main.py 19 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531
  1. # from fastapi import FastAPI, Request, File, UploadFile, HTTPException, Form, Response
  2. # from fastapi.middleware.cors import CORSMiddleware
  3. # from fastapi.responses import JSONResponse, RedirectResponse, StreamingResponse
  4. # from fastapi.staticfiles import StaticFiles
  5. # import os
  6. # import shutil
  7. # import uuid
  8. # from pydantic import BaseModel
  9. # import hashlib
  10. # import asyncio
  11. # from typing import AsyncGenerator
  12. # import soundfile as sf
  13. # import io
  14. # import logging
  15. # from kokoro import KPipeline # Assuming kokoro is installed and available
  16. # # Set up logging
  17. # logging.basicConfig(level=logging.INFO)
  18. # logger = logging.getLogger(__name__)
  19. # # Initialize FastAPI app
  20. # app = FastAPI()
  21. # # Configure CORS
  22. # origins = ["*"]
  23. # app.add_middleware(
  24. # CORSMiddleware,
  25. # allow_origins=origins,
  26. # allow_credentials=True,
  27. # allow_methods=["*"],
  28. # allow_headers=["*"],
  29. # )
  30. # # Directory for uploaded files
  31. # UPLOAD_DIRECTORY = "static/files"
  32. # if not os.path.exists(UPLOAD_DIRECTORY):
  33. # os.makedirs(UPLOAD_DIRECTORY)
  34. # # Mount static files
  35. # app.mount("/static/files", StaticFiles(directory=UPLOAD_DIRECTORY), name="static_files")
  36. # app.mount("/static/web", StaticFiles(directory="static/web"), name="static_web")
  37. # app.mount("/static", StaticFiles(directory="static"), name="static")
  38. # # Audio cache directory
  39. # CACHE_DIR = "audio_cache"
  40. # os.makedirs(CACHE_DIR, exist_ok=True)
  41. # # Root redirect to PDF viewer
  42. # @app.get("/")
  43. # def root():
  44. # return RedirectResponse(url="/static/web/viewer.html?file=/static/files/compress.pdf")
  45. # # Sanitize filename
  46. # def sanitize_filename(name: str) -> str:
  47. # return "".join(c for c in name if c.isalnum() or c in (' ', '.', '_', '-')).rstrip()
  48. # # PDF upload endpoint
  49. # @app.post("/upload-pdf")
  50. # async def upload_pdf(file: UploadFile = File(...), custom_name: str = Form(...)):
  51. # if file.content_type != 'application/pdf':
  52. # raise HTTPException(status_code=400, detail="文件类型必须是 PDF")
  53. # sanitized_name = sanitize_filename(custom_name)
  54. # if not sanitized_name:
  55. # return JSONResponse(status_code=400, content={"success": False, "error": "无效的文件名"})
  56. # unique_filename = f"{sanitized_name}.pdf"
  57. # file_path = os.path.join(UPLOAD_DIRECTORY, unique_filename)
  58. # if os.path.exists(file_path):
  59. # return JSONResponse(status_code=400, content={"success": False, "error": "文件名已存在,请使用其他名称"})
  60. # try:
  61. # with open(file_path, "wb") as buffer:
  62. # shutil.copyfileobj(file.file, buffer)
  63. # except Exception as e:
  64. # raise HTTPException(status_code=500, detail="上传过程中出错")
  65. # finally:
  66. # file.file.close()
  67. # file_relative_path = f"/static/files/{unique_filename}"
  68. # return JSONResponse(content={"success": True, "file_path": file_relative_path})
  69. # # List PDFs endpoint
  70. # @app.get("/list-pdfs")
  71. # async def list_pdfs():
  72. # try:
  73. # files = os.listdir(UPLOAD_DIRECTORY)
  74. # pdf_files = [
  75. # {"name": file, "url": f"/static/files/{file}"}
  76. # for file in files if file.lower().endswith(".pdf")
  77. # ]
  78. # return JSONResponse(content={"success": True, "files": pdf_files})
  79. # except Exception as e:
  80. # raise HTTPException(status_code=500, detail="无法获取文件列表")
  81. # # TTS Server with Kokoro
  82. # class TextToSpeechServer:
  83. # def __init__(self):
  84. # self.pipeline = None
  85. # def load_model(self, lang_code='a'):
  86. # try:
  87. # logger.info("Loading KPipeline model...")
  88. # self.pipeline = KPipeline(lang_code=lang_code)
  89. # logger.info("Model loaded successfully")
  90. # except Exception as e:
  91. # logger.error(f"Failed to load model: {str(e)}")
  92. # raise
  93. # # Initialize TTS server
  94. # tts_server = TextToSpeechServer()
  95. # # Startup event to load Kokoro model
  96. # @app.on_event("startup")
  97. # async def startup_event():
  98. # tts_server.load_model()
  99. # # Request models
  100. # class TextToSpeechRequest(BaseModel):
  101. # user_input: str
  102. # voice: str = 'af_heart' # Default voice for Kokoro
  103. # speed: float = 1.0 # Default speed
  104. # # Text-to-speech endpoint (streaming)
  105. # @app.post("/text-to-speech/")
  106. # async def text_to_speech(request: TextToSpeechRequest):
  107. # user_input = request.user_input.strip()
  108. # if not user_input:
  109. # raise HTTPException(status_code=400, detail="输入文本为空")
  110. # text_hash = hashlib.md5(user_input.encode('utf-8')).hexdigest()
  111. # audio_path = os.path.join(CACHE_DIR, f"{text_hash}.wav")
  112. # if os.path.exists(audio_path):
  113. # with open(audio_path, "rb") as f:
  114. # return Response(content=f.read(), media_type="audio/wav")
  115. # async def audio_generator() -> AsyncGenerator[bytes, None]:
  116. # try:
  117. # if not tts_server.pipeline:
  118. # raise HTTPException(status_code=503, detail="Model not initialized")
  119. # print(user_input)
  120. # generator = tts_server.pipeline(
  121. # text=user_input,
  122. # voice=request.voice,
  123. # speed=request.speed,
  124. # split_pattern=r'\n+'
  125. # )
  126. # # 用于拼接所有音频数据的 NumPy 数组
  127. # full_audio_data = []
  128. # for i, (gs, ps, audio) in enumerate(generator):
  129. # print(f"Generating segment {i}")
  130. # full_audio_data.append(audio) # 假设 audio 是 NumPy 数组
  131. # # 将所有音频片段拼接成一个完整的音频
  132. # import numpy as np
  133. # concatenated_audio = np.concatenate(full_audio_data)
  134. # # 将拼接后的音频写入 WAV 文件
  135. # buffer = io.BytesIO()
  136. # sf.write(buffer, concatenated_audio, 24000, format='WAV')
  137. # buffer.seek(0)
  138. # audio_data = buffer.getvalue()
  139. # # 流式传输整个音频
  140. # yield audio_data
  141. # # 保存到缓存
  142. # with open(audio_path, "wb") as f:
  143. # f.write(audio_data)
  144. # except Exception as e:
  145. # logger.error(f"TTS error: {str(e)}")
  146. # raise HTTPException(status_code=500, detail=str(e))
  147. # return StreamingResponse(audio_generator(), media_type="audio/wav")
  148. # # Page-to-speech endpoint (chunked streaming)
  149. # MAX_CHUNK_SIZE = 200
  150. # def split_text_into_chunks(text: str, max_chunk_size: int = MAX_CHUNK_SIZE) -> list:
  151. # import re
  152. # sentences = re.split('(?<=[.!?]) +', text)
  153. # chunks = []
  154. # current_chunk = ""
  155. # for sentence in sentences:
  156. # if len(current_chunk) + len(sentence) + 1 <= max_chunk_size:
  157. # current_chunk += " " + sentence if current_chunk else sentence
  158. # else:
  159. # if current_chunk:
  160. # chunks.append(current_chunk)
  161. # if len(sentence) > max_chunk_size:
  162. # for i in range(0, len(sentence), max_chunk_size):
  163. # chunks.append(sentence[i:i + max_chunk_size])
  164. # current_chunk = ""
  165. # else:
  166. # current_chunk = sentence
  167. # if current_chunk:
  168. # chunks.append(current_chunk)
  169. # return chunks
  170. # async def generate_kokoro_audio(chunk: str, voice: str, speed: float) -> AsyncGenerator[bytes, None]:
  171. # text_hash = hashlib.md5(chunk.encode('utf-8')).hexdigest()
  172. # audio_path = os.path.join(CACHE_DIR, f"{text_hash}.wav")
  173. # if os.path.exists(audio_path):
  174. # with open(audio_path, "rb") as f:
  175. # yield f.read()
  176. # else:
  177. # try:
  178. # if not tts_server.pipeline:
  179. # raise HTTPException(status_code=503, detail="Model not initialized")
  180. # generator = tts_server.pipeline(
  181. # text=chunk,
  182. # voice=voice,
  183. # speed=speed,
  184. # split_pattern=r'\n+'
  185. # )
  186. # full_audio_buffer = io.BytesIO() # For caching
  187. # for i, (gs, ps, audio) in enumerate(generator):
  188. # buffer = io.BytesIO()
  189. # sf.write(buffer, audio, 24000, format='WAV')
  190. # buffer.seek(0)
  191. # audio_data = buffer.getvalue()
  192. # yield audio_data # Stream immediately
  193. # full_audio_buffer.write(audio_data)
  194. # break # Take first segment (adjust if multiple segments needed)
  195. # # Cache the chunk
  196. # full_audio_buffer.seek(0)
  197. # with open(audio_path, "wb") as f:
  198. # f.write(full_audio_buffer.getvalue())
  199. # except Exception as e:
  200. # raise HTTPException(status_code=500, detail=f"TTS生成失败: {str(e)}")
  201. # @app.post("/page-to-speech/")
  202. # async def page_to_speech(request: TextToSpeechRequest):
  203. # user_input = request.user_input.strip()
  204. # if not user_input:
  205. # raise HTTPException(status_code=400, detail="输入文本为空")
  206. # full_text_hash = hashlib.md5(user_input.encode('utf-8')).hexdigest()
  207. # full_audio_path = os.path.join(CACHE_DIR, f"{full_text_hash}_full.wav")
  208. # if os.path.exists(full_audio_path):
  209. # return StreamingResponse(open(full_audio_path, "rb"), media_type="audio/wav")
  210. # chunks = split_text_into_chunks(user_input)
  211. # async def audio_generator() -> AsyncGenerator[bytes, None]:
  212. # full_audio_buffer = io.BytesIO() # For caching full audio
  213. # for chunk in chunks:
  214. # async for audio_data in generate_kokoro_audio(chunk, request.voice, request.speed):
  215. # yield audio_data # Stream each chunk's audio
  216. # full_audio_buffer.write(audio_data)
  217. # await asyncio.sleep(0) # Yield control to event loop
  218. # # Save the full audio to cache
  219. # full_audio_buffer.seek(0)
  220. # with open(full_audio_path, "wb") as f:
  221. # f.write(full_audio_buffer.getvalue())
  222. # return StreamingResponse(audio_generator(), media_type="audio/wav")
  223. # # Health check
  224. # @app.get("/health")
  225. # async def health_check():
  226. # return {"status": "healthy" if tts_server.pipeline else "model_not_loaded"}
  227. # if __name__ == "__main__":
  228. # import uvicorn
  229. # uvicorn.run(app, host="0.0.0.0", port=8005)
  230. from fastapi import FastAPI, Request, File, UploadFile, HTTPException, Form, Response
  231. from fastapi.middleware.cors import CORSMiddleware
  232. from fastapi.responses import JSONResponse, RedirectResponse, StreamingResponse
  233. from fastapi.staticfiles import StaticFiles
  234. import os
  235. import shutil
  236. import uuid
  237. from pydantic import BaseModel
  238. import hashlib
  239. import asyncio
  240. from typing import AsyncGenerator
  241. import soundfile as sf
  242. import io
  243. import logging
  244. import numpy as np
  245. import re
  246. from kokoro import KPipeline # 假设 kokoro 已安装并可用
  247. # 设置日志
  248. logging.basicConfig(level=logging.INFO)
  249. logger = logging.getLogger(__name__)
  250. # 初始化 FastAPI 应用
  251. app = FastAPI()
  252. # 配置 CORS
  253. origins = ["*"]
  254. app.add_middleware(
  255. CORSMiddleware,
  256. allow_origins=origins,
  257. allow_credentials=True,
  258. allow_methods=["*"],
  259. allow_headers=["*"],
  260. )
  261. # 上传文件目录
  262. UPLOAD_DIRECTORY = "static/files"
  263. if not os.path.exists(UPLOAD_DIRECTORY):
  264. os.makedirs(UPLOAD_DIRECTORY)
  265. # 挂载静态文件
  266. app.mount("/static/files", StaticFiles(directory=UPLOAD_DIRECTORY), name="static_files")
  267. app.mount("/static/web", StaticFiles(directory="static/web"), name="static_web")
  268. app.mount("/static", StaticFiles(directory="static"), name="static")
  269. # 音频缓存目录
  270. CACHE_DIR = "audio_cache"
  271. os.makedirs(CACHE_DIR, exist_ok=True)
  272. # 根路径重定向到 PDF 查看器
  273. @app.get("/")
  274. def root():
  275. return RedirectResponse(url="/static/web/viewer.html?file=/static/files/compress.pdf")
  276. # 清理文件名
  277. def sanitize_filename(name: str) -> str:
  278. return "".join(c for c in name if c.isalnum() or c in (' ', '.', '_', '-')).rstrip()
  279. # PDF 上传端点
  280. @app.post("/upload-pdf")
  281. async def upload_pdf(file: UploadFile = File(...), custom_name: str = Form(...)):
  282. if file.content_type != 'application/pdf':
  283. raise HTTPException(status_code=400, detail="文件类型必须是 PDF")
  284. sanitized_name = sanitize_filename(custom_name)
  285. if not sanitized_name:
  286. return JSONResponse(status_code=400, content={"success": False, "error": "无效的文件名"})
  287. unique_filename = f"{sanitized_name}.pdf"
  288. file_path = os.path.join(UPLOAD_DIRECTORY, unique_filename)
  289. if os.path.exists(file_path):
  290. return JSONResponse(status_code=400, content={"success": False, "error": "文件名已存在,请使用其他名称"})
  291. try:
  292. with open(file_path, "wb") as buffer:
  293. shutil.copyfileobj(file.file, buffer)
  294. except Exception as e:
  295. raise HTTPException(status_code=500, detail="上传过程中出错")
  296. finally:
  297. file.file.close()
  298. file_relative_path = f"/static/files/{unique_filename}"
  299. return JSONResponse(content={"success": True, "file_path": file_relative_path})
  300. # 列出 PDF 文件端点
  301. @app.get("/list-pdfs")
  302. async def list_pdfs():
  303. try:
  304. files = os.listdir(UPLOAD_DIRECTORY)
  305. pdf_files = [
  306. {"name": file, "url": f"/static/files/{file}"}
  307. for file in files if file.lower().endswith(".pdf")
  308. ]
  309. return JSONResponse(content={"success": True, "files": pdf_files})
  310. except Exception as e:
  311. raise HTTPException(status_code=500, detail="无法获取文件列表")
  312. # TTS 服务类
  313. class TextToSpeechServer:
  314. def __init__(self):
  315. self.pipeline = None
  316. def load_model(self, lang_code='a'):
  317. try:
  318. logger.info("加载 KPipeline 模型...")
  319. self.pipeline = KPipeline(lang_code=lang_code)
  320. logger.info("模型加载成功")
  321. except Exception as e:
  322. logger.error(f"模型加载失败: {str(e)}")
  323. raise
  324. # 初始化 TTS 服务
  325. tts_server = TextToSpeechServer()
  326. # 应用启动时加载 Kokoro 模型
  327. @app.on_event("startup")
  328. async def startup_event():
  329. tts_server.load_model()
  330. # 请求模型
  331. class TextToSpeechRequest(BaseModel):
  332. user_input: str
  333. voice: str = 'af_heart' # 默认语音
  334. speed: float = 1.0 # 默认速度
  335. # 文本转语音端点(流式)
  336. @app.post("/text-to-speech/")
  337. async def text_to_speech(request: TextToSpeechRequest):
  338. user_input = request.user_input.strip()
  339. if not user_input:
  340. raise HTTPException(status_code=400, detail="输入文本为空")
  341. text_hash = hashlib.md5(user_input.encode('utf-8')).hexdigest()
  342. audio_path = os.path.join(CACHE_DIR, f"{text_hash}.wav")
  343. if os.path.exists(audio_path):
  344. with open(audio_path, "rb") as f:
  345. return Response(content=f.read(), media_type="audio/wav")
  346. async def audio_generator() -> AsyncGenerator[bytes, None]:
  347. try:
  348. if not tts_server.pipeline:
  349. raise HTTPException(status_code=503, detail="模型未初始化")
  350. generator = tts_server.pipeline(
  351. text=user_input,
  352. voice=request.voice,
  353. speed=request.speed,
  354. split_pattern=r'\n+'
  355. )
  356. full_audio_data = []
  357. for i, (gs, ps, audio) in enumerate(generator):
  358. full_audio_data.append(audio)
  359. concatenated_audio = np.concatenate(full_audio_data)
  360. buffer = io.BytesIO()
  361. sf.write(buffer, concatenated_audio, 24000, format='WAV')
  362. buffer.seek(0)
  363. audio_data = buffer.getvalue()
  364. yield audio_data
  365. with open(audio_path, "wb") as f:
  366. f.write(audio_data)
  367. except Exception as e:
  368. logger.error(f"TTS 错误: {str(e)}")
  369. raise HTTPException(status_code=500, detail=str(e))
  370. return StreamingResponse(audio_generator(), media_type="audio/wav")
  371. # 按句子分割文本
  372. def split_text_into_sentences(text: str) -> list:
  373. # 使用正则表达式按句号、问号、感叹号分割句子
  374. sentences = re.split(r'(?<=[.!?])\s+', text.strip())
  375. return [s.strip() for s in sentences if s.strip()]
  376. # 生成单句音频
  377. async def generate_kokoro_audio(chunk: str, voice: str, speed: float) -> AsyncGenerator[bytes, None]:
  378. text_hash = hashlib.md5(chunk.encode('utf-8')).hexdigest()
  379. audio_path = os.path.join(CACHE_DIR, f"{text_hash}.wav")
  380. if os.path.exists(audio_path):
  381. with open(audio_path, "rb") as f:
  382. yield f.read()
  383. else:
  384. try:
  385. if not tts_server.pipeline:
  386. raise HTTPException(status_code=503, detail="模型未初始化")
  387. generator = tts_server.pipeline(
  388. text=chunk,
  389. voice=voice,
  390. speed=speed,
  391. split_pattern=r'\n+'
  392. )
  393. full_audio_buffer = io.BytesIO()
  394. for i, (gs, ps, audio) in enumerate(generator):
  395. buffer = io.BytesIO()
  396. sf.write(buffer, audio, 24000, format='WAV')
  397. buffer.seek(0)
  398. audio_data = buffer.getvalue()
  399. yield audio_data
  400. full_audio_buffer.write(audio_data)
  401. break # 仅取第一个片段
  402. full_audio_buffer.seek(0)
  403. with open(audio_path, "wb") as f:
  404. f.write(full_audio_buffer.getvalue())
  405. except Exception as e:
  406. raise HTTPException(status_code=500, detail=f"TTS生成失败: {str(e)}")
  407. # 页面转语音端点(按句子逐句转换并播放)
  408. @app.post("/page-to-speech/")
  409. async def page_to_speech(request: TextToSpeechRequest):
  410. user_input = request.user_input.strip()
  411. if not user_input:
  412. raise HTTPException(status_code=400, detail="输入文本为空")
  413. full_text_hash = hashlib.md5(user_input.encode('utf-8')).hexdigest()
  414. full_audio_path = os.path.join(CACHE_DIR, f"{full_text_hash}_full.wav")
  415. if os.path.exists(full_audio_path):
  416. return StreamingResponse(open(full_audio_path, "rb"), media_type="audio/wav")
  417. sentences = split_text_into_sentences(user_input)
  418. if not sentences:
  419. raise HTTPException(status_code=400, detail="没有有效的句子")
  420. async def audio_generator() -> AsyncGenerator[bytes, None]:
  421. full_audio_buffer = io.BytesIO() # 用于缓存完整音频
  422. for sentence in sentences:
  423. logger.info(f"处理句子: {sentence}")
  424. async for audio_data in generate_kokoro_audio(sentence, request.voice, request.speed):
  425. yield audio_data # 立即流式传输当前句子的音频
  426. full_audio_buffer.write(audio_data)
  427. await asyncio.sleep(0) # 让出控制权给事件循环
  428. # 保存完整音频到缓存
  429. full_audio_buffer.seek(0)
  430. with open(full_audio_path, "wb") as f:
  431. f.write(full_audio_buffer.getvalue())
  432. return StreamingResponse(audio_generator(), media_type="audio/wav")
  433. # 健康检查
  434. @app.get("/health")
  435. async def health_check():
  436. return {"status": "healthy" if tts_server.pipeline else "model_not_loaded"}
  437. if __name__ == "__main__":
  438. import uvicorn
  439. uvicorn.run(app, host="0.0.0.0", port=8005)