|
@@ -0,0 +1,448 @@
|
|
|
|
|
+from fastapi import FastAPI, HTTPException, Body, Request
|
|
|
|
|
+from fastapi.responses import StreamingResponse
|
|
|
|
|
+from contextlib import asynccontextmanager
|
|
|
|
|
+from kokoro import KPipeline
|
|
|
|
|
+import soundfile as sf
|
|
|
|
|
+import io
|
|
|
|
|
+import base64
|
|
|
|
|
+import re
|
|
|
|
|
+import logging
|
|
|
|
|
+import json
|
|
|
|
|
+import hashlib
|
|
|
|
|
+import asyncio
|
|
|
|
|
+import os
|
|
|
|
|
+import uuid
|
|
|
|
|
+from fastapi.middleware.cors import CORSMiddleware
|
|
|
|
|
+from typing import Dict, List, Optional
|
|
|
|
|
+from cachetools import TTLCache
|
|
|
|
|
+import concurrent.futures
|
|
|
|
|
+import threading
|
|
|
|
|
+import numpy as np
|
|
|
|
|
+from pathlib import Path
|
|
|
|
|
+import torch
|
|
|
|
|
+import time
|
|
|
|
|
+
|
|
|
|
|
+import logging
|
|
|
|
|
logging.getLogger("torch").setLevel(logging.ERROR)

# torch.backends.nnpack.enabled = False

# Configuration (all overridable via environment variables)
USE_MP3 = os.getenv("USE_MP3", "false").lower() == "true"
MEMORY_CACHE_SIZE = int(os.getenv("MEMORY_CACHE_SIZE", 100))  # max entries in the in-memory cache
DISK_CACHE_SIZE = int(os.getenv("DISK_CACHE_SIZE", 200))  # max number of files in the disk cache
MAX_CONCURRENT_REQUESTS = int(os.getenv("MAX_CONCURRENT_REQUESTS", 10))
LOG_LEVEL = os.getenv("LOG_LEVEL", "WARNING").upper()
CACHE_DIR = os.getenv("CACHE_DIR", "./audio_cache")  # disk cache directory
IDLE_TIMEOUT = 30 * 60  # idle timeout in seconds (30 minutes; some log messages say 20 -- NOTE(review): reconcile)

# Logging setup
logging.basicConfig(level=getattr(logging, LOG_LEVEL))
logger = logging.getLogger(__name__)

# Ensure the cache directory exists before any disk-cache access
Path(CACHE_DIR).mkdir(parents=True, exist_ok=True)

# Sentence-splitting regex: split after '.', '!', '?', ',' or ':' followed by whitespace
SENT_SPLIT_RE = re.compile(r'(?<=[.!?,:])\s+')

# In-memory cache of generated per-sentence payloads (entries expire after 18000 s = 5 h)
memory_cache = TTLCache(maxsize=MEMORY_CACHE_SIZE, ttl=18000)
cache_lock = threading.Lock()

# Thread pool for blocking (file) I/O offloaded from the event loop
executor = concurrent.futures.ThreadPoolExecutor(max_workers=20)

# Global model instance and request-tracking state
model_pipeline = None  # lazily (re)loaded KPipeline; None while unloaded
current_requests: Dict[str, Dict] = {}  # client_id -> {"active": bool, "interrupt": bool, ...}
request_semaphore = asyncio.Semaphore(MAX_CONCURRENT_REQUESTS)
last_request_time = time.time()  # timestamp of the most recent /generate request
model_lock = asyncio.Lock()  # serializes model load/unload
|
|
|
|
|
+
|
|
|
|
|
async def unload_model():
    """Release the TTS model and free GPU memory.

    Safe to call when no model is loaded; synchronized via ``model_lock``
    so it cannot race with ``load_model()``.
    """
    global model_pipeline
    async with model_lock:
        if model_pipeline is not None:
            try:
                del model_pipeline
                model_pipeline = None
                # Only touch the CUDA allocator when CUDA is actually available;
                # on CPU-only builds empty_cache() is pointless.
                if torch.cuda.is_available():
                    torch.cuda.empty_cache()
                logger.info("模型已卸载,GPU资源已释放")
            except Exception as e:
                logger.error(f"释放模型失败: {str(e)}")
|
|
|
|
|
+
|
|
|
|
|
async def load_model():
    """Instantiate the Kokoro pipeline on demand; no-op when already loaded."""
    global model_pipeline
    async with model_lock:
        if model_pipeline is not None:
            return
        try:
            model_pipeline = KPipeline(lang_code='a')
        except Exception as e:
            logger.error(f"模型加载失败: {str(e)}")
            raise
        logger.info("模型加载成功")
|
|
|
|
|
+
|
|
|
|
|
async def idle_checker():
    """Background watchdog: unload the model after IDLE_TIMEOUT seconds of inactivity.

    Polls once per minute; runs until cancelled during application shutdown.
    """
    global last_request_time
    while True:
        await asyncio.sleep(60)  # poll interval
        if time.time() - last_request_time > IDLE_TIMEOUT and model_pipeline is not None:
            # Derive the logged duration from the configured timeout instead of
            # hard-coding "20 minutes" (IDLE_TIMEOUT is actually 30 minutes).
            logger.info(f"超过{IDLE_TIMEOUT // 60}分钟无请求,释放GPU资源")
            await unload_model()
|
|
|
|
|
+
|
|
|
|
|
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Manage the model lifecycle: load at startup, watch for idleness, unload at shutdown."""
    global model_pipeline, last_request_time
    # Bind up front so the finally block cannot hit a NameError (and mask the
    # real startup failure) when load_model() raises before the task exists.
    idle_task = None
    try:
        await load_model()  # load the model at startup
        idle_task = asyncio.create_task(idle_checker())  # start the idle watchdog
        yield
    except Exception as e:
        logger.error(f"生命周期启动失败: {str(e)}")
        raise
    finally:
        if idle_task is not None:
            idle_task.cancel()  # stop the idle watchdog
        await unload_model()  # release model resources on shutdown
        logger.info("应用关闭,资源已清理")
|
|
|
|
|
+
|
|
|
|
|
app = FastAPI(lifespan=lifespan)

# Allow cross-origin requests from any origin (browser clients consume the
# NDJSON stream directly).
# NOTE(review): browsers reject allow_origins=["*"] combined with
# allow_credentials=True for credentialed requests -- confirm whether
# credentials are actually needed here.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
|
|
|
|
|
+
|
|
|
|
|
def split_sentences(text: str) -> List[str]:
    """Split *text* into sentence fragments after '.', '!', '?', ',' or ':', dropping empty pieces."""
    result = []
    for fragment in re.split(r'(?<=[.!?,:])\s+', text.strip()):
        fragment = fragment.strip()
        if fragment:
            result.append(fragment)
    return result
|
|
|
|
|
+
|
|
|
|
|
def generate_cache_key(text: str, voice: str, speed: float) -> str:
    """Return a deterministic MD5 hex digest identifying a (text, voice, speed) request."""
    raw = "_".join([text, voice, str(speed)])
    digest = hashlib.md5(raw.encode('utf-8'))
    return digest.hexdigest()
|
|
|
|
|
+
|
|
|
|
|
def get_cache_file_path(cache_key: str) -> str:
    """Map a cache key to its WAV file path inside CACHE_DIR."""
    filename = cache_key + ".wav"
    return os.path.join(CACHE_DIR, filename)
|
|
|
|
|
+
|
|
|
|
|
async def save_to_memory_cache(cache_key: str, data: List[Dict]):
    """Store generated per-sentence payloads in the in-process TTL cache.

    The previous implementation held the threading ``cache_lock`` across an
    ``await`` into the thread pool, which blocks every other coroutine for the
    duration and can deadlock against threads waiting on the same lock -- all
    for a plain dict assignment. TTLCache writes are fast in-memory operations,
    so perform the write inline under the lock.
    """
    try:
        with cache_lock:
            memory_cache[cache_key] = data
    except Exception as e:
        logger.error(f"保存到内存缓存失败: {str(e)}")
|
|
|
|
|
+
|
|
|
|
|
async def load_from_memory_cache(cache_key: str) -> Optional[List[Dict]]:
    """Return cached per-sentence payloads for *cache_key*, or None on a miss or error.

    Like ``save_to_memory_cache``, the original awaited into a thread pool
    while holding the threading ``cache_lock``, blocking the event loop for a
    simple dict lookup; the lookup is done inline under the lock instead.
    """
    try:
        with cache_lock:
            return memory_cache.get(cache_key)
    except Exception as e:
        logger.error(f"从内存缓存加载失败: {str(e)}")
        return None
|
|
|
|
|
+
|
|
|
|
|
async def save_to_disk_cache(cache_key: str, audio_data: np.ndarray):
    """Persist audio for *cache_key* as a 24 kHz WAV file without blocking the event loop."""
    loop = asyncio.get_running_loop()
    try:
        file_path = get_cache_file_path(cache_key)

        def _write():
            sf.write(file_path, audio_data, 24000, format="WAV")

        await loop.run_in_executor(executor, _write)
        logger.debug(f"音频已保存到磁盘缓存: {file_path}")
    except Exception as e:
        logger.error(f"保存到磁盘缓存失败: {str(e)}")
|
|
|
|
|
+
|
|
|
|
|
async def load_from_disk_cache(cache_key: str) -> Optional[np.ndarray]:
    """Read cached audio for *cache_key* from disk; returns None when absent or on error."""
    loop = asyncio.get_running_loop()
    try:
        file_path = get_cache_file_path(cache_key)
        if not os.path.exists(file_path):
            return None
        # Offload the blocking soundfile read; the sample rate is discarded
        # (files are always written at 24000 Hz).
        audio_data, _sample_rate = await loop.run_in_executor(
            executor, lambda: sf.read(file_path)
        )
        return audio_data
    except Exception as e:
        logger.error(f"从磁盘缓存加载失败: {str(e)}")
        return None
|
|
|
|
|
+
|
|
|
|
|
async def clean_disk_cache():
    """Evict the oldest WAV files so the disk cache stays within DISK_CACHE_SIZE entries.

    Eviction runs in the thread pool. In the original, one failed unlink
    (e.g. a file removed concurrently by another request or by /clear-cache)
    aborted the entire eviction batch; deletions are now attempted per file
    and a vanished file is tolerated.
    """
    loop = asyncio.get_running_loop()

    def _evict() -> int:
        # Oldest (by mtime) first, so the least recently written files go.
        cache_files = sorted(Path(CACHE_DIR).glob("*.wav"), key=os.path.getmtime)
        excess = len(cache_files) - DISK_CACHE_SIZE
        deleted = 0
        for f in cache_files[:max(excess, 0)]:
            try:
                f.unlink()
                deleted += 1
            except FileNotFoundError:
                # Already gone -- concurrent eviction/clearing is fine.
                pass
        return deleted

    try:
        deleted = await loop.run_in_executor(executor, _evict)
        if deleted:
            logger.debug(f"已清理 {deleted} 个磁盘缓存文件")
    except Exception as e:
        logger.error(f"清理磁盘缓存失败: {str(e)}")
|
|
|
|
|
+
|
|
|
|
|
@app.middleware("http")
async def track_clients(request: Request, call_next):
    """HTTP middleware: register /generate clients and refresh the idle timer."""
    global last_request_time
    client_id = (
        request.query_params.get("client_id")
        or request.headers.get("X-Client-ID")
        or str(uuid.uuid4())
    )
    is_generate = request.url.path == "/generate"
    if is_generate:
        last_request_time = time.time()  # keep the idle watchdog from unloading the model
        current_requests[client_id] = {"active": True}
        logger.debug(f"客户端 {client_id} 已连接")

    response = await call_next(request)

    # Deregister once the response is produced.
    if is_generate and client_id in current_requests:
        del current_requests[client_id]
        logger.debug(f"客户端 {client_id} 已断开")

    return response
|
|
|
|
|
+
|
|
|
|
|
async def check_client_active(client_id: str, interval: float = 0.5) -> bool:
    """Poll until *client_id* disappears or is flagged interrupted, then return False.

    Used as a watchdog task: its completion (always returning False) signals
    the generation loop to stop streaming for this client.
    """
    while True:
        await asyncio.sleep(interval)
        entry = current_requests.get(client_id)
        if entry is None:
            logger.debug(f"客户端 {client_id} 不再活跃")
            return False
        if entry.get("interrupt"):
            return False
|
|
|
|
|
+
|
|
|
|
|
@app.post("/generate")
async def generate_audio_stream(data: Dict = Body(...)):
    """Streaming TTS generation with batching and two-level caching.

    Request body keys: "text" (required), "voice" (default "af_heart"),
    "speed" (default 1.0), "client_id" (optional; generated when absent).
    Responds with NDJSON: one JSON object per sentence carrying base64 WAV.
    Lookup order: memory cache (per-sentence) -> disk cache (combined WAV)
    -> fresh generation via the kokoro pipeline.
    """
    global last_request_time
    # Bound the number of concurrently served generations.
    async with request_semaphore:
        last_request_time = time.time()  # refresh the idle-unload timer
        text = data.get("text", "")
        voice = data.get("voice", "af_heart")
        speed = float(data.get("speed", 1.0))
        client_id = data.get("client_id", str(uuid.uuid4()))

        # Ensure the model is loaded (it may have been idle-unloaded).
        await load_model()
        if not model_pipeline:
            raise HTTPException(500, "模型未初始化")
        if not text:
            raise HTTPException(400, "文本不能为空")

        # Interrupt a previous in-flight request from the same client, then
        # give its stream a brief moment to notice the flag.
        if client_id in current_requests:
            current_requests[client_id]["interrupt"] = True
            await asyncio.sleep(0.05)

        cache_key = generate_cache_key(text, voice, speed)

        # 1. Memory cache hit: replay the stored per-sentence payloads verbatim.
        cached_data = await load_from_memory_cache(cache_key)
        if cached_data:
            logger.debug(f"从内存缓存流式传输 (key: {cache_key})")
            async def cached_stream():
                # One NDJSON line per cached sentence payload.
                for idx, item in enumerate(cached_data):
                    yield json.dumps({
                        "index": idx,
                        "sentence": item["sentence"],
                        "audio": item["audio"],
                        "sample_rate": 24000,
                        "format": "wav",
                        "cached": True
                    }, ensure_ascii=False).encode() + b"\n"

            return StreamingResponse(
                cached_stream(),
                media_type="application/x-ndjson",
                headers={
                    "X-Content-Type-Options": "nosniff",
                    "Cache-Control": "no-store",
                    "X-Client-ID": client_id
                }
            )

        # 2. Disk cache hit: the combined WAV is returned as a single NDJSON
        #    record (sentence boundaries are not preserved on disk) and
        #    promoted back into the memory cache.
        cached_audio = await load_from_disk_cache(cache_key)
        if cached_audio is not None:
            logger.debug(f"从磁盘缓存流式传输 (key: {cache_key})")
            with io.BytesIO() as buffer:
                sf.write(buffer, cached_audio, 24000, format="WAV")
                audio_bytes = buffer.getvalue()

            audio_b64 = base64.b64encode(audio_bytes).decode("utf-8")
            cached_data = [{
                "index": 0,
                "sentence": text,
                "audio": audio_b64,
                "sample_rate": 24000,
                "format": "wav"
            }]
            await save_to_memory_cache(cache_key, cached_data)

            return StreamingResponse(
                iter([json.dumps({
                    "index": 0,
                    "sentence": text,
                    "audio": audio_b64,
                    "sample_rate": 24000,
                    "format": "wav",
                    "cached": True
                }, ensure_ascii=False).encode() + b"\n"]),
                media_type="application/x-ndjson",
                headers={
                    "X-Content-Type-Options": "nosniff",
                    "Cache-Control": "no-store",
                    "X-Client-ID": client_id
                }
            )

        # 3. No cache hit: generate sentence by sentence.
        sentences = split_sentences(text)
        logger.debug(f"处理 {len(sentences)} 个句子")
        current_requests[client_id] = {"interrupt": False, "sentences": sentences, "active": True}

        async def audio_stream():
            memory_cache_items = []  # per-sentence payloads for the memory cache
            all_audio = []           # raw audio arrays, concatenated for the disk cache
            client_check_task = asyncio.create_task(check_client_active(client_id))

            try:
                batch_size = 4
                for i in range(0, len(sentences), batch_size):
                    batch = sentences[i:i + batch_size]
                    # Stop when the watchdog finished (client gone) or an
                    # interrupt was requested by a newer request.
                    # NOTE(review): current_requests[client_id] can be removed
                    # by the tracking middleware while this generator is still
                    # running, which would raise KeyError here -- confirm and
                    # consider .get() access.
                    if client_check_task.done() or current_requests[client_id].get("interrupt"):
                        logger.debug(f"请求中断或客户端 {client_id} 已断开")
                        break

                    try:
                        # One pipeline generator per sentence; split_pattern=None
                        # so the pipeline does not re-split our sentences.
                        generators = [model_pipeline(s, voice=voice, speed=speed, split_pattern=None) for s in batch]
                        for idx, (sentence, generator) in enumerate(zip(batch, generators)):
                            global_idx = i + idx
                            # The pipeline yields (graphemes, phonemes, audio)
                            # tuples; only the audio is used here.
                            for _, _, audio in generator:
                                if audio is None or len(audio) == 0:
                                    logger.warning(f"句子 {global_idx} 生成了空音频: {sentence}")
                                    yield json.dumps({
                                        "index": global_idx,
                                        "sentence": sentence,
                                        "audio": "",
                                        "sample_rate": 24000,
                                        "format": "wav",
                                        "warning": "没有生成音频"
                                    }, ensure_ascii=False).encode() + b"\n"
                                    break

                                all_audio.append(audio)
                                # Encode this chunk as an in-memory WAV.
                                with io.BytesIO() as buffer:
                                    sf.write(buffer, audio, 24000, format="WAV")
                                    audio_bytes = buffer.getvalue()

                                audio_b64 = base64.b64encode(audio_bytes).decode("utf-8")
                                memory_cache_items.append({
                                    "index": global_idx,
                                    "sentence": sentence,
                                    "audio": audio_b64,
                                    "sample_rate": 24000,
                                    "format": "wav"
                                })

                                yield json.dumps({
                                    "index": global_idx,
                                    "sentence": sentence,
                                    "audio": audio_b64,
                                    "sample_rate": 24000,
                                    "format": "wav"
                                }, ensure_ascii=False).encode() + b"\n"
                                # Only the first yielded chunk per sentence is
                                # consumed; with split_pattern=None the pipeline
                                # presumably yields exactly one chunk --
                                # NOTE(review): confirm, otherwise audio is dropped.
                                break
                    except Exception as e:
                        logger.error(f"处理批次 {i} 时出错: {str(e)}")
                        # Report the failed batch to the client and continue
                        # with the next batch.
                        yield json.dumps({
                            "error": f"批次处理失败于索引 {i}",
                            "message": str(e)
                        }, ensure_ascii=False).encode() + b"\n"

                # Persist whatever was produced (even on early interrupt, the
                # partial result is cached under the full-text key --
                # NOTE(review): an interrupted run can poison the cache with a
                # truncated result; confirm this is intended).
                if memory_cache_items:
                    await save_to_memory_cache(cache_key, memory_cache_items)
                if all_audio:
                    combined_audio = np.concatenate(all_audio)
                    await save_to_disk_cache(cache_key, combined_audio)
                    logger.debug(f"已保存合并音频到磁盘缓存 (key: {cache_key})")
                    await clean_disk_cache()

            except asyncio.CancelledError:
                logger.debug(f"客户端 {client_id} 的请求已取消")
            except Exception as e:
                logger.error(f"流错误: {str(e)}")
            finally:
                # Deregister the client and stop the watchdog task.
                if client_id in current_requests:
                    del current_requests[client_id]
                if not client_check_task.done():
                    client_check_task.cancel()
                logger.debug(f"已清理客户端 {client_id}")

        return StreamingResponse(
            audio_stream(),
            media_type="application/x-ndjson",
            headers={
                "X-Content-Type-Options": "nosniff",
                "Cache-Control": "no-store",
                "X-Client-ID": client_id
            }
        )
|
|
|
|
|
+
|
|
|
|
|
@app.get("/clear-cache")
async def clear_cache():
    """Drop every entry from both the memory cache and the disk cache."""
    try:
        with cache_lock:
            memory_cache.clear()

        def _purge_disk():
            # Remove every cached WAV file.
            for wav in Path(CACHE_DIR).glob("*.wav"):
                wav.unlink()

        loop = asyncio.get_running_loop()
        await loop.run_in_executor(executor, _purge_disk)
        return {"status": "success", "message": "所有缓存已清除"}
    except Exception as e:
        raise HTTPException(500, f"清除缓存失败: {str(e)}")
|
|
|
|
|
+
|
|
|
|
|
@app.get("/cache-info")
async def get_cache_info():
    """Report entry counts and byte sizes for the memory and disk caches."""
    try:
        # Snapshot both caches under the lock; disk stats are gathered inside
        # the same critical section, matching the original behavior.
        with cache_lock:
            memory_count = len(memory_cache)
            memory_size = sum(len(json.dumps(item)) for item in memory_cache.values())
            disk_files = list(Path(CACHE_DIR).glob("*.wav"))
            disk_count = len(disk_files)
            disk_size = sum(f.stat().st_size for f in disk_files)

        memory_info = {
            "count": memory_count,
            "size": memory_size,
            "max_size": MEMORY_CACHE_SIZE
        }
        disk_info = {
            "count": disk_count,
            "size": disk_size,
            "max_size": DISK_CACHE_SIZE,
            "directory": CACHE_DIR
        }
        return {"memory_cache": memory_info, "disk_cache": disk_info}
    except Exception as e:
        raise HTTPException(500, f"获取缓存信息失败: {str(e)}")
|
|
|
|
|
+
|
|
|
|
|
if __name__ == "__main__":
    import uvicorn
    # Host stays hard-coded (bind all interfaces); the port can now be
    # overridden via the PORT environment variable, defaulting to the
    # original 8028.
    uvicorn.run(app, host="0.0.0.0", port=int(os.getenv("PORT", "8028")))
|