Ver código fonte

轻量级tts

sequoia 1 semana atrás
commit
0fde664b79
7 arquivos alterados com 1028 adições e 0 exclusões
  1. 3 0
      .gitignore
  2. 27 0
      nohup.out
  3. 6 0
      requirement.txt
  4. 448 0
      speech_tts_LRUandLocal.py
  5. 338 0
      speech_tts_cpu.py
  6. 4 0
      start_tts_5.sh
  7. 202 0
      tts_zh.py

+ 3 - 0
.gitignore

@@ -0,0 +1,3 @@
+__pycache__/
+audio_cache/
+.codex

+ 27 - 0
nohup.out

@@ -0,0 +1,27 @@
+INFO:     Started server process [1488309]
+INFO:     Waiting for application startup.
+/root/miniconda3/envs/py311/lib/python3.11/site-packages/torch/nn/modules/rnn.py:123: UserWarning: dropout option adds dropout after all but last recurrent layer, so non-zero dropout expects num_layers greater than 1, but got dropout=0.2 and num_layers=1
+  warnings.warn(
+/root/miniconda3/envs/py311/lib/python3.11/site-packages/torch/nn/utils/weight_norm.py:144: FutureWarning: `torch.nn.utils.weight_norm` is deprecated in favor of `torch.nn.utils.parametrizations.weight_norm`.
+  WeightNorm.apply(module, name, dim)
+INFO:speech_tts_cpu:模型加载完成 (CPU)
+INFO:     Application startup complete.
+INFO:     Uvicorn running on http://0.0.0.0:8028 (Press CTRL+C to quit)
+[W401 23:05:57.020141034 NNPACK.cpp:56] Could not initialize NNPACK! Reason: Unsupported hardware.
+[W401 23:05:57.022192361 NNPACK.cpp:56] Could not initialize NNPACK! Reason: Unsupported hardware.
+[W401 23:05:57.024095751 NNPACK.cpp:56] Could not initialize NNPACK! Reason: Unsupported hardware.
+[W401 23:05:57.026630998 NNPACK.cpp:56] Could not initialize NNPACK! Reason: Unsupported hardware.
+[W401 23:05:57.026837536 NNPACK.cpp:56] Could not initialize NNPACK! Reason: Unsupported hardware.
+WARNING: Defaulting repo_id to hexgrad/Kokoro-82M. Pass repo_id='hexgrad/Kokoro-82M' to suppress this warning.
+INFO:     141.140.15.30:45960 - "POST /generate HTTP/1.1" 200 OK
+[W401 23:05:57.279924627 NNPACK.cpp:56] Could not initialize NNPACK! Reason: Unsupported hardware.
+[W401 23:05:57.282296856 NNPACK.cpp:56] Could not initialize NNPACK! Reason: Unsupported hardware.
+[W401 23:05:57.284107661 NNPACK.cpp:56] Could not initialize NNPACK! Reason: Unsupported hardware.
+[W401 23:05:57.286825082 NNPACK.cpp:56] Could not initialize NNPACK! Reason: Unsupported hardware.
+[W401 23:05:57.287040627 NNPACK.cpp:56] Could not initialize NNPACK! Reason: Unsupported hardware.
+[W401 23:05:57.604346925 NNPACK.cpp:56] Could not initialize NNPACK! Reason: Unsupported hardware.
+[W401 23:05:57.604690382 NNPACK.cpp:56] Could not initialize NNPACK! Reason: Unsupported hardware.
+[W401 23:05:58.581911577 NNPACK.cpp:56] Could not initialize NNPACK! Reason: Unsupported hardware.
+[W401 23:05:58.582329733 NNPACK.cpp:56] Could not initialize NNPACK! Reason: Unsupported hardware.
+[W401 23:05:59.978522511 NNPACK.cpp:56] Could not initialize NNPACK! Reason: Unsupported hardware.
+[W401 23:05:59.979096069 NNPACK.cpp:56] Could not initialize NNPACK! Reason: Unsupported hardware.

+ 6 - 0
requirement.txt

@@ -0,0 +1,7 @@
+kokoro>=0.9.4
+soundfile
+fastapi
+cachetools
+pydub
+uvicorn
+aiofiles

+ 448 - 0
speech_tts_LRUandLocal.py

@@ -0,0 +1,448 @@
+from fastapi import FastAPI, HTTPException, Body, Request
+from fastapi.responses import StreamingResponse
+from contextlib import asynccontextmanager
+from kokoro import KPipeline  
+import soundfile as sf
+import io
+import base64
+import re
+import logging
+import json
+import hashlib
+import asyncio
+import os
+import uuid
+from fastapi.middleware.cors import CORSMiddleware
+from typing import Dict, List, Optional
+from cachetools import TTLCache
+import concurrent.futures
+import threading
+import numpy as np
+from pathlib import Path
+import torch
+import time
+
+import logging
+logging.getLogger("torch").setLevel(logging.ERROR) 
+
+# torch.backends.nnpack.enabled = False
+
# ------------------------------------------------------------------
# Configuration (all values overridable via environment variables)
# ------------------------------------------------------------------
USE_MP3 = os.getenv("USE_MP3", "false").lower() == "true"
MEMORY_CACHE_SIZE = int(os.getenv("MEMORY_CACHE_SIZE", 100))  # max entries in the in-process cache
DISK_CACHE_SIZE = int(os.getenv("DISK_CACHE_SIZE", 200))      # max number of cached .wav files on disk
MAX_CONCURRENT_REQUESTS = int(os.getenv("MAX_CONCURRENT_REQUESTS", 10))
LOG_LEVEL = os.getenv("LOG_LEVEL", "WARNING").upper()
CACHE_DIR = os.getenv("CACHE_DIR", "./audio_cache")  # disk cache directory
IDLE_TIMEOUT = 30 * 60  # idle timeout in seconds (30 minutes — an older comment said 20)

# Logging setup
logging.basicConfig(level=getattr(logging, LOG_LEVEL))
logger = logging.getLogger(__name__)

# Make sure the cache directory exists before any request arrives
Path(CACHE_DIR).mkdir(parents=True, exist_ok=True)

# Sentence splitter: break after ., !, ?, , or : followed by whitespace
SENT_SPLIT_RE = re.compile(r'(?<=[.!?,:])\s+')

# In-memory cache (entries expire after 5 hours)
memory_cache = TTLCache(maxsize=MEMORY_CACHE_SIZE, ttl=18000)
cache_lock = threading.Lock()

# Thread pool for blocking I/O work
executor = concurrent.futures.ThreadPoolExecutor(max_workers=20)

# Global model instance and per-client request tracking state
model_pipeline = None
current_requests: Dict[str, Dict] = {}
request_semaphore = asyncio.Semaphore(MAX_CONCURRENT_REQUESTS)
last_request_time = time.time()  # timestamp of the most recent /generate request
model_lock = asyncio.Lock()  # serializes model load/unload
+
async def unload_model():
    """Release the TTS model and any GPU memory it may hold.

    Idempotent: a no-op when the model is already unloaded. Guarded by
    ``model_lock`` so it cannot race with ``load_model``.
    """
    global model_pipeline
    async with model_lock:
        if model_pipeline is not None:
            try:
                del model_pipeline
                model_pipeline = None
                # Fix: only touch the CUDA allocator when CUDA is actually
                # available — this service also runs on CPU-only hosts.
                if torch.cuda.is_available():
                    torch.cuda.empty_cache()
                logger.info("模型已卸载,GPU资源已释放")
            except Exception as e:
                logger.error(f"释放模型失败: {str(e)}")
+
async def load_model():
    """Load the Kokoro TTS pipeline if it is not already resident.

    Idempotent: returns immediately when the model is loaded. Guarded by
    ``model_lock`` to prevent concurrent double-loads. Re-raises on failure
    so callers (startup / request path) can surface the error.
    """
    global model_pipeline
    async with model_lock:
        if model_pipeline is None:
            try:
                # lang_code='a' selects the default voice pack — see Kokoro docs
                model_pipeline = KPipeline(lang_code='a')
                logger.info("模型加载成功")
            except Exception as e:
                logger.error(f"模型加载失败: {str(e)}")
                raise
+
async def idle_checker():
    """Background task: unload the model after IDLE_TIMEOUT seconds idle.

    Polls once per minute; when no /generate request has arrived within
    IDLE_TIMEOUT seconds and the model is loaded, releases it.
    """
    global last_request_time
    while True:
        await asyncio.sleep(60)  # poll once per minute
        if time.time() - last_request_time > IDLE_TIMEOUT and model_pipeline is not None:
            # Fix: derive the minutes from IDLE_TIMEOUT instead of the
            # hard-coded "20分钟" that disagreed with the 30-minute setting.
            logger.info(f"超过{IDLE_TIMEOUT // 60}分钟无请求,释放GPU资源")
            await unload_model()
+
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Application lifespan: load the model at startup, clean up at shutdown.

    Fix: ``idle_task`` is initialized to None before the try block so the
    finally clause no longer raises NameError when load_model() fails
    before the task is created.
    """
    global model_pipeline, last_request_time
    idle_task = None
    try:
        await load_model()  # load model at startup
        idle_task = asyncio.create_task(idle_checker())  # start idle watchdog
        yield
    except Exception as e:
        logger.error(f"生命周期启动失败: {str(e)}")
        raise
    finally:
        if idle_task is not None:
            idle_task.cancel()  # stop idle watchdog
        await unload_model()  # release model on shutdown
        logger.info("应用关闭,资源已清理")
+
app = FastAPI(lifespan=lifespan)

# CORS: wide open (any origin/method/header) — tighten for production.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
+
def split_sentences(text: str) -> List[str]:
    """Break *text* into trimmed, non-empty sentence fragments.

    Splits on whitespace that follows ., !, ?, , or : — the same pattern
    as the module-level SENT_SPLIT_RE, inlined here for self-containment.
    """
    fragments = re.split(r'(?<=[.!?,:])\s+', text.strip())
    return [fragment.strip() for fragment in fragments if fragment.strip()]
+
def generate_cache_key(text: str, voice: str, speed: float) -> str:
    """Derive a deterministic MD5 hex digest keyed on (text, voice, speed)."""
    material = "_".join((text, voice, str(speed)))
    return hashlib.md5(material.encode('utf-8')).hexdigest()
+
def get_cache_file_path(cache_key: str) -> str:
    """Map a cache key to its .wav file path inside CACHE_DIR."""
    filename = f"{cache_key}.wav"
    return os.path.join(CACHE_DIR, filename)
+
async def save_to_memory_cache(cache_key: str, data: List[Dict]):
    """Store *data* in the in-memory TTL cache under *cache_key*.

    Fix: the original dispatched a trivial dict assignment to a thread pool
    while holding ``cache_lock`` across an ``await``, blocking every other
    coroutine needing the lock. TTLCache ``__setitem__`` is effectively
    instantaneous, so it is done synchronously under the lock.
    """
    try:
        with cache_lock:
            memory_cache[cache_key] = data
    except Exception as e:
        logger.error(f"保存到内存缓存失败: {str(e)}")
+
async def load_from_memory_cache(cache_key: str) -> Optional[List[Dict]]:
    """Fetch cached sentence data for *cache_key*, or None on miss/error.

    Fix: same as save_to_memory_cache — the lookup is O(1), so holding the
    threading lock across an executor await was pure overhead and an
    event-loop contention hazard.
    """
    try:
        with cache_lock:
            return memory_cache.get(cache_key)
    except Exception as e:
        logger.error(f"从内存缓存加载失败: {str(e)}")
    return None
+
async def save_to_disk_cache(cache_key: str, audio_data: np.ndarray):
    """Persist *audio_data* as a 24 kHz WAV file in the disk cache."""
    loop = asyncio.get_running_loop()
    try:
        target = get_cache_file_path(cache_key)

        def _write() -> None:
            sf.write(target, audio_data, 24000, format="WAV")

        await loop.run_in_executor(executor, _write)
        logger.debug(f"音频已保存到磁盘缓存: {target}")
    except Exception as e:
        logger.error(f"保存到磁盘缓存失败: {str(e)}")
+
async def load_from_disk_cache(cache_key: str) -> Optional[np.ndarray]:
    """Read cached audio for *cache_key* from disk; None if absent or unreadable."""
    loop = asyncio.get_running_loop()
    try:
        wav_path = get_cache_file_path(cache_key)
        if os.path.exists(wav_path):
            data, _sample_rate = await loop.run_in_executor(
                executor, lambda: sf.read(wav_path)
            )
            return data
    except Exception as e:
        logger.error(f"从磁盘缓存加载失败: {str(e)}")
    return None
+
async def clean_disk_cache():
    """Trim the disk cache back down to DISK_CACHE_SIZE files (oldest first)."""
    loop = asyncio.get_running_loop()
    try:
        wav_files = sorted(Path(CACHE_DIR).glob("*.wav"), key=os.path.getmtime)
        excess = len(wav_files) - DISK_CACHE_SIZE
        if excess > 0:
            victims = wav_files[:excess]

            def _delete() -> None:
                for victim in victims:
                    victim.unlink()

            await loop.run_in_executor(executor, _delete)
            logger.debug(f"已清理 {len(victims)} 个磁盘缓存文件")
    except Exception as e:
        logger.error(f"清理磁盘缓存失败: {str(e)}")
+
@app.middleware("http")
async def track_clients(request: Request, call_next):
    """HTTP middleware: register/unregister clients hitting /generate.

    Also refreshes ``last_request_time`` so the idle watchdog keeps the
    model loaded while traffic flows.

    NOTE(review): when neither ?client_id= nor X-Client-ID is present a
    fresh uuid4 is generated here, which will NOT match the client_id read
    from the JSON body inside generate_audio_stream — the two tracking
    entries can diverge. Confirm clients always send the header/param.
    """
    global last_request_time
    client_id = request.query_params.get("client_id") or request.headers.get("X-Client-ID") or str(uuid.uuid4())
    if request.url.path == "/generate":
        last_request_time = time.time()  # refresh the idle timer
        current_requests[client_id] = {"active": True}
        logger.debug(f"客户端 {client_id} 已连接")

    response = await call_next(request)

    if request.url.path == "/generate" and client_id in current_requests:
        del current_requests[client_id]
        logger.debug(f"客户端 {client_id} 已断开")

    return response
+
async def check_client_active(client_id: str, interval: float = 0.5) -> bool:
    """Poll until *client_id* disappears or is flagged interrupted; then return False."""
    while True:
        await asyncio.sleep(interval)
        entry = current_requests.get(client_id)
        if entry is None:
            logger.debug(f"客户端 {client_id} 不再活跃")
            return False
        if entry.get("interrupt"):
            return False
+
@app.post("/generate")
async def generate_audio_stream(data: Dict = Body(...)):
    """Streaming TTS endpoint (NDJSON) with batching and two-level caching.

    Body keys: ``text`` (required), ``voice`` (default "af_heart"),
    ``speed`` (default 1.0), ``client_id`` (optional; random uuid otherwise).
    Each emitted line is one JSON object: {index, sentence, audio (base64
    WAV), sample_rate, format[, cached, warning]}.

    Lookup order: in-memory TTL cache -> disk WAV cache -> fresh inference.

    NOTE(review): audio_stream() indexes current_requests[client_id]
    without a guard; if the middleware or another request removes the entry
    mid-run this raises KeyError and aborts the stream. Confirm intended.
    """
    global last_request_time
    async with request_semaphore:
        last_request_time = time.time()  # refresh the idle-unload timer
        text = data.get("text", "")
        voice = data.get("voice", "af_heart")
        speed = float(data.get("speed", 1.0))
        client_id = data.get("client_id", str(uuid.uuid4()))

        # Ensure the model is resident (it may have been idle-unloaded).
        await load_model()
        if not model_pipeline:
            raise HTTPException(500, "模型未初始化")
        if not text:
            raise HTTPException(400, "文本不能为空")

        # Interrupt a previous in-flight request from the same client.
        if client_id in current_requests:
            current_requests[client_id]["interrupt"] = True
            await asyncio.sleep(0.05)

        cache_key = generate_cache_key(text, voice, speed)

        # 1. Memory cache hit: replay the stored per-sentence items.
        cached_data = await load_from_memory_cache(cache_key)
        if cached_data:
            logger.debug(f"从内存缓存流式传输 (key: {cache_key})")
            async def cached_stream():
                # Re-emit each cached sentence in order, flagged cached=True.
                for idx, item in enumerate(cached_data):
                    yield json.dumps({
                        "index": idx,
                        "sentence": item["sentence"],
                        "audio": item["audio"],
                        "sample_rate": 24000,
                        "format": "wav",
                        "cached": True
                    }, ensure_ascii=False).encode() + b"\n"

            return StreamingResponse(
                cached_stream(),
                media_type="application/x-ndjson",
                headers={
                    "X-Content-Type-Options": "nosniff",
                    "Cache-Control": "no-store",
                    "X-Client-ID": client_id
                }
            )

        # 2. Disk cache hit: one combined WAV for the whole text.
        cached_audio = await load_from_disk_cache(cache_key)
        if cached_audio is not None:
            logger.debug(f"从磁盘缓存流式传输 (key: {cache_key})")
            with io.BytesIO() as buffer:
                sf.write(buffer, cached_audio, 24000, format="WAV")
                audio_bytes = buffer.getvalue()

            audio_b64 = base64.b64encode(audio_bytes).decode("utf-8")
            # Promote the disk hit into the memory cache as a single item.
            cached_data = [{
                "index": 0,
                "sentence": text,
                "audio": audio_b64,
                "sample_rate": 24000,
                "format": "wav"
            }]
            await save_to_memory_cache(cache_key, cached_data)

            return StreamingResponse(
                iter([json.dumps({
                    "index": 0,
                    "sentence": text,
                    "audio": audio_b64,
                    "sample_rate": 24000,
                    "format": "wav",
                    "cached": True
                }, ensure_ascii=False).encode() + b"\n"]),
                media_type="application/x-ndjson",
                headers={
                    "X-Content-Type-Options": "nosniff",
                    "Cache-Control": "no-store",
                    "X-Client-ID": client_id
                }
            )

        # 3. Cache miss: run inference sentence-by-sentence.
        sentences = split_sentences(text)
        logger.debug(f"处理 {len(sentences)} 个句子")
        current_requests[client_id] = {"interrupt": False, "sentences": sentences, "active": True}

        async def audio_stream():
            # Accumulates per-sentence items for the memory cache and raw
            # audio arrays for the combined disk-cache WAV.
            memory_cache_items = []
            all_audio = []
            client_check_task = asyncio.create_task(check_client_active(client_id))

            try:
                batch_size = 4
                for i in range(0, len(sentences), batch_size):
                    batch = sentences[i:i + batch_size]
                    if client_check_task.done() or current_requests[client_id].get("interrupt"):
                        logger.debug(f"请求中断或客户端 {client_id} 已断开")
                        break

                    try:
                        generators = [model_pipeline(s, voice=voice, speed=speed, split_pattern=None) for s in batch]
                        for idx, (sentence, generator) in enumerate(zip(batch, generators)):
                            global_idx = i + idx
                            # Only the first chunk of each generator is used.
                            for _, _, audio in generator:
                                if audio is None or len(audio) == 0:
                                    logger.warning(f"句子 {global_idx} 生成了空音频: {sentence}")
                                    yield json.dumps({
                                        "index": global_idx,
                                        "sentence": sentence,
                                        "audio": "",
                                        "sample_rate": 24000,
                                        "format": "wav",
                                        "warning": "没有生成音频"
                                    }, ensure_ascii=False).encode() + b"\n"
                                    break

                                all_audio.append(audio)
                                with io.BytesIO() as buffer:
                                    sf.write(buffer, audio, 24000, format="WAV")
                                    audio_bytes = buffer.getvalue()

                                audio_b64 = base64.b64encode(audio_bytes).decode("utf-8")
                                memory_cache_items.append({
                                    "index": global_idx,
                                    "sentence": sentence,
                                    "audio": audio_b64,
                                    "sample_rate": 24000,
                                    "format": "wav"
                                })

                                yield json.dumps({
                                    "index": global_idx,
                                    "sentence": sentence,
                                    "audio": audio_b64,
                                    "sample_rate": 24000,
                                    "format": "wav"
                                }, ensure_ascii=False).encode() + b"\n"
                                break
                    except Exception as e:
                        logger.error(f"处理批次 {i} 时出错: {str(e)}")
                        yield json.dumps({
                            "error": f"批次处理失败于索引 {i}",
                            "message": str(e)
                        }, ensure_ascii=False).encode() + b"\n"

                # Persist: per-sentence items to memory, combined WAV to disk.
                if memory_cache_items:
                    await save_to_memory_cache(cache_key, memory_cache_items)
                    if all_audio:
                        combined_audio = np.concatenate(all_audio)
                        await save_to_disk_cache(cache_key, combined_audio)
                        logger.debug(f"已保存合并音频到磁盘缓存 (key: {cache_key})")
                        await clean_disk_cache()

            except asyncio.CancelledError:
                logger.debug(f"客户端 {client_id} 的请求已取消")
            except Exception as e:
                logger.error(f"流错误: {str(e)}")
            finally:
                # Always drop tracking state and stop the liveness poller.
                if client_id in current_requests:
                    del current_requests[client_id]
                if not client_check_task.done():
                    client_check_task.cancel()
                logger.debug(f"已清理客户端 {client_id}")

        return StreamingResponse(
            audio_stream(),
            media_type="application/x-ndjson",
            headers={
                "X-Content-Type-Options": "nosniff",
                "Cache-Control": "no-store",
                "X-Client-ID": client_id
            }
        )
+
@app.get("/clear-cache")
async def clear_cache():
    """Drop every entry from the memory cache and the on-disk WAV cache."""
    try:
        with cache_lock:
            memory_cache.clear()

        def _purge() -> None:
            for wav in Path(CACHE_DIR).glob("*.wav"):
                wav.unlink()

        loop = asyncio.get_running_loop()
        await loop.run_in_executor(executor, _purge)
        return {"status": "success", "message": "所有缓存已清除"}
    except Exception as e:
        raise HTTPException(500, f"清除缓存失败: {str(e)}")
+
@app.get("/cache-info")
async def get_cache_info():
    """Report entry counts and byte sizes for the memory and disk caches."""
    try:
        with cache_lock:
            memory_count = len(memory_cache)
            memory_size = sum(len(json.dumps(entry)) for entry in memory_cache.values())

        wav_files = list(Path(CACHE_DIR).glob("*.wav"))

        return {
            "memory_cache": {
                "count": memory_count,
                "size": memory_size,
                "max_size": MEMORY_CACHE_SIZE
            },
            "disk_cache": {
                "count": len(wav_files),
                "size": sum(wav.stat().st_size for wav in wav_files),
                "max_size": DISK_CACHE_SIZE,
                "directory": CACHE_DIR
            }
        }
    except Exception as e:
        raise HTTPException(500, f"获取缓存信息失败: {str(e)}")
+
+if __name__ == "__main__":
+    import uvicorn
+    uvicorn.run(app, host="0.0.0.0", port=8028)

+ 338 - 0
speech_tts_cpu.py

@@ -0,0 +1,338 @@
+from fastapi import FastAPI, HTTPException, Body, Request
+from fastapi.responses import StreamingResponse
+from contextlib import asynccontextmanager
+from kokoro import KPipeline
+import soundfile as sf
+import io
+import base64
+import re
+import logging
+import json
+import hashlib
+import asyncio
+import os
+import uuid
+import aiofiles
+from fastapi.middleware.cors import CORSMiddleware
+from typing import Dict, List, Optional
+from cachetools import TTLCache
+import concurrent.futures
+import threading
+import numpy as np
+from pathlib import Path
+import sys
+os.environ["USE_NNPACK"] = "0"
+os.environ["NNPACK_DISABLE"] = "1"
+import torch
+torch.backends.nnpack.enabled = False
+import warnings
+warnings.filterwarnings("ignore", message=".*NNPACK.*")
+
# ------------------- Configuration -------------------
# All values overridable via environment variables.
USE_MP3 = os.getenv("USE_MP3", "false").lower() == "true"
MEMORY_CACHE_SIZE = int(os.getenv("MEMORY_CACHE_SIZE", 200))   # max in-memory entries
DISK_CACHE_SIZE = int(os.getenv("DISK_CACHE_SIZE", 500))       # max cached .wav files
MAX_CONCURRENT_REQUESTS = int(os.getenv("MAX_CONCURRENT_REQUESTS", 12))
LOG_LEVEL = os.getenv("LOG_LEVEL", "INFO").upper()
CACHE_DIR = os.getenv("CACHE_DIR", "./audio_cache")

logging.basicConfig(level=getattr(logging, LOG_LEVEL))
logger = logging.getLogger(__name__)

Path(CACHE_DIR).mkdir(parents=True, exist_ok=True)

# Split after ., !, ?, , or : followed by whitespace
SENT_SPLIT_RE = re.compile(r'(?<=[.!?,:])\s+')

# Memory cache: per-sentence entries {cache_key -> {"audio": b64, "sr": 24000}},
# expiring after 5 hours
memory_cache = TTLCache(maxsize=MEMORY_CACHE_SIZE, ttl=18000)
cache_lock = threading.Lock()

executor = concurrent.futures.ThreadPoolExecutor(max_workers=16)

# Model (resident for the process lifetime)
model_pipeline = None
model_lock = asyncio.Lock()

# Concurrency management
request_semaphore = asyncio.Semaphore(MAX_CONCURRENT_REQUESTS)
current_requests: Dict[str, Dict] = {}
+
+
+# ============================================================
+# 模型加载
+# ============================================================
async def load_model():
    """Load the Kokoro pipeline on CPU if not already loaded (idempotent)."""
    global model_pipeline
    async with model_lock:
        if model_pipeline is None:
            try:
                # Force CPU inference; lang_code='a' selects the default voice pack
                model_pipeline = KPipeline(lang_code='a', device="cpu")
                logger.info("模型加载完成 (CPU)")
            except Exception as e:
                logger.error(f"模型加载失败: {e}")
                raise
+
+
@asynccontextmanager
async def lifespan(app: FastAPI):
    """App lifespan: load the model once at startup; log on shutdown."""
    await load_model()
    yield
    logger.info("应用关闭中...")
+
+
app = FastAPI(lifespan=lifespan)


# CORS: allow all origins (tighten for production).
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
+
+
+# ============================================================
+# 工具函数
+# ============================================================
+
def split_sentences(text: str) -> List[str]:
    """Split *text* into trimmed, non-empty fragments after ., !, ?, , or :."""
    fragments = re.split(r'(?<=[.!?,:])\s+', text.strip())
    return [frag.strip() for frag in fragments if frag.strip()]
+
+
def sentence_cache_key(sentence: str, voice: str, speed: float) -> str:
    """MD5 hex digest of "sentence|voice|speed" — the per-sentence cache key."""
    material = "|".join((sentence, voice, str(speed)))
    return hashlib.md5(material.encode()).hexdigest()
+
+
def sentence_cache_path(key: str):
    # Path of the cached WAV for this key.
    return os.path.join(CACHE_DIR, f"{key}.wav")


def disk_meta_path(key: str):
    # Path of the JSON sidecar holding sentence text + sample rate.
    return os.path.join(CACHE_DIR, f"{key}.json")
+
+
def put_memory_cache(key: str, value: dict):
    # Thread-safe insert into the TTL memory cache.
    with cache_lock:
        memory_cache[key] = value


def get_memory_cache(key: str) -> Optional[dict]:
    # Thread-safe lookup; returns None on miss (or after TTL expiry).
    with cache_lock:
        return memory_cache.get(key)
+
+
async def load_sentence_from_disk(key: str) -> Optional[dict]:
    """Load a cached sentence (metadata + base64 WAV audio) from disk.

    Returns None when either the .wav or the .json sidecar is missing.

    Fixes: uses asyncio.get_running_loop() (get_event_loop() is deprecated
    inside coroutines since Python 3.10) and no longer shadows the aiofiles
    handle ``f`` with the SoundFile handle.
    """
    wav_path = sentence_cache_path(key)
    meta_path = disk_meta_path(key)

    if not Path(wav_path).exists() or not Path(meta_path).exists():
        return None

    async with aiofiles.open(meta_path, "r") as meta_file:
        meta = json.loads(await meta_file.read())

    loop = asyncio.get_running_loop()
    audio, sr = await loop.run_in_executor(executor, lambda: sf.read(wav_path))

    # Re-encode to 16-bit PCM WAV in memory and ship as base64.
    channels = audio.shape[1] if audio.ndim > 1 else 1
    with io.BytesIO() as buf:
        with sf.SoundFile(buf, 'w', sr, channels=channels,
                          format='WAV', subtype='PCM_16') as snd:
            snd.write(audio)
        meta["audio"] = base64.b64encode(buf.getvalue()).decode()

    return meta
+
+
async def save_sentence_to_disk(key: str, audio: np.ndarray, sr: int, sentence: str):
    """Write the WAV plus a JSON sidecar ({sentence, sample_rate}) to the cache.

    Returns the metadata dict that was written.

    Fix: asyncio.get_running_loop() replaces the deprecated
    get_event_loop() call inside a coroutine.
    """
    wav_path = sentence_cache_path(key)
    meta_path = disk_meta_path(key)

    loop = asyncio.get_running_loop()
    await loop.run_in_executor(
        executor, lambda: sf.write(wav_path, audio, sr, format="WAV")
    )

    meta = {
        "sentence": sentence,
        "sample_rate": sr,
    }

    async with aiofiles.open(meta_path, "w") as f:
        await f.write(json.dumps(meta))

    return meta
+
+
async def clean_disk_cache():
    """Evict oldest cached WAVs (and their JSON sidecars) beyond DISK_CACHE_SIZE.

    Fix: the bare ``except:`` (which would even swallow KeyboardInterrupt /
    CancelledError) is narrowed to OSError — the only failure expected from
    unlink races with concurrent requests. Eviction stays best-effort.
    """
    files = sorted(Path(CACHE_DIR).glob("*.wav"), key=os.path.getmtime)
    if len(files) <= DISK_CACHE_SIZE:
        return
    remove_n = len(files) - DISK_CACHE_SIZE
    for f in files[:remove_n]:
        try:
            json_path = disk_meta_path(f.stem)
            f.unlink()
            if Path(json_path).exists():
                Path(json_path).unlink()
        except OSError:
            # Another worker may have removed the file already.
            pass
+
+
async def check_client_alive(client_id):
    """Poll every 0.3 s; return False once the client vanishes or is interrupted."""
    while True:
        await asyncio.sleep(0.3)
        entry = current_requests.get(client_id)
        if entry is None or entry.get("interrupt"):
            return False
+
+
+# ============================================================
+# 请求跟踪
+# ============================================================
+
@app.middleware("http")
async def track_clients(request: Request, call_next):
    """Register /generate clients for the duration of their request.

    NOTE(review): a random uuid4 is used when no client_id query param or
    X-Client-ID header is supplied; that id will not match the client_id in
    the JSON body used by generate_audio_stream, so interrupt tracking only
    works when callers send the header/param. Confirm with clients.
    """
    client_id = (
        request.query_params.get("client_id")
        or request.headers.get("X-Client-ID")
        or str(uuid.uuid4())
    )
    if request.url.path == "/generate":
        current_requests[client_id] = {"active": True}

    response = await call_next(request)

    if request.url.path == "/generate":
        current_requests.pop(client_id, None)

    return response
+
+
+# ============================================================
+# 主接口
+# ============================================================
+
@app.post("/generate")
async def generate_audio_stream(data: Dict = Body(...)):
    """Streaming NDJSON TTS endpoint with per-sentence two-level caching.

    Body keys: ``text`` (required), ``voice`` (default "af_heart"),
    ``speed`` (default 1.0), ``client_id`` (optional). For each sentence
    the lookup order is memory cache -> disk cache -> CPU inference, and
    one JSON line {index, sentence, audio (base64 WAV), sample_rate} is
    emitted per sentence.
    """
    async with request_semaphore:
        text = data.get("text", "")
        voice = data.get("voice", "af_heart")
        speed = float(data.get("speed", 1.0))
        client_id = data.get("client_id", str(uuid.uuid4()))

        if not text:
            raise HTTPException(400, "文本不能为空")

        await load_model()

        # Cancel an earlier in-flight request from the same client.
        if client_id in current_requests:
            current_requests[client_id]["interrupt"] = True
            await asyncio.sleep(0.05)

        sentences = split_sentences(text)
        current_requests[client_id] = {"interrupt": False}

        async def stream():
            client_alive_task = asyncio.create_task(check_client_alive(client_id))

            try:
                for idx, sentence in enumerate(sentences):
                    # check_client_alive finishing means the client vanished
                    # or was interrupted.
                    if client_alive_task.done():
                        break

                    key = sentence_cache_key(sentence, voice, speed)

                    # 1) memory cache
                    cache_item = get_memory_cache(key)
                    if cache_item:
                        yield json.dumps({
                            "index": idx,
                            **cache_item
                        }).encode() + b"\n"
                        continue

                    # 2) disk cache (promoted into memory on hit)
                    disk_item = await load_sentence_from_disk(key)
                    if disk_item:
                        put_memory_cache(key, disk_item)
                        yield json.dumps({
                            "index": idx,
                            **disk_item
                        }).encode() + b"\n"
                        continue

                    # 3) cache miss -> run inference
                    generator = model_pipeline(sentence, voice=voice, speed=speed)

                    # Only the first chunk of the generator is used.
                    for _, _, audio in generator:
                        sr = 24000
                        with io.BytesIO() as buf:
                            # Encode as 16-bit PCM WAV in memory.
                            with sf.SoundFile(buf, 'w', sr, channels=audio.shape[1] if audio.ndim > 1 else 1, 
                  format='WAV', subtype='PCM_16') as f:
                                f.write(audio)
                            audio_b64 = base64.b64encode(buf.getvalue()).decode()

                        item = {
                            "sentence": sentence,
                            "audio": audio_b64,
                            "sample_rate": sr
                        }

                        put_memory_cache(key, item)

                        # Persist to the disk cache and trim it if oversized.
                        await save_sentence_to_disk(key, audio, sr, sentence)
                        await clean_disk_cache()

                        yield json.dumps({
                            "index": idx,
                            **item
                        }).encode() + b"\n"
                        break

            finally:
                # Always drop tracking state and stop the liveness poller.
                current_requests.pop(client_id, None)
                if not client_alive_task.done():
                    client_alive_task.cancel()

        return StreamingResponse(stream(), media_type="application/x-ndjson")
+
+
+# ============================================================
+# 其它接口
+# ============================================================
+
@app.get("/clear-cache")
async def clear_cache():
    """Empty the memory cache and delete all cached files on disk.

    Fix: ``glob("*")`` can match subdirectories, and ``unlink()`` on a
    directory raises IsADirectoryError, turning the endpoint into a 500.
    Only regular files are unlinked.
    """
    with cache_lock:
        memory_cache.clear()
    for f in Path(CACHE_DIR).glob("*"):
        if f.is_file():
            f.unlink()
    return {"status": "success"}
+
+
@app.get("/cache-info")
async def get_cache_info():
    """Return entry counts for the memory cache and the on-disk WAV cache."""
    with cache_lock:
        mem_count = len(memory_cache)
    wav_count = len(list(Path(CACHE_DIR).glob("*.wav")))
    return {
        "memory_cache": mem_count,
        "disk_cache": wav_count
    }
+
+
+# ============================================================
+# 启动
+# ============================================================
+
+if __name__ == "__main__":
+    import uvicorn
+    uvicorn.run(app, host="0.0.0.0", port=8028)

+ 4 - 0
start_tts_5.sh

@@ -0,0 +1,4 @@
+#!/bin/bash
+eval "$(/root/miniconda3/bin/conda shell.bash hook)"
+conda activate py311
+nohup uvicorn speech_tts_cpu:app --host 0.0.0.0 --port 8028 &

+ 202 - 0
tts_zh.py

@@ -0,0 +1,202 @@
+
+from fastapi import FastAPI, HTTPException, Query
+from fastapi.responses import StreamingResponse
+from fastapi.middleware.cors import CORSMiddleware
+from pydantic import BaseModel, field_validator
+from typing import Optional
+from io import BytesIO
+import numpy as np
+import soundfile as sf
+from kokoro import KPipeline
+import threading
+
# Global TTS pipeline (lang_code must match the chosen voice:
# Chinese female voice 'zf_xiaoxiao' pairs with lang_code='z').
pipeline = KPipeline(lang_code='z')

# Prefer the pipeline's own sample rate; fall back to 24000 if unavailable.
sample_rate = getattr(pipeline, "sample_rate", 24000)

# Serialize TTS inference with a lock to avoid concurrency issues in the model.
synthesis_lock = threading.Lock()

app = FastAPI(title="Online TTS Service (Kokoro)")

# CORS: allow cross-origin calls from notebooks/browsers.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # tighten for production
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
+
+
class TTSRequest(BaseModel):
    """Request body for POST /tts."""

    text: str                               # text to synthesize (required)
    voice: Optional[str] = "zf_xiaoxiao"    # Kokoro voice id
    speed: Optional[float] = 1.0            # playback speed multiplier (> 0)
    split_pattern: Optional[str] = r"\n+"   # regex used to segment the text

    @field_validator("speed")
    @classmethod
    def check_speed(cls, v):
        """Coerce speed to float, default None to 1.0, require it > 0."""
        if v is None:
            return 1.0
        try:
            v = float(v)
        except Exception:
            raise ValueError("speed 必须为数值")
        if v <= 0:
            raise ValueError("speed 必须大于 0")
        return v
+
+
def to_mono_numpy(audio) -> np.ndarray:
    """Coerce a pipeline audio chunk to a 1-D float32 mono numpy array.

    Accepts numpy arrays, PyTorch tensors, anything exposing a callable
    ``.numpy()``, or anything ``np.asarray`` can digest. Unconvertible or
    empty input yields an empty float32 array; 2-D input is squeezed or
    down-mixed to mono; scalars become length-1 arrays.
    """
    empty = np.array([], dtype=np.float32)
    if audio is None:
        return empty

    if isinstance(audio, np.ndarray):
        arr = audio
    else:
        arr = None
        # PyTorch tensors: detach from the graph, move to host, convert.
        if hasattr(audio, "detach") and hasattr(audio, "cpu") and hasattr(audio, "numpy"):
            try:
                arr = audio.detach().cpu().numpy()
            except Exception:
                arr = None
        # Other frameworks exposing a callable .numpy().
        if arr is None and callable(getattr(audio, "numpy", None)):
            try:
                arr = audio.numpy()
            except Exception:
                arr = None
        # Last resort: generic conversion.
        if arr is None:
            try:
                arr = np.asarray(audio)
            except Exception:
                # Not convertible — return empty so callers can filter it out.
                return empty

    arr = np.asarray(arr)
    if arr.size == 0:
        return empty

    # Typical shapes are [T], [T, 1] or [1, T]; true multi-channel is down-mixed.
    if arr.ndim == 2:
        rows, cols = arr.shape
        if cols == 1:
            arr = arr[:, 0]
        elif rows == 1:
            arr = arr[0]
        else:
            arr = arr.mean(axis=1)
    elif arr.ndim > 2:
        # Unexpected rank — flatten to 1-D.
        arr = arr.reshape(-1)

    if arr.ndim == 0:
        arr = arr.reshape(1)

    # float32 is what soundfile expects before its PCM_16 conversion on write.
    return arr if arr.dtype == np.float32 else arr.astype(np.float32, copy=False)
+
+
def synthesize_wav_bytes(
    text: str,
    voice: str = "zf_xiaoxiao",
    speed: float = 1.0,
    split_pattern: str = r"\n+",
) -> BytesIO:
    """Run TTS over *text* and return a BytesIO holding a complete PCM_16 WAV.

    Raises HTTPException(400) when no usable audio segments are produced
    and HTTPException(500) when concatenation of the segments fails.
    """
    # Generate the full audio and package it as WAV bytes.
    segments = []
    # Serialize inference: the underlying model is not known to be concurrency-safe.
    with synthesis_lock:
        generator = pipeline(
            text,
            voice=voice,
            speed=speed,
            split_pattern=split_pattern,
        )
        for _, _, audio in generator:
            arr = to_mono_numpy(audio)
            # Drop empty chunks and anything containing NaN/Inf.
            if arr.size > 0 and np.isfinite(arr).all():
                segments.append(arr)

    if not segments:
        raise HTTPException(status_code=400, detail="未生成音频,请检查输入文本或参数。")

    try:
        audio_concat = np.concatenate(segments, axis=0)
    except Exception as e:
        # Surface a clear error message instead of a bare 500 traceback.
        raise HTTPException(
            status_code=500,
            detail=f"音频拼接失败:{type(e).__name__}: {str(e)}"
        )

    buf = BytesIO()
    # Write as 16-bit PCM WAV.
    sf.write(buf, audio_concat, samplerate=sample_rate, format="WAV", subtype="PCM_16")
    buf.seek(0)
    return buf
+
+
@app.post("/tts", summary="POST: 传入文本返回 WAV 流")
def tts_post(req: TTSRequest):
    """Synthesize req.text and stream it back as audio/wav."""
    wav_buf = synthesize_wav_bytes(
        text=req.text,
        voice=req.voice or "zf_xiaoxiao",
        speed=1.0 if req.speed is None else req.speed,
        split_pattern=req.split_pattern or r"\n+",
    )
    headers = {"Content-Disposition": 'inline; filename="tts.wav"'}
    return StreamingResponse(wav_buf, media_type="audio/wav", headers=headers)
+
+
@app.get("/tts", summary="GET: 传入文本返回 WAV 流(便于直接以 URL 播放)")
def tts_get(
    text: str = Query(..., description="待合成文本"),
    voice: str = Query("zf_xiaoxiao"),
    speed: float = Query(1.0),
    split_pattern: str = Query(r"\n+"),
):
    """GET variant of /tts so audio can be played straight from a URL."""
    if speed is None or float(speed) <= 0:
        raise HTTPException(status_code=400, detail="speed 必须为大于 0 的数值")
    wav_buf = synthesize_wav_bytes(
        text=text,
        voice=voice,
        speed=float(speed),
        split_pattern=split_pattern,
    )
    headers = {"Content-Disposition": 'inline; filename="tts.wav"'}
    return StreamingResponse(wav_buf, media_type="audio/wav", headers=headers)
+
+
+# 运行:
+#   uvicorn server:app --host 0.0.0.0 --port 8000 --workers 1
+# 建议 workers=1 或者保持串行,避免占用同一设备的并发导致显存/模型冲突
+if __name__ == "__main__":
+    import uvicorn
+    uvicorn.run("tts_zh:app", host="0.0.0.0", port=18000, reload=False, workers=1)