speech_tts_onnx.py 22 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652
  1. from __future__ import annotations
  2. import base64
  3. import hashlib
  4. import io
  5. import json
  6. import logging
  7. import os
  8. import re
  9. import asyncio
  10. import threading
  11. import uuid
  12. import urllib.request
  13. from contextlib import asynccontextmanager
  14. from pathlib import Path
  15. from typing import Dict, List, Optional, Tuple
  16. import numpy as np
  17. import soundfile as sf
  18. from fastapi import Body, FastAPI, HTTPException, Query, Request
  19. from fastapi.middleware.cors import CORSMiddleware
  20. from fastapi.responses import StreamingResponse
  21. from pydantic import BaseModel, field_validator
  22. # ============================================================
  23. # 配置
  24. # ============================================================
  25. DEFAULT_MODEL_NAME = os.getenv("TTS_ONNX_MODEL_NAME", "model_quantized.onnx")
  26. MODEL_DIR = os.getenv("TTS_ONNX_MODEL_DIR", "/home/tts-server/onnx")
  27. HF_MODEL_ID = os.getenv("TTS_ONNX_HF_MODEL_ID", "onnx-community/Kokoro-82M-ONNX")
  28. TOKENIZER_PATH = os.getenv("TTS_ONNX_TOKENIZER_PATH", str(Path(MODEL_DIR) / "tokenizer.json"))
  29. CONFIG_PATH = os.getenv("TTS_ONNX_CONFIG_PATH", str(Path(MODEL_DIR) / "config.json"))
  30. VOICES_DIR = os.getenv("TTS_ONNX_VOICES_DIR", str(Path(MODEL_DIR) / "voices"))
  31. VOICES_V1_PATH = os.getenv("TTS_ONNX_VOICES_V1_PATH", str(Path(MODEL_DIR) / "voices-v1.0.bin"))
  32. CACHE_DIR = os.getenv("CACHE_DIR", "./audio_cache")
  33. LOG_LEVEL = os.getenv("LOG_LEVEL", "INFO").upper()
  34. MAX_CONCURRENT_REQUESTS = int(os.getenv("MAX_CONCURRENT_REQUESTS", 8))
  35. MEMORY_CACHE_SIZE = int(os.getenv("MEMORY_CACHE_SIZE", 200))
  36. DISK_CACHE_SIZE = int(os.getenv("DISK_CACHE_SIZE", 500))
  37. DEFAULT_SAMPLE_RATE = int(os.getenv("TTS_SAMPLE_RATE", 24000))
  38. VOICE_ALIASES = {}
  39. logging.basicConfig(level=getattr(logging, LOG_LEVEL, logging.INFO))
  40. logger = logging.getLogger("speech_tts_onnx")
  41. class _PhonemizerWordCountMismatchFilter(logging.Filter):
  42. def filter(self, record: logging.LogRecord) -> bool:
  43. return "words count mismatch" not in record.getMessage()
  44. logging.getLogger("phonemizer").addFilter(_PhonemizerWordCountMismatchFilter())
  45. Path(CACHE_DIR).mkdir(parents=True, exist_ok=True)
  46. SENT_SPLIT_RE = re.compile(r"(?<=[.!?,:])\s+")
  47. _TOKENIZER_CACHE = None
  48. _VOCAB_CACHE = None
  49. _EN_G2P_PIPELINE = None
  50. _KOKORO_ONNX_ENGINE = None
  51. # ============================================================
  52. # 运行时状态
  53. # ============================================================
  54. model_lock = threading.Lock()
  55. synthesis_lock = threading.Lock()
  56. request_semaphore = asyncio.Semaphore(MAX_CONCURRENT_REQUESTS)
  57. current_requests: Dict[str, Dict] = {}
  58. model_session = None
  59. model_name = DEFAULT_MODEL_NAME
  60. sample_rate = DEFAULT_SAMPLE_RATE
  61. # ============================================================
  62. # 工具
  63. # ============================================================
  64. def split_sentences(text: str) -> List[str]:
  65. return [s.strip() for s in SENT_SPLIT_RE.split(text.strip()) if s.strip()]
  66. def sentence_cache_key(sentence: str, voice: str, speed: float, model: str) -> str:
  67. raw = f"{sentence}|{voice}|{speed}|{model}"
  68. return hashlib.md5(raw.encode("utf-8")).hexdigest()
  69. def sentence_cache_path(key: str) -> str:
  70. return os.path.join(CACHE_DIR, f"{key}.wav")
  71. def meta_cache_path(key: str) -> str:
  72. return os.path.join(CACHE_DIR, f"{key}.json")
  73. def to_mono_numpy(audio) -> np.ndarray:
  74. if audio is None:
  75. return np.array([], dtype=np.float32)
  76. if isinstance(audio, np.ndarray):
  77. arr = audio
  78. else:
  79. try:
  80. arr = np.asarray(audio)
  81. except Exception:
  82. return np.array([], dtype=np.float32)
  83. arr = np.asarray(arr)
  84. if arr.size == 0:
  85. return np.array([], dtype=np.float32)
  86. if arr.ndim == 2:
  87. if arr.shape[0] == 1:
  88. arr = arr[0]
  89. elif arr.shape[1] == 1:
  90. arr = arr[:, 0]
  91. else:
  92. arr = arr.mean(axis=1)
  93. elif arr.ndim > 2:
  94. arr = arr.reshape(-1)
  95. if arr.ndim == 0:
  96. arr = arr.reshape(1)
  97. if arr.dtype != np.float32:
  98. arr = arr.astype(np.float32, copy=False)
  99. return arr
  100. def read_wav_bytes(wav_path: str) -> Tuple[bytes, int]:
  101. audio, sr = sf.read(wav_path)
  102. buf = io.BytesIO()
  103. with sf.SoundFile(
  104. buf,
  105. "w",
  106. samplerate=sr,
  107. channels=audio.shape[1] if audio.ndim > 1 else 1,
  108. format="WAV",
  109. subtype="PCM_16",
  110. ) as f:
  111. f.write(audio)
  112. return buf.getvalue(), sr
  113. # ============================================================
  114. # 模型加载
  115. # ============================================================
  116. def resolve_model_path(name: str) -> str:
  117. local_path = Path(MODEL_DIR) / name
  118. if local_path.exists():
  119. return str(local_path)
  120. if os.path.isabs(name) and Path(name).exists():
  121. return name
  122. raise FileNotFoundError(
  123. f"找不到模型文件: {local_path}"
  124. )
  125. def load_vocab() -> Dict[str, int]:
  126. """
  127. 优先使用 Kokoro 官方 config.json 的 vocab。
  128. 如果本地缺失,则回退到 tokenizer.json 里的 vocab。
  129. """
  130. config_file = Path(CONFIG_PATH)
  131. if config_file.exists():
  132. try:
  133. data = json.loads(config_file.read_text(encoding="utf-8"))
  134. vocab = data["vocab"]
  135. if isinstance(vocab, dict) and vocab:
  136. return {str(k): int(v) for k, v in vocab.items()}
  137. except Exception as e:
  138. logger.warning("加载 config vocab 失败,回退 tokenizer vocab: %s", e)
  139. tokenizer_file = Path(TOKENIZER_PATH)
  140. if not tokenizer_file.exists():
  141. raise RuntimeError(
  142. f"缺少 vocab 文件: {config_file} / {tokenizer_file}. "
  143. f"请从 {HF_MODEL_ID} 或 hexgrad/Kokoro-82M 下载后放到模型目录。"
  144. )
  145. try:
  146. data = json.loads(tokenizer_file.read_text(encoding="utf-8"))
  147. vocab = data["model"]["vocab"]
  148. if not isinstance(vocab, dict) or not vocab:
  149. raise ValueError("tokenizer vocab 为空或格式异常")
  150. return {str(k): int(v) for k, v in vocab.items()}
  151. except Exception as e:
  152. raise RuntimeError(f"无法加载 tokenizer vocab: {tokenizer_file}. 错误: {e}") from e
  153. def get_vocab() -> Dict[str, int]:
  154. global _VOCAB_CACHE
  155. if _VOCAB_CACHE is None:
  156. _VOCAB_CACHE = load_vocab()
  157. return _VOCAB_CACHE
  158. def get_en_g2p_pipeline():
  159. global _EN_G2P_PIPELINE
  160. if _EN_G2P_PIPELINE is None:
  161. from kokoro import KPipeline # type: ignore
  162. _EN_G2P_PIPELINE = KPipeline(lang_code="a", repo_id="hexgrad/Kokoro-82M", model=False)
  163. return _EN_G2P_PIPELINE
  164. def resolve_voice_path(voice: str) -> Path:
  165. voice_name = voice.strip()
  166. if not voice_name:
  167. voice_name = "af_heart"
  168. voice_name = VOICE_ALIASES.get(voice_name, voice_name)
  169. if not voice_name.endswith(".bin"):
  170. voice_name = f"{voice_name}.bin"
  171. voice_path = Path(VOICES_DIR) / voice_name
  172. if voice_path.exists():
  173. return voice_path
  174. fallback_path = Path(VOICES_DIR) / "af_bella.bin"
  175. if fallback_path.exists():
  176. logger.warning("voice %s 不存在,回退到 %s", voice_name, fallback_path.name)
  177. return fallback_path
  178. raise FileNotFoundError(
  179. f"找不到 voice 文件: {voice_path}. "
  180. "请从模型仓库下载 voices/*.bin 到本地 voices 目录。"
  181. )
  182. def _download_voice_file(voice_path: Path) -> Path:
  183. voice_name = voice_path.name
  184. raw_url = f"https://huggingface.co/{HF_MODEL_ID}/resolve/main/voices/{voice_name}"
  185. voice_path.parent.mkdir(parents=True, exist_ok=True)
  186. tmp_path = voice_path.with_suffix(voice_path.suffix + ".tmp")
  187. logger.info("从官方地址下载 voice 文件: %s", raw_url)
  188. with urllib.request.urlopen(raw_url, timeout=60) as response:
  189. content_type = response.headers.get_content_type()
  190. payload = response.read()
  191. if content_type == "text/html" or payload.startswith(b"<!doctype html") or payload.startswith(b"<html"):
  192. raise RuntimeError(
  193. f"下载到的仍是 HTML 页面而不是 voice 二进制文件: {raw_url}. "
  194. "请确认模型仓库的 voices 文件可直接通过 raw 链接访问。"
  195. )
  196. with open(tmp_path, "wb") as f:
  197. f.write(payload)
  198. tmp_path.replace(voice_path)
  199. return voice_path
  200. def _is_html_payload(data: bytes) -> bool:
  201. return data.startswith(b"<!doctype html") or data.startswith(b"<html")
  202. def load_onnx_session(name: str):
  203. try:
  204. import onnxruntime as ort # type: ignore
  205. except Exception as e:
  206. raise RuntimeError(
  207. "无法导入 onnxruntime。请在部署环境中安装 onnxruntime。"
  208. ) from e
  209. model_path = resolve_model_path(name)
  210. providers = ["CPUExecutionProvider"]
  211. session = ort.InferenceSession(model_path, providers=providers)
  212. logger.info("ONNX 模型加载完成: %s", model_path)
  213. return session
  214. def load_kokoro_engine(name: str):
  215. try:
  216. from kokoro_onnx import Kokoro # type: ignore
  217. except Exception as e:
  218. raise RuntimeError("无法导入 kokoro_onnx。请在部署环境中安装 kokoro-onnx。") from e
  219. model_path = resolve_model_path(name)
  220. engine = Kokoro(
  221. model_path=model_path,
  222. voices_path=VOICES_V1_PATH,
  223. vocab_config=CONFIG_PATH if Path(CONFIG_PATH).exists() else None,
  224. )
  225. logger.info("kokoro-onnx 引擎加载完成: %s", model_path)
  226. return engine
  227. def load_model(force_reload: bool = False, name: Optional[str] = None):
  228. global model_session, model_name, _KOKORO_ONNX_ENGINE
  229. with model_lock:
  230. target_name = name or model_name
  231. if force_reload or model_session is None or target_name != model_name:
  232. model_session = load_onnx_session(target_name)
  233. _KOKORO_ONNX_ENGINE = load_kokoro_engine(target_name)
  234. model_name = target_name
  235. return model_session
  236. def get_kokoro_engine(name: Optional[str] = None):
  237. global _KOKORO_ONNX_ENGINE
  238. target_name = name or model_name
  239. if _KOKORO_ONNX_ENGINE is None or target_name != model_name:
  240. load_model(name=target_name)
  241. return _KOKORO_ONNX_ENGINE
  242. @asynccontextmanager
  243. async def lifespan(app: FastAPI):
  244. try:
  245. load_model()
  246. yield
  247. finally:
  248. logger.info("应用关闭中...")
  249. app = FastAPI(title="Online TTS Service (Kokoro ONNX)", lifespan=lifespan)
  250. app.add_middleware(
  251. CORSMiddleware,
  252. allow_origins=["*"],
  253. allow_credentials=True,
  254. allow_methods=["*"],
  255. allow_headers=["*"],
  256. )
  257. @app.middleware("http")
  258. async def track_clients(request: Request, call_next):
  259. client_id = (
  260. request.query_params.get("client_id")
  261. or request.headers.get("X-Client-ID")
  262. or str(uuid.uuid4())
  263. )
  264. if request.url.path == "/generate":
  265. current_requests[client_id] = {"active": True}
  266. response = await call_next(request)
  267. if request.url.path == "/generate":
  268. current_requests.pop(client_id, None)
  269. return response
  270. class TTSRequest(BaseModel):
  271. text: str
  272. voice: Optional[str] = "af_heart"
  273. speed: Optional[float] = 1.0
  274. split_pattern: Optional[str] = r"\n+"
  275. model_name: Optional[str] = None
  276. @field_validator("speed")
  277. @classmethod
  278. def validate_speed(cls, v):
  279. if v is None:
  280. return 1.0
  281. v = float(v)
  282. if v <= 0:
  283. raise ValueError("speed 必须大于 0")
  284. return v
  285. def _encode_token_ids(text: str) -> List[int]:
  286. vocab = get_vocab()
  287. unknown_id = vocab.get(" ", 16)
  288. return [vocab.get(ch, unknown_id) for ch in text]
  289. def _phonemize_en_chunks_for_onnx(text: str) -> List[Tuple[str, str]]:
  290. pipeline = get_en_g2p_pipeline()
  291. _, tokens = pipeline.g2p(text)
  292. chunks = list(pipeline.en_tokenize(tokens))
  293. if not chunks:
  294. raise RuntimeError("英文文本音素化失败,未生成 phonemes。")
  295. return [(graphemes, phonemes) for graphemes, phonemes, _ in chunks]
  296. def _phonemize_en_for_onnx(text: str) -> str:
  297. return _phonemize_en_chunks_for_onnx(text)[0][1]
  298. def _tokenize_for_onnx(text: str, voice: str):
  299. """
  300. 将文本转成 ONNX 所需 token。
  301. 这里优先复用 kokoro 的预处理链。
  302. 如果你项目里有更稳定的 tokenizer,可以替换这个函数。
  303. """
  304. normalized_voice = VOICE_ALIASES.get((voice or "").strip(), (voice or "").strip())
  305. if normalized_voice.startswith(("af_", "am_", "bf_", "bm_")):
  306. phonemes = _phonemize_en_for_onnx(text)
  307. tokens = _encode_token_ids(phonemes)
  308. else:
  309. tokens = _encode_token_ids(text)
  310. if not tokens:
  311. raise RuntimeError("文本编码后得到空 tokens。请检查 tokenizer.json 是否正确。")
  312. if len(tokens) > 510:
  313. raise RuntimeError(
  314. f"文本过长,tokens={len(tokens)},超过模型 512 上限。请拆句后再调用。"
  315. )
  316. return np.asarray([[0, *tokens, 0]], dtype=np.int64)
  317. def _is_english_voice(voice: str) -> bool:
  318. normalized_voice = VOICE_ALIASES.get((voice or "").strip(), (voice or "").strip())
  319. return normalized_voice.startswith(("af_", "am_", "bf_", "bm_"))
  320. def _style_for_voice(voice: str) -> np.ndarray:
  321. """
  322. ONNX 模型需要 style 向量。
  323. 这里先提供一个可运行的占位实现,后续可以替换成正式的 voice embedding 映射。
  324. """
  325. voice_path = resolve_voice_path(voice)
  326. if voice_path.suffix == ".bin":
  327. try:
  328. header = voice_path.read_bytes()[:64]
  329. if _is_html_payload(header):
  330. fallback_path = Path(VOICES_DIR) / "af_bella.bin"
  331. if voice_path.name != fallback_path.name and fallback_path.exists():
  332. logger.warning("voice 文件 %s 是 HTML,占位回退到 %s", voice_path.name, fallback_path.name)
  333. voice_path = fallback_path
  334. else:
  335. voice_path = _download_voice_file(voice_path)
  336. except Exception as e:
  337. logger.warning("voice 文件检查失败,尝试直接重新下载: %s", e)
  338. fallback_path = Path(VOICES_DIR) / "af_bella.bin"
  339. if voice_path.name != fallback_path.name and fallback_path.exists():
  340. voice_path = fallback_path
  341. else:
  342. voice_path = _download_voice_file(voice_path)
  343. style = np.fromfile(str(voice_path), dtype=np.float32)
  344. if style.size == 0:
  345. raise RuntimeError(f"voice 文件为空: {voice_path}")
  346. if style.size % 256 != 0:
  347. # 某些仓库文件可能是 xet pointer 或 HTML 页面,重新下载一次兜底。
  348. voice_path = _download_voice_file(voice_path)
  349. style = np.fromfile(str(voice_path), dtype=np.float32)
  350. if style.size == 0 or style.size % 256 != 0:
  351. raise RuntimeError(f"voice 文件维度异常: {voice_path}, size={style.size}")
  352. style = style.reshape(-1, 256)
  353. return style
  354. def _select_style_slice(style: np.ndarray, token_len: int) -> np.ndarray:
  355. if style.ndim != 2 or style.shape[-1] != 256:
  356. raise RuntimeError(f"style 维度异常: {style.shape}")
  357. # Follow the official ONNX example: ref_s = voices[len(tokens)]
  358. idx = min(max(token_len, 0), style.shape[0] - 1)
  359. return style[idx : idx + 1]
  360. def _prepare_style_input(session, style_slice: np.ndarray) -> Tuple[Optional[str], Optional[np.ndarray]]:
  361. for input_name in ("style", "ref_s"):
  362. for model_input in session.get_inputs():
  363. if model_input.name != input_name:
  364. continue
  365. input_shape = model_input.shape
  366. expected_rank = len(input_shape) if input_shape is not None else style_slice.ndim
  367. if expected_rank == 2:
  368. return input_name, style_slice.astype(np.float32, copy=False)
  369. if expected_rank == 3:
  370. return input_name, style_slice[:, np.newaxis, :].astype(np.float32, copy=False)
  371. raise RuntimeError(
  372. f"模型输入 {input_name} 的 rank 不受支持: shape={input_shape}"
  373. )
  374. return None, None
  375. def synthesize_audio(text: str, voice: str, speed: float, model_name: Optional[str] = None) -> np.ndarray:
  376. if not text.strip():
  377. raise HTTPException(status_code=400, detail="文本不能为空")
  378. engine = get_kokoro_engine(name=model_name)
  379. session = load_model(name=model_name)
  380. available_voices = set(engine.get_voices())
  381. if voice not in available_voices:
  382. raise HTTPException(status_code=400, detail=f"不支持的 voice: {voice}")
  383. phonemes = engine.tokenizer.phonemize(text, "en-us")
  384. batched_phonemes = engine._split_phonemes(phonemes)
  385. if not batched_phonemes:
  386. raise HTTPException(status_code=400, detail="文本音素化失败")
  387. voice_style = engine.get_voice_style(voice)
  388. audio_segments: List[np.ndarray] = []
  389. for phoneme_batch in batched_phonemes:
  390. tokens = np.array(engine.tokenizer.tokenize(phoneme_batch), dtype=np.int64)
  391. if tokens.size == 0:
  392. continue
  393. style = voice_style[len(tokens)]
  394. feeds = {
  395. "input_ids": np.asarray([[0, *tokens.tolist(), 0]], dtype=np.int64),
  396. "style": np.asarray(style, dtype=np.float32),
  397. "speed": np.asarray([speed], dtype=np.float32),
  398. }
  399. outputs = session.run(None, feeds)
  400. if outputs:
  401. audio_segments.append(to_mono_numpy(outputs[0]))
  402. if not audio_segments:
  403. raise HTTPException(status_code=500, detail="ONNX 推理未返回音频输出")
  404. audio = np.concatenate(audio_segments, axis=0) if len(audio_segments) > 1 else audio_segments[0]
  405. if audio.size == 0 or not np.isfinite(audio).all():
  406. raise HTTPException(status_code=500, detail="生成的音频无效")
  407. return audio
  408. def synthesize_wav_bytes(text: str, voice: str, speed: float, model_name: Optional[str] = None) -> io.BytesIO:
  409. segments: List[np.ndarray] = []
  410. if _is_english_voice(voice):
  411. parts = [graphemes for graphemes, _ in _phonemize_en_chunks_for_onnx(text)] if text.strip() else []
  412. else:
  413. parts = split_sentences(text) if text.strip() else []
  414. if not parts:
  415. parts = [text.strip()]
  416. with synthesis_lock:
  417. for part in parts:
  418. audio = synthesize_audio(part, voice=voice, speed=speed, model_name=model_name)
  419. if audio.size > 0:
  420. segments.append(audio)
  421. if not segments:
  422. raise HTTPException(status_code=400, detail="未生成音频,请检查输入文本或参数。")
  423. audio_concat = np.concatenate(segments, axis=0)
  424. buf = io.BytesIO()
  425. sf.write(buf, audio_concat, samplerate=sample_rate, format="WAV", subtype="PCM_16")
  426. buf.seek(0)
  427. return buf
  428. @app.post("/tts", summary="POST: 传入文本返回 WAV 流")
  429. def tts_post(req: TTSRequest):
  430. buf = synthesize_wav_bytes(
  431. text=req.text,
  432. voice=req.voice or "af_heart",
  433. speed=req.speed if req.speed is not None else 1.0,
  434. model_name=req.model_name,
  435. )
  436. return StreamingResponse(
  437. buf,
  438. media_type="audio/wav",
  439. headers={"Content-Disposition": 'inline; filename="tts.wav"'},
  440. )
  441. @app.get("/tts", summary="GET: 传入文本返回 WAV 流")
  442. def tts_get(
  443. text: str = Query(..., description="待合成文本"),
  444. voice: str = Query("af_heart"),
  445. speed: float = Query(1.0),
  446. model_name: str = Query(DEFAULT_MODEL_NAME),
  447. ):
  448. if speed is None or float(speed) <= 0:
  449. raise HTTPException(status_code=400, detail="speed 必须为大于 0 的数值")
  450. buf = synthesize_wav_bytes(
  451. text=text,
  452. voice=voice,
  453. speed=float(speed),
  454. model_name=model_name,
  455. )
  456. return StreamingResponse(
  457. buf,
  458. media_type="audio/wav",
  459. headers={"Content-Disposition": 'inline; filename="tts.wav"'},
  460. )
  461. @app.post("/generate")
  462. async def generate_audio_stream(data: Dict = Body(...)):
  463. async with request_semaphore:
  464. text = data.get("text", "")
  465. voice = data.get("voice", "af_heart")
  466. speed = float(data.get("speed", 1.0))
  467. model = data.get("model_name", DEFAULT_MODEL_NAME)
  468. client_id = data.get("client_id", str(uuid.uuid4()))
  469. if not text.strip():
  470. raise HTTPException(status_code=400, detail="文本不能为空")
  471. if _is_english_voice(voice):
  472. parts = [graphemes for graphemes, _ in _phonemize_en_chunks_for_onnx(text)] if text.strip() else []
  473. else:
  474. parts = split_sentences(text) if text.strip() else []
  475. if not parts:
  476. parts = [text.strip()]
  477. if client_id in current_requests:
  478. current_requests[client_id]["interrupt"] = True
  479. await asyncio.sleep(0.05)
  480. current_requests[client_id] = {"interrupt": False}
  481. async def stream():
  482. try:
  483. for idx, part in enumerate(parts):
  484. if not part:
  485. continue
  486. if current_requests.get(client_id, {}).get("interrupt"):
  487. break
  488. audio = await asyncio.to_thread(
  489. synthesize_audio,
  490. text=part,
  491. voice=voice,
  492. speed=speed,
  493. model_name=model,
  494. )
  495. with io.BytesIO() as buf:
  496. with sf.SoundFile(
  497. buf,
  498. "w",
  499. sample_rate,
  500. channels=audio.shape[1] if audio.ndim > 1 else 1,
  501. format="WAV",
  502. subtype="PCM_16",
  503. ) as f:
  504. f.write(audio)
  505. audio_b64 = base64.b64encode(buf.getvalue()).decode()
  506. yield json.dumps(
  507. {
  508. "index": idx,
  509. "sentence": part,
  510. "audio": audio_b64,
  511. "sample_rate": sample_rate,
  512. }
  513. ).encode() + b"\n"
  514. finally:
  515. current_requests.pop(client_id, None)
  516. return StreamingResponse(stream(), media_type="application/x-ndjson")
  517. if __name__ == "__main__":
  518. import uvicorn
  519. uvicorn.run("speech_tts_onnx:app", host="0.0.0.0", port=18000, reload=False, workers=1)