from __future__ import annotations

import asyncio
import base64
import concurrent.futures
import hashlib
import io
import json
import logging
import os
import re
import threading
import time
import urllib.request
import uuid
from collections import OrderedDict
from contextlib import asynccontextmanager
from pathlib import Path
from typing import Dict, List, Optional, Tuple

import aiofiles
import numpy as np
import soundfile as sf
from fastapi import Body, FastAPI, HTTPException, Query, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse
from pydantic import BaseModel, field_validator


def _env_int(name: str, default: int) -> int:
    """Return env var ``name`` parsed as ``int``, or ``default`` when it is unset."""
    raw = os.getenv(name)
    if raw is None:
        return default
    return int(raw)


def _env_flag(name: str, default: str) -> bool:
    """Return True when env var ``name`` (or ``default`` if unset) is "true", case-insensitively."""
    return os.getenv(name, default).lower() == "true"


# ---- Model and asset locations ---------------------------------------------
DEFAULT_MODEL_NAME = os.getenv("TTS_ONNX_MODEL_NAME", "model_uint8.onnx")
MODEL_DIR = os.getenv("TTS_ONNX_MODEL_DIR", "/home/tts-server/onnx")
HF_MODEL_ID = os.getenv("TTS_ONNX_HF_MODEL_ID", "onnx-community/Kokoro-82M-ONNX")
_model_root = Path(MODEL_DIR)
CONFIG_PATH = os.getenv("TTS_ONNX_CONFIG_PATH", str(_model_root / "config.json"))
VOICES_DIR = os.getenv("TTS_ONNX_VOICES_DIR", str(_model_root / "voices"))
VOICES_V1_PATH = os.getenv("TTS_ONNX_VOICES_V1_PATH", str(_model_root / "voices-v1.0.bin"))

# ---- Service behaviour -----------------------------------------------------
CACHE_DIR = os.getenv("CACHE_DIR", "./audio_cache")
LOG_LEVEL = os.getenv("LOG_LEVEL", "INFO").upper()
MAX_CONCURRENT_REQUESTS = _env_int("MAX_CONCURRENT_REQUESTS", 8)
DEFAULT_SAMPLE_RATE = _env_int("TTS_SAMPLE_RATE", 24000)

# ---- Cache sizing (TTL in seconds, sizes in items / bytes / files) ----------
MEMORY_CACHE_TTL = _env_int("MEMORY_CACHE_TTL", 18000)
MEMORY_CACHE_MAX_ITEMS = _env_int("MEMORY_CACHE_SIZE", 120)
MEMORY_CACHE_MAX_BYTES = _env_int("MEMORY_CACHE_MAX_BYTES", 128 * 1024 * 1024)
DISK_CACHE_SIZE = _env_int("DISK_CACHE_SIZE", 500)

# ---- onnxruntime session tuning --------------------------------------------
ORT_INTRA_OP_THREADS = _env_int("ORT_INTRA_OP_THREADS", 2)
ORT_INTER_OP_THREADS = _env_int("ORT_INTER_OP_THREADS", 1)
ORT_ENABLE_CPU_MEM_ARENA = _env_flag("ORT_ENABLE_CPU_MEM_ARENA", "true")
ORT_ENABLE_MEM_PATTERN = _env_flag("ORT_ENABLE_MEM_PATTERN", "false")

VOICE_ALIASES: Dict[str, str] = {}
logging.basicConfig(level=getattr(logging, LOG_LEVEL, logging.INFO))
logger = logging.getLogger("speech_tts_onnx_opt")

Path(CACHE_DIR).mkdir(parents=True, exist_ok=True)

# Split points: right after sentence/clause punctuation (CJK or ASCII) together
# with any following whitespace, or on runs of newlines.
SENT_SPLIT_RE = re.compile(r"(?<=[。!?;.!?;::,,])\s*|\n+")


class _PhonemizerWordCountMismatchFilter(logging.Filter):
    """Drop phonemizer log records that mention "words count mismatch"."""

    def filter(self, record: logging.LogRecord) -> bool:
        return "words count mismatch" not in record.getMessage()


logging.getLogger("phonemizer").addFilter(_PhonemizerWordCountMismatchFilter())


class MemoryAudioCache:
    """Thread-safe in-process LRU cache of encoded sentence audio.

    Each entry stores ``sentence``, ``sample_rate`` and ``audio_bytes``.

    Entries expire ``ttl`` seconds after their last access (``get`` refreshes the
    timestamp). When either ``max_items`` or ``max_bytes`` is exceeded, the least
    recently used entries are evicted first.
    """

    def __init__(self, max_items: int, max_bytes: int, ttl: int):
        self.max_items = max_items
        self.max_bytes = max_bytes
        self.ttl = ttl
        self.lock = threading.Lock()
        # Ordered least-recently-used first.
        self.items: "OrderedDict[str, dict]" = OrderedDict()
        self.total_bytes = 0

    def _entry_size(self, value: dict) -> int:
        # Approximate footprint: audio payload + UTF-8 sentence + fixed overhead.
        return len(value.get("audio_bytes", b"")) + len(value.get("sentence", "").encode("utf-8")) + 128

    def _purge_expired(self, now: float):
        # Caller must hold self.lock.
        expired = [k for k, v in self.items.items() if now - v["ts"] > self.ttl]
        for key in expired:
            entry = self.items.pop(key, None)
            if entry is not None:
                self.total_bytes -= entry["size"]

    def get(self, key: str) -> Optional[dict]:
        """Return a copy of the cached item for ``key`` (refreshing its TTL), or None."""
        now = time.time()
        with self.lock:
            self._purge_expired(now)
            entry = self.items.get(key)
            if entry is None:
                return None
            self.items.move_to_end(key)
            entry["ts"] = now
            return {
                "sentence": entry["sentence"],
                "sample_rate": entry["sample_rate"],
                "audio_bytes": entry["audio_bytes"],
            }

    def set(self, key: str, value: dict):
        """Insert or replace ``key``, then evict LRU entries until within limits."""
        now = time.time()
        with self.lock:
            self._purge_expired(now)
            old = self.items.pop(key, None)
            if old is not None:
                self.total_bytes -= old["size"]
            entry = {
                "sentence": value["sentence"],
                "sample_rate": value["sample_rate"],
                "audio_bytes": value["audio_bytes"],
                "ts": now,
            }
            entry["size"] = self._entry_size(entry)
            self.items[key] = entry
            self.total_bytes += entry["size"]
            while self.items and (
                len(self.items) > self.max_items or self.total_bytes > self.max_bytes
            ):
                _, removed = self.items.popitem(last=False)
                self.total_bytes -= removed["size"]

    def clear(self):
        """Drop every entry."""
        with self.lock:
            self.items.clear()
            self.total_bytes = 0

    def info(self) -> dict:
        """Return ``{"items": count, "bytes": approximate_size}`` after purging expired entries."""
        now = time.time()
        with self.lock:
            self._purge_expired(now)
            return {"items": len(self.items), "bytes": self.total_bytes}


model_lock = threading.Lock()
# Acquired inside each /generate stream so it bounds concurrent synthesis work.
request_semaphore = asyncio.Semaphore(MAX_CONCURRENT_REQUESTS)
# client_id -> state dict owned by the stream currently serving that client.
# Setting state["interrupt"] = True asks that stream to stop after its current part.
current_requests: Dict[str, Dict] = {}
memory_cache = MemoryAudioCache(
    max_items=MEMORY_CACHE_MAX_ITEMS,
    max_bytes=MEMORY_CACHE_MAX_BYTES,
    ttl=MEMORY_CACHE_TTL,
)
executor = concurrent.futures.ThreadPoolExecutor(max_workers=8)
inflight_lock = threading.Lock()
# Cache key -> task currently synthesizing that sentence (de-duplicates concurrent work).
inflight_tasks: Dict[str, asyncio.Future] = {}

model_session = None
model_name = DEFAULT_MODEL_NAME
# Never updated in this module; presumably matches the model's output rate — confirm.
sample_rate = DEFAULT_SAMPLE_RATE
_KOKORO_ONNX_ENGINE = None
_EN_G2P_PIPELINE = None


def split_sentences(text: str) -> List[str]:
    """Split ``text`` into sentence/clause chunks using ``SENT_SPLIT_RE``.

    Fragments shorter than 3 characters are appended (space-separated) to the
    previous chunk. A short fragment with no previous chunk is kept on its own.
    """
    parts = [s.strip() for s in SENT_SPLIT_RE.split(text.strip()) if s.strip()]
    merged: List[str] = []
    for part in parts:
        if len(part) < 3 and merged:
            merged[-1] = f"{merged[-1]} {part}".strip()
        else:
            merged.append(part)
    return merged


def iter_text_parts(text: str, split_pattern: Optional[str]) -> List[str]:
    """Return the synthesis units for ``text``.

    The text is first split into blocks by the optional regex ``split_pattern``.
    Each block is then split with ``split_sentences``. An invalid pattern logs a
    warning and leaves the text as a single block. Empty or None text yields [].
    """
    text = (text or "").strip()
    if not text:
        return []
    blocks = [text]
    if split_pattern:
        # NOTE(review): split_pattern comes straight from request input; a
        # pathological regex can be slow (ReDoS). Consider restricting it.
        try:
            blocks = [p.strip() for p in re.split(split_pattern, text) if p.strip()]
        except re.error:
            logger.warning("split_pattern 非法,回退默认分句: %s", split_pattern)
    parts: List[str] = []
    for block in blocks:
        parts.extend(split_sentences(block) or [block])
    return parts


def sentence_cache_key(sentence: str, voice: str, speed: float, model: str) -> str:
    """Return the cache key for one sentence/voice/speed/model combination.

    Speed is rounded to 4 decimals in the key. MD5 is used only as a fingerprint
    here, not for security.
    """
    raw = f"{sentence}|{voice}|{speed:.4f}|{model}"
    return hashlib.md5(raw.encode("utf-8")).hexdigest()


def sentence_cache_path(key: str) -> str:
    """Path of the cached WAV file for ``key``."""
    return os.path.join(CACHE_DIR, f"{key}.wav")


def meta_cache_path(key: str) -> str:
    """Path of the JSON metadata sidecar for ``key``."""
    return os.path.join(CACHE_DIR, f"{key}.json")


def to_mono_numpy(audio) -> np.ndarray:
    """Convert model or soundfile output to a 1-D float32 array.

    For 2-D input:
    - a singleton column is taken directly;
    - otherwise a singleton row is taken directly;
    - otherwise axis 1 is averaged (treated as channels).

    Input with more than 2 dimensions is flattened.
    """
    arr = np.asarray(audio, dtype=np.float32)
    if arr.size == 0:
        return np.array([], dtype=np.float32)
    if arr.ndim == 2:
        if arr.shape[1] == 1:
            arr = arr[:, 0]
        elif arr.shape[0] == 1:
            arr = arr[0]
        else:
            arr = arr.mean(axis=1)
    elif arr.ndim > 2:
        arr = arr.reshape(-1)
    return arr.astype(np.float32, copy=False)


def read_wav_bytes(wav_path: str) -> Tuple[bytes, int]:
    """Read ``wav_path`` and re-encode it as in-memory 16-bit PCM WAV bytes.

    Returns ``(wav_bytes, sample_rate)``.
    """
    audio, sr = sf.read(wav_path)
    buf = io.BytesIO()
    channels = audio.shape[1] if audio.ndim > 1 else 1
    with sf.SoundFile(buf, "w", samplerate=sr, channels=channels, format="WAV", subtype="PCM_16") as f:
        f.write(audio)
    return buf.getvalue(), sr


async def load_sentence_from_disk(key: str) -> Optional[dict]:
    """Load a cached sentence item from disk, or return None on a miss.

    Unreadable entries are treated as a miss and logged: corrupt JSON, a missing
    "sentence" field, a bad WAV, or files deleted concurrently by cache cleanup.
    """
    wav_path = sentence_cache_path(key)
    meta_path = meta_cache_path(key)
    if not Path(wav_path).exists() or not Path(meta_path).exists():
        return None
    try:
        async with aiofiles.open(meta_path, "r", encoding="utf-8") as f:
            meta = json.loads(await f.read())
        sentence = meta["sentence"]
        loop = asyncio.get_running_loop()
        wav_bytes, sr = await loop.run_in_executor(executor, read_wav_bytes, wav_path)
    except (OSError, ValueError, KeyError, TypeError, RuntimeError) as exc:
        # soundfile errors subclass RuntimeError; JSON/Unicode errors subclass ValueError.
        logger.warning("Ignoring unreadable disk cache entry %s: %s", key, exc)
        return None
    return {"sentence": sentence, "sample_rate": sr, "audio_bytes": wav_bytes}


async def save_sentence_to_disk(key: str, audio: np.ndarray, sr: int, sentence: str):
    """Persist ``audio`` as a WAV file plus a JSON metadata sidecar for ``key``.

    The WAV is written first. Readers require both files to exist, so metadata
    only appears once the audio file has been written.
    """
    wav_path = sentence_cache_path(key)
    meta_path = meta_cache_path(key)
    loop = asyncio.get_running_loop()
    await loop.run_in_executor(executor, lambda: sf.write(wav_path, audio, sr, format="WAV"))
    async with aiofiles.open(meta_path, "w", encoding="utf-8") as f:
        await f.write(json.dumps({"sentence": sentence, "sample_rate": sr}))


def _clean_disk_cache_sync() -> None:
    """Delete the oldest (by mtime) WAV files and their metadata beyond DISK_CACHE_SIZE."""

    def _mtime(path: Path) -> float:
        try:
            return path.stat().st_mtime
        except OSError:
            # Vanished since glob: sort it first so it consumes one (no-op) deletion slot.
            return 0.0

    files = sorted(Path(CACHE_DIR).glob("*.wav"), key=_mtime)
    excess = len(files) - DISK_CACHE_SIZE
    if excess <= 0:
        return
    for wav in files[:excess]:
        try:
            wav.unlink(missing_ok=True)
            Path(meta_cache_path(wav.stem)).unlink(missing_ok=True)
        except OSError as exc:
            logger.debug("Failed to evict disk cache file %s: %s", wav, exc)


async def clean_disk_cache():
    """Trim the on-disk cache to DISK_CACHE_SIZE WAV files without blocking the event loop."""
    await asyncio.to_thread(_clean_disk_cache_sync)


def resolve_model_path(name: str) -> str:
    """Resolve a model file name to an existing path.

    The name is resolved under MODEL_DIR first (an absolute ``name`` makes the
    join yield ``name`` itself). Raises FileNotFoundError otherwise.
    """
    # NOTE(review): ``name`` can come from request bodies/query strings, so a
    # client can point this at arbitrary paths (absolute or "../"). Consider an
    # allow-list of model files inside MODEL_DIR.
    local_path = Path(MODEL_DIR) / name
    if local_path.exists():
        return str(local_path)
    if os.path.isabs(name) and Path(name).exists():
        return name
    raise FileNotFoundError(f"找不到模型文件: {local_path}")


def get_en_g2p_pipeline():
    """Lazily construct and return a kokoro ``KPipeline`` (lang_code "a", no model) for G2P.

    Not referenced elsewhere in this module.
    """
    global _EN_G2P_PIPELINE
    if _EN_G2P_PIPELINE is None:
        from kokoro import KPipeline  # type: ignore

        _EN_G2P_PIPELINE = KPipeline(lang_code="a", repo_id="hexgrad/Kokoro-82M", model=False)
    return _EN_G2P_PIPELINE


def _download_voice_file(voice_path: Path) -> Path:
    """Download ``voices/<voice_path.name>`` from the HF_MODEL_ID repo to ``voice_path``.

    The payload is written to a ``.tmp`` sibling and renamed into place. The
    temp file is removed if anything fails.
    """
    raw_url = f"https://huggingface.co/{HF_MODEL_ID}/resolve/main/voices/{voice_path.name}"
    voice_path.parent.mkdir(parents=True, exist_ok=True)
    tmp_path = voice_path.with_suffix(voice_path.suffix + ".tmp")
    try:
        with urllib.request.urlopen(raw_url, timeout=60) as response:
            payload = response.read()
        with open(tmp_path, "wb") as f:
            f.write(payload)
        tmp_path.replace(voice_path)
    except Exception:
        tmp_path.unlink(missing_ok=True)
        raise
    return voice_path


def resolve_voices_path() -> str:
    """Return the voices archive path to hand to kokoro_onnx.

    If VOICES_DIR is a directory, its ``*.bin`` files are packed into
    ``_voices.generated.npz`` there. Each file must hold 510*256 float32 values
    and is reshaped to (510, 1, 256). The archive is rebuilt when missing or
    older than any ``.bin`` file.

    Otherwise VOICES_V1_PATH is used if it exists. Raises FileNotFoundError or
    ValueError on missing or malformed inputs.
    """
    voices_dir = Path(VOICES_DIR)
    if voices_dir.is_dir():
        voice_files = sorted(voices_dir.glob("*.bin"))
        if not voice_files:
            raise FileNotFoundError(f"voices 目录中没有可用的 .bin 音色文件: {voices_dir}")
        packed_path = voices_dir / "_voices.generated.npz"
        needs_rebuild = not packed_path.exists()
        if not needs_rebuild:
            packed_mtime = packed_path.stat().st_mtime
            needs_rebuild = any(f.stat().st_mtime > packed_mtime for f in voice_files)
        if needs_rebuild:
            voice_map = {}
            for voice_file in voice_files:
                raw = np.fromfile(voice_file, dtype=np.float32)
                if raw.size != 510 * 1 * 256:
                    raise ValueError(f"音色文件格式不正确: {voice_file}")
                voice_map[voice_file.stem] = raw.reshape(510, 1, 256)
            # Write to a per-process temp file, then swap it in atomically so a
            # concurrent reader never sees a half-written archive.
            tmp_path = voices_dir / f"_voices.generated.{os.getpid()}.tmp.npz"
            np.savez(tmp_path, **voice_map)
            os.replace(tmp_path, packed_path)
        return str(packed_path)
    if Path(VOICES_V1_PATH).exists():
        return VOICES_V1_PATH
    raise FileNotFoundError(f"找不到 voices 目录或 voices-v1.0.bin: {voices_dir}, {VOICES_V1_PATH}")


def load_onnx_session(name: str):
    """Create a CPU onnxruntime InferenceSession for model ``name`` using the ORT_* settings."""
    try:
        import onnxruntime as ort  # type: ignore
    except Exception as e:
        raise RuntimeError("无法导入 onnxruntime。请在部署环境中安装 onnxruntime。") from e
    model_path = resolve_model_path(name)
    sess_options = ort.SessionOptions()
    sess_options.intra_op_num_threads = ORT_INTRA_OP_THREADS
    sess_options.inter_op_num_threads = ORT_INTER_OP_THREADS
    sess_options.enable_cpu_mem_arena = ORT_ENABLE_CPU_MEM_ARENA
    sess_options.enable_mem_pattern = ORT_ENABLE_MEM_PATTERN
    return ort.InferenceSession(model_path, sess_options=sess_options, providers=["CPUExecutionProvider"])


def load_kokoro_engine(name: str):
    """Create a kokoro_onnx ``Kokoro`` engine for model ``name``.

    CONFIG_PATH is passed as the vocab config only if that file exists.
    """
    try:
        from kokoro_onnx import Kokoro  # type: ignore
    except Exception as e:
        raise RuntimeError("无法导入 kokoro_onnx。请在部署环境中安装 kokoro-onnx。") from e
    model_path = resolve_model_path(name)
    return Kokoro(
        model_path=model_path,
        voices_path=resolve_voices_path(),
        vocab_config=CONFIG_PATH if Path(CONFIG_PATH).exists() else None,
    )


def load_model(force_reload: bool = False, name: Optional[str] = None):
    """Ensure model ``name`` (default: the currently selected one) is loaded; return its session.

    The ONNX session and the kokoro engine are (re)built together under
    ``model_lock`` when ``force_reload`` is set, nothing is loaded yet, or a
    different model is requested.
    """
    # NOTE(review): a single global model is kept; requests alternating between
    # model names force a full reload each time.
    global model_session, model_name, _KOKORO_ONNX_ENGINE
    with model_lock:
        target = name or model_name
        if force_reload or model_session is None or target != model_name:
            model_session = load_onnx_session(target)
            _KOKORO_ONNX_ENGINE = load_kokoro_engine(target)
            model_name = target
            logger.info("ONNX 模型加载完成: %s", resolve_model_path(target))
        return model_session


def get_kokoro_engine(name: Optional[str] = None):
    """Return the kokoro engine, loading model ``name`` first if it is not the active one."""
    target = name or model_name
    if _KOKORO_ONNX_ENGINE is None or target != model_name:
        load_model(name=target)
    return _KOKORO_ONNX_ENGINE


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load the default model at startup, before requests are served."""
    load_model()
    yield


app = FastAPI(title="Online TTS Service (Kokoro ONNX Optimized)", lifespan=lifespan)
# NOTE(review): wildcard origins combined with allow_credentials=True lets any
# site issue credentialed cross-origin requests; confirm this is intended.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


@app.middleware("http")
async def track_clients(request: Request, call_next):
    """Mark a client as active in ``current_requests`` while a /generate call is handled.

    A marker is inserted only if no stream already owns the slot, because
    /generate relies on that entry to interrupt the previous stream. The marker
    is removed only if it is still the one this middleware inserted.
    """
    if request.url.path != "/generate":
        return await call_next(request)
    client_id = request.query_params.get("client_id") or request.headers.get("X-Client-ID") or str(uuid.uuid4())
    marker = {"active": True}
    current_requests.setdefault(client_id, marker)
    try:
        return await call_next(request)
    finally:
        if current_requests.get(client_id) is marker:
            current_requests.pop(client_id, None)


class TTSRequest(BaseModel):
    """Body of POST /tts."""

    text: str
    voice: Optional[str] = "af_heart"
    speed: Optional[float] = 1.0
    split_pattern: Optional[str] = r"\n+"
    model_name: Optional[str] = None

    @field_validator("speed")
    @classmethod
    def validate_speed(cls, v):
        """Map None to 1.0; reject non-finite or non-positive speeds."""
        if v is None:
            return 1.0
        v = float(v)
        if not np.isfinite(v) or v <= 0:
            raise ValueError("speed 必须大于 0")
        return v


def _parse_speed(value) -> float:
    """Convert a request-supplied speed to a finite positive float, else raise HTTP 400."""
    try:
        speed = float(value)
    except (TypeError, ValueError):
        raise HTTPException(status_code=400, detail="speed 必须为大于 0 的数值") from None
    if not np.isfinite(speed) or speed <= 0:
        raise HTTPException(status_code=400, detail="speed 必须为大于 0 的数值")
    return speed


def synthesize_audio(text: str, voice: str, speed: float, model_name: Optional[str] = None) -> np.ndarray:
    """Synthesize ``text`` with ``voice``; return mono float32 samples.

    Phonemization is hard-coded to "en-us". Phonemes are split into batches by
    the engine, each batch runs through the ONNX session, and the outputs are
    concatenated.

    Raises HTTPException with status 400 for empty text, an unknown voice, or a
    phonemization failure. Raises status 500 when there is no output or the
    output contains non-finite samples.
    """
    if not text.strip():
        raise HTTPException(status_code=400, detail="文本不能为空")
    engine = get_kokoro_engine(name=model_name)
    session = load_model(name=model_name)
    if voice not in set(engine.get_voices()):
        raise HTTPException(status_code=400, detail=f"不支持的 voice: {voice}")
    phonemes = engine.tokenizer.phonemize(text, "en-us")
    # NOTE(review): relies on kokoro_onnx's private ``_split_phonemes``; may break on upgrade.
    batched_phonemes = engine._split_phonemes(phonemes)
    if not batched_phonemes:
        raise HTTPException(status_code=400, detail="文本音素化失败")
    voice_style = engine.get_voice_style(voice)
    audio_segments: List[np.ndarray] = []
    for phoneme_batch in batched_phonemes:
        tokens = np.array(engine.tokenizer.tokenize(phoneme_batch), dtype=np.int64)
        if tokens.size == 0:
            continue
        feeds = {
            # Token ids padded with 0 at both ends.
            "input_ids": np.asarray([[0, *tokens.tolist(), 0]], dtype=np.int64),
            # Style row is selected by the (unpadded) token count.
            "style": np.asarray(voice_style[len(tokens)], dtype=np.float32),
            "speed": np.asarray([speed], dtype=np.float32),
        }
        outputs = session.run(None, feeds)
        if outputs:
            audio_segments.append(to_mono_numpy(outputs[0]))
    if not audio_segments:
        raise HTTPException(status_code=500, detail="ONNX 推理未返回音频输出")
    audio = np.concatenate(audio_segments, axis=0) if len(audio_segments) > 1 else audio_segments[0]
    if audio.size == 0 or not np.isfinite(audio).all():
        raise HTTPException(status_code=500, detail="生成的音频无效")
    return audio


def encode_wav_bytes(audio: np.ndarray, sr: int) -> bytes:
    """Encode mono ``audio`` at ``sr`` Hz as 16-bit PCM WAV bytes."""
    buf = io.BytesIO()
    with sf.SoundFile(buf, "w", samplerate=sr, channels=1, format="WAV", subtype="PCM_16") as f:
        f.write(audio)
    return buf.getvalue()


def cache_item_to_response(item: dict) -> dict:
    """Convert a cache item into the JSON-friendly shape streamed by /generate (base64 audio)."""
    return {
        "sentence": item["sentence"],
        "sample_rate": item["sample_rate"],
        "audio": base64.b64encode(item["audio_bytes"]).decode(),
    }


async def _compute_sentence_item(sentence: str, voice: str, speed: float, model: str) -> dict:
    """Synthesize one sentence off the event loop and store it in the memory and disk caches.

    Disk persistence is best-effort: a failure is logged, and the already
    synthesized item is still returned.
    """
    key = sentence_cache_key(sentence, voice, speed, model)

    def _run():
        audio = synthesize_audio(sentence, voice=voice, speed=speed, model_name=model)
        wav_bytes = encode_wav_bytes(audio, sample_rate)
        return {"sentence": sentence, "sample_rate": sample_rate, "audio_bytes": wav_bytes}, audio

    item, audio = await asyncio.to_thread(_run)
    memory_cache.set(key, item)
    try:
        await save_sentence_to_disk(key, audio, sample_rate, sentence)
        await clean_disk_cache()
    except (OSError, RuntimeError) as exc:
        logger.warning("Failed to persist sentence cache entry %s: %s", key, exc)
    return item


def _forget_inflight(key: str, fut: asyncio.Future) -> None:
    """Done-callback: drop ``fut`` from ``inflight_tasks`` if it is still registered for ``key``."""
    with inflight_lock:
        if inflight_tasks.get(key) is fut:
            inflight_tasks.pop(key, None)


async def get_or_create_sentence_cache_item(sentence: str, voice: str, speed: float, model: str) -> dict:
    """Return the cache item for a sentence: memory cache, then disk cache, then synthesis.

    Concurrent requests for the same key share one synthesis task. Each waiter
    awaits it through ``asyncio.shield``, so one cancelled (disconnected) waiter
    does not cancel the work the others are waiting on.
    """
    key = sentence_cache_key(sentence, voice, speed, model)
    cached = memory_cache.get(key)
    if cached:
        return cached
    disk_item = await load_sentence_from_disk(key)
    if disk_item:
        memory_cache.set(key, disk_item)
        return disk_item
    loop = asyncio.get_running_loop()
    with inflight_lock:
        fut = inflight_tasks.get(key)
        if fut is None:
            fut = loop.create_task(_compute_sentence_item(sentence, voice, speed, model))
            inflight_tasks[key] = fut
            # Unregister exactly when the task finishes, not when the first waiter leaves.
            fut.add_done_callback(lambda done, k=key: _forget_inflight(k, done))
    return await asyncio.shield(fut)


def synthesize_wav_bytes(text: str, voice: str, speed: float, split_pattern: Optional[str], model_name: Optional[str] = None) -> io.BytesIO:
    """Synthesize ``text`` into a single 16-bit PCM WAV, returned as a rewound BytesIO.

    Parts already in the memory cache are reused; missing parts are synthesized
    but not cached. A None or empty ``model_name`` means DEFAULT_MODEL_NAME for
    both the cache lookup and synthesis, so the two always agree.
    """
    parts = iter_text_parts(text, split_pattern)
    if not parts:
        raise HTTPException(status_code=400, detail="文本不能为空")
    model = model_name or DEFAULT_MODEL_NAME
    buffers: List[np.ndarray] = []
    for part in parts:
        cached = memory_cache.get(sentence_cache_key(part, voice, speed, model))
        if cached:
            audio, _ = sf.read(io.BytesIO(cached["audio_bytes"]), dtype="float32")
            buffers.append(to_mono_numpy(audio))
        else:
            buffers.append(synthesize_audio(part, voice=voice, speed=speed, model_name=model))
    merged = np.concatenate(buffers, axis=0) if len(buffers) > 1 else buffers[0]
    buf = io.BytesIO()
    sf.write(buf, merged, samplerate=sample_rate, format="WAV", subtype="PCM_16")
    buf.seek(0)
    return buf


@app.post("/tts", summary="POST: 传入文本返回 WAV 流")
def tts_post(req: TTSRequest):
    """Synthesize the request text and return it as an inline WAV stream."""
    buf = synthesize_wav_bytes(
        text=req.text,
        voice=req.voice or "af_heart",
        speed=req.speed if req.speed is not None else 1.0,
        split_pattern=req.split_pattern,
        model_name=req.model_name,
    )
    return StreamingResponse(buf, media_type="audio/wav", headers={"Content-Disposition": 'inline; filename="tts.wav"'})


@app.get("/tts", summary="GET: 传入文本返回 WAV 流")
def tts_get(
    text: str = Query(..., description="待合成文本"),
    voice: str = Query("af_heart"),
    speed: float = Query(1.0),
    model_name: str = Query(DEFAULT_MODEL_NAME),
    split_pattern: str = Query(r"\n+"),
):
    """Query-string variant of POST /tts."""
    speed_value = _parse_speed(speed)
    buf = synthesize_wav_bytes(text=text, voice=voice, speed=speed_value, split_pattern=split_pattern, model_name=model_name)
    return StreamingResponse(buf, media_type="audio/wav", headers={"Content-Disposition": 'inline; filename="tts.wav"'})


@app.post("/generate")
async def generate_audio_stream(data: Dict = Body(...)):
    """Stream one NDJSON line per sentence: ``{"index", "sentence", "sample_rate", "audio"}``.

    A new request with the same ``client_id`` interrupts the previous stream for
    that client; the old stream stops before its next part. The
    MAX_CONCURRENT_REQUESTS semaphore is held while a stream is producing audio.
    """
    text = data.get("text", "")
    voice = data.get("voice", "af_heart")
    speed = _parse_speed(data.get("speed", 1.0))
    model = data.get("model_name") or DEFAULT_MODEL_NAME
    split_pattern = data.get("split_pattern", r"\n+")
    client_id = str(data.get("client_id") or uuid.uuid4())
    if not isinstance(text, str) or not text.strip():
        raise HTTPException(status_code=400, detail="文本不能为空")
    parts = iter_text_parts(text, split_pattern)

    # Each stream owns its own state dict, so an older stream keeps seeing its
    # interrupt flag even after this request registers a fresh entry.
    previous = current_requests.get(client_id)
    if previous is not None:
        previous["interrupt"] = True
    state = {"interrupt": False}
    current_requests[client_id] = state

    async def stream():
        try:
            async with request_semaphore:
                for idx, part in enumerate(parts):
                    if state["interrupt"]:
                        break
                    item = await get_or_create_sentence_cache_item(part, voice, speed, model)
                    resp = cache_item_to_response(item)
                    yield json.dumps({"index": idx, **resp}).encode() + b"\n"
        finally:
            # Only remove our own entry; a newer stream may already own the slot.
            if current_requests.get(client_id) is state:
                current_requests.pop(client_id, None)

    return StreamingResponse(stream(), media_type="application/x-ndjson")


@app.get("/clear-cache")
async def clear_cache():
    """Empty the memory cache and delete every regular file in CACHE_DIR."""
    memory_cache.clear()
    for f in Path(CACHE_DIR).glob("*"):
        if f.is_file():
            f.unlink(missing_ok=True)
    return {"status": "success"}


@app.get("/cache-info")
async def get_cache_info():
    """Report memory-cache item count/bytes and the number of WAV files on disk."""
    mem = memory_cache.info()
    disk_files = list(Path(CACHE_DIR).glob("*.wav"))
    return {"memory_cache": mem["items"], "memory_cache_bytes": mem["bytes"], "disk_cache": len(disk_files)}


if __name__ == "__main__":
    import uvicorn

    uvicorn.run("speech_tts_onnx_opt:app", host="0.0.0.0", port=18000, reload=False, workers=1)