# speech_tts_LRUandLocal.py — FastAPI TTS service (Kokoro) with in-memory TTL/LRU caching and local disk caching (~17 KB)
  1. from fastapi import FastAPI, HTTPException, Body, Request
  2. from fastapi.responses import StreamingResponse
  3. from contextlib import asynccontextmanager
  4. from kokoro import KPipeline
  5. import soundfile as sf
  6. import io
  7. import base64
  8. import re
  9. import logging
  10. import json
  11. import hashlib
  12. import asyncio
  13. import os
  14. import uuid
  15. from fastapi.middleware.cors import CORSMiddleware
  16. from typing import Dict, List, Optional
  17. from cachetools import TTLCache
  18. import concurrent.futures
  19. import threading
  20. import numpy as np
  21. from pathlib import Path
  22. import torch
  23. import time
  24. import logging
# Silence noisy torch internals; the service logs through its own logger below.
logging.getLogger("torch").setLevel(logging.ERROR)
# torch.backends.nnpack.enabled = False

# --- Configuration (all overridable via environment variables) ---
USE_MP3 = os.getenv("USE_MP3", "false").lower() == "true"  # NOTE(review): read but never used in this file — confirm intent
MEMORY_CACHE_SIZE = int(os.getenv("MEMORY_CACHE_SIZE", 100))  # max entries in the in-memory cache
DISK_CACHE_SIZE = int(os.getenv("DISK_CACHE_SIZE", 200))  # max number of WAV files kept on disk
MAX_CONCURRENT_REQUESTS = int(os.getenv("MAX_CONCURRENT_REQUESTS", 10))  # semaphore width for /generate
LOG_LEVEL = os.getenv("LOG_LEVEL", "WARNING").upper()
CACHE_DIR = os.getenv("CACHE_DIR", "./audio_cache")  # disk cache directory
IDLE_TIMEOUT = 30 * 60  # idle timeout in seconds (30 minutes) before the model is unloaded

# Logging setup
logging.basicConfig(level=getattr(logging, LOG_LEVEL))
logger = logging.getLogger(__name__)

# Ensure the cache directory exists before any request arrives.
Path(CACHE_DIR).mkdir(parents=True, exist_ok=True)

# Sentence-splitting regex: split on whitespace that follows ., !, ?, , or :
SENT_SPLIT_RE = re.compile(r'(?<=[.!?,:])\s+')

# In-memory cache of generated per-sentence payloads (entries expire after 5 hours).
memory_cache = TTLCache(maxsize=MEMORY_CACHE_SIZE, ttl=18000)
cache_lock = threading.Lock()

# Thread pool for blocking I/O (disk reads/writes).
executor = concurrent.futures.ThreadPoolExecutor(max_workers=20)

# --- Global model instance and request-tracking state ---
model_pipeline = None  # lazily (re)loaded Kokoro pipeline; None when unloaded
current_requests: Dict[str, Dict] = {}  # client_id -> per-request flags ("active", "interrupt", ...)
request_semaphore = asyncio.Semaphore(MAX_CONCURRENT_REQUESTS)
last_request_time = time.time()  # timestamp of the most recent request (drives idle unloading)
model_lock = asyncio.Lock()  # serializes model load/unload
  53. async def unload_model():
  54. """释放模型和GPU资源"""
  55. global model_pipeline
  56. async with model_lock:
  57. if model_pipeline is not None:
  58. try:
  59. del model_pipeline
  60. model_pipeline = None
  61. torch.cuda.empty_cache() # 清理GPU缓存
  62. logger.info("模型已卸载,GPU资源已释放")
  63. except Exception as e:
  64. logger.error(f"释放模型失败: {str(e)}")
  65. async def load_model():
  66. """加载模型"""
  67. global model_pipeline
  68. async with model_lock:
  69. if model_pipeline is None:
  70. try:
  71. model_pipeline = KPipeline(lang_code='a')
  72. logger.info("模型加载成功")
  73. except Exception as e:
  74. logger.error(f"模型加载失败: {str(e)}")
  75. raise
  76. async def idle_checker():
  77. """检查空闲时间并释放资源"""
  78. global last_request_time
  79. while True:
  80. await asyncio.sleep(60) # 每分钟检查一次
  81. if time.time() - last_request_time > IDLE_TIMEOUT and model_pipeline is not None:
  82. logger.info("超过20分钟无请求,释放GPU资源")
  83. await unload_model()
  84. @asynccontextmanager
  85. async def lifespan(app: FastAPI):
  86. """管理模型生命周期和空闲检测"""
  87. global model_pipeline, last_request_time
  88. try:
  89. await load_model() # 启动时加载模型
  90. idle_task = asyncio.create_task(idle_checker()) # 启动空闲检测任务
  91. yield
  92. except Exception as e:
  93. logger.error(f"生命周期启动失败: {str(e)}")
  94. raise
  95. finally:
  96. idle_task.cancel() # 停止空闲检测
  97. await unload_model() # 关闭时释放模型
  98. logger.info("应用关闭,资源已清理")
# Application instance; lifespan hook handles model load/unload.
app = FastAPI(lifespan=lifespan)
# Fully-open CORS: any origin, method, and header is accepted.
# NOTE(review): allow_origins=["*"] together with allow_credentials=True is
# rejected by browsers per the CORS spec — confirm whether credentials are needed.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
  107. def split_sentences(text: str) -> List[str]:
  108. """将文本分割成句子"""
  109. return [s.strip() for s in SENT_SPLIT_RE.split(text.strip()) if s.strip()]
  110. def generate_cache_key(text: str, voice: str, speed: float) -> str:
  111. """生成缓存键"""
  112. key_str = f"{text}_{voice}_{speed}"
  113. return hashlib.md5(key_str.encode('utf-8')).hexdigest()
  114. def get_cache_file_path(cache_key: str) -> str:
  115. """获取磁盘缓存文件路径"""
  116. return os.path.join(CACHE_DIR, f"{cache_key}.wav")
  117. async def save_to_memory_cache(cache_key: str, data: List[Dict]):
  118. """保存到内存缓存"""
  119. loop = asyncio.get_running_loop()
  120. try:
  121. with cache_lock:
  122. await loop.run_in_executor(executor, lambda: memory_cache.__setitem__(cache_key, data))
  123. except Exception as e:
  124. logger.error(f"保存到内存缓存失败: {str(e)}")
  125. async def load_from_memory_cache(cache_key: str) -> Optional[List[Dict]]:
  126. """从内存缓存加载"""
  127. loop = asyncio.get_running_loop()
  128. try:
  129. with cache_lock:
  130. return await loop.run_in_executor(executor, lambda: memory_cache.get(cache_key))
  131. except Exception as e:
  132. logger.error(f"从内存缓存加载失败: {str(e)}")
  133. return None
  134. async def save_to_disk_cache(cache_key: str, audio_data: np.ndarray):
  135. """保存音频数据到磁盘缓存"""
  136. loop = asyncio.get_running_loop()
  137. try:
  138. file_path = get_cache_file_path(cache_key)
  139. await loop.run_in_executor(
  140. executor,
  141. lambda: sf.write(file_path, audio_data, 24000, format="WAV")
  142. )
  143. logger.debug(f"音频已保存到磁盘缓存: {file_path}")
  144. except Exception as e:
  145. logger.error(f"保存到磁盘缓存失败: {str(e)}")
  146. async def load_from_disk_cache(cache_key: str) -> Optional[np.ndarray]:
  147. """从磁盘缓存加载音频数据"""
  148. loop = asyncio.get_running_loop()
  149. try:
  150. file_path = get_cache_file_path(cache_key)
  151. if os.path.exists(file_path):
  152. audio_data, sample_rate = await loop.run_in_executor(
  153. executor,
  154. lambda: sf.read(file_path)
  155. )
  156. return audio_data
  157. except Exception as e:
  158. logger.error(f"从磁盘缓存加载失败: {str(e)}")
  159. return None
  160. async def clean_disk_cache():
  161. """清理磁盘缓存"""
  162. loop = asyncio.get_running_loop()
  163. try:
  164. cache_files = sorted(Path(CACHE_DIR).glob("*.wav"), key=os.path.getmtime)
  165. if len(cache_files) > DISK_CACHE_SIZE:
  166. files_to_delete = cache_files[:len(cache_files) - DISK_CACHE_SIZE]
  167. await loop.run_in_executor(
  168. executor,
  169. lambda: [f.unlink() for f in files_to_delete]
  170. )
  171. logger.debug(f"已清理 {len(files_to_delete)} 个磁盘缓存文件")
  172. except Exception as e:
  173. logger.error(f"清理磁盘缓存失败: {str(e)}")
  174. @app.middleware("http")
  175. async def track_clients(request: Request, call_next):
  176. """跟踪客户端连接并更新最后请求时间"""
  177. global last_request_time
  178. client_id = request.query_params.get("client_id") or request.headers.get("X-Client-ID") or str(uuid.uuid4())
  179. if request.url.path == "/generate":
  180. last_request_time = time.time() # 更新最后请求时间
  181. current_requests[client_id] = {"active": True}
  182. logger.debug(f"客户端 {client_id} 已连接")
  183. response = await call_next(request)
  184. if request.url.path == "/generate" and client_id in current_requests:
  185. del current_requests[client_id]
  186. logger.debug(f"客户端 {client_id} 已断开")
  187. return response
  188. async def check_client_active(client_id: str, interval: float = 0.5) -> bool:
  189. """检查客户端是否活跃"""
  190. while True:
  191. await asyncio.sleep(interval)
  192. if client_id not in current_requests:
  193. logger.debug(f"客户端 {client_id} 不再活跃")
  194. return False
  195. if current_requests[client_id].get("interrupt"):
  196. return False
@app.post("/generate")
async def generate_audio_stream(data: Dict = Body(...)):
    """Stream TTS audio as NDJSON, with sentence batching and two cache tiers.

    Request body keys: "text" (required), "voice" (default "af_heart"),
    "speed" (default 1.0), "client_id" (random UUID when absent).
    Responses are newline-delimited JSON objects, one per sentence, each
    carrying a base64-encoded WAV payload.
    Raises HTTP 500 when the model is unavailable and HTTP 400 on empty text.
    """
    global last_request_time
    async with request_semaphore:  # bound concurrent generations
        last_request_time = time.time()  # refresh the idle-unload timer
        text = data.get("text", "")
        voice = data.get("voice", "af_heart")
        speed = float(data.get("speed", 1.0))
        client_id = data.get("client_id", str(uuid.uuid4()))
        # Lazily (re)load the model — it may have been idle-unloaded.
        await load_model()
        if not model_pipeline:
            raise HTTPException(500, "模型未初始化")
        if not text:
            raise HTTPException(400, "文本不能为空")
        # Signal any in-flight request from the same client to stop, and
        # give its poller a moment to notice.
        if client_id in current_requests:
            current_requests[client_id]["interrupt"] = True
            await asyncio.sleep(0.05)
        cache_key = generate_cache_key(text, voice, speed)
        # 1. Memory-cache hit: replay the stored per-sentence payloads.
        cached_data = await load_from_memory_cache(cache_key)
        if cached_data:
            logger.debug(f"从内存缓存流式传输 (key: {cache_key})")
            async def cached_stream():
                # Re-emit each cached sentence as one NDJSON line.
                for idx, item in enumerate(cached_data):
                    yield json.dumps({
                        "index": idx,
                        "sentence": item["sentence"],
                        "audio": item["audio"],
                        "sample_rate": 24000,
                        "format": "wav",
                        "cached": True
                    }, ensure_ascii=False).encode() + b"\n"
            return StreamingResponse(
                cached_stream(),
                media_type="application/x-ndjson",
                headers={
                    "X-Content-Type-Options": "nosniff",
                    "Cache-Control": "no-store",
                    "X-Client-ID": client_id
                }
            )
        # 2. Disk-cache hit: one combined WAV for the whole text; re-encode
        #    it as a single NDJSON line and promote it into the memory cache.
        cached_audio = await load_from_disk_cache(cache_key)
        if cached_audio is not None:
            logger.debug(f"从磁盘缓存流式传输 (key: {cache_key})")
            with io.BytesIO() as buffer:
                sf.write(buffer, cached_audio, 24000, format="WAV")
                audio_bytes = buffer.getvalue()
            audio_b64 = base64.b64encode(audio_bytes).decode("utf-8")
            cached_data = [{
                "index": 0,
                "sentence": text,
                "audio": audio_b64,
                "sample_rate": 24000,
                "format": "wav"
            }]
            await save_to_memory_cache(cache_key, cached_data)
            return StreamingResponse(
                iter([json.dumps({
                    "index": 0,
                    "sentence": text,
                    "audio": audio_b64,
                    "sample_rate": 24000,
                    "format": "wav",
                    "cached": True
                }, ensure_ascii=False).encode() + b"\n"]),
                media_type="application/x-ndjson",
                headers={
                    "X-Content-Type-Options": "nosniff",
                    "Cache-Control": "no-store",
                    "X-Client-ID": client_id
                }
            )
        # 3. Cache miss: synthesize sentence by sentence.
        sentences = split_sentences(text)
        logger.debug(f"处理 {len(sentences)} 个句子")
        current_requests[client_id] = {"interrupt": False, "sentences": sentences, "active": True}
        async def audio_stream():
            memory_cache_items = []  # per-sentence payloads for the memory cache
            all_audio = []           # raw chunks, concatenated for the disk cache
            client_check_task = asyncio.create_task(check_client_active(client_id))
            try:
                batch_size = 4
                for i in range(0, len(sentences), batch_size):
                    batch = sentences[i:i + batch_size]
                    # Stop if the client vanished or requested interruption.
                    # NOTE(review): direct indexing can raise KeyError if the
                    # middleware deleted this entry once call_next returned,
                    # before streaming finished — confirm against Starlette's
                    # middleware/streaming ordering.
                    if client_check_task.done() or current_requests[client_id].get("interrupt"):
                        logger.debug(f"请求中断或客户端 {client_id} 已断开")
                        break
                    try:
                        generators = [model_pipeline(s, voice=voice, speed=speed, split_pattern=None) for s in batch]
                        for idx, (sentence, generator) in enumerate(zip(batch, generators)):
                            global_idx = i + idx
                            # Only the FIRST item yielded by each generator is
                            # consumed (both branches end in `break`) —
                            # presumably KPipeline yields a single segment when
                            # split_pattern=None; TODO confirm.
                            for _, _, audio in generator:
                                if audio is None or len(audio) == 0:
                                    logger.warning(f"句子 {global_idx} 生成了空音频: {sentence}")
                                    yield json.dumps({
                                        "index": global_idx,
                                        "sentence": sentence,
                                        "audio": "",
                                        "sample_rate": 24000,
                                        "format": "wav",
                                        "warning": "没有生成音频"
                                    }, ensure_ascii=False).encode() + b"\n"
                                    break
                                all_audio.append(audio)
                                with io.BytesIO() as buffer:
                                    sf.write(buffer, audio, 24000, format="WAV")
                                    audio_bytes = buffer.getvalue()
                                audio_b64 = base64.b64encode(audio_bytes).decode("utf-8")
                                memory_cache_items.append({
                                    "index": global_idx,
                                    "sentence": sentence,
                                    "audio": audio_b64,
                                    "sample_rate": 24000,
                                    "format": "wav"
                                })
                                yield json.dumps({
                                    "index": global_idx,
                                    "sentence": sentence,
                                    "audio": audio_b64,
                                    "sample_rate": 24000,
                                    "format": "wav"
                                }, ensure_ascii=False).encode() + b"\n"
                                break
                    except Exception as e:
                        # A failed batch is reported in-stream, then the loop continues.
                        logger.error(f"处理批次 {i} 时出错: {str(e)}")
                        yield json.dumps({
                            "error": f"批次处理失败于索引 {i}",
                            "message": str(e)
                        }, ensure_ascii=False).encode() + b"\n"
                # Persist whatever was produced (possibly a partial result if
                # the loop was interrupted above).
                if memory_cache_items:
                    await save_to_memory_cache(cache_key, memory_cache_items)
                if all_audio:
                    combined_audio = np.concatenate(all_audio)
                    await save_to_disk_cache(cache_key, combined_audio)
                    logger.debug(f"已保存合并音频到磁盘缓存 (key: {cache_key})")
                    await clean_disk_cache()
            except asyncio.CancelledError:
                logger.debug(f"客户端 {client_id} 的请求已取消")
            except Exception as e:
                logger.error(f"流错误: {str(e)}")
            finally:
                # Always deregister the client and stop the liveness poller.
                if client_id in current_requests:
                    del current_requests[client_id]
                if not client_check_task.done():
                    client_check_task.cancel()
                logger.debug(f"已清理客户端 {client_id}")
        return StreamingResponse(
            audio_stream(),
            media_type="application/x-ndjson",
            headers={
                "X-Content-Type-Options": "nosniff",
                "Cache-Control": "no-store",
                "X-Client-ID": client_id
            }
        )
  356. @app.get("/clear-cache")
  357. async def clear_cache():
  358. """清除所有缓存"""
  359. try:
  360. with cache_lock:
  361. memory_cache.clear()
  362. loop = asyncio.get_running_loop()
  363. await loop.run_in_executor(
  364. executor,
  365. lambda: [f.unlink() for f in Path(CACHE_DIR).glob("*.wav")]
  366. )
  367. return {"status": "success", "message": "所有缓存已清除"}
  368. except Exception as e:
  369. raise HTTPException(500, f"清除缓存失败: {str(e)}")
  370. @app.get("/cache-info")
  371. async def get_cache_info():
  372. """获取缓存信息"""
  373. try:
  374. with cache_lock:
  375. memory_count = len(memory_cache)
  376. memory_size = sum(len(json.dumps(item)) for item in memory_cache.values())
  377. disk_files = list(Path(CACHE_DIR).glob("*.wav"))
  378. disk_count = len(disk_files)
  379. disk_size = sum(f.stat().st_size for f in disk_files)
  380. return {
  381. "memory_cache": {
  382. "count": memory_count,
  383. "size": memory_size,
  384. "max_size": MEMORY_CACHE_SIZE
  385. },
  386. "disk_cache": {
  387. "count": disk_count,
  388. "size": disk_size,
  389. "max_size": DISK_CACHE_SIZE,
  390. "directory": CACHE_DIR
  391. }
  392. }
  393. except Exception as e:
  394. raise HTTPException(500, f"获取缓存信息失败: {str(e)}")
if __name__ == "__main__":
    # Run the service directly (development entry point): all interfaces, port 8028.
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8028)