from io import BytesIO
import math
import re
import threading
from typing import Optional

from fastapi import FastAPI, HTTPException, Query
from fastapi.responses import Response, StreamingResponse
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, field_validator
import numpy as np
import soundfile as sf
from kokoro import KPipeline

# Global TTS pipeline, created once at import time. The lang_code must match
# the voice family, e.g. the Mandarin female voice 'zf_xiaoxiao' uses
# lang_code='z'.
pipeline = KPipeline(lang_code='z')

# Prefer the sample rate exposed by the pipeline; fall back to 24000 Hz when
# it has no `sample_rate` attribute.
sample_rate = getattr(pipeline, "sample_rate", 24000)

# Serialize TTS inference with a lock to avoid concurrency problems in the
# underlying model (FastAPI runs plain `def` endpoints in a thread pool).
# This only serializes within one process; see the run notes at the bottom.
synthesis_lock = threading.Lock()

# Defaults shared by the POST model, the GET query parameters and
# synthesize_wav_bytes so they cannot drift apart.
DEFAULT_VOICE = "zf_xiaoxiao"
DEFAULT_SPEED = 1.0
DEFAULT_SPLIT_PATTERN = r"\n+"

app = FastAPI(title="Online TTS Service (Kokoro)")

# Allow cross-origin calls from notebooks / browsers.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # restrict this list in production
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


def _validate_speed(value) -> float:
    """Return *value* as a finite float > 0, else raise ValueError.

    NaN must be rejected explicitly: ``nan <= 0`` is False, so a plain
    positivity check would let it through (and pydantic/FastAPI accept
    "nan"/"inf" for float fields by default).
    """
    try:
        speed = float(value)
    except (TypeError, ValueError) as exc:
        raise ValueError("speed 必须为数值") from exc
    if not math.isfinite(speed):
        raise ValueError("speed 必须为有限数值")
    if speed <= 0:
        raise ValueError("speed 必须大于 0")
    return speed


def _validate_split_pattern(pattern: str) -> str:
    """Ensure *pattern* compiles as a regex, else raise ValueError.

    The pipeline presumably splits text with this pattern via ``re``; an
    invalid pattern would otherwise only fail during synthesis as a 500.
    NOTE(review): this is a client-supplied regex, and a pathological pattern
    could stall the worker while it holds synthesis_lock. Consider an
    allow-list of patterns if the service is exposed publicly.
    """
    try:
        re.compile(pattern)
    except re.error as exc:
        raise ValueError(f"split_pattern 不是合法的正则表达式: {exc}") from exc
    return pattern


class TTSRequest(BaseModel):
    """JSON body for ``POST /tts``."""

    # Text to synthesize.
    text: str
    # Kokoro voice name; None/empty falls back to DEFAULT_VOICE in tts_post.
    voice: Optional[str] = DEFAULT_VOICE
    # Speaking-rate multiplier; must be finite and > 0 (None -> 1.0).
    speed: Optional[float] = DEFAULT_SPEED
    # Regex used to split the text into segments; None/empty -> default.
    split_pattern: Optional[str] = DEFAULT_SPLIT_PATTERN

    @field_validator("speed")
    @classmethod
    def check_speed(cls, v):
        """Map None to the default speed; reject non-finite or non-positive values."""
        if v is None:
            return DEFAULT_SPEED
        return _validate_speed(v)

    @field_validator("split_pattern")
    @classmethod
    def check_split_pattern(cls, v):
        """Reject patterns that are not valid regular expressions (422)."""
        if v:
            _validate_split_pattern(v)
        return v


def _as_ndarray(audio) -> Optional[np.ndarray]:
    """Best-effort conversion of an array-like to numpy; None if impossible."""
    if isinstance(audio, np.ndarray):
        return audio
    # PyTorch tensor: detach from autograd and move to CPU before .numpy().
    if all(hasattr(audio, name) for name in ("detach", "cpu", "numpy")):
        try:
            return np.asarray(audio.detach().cpu().numpy())
        except Exception:
            pass  # deliberate best-effort; fall through to the next strategy
    # Other frameworks that expose a numpy() method.
    if callable(getattr(audio, "numpy", None)):
        try:
            return np.asarray(audio.numpy())
        except Exception:
            pass  # deliberate best-effort; fall through to np.asarray
    try:
        return np.asarray(audio)
    except Exception:
        return None


def to_mono_numpy(audio) -> np.ndarray:
    """Convert one pipeline audio chunk into a 1-D float32 mono numpy array.

    Accepts numpy arrays, PyTorch tensors, other objects exposing ``numpy()``
    and anything ``np.asarray`` understands. Returns an empty float32 array
    for ``None`` or unconvertible input so the caller can simply skip it.
    """
    empty = np.array([], dtype=np.float32)
    if audio is None:
        return empty

    arr = _as_ndarray(audio)
    if arr is None or arr.size == 0:
        return empty

    # Integer PCM (e.g. int16) is full-scale at iinfo.max; rescale to [-1, 1]
    # before any float math, otherwise every sample would be clipped on write.
    if np.issubdtype(arr.dtype, np.signedinteger):
        full_scale = float(np.iinfo(arr.dtype).max)
        arr = arr.astype(np.float32) / full_scale

    if arr.ndim == 0:
        arr = arr.reshape(1)
    elif arr.ndim == 2:
        # Common layouts: [T, 1], [1, T], or multi-channel.
        if arr.shape[1] == 1:
            arr = arr[:, 0]
        elif arr.shape[0] == 1:
            arr = arr[0]
        else:
            # Multi-channel: treat the shorter axis as channels so both
            # [T, C] and [C, T] layouts downmix to T samples.
            channel_axis = 0 if arr.shape[0] < arr.shape[1] else 1
            arr = arr.mean(axis=channel_axis)
    elif arr.ndim > 2:
        # Unexpected rank: flatten to 1-D.
        arr = arr.reshape(-1)

    # float32 here; soundfile converts to PCM_16 when writing.
    if arr.dtype != np.float32:
        arr = arr.astype(np.float32, copy=False)
    return arr


def synthesize_wav_bytes(
    text: str,
    voice: str = DEFAULT_VOICE,
    speed: float = DEFAULT_SPEED,
    split_pattern: str = DEFAULT_SPLIT_PATTERN,
) -> BytesIO:
    """Synthesize *text* and return the whole clip as an in-memory WAV file.

    Args:
        text: Text to synthesize.
        voice: Kokoro voice name passed to the pipeline.
        speed: Speaking-rate multiplier passed to the pipeline.
        split_pattern: Segment-splitting regex passed to the pipeline.

    Returns:
        A BytesIO positioned at 0 containing a mono PCM_16 WAV at
        ``sample_rate``.

    Raises:
        HTTPException: 400 if no usable audio was produced; 500 if the
            segments could not be concatenated.
    """
    segments = []
    # The pipeline result is presumably a lazy generator (inference happens
    # while iterating), so the whole loop stays inside the lock.
    with synthesis_lock:
        generator = pipeline(
            text,
            voice=voice,
            speed=speed,
            split_pattern=split_pattern,
        )
        for _, _, audio in generator:
            arr = to_mono_numpy(audio)
            # Skip empty chunks and chunks containing NaN/inf samples.
            if arr.size > 0 and np.isfinite(arr).all():
                segments.append(arr)

    if not segments:
        raise HTTPException(status_code=400, detail="未生成音频,请检查输入文本或参数。")

    try:
        audio_concat = np.concatenate(segments, axis=0)
    except Exception as e:
        # Surface a clear error message instead of a bare 500.
        raise HTTPException(
            status_code=500, detail=f"音频拼接失败:{type(e).__name__}: {str(e)}"
        ) from e

    # libsndfile's automatic float->int clipping is off by default, so
    # out-of-range samples may overflow when converted to PCM; clip first.
    audio_concat = np.clip(audio_concat, -1.0, 1.0)

    buf = BytesIO()
    sf.write(buf, audio_concat, samplerate=sample_rate, format="WAV", subtype="PCM_16")
    buf.seek(0)
    return buf


def _wav_response(buf: BytesIO) -> Response:
    """Wrap a finished WAV buffer in an inline ``audio/wav`` response."""
    # The buffer is already fully in memory: send it as one body with a
    # Content-Length (lets players seek). StreamingResponse would iterate the
    # BytesIO line by line, chopping binary data at every b"\n" byte.
    return Response(
        content=buf.getvalue(),
        media_type="audio/wav",
        headers={"Content-Disposition": 'inline; filename="tts.wav"'},
    )


@app.post("/tts", summary="POST: 传入文本返回 WAV 流")
def tts_post(req: TTSRequest):
    """Synthesize the JSON-supplied text and return it as a WAV file."""
    buf = synthesize_wav_bytes(
        text=req.text,
        voice=req.voice or DEFAULT_VOICE,
        speed=req.speed if req.speed is not None else DEFAULT_SPEED,
        split_pattern=req.split_pattern or DEFAULT_SPLIT_PATTERN,
    )
    return _wav_response(buf)


@app.get("/tts", summary="GET: 传入文本返回 WAV 流(便于直接以 URL 播放)")
def tts_get(
    text: str = Query(..., description="待合成文本"),
    voice: str = Query(DEFAULT_VOICE),
    speed: float = Query(DEFAULT_SPEED),
    split_pattern: str = Query(DEFAULT_SPLIT_PATTERN),
):
    """Synthesize query-string text and return a WAV (playable from a URL).

    Raises:
        HTTPException: 400 for a non-finite / non-positive speed or an
            invalid split_pattern regex.
    """
    try:
        speed = _validate_speed(speed)
    except ValueError as exc:
        raise HTTPException(status_code=400, detail="speed 必须为大于 0 的数值") from exc
    try:
        _validate_split_pattern(split_pattern)
    except ValueError as exc:
        raise HTTPException(status_code=400, detail=str(exc)) from exc

    buf = synthesize_wav_bytes(
        text=text,
        voice=voice,
        speed=speed,
        split_pattern=split_pattern,
    )
    return _wav_response(buf)


# Run with uvicorn (replace <module> with this file's module name):
#   uvicorn <module>:app --host 0.0.0.0 --port 8000 --workers 1
# Keep workers=1 (or otherwise serialize requests): synthesis_lock only
# serializes within one process, and each extra worker loads its own model on
# the same device, risking VRAM / model conflicts.
if __name__ == "__main__":
    import uvicorn

    # Pass the app object, not an import string: running this file as a
    # script already executed the module (and loaded the model) as __main__.
    # An import string such as "tts_zh:app" would import the module a second
    # time, loading a second pipeline, and breaks if the file is renamed.
    # (An app object is allowed because reload is off and workers == 1.)
    uvicorn.run(app, host="0.0.0.0", port=18000, reload=False, workers=1)