from __future__ import annotations import base64 import hashlib import io import json import logging import os import re import asyncio import threading import uuid import urllib.request from contextlib import asynccontextmanager from pathlib import Path from typing import Dict, List, Optional, Tuple import numpy as np import soundfile as sf import aiofiles from fastapi import Body, FastAPI, HTTPException, Query, Request from fastapi.middleware.cors import CORSMiddleware from fastapi.responses import StreamingResponse from pydantic import BaseModel, field_validator from cachetools import TTLCache import concurrent.futures # ============================================================ # 配置 model_quantized # ============================================================ DEFAULT_MODEL_NAME = os.getenv("TTS_ONNX_MODEL_NAME", "model.onnx") MODEL_DIR = os.getenv("TTS_ONNX_MODEL_DIR", "/home/tts-server/onnx") HF_MODEL_ID = os.getenv("TTS_ONNX_HF_MODEL_ID", "onnx-community/Kokoro-82M-ONNX") TOKENIZER_PATH = os.getenv("TTS_ONNX_TOKENIZER_PATH", str(Path(MODEL_DIR) / "tokenizer.json")) CONFIG_PATH = os.getenv("TTS_ONNX_CONFIG_PATH", str(Path(MODEL_DIR) / "config.json")) VOICES_DIR = os.getenv("TTS_ONNX_VOICES_DIR", str(Path(MODEL_DIR) / "voices")) VOICES_V1_PATH = os.getenv("TTS_ONNX_VOICES_V1_PATH", str(Path(MODEL_DIR) / "voices-v1.0.bin")) CACHE_DIR = os.getenv("CACHE_DIR", "./audio_cache") LOG_LEVEL = os.getenv("LOG_LEVEL", "INFO").upper() MAX_CONCURRENT_REQUESTS = int(os.getenv("MAX_CONCURRENT_REQUESTS", 8)) MEMORY_CACHE_SIZE = int(os.getenv("MEMORY_CACHE_SIZE", 200)) DISK_CACHE_SIZE = int(os.getenv("DISK_CACHE_SIZE", 500)) DEFAULT_SAMPLE_RATE = int(os.getenv("TTS_SAMPLE_RATE", 24000)) MEMORY_CACHE_TTL = int(os.getenv("MEMORY_CACHE_TTL", 18000)) ORT_INTRA_OP_THREADS = int(os.getenv("ORT_INTRA_OP_THREADS", "2")) ORT_INTER_OP_THREADS = int(os.getenv("ORT_INTER_OP_THREADS", "1")) ORT_ENABLE_CPU_MEM_ARENA = os.getenv("ORT_ENABLE_CPU_MEM_ARENA", "true").lower() == "true" 
# NOTE(review): this copy of the file is damaged and does not parse -- see the
# NOTE inside _download_voice_file. Restore the missing code from version control.

# Only the string "true" (any case) enables this flag.
ORT_ENABLE_MEM_PATTERN = os.getenv("ORT_ENABLE_MEM_PATTERN", "true").lower() == "true"
# Optional renames applied to requested voice names before lookup. It is empty,
# so every voice name is currently used as given.
VOICE_ALIASES = {}
logging.basicConfig(level=getattr(logging, LOG_LEVEL, logging.INFO))
logger = logging.getLogger("speech_tts_onnx")


class _PhonemizerWordCountMismatchFilter(logging.Filter):
    """Drop log records whose message contains "words count mismatch"."""

    def filter(self, record: logging.LogRecord) -> bool:
        # Returning False suppresses the record.
        return "words count mismatch" not in record.getMessage()


# Attach the filter to the "phonemizer" logger.
# NOTE(review): a filter attached to a logger only sees records logged directly
# on that logger, not records propagated from child loggers (e.g.
# "phonemizer.<sub>"). If the warning still appears, attach it to a handler instead.
logging.getLogger("phonemizer").addFilter(_PhonemizerWordCountMismatchFilter())
Path(CACHE_DIR).mkdir(parents=True, exist_ok=True)
# Split on whitespace that follows . ! ? , or : -- the lookbehind keeps the
# punctuation attached to the preceding piece. Commas split too, so "sentences"
# are really clauses.
SENT_SPLIT_RE = re.compile(r"(?<=[.!?,:])\s+")
# Lazily populated singletons.
# NOTE(review): _TOKENIZER_CACHE and _KOKORO_ONNX_ENGINE are not referenced by
# any code still present in this file; presumably they were used by loaders
# that are missing from this copy.
_TOKENIZER_CACHE = None
_VOCAB_CACHE = None  # populated by get_vocab()
_EN_G2P_PIPELINE = None  # populated by get_en_g2p_pipeline()
_KOKORO_ONNX_ENGINE = None
# ============================================================
# Runtime state
# ============================================================
# NOTE(review): model_lock and synthesis_lock are not used by the code still
# present in this file.
model_lock = threading.Lock()
synthesis_lock = threading.Lock()
# Guards memory_cache; cachetools caches are not thread-safe on their own.
cache_lock = threading.Lock()
# Intended to cap concurrent /generate and /generate_pcm work (but see the NOTE
# in generate_audio_stream about when it is actually released).
request_semaphore = asyncio.Semaphore(MAX_CONCURRENT_REQUESTS)
# client_id -> {"interrupt": bool}. /generate uses it to stop a client's
# previous stream when the same client starts a new one.
current_requests: Dict[str, Dict] = {}
# In-process cache of synthesized sentences, keyed by sentence_cache_key().
memory_cache = TTLCache(maxsize=MEMORY_CACHE_SIZE, ttl=MEMORY_CACHE_TTL)
# Worker pool for the blocking WAV read/write done from async code.
# NOTE(review): max_workers is hard-coded instead of derived from configuration.
executor = concurrent.futures.ThreadPoolExecutor(max_workers=8)
# NOTE(review): model_session and model_name are not read or written by the
# code still present here. sample_rate is read when encoding WAVs and is
# presumably updated by the (missing) model loader.
model_session = None
model_name = DEFAULT_MODEL_NAME
sample_rate = DEFAULT_SAMPLE_RATE
# ============================================================
# Utilities
# ============================================================
def split_sentences(text: str) -> List[str]:
    """Split *text* with SENT_SPLIT_RE and return the stripped, non-empty pieces."""
    return [s.strip() for s in SENT_SPLIT_RE.split(text.strip()) if s.strip()]


def sentence_cache_key(sentence: str, voice: str, speed: float, model: str) -> str:
    """Return the MD5 hex digest of "sentence|voice|speed|model", used as a cache key."""
    raw = f"{sentence}|{voice}|{speed}|{model}"
    return hashlib.md5(raw.encode("utf-8")).hexdigest()


def sentence_cache_path(key: str) -> str:
    """Return the on-disk WAV path for cache *key* (CACHE_DIR/<key>.wav)."""
    return os.path.join(CACHE_DIR, f"{key}.wav")


def meta_cache_path(key: str) -> str:
    """Return the on-disk metadata path for cache *key* (CACHE_DIR/<key>.json)."""
    return os.path.join(CACHE_DIR, f"{key}.json")


def put_memory_cache(key: str, value: dict):
    """Store *value* in memory_cache under *key* while holding cache_lock."""
    with cache_lock:
        memory_cache[key] = value


def get_memory_cache(key: str) -> Optional[dict]:
    """Return the memory_cache entry for *key*, or None; reads under cache_lock."""
    with cache_lock:
        return memory_cache.get(key)


def to_mono_numpy(audio) -> np.ndarray:
    """Coerce a model output into a 1-D float32 array.

    - None, values np.asarray() cannot convert, and empty arrays yield an empty
      float32 array.
    - 2-D input: a leading or trailing axis of length 1 is dropped. Otherwise the
      mean over axis 1 is taken, i.e. the input is treated as (samples, channels).
    - Input with more than 2 dimensions is flattened; 0-D input becomes length 1.
    """
    if audio is None:
        return np.array([], dtype=np.float32)
    if isinstance(audio, np.ndarray):
        arr = audio
    else:
        try:
            arr = np.asarray(audio)
        except Exception:
            return np.array([], dtype=np.float32)
    arr = np.asarray(arr)
    if arr.size == 0:
        return np.array([], dtype=np.float32)
    if arr.ndim == 2:
        if arr.shape[0] == 1:
            arr = arr[0]
        elif arr.shape[1] == 1:
            arr = arr[:, 0]
        else:
            # NOTE(review): if a model ever returns (channels, samples) with
            # channels > 1, this averages over time instead -- confirm the layout.
            arr = arr.mean(axis=1)
    elif arr.ndim > 2:
        arr = arr.reshape(-1)
    if arr.ndim == 0:
        arr = arr.reshape(1)
    if arr.dtype != np.float32:
        arr = arr.astype(np.float32, copy=False)
    return arr


def read_wav_bytes(wav_path: str) -> Tuple[bytes, int]:
    """Read a WAV file and re-encode it in memory as 16-bit PCM WAV.

    Returns:
        (wav_bytes, sample_rate)
    """
    audio, sr = sf.read(wav_path)
    buf = io.BytesIO()
    with sf.SoundFile(
        buf,
        "w",
        samplerate=sr,
        channels=audio.shape[1] if audio.ndim > 1 else 1,
        format="WAV",
        subtype="PCM_16",
    ) as f:
        f.write(audio)
    return buf.getvalue(), sr


async def load_sentence_from_disk(key: str) -> Optional[dict]:
    """Load a cached sentence from <key>.json (metadata) and <key>.wav (audio).

    Returns None if either file is missing. Otherwise returns the metadata dict
    extended with "audio_bytes" (WAV bytes), "audio" (base64 of those bytes) and
    "sample_rate". The WAV is decoded in ``executor`` to keep the event loop free.

    NOTE(review): a malformed JSON/WAV file raises and fails the request instead
    of being treated as a cache miss. Also, asyncio.get_running_loop() is the
    preferred call inside a coroutine.
    """
    wav_path = sentence_cache_path(key)
    meta_path = meta_cache_path(key)
    if not Path(wav_path).exists() or not Path(meta_path).exists():
        return None
    async with aiofiles.open(meta_path, "r") as f:
        meta_text = await f.read()
    meta = json.loads(meta_text)
    wav_bytes, sr = await asyncio.get_event_loop().run_in_executor(
        executor, lambda: read_wav_bytes(wav_path)
    )
    meta["audio_bytes"] = wav_bytes
    meta["audio"] = base64.b64encode(wav_bytes).decode()
    meta["sample_rate"] = sr
    return meta


async def save_sentence_to_disk(key: str, audio: np.ndarray, sr: int, sentence: str):
    """Write *audio* to <key>.wav (in ``executor``), then {"sentence", "sample_rate"} to <key>.json.

    The WAV is written before the JSON. An interrupted save therefore leaves a
    .wav without .json, which load_sentence_from_disk treats as a miss.
    """
    wav_path = sentence_cache_path(key)
    meta_path = meta_cache_path(key)
    await asyncio.get_event_loop().run_in_executor(
        executor, lambda: sf.write(wav_path, audio, sr, format="WAV")
    )
    async with aiofiles.open(meta_path, "w") as f:
        await f.write(json.dumps({"sentence": sentence, "sample_rate": sr}))


async def clean_disk_cache():
    """Delete the oldest *.wav files (by mtime) and their .json files until DISK_CACHE_SIZE remain.

    Errors while deleting an individual file are ignored.

    NOTE(review): glob/stat/unlink run synchronously on the event loop. Also,
    os.path.getmtime in the sort key sits outside the try/except, so a file
    removed concurrently (e.g. by /clear-cache) raises FileNotFoundError here.
    """
    files = sorted(Path(CACHE_DIR).glob("*.wav"), key=os.path.getmtime)
    if len(files) <= DISK_CACHE_SIZE:
        return
    remove_n = len(files) - DISK_CACHE_SIZE
    for f in files[:remove_n]:
        try:
            f.unlink()
            json_path = meta_cache_path(f.stem)
            if Path(json_path).exists():
                Path(json_path).unlink()
        except Exception:
            pass


def encode_wav_bytes(audio: np.ndarray, sr: int) -> bytes:
    """Encode *audio* at sample rate *sr* as an in-memory 16-bit PCM WAV and return its bytes."""
    buf = io.BytesIO()
    with sf.SoundFile(
        buf,
        "w",
        samplerate=sr,
        channels=audio.shape[1] if audio.ndim > 1 else 1,
        format="WAV",
        subtype="PCM_16",
    ) as f:
        f.write(audio)
    return buf.getvalue()


def wav_to_pcm_payload(wav_bytes: bytes) -> bytes:
    """Decode WAV bytes to raw int16 samples and return them as bytes (no header).

    NOTE(review): tobytes() uses the host's native byte order, while
    /generate_pcm labels the payload "s16le". That matches only on
    little-endian hosts.
    """
    data, _ = sf.read(io.BytesIO(wav_bytes), dtype="int16")
    return np.asarray(data, dtype=np.int16).tobytes()
# ============================================================
# Model loading
# ============================================================
def resolve_model_path(name: str) -> str:
    """Return the path of model file *name*: MODEL_DIR/name if it exists, else *name* if absolute and existing.

    Raises:
        FileNotFoundError: if neither location exists.
    """
    local_path = Path(MODEL_DIR) / name
    if local_path.exists():
        return str(local_path)
    if os.path.isabs(name) and Path(name).exists():
        return name
    raise FileNotFoundError(
        f"找不到模型文件: {local_path}"
    )


def load_vocab() -> Dict[str, int]:
    """
    Load the symbol -> token-id vocabulary.

    Prefers the "vocab" mapping in Kokoro's official config.json (CONFIG_PATH).
    If that file is missing, unreadable, or its vocab is empty/not a dict, falls
    back to ``model.vocab`` in tokenizer.json (TOKENIZER_PATH).

    Raises:
        RuntimeError: if the tokenizer file is missing or its vocab cannot be loaded.
    """
    config_file = Path(CONFIG_PATH)
    if config_file.exists():
        try:
            data = json.loads(config_file.read_text(encoding="utf-8"))
            vocab = data["vocab"]
            if isinstance(vocab, dict) and vocab:
                return {str(k): int(v) for k, v in vocab.items()}
        except Exception as e:
            logger.warning("加载 config vocab 失败,回退 tokenizer vocab: %s", e)
    tokenizer_file = Path(TOKENIZER_PATH)
    if not tokenizer_file.exists():
        raise RuntimeError(
            f"缺少 vocab 文件: {config_file} / {tokenizer_file}. "
            f"请从 {HF_MODEL_ID} 或 hexgrad/Kokoro-82M 下载后放到模型目录。"
        )
    try:
        data = json.loads(tokenizer_file.read_text(encoding="utf-8"))
        vocab = data["model"]["vocab"]
        if not isinstance(vocab, dict) or not vocab:
            raise ValueError("tokenizer vocab 为空或格式异常")
        return {str(k): int(v) for k, v in vocab.items()}
    except Exception as e:
        raise RuntimeError(f"无法加载 tokenizer vocab: {tokenizer_file}. 错误: {e}") from e


def get_vocab() -> Dict[str, int]:
    """Return the vocabulary, loading it via load_vocab() on first use.

    There is no lock, so concurrent first calls may each load it (harmless duplicate work).
    """
    global _VOCAB_CACHE
    if _VOCAB_CACHE is None:
        _VOCAB_CACHE = load_vocab()
    return _VOCAB_CACHE


def get_en_g2p_pipeline():
    """Return a lazily created kokoro KPipeline used for English G2P.

    Uses lang_code "a" (American English in kokoro's convention). Passes
    ``model=False``, so presumably no acoustic model is loaded.
    """
    global _EN_G2P_PIPELINE
    if _EN_G2P_PIPELINE is None:
        from kokoro import KPipeline  # type: ignore
        _EN_G2P_PIPELINE = KPipeline(lang_code="a", repo_id="hexgrad/Kokoro-82M", model=False)
    return _EN_G2P_PIPELINE


def resolve_voice_path(voice: str) -> Path:
    """Map a voice name to VOICES_DIR/<name>.bin.

    Resolution steps:
    - An empty/blank name becomes "af_heart".
    - VOICE_ALIASES is applied.
    - ".bin" is appended if missing.
    - If that file does not exist, falls back to af_bella.bin with a warning.

    Raises:
        FileNotFoundError: if neither the requested nor the fallback file exists.

    NOTE(review): the default voice ("af_heart") and the fallback ("af_bella")
    differ. The name is not sanitized either, so "../" segments can escape
    VOICES_DIR. No caller reachable from a request is visible in this copy.
    """
    voice_name = voice.strip()
    if not voice_name:
        voice_name = "af_heart"
    voice_name = VOICE_ALIASES.get(voice_name, voice_name)
    if not voice_name.endswith(".bin"):
        voice_name = f"{voice_name}.bin"
    voice_path = Path(VOICES_DIR) / voice_name
    if voice_path.exists():
        return voice_path
    fallback_path = Path(VOICES_DIR) / "af_bella.bin"
    if fallback_path.exists():
        logger.warning("voice %s 不存在,回退到 %s", voice_name, fallback_path.name)
        return fallback_path
    raise FileNotFoundError(
        f"找不到 voice 文件: {voice_path}. "
        "请从模型仓库下载 voices/*.bin 到本地 voices 目录。"
    )


def _download_voice_file(voice_path: Path) -> Path:
    """Download voices/<voice_path.name> from the Hugging Face repo HF_MODEL_ID.

    The visible part builds the URL, ensures the target directory exists,
    computes a "<name>.tmp" staging path, and reads the response with a 60 s
    timeout. The rest of the function (presumably the HTML check, writing
    tmp_path and moving it into place) is lost -- see the NOTE below.
    """
    voice_name = voice_path.name
    raw_url = f"https://huggingface.co/{HF_MODEL_ID}/resolve/main/voices/{voice_name}"
    voice_path.parent.mkdir(parents=True, exist_ok=True)
    tmp_path = voice_path.with_suffix(voice_path.suffix + ".tmp")
    logger.info("从官方地址下载 voice 文件: %s", raw_url)
    with urllib.request.urlopen(raw_url, timeout=60) as response:
        content_type = response.headers.get_content_type()
        payload = response.read()
    # NOTE(review): SOURCE IS CORRUPTED ON THE NEXT LINE. It looks like an
    # HTML-tag stripper deleted everything from each "<" to the next ">". Lost:
    #   - the rest of this function;
    #   - the start of `def _is_html_payload(data: bytes) -> bool` (called by
    #     _style_for_voice);
    #   - the start of `def _encode_token_ids(text: str) -> List[int]` (called by
    #     _tokenize_for_onnx).
    # Probably also lost: load_model, get_kokoro_engine, TTSRequest and the
    # FastAPI `app` object, all referenced below but defined nowhere in this file.
    # The statements after the broken line are the surviving body of
    # _encode_token_ids: it maps each character of `text` to its vocab id,
    # using vocab[" "] (or 16) for unknown characters.
    if content_type == "text/html" or payload.startswith(b" bool: return data.startswith(b" List[int]:
    vocab = get_vocab()
    unknown_id = vocab.get(" ", 16)
    return [vocab.get(ch, unknown_id) for ch in text]


def _phonemize_en_chunks_for_onnx(text: str) -> List[Tuple[str, str]]:
    """Run kokoro G2P on *text* and return (graphemes, phonemes) for each en_tokenize chunk.

    Raises:
        RuntimeError: if no chunks are produced.
    """
    pipeline = get_en_g2p_pipeline()
    _, tokens = pipeline.g2p(text)
    chunks = list(pipeline.en_tokenize(tokens))
    if not chunks:
        raise RuntimeError("英文文本音素化失败,未生成 phonemes。")
    return [(graphemes, phonemes) for graphemes, phonemes, _ in chunks]


def _phonemize_en_for_onnx(text: str) -> str:
    """Return the phonemes of the FIRST G2P chunk of *text*.

    NOTE(review): any further chunks are silently dropped, so long input is truncated.
    """
    return _phonemize_en_chunks_for_onnx(text)[0][1]


def _tokenize_for_onnx(text: str, voice: str):
    """
    Convert text into the token ids the ONNX model expects.

    Where possible this reuses kokoro's preprocessing chain; replace this
    function if the project has a more reliable tokenizer. English voices
    (af_/am_/bf_/bm_ prefixes) are phonemized first; other text is encoded
    character by character.

    Returns:
        int64 array of shape (1, n + 2): the n token ids wrapped in a 0 at each end.

    Raises:
        RuntimeError: if encoding yields no tokens or more than 510 tokens.

    NOTE(review): no caller is visible in this file, and the voice-prefix check
    duplicates _is_english_voice().
    """
    normalized_voice = VOICE_ALIASES.get((voice or "").strip(), (voice or "").strip())
    if normalized_voice.startswith(("af_", "am_", "bf_", "bm_")):
        phonemes = _phonemize_en_for_onnx(text)
        tokens = _encode_token_ids(phonemes)
    else:
        tokens = _encode_token_ids(text)
    if not tokens:
        raise RuntimeError("文本编码后得到空 tokens。请检查 tokenizer.json 是否正确。")
    # 510 payload tokens + 2 padding zeros = 512.
    if len(tokens) > 510:
        raise RuntimeError(
            f"文本过长,tokens={len(tokens)},超过模型 512 上限。请拆句后再调用。"
        )
    return np.asarray([[0, *tokens, 0]], dtype=np.int64)


def _is_english_voice(voice: str) -> bool:
    """Return True if the alias-resolved voice name starts with af_, am_, bf_ or bm_."""
    normalized_voice = VOICE_ALIASES.get((voice or "").strip(), (voice or "").strip())
    return normalized_voice.startswith(("af_", "am_", "bf_", "bm_"))


def _style_for_voice(voice: str) -> np.ndarray:
    """
    Load the style-vector table the ONNX model needs for *voice*.

    This is a runnable placeholder implementation; it may later be replaced by a
    proper voice-embedding mapping.

    Behaviour:
    - If the .bin file starts like an HTML page, switch to af_bella.bin when
      available, otherwise re-download it. Any error during this check takes the
      same fallback/re-download path.
    - The file is read as raw float32 and reshaped to (-1, 256). If its size is
      not a multiple of 256, it is re-downloaded once.

    Raises:
        RuntimeError: if the file is empty or still has an invalid size.

    NOTE(review): no caller is visible in this file; synthesize_audio uses
    engine.get_voice_style() instead.
    """
    voice_path = resolve_voice_path(voice)
    if voice_path.suffix == ".bin":
        try:
            header = voice_path.read_bytes()[:64]
            if _is_html_payload(header):
                fallback_path = Path(VOICES_DIR) / "af_bella.bin"
                if voice_path.name != fallback_path.name and fallback_path.exists():
                    logger.warning("voice 文件 %s 是 HTML,占位回退到 %s", voice_path.name, fallback_path.name)
                    voice_path = fallback_path
                else:
                    voice_path = _download_voice_file(voice_path)
        except Exception as e:
            logger.warning("voice 文件检查失败,尝试直接重新下载: %s", e)
            fallback_path = Path(VOICES_DIR) / "af_bella.bin"
            if voice_path.name != fallback_path.name and fallback_path.exists():
                voice_path = fallback_path
            else:
                voice_path = _download_voice_file(voice_path)
    style = np.fromfile(str(voice_path), dtype=np.float32)
    if style.size == 0:
        raise RuntimeError(f"voice 文件为空: {voice_path}")
    if style.size % 256 != 0:
        # Some repository files may be an xet pointer or an HTML page;
        # re-download once as a safety net.
        voice_path = _download_voice_file(voice_path)
        style = np.fromfile(str(voice_path), dtype=np.float32)
        if style.size == 0 or style.size % 256 != 0:
            raise RuntimeError(f"voice 文件维度异常: {voice_path}, size={style.size}")
    style = style.reshape(-1, 256)
    return style


def _select_style_slice(style: np.ndarray, token_len: int) -> np.ndarray:
    """Return the style row indexed by *token_len*, clamped to the table, as shape (1, 256).

    Raises:
        RuntimeError: if *style* is not 2-D with a last dimension of 256.
    """
    if style.ndim != 2 or style.shape[-1] != 256:
        raise RuntimeError(f"style 维度异常: {style.shape}")
    # Follow the official ONNX example: ref_s = voices[len(tokens)]
    idx = min(max(token_len, 0), style.shape[0] - 1)
    return style[idx : idx + 1]


def _prepare_style_input(session, style_slice: np.ndarray) -> Tuple[Optional[str], Optional[np.ndarray]]:
    """Find the session input named "style" (checked first) or "ref_s" and shape *style_slice* for it.

    A declared rank of 2 passes the slice through; rank 3 inserts a new axis 1.
    If the input's shape is None, the slice's own ndim is used as the rank.

    Returns:
        (input_name, float32 array), or (None, None) if neither input exists.

    Raises:
        RuntimeError: for any other declared rank.
    """
    for input_name in ("style", "ref_s"):
        for model_input in session.get_inputs():
            if model_input.name != input_name:
                continue
            input_shape = model_input.shape
            expected_rank = len(input_shape) if input_shape is not None else style_slice.ndim
            if expected_rank == 2:
                return input_name, style_slice.astype(np.float32, copy=False)
            if expected_rank == 3:
                return input_name, style_slice[:, np.newaxis, :].astype(np.float32, copy=False)
            raise RuntimeError(
                f"模型输入 {input_name} 的 rank 不受支持: shape={input_shape}"
            )
    return None, None


def synthesize_audio(text: str, voice: str, speed: float, model_name: Optional[str] = None) -> np.ndarray:
    """Synthesize *text* with the kokoro ONNX engine and return 1-D float32 audio.

    How it works:
    - The text is always phonemized as "en-us", whatever the voice.
    - Phonemes are batched via the engine's private ``_split_phonemes``.
    - Each batch gets one session.run(); the outputs are concatenated.
    - The style row is chosen by the unpadded token count, before the 0 padding is added.

    Raises:
        HTTPException: 400 for empty text, an unsupported voice, or empty
            phonemization; 500 when the model returns no audio or non-finite samples.

    NOTE(review): get_kokoro_engine and load_model are not defined anywhere in
    this copy of the file.
    """
    if not text.strip():
        raise HTTPException(status_code=400, detail="文本不能为空")
    engine = get_kokoro_engine(name=model_name)
    session = load_model(name=model_name)
    available_voices = set(engine.get_voices())
    if voice not in available_voices:
        raise HTTPException(status_code=400, detail=f"不支持的 voice: {voice}")
    phonemes = engine.tokenizer.phonemize(text, "en-us")
    batched_phonemes = engine._split_phonemes(phonemes)
    if not batched_phonemes:
        raise HTTPException(status_code=400, detail="文本音素化失败")
    voice_style = engine.get_voice_style(voice)
    audio_segments: List[np.ndarray] = []
    for phoneme_batch in batched_phonemes:
        tokens = np.array(engine.tokenizer.tokenize(phoneme_batch), dtype=np.int64)
        if tokens.size == 0:
            continue
        style = voice_style[len(tokens)]
        feeds = {
            "input_ids": np.asarray([[0, *tokens.tolist(), 0]], dtype=np.int64),
            "style": np.asarray(style, dtype=np.float32),
            "speed": np.asarray([speed], dtype=np.float32),
        }
        outputs = session.run(None, feeds)
        if outputs:
            audio_segments.append(to_mono_numpy(outputs[0]))
    if not audio_segments:
        raise HTTPException(status_code=500, detail="ONNX 推理未返回音频输出")
    audio = np.concatenate(audio_segments, axis=0) if len(audio_segments) > 1 else audio_segments[0]
    if audio.size == 0 or not np.isfinite(audio).all():
        raise HTTPException(status_code=500, detail="生成的音频无效")
    return audio


def synthesize_wav_bytes(text: str, voice: str, speed: float, model_name: Optional[str] = None) -> io.BytesIO:
    """Synthesize *text* and return a rewound BytesIO holding a 16-bit PCM WAV.

    The WAV header uses the module-level ``sample_rate``. That presumably
    matches the model's output rate -- confirm.
    """
    audio = synthesize_audio(text=text, voice=voice, speed=speed, model_name=model_name)
    buf = io.BytesIO()
    sf.write(buf, audio, samplerate=sample_rate, format="WAV", subtype="PCM_16")
    buf.seek(0)
    return buf


def iter_text_parts(text: str, split_pattern: Optional[str] = None) -> List[str]:
    """Split *text* into synthesis units.

    Steps:
    - Split by the optional regex *split_pattern*; an invalid pattern is logged
      and the whole text is kept as one block.
    - Split each block with split_sentences().

    Returns [] for empty/None text.

    NOTE(review): /generate and /generate_pcm pass a client-supplied
    split_pattern straight into re.split. A pathological pattern can cause
    catastrophic backtracking (ReDoS) -- consider an allow-list.
    """
    text = (text or "").strip()
    if not text:
        return []
    blocks = [text]
    if split_pattern:
        try:
            blocks = [part.strip() for part in re.split(split_pattern, text) if part.strip()]
        except re.error:
            logger.warning("split_pattern 非法,回退默认分句: %s", split_pattern)
    parts: List[str] = []
    for block in blocks:
        sentences = split_sentences(block)
        if sentences:
            parts.extend(sentences)
        elif block:
            parts.append(block)
    return parts


async def get_or_create_sentence_cache_item(
    sentence: str,
    voice: str,
    speed: float,
    model_name: str,
) -> dict:
    """Return the cache item for one sentence, synthesizing it on a miss.

    Lookup order:
    - memory_cache;
    - the disk cache (a hit is promoted into memory);
    - fresh synthesis in a worker thread via asyncio.to_thread.

    A fresh result is stored in memory and on disk, then the disk cache is
    pruned. Item keys: "sentence", "sample_rate", "audio_bytes" (WAV bytes),
    "audio" (base64 WAV).

    NOTE(review): there is no in-flight de-duplication, so concurrent requests
    for the same uncached sentence each synthesize it.
    """
    key = sentence_cache_key(sentence, voice, speed, model_name)
    cached_item = get_memory_cache(key)
    if cached_item:
        return cached_item
    disk_item = await load_sentence_from_disk(key)
    if disk_item:
        put_memory_cache(key, disk_item)
        return disk_item

    def _synthesize_item():
        # Runs in a worker thread: synthesize, encode to WAV, build the cache item.
        audio = synthesize_audio(sentence, voice=voice, speed=speed, model_name=model_name)
        wav_bytes = encode_wav_bytes(audio, sample_rate)
        return {
            "sentence": sentence,
            "sample_rate": sample_rate,
            "audio_bytes": wav_bytes,
            "audio": base64.b64encode(wav_bytes).decode(),
        }, audio

    item, audio = await asyncio.to_thread(_synthesize_item)
    put_memory_cache(key, item)
    await save_sentence_to_disk(key, audio, sample_rate, sentence)
    await clean_disk_cache()
    return item


@app.post("/tts", summary="POST: 传入文本返回 WAV 流")
def tts_post(req: TTSRequest):
    """POST /tts: synthesize ``req.text`` and stream it back as a single WAV.

    voice defaults to "af_heart" and speed to 1.0 when absent.

    NOTE(review): TTSRequest is not defined anywhere in this copy. Unlike
    GET /tts, no speed > 0 check is visible here (the model may validate it --
    confirm). This sync endpoint is not bounded by request_semaphore.
    """
    buf = synthesize_wav_bytes(
        text=req.text,
        voice=req.voice or "af_heart",
        speed=req.speed if req.speed is not None else 1.0,
        model_name=req.model_name,
    )
    return StreamingResponse(
        buf,
        media_type="audio/wav",
        headers={"Content-Disposition": 'inline; filename="tts.wav"'},
    )


@app.get("/tts", summary="GET: 传入文本返回 WAV 流")
def tts_get(
    text: str = Query(..., description="待合成文本"),
    voice: str = Query("af_heart"),
    speed: float = Query(1.0),
    model_name: str = Query(DEFAULT_MODEL_NAME),
):
    """GET /tts: synthesize the query-string text and stream it back as a single WAV.

    Raises:
        HTTPException: 400 if speed is not > 0.
    """
    if speed is None or float(speed) <= 0:
        raise HTTPException(status_code=400, detail="speed 必须为大于 0 的数值")
    buf = synthesize_wav_bytes(
        text=text,
        voice=voice,
        speed=float(speed),
        model_name=model_name,
    )
    return StreamingResponse(
        buf,
        media_type="audio/wav",
        headers={"Content-Disposition": 'inline; filename="tts.wav"'},
    )


@app.post("/generate")
async def generate_audio_stream(data: Dict = Body(...)):
    """POST /generate: stream NDJSON, one line per sentence.

    Each line has "index", "sentence", "audio" (base64 WAV) and "sample_rate".

    Body fields:
    - text (required, non-blank);
    - voice (default "af_heart");
    - speed (default 1.0);
    - model_name (default DEFAULT_MODEL_NAME);
    - client_id (default: random UUID);
    - split_pattern (default r"\\n+").

    A new request with an existing client_id sets the old stream's interrupt
    flag. The old stream checks that flag only between sentences.

    NOTE(review):
    - StreamingResponse iterates stream() only after this handler has returned,
      i.e. after `async with request_semaphore` has already released. So the
      actual synthesis is NOT limited by MAX_CONCURRENT_REQUESTS.
    - The interrupted stream's `finally` pops client_id. That can delete the
      NEWER request's entry, after which that stream can no longer be interrupted.
    - A JSON null "text" raises AttributeError (500) instead of a 400.
    """
    async with request_semaphore:
        text = data.get("text", "")
        voice = data.get("voice", "af_heart")
        speed = float(data.get("speed", 1.0))
        model = data.get("model_name", DEFAULT_MODEL_NAME)
        client_id = data.get("client_id", str(uuid.uuid4()))
        if not text.strip():
            raise HTTPException(status_code=400, detail="文本不能为空")
        parts = iter_text_parts(text, data.get("split_pattern", r"\n+"))
        if client_id in current_requests:
            # Ask this client's previous stream to stop, then give it a moment.
            current_requests[client_id]["interrupt"] = True
            await asyncio.sleep(0.05)
        current_requests[client_id] = {"interrupt": False}

        async def stream():
            # Yields one NDJSON line per sentence; deregisters client_id when done.
            try:
                for idx, part in enumerate(parts):
                    if current_requests.get(client_id, {}).get("interrupt"):
                        break
                    item = await get_or_create_sentence_cache_item(part, voice, speed, model)
                    yield json.dumps(
                        {
                            "index": idx,
                            "sentence": item["sentence"],
                            "audio": item["audio"],
                            "sample_rate": item["sample_rate"],
                        }
                    ).encode() + b"\n"
            finally:
                current_requests.pop(client_id, None)

        return StreamingResponse(stream(), media_type="application/x-ndjson")


@app.post("/generate_pcm", summary="POST: 返回更轻量的 PCM 分片流")
async def generate_audio_pcm_stream(data: Dict = Body(...)):
    """POST /generate_pcm: binary stream of raw PCM chunks, one per sentence.

    Each chunk is:
    - a JSON header line ("index", "sentence", "sample_rate", "format": "s16le",
      "bytes": N, terminated by "\\n");
    - followed by N bytes of int16 PCM.

    Body fields are the same as /generate, minus client_id; there is no
    interrupt support.

    NOTE(review): as in /generate, request_semaphore is released before
    pcm_stream() runs. The X-Sample-Rate response header uses the module-level
    sample_rate, while each chunk header reports the cached item's own rate.
    """
    async with request_semaphore:
        text = data.get("text", "")
        voice = data.get("voice", "af_heart")
        speed = float(data.get("speed", 1.0))
        model = data.get("model_name", DEFAULT_MODEL_NAME)
        if not text.strip():
            raise HTTPException(status_code=400, detail="文本不能为空")
        parts = iter_text_parts(text, data.get("split_pattern", r"\n+"))

        async def pcm_stream():
            # For each sentence: a JSON header line, then exactly "bytes" bytes of PCM.
            for idx, part in enumerate(parts):
                item = await get_or_create_sentence_cache_item(part, voice, speed, model)
                payload = wav_to_pcm_payload(item["audio_bytes"])
                header = json.dumps(
                    {
                        "index": idx,
                        "sentence": item["sentence"],
                        "sample_rate": item["sample_rate"],
                        "format": "s16le",
                        "bytes": len(payload),
                    }
                ).encode() + b"\n"
                yield header
                yield payload

        return StreamingResponse(
            pcm_stream(),
            media_type="application/octet-stream",
            headers={"X-Audio-Format": "s16le", "X-Sample-Rate": str(sample_rate)},
        )


@app.get("/clear-cache")
async def clear_cache():
    """GET /clear-cache: empty memory_cache and delete every entry in CACHE_DIR.

    NOTE(review): a state-changing operation exposed on GET. Path.unlink()
    raises on subdirectories or on files removed concurrently, and the deletion
    runs synchronously on the event loop.
    """
    with cache_lock:
        memory_cache.clear()
    for f in Path(CACHE_DIR).glob("*"):
        f.unlink()
    return {"status": "success"}


@app.get("/cache-info")
async def get_cache_info():
    """GET /cache-info: return the number of memory-cache entries and of *.wav files in CACHE_DIR."""
    with cache_lock:
        mem_count = len(memory_cache)
    disk_files = list(Path(CACHE_DIR).glob("*.wav"))
    return {"memory_cache": mem_count, "disk_cache": len(disk_files)}


if __name__ == "__main__":
    import uvicorn

    # The import string assumes this module is saved as speech_tts_onnx.py.
    uvicorn.run("speech_tts_onnx:app", host="0.0.0.0", port=18000, reload=False, workers=1)