| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652 |
from __future__ import annotations

import asyncio
import base64
import hashlib
import io
import json
import logging
import math
import os
import re
import threading
import urllib.request
import uuid
from contextlib import asynccontextmanager
from pathlib import Path
from typing import Dict, List, Optional, Tuple

import numpy as np
import soundfile as sf
from fastapi import Body, FastAPI, HTTPException, Query, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse
from pydantic import BaseModel, field_validator
# ============================================================
# Configuration (every value below can be overridden through environment variables)
# ============================================================
# ONNX model file name; resolved relative to MODEL_DIR by resolve_model_path().
DEFAULT_MODEL_NAME = os.getenv("TTS_ONNX_MODEL_NAME", "model_quantized.onnx")
MODEL_DIR = os.getenv("TTS_ONNX_MODEL_DIR", "/home/tts-server/onnx")
# Hugging Face repo used to build download URLs for missing voice files.
HF_MODEL_ID = os.getenv("TTS_ONNX_HF_MODEL_ID", "onnx-community/Kokoro-82M-ONNX")
# Vocab sources: config.json is preferred, tokenizer.json is the fallback (see load_vocab).
TOKENIZER_PATH = os.getenv("TTS_ONNX_TOKENIZER_PATH", str(Path(MODEL_DIR) / "tokenizer.json"))
CONFIG_PATH = os.getenv("TTS_ONNX_CONFIG_PATH", str(Path(MODEL_DIR) / "config.json"))
# Directory holding per-voice *.bin style files (resolve_voice_path / _style_for_voice).
VOICES_DIR = os.getenv("TTS_ONNX_VOICES_DIR", str(Path(MODEL_DIR) / "voices"))
# Combined voices file handed to kokoro_onnx.Kokoro(voices_path=...).
VOICES_V1_PATH = os.getenv("TTS_ONNX_VOICES_V1_PATH", str(Path(MODEL_DIR) / "voices-v1.0.bin"))
CACHE_DIR = os.getenv("CACHE_DIR", "./audio_cache")
# An unknown level name falls back to INFO via the getattr default below.
LOG_LEVEL = os.getenv("LOG_LEVEL", "INFO").upper()
# Capacity of the asyncio semaphore used by /generate.
MAX_CONCURRENT_REQUESTS = int(os.getenv("MAX_CONCURRENT_REQUESTS", 8))
# NOTE(review): MEMORY_CACHE_SIZE and DISK_CACHE_SIZE are not read anywhere in this file view.
MEMORY_CACHE_SIZE = int(os.getenv("MEMORY_CACHE_SIZE", 200))
DISK_CACHE_SIZE = int(os.getenv("DISK_CACHE_SIZE", 500))
# Sample rate written into every WAV this service produces.
DEFAULT_SAMPLE_RATE = int(os.getenv("TTS_SAMPLE_RATE", 24000))
# Optional mapping from user-facing voice names to real voice ids (empty by default).
VOICE_ALIASES: Dict[str, str] = {}
logging.basicConfig(level=getattr(logging, LOG_LEVEL, logging.INFO))
logger = logging.getLogger("speech_tts_onnx")
- class _PhonemizerWordCountMismatchFilter(logging.Filter):
- def filter(self, record: logging.LogRecord) -> bool:
- return "words count mismatch" not in record.getMessage()
# Silence phonemizer's repetitive "words count mismatch" warnings.
logging.getLogger("phonemizer").addFilter(_PhonemizerWordCountMismatchFilter())
# Make sure the audio cache directory exists at import time.
Path(CACHE_DIR).mkdir(parents=True, exist_ok=True)
# Split at whitespace that follows . ! ? , or : (the punctuation stays with the left piece).
SENT_SPLIT_RE = re.compile(r"(?<=[.!?,:])\s+")
# Lazily populated singletons.
_TOKENIZER_CACHE = None  # NOTE(review): never read or assigned elsewhere in this file view
_VOCAB_CACHE: Optional[Dict[str, int]] = None  # filled by get_vocab()
_EN_G2P_PIPELINE = None  # filled by get_en_g2p_pipeline()
_KOKORO_ONNX_ENGINE = None  # filled by load_model()
# ============================================================
# Runtime state
# ============================================================
# Serializes model (re)loading in load_model().
model_lock = threading.Lock()
# Held by synthesize_wav_bytes for a whole /tts request.
synthesis_lock = threading.Lock()
# Concurrency limit used by /generate.
# NOTE(review): created at import time, outside any running event loop. On Python < 3.10
# asyncio primitives bind to the loop current at construction time; confirm the runtime version.
request_semaphore = asyncio.Semaphore(MAX_CONCURRENT_REQUESTS)
# client_id -> per-request state dict ({"active": True} from the middleware,
# {"interrupt": bool} from /generate).
current_requests: Dict[str, Dict] = {}
# Currently loaded ONNX session, its model name and the output sample rate.
model_session = None
model_name = DEFAULT_MODEL_NAME
sample_rate = DEFAULT_SAMPLE_RATE
# ============================================================
# Utilities
# ============================================================
def split_sentences(text: str) -> List[str]:
    """Split ``text`` with SENT_SPLIT_RE and return the stripped, non-empty pieces."""
    raw_pieces = SENT_SPLIT_RE.split(text.strip())
    return list(filter(None, map(str.strip, raw_pieces)))
def sentence_cache_key(sentence: str, voice: str, speed: float, model: str) -> str:
    """Return a hex MD5 digest of ``sentence|voice|speed|model``.

    MD5 is used only as a compact, deterministic cache key, not for security.
    """
    joined = "|".join(f"{field}" for field in (sentence, voice, speed, model))
    digest = hashlib.md5(joined.encode("utf-8"))
    return digest.hexdigest()
def sentence_cache_path(key: str) -> str:
    """Return the cached WAV path for ``key`` inside CACHE_DIR."""
    file_name = "{}.wav".format(key)
    return os.path.join(CACHE_DIR, file_name)
def meta_cache_path(key: str) -> str:
    """Return the cached JSON metadata path for ``key`` inside CACHE_DIR."""
    file_name = "{}.json".format(key)
    return os.path.join(CACHE_DIR, file_name)
def to_mono_numpy(audio) -> np.ndarray:
    """Coerce model output into a 1-D float32 mono waveform.

    Args:
        audio: ``None``, a NumPy array, or anything ``np.asarray`` accepts.

    Returns:
        A 1-D ``float32`` array. ``None``, unconvertible input and empty
        arrays all yield an empty array.

    2-D input: a singleton axis is squeezed away; otherwise the shorter axis
    is treated as the channel axis and averaged out, so both
    ``(channels, samples)`` and ``(samples, channels)`` layouts become mono.
    Arrays with more than two dimensions are flattened; scalars become a
    one-element array.
    """
    empty = np.array([], dtype=np.float32)
    if audio is None:
        return empty
    try:
        # np.asarray returns ndarrays unchanged, so no separate isinstance branch is needed.
        arr = np.asarray(audio)
    except Exception:
        # Deliberate best-effort: unconvertible objects are treated as "no audio".
        return empty
    if arr.size == 0:
        return empty
    if arr.ndim == 2:
        rows, cols = arr.shape
        if rows == 1:
            arr = arr[0]
        elif cols == 1:
            arr = arr[:, 0]
        elif rows < cols:
            # Channels-first layout: averaging over axis=1 here would return one
            # value per channel instead of a waveform.
            arr = arr.mean(axis=0)
        else:
            arr = arr.mean(axis=1)
    elif arr.ndim > 2:
        arr = arr.reshape(-1)
    elif arr.ndim == 0:
        arr = arr.reshape(1)
    # copy=False makes this a no-op when the data is already float32.
    return arr.astype(np.float32, copy=False)
def read_wav_bytes(wav_path: str) -> Tuple[bytes, int]:
    """Read an audio file and re-encode it as 16-bit PCM WAV bytes.

    Returns:
        ``(wav_bytes, sample_rate)``, with the channel count preserved.
    """
    samples, rate = sf.read(wav_path)
    n_channels = 1 if samples.ndim <= 1 else samples.shape[1]
    out = io.BytesIO()
    with sf.SoundFile(
        out,
        "w",
        samplerate=rate,
        channels=n_channels,
        format="WAV",
        subtype="PCM_16",
    ) as writer:
        writer.write(samples)
    return out.getvalue(), rate
# ============================================================
# Model loading
# ============================================================
def resolve_model_path(name: str) -> str:
    """Return the filesystem path for model ``name``.

    ``name`` is looked up under MODEL_DIR first; an absolute ``name`` that
    exists is accepted as-is.

    NOTE(review): ``name`` can come straight from HTTP request parameters
    (``model_name``), and paths outside MODEL_DIR are accepted — consider
    restricting this if the service is exposed to untrusted clients.

    Raises:
        FileNotFoundError: if the model cannot be found.
    """
    local_path = Path(MODEL_DIR) / name
    if local_path.exists():
        return str(local_path)
    absolute_hit = os.path.isabs(name) and Path(name).exists()
    if absolute_hit:
        return name
    raise FileNotFoundError(f"找不到模型文件: {local_path}")
def load_vocab() -> Dict[str, int]:
    """Load the symbol -> token id vocabulary.

    Prefers the ``vocab`` mapping in Kokoro's official config.json
    (CONFIG_PATH). If that file is missing, unreadable, or its vocab is empty
    or not a dict, falls back to ``model.vocab`` in tokenizer.json
    (TOKENIZER_PATH).

    Raises:
        RuntimeError: if the tokenizer fallback is missing or cannot be parsed.
    """
    config_file = Path(CONFIG_PATH)
    if config_file.exists():
        try:
            config_vocab = json.loads(config_file.read_text(encoding="utf-8"))["vocab"]
            if isinstance(config_vocab, dict) and config_vocab:
                return {str(symbol): int(idx) for symbol, idx in config_vocab.items()}
        except Exception as e:
            logger.warning("加载 config vocab 失败,回退 tokenizer vocab: %s", e)

    tokenizer_file = Path(TOKENIZER_PATH)
    if not tokenizer_file.exists():
        raise RuntimeError(
            f"缺少 vocab 文件: {config_file} / {tokenizer_file}. "
            f"请从 {HF_MODEL_ID} 或 hexgrad/Kokoro-82M 下载后放到模型目录。"
        )
    try:
        tokenizer_vocab = json.loads(tokenizer_file.read_text(encoding="utf-8"))["model"]["vocab"]
        if not isinstance(tokenizer_vocab, dict) or not tokenizer_vocab:
            raise ValueError("tokenizer vocab 为空或格式异常")
        return {str(symbol): int(idx) for symbol, idx in tokenizer_vocab.items()}
    except Exception as e:
        raise RuntimeError(f"无法加载 tokenizer vocab: {tokenizer_file}. 错误: {e}") from e
def get_vocab() -> Dict[str, int]:
    """Return the vocabulary, loading it once and caching it in ``_VOCAB_CACHE``."""
    global _VOCAB_CACHE
    cached = _VOCAB_CACHE
    if cached is None:
        cached = _VOCAB_CACHE = load_vocab()
    return cached
def get_en_g2p_pipeline():
    """Return the shared English (``lang_code="a"``) KPipeline, creating it on first use.

    It is built with ``model=False`` because callers here only use its
    ``g2p`` / ``en_tokenize`` helpers. ``kokoro`` is imported lazily so the
    module loads without it.
    """
    global _EN_G2P_PIPELINE
    if _EN_G2P_PIPELINE is not None:
        return _EN_G2P_PIPELINE
    from kokoro import KPipeline  # type: ignore
    _EN_G2P_PIPELINE = KPipeline(lang_code="a", repo_id="hexgrad/Kokoro-82M", model=False)
    return _EN_G2P_PIPELINE
def resolve_voice_path(voice: str) -> Path:
    """Map a voice id to its ``.bin`` style file inside VOICES_DIR.

    Blank or ``None`` voices default to ``af_heart``. Aliases from
    VOICE_ALIASES are applied, and ``.bin`` is appended when missing. A voice
    whose file does not exist — or whose name carries directory components
    (e.g. ``../x``), which could otherwise escape VOICES_DIR — falls back to
    ``af_bella.bin`` when that file exists.

    Raises:
        FileNotFoundError: if neither the voice file nor the fallback exists.
    """
    # `voice or ""` matches _is_english_voice/_tokenize_for_onnx and tolerates None.
    voice_name = (voice or "").strip() or "af_heart"
    voice_name = VOICE_ALIASES.get(voice_name, voice_name)
    if not voice_name.endswith(".bin"):
        voice_name = f"{voice_name}.bin"
    voice_path = Path(VOICES_DIR) / voice_name
    # Only bare file names are served; anything with a path component is treated as unknown.
    is_plain_name = Path(voice_name).name == voice_name
    if is_plain_name and voice_path.exists():
        return voice_path
    fallback_path = Path(VOICES_DIR) / "af_bella.bin"
    if fallback_path.exists():
        logger.warning("voice %s 不存在,回退到 %s", voice_name, fallback_path.name)
        return fallback_path
    raise FileNotFoundError(
        f"找不到 voice 文件: {voice_path}. "
        "请从模型仓库下载 voices/*.bin 到本地 voices 目录。"
    )
def _download_voice_file(voice_path: Path) -> Path:
    """Download ``voices/<voice_path.name>`` from the HF_MODEL_ID repo into ``voice_path``.

    The payload is written to a sibling ``.tmp`` file and then moved into
    place, so a failed write never leaves a truncated voice file behind.

    Returns:
        ``voice_path``.

    Raises:
        RuntimeError: if the server returned an HTML page or an empty body.
        OSError / urllib errors: on network or filesystem failures.
    """
    voice_name = voice_path.name
    raw_url = f"https://huggingface.co/{HF_MODEL_ID}/resolve/main/voices/{voice_name}"
    voice_path.parent.mkdir(parents=True, exist_ok=True)
    tmp_path = voice_path.with_suffix(voice_path.suffix + ".tmp")
    logger.info("从官方地址下载 voice 文件: %s", raw_url)
    with urllib.request.urlopen(raw_url, timeout=60) as response:
        content_type = response.headers.get_content_type()
        payload = response.read()
    # Share the HTML sniffing logic with _style_for_voice via _is_html_payload.
    if content_type == "text/html" or _is_html_payload(payload):
        raise RuntimeError(
            f"下载到的仍是 HTML 页面而不是 voice 二进制文件: {raw_url}. "
            "请确认模型仓库的 voices 文件可直接通过 raw 链接访问。"
        )
    if not payload:
        raise RuntimeError(f"下载到的 voice 文件为空: {raw_url}")
    try:
        with open(tmp_path, "wb") as f:
            f.write(payload)
        tmp_path.replace(voice_path)
    except Exception:
        # Do not leave a half-written temp file around.
        tmp_path.unlink(missing_ok=True)
        raise
    return voice_path
- def _is_html_payload(data: bytes) -> bool:
- return data.startswith(b"<!doctype html") or data.startswith(b"<html")
def load_onnx_session(name: str):
    """Create a CPU-only onnxruntime ``InferenceSession`` for model ``name``.

    Raises:
        RuntimeError: if onnxruntime cannot be imported.
        FileNotFoundError: if the model file cannot be resolved.
    """
    try:
        import onnxruntime as ort  # type: ignore
    except Exception as e:
        raise RuntimeError("无法导入 onnxruntime。请在部署环境中安装 onnxruntime。") from e
    path = resolve_model_path(name)
    session = ort.InferenceSession(path, providers=["CPUExecutionProvider"])
    logger.info("ONNX 模型加载完成: %s", path)
    return session
def load_kokoro_engine(name: str):
    """Build a ``kokoro_onnx.Kokoro`` engine for model ``name`` using VOICES_V1_PATH.

    CONFIG_PATH is passed as ``vocab_config`` only when that file exists.

    Raises:
        RuntimeError: if kokoro_onnx cannot be imported.
        FileNotFoundError: if the model file cannot be resolved.
    """
    try:
        from kokoro_onnx import Kokoro  # type: ignore
    except Exception as e:
        raise RuntimeError("无法导入 kokoro_onnx。请在部署环境中安装 kokoro-onnx。") from e
    path = resolve_model_path(name)
    vocab_config = CONFIG_PATH if Path(CONFIG_PATH).exists() else None
    engine = Kokoro(model_path=path, voices_path=VOICES_V1_PATH, vocab_config=vocab_config)
    logger.info("kokoro-onnx 引擎加载完成: %s", path)
    return engine
def load_model(force_reload: bool = False, name: Optional[str] = None):
    """Ensure the ONNX session and kokoro-onnx engine for ``name`` are loaded.

    Args:
        force_reload: reload even if the requested model is already active.
        name: model file name; ``None`` means the currently active model.

    Returns:
        The active ONNX ``InferenceSession``.

    Both objects are built first and the globals are swapped together under
    ``model_lock``. If loading fails, the previous session, engine and model
    name stay consistent with each other.
    """
    global model_session, model_name, _KOKORO_ONNX_ENGINE
    with model_lock:
        target_name = name or model_name
        needs_reload = (
            force_reload
            or model_session is None
            or _KOKORO_ONNX_ENGINE is None
            or target_name != model_name
        )
        if needs_reload:
            new_session = load_onnx_session(target_name)
            new_engine = load_kokoro_engine(target_name)
            model_session, _KOKORO_ONNX_ENGINE, model_name = new_session, new_engine, target_name
        return model_session
def get_kokoro_engine(name: Optional[str] = None):
    """Return the kokoro-onnx engine for ``name`` (default: active model), loading it if needed."""
    wanted = name or model_name
    needs_load = _KOKORO_ONNX_ENGINE is None or wanted != model_name
    if needs_load:
        load_model(name=wanted)
    return _KOKORO_ONNX_ENGINE
@asynccontextmanager
async def lifespan(app: FastAPI):
    """FastAPI lifespan hook: eagerly load the model, then log on shutdown.

    The shutdown message is logged even if the startup load raises.
    """
    try:
        # Load up front so the first request does not pay the model-loading cost.
        load_model()
        yield
    finally:
        logger.info("应用关闭中...")
# FastAPI application; lifespan() loads the model at startup.
app = FastAPI(title="Online TTS Service (Kokoro ONNX)", lifespan=lifespan)
# Fully permissive CORS: any origin, method and header.
# NOTE(review): combining allow_origins=["*"] with allow_credentials=True makes Starlette echo
# the caller's origin for credentialed requests (confirm against the CORSMiddleware docs);
# narrow allow_origins if this service is reachable from untrusted sites.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
@app.middleware("http")
async def track_clients(request: Request, call_next):
    """Register an ``{"active": True}`` marker in ``current_requests`` for /generate calls.

    The client id is taken from the ``client_id`` query parameter, the
    ``X-Client-ID`` header, or a fresh UUID.

    The marker is inserted only when no entry exists for that id. An in-flight
    stream's state, which the /generate handler owns and uses for
    interruption, is therefore never overwritten.

    Cleanup runs in ``finally`` and removes the entry only if it is still the
    marker inserted here. A failing handler does not leak it, and a newer
    request's state is never removed.
    """
    if request.url.path != "/generate":
        return await call_next(request)
    client_id = (
        request.query_params.get("client_id")
        or request.headers.get("X-Client-ID")
        or str(uuid.uuid4())
    )
    marker = {"active": True}
    inserted = current_requests.setdefault(client_id, marker) is marker
    try:
        return await call_next(request)
    finally:
        if inserted and current_requests.get(client_id) is marker:
            current_requests.pop(client_id, None)
class TTSRequest(BaseModel):
    """JSON body accepted by ``POST /tts``."""

    # Allow the ``model_name`` field without pydantic v2's "model_" protected-namespace warning.
    model_config = {"protected_namespaces": ()}

    text: str  # text to synthesize
    voice: Optional[str] = "af_heart"  # voice id; tts_post maps None/"" to "af_heart"
    speed: Optional[float] = 1.0  # playback speed; normalized by validate_speed
    split_pattern: Optional[str] = r"\n+"  # NOTE(review): not read anywhere in this file view
    model_name: Optional[str] = None  # None keeps the currently loaded model

    @field_validator("speed")
    @classmethod
    def validate_speed(cls, v):
        """Normalize ``speed``: ``None`` -> 1.0; reject non-finite and non-positive values."""
        if v is None:
            return 1.0
        v = float(v)
        # NaN compares False against everything, so `v <= 0` alone would let it through.
        if not math.isfinite(v) or v <= 0:
            raise ValueError("speed 必须大于 0")
        return v
def _encode_token_ids(text: str) -> List[int]:
    """Encode ``text`` character by character into vocab ids.

    Characters missing from the vocab are encoded as the id of ``" "`` (or 16
    when the vocab has no space entry).
    """
    vocab = get_vocab()
    fallback_id = vocab.get(" ", 16)
    lookup = vocab.get
    return [lookup(char, fallback_id) for char in text]
def _phonemize_en_chunks_for_onnx(text: str) -> List[Tuple[str, str]]:
    """Run English G2P on ``text`` and return ``(graphemes, phonemes)`` for each chunk.

    Raises:
        RuntimeError: if the pipeline produces no chunks.
    """
    pipeline = get_en_g2p_pipeline()
    _unused, g2p_tokens = pipeline.g2p(text)
    pairs = [(gs, ps) for gs, ps, _tks in pipeline.en_tokenize(g2p_tokens)]
    if not pairs:
        raise RuntimeError("英文文本音素化失败,未生成 phonemes。")
    return pairs
def _phonemize_en_for_onnx(text: str) -> str:
    """Phonemize English ``text`` into a single phoneme string.

    The phonemes of every G2P chunk are joined with one space, so long inputs
    are not silently truncated to their first chunk. Over-long results are left
    for the caller's length check to reject. Use
    ``_phonemize_en_chunks_for_onnx`` when per-chunk output is needed.

    Raises:
        RuntimeError: if G2P produces no chunks.
    """
    chunks = _phonemize_en_chunks_for_onnx(text)
    return " ".join(ps.strip() for _, ps in chunks if ps and ps.strip())
def _tokenize_for_onnx(text: str, voice: str):
    """Convert ``text`` into a ``(1, n + 2)`` int64 ``input_ids`` array for the ONNX model.

    English voices (see ``_is_english_voice``) are phonemized first through
    kokoro's G2P chain; other voices encode the raw characters. The id
    sequence is padded with a 0 on both ends. Replace this function if the
    project gains a more reliable tokenizer.

    Raises:
        RuntimeError: if encoding yields no tokens or more than 510 tokens.
    """
    source = _phonemize_en_for_onnx(text) if _is_english_voice(voice) else text
    tokens = _encode_token_ids(source)
    if not tokens:
        raise RuntimeError("文本编码后得到空 tokens。请检查 tokenizer.json 是否正确。")
    if len(tokens) > 510:
        # 510 ids + the two 0 pads = the model's 512 limit.
        raise RuntimeError(
            f"文本过长,tokens={len(tokens)},超过模型 512 上限。请拆句后再调用。"
        )
    return np.asarray([[0, *tokens, 0]], dtype=np.int64)
def _is_english_voice(voice: str) -> bool:
    """Return True when ``voice`` (after alias lookup) starts with af_, am_, bf_ or bm_."""
    key = (voice or "").strip()
    resolved = VOICE_ALIASES.get(key, key)
    return resolved.startswith(("af_", "am_", "bf_", "bm_"))
def _style_for_voice(voice: str) -> np.ndarray:
    """
    Load the style-vector table for ``voice`` from its per-voice ``.bin`` file.

    The ONNX model needs a style vector. This is a working placeholder that can
    later be replaced by a proper voice-embedding mapping.

    Validation and repair:
    - A file that starts like an HTML page (a failed download) is replaced by
      af_bella.bin when available; otherwise the voice is re-downloaded.
    - Any exception during that check gets one more attempt: switch to the
      fallback if possible, otherwise download again.
    - A file whose float32 count is not a multiple of 256 is re-downloaded once.

    Returns:
        float32 array reshaped to ``(-1, 256)``.

    Raises:
        FileNotFoundError: from resolve_voice_path when no voice file exists.
        RuntimeError: if the file is empty or still not a multiple of 256 floats.
        Errors from _download_voice_file propagate from the retry paths.

    NOTE(review): nothing in this file view calls this function; synthesize_audio
    gets style vectors from the kokoro-onnx engine instead.
    """
    voice_path = resolve_voice_path(voice)
    # resolve_voice_path always returns a *.bin path, so this check is expected to pass.
    if voice_path.suffix == ".bin":
        try:
            # Reads the whole file but only inspects the first 64 bytes.
            header = voice_path.read_bytes()[:64]
            if _is_html_payload(header):
                fallback_path = Path(VOICES_DIR) / "af_bella.bin"
                if voice_path.name != fallback_path.name and fallback_path.exists():
                    logger.warning("voice 文件 %s 是 HTML,占位回退到 %s", voice_path.name, fallback_path.name)
                    voice_path = fallback_path
                else:
                    voice_path = _download_voice_file(voice_path)
        except Exception as e:
            # Also reached when the download above fails, which makes this a single retry.
            logger.warning("voice 文件检查失败,尝试直接重新下载: %s", e)
            fallback_path = Path(VOICES_DIR) / "af_bella.bin"
            if voice_path.name != fallback_path.name and fallback_path.exists():
                voice_path = fallback_path
            else:
                voice_path = _download_voice_file(voice_path)
    style = np.fromfile(str(voice_path), dtype=np.float32)
    if style.size == 0:
        raise RuntimeError(f"voice 文件为空: {voice_path}")
    if style.size % 256 != 0:
        # Some repository files may be xet pointers or HTML pages; re-download once as a fallback.
        voice_path = _download_voice_file(voice_path)
        style = np.fromfile(str(voice_path), dtype=np.float32)
        if style.size == 0 or style.size % 256 != 0:
            raise RuntimeError(f"voice 文件维度异常: {voice_path}, size={style.size}")
    style = style.reshape(-1, 256)
    return style
- def _select_style_slice(style: np.ndarray, token_len: int) -> np.ndarray:
- if style.ndim != 2 or style.shape[-1] != 256:
- raise RuntimeError(f"style 维度异常: {style.shape}")
- # Follow the official ONNX example: ref_s = voices[len(tokens)]
- idx = min(max(token_len, 0), style.shape[0] - 1)
- return style[idx : idx + 1]
- def _prepare_style_input(session, style_slice: np.ndarray) -> Tuple[Optional[str], Optional[np.ndarray]]:
- for input_name in ("style", "ref_s"):
- for model_input in session.get_inputs():
- if model_input.name != input_name:
- continue
- input_shape = model_input.shape
- expected_rank = len(input_shape) if input_shape is not None else style_slice.ndim
- if expected_rank == 2:
- return input_name, style_slice.astype(np.float32, copy=False)
- if expected_rank == 3:
- return input_name, style_slice[:, np.newaxis, :].astype(np.float32, copy=False)
- raise RuntimeError(
- f"模型输入 {input_name} 的 rank 不受支持: shape={input_shape}"
- )
- return None, None
def synthesize_audio(text: str, voice: str, speed: float, model_name: Optional[str] = None) -> np.ndarray:
    """Synthesize ``text`` with the kokoro-onnx engine and return a mono float32 waveform.

    The engine's tokenizer phonemizes the text and splits it into phoneme
    batches. Each batch is run through the ONNX session with the voice's style
    row for that batch's token count, and the resulting segments are
    concatenated.

    Args:
        text: text to speak; must not be blank.
        voice: a voice id known to the engine (``engine.get_voices()``).
        speed: speed factor fed to the model's ``speed`` input.
        model_name: model file to use; ``None`` keeps the active model.

    Raises:
        HTTPException: 400 for blank text, unknown voice or failed phonemization;
            500 when inference yields no or invalid audio.
    """
    if not text.strip():
        raise HTTPException(status_code=400, detail="文本不能为空")
    engine = get_kokoro_engine(name=model_name)
    session = load_model(name=model_name)
    if voice not in set(engine.get_voices()):
        raise HTTPException(status_code=400, detail=f"不支持的 voice: {voice}")
    # bf_/bm_ voices are British English; phonemize them with the matching espeak dialect.
    lang = "en-gb" if voice.startswith(("bf_", "bm_")) else "en-us"
    phonemes = engine.tokenizer.phonemize(text, lang)
    batched_phonemes = engine._split_phonemes(phonemes)
    if not batched_phonemes:
        raise HTTPException(status_code=400, detail="文本音素化失败")
    voice_style = engine.get_voice_style(voice)
    # Style rows are indexed by token count (as in the official example); clamp the index
    # like _select_style_slice does so a maximal batch cannot index past the table.
    last_style_row = len(voice_style) - 1
    speed_input = np.asarray([speed], dtype=np.float32)
    audio_segments: List[np.ndarray] = []
    for phoneme_batch in batched_phonemes:
        tokens = np.array(engine.tokenizer.tokenize(phoneme_batch), dtype=np.int64)
        if tokens.size == 0:
            continue
        style = voice_style[min(len(tokens), last_style_row)]
        feeds = {
            "input_ids": np.asarray([[0, *tokens.tolist(), 0]], dtype=np.int64),
            "style": np.asarray(style, dtype=np.float32),
            "speed": speed_input,
        }
        outputs = session.run(None, feeds)
        if outputs:
            audio_segments.append(to_mono_numpy(outputs[0]))
    if not audio_segments:
        raise HTTPException(status_code=500, detail="ONNX 推理未返回音频输出")
    audio = np.concatenate(audio_segments, axis=0) if len(audio_segments) > 1 else audio_segments[0]
    if audio.size == 0 or not np.isfinite(audio).all():
        raise HTTPException(status_code=500, detail="生成的音频无效")
    return audio
def synthesize_wav_bytes(text: str, voice: str, speed: float, model_name: Optional[str] = None) -> io.BytesIO:
    """Synthesize ``text`` into a complete in-memory 16-bit PCM WAV file.

    English voices are chunked with the kokoro G2P pipeline; other voices are
    split with ``split_sentences``. Blank chunks are skipped so a stray empty
    piece cannot fail the whole request. If nothing usable remains, the
    stripped text is synthesized as one piece. All pieces are synthesized
    while holding ``synthesis_lock``.

    Returns:
        A ``BytesIO`` positioned at offset 0.

    Raises:
        HTTPException: 400 for blank input or when no audio was produced;
            errors from ``synthesize_audio`` propagate.
    """
    stripped = text.strip()
    parts: List[str] = []
    if stripped:
        if _is_english_voice(voice):
            parts = [graphemes for graphemes, _ in _phonemize_en_chunks_for_onnx(text)]
        else:
            parts = split_sentences(text)
    parts = [part for part in parts if part and part.strip()]
    if not parts:
        # Blank text reaches synthesize_audio, which reports it as a 400.
        parts = [stripped]
    segments: List[np.ndarray] = []
    with synthesis_lock:
        for part in parts:
            audio = synthesize_audio(part, voice=voice, speed=speed, model_name=model_name)
            if audio.size > 0:
                segments.append(audio)
    if not segments:
        raise HTTPException(status_code=400, detail="未生成音频,请检查输入文本或参数。")
    audio_concat = np.concatenate(segments, axis=0)
    buf = io.BytesIO()
    sf.write(buf, audio_concat, samplerate=sample_rate, format="WAV", subtype="PCM_16")
    buf.seek(0)
    return buf
@app.post("/tts", summary="POST: 传入文本返回 WAV 流")
def tts_post(req: TTSRequest):
    """Synthesize ``req.text`` and stream it back as an inline WAV file.

    A missing or empty voice defaults to ``af_heart``; a missing speed to 1.0.
    """
    voice = req.voice or "af_heart"
    speed = 1.0 if req.speed is None else req.speed
    wav = synthesize_wav_bytes(
        text=req.text,
        voice=voice,
        speed=speed,
        model_name=req.model_name,
    )
    headers = {"Content-Disposition": 'inline; filename="tts.wav"'}
    return StreamingResponse(wav, media_type="audio/wav", headers=headers)
@app.get("/tts", summary="GET: 传入文本返回 WAV 流")
def tts_get(
    text: str = Query(..., description="待合成文本"),
    voice: str = Query("af_heart"),
    speed: float = Query(1.0),
    model_name: str = Query(DEFAULT_MODEL_NAME),
):
    """Synthesize ``text`` from query parameters and stream it back as an inline WAV file.

    Raises:
        HTTPException: 400 when ``speed`` is not a finite number greater than 0.
    """
    # Query floats accept "nan"/"inf"; NaN also slips past a plain `<= 0` comparison.
    if speed is None or not math.isfinite(float(speed)) or float(speed) <= 0:
        raise HTTPException(status_code=400, detail="speed 必须为大于 0 的数值")
    buf = synthesize_wav_bytes(
        text=text,
        voice=voice,
        speed=float(speed),
        model_name=model_name,
    )
    return StreamingResponse(
        buf,
        media_type="audio/wav",
        headers={"Content-Disposition": 'inline; filename="tts.wav"'},
    )
def _encode_wav_base64(audio: np.ndarray) -> str:
    """Encode ``audio`` as a 16-bit PCM WAV at ``sample_rate`` and return it base64-encoded."""
    with io.BytesIO() as buf:
        with sf.SoundFile(
            buf,
            "w",
            sample_rate,
            channels=audio.shape[1] if audio.ndim > 1 else 1,
            format="WAV",
            subtype="PCM_16",
        ) as f:
            f.write(audio)
        return base64.b64encode(buf.getvalue()).decode()


@app.post("/generate")
async def generate_audio_stream(data: Dict = Body(...)):
    """Stream synthesized audio for ``data["text"]`` as NDJSON, one line per segment.

    Body keys:
        text: required, non-blank string.
        voice: voice id (default ``"af_heart"``).
        speed: finite number > 0 (default 1.0).
        model_name: model file (default ``DEFAULT_MODEL_NAME``).
        client_id: optional id; a new request with the same id interrupts the
            stream currently registered for it.

    Each output line is ``{"index", "sentence", "audio" (base64 WAV), "sample_rate"}``.

    Raises:
        HTTPException: 400 for blank or non-string text, non-string voice, or
            invalid speed.
    """
    text = data.get("text", "")
    if not isinstance(text, str) or not text.strip():
        raise HTTPException(status_code=400, detail="文本不能为空")
    voice = data.get("voice") or "af_heart"
    if not isinstance(voice, str):
        raise HTTPException(status_code=400, detail=f"不支持的 voice: {voice}")
    try:
        speed = float(data.get("speed", 1.0))
    except (TypeError, ValueError):
        raise HTTPException(status_code=400, detail="speed 必须为大于 0 的数值") from None
    if not math.isfinite(speed) or speed <= 0:
        raise HTTPException(status_code=400, detail="speed 必须为大于 0 的数值")
    model = data.get("model_name", DEFAULT_MODEL_NAME)
    # Header/query ids seen by the middleware are strings; normalize body ids the same way.
    client_id = str(data.get("client_id") or uuid.uuid4())

    def _split_parts() -> List[str]:
        # English voices are chunked by the G2P pipeline, others by punctuation.
        if _is_english_voice(voice):
            chunks = [graphemes for graphemes, _ in _phonemize_en_chunks_for_onnx(text)]
        else:
            chunks = split_sentences(text)
        return chunks or [text.strip()]

    # G2P chunking is blocking CPU work: run it off the event loop, under the concurrency limit.
    async with request_semaphore:
        parts = await asyncio.to_thread(_split_parts)

    # Interrupt the stream currently registered for this client, if any.
    previous = current_requests.get(client_id)
    if previous is not None:
        previous["interrupt"] = True
        await asyncio.sleep(0.05)
    # Each stream owns its state dict. A later request flips *this* object's flag,
    # which cannot be reset by re-keying, and cleanup never removes a newer request's entry.
    state: Dict[str, bool] = {"interrupt": False}
    current_requests[client_id] = state

    async def stream():
        try:
            for idx, part in enumerate(parts):
                if not part:
                    continue
                if state.get("interrupt"):
                    break
                # Hold a semaphore slot only while synthesizing; this is the work the
                # limit protects. The slot is released before yielding, so slow
                # consumers do not hold it.
                async with request_semaphore:
                    audio = await asyncio.to_thread(
                        synthesize_audio,
                        text=part,
                        voice=voice,
                        speed=speed,
                        model_name=model,
                    )
                # Do not emit stale audio if we were interrupted during synthesis.
                if state.get("interrupt"):
                    break
                yield json.dumps(
                    {
                        "index": idx,
                        "sentence": part,
                        "audio": _encode_wav_base64(audio),
                        "sample_rate": sample_rate,
                    }
                ).encode() + b"\n"
        finally:
            if current_requests.get(client_id) is state:
                current_requests.pop(client_id, None)

    return StreamingResponse(stream(), media_type="application/x-ndjson")
if __name__ == "__main__":
    import uvicorn

    # Pass the app object itself. With reload disabled and a single worker, uvicorn
    # does not need an import string. A hard-coded "module:app" string fails when the
    # file has a different name, and otherwise re-imports and re-runs all module setup.
    uvicorn.run(
        app,
        host=os.getenv("TTS_HOST", "0.0.0.0"),
        port=int(os.getenv("TTS_PORT", 18000)),
        reload=False,
        workers=1,
    )
|