|
|
@@ -17,17 +17,20 @@ from typing import Dict, List, Optional, Tuple
|
|
|
|
|
|
import numpy as np
|
|
|
import soundfile as sf
|
|
|
+import aiofiles
|
|
|
from fastapi import Body, FastAPI, HTTPException, Query, Request
|
|
|
from fastapi.middleware.cors import CORSMiddleware
|
|
|
from fastapi.responses import StreamingResponse
|
|
|
from pydantic import BaseModel, field_validator
|
|
|
+from cachetools import TTLCache
|
|
|
+import concurrent.futures
|
|
|
|
|
|
|
|
|
# ============================================================
|
|
|
-# 配置
|
|
|
+# 配置 model_quantized
|
|
|
# ============================================================
|
|
|
|
|
|
-DEFAULT_MODEL_NAME = os.getenv("TTS_ONNX_MODEL_NAME", "model_quantized.onnx")
|
|
|
+DEFAULT_MODEL_NAME = os.getenv("TTS_ONNX_MODEL_NAME", "model.onnx")
|
|
|
MODEL_DIR = os.getenv("TTS_ONNX_MODEL_DIR", "/home/tts-server/onnx")
|
|
|
HF_MODEL_ID = os.getenv("TTS_ONNX_HF_MODEL_ID", "onnx-community/Kokoro-82M-ONNX")
|
|
|
TOKENIZER_PATH = os.getenv("TTS_ONNX_TOKENIZER_PATH", str(Path(MODEL_DIR) / "tokenizer.json"))
|
|
|
@@ -40,6 +43,11 @@ MAX_CONCURRENT_REQUESTS = int(os.getenv("MAX_CONCURRENT_REQUESTS", 8))
|
|
|
MEMORY_CACHE_SIZE = int(os.getenv("MEMORY_CACHE_SIZE", 200))
|
|
|
DISK_CACHE_SIZE = int(os.getenv("DISK_CACHE_SIZE", 500))
|
|
|
DEFAULT_SAMPLE_RATE = int(os.getenv("TTS_SAMPLE_RATE", 24000))
|
|
|
+MEMORY_CACHE_TTL = int(os.getenv("MEMORY_CACHE_TTL", 18000))
|
|
|
+ORT_INTRA_OP_THREADS = int(os.getenv("ORT_INTRA_OP_THREADS", "2"))
|
|
|
+ORT_INTER_OP_THREADS = int(os.getenv("ORT_INTER_OP_THREADS", "1"))
|
|
|
+ORT_ENABLE_CPU_MEM_ARENA = os.getenv("ORT_ENABLE_CPU_MEM_ARENA", "true").lower() == "true"
|
|
|
+ORT_ENABLE_MEM_PATTERN = os.getenv("ORT_ENABLE_MEM_PATTERN", "true").lower() == "true"
|
|
|
VOICE_ALIASES = {}
|
|
|
|
|
|
logging.basicConfig(level=getattr(logging, LOG_LEVEL, logging.INFO))
|
|
|
@@ -67,8 +75,11 @@ _KOKORO_ONNX_ENGINE = None
|
|
|
|
|
|
model_lock = threading.Lock()
|
|
|
synthesis_lock = threading.Lock()
|
|
|
+cache_lock = threading.Lock()
|
|
|
request_semaphore = asyncio.Semaphore(MAX_CONCURRENT_REQUESTS)
|
|
|
current_requests: Dict[str, Dict] = {}
|
|
|
+memory_cache = TTLCache(maxsize=MEMORY_CACHE_SIZE, ttl=MEMORY_CACHE_TTL)
|
|
|
+executor = concurrent.futures.ThreadPoolExecutor(max_workers=8)
|
|
|
|
|
|
model_session = None
|
|
|
model_name = DEFAULT_MODEL_NAME
|
|
|
@@ -96,6 +107,16 @@ def meta_cache_path(key: str) -> str:
|
|
|
return os.path.join(CACHE_DIR, f"{key}.json")
|
|
|
|
|
|
|
|
|
def put_memory_cache(key: str, value: dict) -> None:
    """Store ``value`` under ``key`` in the in-memory TTL cache.

    The write is performed while holding ``cache_lock`` so it cannot
    interleave with other cache accesses made under the same lock.

    Args:
        key: Cache key of the entry.
        value: Cache item to store.
    """
    cache_lock.acquire()
    try:
        memory_cache[key] = value
    finally:
        cache_lock.release()
|
|
|
+
|
|
|
+
|
|
|
def get_memory_cache(key: str) -> Optional[dict]:
    """Look up ``key`` in the in-memory TTL cache.

    The lookup is performed while holding ``cache_lock``.

    Args:
        key: Cache key of the entry.

    Returns:
        The cached item, or ``None`` when the cache has no entry for ``key``.
    """
    with cache_lock:
        cached = memory_cache.get(key)
    return cached
|
|
|
+
|
|
|
+
|
|
|
def to_mono_numpy(audio) -> np.ndarray:
|
|
|
if audio is None:
|
|
|
return np.array([], dtype=np.float32)
|
|
|
@@ -146,6 +167,71 @@ def read_wav_bytes(wav_path: str) -> Tuple[bytes, int]:
|
|
|
return buf.getvalue(), sr
|
|
|
|
|
|
|
|
|
async def load_sentence_from_disk(key: str) -> Optional[dict]:
    """Load a cached sentence (metadata plus WAV audio) from the disk cache.

    Args:
        key: Cache key used to derive the WAV and metadata file paths.

    Returns:
        The metadata dict extended with ``audio_bytes`` (the bytes returned by
        ``read_wav_bytes``), ``audio`` (the same bytes base64-encoded as text)
        and ``sample_rate``. ``None`` when the entry is missing or cannot be
        read, so callers can treat every failure as a cache miss.
    """
    wav_path = sentence_cache_path(key)
    meta_path = meta_cache_path(key)
    if not Path(wav_path).exists() or not Path(meta_path).exists():
        return None

    try:
        async with aiofiles.open(meta_path, "r") as f:
            meta = json.loads(await f.read())
        # read_wav_bytes is synchronous, so it runs on the shared thread pool.
        wav_bytes, sr = await asyncio.get_running_loop().run_in_executor(
            executor, read_wav_bytes, wav_path
        )
    except (OSError, ValueError, RuntimeError) as exc:
        # The files may be deleted or still be half-written between the
        # existence check above and the reads; a broken entry is just a miss.
        logger.warning("磁盘缓存读取失败,按未命中处理: %s (%s)", key, exc)
        return None

    if not isinstance(meta, dict):
        # Valid JSON but not the expected object (e.g. "null").
        logger.warning("磁盘缓存元数据格式异常,按未命中处理: %s", key)
        return None

    meta["audio_bytes"] = wav_bytes
    meta["audio"] = base64.b64encode(wav_bytes).decode()
    meta["sample_rate"] = sr
    return meta
|
|
|
+
|
|
|
+
|
|
|
async def save_sentence_to_disk(key: str, audio: np.ndarray, sr: int, sentence: str):
    """Persist one synthesized sentence as a WAV file plus a JSON metadata file.

    Each file is first written under a unique temporary name and then moved
    into place with ``os.replace``, so a concurrent reader sees either the
    previous complete file or the new complete file, never a partial one.

    Args:
        key: Cache key used to derive the WAV and metadata file paths.
        audio: Audio samples to write.
        sr: Sample rate stored in the WAV file and in the metadata.
        sentence: Source text stored in the metadata.
    """
    from uuid import uuid4  # local import: only needed for temp-file names

    wav_path = sentence_cache_path(key)
    meta_path = meta_cache_path(key)
    # The ".tmp" suffix keeps in-flight files out of "*.wav" cache listings.
    token = uuid4().hex
    wav_tmp = f"{wav_path}.{token}.tmp"
    meta_tmp = f"{meta_path}.{token}.tmp"

    def _write_wav() -> None:
        try:
            sf.write(wav_tmp, audio, sr, format="WAV")
            os.replace(wav_tmp, wav_path)
        finally:
            # No-op after a successful replace; removes debris after a failure.
            if os.path.exists(wav_tmp):
                os.remove(wav_tmp)

    await asyncio.get_running_loop().run_in_executor(executor, _write_wav)

    try:
        async with aiofiles.open(meta_tmp, "w") as f:
            await f.write(json.dumps({"sentence": sentence, "sample_rate": sr}))
        os.replace(meta_tmp, meta_path)
    finally:
        if os.path.exists(meta_tmp):
            os.remove(meta_tmp)
|
|
|
+
|
|
|
+
|
|
|
async def clean_disk_cache():
    """Evict the oldest disk-cache entries until at most DISK_CACHE_SIZE remain.

    Entries are the ``*.wav`` files in ``CACHE_DIR``, ordered by modification
    time; each evicted WAV file is removed together with its metadata file.
    The filesystem work runs in a worker thread so the event loop is not
    blocked, and files that disappear concurrently are tolerated.
    """

    def _evict_oldest() -> None:
        entries = []
        for wav_file in Path(CACHE_DIR).glob("*.wav"):
            try:
                entries.append((wav_file.stat().st_mtime, wav_file))
            except OSError:
                # Deleted by a concurrent cleanup between glob() and stat().
                continue

        overflow = len(entries) - DISK_CACHE_SIZE
        if overflow <= 0:
            return

        entries.sort(key=lambda entry: entry[0])
        for _, wav_file in entries[:overflow]:
            try:
                wav_file.unlink(missing_ok=True)
                Path(meta_cache_path(wav_file.stem)).unlink(missing_ok=True)
            except OSError as exc:
                # Best effort: keep evicting the rest, but leave a trace.
                logger.warning("磁盘缓存清理失败: %s (%s)", wav_file, exc)

    await asyncio.to_thread(_evict_oldest)
|
|
|
+
|
|
|
+
|
|
|
def encode_wav_bytes(audio: np.ndarray, sr: int) -> bytes:
    """Encode audio samples as an in-memory 16-bit PCM WAV file.

    Args:
        audio: Samples to encode; when the array has more than one dimension
            its second dimension is used as the channel count, otherwise the
            audio is written as a single channel.
        sr: Sample rate written to the WAV header.

    Returns:
        The complete WAV file contents.
    """
    channel_count = 1 if audio.ndim <= 1 else audio.shape[1]
    stream = io.BytesIO()
    writer = sf.SoundFile(
        stream,
        "w",
        samplerate=sr,
        channels=channel_count,
        format="WAV",
        subtype="PCM_16",
    )
    with writer:
        writer.write(audio)
    return stream.getvalue()
|
|
|
+
|
|
|
+
|
|
|
def wav_to_pcm_payload(wav_bytes: bytes) -> bytes:
    """Decode WAV file bytes and return the samples as raw int16 bytes.

    Args:
        wav_bytes: Contents of a WAV file.

    Returns:
        The decoded samples as the raw bytes of an ``int16`` array; the
        sample rate reported by the decoder is discarded.
    """
    source = io.BytesIO(wav_bytes)
    samples, _unused_rate = sf.read(source, dtype="int16")
    pcm = np.asarray(samples, dtype=np.int16)
    return pcm.tobytes()
|
|
|
+
|
|
|
+
|
|
|
# ============================================================
|
|
|
# 模型加载
|
|
|
# ============================================================
|
|
|
@@ -267,7 +353,12 @@ def load_onnx_session(name: str):
|
|
|
|
|
|
model_path = resolve_model_path(name)
|
|
|
providers = ["CPUExecutionProvider"]
|
|
|
- session = ort.InferenceSession(model_path, providers=providers)
|
|
|
+ sess_options = ort.SessionOptions()
|
|
|
+ sess_options.intra_op_num_threads = ORT_INTRA_OP_THREADS
|
|
|
+ sess_options.inter_op_num_threads = ORT_INTER_OP_THREADS
|
|
|
+ sess_options.enable_cpu_mem_arena = ORT_ENABLE_CPU_MEM_ARENA
|
|
|
+ sess_options.enable_mem_pattern = ORT_ENABLE_MEM_PATTERN
|
|
|
+ session = ort.InferenceSession(model_path, sess_options=sess_options, providers=providers)
|
|
|
logger.info("ONNX 模型加载完成: %s", model_path)
|
|
|
return session
|
|
|
|
|
|
@@ -518,30 +609,68 @@ def synthesize_audio(text: str, voice: str, speed: float, model_name: Optional[s
|
|
|
|
|
|
|
|
|
def synthesize_wav_bytes(text: str, voice: str, speed: float, model_name: Optional[str] = None) -> io.BytesIO:
    """Synthesize ``text`` and return it as an in-memory 16-bit PCM WAV file.

    Args:
        text: Text to synthesize.
        voice: Voice identifier passed to ``synthesize_audio``.
        speed: Speed factor passed to ``synthesize_audio``.
        model_name: Optional model name passed to ``synthesize_audio``.

    Returns:
        A ``BytesIO`` positioned at offset 0 that holds the WAV data.

    Raises:
        HTTPException: 400 when synthesis produced no samples.
    """
    audio = synthesize_audio(text=text, voice=voice, speed=speed, model_name=model_name)
    # Without this guard an empty result would be served as a header-only WAV
    # instead of being reported to the client as a bad request.
    if audio.size == 0:
        raise HTTPException(status_code=400, detail="未生成音频,请检查输入文本或参数。")

    buf = io.BytesIO()
    sf.write(buf, audio, samplerate=sample_rate, format="WAV", subtype="PCM_16")
    buf.seek(0)
    return buf
|
|
|
|
|
|
|
|
|
def iter_text_parts(text: str, split_pattern: Optional[str] = None) -> List[str]:
    """Split ``text`` into the parts that are synthesized one by one.

    The stripped text is first cut into blocks with ``split_pattern`` (when
    given), then every block is passed through ``split_sentences``. A block
    for which ``split_sentences`` returns nothing is kept as a single part.

    Args:
        text: Input text; ``None`` and whitespace-only text yield ``[]``.
        split_pattern: Optional regular expression used to pre-split the text.
            An invalid or non-string pattern is logged and ignored, so the
            whole text is handled as one block.

    Returns:
        The list of parts in their original order.
    """
    text = (text or "").strip()
    if not text:
        return []

    blocks = [text]
    if split_pattern:
        try:
            pieces = re.split(split_pattern, text)
        except (re.error, TypeError):
            # TypeError covers non-string patterns coming from request JSON.
            logger.warning("split_pattern 非法,回退默认分句: %s", split_pattern)
        else:
            # With capture groups re.split() also returns the group matches,
            # which are None for groups that did not participate.
            blocks = [piece.strip() for piece in pieces if piece and piece.strip()]

    parts: List[str] = []
    for block in blocks:
        sentences = split_sentences(block)
        if sentences:
            parts.extend(sentences)
        elif block:
            parts.append(block)
    return parts
|
|
|
+
|
|
|
+
|
|
|
async def get_or_create_sentence_cache_item(
    sentence: str,
    voice: str,
    speed: float,
    model_name: str,
) -> dict:
    """Return the cache item for one sentence, synthesizing it on a miss.

    Lookup order: in-memory cache, then disk cache (a disk hit is promoted to
    memory), then synthesis in a worker thread. A freshly synthesized item is
    stored in memory, written to the disk cache, and the disk cache is trimmed.

    Args:
        sentence: Text of the sentence.
        voice: Voice identifier.
        speed: Speed factor.
        model_name: Model name.

    Returns:
        A dict with ``sentence``, ``sample_rate``, ``audio_bytes`` (WAV bytes)
        and ``audio`` (the WAV bytes base64-encoded as text).
    """
    key = sentence_cache_key(sentence, voice, speed, model_name)
    cached_item = get_memory_cache(key)
    if cached_item:
        return cached_item

    disk_item = await load_sentence_from_disk(key)
    if disk_item:
        put_memory_cache(key, disk_item)
        return disk_item

    def _synthesize_item():
        audio = synthesize_audio(sentence, voice=voice, speed=speed, model_name=model_name)
        wav_bytes = encode_wav_bytes(audio, sample_rate)
        return {
            "sentence": sentence,
            "sample_rate": sample_rate,
            "audio_bytes": wav_bytes,
            "audio": base64.b64encode(wav_bytes).decode(),
        }, audio

    item, audio = await asyncio.to_thread(_synthesize_item)
    put_memory_cache(key, item)
    try:
        await save_sentence_to_disk(key, audio, sample_rate, sentence)
        await clean_disk_cache()
    except (OSError, RuntimeError) as exc:
        # The audio is already synthesized and cached in memory; a disk-cache
        # failure (full disk, permissions, ...) must not fail the request.
        logger.warning("磁盘缓存写入失败,仅保留内存缓存: %s (%s)", key, exc)
    return item
|
|
|
+
|
|
|
+
|
|
|
@app.post("/tts", summary="POST: 传入文本返回 WAV 流")
|
|
|
def tts_post(req: TTSRequest):
|
|
|
buf = synthesize_wav_bytes(
|
|
|
@@ -591,12 +720,7 @@ async def generate_audio_stream(data: Dict = Body(...)):
|
|
|
if not text.strip():
|
|
|
raise HTTPException(status_code=400, detail="文本不能为空")
|
|
|
|
|
|
- if _is_english_voice(voice):
|
|
|
- parts = [graphemes for graphemes, _ in _phonemize_en_chunks_for_onnx(text)] if text.strip() else []
|
|
|
- else:
|
|
|
- parts = split_sentences(text) if text.strip() else []
|
|
|
- if not parts:
|
|
|
- parts = [text.strip()]
|
|
|
+ parts = iter_text_parts(text, data.get("split_pattern", r"\n+"))
|
|
|
|
|
|
if client_id in current_requests:
|
|
|
current_requests[client_id]["interrupt"] = True
|
|
|
@@ -607,37 +731,17 @@ async def generate_audio_stream(data: Dict = Body(...)):
|
|
|
async def stream():
|
|
|
try:
|
|
|
for idx, part in enumerate(parts):
|
|
|
- if not part:
|
|
|
- continue
|
|
|
if current_requests.get(client_id, {}).get("interrupt"):
|
|
|
break
|
|
|
|
|
|
- audio = await asyncio.to_thread(
|
|
|
- synthesize_audio,
|
|
|
- text=part,
|
|
|
- voice=voice,
|
|
|
- speed=speed,
|
|
|
- model_name=model,
|
|
|
- )
|
|
|
-
|
|
|
- with io.BytesIO() as buf:
|
|
|
- with sf.SoundFile(
|
|
|
- buf,
|
|
|
- "w",
|
|
|
- sample_rate,
|
|
|
- channels=audio.shape[1] if audio.ndim > 1 else 1,
|
|
|
- format="WAV",
|
|
|
- subtype="PCM_16",
|
|
|
- ) as f:
|
|
|
- f.write(audio)
|
|
|
- audio_b64 = base64.b64encode(buf.getvalue()).decode()
|
|
|
+ item = await get_or_create_sentence_cache_item(part, voice, speed, model)
|
|
|
|
|
|
yield json.dumps(
|
|
|
{
|
|
|
"index": idx,
|
|
|
- "sentence": part,
|
|
|
- "audio": audio_b64,
|
|
|
- "sample_rate": sample_rate,
|
|
|
+ "sentence": item["sentence"],
|
|
|
+ "audio": item["audio"],
|
|
|
+ "sample_rate": item["sample_rate"],
|
|
|
}
|
|
|
).encode() + b"\n"
|
|
|
finally:
|
|
|
@@ -646,6 +750,59 @@ async def generate_audio_stream(data: Dict = Body(...)):
|
|
|
return StreamingResponse(stream(), media_type="application/x-ndjson")
|
|
|
|
|
|
|
|
|
@app.post("/generate_pcm", summary="POST: 返回更轻量的 PCM 分片流")
async def generate_audio_pcm_stream(data: Dict = Body(...)):
    """Stream synthesized audio as raw PCM chunks, one chunk per text part.

    Request body keys: ``text`` (required), ``voice``, ``speed``,
    ``model_name`` and ``split_pattern``.

    For every part the response carries one JSON header line (``index``,
    ``sentence``, ``sample_rate``, ``format``, ``bytes``) terminated by a
    newline, followed by exactly ``bytes`` bytes of PCM payload.

    Raises:
        HTTPException: 400 when ``text`` is empty or not a string, or when
            ``speed`` is not a number.
    """
    text = data.get("text", "")
    voice = data.get("voice", "af_heart")
    model = data.get("model_name", DEFAULT_MODEL_NAME)

    try:
        speed = float(data.get("speed", 1.0))
    except (TypeError, ValueError):
        raise HTTPException(status_code=400, detail="speed 参数必须为数字") from None

    if not isinstance(text, str) or not text.strip():
        raise HTTPException(status_code=400, detail="文本不能为空")

    parts = iter_text_parts(text, data.get("split_pattern", r"\n+"))

    async def pcm_stream():
        # The semaphore is taken here rather than around the handler body:
        # the handler returns before the response body is produced, so a
        # handler-level semaphore would be released before any synthesis runs.
        async with request_semaphore:
            for idx, part in enumerate(parts):
                item = await get_or_create_sentence_cache_item(part, voice, speed, model)
                payload = wav_to_pcm_payload(item["audio_bytes"])
                header = json.dumps(
                    {
                        "index": idx,
                        "sentence": item["sentence"],
                        "sample_rate": item["sample_rate"],
                        "format": "s16le",
                        "bytes": len(payload),
                    }
                ).encode() + b"\n"
                yield header
                yield payload

    return StreamingResponse(
        pcm_stream(),
        media_type="application/octet-stream",
        headers={"X-Audio-Format": "s16le", "X-Sample-Rate": str(sample_rate)},
    )
|
|
|
+
|
|
|
+
|
|
|
@app.get("/clear-cache")
async def clear_cache():
    """Empty the in-memory cache and delete the files in the cache directory.

    Sub-directories are left untouched, files that vanish concurrently are
    ignored, and a file that cannot be deleted is logged instead of failing
    the whole request. The deletions run in a worker thread so the event loop
    is not blocked.

    Returns:
        ``{"status": "success"}``.
    """
    # NOTE(review): this destructive operation is exposed via GET; the route
    # is kept as-is so existing clients continue to work.
    with cache_lock:
        memory_cache.clear()

    def _remove_cache_files() -> None:
        for entry in Path(CACHE_DIR).glob("*"):
            if entry.is_dir() and not entry.is_symlink():
                # unlink() cannot remove a directory; skip instead of raising.
                continue
            try:
                entry.unlink(missing_ok=True)
            except OSError as exc:
                logger.warning("缓存文件删除失败: %s (%s)", entry, exc)

    await asyncio.to_thread(_remove_cache_files)
    return {"status": "success"}
|
|
|
+
|
|
|
+
|
|
|
@app.get("/cache-info")
async def get_cache_info():
    """Report how many entries the memory cache and the disk cache hold.

    Returns:
        A dict with ``memory_cache`` (number of items in the in-memory cache)
        and ``disk_cache`` (number of ``*.wav`` files in ``CACHE_DIR``).
    """
    with cache_lock:
        memory_entries = len(memory_cache)
    wav_count = sum(1 for _ in Path(CACHE_DIR).glob("*.wav"))
    return {"memory_cache": memory_entries, "disk_cache": wav_count}
|
|
|
+
|
|
|
+
|
|
|
if __name__ == "__main__":
|
|
|
import uvicorn
|
|
|
|