Просмотр исходного кода

onnx模型优化版本,内存占用小

sequoia 1 месяц назад
Родитель
Сommit
27e5b85adc
2 измененных файлов с 213 добавлено и 52 удалено
  1. 209 52
      speech_tts_onnx.py
  2. 4 0
      start_onnx.sh

+ 209 - 52
speech_tts_onnx.py

@@ -17,17 +17,20 @@ from typing import Dict, List, Optional, Tuple
 
 
 import numpy as np
 import numpy as np
 import soundfile as sf
 import soundfile as sf
+import aiofiles
 from fastapi import Body, FastAPI, HTTPException, Query, Request
 from fastapi import Body, FastAPI, HTTPException, Query, Request
 from fastapi.middleware.cors import CORSMiddleware
 from fastapi.middleware.cors import CORSMiddleware
 from fastapi.responses import StreamingResponse
 from fastapi.responses import StreamingResponse
 from pydantic import BaseModel, field_validator
 from pydantic import BaseModel, field_validator
+from cachetools import TTLCache
+import concurrent.futures
 
 
 
 
 # ============================================================
 # ============================================================
-# 配置
+# 配置 model_quantized
 # ============================================================
 # ============================================================
 
 
-DEFAULT_MODEL_NAME = os.getenv("TTS_ONNX_MODEL_NAME", "model_quantized.onnx")
+DEFAULT_MODEL_NAME = os.getenv("TTS_ONNX_MODEL_NAME", "model.onnx")
 MODEL_DIR = os.getenv("TTS_ONNX_MODEL_DIR", "/home/tts-server/onnx")
 MODEL_DIR = os.getenv("TTS_ONNX_MODEL_DIR", "/home/tts-server/onnx")
 HF_MODEL_ID = os.getenv("TTS_ONNX_HF_MODEL_ID", "onnx-community/Kokoro-82M-ONNX")
 HF_MODEL_ID = os.getenv("TTS_ONNX_HF_MODEL_ID", "onnx-community/Kokoro-82M-ONNX")
 TOKENIZER_PATH = os.getenv("TTS_ONNX_TOKENIZER_PATH", str(Path(MODEL_DIR) / "tokenizer.json"))
 TOKENIZER_PATH = os.getenv("TTS_ONNX_TOKENIZER_PATH", str(Path(MODEL_DIR) / "tokenizer.json"))
@@ -40,6 +43,11 @@ MAX_CONCURRENT_REQUESTS = int(os.getenv("MAX_CONCURRENT_REQUESTS", 8))
 MEMORY_CACHE_SIZE = int(os.getenv("MEMORY_CACHE_SIZE", 200))
 MEMORY_CACHE_SIZE = int(os.getenv("MEMORY_CACHE_SIZE", 200))
 DISK_CACHE_SIZE = int(os.getenv("DISK_CACHE_SIZE", 500))
 DISK_CACHE_SIZE = int(os.getenv("DISK_CACHE_SIZE", 500))
 DEFAULT_SAMPLE_RATE = int(os.getenv("TTS_SAMPLE_RATE", 24000))
 DEFAULT_SAMPLE_RATE = int(os.getenv("TTS_SAMPLE_RATE", 24000))
+MEMORY_CACHE_TTL = int(os.getenv("MEMORY_CACHE_TTL", 18000))
+ORT_INTRA_OP_THREADS = int(os.getenv("ORT_INTRA_OP_THREADS", "2"))
+ORT_INTER_OP_THREADS = int(os.getenv("ORT_INTER_OP_THREADS", "1"))
+ORT_ENABLE_CPU_MEM_ARENA = os.getenv("ORT_ENABLE_CPU_MEM_ARENA", "true").lower() == "true"
+ORT_ENABLE_MEM_PATTERN = os.getenv("ORT_ENABLE_MEM_PATTERN", "true").lower() == "true"
 VOICE_ALIASES = {}
 VOICE_ALIASES = {}
 
 
 logging.basicConfig(level=getattr(logging, LOG_LEVEL, logging.INFO))
 logging.basicConfig(level=getattr(logging, LOG_LEVEL, logging.INFO))
@@ -67,8 +75,11 @@ _KOKORO_ONNX_ENGINE = None
 
 
 model_lock = threading.Lock()
 model_lock = threading.Lock()
 synthesis_lock = threading.Lock()
 synthesis_lock = threading.Lock()
+cache_lock = threading.Lock()
 request_semaphore = asyncio.Semaphore(MAX_CONCURRENT_REQUESTS)
 request_semaphore = asyncio.Semaphore(MAX_CONCURRENT_REQUESTS)
 current_requests: Dict[str, Dict] = {}
 current_requests: Dict[str, Dict] = {}
+memory_cache = TTLCache(maxsize=MEMORY_CACHE_SIZE, ttl=MEMORY_CACHE_TTL)
+executor = concurrent.futures.ThreadPoolExecutor(max_workers=8)
 
 
 model_session = None
 model_session = None
 model_name = DEFAULT_MODEL_NAME
 model_name = DEFAULT_MODEL_NAME
@@ -96,6 +107,16 @@ def meta_cache_path(key: str) -> str:
     return os.path.join(CACHE_DIR, f"{key}.json")
     return os.path.join(CACHE_DIR, f"{key}.json")
 
 
 
 
+def put_memory_cache(key: str, value: dict):
+    with cache_lock:
+        memory_cache[key] = value
+
+
+def get_memory_cache(key: str) -> Optional[dict]:
+    with cache_lock:
+        return memory_cache.get(key)
+
+
 def to_mono_numpy(audio) -> np.ndarray:
 def to_mono_numpy(audio) -> np.ndarray:
     if audio is None:
     if audio is None:
         return np.array([], dtype=np.float32)
         return np.array([], dtype=np.float32)
@@ -146,6 +167,71 @@ def read_wav_bytes(wav_path: str) -> Tuple[bytes, int]:
     return buf.getvalue(), sr
     return buf.getvalue(), sr
 
 
 
 
+async def load_sentence_from_disk(key: str) -> Optional[dict]:
+    wav_path = sentence_cache_path(key)
+    meta_path = meta_cache_path(key)
+    if not Path(wav_path).exists() or not Path(meta_path).exists():
+        return None
+
+    async with aiofiles.open(meta_path, "r") as f:
+        meta_text = await f.read()
+        meta = json.loads(meta_text)
+
+    wav_bytes, sr = await asyncio.get_event_loop().run_in_executor(
+        executor, lambda: read_wav_bytes(wav_path)
+    )
+    meta["audio_bytes"] = wav_bytes
+    meta["audio"] = base64.b64encode(wav_bytes).decode()
+    meta["sample_rate"] = sr
+    return meta
+
+
+async def save_sentence_to_disk(key: str, audio: np.ndarray, sr: int, sentence: str):
+    wav_path = sentence_cache_path(key)
+    meta_path = meta_cache_path(key)
+
+    await asyncio.get_event_loop().run_in_executor(
+        executor, lambda: sf.write(wav_path, audio, sr, format="WAV")
+    )
+
+    async with aiofiles.open(meta_path, "w") as f:
+        await f.write(json.dumps({"sentence": sentence, "sample_rate": sr}))
+
+
+async def clean_disk_cache():
+    files = sorted(Path(CACHE_DIR).glob("*.wav"), key=os.path.getmtime)
+    if len(files) <= DISK_CACHE_SIZE:
+        return
+    remove_n = len(files) - DISK_CACHE_SIZE
+    for f in files[:remove_n]:
+        try:
+            f.unlink()
+            json_path = meta_cache_path(f.stem)
+            if Path(json_path).exists():
+                Path(json_path).unlink()
+        except Exception:
+            pass
+
+
+def encode_wav_bytes(audio: np.ndarray, sr: int) -> bytes:
+    buf = io.BytesIO()
+    with sf.SoundFile(
+        buf,
+        "w",
+        samplerate=sr,
+        channels=audio.shape[1] if audio.ndim > 1 else 1,
+        format="WAV",
+        subtype="PCM_16",
+    ) as f:
+        f.write(audio)
+    return buf.getvalue()
+
+
+def wav_to_pcm_payload(wav_bytes: bytes) -> bytes:
+    data, _ = sf.read(io.BytesIO(wav_bytes), dtype="int16")
+    return np.asarray(data, dtype=np.int16).tobytes()
+
+
 # ============================================================
 # ============================================================
 # 模型加载
 # 模型加载
 # ============================================================
 # ============================================================
@@ -267,7 +353,12 @@ def load_onnx_session(name: str):
 
 
     model_path = resolve_model_path(name)
     model_path = resolve_model_path(name)
     providers = ["CPUExecutionProvider"]
     providers = ["CPUExecutionProvider"]
-    session = ort.InferenceSession(model_path, providers=providers)
+    sess_options = ort.SessionOptions()
+    sess_options.intra_op_num_threads = ORT_INTRA_OP_THREADS
+    sess_options.inter_op_num_threads = ORT_INTER_OP_THREADS
+    sess_options.enable_cpu_mem_arena = ORT_ENABLE_CPU_MEM_ARENA
+    sess_options.enable_mem_pattern = ORT_ENABLE_MEM_PATTERN
+    session = ort.InferenceSession(model_path, sess_options=sess_options, providers=providers)
     logger.info("ONNX 模型加载完成: %s", model_path)
     logger.info("ONNX 模型加载完成: %s", model_path)
     return session
     return session
 
 
@@ -518,30 +609,68 @@ def synthesize_audio(text: str, voice: str, speed: float, model_name: Optional[s
 
 
 
 
 def synthesize_wav_bytes(text: str, voice: str, speed: float, model_name: Optional[str] = None) -> io.BytesIO:
 def synthesize_wav_bytes(text: str, voice: str, speed: float, model_name: Optional[str] = None) -> io.BytesIO:
-    segments: List[np.ndarray] = []
-    if _is_english_voice(voice):
-        parts = [graphemes for graphemes, _ in _phonemize_en_chunks_for_onnx(text)] if text.strip() else []
-    else:
-        parts = split_sentences(text) if text.strip() else []
-    if not parts:
-        parts = [text.strip()]
-
-    with synthesis_lock:
-        for part in parts:
-            audio = synthesize_audio(part, voice=voice, speed=speed, model_name=model_name)
-            if audio.size > 0:
-                segments.append(audio)
-
-    if not segments:
-        raise HTTPException(status_code=400, detail="未生成音频,请检查输入文本或参数。")
-
-    audio_concat = np.concatenate(segments, axis=0)
+    audio = synthesize_audio(text=text, voice=voice, speed=speed, model_name=model_name)
     buf = io.BytesIO()
     buf = io.BytesIO()
-    sf.write(buf, audio_concat, samplerate=sample_rate, format="WAV", subtype="PCM_16")
+    sf.write(buf, audio, samplerate=sample_rate, format="WAV", subtype="PCM_16")
     buf.seek(0)
     buf.seek(0)
     return buf
     return buf
 
 
 
 
+def iter_text_parts(text: str, split_pattern: Optional[str] = None) -> List[str]:
+    text = (text or "").strip()
+    if not text:
+        return []
+
+    blocks = [text]
+    if split_pattern:
+        try:
+            blocks = [part.strip() for part in re.split(split_pattern, text) if part.strip()]
+        except re.error:
+            logger.warning("split_pattern 非法,回退默认分句: %s", split_pattern)
+
+    parts: List[str] = []
+    for block in blocks:
+        sentences = split_sentences(block)
+        if sentences:
+            parts.extend(sentences)
+        elif block:
+            parts.append(block)
+    return parts
+
+
+async def get_or_create_sentence_cache_item(
+    sentence: str,
+    voice: str,
+    speed: float,
+    model_name: str,
+) -> dict:
+    key = sentence_cache_key(sentence, voice, speed, model_name)
+    cached_item = get_memory_cache(key)
+    if cached_item:
+        return cached_item
+
+    disk_item = await load_sentence_from_disk(key)
+    if disk_item:
+        put_memory_cache(key, disk_item)
+        return disk_item
+
+    def _synthesize_item():
+        audio = synthesize_audio(sentence, voice=voice, speed=speed, model_name=model_name)
+        wav_bytes = encode_wav_bytes(audio, sample_rate)
+        return {
+            "sentence": sentence,
+            "sample_rate": sample_rate,
+            "audio_bytes": wav_bytes,
+            "audio": base64.b64encode(wav_bytes).decode(),
+        }, audio
+
+    item, audio = await asyncio.to_thread(_synthesize_item)
+    put_memory_cache(key, item)
+    await save_sentence_to_disk(key, audio, sample_rate, sentence)
+    await clean_disk_cache()
+    return item
+
+
 @app.post("/tts", summary="POST: 传入文本返回 WAV 流")
 @app.post("/tts", summary="POST: 传入文本返回 WAV 流")
 def tts_post(req: TTSRequest):
 def tts_post(req: TTSRequest):
     buf = synthesize_wav_bytes(
     buf = synthesize_wav_bytes(
@@ -591,12 +720,7 @@ async def generate_audio_stream(data: Dict = Body(...)):
         if not text.strip():
         if not text.strip():
             raise HTTPException(status_code=400, detail="文本不能为空")
             raise HTTPException(status_code=400, detail="文本不能为空")
 
 
-        if _is_english_voice(voice):
-            parts = [graphemes for graphemes, _ in _phonemize_en_chunks_for_onnx(text)] if text.strip() else []
-        else:
-            parts = split_sentences(text) if text.strip() else []
-        if not parts:
-            parts = [text.strip()]
+        parts = iter_text_parts(text, data.get("split_pattern", r"\n+"))
 
 
         if client_id in current_requests:
         if client_id in current_requests:
             current_requests[client_id]["interrupt"] = True
             current_requests[client_id]["interrupt"] = True
@@ -607,37 +731,17 @@ async def generate_audio_stream(data: Dict = Body(...)):
         async def stream():
         async def stream():
             try:
             try:
                 for idx, part in enumerate(parts):
                 for idx, part in enumerate(parts):
-                    if not part:
-                        continue
                     if current_requests.get(client_id, {}).get("interrupt"):
                     if current_requests.get(client_id, {}).get("interrupt"):
                         break
                         break
 
 
-                    audio = await asyncio.to_thread(
-                        synthesize_audio,
-                        text=part,
-                        voice=voice,
-                        speed=speed,
-                        model_name=model,
-                    )
-
-                    with io.BytesIO() as buf:
-                        with sf.SoundFile(
-                            buf,
-                            "w",
-                            sample_rate,
-                            channels=audio.shape[1] if audio.ndim > 1 else 1,
-                            format="WAV",
-                            subtype="PCM_16",
-                        ) as f:
-                            f.write(audio)
-                        audio_b64 = base64.b64encode(buf.getvalue()).decode()
+                    item = await get_or_create_sentence_cache_item(part, voice, speed, model)
 
 
                     yield json.dumps(
                     yield json.dumps(
                         {
                         {
                             "index": idx,
                             "index": idx,
-                            "sentence": part,
-                            "audio": audio_b64,
-                            "sample_rate": sample_rate,
+                            "sentence": item["sentence"],
+                            "audio": item["audio"],
+                            "sample_rate": item["sample_rate"],
                         }
                         }
                     ).encode() + b"\n"
                     ).encode() + b"\n"
             finally:
             finally:
@@ -646,6 +750,59 @@ async def generate_audio_stream(data: Dict = Body(...)):
         return StreamingResponse(stream(), media_type="application/x-ndjson")
         return StreamingResponse(stream(), media_type="application/x-ndjson")
 
 
 
 
+@app.post("/generate_pcm", summary="POST: 返回更轻量的 PCM 分片流")
+async def generate_audio_pcm_stream(data: Dict = Body(...)):
+    async with request_semaphore:
+        text = data.get("text", "")
+        voice = data.get("voice", "af_heart")
+        speed = float(data.get("speed", 1.0))
+        model = data.get("model_name", DEFAULT_MODEL_NAME)
+
+        if not text.strip():
+            raise HTTPException(status_code=400, detail="文本不能为空")
+
+        parts = iter_text_parts(text, data.get("split_pattern", r"\n+"))
+
+        async def pcm_stream():
+            for idx, part in enumerate(parts):
+                item = await get_or_create_sentence_cache_item(part, voice, speed, model)
+                payload = wav_to_pcm_payload(item["audio_bytes"])
+                header = json.dumps(
+                    {
+                        "index": idx,
+                        "sentence": item["sentence"],
+                        "sample_rate": item["sample_rate"],
+                        "format": "s16le",
+                        "bytes": len(payload),
+                    }
+                ).encode() + b"\n"
+                yield header
+                yield payload
+
+        return StreamingResponse(
+            pcm_stream(),
+            media_type="application/octet-stream",
+            headers={"X-Audio-Format": "s16le", "X-Sample-Rate": str(sample_rate)},
+        )
+
+
+@app.get("/clear-cache")
+async def clear_cache():
+    with cache_lock:
+        memory_cache.clear()
+    for f in Path(CACHE_DIR).glob("*"):
+        f.unlink()
+    return {"status": "success"}
+
+
+@app.get("/cache-info")
+async def get_cache_info():
+    with cache_lock:
+        mem_count = len(memory_cache)
+    disk_files = list(Path(CACHE_DIR).glob("*.wav"))
+    return {"memory_cache": mem_count, "disk_cache": len(disk_files)}
+
+
 if __name__ == "__main__":
 if __name__ == "__main__":
     import uvicorn
     import uvicorn
 
 

+ 4 - 0
start_onnx.sh

@@ -0,0 +1,4 @@
+#!/bin/bash
+eval "$(/root/miniconda3/bin/conda shell.bash hook)"
+conda activate py311
+nohup uvicorn speech_tts_onnx:app --host 0.0.0.0 --port 8028 &