| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448 |
- from fastapi import FastAPI, HTTPException, Body, Request
- from fastapi.responses import StreamingResponse
- from contextlib import asynccontextmanager
- from kokoro import KPipeline
- import soundfile as sf
- import io
- import base64
- import re
- import logging
- import json
- import hashlib
- import asyncio
- import os
- import uuid
- from fastapi.middleware.cors import CORSMiddleware
- from typing import Dict, List, Optional
- from cachetools import TTLCache
- import concurrent.futures
- import threading
- import numpy as np
- from pathlib import Path
- import torch
- import time
- import logging
- logging.getLogger("torch").setLevel(logging.ERROR)
- # torch.backends.nnpack.enabled = False
# Configuration (all overridable via environment variables)
USE_MP3 = os.getenv("USE_MP3", "false").lower() == "true"
MEMORY_CACHE_SIZE = int(os.getenv("MEMORY_CACHE_SIZE", 100))  # max entries in the in-memory cache
DISK_CACHE_SIZE = int(os.getenv("DISK_CACHE_SIZE", 200))  # max number of WAV files kept on disk
MAX_CONCURRENT_REQUESTS = int(os.getenv("MAX_CONCURRENT_REQUESTS", 10))
LOG_LEVEL = os.getenv("LOG_LEVEL", "WARNING").upper()
CACHE_DIR = os.getenv("CACHE_DIR", "./audio_cache")  # disk cache directory
# Idle timeout in seconds before the model is unloaded.
# NOTE(review): the value is 30 minutes but the original comment said
# 20 minutes (as does a log message elsewhere) — confirm which is intended.
IDLE_TIMEOUT = 30 * 60
# Logging setup
logging.basicConfig(level=getattr(logging, LOG_LEVEL))
logger = logging.getLogger(__name__)
# Ensure the cache directory exists
Path(CACHE_DIR).mkdir(parents=True, exist_ok=True)
# Sentence splitter: break after . ! ? , or : followed by whitespace
SENT_SPLIT_RE = re.compile(r'(?<=[.!?,:])\s+')
# In-memory cache of per-sentence payload lists (TTL: 18000 s = 5 h)
memory_cache = TTLCache(maxsize=MEMORY_CACHE_SIZE, ttl=18000)
cache_lock = threading.Lock()
# Thread pool for blocking file I/O
executor = concurrent.futures.ThreadPoolExecutor(max_workers=20)
# Global model instance and per-client request tracking state
model_pipeline = None
current_requests: Dict[str, Dict] = {}
request_semaphore = asyncio.Semaphore(MAX_CONCURRENT_REQUESTS)
last_request_time = time.time()  # timestamp of the most recent /generate request
model_lock = asyncio.Lock()  # serializes model load/unload
async def unload_model():
    """Drop the global model reference and release cached GPU memory."""
    global model_pipeline
    async with model_lock:
        if model_pipeline is None:
            return  # nothing loaded, nothing to do
        try:
            model_pipeline = None  # drop the only reference so it can be GC'd
            torch.cuda.empty_cache()  # hand cached GPU memory back to the driver
            logger.info("模型已卸载,GPU资源已释放")
        except Exception as e:
            logger.error(f"释放模型失败: {str(e)}")
async def load_model():
    """Instantiate the Kokoro pipeline unless it is already loaded."""
    global model_pipeline
    async with model_lock:
        if model_pipeline is not None:
            return  # already loaded
        try:
            model_pipeline = KPipeline(lang_code='a')
            logger.info("模型加载成功")
        except Exception as e:
            logger.error(f"模型加载失败: {str(e)}")
            raise
async def idle_checker():
    """Background task: unload the model after IDLE_TIMEOUT idle seconds.

    Polls once per minute; when no /generate request has arrived for
    IDLE_TIMEOUT seconds and a model is loaded, releases it.
    """
    global last_request_time
    while True:
        await asyncio.sleep(60)  # poll once per minute
        if time.time() - last_request_time > IDLE_TIMEOUT and model_pipeline is not None:
            # Bug fix: the message used to hardcode "20分钟" while
            # IDLE_TIMEOUT is 30 minutes; derive it from the constant.
            logger.info(f"超过{IDLE_TIMEOUT // 60}分钟无请求,释放GPU资源")
            await unload_model()
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Manage model lifecycle: load on startup, idle-watch, unload on shutdown.

    Bug fix: ``idle_task`` was previously only assigned inside ``try``; if
    ``load_model()`` raised, the ``finally`` block hit a NameError on
    ``idle_task.cancel()`` that masked the real startup error. It is now
    pre-initialized and guarded.
    """
    idle_task = None
    try:
        await load_model()  # load the model at startup
        idle_task = asyncio.create_task(idle_checker())  # start the idle watchdog
        yield
    except Exception as e:
        logger.error(f"生命周期启动失败: {str(e)}")
        raise
    finally:
        if idle_task is not None:
            idle_task.cancel()  # stop the idle watchdog
        await unload_model()  # release the model on shutdown
        logger.info("应用关闭,资源已清理")
app = FastAPI(lifespan=lifespan)
# Allow cross-origin requests from any origin (browser clients stream NDJSON).
# NOTE(review): allow_credentials=True combined with allow_origins=["*"] is
# rejected by browsers for credentialed requests — confirm this is intended.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
def split_sentences(text: str) -> List[str]:
    """Split *text* into sentence-like chunks after ``.``, ``!``, ``?``, ``,`` or ``:``.

    Whitespace-only chunks are discarded and each chunk is stripped.
    """
    pieces = re.split(r'(?<=[.!?,:])\s+', text.strip())
    return [piece.strip() for piece in pieces if piece.strip()]
def generate_cache_key(text: str, voice: str, speed: float) -> str:
    """Derive a deterministic cache key (MD5 hex digest) from the synthesis parameters."""
    raw = "_".join([text, voice, str(speed)])
    return hashlib.md5(raw.encode('utf-8')).hexdigest()
def get_cache_file_path(cache_key: str) -> str:
    """Return the on-disk path of the cached WAV for *cache_key*."""
    filename = f"{cache_key}.wav"
    return os.path.join(CACHE_DIR, filename)
async def save_to_memory_cache(cache_key: str, data: List[Dict]):
    """Store the per-sentence payload list in the in-memory TTL cache.

    Bug fix: the old code awaited ``run_in_executor`` while holding the
    ``threading.Lock`` — a second coroutine acquiring the lock would block
    the event loop thread, preventing the executor callback from ever
    running (deadlock). A TTLCache insert is cheap, so it is done inline
    under the lock with no await in between.
    """
    try:
        with cache_lock:
            memory_cache[cache_key] = data
    except Exception as e:
        logger.error(f"保存到内存缓存失败: {str(e)}")
async def load_from_memory_cache(cache_key: str) -> Optional[List[Dict]]:
    """Fetch the cached per-sentence payload list, or None on miss/error.

    Bug fix: mirrors save_to_memory_cache — the old code held the
    ``threading.Lock`` across an ``await`` on the executor, which could
    deadlock the event loop. The dict lookup is cheap; do it inline.
    """
    try:
        with cache_lock:
            return memory_cache.get(cache_key)
    except Exception as e:
        logger.error(f"从内存缓存加载失败: {str(e)}")
        return None
async def save_to_disk_cache(cache_key: str, audio_data: np.ndarray):
    """Persist raw audio samples as a 24 kHz WAV file in the disk cache."""
    loop = asyncio.get_running_loop()
    try:
        file_path = get_cache_file_path(cache_key)

        def _write_wav():
            # Blocking file write, run on the thread pool
            sf.write(file_path, audio_data, 24000, format="WAV")

        await loop.run_in_executor(executor, _write_wav)
        logger.debug(f"音频已保存到磁盘缓存: {file_path}")
    except Exception as e:
        logger.error(f"保存到磁盘缓存失败: {str(e)}")
async def load_from_disk_cache(cache_key: str) -> Optional[np.ndarray]:
    """Read cached audio samples from disk; None on miss or read failure."""
    loop = asyncio.get_running_loop()
    try:
        file_path = get_cache_file_path(cache_key)
        if not os.path.exists(file_path):
            return None  # cache miss
        audio_data, _sample_rate = await loop.run_in_executor(
            executor, lambda: sf.read(file_path)
        )
        return audio_data
    except Exception as e:
        logger.error(f"从磁盘缓存加载失败: {str(e)}")
        return None
async def clean_disk_cache():
    """Evict the oldest WAV files so the disk cache stays within DISK_CACHE_SIZE."""
    loop = asyncio.get_running_loop()
    try:
        # Oldest first, by modification time
        wav_files = sorted(Path(CACHE_DIR).glob("*.wav"), key=os.path.getmtime)
        excess = len(wav_files) - DISK_CACHE_SIZE
        if excess > 0:
            stale = wav_files[:excess]

            def _unlink_all():
                for wav in stale:
                    wav.unlink()

            await loop.run_in_executor(executor, _unlink_all)
            logger.debug(f"已清理 {len(stale)} 个磁盘缓存文件")
    except Exception as e:
        logger.error(f"清理磁盘缓存失败: {str(e)}")
@app.middleware("http")
async def track_clients(request: Request, call_next):
    """Track /generate clients and refresh the idle-unload timestamp.

    The client id comes from the ``client_id`` query parameter, the
    ``X-Client-ID`` header, or a fresh UUID, in that order.

    Bug fix: the entry in ``current_requests`` was only removed after a
    successful ``call_next``; an exception in a downstream handler leaked
    the entry permanently. Cleanup now runs in ``finally``.
    """
    global last_request_time
    client_id = request.query_params.get("client_id") or request.headers.get("X-Client-ID") or str(uuid.uuid4())
    is_generate = request.url.path == "/generate"
    if is_generate:
        last_request_time = time.time()  # refresh the idle timer
        current_requests[client_id] = {"active": True}
        logger.debug(f"客户端 {client_id} 已连接")
    try:
        response = await call_next(request)
    finally:
        if is_generate and client_id in current_requests:
            del current_requests[client_id]
            logger.debug(f"客户端 {client_id} 已断开")
    return response
async def check_client_active(client_id: str, interval: float = 0.5) -> bool:
    """Poll until *client_id* disappears or is flagged for interrupt.

    Returns False as soon as the client is gone or interrupted; otherwise
    loops forever (the caller cancels this task when done).
    """
    while True:
        await asyncio.sleep(interval)
        entry = current_requests.get(client_id)
        if entry is None:
            logger.debug(f"客户端 {client_id} 不再活跃")
            return False
        if entry.get("interrupt"):
            return False
@app.post("/generate")
async def generate_audio_stream(data: Dict = Body(...)):
    """Streaming TTS endpoint with batching and two-level caching.

    Body keys: ``text`` (required), ``voice`` (default ``af_heart``),
    ``speed`` (default 1.0), ``client_id`` (optional; a UUID is generated
    if absent). Responds with NDJSON — one JSON object per sentence
    carrying base64-encoded 24 kHz WAV audio. Lookup order: memory cache,
    disk cache, then fresh synthesis.
    """
    global last_request_time
    # Cap the number of simultaneously processed requests
    async with request_semaphore:
        last_request_time = time.time()  # refresh the idle-unload timer
        text = data.get("text", "")
        voice = data.get("voice", "af_heart")
        speed = float(data.get("speed", 1.0))
        client_id = data.get("client_id", str(uuid.uuid4()))
        # Make sure the model is loaded (it may have been idle-unloaded)
        await load_model()
        if not model_pipeline:
            raise HTTPException(500, "模型未初始化")
        if not text:
            raise HTTPException(400, "文本不能为空")
        # Interrupt a previous in-flight request from the same client
        if client_id in current_requests:
            current_requests[client_id]["interrupt"] = True
            await asyncio.sleep(0.05)  # give the old stream a moment to notice
        cache_key = generate_cache_key(text, voice, speed)
        # 1. Memory cache hit: replay the stored per-sentence payloads
        cached_data = await load_from_memory_cache(cache_key)
        if cached_data:
            logger.debug(f"从内存缓存流式传输 (key: {cache_key})")

            async def cached_stream():
                # Re-emit each cached sentence payload as one NDJSON line
                for idx, item in enumerate(cached_data):
                    yield json.dumps({
                        "index": idx,
                        "sentence": item["sentence"],
                        "audio": item["audio"],
                        "sample_rate": 24000,
                        "format": "wav",
                        "cached": True
                    }, ensure_ascii=False).encode() + b"\n"
            return StreamingResponse(
                cached_stream(),
                media_type="application/x-ndjson",
                headers={
                    "X-Content-Type-Options": "nosniff",
                    "Cache-Control": "no-store",
                    "X-Client-ID": client_id
                }
            )
        # 2. Disk cache hit: one combined WAV saved by a previous request
        cached_audio = await load_from_disk_cache(cache_key)
        if cached_audio is not None:
            logger.debug(f"从磁盘缓存流式传输 (key: {cache_key})")
            with io.BytesIO() as buffer:
                sf.write(buffer, cached_audio, 24000, format="WAV")
                audio_bytes = buffer.getvalue()
            audio_b64 = base64.b64encode(audio_bytes).decode("utf-8")
            # Promote the disk hit into the memory cache as a single item
            cached_data = [{
                "index": 0,
                "sentence": text,
                "audio": audio_b64,
                "sample_rate": 24000,
                "format": "wav"
            }]
            await save_to_memory_cache(cache_key, cached_data)
            return StreamingResponse(
                iter([json.dumps({
                    "index": 0,
                    "sentence": text,
                    "audio": audio_b64,
                    "sample_rate": 24000,
                    "format": "wav",
                    "cached": True
                }, ensure_ascii=False).encode() + b"\n"]),
                media_type="application/x-ndjson",
                headers={
                    "X-Content-Type-Options": "nosniff",
                    "Cache-Control": "no-store",
                    "X-Client-ID": client_id
                }
            )
        # 3. Cache miss: synthesize sentence by sentence
        sentences = split_sentences(text)
        logger.debug(f"处理 {len(sentences)} 个句子")
        current_requests[client_id] = {"interrupt": False, "sentences": sentences, "active": True}

        async def audio_stream():
            memory_cache_items = []  # per-sentence payloads for the memory cache
            all_audio = []  # raw audio arrays, later concatenated for the disk cache
            client_check_task = asyncio.create_task(check_client_active(client_id))
            try:
                batch_size = 4
                for i in range(0, len(sentences), batch_size):
                    batch = sentences[i:i + batch_size]
                    # Stop early if the client vanished or asked to interrupt.
                    # NOTE(review): current_requests[client_id] may have been
                    # removed concurrently (middleware/another request), which
                    # would raise KeyError here — confirm.
                    if client_check_task.done() or current_requests[client_id].get("interrupt"):
                        logger.debug(f"请求中断或客户端 {client_id} 已断开")
                        break
                    try:
                        # NOTE(review): the pipeline calls run synchronously on
                        # the event loop thread; presumably call setup is cheap
                        # and the heavy work happens when the generators are
                        # consumed — confirm.
                        generators = [model_pipeline(s, voice=voice, speed=speed, split_pattern=None) for s in batch]
                        for idx, (sentence, generator) in enumerate(zip(batch, generators)):
                            global_idx = i + idx
                            # Each generator yields 3-tuples; only the third
                            # element (the audio array) is used, and only the
                            # first chunk is consumed (note the break below).
                            for _, _, audio in generator:
                                if audio is None or len(audio) == 0:
                                    logger.warning(f"句子 {global_idx} 生成了空音频: {sentence}")
                                    yield json.dumps({
                                        "index": global_idx,
                                        "sentence": sentence,
                                        "audio": "",
                                        "sample_rate": 24000,
                                        "format": "wav",
                                        "warning": "没有生成音频"
                                    }, ensure_ascii=False).encode() + b"\n"
                                    break
                                all_audio.append(audio)
                                # Encode this sentence's audio as an in-memory WAV
                                with io.BytesIO() as buffer:
                                    sf.write(buffer, audio, 24000, format="WAV")
                                    audio_bytes = buffer.getvalue()
                                audio_b64 = base64.b64encode(audio_bytes).decode("utf-8")
                                memory_cache_items.append({
                                    "index": global_idx,
                                    "sentence": sentence,
                                    "audio": audio_b64,
                                    "sample_rate": 24000,
                                    "format": "wav"
                                })
                                yield json.dumps({
                                    "index": global_idx,
                                    "sentence": sentence,
                                    "audio": audio_b64,
                                    "sample_rate": 24000,
                                    "format": "wav"
                                }, ensure_ascii=False).encode() + b"\n"
                                break
                    except Exception as e:
                        logger.error(f"处理批次 {i} 时出错: {str(e)}")
                        yield json.dumps({
                            "error": f"批次处理失败于索引 {i}",
                            "message": str(e)
                        }, ensure_ascii=False).encode() + b"\n"
                # Persist results so identical future requests hit the caches
                if memory_cache_items:
                    await save_to_memory_cache(cache_key, memory_cache_items)
                if all_audio:
                    combined_audio = np.concatenate(all_audio)
                    await save_to_disk_cache(cache_key, combined_audio)
                    logger.debug(f"已保存合并音频到磁盘缓存 (key: {cache_key})")
                    await clean_disk_cache()
            except asyncio.CancelledError:
                logger.debug(f"客户端 {client_id} 的请求已取消")
            except Exception as e:
                logger.error(f"流错误: {str(e)}")
            finally:
                # Always drop tracking state and stop the liveness watcher
                if client_id in current_requests:
                    del current_requests[client_id]
                if not client_check_task.done():
                    client_check_task.cancel()
                logger.debug(f"已清理客户端 {client_id}")
        return StreamingResponse(
            audio_stream(),
            media_type="application/x-ndjson",
            headers={
                "X-Content-Type-Options": "nosniff",
                "Cache-Control": "no-store",
                "X-Client-ID": client_id
            }
        )
@app.get("/clear-cache")
async def clear_cache():
    """Drop every entry from both the memory cache and the disk cache."""
    try:
        # Memory cache: clear under the lock (cheap, no await while held)
        with cache_lock:
            memory_cache.clear()
        # Disk cache: delete all WAV files on the thread pool
        loop = asyncio.get_running_loop()

        def _purge_disk():
            for wav in Path(CACHE_DIR).glob("*.wav"):
                wav.unlink()

        await loop.run_in_executor(executor, _purge_disk)
        return {"status": "success", "message": "所有缓存已清除"}
    except Exception as e:
        raise HTTPException(500, f"清除缓存失败: {str(e)}")
@app.get("/cache-info")
async def get_cache_info():
    """Report entry counts and approximate byte sizes for both caches."""
    try:
        # Snapshot the memory cache under the lock, then measure outside it
        with cache_lock:
            entries = list(memory_cache.values())
        memory_count = len(entries)
        memory_size = sum(len(json.dumps(item)) for item in entries)
        # Disk cache stats straight from the filesystem
        disk_files = list(Path(CACHE_DIR).glob("*.wav"))
        return {
            "memory_cache": {
                "count": memory_count,
                "size": memory_size,
                "max_size": MEMORY_CACHE_SIZE
            },
            "disk_cache": {
                "count": len(disk_files),
                "size": sum(f.stat().st_size for f in disk_files),
                "max_size": DISK_CACHE_SIZE,
                "directory": CACHE_DIR
            }
        }
    except Exception as e:
        raise HTTPException(500, f"获取缓存信息失败: {str(e)}")
if __name__ == "__main__":
    import uvicorn
    # Development entry point: bind on all interfaces, port 8028
    uvicorn.run(app, host="0.0.0.0", port=8028)
|