speech_tts_cpu.py 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339
  1. from fastapi import FastAPI, HTTPException, Body, Request
  2. from fastapi.responses import StreamingResponse
  3. from contextlib import asynccontextmanager
  4. from kokoro import KPipeline
  5. import soundfile as sf
  6. import io
  7. import base64
  8. import re
  9. import logging
  10. import json
  11. import hashlib
  12. import asyncio
  13. import os
  14. import uuid
  15. import aiofiles
  16. from fastapi.middleware.cors import CORSMiddleware
  17. from typing import Dict, List, Optional
  18. from cachetools import TTLCache
  19. import concurrent.futures
  20. import threading
  21. import numpy as np
  22. from pathlib import Path
  23. import sys
  24. os.environ["USE_NNPACK"] = "0"
  25. os.environ["NNPACK_DISABLE"] = "1"
  26. os.environ["TORCH_CPP_LOG_LEVEL"] = "ERROR"
  27. import torch
  28. torch.backends.nnpack.enabled = False
  29. import warnings
  30. warnings.filterwarnings("ignore", message=".*NNPACK.*")
  31. # ------------------- 配置 -------------------
  32. USE_MP3 = os.getenv("USE_MP3", "false").lower() == "true"
  33. MEMORY_CACHE_SIZE = int(os.getenv("MEMORY_CACHE_SIZE", 200))
  34. DISK_CACHE_SIZE = int(os.getenv("DISK_CACHE_SIZE", 500))
  35. MAX_CONCURRENT_REQUESTS = int(os.getenv("MAX_CONCURRENT_REQUESTS", 12))
  36. LOG_LEVEL = os.getenv("LOG_LEVEL", "INFO").upper()
  37. CACHE_DIR = os.getenv("CACHE_DIR", "./audio_cache")
  38. logging.basicConfig(level=getattr(logging, LOG_LEVEL))
  39. logger = logging.getLogger(__name__)
  40. Path(CACHE_DIR).mkdir(parents=True, exist_ok=True)
  41. SENT_SPLIT_RE = re.compile(r'(?<=[.!?,:])\s+')
  42. # 内存缓存:单句缓存结构 {cache_key → { "audio": b64, "sr": 24000 }}
  43. memory_cache = TTLCache(maxsize=MEMORY_CACHE_SIZE, ttl=18000)
  44. cache_lock = threading.Lock()
  45. executor = concurrent.futures.ThreadPoolExecutor(max_workers=16)
  46. # 模型(常驻)
  47. model_pipeline = None
  48. model_lock = asyncio.Lock()
  49. # 并发管理
  50. request_semaphore = asyncio.Semaphore(MAX_CONCURRENT_REQUESTS)
  51. current_requests: Dict[str, Dict] = {}
  52. # ============================================================
  53. # 模型加载
  54. # ============================================================
  55. async def load_model():
  56. global model_pipeline
  57. async with model_lock:
  58. if model_pipeline is None:
  59. try:
  60. model_pipeline = KPipeline(lang_code='a', device="cpu")
  61. logger.info("模型加载完成 (CPU)")
  62. except Exception as e:
  63. logger.error(f"模型加载失败: {e}")
  64. raise
  65. @asynccontextmanager
  66. async def lifespan(app: FastAPI):
  67. await load_model()
  68. yield
  69. logger.info("应用关闭中...")
  70. app = FastAPI(lifespan=lifespan)
  71. # CORS 允许所有来源
  72. app.add_middleware(
  73. CORSMiddleware,
  74. allow_origins=["*"],
  75. allow_credentials=True,
  76. allow_methods=["*"],
  77. allow_headers=["*"],
  78. )
  79. # ============================================================
  80. # 工具函数
  81. # ============================================================
  82. def split_sentences(text: str) -> List[str]:
  83. return [s.strip() for s in SENT_SPLIT_RE.split(text.strip()) if s.strip()]
  84. def sentence_cache_key(sentence: str, voice: str, speed: float) -> str:
  85. raw = f"{sentence}|{voice}|{speed}"
  86. return hashlib.md5(raw.encode()).hexdigest()
  87. def sentence_cache_path(key: str):
  88. return os.path.join(CACHE_DIR, f"{key}.wav")
  89. def disk_meta_path(key: str):
  90. return os.path.join(CACHE_DIR, f"{key}.json")
  91. def put_memory_cache(key: str, value: dict):
  92. with cache_lock:
  93. memory_cache[key] = value
  94. def get_memory_cache(key: str) -> Optional[dict]:
  95. with cache_lock:
  96. return memory_cache.get(key)
  97. async def load_sentence_from_disk(key: str) -> Optional[dict]:
  98. wav_path = sentence_cache_path(key)
  99. meta_path = disk_meta_path(key)
  100. if not Path(wav_path).exists() or not Path(meta_path).exists():
  101. return None
  102. async with aiofiles.open(meta_path, "r") as f:
  103. meta_text = await f.read()
  104. meta = json.loads(meta_text)
  105. audio, sr = await asyncio.get_event_loop().run_in_executor(
  106. executor, lambda: sf.read(wav_path)
  107. )
  108. with io.BytesIO() as buf:
  109. # sf.write(buf, audio, sr, "WAV")
  110. with sf.SoundFile(buf, 'w', sr, channels=audio.shape[1] if audio.ndim > 1 else 1,
  111. format='WAV', subtype='PCM_16') as f:
  112. f.write(audio)
  113. audio_b64 = base64.b64encode(buf.getvalue()).decode()
  114. meta["audio"] = audio_b64
  115. return meta
  116. async def save_sentence_to_disk(key: str, audio: np.ndarray, sr: int, sentence: str):
  117. wav_path = sentence_cache_path(key)
  118. meta_path = disk_meta_path(key)
  119. await asyncio.get_event_loop().run_in_executor(
  120. executor, lambda: sf.write(wav_path, audio, sr, format="WAV")
  121. )
  122. meta = {
  123. "sentence": sentence,
  124. "sample_rate": sr,
  125. }
  126. async with aiofiles.open(meta_path, "w") as f:
  127. await f.write(json.dumps(meta))
  128. return meta
  129. async def clean_disk_cache():
  130. files = sorted(Path(CACHE_DIR).glob("*.wav"), key=os.path.getmtime)
  131. if len(files) <= DISK_CACHE_SIZE:
  132. return
  133. remove_n = len(files) - DISK_CACHE_SIZE
  134. for f in files[:remove_n]:
  135. try:
  136. json_path = disk_meta_path(f.stem)
  137. f.unlink()
  138. if Path(json_path).exists():
  139. Path(json_path).unlink()
  140. except:
  141. pass
  142. async def check_client_alive(client_id):
  143. while True:
  144. await asyncio.sleep(0.3)
  145. if client_id not in current_requests or current_requests[client_id].get("interrupt"):
  146. return False
  147. # ============================================================
  148. # 请求跟踪
  149. # ============================================================
  150. @app.middleware("http")
  151. async def track_clients(request: Request, call_next):
  152. client_id = (
  153. request.query_params.get("client_id")
  154. or request.headers.get("X-Client-ID")
  155. or str(uuid.uuid4())
  156. )
  157. if request.url.path == "/generate":
  158. current_requests[client_id] = {"active": True}
  159. response = await call_next(request)
  160. if request.url.path == "/generate":
  161. current_requests.pop(client_id, None)
  162. return response
  163. # ============================================================
  164. # 主接口
  165. # ============================================================
  166. @app.post("/generate")
  167. async def generate_audio_stream(data: Dict = Body(...)):
  168. async with request_semaphore:
  169. text = data.get("text", "")
  170. voice = data.get("voice", "af_heart")
  171. speed = float(data.get("speed", 1.0))
  172. client_id = data.get("client_id", str(uuid.uuid4()))
  173. if not text:
  174. raise HTTPException(400, "文本不能为空")
  175. await load_model()
  176. # 取消旧请求
  177. if client_id in current_requests:
  178. current_requests[client_id]["interrupt"] = True
  179. await asyncio.sleep(0.05)
  180. sentences = split_sentences(text)
  181. current_requests[client_id] = {"interrupt": False}
  182. async def stream():
  183. client_alive_task = asyncio.create_task(check_client_alive(client_id))
  184. try:
  185. for idx, sentence in enumerate(sentences):
  186. if client_alive_task.done():
  187. break
  188. key = sentence_cache_key(sentence, voice, speed)
  189. # 1) 内存缓存
  190. cache_item = get_memory_cache(key)
  191. if cache_item:
  192. yield json.dumps({
  193. "index": idx,
  194. **cache_item
  195. }).encode() + b"\n"
  196. continue
  197. # 2) 磁盘缓存
  198. disk_item = await load_sentence_from_disk(key)
  199. if disk_item:
  200. put_memory_cache(key, disk_item)
  201. yield json.dumps({
  202. "index": idx,
  203. **disk_item
  204. }).encode() + b"\n"
  205. continue
  206. # 3) 无缓存 → 推理
  207. generator = model_pipeline(sentence, voice=voice, speed=speed)
  208. # 仅取第一个 chunk
  209. for _, _, audio in generator:
  210. sr = 24000
  211. with io.BytesIO() as buf:
  212. # sf.write(buf, audio, sr, "WAV")
  213. with sf.SoundFile(buf, 'w', sr, channels=audio.shape[1] if audio.ndim > 1 else 1,
  214. format='WAV', subtype='PCM_16') as f:
  215. f.write(audio)
  216. audio_b64 = base64.b64encode(buf.getvalue()).decode()
  217. item = {
  218. "sentence": sentence,
  219. "audio": audio_b64,
  220. "sample_rate": sr
  221. }
  222. put_memory_cache(key, item)
  223. # 保存磁盘缓存
  224. await save_sentence_to_disk(key, audio, sr, sentence)
  225. await clean_disk_cache()
  226. yield json.dumps({
  227. "index": idx,
  228. **item
  229. }).encode() + b"\n"
  230. break
  231. finally:
  232. current_requests.pop(client_id, None)
  233. if not client_alive_task.done():
  234. client_alive_task.cancel()
  235. return StreamingResponse(stream(), media_type="application/x-ndjson")
  236. # ============================================================
  237. # 其它接口
  238. # ============================================================
  239. @app.get("/clear-cache")
  240. async def clear_cache():
  241. with cache_lock:
  242. memory_cache.clear()
  243. for f in Path(CACHE_DIR).glob("*"):
  244. f.unlink()
  245. return {"status": "success"}
  246. @app.get("/cache-info")
  247. async def get_cache_info():
  248. with cache_lock:
  249. mem_count = len(memory_cache)
  250. disk_files = list(Path(CACHE_DIR).glob("*.wav"))
  251. return {
  252. "memory_cache": mem_count,
  253. "disk_cache": len(disk_files)
  254. }
  255. # ============================================================
  256. # 启动
  257. # ============================================================
  258. if __name__ == "__main__":
  259. import uvicorn
  260. uvicorn.run(app, host="0.0.0.0", port=8028)