| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567 |
- from __future__ import annotations
- import asyncio
- import base64
- import concurrent.futures
- import hashlib
- import io
- import json
- import logging
- import os
- import re
- import threading
- import time
- import uuid
- import urllib.request
- from collections import OrderedDict
- from contextlib import asynccontextmanager
- from pathlib import Path
- from typing import Dict, List, Optional, Tuple
- import aiofiles
- import numpy as np
- import soundfile as sf
- from fastapi import Body, FastAPI, HTTPException, Query, Request
- from fastapi.middleware.cors import CORSMiddleware
- from fastapi.responses import StreamingResponse
- from pydantic import BaseModel, field_validator
def _env_int(name: str, default) -> int:
    """Return env var ``name`` parsed with ``int()``; ``default`` is parsed when it is unset."""
    return int(os.getenv(name, default))


def _env_flag(name: str, default: str) -> bool:
    """Return True only when env var ``name`` (or ``default``) equals "true", case-insensitively."""
    return os.getenv(name, default).lower() == "true"


# --- model and asset locations ---------------------------------------------
DEFAULT_MODEL_NAME = os.getenv("TTS_ONNX_MODEL_NAME", "model.onnx")
MODEL_DIR = os.getenv("TTS_ONNX_MODEL_DIR", "/home/tts-server/onnx")
HF_MODEL_ID = os.getenv("TTS_ONNX_HF_MODEL_ID", "onnx-community/Kokoro-82M-ONNX")
CONFIG_PATH = os.getenv("TTS_ONNX_CONFIG_PATH", str(Path(MODEL_DIR) / "config.json"))
VOICES_DIR = os.getenv("TTS_ONNX_VOICES_DIR", str(Path(MODEL_DIR) / "voices"))
VOICES_V1_PATH = os.getenv("TTS_ONNX_VOICES_V1_PATH", str(Path(MODEL_DIR) / "voices-v1.0.bin"))

# --- service tuning -----------------------------------------------------------
CACHE_DIR = os.getenv("CACHE_DIR", "./audio_cache")
LOG_LEVEL = os.getenv("LOG_LEVEL", "INFO").upper()
MAX_CONCURRENT_REQUESTS = _env_int("MAX_CONCURRENT_REQUESTS", 8)
DEFAULT_SAMPLE_RATE = _env_int("TTS_SAMPLE_RATE", 24000)
MEMORY_CACHE_TTL = _env_int("MEMORY_CACHE_TTL", 18000)
# Note: the env var is MEMORY_CACHE_SIZE, the constant MEMORY_CACHE_MAX_ITEMS.
MEMORY_CACHE_MAX_ITEMS = _env_int("MEMORY_CACHE_SIZE", 120)
MEMORY_CACHE_MAX_BYTES = _env_int("MEMORY_CACHE_MAX_BYTES", str(128 * 1024 * 1024))
DISK_CACHE_SIZE = _env_int("DISK_CACHE_SIZE", 500)

# --- onnxruntime session options --------------------------------------------
ORT_INTRA_OP_THREADS = _env_int("ORT_INTRA_OP_THREADS", "2")
ORT_INTER_OP_THREADS = _env_int("ORT_INTER_OP_THREADS", "1")
ORT_ENABLE_CPU_MEM_ARENA = _env_flag("ORT_ENABLE_CPU_MEM_ARENA", "true")
ORT_ENABLE_MEM_PATTERN = _env_flag("ORT_ENABLE_MEM_PATTERN", "false")

VOICE_ALIASES: Dict[str, str] = {}

# Unknown LOG_LEVEL names fall back to INFO.
logging.basicConfig(level=getattr(logging, LOG_LEVEL, logging.INFO))
logger = logging.getLogger("speech_tts_onnx_opt")
Path(CACHE_DIR).mkdir(parents=True, exist_ok=True)

# Split after sentence/clause punctuation (eating following whitespace) or at newline runs.
# "." is in the class, so decimals such as "3.14" are split as well.
# NOTE(review): "!?;:," each appear twice in the class; presumably one copy was meant
# to be the full-width (CJK) form and was lost to character normalization — confirm.
SENT_SPLIT_RE = re.compile(r"(?<=[。!?;.!?;::,,])\s*|\n+")
- class _PhonemizerWordCountMismatchFilter(logging.Filter):
- def filter(self, record: logging.LogRecord) -> bool:
- return "words count mismatch" not in record.getMessage()
- logging.getLogger("phonemizer").addFilter(_PhonemizerWordCountMismatchFilter())
class MemoryAudioCache:
    """Thread-safe in-memory LRU cache of encoded sentence audio.

    Bounded by item count (``max_items``), approximate total size
    (``max_bytes``) and idle lifetime (``ttl`` seconds since last set/get).
    Stored values are dicts with ``sentence``, ``sample_rate`` and
    ``audio_bytes``; ``get`` returns a fresh dict with those three keys.
    """

    def __init__(self, max_items: int, max_bytes: int, ttl: int):
        self.max_items = max_items
        self.max_bytes = max_bytes
        self.ttl = ttl
        self.lock = threading.Lock()
        # Least- to most-recently used. Every insert/hit moves the entry to the
        # end and stamps it with a monotonic time taken while holding the lock,
        # so the order is also oldest-to-newest "ts" (used by _purge_expired).
        self.items: "OrderedDict[str, dict]" = OrderedDict()
        self.total_bytes = 0

    def _entry_size(self, value: dict) -> int:
        """Approximate memory cost: audio bytes + UTF-8 sentence + fixed 128-byte overhead."""
        return len(value.get("audio_bytes", b"")) + len(value.get("sentence", "").encode("utf-8")) + 128

    def _purge_expired(self, now: float):
        """Drop entries idle for more than ``ttl`` seconds. Caller must hold ``self.lock``."""
        # Entries are ordered by ts, so stop at the first one that is still live.
        while self.items:
            key, entry = next(iter(self.items.items()))
            if now - entry["ts"] <= self.ttl:
                break
            del self.items[key]
            self.total_bytes -= entry["size"]

    def get(self, key: str) -> Optional[dict]:
        """Return a copy of the entry for ``key`` (refreshing its LRU position and TTL), or None."""
        with self.lock:
            # Monotonic clock: wall-clock adjustments must not expire or pin entries.
            now = time.monotonic()
            self._purge_expired(now)
            entry = self.items.get(key)
            if entry is None:
                return None
            self.items.move_to_end(key)
            entry["ts"] = now
            return {
                "sentence": entry["sentence"],
                "sample_rate": entry["sample_rate"],
                "audio_bytes": entry["audio_bytes"],
            }

    def set(self, key: str, value: dict):
        """Insert or replace ``key``, then evict LRU entries until both limits hold.

        An entry larger than ``max_bytes`` by itself is evicted immediately.
        Raises KeyError if ``value`` lacks ``sentence``/``sample_rate``/``audio_bytes``.
        """
        with self.lock:
            now = time.monotonic()
            self._purge_expired(now)
            old = self.items.pop(key, None)
            if old is not None:
                self.total_bytes -= old["size"]
            entry = {
                "sentence": value["sentence"],
                "sample_rate": value["sample_rate"],
                "audio_bytes": value["audio_bytes"],
                "ts": now,
            }
            entry["size"] = self._entry_size(entry)
            self.items[key] = entry
            self.total_bytes += entry["size"]
            while self.items and (
                len(self.items) > self.max_items or self.total_bytes > self.max_bytes
            ):
                _, removed = self.items.popitem(last=False)
                self.total_bytes -= removed["size"]

    def clear(self):
        """Remove every entry."""
        with self.lock:
            self.items.clear()
            self.total_bytes = 0

    def info(self) -> dict:
        """Return ``{"items": count, "bytes": approximate size}`` after purging expired entries."""
        with self.lock:
            self._purge_expired(time.monotonic())
            return {"items": len(self.items), "bytes": self.total_bytes}
# --- model state ---------------------------------------------------------------
# model_session / model_name / _KOKORO_ONNX_ENGINE are replaced by load_model()
# while holding model_lock; sample_rate is the rate used to encode output audio.
model_lock = threading.Lock()
model_session = None
model_name = DEFAULT_MODEL_NAME
sample_rate = DEFAULT_SAMPLE_RATE
_KOKORO_ONNX_ENGINE = None
_EN_G2P_PIPELINE = None  # lazily created by get_en_g2p_pipeline()

# --- request bookkeeping --------------------------------------------------------
# NOTE(review): on Python < 3.10 asyncio primitives bind to the event loop that is
# current at construction; creating this at import time assumes 3.10+ — confirm.
request_semaphore = asyncio.Semaphore(MAX_CONCURRENT_REQUESTS)
# client_id -> per-request state dict (e.g. {"interrupt": bool}) used by /generate.
current_requests: Dict[str, Dict] = {}

# --- caching and background work ------------------------------------------------
memory_cache = MemoryAudioCache(
    ttl=MEMORY_CACHE_TTL,
    max_items=MEMORY_CACHE_MAX_ITEMS,
    max_bytes=MEMORY_CACHE_MAX_BYTES,
)
executor = concurrent.futures.ThreadPoolExecutor(max_workers=8)
# cache key -> task currently synthesizing that sentence (dedups concurrent work).
inflight_lock = threading.Lock()
inflight_tasks: Dict[str, asyncio.Future] = {}
def split_sentences(text: str) -> List[str]:
    """Split ``text`` into stripped, non-empty sentence chunks using ``SENT_SPLIT_RE``.

    A chunk shorter than 3 characters is appended (space-separated) to the
    previous chunk so tiny fragments are not synthesized on their own; a short
    first chunk is kept as-is because there is nothing to attach it to.
    """
    merged: List[str] = []
    for piece in SENT_SPLIT_RE.split(text.strip()):
        part = piece.strip()
        if not part:
            continue
        if len(part) < 3 and merged:
            merged[-1] = f"{merged[-1]} {part}".strip()
        else:
            merged.append(part)
    return merged
def iter_text_parts(text: str, split_pattern: Optional[str]) -> List[str]:
    """Break ``text`` into the sentence-sized parts that get synthesized one by one.

    The text is first split into blocks with the caller's ``split_pattern``
    (a regex), then each block is split further by ``split_sentences``.
    An invalid pattern is logged and ignored (the whole text is one block).
    A pattern whose split leaves nothing non-empty (e.g. one that matches the
    whole text) is ignored the same way, so non-empty input always yields parts.
    Returns ``[]`` only for empty/whitespace-only ``text``.
    """
    text = (text or "").strip()
    if not text:
        return []
    blocks = [text]
    if split_pattern:
        # NOTE(security): split_pattern comes straight from request payloads; a
        # pathological regex can backtrack catastrophically (ReDoS). Consider
        # restricting or bounding user-supplied patterns.
        try:
            pieces = re.split(split_pattern, text)
        except (re.error, TypeError):
            logger.warning("split_pattern 非法,回退默认分句: %s", split_pattern)
        else:
            # Unmatched capture groups show up as None in re.split() output.
            split_blocks = [p.strip() for p in pieces if p and p.strip()]
            if split_blocks:
                blocks = split_blocks
    parts: List[str] = []
    for block in blocks:
        parts.extend(split_sentences(block) or [block])
    return parts
def sentence_cache_key(sentence: str, voice: str, speed: float, model: str) -> str:
    """Return the hex cache key for one synthesized sentence.

    The key covers the sentence text, voice, speed (rounded to 4 decimals) and
    model name, so any change to those produces a different cache entry.
    """
    raw = f"{sentence}|{voice}|{speed:.4f}|{model}"
    # MD5 only names cache files; usedforsecurity=False keeps this working on
    # FIPS-restricted OpenSSL builds that refuse plain md5(). Digest is unchanged.
    return hashlib.md5(raw.encode("utf-8"), usedforsecurity=False).hexdigest()
def sentence_cache_path(key: str) -> str:
    """Return the WAV file path for cache ``key`` inside ``CACHE_DIR``."""
    filename = key + ".wav"
    return os.path.join(CACHE_DIR, filename)
def meta_cache_path(key: str) -> str:
    """Return the JSON metadata path for cache ``key`` inside ``CACHE_DIR``."""
    filename = key + ".json"
    return os.path.join(CACHE_DIR, filename)
def to_mono_numpy(audio) -> np.ndarray:
    """Convert array-like ``audio`` to a float32 mono signal.

    Empty input gives an empty 1-D array. A 2-D array with a single column or a
    single row is squeezed to 1-D; any other 2-D array is averaged across
    axis 1 (assumes frames x channels layout). Arrays with more than two
    dimensions are flattened.
    """
    samples = np.asarray(audio, dtype=np.float32)
    if samples.size == 0:
        return np.array([], dtype=np.float32)
    if samples.ndim > 2:
        samples = samples.reshape(-1)
    elif samples.ndim == 2:
        rows, cols = samples.shape
        if cols == 1:
            samples = samples[:, 0]
        elif rows == 1:
            samples = samples[0]
        else:
            samples = samples.mean(axis=1)
    return samples.astype(np.float32, copy=False)
def read_wav_bytes(wav_path: str) -> Tuple[bytes, int]:
    """Read the audio file at ``wav_path`` and re-encode it as an in-memory 16-bit PCM WAV.

    Returns ``(wav_bytes, sample_rate)``; the channel count of the file is preserved.
    """
    samples, rate = sf.read(wav_path)
    n_channels = 1 if samples.ndim <= 1 else samples.shape[1]
    out = io.BytesIO()
    with sf.SoundFile(out, mode="w", samplerate=rate, channels=n_channels, format="WAV", subtype="PCM_16") as wav:
        wav.write(samples)
    return out.getvalue(), rate
async def load_sentence_from_disk(key: str) -> Optional[dict]:
    """Load the disk-cached sentence for ``key``, or return None on a miss.

    An entry is the pair ``<key>.wav`` + ``<key>.json`` in CACHE_DIR. The WAV is
    decoded and re-encoded in ``executor``; the returned sample rate comes from
    the WAV file. An entry that cannot be read (removed concurrently by
    clean_disk_cache, truncated metadata, undecodable audio) is logged and
    treated as a miss instead of failing the request.
    """
    wav_path = sentence_cache_path(key)
    meta_path = meta_cache_path(key)
    if not Path(wav_path).exists() or not Path(meta_path).exists():
        return None
    try:
        async with aiofiles.open(meta_path, "r") as f:
            meta = json.loads(await f.read())
        sentence = meta["sentence"]
        wav_bytes, sr = await asyncio.get_running_loop().run_in_executor(executor, read_wav_bytes, wav_path)
    except (OSError, ValueError, KeyError, TypeError, RuntimeError):
        # ValueError covers JSONDecodeError; soundfile errors derive from RuntimeError.
        logger.warning("读取磁盘缓存失败,按未命中处理: %s", key, exc_info=True)
        return None
    return {"sentence": sentence, "sample_rate": sr, "audio_bytes": wav_bytes}
async def save_sentence_to_disk(key: str, audio: np.ndarray, sr: int, sentence: str):
    """Persist one synthesized sentence as ``<key>.wav`` + ``<key>.json`` in CACHE_DIR.

    Each file is written under a unique temporary name and then atomically
    renamed into place, so a concurrent load_sentence_from_disk never reads a
    half-written file. The metadata is published last, and the loader requires
    both files. Temporary files left by a failed step are removed.
    """
    wav_path = sentence_cache_path(key)
    meta_path = meta_cache_path(key)
    # Unique suffix: no clash with another writer; not matched by the "*.wav" glob.
    suffix = f".{uuid.uuid4().hex}.tmp"
    wav_tmp = wav_path + suffix
    meta_tmp = meta_path + suffix

    def _write_wav():
        # format is explicit, so the ".tmp" extension does not matter to soundfile.
        sf.write(wav_tmp, audio, sr, format="WAV")
        os.replace(wav_tmp, wav_path)

    try:
        await asyncio.get_running_loop().run_in_executor(executor, _write_wav)
        async with aiofiles.open(meta_tmp, "w") as f:
            await f.write(json.dumps({"sentence": sentence, "sample_rate": sr}))
        os.replace(meta_tmp, meta_path)
    finally:
        # No-op after a successful rename; cleans up leftovers otherwise.
        for leftover in (wav_tmp, meta_tmp):
            Path(leftover).unlink(missing_ok=True)
def _clean_disk_cache_sync() -> None:
    """Blocking implementation of clean_disk_cache (runs in a worker thread)."""
    dated = []
    for wav in Path(CACHE_DIR).glob("*.wav"):
        try:
            dated.append((wav.stat().st_mtime, wav))
        except FileNotFoundError:
            continue  # removed between glob() and stat() by a concurrent cleanup
    excess = len(dated) - DISK_CACHE_SIZE
    if excess <= 0:
        return
    dated.sort(key=lambda pair: pair[0])  # oldest first; stable for equal mtimes
    for _, wav in dated[:excess]:
        try:
            wav.unlink(missing_ok=True)
            Path(meta_cache_path(wav.stem)).unlink(missing_ok=True)
        except OSError:
            logger.warning("清理磁盘缓存失败: %s", wav, exc_info=True)


async def clean_disk_cache():
    """Trim CACHE_DIR to at most DISK_CACHE_SIZE WAV entries, evicting the oldest by mtime.

    Each evicted ``<key>.wav`` is deleted together with its ``<key>.json``.
    Filesystem work runs in a worker thread so the event loop is not blocked;
    files that vanish concurrently are skipped, and other deletion errors are
    logged instead of being silently swallowed.
    """
    await asyncio.to_thread(_clean_disk_cache_sync)
def resolve_model_path(name: str) -> str:
    """Map model ``name`` to an existing file path.

    ``name`` is first resolved against MODEL_DIR; otherwise an absolute ``name``
    that exists is returned unchanged. Raises FileNotFoundError when neither exists.
    """
    # NOTE(review): ``name`` can come from request payloads (model_name), so a
    # caller can point this at any existing file ("../", absolute paths).
    # Consider restricting it to files inside MODEL_DIR.
    candidate = Path(MODEL_DIR) / name
    if candidate.exists():
        return str(candidate)
    absolute_hit = os.path.isabs(name) and Path(name).exists()
    if absolute_hit:
        return name
    raise FileNotFoundError(f"找不到模型文件: {candidate}")
def get_en_g2p_pipeline():
    """Return the process-wide ``kokoro.KPipeline`` (lang_code "a"), creating it on first call.

    ``model=False`` is passed to the constructor — presumably so only the
    text/G2P front-end is built, without the torch model; confirm with kokoro docs.
    Creation is not guarded by a lock.
    """
    global _EN_G2P_PIPELINE
    if _EN_G2P_PIPELINE is not None:
        return _EN_G2P_PIPELINE
    from kokoro import KPipeline  # type: ignore
    _EN_G2P_PIPELINE = KPipeline(lang_code="a", repo_id="hexgrad/Kokoro-82M", model=False)
    return _EN_G2P_PIPELINE
def _download_voice_file(voice_path: Path) -> Path:
    """Download ``voices/<voice_path.name>`` from the HF_MODEL_ID repo to ``voice_path``.

    The response is streamed into a sibling ``.tmp`` file (no full in-memory
    copy) and atomically moved into place, so a failed or interrupted download
    never leaves a truncated voice file; the temp file is removed on failure.
    Returns ``voice_path``. Network errors from urlopen propagate.
    """
    import shutil

    raw_url = f"https://huggingface.co/{HF_MODEL_ID}/resolve/main/voices/{voice_path.name}"
    voice_path.parent.mkdir(parents=True, exist_ok=True)
    tmp_path = voice_path.with_suffix(voice_path.suffix + ".tmp")
    try:
        with urllib.request.urlopen(raw_url, timeout=60) as response, open(tmp_path, "wb") as f:
            shutil.copyfileobj(response, f)
        tmp_path.replace(voice_path)
    finally:
        # Already gone after a successful replace(); removes partial data otherwise.
        tmp_path.unlink(missing_ok=True)
    return voice_path
def load_onnx_session(name: str):
    """Create a CPU-only onnxruntime InferenceSession for model ``name``.

    Session threading and memory options come from the ORT_* settings.
    Raises RuntimeError when onnxruntime cannot be imported and
    FileNotFoundError (from resolve_model_path) when the model file is missing.
    """
    try:
        import onnxruntime as ort  # type: ignore
    except Exception as exc:
        raise RuntimeError("无法导入 onnxruntime。请在部署环境中安装 onnxruntime。") from exc
    model_path = resolve_model_path(name)
    options = ort.SessionOptions()
    for attr, value in (
        ("intra_op_num_threads", ORT_INTRA_OP_THREADS),
        ("inter_op_num_threads", ORT_INTER_OP_THREADS),
        ("enable_cpu_mem_arena", ORT_ENABLE_CPU_MEM_ARENA),
        ("enable_mem_pattern", ORT_ENABLE_MEM_PATTERN),
    ):
        setattr(options, attr, value)
    return ort.InferenceSession(model_path, sess_options=options, providers=["CPUExecutionProvider"])
def load_kokoro_engine(name: str):
    """Build a ``kokoro_onnx.Kokoro`` engine for model ``name``.

    Voices are read from VOICES_V1_PATH; CONFIG_PATH is passed as the vocab
    config only when that file exists. Raises RuntimeError when kokoro_onnx
    cannot be imported and FileNotFoundError when the model is missing.
    """
    try:
        from kokoro_onnx import Kokoro  # type: ignore
    except Exception as exc:
        raise RuntimeError("无法导入 kokoro_onnx。请在部署环境中安装 kokoro-onnx。") from exc
    model_path = resolve_model_path(name)
    vocab_config = None
    if Path(CONFIG_PATH).exists():
        vocab_config = CONFIG_PATH
    return Kokoro(model_path=model_path, voices_path=VOICES_V1_PATH, vocab_config=vocab_config)
def load_model(force_reload: bool = False, name: Optional[str] = None):
    """Ensure the ONNX session and Kokoro engine for ``name`` are loaded; return the session.

    ``name`` defaults to the currently selected model. A (re)load happens when
    ``force_reload`` is set, when nothing is loaded yet, or when a different
    model is requested. Both objects are built before any global is updated,
    so if loading either one fails the previously loaded model stays fully
    intact (session, engine and ``model_name`` never disagree).
    """
    global model_session, model_name, _KOKORO_ONNX_ENGINE
    with model_lock:
        target = name or model_name
        if force_reload or model_session is None or target != model_name:
            new_session = load_onnx_session(target)
            new_engine = load_kokoro_engine(target)
            # Publish all three together only after both loads succeeded.
            model_session, _KOKORO_ONNX_ENGINE, model_name = new_session, new_engine, target
            logger.info("ONNX 模型加载完成: %s", resolve_model_path(target))
        return model_session
def get_kokoro_engine(name: Optional[str] = None):
    """Return the Kokoro engine for ``name`` (default: current model), loading it if needed.

    The check reads ``model_name``/``_KOKORO_ONNX_ENGINE`` without taking
    ``model_lock``; a concurrent load_model for another model can swap the
    engine between the check and the return.
    """
    target = name or model_name
    needs_load = _KOKORO_ONNX_ENGINE is None or target != model_name
    if needs_load:
        load_model(name=target)
    return _KOKORO_ONNX_ENGINE
@asynccontextmanager
async def lifespan(app: FastAPI):
    """FastAPI lifespan: load the default model before the app serves requests.

    Nothing is torn down on shutdown.
    """
    load_model(force_reload=False, name=None)  # blocking load during startup
    yield
app = FastAPI(title="Online TTS Service (Kokoro ONNX Optimized)", lifespan=lifespan)

# NOTE(review): wildcard origins together with allow_credentials=True — Starlette's
# CORSMiddleware then appears to echo the caller's Origin for credentialed requests,
# i.e. any site may call this API with credentials. Confirm this is intended.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
    allow_credentials=True,
)
@app.middleware("http")
async def track_clients(request: Request, call_next):
    """Register ``/generate`` callers in ``current_requests`` while the request is handled.

    The client id is taken from the ``client_id`` query parameter, then the
    ``X-Client-ID`` header, else a random UUID. An existing entry for that id
    is never overwritten, because it may belong to a stream that is still
    running and carries that stream's interrupt flag. On the way out, only the
    entry this middleware inserted is removed (in ``finally``, so errors do not
    leak entries). The handler may replace it with its own state, and that
    state is left alone.
    """
    if request.url.path != "/generate":
        return await call_next(request)
    client_id = request.query_params.get("client_id") or request.headers.get("X-Client-ID") or str(uuid.uuid4())
    marker = {"active": True}
    inserted = current_requests.setdefault(client_id, marker) is marker
    try:
        return await call_next(request)
    finally:
        if inserted and current_requests.get(client_id) is marker:
            current_requests.pop(client_id, None)
class TTSRequest(BaseModel):
    """JSON body of ``POST /tts``.

    ``speed`` defaults to 1.0 (also when sent as null) and must be a finite
    number greater than 0; ``split_pattern`` is a regex used to pre-split text.
    """

    # Allow a field named "model_name" without pydantic v2's "model_" namespace warning.
    model_config = {"protected_namespaces": ()}

    text: str
    voice: Optional[str] = "af_heart"
    speed: Optional[float] = 1.0
    split_pattern: Optional[str] = r"\n+"
    model_name: Optional[str] = None

    @field_validator("speed")
    @classmethod
    def validate_speed(cls, v):
        """Map null to 1.0; reject NaN/inf (pydantic floats accept them) and values <= 0."""
        import math

        if v is None:
            return 1.0
        v = float(v)
        if not math.isfinite(v):
            raise ValueError("speed 必须为有限数值")
        if v <= 0:
            raise ValueError("speed 必须大于 0")
        return v
def _synthesize_phoneme_batch(engine, session, voice_style, phoneme_batch, speed: float) -> Optional[np.ndarray]:
    """Run one phoneme batch through the ONNX session; return mono audio, or None if nothing was produced."""
    tokens = np.array(engine.tokenizer.tokenize(phoneme_batch), dtype=np.int64)
    if tokens.size == 0:
        return None
    feeds = {
        # Token ids padded with 0 at both ends, as a batch of one.
        "input_ids": np.asarray([[0, *tokens.tolist(), 0]], dtype=np.int64),
        # voice_style is indexed by token count (presumably a per-length style table — confirm).
        "style": np.asarray(voice_style[len(tokens)], dtype=np.float32),
        "speed": np.asarray([speed], dtype=np.float32),
    }
    outputs = session.run(None, feeds)
    return to_mono_numpy(outputs[0]) if outputs else None


def synthesize_audio(text: str, voice: str, speed: float, model_name: Optional[str] = None) -> np.ndarray:
    """Synthesize ``text`` with ``voice`` into a float32 mono waveform.

    Text is phonemized as "en-us", split into batches by the engine, and each
    batch is run through the shared ONNX session; batch outputs are concatenated.
    Raises HTTPException 400 for empty text, an unknown voice or failed
    phonemization, and 500 when no output or non-finite audio is produced.
    """
    if not text.strip():
        raise HTTPException(status_code=400, detail="文本不能为空")
    engine = get_kokoro_engine(name=model_name)
    session = load_model(name=model_name)
    known_voices = set(engine.get_voices())
    if voice not in known_voices:
        raise HTTPException(status_code=400, detail=f"不支持的 voice: {voice}")
    batches = engine._split_phonemes(engine.tokenizer.phonemize(text, "en-us"))
    if not batches:
        raise HTTPException(status_code=400, detail="文本音素化失败")
    style_table = engine.get_voice_style(voice)
    segments = [
        segment
        for segment in (
            _synthesize_phoneme_batch(engine, session, style_table, batch, speed) for batch in batches
        )
        if segment is not None
    ]
    if not segments:
        raise HTTPException(status_code=500, detail="ONNX 推理未返回音频输出")
    audio = segments[0] if len(segments) == 1 else np.concatenate(segments, axis=0)
    if audio.size == 0 or not np.isfinite(audio).all():
        raise HTTPException(status_code=500, detail="生成的音频无效")
    return audio
def encode_wav_bytes(audio: np.ndarray, sr: int) -> bytes:
    """Encode mono ``audio`` sampled at ``sr`` Hz as an in-memory 16-bit PCM WAV file."""
    sink = io.BytesIO()
    with sf.SoundFile(sink, mode="w", samplerate=sr, channels=1, format="WAV", subtype="PCM_16") as wav:
        wav.write(audio)
    return sink.getvalue()
def cache_item_to_response(item: dict) -> dict:
    """Turn a cache item into the JSON-safe payload streamed by /generate.

    ``audio_bytes`` is base64-encoded into the ``audio`` field.
    """
    sentence = item["sentence"]
    rate = item["sample_rate"]
    encoded = base64.b64encode(item["audio_bytes"]).decode()
    return {"sentence": sentence, "sample_rate": rate, "audio": encoded}
async def _compute_sentence_item(sentence: str, voice: str, speed: float, model: str) -> dict:
    """Synthesize one sentence, cache it in memory, and persist it to disk on a best-effort basis.

    Synthesis and WAV encoding run in a worker thread. The item is stored in
    the memory cache before disk persistence. A failure while writing the
    disk cache or trimming it is logged and ignored, so a request is not failed
    after its audio was already produced.
    """
    key = sentence_cache_key(sentence, voice, speed, model)

    def _run():
        audio = synthesize_audio(sentence, voice=voice, speed=speed, model_name=model)
        wav_bytes = encode_wav_bytes(audio, sample_rate)
        return {"sentence": sentence, "sample_rate": sample_rate, "audio_bytes": wav_bytes}, audio

    item, audio = await asyncio.to_thread(_run)
    memory_cache.set(key, item)
    try:
        await save_sentence_to_disk(key, audio, sample_rate, sentence)
        await clean_disk_cache()
    except Exception:
        # Best-effort boundary; CancelledError is not an Exception and still propagates.
        logger.warning("写入磁盘缓存失败: %s", key, exc_info=True)
    return item
async def get_or_create_sentence_cache_item(sentence: str, voice: str, speed: float, model: str) -> dict:
    """Return the cache item for one sentence: memory cache, then disk cache, then synthesis.

    A disk hit is promoted into the memory cache. Concurrent callers for the
    same key share a single synthesis task registered in ``inflight_tasks``.
    Callers await it through ``asyncio.shield``, so a disconnecting client
    cancels only its own wait, not the shared task that other callers depend
    on. The task removes itself from ``inflight_tasks`` when it finishes.
    """
    key = sentence_cache_key(sentence, voice, speed, model)
    cached = memory_cache.get(key)
    if cached:
        return cached
    disk_item = await load_sentence_from_disk(key)
    if disk_item:
        memory_cache.set(key, disk_item)
        return disk_item

    def _forget(task: "asyncio.Future") -> None:
        with inflight_lock:
            if inflight_tasks.get(key) is task:
                inflight_tasks.pop(key, None)
        # Retrieve the outcome so an error raised after every waiter left is not
        # reported as "Task exception was never retrieved".
        if not task.cancelled():
            task.exception()

    loop = asyncio.get_running_loop()
    with inflight_lock:
        fut = inflight_tasks.get(key)
        if fut is None:
            fut = loop.create_task(_compute_sentence_item(sentence, voice, speed, model))
            inflight_tasks[key] = fut
            fut.add_done_callback(_forget)
    return await asyncio.shield(fut)
def synthesize_wav_bytes(text: str, voice: str, speed: float, split_pattern: Optional[str], model_name: Optional[str] = None) -> io.BytesIO:
    """Synthesize ``text`` into a single 16-bit PCM WAV in a rewound BytesIO.

    Text is split with iter_text_parts. Each part is taken from the memory
    cache when present, otherwise synthesized directly; this path does not write
    to the cache. ``model_name`` defaults to DEFAULT_MODEL_NAME for BOTH the
    cache lookup and the synthesis, so one response never mixes cached audio
    from one model with freshly synthesized audio from another.
    Raises HTTPException 400 when the text is empty.
    """
    parts = iter_text_parts(text, split_pattern)
    if not parts:
        raise HTTPException(status_code=400, detail="文本不能为空")
    model = model_name or DEFAULT_MODEL_NAME
    buffers: List[np.ndarray] = []
    for part in parts:
        cached = memory_cache.get(sentence_cache_key(part, voice, speed, model))
        if cached:
            decoded, _ = sf.read(io.BytesIO(cached["audio_bytes"]), dtype="float32")
            buffers.append(to_mono_numpy(decoded))
        else:
            buffers.append(synthesize_audio(part, voice=voice, speed=speed, model_name=model))
    merged = np.concatenate(buffers, axis=0) if len(buffers) > 1 else buffers[0]
    buf = io.BytesIO()
    sf.write(buf, merged, samplerate=sample_rate, format="WAV", subtype="PCM_16")
    buf.seek(0)
    return buf
@app.post("/tts", summary="POST: 传入文本返回 WAV 流")
def tts_post(req: TTSRequest):
    """Synthesize the request body's text and return it as one WAV stream.

    A missing voice falls back to "af_heart" and a missing speed to 1.0.
    """
    voice = req.voice or "af_heart"
    speed = 1.0 if req.speed is None else req.speed
    wav = synthesize_wav_bytes(
        text=req.text,
        voice=voice,
        speed=speed,
        split_pattern=req.split_pattern,
        model_name=req.model_name,
    )
    headers = {"Content-Disposition": 'inline; filename="tts.wav"'}
    return StreamingResponse(wav, media_type="audio/wav", headers=headers)
@app.get("/tts", summary="GET: 传入文本返回 WAV 流")
def tts_get(
    text: str = Query(..., description="待合成文本"),
    voice: str = Query("af_heart"),
    speed: float = Query(1.0),
    model_name: str = Query(DEFAULT_MODEL_NAME),
    split_pattern: str = Query(r"\n+"),
):
    """Synthesize ``text`` from query parameters and return it as one WAV stream.

    Raises HTTPException 400 unless ``speed`` is a finite number greater than 0.
    Float query parsing accepts "nan"/"inf", which would otherwise reach the
    model and fail later as a 500.
    """
    import math

    speed_value = None if speed is None else float(speed)
    if speed_value is None or not math.isfinite(speed_value) or speed_value <= 0:
        raise HTTPException(status_code=400, detail="speed 必须为大于 0 的数值")
    buf = synthesize_wav_bytes(text=text, voice=voice, speed=speed_value, split_pattern=split_pattern, model_name=model_name)
    return StreamingResponse(buf, media_type="audio/wav", headers={"Content-Disposition": 'inline; filename="tts.wav"'})
@app.post("/generate")
async def generate_audio_stream(data: Dict = Body(...)):
    """Stream per-sentence synthesized audio as NDJSON lines.

    Body keys: ``text`` (required string), ``voice`` (default "af_heart"),
    ``speed`` (default 1.0; must be finite and > 0), ``model_name`` (default
    DEFAULT_MODEL_NAME), ``split_pattern`` (default: newline runs) and
    ``client_id``. Each line is ``{"index", "sentence", "sample_rate", "audio"}``
    with base64 WAV audio.

    A new request with the same ``client_id`` sets the previous stream's own
    interrupt flag, so that stream stops before its next sentence.
    ``request_semaphore`` is held for as long as the stream produces audio. The
    synthesis happens while the response body is sent, after this handler has
    returned, so the semaphore is acquired inside the generator.
    Invalid input is rejected with HTTP 400 before streaming starts.
    """
    import math

    text = data.get("text") or ""
    if not isinstance(text, str):
        raise HTTPException(status_code=400, detail="text 必须为字符串")
    if not text.strip():
        raise HTTPException(status_code=400, detail="文本不能为空")
    raw_speed = data.get("speed")
    try:
        speed = 1.0 if raw_speed is None else float(raw_speed)
    except (TypeError, ValueError):
        raise HTTPException(status_code=400, detail="speed 必须为大于 0 的数值") from None
    if not math.isfinite(speed) or speed <= 0:
        raise HTTPException(status_code=400, detail="speed 必须为大于 0 的数值")
    voice = data.get("voice") or "af_heart"
    model = data.get("model_name") or DEFAULT_MODEL_NAME
    split_pattern = data.get("split_pattern", r"\n+")
    client_id = str(data.get("client_id") or uuid.uuid4())
    parts = iter_text_parts(text, split_pattern)

    # Per-request state: the stream checks its own object, so replacing the
    # registry entry for a newer request cannot lose or reset this flag.
    state = {"interrupt": False}
    previous = current_requests.get(client_id)
    if previous is not None:
        previous["interrupt"] = True
    current_requests[client_id] = state

    async def stream():
        try:
            async with request_semaphore:
                for idx, part in enumerate(parts):
                    if state["interrupt"]:
                        break
                    item = await get_or_create_sentence_cache_item(part, voice, speed, model)
                    yield json.dumps({"index": idx, **cache_item_to_response(item)}).encode() + b"\n"
        finally:
            # Only deregister ourselves; a newer request may already own this id.
            if current_requests.get(client_id) is state:
                current_requests.pop(client_id, None)

    return StreamingResponse(stream(), media_type="application/x-ndjson")
@app.get("/clear-cache")
async def clear_cache():
    """Empty the memory cache and delete every regular file in CACHE_DIR.

    Sub-directories are skipped, and files that disappear concurrently (for
    example removed by clean_disk_cache) are ignored, so the endpoint does not
    fail with a 500 in those cases.
    """
    memory_cache.clear()
    for f in Path(CACHE_DIR).glob("*"):
        if f.is_file():
            f.unlink(missing_ok=True)
    return {"status": "success"}
@app.get("/cache-info")
async def get_cache_info():
    """Report memory-cache item and byte counts and the number of WAV files in CACHE_DIR."""
    stats = memory_cache.info()
    wav_count = sum(1 for _ in Path(CACHE_DIR).glob("*.wav"))
    return {
        "memory_cache": stats["items"],
        "memory_cache_bytes": stats["bytes"],
        "disk_cache": wav_count,
    }
if __name__ == "__main__":
    import uvicorn

    # Pass the app object instead of the "speech_tts_onnx_opt:app" import string.
    # The string only works when this file is named speech_tts_onnx_opt.py, and it
    # makes uvicorn import the module a second time, re-running module-level setup
    # such as the thread pool. An object is allowed because reload is off and
    # there is a single worker.
    uvicorn.run(app, host="0.0.0.0", port=18000, reload=False, workers=1)
|