main_server.py 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505
  1. from fastapi import FastAPI, File, UploadFile, HTTPException, Form, Response, Cookie
  2. from fastapi.middleware.cors import CORSMiddleware
  3. from fastapi.responses import JSONResponse, RedirectResponse, StreamingResponse
  4. from fastapi.staticfiles import StaticFiles
  5. from pydantic import BaseModel
  6. import os
  7. import shutil
  8. import hashlib
  9. import asyncio
  10. from typing import AsyncGenerator, Optional
  11. import aiohttp
  12. import io
  13. import logging
  14. import base64
  15. import json
  16. from datetime import datetime, timezone
  17. import secrets
  18. from config import (
  19. OPENAI_TTS_BASE_URL,
  20. OPENAI_TTS_API_KEY,
  21. OPENAI_TTS_MODEL,
  22. OPENAI_TTS_DEFAULT_VOICE,
  23. OPENAI_TTS_FORMAT,
  24. )
# Name of the cookie that carries the per-browser client identifier.
CLIENT_COOKIE = "reader_pro_client"
# JSON file (relative path, resolved against the working directory) that stores
# the reading progress of every client, keyed by client id.
PROGRESS_FILE = "reading_progress.json"
# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Initialize FastAPI app
app = FastAPI()
# Shared aiohttp session used for upstream TTS requests. It is created in
# startup_event, closed in shutdown_event, and lazily re-created by
# request_openai_tts_audio when it is missing or closed.
tts_http_session: aiohttp.ClientSession | None = None
# Configure CORS
# NOTE(review): a wildcard origin list together with allow_credentials=True lets
# any site issue credentialed requests against this API -- confirm this is
# intended, otherwise list the trusted origins explicitly.
origins = ["*"]
app.add_middleware(
    CORSMiddleware,
    allow_origins=origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Base directories
# Uploaded PDFs are stored here and served under /static/files.
BASE_STATIC_FILES_DIR = "static/files"
os.makedirs(BASE_STATIC_FILES_DIR, exist_ok=True)
# Mount static files
# The more specific prefixes are mounted before the generic "/static" one.
app.mount("/static/files", StaticFiles(directory=BASE_STATIC_FILES_DIR), name="static_files")
app.mount("/static/web", StaticFiles(directory="static/web"), name="static_web")
app.mount("/static", StaticFiles(directory="static"), name="static")
# Audio cache directory
# Generated audio is cached on disk here, one file per hashed input.
CACHE_DIR = "audio_cache"
os.makedirs(CACHE_DIR, exist_ok=True)
  52. def sanitize_filename(name: str) -> str:
  53. return "".join(c for c in name if c.isalnum() or c in (" ", ".", "_", "-")).rstrip()
  54. def build_file_url(filename: str) -> str:
  55. return f"/static/files/{filename}"
  56. def get_or_create_client_id(client_id: Optional[str]) -> str:
  57. normalized = (client_id or "").strip()
  58. return normalized or secrets.token_urlsafe(24)
  59. def load_progress_store() -> dict:
  60. if not os.path.exists(PROGRESS_FILE):
  61. return {}
  62. try:
  63. with open(PROGRESS_FILE, "r", encoding="utf-8") as f:
  64. data = json.load(f)
  65. return data if isinstance(data, dict) else {}
  66. except Exception:
  67. return {}
  68. def save_progress_store(data: dict) -> None:
  69. with open(PROGRESS_FILE, "w", encoding="utf-8") as f:
  70. json.dump(data, f, ensure_ascii=False, indent=2)
  71. @app.on_event("startup")
  72. async def startup_event():
  73. global tts_http_session
  74. timeout = aiohttp.ClientTimeout(total=120)
  75. connector = aiohttp.TCPConnector(limit=20, ttl_dns_cache=300)
  76. tts_http_session = aiohttp.ClientSession(timeout=timeout, connector=connector)
  77. @app.on_event("shutdown")
  78. async def shutdown_event():
  79. global tts_http_session
  80. if tts_http_session and not tts_http_session.closed:
  81. await tts_http_session.close()
  82. tts_http_session = None
  83. @app.get("/")
  84. def root(client_id: Optional[str] = Cookie(default=None, alias=CLIENT_COOKIE)):
  85. current_client_id = get_or_create_client_id(client_id)
  86. progress = load_progress_store().get(current_client_id, {})
  87. last_file = (progress.get("last_file") or "").strip()
  88. if last_file:
  89. response = RedirectResponse(url=f"/static/web/viewer.html?file={last_file}", status_code=302)
  90. else:
  91. files = sorted([f for f in os.listdir(BASE_STATIC_FILES_DIR) if f.lower().endswith(".pdf")])
  92. if files:
  93. response = RedirectResponse(
  94. url=f"/static/web/viewer.html?file={build_file_url(files[0])}",
  95. status_code=302,
  96. )
  97. else:
  98. response = RedirectResponse(url="/static/web/viewer.html?file=/static/files/compress.pdf", status_code=302)
  99. response.set_cookie(key=CLIENT_COOKIE, value=current_client_id, httponly=True, samesite="lax", secure=False, path="/")
  100. return response
  101. # PDF upload endpoint
  102. @app.post("/upload-pdf")
  103. async def upload_pdf(
  104. file: UploadFile = File(...),
  105. custom_name: str = Form(...),
  106. client_id: Optional[str] = Cookie(default=None, alias=CLIENT_COOKIE),
  107. ):
  108. if file.content_type != "application/pdf":
  109. raise HTTPException(status_code=400, detail="文件类型必须是 PDF")
  110. sanitized_name = sanitize_filename(custom_name)
  111. if not sanitized_name:
  112. return JSONResponse(status_code=400, content={"success": False, "error": "无效的文件名"})
  113. unique_filename = f"{sanitized_name}.pdf"
  114. file_path = os.path.join(BASE_STATIC_FILES_DIR, unique_filename)
  115. if os.path.exists(file_path):
  116. return JSONResponse(status_code=400, content={"success": False, "error": "文件名已存在,请使用其他名称"})
  117. try:
  118. with open(file_path, "wb") as buffer:
  119. shutil.copyfileobj(file.file, buffer)
  120. except Exception:
  121. raise HTTPException(status_code=500, detail="上传过程中出错")
  122. finally:
  123. file.file.close()
  124. current_client_id = get_or_create_client_id(client_id)
  125. file_relative_path = build_file_url(unique_filename)
  126. store = load_progress_store()
  127. store[current_client_id] = {
  128. "last_file": file_relative_path,
  129. "last_page": 1,
  130. "updated_at": datetime.now(timezone.utc).isoformat(),
  131. }
  132. save_progress_store(store)
  133. response = JSONResponse(content={"success": True, "file_path": file_relative_path})
  134. response.set_cookie(key=CLIENT_COOKIE, value=current_client_id, httponly=True, samesite="lax", secure=False, path="/")
  135. return response
  136. # List PDFs endpoint
  137. @app.get("/list-pdfs")
  138. async def list_pdfs():
  139. try:
  140. files = os.listdir(BASE_STATIC_FILES_DIR)
  141. pdf_files = [
  142. {"name": file, "url": build_file_url(file)}
  143. for file in files
  144. if file.lower().endswith(".pdf")
  145. ]
  146. pdf_files.sort(key=lambda x: x["name"].lower())
  147. return JSONResponse(content={"success": True, "files": pdf_files})
  148. except Exception:
  149. raise HTTPException(status_code=500, detail="无法获取文件列表")
# Request body of POST /reading-progress.
class ReadingProgressRequest(BaseModel):
    # Identifier of the document being read; the handler strips it and
    # rejects an empty value.
    file: str
    # Page number to remember; the handler rejects values below 1.
    page: int
  153. @app.get("/reading-progress")
  154. async def get_reading_progress(file: str, client_id: Optional[str] = Cookie(default=None, alias=CLIENT_COOKIE)):
  155. normalized_file = (file or "").strip()
  156. if not normalized_file:
  157. return JSONResponse(status_code=400, content={"success": False, "error": "缺少 file 参数"})
  158. current_client_id = get_or_create_client_id(client_id)
  159. progress = load_progress_store().get(current_client_id, {})
  160. page = progress.get("last_page") if progress.get("last_file") == normalized_file else None
  161. response = JSONResponse(content={"success": True, "file": normalized_file, "page": page})
  162. response.set_cookie(key=CLIENT_COOKIE, value=current_client_id, httponly=True, samesite="lax", secure=False, path="/")
  163. return response
  164. @app.post("/reading-progress")
  165. async def save_reading_progress(request: ReadingProgressRequest, client_id: Optional[str] = Cookie(default=None, alias=CLIENT_COOKIE)):
  166. normalized_file = (request.file or "").strip()
  167. page = int(request.page)
  168. if not normalized_file:
  169. return JSONResponse(status_code=400, content={"success": False, "error": "file 不能为空"})
  170. if page < 1:
  171. return JSONResponse(status_code=400, content={"success": False, "error": "page 必须 >= 1"})
  172. current_client_id = get_or_create_client_id(client_id)
  173. store = load_progress_store()
  174. store[current_client_id] = {
  175. "last_file": normalized_file,
  176. "last_page": page,
  177. "updated_at": datetime.now(timezone.utc).isoformat(),
  178. }
  179. save_progress_store(store)
  180. response = JSONResponse(content={"success": True})
  181. response.set_cookie(key=CLIENT_COOKIE, value=current_client_id, httponly=True, samesite="lax", secure=False, path="/")
  182. return response
# Request body shared by the /generate, /text-to-speech/ and /page-to-speech/ endpoints.
class TextToSpeechRequest(BaseModel):
    # Text to synthesize; the handlers strip it and reject an empty value.
    user_input: str
    # Requested voice name; normalize_openai_voice replaces unsupported
    # values with the configured default before the upstream call.
    voice: str = OPENAI_TTS_DEFAULT_VOICE
    # NOTE(review): accepted in the request, but request_openai_tts_audio
    # always sends "speed": 1.0 upstream, so this value has no effect on the
    # generated audio -- confirm whether it should be forwarded.
    speed: float = 1.0
  187. @app.post("/generate")
  188. async def generate_proxy(request: TextToSpeechRequest):
  189. user_input = request.user_input.strip()
  190. if not user_input:
  191. raise HTTPException(status_code=400, detail="输入文本为空")
  192. async def stream_generator() -> AsyncGenerator[bytes, None]:
  193. try:
  194. chunks = split_text_into_chunks(user_input)
  195. if not chunks:
  196. return
  197. first_audio = await request_openai_tts_audio(chunks[0], request.voice)
  198. first_payload = {
  199. "index": 0,
  200. "text": chunks[0],
  201. "audio": base64.b64encode(first_audio).decode("utf-8"),
  202. "format": OPENAI_TTS_FORMAT,
  203. }
  204. yield (json.dumps(first_payload, ensure_ascii=False) + "\n").encode("utf-8")
  205. await asyncio.sleep(0)
  206. pending_tasks = {
  207. index: asyncio.create_task(request_openai_tts_audio(chunk, request.voice))
  208. for index, chunk in enumerate(chunks[1:], start=1)
  209. }
  210. for index in range(1, len(chunks)):
  211. audio_bytes = await pending_tasks[index]
  212. payload = {
  213. "index": index,
  214. "text": chunks[index],
  215. "audio": base64.b64encode(audio_bytes).decode("utf-8"),
  216. "format": OPENAI_TTS_FORMAT,
  217. }
  218. yield (json.dumps(payload, ensure_ascii=False) + "\n").encode("utf-8")
  219. await asyncio.sleep(0)
  220. except HTTPException as e:
  221. logger.error("generate proxy http error: %s", e.detail)
  222. yield (json.dumps({"error": e.detail}, ensure_ascii=False) + "\n").encode("utf-8")
  223. except Exception as e:
  224. logger.error(f"generate proxy error: {str(e)}")
  225. yield (json.dumps({"error": "TTS生成失败"}, ensure_ascii=False) + "\n").encode("utf-8")
  226. return StreamingResponse(stream_generator(), media_type="application/x-ndjson")
  227. def normalize_openai_voice(voice: str) -> str:
  228. allowed_voices = {"alloy", "ash", "ballad", "coral", "echo", "sage", "shimmer", "verse"}
  229. normalized = (voice or "").strip().lower()
  230. return normalized if normalized in allowed_voices else OPENAI_TTS_DEFAULT_VOICE
  231. def get_audio_media_type(audio_format: str) -> str:
  232. mapping = {
  233. "wav": "audio/wav",
  234. "mp3": "audio/mpeg",
  235. "flac": "audio/flac",
  236. "opus": "audio/opus",
  237. "pcm16": "audio/L16",
  238. }
  239. return mapping.get(audio_format.lower(), "application/octet-stream")
  240. async def request_openai_tts_audio(text: str, voice: str) -> bytes:
  241. global tts_http_session
  242. payload = {
  243. "model": OPENAI_TTS_MODEL,
  244. "voice": normalize_openai_voice(voice),
  245. "input": text,
  246. "response_format": OPENAI_TTS_FORMAT,
  247. "speed": 1.0,
  248. }
  249. headers = {
  250. "Authorization": f"Bearer {OPENAI_TTS_API_KEY}",
  251. "Content-Type": "application/json",
  252. }
  253. if tts_http_session is None or tts_http_session.closed:
  254. timeout = aiohttp.ClientTimeout(total=120)
  255. connector = aiohttp.TCPConnector(limit=20, ttl_dns_cache=300)
  256. tts_http_session = aiohttp.ClientSession(timeout=timeout, connector=connector)
  257. async with tts_http_session.post(
  258. f"{OPENAI_TTS_BASE_URL.rstrip('/')}/v1/audio/speech",
  259. headers=headers,
  260. json=payload,
  261. ) as response:
  262. if response.status != 200:
  263. response_text = await response.text()
  264. logger.error("OpenAI TTS request failed: %s", response_text)
  265. error_detail = "OpenAI TTS API 请求失败"
  266. try:
  267. error_data = json.loads(response_text)
  268. error_obj = error_data.get("error", {})
  269. error_message = error_obj.get("message")
  270. error_code = error_obj.get("code")
  271. if error_message:
  272. error_detail = f"{error_detail}: {error_message}"
  273. if response.status == 429 or error_code in {"rate_limit_exceeded", "model_not_found", "upstream_error"}:
  274. raise HTTPException(status_code=503, detail=error_detail)
  275. except json.JSONDecodeError:
  276. pass
  277. raise HTTPException(status_code=500 if response.status < 500 else 502, detail=error_detail)
  278. return await response.read()
  279. @app.post("/text-to-speech/")
  280. async def text_to_speech(request: TextToSpeechRequest):
  281. user_input = request.user_input.strip()
  282. if not user_input:
  283. raise HTTPException(status_code=400, detail="输入文本为空")
  284. text_hash = hashlib.md5(user_input.encode("utf-8")).hexdigest()
  285. audio_path = os.path.join(CACHE_DIR, f"{text_hash}.{OPENAI_TTS_FORMAT}")
  286. media_type = get_audio_media_type(OPENAI_TTS_FORMAT)
  287. if os.path.exists(audio_path):
  288. with open(audio_path, "rb") as f:
  289. return Response(content=f.read(), media_type=media_type)
  290. try:
  291. audio_bytes = await request_openai_tts_audio(user_input, request.voice)
  292. with open(audio_path, "wb") as f:
  293. f.write(audio_bytes)
  294. return Response(content=audio_bytes, media_type=media_type)
  295. except HTTPException:
  296. raise
  297. except Exception as e:
  298. logger.error(f"TTS error: {str(e)}")
  299. raise HTTPException(status_code=500, detail="TTS生成失败")
  300. FIRST_CHUNK_SIZE = 80
  301. FOLLOWING_CHUNK_SIZE = 180
  302. def split_text_into_chunks(
  303. text: str,
  304. first_chunk_size: int = FIRST_CHUNK_SIZE,
  305. following_chunk_size: int = FOLLOWING_CHUNK_SIZE,
  306. ) -> list:
  307. import re
  308. def split_long_sentence(sentence: str, chunk_size: int) -> list[str]:
  309. parts = []
  310. remaining = sentence.strip()
  311. while len(remaining) > chunk_size:
  312. split_at = -1
  313. candidate = remaining[: chunk_size + 1]
  314. for pattern in [r"[,:;)\]]\s+", r"\s+"]:
  315. matches = list(re.finditer(pattern, candidate))
  316. if matches:
  317. split_at = matches[-1].start()
  318. break
  319. if split_at <= 0:
  320. split_at = chunk_size
  321. part = remaining[:split_at].strip()
  322. if not part:
  323. part = remaining[:chunk_size].strip()
  324. split_at = len(part)
  325. parts.append(part)
  326. remaining = remaining[split_at:].strip()
  327. if remaining:
  328. parts.append(remaining)
  329. return parts
  330. sentences = re.split(r"(?<=[.!?])\s+", text.strip())
  331. chunks = []
  332. current_chunk = ""
  333. current_limit = first_chunk_size
  334. for sentence in sentences:
  335. sentence = sentence.strip()
  336. if not sentence:
  337. continue
  338. if len(sentence) > current_limit:
  339. if current_chunk:
  340. chunks.append(current_chunk)
  341. current_chunk = ""
  342. current_limit = following_chunk_size
  343. long_parts = split_long_sentence(sentence, current_limit)
  344. chunks.extend(long_parts)
  345. if long_parts:
  346. current_limit = following_chunk_size
  347. continue
  348. next_chunk = f"{current_chunk} {sentence}".strip() if current_chunk else sentence
  349. if len(next_chunk) <= current_limit:
  350. current_chunk = next_chunk
  351. else:
  352. if current_chunk:
  353. chunks.append(current_chunk)
  354. current_limit = following_chunk_size
  355. current_chunk = sentence
  356. if current_chunk:
  357. chunks.append(current_chunk)
  358. return chunks
  359. async def generate_api_audio(chunk: str, voice: str, speed: float) -> AsyncGenerator[bytes, None]:
  360. text_hash = hashlib.md5(chunk.encode("utf-8")).hexdigest()
  361. audio_path = os.path.join(CACHE_DIR, f"{text_hash}.{OPENAI_TTS_FORMAT}")
  362. if os.path.exists(audio_path):
  363. with open(audio_path, "rb") as f:
  364. yield f.read()
  365. else:
  366. try:
  367. audio_bytes = await request_openai_tts_audio(chunk, voice)
  368. with open(audio_path, "wb") as f:
  369. f.write(audio_bytes)
  370. yield audio_bytes
  371. except HTTPException as e:
  372. raise HTTPException(status_code=e.status_code, detail=e.detail)
  373. except Exception as e:
  374. raise HTTPException(status_code=500, detail=f"TTS生成失败: {str(e)}")
  375. @app.post("/page-to-speech/")
  376. async def page_to_speech(request: TextToSpeechRequest):
  377. user_input = request.user_input.strip()
  378. if not user_input:
  379. raise HTTPException(status_code=400, detail="输入文本为空")
  380. full_text_hash = hashlib.md5(user_input.encode("utf-8")).hexdigest()
  381. full_audio_path = os.path.join(CACHE_DIR, f"{full_text_hash}_full.{OPENAI_TTS_FORMAT}")
  382. media_type = get_audio_media_type(OPENAI_TTS_FORMAT)
  383. if os.path.exists(full_audio_path):
  384. return StreamingResponse(open(full_audio_path, "rb"), media_type=media_type)
  385. chunks = split_text_into_chunks(user_input)
  386. async def audio_generator() -> AsyncGenerator[bytes, None]:
  387. full_audio_buffer = io.BytesIO()
  388. for chunk in chunks:
  389. async for audio_data in generate_api_audio(chunk, request.voice, request.speed):
  390. yield audio_data
  391. full_audio_buffer.write(audio_data)
  392. await asyncio.sleep(0)
  393. full_audio_buffer.seek(0)
  394. with open(full_audio_path, "wb") as f:
  395. f.write(full_audio_buffer.getvalue())
  396. return StreamingResponse(audio_generator(), media_type=media_type)
  397. @app.get("/health")
  398. async def health_check():
  399. return {"status": "healthy"}
# Script entry point: serve the app with uvicorn when the file is run directly.
if __name__ == "__main__":
    import uvicorn
    # Binds to all network interfaces on port 8005.
    uvicorn.run(app, host="0.0.0.0", port=8005)