|
@@ -0,0 +1,567 @@
|
|
|
|
|
+from __future__ import annotations
|
|
|
|
|
+
|
|
|
|
|
+import asyncio
|
|
|
|
|
+import base64
|
|
|
|
|
+import concurrent.futures
|
|
|
|
|
+import hashlib
|
|
|
|
|
+import io
|
|
|
|
|
+import json
|
|
|
|
|
+import logging
|
|
|
|
|
+import os
|
|
|
|
|
+import re
|
|
|
|
|
+import threading
|
|
|
|
|
+import time
|
|
|
|
|
+import uuid
|
|
|
|
|
+import urllib.request
|
|
|
|
|
+from collections import OrderedDict
|
|
|
|
|
+from contextlib import asynccontextmanager
|
|
|
|
|
+from pathlib import Path
|
|
|
|
|
+from typing import Dict, List, Optional, Tuple
|
|
|
|
|
+
|
|
|
|
|
+import aiofiles
|
|
|
|
|
+import numpy as np
|
|
|
|
|
+import soundfile as sf
|
|
|
|
|
+from fastapi import Body, FastAPI, HTTPException, Query, Request
|
|
|
|
|
+from fastapi.middleware.cors import CORSMiddleware
|
|
|
|
|
+from fastapi.responses import StreamingResponse
|
|
|
|
|
+from pydantic import BaseModel, field_validator
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
+DEFAULT_MODEL_NAME = os.getenv("TTS_ONNX_MODEL_NAME", "model.onnx")
|
|
|
|
|
+MODEL_DIR = os.getenv("TTS_ONNX_MODEL_DIR", "/home/tts-server/onnx")
|
|
|
|
|
+HF_MODEL_ID = os.getenv("TTS_ONNX_HF_MODEL_ID", "onnx-community/Kokoro-82M-ONNX")
|
|
|
|
|
+CONFIG_PATH = os.getenv("TTS_ONNX_CONFIG_PATH", str(Path(MODEL_DIR) / "config.json"))
|
|
|
|
|
+VOICES_DIR = os.getenv("TTS_ONNX_VOICES_DIR", str(Path(MODEL_DIR) / "voices"))
|
|
|
|
|
+VOICES_V1_PATH = os.getenv("TTS_ONNX_VOICES_V1_PATH", str(Path(MODEL_DIR) / "voices-v1.0.bin"))
|
|
|
|
|
+CACHE_DIR = os.getenv("CACHE_DIR", "./audio_cache")
|
|
|
|
|
+LOG_LEVEL = os.getenv("LOG_LEVEL", "INFO").upper()
|
|
|
|
|
+MAX_CONCURRENT_REQUESTS = int(os.getenv("MAX_CONCURRENT_REQUESTS", 8))
|
|
|
|
|
+DEFAULT_SAMPLE_RATE = int(os.getenv("TTS_SAMPLE_RATE", 24000))
|
|
|
|
|
+MEMORY_CACHE_TTL = int(os.getenv("MEMORY_CACHE_TTL", 18000))
|
|
|
|
|
+MEMORY_CACHE_MAX_ITEMS = int(os.getenv("MEMORY_CACHE_SIZE", 120))
|
|
|
|
|
+MEMORY_CACHE_MAX_BYTES = int(os.getenv("MEMORY_CACHE_MAX_BYTES", str(128 * 1024 * 1024)))
|
|
|
|
|
+DISK_CACHE_SIZE = int(os.getenv("DISK_CACHE_SIZE", 500))
|
|
|
|
|
+ORT_INTRA_OP_THREADS = int(os.getenv("ORT_INTRA_OP_THREADS", "2"))
|
|
|
|
|
+ORT_INTER_OP_THREADS = int(os.getenv("ORT_INTER_OP_THREADS", "1"))
|
|
|
|
|
+ORT_ENABLE_CPU_MEM_ARENA = os.getenv("ORT_ENABLE_CPU_MEM_ARENA", "true").lower() == "true"
|
|
|
|
|
+ORT_ENABLE_MEM_PATTERN = os.getenv("ORT_ENABLE_MEM_PATTERN", "false").lower() == "true"
|
|
|
|
|
+VOICE_ALIASES: Dict[str, str] = {}
|
|
|
|
|
+
|
|
|
|
|
+logging.basicConfig(level=getattr(logging, LOG_LEVEL, logging.INFO))
|
|
|
|
|
+logger = logging.getLogger("speech_tts_onnx_opt")
|
|
|
|
|
+
|
|
|
|
|
+Path(CACHE_DIR).mkdir(parents=True, exist_ok=True)
|
|
|
|
|
+SENT_SPLIT_RE = re.compile(r"(?<=[。!?;.!?;::,,])\s*|\n+")
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
+class _PhonemizerWordCountMismatchFilter(logging.Filter):
|
|
|
|
|
+ def filter(self, record: logging.LogRecord) -> bool:
|
|
|
|
|
+ return "words count mismatch" not in record.getMessage()
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
+logging.getLogger("phonemizer").addFilter(_PhonemizerWordCountMismatchFilter())
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
+class MemoryAudioCache:
|
|
|
|
|
+ def __init__(self, max_items: int, max_bytes: int, ttl: int):
|
|
|
|
|
+ self.max_items = max_items
|
|
|
|
|
+ self.max_bytes = max_bytes
|
|
|
|
|
+ self.ttl = ttl
|
|
|
|
|
+ self.lock = threading.Lock()
|
|
|
|
|
+ self.items: "OrderedDict[str, dict]" = OrderedDict()
|
|
|
|
|
+ self.total_bytes = 0
|
|
|
|
|
+
|
|
|
|
|
+ def _entry_size(self, value: dict) -> int:
|
|
|
|
|
+ return len(value.get("audio_bytes", b"")) + len(value.get("sentence", "").encode("utf-8")) + 128
|
|
|
|
|
+
|
|
|
|
|
+ def _purge_expired(self, now: float):
|
|
|
|
|
+ expired = [k for k, v in self.items.items() if now - v["ts"] > self.ttl]
|
|
|
|
|
+ for key in expired:
|
|
|
|
|
+ entry = self.items.pop(key, None)
|
|
|
|
|
+ if entry:
|
|
|
|
|
+ self.total_bytes -= entry["size"]
|
|
|
|
|
+
|
|
|
|
|
+ def get(self, key: str) -> Optional[dict]:
|
|
|
|
|
+ now = time.time()
|
|
|
|
|
+ with self.lock:
|
|
|
|
|
+ self._purge_expired(now)
|
|
|
|
|
+ entry = self.items.get(key)
|
|
|
|
|
+ if not entry:
|
|
|
|
|
+ return None
|
|
|
|
|
+ self.items.move_to_end(key)
|
|
|
|
|
+ entry["ts"] = now
|
|
|
|
|
+ return {
|
|
|
|
|
+ "sentence": entry["sentence"],
|
|
|
|
|
+ "sample_rate": entry["sample_rate"],
|
|
|
|
|
+ "audio_bytes": entry["audio_bytes"],
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ def set(self, key: str, value: dict):
|
|
|
|
|
+ now = time.time()
|
|
|
|
|
+ with self.lock:
|
|
|
|
|
+ self._purge_expired(now)
|
|
|
|
|
+ old = self.items.pop(key, None)
|
|
|
|
|
+ if old:
|
|
|
|
|
+ self.total_bytes -= old["size"]
|
|
|
|
|
+
|
|
|
|
|
+ entry = {
|
|
|
|
|
+ "sentence": value["sentence"],
|
|
|
|
|
+ "sample_rate": value["sample_rate"],
|
|
|
|
|
+ "audio_bytes": value["audio_bytes"],
|
|
|
|
|
+ "ts": now,
|
|
|
|
|
+ }
|
|
|
|
|
+ entry["size"] = self._entry_size(entry)
|
|
|
|
|
+ self.items[key] = entry
|
|
|
|
|
+ self.total_bytes += entry["size"]
|
|
|
|
|
+
|
|
|
|
|
+ while self.items and (
|
|
|
|
|
+ len(self.items) > self.max_items or self.total_bytes > self.max_bytes
|
|
|
|
|
+ ):
|
|
|
|
|
+ _, removed = self.items.popitem(last=False)
|
|
|
|
|
+ self.total_bytes -= removed["size"]
|
|
|
|
|
+
|
|
|
|
|
+ def clear(self):
|
|
|
|
|
+ with self.lock:
|
|
|
|
|
+ self.items.clear()
|
|
|
|
|
+ self.total_bytes = 0
|
|
|
|
|
+
|
|
|
|
|
+ def info(self) -> dict:
|
|
|
|
|
+ now = time.time()
|
|
|
|
|
+ with self.lock:
|
|
|
|
|
+ self._purge_expired(now)
|
|
|
|
|
+ return {"items": len(self.items), "bytes": self.total_bytes}
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
+model_lock = threading.Lock()
|
|
|
|
|
+request_semaphore = asyncio.Semaphore(MAX_CONCURRENT_REQUESTS)
|
|
|
|
|
+current_requests: Dict[str, Dict] = {}
|
|
|
|
|
+memory_cache = MemoryAudioCache(
|
|
|
|
|
+ max_items=MEMORY_CACHE_MAX_ITEMS,
|
|
|
|
|
+ max_bytes=MEMORY_CACHE_MAX_BYTES,
|
|
|
|
|
+ ttl=MEMORY_CACHE_TTL,
|
|
|
|
|
+)
|
|
|
|
|
+executor = concurrent.futures.ThreadPoolExecutor(max_workers=8)
|
|
|
|
|
+inflight_lock = threading.Lock()
|
|
|
|
|
+inflight_tasks: Dict[str, asyncio.Future] = {}
|
|
|
|
|
+
|
|
|
|
|
+model_session = None
|
|
|
|
|
+model_name = DEFAULT_MODEL_NAME
|
|
|
|
|
+sample_rate = DEFAULT_SAMPLE_RATE
|
|
|
|
|
+_KOKORO_ONNX_ENGINE = None
|
|
|
|
|
+_EN_G2P_PIPELINE = None
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
+def split_sentences(text: str) -> List[str]:
|
|
|
|
|
+ parts = [s.strip() for s in SENT_SPLIT_RE.split(text.strip()) if s.strip()]
|
|
|
|
|
+ merged: List[str] = []
|
|
|
|
|
+ buf = ""
|
|
|
|
|
+ for part in parts:
|
|
|
|
|
+ if len(part) < 3 and merged:
|
|
|
|
|
+ merged[-1] = f"{merged[-1]} {part}".strip()
|
|
|
|
|
+ else:
|
|
|
|
|
+ merged.append(part)
|
|
|
|
|
+ return merged
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
+def iter_text_parts(text: str, split_pattern: Optional[str]) -> List[str]:
|
|
|
|
|
+ text = (text or "").strip()
|
|
|
|
|
+ if not text:
|
|
|
|
|
+ return []
|
|
|
|
|
+
|
|
|
|
|
+ blocks = [text]
|
|
|
|
|
+ if split_pattern:
|
|
|
|
|
+ try:
|
|
|
|
|
+ blocks = [p.strip() for p in re.split(split_pattern, text) if p.strip()]
|
|
|
|
|
+ except re.error:
|
|
|
|
|
+ logger.warning("split_pattern 非法,回退默认分句: %s", split_pattern)
|
|
|
|
|
+
|
|
|
|
|
+ parts: List[str] = []
|
|
|
|
|
+ for block in blocks:
|
|
|
|
|
+ parts.extend(split_sentences(block) or [block])
|
|
|
|
|
+ return parts
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
+def sentence_cache_key(sentence: str, voice: str, speed: float, model: str) -> str:
|
|
|
|
|
+ raw = f"{sentence}|{voice}|{speed:.4f}|{model}"
|
|
|
|
|
+ return hashlib.md5(raw.encode("utf-8")).hexdigest()
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
+def sentence_cache_path(key: str) -> str:
|
|
|
|
|
+ return os.path.join(CACHE_DIR, f"{key}.wav")
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
+def meta_cache_path(key: str) -> str:
|
|
|
|
|
+ return os.path.join(CACHE_DIR, f"{key}.json")
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
+def to_mono_numpy(audio) -> np.ndarray:
|
|
|
|
|
+ arr = np.asarray(audio, dtype=np.float32)
|
|
|
|
|
+ if arr.size == 0:
|
|
|
|
|
+ return np.array([], dtype=np.float32)
|
|
|
|
|
+ if arr.ndim == 2:
|
|
|
|
|
+ if arr.shape[1] == 1:
|
|
|
|
|
+ arr = arr[:, 0]
|
|
|
|
|
+ elif arr.shape[0] == 1:
|
|
|
|
|
+ arr = arr[0]
|
|
|
|
|
+ else:
|
|
|
|
|
+ arr = arr.mean(axis=1)
|
|
|
|
|
+ elif arr.ndim > 2:
|
|
|
|
|
+ arr = arr.reshape(-1)
|
|
|
|
|
+ return arr.astype(np.float32, copy=False)
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
+def read_wav_bytes(wav_path: str) -> Tuple[bytes, int]:
|
|
|
|
|
+ audio, sr = sf.read(wav_path)
|
|
|
|
|
+ buf = io.BytesIO()
|
|
|
|
|
+ with sf.SoundFile(buf, "w", samplerate=sr, channels=audio.shape[1] if audio.ndim > 1 else 1, format="WAV", subtype="PCM_16") as f:
|
|
|
|
|
+ f.write(audio)
|
|
|
|
|
+ return buf.getvalue(), sr
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
+async def load_sentence_from_disk(key: str) -> Optional[dict]:
|
|
|
|
|
+ wav_path = sentence_cache_path(key)
|
|
|
|
|
+ meta_path = meta_cache_path(key)
|
|
|
|
|
+ if not Path(wav_path).exists() or not Path(meta_path).exists():
|
|
|
|
|
+ return None
|
|
|
|
|
+
|
|
|
|
|
+ async with aiofiles.open(meta_path, "r") as f:
|
|
|
|
|
+ meta = json.loads(await f.read())
|
|
|
|
|
+
|
|
|
|
|
+ wav_bytes, sr = await asyncio.get_event_loop().run_in_executor(executor, lambda: read_wav_bytes(wav_path))
|
|
|
|
|
+ return {"sentence": meta["sentence"], "sample_rate": sr, "audio_bytes": wav_bytes}
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
+async def save_sentence_to_disk(key: str, audio: np.ndarray, sr: int, sentence: str):
|
|
|
|
|
+ wav_path = sentence_cache_path(key)
|
|
|
|
|
+ meta_path = meta_cache_path(key)
|
|
|
|
|
+ await asyncio.get_event_loop().run_in_executor(executor, lambda: sf.write(wav_path, audio, sr, format="WAV"))
|
|
|
|
|
+ async with aiofiles.open(meta_path, "w") as f:
|
|
|
|
|
+ await f.write(json.dumps({"sentence": sentence, "sample_rate": sr}))
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
+async def clean_disk_cache():
|
|
|
|
|
+ files = sorted(Path(CACHE_DIR).glob("*.wav"), key=os.path.getmtime)
|
|
|
|
|
+ if len(files) <= DISK_CACHE_SIZE:
|
|
|
|
|
+ return
|
|
|
|
|
+ for f in files[: len(files) - DISK_CACHE_SIZE]:
|
|
|
|
|
+ try:
|
|
|
|
|
+ f.unlink()
|
|
|
|
|
+ meta = Path(meta_cache_path(f.stem))
|
|
|
|
|
+ if meta.exists():
|
|
|
|
|
+ meta.unlink()
|
|
|
|
|
+ except Exception:
|
|
|
|
|
+ pass
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
+def resolve_model_path(name: str) -> str:
|
|
|
|
|
+ local_path = Path(MODEL_DIR) / name
|
|
|
|
|
+ if local_path.exists():
|
|
|
|
|
+ return str(local_path)
|
|
|
|
|
+ if os.path.isabs(name) and Path(name).exists():
|
|
|
|
|
+ return name
|
|
|
|
|
+ raise FileNotFoundError(f"找不到模型文件: {local_path}")
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
+def get_en_g2p_pipeline():
|
|
|
|
|
+ global _EN_G2P_PIPELINE
|
|
|
|
|
+ if _EN_G2P_PIPELINE is None:
|
|
|
|
|
+ from kokoro import KPipeline # type: ignore
|
|
|
|
|
+
|
|
|
|
|
+ _EN_G2P_PIPELINE = KPipeline(lang_code="a", repo_id="hexgrad/Kokoro-82M", model=False)
|
|
|
|
|
+ return _EN_G2P_PIPELINE
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
+def _download_voice_file(voice_path: Path) -> Path:
|
|
|
|
|
+ raw_url = f"https://huggingface.co/{HF_MODEL_ID}/resolve/main/voices/{voice_path.name}"
|
|
|
|
|
+ voice_path.parent.mkdir(parents=True, exist_ok=True)
|
|
|
|
|
+ tmp_path = voice_path.with_suffix(voice_path.suffix + ".tmp")
|
|
|
|
|
+ with urllib.request.urlopen(raw_url, timeout=60) as response:
|
|
|
|
|
+ payload = response.read()
|
|
|
|
|
+ with open(tmp_path, "wb") as f:
|
|
|
|
|
+ f.write(payload)
|
|
|
|
|
+ tmp_path.replace(voice_path)
|
|
|
|
|
+ return voice_path
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
+def load_onnx_session(name: str):
|
|
|
|
|
+ try:
|
|
|
|
|
+ import onnxruntime as ort # type: ignore
|
|
|
|
|
+ except Exception as e:
|
|
|
|
|
+ raise RuntimeError("无法导入 onnxruntime。请在部署环境中安装 onnxruntime。") from e
|
|
|
|
|
+
|
|
|
|
|
+ model_path = resolve_model_path(name)
|
|
|
|
|
+ sess_options = ort.SessionOptions()
|
|
|
|
|
+ sess_options.intra_op_num_threads = ORT_INTRA_OP_THREADS
|
|
|
|
|
+ sess_options.inter_op_num_threads = ORT_INTER_OP_THREADS
|
|
|
|
|
+ sess_options.enable_cpu_mem_arena = ORT_ENABLE_CPU_MEM_ARENA
|
|
|
|
|
+ sess_options.enable_mem_pattern = ORT_ENABLE_MEM_PATTERN
|
|
|
|
|
+ return ort.InferenceSession(model_path, sess_options=sess_options, providers=["CPUExecutionProvider"])
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
+def load_kokoro_engine(name: str):
|
|
|
|
|
+ try:
|
|
|
|
|
+ from kokoro_onnx import Kokoro # type: ignore
|
|
|
|
|
+ except Exception as e:
|
|
|
|
|
+ raise RuntimeError("无法导入 kokoro_onnx。请在部署环境中安装 kokoro-onnx。") from e
|
|
|
|
|
+
|
|
|
|
|
+ model_path = resolve_model_path(name)
|
|
|
|
|
+ return Kokoro(model_path=model_path, voices_path=VOICES_V1_PATH, vocab_config=CONFIG_PATH if Path(CONFIG_PATH).exists() else None)
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
+def load_model(force_reload: bool = False, name: Optional[str] = None):
|
|
|
|
|
+ global model_session, model_name, _KOKORO_ONNX_ENGINE
|
|
|
|
|
+ with model_lock:
|
|
|
|
|
+ target = name or model_name
|
|
|
|
|
+ if force_reload or model_session is None or target != model_name:
|
|
|
|
|
+ model_session = load_onnx_session(target)
|
|
|
|
|
+ _KOKORO_ONNX_ENGINE = load_kokoro_engine(target)
|
|
|
|
|
+ model_name = target
|
|
|
|
|
+ logger.info("ONNX 模型加载完成: %s", resolve_model_path(target))
|
|
|
|
|
+ return model_session
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
+def get_kokoro_engine(name: Optional[str] = None):
|
|
|
|
|
+ global _KOKORO_ONNX_ENGINE
|
|
|
|
|
+ target = name or model_name
|
|
|
|
|
+ if _KOKORO_ONNX_ENGINE is None or target != model_name:
|
|
|
|
|
+ load_model(name=target)
|
|
|
|
|
+ return _KOKORO_ONNX_ENGINE
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
+@asynccontextmanager
|
|
|
|
|
+async def lifespan(app: FastAPI):
|
|
|
|
|
+ load_model()
|
|
|
|
|
+ yield
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
+app = FastAPI(title="Online TTS Service (Kokoro ONNX Optimized)", lifespan=lifespan)
|
|
|
|
|
+app.add_middleware(
|
|
|
|
|
+ CORSMiddleware,
|
|
|
|
|
+ allow_origins=["*"],
|
|
|
|
|
+ allow_credentials=True,
|
|
|
|
|
+ allow_methods=["*"],
|
|
|
|
|
+ allow_headers=["*"],
|
|
|
|
|
+)
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
+@app.middleware("http")
|
|
|
|
|
+async def track_clients(request: Request, call_next):
|
|
|
|
|
+ client_id = request.query_params.get("client_id") or request.headers.get("X-Client-ID") or str(uuid.uuid4())
|
|
|
|
|
+ if request.url.path == "/generate":
|
|
|
|
|
+ current_requests[client_id] = {"active": True}
|
|
|
|
|
+ response = await call_next(request)
|
|
|
|
|
+ if request.url.path == "/generate":
|
|
|
|
|
+ current_requests.pop(client_id, None)
|
|
|
|
|
+ return response
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
+class TTSRequest(BaseModel):
|
|
|
|
|
+ text: str
|
|
|
|
|
+ voice: Optional[str] = "af_heart"
|
|
|
|
|
+ speed: Optional[float] = 1.0
|
|
|
|
|
+ split_pattern: Optional[str] = r"\n+"
|
|
|
|
|
+ model_name: Optional[str] = None
|
|
|
|
|
+
|
|
|
|
|
+ @field_validator("speed")
|
|
|
|
|
+ @classmethod
|
|
|
|
|
+ def validate_speed(cls, v):
|
|
|
|
|
+ if v is None:
|
|
|
|
|
+ return 1.0
|
|
|
|
|
+ v = float(v)
|
|
|
|
|
+ if v <= 0:
|
|
|
|
|
+ raise ValueError("speed 必须大于 0")
|
|
|
|
|
+ return v
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
+def synthesize_audio(text: str, voice: str, speed: float, model_name: Optional[str] = None) -> np.ndarray:
|
|
|
|
|
+ if not text.strip():
|
|
|
|
|
+ raise HTTPException(status_code=400, detail="文本不能为空")
|
|
|
|
|
+
|
|
|
|
|
+ engine = get_kokoro_engine(name=model_name)
|
|
|
|
|
+ session = load_model(name=model_name)
|
|
|
|
|
+ if voice not in set(engine.get_voices()):
|
|
|
|
|
+ raise HTTPException(status_code=400, detail=f"不支持的 voice: {voice}")
|
|
|
|
|
+
|
|
|
|
|
+ phonemes = engine.tokenizer.phonemize(text, "en-us")
|
|
|
|
|
+ batched_phonemes = engine._split_phonemes(phonemes)
|
|
|
|
|
+ if not batched_phonemes:
|
|
|
|
|
+ raise HTTPException(status_code=400, detail="文本音素化失败")
|
|
|
|
|
+
|
|
|
|
|
+ voice_style = engine.get_voice_style(voice)
|
|
|
|
|
+ audio_segments: List[np.ndarray] = []
|
|
|
|
|
+ for phoneme_batch in batched_phonemes:
|
|
|
|
|
+ tokens = np.array(engine.tokenizer.tokenize(phoneme_batch), dtype=np.int64)
|
|
|
|
|
+ if tokens.size == 0:
|
|
|
|
|
+ continue
|
|
|
|
|
+ feeds = {
|
|
|
|
|
+ "input_ids": np.asarray([[0, *tokens.tolist(), 0]], dtype=np.int64),
|
|
|
|
|
+ "style": np.asarray(voice_style[len(tokens)], dtype=np.float32),
|
|
|
|
|
+ "speed": np.asarray([speed], dtype=np.float32),
|
|
|
|
|
+ }
|
|
|
|
|
+ outputs = session.run(None, feeds)
|
|
|
|
|
+ if outputs:
|
|
|
|
|
+ audio_segments.append(to_mono_numpy(outputs[0]))
|
|
|
|
|
+
|
|
|
|
|
+ if not audio_segments:
|
|
|
|
|
+ raise HTTPException(status_code=500, detail="ONNX 推理未返回音频输出")
|
|
|
|
|
+ audio = np.concatenate(audio_segments, axis=0) if len(audio_segments) > 1 else audio_segments[0]
|
|
|
|
|
+ if audio.size == 0 or not np.isfinite(audio).all():
|
|
|
|
|
+ raise HTTPException(status_code=500, detail="生成的音频无效")
|
|
|
|
|
+ return audio
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
+def encode_wav_bytes(audio: np.ndarray, sr: int) -> bytes:
|
|
|
|
|
+ buf = io.BytesIO()
|
|
|
|
|
+ with sf.SoundFile(buf, "w", samplerate=sr, channels=1, format="WAV", subtype="PCM_16") as f:
|
|
|
|
|
+ f.write(audio)
|
|
|
|
|
+ return buf.getvalue()
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
+def cache_item_to_response(item: dict) -> dict:
|
|
|
|
|
+ return {
|
|
|
|
|
+ "sentence": item["sentence"],
|
|
|
|
|
+ "sample_rate": item["sample_rate"],
|
|
|
|
|
+ "audio": base64.b64encode(item["audio_bytes"]).decode(),
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
+async def _compute_sentence_item(sentence: str, voice: str, speed: float, model: str) -> dict:
|
|
|
|
|
+ def _run():
|
|
|
|
|
+ audio = synthesize_audio(sentence, voice=voice, speed=speed, model_name=model)
|
|
|
|
|
+ wav_bytes = encode_wav_bytes(audio, sample_rate)
|
|
|
|
|
+ return {"sentence": sentence, "sample_rate": sample_rate, "audio_bytes": wav_bytes}, audio
|
|
|
|
|
+
|
|
|
|
|
+ item, audio = await asyncio.to_thread(_run)
|
|
|
|
|
+ memory_cache.set(sentence_cache_key(sentence, voice, speed, model), item)
|
|
|
|
|
+ await save_sentence_to_disk(sentence_cache_key(sentence, voice, speed, model), audio, sample_rate, sentence)
|
|
|
|
|
+ await clean_disk_cache()
|
|
|
|
|
+ return item
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
+async def get_or_create_sentence_cache_item(sentence: str, voice: str, speed: float, model: str) -> dict:
|
|
|
|
|
+ key = sentence_cache_key(sentence, voice, speed, model)
|
|
|
|
|
+ cached = memory_cache.get(key)
|
|
|
|
|
+ if cached:
|
|
|
|
|
+ return cached
|
|
|
|
|
+
|
|
|
|
|
+ disk_item = await load_sentence_from_disk(key)
|
|
|
|
|
+ if disk_item:
|
|
|
|
|
+ memory_cache.set(key, disk_item)
|
|
|
|
|
+ return disk_item
|
|
|
|
|
+
|
|
|
|
|
+ loop = asyncio.get_running_loop()
|
|
|
|
|
+ with inflight_lock:
|
|
|
|
|
+ fut = inflight_tasks.get(key)
|
|
|
|
|
+ if fut is None:
|
|
|
|
|
+ fut = loop.create_task(_compute_sentence_item(sentence, voice, speed, model))
|
|
|
|
|
+ inflight_tasks[key] = fut
|
|
|
|
|
+
|
|
|
|
|
+ try:
|
|
|
|
|
+ return await fut
|
|
|
|
|
+ finally:
|
|
|
|
|
+ with inflight_lock:
|
|
|
|
|
+ if inflight_tasks.get(key) is fut:
|
|
|
|
|
+ inflight_tasks.pop(key, None)
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
+def synthesize_wav_bytes(text: str, voice: str, speed: float, split_pattern: Optional[str], model_name: Optional[str] = None) -> io.BytesIO:
|
|
|
|
|
+ parts = iter_text_parts(text, split_pattern)
|
|
|
|
|
+ if not parts:
|
|
|
|
|
+ raise HTTPException(status_code=400, detail="文本不能为空")
|
|
|
|
|
+
|
|
|
|
|
+ buffers: List[np.ndarray] = []
|
|
|
|
|
+ for part in parts:
|
|
|
|
|
+ key = sentence_cache_key(part, voice, speed, model_name or DEFAULT_MODEL_NAME)
|
|
|
|
|
+ cached = memory_cache.get(key)
|
|
|
|
|
+ if cached:
|
|
|
|
|
+ audio, _ = sf.read(io.BytesIO(cached["audio_bytes"]), dtype="float32")
|
|
|
|
|
+ buffers.append(to_mono_numpy(audio))
|
|
|
|
|
+ continue
|
|
|
|
|
+ audio = synthesize_audio(part, voice=voice, speed=speed, model_name=model_name)
|
|
|
|
|
+ buffers.append(audio)
|
|
|
|
|
+
|
|
|
|
|
+ merged = np.concatenate(buffers, axis=0) if len(buffers) > 1 else buffers[0]
|
|
|
|
|
+ buf = io.BytesIO()
|
|
|
|
|
+ sf.write(buf, merged, samplerate=sample_rate, format="WAV", subtype="PCM_16")
|
|
|
|
|
+ buf.seek(0)
|
|
|
|
|
+ return buf
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
+@app.post("/tts", summary="POST: 传入文本返回 WAV 流")
|
|
|
|
|
+def tts_post(req: TTSRequest):
|
|
|
|
|
+ buf = synthesize_wav_bytes(
|
|
|
|
|
+ text=req.text,
|
|
|
|
|
+ voice=req.voice or "af_heart",
|
|
|
|
|
+ speed=req.speed if req.speed is not None else 1.0,
|
|
|
|
|
+ split_pattern=req.split_pattern,
|
|
|
|
|
+ model_name=req.model_name,
|
|
|
|
|
+ )
|
|
|
|
|
+ return StreamingResponse(buf, media_type="audio/wav", headers={"Content-Disposition": 'inline; filename="tts.wav"'})
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
+@app.get("/tts", summary="GET: 传入文本返回 WAV 流")
|
|
|
|
|
+def tts_get(
|
|
|
|
|
+ text: str = Query(..., description="待合成文本"),
|
|
|
|
|
+ voice: str = Query("af_heart"),
|
|
|
|
|
+ speed: float = Query(1.0),
|
|
|
|
|
+ model_name: str = Query(DEFAULT_MODEL_NAME),
|
|
|
|
|
+ split_pattern: str = Query(r"\n+"),
|
|
|
|
|
+):
|
|
|
|
|
+ if speed is None or float(speed) <= 0:
|
|
|
|
|
+ raise HTTPException(status_code=400, detail="speed 必须为大于 0 的数值")
|
|
|
|
|
+ buf = synthesize_wav_bytes(text=text, voice=voice, speed=float(speed), split_pattern=split_pattern, model_name=model_name)
|
|
|
|
|
+ return StreamingResponse(buf, media_type="audio/wav", headers={"Content-Disposition": 'inline; filename="tts.wav"'})
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
+@app.post("/generate")
|
|
|
|
|
+async def generate_audio_stream(data: Dict = Body(...)):
|
|
|
|
|
+ async with request_semaphore:
|
|
|
|
|
+ text = data.get("text", "")
|
|
|
|
|
+ voice = data.get("voice", "af_heart")
|
|
|
|
|
+ speed = float(data.get("speed", 1.0))
|
|
|
|
|
+ model = data.get("model_name", DEFAULT_MODEL_NAME)
|
|
|
|
|
+ split_pattern = data.get("split_pattern", r"\n+")
|
|
|
|
|
+ client_id = data.get("client_id", str(uuid.uuid4()))
|
|
|
|
|
+
|
|
|
|
|
+ if not text.strip():
|
|
|
|
|
+ raise HTTPException(status_code=400, detail="文本不能为空")
|
|
|
|
|
+
|
|
|
|
|
+ parts = iter_text_parts(text, split_pattern)
|
|
|
|
|
+ if client_id in current_requests:
|
|
|
|
|
+ current_requests[client_id]["interrupt"] = True
|
|
|
|
|
+ await asyncio.sleep(0.05)
|
|
|
|
|
+ current_requests[client_id] = {"interrupt": False}
|
|
|
|
|
+
|
|
|
|
|
+ async def stream():
|
|
|
|
|
+ try:
|
|
|
|
|
+ for idx, part in enumerate(parts):
|
|
|
|
|
+ if current_requests.get(client_id, {}).get("interrupt"):
|
|
|
|
|
+ break
|
|
|
|
|
+ item = await get_or_create_sentence_cache_item(part, voice, speed, model)
|
|
|
|
|
+ resp = cache_item_to_response(item)
|
|
|
|
|
+ yield json.dumps({"index": idx, **resp}).encode() + b"\n"
|
|
|
|
|
+ finally:
|
|
|
|
|
+ current_requests.pop(client_id, None)
|
|
|
|
|
+
|
|
|
|
|
+ return StreamingResponse(stream(), media_type="application/x-ndjson")
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
+@app.get("/clear-cache")
|
|
|
|
|
+async def clear_cache():
|
|
|
|
|
+ memory_cache.clear()
|
|
|
|
|
+ for f in Path(CACHE_DIR).glob("*"):
|
|
|
|
|
+ f.unlink()
|
|
|
|
|
+ return {"status": "success"}
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
+@app.get("/cache-info")
|
|
|
|
|
+async def get_cache_info():
|
|
|
|
|
+ mem = memory_cache.info()
|
|
|
|
|
+ disk_files = list(Path(CACHE_DIR).glob("*.wav"))
|
|
|
|
|
+ return {"memory_cache": mem["items"], "memory_cache_bytes": mem["bytes"], "disk_cache": len(disk_files)}
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
+if __name__ == "__main__":
|
|
|
|
|
+ import uvicorn
|
|
|
|
|
+
|
|
|
|
|
+ uvicorn.run("speech_tts_onnx_opt:app", host="0.0.0.0", port=18000, reload=False, workers=1)
|