"""Kokoro TTS streaming server.

FastAPI service that splits input text into sentences, synthesizes each
sentence with the Kokoro pipeline (CPU), and streams base64-encoded WAV
chunks back as NDJSON.  Per-sentence results are cached in memory (TTL)
and on disk.
"""

from fastapi import FastAPI, HTTPException, Body, Request
from fastapi.responses import StreamingResponse
from contextlib import asynccontextmanager
from kokoro import KPipeline
import soundfile as sf
import io
import base64
import re
import logging
import json
import hashlib
import asyncio
import os
import uuid
import aiofiles
from fastapi.middleware.cors import CORSMiddleware
from typing import Dict, List, Optional
from cachetools import TTLCache
import concurrent.futures
import threading
import numpy as np
from pathlib import Path
import sys

# Disable NNPACK before torch is imported — must happen at this point in the
# module, ahead of `import torch`.
os.environ["USE_NNPACK"] = "0"
os.environ["NNPACK_DISABLE"] = "1"
import torch
torch.backends.nnpack.enabled = False
import warnings
warnings.filterwarnings("ignore", message=".*NNPACK.*")

# ------------------- Configuration (via environment) -------------------
USE_MP3 = os.getenv("USE_MP3", "false").lower() == "true"  # NOTE(review): currently unused; output is always WAV
MEMORY_CACHE_SIZE = int(os.getenv("MEMORY_CACHE_SIZE", 200))
DISK_CACHE_SIZE = int(os.getenv("DISK_CACHE_SIZE", 500))
MAX_CONCURRENT_REQUESTS = int(os.getenv("MAX_CONCURRENT_REQUESTS", 12))
LOG_LEVEL = os.getenv("LOG_LEVEL", "INFO").upper()
CACHE_DIR = os.getenv("CACHE_DIR", "./audio_cache")

logging.basicConfig(level=getattr(logging, LOG_LEVEL))
logger = logging.getLogger(__name__)

Path(CACHE_DIR).mkdir(parents=True, exist_ok=True)

# Split after sentence-ish punctuation followed by whitespace.
SENT_SPLIT_RE = re.compile(r'(?<=[.!?,:])\s+')

# In-memory per-sentence cache: cache_key -> {"sentence", "audio" (b64 WAV), "sample_rate"}
memory_cache = TTLCache(maxsize=MEMORY_CACHE_SIZE, ttl=18000)
cache_lock = threading.Lock()  # TTLCache is not thread-safe; guard all access

# Thread pool for blocking soundfile I/O offloaded from the event loop.
executor = concurrent.futures.ThreadPoolExecutor(max_workers=16)

# Model (resident for the lifetime of the process).
model_pipeline = None
model_lock = asyncio.Lock()

# Concurrency management.
request_semaphore = asyncio.Semaphore(MAX_CONCURRENT_REQUESTS)
current_requests: Dict[str, Dict] = {}


# ============================================================
# Model loading
# ============================================================
async def load_model():
    """Lazily load the Kokoro pipeline (CPU) exactly once.

    Raises whatever ``KPipeline`` raises if construction fails.
    """
    global model_pipeline
    async with model_lock:
        if model_pipeline is None:
            try:
                model_pipeline = KPipeline(lang_code='a', device="cpu")
                logger.info("模型加载完成 (CPU)")
            except Exception as e:
                logger.error(f"模型加载失败: {e}")
                raise


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load the model at startup; log at shutdown."""
    await load_model()
    yield
    logger.info("应用关闭中...")


app = FastAPI(lifespan=lifespan)

# CORS: allow all origins.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


# ============================================================
# Utility functions
# ============================================================
def split_sentences(text: str) -> List[str]:
    """Split *text* into non-empty, stripped sentence fragments."""
    return [s.strip() for s in SENT_SPLIT_RE.split(text.strip()) if s.strip()]


def sentence_cache_key(sentence: str, voice: str, speed: float) -> str:
    """Deterministic cache key over (sentence, voice, speed).

    MD5 is fine here: the key is a cache identifier, not a security token.
    """
    raw = f"{sentence}|{voice}|{speed}"
    return hashlib.md5(raw.encode()).hexdigest()


def sentence_cache_path(key: str) -> str:
    """Path of the cached WAV file for *key*."""
    return os.path.join(CACHE_DIR, f"{key}.wav")


def disk_meta_path(key: str) -> str:
    """Path of the JSON metadata sidecar for *key*."""
    return os.path.join(CACHE_DIR, f"{key}.json")


def put_memory_cache(key: str, value: dict) -> None:
    """Thread-safe insert into the in-memory TTL cache."""
    with cache_lock:
        memory_cache[key] = value


def get_memory_cache(key: str) -> Optional[dict]:
    """Thread-safe lookup in the in-memory TTL cache."""
    with cache_lock:
        return memory_cache.get(key)


def _encode_wav_b64(audio: np.ndarray, sr: int) -> str:
    """Encode an audio array as base64 of a 16-bit PCM WAV byte stream."""
    channels = audio.shape[1] if audio.ndim > 1 else 1
    with io.BytesIO() as buf:
        with sf.SoundFile(buf, 'w', sr, channels=channels,
                          format='WAV', subtype='PCM_16') as snd:
            snd.write(audio)
        return base64.b64encode(buf.getvalue()).decode()


async def load_sentence_from_disk(key: str) -> Optional[dict]:
    """Load a cached sentence from disk, or ``None`` if either file is missing.

    Returns the metadata dict with an added ``"audio"`` base64 WAV field.
    """
    wav_path = sentence_cache_path(key)
    meta_path = disk_meta_path(key)
    if not Path(wav_path).exists() or not Path(meta_path).exists():
        return None
    async with aiofiles.open(meta_path, "r") as meta_file:
        meta_text = await meta_file.read()
    meta = json.loads(meta_text)
    # Blocking soundfile read is offloaded to the thread pool.
    audio, sr = await asyncio.get_running_loop().run_in_executor(
        executor, lambda: sf.read(wav_path)
    )
    meta["audio"] = _encode_wav_b64(audio, sr)
    return meta


async def save_sentence_to_disk(key: str, audio: np.ndarray, sr: int,
                                sentence: str) -> dict:
    """Persist a synthesized sentence (WAV + JSON metadata) to the disk cache."""
    wav_path = sentence_cache_path(key)
    meta_path = disk_meta_path(key)
    await asyncio.get_running_loop().run_in_executor(
        executor, lambda: sf.write(wav_path, audio, sr, format="WAV")
    )
    meta = {
        "sentence": sentence,
        "sample_rate": sr,
    }
    async with aiofiles.open(meta_path, "w") as meta_file:
        await meta_file.write(json.dumps(meta))
    return meta


async def clean_disk_cache() -> None:
    """Evict oldest WAV/JSON pairs until the disk cache is within its limit."""
    files = sorted(Path(CACHE_DIR).glob("*.wav"), key=os.path.getmtime)
    if len(files) <= DISK_CACHE_SIZE:
        return
    remove_n = len(files) - DISK_CACHE_SIZE
    for wav in files[:remove_n]:
        try:
            json_path = disk_meta_path(wav.stem)
            wav.unlink()
            if Path(json_path).exists():
                Path(json_path).unlink()
        except OSError:
            # Best-effort eviction: a file vanishing concurrently is harmless.
            pass


async def check_client_alive(client_id) -> bool:
    """Poll until *client_id* is gone or flagged interrupted; then return False."""
    while True:
        await asyncio.sleep(0.3)
        if client_id not in current_requests or current_requests[client_id].get("interrupt"):
            return False


# ============================================================
# Request tracking
# ============================================================
@app.middleware("http")
async def track_clients(request: Request, call_next):
    """Register /generate requests in ``current_requests`` for cancellation.

    NOTE(review): the client_id here comes from the query string / header,
    while the /generate handler reads it from the JSON body — if a client
    supplies it only in the body the two never match.  Also, ``call_next``
    returns before a StreamingResponse finishes streaming, so the pop below
    can race with the still-running generator.  Confirm intended semantics.
    """
    client_id = (
        request.query_params.get("client_id")
        or request.headers.get("X-Client-ID")
        or str(uuid.uuid4())
    )
    if request.url.path == "/generate":
        current_requests[client_id] = {"active": True}
    response = await call_next(request)
    if request.url.path == "/generate":
        current_requests.pop(client_id, None)
    return response


# ============================================================
# Main endpoint
# ============================================================
@app.post("/generate")
async def generate_audio_stream(data: Dict = Body(...)):
    """Synthesize ``data["text"]`` and stream per-sentence NDJSON chunks.

    Body keys: text (required), voice, speed, client_id.  Each NDJSON line
    carries {"index", "sentence", "audio" (b64 WAV), "sample_rate"}.
    """
    async with request_semaphore:
        text = data.get("text", "")
        voice = data.get("voice", "af_heart")
        speed = float(data.get("speed", 1.0))
        client_id = data.get("client_id", str(uuid.uuid4()))
        if not text:
            raise HTTPException(400, "文本不能为空")
        await load_model()

        # Cancel any in-flight request from the same client.
        if client_id in current_requests:
            current_requests[client_id]["interrupt"] = True
            await asyncio.sleep(0.05)

        sentences = split_sentences(text)
        current_requests[client_id] = {"interrupt": False}

        async def stream():
            client_alive_task = asyncio.create_task(check_client_alive(client_id))
            try:
                for idx, sentence in enumerate(sentences):
                    # The watcher task completes only when the client is gone
                    # or interrupted — stop streaming in that case.
                    if client_alive_task.done():
                        break
                    key = sentence_cache_key(sentence, voice, speed)

                    # 1) Memory cache
                    cache_item = get_memory_cache(key)
                    if cache_item:
                        yield json.dumps({"index": idx, **cache_item}).encode() + b"\n"
                        continue

                    # 2) Disk cache
                    disk_item = await load_sentence_from_disk(key)
                    if disk_item:
                        put_memory_cache(key, disk_item)
                        yield json.dumps({"index": idx, **disk_item}).encode() + b"\n"
                        continue

                    # 3) Cache miss -> run inference (take the first chunk only)
                    generator = model_pipeline(sentence, voice=voice, speed=speed)
                    for _, _, audio in generator:
                        sr = 24000
                        item = {
                            "sentence": sentence,
                            "audio": _encode_wav_b64(audio, sr),
                            "sample_rate": sr,
                        }
                        put_memory_cache(key, item)
                        # Persist to the disk cache and keep it bounded.
                        await save_sentence_to_disk(key, audio, sr, sentence)
                        await clean_disk_cache()
                        yield json.dumps({"index": idx, **item}).encode() + b"\n"
                        break
            finally:
                current_requests.pop(client_id, None)
                if not client_alive_task.done():
                    client_alive_task.cancel()

        return StreamingResponse(stream(), media_type="application/x-ndjson")


# ============================================================
# Auxiliary endpoints
# ============================================================
@app.get("/clear-cache")
async def clear_cache():
    """Drop the memory cache and delete every file in the disk cache."""
    with cache_lock:
        memory_cache.clear()
    for f in Path(CACHE_DIR).glob("*"):
        f.unlink()
    return {"status": "success"}


@app.get("/cache-info")
async def get_cache_info():
    """Report current memory-cache entry count and disk-cache WAV count."""
    with cache_lock:
        mem_count = len(memory_cache)
    disk_files = list(Path(CACHE_DIR).glob("*.wav"))
    return {
        "memory_cache": mem_count,
        "disk_cache": len(disk_files),
    }


# ============================================================
# Entry point
# ============================================================
if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8028)