from __future__ import annotations import base64 import hashlib import io import json import logging import os import re import asyncio import threading import uuid import urllib.request from contextlib import asynccontextmanager from pathlib import Path from typing import Dict, List, Optional, Tuple import numpy as np import soundfile as sf from fastapi import Body, FastAPI, HTTPException, Query, Request from fastapi.middleware.cors import CORSMiddleware from fastapi.responses import StreamingResponse from pydantic import BaseModel, field_validator # ============================================================ # 配置 # ============================================================ DEFAULT_MODEL_NAME = os.getenv("TTS_ONNX_MODEL_NAME", "model.onnx") MODEL_DIR = os.getenv("TTS_ONNX_MODEL_DIR", "/home/tts-server/onnx") HF_MODEL_ID = os.getenv("TTS_ONNX_HF_MODEL_ID", "onnx-community/Kokoro-82M-ONNX") TOKENIZER_PATH = os.getenv("TTS_ONNX_TOKENIZER_PATH", str(Path(MODEL_DIR) / "tokenizer.json")) CONFIG_PATH = os.getenv("TTS_ONNX_CONFIG_PATH", str(Path(MODEL_DIR) / "config.json")) VOICES_DIR = os.getenv("TTS_ONNX_VOICES_DIR", str(Path(MODEL_DIR) / "voices")) VOICES_V1_PATH = os.getenv("TTS_ONNX_VOICES_V1_PATH", str(Path(MODEL_DIR) / "voices-v1.0.bin")) CACHE_DIR = os.getenv("CACHE_DIR", "./audio_cache") LOG_LEVEL = os.getenv("LOG_LEVEL", "INFO").upper() MAX_CONCURRENT_REQUESTS = int(os.getenv("MAX_CONCURRENT_REQUESTS", 8)) MEMORY_CACHE_SIZE = int(os.getenv("MEMORY_CACHE_SIZE", 200)) DISK_CACHE_SIZE = int(os.getenv("DISK_CACHE_SIZE", 500)) DEFAULT_SAMPLE_RATE = int(os.getenv("TTS_SAMPLE_RATE", 24000)) VOICE_ALIASES = {} logging.basicConfig(level=getattr(logging, LOG_LEVEL, logging.INFO)) logger = logging.getLogger("speech_tts_onnx") Path(CACHE_DIR).mkdir(parents=True, exist_ok=True) SENT_SPLIT_RE = re.compile(r"(?<=[.!?,:])\s+") _TOKENIZER_CACHE = None _VOCAB_CACHE = None _EN_G2P_PIPELINE = None _KOKORO_ONNX_ENGINE = None # 
# ============================================================
# Runtime state
# ============================================================
# NOTE(review): `app`, `TTSRequest`, `load_model` and `get_kokoro_engine` are
# used further down but are not defined in the visible part of this module;
# they are assumed to be provided elsewhere in the file.
model_lock = threading.Lock()
synthesis_lock = threading.Lock()
request_semaphore = asyncio.Semaphore(MAX_CONCURRENT_REQUESTS)
# client_id -> per-request state dict ({"interrupt": bool}) of the active
# /generate stream for that client.
current_requests: Dict[str, Dict] = {}
model_session = None
model_name = DEFAULT_MODEL_NAME
sample_rate = DEFAULT_SAMPLE_RATE


# ============================================================
# Utilities
# ============================================================
def split_sentences(text: str) -> List[str]:
    """Split *text* on SENT_SPLIT_RE and return the non-empty, stripped parts."""
    return [s.strip() for s in SENT_SPLIT_RE.split(text.strip()) if s.strip()]


def sentence_cache_key(sentence: str, voice: str, speed: float, model: str) -> str:
    """Return a hex digest identifying one (sentence, voice, speed, model) tuple.

    MD5 is used purely as a cache key, not for security.
    """
    raw = f"{sentence}|{voice}|{speed}|{model}"
    return hashlib.md5(raw.encode("utf-8")).hexdigest()


def sentence_cache_path(key: str) -> str:
    """Return the path of the cached WAV file for *key* inside CACHE_DIR."""
    return os.path.join(CACHE_DIR, f"{key}.wav")


def meta_cache_path(key: str) -> str:
    """Return the path of the cached JSON metadata for *key* inside CACHE_DIR."""
    return os.path.join(CACHE_DIR, f"{key}.json")


def to_mono_numpy(audio) -> np.ndarray:
    """Convert *audio* to a 1-D float32 array.

    ``None``, empty input and input that cannot be converted to an array all
    yield an empty float32 array. 2-D input with a singleton axis is squeezed;
    other 2-D input is averaged over axis 1; higher ranks are flattened.
    """
    if audio is None:
        return np.array([], dtype=np.float32)
    try:
        arr = np.asarray(audio)
    except Exception:
        # Deliberate best effort: unconvertible input is treated as "no audio".
        return np.array([], dtype=np.float32)
    if arr.size == 0:
        return np.array([], dtype=np.float32)
    if arr.ndim == 2:
        if arr.shape[0] == 1:
            arr = arr[0]
        elif arr.shape[1] == 1:
            arr = arr[:, 0]
        else:
            arr = arr.mean(axis=1)
    elif arr.ndim > 2:
        arr = arr.reshape(-1)
    if arr.ndim == 0:
        arr = arr.reshape(1)
    if arr.dtype != np.float32:
        arr = arr.astype(np.float32, copy=False)
    return arr


def read_wav_bytes(wav_path: str) -> Tuple[bytes, int]:
    """Read *wav_path* and re-encode it as 16-bit PCM WAV.

    Returns:
        ``(wav_bytes, sample_rate)``.
    """
    audio, sr = sf.read(wav_path)
    buf = io.BytesIO()
    with sf.SoundFile(
        buf,
        "w",
        samplerate=sr,
        channels=audio.shape[1] if audio.ndim > 1 else 1,
        format="WAV",
        subtype="PCM_16",
    ) as f:
        f.write(audio)
    return buf.getvalue(), sr


# ============================================================
# Model loading
# ============================================================
def resolve_model_path(name: str) -> str:
    """Resolve *name* to an existing model file.

    The file is looked up under MODEL_DIR first; an existing absolute path is
    accepted as well.

    Raises:
        FileNotFoundError: if neither location exists.
    """
    local_path = Path(MODEL_DIR) / name
    if local_path.exists():
        return str(local_path)
    if os.path.isabs(name) and Path(name).exists():
        return name
    raise FileNotFoundError(
        f"找不到模型文件: {local_path}"
    )


def load_vocab() -> Dict[str, int]:
    """Load the symbol -> token-id vocabulary.

    The ``vocab`` entry of CONFIG_PATH is preferred. If that file is missing,
    unreadable, or has no usable vocab, the ``model.vocab`` entry of
    TOKENIZER_PATH is used instead.

    Raises:
        RuntimeError: if the tokenizer file is missing or its vocab cannot be
            loaded.
    """
    config_file = Path(CONFIG_PATH)
    if config_file.exists():
        try:
            data = json.loads(config_file.read_text(encoding="utf-8"))
            vocab = data["vocab"]
            if isinstance(vocab, dict) and vocab:
                return {str(k): int(v) for k, v in vocab.items()}
        except Exception as e:
            # Best effort: any failure here just triggers the tokenizer fallback.
            logger.warning("加载 config vocab 失败,回退 tokenizer vocab: %s", e)

    tokenizer_file = Path(TOKENIZER_PATH)
    if not tokenizer_file.exists():
        raise RuntimeError(
            f"缺少 vocab 文件: {config_file} / {tokenizer_file}. "
            f"请从 {HF_MODEL_ID} 或 hexgrad/Kokoro-82M 下载后放到模型目录。"
        )
    try:
        data = json.loads(tokenizer_file.read_text(encoding="utf-8"))
        vocab = data["model"]["vocab"]
        if not isinstance(vocab, dict) or not vocab:
            raise ValueError("tokenizer vocab 为空或格式异常")
        return {str(k): int(v) for k, v in vocab.items()}
    except Exception as e:
        raise RuntimeError(f"无法加载 tokenizer vocab: {tokenizer_file}. 错误: {e}") from e


def get_vocab() -> Dict[str, int]:
    """Return the vocabulary, loading it on first use and caching it globally."""
    global _VOCAB_CACHE
    if _VOCAB_CACHE is None:
        _VOCAB_CACHE = load_vocab()
    return _VOCAB_CACHE


def get_en_g2p_pipeline():
    """Return the cached English ``KPipeline``, creating it on first use.

    ``kokoro`` is imported lazily so the module can be imported without it.
    """
    global _EN_G2P_PIPELINE
    if _EN_G2P_PIPELINE is None:
        from kokoro import KPipeline  # type: ignore

        _EN_G2P_PIPELINE = KPipeline(lang_code="a", repo_id="hexgrad/Kokoro-82M", model=False)
    return _EN_G2P_PIPELINE


def resolve_voice_path(voice: str) -> Path:
    """Map a voice name to an existing ``.bin`` file under VOICES_DIR.

    An empty (or ``None``) name means ``af_heart``; VOICE_ALIASES is applied;
    if the requested file does not exist, ``af_bella.bin`` is used when present.

    Raises:
        FileNotFoundError: if neither the requested nor the fallback file exists.
    """
    # `or ""` keeps a None voice from crashing, matching _is_english_voice.
    voice_name = (voice or "").strip()
    if not voice_name:
        voice_name = "af_heart"
    voice_name = VOICE_ALIASES.get(voice_name, voice_name)
    if not voice_name.endswith(".bin"):
        voice_name = f"{voice_name}.bin"

    voice_path = Path(VOICES_DIR) / voice_name
    if voice_path.exists():
        return voice_path

    fallback_path = Path(VOICES_DIR) / "af_bella.bin"
    if fallback_path.exists():
        logger.warning("voice %s 不存在,回退到 %s", voice_name, fallback_path.name)
        return fallback_path

    raise FileNotFoundError(
        f"找不到 voice 文件: {voice_path}. "
        "请从模型仓库下载 voices/*.bin 到本地 voices 目录。"
    )


def _download_voice_file(voice_path: Path) -> Path:
    """Download ``voices/<name>`` from the HF_MODEL_ID repo to *voice_path*.

    The payload is written to a ``.tmp`` sibling and then moved into place so
    a partial download never replaces an existing file.

    NOTE(review): the original source of this function was truncated after the
    HTML check; everything from that check onwards is a reconstruction and
    should be confirmed against the intended behaviour.

    Raises:
        RuntimeError: if the server answered with an HTML page.
    """
    voice_name = voice_path.name
    raw_url = f"https://huggingface.co/{HF_MODEL_ID}/resolve/main/voices/{voice_name}"
    voice_path.parent.mkdir(parents=True, exist_ok=True)
    tmp_path = voice_path.with_suffix(voice_path.suffix + ".tmp")
    logger.info("从官方地址下载 voice 文件: %s", raw_url)
    with urllib.request.urlopen(raw_url, timeout=60) as response:
        content_type = response.headers.get_content_type()
        payload = response.read()
    if content_type == "text/html" or _is_html_payload(payload):
        raise RuntimeError(f"下载到的 voice 文件是 HTML 页面而不是二进制数据: {raw_url}")
    tmp_path.write_bytes(payload)
    tmp_path.replace(voice_path)
    return voice_path


def _is_html_payload(data: bytes) -> bool:
    """Return True when *data* starts like an HTML document.

    NOTE(review): the original prefix literal was lost in the source; the
    prefixes below are a reconstruction. A bare ``<`` is not used because it
    can legitimately be the first byte of float32 voice data.
    """
    head = data[:64].lstrip().lower()
    return head.startswith((b"<!doctype", b"<html"))


def _encode_token_ids(text: str) -> List[int]:
    """Map each character of *text* to its vocabulary id.

    Characters missing from the vocabulary get the id of the space symbol
    (16 if the vocabulary has no space entry).
    """
    vocab = get_vocab()
    unknown_id = vocab.get(" ", 16)
    return [vocab.get(ch, unknown_id) for ch in text]


def _phonemize_en_chunks_for_onnx(text: str) -> List[Tuple[str, str]]:
    """Run English G2P on *text* and return ``(graphemes, phonemes)`` chunks.

    Raises:
        RuntimeError: if the pipeline produced no chunks.
    """
    pipeline = get_en_g2p_pipeline()
    _, tokens = pipeline.g2p(text)
    chunks = list(pipeline.en_tokenize(tokens))
    if not chunks:
        raise RuntimeError("英文文本音素化失败,未生成 phonemes。")
    return [(graphemes, phonemes) for graphemes, phonemes, _ in chunks]


def _phonemize_en_for_onnx(text: str) -> str:
    """Return the phonemes of the first G2P chunk of *text*.

    NOTE(review): only the first chunk is used, so text that the pipeline
    splits into several chunks is silently truncated here.
    """
    return _phonemize_en_chunks_for_onnx(text)[0][1]


def _tokenize_for_onnx(text: str, voice: str):
    """Convert *text* into the token tensor expected by the ONNX model.

    English voices are phonemized first; other voices are encoded character by
    character. The result has shape ``(1, len(tokens) + 2)`` with a 0 pad token
    on both ends.

    Raises:
        RuntimeError: if no tokens were produced or more than 510 were produced.
    """
    if _is_english_voice(voice):
        tokens = _encode_token_ids(_phonemize_en_for_onnx(text))
    else:
        tokens = _encode_token_ids(text)

    if not tokens:
        raise RuntimeError("文本编码后得到空 tokens。请检查 tokenizer.json 是否正确。")
    if len(tokens) > 510:
        raise RuntimeError(
            f"文本过长,tokens={len(tokens)},超过模型 512 上限。请拆句后再调用。"
        )
    return np.asarray([[0, *tokens, 0]], dtype=np.int64)


def _is_english_voice(voice: str) -> bool:
    """Return True if the (alias-resolved) voice name has an English prefix."""
    normalized_voice = VOICE_ALIASES.get((voice or "").strip(), (voice or "").strip())
    return normalized_voice.startswith(("af_", "am_", "bf_", "bm_"))


def _style_for_voice(voice: str) -> np.ndarray:
    """Load the style matrix for *voice* as a ``(-1, 256)`` float32 array.

    If the voice file looks like an HTML page, or checking it fails, the
    ``af_bella.bin`` fallback is used when available, otherwise the file is
    downloaded again. A file whose element count is not a multiple of 256 is
    re-downloaded once.

    Raises:
        RuntimeError: if the file is empty or still malformed after the retry.
    """
    voice_path = resolve_voice_path(voice)
    if voice_path.suffix == ".bin":
        try:
            # Only the first bytes are needed to recognise an HTML page.
            with voice_path.open("rb") as fh:
                header = fh.read(64)
            if _is_html_payload(header):
                fallback_path = Path(VOICES_DIR) / "af_bella.bin"
                if voice_path.name != fallback_path.name and fallback_path.exists():
                    logger.warning("voice 文件 %s 是 HTML,占位回退到 %s", voice_path.name, fallback_path.name)
                    voice_path = fallback_path
                else:
                    voice_path = _download_voice_file(voice_path)
        except Exception as e:
            logger.warning("voice 文件检查失败,尝试直接重新下载: %s", e)
            fallback_path = Path(VOICES_DIR) / "af_bella.bin"
            if voice_path.name != fallback_path.name and fallback_path.exists():
                voice_path = fallback_path
            else:
                voice_path = _download_voice_file(voice_path)

    style = np.fromfile(str(voice_path), dtype=np.float32)
    if style.size == 0:
        raise RuntimeError(f"voice 文件为空: {voice_path}")
    if style.size % 256 != 0:
        # Some repository files may be an xet pointer or an HTML page;
        # download once more as a last resort.
        voice_path = _download_voice_file(voice_path)
        style = np.fromfile(str(voice_path), dtype=np.float32)
        if style.size == 0 or style.size % 256 != 0:
            raise RuntimeError(f"voice 文件维度异常: {voice_path}, size={style.size}")
    style = style.reshape(-1, 256)
    return style


def _select_style_slice(style: np.ndarray, token_len: int) -> np.ndarray:
    """Return the ``(1, 256)`` style row for *token_len*, clamped to valid rows.

    Raises:
        RuntimeError: if *style* is not 2-D with 256 columns.
    """
    if style.ndim != 2 or style.shape[-1] != 256:
        raise RuntimeError(f"style 维度异常: {style.shape}")
    # Follow the official ONNX example: ref_s = voices[len(tokens)]
    idx = min(max(token_len, 0), style.shape[0] - 1)
    return style[idx : idx + 1]


def _prepare_style_input(session, style_slice: np.ndarray) -> Tuple[Optional[str], Optional[np.ndarray]]:
    """Find the session's style input and shape *style_slice* to match its rank.

    Inputs named ``style`` and then ``ref_s`` are looked for.

    Returns:
        ``(input_name, array)``, or ``(None, None)`` if the model has neither
        input.

    Raises:
        RuntimeError: if the input's rank is neither 2 nor 3.
    """
    for input_name in ("style", "ref_s"):
        for model_input in session.get_inputs():
            if model_input.name != input_name:
                continue
            input_shape = model_input.shape
            expected_rank = len(input_shape) if input_shape is not None else style_slice.ndim
            if expected_rank == 2:
                return input_name, style_slice.astype(np.float32, copy=False)
            if expected_rank == 3:
                return input_name, style_slice[:, np.newaxis, :].astype(np.float32, copy=False)
            raise RuntimeError(
                f"模型输入 {input_name} 的 rank 不受支持: shape={input_shape}"
            )
    return None, None


def synthesize_audio(text: str, voice: str, speed: float, model_name: Optional[str] = None) -> np.ndarray:
    """Synthesize *text* with *voice* and return mono float32 audio.

    The text is phonemized as ``en-us``, split into batches by the engine, and
    each batch is run through the ONNX session; the outputs are concatenated.

    Raises:
        HTTPException: 400 for empty text, an unknown voice or failed
            phonemization; 500 if inference yields no or invalid audio.
    """
    if not text.strip():
        raise HTTPException(status_code=400, detail="文本不能为空")

    engine = get_kokoro_engine(name=model_name)
    session = load_model(name=model_name)
    available_voices = set(engine.get_voices())
    if voice not in available_voices:
        raise HTTPException(status_code=400, detail=f"不支持的 voice: {voice}")

    phonemes = engine.tokenizer.phonemize(text, "en-us")
    batched_phonemes = engine._split_phonemes(phonemes)
    if not batched_phonemes:
        raise HTTPException(status_code=400, detail="文本音素化失败")

    voice_style = engine.get_voice_style(voice)
    audio_segments: List[np.ndarray] = []
    for phoneme_batch in batched_phonemes:
        tokens = np.array(engine.tokenizer.tokenize(phoneme_batch), dtype=np.int64)
        if tokens.size == 0:
            continue
        # The style entry is selected by the unpadded token count.
        style = voice_style[len(tokens)]
        feeds = {
            "input_ids": np.asarray([[0, *tokens.tolist(), 0]], dtype=np.int64),
            "style": np.asarray(style, dtype=np.float32),
            "speed": np.asarray([speed], dtype=np.float32),
        }
        outputs = session.run(None, feeds)
        if outputs:
            audio_segments.append(to_mono_numpy(outputs[0]))

    if not audio_segments:
        raise HTTPException(status_code=500, detail="ONNX 推理未返回音频输出")

    audio = np.concatenate(audio_segments, axis=0) if len(audio_segments) > 1 else audio_segments[0]
    if audio.size == 0 or not np.isfinite(audio).all():
        raise HTTPException(status_code=500, detail="生成的音频无效")
    return audio


def _split_text_for_voice(text: str, voice: str) -> List[str]:
    """Split *text* into synthesis units for *voice*.

    English voices use the grapheme chunks of the G2P pipeline; other voices
    use punctuation-based sentence splitting. If nothing is produced, the
    stripped text is returned as the only unit.
    """
    stripped = text.strip()
    if not stripped:
        parts: List[str] = []
    elif _is_english_voice(voice):
        parts = [graphemes for graphemes, _ in _phonemize_en_chunks_for_onnx(text)]
    else:
        parts = split_sentences(text)
    return parts or [stripped]


def synthesize_wav_bytes(text: str, voice: str, speed: float, model_name: Optional[str] = None) -> io.BytesIO:
    """Synthesize *text* and return a rewound buffer holding a 16-bit PCM WAV.

    Synthesis of all parts runs under ``synthesis_lock``.

    Raises:
        HTTPException: 400 if no audio was produced; errors from
            ``synthesize_audio`` propagate unchanged.
    """
    segments: List[np.ndarray] = []
    parts = _split_text_for_voice(text, voice)

    with synthesis_lock:
        for part in parts:
            audio = synthesize_audio(part, voice=voice, speed=speed, model_name=model_name)
            if audio.size > 0:
                segments.append(audio)

    if not segments:
        raise HTTPException(status_code=400, detail="未生成音频,请检查输入文本或参数。")

    audio_concat = np.concatenate(segments, axis=0)
    buf = io.BytesIO()
    sf.write(buf, audio_concat, samplerate=sample_rate, format="WAV", subtype="PCM_16")
    buf.seek(0)
    return buf


@app.post("/tts", summary="POST: 传入文本返回 WAV 流")
def tts_post(req: TTSRequest):
    """Synthesize the text of a JSON request body and return it as a WAV stream."""
    buf = synthesize_wav_bytes(
        text=req.text,
        voice=req.voice or "af_heart",
        speed=req.speed if req.speed is not None else 1.0,
        model_name=req.model_name,
    )
    return StreamingResponse(
        buf,
        media_type="audio/wav",
        headers={"Content-Disposition": 'inline; filename="tts.wav"'},
    )


@app.get("/tts", summary="GET: 传入文本返回 WAV 流")
def tts_get(
    text: str = Query(..., description="待合成文本"),
    voice: str = Query("af_heart"),
    speed: float = Query(1.0),
    model_name: str = Query(DEFAULT_MODEL_NAME),
):
    """Synthesize the text of the query string and return it as a WAV stream.

    Raises:
        HTTPException: 400 if *speed* is not greater than 0.
    """
    if speed is None or float(speed) <= 0:
        raise HTTPException(status_code=400, detail="speed 必须为大于 0 的数值")

    buf = synthesize_wav_bytes(
        text=text,
        voice=voice,
        speed=float(speed),
        model_name=model_name,
    )
    return StreamingResponse(
        buf,
        media_type="audio/wav",
        headers={"Content-Disposition": 'inline; filename="tts.wav"'},
    )


@app.post("/generate")
async def generate_audio_stream(data: Dict = Body(...)):
    """Stream synthesized audio as NDJSON, one line per text part.

    Body keys: ``text`` (required), ``voice``, ``speed``, ``model_name`` and
    ``client_id``. Each emitted line is a JSON object with ``index``,
    ``sentence``, ``audio`` (base64 WAV) and ``sample_rate``. A new request
    carrying the same ``client_id`` interrupts the stream already running for
    that client.

    Raises:
        HTTPException: 400 if the text is empty or *speed* is not a finite
            number greater than 0.
    """
    text = data.get("text", "")
    if not isinstance(text, str) or not text.strip():
        raise HTTPException(status_code=400, detail="文本不能为空")

    # A null voice means "default", as in tts_post.
    voice = data.get("voice") or "af_heart"
    if not isinstance(voice, str):
        raise HTTPException(status_code=400, detail=f"不支持的 voice: {voice}")

    try:
        speed = float(data.get("speed", 1.0))
    except (TypeError, ValueError):
        raise HTTPException(status_code=400, detail="speed 必须为大于 0 的数值") from None
    # Same rule as tts_get; the isfinite test also rejects NaN and infinity.
    if not (np.isfinite(speed) and speed > 0):
        raise HTTPException(status_code=400, detail="speed 必须为大于 0 的数值")

    model = data.get("model_name", DEFAULT_MODEL_NAME)
    # str() keeps unhashable JSON values (lists, objects) usable as dict keys.
    client_id = str(data.get("client_id", uuid.uuid4()))

    # G2P is blocking work: keep it off the event loop and count it against
    # the concurrency limit.
    async with request_semaphore:
        parts = await asyncio.to_thread(_split_text_for_voice, text, voice)

    # Each request owns its state dict. The previous stream keeps a reference
    # to its own dict, so flagging it here reliably stops that stream even
    # after the registry entry is replaced below.
    previous = current_requests.get(client_id)
    if previous is not None:
        previous["interrupt"] = True
    state = {"interrupt": False}
    current_requests[client_id] = state

    async def stream():
        try:
            # Hold the semaphore while audio is actually produced; acquiring it
            # only around the handler body would release it before the
            # response body is generated.
            async with request_semaphore:
                for idx, part in enumerate(parts):
                    if not part:
                        continue
                    if state["interrupt"]:
                        break

                    audio = await asyncio.to_thread(
                        synthesize_audio,
                        text=part,
                        voice=voice,
                        speed=speed,
                        model_name=model,
                    )

                    with io.BytesIO() as buf:
                        with sf.SoundFile(
                            buf,
                            "w",
                            sample_rate,
                            channels=audio.shape[1] if audio.ndim > 1 else 1,
                            format="WAV",
                            subtype="PCM_16",
                        ) as f:
                            f.write(audio)
                        audio_b64 = base64.b64encode(buf.getvalue()).decode()

                    yield json.dumps(
                        {
                            "index": idx,
                            "sentence": part,
                            "audio": audio_b64,
                            "sample_rate": sample_rate,
                        }
                    ).encode() + b"\n"
        finally:
            # Remove the registry entry only if it still belongs to this
            # request; a newer request for the same client may have replaced it.
            if current_requests.get(client_id) is state:
                current_requests.pop(client_id, None)

    return StreamingResponse(stream(), media_type="application/x-ndjson")


if __name__ == "__main__":
    import uvicorn

    uvicorn.run("speech_tts_onnx:app", host="0.0.0.0", port=18000, reload=False, workers=1)