main_server.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299
  1. from fastapi import FastAPI, Request, File, UploadFile, HTTPException, Form, Response
  2. from fastapi.middleware.cors import CORSMiddleware
  3. from fastapi.responses import JSONResponse, RedirectResponse, StreamingResponse
  4. from fastapi.staticfiles import StaticFiles
  5. import os
  6. import shutil
  7. import hashlib
  8. import asyncio
  9. from typing import AsyncGenerator
  10. import aiohttp
  11. import io
  12. import logging
  13. import base64
  14. import json
  # Set up logging
  logging.basicConfig(level=logging.INFO)
  logger = logging.getLogger(__name__)

  # Initialize FastAPI app
  app = FastAPI()

  # Configure CORS
  # NOTE(review): every origin, method and header is allowed while credentials
  # are enabled. FastAPI's CORS documentation says wildcard values cannot be
  # combined with allow_credentials=True; the settings are left unchanged here
  # because tightening them would change which clients can call the API.
  # Confirm the intended origin list.
  origins = ["*"]
  app.add_middleware(
      CORSMiddleware,
      allow_origins=origins,
      allow_credentials=True,
      allow_methods=["*"],
      allow_headers=["*"],
  )

  # Directory for uploaded files.
  # exist_ok=True replaces the former exists()-then-makedirs() pair, which could
  # fail if the directory appeared between the check and the creation, and it
  # matches how the audio cache directory is created below.
  UPLOAD_DIRECTORY = "static/files"
  os.makedirs(UPLOAD_DIRECTORY, exist_ok=True)

  # Mount static files. The two specific sub-paths are registered before the
  # general "/static" mount (presumably so they take precedence - confirm).
  app.mount("/static/files", StaticFiles(directory=UPLOAD_DIRECTORY), name="static_files")
  app.mount("/static/web", StaticFiles(directory="static/web"), name="static_web")
  app.mount("/static", StaticFiles(directory="static"), name="static")

  # Audio cache directory
  CACHE_DIR = "audio_cache"
  os.makedirs(CACHE_DIR, exist_ok=True)
  # Root redirect to PDF viewer
  @app.get("/")
  def root():
      """Redirect "/" to viewer.html with compress.pdf passed as the ``file`` query parameter."""
      viewer_url = "/static/web/viewer.html?file=/static/files/compress.pdf"
      return RedirectResponse(url=viewer_url)
  # Sanitize filename
  def sanitize_filename(name: str) -> str:
      """Reduce a user-supplied name to a safe file-name stem.

      Only alphanumeric characters (``str.isalnum``, so non-ASCII letters are
      kept) plus space, dot, underscore and hyphen survive. Trailing whitespace
      is removed, and so are leading spaces and dots, so that names such as
      ".." or "../x" cannot yield a dot-only or dot-leading (hidden) file name.

      Args:
          name: Raw name as received from the client.

      Returns:
          The cleaned name, or an empty string when nothing usable remains.
      """
      allowed_punctuation = (' ', '.', '_', '-')
      kept = "".join(c for c in name if c.isalnum() or c in allowed_punctuation)
      # Trailing whitespace goes first; stripping leading spaces/dots afterwards
      # cannot re-introduce trailing whitespace.
      return kept.rstrip().lstrip(" .")
  # PDF upload endpoint
  @app.post("/upload-pdf")
  async def upload_pdf(file: UploadFile = File(...), custom_name: str = Form(...)):
      """Store an uploaded PDF under UPLOAD_DIRECTORY as "<sanitized custom_name>.pdf".

      Args:
          file: The uploaded file; its content type must be "application/pdf".
          custom_name: Client-chosen name, passed through sanitize_filename().

      Returns:
          A JSON response with ``success`` and ``file_path`` on success, or a
          400 JSON response with ``success: False`` and ``error`` when the name
          is invalid or already taken.

      Raises:
          HTTPException: 400 for a non-PDF content type, 500 when writing fails.
      """
      if file.content_type != 'application/pdf':
          raise HTTPException(status_code=400, detail="文件类型必须是 PDF")

      sanitized_name = sanitize_filename(custom_name)
      if not sanitized_name:
          return JSONResponse(status_code=400, content={"success": False, "error": "无效的文件名"})

      unique_filename = f"{sanitized_name}.pdf"
      file_path = os.path.join(UPLOAD_DIRECTORY, unique_filename)

      created = False
      try:
          # "x" mode creates the file exclusively, so the "already exists" check
          # and the creation are a single step and two concurrent uploads cannot
          # overwrite each other.
          with open(file_path, "xb") as buffer:
              created = True
              shutil.copyfileobj(file.file, buffer)
      except FileExistsError:
          return JSONResponse(status_code=400, content={"success": False, "error": "文件名已存在,请使用其他名称"})
      except Exception as exc:
          logger.exception("Failed to store upload at %s", file_path)
          if created:
              # Do not leave a truncated PDF behind; it would be listed and
              # would block the name for later uploads.
              try:
                  os.remove(file_path)
              except OSError:
                  logger.warning("Could not remove partial upload %s", file_path)
          raise HTTPException(status_code=500, detail="上传过程中出错") from exc
      finally:
          file.file.close()

      file_relative_path = f"/static/files/{unique_filename}"
      return JSONResponse(content={"success": True, "file_path": file_relative_path})
  # List PDFs endpoint
  @app.get("/list-pdfs")
  async def list_pdfs():
      """List the entries of UPLOAD_DIRECTORY whose name ends in ".pdf" (case-insensitive).

      Returns:
          A JSON response ``{"success": True, "files": [{"name", "url"}, ...]}``
          with the entries sorted by name so the order is stable between calls.

      Raises:
          HTTPException: 500 when the directory cannot be read.
      """
      try:
          names = sorted(os.listdir(UPLOAD_DIRECTORY))
      except OSError as exc:
          # Keep the cause in the server log; the client only gets a generic message.
          logger.exception("Failed to list %s", UPLOAD_DIRECTORY)
          raise HTTPException(status_code=500, detail="无法获取文件列表") from exc

      pdf_files = [
          {"name": name, "url": f"/static/files/{name}"}
          for name in names
          if name.lower().endswith(".pdf")
      ]
      return JSONResponse(content={"success": True, "files": pdf_files})
  80. # Request models
  81. from pydantic import BaseModel
  class TextToSpeechRequest(BaseModel):
      """Request body accepted by the text-to-speech and page-to-speech endpoints."""

      # Text to synthesize (required).
      user_input: str
      # Voice identifier forwarded to the TTS service; defaults to "af_heart".
      voice: str = "af_heart"
      # Speed value forwarded to the TTS service; defaults to 1.0.
      speed: float = 1.0
  # Text-to-speech endpoint (streaming)
  @app.post("/text-to-speech/")
  async def text_to_speech(request: TextToSpeechRequest):
      """Synthesize speech for ``request.user_input`` and return it as audio/wav.

      A cache hit is answered from CACHE_DIR in one piece. Otherwise the text is
      posted to the TTS service, whose newline-delimited JSON reply is decoded
      line by line; each base64 ``audio`` payload is streamed to the client as it
      arrives, and the complete audio is cached once the stream has finished.

      Args:
          request: Text plus the voice and speed to forward to the TTS service.

      Returns:
          A ``Response`` (cache hit) or ``StreamingResponse`` (cache miss).

      Raises:
          HTTPException: 400 when the stripped text is empty. Errors from the
              TTS service are raised inside the streaming generator.
              NOTE(review): by then the response has presumably already
              started, so they likely abort the stream rather than produce a
              500 status - confirm against StreamingResponse behaviour.
      """
      user_input = request.user_input.strip()
      if not user_input:
          raise HTTPException(status_code=400, detail="输入文本为空")

      # The cache key covers voice and speed as well as the text; keying on the
      # text alone would replay audio generated with a different voice or speed.
      cache_key = f"{request.voice}\x00{request.speed}\x00{user_input}"
      text_hash = hashlib.md5(cache_key.encode('utf-8')).hexdigest()
      audio_path = os.path.join(CACHE_DIR, f"{text_hash}.wav")
      try:
          with open(audio_path, "rb") as cached:
              return Response(content=cached.read(), media_type="audio/wav")
      except FileNotFoundError:
          pass  # Cache miss: generate the audio below.

      def extract_audio(raw_line: bytes) -> bytes:
          """Decode the audio payload of one NDJSON line; b"" when it has none."""
          if not raw_line.strip():
              return b""
          try:
              data = json.loads(raw_line)
          except (json.JSONDecodeError, UnicodeDecodeError) as e:
              logger.error(f"JSON decode error: {str(e)}")
              return b""
          if data.get('error'):
              raise HTTPException(status_code=500, detail=data['error'])
          audio_b64 = data.get('audio')
          return base64.b64decode(audio_b64) if audio_b64 else b""

      async def audio_generator() -> AsyncGenerator[bytes, None]:
          full_audio = io.BytesIO()
          try:
              async with aiohttp.ClientSession() as session:
                  async with session.post(
                      'http://141.140.15.30:8028/generate',
                      headers={'Content-Type': 'application/json'},
                      json={
                          'text': user_input,
                          'voice': request.voice,
                          'speed': request.speed
                      }
                  ) as response:
                      if response.status != 200:
                          raise HTTPException(status_code=500, detail="TTS API 请求失败")
                      # Buffer raw bytes and split on newlines before decoding,
                      # so a multi-byte character cut by a network chunk
                      # boundary cannot break decoding.
                      pending = b""
                      async for chunk in response.content.iter_any():
                          pending += chunk
                          *complete_lines, pending = pending.split(b"\n")
                          for line in complete_lines:
                              audio_bytes = extract_audio(line)
                              if audio_bytes:
                                  full_audio.write(audio_bytes)
                                  yield audio_bytes
                      # A last line without a trailing newline is handled like
                      # every other line, including its 'error' field.
                      audio_bytes = extract_audio(pending)
                      if audio_bytes:
                          full_audio.write(audio_bytes)
                          yield audio_bytes
          except HTTPException as e:
              logger.error(f"TTS error: {str(e)}")
              raise
          except Exception as e:
              logger.error(f"TTS error: {str(e)}")
              raise HTTPException(status_code=500, detail=str(e)) from e

          # Save to cache. This is best-effort: the client already has the audio.
          audio = full_audio.getvalue()
          if not audio:
              # Caching an empty file would make every later request for this
              # text return silence.
              return
          # Write to a private temporary name and rename it into place, so a
          # concurrent cache hit never reads a half-written file.
          tmp_path = f"{audio_path}.{os.getpid()}.{id(full_audio)}.tmp"
          try:
              with open(tmp_path, "wb") as f:
                  f.write(audio)
              os.replace(tmp_path, audio_path)
          except OSError:
              logger.exception("Failed to cache TTS audio at %s", audio_path)
              try:
                  os.remove(tmp_path)
              except OSError:
                  pass  # Nothing was written, or it is already gone.

      return StreamingResponse(audio_generator(), media_type="audio/wav")
  # Page-to-speech endpoint (chunked streaming)
  MAX_CHUNK_SIZE = 200


  def split_text_into_chunks(text: str, max_chunk_size: int = MAX_CHUNK_SIZE) -> list:
      """Split ``text`` into pieces of at most ``max_chunk_size`` characters.

      The text is cut after '.', '!' or '?' followed by one or more spaces.
      Consecutive sentences are packed into one piece, joined by a single
      space, while they fit. A sentence longer than the limit is sliced into
      fixed-size pieces.

      Args:
          text: Text to split.
          max_chunk_size: Maximum length of each returned piece.

      Returns:
          The pieces in their original order; an empty list for empty text.
      """
      import re

      pieces = []
      pending = ""
      for part in re.split('(?<=[.!?]) +', text):
          # The +1 accounts for the joining space.
          if len(pending) + len(part) + 1 <= max_chunk_size:
              pending = f"{pending} {part}" if pending else part
              continue

          # The sentence does not fit: flush what has been collected so far.
          if pending:
              pieces.append(pending)

          if len(part) <= max_chunk_size:
              pending = part
              continue

          # Oversized sentence: emit fixed-size slices and start afresh.
          pieces.extend(
              part[start:start + max_chunk_size]
              for start in range(0, len(part), max_chunk_size)
          )
          pending = ""

      if pending:
          pieces.append(pending)
      return pieces
  async def generate_api_audio(chunk: str, voice: str, speed: float) -> AsyncGenerator[bytes, None]:
      """Yield the synthesized audio for one text chunk, using the on-disk cache.

      On a cache hit the cached bytes are yielded in one piece. Otherwise the
      chunk is posted to the TTS service, each base64 ``audio`` payload of its
      newline-delimited JSON reply is yielded as it arrives, and the
      concatenation of all payloads is cached once the reply is complete.

      Args:
          chunk: Text to synthesize.
          voice: Voice identifier forwarded to the TTS service.
          speed: Speed value forwarded to the TTS service.

      Yields:
          Decoded audio bytes.

      Raises:
          HTTPException: 500 when the TTS request fails or reports an error.
      """
      # The cache key covers voice and speed as well as the text; keying on the
      # text alone would replay audio generated with a different voice or speed.
      cache_key = f"{voice}\x00{speed}\x00{chunk}"
      text_hash = hashlib.md5(cache_key.encode('utf-8')).hexdigest()
      audio_path = os.path.join(CACHE_DIR, f"{text_hash}.wav")
      try:
          with open(audio_path, "rb") as f:
              cached_audio = f.read()
      except FileNotFoundError:
          cached_audio = None
      if cached_audio is not None:
          yield cached_audio
          return

      def decode_line(raw_line: bytes) -> bytes:
          """Decode the audio payload of one NDJSON line; b"" when it has none."""
          if not raw_line.strip():
              return b""
          try:
              data = json.loads(raw_line)
          except (json.JSONDecodeError, UnicodeDecodeError) as e:
              logger.error(f"JSON decode error: {str(e)}")
              return b""
          if data.get('error'):
              raise HTTPException(status_code=500, detail=data['error'])
          audio_b64 = data.get('audio')
          return base64.b64decode(audio_b64) if audio_b64 else b""

      # Every segment is collected so the cache receives the whole audio.
      # Re-opening the cache file with "wb" per segment kept only the last one.
      segments = []
      try:
          async with aiohttp.ClientSession() as session:
              async with session.post(
                  'http://141.140.15.30:8028/generate',
                  headers={'Content-Type': 'application/json'},
                  json={
                      'text': chunk,
                      'voice': voice,
                      'speed': speed
                  }
              ) as response:
                  if response.status != 200:
                      raise HTTPException(status_code=500, detail="TTS API 请求失败")
                  # Buffer raw bytes and split on newlines before decoding, so
                  # a multi-byte character cut by a network chunk boundary
                  # cannot break decoding. The loop variable is not named
                  # "chunk", to avoid shadowing the text parameter.
                  pending = b""
                  async for payload in response.content.iter_any():
                      pending += payload
                      *complete_lines, pending = pending.split(b"\n")
                      for line in complete_lines:
                          audio_bytes = decode_line(line)
                          if audio_bytes:
                              segments.append(audio_bytes)
                              yield audio_bytes
                  # A last line without a trailing newline is handled like
                  # every other line.
                  audio_bytes = decode_line(pending)
                  if audio_bytes:
                      segments.append(audio_bytes)
                      yield audio_bytes
      except Exception as e:
          logger.error(f"TTS error: {str(e)}")
          raise HTTPException(status_code=500, detail=f"TTS生成失败: {str(e)}") from e

      # Cache the chunk. This is best-effort: the audio was already yielded.
      audio = b"".join(segments)
      if not audio:
          # Never cache an empty result; it would be replayed as silence.
          return
      # Write to a private temporary name and rename it into place, so a
      # concurrent cache hit never reads a half-written file.
      tmp_path = f"{audio_path}.{os.getpid()}.{id(segments)}.tmp"
      try:
          with open(tmp_path, "wb") as f:
              f.write(audio)
          os.replace(tmp_path, audio_path)
      except OSError:
          logger.exception("Failed to cache TTS audio at %s", audio_path)
          try:
              os.remove(tmp_path)
          except OSError:
              pass  # Nothing was written, or it is already gone.
  @app.post("/page-to-speech/")
  async def page_to_speech(request: TextToSpeechRequest):
      """Synthesize speech for a longer text by streaming it chunk by chunk.

      The stripped text is split with split_text_into_chunks(), each chunk is
      synthesized through generate_api_audio(), and the audio is streamed in
      chunk order. The concatenated audio is cached under "<hash>_full.wav" in
      CACHE_DIR and served from there on later identical requests.

      Args:
          request: Text plus the voice and speed to forward to the TTS service.

      Returns:
          A ``Response`` (cache hit) or ``StreamingResponse`` (cache miss).

      Raises:
          HTTPException: 400 when the stripped text is empty.
      """
      user_input = request.user_input.strip()
      if not user_input:
          raise HTTPException(status_code=400, detail="输入文本为空")

      # The cache key covers voice and speed as well as the text; keying on the
      # text alone would replay audio generated with a different voice or speed.
      cache_key = f"{request.voice}\x00{request.speed}\x00{user_input}"
      full_text_hash = hashlib.md5(cache_key.encode('utf-8')).hexdigest()
      full_audio_path = os.path.join(CACHE_DIR, f"{full_text_hash}_full.wav")
      try:
          # Read inside a "with" block and return the bytes, as the
          # text-to-speech endpoint does. Handing an open() handle to
          # StreamingResponse left that handle without an explicit close.
          with open(full_audio_path, "rb") as cached:
              return Response(content=cached.read(), media_type="audio/wav")
      except FileNotFoundError:
          pass  # Cache miss: generate the audio below.

      chunks = split_text_into_chunks(user_input)

      async def audio_generator() -> AsyncGenerator[bytes, None]:
          full_audio_buffer = io.BytesIO()  # For caching full audio
          for chunk in chunks:
              async for audio_data in generate_api_audio(chunk, request.voice, request.speed):
                  full_audio_buffer.write(audio_data)
                  yield audio_data  # Stream each chunk's audio
              await asyncio.sleep(0)  # Yield control to event loop

          # Save the full audio to cache. This is best-effort: the client
          # already has the audio.
          audio = full_audio_buffer.getvalue()
          if not audio:
              # Never cache an empty result; it would be replayed as silence.
              return
          # Write to a private temporary name and rename it into place, so a
          # concurrent cache hit never reads a half-written file.
          tmp_path = f"{full_audio_path}.{os.getpid()}.{id(full_audio_buffer)}.tmp"
          try:
              with open(tmp_path, "wb") as f:
                  f.write(audio)
              os.replace(tmp_path, full_audio_path)
          except OSError:
              logger.exception("Failed to cache page audio at %s", full_audio_path)
              try:
                  os.remove(tmp_path)
              except OSError:
                  pass  # Nothing was written, or it is already gone.

      return StreamingResponse(audio_generator(), media_type="audio/wav")
  # Health check
  @app.get("/health")
  async def health_check():
      """Return the fixed payload ``{"status": "healthy"}``."""
      payload = {"status": "healthy"}
      return payload
  if __name__ == "__main__":
      # Run as a script: serve the app with uvicorn on 0.0.0.0, port 8005.
      import uvicorn

      uvicorn.run(
          app,
          port=8005,
          host="0.0.0.0",
      )