Ver Fonte

后端对接api.aimanyi.top 调用openai gpt-4o-mini-tts

sequoia00 há 1 mês atrás
pai
commit
9c33e50004
3 ficheiros alterados com 201 adições e 220 exclusões
  1. 184 188
      main_server.py
  2. 5 0
      reading_progress.json
  3. 12 32
      static/web/viewer.html

+ 184 - 188
main_server.py

@@ -21,6 +21,17 @@ import pymysql
 
 from config import MYSQL_HOST, MYSQL_PORT, MYSQL_USER, MYSQL_PASSWORD, MYSQL_DATABASE
 
+OPENAI_TTS_BASE_URL = os.getenv("OPENAI_TTS_BASE_URL", "https://api.aimanyi.top")
+OPENAI_TTS_API_KEY = os.getenv(
+    "OPENAI_TTS_API_KEY",
+    "sk-re2NlaKIQn11ZNWzAbB6339cEbF94c6aAfC8B7Ab82879bEa",
+)
+OPENAI_TTS_MODEL = os.getenv("OPENAI_TTS_MODEL", "gpt-4o-mini-tts")
+OPENAI_TTS_DEFAULT_VOICE = os.getenv("OPENAI_TTS_DEFAULT_VOICE", "sage")
+OPENAI_TTS_FORMAT = os.getenv("OPENAI_TTS_FORMAT", "wav")
+CLIENT_COOKIE = "reader_pro_client"
+PROGRESS_FILE = "reading_progress.json"
+
 # Set up logging
 logging.basicConfig(level=logging.INFO)
 logger = logging.getLogger(__name__)
@@ -145,6 +156,31 @@ def sanitize_filename(name: str) -> str:
     return "".join(c for c in name if c.isalnum() or c in (" ", ".", "_", "-")).rstrip()
 
 
+def build_file_url(filename: str) -> str:
+    return f"/static/files/{filename}"
+
+
+def get_or_create_client_id(client_id: Optional[str]) -> str:
+    normalized = (client_id or "").strip()
+    return normalized or secrets.token_urlsafe(24)
+
+
+def load_progress_store() -> dict:
+    if not os.path.exists(PROGRESS_FILE):
+        return {}
+    try:
+        with open(PROGRESS_FILE, "r", encoding="utf-8") as f:
+            data = json.load(f)
+        return data if isinstance(data, dict) else {}
+    except Exception:
+        return {}
+
+
+def save_progress_store(data: dict) -> None:
+    with open(PROGRESS_FILE, "w", encoding="utf-8") as f:
+        json.dump(data, f, ensure_ascii=False, indent=2)
+
+
 def get_user_dir(username: str) -> str:
     safe = sanitize_filename(username) or "user"
     path = os.path.join(BASE_STATIC_FILES_DIR, safe)
@@ -217,24 +253,25 @@ def on_startup():
 
 
 @app.get("/")
-def root(session_token: Optional[str] = Cookie(default=None, alias=SESSION_COOKIE)):
-    user = get_user_by_session(session_token)
-    if not user:
-        return RedirectResponse(url="/login", status_code=302)
+def root(client_id: Optional[str] = Cookie(default=None, alias=CLIENT_COOKIE)):
+    current_client_id = get_or_create_client_id(client_id)
+    progress = load_progress_store().get(current_client_id, {})
+    last_file = (progress.get("last_file") or "").strip()
 
-    last_file = user.get("last_file")
     if last_file:
-        return RedirectResponse(url=f"/static/web/viewer.html?file={last_file}", status_code=302)
-
-    # fallback: first file in user's own directory
-    user_dir = get_user_dir(user["username"])
-    files = sorted([f for f in os.listdir(user_dir) if f.lower().endswith(".pdf")])
-    if files:
-        file_url = build_user_file_url(user["username"], files[0])
-        return RedirectResponse(url=f"/static/web/viewer.html?file={file_url}", status_code=302)
+        response = RedirectResponse(url=f"/static/web/viewer.html?file={last_file}", status_code=302)
+    else:
+        files = sorted([f for f in os.listdir(BASE_STATIC_FILES_DIR) if f.lower().endswith(".pdf")])
+        if files:
+            response = RedirectResponse(
+                url=f"/static/web/viewer.html?file={build_file_url(files[0])}",
+                status_code=302,
+            )
+        else:
+            response = RedirectResponse(url="/static/web/viewer.html?file=/static/files/compress.pdf", status_code=302)
 
-    # fallback to sample document
-    return RedirectResponse(url="/static/web/viewer.html?file=/static/files/compress.pdf", status_code=302)
+    response.set_cookie(key=CLIENT_COOKIE, value=current_client_id, httponly=True, samesite="lax", secure=False, path="/")
+    return response
 
 
 @app.get("/login")
@@ -655,9 +692,8 @@ async def admin_set_config(payload: AdminConfigRequest, session_token: Optional[
 async def upload_pdf(
     file: UploadFile = File(...),
     custom_name: str = Form(...),
-    session_token: Optional[str] = Cookie(default=None, alias=SESSION_COOKIE),
+    client_id: Optional[str] = Cookie(default=None, alias=CLIENT_COOKIE),
 ):
-    user = require_user(session_token)
     if file.content_type != "application/pdf":
         raise HTTPException(status_code=400, detail="文件类型必须是 PDF")
 
@@ -666,8 +702,7 @@ async def upload_pdf(
         return JSONResponse(status_code=400, content={"success": False, "error": "无效的文件名"})
 
     unique_filename = f"{sanitized_name}.pdf"
-    user_dir = get_user_dir(user["username"])
-    file_path = os.path.join(user_dir, unique_filename)
+    file_path = os.path.join(BASE_STATIC_FILES_DIR, unique_filename)
 
     if os.path.exists(file_path):
         return JSONResponse(status_code=400, content={"success": False, "error": "文件名已存在,请使用其他名称"})
@@ -680,30 +715,28 @@ async def upload_pdf(
     finally:
         file.file.close()
 
-    file_relative_path = build_user_file_url(user["username"], unique_filename)
-    now = datetime.now(timezone.utc).replace(tzinfo=None)
-    conn = db_conn()
-    try:
-        with conn.cursor() as cur:
-            cur.execute(
-                "UPDATE user SET last_file=%s, last_page=1, updated_at=%s WHERE id=%s",
-                (file_relative_path, now, user["id"]),
-            )
-    finally:
-        conn.close()
+    current_client_id = get_or_create_client_id(client_id)
+    file_relative_path = build_file_url(unique_filename)
+    store = load_progress_store()
+    store[current_client_id] = {
+        "last_file": file_relative_path,
+        "last_page": 1,
+        "updated_at": datetime.now(timezone.utc).isoformat(),
+    }
+    save_progress_store(store)
 
-    return JSONResponse(content={"success": True, "file_path": file_relative_path})
+    response = JSONResponse(content={"success": True, "file_path": file_relative_path})
+    response.set_cookie(key=CLIENT_COOKIE, value=current_client_id, httponly=True, samesite="lax", secure=False, path="/")
+    return response
 
 
 # List PDFs endpoint
 @app.get("/list-pdfs")
-async def list_pdfs(session_token: Optional[str] = Cookie(default=None, alias=SESSION_COOKIE)):
-    user = require_user(session_token)
+async def list_pdfs():
     try:
-        user_dir = get_user_dir(user["username"])
-        files = os.listdir(user_dir)
+        files = os.listdir(BASE_STATIC_FILES_DIR)
         pdf_files = [
-            {"name": file, "url": build_user_file_url(user["username"], file)}
+            {"name": file, "url": build_file_url(file)}
             for file in files
             if file.lower().endswith(".pdf")
         ]
@@ -719,30 +752,21 @@ class ReadingProgressRequest(BaseModel):
 
 
 @app.get("/reading-progress")
-async def get_reading_progress(file: str, session_token: Optional[str] = Cookie(default=None, alias=SESSION_COOKIE)):
-    user = require_user(session_token)
+async def get_reading_progress(file: str, client_id: Optional[str] = Cookie(default=None, alias=CLIENT_COOKIE)):
     normalized_file = (file or "").strip()
     if not normalized_file:
         return JSONResponse(status_code=400, content={"success": False, "error": "缺少 file 参数"})
 
-    conn = db_conn()
-    try:
-        with conn.cursor() as cur:
-            cur.execute(
-                "SELECT page FROM user_progress WHERE user_id=%s AND file_path=%s",
-                (user["id"], normalized_file),
-            )
-            row = cur.fetchone()
-            page = row["page"] if row else None
-    finally:
-        conn.close()
-
-    return JSONResponse(content={"success": True, "file": normalized_file, "page": page})
+    current_client_id = get_or_create_client_id(client_id)
+    progress = load_progress_store().get(current_client_id, {})
+    page = progress.get("last_page") if progress.get("last_file") == normalized_file else None
+    response = JSONResponse(content={"success": True, "file": normalized_file, "page": page})
+    response.set_cookie(key=CLIENT_COOKIE, value=current_client_id, httponly=True, samesite="lax", secure=False, path="/")
+    return response
 
 
 @app.post("/reading-progress")
-async def save_reading_progress(request: ReadingProgressRequest, session_token: Optional[str] = Cookie(default=None, alias=SESSION_COOKIE)):
-    user = require_user(session_token)
+async def save_reading_progress(request: ReadingProgressRequest, client_id: Optional[str] = Cookie(default=None, alias=CLIENT_COOKIE)):
     normalized_file = (request.file or "").strip()
     page = int(request.page)
     if not normalized_file:
@@ -750,31 +774,22 @@ async def save_reading_progress(request: ReadingProgressRequest, session_token:
     if page < 1:
         return JSONResponse(status_code=400, content={"success": False, "error": "page 必须 >= 1"})
 
-    now = datetime.now(timezone.utc).replace(tzinfo=None)
-    conn = db_conn()
-    try:
-        with conn.cursor() as cur:
-            cur.execute(
-                """
-                INSERT INTO user_progress (user_id, file_path, page, updated_at)
-                VALUES (%s, %s, %s, %s)
-                ON DUPLICATE KEY UPDATE page=VALUES(page), updated_at=VALUES(updated_at)
-                """,
-                (user["id"], normalized_file, page, now),
-            )
-            cur.execute(
-                "UPDATE user SET last_file=%s, last_page=%s, updated_at=%s WHERE id=%s",
-                (normalized_file, page, now, user["id"]),
-            )
-    finally:
-        conn.close()
-
-    return JSONResponse(content={"success": True})
+    current_client_id = get_or_create_client_id(client_id)
+    store = load_progress_store()
+    store[current_client_id] = {
+        "last_file": normalized_file,
+        "last_page": page,
+        "updated_at": datetime.now(timezone.utc).isoformat(),
+    }
+    save_progress_store(store)
+    response = JSONResponse(content={"success": True})
+    response.set_cookie(key=CLIENT_COOKIE, value=current_client_id, httponly=True, samesite="lax", secure=False, path="/")
+    return response
 
 
 class TextToSpeechRequest(BaseModel):
     user_input: str
-    voice: str = "af_heart"
+    voice: str = OPENAI_TTS_DEFAULT_VOICE
     speed: float = 1.0
 
 
@@ -786,24 +801,83 @@ async def generate_proxy(request: TextToSpeechRequest):
 
     async def stream_generator() -> AsyncGenerator[bytes, None]:
         try:
-            async with aiohttp.ClientSession() as session:
-                async with session.post(
-                    "http://141.140.15.30:8028/generate",
-                    headers={"Content-Type": "application/json"},
-                    json={"text": user_input, "voice": request.voice, "speed": request.speed},
-                ) as response:
-                    if response.status != 200:
-                        raise HTTPException(status_code=500, detail="TTS API 请求失败")
-
-                    async for chunk in response.content.iter_any():
-                        yield chunk
+            chunks = split_text_into_chunks(user_input)
+            for index, chunk in enumerate(chunks):
+                audio_bytes = await request_openai_tts_audio(chunk, request.voice)
+                payload = {
+                    "index": index,
+                    "text": chunk,
+                    "audio": base64.b64encode(audio_bytes).decode("utf-8"),
+                }
+                yield (json.dumps(payload, ensure_ascii=False) + "\n").encode("utf-8")
+                await asyncio.sleep(0)
+        except HTTPException as e:
+            logger.error("generate proxy http error: %s", e.detail)
+            yield (json.dumps({"error": e.detail}, ensure_ascii=False) + "\n").encode("utf-8")
         except Exception as e:
             logger.error(f"generate proxy error: {str(e)}")
-            raise HTTPException(status_code=500, detail=str(e))
+            yield (json.dumps({"error": "TTS生成失败"}, ensure_ascii=False) + "\n").encode("utf-8")
 
     return StreamingResponse(stream_generator(), media_type="application/x-ndjson")
 
 
+def normalize_openai_voice(voice: str) -> str:
+    allowed_voices = {"alloy", "ash", "ballad", "coral", "echo", "sage", "shimmer", "verse"}
+    normalized = (voice or "").strip().lower()
+    return normalized if normalized in allowed_voices else OPENAI_TTS_DEFAULT_VOICE
+
+
+def get_audio_media_type(audio_format: str) -> str:
+    mapping = {
+        "wav": "audio/wav",
+        "mp3": "audio/mpeg",
+        "flac": "audio/flac",
+        "opus": "audio/opus",
+        "pcm16": "audio/L16",
+    }
+    return mapping.get(audio_format.lower(), "application/octet-stream")
+
+
+async def request_openai_tts_audio(text: str, voice: str) -> bytes:
+    payload = {
+        "model": OPENAI_TTS_MODEL,
+        "voice": normalize_openai_voice(voice),
+        "input": text,
+        "response_format": OPENAI_TTS_FORMAT,
+        "speed": 1.0,
+    }
+
+    headers = {
+        "Authorization": f"Bearer {OPENAI_TTS_API_KEY}",
+        "Content-Type": "application/json",
+    }
+
+    async with aiohttp.ClientSession() as session:
+        async with session.post(
+            f"{OPENAI_TTS_BASE_URL.rstrip('/')}/v1/audio/speech",
+            headers=headers,
+            json=payload,
+        ) as response:
+            if response.status != 200:
+                response_text = await response.text()
+                logger.error("OpenAI TTS request failed: %s", response_text)
+                error_detail = "OpenAI TTS API 请求失败"
+                try:
+                    error_data = json.loads(response_text)
+                    error_obj = error_data.get("error", {})
+                    error_message = error_obj.get("message")
+                    error_code = error_obj.get("code")
+                    if error_message:
+                        error_detail = f"{error_detail}: {error_message}"
+                    if response.status == 429 or error_code in {"rate_limit_exceeded", "model_not_found", "upstream_error"}:
+                        raise HTTPException(status_code=503, detail=error_detail)
+                except json.JSONDecodeError:
+                    pass
+                raise HTTPException(status_code=500 if response.status < 500 else 502, detail=error_detail)
+
+            return await response.read()
+
+
 @app.post("/text-to-speech/")
 async def text_to_speech(request: TextToSpeechRequest):
     user_input = request.user_input.strip()
@@ -811,65 +885,23 @@ async def text_to_speech(request: TextToSpeechRequest):
         raise HTTPException(status_code=400, detail="输入文本为空")
 
     text_hash = hashlib.md5(user_input.encode("utf-8")).hexdigest()
-    audio_path = os.path.join(CACHE_DIR, f"{text_hash}.wav")
+    audio_path = os.path.join(CACHE_DIR, f"{text_hash}.{OPENAI_TTS_FORMAT}")
+    media_type = get_audio_media_type(OPENAI_TTS_FORMAT)
 
     if os.path.exists(audio_path):
         with open(audio_path, "rb") as f:
-            return Response(content=f.read(), media_type="audio/wav")
+            return Response(content=f.read(), media_type=media_type)
 
-    async def audio_generator() -> AsyncGenerator[bytes, None]:
-        try:
-            async with aiohttp.ClientSession() as session:
-                async with session.post(
-                    "http://141.140.15.30:8028/generate",
-                    headers={"Content-Type": "application/json"},
-                    json={"text": user_input, "voice": request.voice, "speed": request.speed},
-                ) as response:
-                    if response.status != 200:
-                        raise HTTPException(status_code=500, detail="TTS API 请求失败")
-
-                    buffer = ""
-                    full_audio = io.BytesIO()
-                    async for chunk in response.content.iter_any():
-                        buffer += chunk.decode("utf-8")
-                        lines = buffer.split("\n")
-                        buffer = lines[-1]
-
-                        for line in lines[:-1]:
-                            if not line.strip():
-                                continue
-                            try:
-                                data = json.loads(line)
-                                if data.get("error"):
-                                    raise HTTPException(status_code=500, detail=data["error"])
-                                audio_b64 = data.get("audio")
-                                if audio_b64:
-                                    audio_bytes = base64.b64decode(audio_b64)
-                                    full_audio.write(audio_bytes)
-                                    yield audio_bytes
-                            except json.JSONDecodeError as e:
-                                logger.error(f"JSON decode error: {str(e)}")
-                                continue
-
-                    if buffer.strip():
-                        try:
-                            data = json.loads(buffer)
-                            if data.get("audio"):
-                                audio_bytes = base64.b64decode(data["audio"])
-                                full_audio.write(audio_bytes)
-                                yield audio_bytes
-                        except json.JSONDecodeError:
-                            pass
-
-                    full_audio.seek(0)
-                    with open(audio_path, "wb") as f:
-                        f.write(full_audio.getvalue())
-
-        except Exception as e:
-            logger.error(f"TTS error: {str(e)}")
-            raise HTTPException(status_code=500, detail=str(e))
-
-    return StreamingResponse(audio_generator(), media_type="audio/wav")
+    try:
+        audio_bytes = await request_openai_tts_audio(user_input, request.voice)
+        with open(audio_path, "wb") as f:
+            f.write(audio_bytes)
+        return Response(content=audio_bytes, media_type=media_type)
+    except HTTPException:
+        raise
+    except Exception as e:
+        logger.error(f"TTS error: {str(e)}")
+        raise HTTPException(status_code=500, detail="TTS生成失败")
 
 
 MAX_CHUNK_SIZE = 200
@@ -900,56 +932,19 @@ def split_text_into_chunks(text: str, max_chunk_size: int = MAX_CHUNK_SIZE) -> l
 
 async def generate_api_audio(chunk: str, voice: str, speed: float) -> AsyncGenerator[bytes, None]:
     text_hash = hashlib.md5(chunk.encode("utf-8")).hexdigest()
-    audio_path = os.path.join(CACHE_DIR, f"{text_hash}.wav")
+    audio_path = os.path.join(CACHE_DIR, f"{text_hash}.{OPENAI_TTS_FORMAT}")
 
     if os.path.exists(audio_path):
         with open(audio_path, "rb") as f:
             yield f.read()
     else:
         try:
-            async with aiohttp.ClientSession() as session:
-                async with session.post(
-                    "http://141.140.15.30:8028/generate",
-                    headers={"Content-Type": "application/json"},
-                    json={"text": chunk, "voice": voice, "speed": speed},
-                ) as response:
-                    if response.status != 200:
-                        raise HTTPException(status_code=500, detail="TTS API 请求失败")
-
-                    buffer = ""
-                    async for part in response.content.iter_any():
-                        buffer += part.decode("utf-8")
-                        lines = buffer.split("\n")
-                        buffer = lines[-1]
-
-                        for line in lines[:-1]:
-                            if not line.strip():
-                                continue
-                            try:
-                                data = json.loads(line)
-                                if data.get("error"):
-                                    raise HTTPException(status_code=500, detail=data["error"])
-                                audio_b64 = data.get("audio")
-                                if audio_b64:
-                                    audio_bytes = base64.b64decode(audio_b64)
-                                    yield audio_bytes
-                                    with open(audio_path, "wb") as f:
-                                        f.write(audio_bytes)
-                            except json.JSONDecodeError as e:
-                                logger.error(f"JSON decode error: {str(e)}")
-                                continue
-
-                    if buffer.strip():
-                        try:
-                            data = json.loads(buffer)
-                            if data.get("audio"):
-                                audio_bytes = base64.b64decode(data["audio"])
-                                yield audio_bytes
-                                with open(audio_path, "wb") as f:
-                                    f.write(audio_bytes)
-                        except json.JSONDecodeError:
-                            pass
-
+            audio_bytes = await request_openai_tts_audio(chunk, voice)
+            with open(audio_path, "wb") as f:
+                f.write(audio_bytes)
+            yield audio_bytes
+        except HTTPException as e:
+            raise HTTPException(status_code=e.status_code, detail=e.detail)
         except Exception as e:
             raise HTTPException(status_code=500, detail=f"TTS生成失败: {str(e)}")
 
@@ -961,10 +956,11 @@ async def page_to_speech(request: TextToSpeechRequest):
         raise HTTPException(status_code=400, detail="输入文本为空")
 
     full_text_hash = hashlib.md5(user_input.encode("utf-8")).hexdigest()
-    full_audio_path = os.path.join(CACHE_DIR, f"{full_text_hash}_full.wav")
+    full_audio_path = os.path.join(CACHE_DIR, f"{full_text_hash}_full.{OPENAI_TTS_FORMAT}")
+    media_type = get_audio_media_type(OPENAI_TTS_FORMAT)
 
     if os.path.exists(full_audio_path):
-        return StreamingResponse(open(full_audio_path, "rb"), media_type="audio/wav")
+        return StreamingResponse(open(full_audio_path, "rb"), media_type=media_type)
 
     chunks = split_text_into_chunks(user_input)
 
@@ -980,7 +976,7 @@ async def page_to_speech(request: TextToSpeechRequest):
         with open(full_audio_path, "wb") as f:
             f.write(full_audio_buffer.getvalue())
 
-    return StreamingResponse(audio_generator(), media_type="audio/wav")
+    return StreamingResponse(audio_generator(), media_type=media_type)
 
 
 @app.get("/health")

+ 5 - 0
reading_progress.json

@@ -1,5 +1,10 @@
 {
   "/static/files/find.pdf": {
     "page": 13
+  },
+  "9JpcCku5tt1Mch8-Lr8CsV2ADhKz9SIp": {
+    "last_file": "/static/files/study/compress.pdf",
+    "last_page": 11,
+    "updated_at": "2026-04-30T02:26:34.047449+00:00"
   }
 }

+ 12 - 32
static/web/viewer.html

@@ -1618,26 +1618,10 @@
                     setNightMode(enabled);
                 }
 
-                async function ensureLoggedIn() {
-                    try {
-                        const resp = await fetch('/auth/me');
-                        if (!resp.ok) {
-                            window.location.href = '/login';
-                            return false;
-                        }
-                        const data = await resp.json();
-                        if (!data || !data.success) {
-                            window.location.href = '/login';
-                            return false;
-                        }
-                        userMenuTrigger.textContent = data.username || '用户';
-                        userAdminBtn.style.display = data.is_admin ? 'block' : 'none';
-                        return true;
-                    } catch (e) {
-                        window.location.href = '/login';
-                        return false;
-                    }
-                }
+                function initAnonymousReader() {
+                    userMenuTrigger.textContent = '本机阅读';
+                    userAdminBtn.style.display = 'none';
+                }
 
                 userMenuTrigger.addEventListener('click', function (event) {
                     event.stopPropagation();
@@ -1648,16 +1632,12 @@
                         userMenu.classList.remove('open');
                     }
                 });
-                userLogoutBtn.addEventListener('click', async function () {
-                    try {
-                        await fetch('/auth/logout', { method: 'POST' });
-                    } finally {
-                        window.location.href = '/login';
-                    }
-                });
-                userAdminBtn.addEventListener('click', function () {
-                    window.location.href = '/admin';
-                });
+                userLogoutBtn.addEventListener('click', function () {
+                    userMenu.classList.remove('open');
+                });
+                userAdminBtn.addEventListener('click', function () {
+                    userMenu.classList.remove('open');
+                });
                 initNightMode();
                 initReaderControls();
                 nightModeButton.addEventListener('click', function () {
@@ -1746,8 +1726,8 @@
                     };
                     bindEvents();
                 }
-                setupReadingProgressSync();
-                ensureLoggedIn();
+                setupReadingProgressSync();
+                initAnonymousReader();
 
                 // 上传按钮点击事件
                 uploadButton.addEventListener('click', function () {