|
|
@@ -0,0 +1,644 @@
|
|
|
+from __future__ import annotations
|
|
|
+
|
|
|
+import base64
|
|
|
+import hashlib
|
|
|
+import io
|
|
|
+import json
|
|
|
+import logging
|
|
|
+import os
|
|
|
+import re
|
|
|
+import asyncio
|
|
|
+import threading
|
|
|
+import uuid
|
|
|
+import urllib.request
|
|
|
+from contextlib import asynccontextmanager
|
|
|
+from pathlib import Path
|
|
|
+from typing import Dict, List, Optional, Tuple
|
|
|
+
|
|
|
+import numpy as np
|
|
|
+import soundfile as sf
|
|
|
+from fastapi import Body, FastAPI, HTTPException, Query, Request
|
|
|
+from fastapi.middleware.cors import CORSMiddleware
|
|
|
+from fastapi.responses import StreamingResponse
|
|
|
+from pydantic import BaseModel, field_validator
|
|
|
+
|
|
|
+
|
|
|
# ============================================================
# Configuration
# ============================================================

# All settings below are read once, at import time, from environment variables.

# File name of the ONNX model; resolve_model_path() looks it up under MODEL_DIR.
DEFAULT_MODEL_NAME = os.getenv("TTS_ONNX_MODEL_NAME", "model.onnx")
# Directory holding the model and, by default, tokenizer/config/voice files.
MODEL_DIR = os.getenv("TTS_ONNX_MODEL_DIR", "/home/tts-server/onnx")
# Hugging Face repository id used in voice download URLs and error hints.
HF_MODEL_ID = os.getenv("TTS_ONNX_HF_MODEL_ID", "onnx-community/Kokoro-82M-ONNX")
# Vocabulary sources for load_vocab(): config.json is tried first, tokenizer.json second.
TOKENIZER_PATH = os.getenv("TTS_ONNX_TOKENIZER_PATH", str(Path(MODEL_DIR) / "tokenizer.json"))
CONFIG_PATH = os.getenv("TTS_ONNX_CONFIG_PATH", str(Path(MODEL_DIR) / "config.json"))
# Directory of per-voice "<name>.bin" files read by resolve_voice_path().
VOICES_DIR = os.getenv("TTS_ONNX_VOICES_DIR", str(Path(MODEL_DIR) / "voices"))
# Combined voices file passed to the kokoro-onnx engine in load_kokoro_engine().
VOICES_V1_PATH = os.getenv("TTS_ONNX_VOICES_V1_PATH", str(Path(MODEL_DIR) / "voices-v1.0.bin"))
# Directory for cached audio; created below at import time.
CACHE_DIR = os.getenv("CACHE_DIR", "./audio_cache")
LOG_LEVEL = os.getenv("LOG_LEVEL", "INFO").upper()
# Size of the asyncio semaphore used by the /generate endpoint.
MAX_CONCURRENT_REQUESTS = int(os.getenv("MAX_CONCURRENT_REQUESTS", 8))
# NOTE(review): the two cache sizes are not referenced anywhere in the visible code.
MEMORY_CACHE_SIZE = int(os.getenv("MEMORY_CACHE_SIZE", 200))
DISK_CACHE_SIZE = int(os.getenv("DISK_CACHE_SIZE", 500))
# Initial value of the module-level `sample_rate` used when encoding WAV output.
DEFAULT_SAMPLE_RATE = int(os.getenv("TTS_SAMPLE_RATE", 24000))
# Optional mapping from alternative voice names to canonical ones (empty by default).
VOICE_ALIASES = {}

# Unknown LOG_LEVEL values fall back to INFO.
logging.basicConfig(level=getattr(logging, LOG_LEVEL, logging.INFO))
logger = logging.getLogger("speech_tts_onnx")

Path(CACHE_DIR).mkdir(parents=True, exist_ok=True)
# Splits on whitespace that directly follows one of . ! ? , :
SENT_SPLIT_RE = re.compile(r"(?<=[.!?,:])\s+")
# Lazily populated singletons (see get_vocab, get_en_g2p_pipeline, load_model).
# NOTE(review): _TOKENIZER_CACHE is not assigned anywhere in the visible code.
_TOKENIZER_CACHE = None
_VOCAB_CACHE = None
_EN_G2P_PIPELINE = None
_KOKORO_ONNX_ENGINE = None
|
|
|
+
|
|
|
+
|
|
|
# ============================================================
# Runtime state
# ============================================================

# Serialises model/engine (re)loading in load_model().
model_lock = threading.Lock()
# Held by synthesize_wav_bytes() while it synthesises all parts of one request.
synthesis_lock = threading.Lock()
# Acquired by the /generate handler; sized by MAX_CONCURRENT_REQUESTS.
request_semaphore = asyncio.Semaphore(MAX_CONCURRENT_REQUESTS)
# client_id -> small state dict; written by track_clients() and the /generate handler.
current_requests: Dict[str, Dict] = {}

# Currently loaded ONNX session (None until load_model() succeeds) and its model name.
model_session = None
model_name = DEFAULT_MODEL_NAME
# Sample rate written into every generated WAV payload.
sample_rate = DEFAULT_SAMPLE_RATE
|
|
|
+
|
|
|
+
|
|
|
+# ============================================================
|
|
|
+# 工具
|
|
|
+# ============================================================
|
|
|
+
|
|
|
def split_sentences(text: str) -> List[str]:
    """Split ``text`` with ``SENT_SPLIT_RE`` and return the non-empty, stripped pieces."""
    pieces: List[str] = []
    for fragment in SENT_SPLIT_RE.split(text.strip()):
        cleaned = fragment.strip()
        if cleaned:
            pieces.append(cleaned)
    return pieces
|
|
|
+
|
|
|
+
|
|
|
def sentence_cache_key(sentence: str, voice: str, speed: float, model: str) -> str:
    """Return the hex MD5 digest of ``sentence|voice|speed|model`` (UTF-8 encoded).

    The digest is only used as a cache identifier, not for security.
    """
    raw = f"{sentence}|{voice}|{speed}|{model}"
    digest = hashlib.md5()
    digest.update(raw.encode("utf-8"))
    return digest.hexdigest()
|
|
|
+
|
|
|
+
|
|
|
def sentence_cache_path(key: str) -> str:
    """Return the path of the cached WAV file for ``key`` inside ``CACHE_DIR``."""
    filename = f"{key}.wav"
    return os.path.join(CACHE_DIR, filename)
|
|
|
+
|
|
|
+
|
|
|
def meta_cache_path(key: str) -> str:
    """Return the path of the JSON metadata file for ``key`` inside ``CACHE_DIR``."""
    filename = f"{key}.json"
    return os.path.join(CACHE_DIR, filename)
|
|
|
+
|
|
|
+
|
|
|
def to_mono_numpy(audio) -> np.ndarray:
    """Coerce ``audio`` into a one-dimensional float32 NumPy array.

    ``None``, inputs that cannot be converted to an array, and empty inputs
    all yield an empty float32 array. Two-dimensional input is reduced to one
    dimension: a singleton axis is dropped, otherwise values are averaged
    over axis 1. Input with more than two dimensions is flattened, and a
    scalar becomes a one-element array.
    """
    if audio is None:
        return np.array([], dtype=np.float32)

    try:
        samples = np.asarray(audio)
    except Exception:
        return np.array([], dtype=np.float32)

    if samples.size == 0:
        return np.array([], dtype=np.float32)

    dims = samples.ndim
    if dims == 0:
        samples = samples.reshape(1)
    elif dims == 2:
        rows, cols = samples.shape
        if rows == 1:
            samples = samples[0]
        elif cols == 1:
            samples = samples[:, 0]
        else:
            samples = samples.mean(axis=1)
    elif dims > 2:
        samples = samples.reshape(-1)

    # Only convert when needed so float32 input is returned as-is.
    if samples.dtype != np.float32:
        samples = samples.astype(np.float32, copy=False)
    return samples
|
|
|
+
|
|
|
+
|
|
|
def read_wav_bytes(wav_path: str) -> Tuple[bytes, int]:
    """Read an audio file and re-encode it in memory as 16-bit PCM WAV.

    Returns:
        A ``(wav_bytes, sample_rate)`` tuple, where ``sample_rate`` is the
        rate reported by ``sf.read`` for ``wav_path``.
    """
    samples, rate = sf.read(wav_path)

    # Multi-channel data comes back as 2-D; the second axis is the channel count.
    channel_count = 1
    if samples.ndim > 1:
        channel_count = samples.shape[1]

    sink = io.BytesIO()
    with sf.SoundFile(
        sink,
        "w",
        samplerate=rate,
        channels=channel_count,
        format="WAV",
        subtype="PCM_16",
    ) as writer:
        writer.write(samples)
    return sink.getvalue(), rate
|
|
|
+
|
|
|
+
|
|
|
+# ============================================================
|
|
|
+# 模型加载
|
|
|
+# ============================================================
|
|
|
+
|
|
|
def resolve_model_path(name: str) -> str:
    """Return the filesystem path of the model called ``name``.

    ``name`` is first looked up under ``MODEL_DIR``. If that path does not
    exist, an absolute ``name`` that exists is returned unchanged.

    Raises:
        FileNotFoundError: if neither location exists.
    """
    # NOTE(review): the HTTP endpoints in this file forward a request-supplied
    # model name down to this function; consider restricting it to MODEL_DIR.
    candidate = Path(MODEL_DIR) / name
    if candidate.exists():
        return str(candidate)

    absolute_hit = os.path.isabs(name) and Path(name).exists()
    if not absolute_hit:
        raise FileNotFoundError(
            f"找不到模型文件: {candidate}"
        )
    return name
|
|
|
+
|
|
|
+
|
|
|
def load_vocab() -> Dict[str, int]:
    """Load the symbol-to-id vocabulary from disk.

    The ``vocab`` mapping in ``CONFIG_PATH`` is preferred. If that file is
    missing, unreadable, or its ``vocab`` is empty or not a dict, the mapping
    at ``model.vocab`` in ``TOKENIZER_PATH`` is used instead.

    Raises:
        RuntimeError: if the tokenizer file is needed but missing, or its
            vocabulary cannot be loaded.
    """
    config_file = Path(CONFIG_PATH)
    tokenizer_file = Path(TOKENIZER_PATH)

    if config_file.exists():
        try:
            config_vocab = json.loads(config_file.read_text(encoding="utf-8"))["vocab"]
            if isinstance(config_vocab, dict) and config_vocab:
                return {str(symbol): int(index) for symbol, index in config_vocab.items()}
        except Exception as e:
            # Best effort: any problem with config.json falls through to tokenizer.json.
            logger.warning("加载 config vocab 失败,回退 tokenizer vocab: %s", e)

    if not tokenizer_file.exists():
        raise RuntimeError(
            f"缺少 vocab 文件: {config_file} / {tokenizer_file}. "
            f"请从 {HF_MODEL_ID} 或 hexgrad/Kokoro-82M 下载后放到模型目录。"
        )

    try:
        tokenizer_vocab = json.loads(tokenizer_file.read_text(encoding="utf-8"))["model"]["vocab"]
        if not (isinstance(tokenizer_vocab, dict) and tokenizer_vocab):
            raise ValueError("tokenizer vocab 为空或格式异常")
        return {str(symbol): int(index) for symbol, index in tokenizer_vocab.items()}
    except Exception as e:
        raise RuntimeError(f"无法加载 tokenizer vocab: {tokenizer_file}. 错误: {e}") from e
|
|
|
+
|
|
|
+
|
|
|
def get_vocab() -> Dict[str, int]:
    """Return the vocabulary, loading it with ``load_vocab()`` on first use."""
    global _VOCAB_CACHE
    if _VOCAB_CACHE is not None:
        return _VOCAB_CACHE
    _VOCAB_CACHE = load_vocab()
    return _VOCAB_CACHE
|
|
|
+
|
|
|
+
|
|
|
def get_en_g2p_pipeline():
    """Return the shared ``KPipeline`` instance, creating it on first use.

    ``kokoro`` is imported lazily so the module can be imported without it.
    """
    global _EN_G2P_PIPELINE
    if _EN_G2P_PIPELINE is not None:
        return _EN_G2P_PIPELINE

    from kokoro import KPipeline  # type: ignore

    _EN_G2P_PIPELINE = KPipeline(lang_code="a", repo_id="hexgrad/Kokoro-82M", model=False)
    return _EN_G2P_PIPELINE
|
|
|
+
|
|
|
+
|
|
|
def resolve_voice_path(voice: str) -> Path:
    """Map a voice name to its ``.bin`` file under ``VOICES_DIR``.

    A blank name defaults to ``af_heart``; ``VOICE_ALIASES`` is applied and
    the ``.bin`` suffix is appended when absent. If the requested file does
    not exist but ``af_bella.bin`` does, the latter is returned and a warning
    is logged.

    Raises:
        FileNotFoundError: if neither the requested nor the fallback file exists.
    """
    requested = voice.strip() or "af_heart"
    requested = VOICE_ALIASES.get(requested, requested)
    filename = requested if requested.endswith(".bin") else f"{requested}.bin"

    voices_root = Path(VOICES_DIR)
    primary = voices_root / filename
    if primary.exists():
        return primary

    fallback = voices_root / "af_bella.bin"
    if not fallback.exists():
        raise FileNotFoundError(
            f"找不到 voice 文件: {primary}. "
            "请从模型仓库下载 voices/*.bin 到本地 voices 目录。"
        )
    logger.warning("voice %s 不存在,回退到 %s", filename, fallback.name)
    return fallback
|
|
|
+
|
|
|
+
|
|
|
def _download_voice_file(voice_path: Path) -> Path:
    """Download ``voice_path.name`` from the ``HF_MODEL_ID`` repository.

    The payload is written to a sibling ``*.tmp`` file and then renamed over
    ``voice_path``, so a partially written file never replaces the target.

    Args:
        voice_path: Destination path; only its file name is used in the URL.

    Returns:
        ``voice_path`` once the file has been written.

    Raises:
        RuntimeError: if the server answered with an HTML page instead of
            binary voice data.
        OSError: propagated from the network request or the file write.
    """
    voice_name = voice_path.name
    raw_url = f"https://huggingface.co/{HF_MODEL_ID}/resolve/main/voices/{voice_name}"
    voice_path.parent.mkdir(parents=True, exist_ok=True)
    tmp_path = voice_path.with_suffix(voice_path.suffix + ".tmp")

    logger.info("从官方地址下载 voice 文件: %s", raw_url)
    with urllib.request.urlopen(raw_url, timeout=60) as response:
        content_type = response.headers.get_content_type()
        payload = response.read()

    # HTML doctype and tag names are case-insensitive and may be preceded by
    # whitespace, so normalise the head of the payload before sniffing it.
    head = payload[:256].lstrip().lower()
    looks_like_html = head.startswith((b"<!doctype html", b"<html"))
    if content_type == "text/html" or looks_like_html:
        raise RuntimeError(
            f"下载到的仍是 HTML 页面而不是 voice 二进制文件: {raw_url}. "
            "请确认模型仓库的 voices 文件可直接通过 raw 链接访问。"
        )

    try:
        with open(tmp_path, "wb") as f:
            f.write(payload)
        tmp_path.replace(voice_path)
    except OSError:
        # Do not leave a stray temp file behind when the write or rename fails.
        tmp_path.unlink(missing_ok=True)
        raise
    return voice_path
|
|
|
+
|
|
|
+
|
|
|
+def _is_html_payload(data: bytes) -> bool:
|
|
|
+ return data.startswith(b"<!doctype html") or data.startswith(b"<html")
|
|
|
+
|
|
|
+
|
|
|
def load_onnx_session(name: str):
    """Create an onnxruntime ``InferenceSession`` for the model called ``name``.

    The session is created with the CPU execution provider only.

    Raises:
        RuntimeError: if ``onnxruntime`` cannot be imported.
        FileNotFoundError: propagated from ``resolve_model_path``.
    """
    try:
        import onnxruntime as ort  # type: ignore
    except Exception as e:
        raise RuntimeError(
            "无法导入 onnxruntime。请在部署环境中安装 onnxruntime。"
        ) from e

    resolved = resolve_model_path(name)
    session = ort.InferenceSession(resolved, providers=["CPUExecutionProvider"])
    logger.info("ONNX 模型加载完成: %s", resolved)
    return session
|
|
|
+
|
|
|
+
|
|
|
def load_kokoro_engine(name: str):
    """Build a ``kokoro_onnx.Kokoro`` engine for the model called ``name``.

    The engine receives ``VOICES_V1_PATH`` as its voices file, and
    ``CONFIG_PATH`` as vocabulary config only when that file exists.

    Raises:
        RuntimeError: if ``kokoro_onnx`` cannot be imported.
        FileNotFoundError: propagated from ``resolve_model_path``.
    """
    try:
        from kokoro_onnx import Kokoro  # type: ignore
    except Exception as e:
        raise RuntimeError("无法导入 kokoro_onnx。请在部署环境中安装 kokoro-onnx。") from e

    resolved = resolve_model_path(name)

    vocab_config = None
    if Path(CONFIG_PATH).exists():
        vocab_config = CONFIG_PATH

    engine = Kokoro(
        model_path=resolved,
        voices_path=VOICES_V1_PATH,
        vocab_config=vocab_config,
    )
    logger.info("kokoro-onnx 引擎加载完成: %s", resolved)
    return engine
|
|
|
+
|
|
|
+
|
|
|
def load_model(force_reload: bool = False, name: Optional[str] = None):
    """Ensure the ONNX session and kokoro-onnx engine for ``name`` are loaded.

    Args:
        force_reload: Reload even if the requested model is already active.
        name: Model name to load; defaults to the currently active name.

    Returns:
        The active ONNX inference session.

    Raises:
        Whatever ``load_onnx_session`` / ``load_kokoro_engine`` raise. In that
        case the previously loaded session, engine and name are left untouched.
    """
    global model_session, model_name, _KOKORO_ONNX_ENGINE
    with model_lock:
        target_name = name or model_name
        needs_load = (
            force_reload
            or model_session is None
            # A missing engine means an earlier load was incomplete; retry it.
            or _KOKORO_ONNX_ENGINE is None
            or target_name != model_name
        )
        if needs_load:
            # Load both pieces into locals first, so a failure in either one
            # cannot leave a new session paired with a stale or missing engine.
            new_session = load_onnx_session(target_name)
            new_engine = load_kokoro_engine(target_name)
            model_session = new_session
            _KOKORO_ONNX_ENGINE = new_engine
            model_name = target_name
        return model_session
|
|
|
+
|
|
|
+
|
|
|
def get_kokoro_engine(name: Optional[str] = None):
    """Return the kokoro-onnx engine for ``name``, loading it through ``load_model`` if needed.

    ``name`` defaults to the currently active model name.
    """
    wanted = name or model_name
    engine_ready = _KOKORO_ONNX_ENGINE is not None and wanted == model_name
    if not engine_ready:
        load_model(name=wanted)
    return _KOKORO_ONNX_ENGINE
|
|
|
+
|
|
|
+
|
|
|
@asynccontextmanager
async def lifespan(app: FastAPI):
    """FastAPI lifespan hook: load the model before serving, log on shutdown.

    If ``load_model()`` raises, the exception propagates after the shutdown
    message has been logged.
    """
    try:
        # Load the default model eagerly so the first request does not pay for it.
        load_model()
        yield
    finally:
        logger.info("应用关闭中...")
|
|
|
+
|
|
|
+
|
|
|
# Application instance; the model is loaded through the `lifespan` hook.
app = FastAPI(title="Online TTS Service (Kokoro ONNX)", lifespan=lifespan)

# NOTE(review): wildcard origins are combined with allow_credentials=True here.
# The FastAPI CORS documentation advises listing explicit origins when
# credentials are allowed — confirm this configuration is intended.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
|
|
|
+
|
|
|
+
|
|
|
@app.middleware("http")
async def track_clients(request: Request, call_next):
    """Register a placeholder entry in ``current_requests`` for /generate calls.

    The client id comes from the ``client_id`` query parameter, then the
    ``X-Client-ID`` header, and otherwise a fresh UUID.

    The placeholder is only inserted when the client has no entry yet, so the
    state of a stream that is still running is not clobbered. It is removed
    again, also when the downstream handler raises, but only if it is still
    the object this middleware inserted. An entry that the /generate handler
    installed for its own stream is left alone, because that stream is still
    being sent after ``call_next`` has returned.
    """
    client_id = (
        request.query_params.get("client_id")
        or request.headers.get("X-Client-ID")
        or str(uuid.uuid4())
    )

    placeholder = None
    if request.url.path == "/generate" and client_id not in current_requests:
        placeholder = {"active": True}
        current_requests[client_id] = placeholder

    try:
        return await call_next(request)
    finally:
        if placeholder is not None and current_requests.get(client_id) is placeholder:
            current_requests.pop(client_id, None)
|
|
|
+
|
|
|
+
|
|
|
class TTSRequest(BaseModel):
    """Request body of ``POST /tts``."""

    # Text to synthesise (required).
    text: str
    # Voice name; the endpoint falls back to "af_heart" when it is empty.
    voice: Optional[str] = "af_heart"
    # Playback speed; must be greater than 0.
    speed: Optional[float] = 1.0
    # NOTE(review): not read by any code visible in this file.
    split_pattern: Optional[str] = r"\n+"
    # Model to use; None means the currently active model.
    model_name: Optional[str] = None

    @field_validator("speed")
    @classmethod
    def validate_speed(cls, v):
        """Default a missing speed to 1.0 and reject non-positive values."""
        speed = 1.0 if v is None else float(v)
        if speed <= 0:
            raise ValueError("speed 必须大于 0")
        return speed
|
|
|
+
|
|
|
+
|
|
|
def _encode_token_ids(text: str) -> List[int]:
    """Map every character of ``text`` to its vocabulary id.

    Characters missing from the vocabulary get the id of the space character,
    or 16 when the vocabulary has no entry for a space.
    """
    table = get_vocab()
    fallback_id = table.get(" ", 16)
    ids: List[int] = []
    for symbol in text:
        ids.append(table.get(symbol, fallback_id))
    return ids
|
|
|
+
|
|
|
+
|
|
|
def _phonemize_en_chunks_for_onnx(text: str) -> List[Tuple[str, str]]:
    """Run the English G2P pipeline on ``text`` and return its chunks.

    Returns:
        One ``(graphemes, phonemes)`` pair per chunk yielded by the
        pipeline's ``en_tokenize``.

    Raises:
        RuntimeError: if the pipeline yields no chunks.
    """
    g2p = get_en_g2p_pipeline()
    _, tokens = g2p.g2p(text)
    chunks = list(g2p.en_tokenize(tokens))
    if not chunks:
        raise RuntimeError("英文文本音素化失败,未生成 phonemes。")

    pairs: List[Tuple[str, str]] = []
    for graphemes, phonemes, _ in chunks:
        pairs.append((graphemes, phonemes))
    return pairs
|
|
|
+
|
|
|
+
|
|
|
def _phonemize_en_for_onnx(text: str) -> str:
    """Return the phoneme string of the first chunk produced for ``text``."""
    # NOTE(review): only the first chunk is used; if the pipeline splits the
    # text into several chunks, the remaining ones are ignored here — confirm
    # this is intended.
    first_chunk = _phonemize_en_chunks_for_onnx(text)[0]
    _, phonemes = first_chunk
    return phonemes
|
|
|
+
|
|
|
+
|
|
|
def _tokenize_for_onnx(text: str, voice: str):
    """Convert ``text`` into the token tensor expected by the ONNX model.

    For English voices (prefix ``af_``, ``am_``, ``bf_`` or ``bm_`` after
    alias resolution) the text is phonemized first; otherwise its characters
    are encoded directly. Swap this function out if the project has a more
    reliable tokenizer.

    Returns:
        An int64 array of shape ``(1, len(tokens) + 2)``: the token ids with
        a 0 added at each end.

    Raises:
        RuntimeError: if encoding yields no tokens or more than 510 tokens.
    """
    source = _phonemize_en_for_onnx(text) if _is_english_voice(voice) else text
    tokens = _encode_token_ids(source)

    if not tokens:
        raise RuntimeError("文本编码后得到空 tokens。请检查 tokenizer.json 是否正确。")

    token_count = len(tokens)
    # 510 tokens plus the two padding zeros gives the 512 limit named in the message.
    if token_count > 510:
        raise RuntimeError(
            f"文本过长,tokens={token_count},超过模型 512 上限。请拆句后再调用。"
        )

    return np.asarray([[0, *tokens, 0]], dtype=np.int64)
|
|
|
+
|
|
|
+
|
|
|
def _is_english_voice(voice: str) -> bool:
    """Return True when the alias-resolved voice name starts with ``af_``, ``am_``, ``bf_`` or ``bm_``.

    A ``None`` or empty ``voice`` is treated as an empty name.
    """
    raw_name = (voice or "").strip()
    canonical = VOICE_ALIASES.get(raw_name, raw_name)
    return canonical.startswith(("af_", "am_", "bf_", "bm_"))
|
|
|
+
|
|
|
+
|
|
|
def _style_for_voice(voice: str) -> np.ndarray:
    """Load the style table for ``voice`` from its ``.bin`` file.

    The ONNX model needs a style vector. This is a working placeholder
    implementation that can later be replaced by a proper voice-embedding
    mapping.

    Returns:
        A float32 array reshaped to ``(-1, 256)``.

    Raises:
        RuntimeError: if the file is empty, or its number of float32 values is
            still not a multiple of 256 after one re-download.
        FileNotFoundError: propagated from ``resolve_voice_path``.
    """
    voice_path = resolve_voice_path(voice)
    if voice_path.suffix == ".bin":
        try:
            # Only the first 64 bytes are inspected to detect an HTML page
            # that was saved in place of the binary voice file.
            header = voice_path.read_bytes()[:64]
            if _is_html_payload(header):
                fallback_path = Path(VOICES_DIR) / "af_bella.bin"
                if voice_path.name != fallback_path.name and fallback_path.exists():
                    logger.warning("voice 文件 %s 是 HTML,占位回退到 %s", voice_path.name, fallback_path.name)
                    voice_path = fallback_path
                else:
                    # No usable fallback: fetch the file again.
                    voice_path = _download_voice_file(voice_path)
        except Exception as e:
            # Any failure above (including a failed download) lands here and
            # gets one more fallback/download attempt; errors raised by this
            # second attempt propagate to the caller.
            logger.warning("voice 文件检查失败,尝试直接重新下载: %s", e)
            fallback_path = Path(VOICES_DIR) / "af_bella.bin"
            if voice_path.name != fallback_path.name and fallback_path.exists():
                voice_path = fallback_path
            else:
                voice_path = _download_voice_file(voice_path)

    style = np.fromfile(str(voice_path), dtype=np.float32)
    if style.size == 0:
        raise RuntimeError(f"voice 文件为空: {voice_path}")
    if style.size % 256 != 0:
        # Some repository files may be xet pointers or HTML pages; download
        # once more as a last resort.
        voice_path = _download_voice_file(voice_path)
        style = np.fromfile(str(voice_path), dtype=np.float32)
    if style.size == 0 or style.size % 256 != 0:
        raise RuntimeError(f"voice 文件维度异常: {voice_path}, size={style.size}")
    style = style.reshape(-1, 256)
    return style
|
|
|
+
|
|
|
+
|
|
|
+def _select_style_slice(style: np.ndarray, token_len: int) -> np.ndarray:
|
|
|
+ if style.ndim != 2 or style.shape[-1] != 256:
|
|
|
+ raise RuntimeError(f"style 维度异常: {style.shape}")
|
|
|
+ # Follow the official ONNX example: ref_s = voices[len(tokens)]
|
|
|
+ idx = min(max(token_len, 0), style.shape[0] - 1)
|
|
|
+ return style[idx : idx + 1]
|
|
|
+
|
|
|
+
|
|
|
+def _prepare_style_input(session, style_slice: np.ndarray) -> Tuple[Optional[str], Optional[np.ndarray]]:
|
|
|
+ for input_name in ("style", "ref_s"):
|
|
|
+ for model_input in session.get_inputs():
|
|
|
+ if model_input.name != input_name:
|
|
|
+ continue
|
|
|
+
|
|
|
+ input_shape = model_input.shape
|
|
|
+ expected_rank = len(input_shape) if input_shape is not None else style_slice.ndim
|
|
|
+
|
|
|
+ if expected_rank == 2:
|
|
|
+ return input_name, style_slice.astype(np.float32, copy=False)
|
|
|
+ if expected_rank == 3:
|
|
|
+ return input_name, style_slice[:, np.newaxis, :].astype(np.float32, copy=False)
|
|
|
+
|
|
|
+ raise RuntimeError(
|
|
|
+ f"模型输入 {input_name} 的 rank 不受支持: shape={input_shape}"
|
|
|
+ )
|
|
|
+
|
|
|
+ return None, None
|
|
|
+
|
|
|
+
|
|
|
def synthesize_audio(text: str, voice: str, speed: float, model_name: Optional[str] = None) -> np.ndarray:
    """Synthesise ``text`` with the given voice and return the audio samples.

    The text is phonemized with the kokoro-onnx tokenizer (always as
    ``en-us``) and split into phoneme batches. Each batch is run through the
    ONNX session, and the resulting audio is concatenated.

    Args:
        text: Text to synthesise; must contain non-whitespace characters.
        voice: Voice name; must be one of ``engine.get_voices()``.
        speed: Value fed to the model's ``speed`` input.
        model_name: Model to use; ``None`` means the active model.

    Returns:
        A one-dimensional float32 array.

    Raises:
        HTTPException: 400 for empty text, an unsupported voice or failed
            phonemization; 500 when inference yields no or invalid audio.
    """
    if not text.strip():
        raise HTTPException(status_code=400, detail="文本不能为空")

    engine = get_kokoro_engine(name=model_name)
    session = load_model(name=model_name)

    if voice not in set(engine.get_voices()):
        raise HTTPException(status_code=400, detail=f"不支持的 voice: {voice}")

    phoneme_text = engine.tokenizer.phonemize(text, "en-us")
    phoneme_batches = engine._split_phonemes(phoneme_text)
    if not phoneme_batches:
        raise HTTPException(status_code=400, detail="文本音素化失败")

    voice_style = engine.get_voice_style(voice)
    rendered: List[np.ndarray] = []
    for batch in phoneme_batches:
        token_ids = np.array(engine.tokenizer.tokenize(batch), dtype=np.int64)
        if token_ids.size == 0:
            continue

        # The style table is indexed by the (unpadded) token count.
        batch_style = voice_style[len(token_ids)]
        padded_ids = [0, *token_ids.tolist(), 0]
        outputs = session.run(
            None,
            {
                "input_ids": np.asarray([padded_ids], dtype=np.int64),
                "style": np.asarray(batch_style, dtype=np.float32),
                "speed": np.asarray([speed], dtype=np.float32),
            },
        )
        if outputs:
            rendered.append(to_mono_numpy(outputs[0]))

    if not rendered:
        raise HTTPException(status_code=500, detail="ONNX 推理未返回音频输出")

    if len(rendered) > 1:
        audio = np.concatenate(rendered, axis=0)
    else:
        audio = rendered[0]

    if audio.size == 0 or not np.isfinite(audio).all():
        raise HTTPException(status_code=500, detail="生成的音频无效")
    return audio
|
|
|
+
|
|
|
+
|
|
|
def synthesize_wav_bytes(text: str, voice: str, speed: float, model_name: Optional[str] = None) -> io.BytesIO:
    """Synthesise ``text`` and return it as an in-memory 16-bit PCM WAV file.

    English voices are split into the grapheme chunks of the G2P pipeline;
    other voices are split with ``split_sentences``. All parts are synthesised
    while holding ``synthesis_lock`` and concatenated into one WAV.

    Returns:
        A ``BytesIO`` positioned at the start of the WAV data.

    Raises:
        HTTPException: 400 when no audio was produced, plus anything raised by
            ``synthesize_audio``.
    """
    if not text.strip():
        parts = []
    elif _is_english_voice(voice):
        parts = [graphemes for graphemes, _ in _phonemize_en_chunks_for_onnx(text)]
    else:
        parts = split_sentences(text)
    if not parts:
        # Leave it to synthesize_audio to validate (and reject) the bare text.
        parts = [text.strip()]

    rendered: List[np.ndarray] = []
    with synthesis_lock:
        for part in parts:
            clip = synthesize_audio(part, voice=voice, speed=speed, model_name=model_name)
            if clip.size > 0:
                rendered.append(clip)

    if not rendered:
        raise HTTPException(status_code=400, detail="未生成音频,请检查输入文本或参数。")

    wav_buffer = io.BytesIO()
    sf.write(
        wav_buffer,
        np.concatenate(rendered, axis=0),
        samplerate=sample_rate,
        format="WAV",
        subtype="PCM_16",
    )
    wav_buffer.seek(0)
    return wav_buffer
|
|
|
+
|
|
|
+
|
|
|
@app.post("/tts", summary="POST: 传入文本返回 WAV 流")
def tts_post(req: TTSRequest):
    """Synthesise the text of a JSON request body and return it as a WAV response."""
    voice = req.voice or "af_heart"
    speed = 1.0 if req.speed is None else req.speed
    wav = synthesize_wav_bytes(
        text=req.text,
        voice=voice,
        speed=speed,
        model_name=req.model_name,
    )
    return StreamingResponse(
        wav,
        media_type="audio/wav",
        headers={"Content-Disposition": 'inline; filename="tts.wav"'},
    )
|
|
|
+
|
|
|
+
|
|
|
@app.get("/tts", summary="GET: 传入文本返回 WAV 流")
def tts_get(
    text: str = Query(..., description="待合成文本"),
    voice: str = Query("af_heart"),
    speed: float = Query(1.0),
    model_name: str = Query(DEFAULT_MODEL_NAME),
):
    """Synthesise text passed as query parameters and return it as a WAV response.

    Raises:
        HTTPException: 400 when ``speed`` is missing or not greater than 0.
    """
    speed_value = None if speed is None else float(speed)
    if speed_value is None or speed_value <= 0:
        raise HTTPException(status_code=400, detail="speed 必须为大于 0 的数值")

    wav = synthesize_wav_bytes(
        text=text,
        voice=voice,
        speed=speed_value,
        model_name=model_name,
    )
    return StreamingResponse(
        wav,
        media_type="audio/wav",
        headers={"Content-Disposition": 'inline; filename="tts.wav"'},
    )
|
|
|
+
|
|
|
+
|
|
|
@app.post("/generate")
async def generate_audio_stream(data: Dict = Body(...)):
    """Stream synthesised audio for ``data["text"]`` as newline-delimited JSON.

    Accepted keys: ``text`` (required), ``voice``, ``speed``, ``model_name``
    and ``client_id``. Each streamed line is a JSON object with ``index``,
    ``sentence``, ``audio`` (base64-encoded 16-bit PCM WAV) and
    ``sample_rate``.

    A new request with the ``client_id`` of a stream that is still running
    flags that stream as interrupted; it stops before synthesising its next
    part.

    Raises:
        HTTPException: 400 for empty or non-string text, a non-string voice,
            or a speed that is not a number greater than 0.
    """
    text = data.get("text", "")
    if not isinstance(text, str) or not text.strip():
        raise HTTPException(status_code=400, detail="文本不能为空")

    voice = data.get("voice", "af_heart")
    if not isinstance(voice, str):
        raise HTTPException(status_code=400, detail=f"不支持的 voice: {voice}")

    # Mirror the /tts validation: a missing/null speed means 1.0, and anything
    # that is not a positive number is a client error rather than a 500.
    raw_speed = data.get("speed", 1.0)
    try:
        speed = 1.0 if raw_speed is None else float(raw_speed)
    except (TypeError, ValueError) as e:
        raise HTTPException(status_code=400, detail="speed 必须为大于 0 的数值") from e
    if not speed > 0:  # written this way so NaN is rejected as well
        raise HTTPException(status_code=400, detail="speed 必须为大于 0 的数值")

    model = data.get("model_name", DEFAULT_MODEL_NAME)
    client_id = data.get("client_id", str(uuid.uuid4()))

    async with request_semaphore:
        if _is_english_voice(voice):
            # Phonemization is blocking work; keep it off the event loop.
            chunks = await asyncio.to_thread(_phonemize_en_chunks_for_onnx, text)
            parts = [graphemes for graphemes, _ in chunks]
        else:
            parts = split_sentences(text)
    if not parts:
        parts = [text.strip()]

    previous = current_requests.get(client_id)
    if previous is not None:
        previous["interrupt"] = True
        await asyncio.sleep(0.05)

    # Each stream keeps a reference to its own state dict, so an interrupt set
    # by a later request is still visible after the registry entry is replaced.
    state = {"interrupt": False}
    current_requests[client_id] = state

    async def stream():
        try:
            # Hold the semaphore while synthesising: this is where the real
            # work happens, so this is what the concurrency limit must cover.
            async with request_semaphore:
                for idx, part in enumerate(parts):
                    if not part:
                        continue
                    if state["interrupt"]:
                        break

                    audio = await asyncio.to_thread(
                        synthesize_audio,
                        text=part,
                        voice=voice,
                        speed=speed,
                        model_name=model,
                    )

                    with io.BytesIO() as buf:
                        with sf.SoundFile(
                            buf,
                            "w",
                            sample_rate,
                            channels=audio.shape[1] if audio.ndim > 1 else 1,
                            format="WAV",
                            subtype="PCM_16",
                        ) as f:
                            f.write(audio)
                        audio_b64 = base64.b64encode(buf.getvalue()).decode()

                    yield json.dumps(
                        {
                            "index": idx,
                            "sentence": part,
                            "audio": audio_b64,
                            "sample_rate": sample_rate,
                        }
                    ).encode() + b"\n"
        finally:
            # Remove only our own entry; a newer request for the same client
            # may already have installed its state under this id.
            if current_requests.get(client_id) is state:
                current_requests.pop(client_id, None)

    return StreamingResponse(stream(), media_type="application/x-ndjson")
|
|
|
+
|
|
|
+
|
|
|
# Script entry point: serve the app with uvicorn on all interfaces, port 18000.
if __name__ == "__main__":
    # Imported here so uvicorn is only required when the file is run as a script.
    import uvicorn

    # NOTE(review): the import string assumes this module is importable as
    # "speech_tts_onnx" — confirm it matches the file name.
    uvicorn.run("speech_tts_onnx:app", host="0.0.0.0", port=18000, reload=False, workers=1)
|