from fastapi import FastAPI, HTTPException, Body, Request
from fastapi.responses import StreamingResponse
from contextlib import asynccontextmanager
from kokoro import KPipeline
import soundfile as sf
import io
import base64
import re
import logging
import json
import hashlib
import asyncio
import os
import uuid
from fastapi.middleware.cors import CORSMiddleware
from typing import Dict, List, Optional
from cachetools import TTLCache
import concurrent.futures
import threading
import numpy as np
from pathlib import Path
import torch
import time

logging.getLogger("torch").setLevel(logging.ERROR)
# torch.backends.nnpack.enabled = False

# Configuration (all overridable via environment variables)
USE_MP3 = os.getenv("USE_MP3", "false").lower() == "true"
MEMORY_CACHE_SIZE = int(os.getenv("MEMORY_CACHE_SIZE", 100))  # max in-memory cache entries
DISK_CACHE_SIZE = int(os.getenv("DISK_CACHE_SIZE", 200))  # max cached WAV files on disk
MAX_CONCURRENT_REQUESTS = int(os.getenv("MAX_CONCURRENT_REQUESTS", 10))
LOG_LEVEL = os.getenv("LOG_LEVEL", "WARNING").upper()
CACHE_DIR = os.getenv("CACHE_DIR", "./audio_cache")  # disk cache directory
IDLE_TIMEOUT = 30 * 60  # idle timeout in seconds before the model is unloaded

# Logging setup
logging.basicConfig(level=getattr(logging, LOG_LEVEL))
logger = logging.getLogger(__name__)

# Make sure the cache directory exists
Path(CACHE_DIR).mkdir(parents=True, exist_ok=True)

# Sentence-splitting regex: split on whitespace that follows ., !, ?, , or :
SENT_SPLIT_RE = re.compile(r'(?<=[.!?,:])\s+')

# In-memory cache of generated audio payloads (5-hour TTL)
memory_cache = TTLCache(maxsize=MEMORY_CACHE_SIZE, ttl=18000)
cache_lock = threading.Lock()

# Thread pool for blocking file I/O
executor = concurrent.futures.ThreadPoolExecutor(max_workers=20)

# Global model instance and per-client request state
model_pipeline = None
current_requests: Dict[str, Dict] = {}
request_semaphore = asyncio.Semaphore(MAX_CONCURRENT_REQUESTS)
last_request_time = time.time()  # timestamp of the most recent /generate request
model_lock = asyncio.Lock()  # serializes model load/unload


async def unload_model():
    """Release the model and any GPU memory it holds."""
    global model_pipeline
    async with model_lock:
        if model_pipeline is not None:
            try:
                del model_pipeline
                model_pipeline = None
                torch.cuda.empty_cache()  # release cached GPU memory
                logger.info("模型已卸载,GPU资源已释放")
            except Exception as e:
                logger.error(f"释放模型失败: {str(e)}")


async def load_model():
    """Load the TTS pipeline; no-op when it is already loaded."""
    global model_pipeline
    async with model_lock:
        if model_pipeline is None:
            try:
                model_pipeline = KPipeline(lang_code='a')
                logger.info("模型加载成功")
            except Exception as e:
                logger.error(f"模型加载失败: {str(e)}")
                raise


async def idle_checker():
    """Background task: unload the model after IDLE_TIMEOUT seconds without requests."""
    global last_request_time
    while True:
        await asyncio.sleep(60)  # check once per minute
        if time.time() - last_request_time > IDLE_TIMEOUT and model_pipeline is not None:
            # Message derived from IDLE_TIMEOUT (the old hard-coded "20分钟" was wrong)
            logger.info(f"超过{IDLE_TIMEOUT // 60}分钟无请求,释放GPU资源")
            await unload_model()


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Manage the model lifecycle and the idle-unload watchdog."""
    global model_pipeline, last_request_time
    idle_task = None  # defined up front so `finally` is safe if startup fails
    try:
        await load_model()  # load the model at startup
        idle_task = asyncio.create_task(idle_checker())  # start the idle watchdog
        yield
    except Exception as e:
        logger.error(f"生命周期启动失败: {str(e)}")
        raise
    finally:
        if idle_task is not None:
            idle_task.cancel()  # stop the idle watchdog
        await unload_model()  # release the model on shutdown
        logger.info("应用关闭,资源已清理")


app = FastAPI(lifespan=lifespan)

app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


def split_sentences(text: str) -> List[str]:
    """Split text into trimmed, non-empty sentences using SENT_SPLIT_RE."""
    return [s.strip() for s in SENT_SPLIT_RE.split(text.strip()) if s.strip()]


def generate_cache_key(text: str, voice: str, speed: float) -> str:
    """Derive a deterministic cache key from the generation parameters."""
    key_str = f"{text}_{voice}_{speed}"
    return hashlib.md5(key_str.encode('utf-8')).hexdigest()


def get_cache_file_path(cache_key: str) -> str:
    """Return the on-disk cache file path for a cache key."""
    return os.path.join(CACHE_DIR, f"{cache_key}.wav")


async def save_to_memory_cache(cache_key: str, data: List[Dict]):
    """Store per-sentence audio payloads in the in-memory cache.

    The TTLCache write is a microsecond-scale dict operation, so it is done
    directly under the lock. (The previous version held the threading lock
    across an await into the executor, which could stall the event loop on
    contention while buying nothing.)
    """
    try:
        with cache_lock:
            memory_cache[cache_key] = data
    except Exception as e:
        logger.error(f"保存到内存缓存失败: {str(e)}")


async def load_from_memory_cache(cache_key: str) -> Optional[List[Dict]]:
    """Fetch per-sentence audio payloads from the in-memory cache, or None."""
    try:
        with cache_lock:
            return memory_cache.get(cache_key)
    except Exception as e:
        logger.error(f"从内存缓存加载失败: {str(e)}")
        return None


async def save_to_disk_cache(cache_key: str, audio_data: np.ndarray):
    """Write audio samples to the disk cache as a 24 kHz WAV file."""
    loop = asyncio.get_running_loop()
    try:
        file_path = get_cache_file_path(cache_key)
        await loop.run_in_executor(
            executor,
            lambda: sf.write(file_path, audio_data, 24000, format="WAV")
        )
        logger.debug(f"音频已保存到磁盘缓存: {file_path}")
    except Exception as e:
        logger.error(f"保存到磁盘缓存失败: {str(e)}")


async def load_from_disk_cache(cache_key: str) -> Optional[np.ndarray]:
    """Read audio samples from the disk cache; None when missing or unreadable."""
    loop = asyncio.get_running_loop()
    try:
        file_path = get_cache_file_path(cache_key)
        if os.path.exists(file_path):
            # sample rate from the file is ignored; payloads are written at 24 kHz
            audio_data, _sample_rate = await loop.run_in_executor(
                executor,
                lambda: sf.read(file_path)
            )
            return audio_data
    except Exception as e:
        logger.error(f"从磁盘缓存加载失败: {str(e)}")
    return None


async def clean_disk_cache():
    """Evict the oldest disk-cache files once DISK_CACHE_SIZE is exceeded."""
    loop = asyncio.get_running_loop()
    try:
        cache_files = sorted(Path(CACHE_DIR).glob("*.wav"), key=os.path.getmtime)
        if len(cache_files) > DISK_CACHE_SIZE:
            files_to_delete = cache_files[:len(cache_files) - DISK_CACHE_SIZE]
            await loop.run_in_executor(
                executor,
                lambda: [f.unlink() for f in files_to_delete]
            )
            logger.debug(f"已清理 {len(files_to_delete)} 个磁盘缓存文件")
    except Exception as e:
        logger.error(f"清理磁盘缓存失败: {str(e)}")


@app.middleware("http")
async def track_clients(request: Request, call_next):
    """Track connected clients and refresh the idle timer for /generate.

    NOTE(review): with Starlette's BaseHTTPMiddleware, `call_next` returns
    as soon as the response starts, so for streaming responses this entry may
    be deleted while the body is still being produced; the /generate handler
    re-registers the client itself, but this interaction should be confirmed.
    """
    global last_request_time
    client_id = (request.query_params.get("client_id")
                 or request.headers.get("X-Client-ID")
                 or str(uuid.uuid4()))
    if request.url.path == "/generate":
        last_request_time = time.time()  # refresh the idle timer
        current_requests[client_id] = {"active": True}
        logger.debug(f"客户端 {client_id} 已连接")
    response = await call_next(request)
    if request.url.path == "/generate" and client_id in current_requests:
        del current_requests[client_id]
        logger.debug(f"客户端 {client_id} 已断开")
    return response


async def check_client_active(client_id: str, interval: float = 0.5) -> bool:
    """Poll until the client disconnects or is interrupted, then return False."""
    while True:
        await asyncio.sleep(interval)
        entry = current_requests.get(client_id)  # .get avoids a KeyError race
        if entry is None:
            logger.debug(f"客户端 {client_id} 不再活跃")
            return False
        if entry.get("interrupt"):
            return False


@app.post("/generate")
async def generate_audio_stream(data: Dict = Body(...)):
    """Stream TTS audio as NDJSON, with sentence batching and two-level caching.

    Request body: text (required), voice, speed, client_id.
    Response: application/x-ndjson; one JSON object per sentence with
    base64-encoded 24 kHz WAV audio.
    """
    global last_request_time
    async with request_semaphore:
        last_request_time = time.time()  # refresh the idle timer
        text = data.get("text", "")
        voice = data.get("voice", "af_heart")
        speed = float(data.get("speed", 1.0))
        client_id = data.get("client_id", str(uuid.uuid4()))

        # The model may have been idle-unloaded; make sure it is loaded
        await load_model()
        if not model_pipeline:
            raise HTTPException(500, "模型未初始化")
        if not text:
            raise HTTPException(400, "文本不能为空")

        # Interrupt any in-flight request from the same client
        if client_id in current_requests:
            current_requests[client_id]["interrupt"] = True
            await asyncio.sleep(0.05)

        cache_key = generate_cache_key(text, voice, speed)

        # 1. Memory cache hit: replay the cached per-sentence payloads
        cached_data = await load_from_memory_cache(cache_key)
        if cached_data:
            logger.debug(f"从内存缓存流式传输 (key: {cache_key})")

            async def cached_stream():
                for idx, item in enumerate(cached_data):
                    yield json.dumps({
                        "index": idx,
                        "sentence": item["sentence"],
                        "audio": item["audio"],
                        "sample_rate": 24000,
                        "format": "wav",
                        "cached": True
                    }, ensure_ascii=False).encode() + b"\n"

            return StreamingResponse(
                cached_stream(),
                media_type="application/x-ndjson",
                headers={
                    "X-Content-Type-Options": "nosniff",
                    "Cache-Control": "no-store",
                    "X-Client-ID": client_id
                }
            )

        # 2. Disk cache hit: send the combined audio as a single chunk
        cached_audio = await load_from_disk_cache(cache_key)
        if cached_audio is not None:
            logger.debug(f"从磁盘缓存流式传输 (key: {cache_key})")
            with io.BytesIO() as buffer:
                sf.write(buffer, cached_audio, 24000, format="WAV")
                audio_bytes = buffer.getvalue()
            audio_b64 = base64.b64encode(audio_bytes).decode("utf-8")
            cached_data = [{
                "index": 0,
                "sentence": text,
                "audio": audio_b64,
                "sample_rate": 24000,
                "format": "wav"
            }]
            # Promote to the memory cache for faster subsequent hits
            await save_to_memory_cache(cache_key, cached_data)
            return StreamingResponse(
                iter([json.dumps({
                    "index": 0,
                    "sentence": text,
                    "audio": audio_b64,
                    "sample_rate": 24000,
                    "format": "wav",
                    "cached": True
                }, ensure_ascii=False).encode() + b"\n"]),
                media_type="application/x-ndjson",
                headers={
                    "X-Content-Type-Options": "nosniff",
                    "Cache-Control": "no-store",
                    "X-Client-ID": client_id
                }
            )

        # 3. Cache miss: synthesize sentence by sentence
        sentences = split_sentences(text)
        logger.debug(f"处理 {len(sentences)} 个句子")
        current_requests[client_id] = {"interrupt": False, "sentences": sentences, "active": True}

        async def audio_stream():
            memory_cache_items = []
            all_audio = []
            client_check_task = asyncio.create_task(check_client_active(client_id))
            try:
                batch_size = 4
                for i in range(0, len(sentences), batch_size):
                    batch = sentences[i:i + batch_size]
                    # .get(...) guards against the middleware having removed the
                    # entry mid-stream (previously a possible KeyError)
                    if client_check_task.done() or current_requests.get(client_id, {}).get("interrupt"):
                        logger.debug(f"请求中断或客户端 {client_id} 已断开")
                        break
                    try:
                        generators = [model_pipeline(s, voice=voice, speed=speed, split_pattern=None)
                                      for s in batch]
                        for idx, (sentence, generator) in enumerate(zip(batch, generators)):
                            global_idx = i + idx
                            # Only the first chunk of each generator is consumed
                            # (the loop breaks after one iteration)
                            for _, _, audio in generator:
                                if audio is None or len(audio) == 0:
                                    logger.warning(f"句子 {global_idx} 生成了空音频: {sentence}")
                                    yield json.dumps({
                                        "index": global_idx,
                                        "sentence": sentence,
                                        "audio": "",
                                        "sample_rate": 24000,
                                        "format": "wav",
                                        "warning": "没有生成音频"
                                    }, ensure_ascii=False).encode() + b"\n"
                                    break
                                all_audio.append(audio)
                                with io.BytesIO() as buffer:
                                    sf.write(buffer, audio, 24000, format="WAV")
                                    audio_bytes = buffer.getvalue()
                                audio_b64 = base64.b64encode(audio_bytes).decode("utf-8")
                                memory_cache_items.append({
                                    "index": global_idx,
                                    "sentence": sentence,
                                    "audio": audio_b64,
                                    "sample_rate": 24000,
                                    "format": "wav"
                                })
                                yield json.dumps({
                                    "index": global_idx,
                                    "sentence": sentence,
                                    "audio": audio_b64,
                                    "sample_rate": 24000,
                                    "format": "wav"
                                }, ensure_ascii=False).encode() + b"\n"
                                break
                    except Exception as e:
                        logger.error(f"处理批次 {i} 时出错: {str(e)}")
                        yield json.dumps({
                            "error": f"批次处理失败于索引 {i}",
                            "message": str(e)
                        }, ensure_ascii=False).encode() + b"\n"

                # Persist whatever was generated to both cache levels
                if memory_cache_items:
                    await save_to_memory_cache(cache_key, memory_cache_items)
                if all_audio:
                    combined_audio = np.concatenate(all_audio)
                    await save_to_disk_cache(cache_key, combined_audio)
                    logger.debug(f"已保存合并音频到磁盘缓存 (key: {cache_key})")
                    await clean_disk_cache()
            except asyncio.CancelledError:
                logger.debug(f"客户端 {client_id} 的请求已取消")
            except Exception as e:
                logger.error(f"流错误: {str(e)}")
            finally:
                if client_id in current_requests:
                    del current_requests[client_id]
                if not client_check_task.done():
                    client_check_task.cancel()
                logger.debug(f"已清理客户端 {client_id}")

        return StreamingResponse(
            audio_stream(),
            media_type="application/x-ndjson",
            headers={
                "X-Content-Type-Options": "nosniff",
                "Cache-Control": "no-store",
                "X-Client-ID": client_id
            }
        )


@app.get("/clear-cache")
async def clear_cache():
    """Delete every entry in both the memory and disk caches."""
    try:
        with cache_lock:
            memory_cache.clear()
        loop = asyncio.get_running_loop()
        await loop.run_in_executor(
            executor,
            lambda: [f.unlink() for f in Path(CACHE_DIR).glob("*.wav")]
        )
        return {"status": "success", "message": "所有缓存已清除"}
    except Exception as e:
        raise HTTPException(500, f"清除缓存失败: {str(e)}")


@app.get("/cache-info")
async def get_cache_info():
    """Report current memory- and disk-cache usage statistics."""
    try:
        with cache_lock:
            memory_count = len(memory_cache)
            memory_size = sum(len(json.dumps(item)) for item in memory_cache.values())
        disk_files = list(Path(CACHE_DIR).glob("*.wav"))
        disk_count = len(disk_files)
        disk_size = sum(f.stat().st_size for f in disk_files)
        return {
            "memory_cache": {
                "count": memory_count,
                "size": memory_size,
                "max_size": MEMORY_CACHE_SIZE
            },
            "disk_cache": {
                "count": disk_count,
                "size": disk_size,
                "max_size": DISK_CACHE_SIZE,
                "directory": CACHE_DIR
            }
        }
    except Exception as e:
        raise HTTPException(500, f"获取缓存信息失败: {str(e)}")


if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8028)