# tts_zh.py — online Chinese TTS service (FastAPI + Kokoro)
  1. from fastapi import FastAPI, HTTPException, Query
  2. from fastapi.responses import StreamingResponse
  3. from fastapi.middleware.cors import CORSMiddleware
  4. from pydantic import BaseModel, field_validator
  5. from typing import Optional
  6. from io import BytesIO
  7. import numpy as np
  8. import soundfile as sf
  9. from kokoro import KPipeline
  10. import threading
# Global TTS pipeline initialization (lang_code must match the voice used:
# e.g. Chinese female voice 'zf_xiaoxiao' requires lang_code='z').
pipeline = KPipeline(lang_code='z')
# Prefer the sample rate reported by the pipeline; fall back to 24000 Hz
# when the attribute is absent.
sample_rate = getattr(pipeline, "sample_rate", 24000)
# The underlying model is not assumed to be safe for concurrent inference,
# so all synthesis calls are serialized through this lock.
synthesis_lock = threading.Lock()
app = FastAPI(title="Online TTS Service (Kokoro)")
# CORS: allow cross-origin calls from notebooks/browsers.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # restrict origins as needed in production
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
  27. class TTSRequest(BaseModel):
  28. text: str
  29. voice: Optional[str] = "zf_xiaoxiao"
  30. speed: Optional[float] = 1.0
  31. split_pattern: Optional[str] = r"\n+"
  32. @field_validator("speed")
  33. @classmethod
  34. def check_speed(cls, v):
  35. if v is None:
  36. return 1.0
  37. try:
  38. v = float(v)
  39. except Exception:
  40. raise ValueError("speed 必须为数值")
  41. if v <= 0:
  42. raise ValueError("speed 必须大于 0")
  43. return v
  44. def to_mono_numpy(audio) -> np.ndarray:
  45. """
  46. 将 pipeline 返回的 audio 安全地转换为 numpy 1D float32 单声道数组。
  47. 兼容 numpy.ndarray、PyTorch Tensor 以及其他可转 numpy 的类型。
  48. """
  49. if audio is None:
  50. return np.array([], dtype=np.float32)
  51. # 已是 numpy
  52. if isinstance(audio, np.ndarray):
  53. arr = audio
  54. else:
  55. # 尝试兼容 PyTorch Tensor / 具有 numpy() 的对象
  56. arr = None
  57. # PyTorch Tensor 情况
  58. if hasattr(audio, "detach") and hasattr(audio, "cpu") and hasattr(audio, "numpy"):
  59. try:
  60. arr = audio.detach().cpu().numpy()
  61. except Exception:
  62. arr = None
  63. # 其他框架的 numpy() 情况
  64. if arr is None and hasattr(audio, "numpy") and callable(getattr(audio, "numpy")):
  65. try:
  66. arr = audio.numpy()
  67. except Exception:
  68. arr = None
  69. # 兜底转换
  70. if arr is None:
  71. try:
  72. arr = np.asarray(audio)
  73. except Exception:
  74. # 无法转换,返回空数组以便后续过滤
  75. return np.array([], dtype=np.float32)
  76. # 标准化形状与 dtype
  77. arr = np.asarray(arr)
  78. if arr.size == 0:
  79. return np.array([], dtype=np.float32)
  80. # 常见返回为 [T]、[T, 1] 或 [1, T]
  81. if arr.ndim == 2:
  82. if arr.shape[1] == 1:
  83. arr = arr[:, 0]
  84. elif arr.shape[0] == 1:
  85. arr = arr[0]
  86. else:
  87. # 多声道时做下混为单声道
  88. arr = arr.mean(axis=1)
  89. elif arr.ndim > 2:
  90. # 形状异常时,拉平为 1D
  91. arr = arr.reshape(-1)
  92. if arr.ndim == 0:
  93. arr = arr.reshape(1)
  94. # 转 float32,soundfile 会在写入时转为 PCM_16
  95. if arr.dtype != np.float32:
  96. arr = arr.astype(np.float32, copy=False)
  97. return arr
  98. def synthesize_wav_bytes(
  99. text: str,
  100. voice: str = "zf_xiaoxiao",
  101. speed: float = 1.0,
  102. split_pattern: str = r"\n+",
  103. ) -> BytesIO:
  104. # 生成完整音频并打包为 WAV 字节
  105. segments = []
  106. with synthesis_lock:
  107. generator = pipeline(
  108. text,
  109. voice=voice,
  110. speed=speed,
  111. split_pattern=split_pattern,
  112. )
  113. for _, _, audio in generator:
  114. arr = to_mono_numpy(audio)
  115. if arr.size > 0 and np.isfinite(arr).all():
  116. segments.append(arr)
  117. if not segments:
  118. raise HTTPException(status_code=400, detail="未生成音频,请检查输入文本或参数。")
  119. try:
  120. audio_concat = np.concatenate(segments, axis=0)
  121. except Exception as e:
  122. # 捕获异常并返回清晰错误信息
  123. raise HTTPException(
  124. status_code=500,
  125. detail=f"音频拼接失败:{type(e).__name__}: {str(e)}"
  126. )
  127. buf = BytesIO()
  128. # 以 PCM_16 写入 WAV
  129. sf.write(buf, audio_concat, samplerate=sample_rate, format="WAV", subtype="PCM_16")
  130. buf.seek(0)
  131. return buf
  132. @app.post("/tts", summary="POST: 传入文本返回 WAV 流")
  133. def tts_post(req: TTSRequest):
  134. buf = synthesize_wav_bytes(
  135. text=req.text,
  136. voice=req.voice or "zf_xiaoxiao",
  137. speed=req.speed if req.speed is not None else 1.0,
  138. split_pattern=req.split_pattern or r"\n+",
  139. )
  140. return StreamingResponse(
  141. buf,
  142. media_type="audio/wav",
  143. headers={
  144. "Content-Disposition": 'inline; filename="tts.wav"'
  145. },
  146. )
  147. @app.get("/tts", summary="GET: 传入文本返回 WAV 流(便于直接以 URL 播放)")
  148. def tts_get(
  149. text: str = Query(..., description="待合成文本"),
  150. voice: str = Query("zf_xiaoxiao"),
  151. speed: float = Query(1.0),
  152. split_pattern: str = Query(r"\n+"),
  153. ):
  154. if speed is None or float(speed) <= 0:
  155. raise HTTPException(status_code=400, detail="speed 必须为大于 0 的数值")
  156. buf = synthesize_wav_bytes(
  157. text=text,
  158. voice=voice,
  159. speed=float(speed),
  160. split_pattern=split_pattern,
  161. )
  162. return StreamingResponse(
  163. buf,
  164. media_type="audio/wav",
  165. headers={
  166. "Content-Disposition": 'inline; filename="tts.wav"'
  167. },
  168. )
# Run with:
#   uvicorn tts_zh:app --host 0.0.0.0 --port 18000 --workers 1
# Keep workers=1 (inference is serialized by synthesis_lock anyway) so multiple
# model copies do not compete for the same device memory.
if __name__ == "__main__":
    import uvicorn
    uvicorn.run("tts_zh:app", host="0.0.0.0", port=18000, reload=False, workers=1)