from fastapi import FastAPI, HTTPException, Body, Request
from fastapi.responses import StreamingResponse
from contextlib import asynccontextmanager
from kokoro import KPipeline
import soundfile as sf
import io
import base64
import re
import logging
import json
import hashlib
import asyncio
import os
import uuid
import aiofiles
from fastapi.middleware.cors import CORSMiddleware
from typing import Dict, List, Optional
from cachetools import TTLCache
import concurrent.futures
import threading
import numpy as np
from pathlib import Path
import sys

# Disable NNPACK. The env vars are set BEFORE `import torch` so they are
# already visible when torch initializes; keep this ordering intact.
os.environ["USE_NNPACK"] = "0"
os.environ["NNPACK_DISABLE"] = "1"
import torch
torch.backends.nnpack.enabled = False
import warnings
warnings.filterwarnings("ignore", message=".*NNPACK.*")
# ------------------- Configuration -------------------
USE_MP3 = os.getenv("USE_MP3", "false").lower() == "true"  # NOTE(review): not referenced elsewhere in this file
MEMORY_CACHE_SIZE = int(os.getenv("MEMORY_CACHE_SIZE", 200))
DISK_CACHE_SIZE = int(os.getenv("DISK_CACHE_SIZE", 500))
MAX_CONCURRENT_REQUESTS = int(os.getenv("MAX_CONCURRENT_REQUESTS", 12))
LOG_LEVEL = os.getenv("LOG_LEVEL", "INFO").upper()
CACHE_DIR = os.getenv("CACHE_DIR", "./audio_cache")

logging.basicConfig(level=getattr(logging, LOG_LEVEL))
logger = logging.getLogger(__name__)

# Make sure the disk-cache directory exists before any request arrives.
Path(CACHE_DIR).mkdir(parents=True, exist_ok=True)

# Split after sentence/clause punctuation followed by whitespace.
SENT_SPLIT_RE = re.compile(r'(?<=[.!?,:])\s+')

# In-memory cache of single sentences:
# {cache_key -> {"sentence": str, "audio": base64 WAV, "sample_rate": int}}
# TTL is 18000 s (5 hours).
memory_cache = TTLCache(maxsize=MEMORY_CACHE_SIZE, ttl=18000)
cache_lock = threading.Lock()
executor = concurrent.futures.ThreadPoolExecutor(max_workers=16)

# TTS model (resident; loaded lazily by load_model())
model_pipeline = None
model_lock = asyncio.Lock()

# Concurrency management
request_semaphore = asyncio.Semaphore(MAX_CONCURRENT_REQUESTS)
current_requests: Dict[str, Dict] = {}
- # ============================================================
- # 模型加载
- # ============================================================
async def load_model():
    """Lazily instantiate the global Kokoro pipeline (idempotent).

    Safe to call from every request: the lock serializes loading and the
    guard clause makes repeated calls no-ops once the model is resident.
    """
    global model_pipeline
    async with model_lock:
        if model_pipeline is not None:
            return
        try:
            model_pipeline = KPipeline(lang_code='a', device="cpu")
            logger.info("模型加载完成 (CPU)")
        except Exception as e:
            logger.error(f"模型加载失败: {e}")
            raise
@asynccontextmanager
async def lifespan(app: FastAPI):
    """FastAPI lifespan: load the TTS model at startup, log at shutdown."""
    await load_model()
    yield
    logger.info("应用关闭中...")
app = FastAPI(lifespan=lifespan)

# CORS: allow every origin.
# NOTE(review): allow_origins=["*"] combined with allow_credentials=True is
# rejected by browsers for credentialed requests (the CORS spec forbids the
# wildcard there) — confirm whether credentials are actually needed.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
- # ============================================================
- # 工具函数
- # ============================================================
def split_sentences(text: str) -> List[str]:
    """Split *text* into trimmed fragments at punctuation boundaries.

    A boundary is any whitespace run preceded by one of ``. ! ? , :``.
    Empty fragments are dropped.
    """
    fragments = re.split(r'(?<=[.!?,:])\s+', text.strip())
    sentences = []
    for fragment in fragments:
        fragment = fragment.strip()
        if fragment:
            sentences.append(fragment)
    return sentences
def sentence_cache_key(sentence: str, voice: str, speed: float) -> str:
    """Deterministic MD5 hex key for one (sentence, voice, speed) combo."""
    return hashlib.md5(f"{sentence}|{voice}|{speed}".encode()).hexdigest()
def sentence_cache_path(key: str):
    """Path of the cached WAV file for *key* inside CACHE_DIR."""
    wav_name = f"{key}.wav"
    return os.path.join(CACHE_DIR, wav_name)
def disk_meta_path(key: str):
    """Path of the JSON metadata file paired with *key*'s WAV."""
    meta_name = f"{key}.json"
    return os.path.join(CACHE_DIR, meta_name)
def put_memory_cache(key: str, value: dict):
    """Insert *value* into the TTL memory cache under the cache_lock."""
    with cache_lock:
        memory_cache[key] = value
def get_memory_cache(key: str) -> Optional[dict]:
    """Return the cached entry for *key*, or None if absent or TTL-expired."""
    with cache_lock:
        return memory_cache.get(key)
async def load_sentence_from_disk(key: str) -> Optional[dict]:
    """Load one cached sentence from the disk cache.

    Returns the meta dict ({"sentence", "sample_rate"}) with the base64-encoded
    WAV payload added under "audio", or None when the entry is missing or
    unreadable (an unreadable entry is treated as a cache miss so the caller
    regenerates it instead of failing mid-stream).
    """
    wav_path = sentence_cache_path(key)
    meta_path = disk_meta_path(key)
    if not Path(wav_path).exists() or not Path(meta_path).exists():
        return None
    try:
        async with aiofiles.open(meta_path, "r") as f:
            meta = json.loads(await f.read())
        # The file on disk is already a complete PCM_16 WAV container (written
        # by save_sentence_to_disk), so forward its raw bytes directly instead
        # of decoding the audio and re-encoding a second WAV in memory.
        async with aiofiles.open(wav_path, "rb") as f:
            wav_bytes = await f.read()
    except (OSError, json.JSONDecodeError) as e:
        logger.warning(f"Corrupt disk cache entry {key}: {e}")
        return None
    meta["audio"] = base64.b64encode(wav_bytes).decode()
    return meta
async def save_sentence_to_disk(key: str, audio: np.ndarray, sr: int, sentence: str):
    """Persist one synthesized sentence as <key>.wav plus <key>.json metadata.

    The blocking soundfile write runs on the shared thread pool so the event
    loop is not stalled. Returns the meta dict that was written.
    """
    wav_path = sentence_cache_path(key)
    meta_path = disk_meta_path(key)
    # get_running_loop() replaces get_event_loop(), which is deprecated
    # inside coroutines since Python 3.10.
    await asyncio.get_running_loop().run_in_executor(
        executor, lambda: sf.write(wav_path, audio, sr, format="WAV")
    )
    meta = {
        "sentence": sentence,
        "sample_rate": sr,
    }
    async with aiofiles.open(meta_path, "w") as f:
        await f.write(json.dumps(meta))
    return meta
async def clean_disk_cache():
    """Evict oldest WAV/JSON cache pairs until at most DISK_CACHE_SIZE remain.

    Best-effort: races with concurrent deletions are tolerated, never raised.
    """
    def _mtime(p: Path) -> float:
        # A file deleted between glob() and stat() sorts as oldest and is
        # then skipped harmlessly by the unlink loop below.
        try:
            return p.stat().st_mtime
        except OSError:
            return 0.0

    wav_files = sorted(Path(CACHE_DIR).glob("*.wav"), key=_mtime)
    excess = len(wav_files) - DISK_CACHE_SIZE
    if excess <= 0:
        return
    for wav in wav_files[:excess]:
        try:
            meta = Path(disk_meta_path(wav.stem))
            wav.unlink()
            if meta.exists():
                meta.unlink()
        except OSError:
            # Narrowed from a bare `except:` — only filesystem races are
            # expected here; anything else should surface.
            pass
async def check_client_alive(client_id):
    """Watchdog coroutine: poll current_requests every 0.3 s.

    Completes (returning False) as soon as *client_id* has been removed or
    flagged with "interrupt"; otherwise keeps polling forever. The caller
    observes completion via Task.done().
    """
    while True:
        await asyncio.sleep(0.3)
        entry = current_requests.get(client_id)
        if entry is None or entry.get("interrupt"):
            return False
# ============================================================
# Request tracking
# ============================================================
@app.middleware("http")
async def track_clients(request: Request, call_next):
    """Register /generate requests in current_requests for interrupt tracking.

    NOTE(review): the id here comes from the query string or the X-Client-ID
    header, while /generate reads client_id from the JSON body. When only the
    body carries it, this middleware registers (and pops) a different random
    UUID key than the endpoint uses — confirm the intended id source.
    """
    client_id = (
        request.query_params.get("client_id")
        or request.headers.get("X-Client-ID")
        or str(uuid.uuid4())
    )
    if request.url.path == "/generate":
        current_requests[client_id] = {"active": True}
    response = await call_next(request)
    if request.url.path == "/generate":
        current_requests.pop(client_id, None)
    return response
# ============================================================
# Main endpoint
# ============================================================
@app.post("/generate")
async def generate_audio_stream(data: Dict = Body(...)):
    """Stream per-sentence TTS audio as NDJSON.

    Body: {"text": str, "voice": str = "af_heart", "speed": float = 1.0,
    "client_id": str = random UUID}. Each emitted line is one JSON object:
    {"index", "sentence", "audio" (base64 WAV), "sample_rate"}.

    Raises HTTPException(400) when "text" is empty.
    """
    # NOTE(review): the semaphore only guards request setup — it is released
    # when this function returns the StreamingResponse, i.e. BEFORE the
    # stream() generator below runs. Confirm whether synthesis itself should
    # be bounded by MAX_CONCURRENT_REQUESTS.
    async with request_semaphore:
        text = data.get("text", "")
        voice = data.get("voice", "af_heart")
        speed = float(data.get("speed", 1.0))
        client_id = data.get("client_id", str(uuid.uuid4()))
        if not text:
            raise HTTPException(400, "文本不能为空")
        await load_model()
        # Interrupt a previous in-flight request from the same client, then
        # give its watchdog one poll interval to observe the flag.
        if client_id in current_requests:
            current_requests[client_id]["interrupt"] = True
            await asyncio.sleep(0.05)
        sentences = split_sentences(text)
        current_requests[client_id] = {"interrupt": False}

        async def stream():
            # Watchdog task: finishes once this client is interrupted or
            # removed from current_requests (see check_client_alive).
            client_alive_task = asyncio.create_task(check_client_alive(client_id))
            try:
                for idx, sentence in enumerate(sentences):
                    if client_alive_task.done():
                        break  # client interrupted — stop synthesizing
                    key = sentence_cache_key(sentence, voice, speed)
                    # 1) in-memory cache
                    cache_item = get_memory_cache(key)
                    if cache_item:
                        yield json.dumps({
                            "index": idx,
                            **cache_item
                        }).encode() + b"\n"
                        continue
                    # 2) disk cache (promoted into memory on hit)
                    disk_item = await load_sentence_from_disk(key)
                    if disk_item:
                        put_memory_cache(key, disk_item)
                        yield json.dumps({
                            "index": idx,
                            **disk_item
                        }).encode() + b"\n"
                        continue
                    # 3) cache miss -> run inference.
                    # NOTE(review): this call is synchronous and runs on the
                    # event loop thread; consider run_in_executor if it blocks
                    # for long on CPU.
                    generator = model_pipeline(sentence, voice=voice, speed=speed)
                    # Only the first chunk of the pipeline output is used
                    # (note the `break` at the bottom of this loop).
                    for _, _, audio in generator:
                        sr = 24000  # assumed Kokoro output rate — TODO confirm
                        with io.BytesIO() as buf:
                            # Encode as 16-bit PCM WAV; the SoundFile must be
                            # closed before getvalue() so the header is finalized.
                            with sf.SoundFile(buf, 'w', sr, channels=audio.shape[1] if audio.ndim > 1 else 1,
                                              format='WAV', subtype='PCM_16') as f:
                                f.write(audio)
                            audio_b64 = base64.b64encode(buf.getvalue()).decode()
                        item = {
                            "sentence": sentence,
                            "audio": audio_b64,
                            "sample_rate": sr
                        }
                        put_memory_cache(key, item)
                        # Persist to disk and trim the disk cache.
                        await save_sentence_to_disk(key, audio, sr, sentence)
                        await clean_disk_cache()
                        yield json.dumps({
                            "index": idx,
                            **item
                        }).encode() + b"\n"
                        break
            finally:
                # Always deregister the client and stop the watchdog.
                current_requests.pop(client_id, None)
                if not client_alive_task.done():
                    client_alive_task.cancel()

        return StreamingResponse(stream(), media_type="application/x-ndjson")
# ============================================================
# Other endpoints
# ============================================================
@app.get("/clear-cache")
async def clear_cache():
    """Drop the in-memory cache and delete every file in CACHE_DIR."""
    with cache_lock:
        memory_cache.clear()
    for f in Path(CACHE_DIR).glob("*"):
        # is_file() skips stray subdirectories; missing_ok tolerates files
        # deleted concurrently (e.g. by clean_disk_cache) instead of a 500.
        if f.is_file():
            f.unlink(missing_ok=True)
    return {"status": "success"}
@app.get("/cache-info")
async def get_cache_info():
    """Report how many entries the memory and disk caches currently hold."""
    with cache_lock:
        mem_entries = len(memory_cache)
    disk_entries = len(list(Path(CACHE_DIR).glob("*.wav")))
    return {
        "memory_cache": mem_entries,
        "disk_cache": disk_entries
    }
# ============================================================
# Entry point
# ============================================================
if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8028)