"""PDF reader backend: PDF upload/listing, per-client reading progress, and an
OpenAI-compatible text-to-speech proxy with an on-disk audio cache."""

import asyncio
import base64
import contextlib
import hashlib
import io
import json
import logging
import os
import re
import secrets
import shutil
import tempfile
from datetime import datetime, timezone
from typing import AsyncGenerator, Iterable, Iterator, Optional
from urllib.parse import quote

import aiohttp
from fastapi import FastAPI, File, UploadFile, HTTPException, Form, Response, Cookie
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse, RedirectResponse, StreamingResponse
from fastapi.staticfiles import StaticFiles
from pydantic import BaseModel

from config import (
    OPENAI_TTS_BASE_URL,
    OPENAI_TTS_API_KEY,
    OPENAI_TTS_MODEL,
    OPENAI_TTS_DEFAULT_VOICE,
    OPENAI_TTS_FORMAT,
)

CLIENT_COOKIE = "reader_pro_client"
PROGRESS_FILE = "reading_progress.json"

# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Initialize FastAPI app
app = FastAPI()
tts_http_session: aiohttp.ClientSession | None = None

# Configure CORS
# NOTE(review): wildcard origins combined with allow_credentials=True lets any
# site issue credentialed requests; confirm this is intended before deploying
# outside a trusted network.
origins = ["*"]
app.add_middleware(
    CORSMiddleware,
    allow_origins=origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Base directories
BASE_STATIC_FILES_DIR = "static/files"
os.makedirs(BASE_STATIC_FILES_DIR, exist_ok=True)

# Mount static files
app.mount("/static/files", StaticFiles(directory=BASE_STATIC_FILES_DIR), name="static_files")
app.mount("/static/web", StaticFiles(directory="static/web"), name="static_web")
app.mount("/static", StaticFiles(directory="static"), name="static")

# Audio cache directory
CACHE_DIR = "audio_cache"
os.makedirs(CACHE_DIR, exist_ok=True)


# ---------------------------------------------------------------------------
# Small helpers
# ---------------------------------------------------------------------------

def sanitize_filename(name: str) -> str:
    """Keep only alphanumerics, space, '.', '_' and '-'; strip trailing whitespace."""
    return "".join(c for c in name if c.isalnum() or c in (" ", ".", "_", "-")).rstrip()


def build_file_url(filename: str) -> str:
    """Return the static URL path under which an uploaded file is served."""
    return f"/static/files/{filename}"


def get_or_create_client_id(client_id: Optional[str]) -> str:
    """Return the stripped cookie value, or a fresh random id when it is empty."""
    normalized = (client_id or "").strip()
    return normalized or secrets.token_urlsafe(24)


def _set_client_cookie(response: Response, client_id: str) -> None:
    """Attach the client-id cookie to *response* (same attributes on every endpoint)."""
    response.set_cookie(
        key=CLIENT_COOKIE,
        value=client_id,
        httponly=True,
        samesite="lax",
        secure=False,
        path="/",
    )


def _viewer_redirect_url(file_url: str) -> str:
    """Build the viewer URL for *file_url*.

    The value is percent-encoded so that characters with a meaning inside a
    query string ('+', '&', '#', '=') cannot truncate or alter the ``file``
    parameter.
    """
    return f"/static/web/viewer.html?file={quote(file_url, safe='/')}"


def _atomic_write_bytes(path: str, data: bytes) -> None:
    """Write *data* to *path* via a temp file + ``os.replace``.

    Readers therefore see either the previous content or the complete new
    content, never a partially written file.
    """
    directory = os.path.dirname(path) or "."
    fd, tmp_path = tempfile.mkstemp(dir=directory, prefix=".tmp-", suffix=".tmp")
    try:
        with os.fdopen(fd, "wb") as tmp:
            tmp.write(data)
        os.replace(tmp_path, path)
    except BaseException:
        # Do not leave temp files behind when the write or rename fails.
        with contextlib.suppress(OSError):
            os.remove(tmp_path)
        raise


# ---------------------------------------------------------------------------
# Reading-progress store (single JSON file keyed by client id)
# ---------------------------------------------------------------------------

def load_progress_store() -> dict:
    """Load the progress store; return ``{}`` when it is missing or unreadable."""
    try:
        with open(PROGRESS_FILE, "r", encoding="utf-8") as f:
            data = json.load(f)
    except FileNotFoundError:
        return {}
    except (OSError, ValueError) as exc:
        # Best effort: a corrupt store must not take the service down, but it
        # should be visible in the logs because the next save overwrites it.
        logger.warning("could not read %s: %s", PROGRESS_FILE, exc)
        return {}
    return data if isinstance(data, dict) else {}


def save_progress_store(data: dict) -> None:
    """Persist the progress store atomically as pretty-printed UTF-8 JSON."""
    serialized = json.dumps(data, ensure_ascii=False, indent=2)
    _atomic_write_bytes(PROGRESS_FILE, serialized.encode("utf-8"))


def _get_client_progress(client_id: str) -> dict:
    """Return the stored progress entry for *client_id* (``{}`` if absent/malformed)."""
    entry = load_progress_store().get(client_id, {})
    return entry if isinstance(entry, dict) else {}


# ---------------------------------------------------------------------------
# Application lifecycle
# ---------------------------------------------------------------------------

def _create_tts_session() -> aiohttp.ClientSession:
    """Create the shared aiohttp session used for upstream TTS requests."""
    timeout = aiohttp.ClientTimeout(total=120)
    connector = aiohttp.TCPConnector(limit=20, ttl_dns_cache=300)
    return aiohttp.ClientSession(timeout=timeout, connector=connector)


@app.on_event("startup")
async def startup_event():
    """Open the shared upstream HTTP session."""
    global tts_http_session
    tts_http_session = _create_tts_session()


@app.on_event("shutdown")
async def shutdown_event():
    """Close the shared upstream HTTP session if it is still open."""
    global tts_http_session
    if tts_http_session and not tts_http_session.closed:
        await tts_http_session.close()
    tts_http_session = None


# ---------------------------------------------------------------------------
# Viewer entry point, upload and listing
# ---------------------------------------------------------------------------

@app.get("/")
def root(client_id: Optional[str] = Cookie(default=None, alias=CLIENT_COOKIE)):
    """Redirect to the viewer.

    Target, in order of preference: the client's last opened file, the
    alphabetically first uploaded PDF, or the bundled ``compress.pdf``.
    """
    current_client_id = get_or_create_client_id(client_id)
    progress = _get_client_progress(current_client_id)
    last_file = str(progress.get("last_file") or "").strip()

    if last_file:
        target = last_file
    else:
        files = sorted(
            f for f in os.listdir(BASE_STATIC_FILES_DIR) if f.lower().endswith(".pdf")
        )
        target = build_file_url(files[0]) if files else "/static/files/compress.pdf"

    response = RedirectResponse(url=_viewer_redirect_url(target), status_code=302)
    _set_client_cookie(response, current_client_id)
    return response


# PDF upload endpoint
@app.post("/upload-pdf")
async def upload_pdf(
    file: UploadFile = File(...),
    custom_name: str = Form(...),
    client_id: Optional[str] = Cookie(default=None, alias=CLIENT_COOKIE),
):
    """Store an uploaded PDF under a sanitized name and make it the client's current file.

    Raises:
        HTTPException: 400 when the content type is not PDF, 500 when writing fails.
    """
    if file.content_type != "application/pdf":
        raise HTTPException(status_code=400, detail="文件类型必须是 PDF")

    sanitized_name = sanitize_filename(custom_name)
    if not sanitized_name:
        return JSONResponse(status_code=400, content={"success": False, "error": "无效的文件名"})

    unique_filename = f"{sanitized_name}.pdf"
    file_path = os.path.join(BASE_STATIC_FILES_DIR, unique_filename)

    try:
        # "xb" creates the file exclusively, so two concurrent uploads with the
        # same name cannot overwrite each other (no check-then-open race).
        with open(file_path, "xb") as buffer:
            shutil.copyfileobj(file.file, buffer)
    except FileExistsError:
        return JSONResponse(
            status_code=400,
            content={"success": False, "error": "文件名已存在,请使用其他名称"},
        )
    except Exception as exc:
        logger.exception("upload failed for %s", file_path)
        # Remove the partial file; otherwise the name stays taken by a broken PDF.
        with contextlib.suppress(OSError):
            os.remove(file_path)
        raise HTTPException(status_code=500, detail="上传过程中出错") from exc
    finally:
        file.file.close()

    current_client_id = get_or_create_client_id(client_id)
    file_relative_path = build_file_url(unique_filename)

    store = load_progress_store()
    store[current_client_id] = {
        "last_file": file_relative_path,
        "last_page": 1,
        "updated_at": datetime.now(timezone.utc).isoformat(),
    }
    save_progress_store(store)

    response = JSONResponse(content={"success": True, "file_path": file_relative_path})
    _set_client_cookie(response, current_client_id)
    return response


# List PDFs endpoint
@app.get("/list-pdfs")
async def list_pdfs():
    """Return all uploaded PDFs (name + URL), sorted case-insensitively by name."""
    try:
        pdf_files = [
            {"name": name, "url": build_file_url(name)}
            for name in os.listdir(BASE_STATIC_FILES_DIR)
            if name.lower().endswith(".pdf")
        ]
        pdf_files.sort(key=lambda entry: entry["name"].lower())
        return JSONResponse(content={"success": True, "files": pdf_files})
    except Exception as exc:
        logger.exception("could not list %s", BASE_STATIC_FILES_DIR)
        raise HTTPException(status_code=500, detail="无法获取文件列表") from exc


# ---------------------------------------------------------------------------
# Reading progress endpoints
# ---------------------------------------------------------------------------

class ReadingProgressRequest(BaseModel):
    """Body of ``POST /reading-progress``."""

    file: str
    page: int


@app.get("/reading-progress")
async def get_reading_progress(
    file: str,
    client_id: Optional[str] = Cookie(default=None, alias=CLIENT_COOKIE),
):
    """Return the saved page for *file*, or ``null`` when it is not the client's last file."""
    normalized_file = (file or "").strip()
    if not normalized_file:
        return JSONResponse(status_code=400, content={"success": False, "error": "缺少 file 参数"})

    current_client_id = get_or_create_client_id(client_id)
    progress = _get_client_progress(current_client_id)
    page = progress.get("last_page") if progress.get("last_file") == normalized_file else None

    response = JSONResponse(content={"success": True, "file": normalized_file, "page": page})
    _set_client_cookie(response, current_client_id)
    return response


@app.post("/reading-progress")
async def save_reading_progress(
    request: ReadingProgressRequest,
    client_id: Optional[str] = Cookie(default=None, alias=CLIENT_COOKIE),
):
    """Record the client's current file and page (page must be >= 1)."""
    normalized_file = (request.file or "").strip()
    page = int(request.page)
    if not normalized_file:
        return JSONResponse(status_code=400, content={"success": False, "error": "file 不能为空"})
    if page < 1:
        return JSONResponse(status_code=400, content={"success": False, "error": "page 必须 >= 1"})

    current_client_id = get_or_create_client_id(client_id)
    store = load_progress_store()
    store[current_client_id] = {
        "last_file": normalized_file,
        "last_page": page,
        "updated_at": datetime.now(timezone.utc).isoformat(),
    }
    save_progress_store(store)

    response = JSONResponse(content={"success": True})
    _set_client_cookie(response, current_client_id)
    return response


# ---------------------------------------------------------------------------
# Text-to-speech
# ---------------------------------------------------------------------------

class TextToSpeechRequest(BaseModel):
    """Body of the TTS endpoints.

    ``speed`` is accepted but not forwarded upstream: the upstream request is
    always sent with speed 1.0.
    """

    user_input: str
    voice: str = OPENAI_TTS_DEFAULT_VOICE
    speed: float = 1.0


def _ndjson_line(payload: dict) -> bytes:
    """Serialize *payload* as one UTF-8 NDJSON line."""
    return (json.dumps(payload, ensure_ascii=False) + "\n").encode("utf-8")


def _audio_payload(index: int, text: str, audio: bytes) -> dict:
    """Build the per-chunk NDJSON record emitted by ``/generate``."""
    return {
        "index": index,
        "text": text,
        "audio": base64.b64encode(audio).decode("utf-8"),
        "format": OPENAI_TTS_FORMAT,
    }


def _discard_tasks(tasks: Iterable["asyncio.Task"]) -> None:
    """Cancel unfinished tasks and consume results of finished ones.

    Reading ``exception()`` on a finished task marks its exception as
    retrieved, which avoids "Task exception was never retrieved" warnings.
    """
    for task in tasks:
        if not task.done():
            task.cancel()
        elif not task.cancelled():
            task.exception()


@app.post("/generate")
async def generate_proxy(request: TextToSpeechRequest):
    """Stream synthesized audio as NDJSON, one record per text chunk.

    The first chunk is requested on its own so playback can start early; the
    remaining chunks are requested concurrently and emitted in order. Errors
    after streaming has begun are reported as an ``{"error": ...}`` record.
    """
    user_input = request.user_input.strip()
    if not user_input:
        raise HTTPException(status_code=400, detail="输入文本为空")

    async def stream_generator() -> AsyncGenerator[bytes, None]:
        pending_tasks: dict[int, asyncio.Task] = {}
        try:
            chunks = split_text_into_chunks(user_input)
            if not chunks:
                return

            first_audio = await request_openai_tts_audio(chunks[0], request.voice)
            yield _ndjson_line(_audio_payload(0, chunks[0], first_audio))
            await asyncio.sleep(0)

            pending_tasks = {
                index: asyncio.create_task(request_openai_tts_audio(chunk, request.voice))
                for index, chunk in enumerate(chunks[1:], start=1)
            }
            for index in range(1, len(chunks)):
                audio_bytes = await pending_tasks[index]
                yield _ndjson_line(_audio_payload(index, chunks[index], audio_bytes))
                await asyncio.sleep(0)
        except HTTPException as e:
            logger.error("generate proxy http error: %s", e.detail)
            yield _ndjson_line({"error": e.detail})
        except Exception as e:
            logger.exception("generate proxy error: %s", e)
            yield _ndjson_line({"error": "TTS生成失败"})
        finally:
            # On failure or client disconnect the look-ahead requests are no
            # longer needed; stop them instead of letting them run to completion.
            _discard_tasks(pending_tasks.values())

    return StreamingResponse(stream_generator(), media_type="application/x-ndjson")


def normalize_openai_voice(voice: str) -> str:
    """Return *voice* lower-cased if it is a known voice, else the configured default."""
    allowed_voices = {"alloy", "ash", "ballad", "coral", "echo", "sage", "shimmer", "verse"}
    normalized = (voice or "").strip().lower()
    return normalized if normalized in allowed_voices else OPENAI_TTS_DEFAULT_VOICE


def get_audio_media_type(audio_format: str) -> str:
    """Map an audio format name to its MIME type (octet-stream when unknown)."""
    mapping = {
        "wav": "audio/wav",
        "mp3": "audio/mpeg",
        "flac": "audio/flac",
        "opus": "audio/opus",
        "pcm16": "audio/L16",
    }
    return mapping.get(audio_format.lower(), "application/octet-stream")


# Upstream error codes that are reported to the client as 503.
_UNAVAILABLE_ERROR_CODES = frozenset({"rate_limit_exceeded", "model_not_found", "upstream_error"})


def _parse_openai_error(response_text: str) -> tuple[Optional[str], Optional[str]]:
    """Extract ``(message, code)`` from an upstream error body.

    Tolerates non-JSON bodies, non-object JSON, and ``"error"`` given as a
    plain string; anything unrecognized yields ``None``.
    """
    try:
        data = json.loads(response_text)
    except ValueError:
        return None, None

    error_obj = data.get("error") if isinstance(data, dict) else None
    if isinstance(error_obj, str):
        return error_obj or None, None
    if not isinstance(error_obj, dict):
        return None, None

    message = error_obj.get("message")
    code = error_obj.get("code")
    return (str(message) if message else None, code if isinstance(code, str) else None)


async def request_openai_tts_audio(text: str, voice: str) -> bytes:
    """Request speech audio for *text* from the upstream TTS API.

    Args:
        text: Text to synthesize.
        voice: Requested voice; unknown values fall back to the default voice.

    Returns:
        The raw audio bytes in ``OPENAI_TTS_FORMAT``.

    Raises:
        HTTPException: 503 for upstream 429 / unavailable error codes, 502 for
            other upstream 5xx responses, 500 for any other non-200 response.
    """
    global tts_http_session
    payload = {
        "model": OPENAI_TTS_MODEL,
        "voice": normalize_openai_voice(voice),
        "input": text,
        "response_format": OPENAI_TTS_FORMAT,
        "speed": 1.0,
    }
    headers = {
        "Authorization": f"Bearer {OPENAI_TTS_API_KEY}",
        "Content-Type": "application/json",
    }

    # The session is normally created at startup; recreate it lazily so the
    # function also works when called outside the app lifecycle.
    if tts_http_session is None or tts_http_session.closed:
        tts_http_session = _create_tts_session()

    async with tts_http_session.post(
        f"{OPENAI_TTS_BASE_URL.rstrip('/')}/v1/audio/speech",
        headers=headers,
        json=payload,
    ) as response:
        if response.status != 200:
            response_text = await response.text()
            logger.error("OpenAI TTS request failed: %s", response_text)

            error_detail = "OpenAI TTS API 请求失败"
            error_message, error_code = _parse_openai_error(response_text)
            if error_message:
                error_detail = f"{error_detail}: {error_message}"

            # Checked independently of body parsing so that a 429 with a
            # non-JSON body is still reported as "temporarily unavailable".
            if response.status == 429 or error_code in _UNAVAILABLE_ERROR_CODES:
                raise HTTPException(status_code=503, detail=error_detail)
            raise HTTPException(
                status_code=500 if response.status < 500 else 502,
                detail=error_detail,
            )
        return await response.read()


def _audio_cache_path(text: str, voice: str, suffix: str = "") -> str:
    """Return the cache file path for *text* spoken with *voice*.

    The normalized voice is part of the key so that the same text requested
    with different voices does not share a cache entry.
    """
    key = f"{normalize_openai_voice(voice)}\n{text}"
    digest = hashlib.md5(key.encode("utf-8")).hexdigest()
    return os.path.join(CACHE_DIR, f"{digest}{suffix}.{OPENAI_TTS_FORMAT}")


def _read_cached_audio(path: str) -> Optional[bytes]:
    """Return the content of a cache file, or ``None`` when it does not exist."""
    try:
        with open(path, "rb") as f:
            return f.read()
    except FileNotFoundError:
        return None


def _iter_file_chunks(handle: io.BufferedIOBase, chunk_size: int = 64 * 1024) -> Iterator[bytes]:
    """Yield *handle*'s content in fixed-size chunks and close it afterwards."""
    with handle:
        while chunk := handle.read(chunk_size):
            yield chunk


@app.post("/text-to-speech/")
async def text_to_speech(request: TextToSpeechRequest):
    """Synthesize the whole input in one upstream request, using the disk cache."""
    user_input = request.user_input.strip()
    if not user_input:
        raise HTTPException(status_code=400, detail="输入文本为空")

    audio_path = _audio_cache_path(user_input, request.voice)
    media_type = get_audio_media_type(OPENAI_TTS_FORMAT)

    cached_audio = _read_cached_audio(audio_path)
    if cached_audio is not None:
        return Response(content=cached_audio, media_type=media_type)

    try:
        audio_bytes = await request_openai_tts_audio(user_input, request.voice)
        _atomic_write_bytes(audio_path, audio_bytes)
        return Response(content=audio_bytes, media_type=media_type)
    except HTTPException:
        raise
    except Exception as e:
        logger.exception("TTS error: %s", e)
        raise HTTPException(status_code=500, detail="TTS生成失败") from e


FIRST_CHUNK_SIZE = 80
FOLLOWING_CHUNK_SIZE = 180

_SENTENCE_BOUNDARY_RE = re.compile(r"(?<=[.!?])\s+")
_CLAUSE_BREAK_RE = re.compile(r"[,:;)\]]\s+")
_WHITESPACE_RE = re.compile(r"\s+")


def _find_split_point(candidate: str) -> int:
    """Return the index at which to cut *candidate*, or -1 when there is no good spot.

    Preference: just after the last clause punctuation followed by whitespace
    (so the punctuation stays with the earlier part), otherwise at the last
    whitespace run. Position 0 is never returned.
    """
    last_clause = None
    for match in _CLAUSE_BREAK_RE.finditer(candidate):
        if match.start() > 0:
            last_clause = match
    if last_clause is not None:
        return last_clause.start() + 1

    last_space = None
    for match in _WHITESPACE_RE.finditer(candidate):
        last_space = match
    if last_space is not None and last_space.start() > 0:
        return last_space.start()
    return -1


def _split_long_sentence(sentence: str, chunk_size: int) -> list[str]:
    """Split one over-long sentence into parts of at most *chunk_size* characters."""
    parts: list[str] = []
    remaining = sentence.strip()
    while len(remaining) > chunk_size:
        # Look one character past the limit so a break exactly at the limit is found.
        split_at = _find_split_point(remaining[: chunk_size + 1])
        if split_at <= 0:
            split_at = chunk_size  # no natural break: hard cut
        # `remaining` is stripped and split_at >= 1, so the part is never empty.
        parts.append(remaining[:split_at].strip())
        remaining = remaining[split_at:].strip()
    if remaining:
        parts.append(remaining)
    return parts


def split_text_into_chunks(
    text: str,
    first_chunk_size: int = FIRST_CHUNK_SIZE,
    following_chunk_size: int = FOLLOWING_CHUNK_SIZE,
) -> list:
    """Split *text* into TTS-sized chunks along sentence boundaries.

    Sentences (split after '.', '!' or '?' followed by whitespace) are packed
    greedily. The limit is *first_chunk_size* until the first chunk has been
    emitted and *following_chunk_size* afterwards. A sentence longer than the
    current limit is broken at clause punctuation or whitespace.

    Returns:
        The list of non-empty chunk strings, in reading order.

    Raises:
        ValueError: If either chunk size is smaller than 1 (it would never
            make progress).
    """
    if first_chunk_size < 1 or following_chunk_size < 1:
        raise ValueError("chunk sizes must be >= 1")

    chunks: list[str] = []
    current_chunk = ""
    current_limit = first_chunk_size

    for raw_sentence in _SENTENCE_BOUNDARY_RE.split(text.strip()):
        sentence = raw_sentence.strip()
        if not sentence:
            continue

        if len(sentence) > current_limit:
            # Flush what has been collected, then break the long sentence up.
            if current_chunk:
                chunks.append(current_chunk)
                current_chunk = ""
                current_limit = following_chunk_size
            long_parts = _split_long_sentence(sentence, current_limit)
            chunks.extend(long_parts)
            if long_parts:
                current_limit = following_chunk_size
            continue

        next_chunk = f"{current_chunk} {sentence}".strip() if current_chunk else sentence
        if len(next_chunk) <= current_limit:
            current_chunk = next_chunk
        else:
            if current_chunk:
                chunks.append(current_chunk)
                current_limit = following_chunk_size
            current_chunk = sentence

    if current_chunk:
        chunks.append(current_chunk)
    return chunks


async def generate_api_audio(chunk: str, voice: str, speed: float) -> AsyncGenerator[bytes, None]:
    """Yield the audio for one text chunk, from the disk cache when available.

    ``speed`` is accepted for interface compatibility and is not used.

    Raises:
        HTTPException: Propagated from the upstream request, or 500 for any
            other failure.
    """
    audio_path = _audio_cache_path(chunk, voice)

    cached_audio = _read_cached_audio(audio_path)
    if cached_audio is not None:
        yield cached_audio
        return

    try:
        audio_bytes = await request_openai_tts_audio(chunk, voice)
        _atomic_write_bytes(audio_path, audio_bytes)
    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"TTS生成失败: {str(e)}") from e
    yield audio_bytes


@app.post("/page-to-speech/")
async def page_to_speech(request: TextToSpeechRequest):
    """Stream audio for a whole page chunk by chunk and cache the concatenation."""
    user_input = request.user_input.strip()
    if not user_input:
        raise HTTPException(status_code=400, detail="输入文本为空")

    full_audio_path = _audio_cache_path(user_input, request.voice, suffix="_full")
    media_type = get_audio_media_type(OPENAI_TTS_FORMAT)

    try:
        cached_file = open(full_audio_path, "rb")
    except FileNotFoundError:
        cached_file = None
    if cached_file is not None:
        # Stream in fixed-size chunks; the generator closes the handle when done.
        return StreamingResponse(_iter_file_chunks(cached_file), media_type=media_type)

    chunks = split_text_into_chunks(user_input)

    async def audio_generator() -> AsyncGenerator[bytes, None]:
        full_audio_buffer = io.BytesIO()
        for chunk in chunks:
            async for audio_data in generate_api_audio(chunk, request.voice, request.speed):
                yield audio_data
                full_audio_buffer.write(audio_data)
                await asyncio.sleep(0)
        # Only reached when every chunk succeeded; written atomically so a
        # concurrent request never streams a half-written cache file.
        _atomic_write_bytes(full_audio_path, full_audio_buffer.getvalue())

    return StreamingResponse(audio_generator(), media_type=media_type)


@app.get("/health")
async def health_check():
    """Liveness probe."""
    return {"status": "healthy"}


if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8005)