speech_tts_onnx.py 22 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644
  1. from __future__ import annotations
  2. import base64
  3. import hashlib
  4. import io
  5. import json
  6. import logging
  7. import os
  8. import re
  9. import asyncio
  10. import threading
  11. import uuid
  12. import urllib.request
  13. from contextlib import asynccontextmanager
  14. from pathlib import Path
  15. from typing import Dict, List, Optional, Tuple
  16. import numpy as np
  17. import soundfile as sf
  18. from fastapi import Body, FastAPI, HTTPException, Query, Request
  19. from fastapi.middleware.cors import CORSMiddleware
  20. from fastapi.responses import StreamingResponse
  21. from pydantic import BaseModel, field_validator
  # ============================================================
  # Configuration (every value can be overridden via an environment variable)
  # ============================================================
  DEFAULT_MODEL_NAME = os.getenv("TTS_ONNX_MODEL_NAME", "model.onnx")
  # Directory holding the ONNX model and, by default, tokenizer/config/voice files.
  MODEL_DIR = os.getenv("TTS_ONNX_MODEL_DIR", "/home/tts-server/onnx")
  # Hugging Face repo used for voice download URLs and in error hints.
  HF_MODEL_ID = os.getenv("TTS_ONNX_HF_MODEL_ID", "onnx-community/Kokoro-82M-ONNX")
  TOKENIZER_PATH = os.getenv("TTS_ONNX_TOKENIZER_PATH", str(Path(MODEL_DIR) / "tokenizer.json"))
  CONFIG_PATH = os.getenv("TTS_ONNX_CONFIG_PATH", str(Path(MODEL_DIR) / "config.json"))
  # Directory of per-voice `<name>.bin` files (resolve_voice_path / _style_for_voice).
  VOICES_DIR = os.getenv("TTS_ONNX_VOICES_DIR", str(Path(MODEL_DIR) / "voices"))
  # Combined voices file handed to kokoro_onnx.Kokoro(voices_path=...) in load_kokoro_engine.
  VOICES_V1_PATH = os.getenv("TTS_ONNX_VOICES_V1_PATH", str(Path(MODEL_DIR) / "voices-v1.0.bin"))
  # Directory used by sentence_cache_path / meta_cache_path; created below at import time.
  CACHE_DIR = os.getenv("CACHE_DIR", "./audio_cache")
  LOG_LEVEL = os.getenv("LOG_LEVEL", "INFO").upper()
  # Size of request_semaphore (used by /generate).
  MAX_CONCURRENT_REQUESTS = int(os.getenv("MAX_CONCURRENT_REQUESTS", 8))
  # NOTE(review): MEMORY_CACHE_SIZE and DISK_CACHE_SIZE are not referenced anywhere in this file.
  MEMORY_CACHE_SIZE = int(os.getenv("MEMORY_CACHE_SIZE", 200))
  DISK_CACHE_SIZE = int(os.getenv("DISK_CACHE_SIZE", 500))
  # Sample rate written into every WAV this service emits (it is not read from the model).
  DEFAULT_SAMPLE_RATE = int(os.getenv("TTS_SAMPLE_RATE", 24000))
  # Optional mapping from user-facing voice names to real voice names; empty by default.
  VOICE_ALIASES: Dict[str, str] = {}
  # Unknown LOG_LEVEL names fall back to INFO.
  logging.basicConfig(level=getattr(logging, LOG_LEVEL, logging.INFO))
  logger = logging.getLogger("speech_tts_onnx")
  Path(CACHE_DIR).mkdir(parents=True, exist_ok=True)
  # Split points: whitespace that follows ".", "!", "?", "," or ":" (punctuation stays on the left piece).
  SENT_SPLIT_RE = re.compile(r"(?<=[.!?,:])\s+")
  # Lazily populated singletons.
  # NOTE(review): _TOKENIZER_CACHE is not referenced anywhere in this file.
  _TOKENIZER_CACHE = None
  _VOCAB_CACHE = None  # filled by get_vocab()
  _EN_G2P_PIPELINE = None  # filled by get_en_g2p_pipeline()
  _KOKORO_ONNX_ENGINE = None  # replaced by load_model()
  # ============================================================
  # Runtime state
  # ============================================================
  # Serializes model (re)loading inside load_model().
  model_lock = threading.Lock()
  # Held by synthesize_wav_bytes (the /tts endpoints) for a whole request; /generate does not take it.
  synthesis_lock = threading.Lock()
  # Concurrency limiter used by /generate (see generate_audio_stream for exactly what it guards).
  request_semaphore = asyncio.Semaphore(MAX_CONCURRENT_REQUESTS)
  # client_id -> per-request state dict, written by track_clients and generate_audio_stream;
  # the "interrupt" key lets a newer /generate call stop an older stream for the same client.
  current_requests: Dict[str, Dict] = {}
  # onnxruntime session for `model_name`; replaced by load_model().
  model_session = None
  model_name = DEFAULT_MODEL_NAME
  # Sample rate used when encoding output WAVs.
  sample_rate = DEFAULT_SAMPLE_RATE
  # ============================================================
  # Utilities
  # ============================================================
  def split_sentences(text: str) -> List[str]:
      """Split ``text`` into trimmed, non-empty pieces.

      Splits on the whitespace that follows ``.``, ``!``, ``?``, ``,`` or ``:``
      (see ``SENT_SPLIT_RE``); empty pieces are discarded.
      """
      raw_pieces = SENT_SPLIT_RE.split(text.strip())
      trimmed = (piece.strip() for piece in raw_pieces)
      return [piece for piece in trimmed if piece]
  def sentence_cache_key(sentence: str, voice: str, speed: float, model: str) -> str:
      """Return an MD5 hex digest identifying one (sentence, voice, speed, model) combination.

      The digest is only a cache key; it is not used for anything security-related.
      """
      fields = (sentence, voice, speed, model)
      joined = "|".join(f"{field}" for field in fields)
      digest = hashlib.md5(joined.encode("utf-8"))
      return digest.hexdigest()
  def sentence_cache_path(key: str) -> str:
      """Return the WAV file path for cache ``key`` inside ``CACHE_DIR``."""
      filename = key + ".wav"
      return os.path.join(CACHE_DIR, filename)
  def meta_cache_path(key: str) -> str:
      """Return the JSON metadata file path for cache ``key`` inside ``CACHE_DIR``."""
      filename = key + ".json"
      return os.path.join(CACHE_DIR, filename)
  def to_mono_numpy(audio) -> np.ndarray:
      """Coerce an arbitrary audio container into a 1-D float32 NumPy array.

      - ``None``, empty input, or anything ``np.asarray`` rejects -> empty float32 array.
      - 2-D input: a single row or single column is taken as-is; otherwise
        columns are averaged (axis 1 is treated as the channel axis).
      - 3-D or higher input is flattened; a scalar becomes a 1-element array.
      """
      if audio is None:
          return np.array([], dtype=np.float32)
      try:
          samples = np.asarray(audio)
      except Exception:
          return np.array([], dtype=np.float32)
      if samples.size == 0:
          return np.array([], dtype=np.float32)

      if samples.ndim == 2:
          rows, cols = samples.shape
          if rows == 1:
              samples = samples[0]
          elif cols == 1:
              samples = samples[:, 0]
          else:
              samples = samples.mean(axis=1)
      elif samples.ndim > 2:
          samples = samples.reshape(-1)
      elif samples.ndim == 0:
          samples = samples.reshape(1)

      if samples.dtype == np.float32:
          return samples
      return samples.astype(np.float32, copy=False)
  def read_wav_bytes(wav_path: str) -> Tuple[bytes, int]:
      """Read an audio file and re-encode it as 16-bit PCM WAV bytes.

      Returns:
          ``(wav_bytes, sample_rate)`` where ``sample_rate`` is the file's own rate.
      """
      data, rate = sf.read(wav_path)
      n_channels = 1 if data.ndim <= 1 else data.shape[1]
      out = io.BytesIO()
      with sf.SoundFile(
          out,
          "w",
          samplerate=rate,
          channels=n_channels,
          format="WAV",
          subtype="PCM_16",
      ) as wav_file:
          wav_file.write(data)
      return out.getvalue(), rate
  # ============================================================
  # Model loading
  # ============================================================
  def resolve_model_path(name: str) -> str:
      """Resolve a model file name to an existing path.

      ``name`` is first looked up under ``MODEL_DIR``. Joining a pathlib path
      with an absolute ``name`` yields ``name`` itself, so absolute paths are
      normally matched by that first check. The explicit absolute-path check
      is a second chance.

      NOTE(review): ``name`` can come straight from request data
      (``model_name``), so clients can point this at any existing file on the
      host; consider restricting it to ``MODEL_DIR``.

      Raises:
          FileNotFoundError: if no candidate exists.
      """
      candidate = Path(MODEL_DIR) / name
      if candidate.exists():
          return str(candidate)
      absolute_hit = os.path.isabs(name) and Path(name).exists()
      if absolute_hit:
          return name
      raise FileNotFoundError(f"找不到模型文件: {candidate}")
  def load_vocab() -> Dict[str, int]:
      """Load the symbol -> token-id vocabulary.

      The ``vocab`` mapping in Kokoro's ``config.json`` (``CONFIG_PATH``) is
      preferred. If that file is missing, unreadable or has an empty or non-dict
      vocab, fall back to ``model.vocab`` in ``tokenizer.json``
      (``TOKENIZER_PATH``).

      Raises:
          RuntimeError: if the tokenizer fallback is missing or unusable.
      """
      config_file = Path(CONFIG_PATH)
      tokenizer_file = Path(TOKENIZER_PATH)

      if config_file.exists():
          try:
              config_vocab = json.loads(config_file.read_text(encoding="utf-8"))["vocab"]
              if isinstance(config_vocab, dict) and config_vocab:
                  return {str(symbol): int(token_id) for symbol, token_id in config_vocab.items()}
          except Exception as exc:
              logger.warning("加载 config vocab 失败,回退 tokenizer vocab: %s", exc)

      if not tokenizer_file.exists():
          raise RuntimeError(
              f"缺少 vocab 文件: {config_file} / {tokenizer_file}. "
              f"请从 {HF_MODEL_ID} 或 hexgrad/Kokoro-82M 下载后放到模型目录。"
          )

      try:
          tokenizer_vocab = json.loads(tokenizer_file.read_text(encoding="utf-8"))["model"]["vocab"]
          if not isinstance(tokenizer_vocab, dict) or not tokenizer_vocab:
              raise ValueError("tokenizer vocab 为空或格式异常")
          return {str(symbol): int(token_id) for symbol, token_id in tokenizer_vocab.items()}
      except Exception as exc:
          raise RuntimeError(f"无法加载 tokenizer vocab: {tokenizer_file}. 错误: {exc}") from exc
  def get_vocab() -> Dict[str, int]:
      """Return the vocabulary, loading it on first use and memoizing it in ``_VOCAB_CACHE``.

      NOTE(review): the lazy initialization is not lock-protected, so concurrent
      first calls may each run load_vocab() (harmless apart from duplicate work).
      """
      global _VOCAB_CACHE
      cached = _VOCAB_CACHE
      if cached is None:
          cached = _VOCAB_CACHE = load_vocab()
      return cached
  def get_en_g2p_pipeline():
      """Return the lazily created Kokoro ``KPipeline`` used for English G2P.

      The pipeline is built once with ``lang_code="a"`` and ``model=False``
      (presumably G2P only, without the acoustic model; confirm against the kokoro
      docs). It is cached in ``_EN_G2P_PIPELINE``.
      """
      global _EN_G2P_PIPELINE
      if _EN_G2P_PIPELINE is not None:
          return _EN_G2P_PIPELINE
      from kokoro import KPipeline  # type: ignore
      _EN_G2P_PIPELINE = KPipeline(lang_code="a", repo_id="hexgrad/Kokoro-82M", model=False)
      return _EN_G2P_PIPELINE
  def resolve_voice_path(voice: str) -> Path:
      """Map a voice name to its ``.bin`` file under ``VOICES_DIR``.

      A blank name means ``af_heart``. Aliases from ``VOICE_ALIASES`` are applied,
      and ``.bin`` is appended when missing. If the file does not exist,
      ``af_bella.bin`` is used instead, with a warning.

      NOTE(review): ``voice`` is joined onto ``VOICES_DIR`` without rejecting path
      separators; validate it if it can come from untrusted input.

      Raises:
          FileNotFoundError: if neither the requested nor the fallback file exists.
      """
      requested = voice.strip() or "af_heart"
      resolved = VOICE_ALIASES.get(requested, requested)
      file_name = resolved if resolved.endswith(".bin") else f"{resolved}.bin"

      voices_dir = Path(VOICES_DIR)
      voice_path = voices_dir / file_name
      if voice_path.exists():
          return voice_path

      fallback_path = voices_dir / "af_bella.bin"
      if not fallback_path.exists():
          raise FileNotFoundError(
              f"找不到 voice 文件: {voice_path}. "
              "请从模型仓库下载 voices/*.bin 到本地 voices 目录。"
          )
      logger.warning("voice %s 不存在,回退到 %s", file_name, fallback_path.name)
      return fallback_path
  def _download_voice_file(voice_path: Path) -> Path:
      """Download ``voice_path.name`` from the ``HF_MODEL_ID`` repo into ``voice_path``.

      The payload goes to a ``<name>.bin.tmp`` sibling first and is then moved into
      place with ``Path.replace``. A partially written temp file is removed on
      failure.

      Returns:
          ``voice_path``.

      Raises:
          RuntimeError: if the server answered with an HTML page instead of binary data.
          OSError / urllib.error.URLError: on filesystem or network failure.
      """
      voice_name = voice_path.name
      raw_url = f"https://huggingface.co/{HF_MODEL_ID}/resolve/main/voices/{voice_name}"
      voice_path.parent.mkdir(parents=True, exist_ok=True)
      tmp_path = voice_path.with_suffix(voice_path.suffix + ".tmp")
      logger.info("从官方地址下载 voice 文件: %s", raw_url)
      with urllib.request.urlopen(raw_url, timeout=60) as response:
          content_type = response.headers.get_content_type()
          payload = response.read()

      # Sniff for HTML case-insensitively: pages usually start with "<!DOCTYPE html>"
      # in upper case, possibly after whitespace or a UTF-8 BOM.
      head = payload[:512].lstrip()
      if head.startswith(b"\xef\xbb\xbf"):
          head = head[3:].lstrip()
      looks_like_html = head.lower().startswith((b"<!doctype html", b"<html"))
      if content_type == "text/html" or looks_like_html:
          raise RuntimeError(
              f"下载到的仍是 HTML 页面而不是 voice 二进制文件: {raw_url}. "
              "请确认模型仓库的 voices 文件可直接通过 raw 链接访问。"
          )

      try:
          with open(tmp_path, "wb") as f:
              f.write(payload)
          tmp_path.replace(voice_path)
      finally:
          # After a successful replace the temp file is gone; otherwise drop the partial file.
          if tmp_path.exists():
              tmp_path.unlink()
      return voice_path
  196. def _is_html_payload(data: bytes) -> bool:
  197. return data.startswith(b"<!doctype html") or data.startswith(b"<html")
  def load_onnx_session(name: str):
      """Create a CPU-only onnxruntime ``InferenceSession`` for model ``name``.

      Raises:
          RuntimeError: if onnxruntime cannot be imported.
          FileNotFoundError: propagated from resolve_model_path.
      """
      try:
          import onnxruntime as ort  # type: ignore
      except Exception as e:
          raise RuntimeError(
              "无法导入 onnxruntime。请在部署环境中安装 onnxruntime。"
          ) from e
      model_path = resolve_model_path(name)
      session = ort.InferenceSession(model_path, providers=["CPUExecutionProvider"])
      logger.info("ONNX 模型加载完成: %s", model_path)
      return session
  def load_kokoro_engine(name: str):
      """Build a ``kokoro_onnx.Kokoro`` engine for model ``name``.

      The engine uses ``VOICES_V1_PATH`` for voices and, when it exists,
      ``CONFIG_PATH`` as the vocab config.

      Raises:
          RuntimeError: if kokoro_onnx cannot be imported.
          FileNotFoundError: propagated from resolve_model_path.
      """
      try:
          from kokoro_onnx import Kokoro  # type: ignore
      except Exception as e:
          raise RuntimeError("无法导入 kokoro_onnx。请在部署环境中安装 kokoro-onnx。") from e
      model_path = resolve_model_path(name)
      vocab_config = CONFIG_PATH if Path(CONFIG_PATH).exists() else None
      engine = Kokoro(
          model_path=model_path,
          voices_path=VOICES_V1_PATH,
          vocab_config=vocab_config,
      )
      logger.info("kokoro-onnx 引擎加载完成: %s", model_path)
      return engine
  def load_model(force_reload: bool = False, name: Optional[str] = None):
      """Ensure the ONNX session and kokoro-onnx engine for ``name`` are loaded; return the session.

      ``name=None`` means "the currently selected model". A reload happens when
      forced, when either object is missing, or when a different model is requested.
      Both objects are built before any global is replaced, so a failure while
      building the second one leaves the previously loaded pair (and ``model_name``)
      consistent.

      Raises:
          RuntimeError / FileNotFoundError: propagated from the loaders.
      """
      global model_session, model_name, _KOKORO_ONNX_ENGINE
      with model_lock:
          target_name = name or model_name
          needs_reload = (
              force_reload
              or model_session is None
              or _KOKORO_ONNX_ENGINE is None
              or target_name != model_name
          )
          if needs_reload:
              new_session = load_onnx_session(target_name)
              new_engine = load_kokoro_engine(target_name)
              # Commit all three together only after both loads succeeded.
              model_session = new_session
              _KOKORO_ONNX_ENGINE = new_engine
              model_name = target_name
          return model_session
  def get_kokoro_engine(name: Optional[str] = None):
      """Return the kokoro-onnx engine for ``name`` (default: current model), loading it if needed.

      NOTE(review): ``model_name`` and ``_KOKORO_ONNX_ENGINE`` are read without
      ``model_lock``, so a concurrent load_model() for another model can swap the
      engine between the check and the return.
      """
      requested = name or model_name
      missing_or_stale = _KOKORO_ONNX_ENGINE is None or requested != model_name
      if missing_or_stale:
          load_model(name=requested)
      return _KOKORO_ONNX_ENGINE
  @asynccontextmanager
  async def lifespan(app: FastAPI):
      """Application lifespan: load the default model at startup and log on shutdown.

      ``load_model()`` runs synchronously before the app starts serving. If it raises,
      the ``finally`` clause still logs the shutdown message and the error propagates.
      """
      shutdown_message = "应用关闭中..."
      try:
          load_model()
          yield
      finally:
          logger.info(shutdown_message)
  app = FastAPI(lifespan=lifespan, title="Online TTS Service (Kokoro ONNX)")
  # CORS: every origin, method and header is allowed, with credentials.
  # NOTE(review): wildcard origins combined with allow_credentials=True is very
  # permissive; check how Starlette's CORSMiddleware answers credentialed requests
  # in this mode and confirm it is intended for the deployment.
  app.add_middleware(
      CORSMiddleware,
      allow_credentials=True,
      allow_origins=["*"],
      allow_methods=["*"],
      allow_headers=["*"],
  )
  @app.middleware("http")
  async def track_clients(request: Request, call_next):
      """HTTP middleware that marks in-flight ``/generate`` requests in ``current_requests``.

      The client id is taken from the ``client_id`` query parameter, then the
      ``X-Client-ID`` header, and otherwise is a random UUID.

      An entry already present for that id (for example the state of a stream that
      generate_audio_stream registered) is left untouched. Only the marker this
      middleware inserted is removed afterwards, even when ``call_next`` raises.
      The middleware therefore never clobbers or deletes the endpoint's
      interrupt state.

      NOTE(review): for streaming responses ``call_next`` presumably returns once the
      response has started, so the marker is removed before the body finishes.
      """
      client_id = (
          request.query_params.get("client_id")
          or request.headers.get("X-Client-ID")
          or str(uuid.uuid4())
      )
      if request.url.path != "/generate":
          return await call_next(request)

      marker = {"active": True}
      # setdefault keeps an existing entry instead of overwriting it.
      inserted = current_requests.setdefault(client_id, marker) is marker
      try:
          return await call_next(request)
      finally:
          # The endpoint may have replaced our marker with its own state; only drop our own.
          if inserted and current_requests.get(client_id) is marker:
              current_requests.pop(client_id, None)
  class TTSRequest(BaseModel):
      """Request body for ``POST /tts``.

      ``voice`` and ``speed`` may be null; the endpoint then falls back to ``af_heart``
      and ``1.0``. ``model_name=None`` keeps whichever model is currently loaded.
      ``split_pattern`` is accepted but not read anywhere in this file.
      """

      # Allow the ``model_name`` field: pydantic v2 otherwise warns that it collides
      # with the protected ``model_`` namespace.
      model_config = {"protected_namespaces": ()}

      text: str
      voice: Optional[str] = "af_heart"
      speed: Optional[float] = 1.0
      split_pattern: Optional[str] = r"\n+"
      model_name: Optional[str] = None

      @field_validator("speed")
      @classmethod
      def validate_speed(cls, v):
          """Normalize ``speed``: None -> 1.0; reject non-finite or non-positive values."""
          import math

          if v is None:
              return 1.0
          v = float(v)
          # NaN compares False against 0 and would slip through ``v <= 0``; inf is unusable too.
          if not math.isfinite(v):
              raise ValueError("speed 必须是有限数值")
          if v <= 0:
              raise ValueError("speed 必须大于 0")
          return v
  def _encode_token_ids(text: str) -> List[int]:
      """Encode ``text`` character by character into vocab ids.

      Characters missing from the vocab map to the id of ``" "`` (or 16 when the
      vocab has no space entry).
      """
      lookup = get_vocab().get
      fallback_id = lookup(" ", 16)
      return list(map(lambda symbol: lookup(symbol, fallback_id), text))
  def _phonemize_en_chunks_for_onnx(text: str) -> List[Tuple[str, str]]:
      """Run English G2P on ``text`` and return ``(graphemes, phonemes)`` per chunk.

      Chunking comes from the Kokoro pipeline's ``en_tokenize``.

      Raises:
          RuntimeError: if the pipeline produced no chunks.
      """
      g2p = get_en_g2p_pipeline()
      _, g2p_tokens = g2p.g2p(text)
      pairs = [(graphemes, phonemes) for graphemes, phonemes, _ in g2p.en_tokenize(g2p_tokens)]
      if not pairs:
          raise RuntimeError("英文文本音素化失败,未生成 phonemes。")
      return pairs
  def _phonemize_en_for_onnx(text: str) -> str:
      """Return the phoneme string for the whole of ``text``.

      The G2P pipeline may split long input into several chunks. All non-empty
      chunk phoneme strings are joined with a single space, so no part of the
      text is silently dropped. Single-chunk input yields exactly that chunk's
      phonemes.
      """
      chunks = _phonemize_en_chunks_for_onnx(text)
      return " ".join(phonemes for _, phonemes in chunks if phonemes)
  def _tokenize_for_onnx(text: str, voice: str):
      """Encode ``text`` into the padded ``int64`` id array the ONNX model expects.

      English voices (see _is_english_voice) are phonemized first. Other voices
      are encoded character by character. The result has shape ``(1, n + 2)``
      with a 0 pad id on each side. Swap this out if a more robust tokenizer
      becomes available.

      Raises:
          RuntimeError: if encoding yields no ids, or more than 510 ids.
      """
      source = _phonemize_en_for_onnx(text) if _is_english_voice(voice) else text
      token_ids = _encode_token_ids(source)
      if not token_ids:
          raise RuntimeError("文本编码后得到空 tokens。请检查 tokenizer.json 是否正确。")
      # 510 ids + 2 pad ids = the model's 512-position limit.
      if len(token_ids) > 510:
          raise RuntimeError(
              f"文本过长,tokens={len(token_ids)},超过模型 512 上限。请拆句后再调用。"
          )
      return np.asarray([[0, *token_ids, 0]], dtype=np.int64)
  def _is_english_voice(voice: str) -> bool:
      """Return True if ``voice``, after alias resolution, has an af_/am_/bf_/bm_ prefix."""
      key = (voice or "").strip()
      canonical = VOICE_ALIASES.get(key, key)
      english_prefixes = ("af_", "am_", "bf_", "bm_")
      return canonical.startswith(english_prefixes)
  def _style_for_voice(voice: str) -> np.ndarray:
      """
      Load the style-vector table for ``voice`` from its per-voice ``.bin`` file.

      The ONNX model needs a style vector as input. This is a working placeholder
      implementation; it can later be replaced by a proper voice-embedding mapping.

      NOTE(review): this helper is not called anywhere in this file; synthesize_audio
      takes styles from the kokoro-onnx engine (engine.get_voice_style) instead.

      Fallback order when the file looks bad:
      - header is HTML: use af_bella.bin if it exists and is not the requested
        file, otherwise re-download the requested file;
      - reading/checking raised (including a failed re-download above): make the
        same choice again (fallback file or re-download);
      - float32 element count not a multiple of 256: re-download once.

      Returns:
          float32 array reshaped to ``(-1, 256)``.

      Raises:
          RuntimeError: if the file is empty, or still mis-sized after the re-download.
          FileNotFoundError: propagated from resolve_voice_path.
      """
      voice_path = resolve_voice_path(voice)
      if voice_path.suffix == ".bin":
          try:
              # Only the first 64 bytes are inspected (the whole file is read, then sliced).
              header = voice_path.read_bytes()[:64]
              if _is_html_payload(header):
                  fallback_path = Path(VOICES_DIR) / "af_bella.bin"
                  if voice_path.name != fallback_path.name and fallback_path.exists():
                      logger.warning("voice 文件 %s 是 HTML,占位回退到 %s", voice_path.name, fallback_path.name)
                      voice_path = fallback_path
                  else:
                      voice_path = _download_voice_file(voice_path)
          except Exception as e:
              # NOTE(review): the log says "re-download" even when the af_bella fallback is used.
              logger.warning("voice 文件检查失败,尝试直接重新下载: %s", e)
              fallback_path = Path(VOICES_DIR) / "af_bella.bin"
              if voice_path.name != fallback_path.name and fallback_path.exists():
                  voice_path = fallback_path
              else:
                  voice_path = _download_voice_file(voice_path)
      # Raw float32 dump; the 256-wide reshape below fixes the row width.
      style = np.fromfile(str(voice_path), dtype=np.float32)
      if style.size == 0:
          raise RuntimeError(f"voice 文件为空: {voice_path}")
      if style.size % 256 != 0:
          # Some repository files may be xet pointers or HTML pages; re-download once as a fallback.
          voice_path = _download_voice_file(voice_path)
          style = np.fromfile(str(voice_path), dtype=np.float32)
          if style.size == 0 or style.size % 256 != 0:
              raise RuntimeError(f"voice 文件维度异常: {voice_path}, size={style.size}")
      style = style.reshape(-1, 256)
      return style
  350. def _select_style_slice(style: np.ndarray, token_len: int) -> np.ndarray:
  351. if style.ndim != 2 or style.shape[-1] != 256:
  352. raise RuntimeError(f"style 维度异常: {style.shape}")
  353. # Follow the official ONNX example: ref_s = voices[len(tokens)]
  354. idx = min(max(token_len, 0), style.shape[0] - 1)
  355. return style[idx : idx + 1]
  356. def _prepare_style_input(session, style_slice: np.ndarray) -> Tuple[Optional[str], Optional[np.ndarray]]:
  357. for input_name in ("style", "ref_s"):
  358. for model_input in session.get_inputs():
  359. if model_input.name != input_name:
  360. continue
  361. input_shape = model_input.shape
  362. expected_rank = len(input_shape) if input_shape is not None else style_slice.ndim
  363. if expected_rank == 2:
  364. return input_name, style_slice.astype(np.float32, copy=False)
  365. if expected_rank == 3:
  366. return input_name, style_slice[:, np.newaxis, :].astype(np.float32, copy=False)
  367. raise RuntimeError(
  368. f"模型输入 {input_name} 的 rank 不受支持: shape={input_shape}"
  369. )
  370. return None, None
  def synthesize_audio(text: str, voice: str, speed: float, model_name: Optional[str] = None) -> np.ndarray:
      """Synthesize one piece of text and return a 1-D float32 waveform.

      The kokoro-onnx engine supplies phonemization, phoneme batching and the voice
      style table. Inference runs on the module's onnxruntime session.

      Args:
          text: text to speak (must not be blank).
          voice: a voice name known to the engine.
          speed: speed factor fed to the model.
          model_name: model to use; ``None`` keeps the currently loaded one.

      Raises:
          HTTPException(400): blank text, unknown voice, or phonemization produced nothing.
          HTTPException(500): the model returned no audio, or non-finite samples.
      """
      if not text.strip():
          raise HTTPException(status_code=400, detail="文本不能为空")
      engine = get_kokoro_engine(name=model_name)
      session = load_model(name=model_name)
      if voice not in set(engine.get_voices()):
          raise HTTPException(status_code=400, detail=f"不支持的 voice: {voice}")

      # Kokoro's British voices (bf_/bm_) expect British-English phonemes; others keep en-us.
      lang = "en-gb" if voice.startswith(("bf_", "bm_")) else "en-us"
      phonemes = engine.tokenizer.phonemize(text, lang)
      batched_phonemes = engine._split_phonemes(phonemes)
      if not batched_phonemes:
          raise HTTPException(status_code=400, detail="文本音素化失败")

      voice_style = engine.get_voice_style(voice)
      # The style table is indexed by token count (ref_s = voices[len(tokens)]);
      # clamp so a maximal batch cannot index one past its last row.
      last_style_idx = len(voice_style) - 1
      audio_segments: List[np.ndarray] = []
      for phoneme_batch in batched_phonemes:
          tokens = np.array(engine.tokenizer.tokenize(phoneme_batch), dtype=np.int64)
          if tokens.size == 0:
              continue
          style = voice_style[min(len(tokens), last_style_idx)]
          feeds = {
              "input_ids": np.asarray([[0, *tokens.tolist(), 0]], dtype=np.int64),
              "style": np.asarray(style, dtype=np.float32),
              "speed": np.asarray([speed], dtype=np.float32),
          }
          outputs = session.run(None, feeds)
          if outputs:
              audio_segments.append(to_mono_numpy(outputs[0]))

      if not audio_segments:
          raise HTTPException(status_code=500, detail="ONNX 推理未返回音频输出")
      audio = np.concatenate(audio_segments, axis=0) if len(audio_segments) > 1 else audio_segments[0]
      if audio.size == 0 or not np.isfinite(audio).all():
          raise HTTPException(status_code=500, detail="生成的音频无效")
      return audio
  def synthesize_wav_bytes(text: str, voice: str, speed: float, model_name: Optional[str] = None) -> io.BytesIO:
      """Synthesize ``text`` into a single in-memory 16-bit PCM WAV file.

      The text is split into parts: G2P chunks for English voices, split_sentences
      otherwise. Blank parts are skipped. Each part is synthesized while holding
      ``synthesis_lock``, and the waveforms are concatenated.

      Returns:
          A BytesIO positioned at 0 containing the WAV data.

      Raises:
          HTTPException(400): blank text or no audio produced (plus anything
          synthesize_audio raises).
      """
      if not text.strip():
          raise HTTPException(status_code=400, detail="文本不能为空")
      if _is_english_voice(voice):
          parts = [graphemes for graphemes, _ in _phonemize_en_chunks_for_onnx(text)]
      else:
          parts = split_sentences(text)
      # A whitespace-only chunk would make synthesize_audio raise 400 and abort the
      # whole request, so such chunks are skipped (as the streaming endpoint does).
      parts = [part for part in parts if part.strip()] or [text.strip()]

      segments: List[np.ndarray] = []
      with synthesis_lock:
          for part in parts:
              audio = synthesize_audio(part, voice=voice, speed=speed, model_name=model_name)
              if audio.size > 0:
                  segments.append(audio)
      if not segments:
          raise HTTPException(status_code=400, detail="未生成音频,请检查输入文本或参数。")

      audio_concat = np.concatenate(segments, axis=0)
      buf = io.BytesIO()
      sf.write(buf, audio_concat, samplerate=sample_rate, format="WAV", subtype="PCM_16")
      buf.seek(0)
      return buf
  @app.post("/tts", summary="POST: 传入文本返回 WAV 流")
  def tts_post(req: TTSRequest):
      """POST /tts: synthesize ``req.text`` and return the complete WAV as ``audio/wav``.

      A null ``voice`` or ``speed`` falls back to ``af_heart`` / ``1.0``.
      ``model_name=None`` keeps the currently loaded model.
      """
      chosen_voice = req.voice or "af_heart"
      chosen_speed = 1.0 if req.speed is None else req.speed
      wav = synthesize_wav_bytes(
          text=req.text,
          voice=chosen_voice,
          speed=chosen_speed,
          model_name=req.model_name,
      )
      disposition = {"Content-Disposition": 'inline; filename="tts.wav"'}
      return StreamingResponse(wav, media_type="audio/wav", headers=disposition)
  @app.get("/tts", summary="GET: 传入文本返回 WAV 流")
  def tts_get(
      text: str = Query(..., description="待合成文本"),
      voice: str = Query("af_heart"),
      speed: float = Query(1.0),
      model_name: str = Query(DEFAULT_MODEL_NAME),
  ):
      """GET /tts: synthesize ``text`` and return the complete WAV as ``audio/wav``.

      NOTE(review): ``model_name`` defaults to DEFAULT_MODEL_NAME here, whereas the POST
      variant defaults to None (the current model). A GET without ``model_name``
      therefore makes load_model switch back to the default model if another one
      is loaded.

      Raises:
          HTTPException(400): ``speed`` is not a finite number greater than 0.
      """
      import math

      # NaN compares False against 0 and would slip past a bare ``<= 0`` check.
      if speed is None or not math.isfinite(float(speed)) or float(speed) <= 0:
          raise HTTPException(status_code=400, detail="speed 必须为大于 0 的数值")
      buf = synthesize_wav_bytes(
          text=text,
          voice=voice,
          speed=float(speed),
          model_name=model_name,
      )
      return StreamingResponse(
          buf,
          media_type="audio/wav",
          headers={"Content-Disposition": 'inline; filename="tts.wav"'},
      )
  @app.post("/generate")
  async def generate_audio_stream(data: Dict = Body(...)):
      """Stream synthesized audio part by part as NDJSON.

      Body keys: ``text`` (required), ``voice`` (default ``af_heart``), ``speed``
      (default ``1.0``), ``model_name`` (default ``DEFAULT_MODEL_NAME``) and
      ``client_id`` (default: random UUID).

      Each output line is ``{"index", "sentence", "audio", "sample_rate"}``, where
      ``audio`` is a base64-encoded 16-bit PCM WAV of that part.

      A newer /generate call with the same ``client_id`` supersedes an in-flight
      stream. The older stream stops before synthesizing, or emitting, its next
      part.

      Raises:
          HTTPException(400): blank/non-string ``text``, or ``speed`` that is not a
          finite number greater than 0.
      """
      import math

      text = data.get("text", "")
      voice = data.get("voice", "af_heart")
      model = data.get("model_name", DEFAULT_MODEL_NAME)
      client_id = data.get("client_id", str(uuid.uuid4()))
      if not isinstance(text, str) or not text.strip():
          raise HTTPException(status_code=400, detail="文本不能为空")
      try:
          speed = float(data.get("speed", 1.0))
      except (TypeError, ValueError):
          speed = math.nan
      if not math.isfinite(speed) or speed <= 0:
          raise HTTPException(status_code=400, detail="speed 必须为大于 0 的数值")

      def _split_parts() -> List[str]:
          # Same chunking as the /tts path: G2P chunks for English voices, punctuation split otherwise.
          if _is_english_voice(voice):
              return [graphemes for graphemes, _ in _phonemize_en_chunks_for_onnx(text)]
          return split_sentences(text)

      # G2P is CPU-bound; keep it off the event loop.
      parts = await asyncio.to_thread(_split_parts)
      if not parts:
          parts = [text.strip()]

      # Supersede any in-flight stream for this client. Each stream keeps a reference to
      # its own state dict, so the flag survives other code replacing the registry entry.
      previous = current_requests.get(client_id)
      if isinstance(previous, dict):
          previous["interrupt"] = True
      state = {"interrupt": False}
      current_requests[client_id] = state

      def _superseded() -> bool:
          if state["interrupt"]:
              return True
          current = current_requests.get(client_id)
          # A different entry carrying an "interrupt" key belongs to a newer /generate call.
          return current is not None and current is not state and "interrupt" in current

      async def stream():
          try:
              for idx, part in enumerate(parts):
                  if not part.strip():
                      continue
                  if _superseded():
                      break
                  # Hold the semaphore around the actual synthesis. The response is
                  # streamed after the endpoint returns, so guarding only the setup
                  # would bound nothing.
                  async with request_semaphore:
                      audio = await asyncio.to_thread(
                          synthesize_audio,
                          text=part,
                          voice=voice,
                          speed=speed,
                          model_name=model,
                      )
                  if _superseded():
                      break
                  with io.BytesIO() as buf:
                      with sf.SoundFile(
                          buf,
                          "w",
                          sample_rate,
                          channels=audio.shape[1] if audio.ndim > 1 else 1,
                          format="WAV",
                          subtype="PCM_16",
                      ) as f:
                          f.write(audio)
                      audio_b64 = base64.b64encode(buf.getvalue()).decode()
                  yield json.dumps(
                      {
                          "index": idx,
                          "sentence": part,
                          "audio": audio_b64,
                          "sample_rate": sample_rate,
                      }
                  ).encode() + b"\n"
          finally:
              # Remove only our own registration; a newer stream may own the key now.
              if current_requests.get(client_id) is state:
                  current_requests.pop(client_id, None)

      return StreamingResponse(stream(), media_type="application/x-ndjson")
  if __name__ == "__main__":
      import uvicorn

      # Pass the app object instead of the "speech_tts_onnx:app" import string. With
      # reload=False and workers=1 uvicorn accepts an instance. This avoids importing
      # this file a second time under a hard-coded module name, which fails if the
      # file is renamed and re-runs all module-level setup.
      uvicorn.run(app, host="0.0.0.0", port=18000, reload=False, workers=1)