# speech_tts_onnx.py
  1. from __future__ import annotations
  2. import base64
  3. import hashlib
  4. import io
  5. import json
  6. import logging
  7. import os
  8. import re
  9. import asyncio
  10. import threading
  11. import uuid
  12. import urllib.request
  13. from contextlib import asynccontextmanager
  14. from pathlib import Path
  15. from typing import Dict, List, Optional, Tuple
  16. import numpy as np
  17. import soundfile as sf
  18. import aiofiles
  19. from fastapi import Body, FastAPI, HTTPException, Query, Request
  20. from fastapi.middleware.cors import CORSMiddleware
  21. from fastapi.responses import StreamingResponse
  22. from pydantic import BaseModel, field_validator
  23. from cachetools import TTLCache
  24. import concurrent.futures
  25. # ============================================================
  26. # 配置 model_quantized
  27. # ============================================================
  28. DEFAULT_MODEL_NAME = os.getenv("TTS_ONNX_MODEL_NAME", "model.onnx")
  29. MODEL_DIR = os.getenv("TTS_ONNX_MODEL_DIR", "/home/tts-server/onnx")
  30. HF_MODEL_ID = os.getenv("TTS_ONNX_HF_MODEL_ID", "onnx-community/Kokoro-82M-ONNX")
  31. TOKENIZER_PATH = os.getenv("TTS_ONNX_TOKENIZER_PATH", str(Path(MODEL_DIR) / "tokenizer.json"))
  32. CONFIG_PATH = os.getenv("TTS_ONNX_CONFIG_PATH", str(Path(MODEL_DIR) / "config.json"))
  33. VOICES_DIR = os.getenv("TTS_ONNX_VOICES_DIR", str(Path(MODEL_DIR) / "voices"))
  34. VOICES_V1_PATH = os.getenv("TTS_ONNX_VOICES_V1_PATH", str(Path(MODEL_DIR) / "voices-v1.0.bin"))
  35. CACHE_DIR = os.getenv("CACHE_DIR", "./audio_cache")
  36. LOG_LEVEL = os.getenv("LOG_LEVEL", "INFO").upper()
  37. MAX_CONCURRENT_REQUESTS = int(os.getenv("MAX_CONCURRENT_REQUESTS", 8))
  38. MEMORY_CACHE_SIZE = int(os.getenv("MEMORY_CACHE_SIZE", 200))
  39. DISK_CACHE_SIZE = int(os.getenv("DISK_CACHE_SIZE", 500))
  40. DEFAULT_SAMPLE_RATE = int(os.getenv("TTS_SAMPLE_RATE", 24000))
  41. MEMORY_CACHE_TTL = int(os.getenv("MEMORY_CACHE_TTL", 18000))
  42. ORT_INTRA_OP_THREADS = int(os.getenv("ORT_INTRA_OP_THREADS", "2"))
  43. ORT_INTER_OP_THREADS = int(os.getenv("ORT_INTER_OP_THREADS", "1"))
  44. ORT_ENABLE_CPU_MEM_ARENA = os.getenv("ORT_ENABLE_CPU_MEM_ARENA", "true").lower() == "true"
  45. ORT_ENABLE_MEM_PATTERN = os.getenv("ORT_ENABLE_MEM_PATTERN", "true").lower() == "true"
  46. VOICE_ALIASES = {}
  47. logging.basicConfig(level=getattr(logging, LOG_LEVEL, logging.INFO))
  48. logger = logging.getLogger("speech_tts_onnx")
  49. class _PhonemizerWordCountMismatchFilter(logging.Filter):
  50. def filter(self, record: logging.LogRecord) -> bool:
  51. return "words count mismatch" not in record.getMessage()
  52. logging.getLogger("phonemizer").addFilter(_PhonemizerWordCountMismatchFilter())
  53. Path(CACHE_DIR).mkdir(parents=True, exist_ok=True)
  54. SENT_SPLIT_RE = re.compile(r"(?<=[.!?,:])\s+")
  55. _TOKENIZER_CACHE = None
  56. _VOCAB_CACHE = None
  57. _EN_G2P_PIPELINE = None
  58. _KOKORO_ONNX_ENGINE = None
  59. # ============================================================
  60. # 运行时状态
  61. # ============================================================
  62. model_lock = threading.Lock()
  63. synthesis_lock = threading.Lock()
  64. cache_lock = threading.Lock()
  65. request_semaphore = asyncio.Semaphore(MAX_CONCURRENT_REQUESTS)
  66. current_requests: Dict[str, Dict] = {}
  67. memory_cache = TTLCache(maxsize=MEMORY_CACHE_SIZE, ttl=MEMORY_CACHE_TTL)
  68. executor = concurrent.futures.ThreadPoolExecutor(max_workers=8)
  69. model_session = None
  70. model_name = DEFAULT_MODEL_NAME
  71. sample_rate = DEFAULT_SAMPLE_RATE
  72. # ============================================================
  73. # 工具
  74. # ============================================================
  75. def split_sentences(text: str) -> List[str]:
  76. return [s.strip() for s in SENT_SPLIT_RE.split(text.strip()) if s.strip()]
  77. def sentence_cache_key(sentence: str, voice: str, speed: float, model: str) -> str:
  78. raw = f"{sentence}|{voice}|{speed}|{model}"
  79. return hashlib.md5(raw.encode("utf-8")).hexdigest()
  80. def sentence_cache_path(key: str) -> str:
  81. return os.path.join(CACHE_DIR, f"{key}.wav")
  82. def meta_cache_path(key: str) -> str:
  83. return os.path.join(CACHE_DIR, f"{key}.json")
  84. def put_memory_cache(key: str, value: dict):
  85. with cache_lock:
  86. memory_cache[key] = value
  87. def get_memory_cache(key: str) -> Optional[dict]:
  88. with cache_lock:
  89. return memory_cache.get(key)
  90. def to_mono_numpy(audio) -> np.ndarray:
  91. if audio is None:
  92. return np.array([], dtype=np.float32)
  93. if isinstance(audio, np.ndarray):
  94. arr = audio
  95. else:
  96. try:
  97. arr = np.asarray(audio)
  98. except Exception:
  99. return np.array([], dtype=np.float32)
  100. arr = np.asarray(arr)
  101. if arr.size == 0:
  102. return np.array([], dtype=np.float32)
  103. if arr.ndim == 2:
  104. if arr.shape[0] == 1:
  105. arr = arr[0]
  106. elif arr.shape[1] == 1:
  107. arr = arr[:, 0]
  108. else:
  109. arr = arr.mean(axis=1)
  110. elif arr.ndim > 2:
  111. arr = arr.reshape(-1)
  112. if arr.ndim == 0:
  113. arr = arr.reshape(1)
  114. if arr.dtype != np.float32:
  115. arr = arr.astype(np.float32, copy=False)
  116. return arr
  117. def read_wav_bytes(wav_path: str) -> Tuple[bytes, int]:
  118. audio, sr = sf.read(wav_path)
  119. buf = io.BytesIO()
  120. with sf.SoundFile(
  121. buf,
  122. "w",
  123. samplerate=sr,
  124. channels=audio.shape[1] if audio.ndim > 1 else 1,
  125. format="WAV",
  126. subtype="PCM_16",
  127. ) as f:
  128. f.write(audio)
  129. return buf.getvalue(), sr
  130. async def load_sentence_from_disk(key: str) -> Optional[dict]:
  131. wav_path = sentence_cache_path(key)
  132. meta_path = meta_cache_path(key)
  133. if not Path(wav_path).exists() or not Path(meta_path).exists():
  134. return None
  135. async with aiofiles.open(meta_path, "r") as f:
  136. meta_text = await f.read()
  137. meta = json.loads(meta_text)
  138. wav_bytes, sr = await asyncio.get_event_loop().run_in_executor(
  139. executor, lambda: read_wav_bytes(wav_path)
  140. )
  141. meta["audio_bytes"] = wav_bytes
  142. meta["audio"] = base64.b64encode(wav_bytes).decode()
  143. meta["sample_rate"] = sr
  144. return meta
  145. async def save_sentence_to_disk(key: str, audio: np.ndarray, sr: int, sentence: str):
  146. wav_path = sentence_cache_path(key)
  147. meta_path = meta_cache_path(key)
  148. await asyncio.get_event_loop().run_in_executor(
  149. executor, lambda: sf.write(wav_path, audio, sr, format="WAV")
  150. )
  151. async with aiofiles.open(meta_path, "w") as f:
  152. await f.write(json.dumps({"sentence": sentence, "sample_rate": sr}))
  153. async def clean_disk_cache():
  154. files = sorted(Path(CACHE_DIR).glob("*.wav"), key=os.path.getmtime)
  155. if len(files) <= DISK_CACHE_SIZE:
  156. return
  157. remove_n = len(files) - DISK_CACHE_SIZE
  158. for f in files[:remove_n]:
  159. try:
  160. f.unlink()
  161. json_path = meta_cache_path(f.stem)
  162. if Path(json_path).exists():
  163. Path(json_path).unlink()
  164. except Exception:
  165. pass
  166. def encode_wav_bytes(audio: np.ndarray, sr: int) -> bytes:
  167. buf = io.BytesIO()
  168. with sf.SoundFile(
  169. buf,
  170. "w",
  171. samplerate=sr,
  172. channels=audio.shape[1] if audio.ndim > 1 else 1,
  173. format="WAV",
  174. subtype="PCM_16",
  175. ) as f:
  176. f.write(audio)
  177. return buf.getvalue()
  178. def wav_to_pcm_payload(wav_bytes: bytes) -> bytes:
  179. data, _ = sf.read(io.BytesIO(wav_bytes), dtype="int16")
  180. return np.asarray(data, dtype=np.int16).tobytes()
  181. # ============================================================
  182. # 模型加载
  183. # ============================================================
  184. def resolve_model_path(name: str) -> str:
  185. local_path = Path(MODEL_DIR) / name
  186. if local_path.exists():
  187. return str(local_path)
  188. if os.path.isabs(name) and Path(name).exists():
  189. return name
  190. raise FileNotFoundError(
  191. f"找不到模型文件: {local_path}"
  192. )
  193. def load_vocab() -> Dict[str, int]:
  194. """
  195. 优先使用 Kokoro 官方 config.json 的 vocab。
  196. 如果本地缺失,则回退到 tokenizer.json 里的 vocab。
  197. """
  198. config_file = Path(CONFIG_PATH)
  199. if config_file.exists():
  200. try:
  201. data = json.loads(config_file.read_text(encoding="utf-8"))
  202. vocab = data["vocab"]
  203. if isinstance(vocab, dict) and vocab:
  204. return {str(k): int(v) for k, v in vocab.items()}
  205. except Exception as e:
  206. logger.warning("加载 config vocab 失败,回退 tokenizer vocab: %s", e)
  207. tokenizer_file = Path(TOKENIZER_PATH)
  208. if not tokenizer_file.exists():
  209. raise RuntimeError(
  210. f"缺少 vocab 文件: {config_file} / {tokenizer_file}. "
  211. f"请从 {HF_MODEL_ID} 或 hexgrad/Kokoro-82M 下载后放到模型目录。"
  212. )
  213. try:
  214. data = json.loads(tokenizer_file.read_text(encoding="utf-8"))
  215. vocab = data["model"]["vocab"]
  216. if not isinstance(vocab, dict) or not vocab:
  217. raise ValueError("tokenizer vocab 为空或格式异常")
  218. return {str(k): int(v) for k, v in vocab.items()}
  219. except Exception as e:
  220. raise RuntimeError(f"无法加载 tokenizer vocab: {tokenizer_file}. 错误: {e}") from e
  221. def get_vocab() -> Dict[str, int]:
  222. global _VOCAB_CACHE
  223. if _VOCAB_CACHE is None:
  224. _VOCAB_CACHE = load_vocab()
  225. return _VOCAB_CACHE
  226. def get_en_g2p_pipeline():
  227. global _EN_G2P_PIPELINE
  228. if _EN_G2P_PIPELINE is None:
  229. from kokoro import KPipeline # type: ignore
  230. _EN_G2P_PIPELINE = KPipeline(lang_code="a", repo_id="hexgrad/Kokoro-82M", model=False)
  231. return _EN_G2P_PIPELINE
  232. def resolve_voice_path(voice: str) -> Path:
  233. voice_name = voice.strip()
  234. if not voice_name:
  235. voice_name = "af_heart"
  236. voice_name = VOICE_ALIASES.get(voice_name, voice_name)
  237. if not voice_name.endswith(".bin"):
  238. voice_name = f"{voice_name}.bin"
  239. voice_path = Path(VOICES_DIR) / voice_name
  240. if voice_path.exists():
  241. return voice_path
  242. fallback_path = Path(VOICES_DIR) / "af_bella.bin"
  243. if fallback_path.exists():
  244. logger.warning("voice %s 不存在,回退到 %s", voice_name, fallback_path.name)
  245. return fallback_path
  246. raise FileNotFoundError(
  247. f"找不到 voice 文件: {voice_path}. "
  248. "请从模型仓库下载 voices/*.bin 到本地 voices 目录。"
  249. )
  250. def _download_voice_file(voice_path: Path) -> Path:
  251. voice_name = voice_path.name
  252. raw_url = f"https://huggingface.co/{HF_MODEL_ID}/resolve/main/voices/{voice_name}"
  253. voice_path.parent.mkdir(parents=True, exist_ok=True)
  254. tmp_path = voice_path.with_suffix(voice_path.suffix + ".tmp")
  255. logger.info("从官方地址下载 voice 文件: %s", raw_url)
  256. with urllib.request.urlopen(raw_url, timeout=60) as response:
  257. content_type = response.headers.get_content_type()
  258. payload = response.read()
  259. if content_type == "text/html" or payload.startswith(b"<!doctype html") or payload.startswith(b"<html"):
  260. raise RuntimeError(
  261. f"下载到的仍是 HTML 页面而不是 voice 二进制文件: {raw_url}. "
  262. "请确认模型仓库的 voices 文件可直接通过 raw 链接访问。"
  263. )
  264. with open(tmp_path, "wb") as f:
  265. f.write(payload)
  266. tmp_path.replace(voice_path)
  267. return voice_path
  268. def _is_html_payload(data: bytes) -> bool:
  269. return data.startswith(b"<!doctype html") or data.startswith(b"<html")
  270. def load_onnx_session(name: str):
  271. try:
  272. import onnxruntime as ort # type: ignore
  273. except Exception as e:
  274. raise RuntimeError(
  275. "无法导入 onnxruntime。请在部署环境中安装 onnxruntime。"
  276. ) from e
  277. model_path = resolve_model_path(name)
  278. providers = ["CPUExecutionProvider"]
  279. sess_options = ort.SessionOptions()
  280. sess_options.intra_op_num_threads = ORT_INTRA_OP_THREADS
  281. sess_options.inter_op_num_threads = ORT_INTER_OP_THREADS
  282. sess_options.enable_cpu_mem_arena = ORT_ENABLE_CPU_MEM_ARENA
  283. sess_options.enable_mem_pattern = ORT_ENABLE_MEM_PATTERN
  284. session = ort.InferenceSession(model_path, sess_options=sess_options, providers=providers)
  285. logger.info("ONNX 模型加载完成: %s", model_path)
  286. return session
  287. def load_kokoro_engine(name: str):
  288. try:
  289. from kokoro_onnx import Kokoro # type: ignore
  290. except Exception as e:
  291. raise RuntimeError("无法导入 kokoro_onnx。请在部署环境中安装 kokoro-onnx。") from e
  292. model_path = resolve_model_path(name)
  293. engine = Kokoro(
  294. model_path=model_path,
  295. voices_path=VOICES_V1_PATH,
  296. vocab_config=CONFIG_PATH if Path(CONFIG_PATH).exists() else None,
  297. )
  298. logger.info("kokoro-onnx 引擎加载完成: %s", model_path)
  299. return engine
  300. def load_model(force_reload: bool = False, name: Optional[str] = None):
  301. global model_session, model_name, _KOKORO_ONNX_ENGINE
  302. with model_lock:
  303. target_name = name or model_name
  304. if force_reload or model_session is None or target_name != model_name:
  305. model_session = load_onnx_session(target_name)
  306. _KOKORO_ONNX_ENGINE = load_kokoro_engine(target_name)
  307. model_name = target_name
  308. return model_session
  309. def get_kokoro_engine(name: Optional[str] = None):
  310. global _KOKORO_ONNX_ENGINE
  311. target_name = name or model_name
  312. if _KOKORO_ONNX_ENGINE is None or target_name != model_name:
  313. load_model(name=target_name)
  314. return _KOKORO_ONNX_ENGINE
  315. @asynccontextmanager
  316. async def lifespan(app: FastAPI):
  317. try:
  318. load_model()
  319. yield
  320. finally:
  321. logger.info("应用关闭中...")
  322. app = FastAPI(title="Online TTS Service (Kokoro ONNX)", lifespan=lifespan)
  323. app.add_middleware(
  324. CORSMiddleware,
  325. allow_origins=["*"],
  326. allow_credentials=True,
  327. allow_methods=["*"],
  328. allow_headers=["*"],
  329. )
  330. @app.middleware("http")
  331. async def track_clients(request: Request, call_next):
  332. client_id = (
  333. request.query_params.get("client_id")
  334. or request.headers.get("X-Client-ID")
  335. or str(uuid.uuid4())
  336. )
  337. if request.url.path == "/generate":
  338. current_requests[client_id] = {"active": True}
  339. response = await call_next(request)
  340. if request.url.path == "/generate":
  341. current_requests.pop(client_id, None)
  342. return response
  343. class TTSRequest(BaseModel):
  344. text: str
  345. voice: Optional[str] = "af_heart"
  346. speed: Optional[float] = 1.0
  347. split_pattern: Optional[str] = r"\n+"
  348. model_name: Optional[str] = None
  349. @field_validator("speed")
  350. @classmethod
  351. def validate_speed(cls, v):
  352. if v is None:
  353. return 1.0
  354. v = float(v)
  355. if v <= 0:
  356. raise ValueError("speed 必须大于 0")
  357. return v
  358. def _encode_token_ids(text: str) -> List[int]:
  359. vocab = get_vocab()
  360. unknown_id = vocab.get(" ", 16)
  361. return [vocab.get(ch, unknown_id) for ch in text]
  362. def _phonemize_en_chunks_for_onnx(text: str) -> List[Tuple[str, str]]:
  363. pipeline = get_en_g2p_pipeline()
  364. _, tokens = pipeline.g2p(text)
  365. chunks = list(pipeline.en_tokenize(tokens))
  366. if not chunks:
  367. raise RuntimeError("英文文本音素化失败,未生成 phonemes。")
  368. return [(graphemes, phonemes) for graphemes, phonemes, _ in chunks]
  369. def _phonemize_en_for_onnx(text: str) -> str:
  370. return _phonemize_en_chunks_for_onnx(text)[0][1]
  371. def _tokenize_for_onnx(text: str, voice: str):
  372. """
  373. 将文本转成 ONNX 所需 token。
  374. 这里优先复用 kokoro 的预处理链。
  375. 如果你项目里有更稳定的 tokenizer,可以替换这个函数。
  376. """
  377. normalized_voice = VOICE_ALIASES.get((voice or "").strip(), (voice or "").strip())
  378. if normalized_voice.startswith(("af_", "am_", "bf_", "bm_")):
  379. phonemes = _phonemize_en_for_onnx(text)
  380. tokens = _encode_token_ids(phonemes)
  381. else:
  382. tokens = _encode_token_ids(text)
  383. if not tokens:
  384. raise RuntimeError("文本编码后得到空 tokens。请检查 tokenizer.json 是否正确。")
  385. if len(tokens) > 510:
  386. raise RuntimeError(
  387. f"文本过长,tokens={len(tokens)},超过模型 512 上限。请拆句后再调用。"
  388. )
  389. return np.asarray([[0, *tokens, 0]], dtype=np.int64)
  390. def _is_english_voice(voice: str) -> bool:
  391. normalized_voice = VOICE_ALIASES.get((voice or "").strip(), (voice or "").strip())
  392. return normalized_voice.startswith(("af_", "am_", "bf_", "bm_"))
  393. def _style_for_voice(voice: str) -> np.ndarray:
  394. """
  395. ONNX 模型需要 style 向量。
  396. 这里先提供一个可运行的占位实现,后续可以替换成正式的 voice embedding 映射。
  397. """
  398. voice_path = resolve_voice_path(voice)
  399. if voice_path.suffix == ".bin":
  400. try:
  401. header = voice_path.read_bytes()[:64]
  402. if _is_html_payload(header):
  403. fallback_path = Path(VOICES_DIR) / "af_bella.bin"
  404. if voice_path.name != fallback_path.name and fallback_path.exists():
  405. logger.warning("voice 文件 %s 是 HTML,占位回退到 %s", voice_path.name, fallback_path.name)
  406. voice_path = fallback_path
  407. else:
  408. voice_path = _download_voice_file(voice_path)
  409. except Exception as e:
  410. logger.warning("voice 文件检查失败,尝试直接重新下载: %s", e)
  411. fallback_path = Path(VOICES_DIR) / "af_bella.bin"
  412. if voice_path.name != fallback_path.name and fallback_path.exists():
  413. voice_path = fallback_path
  414. else:
  415. voice_path = _download_voice_file(voice_path)
  416. style = np.fromfile(str(voice_path), dtype=np.float32)
  417. if style.size == 0:
  418. raise RuntimeError(f"voice 文件为空: {voice_path}")
  419. if style.size % 256 != 0:
  420. # 某些仓库文件可能是 xet pointer 或 HTML 页面,重新下载一次兜底。
  421. voice_path = _download_voice_file(voice_path)
  422. style = np.fromfile(str(voice_path), dtype=np.float32)
  423. if style.size == 0 or style.size % 256 != 0:
  424. raise RuntimeError(f"voice 文件维度异常: {voice_path}, size={style.size}")
  425. style = style.reshape(-1, 256)
  426. return style
  427. def _select_style_slice(style: np.ndarray, token_len: int) -> np.ndarray:
  428. if style.ndim != 2 or style.shape[-1] != 256:
  429. raise RuntimeError(f"style 维度异常: {style.shape}")
  430. # Follow the official ONNX example: ref_s = voices[len(tokens)]
  431. idx = min(max(token_len, 0), style.shape[0] - 1)
  432. return style[idx : idx + 1]
  433. def _prepare_style_input(session, style_slice: np.ndarray) -> Tuple[Optional[str], Optional[np.ndarray]]:
  434. for input_name in ("style", "ref_s"):
  435. for model_input in session.get_inputs():
  436. if model_input.name != input_name:
  437. continue
  438. input_shape = model_input.shape
  439. expected_rank = len(input_shape) if input_shape is not None else style_slice.ndim
  440. if expected_rank == 2:
  441. return input_name, style_slice.astype(np.float32, copy=False)
  442. if expected_rank == 3:
  443. return input_name, style_slice[:, np.newaxis, :].astype(np.float32, copy=False)
  444. raise RuntimeError(
  445. f"模型输入 {input_name} 的 rank 不受支持: shape={input_shape}"
  446. )
  447. return None, None
  448. def synthesize_audio(text: str, voice: str, speed: float, model_name: Optional[str] = None) -> np.ndarray:
  449. if not text.strip():
  450. raise HTTPException(status_code=400, detail="文本不能为空")
  451. engine = get_kokoro_engine(name=model_name)
  452. session = load_model(name=model_name)
  453. available_voices = set(engine.get_voices())
  454. if voice not in available_voices:
  455. raise HTTPException(status_code=400, detail=f"不支持的 voice: {voice}")
  456. phonemes = engine.tokenizer.phonemize(text, "en-us")
  457. batched_phonemes = engine._split_phonemes(phonemes)
  458. if not batched_phonemes:
  459. raise HTTPException(status_code=400, detail="文本音素化失败")
  460. voice_style = engine.get_voice_style(voice)
  461. audio_segments: List[np.ndarray] = []
  462. for phoneme_batch in batched_phonemes:
  463. tokens = np.array(engine.tokenizer.tokenize(phoneme_batch), dtype=np.int64)
  464. if tokens.size == 0:
  465. continue
  466. style = voice_style[len(tokens)]
  467. feeds = {
  468. "input_ids": np.asarray([[0, *tokens.tolist(), 0]], dtype=np.int64),
  469. "style": np.asarray(style, dtype=np.float32),
  470. "speed": np.asarray([speed], dtype=np.float32),
  471. }
  472. outputs = session.run(None, feeds)
  473. if outputs:
  474. audio_segments.append(to_mono_numpy(outputs[0]))
  475. if not audio_segments:
  476. raise HTTPException(status_code=500, detail="ONNX 推理未返回音频输出")
  477. audio = np.concatenate(audio_segments, axis=0) if len(audio_segments) > 1 else audio_segments[0]
  478. if audio.size == 0 or not np.isfinite(audio).all():
  479. raise HTTPException(status_code=500, detail="生成的音频无效")
  480. return audio
  481. def synthesize_wav_bytes(text: str, voice: str, speed: float, model_name: Optional[str] = None) -> io.BytesIO:
  482. audio = synthesize_audio(text=text, voice=voice, speed=speed, model_name=model_name)
  483. buf = io.BytesIO()
  484. sf.write(buf, audio, samplerate=sample_rate, format="WAV", subtype="PCM_16")
  485. buf.seek(0)
  486. return buf
  487. def iter_text_parts(text: str, split_pattern: Optional[str] = None) -> List[str]:
  488. text = (text or "").strip()
  489. if not text:
  490. return []
  491. blocks = [text]
  492. if split_pattern:
  493. try:
  494. blocks = [part.strip() for part in re.split(split_pattern, text) if part.strip()]
  495. except re.error:
  496. logger.warning("split_pattern 非法,回退默认分句: %s", split_pattern)
  497. parts: List[str] = []
  498. for block in blocks:
  499. sentences = split_sentences(block)
  500. if sentences:
  501. parts.extend(sentences)
  502. elif block:
  503. parts.append(block)
  504. return parts
  505. async def get_or_create_sentence_cache_item(
  506. sentence: str,
  507. voice: str,
  508. speed: float,
  509. model_name: str,
  510. ) -> dict:
  511. key = sentence_cache_key(sentence, voice, speed, model_name)
  512. cached_item = get_memory_cache(key)
  513. if cached_item:
  514. return cached_item
  515. disk_item = await load_sentence_from_disk(key)
  516. if disk_item:
  517. put_memory_cache(key, disk_item)
  518. return disk_item
  519. def _synthesize_item():
  520. audio = synthesize_audio(sentence, voice=voice, speed=speed, model_name=model_name)
  521. wav_bytes = encode_wav_bytes(audio, sample_rate)
  522. return {
  523. "sentence": sentence,
  524. "sample_rate": sample_rate,
  525. "audio_bytes": wav_bytes,
  526. "audio": base64.b64encode(wav_bytes).decode(),
  527. }, audio
  528. item, audio = await asyncio.to_thread(_synthesize_item)
  529. put_memory_cache(key, item)
  530. await save_sentence_to_disk(key, audio, sample_rate, sentence)
  531. await clean_disk_cache()
  532. return item
  533. @app.post("/tts", summary="POST: 传入文本返回 WAV 流")
  534. def tts_post(req: TTSRequest):
  535. buf = synthesize_wav_bytes(
  536. text=req.text,
  537. voice=req.voice or "af_heart",
  538. speed=req.speed if req.speed is not None else 1.0,
  539. model_name=req.model_name,
  540. )
  541. return StreamingResponse(
  542. buf,
  543. media_type="audio/wav",
  544. headers={"Content-Disposition": 'inline; filename="tts.wav"'},
  545. )
  546. @app.get("/tts", summary="GET: 传入文本返回 WAV 流")
  547. def tts_get(
  548. text: str = Query(..., description="待合成文本"),
  549. voice: str = Query("af_heart"),
  550. speed: float = Query(1.0),
  551. model_name: str = Query(DEFAULT_MODEL_NAME),
  552. ):
  553. if speed is None or float(speed) <= 0:
  554. raise HTTPException(status_code=400, detail="speed 必须为大于 0 的数值")
  555. buf = synthesize_wav_bytes(
  556. text=text,
  557. voice=voice,
  558. speed=float(speed),
  559. model_name=model_name,
  560. )
  561. return StreamingResponse(
  562. buf,
  563. media_type="audio/wav",
  564. headers={"Content-Disposition": 'inline; filename="tts.wav"'},
  565. )
  566. @app.post("/generate")
  567. async def generate_audio_stream(data: Dict = Body(...)):
  568. async with request_semaphore:
  569. text = data.get("text", "")
  570. voice = data.get("voice", "af_heart")
  571. speed = float(data.get("speed", 1.0))
  572. model = data.get("model_name", DEFAULT_MODEL_NAME)
  573. client_id = data.get("client_id", str(uuid.uuid4()))
  574. if not text.strip():
  575. raise HTTPException(status_code=400, detail="文本不能为空")
  576. parts = iter_text_parts(text, data.get("split_pattern", r"\n+"))
  577. if client_id in current_requests:
  578. current_requests[client_id]["interrupt"] = True
  579. await asyncio.sleep(0.05)
  580. current_requests[client_id] = {"interrupt": False}
  581. async def stream():
  582. try:
  583. for idx, part in enumerate(parts):
  584. if current_requests.get(client_id, {}).get("interrupt"):
  585. break
  586. item = await get_or_create_sentence_cache_item(part, voice, speed, model)
  587. yield json.dumps(
  588. {
  589. "index": idx,
  590. "sentence": item["sentence"],
  591. "audio": item["audio"],
  592. "sample_rate": item["sample_rate"],
  593. }
  594. ).encode() + b"\n"
  595. finally:
  596. current_requests.pop(client_id, None)
  597. return StreamingResponse(stream(), media_type="application/x-ndjson")
  598. @app.post("/generate_pcm", summary="POST: 返回更轻量的 PCM 分片流")
  599. async def generate_audio_pcm_stream(data: Dict = Body(...)):
  600. async with request_semaphore:
  601. text = data.get("text", "")
  602. voice = data.get("voice", "af_heart")
  603. speed = float(data.get("speed", 1.0))
  604. model = data.get("model_name", DEFAULT_MODEL_NAME)
  605. if not text.strip():
  606. raise HTTPException(status_code=400, detail="文本不能为空")
  607. parts = iter_text_parts(text, data.get("split_pattern", r"\n+"))
  608. async def pcm_stream():
  609. for idx, part in enumerate(parts):
  610. item = await get_or_create_sentence_cache_item(part, voice, speed, model)
  611. payload = wav_to_pcm_payload(item["audio_bytes"])
  612. header = json.dumps(
  613. {
  614. "index": idx,
  615. "sentence": item["sentence"],
  616. "sample_rate": item["sample_rate"],
  617. "format": "s16le",
  618. "bytes": len(payload),
  619. }
  620. ).encode() + b"\n"
  621. yield header
  622. yield payload
  623. return StreamingResponse(
  624. pcm_stream(),
  625. media_type="application/octet-stream",
  626. headers={"X-Audio-Format": "s16le", "X-Sample-Rate": str(sample_rate)},
  627. )
  628. @app.get("/clear-cache")
  629. async def clear_cache():
  630. with cache_lock:
  631. memory_cache.clear()
  632. for f in Path(CACHE_DIR).glob("*"):
  633. f.unlink()
  634. return {"status": "success"}
  635. @app.get("/cache-info")
  636. async def get_cache_info():
  637. with cache_lock:
  638. mem_count = len(memory_cache)
  639. disk_files = list(Path(CACHE_DIR).glob("*.wav"))
  640. return {"memory_cache": mem_count, "disk_cache": len(disk_files)}
  641. if __name__ == "__main__":
  642. import uvicorn
  643. uvicorn.run("speech_tts_onnx:app", host="0.0.0.0", port=18000, reload=False, workers=1)