| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202 |
- from fastapi import FastAPI, HTTPException, Query
- from fastapi.responses import StreamingResponse
- from fastapi.middleware.cors import CORSMiddleware
- from pydantic import BaseModel, field_validator
- from typing import Optional
- from io import BytesIO
- import numpy as np
- import soundfile as sf
- from kokoro import KPipeline
- import threading
# Global TTS pipeline initialization (lang_code must match the voice used:
# e.g. Chinese female voice: voice='zf_xiaoxiao' with lang_code='z').
pipeline = KPipeline(lang_code='z')
# Prefer the sample rate reported by the pipeline; fall back to 24000 if absent.
sample_rate = getattr(pipeline, "sample_rate", 24000)
# Serialize TTS inference with a lock to avoid concurrency issues in the
# underlying model.
synthesis_lock = threading.Lock()
app = FastAPI(title="Online TTS Service (Kokoro)")
# CORS: allow cross-origin calls from notebooks/browsers.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # restrict as needed in production
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
class TTSRequest(BaseModel):
    """Request body for POST /tts.

    Attributes mirror the synthesis parameters: the text to speak, the
    voice id, a playback speed multiplier, and the regex used to split
    the text into segments.
    """

    text: str
    voice: Optional[str] = "zf_xiaoxiao"
    speed: Optional[float] = 1.0
    split_pattern: Optional[str] = r"\n+"

    @field_validator("speed")
    @classmethod
    def check_speed(cls, v):
        """Coerce speed to a positive float; None falls back to 1.0."""
        if v is None:
            return 1.0
        try:
            value = float(v)
        except Exception:
            raise ValueError("speed 必须为数值")
        if not value > 0:
            raise ValueError("speed 必须大于 0")
        return value
def to_mono_numpy(audio) -> np.ndarray:
    """Convert pipeline audio output into a 1-D float32 mono numpy array.

    Accepts numpy arrays, PyTorch tensors, objects exposing ``.numpy()``,
    or anything ``np.asarray`` can handle. Returns an empty float32 array
    when the input is None or cannot be converted, so callers can filter
    it out.
    """
    if audio is None:
        return np.array([], dtype=np.float32)
    if isinstance(audio, np.ndarray):
        arr = audio
    else:
        arr = None
        # PyTorch tensor: detach from the autograd graph and move to CPU first.
        if hasattr(audio, "detach") and hasattr(audio, "cpu") and hasattr(audio, "numpy"):
            try:
                arr = audio.detach().cpu().numpy()
            except Exception:
                arr = None
        # Other frameworks that expose a numpy() conversion method.
        if arr is None and callable(getattr(audio, "numpy", None)):
            try:
                arr = audio.numpy()
            except Exception:
                arr = None
        # Last resort: generic conversion.
        if arr is None:
            try:
                arr = np.asarray(audio)
            except Exception:
                # Not convertible; return empty so downstream code skips it.
                return np.array([], dtype=np.float32)
    # Normalize shape and dtype.
    arr = np.asarray(arr)
    if arr.size == 0:
        return np.array([], dtype=np.float32)
    # Common returns are [T], [T, 1], [1, T], or multi-channel 2-D.
    if arr.ndim == 2:
        if arr.shape[1] == 1:
            arr = arr[:, 0]
        elif arr.shape[0] == 1:
            arr = arr[0]
        else:
            # Downmix to mono along the channel axis. The original code always
            # averaged axis 1, which mangled channels-first [C, T] input by
            # averaging over time; heuristic: the smaller dimension is the
            # channel axis ([C, T] vs [T, C]).
            channel_axis = 0 if arr.shape[0] < arr.shape[1] else 1
            arr = arr.mean(axis=channel_axis)
    elif arr.ndim > 2:
        # Unexpected rank: flatten to 1-D as a best effort.
        arr = arr.reshape(-1)
    if arr.ndim == 0:
        arr = arr.reshape(1)
    # float32 output; soundfile converts to PCM_16 at write time.
    if arr.dtype != np.float32:
        arr = arr.astype(np.float32, copy=False)
    return arr
def synthesize_wav_bytes(
    text: str,
    voice: str = "zf_xiaoxiao",
    speed: float = 1.0,
    split_pattern: str = r"\n+",
) -> BytesIO:
    """Run TTS on *text* and return an in-memory WAV (PCM_16) buffer.

    Raises HTTPException(400) when no usable audio is produced and
    HTTPException(500) when the generated segments cannot be concatenated.
    """
    # Hold the lock for the whole generation pass: the underlying model is
    # not safe for concurrent inference.
    with synthesis_lock:
        chunks = pipeline(
            text,
            voice=voice,
            speed=speed,
            split_pattern=split_pattern,
        )
        # Keep only non-empty segments whose samples are all finite.
        pieces = [
            seg
            for seg in (to_mono_numpy(audio) for _, _, audio in chunks)
            if seg.size > 0 and np.isfinite(seg).all()
        ]
    if not pieces:
        raise HTTPException(status_code=400, detail="未生成音频,请检查输入文本或参数。")
    try:
        full_audio = np.concatenate(pieces, axis=0)
    except Exception as e:
        # Surface a clear error instead of an opaque 500 traceback.
        raise HTTPException(
            status_code=500,
            detail=f"音频拼接失败:{type(e).__name__}: {str(e)}"
        )
    # Write the whole take as a 16-bit PCM WAV into an in-memory buffer.
    out = BytesIO()
    sf.write(out, full_audio, samplerate=sample_rate, format="WAV", subtype="PCM_16")
    out.seek(0)
    return out
@app.post("/tts", summary="POST: 传入文本返回 WAV 流")
def tts_post(req: TTSRequest):
    """Synthesize req.text and stream the result back as a WAV file."""
    voice = req.voice or "zf_xiaoxiao"
    speed = 1.0 if req.speed is None else req.speed
    pattern = req.split_pattern or r"\n+"
    wav = synthesize_wav_bytes(
        text=req.text,
        voice=voice,
        speed=speed,
        split_pattern=pattern,
    )
    # inline disposition lets browsers play the stream directly.
    headers = {"Content-Disposition": 'inline; filename="tts.wav"'}
    return StreamingResponse(wav, media_type="audio/wav", headers=headers)
@app.get("/tts", summary="GET: 传入文本返回 WAV 流(便于直接以 URL 播放)")
def tts_get(
    text: str = Query(..., description="待合成文本"),
    voice: str = Query("zf_xiaoxiao"),
    speed: float = Query(1.0),
    split_pattern: str = Query(r"\n+"),
):
    """GET variant of /tts so the URL can be played directly (e.g. <audio> tag)."""
    # Guard clause: the pydantic validator only covers the POST body, so the
    # query parameter is validated here.
    if speed is None or float(speed) <= 0:
        raise HTTPException(status_code=400, detail="speed 必须为大于 0 的数值")
    wav = synthesize_wav_bytes(
        text=text,
        voice=voice,
        speed=float(speed),
        split_pattern=split_pattern,
    )
    headers = {"Content-Disposition": 'inline; filename="tts.wav"'}
    return StreamingResponse(wav, media_type="audio/wav", headers=headers)
# Run with:
#   uvicorn <module_name>:app --host 0.0.0.0 --port 8000 --workers 1
# Keep workers=1 (inference is serialized by synthesis_lock anyway) so the
# model is not loaded into multiple processes competing for the same device.
if __name__ == "__main__":
    import uvicorn

    # Pass the app object directly instead of the hard-coded "tts_zh:app"
    # import string: the string breaks if the file is renamed, and the object
    # form works with reload=False / workers=1.
    uvicorn.run(app, host="0.0.0.0", port=18000, reload=False, workers=1)
|