浏览代码

onnx模型优化版本,内存占用小

sequoia 1 月之前
父节点
当前提交
27e5b85adc
共有 2 个文件被更改,包括 213 次插入52 次删除
  1. 209 52
      speech_tts_onnx.py
  2. 4 0
      start_onnx.sh

+ 209 - 52
speech_tts_onnx.py

@@ -17,17 +17,20 @@ from typing import Dict, List, Optional, Tuple
 
 import numpy as np
 import soundfile as sf
+import aiofiles
 from fastapi import Body, FastAPI, HTTPException, Query, Request
 from fastapi.middleware.cors import CORSMiddleware
 from fastapi.responses import StreamingResponse
 from pydantic import BaseModel, field_validator
+from cachetools import TTLCache
+import concurrent.futures
 
 
 # ============================================================
-# 配置
+# 配置 model_quantized
 # ============================================================
 
-DEFAULT_MODEL_NAME = os.getenv("TTS_ONNX_MODEL_NAME", "model_quantized.onnx")
+DEFAULT_MODEL_NAME = os.getenv("TTS_ONNX_MODEL_NAME", "model.onnx")
 MODEL_DIR = os.getenv("TTS_ONNX_MODEL_DIR", "/home/tts-server/onnx")
 HF_MODEL_ID = os.getenv("TTS_ONNX_HF_MODEL_ID", "onnx-community/Kokoro-82M-ONNX")
 TOKENIZER_PATH = os.getenv("TTS_ONNX_TOKENIZER_PATH", str(Path(MODEL_DIR) / "tokenizer.json"))
@@ -40,6 +43,11 @@ MAX_CONCURRENT_REQUESTS = int(os.getenv("MAX_CONCURRENT_REQUESTS", 8))
 MEMORY_CACHE_SIZE = int(os.getenv("MEMORY_CACHE_SIZE", 200))
 DISK_CACHE_SIZE = int(os.getenv("DISK_CACHE_SIZE", 500))
 DEFAULT_SAMPLE_RATE = int(os.getenv("TTS_SAMPLE_RATE", 24000))
+MEMORY_CACHE_TTL = int(os.getenv("MEMORY_CACHE_TTL", 18000))
+ORT_INTRA_OP_THREADS = int(os.getenv("ORT_INTRA_OP_THREADS", "2"))
+ORT_INTER_OP_THREADS = int(os.getenv("ORT_INTER_OP_THREADS", "1"))
+ORT_ENABLE_CPU_MEM_ARENA = os.getenv("ORT_ENABLE_CPU_MEM_ARENA", "true").lower() == "true"
+ORT_ENABLE_MEM_PATTERN = os.getenv("ORT_ENABLE_MEM_PATTERN", "true").lower() == "true"
 VOICE_ALIASES = {}
 
 logging.basicConfig(level=getattr(logging, LOG_LEVEL, logging.INFO))
@@ -67,8 +75,11 @@ _KOKORO_ONNX_ENGINE = None
 
 model_lock = threading.Lock()
 synthesis_lock = threading.Lock()
+cache_lock = threading.Lock()
 request_semaphore = asyncio.Semaphore(MAX_CONCURRENT_REQUESTS)
 current_requests: Dict[str, Dict] = {}
+memory_cache = TTLCache(maxsize=MEMORY_CACHE_SIZE, ttl=MEMORY_CACHE_TTL)
+executor = concurrent.futures.ThreadPoolExecutor(max_workers=8)
 
 model_session = None
 model_name = DEFAULT_MODEL_NAME
@@ -96,6 +107,16 @@ def meta_cache_path(key: str) -> str:
     return os.path.join(CACHE_DIR, f"{key}.json")
 
 
+def put_memory_cache(key: str, value: dict):
+    with cache_lock:
+        memory_cache[key] = value
+
+
+def get_memory_cache(key: str) -> Optional[dict]:
+    with cache_lock:
+        return memory_cache.get(key)
+
+
 def to_mono_numpy(audio) -> np.ndarray:
     if audio is None:
         return np.array([], dtype=np.float32)
@@ -146,6 +167,71 @@ def read_wav_bytes(wav_path: str) -> Tuple[bytes, int]:
     return buf.getvalue(), sr
 
 
+async def load_sentence_from_disk(key: str) -> Optional[dict]:
+    wav_path = sentence_cache_path(key)
+    meta_path = meta_cache_path(key)
+    if not Path(wav_path).exists() or not Path(meta_path).exists():
+        return None
+
+    async with aiofiles.open(meta_path, "r") as f:
+        meta_text = await f.read()
+        meta = json.loads(meta_text)
+
+    wav_bytes, sr = await asyncio.get_event_loop().run_in_executor(
+        executor, lambda: read_wav_bytes(wav_path)
+    )
+    meta["audio_bytes"] = wav_bytes
+    meta["audio"] = base64.b64encode(wav_bytes).decode()
+    meta["sample_rate"] = sr
+    return meta
+
+
+async def save_sentence_to_disk(key: str, audio: np.ndarray, sr: int, sentence: str):
+    wav_path = sentence_cache_path(key)
+    meta_path = meta_cache_path(key)
+
+    await asyncio.get_event_loop().run_in_executor(
+        executor, lambda: sf.write(wav_path, audio, sr, format="WAV")
+    )
+
+    async with aiofiles.open(meta_path, "w") as f:
+        await f.write(json.dumps({"sentence": sentence, "sample_rate": sr}))
+
+
+async def clean_disk_cache():
+    files = sorted(Path(CACHE_DIR).glob("*.wav"), key=os.path.getmtime)
+    if len(files) <= DISK_CACHE_SIZE:
+        return
+    remove_n = len(files) - DISK_CACHE_SIZE
+    for f in files[:remove_n]:
+        try:
+            f.unlink()
+            json_path = meta_cache_path(f.stem)
+            if Path(json_path).exists():
+                Path(json_path).unlink()
+        except Exception:
+            pass
+
+
+def encode_wav_bytes(audio: np.ndarray, sr: int) -> bytes:
+    buf = io.BytesIO()
+    with sf.SoundFile(
+        buf,
+        "w",
+        samplerate=sr,
+        channels=audio.shape[1] if audio.ndim > 1 else 1,
+        format="WAV",
+        subtype="PCM_16",
+    ) as f:
+        f.write(audio)
+    return buf.getvalue()
+
+
+def wav_to_pcm_payload(wav_bytes: bytes) -> bytes:
+    data, _ = sf.read(io.BytesIO(wav_bytes), dtype="int16")
+    return np.asarray(data, dtype=np.int16).tobytes()
+
+
 # ============================================================
 # 模型加载
 # ============================================================
@@ -267,7 +353,12 @@ def load_onnx_session(name: str):
 
     model_path = resolve_model_path(name)
     providers = ["CPUExecutionProvider"]
-    session = ort.InferenceSession(model_path, providers=providers)
+    sess_options = ort.SessionOptions()
+    sess_options.intra_op_num_threads = ORT_INTRA_OP_THREADS
+    sess_options.inter_op_num_threads = ORT_INTER_OP_THREADS
+    sess_options.enable_cpu_mem_arena = ORT_ENABLE_CPU_MEM_ARENA
+    sess_options.enable_mem_pattern = ORT_ENABLE_MEM_PATTERN
+    session = ort.InferenceSession(model_path, sess_options=sess_options, providers=providers)
     logger.info("ONNX 模型加载完成: %s", model_path)
     return session
 
@@ -518,30 +609,68 @@ def synthesize_audio(text: str, voice: str, speed: float, model_name: Optional[s
 
 
 def synthesize_wav_bytes(text: str, voice: str, speed: float, model_name: Optional[str] = None) -> io.BytesIO:
-    segments: List[np.ndarray] = []
-    if _is_english_voice(voice):
-        parts = [graphemes for graphemes, _ in _phonemize_en_chunks_for_onnx(text)] if text.strip() else []
-    else:
-        parts = split_sentences(text) if text.strip() else []
-    if not parts:
-        parts = [text.strip()]
-
-    with synthesis_lock:
-        for part in parts:
-            audio = synthesize_audio(part, voice=voice, speed=speed, model_name=model_name)
-            if audio.size > 0:
-                segments.append(audio)
-
-    if not segments:
-        raise HTTPException(status_code=400, detail="未生成音频,请检查输入文本或参数。")
-
-    audio_concat = np.concatenate(segments, axis=0)
+    audio = synthesize_audio(text=text, voice=voice, speed=speed, model_name=model_name)
     buf = io.BytesIO()
-    sf.write(buf, audio_concat, samplerate=sample_rate, format="WAV", subtype="PCM_16")
+    sf.write(buf, audio, samplerate=sample_rate, format="WAV", subtype="PCM_16")
     buf.seek(0)
     return buf
 
 
+def iter_text_parts(text: str, split_pattern: Optional[str] = None) -> List[str]:
+    text = (text or "").strip()
+    if not text:
+        return []
+
+    blocks = [text]
+    if split_pattern:
+        try:
+            blocks = [part.strip() for part in re.split(split_pattern, text) if part.strip()]
+        except re.error:
+            logger.warning("split_pattern 非法,回退默认分句: %s", split_pattern)
+
+    parts: List[str] = []
+    for block in blocks:
+        sentences = split_sentences(block)
+        if sentences:
+            parts.extend(sentences)
+        elif block:
+            parts.append(block)
+    return parts
+
+
+async def get_or_create_sentence_cache_item(
+    sentence: str,
+    voice: str,
+    speed: float,
+    model_name: str,
+) -> dict:
+    key = sentence_cache_key(sentence, voice, speed, model_name)
+    cached_item = get_memory_cache(key)
+    if cached_item:
+        return cached_item
+
+    disk_item = await load_sentence_from_disk(key)
+    if disk_item:
+        put_memory_cache(key, disk_item)
+        return disk_item
+
+    def _synthesize_item():
+        audio = synthesize_audio(sentence, voice=voice, speed=speed, model_name=model_name)
+        wav_bytes = encode_wav_bytes(audio, sample_rate)
+        return {
+            "sentence": sentence,
+            "sample_rate": sample_rate,
+            "audio_bytes": wav_bytes,
+            "audio": base64.b64encode(wav_bytes).decode(),
+        }, audio
+
+    item, audio = await asyncio.to_thread(_synthesize_item)
+    put_memory_cache(key, item)
+    await save_sentence_to_disk(key, audio, sample_rate, sentence)
+    await clean_disk_cache()
+    return item
+
+
 @app.post("/tts", summary="POST: 传入文本返回 WAV 流")
 def tts_post(req: TTSRequest):
     buf = synthesize_wav_bytes(
@@ -591,12 +720,7 @@ async def generate_audio_stream(data: Dict = Body(...)):
         if not text.strip():
             raise HTTPException(status_code=400, detail="文本不能为空")
 
-        if _is_english_voice(voice):
-            parts = [graphemes for graphemes, _ in _phonemize_en_chunks_for_onnx(text)] if text.strip() else []
-        else:
-            parts = split_sentences(text) if text.strip() else []
-        if not parts:
-            parts = [text.strip()]
+        parts = iter_text_parts(text, data.get("split_pattern", r"\n+"))
 
         if client_id in current_requests:
             current_requests[client_id]["interrupt"] = True
@@ -607,37 +731,17 @@ async def generate_audio_stream(data: Dict = Body(...)):
         async def stream():
             try:
                 for idx, part in enumerate(parts):
-                    if not part:
-                        continue
                     if current_requests.get(client_id, {}).get("interrupt"):
                         break
 
-                    audio = await asyncio.to_thread(
-                        synthesize_audio,
-                        text=part,
-                        voice=voice,
-                        speed=speed,
-                        model_name=model,
-                    )
-
-                    with io.BytesIO() as buf:
-                        with sf.SoundFile(
-                            buf,
-                            "w",
-                            sample_rate,
-                            channels=audio.shape[1] if audio.ndim > 1 else 1,
-                            format="WAV",
-                            subtype="PCM_16",
-                        ) as f:
-                            f.write(audio)
-                        audio_b64 = base64.b64encode(buf.getvalue()).decode()
+                    item = await get_or_create_sentence_cache_item(part, voice, speed, model)
 
                     yield json.dumps(
                         {
                             "index": idx,
-                            "sentence": part,
-                            "audio": audio_b64,
-                            "sample_rate": sample_rate,
+                            "sentence": item["sentence"],
+                            "audio": item["audio"],
+                            "sample_rate": item["sample_rate"],
                         }
                     ).encode() + b"\n"
             finally:
@@ -646,6 +750,59 @@ async def generate_audio_stream(data: Dict = Body(...)):
         return StreamingResponse(stream(), media_type="application/x-ndjson")
 
 
+@app.post("/generate_pcm", summary="POST: 返回更轻量的 PCM 分片流")
+async def generate_audio_pcm_stream(data: Dict = Body(...)):
+    async with request_semaphore:
+        text = data.get("text", "")
+        voice = data.get("voice", "af_heart")
+        speed = float(data.get("speed", 1.0))
+        model = data.get("model_name", DEFAULT_MODEL_NAME)
+
+        if not text.strip():
+            raise HTTPException(status_code=400, detail="文本不能为空")
+
+        parts = iter_text_parts(text, data.get("split_pattern", r"\n+"))
+
+        async def pcm_stream():
+            for idx, part in enumerate(parts):
+                item = await get_or_create_sentence_cache_item(part, voice, speed, model)
+                payload = wav_to_pcm_payload(item["audio_bytes"])
+                header = json.dumps(
+                    {
+                        "index": idx,
+                        "sentence": item["sentence"],
+                        "sample_rate": item["sample_rate"],
+                        "format": "s16le",
+                        "bytes": len(payload),
+                    }
+                ).encode() + b"\n"
+                yield header
+                yield payload
+
+        return StreamingResponse(
+            pcm_stream(),
+            media_type="application/octet-stream",
+            headers={"X-Audio-Format": "s16le", "X-Sample-Rate": str(sample_rate)},
+        )
+
+
+@app.get("/clear-cache")
+async def clear_cache():
+    with cache_lock:
+        memory_cache.clear()
+    for f in Path(CACHE_DIR).glob("*"):
+        f.unlink()
+    return {"status": "success"}
+
+
+@app.get("/cache-info")
+async def get_cache_info():
+    with cache_lock:
+        mem_count = len(memory_cache)
+    disk_files = list(Path(CACHE_DIR).glob("*.wav"))
+    return {"memory_cache": mem_count, "disk_cache": len(disk_files)}
+
+
 if __name__ == "__main__":
     import uvicorn
 

+ 4 - 0
start_onnx.sh

@@ -0,0 +1,4 @@
+#!/bin/bash
+eval "$(/root/miniconda3/bin/conda shell.bash hook)"
+conda activate py311
+nohup uvicorn speech_tts_onnx:app --host 0.0.0.0 --port 8028 &