speech_tts_onnx_opt.py 20 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567
  1. from __future__ import annotations
  2. import asyncio
  3. import base64
  4. import concurrent.futures
  5. import hashlib
  6. import io
  7. import json
  8. import logging
  9. import os
  10. import re
  11. import threading
  12. import time
  13. import uuid
  14. import urllib.request
  15. from collections import OrderedDict
  16. from contextlib import asynccontextmanager
  17. from pathlib import Path
  18. from typing import Dict, List, Optional, Tuple
  19. import aiofiles
  20. import numpy as np
  21. import soundfile as sf
  22. from fastapi import Body, FastAPI, HTTPException, Query, Request
  23. from fastapi.middleware.cors import CORSMiddleware
  24. from fastapi.responses import StreamingResponse
  25. from pydantic import BaseModel, field_validator
# ---------------------------------------------------------------------------
# Runtime configuration. Each value can be overridden through the environment
# variable named in its os.getenv() call.
# ---------------------------------------------------------------------------

# Model file name; resolve_model_path() looks it up relative to MODEL_DIR.
DEFAULT_MODEL_NAME = os.getenv("TTS_ONNX_MODEL_NAME", "model.onnx")
# Directory that holds the ONNX model and, by default, its companion assets.
MODEL_DIR = os.getenv("TTS_ONNX_MODEL_DIR", "/home/tts-server/onnx")
# Hugging Face repository id used to build voice download URLs.
HF_MODEL_ID = os.getenv("TTS_ONNX_HF_MODEL_ID", "onnx-community/Kokoro-82M-ONNX")
# Passed to Kokoro as vocab_config, but only when the file exists.
CONFIG_PATH = os.getenv("TTS_ONNX_CONFIG_PATH", str(Path(MODEL_DIR) / "config.json"))
# NOTE(review): VOICES_DIR is not referenced by any code visible in this file.
VOICES_DIR = os.getenv("TTS_ONNX_VOICES_DIR", str(Path(MODEL_DIR) / "voices"))
# Passed to Kokoro as voices_path.
VOICES_V1_PATH = os.getenv("TTS_ONNX_VOICES_V1_PATH", str(Path(MODEL_DIR) / "voices-v1.0.bin"))
# Directory for the on-disk sentence cache (<key>.wav plus <key>.json).
CACHE_DIR = os.getenv("CACHE_DIR", "./audio_cache")
# Name of a logging level; unknown names fall back to INFO below.
LOG_LEVEL = os.getenv("LOG_LEVEL", "INFO").upper()
# Initial value of the request semaphore.
MAX_CONCURRENT_REQUESTS = int(os.getenv("MAX_CONCURRENT_REQUESTS", 8))
# Sample rate written into every WAV this service encodes.
DEFAULT_SAMPLE_RATE = int(os.getenv("TTS_SAMPLE_RATE", 24000))
# Entry lifetime for the in-memory cache, compared against time.time() deltas
# (i.e. seconds).
MEMORY_CACHE_TTL = int(os.getenv("MEMORY_CACHE_TTL", 18000))
# Maximum number of in-memory entries. Note the environment variable is named
# MEMORY_CACHE_SIZE, not MEMORY_CACHE_MAX_ITEMS.
MEMORY_CACHE_MAX_ITEMS = int(os.getenv("MEMORY_CACHE_SIZE", 120))
# Maximum accounted size of the in-memory cache, in bytes (default 128 MiB).
MEMORY_CACHE_MAX_BYTES = int(os.getenv("MEMORY_CACHE_MAX_BYTES", str(128 * 1024 * 1024)))
# Maximum number of .wav files kept on disk by clean_disk_cache().
DISK_CACHE_SIZE = int(os.getenv("DISK_CACHE_SIZE", 500))
# ONNX Runtime SessionOptions values, applied in load_onnx_session().
ORT_INTRA_OP_THREADS = int(os.getenv("ORT_INTRA_OP_THREADS", "2"))
ORT_INTER_OP_THREADS = int(os.getenv("ORT_INTER_OP_THREADS", "1"))
# Boolean flags: only the (case-insensitive) string "true" enables them.
ORT_ENABLE_CPU_MEM_ARENA = os.getenv("ORT_ENABLE_CPU_MEM_ARENA", "true").lower() == "true"
ORT_ENABLE_MEM_PATTERN = os.getenv("ORT_ENABLE_MEM_PATTERN", "false").lower() == "true"
# NOTE(review): empty and not referenced by any code visible in this file.
VOICE_ALIASES: Dict[str, str] = {}
logging.basicConfig(level=getattr(logging, LOG_LEVEL, logging.INFO))
logger = logging.getLogger("speech_tts_onnx_opt")
# Import-time side effect: make sure the disk cache directory exists.
Path(CACHE_DIR).mkdir(parents=True, exist_ok=True)
# Splits right after a punctuation mark (swallowing any following whitespace)
# or on a run of newlines.
# NOTE(review): the character class lists several characters twice; this may
# mean full-width punctuation was normalised to ASCII somewhere - verify
# against the intended pattern.
SENT_SPLIT_RE = re.compile(r"(?<=[。!?;.!?;::,,])\s*|\n+")
  49. class _PhonemizerWordCountMismatchFilter(logging.Filter):
  50. def filter(self, record: logging.LogRecord) -> bool:
  51. return "words count mismatch" not in record.getMessage()
  52. logging.getLogger("phonemizer").addFilter(_PhonemizerWordCountMismatchFilter())
  53. class MemoryAudioCache:
  54. def __init__(self, max_items: int, max_bytes: int, ttl: int):
  55. self.max_items = max_items
  56. self.max_bytes = max_bytes
  57. self.ttl = ttl
  58. self.lock = threading.Lock()
  59. self.items: "OrderedDict[str, dict]" = OrderedDict()
  60. self.total_bytes = 0
  61. def _entry_size(self, value: dict) -> int:
  62. return len(value.get("audio_bytes", b"")) + len(value.get("sentence", "").encode("utf-8")) + 128
  63. def _purge_expired(self, now: float):
  64. expired = [k for k, v in self.items.items() if now - v["ts"] > self.ttl]
  65. for key in expired:
  66. entry = self.items.pop(key, None)
  67. if entry:
  68. self.total_bytes -= entry["size"]
  69. def get(self, key: str) -> Optional[dict]:
  70. now = time.time()
  71. with self.lock:
  72. self._purge_expired(now)
  73. entry = self.items.get(key)
  74. if not entry:
  75. return None
  76. self.items.move_to_end(key)
  77. entry["ts"] = now
  78. return {
  79. "sentence": entry["sentence"],
  80. "sample_rate": entry["sample_rate"],
  81. "audio_bytes": entry["audio_bytes"],
  82. }
  83. def set(self, key: str, value: dict):
  84. now = time.time()
  85. with self.lock:
  86. self._purge_expired(now)
  87. old = self.items.pop(key, None)
  88. if old:
  89. self.total_bytes -= old["size"]
  90. entry = {
  91. "sentence": value["sentence"],
  92. "sample_rate": value["sample_rate"],
  93. "audio_bytes": value["audio_bytes"],
  94. "ts": now,
  95. }
  96. entry["size"] = self._entry_size(entry)
  97. self.items[key] = entry
  98. self.total_bytes += entry["size"]
  99. while self.items and (
  100. len(self.items) > self.max_items or self.total_bytes > self.max_bytes
  101. ):
  102. _, removed = self.items.popitem(last=False)
  103. self.total_bytes -= removed["size"]
  104. def clear(self):
  105. with self.lock:
  106. self.items.clear()
  107. self.total_bytes = 0
  108. def info(self) -> dict:
  109. now = time.time()
  110. with self.lock:
  111. self._purge_expired(now)
  112. return {"items": len(self.items), "bytes": self.total_bytes}
# ---------------------------------------------------------------------------
# Process-wide runtime state.
# ---------------------------------------------------------------------------

# Serialises model (re)loading inside load_model().
model_lock = threading.Lock()
# NOTE(review): constructed at import time, outside any running event loop.
# On Python versions before 3.10 asyncio primitives bind to a loop when they
# are created - confirm the target runtime.
request_semaphore = asyncio.Semaphore(MAX_CONCURRENT_REQUESTS)
# client_id -> small state dict; written by the track_clients middleware and by
# the /generate handler.
current_requests: Dict[str, Dict] = {}
memory_cache = MemoryAudioCache(
    max_items=MEMORY_CACHE_MAX_ITEMS,
    max_bytes=MEMORY_CACHE_MAX_BYTES,
    ttl=MEMORY_CACHE_TTL,
)
# Thread pool used for blocking soundfile reads/writes of the disk cache.
executor = concurrent.futures.ThreadPoolExecutor(max_workers=8)
# cache key -> in-flight synthesis task, so identical sentences requested
# concurrently are synthesised once; inflight_lock guards the dict.
inflight_lock = threading.Lock()
inflight_tasks: Dict[str, asyncio.Future] = {}
# Currently loaded ONNX session and the model name it was loaded for.
model_session = None
model_name = DEFAULT_MODEL_NAME
# Sample rate used when encoding generated audio.
sample_rate = DEFAULT_SAMPLE_RATE
# Lazily created singletons (see load_model() and get_en_g2p_pipeline()).
_KOKORO_ONNX_ENGINE = None
_EN_G2P_PIPELINE = None
  129. def split_sentences(text: str) -> List[str]:
  130. parts = [s.strip() for s in SENT_SPLIT_RE.split(text.strip()) if s.strip()]
  131. merged: List[str] = []
  132. buf = ""
  133. for part in parts:
  134. if len(part) < 3 and merged:
  135. merged[-1] = f"{merged[-1]} {part}".strip()
  136. else:
  137. merged.append(part)
  138. return merged
  139. def iter_text_parts(text: str, split_pattern: Optional[str]) -> List[str]:
  140. text = (text or "").strip()
  141. if not text:
  142. return []
  143. blocks = [text]
  144. if split_pattern:
  145. try:
  146. blocks = [p.strip() for p in re.split(split_pattern, text) if p.strip()]
  147. except re.error:
  148. logger.warning("split_pattern 非法,回退默认分句: %s", split_pattern)
  149. parts: List[str] = []
  150. for block in blocks:
  151. parts.extend(split_sentences(block) or [block])
  152. return parts
  153. def sentence_cache_key(sentence: str, voice: str, speed: float, model: str) -> str:
  154. raw = f"{sentence}|{voice}|{speed:.4f}|{model}"
  155. return hashlib.md5(raw.encode("utf-8")).hexdigest()
  156. def sentence_cache_path(key: str) -> str:
  157. return os.path.join(CACHE_DIR, f"{key}.wav")
  158. def meta_cache_path(key: str) -> str:
  159. return os.path.join(CACHE_DIR, f"{key}.json")
  160. def to_mono_numpy(audio) -> np.ndarray:
  161. arr = np.asarray(audio, dtype=np.float32)
  162. if arr.size == 0:
  163. return np.array([], dtype=np.float32)
  164. if arr.ndim == 2:
  165. if arr.shape[1] == 1:
  166. arr = arr[:, 0]
  167. elif arr.shape[0] == 1:
  168. arr = arr[0]
  169. else:
  170. arr = arr.mean(axis=1)
  171. elif arr.ndim > 2:
  172. arr = arr.reshape(-1)
  173. return arr.astype(np.float32, copy=False)
  174. def read_wav_bytes(wav_path: str) -> Tuple[bytes, int]:
  175. audio, sr = sf.read(wav_path)
  176. buf = io.BytesIO()
  177. with sf.SoundFile(buf, "w", samplerate=sr, channels=audio.shape[1] if audio.ndim > 1 else 1, format="WAV", subtype="PCM_16") as f:
  178. f.write(audio)
  179. return buf.getvalue(), sr
  180. async def load_sentence_from_disk(key: str) -> Optional[dict]:
  181. wav_path = sentence_cache_path(key)
  182. meta_path = meta_cache_path(key)
  183. if not Path(wav_path).exists() or not Path(meta_path).exists():
  184. return None
  185. async with aiofiles.open(meta_path, "r") as f:
  186. meta = json.loads(await f.read())
  187. wav_bytes, sr = await asyncio.get_event_loop().run_in_executor(executor, lambda: read_wav_bytes(wav_path))
  188. return {"sentence": meta["sentence"], "sample_rate": sr, "audio_bytes": wav_bytes}
  189. async def save_sentence_to_disk(key: str, audio: np.ndarray, sr: int, sentence: str):
  190. wav_path = sentence_cache_path(key)
  191. meta_path = meta_cache_path(key)
  192. await asyncio.get_event_loop().run_in_executor(executor, lambda: sf.write(wav_path, audio, sr, format="WAV"))
  193. async with aiofiles.open(meta_path, "w") as f:
  194. await f.write(json.dumps({"sentence": sentence, "sample_rate": sr}))
  195. async def clean_disk_cache():
  196. files = sorted(Path(CACHE_DIR).glob("*.wav"), key=os.path.getmtime)
  197. if len(files) <= DISK_CACHE_SIZE:
  198. return
  199. for f in files[: len(files) - DISK_CACHE_SIZE]:
  200. try:
  201. f.unlink()
  202. meta = Path(meta_cache_path(f.stem))
  203. if meta.exists():
  204. meta.unlink()
  205. except Exception:
  206. pass
  207. def resolve_model_path(name: str) -> str:
  208. local_path = Path(MODEL_DIR) / name
  209. if local_path.exists():
  210. return str(local_path)
  211. if os.path.isabs(name) and Path(name).exists():
  212. return name
  213. raise FileNotFoundError(f"找不到模型文件: {local_path}")
  214. def get_en_g2p_pipeline():
  215. global _EN_G2P_PIPELINE
  216. if _EN_G2P_PIPELINE is None:
  217. from kokoro import KPipeline # type: ignore
  218. _EN_G2P_PIPELINE = KPipeline(lang_code="a", repo_id="hexgrad/Kokoro-82M", model=False)
  219. return _EN_G2P_PIPELINE
  220. def _download_voice_file(voice_path: Path) -> Path:
  221. raw_url = f"https://huggingface.co/{HF_MODEL_ID}/resolve/main/voices/{voice_path.name}"
  222. voice_path.parent.mkdir(parents=True, exist_ok=True)
  223. tmp_path = voice_path.with_suffix(voice_path.suffix + ".tmp")
  224. with urllib.request.urlopen(raw_url, timeout=60) as response:
  225. payload = response.read()
  226. with open(tmp_path, "wb") as f:
  227. f.write(payload)
  228. tmp_path.replace(voice_path)
  229. return voice_path
  230. def load_onnx_session(name: str):
  231. try:
  232. import onnxruntime as ort # type: ignore
  233. except Exception as e:
  234. raise RuntimeError("无法导入 onnxruntime。请在部署环境中安装 onnxruntime。") from e
  235. model_path = resolve_model_path(name)
  236. sess_options = ort.SessionOptions()
  237. sess_options.intra_op_num_threads = ORT_INTRA_OP_THREADS
  238. sess_options.inter_op_num_threads = ORT_INTER_OP_THREADS
  239. sess_options.enable_cpu_mem_arena = ORT_ENABLE_CPU_MEM_ARENA
  240. sess_options.enable_mem_pattern = ORT_ENABLE_MEM_PATTERN
  241. return ort.InferenceSession(model_path, sess_options=sess_options, providers=["CPUExecutionProvider"])
  242. def load_kokoro_engine(name: str):
  243. try:
  244. from kokoro_onnx import Kokoro # type: ignore
  245. except Exception as e:
  246. raise RuntimeError("无法导入 kokoro_onnx。请在部署环境中安装 kokoro-onnx。") from e
  247. model_path = resolve_model_path(name)
  248. return Kokoro(model_path=model_path, voices_path=VOICES_V1_PATH, vocab_config=CONFIG_PATH if Path(CONFIG_PATH).exists() else None)
  249. def load_model(force_reload: bool = False, name: Optional[str] = None):
  250. global model_session, model_name, _KOKORO_ONNX_ENGINE
  251. with model_lock:
  252. target = name or model_name
  253. if force_reload or model_session is None or target != model_name:
  254. model_session = load_onnx_session(target)
  255. _KOKORO_ONNX_ENGINE = load_kokoro_engine(target)
  256. model_name = target
  257. logger.info("ONNX 模型加载完成: %s", resolve_model_path(target))
  258. return model_session
  259. def get_kokoro_engine(name: Optional[str] = None):
  260. global _KOKORO_ONNX_ENGINE
  261. target = name or model_name
  262. if _KOKORO_ONNX_ENGINE is None or target != model_name:
  263. load_model(name=target)
  264. return _KOKORO_ONNX_ENGINE
  265. @asynccontextmanager
  266. async def lifespan(app: FastAPI):
  267. load_model()
  268. yield
# ASGI application; the model is loaded through the lifespan hook.
app = FastAPI(title="Online TTS Service (Kokoro ONNX Optimized)", lifespan=lifespan)
# NOTE(review): wildcard origins/methods/headers are combined with
# allow_credentials=True. The FastAPI CORS documentation states that these
# wildcards cannot be used together with credentials and that explicit values
# must be listed - confirm whether credentials are needed at all.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
  277. @app.middleware("http")
  278. async def track_clients(request: Request, call_next):
  279. client_id = request.query_params.get("client_id") or request.headers.get("X-Client-ID") or str(uuid.uuid4())
  280. if request.url.path == "/generate":
  281. current_requests[client_id] = {"active": True}
  282. response = await call_next(request)
  283. if request.url.path == "/generate":
  284. current_requests.pop(client_id, None)
  285. return response
  286. class TTSRequest(BaseModel):
  287. text: str
  288. voice: Optional[str] = "af_heart"
  289. speed: Optional[float] = 1.0
  290. split_pattern: Optional[str] = r"\n+"
  291. model_name: Optional[str] = None
  292. @field_validator("speed")
  293. @classmethod
  294. def validate_speed(cls, v):
  295. if v is None:
  296. return 1.0
  297. v = float(v)
  298. if v <= 0:
  299. raise ValueError("speed 必须大于 0")
  300. return v
  301. def synthesize_audio(text: str, voice: str, speed: float, model_name: Optional[str] = None) -> np.ndarray:
  302. if not text.strip():
  303. raise HTTPException(status_code=400, detail="文本不能为空")
  304. engine = get_kokoro_engine(name=model_name)
  305. session = load_model(name=model_name)
  306. if voice not in set(engine.get_voices()):
  307. raise HTTPException(status_code=400, detail=f"不支持的 voice: {voice}")
  308. phonemes = engine.tokenizer.phonemize(text, "en-us")
  309. batched_phonemes = engine._split_phonemes(phonemes)
  310. if not batched_phonemes:
  311. raise HTTPException(status_code=400, detail="文本音素化失败")
  312. voice_style = engine.get_voice_style(voice)
  313. audio_segments: List[np.ndarray] = []
  314. for phoneme_batch in batched_phonemes:
  315. tokens = np.array(engine.tokenizer.tokenize(phoneme_batch), dtype=np.int64)
  316. if tokens.size == 0:
  317. continue
  318. feeds = {
  319. "input_ids": np.asarray([[0, *tokens.tolist(), 0]], dtype=np.int64),
  320. "style": np.asarray(voice_style[len(tokens)], dtype=np.float32),
  321. "speed": np.asarray([speed], dtype=np.float32),
  322. }
  323. outputs = session.run(None, feeds)
  324. if outputs:
  325. audio_segments.append(to_mono_numpy(outputs[0]))
  326. if not audio_segments:
  327. raise HTTPException(status_code=500, detail="ONNX 推理未返回音频输出")
  328. audio = np.concatenate(audio_segments, axis=0) if len(audio_segments) > 1 else audio_segments[0]
  329. if audio.size == 0 or not np.isfinite(audio).all():
  330. raise HTTPException(status_code=500, detail="生成的音频无效")
  331. return audio
  332. def encode_wav_bytes(audio: np.ndarray, sr: int) -> bytes:
  333. buf = io.BytesIO()
  334. with sf.SoundFile(buf, "w", samplerate=sr, channels=1, format="WAV", subtype="PCM_16") as f:
  335. f.write(audio)
  336. return buf.getvalue()
  337. def cache_item_to_response(item: dict) -> dict:
  338. return {
  339. "sentence": item["sentence"],
  340. "sample_rate": item["sample_rate"],
  341. "audio": base64.b64encode(item["audio_bytes"]).decode(),
  342. }
  343. async def _compute_sentence_item(sentence: str, voice: str, speed: float, model: str) -> dict:
  344. def _run():
  345. audio = synthesize_audio(sentence, voice=voice, speed=speed, model_name=model)
  346. wav_bytes = encode_wav_bytes(audio, sample_rate)
  347. return {"sentence": sentence, "sample_rate": sample_rate, "audio_bytes": wav_bytes}, audio
  348. item, audio = await asyncio.to_thread(_run)
  349. memory_cache.set(sentence_cache_key(sentence, voice, speed, model), item)
  350. await save_sentence_to_disk(sentence_cache_key(sentence, voice, speed, model), audio, sample_rate, sentence)
  351. await clean_disk_cache()
  352. return item
  353. async def get_or_create_sentence_cache_item(sentence: str, voice: str, speed: float, model: str) -> dict:
  354. key = sentence_cache_key(sentence, voice, speed, model)
  355. cached = memory_cache.get(key)
  356. if cached:
  357. return cached
  358. disk_item = await load_sentence_from_disk(key)
  359. if disk_item:
  360. memory_cache.set(key, disk_item)
  361. return disk_item
  362. loop = asyncio.get_running_loop()
  363. with inflight_lock:
  364. fut = inflight_tasks.get(key)
  365. if fut is None:
  366. fut = loop.create_task(_compute_sentence_item(sentence, voice, speed, model))
  367. inflight_tasks[key] = fut
  368. try:
  369. return await fut
  370. finally:
  371. with inflight_lock:
  372. if inflight_tasks.get(key) is fut:
  373. inflight_tasks.pop(key, None)
  374. def synthesize_wav_bytes(text: str, voice: str, speed: float, split_pattern: Optional[str], model_name: Optional[str] = None) -> io.BytesIO:
  375. parts = iter_text_parts(text, split_pattern)
  376. if not parts:
  377. raise HTTPException(status_code=400, detail="文本不能为空")
  378. buffers: List[np.ndarray] = []
  379. for part in parts:
  380. key = sentence_cache_key(part, voice, speed, model_name or DEFAULT_MODEL_NAME)
  381. cached = memory_cache.get(key)
  382. if cached:
  383. audio, _ = sf.read(io.BytesIO(cached["audio_bytes"]), dtype="float32")
  384. buffers.append(to_mono_numpy(audio))
  385. continue
  386. audio = synthesize_audio(part, voice=voice, speed=speed, model_name=model_name)
  387. buffers.append(audio)
  388. merged = np.concatenate(buffers, axis=0) if len(buffers) > 1 else buffers[0]
  389. buf = io.BytesIO()
  390. sf.write(buf, merged, samplerate=sample_rate, format="WAV", subtype="PCM_16")
  391. buf.seek(0)
  392. return buf
  393. @app.post("/tts", summary="POST: 传入文本返回 WAV 流")
  394. def tts_post(req: TTSRequest):
  395. buf = synthesize_wav_bytes(
  396. text=req.text,
  397. voice=req.voice or "af_heart",
  398. speed=req.speed if req.speed is not None else 1.0,
  399. split_pattern=req.split_pattern,
  400. model_name=req.model_name,
  401. )
  402. return StreamingResponse(buf, media_type="audio/wav", headers={"Content-Disposition": 'inline; filename="tts.wav"'})
  403. @app.get("/tts", summary="GET: 传入文本返回 WAV 流")
  404. def tts_get(
  405. text: str = Query(..., description="待合成文本"),
  406. voice: str = Query("af_heart"),
  407. speed: float = Query(1.0),
  408. model_name: str = Query(DEFAULT_MODEL_NAME),
  409. split_pattern: str = Query(r"\n+"),
  410. ):
  411. if speed is None or float(speed) <= 0:
  412. raise HTTPException(status_code=400, detail="speed 必须为大于 0 的数值")
  413. buf = synthesize_wav_bytes(text=text, voice=voice, speed=float(speed), split_pattern=split_pattern, model_name=model_name)
  414. return StreamingResponse(buf, media_type="audio/wav", headers={"Content-Disposition": 'inline; filename="tts.wav"'})
  415. @app.post("/generate")
  416. async def generate_audio_stream(data: Dict = Body(...)):
  417. async with request_semaphore:
  418. text = data.get("text", "")
  419. voice = data.get("voice", "af_heart")
  420. speed = float(data.get("speed", 1.0))
  421. model = data.get("model_name", DEFAULT_MODEL_NAME)
  422. split_pattern = data.get("split_pattern", r"\n+")
  423. client_id = data.get("client_id", str(uuid.uuid4()))
  424. if not text.strip():
  425. raise HTTPException(status_code=400, detail="文本不能为空")
  426. parts = iter_text_parts(text, split_pattern)
  427. if client_id in current_requests:
  428. current_requests[client_id]["interrupt"] = True
  429. await asyncio.sleep(0.05)
  430. current_requests[client_id] = {"interrupt": False}
  431. async def stream():
  432. try:
  433. for idx, part in enumerate(parts):
  434. if current_requests.get(client_id, {}).get("interrupt"):
  435. break
  436. item = await get_or_create_sentence_cache_item(part, voice, speed, model)
  437. resp = cache_item_to_response(item)
  438. yield json.dumps({"index": idx, **resp}).encode() + b"\n"
  439. finally:
  440. current_requests.pop(client_id, None)
  441. return StreamingResponse(stream(), media_type="application/x-ndjson")
  442. @app.get("/clear-cache")
  443. async def clear_cache():
  444. memory_cache.clear()
  445. for f in Path(CACHE_DIR).glob("*"):
  446. f.unlink()
  447. return {"status": "success"}
  448. @app.get("/cache-info")
  449. async def get_cache_info():
  450. mem = memory_cache.info()
  451. disk_files = list(Path(CACHE_DIR).glob("*.wav"))
  452. return {"memory_cache": mem["items"], "memory_cache_bytes": mem["bytes"], "disk_cache": len(disk_files)}
  453. if __name__ == "__main__":
  454. import uvicorn
  455. uvicorn.run("speech_tts_onnx_opt:app", host="0.0.0.0", port=18000, reload=False, workers=1)