sequoia 1 mesiac pred
rodič
commit
96d2d79a43
2 zmenil súbory, kde vykonal 568 pridanie a 1 odobranie
  1. 567 0
      speech_tts_onnx_opt.py
  2. 1 1
      start_onnx.sh

+ 567 - 0
speech_tts_onnx_opt.py

@@ -0,0 +1,567 @@
+from __future__ import annotations
+
+import asyncio
+import base64
+import concurrent.futures
+import hashlib
+import io
+import json
+import logging
+import os
+import re
+import threading
+import time
+import uuid
+import urllib.request
+from collections import OrderedDict
+from contextlib import asynccontextmanager
+from pathlib import Path
+from typing import Dict, List, Optional, Tuple
+
+import aiofiles
+import numpy as np
+import soundfile as sf
+from fastapi import Body, FastAPI, HTTPException, Query, Request
+from fastapi.middleware.cors import CORSMiddleware
+from fastapi.responses import StreamingResponse
+from pydantic import BaseModel, field_validator
+
+
+DEFAULT_MODEL_NAME = os.getenv("TTS_ONNX_MODEL_NAME", "model.onnx")
+MODEL_DIR = os.getenv("TTS_ONNX_MODEL_DIR", "/home/tts-server/onnx")
+HF_MODEL_ID = os.getenv("TTS_ONNX_HF_MODEL_ID", "onnx-community/Kokoro-82M-ONNX")
+CONFIG_PATH = os.getenv("TTS_ONNX_CONFIG_PATH", str(Path(MODEL_DIR) / "config.json"))
+VOICES_DIR = os.getenv("TTS_ONNX_VOICES_DIR", str(Path(MODEL_DIR) / "voices"))
+VOICES_V1_PATH = os.getenv("TTS_ONNX_VOICES_V1_PATH", str(Path(MODEL_DIR) / "voices-v1.0.bin"))
+CACHE_DIR = os.getenv("CACHE_DIR", "./audio_cache")
+LOG_LEVEL = os.getenv("LOG_LEVEL", "INFO").upper()
+MAX_CONCURRENT_REQUESTS = int(os.getenv("MAX_CONCURRENT_REQUESTS", 8))
+DEFAULT_SAMPLE_RATE = int(os.getenv("TTS_SAMPLE_RATE", 24000))
+MEMORY_CACHE_TTL = int(os.getenv("MEMORY_CACHE_TTL", 18000))
+MEMORY_CACHE_MAX_ITEMS = int(os.getenv("MEMORY_CACHE_SIZE", 120))
+MEMORY_CACHE_MAX_BYTES = int(os.getenv("MEMORY_CACHE_MAX_BYTES", str(128 * 1024 * 1024)))
+DISK_CACHE_SIZE = int(os.getenv("DISK_CACHE_SIZE", 500))
+ORT_INTRA_OP_THREADS = int(os.getenv("ORT_INTRA_OP_THREADS", "2"))
+ORT_INTER_OP_THREADS = int(os.getenv("ORT_INTER_OP_THREADS", "1"))
+ORT_ENABLE_CPU_MEM_ARENA = os.getenv("ORT_ENABLE_CPU_MEM_ARENA", "true").lower() == "true"
+ORT_ENABLE_MEM_PATTERN = os.getenv("ORT_ENABLE_MEM_PATTERN", "false").lower() == "true"
+VOICE_ALIASES: Dict[str, str] = {}
+
+logging.basicConfig(level=getattr(logging, LOG_LEVEL, logging.INFO))
+logger = logging.getLogger("speech_tts_onnx_opt")
+
+Path(CACHE_DIR).mkdir(parents=True, exist_ok=True)
+SENT_SPLIT_RE = re.compile(r"(?<=[。!?;.!?;::,,])\s*|\n+")
+
+
+class _PhonemizerWordCountMismatchFilter(logging.Filter):
+    def filter(self, record: logging.LogRecord) -> bool:
+        return "words count mismatch" not in record.getMessage()
+
+
+logging.getLogger("phonemizer").addFilter(_PhonemizerWordCountMismatchFilter())
+
+
+class MemoryAudioCache:
+    def __init__(self, max_items: int, max_bytes: int, ttl: int):
+        self.max_items = max_items
+        self.max_bytes = max_bytes
+        self.ttl = ttl
+        self.lock = threading.Lock()
+        self.items: "OrderedDict[str, dict]" = OrderedDict()
+        self.total_bytes = 0
+
+    def _entry_size(self, value: dict) -> int:
+        return len(value.get("audio_bytes", b"")) + len(value.get("sentence", "").encode("utf-8")) + 128
+
+    def _purge_expired(self, now: float):
+        expired = [k for k, v in self.items.items() if now - v["ts"] > self.ttl]
+        for key in expired:
+            entry = self.items.pop(key, None)
+            if entry:
+                self.total_bytes -= entry["size"]
+
+    def get(self, key: str) -> Optional[dict]:
+        now = time.time()
+        with self.lock:
+            self._purge_expired(now)
+            entry = self.items.get(key)
+            if not entry:
+                return None
+            self.items.move_to_end(key)
+            entry["ts"] = now
+            return {
+                "sentence": entry["sentence"],
+                "sample_rate": entry["sample_rate"],
+                "audio_bytes": entry["audio_bytes"],
+            }
+
+    def set(self, key: str, value: dict):
+        now = time.time()
+        with self.lock:
+            self._purge_expired(now)
+            old = self.items.pop(key, None)
+            if old:
+                self.total_bytes -= old["size"]
+
+            entry = {
+                "sentence": value["sentence"],
+                "sample_rate": value["sample_rate"],
+                "audio_bytes": value["audio_bytes"],
+                "ts": now,
+            }
+            entry["size"] = self._entry_size(entry)
+            self.items[key] = entry
+            self.total_bytes += entry["size"]
+
+            while self.items and (
+                len(self.items) > self.max_items or self.total_bytes > self.max_bytes
+            ):
+                _, removed = self.items.popitem(last=False)
+                self.total_bytes -= removed["size"]
+
+    def clear(self):
+        with self.lock:
+            self.items.clear()
+            self.total_bytes = 0
+
+    def info(self) -> dict:
+        now = time.time()
+        with self.lock:
+            self._purge_expired(now)
+            return {"items": len(self.items), "bytes": self.total_bytes}
+
+
+model_lock = threading.Lock()
+request_semaphore = asyncio.Semaphore(MAX_CONCURRENT_REQUESTS)
+current_requests: Dict[str, Dict] = {}
+memory_cache = MemoryAudioCache(
+    max_items=MEMORY_CACHE_MAX_ITEMS,
+    max_bytes=MEMORY_CACHE_MAX_BYTES,
+    ttl=MEMORY_CACHE_TTL,
+)
+executor = concurrent.futures.ThreadPoolExecutor(max_workers=8)
+inflight_lock = threading.Lock()
+inflight_tasks: Dict[str, asyncio.Future] = {}
+
+model_session = None
+model_name = DEFAULT_MODEL_NAME
+sample_rate = DEFAULT_SAMPLE_RATE
+_KOKORO_ONNX_ENGINE = None
+_EN_G2P_PIPELINE = None
+
+
+def split_sentences(text: str) -> List[str]:
+    parts = [s.strip() for s in SENT_SPLIT_RE.split(text.strip()) if s.strip()]
+    merged: List[str] = []
+    buf = ""
+    for part in parts:
+        if len(part) < 3 and merged:
+            merged[-1] = f"{merged[-1]} {part}".strip()
+        else:
+            merged.append(part)
+    return merged
+
+
+def iter_text_parts(text: str, split_pattern: Optional[str]) -> List[str]:
+    text = (text or "").strip()
+    if not text:
+        return []
+
+    blocks = [text]
+    if split_pattern:
+        try:
+            blocks = [p.strip() for p in re.split(split_pattern, text) if p.strip()]
+        except re.error:
+            logger.warning("split_pattern 非法,回退默认分句: %s", split_pattern)
+
+    parts: List[str] = []
+    for block in blocks:
+        parts.extend(split_sentences(block) or [block])
+    return parts
+
+
+def sentence_cache_key(sentence: str, voice: str, speed: float, model: str) -> str:
+    raw = f"{sentence}|{voice}|{speed:.4f}|{model}"
+    return hashlib.md5(raw.encode("utf-8")).hexdigest()
+
+
+def sentence_cache_path(key: str) -> str:
+    return os.path.join(CACHE_DIR, f"{key}.wav")
+
+
+def meta_cache_path(key: str) -> str:
+    return os.path.join(CACHE_DIR, f"{key}.json")
+
+
+def to_mono_numpy(audio) -> np.ndarray:
+    arr = np.asarray(audio, dtype=np.float32)
+    if arr.size == 0:
+        return np.array([], dtype=np.float32)
+    if arr.ndim == 2:
+        if arr.shape[1] == 1:
+            arr = arr[:, 0]
+        elif arr.shape[0] == 1:
+            arr = arr[0]
+        else:
+            arr = arr.mean(axis=1)
+    elif arr.ndim > 2:
+        arr = arr.reshape(-1)
+    return arr.astype(np.float32, copy=False)
+
+
+def read_wav_bytes(wav_path: str) -> Tuple[bytes, int]:
+    audio, sr = sf.read(wav_path)
+    buf = io.BytesIO()
+    with sf.SoundFile(buf, "w", samplerate=sr, channels=audio.shape[1] if audio.ndim > 1 else 1, format="WAV", subtype="PCM_16") as f:
+        f.write(audio)
+    return buf.getvalue(), sr
+
+
+async def load_sentence_from_disk(key: str) -> Optional[dict]:
+    wav_path = sentence_cache_path(key)
+    meta_path = meta_cache_path(key)
+    if not Path(wav_path).exists() or not Path(meta_path).exists():
+        return None
+
+    async with aiofiles.open(meta_path, "r") as f:
+        meta = json.loads(await f.read())
+
+    wav_bytes, sr = await asyncio.get_event_loop().run_in_executor(executor, lambda: read_wav_bytes(wav_path))
+    return {"sentence": meta["sentence"], "sample_rate": sr, "audio_bytes": wav_bytes}
+
+
+async def save_sentence_to_disk(key: str, audio: np.ndarray, sr: int, sentence: str):
+    wav_path = sentence_cache_path(key)
+    meta_path = meta_cache_path(key)
+    await asyncio.get_event_loop().run_in_executor(executor, lambda: sf.write(wav_path, audio, sr, format="WAV"))
+    async with aiofiles.open(meta_path, "w") as f:
+        await f.write(json.dumps({"sentence": sentence, "sample_rate": sr}))
+
+
+async def clean_disk_cache():
+    files = sorted(Path(CACHE_DIR).glob("*.wav"), key=os.path.getmtime)
+    if len(files) <= DISK_CACHE_SIZE:
+        return
+    for f in files[: len(files) - DISK_CACHE_SIZE]:
+        try:
+            f.unlink()
+            meta = Path(meta_cache_path(f.stem))
+            if meta.exists():
+                meta.unlink()
+        except Exception:
+            pass
+
+
+def resolve_model_path(name: str) -> str:
+    local_path = Path(MODEL_DIR) / name
+    if local_path.exists():
+        return str(local_path)
+    if os.path.isabs(name) and Path(name).exists():
+        return name
+    raise FileNotFoundError(f"找不到模型文件: {local_path}")
+
+
+def get_en_g2p_pipeline():
+    global _EN_G2P_PIPELINE
+    if _EN_G2P_PIPELINE is None:
+        from kokoro import KPipeline  # type: ignore
+
+        _EN_G2P_PIPELINE = KPipeline(lang_code="a", repo_id="hexgrad/Kokoro-82M", model=False)
+    return _EN_G2P_PIPELINE
+
+
+def _download_voice_file(voice_path: Path) -> Path:
+    raw_url = f"https://huggingface.co/{HF_MODEL_ID}/resolve/main/voices/{voice_path.name}"
+    voice_path.parent.mkdir(parents=True, exist_ok=True)
+    tmp_path = voice_path.with_suffix(voice_path.suffix + ".tmp")
+    with urllib.request.urlopen(raw_url, timeout=60) as response:
+        payload = response.read()
+    with open(tmp_path, "wb") as f:
+        f.write(payload)
+    tmp_path.replace(voice_path)
+    return voice_path
+
+
+def load_onnx_session(name: str):
+    try:
+        import onnxruntime as ort  # type: ignore
+    except Exception as e:
+        raise RuntimeError("无法导入 onnxruntime。请在部署环境中安装 onnxruntime。") from e
+
+    model_path = resolve_model_path(name)
+    sess_options = ort.SessionOptions()
+    sess_options.intra_op_num_threads = ORT_INTRA_OP_THREADS
+    sess_options.inter_op_num_threads = ORT_INTER_OP_THREADS
+    sess_options.enable_cpu_mem_arena = ORT_ENABLE_CPU_MEM_ARENA
+    sess_options.enable_mem_pattern = ORT_ENABLE_MEM_PATTERN
+    return ort.InferenceSession(model_path, sess_options=sess_options, providers=["CPUExecutionProvider"])
+
+
+def load_kokoro_engine(name: str):
+    try:
+        from kokoro_onnx import Kokoro  # type: ignore
+    except Exception as e:
+        raise RuntimeError("无法导入 kokoro_onnx。请在部署环境中安装 kokoro-onnx。") from e
+
+    model_path = resolve_model_path(name)
+    return Kokoro(model_path=model_path, voices_path=VOICES_V1_PATH, vocab_config=CONFIG_PATH if Path(CONFIG_PATH).exists() else None)
+
+
+def load_model(force_reload: bool = False, name: Optional[str] = None):
+    global model_session, model_name, _KOKORO_ONNX_ENGINE
+    with model_lock:
+        target = name or model_name
+        if force_reload or model_session is None or target != model_name:
+            model_session = load_onnx_session(target)
+            _KOKORO_ONNX_ENGINE = load_kokoro_engine(target)
+            model_name = target
+            logger.info("ONNX 模型加载完成: %s", resolve_model_path(target))
+        return model_session
+
+
+def get_kokoro_engine(name: Optional[str] = None):
+    global _KOKORO_ONNX_ENGINE
+    target = name or model_name
+    if _KOKORO_ONNX_ENGINE is None or target != model_name:
+        load_model(name=target)
+    return _KOKORO_ONNX_ENGINE
+
+
+@asynccontextmanager
+async def lifespan(app: FastAPI):
+    load_model()
+    yield
+
+
+app = FastAPI(title="Online TTS Service (Kokoro ONNX Optimized)", lifespan=lifespan)
+app.add_middleware(
+    CORSMiddleware,
+    allow_origins=["*"],
+    allow_credentials=True,
+    allow_methods=["*"],
+    allow_headers=["*"],
+)
+
+
+@app.middleware("http")
+async def track_clients(request: Request, call_next):
+    client_id = request.query_params.get("client_id") or request.headers.get("X-Client-ID") or str(uuid.uuid4())
+    if request.url.path == "/generate":
+        current_requests[client_id] = {"active": True}
+    response = await call_next(request)
+    if request.url.path == "/generate":
+        current_requests.pop(client_id, None)
+    return response
+
+
+class TTSRequest(BaseModel):
+    text: str
+    voice: Optional[str] = "af_heart"
+    speed: Optional[float] = 1.0
+    split_pattern: Optional[str] = r"\n+"
+    model_name: Optional[str] = None
+
+    @field_validator("speed")
+    @classmethod
+    def validate_speed(cls, v):
+        if v is None:
+            return 1.0
+        v = float(v)
+        if v <= 0:
+            raise ValueError("speed 必须大于 0")
+        return v
+
+
+def synthesize_audio(text: str, voice: str, speed: float, model_name: Optional[str] = None) -> np.ndarray:
+    if not text.strip():
+        raise HTTPException(status_code=400, detail="文本不能为空")
+
+    engine = get_kokoro_engine(name=model_name)
+    session = load_model(name=model_name)
+    if voice not in set(engine.get_voices()):
+        raise HTTPException(status_code=400, detail=f"不支持的 voice: {voice}")
+
+    phonemes = engine.tokenizer.phonemize(text, "en-us")
+    batched_phonemes = engine._split_phonemes(phonemes)
+    if not batched_phonemes:
+        raise HTTPException(status_code=400, detail="文本音素化失败")
+
+    voice_style = engine.get_voice_style(voice)
+    audio_segments: List[np.ndarray] = []
+    for phoneme_batch in batched_phonemes:
+        tokens = np.array(engine.tokenizer.tokenize(phoneme_batch), dtype=np.int64)
+        if tokens.size == 0:
+            continue
+        feeds = {
+            "input_ids": np.asarray([[0, *tokens.tolist(), 0]], dtype=np.int64),
+            "style": np.asarray(voice_style[len(tokens)], dtype=np.float32),
+            "speed": np.asarray([speed], dtype=np.float32),
+        }
+        outputs = session.run(None, feeds)
+        if outputs:
+            audio_segments.append(to_mono_numpy(outputs[0]))
+
+    if not audio_segments:
+        raise HTTPException(status_code=500, detail="ONNX 推理未返回音频输出")
+    audio = np.concatenate(audio_segments, axis=0) if len(audio_segments) > 1 else audio_segments[0]
+    if audio.size == 0 or not np.isfinite(audio).all():
+        raise HTTPException(status_code=500, detail="生成的音频无效")
+    return audio
+
+
+def encode_wav_bytes(audio: np.ndarray, sr: int) -> bytes:
+    buf = io.BytesIO()
+    with sf.SoundFile(buf, "w", samplerate=sr, channels=1, format="WAV", subtype="PCM_16") as f:
+        f.write(audio)
+    return buf.getvalue()
+
+
+def cache_item_to_response(item: dict) -> dict:
+    return {
+        "sentence": item["sentence"],
+        "sample_rate": item["sample_rate"],
+        "audio": base64.b64encode(item["audio_bytes"]).decode(),
+    }
+
+
+async def _compute_sentence_item(sentence: str, voice: str, speed: float, model: str) -> dict:
+    def _run():
+        audio = synthesize_audio(sentence, voice=voice, speed=speed, model_name=model)
+        wav_bytes = encode_wav_bytes(audio, sample_rate)
+        return {"sentence": sentence, "sample_rate": sample_rate, "audio_bytes": wav_bytes}, audio
+
+    item, audio = await asyncio.to_thread(_run)
+    memory_cache.set(sentence_cache_key(sentence, voice, speed, model), item)
+    await save_sentence_to_disk(sentence_cache_key(sentence, voice, speed, model), audio, sample_rate, sentence)
+    await clean_disk_cache()
+    return item
+
+
+async def get_or_create_sentence_cache_item(sentence: str, voice: str, speed: float, model: str) -> dict:
+    key = sentence_cache_key(sentence, voice, speed, model)
+    cached = memory_cache.get(key)
+    if cached:
+        return cached
+
+    disk_item = await load_sentence_from_disk(key)
+    if disk_item:
+        memory_cache.set(key, disk_item)
+        return disk_item
+
+    loop = asyncio.get_running_loop()
+    with inflight_lock:
+        fut = inflight_tasks.get(key)
+        if fut is None:
+            fut = loop.create_task(_compute_sentence_item(sentence, voice, speed, model))
+            inflight_tasks[key] = fut
+
+    try:
+        return await fut
+    finally:
+        with inflight_lock:
+            if inflight_tasks.get(key) is fut:
+                inflight_tasks.pop(key, None)
+
+
+def synthesize_wav_bytes(text: str, voice: str, speed: float, split_pattern: Optional[str], model_name: Optional[str] = None) -> io.BytesIO:
+    parts = iter_text_parts(text, split_pattern)
+    if not parts:
+        raise HTTPException(status_code=400, detail="文本不能为空")
+
+    buffers: List[np.ndarray] = []
+    for part in parts:
+        key = sentence_cache_key(part, voice, speed, model_name or DEFAULT_MODEL_NAME)
+        cached = memory_cache.get(key)
+        if cached:
+            audio, _ = sf.read(io.BytesIO(cached["audio_bytes"]), dtype="float32")
+            buffers.append(to_mono_numpy(audio))
+            continue
+        audio = synthesize_audio(part, voice=voice, speed=speed, model_name=model_name)
+        buffers.append(audio)
+
+    merged = np.concatenate(buffers, axis=0) if len(buffers) > 1 else buffers[0]
+    buf = io.BytesIO()
+    sf.write(buf, merged, samplerate=sample_rate, format="WAV", subtype="PCM_16")
+    buf.seek(0)
+    return buf
+
+
+@app.post("/tts", summary="POST: 传入文本返回 WAV 流")
+def tts_post(req: TTSRequest):
+    buf = synthesize_wav_bytes(
+        text=req.text,
+        voice=req.voice or "af_heart",
+        speed=req.speed if req.speed is not None else 1.0,
+        split_pattern=req.split_pattern,
+        model_name=req.model_name,
+    )
+    return StreamingResponse(buf, media_type="audio/wav", headers={"Content-Disposition": 'inline; filename="tts.wav"'})
+
+
+@app.get("/tts", summary="GET: 传入文本返回 WAV 流")
+def tts_get(
+    text: str = Query(..., description="待合成文本"),
+    voice: str = Query("af_heart"),
+    speed: float = Query(1.0),
+    model_name: str = Query(DEFAULT_MODEL_NAME),
+    split_pattern: str = Query(r"\n+"),
+):
+    if speed is None or float(speed) <= 0:
+        raise HTTPException(status_code=400, detail="speed 必须为大于 0 的数值")
+    buf = synthesize_wav_bytes(text=text, voice=voice, speed=float(speed), split_pattern=split_pattern, model_name=model_name)
+    return StreamingResponse(buf, media_type="audio/wav", headers={"Content-Disposition": 'inline; filename="tts.wav"'})
+
+
+@app.post("/generate")
+async def generate_audio_stream(data: Dict = Body(...)):
+    async with request_semaphore:
+        text = data.get("text", "")
+        voice = data.get("voice", "af_heart")
+        speed = float(data.get("speed", 1.0))
+        model = data.get("model_name", DEFAULT_MODEL_NAME)
+        split_pattern = data.get("split_pattern", r"\n+")
+        client_id = data.get("client_id", str(uuid.uuid4()))
+
+        if not text.strip():
+            raise HTTPException(status_code=400, detail="文本不能为空")
+
+        parts = iter_text_parts(text, split_pattern)
+        if client_id in current_requests:
+            current_requests[client_id]["interrupt"] = True
+            await asyncio.sleep(0.05)
+        current_requests[client_id] = {"interrupt": False}
+
+        async def stream():
+            try:
+                for idx, part in enumerate(parts):
+                    if current_requests.get(client_id, {}).get("interrupt"):
+                        break
+                    item = await get_or_create_sentence_cache_item(part, voice, speed, model)
+                    resp = cache_item_to_response(item)
+                    yield json.dumps({"index": idx, **resp}).encode() + b"\n"
+            finally:
+                current_requests.pop(client_id, None)
+
+        return StreamingResponse(stream(), media_type="application/x-ndjson")
+
+
+@app.get("/clear-cache")
+async def clear_cache():
+    memory_cache.clear()
+    for f in Path(CACHE_DIR).glob("*"):
+        f.unlink()
+    return {"status": "success"}
+
+
+@app.get("/cache-info")
+async def get_cache_info():
+    mem = memory_cache.info()
+    disk_files = list(Path(CACHE_DIR).glob("*.wav"))
+    return {"memory_cache": mem["items"], "memory_cache_bytes": mem["bytes"], "disk_cache": len(disk_files)}
+
+
+if __name__ == "__main__":
+    import uvicorn
+
+    uvicorn.run("speech_tts_onnx_opt:app", host="0.0.0.0", port=18000, reload=False, workers=1)

+ 1 - 1
start_onnx.sh

@@ -1,4 +1,4 @@
 #!/bin/bash
 eval "$(/root/miniconda3/bin/conda shell.bash hook)"
 conda activate py311
-nohup uvicorn speech_tts_onnx:app --host 0.0.0.0 --port 8028 &
+nohup uvicorn speech_tts_onnx_opt:app --host 0.0.0.0 --port 8028 &