main_kokoro.py 8.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253
  1. from fastapi import FastAPI, Request, File, UploadFile, HTTPException, Form, Response
  2. from fastapi.middleware.cors import CORSMiddleware
  3. from fastapi.responses import JSONResponse, RedirectResponse, StreamingResponse
  4. from fastapi.staticfiles import StaticFiles
  5. import os
  6. import shutil
  7. import uuid
  8. from pydantic import BaseModel
  9. import hashlib
  10. import asyncio
  11. from typing import AsyncGenerator
  12. import soundfile as sf
  13. import io
  14. import logging
  15. import numpy as np
  16. import re
  17. from kokoro import KPipeline # 假设 kokoro 已安装并可用
  18. # 设置日志
  19. logging.basicConfig(level=logging.INFO)
  20. logger = logging.getLogger(__name__)
  21. # 初始化 FastAPI 应用
  22. app = FastAPI()
  23. # 配置 CORS
  24. origins = ["*"]
  25. app.add_middleware(
  26. CORSMiddleware,
  27. allow_origins=origins,
  28. allow_credentials=True,
  29. allow_methods=["*"],
  30. allow_headers=["*"],
  31. )
  32. # 上传文件目录
  33. UPLOAD_DIRECTORY = "static/files"
  34. if not os.path.exists(UPLOAD_DIRECTORY):
  35. os.makedirs(UPLOAD_DIRECTORY)
  36. # 挂载静态文件
  37. app.mount("/static/files", StaticFiles(directory=UPLOAD_DIRECTORY), name="static_files")
  38. app.mount("/static/web", StaticFiles(directory="static/web"), name="static_web")
  39. app.mount("/static", StaticFiles(directory="static"), name="static")
  40. # 音频缓存目录
  41. CACHE_DIR = "audio_cache"
  42. os.makedirs(CACHE_DIR, exist_ok=True)
  43. # 根路径重定向到 PDF 查看器
  44. @app.get("/")
  45. def root():
  46. return RedirectResponse(url="/static/web/viewer.html?file=/static/files/compress.pdf")
  47. # 清理文件名
  48. def sanitize_filename(name: str) -> str:
  49. return "".join(c for c in name if c.isalnum() or c in (' ', '.', '_', '-')).rstrip()
  50. # PDF 上传端点
  51. @app.post("/upload-pdf")
  52. async def upload_pdf(file: UploadFile = File(...), custom_name: str = Form(...)):
  53. if file.content_type != 'application/pdf':
  54. raise HTTPException(status_code=400, detail="文件类型必须是 PDF")
  55. sanitized_name = sanitize_filename(custom_name)
  56. if not sanitized_name:
  57. return JSONResponse(status_code=400, content={"success": False, "error": "无效的文件名"})
  58. unique_filename = f"{sanitized_name}.pdf"
  59. file_path = os.path.join(UPLOAD_DIRECTORY, unique_filename)
  60. if os.path.exists(file_path):
  61. return JSONResponse(status_code=400, content={"success": False, "error": "文件名已存在,请使用其他名称"})
  62. try:
  63. with open(file_path, "wb") as buffer:
  64. shutil.copyfileobj(file.file, buffer)
  65. except Exception as e:
  66. raise HTTPException(status_code=500, detail="上传过程中出错")
  67. finally:
  68. file.file.close()
  69. file_relative_path = f"/static/files/{unique_filename}"
  70. return JSONResponse(content={"success": True, "file_path": file_relative_path})
  71. # 列出 PDF 文件端点
  72. @app.get("/list-pdfs")
  73. async def list_pdfs():
  74. try:
  75. files = os.listdir(UPLOAD_DIRECTORY)
  76. pdf_files = [
  77. {"name": file, "url": f"/static/files/{file}"}
  78. for file in files if file.lower().endswith(".pdf")
  79. ]
  80. return JSONResponse(content={"success": True, "files": pdf_files})
  81. except Exception as e:
  82. raise HTTPException(status_code=500, detail="无法获取文件列表")
  83. # TTS 服务类
  84. class TextToSpeechServer:
  85. def __init__(self):
  86. self.pipeline = None
  87. def load_model(self, lang_code='a'):
  88. try:
  89. logger.info("加载 KPipeline 模型...")
  90. self.pipeline = KPipeline(lang_code=lang_code)
  91. logger.info("模型加载成功")
  92. except Exception as e:
  93. logger.error(f"模型加载失败: {str(e)}")
  94. raise
  95. # 初始化 TTS 服务
  96. tts_server = TextToSpeechServer()
  97. # 应用启动时加载 Kokoro 模型
  98. @app.on_event("startup")
  99. async def startup_event():
  100. tts_server.load_model()
  101. # 请求模型
  102. class TextToSpeechRequest(BaseModel):
  103. user_input: str
  104. voice: str = 'af_heart' # 默认语音
  105. speed: float = 1.0 # 默认速度
  106. # 文本转语音端点(流式)
  107. @app.post("/text-to-speech/")
  108. async def text_to_speech(request: TextToSpeechRequest):
  109. user_input = request.user_input.strip()
  110. if not user_input:
  111. raise HTTPException(status_code=400, detail="输入文本为空")
  112. text_hash = hashlib.md5(user_input.encode('utf-8')).hexdigest()
  113. audio_path = os.path.join(CACHE_DIR, f"{text_hash}.wav")
  114. if os.path.exists(audio_path):
  115. with open(audio_path, "rb") as f:
  116. return Response(content=f.read(), media_type="audio/wav")
  117. async def audio_generator() -> AsyncGenerator[bytes, None]:
  118. try:
  119. if not tts_server.pipeline:
  120. raise HTTPException(status_code=503, detail="模型未初始化")
  121. generator = tts_server.pipeline(
  122. text=user_input,
  123. voice=request.voice,
  124. speed=request.speed,
  125. split_pattern=r'\n+'
  126. )
  127. full_audio_data = []
  128. for i, (gs, ps, audio) in enumerate(generator):
  129. full_audio_data.append(audio)
  130. concatenated_audio = np.concatenate(full_audio_data)
  131. buffer = io.BytesIO()
  132. sf.write(buffer, concatenated_audio, 24000, format='WAV')
  133. buffer.seek(0)
  134. audio_data = buffer.getvalue()
  135. yield audio_data
  136. with open(audio_path, "wb") as f:
  137. f.write(audio_data)
  138. except Exception as e:
  139. logger.error(f"TTS 错误: {str(e)}")
  140. raise HTTPException(status_code=500, detail=str(e))
  141. return StreamingResponse(audio_generator(), media_type="audio/wav")
  142. # 按句子分割文本
  143. def split_text_into_sentences(text: str) -> list:
  144. # 使用正则表达式按句号、问号、感叹号分割句子
  145. sentences = re.split(r'(?<=[.!?])\s+', text.strip())
  146. return [s.strip() for s in sentences if s.strip()]
  147. # 生成单句音频
  148. async def generate_kokoro_audio(chunk: str, voice: str, speed: float) -> AsyncGenerator[bytes, None]:
  149. text_hash = hashlib.md5(chunk.encode('utf-8')).hexdigest()
  150. audio_path = os.path.join(CACHE_DIR, f"{text_hash}.wav")
  151. if os.path.exists(audio_path):
  152. with open(audio_path, "rb") as f:
  153. yield f.read()
  154. else:
  155. try:
  156. if not tts_server.pipeline:
  157. raise HTTPException(status_code=503, detail="模型未初始化")
  158. generator = tts_server.pipeline(
  159. text=chunk,
  160. voice=voice,
  161. speed=speed,
  162. split_pattern=r'\n+'
  163. )
  164. full_audio_buffer = io.BytesIO()
  165. for i, (gs, ps, audio) in enumerate(generator):
  166. buffer = io.BytesIO()
  167. sf.write(buffer, audio, 24000, format='WAV')
  168. buffer.seek(0)
  169. audio_data = buffer.getvalue()
  170. yield audio_data
  171. full_audio_buffer.write(audio_data)
  172. break # 仅取第一个片段
  173. full_audio_buffer.seek(0)
  174. with open(audio_path, "wb") as f:
  175. f.write(full_audio_buffer.getvalue())
  176. except Exception as e:
  177. raise HTTPException(status_code=500, detail=f"TTS生成失败: {str(e)}")
  178. # 页面转语音端点(按句子逐句转换并播放)
  179. @app.post("/page-to-speech/")
  180. async def page_to_speech(request: TextToSpeechRequest):
  181. user_input = request.user_input.strip()
  182. if not user_input:
  183. raise HTTPException(status_code=400, detail="输入文本为空")
  184. full_text_hash = hashlib.md5(user_input.encode('utf-8')).hexdigest()
  185. full_audio_path = os.path.join(CACHE_DIR, f"{full_text_hash}_full.wav")
  186. if os.path.exists(full_audio_path):
  187. return StreamingResponse(open(full_audio_path, "rb"), media_type="audio/wav")
  188. sentences = split_text_into_sentences(user_input)
  189. if not sentences:
  190. raise HTTPException(status_code=400, detail="没有有效的句子")
  191. async def audio_generator() -> AsyncGenerator[bytes, None]:
  192. full_audio_buffer = io.BytesIO() # 用于缓存完整音频
  193. for sentence in sentences:
  194. logger.info(f"处理句子: {sentence}")
  195. async for audio_data in generate_kokoro_audio(sentence, request.voice, request.speed):
  196. yield audio_data # 立即流式传输当前句子的音频
  197. full_audio_buffer.write(audio_data)
  198. await asyncio.sleep(0) # 让出控制权给事件循环
  199. # 保存完整音频到缓存
  200. full_audio_buffer.seek(0)
  201. with open(full_audio_path, "wb") as f:
  202. f.write(full_audio_buffer.getvalue())
  203. return StreamingResponse(audio_generator(), media_type="audio/wav")
  204. # 健康检查
  205. @app.get("/health")
  206. async def health_check():
  207. return {"status": "healthy" if tts_server.pipeline else "model_not_loaded"}
  208. if __name__ == "__main__":
  209. import uvicorn
  210. uvicorn.run(app, host="0.0.0.0", port=8005)