| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505 |
- from fastapi import FastAPI, File, UploadFile, HTTPException, Form, Response, Cookie
- from fastapi.middleware.cors import CORSMiddleware
- from fastapi.responses import JSONResponse, RedirectResponse, StreamingResponse
- from fastapi.staticfiles import StaticFiles
- from pydantic import BaseModel
- import os
- import shutil
- import hashlib
- import asyncio
- from typing import AsyncGenerator, Optional
- import aiohttp
- import io
- import logging
- import base64
- import json
- from datetime import datetime, timezone
- import secrets
- from config import (
- OPENAI_TTS_BASE_URL,
- OPENAI_TTS_API_KEY,
- OPENAI_TTS_MODEL,
- OPENAI_TTS_DEFAULT_VOICE,
- OPENAI_TTS_FORMAT,
- )
# Cookie name that carries the per-browser client identifier.
CLIENT_COOKIE = "reader_pro_client"
# JSON file (relative to the working directory) storing per-client progress.
PROGRESS_FILE = "reading_progress.json"

# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Initialize FastAPI app
app = FastAPI()

# Shared aiohttp session for upstream TTS requests; None until one is created.
tts_http_session: aiohttp.ClientSession | None = None
# Configure CORS
# NOTE(review): a wildcard origin together with allow_credentials=True is
# discouraged by the FastAPI CORS documentation; consider listing the real
# front-end origins explicitly -- confirm against the deployment.
origins = ["*"]
app.add_middleware(
    CORSMiddleware,
    allow_origins=origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Base directories
# Uploaded PDFs live here; the directory is created at import time if missing.
BASE_STATIC_FILES_DIR = "static/files"
os.makedirs(BASE_STATIC_FILES_DIR, exist_ok=True)

# Mount static files
# The more specific prefixes are registered before the catch-all "/static".
app.mount("/static/files", StaticFiles(directory=BASE_STATIC_FILES_DIR), name="static_files")
app.mount("/static/web", StaticFiles(directory="static/web"), name="static_web")
app.mount("/static", StaticFiles(directory="static"), name="static")

# Audio cache directory
# Synthesized audio is stored here, keyed by a hash of the request text.
CACHE_DIR = "audio_cache"
os.makedirs(CACHE_DIR, exist_ok=True)
def sanitize_filename(name: str) -> str:
    """Keep only alphanumeric characters and space, dot, underscore, hyphen.

    Trailing whitespace that remains after filtering is removed; leading
    characters are left untouched.
    """
    allowed_punctuation = {" ", ".", "_", "-"}
    kept = [ch for ch in name if ch.isalnum() or ch in allowed_punctuation]
    return "".join(kept).rstrip()
def build_file_url(filename: str) -> str:
    """Return the URL path of ``filename`` under the ``/static/files`` mount."""
    return "/static/files/{}".format(filename)
def get_or_create_client_id(client_id: Optional[str]) -> str:
    """Return the trimmed ``client_id``, or a fresh random token if it is blank.

    ``None``, empty and whitespace-only values all count as blank.
    """
    if client_id:
        trimmed = client_id.strip()
        if trimmed:
            return trimmed
    return secrets.token_urlsafe(24)
def load_progress_store() -> dict:
    """Load the per-client reading-progress mapping from ``PROGRESS_FILE``.

    Returns:
        The decoded JSON object, or an empty dict when the file is missing,
        unreadable, not valid JSON, or does not hold a JSON object.
    """
    try:
        with open(PROGRESS_FILE, "r", encoding="utf-8") as f:
            data = json.load(f)
    except FileNotFoundError:
        # Nothing has been saved yet; this is the normal first-run case.
        return {}
    except (OSError, ValueError) as exc:
        # ValueError covers json.JSONDecodeError and UnicodeDecodeError.
        # Loading stays best-effort, but a damaged store is now visible in
        # the logs instead of silently looking like "no progress".
        logger.warning("could not read %s, using empty progress: %s", PROGRESS_FILE, exc)
        return {}
    return data if isinstance(data, dict) else {}
def save_progress_store(data: dict) -> None:
    """Persist the reading-progress mapping to ``PROGRESS_FILE`` as JSON.

    The data is written to a sibling temporary file and then moved over the
    target with ``os.replace``, so a crash or serialization error mid-write
    cannot leave a truncated store behind.

    Raises:
        OSError: if the file cannot be written or replaced.
        TypeError: if ``data`` is not JSON-serializable.
    """
    tmp_path = f"{PROGRESS_FILE}.tmp"
    try:
        with open(tmp_path, "w", encoding="utf-8") as f:
            json.dump(data, f, ensure_ascii=False, indent=2)
        os.replace(tmp_path, PROGRESS_FILE)
    except Exception:
        # Do not leave a stale temporary file next to the real store.
        try:
            os.remove(tmp_path)
        except OSError:
            pass
        raise
@app.on_event("startup")
async def startup_event():
    """Create the shared aiohttp session used for upstream TTS requests."""
    global tts_http_session
    tts_http_session = aiohttp.ClientSession(
        timeout=aiohttp.ClientTimeout(total=120),
        connector=aiohttp.TCPConnector(limit=20, ttl_dns_cache=300),
    )
@app.on_event("shutdown")
async def shutdown_event():
    """Close the shared aiohttp session, if one is open, and clear the global."""
    global tts_http_session
    session = tts_http_session
    if session is not None and not session.closed:
        await session.close()
    tts_http_session = None
@app.get("/")
def root(client_id: Optional[str] = Cookie(default=None, alias=CLIENT_COOKIE)):
    """Redirect to the PDF viewer, resuming the client's last file if known.

    Target selection, in order: the ``last_file`` stored for this client, the
    alphabetically first PDF in ``BASE_STATIC_FILES_DIR``, or the bundled
    ``/static/files/compress.pdf`` fallback. The client cookie is set (or
    refreshed) on the response.
    """
    from urllib.parse import quote

    current_client_id = get_or_create_client_id(client_id)

    # Stored values originate from client requests or a hand-edited file, so
    # do not assume their types.
    stored = load_progress_store().get(current_client_id)
    progress = stored if isinstance(stored, dict) else {}
    raw_last_file = progress.get("last_file")
    last_file = raw_last_file.strip() if isinstance(raw_last_file, str) else ""

    if last_file:
        target = last_file
    else:
        pdf_names = sorted(
            name for name in os.listdir(BASE_STATIC_FILES_DIR) if name.lower().endswith(".pdf")
        )
        target = build_file_url(pdf_names[0]) if pdf_names else "/static/files/compress.pdf"

    # Percent-encode the value so characters such as "&", "#" or "?" in a file
    # name cannot truncate or split the ``file`` query parameter.
    # NOTE(review): assumes the viewer percent-decodes its query string.
    encoded_target = quote(target, safe="/")
    response = RedirectResponse(
        url=f"/static/web/viewer.html?file={encoded_target}",
        status_code=302,
    )
    response.set_cookie(key=CLIENT_COOKIE, value=current_client_id, httponly=True, samesite="lax", secure=False, path="/")
    return response
# PDF upload endpoint
@app.post("/upload-pdf")
async def upload_pdf(
    file: UploadFile = File(...),
    custom_name: str = Form(...),
    client_id: Optional[str] = Cookie(default=None, alias=CLIENT_COOKIE),
):
    """Store an uploaded PDF under a sanitized custom name.

    On success the new file becomes the client's current document (page 1),
    the client cookie is set, and the file's URL path is returned.

    Returns:
        JSON ``{"success": True, "file_path": ...}``, or a 400 JSON error
        when the name is empty after sanitizing or already taken.

    Raises:
        HTTPException: 400 if the content type is not ``application/pdf``;
            500 if writing the file fails.
    """
    try:
        if file.content_type != "application/pdf":
            raise HTTPException(status_code=400, detail="文件类型必须是 PDF")

        sanitized_name = sanitize_filename(custom_name)
        if not sanitized_name:
            return JSONResponse(status_code=400, content={"success": False, "error": "无效的文件名"})

        unique_filename = f"{sanitized_name}.pdf"
        file_path = os.path.join(BASE_STATIC_FILES_DIR, unique_filename)

        try:
            # "xb" creates the file exclusively, so two concurrent uploads of
            # the same name can never overwrite each other.
            with open(file_path, "xb") as buffer:
                shutil.copyfileobj(file.file, buffer)
        except FileExistsError:
            return JSONResponse(status_code=400, content={"success": False, "error": "文件名已存在,请使用其他名称"})
        except Exception as exc:
            logger.exception("failed to store uploaded PDF at %s", file_path)
            # Remove the partial file; otherwise the name stays blocked and a
            # truncated PDF would be served.
            try:
                os.remove(file_path)
            except OSError:
                pass
            raise HTTPException(status_code=500, detail="上传过程中出错") from exc
    finally:
        # Release the upload stream on every exit path, including early returns.
        file.file.close()

    current_client_id = get_or_create_client_id(client_id)
    file_relative_path = build_file_url(unique_filename)
    store = load_progress_store()
    store[current_client_id] = {
        "last_file": file_relative_path,
        "last_page": 1,
        "updated_at": datetime.now(timezone.utc).isoformat(),
    }
    save_progress_store(store)

    response = JSONResponse(content={"success": True, "file_path": file_relative_path})
    response.set_cookie(key=CLIENT_COOKIE, value=current_client_id, httponly=True, samesite="lax", secure=False, path="/")
    return response
# List PDFs endpoint
# Returns every "*.pdf" entry (case-insensitive) of BASE_STATIC_FILES_DIR with
# its URL, ordered case-insensitively by name; any failure becomes a 500.
@app.get("/list-pdfs")
async def list_pdfs():
    try:
        entries = []
        for name in os.listdir(BASE_STATIC_FILES_DIR):
            if name.lower().endswith(".pdf"):
                entries.append({"name": name, "url": build_file_url(name)})
        entries = sorted(entries, key=lambda item: item["name"].lower())
        return JSONResponse(content={"success": True, "files": entries})
    except Exception:
        raise HTTPException(status_code=500, detail="无法获取文件列表")
# Request body for POST /reading-progress.
class ReadingProgressRequest(BaseModel):
    # Identifier of the document being read; stored as the client's last file.
    file: str
    # Page number to remember; the handler rejects values below 1.
    page: int
# Returns the saved page for `file` when it is the client's most recently
# saved document, otherwise page=None. Always (re)sets the client cookie.
@app.get("/reading-progress")
async def get_reading_progress(file: str, client_id: Optional[str] = Cookie(default=None, alias=CLIENT_COOKIE)):
    requested_file = (file or "").strip()
    if requested_file == "":
        return JSONResponse(status_code=400, content={"success": False, "error": "缺少 file 参数"})

    current_client_id = get_or_create_client_id(client_id)
    saved = load_progress_store().get(current_client_id, {})

    # Only one document is tracked per client, so any other file has no page.
    if saved.get("last_file") == requested_file:
        page = saved.get("last_page")
    else:
        page = None

    response = JSONResponse(content={"success": True, "file": requested_file, "page": page})
    response.set_cookie(key=CLIENT_COOKIE, value=current_client_id, httponly=True, samesite="lax", secure=False, path="/")
    return response
# Records (file, page) as the client's current reading position, replacing
# whatever was stored for that client before. Always (re)sets the cookie.
@app.post("/reading-progress")
async def save_reading_progress(request: ReadingProgressRequest, client_id: Optional[str] = Cookie(default=None, alias=CLIENT_COOKIE)):
    target_file = (request.file or "").strip()
    page = int(request.page)

    if target_file == "":
        return JSONResponse(status_code=400, content={"success": False, "error": "file 不能为空"})
    if not page >= 1:
        return JSONResponse(status_code=400, content={"success": False, "error": "page 必须 >= 1"})

    current_client_id = get_or_create_client_id(client_id)
    entry = {
        "last_file": target_file,
        "last_page": page,
        "updated_at": datetime.now(timezone.utc).isoformat(),
    }
    store = load_progress_store()
    store[current_client_id] = entry
    save_progress_store(store)

    response = JSONResponse(content={"success": True})
    response.set_cookie(key=CLIENT_COOKIE, value=current_client_id, httponly=True, samesite="lax", secure=False, path="/")
    return response
# Request body shared by the TTS endpoints.
class TextToSpeechRequest(BaseModel):
    # Text to synthesize; handlers strip it and reject empty input.
    user_input: str
    # Requested voice; defaults to the configured voice.
    voice: str = OPENAI_TTS_DEFAULT_VOICE
    # Playback speed requested by the client.
    # NOTE(review): no handler in this file forwards it to the upstream request.
    speed: float = 1.0
@app.post("/generate")
async def generate_proxy(request: TextToSpeechRequest):
    """Stream synthesized speech for the input text as NDJSON.

    The text is split with ``split_text_into_chunks``. The first chunk is
    synthesized on its own so its audio is emitted as early as possible; the
    remaining chunks are requested concurrently and emitted in order. Each
    line is a JSON object with ``index``, ``text``, base64 ``audio`` and
    ``format``. On failure a line with a single ``error`` key is emitted and
    the stream ends.

    Raises:
        HTTPException: 400 when the input is empty after stripping.
    """
    user_input = request.user_input.strip()
    if not user_input:
        raise HTTPException(status_code=400, detail="输入文本为空")

    def encode_line(payload: dict) -> bytes:
        """Serialize one payload as a UTF-8 NDJSON line."""
        return (json.dumps(payload, ensure_ascii=False) + "\n").encode("utf-8")

    def audio_payload(index: int, text: str, audio: bytes) -> dict:
        """Build the per-chunk payload emitted on the stream."""
        return {
            "index": index,
            "text": text,
            "audio": base64.b64encode(audio).decode("utf-8"),
            "format": OPENAI_TTS_FORMAT,
        }

    async def stream_generator() -> AsyncGenerator[bytes, None]:
        pending_tasks: dict[int, asyncio.Task] = {}
        try:
            chunks = split_text_into_chunks(user_input)
            if not chunks:
                return

            first_audio = await request_openai_tts_audio(chunks[0], request.voice)
            yield encode_line(audio_payload(0, chunks[0], first_audio))
            await asyncio.sleep(0)

            pending_tasks = {
                index: asyncio.create_task(request_openai_tts_audio(chunk, request.voice))
                for index, chunk in enumerate(chunks[1:], start=1)
            }
            for index in range(1, len(chunks)):
                # pop() so that only not-yet-consumed tasks remain for cleanup.
                audio_bytes = await pending_tasks.pop(index)
                yield encode_line(audio_payload(index, chunks[index], audio_bytes))
                await asyncio.sleep(0)
        except HTTPException as e:
            logger.error("generate proxy http error: %s", e.detail)
            yield encode_line({"error": e.detail})
        except Exception as e:
            logger.exception("generate proxy error: %s", e)
            yield encode_line({"error": "TTS生成失败"})
        finally:
            # After an error or a client disconnect the remaining requests are
            # useless: cancel the ones still running and mark finished ones
            # as retrieved so asyncio does not warn about them.
            for task in pending_tasks.values():
                if not task.done():
                    task.cancel()
                elif not task.cancelled():
                    task.exception()

    return StreamingResponse(stream_generator(), media_type="application/x-ndjson")
def normalize_openai_voice(voice: str) -> str:
    """Return ``voice`` trimmed and lower-cased if it is an allowed voice.

    Any other value, including ``None`` or an empty string, falls back to
    ``OPENAI_TTS_DEFAULT_VOICE``.
    """
    candidate = (voice or "").strip().lower()
    if candidate in {"alloy", "ash", "ballad", "coral", "echo", "sage", "shimmer", "verse"}:
        return candidate
    return OPENAI_TTS_DEFAULT_VOICE
- def get_audio_media_type(audio_format: str) -> str:
- mapping = {
- "wav": "audio/wav",
- "mp3": "audio/mpeg",
- "flac": "audio/flac",
- "opus": "audio/opus",
- "pcm16": "audio/L16",
- }
- return mapping.get(audio_format.lower(), "application/octet-stream")
async def request_openai_tts_audio(text: str, voice: str, speed: float = 1.0) -> bytes:
    """Request synthesized audio for ``text`` from the configured TTS endpoint.

    Args:
        text: Text to synthesize.
        voice: Requested voice; normalized with ``normalize_openai_voice``.
        speed: Speed value sent upstream. Defaults to 1.0, the value that was
            previously hard-coded.

    Returns:
        The raw audio bytes of the upstream response.

    Raises:
        HTTPException: 503 for HTTP 429 or a rate-limit / model-not-found /
            upstream error code; 502 for other upstream 5xx responses; 500
            for any other non-200 response.
    """
    global tts_http_session

    payload = {
        "model": OPENAI_TTS_MODEL,
        "voice": normalize_openai_voice(voice),
        "input": text,
        "response_format": OPENAI_TTS_FORMAT,
        "speed": speed,
    }
    headers = {
        "Authorization": f"Bearer {OPENAI_TTS_API_KEY}",
        "Content-Type": "application/json",
    }

    # Lazily (re)create the session in case the startup hook has not run or
    # the session was closed.
    if tts_http_session is None or tts_http_session.closed:
        timeout = aiohttp.ClientTimeout(total=120)
        connector = aiohttp.TCPConnector(limit=20, ttl_dns_cache=300)
        tts_http_session = aiohttp.ClientSession(timeout=timeout, connector=connector)

    async with tts_http_session.post(
        f"{OPENAI_TTS_BASE_URL.rstrip('/')}/v1/audio/speech",
        headers=headers,
        json=payload,
    ) as response:
        if response.status == 200:
            return await response.read()

        response_text = await response.text()
        logger.error("OpenAI TTS request failed (%s): %s", response.status, response_text)

        error_detail = "OpenAI TTS API 请求失败"
        error_code = None
        try:
            error_data = json.loads(response_text)
        except json.JSONDecodeError:
            error_data = None

        # The body is untrusted: tolerate non-object JSON and a non-object
        # "error" member instead of failing with AttributeError.
        error_obj = error_data.get("error") if isinstance(error_data, dict) else None
        if isinstance(error_obj, dict):
            error_message = error_obj.get("message")
            error_code = error_obj.get("code")
            if error_message:
                error_detail = f"{error_detail}: {error_message}"

        # Decided after parsing so a 429 with a non-JSON body is still
        # reported as "temporarily unavailable". A tuple is used so an
        # unhashable error code cannot raise.
        retryable_codes = ("rate_limit_exceeded", "model_not_found", "upstream_error")
        if response.status == 429 or error_code in retryable_codes:
            raise HTTPException(status_code=503, detail=error_detail)
        raise HTTPException(status_code=500 if response.status < 500 else 502, detail=error_detail)
@app.post("/text-to-speech/")
async def text_to_speech(request: TextToSpeechRequest):
    """Return synthesized audio for the input text, using an on-disk cache.

    The cache key covers both the normalized voice and the text, so the same
    text requested with a different voice is synthesized again rather than
    served from another voice's cached audio.

    Raises:
        HTTPException: 400 when the input is empty after stripping; upstream
            errors from ``request_openai_tts_audio`` are passed through; 500
            for any other failure during synthesis.
    """
    user_input = request.user_input.strip()
    if not user_input:
        raise HTTPException(status_code=400, detail="输入文本为空")

    voice = normalize_openai_voice(request.voice)
    # MD5 is only a cache-file name here, not a security measure.
    cache_key = hashlib.md5(f"{voice}\n{user_input}".encode("utf-8")).hexdigest()
    audio_path = os.path.join(CACHE_DIR, f"{cache_key}.{OPENAI_TTS_FORMAT}")
    media_type = get_audio_media_type(OPENAI_TTS_FORMAT)

    try:
        with open(audio_path, "rb") as f:
            return Response(content=f.read(), media_type=media_type)
    except FileNotFoundError:
        pass  # cache miss: synthesize below

    try:
        audio_bytes = await request_openai_tts_audio(user_input, voice)
    except HTTPException:
        raise
    except Exception as e:
        logger.exception("TTS error: %s", e)
        raise HTTPException(status_code=500, detail="TTS生成失败") from e

    # Caching is best-effort: write to a temporary file and rename it so a
    # partial file is never served, and still return the audio if it fails.
    tmp_path = f"{audio_path}.{os.getpid()}.tmp"
    try:
        with open(tmp_path, "wb") as f:
            f.write(audio_bytes)
        os.replace(tmp_path, audio_path)
    except OSError as e:
        logger.warning("could not cache TTS audio at %s: %s", audio_path, e)

    return Response(content=audio_bytes, media_type=media_type)
# Maximum length of the first chunk; kept small so the first audio is ready soon.
FIRST_CHUNK_SIZE = 80
# Maximum length of every later chunk.
FOLLOWING_CHUNK_SIZE = 180


def split_text_into_chunks(
    text: str,
    first_chunk_size: int = FIRST_CHUNK_SIZE,
    following_chunk_size: int = FOLLOWING_CHUNK_SIZE,
) -> list[str]:
    """Split ``text`` into chunks suitable for separate TTS requests.

    Sentences (ended by ``.``, ``!`` or ``?`` followed by whitespace) are
    packed greedily into chunks. The first chunk holds at most
    ``first_chunk_size`` characters, later ones at most
    ``following_chunk_size``. A sentence longer than the current limit is
    split at clause punctuation, then at whitespace, and only as a last
    resort at a fixed offset.

    Args:
        text: Text to split.
        first_chunk_size: Character limit for the first chunk.
        following_chunk_size: Character limit for all later chunks.

    Returns:
        The non-empty chunks, in reading order.

    Raises:
        ValueError: if either size is smaller than 1 (the splitter could not
            make progress).
    """
    import re

    if first_chunk_size < 1 or following_chunk_size < 1:
        raise ValueError("chunk sizes must be >= 1")

    clause_break = re.compile(r"[,:;)\]]\s+")
    whitespace_run = re.compile(r"\s+")

    def split_long_sentence(sentence: str, chunk_size: int) -> list[str]:
        """Break one over-long sentence into parts of at most ``chunk_size``."""
        parts = []
        remaining = sentence.strip()
        while len(remaining) > chunk_size:
            # One extra character so a separator that sits exactly at the
            # limit is still visible.
            window = remaining[: chunk_size + 1]
            split_at = -1
            clause_matches = list(clause_break.finditer(window))
            if clause_matches:
                # Cut AFTER the punctuation mark so it stays with its clause.
                # Cutting before it left the next part starting with ", ",
                # which then forced a hard mid-word cut.
                split_at = clause_matches[-1].start() + 1
            else:
                space_matches = list(whitespace_run.finditer(window))
                if space_matches:
                    split_at = space_matches[-1].start()
            if split_at <= 0:
                # No natural break in range: hard cut at the limit.
                split_at = chunk_size
            part = remaining[:split_at].strip()
            if not part:
                part = remaining[:chunk_size].strip()
                split_at = len(part)
            parts.append(part)
            remaining = remaining[split_at:].strip()
        if remaining:
            parts.append(remaining)
        return parts

    sentences = re.split(r"(?<=[.!?])\s+", text.strip())
    chunks: list[str] = []
    current_chunk = ""
    current_limit = first_chunk_size

    for sentence in sentences:
        sentence = sentence.strip()
        if not sentence:
            continue

        if len(sentence) > current_limit:
            # Flush what has been collected, then split the long sentence.
            if current_chunk:
                chunks.append(current_chunk)
                current_chunk = ""
                current_limit = following_chunk_size
            long_parts = split_long_sentence(sentence, current_limit)
            chunks.extend(long_parts)
            if long_parts:
                current_limit = following_chunk_size
            continue

        next_chunk = f"{current_chunk} {sentence}".strip() if current_chunk else sentence
        if len(next_chunk) <= current_limit:
            current_chunk = next_chunk
        else:
            if current_chunk:
                chunks.append(current_chunk)
                current_limit = following_chunk_size
            current_chunk = sentence

    if current_chunk:
        chunks.append(current_chunk)
    return chunks
async def generate_api_audio(chunk: str, voice: str, speed: float) -> AsyncGenerator[bytes, None]:
    """Yield the audio for one text chunk, served from the disk cache if present.

    The cache key covers the normalized voice and the chunk text, so audio
    cached for one voice is never returned for another.

    Args:
        chunk: Text to synthesize.
        voice: Requested voice; normalized with ``normalize_openai_voice``.
        speed: Accepted for interface compatibility; not used by this function.

    Yields:
        Exactly one ``bytes`` object containing the chunk's audio.

    Raises:
        HTTPException: upstream errors are passed through unchanged; any
            other failure during synthesis becomes a 500.
    """
    voice_key = normalize_openai_voice(voice)
    # MD5 is only a cache-file name here, not a security measure.
    cache_key = hashlib.md5(f"{voice_key}\n{chunk}".encode("utf-8")).hexdigest()
    audio_path = os.path.join(CACHE_DIR, f"{cache_key}.{OPENAI_TTS_FORMAT}")

    try:
        with open(audio_path, "rb") as f:
            cached_audio = f.read()
    except FileNotFoundError:
        cached_audio = None

    if cached_audio is not None:
        yield cached_audio
        return

    try:
        audio_bytes = await request_openai_tts_audio(chunk, voice_key)
    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"TTS生成失败: {str(e)}") from e

    # Caching is best-effort: write to a temporary file and rename it so a
    # partial file is never served, and still yield the audio if it fails.
    tmp_path = f"{audio_path}.{os.getpid()}.tmp"
    try:
        with open(tmp_path, "wb") as f:
            f.write(audio_bytes)
        os.replace(tmp_path, audio_path)
    except OSError as e:
        logger.warning("could not cache TTS audio at %s: %s", audio_path, e)

    yield audio_bytes
@app.post("/page-to-speech/")
async def page_to_speech(request: TextToSpeechRequest):
    """Stream synthesized audio for a whole page of text.

    If the full audio for this voice and text is already cached it is
    streamed from disk. Otherwise the text is split into chunks, each chunk's
    audio is streamed as soon as it is available, and the concatenation is
    cached once every chunk has been produced.

    Raises:
        HTTPException: 400 when the input is empty after stripping.
    """
    user_input = request.user_input.strip()
    if not user_input:
        raise HTTPException(status_code=400, detail="输入文本为空")

    voice_key = normalize_openai_voice(request.voice)
    # The key includes the voice so one voice's audio is never reused for another.
    full_text_hash = hashlib.md5(f"{voice_key}\n{user_input}".encode("utf-8")).hexdigest()
    full_audio_path = os.path.join(CACHE_DIR, f"{full_text_hash}_full.{OPENAI_TTS_FORMAT}")
    media_type = get_audio_media_type(OPENAI_TTS_FORMAT)

    if os.path.exists(full_audio_path):

        def cached_file_blocks():
            """Read the cached file in blocks and close it when done."""
            with open(full_audio_path, "rb") as f:
                while block := f.read(64 * 1024):
                    yield block

        return StreamingResponse(cached_file_blocks(), media_type=media_type)

    chunks = split_text_into_chunks(user_input)

    async def audio_generator() -> AsyncGenerator[bytes, None]:
        collected: list[bytes] = []
        for chunk in chunks:
            async for audio_data in generate_api_audio(chunk, voice_key, request.speed):
                yield audio_data
                collected.append(audio_data)
                await asyncio.sleep(0)

        # Reached only when every chunk was produced and sent, so an aborted
        # stream never caches a partial page.
        # NOTE(review): per-chunk responses are concatenated byte-wise; for
        # container formats with a header this may not be a valid single
        # file -- confirm for the configured format.
        tmp_path = f"{full_audio_path}.{os.getpid()}.tmp"
        try:
            with open(tmp_path, "wb") as f:
                f.write(b"".join(collected))
            os.replace(tmp_path, full_audio_path)
        except OSError as e:
            logger.warning("could not cache page audio at %s: %s", full_audio_path, e)

    return StreamingResponse(audio_generator(), media_type=media_type)
# Liveness probe: always reports a healthy status.
@app.get("/health")
async def health_check():
    status_report = {"status": "healthy"}
    return status_report
if __name__ == "__main__":
    # Direct execution: serve the app on all interfaces, port 8005.
    from uvicorn import run as run_server

    run_server(app, host="0.0.0.0", port=8005)
|