# speech_tts_cpu.py — Kokoro-based TTS streaming service (CPU-only, FastAPI).
  1. from fastapi import FastAPI, HTTPException, Body, Request
  2. from fastapi.responses import StreamingResponse
  3. from contextlib import asynccontextmanager
  4. from kokoro import KPipeline
  5. import soundfile as sf
  6. import io
  7. import base64
  8. import re
  9. import logging
  10. import json
  11. import hashlib
  12. import asyncio
  13. import os
  14. import uuid
  15. import aiofiles
  16. from fastapi.middleware.cors import CORSMiddleware
  17. from typing import Dict, List, Optional
  18. from cachetools import TTLCache
  19. import concurrent.futures
  20. import threading
  21. import numpy as np
  22. from pathlib import Path
  23. import sys
  24. os.environ["USE_NNPACK"] = "0"
  25. os.environ["NNPACK_DISABLE"] = "1"
  26. import torch
  27. torch.backends.nnpack.enabled = False
  28. import warnings
  29. warnings.filterwarnings("ignore", message=".*NNPACK.*")
# ------------------- Configuration -------------------
USE_MP3 = os.getenv("USE_MP3", "false").lower() == "true"  # read but not used in this chunk
MEMORY_CACHE_SIZE = int(os.getenv("MEMORY_CACHE_SIZE", 200))
DISK_CACHE_SIZE = int(os.getenv("DISK_CACHE_SIZE", 500))
MAX_CONCURRENT_REQUESTS = int(os.getenv("MAX_CONCURRENT_REQUESTS", 12))
LOG_LEVEL = os.getenv("LOG_LEVEL", "INFO").upper()
CACHE_DIR = os.getenv("CACHE_DIR", "./audio_cache")

logging.basicConfig(level=getattr(logging, LOG_LEVEL))
logger = logging.getLogger(__name__)
Path(CACHE_DIR).mkdir(parents=True, exist_ok=True)

# Sentence splitter: breaks after '.', '!', '?', ',' or ':' followed by whitespace.
SENT_SPLIT_RE = re.compile(r'(?<=[.!?,:])\s+')

# In-memory cache: per-sentence entries {cache_key -> {"audio": b64, "sample_rate": sr}},
# evicted after 18000 s (5 h).
memory_cache = TTLCache(maxsize=MEMORY_CACHE_SIZE, ttl=18000)
cache_lock = threading.Lock()  # TTLCache is not thread-safe; guard all access
executor = concurrent.futures.ThreadPoolExecutor(max_workers=16)

# Model (resident; lazily created by load_model)
model_pipeline = None
# NOTE(review): asyncio primitives created at import time, outside a running
# loop — requires Python 3.10+ where Lock/Semaphore no longer bind a loop at
# construction. Confirm the deployment's Python version.
model_lock = asyncio.Lock()

# Concurrency management
request_semaphore = asyncio.Semaphore(MAX_CONCURRENT_REQUESTS)
current_requests: Dict[str, Dict] = {}  # client_id -> {"active"/"interrupt": bool}
  51. # ============================================================
  52. # 模型加载
  53. # ============================================================
  54. async def load_model():
  55. global model_pipeline
  56. async with model_lock:
  57. if model_pipeline is None:
  58. try:
  59. model_pipeline = KPipeline(lang_code='a', device="cpu")
  60. logger.info("模型加载完成 (CPU)")
  61. except Exception as e:
  62. logger.error(f"模型加载失败: {e}")
  63. raise
@asynccontextmanager
async def lifespan(app: FastAPI):
    # Startup: block until the TTS model is loaded before serving traffic.
    await load_model()
    yield
    # Shutdown (clean exit only): nothing to release; just log.
    logger.info("应用关闭中...")
app = FastAPI(lifespan=lifespan)

# CORS: allow every origin.
# NOTE(review): allow_credentials=True combined with allow_origins=["*"] is
# rejected by browsers for credentialed requests (the CORS spec forbids the
# wildcard there) — confirm whether credentials are actually needed.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
  78. # ============================================================
  79. # 工具函数
  80. # ============================================================
  81. def split_sentences(text: str) -> List[str]:
  82. return [s.strip() for s in SENT_SPLIT_RE.split(text.strip()) if s.strip()]
  83. def sentence_cache_key(sentence: str, voice: str, speed: float) -> str:
  84. raw = f"{sentence}|{voice}|{speed}"
  85. return hashlib.md5(raw.encode()).hexdigest()
  86. def sentence_cache_path(key: str):
  87. return os.path.join(CACHE_DIR, f"{key}.wav")
  88. def disk_meta_path(key: str):
  89. return os.path.join(CACHE_DIR, f"{key}.json")
  90. def put_memory_cache(key: str, value: dict):
  91. with cache_lock:
  92. memory_cache[key] = value
  93. def get_memory_cache(key: str) -> Optional[dict]:
  94. with cache_lock:
  95. return memory_cache.get(key)
  96. async def load_sentence_from_disk(key: str) -> Optional[dict]:
  97. wav_path = sentence_cache_path(key)
  98. meta_path = disk_meta_path(key)
  99. if not Path(wav_path).exists() or not Path(meta_path).exists():
  100. return None
  101. async with aiofiles.open(meta_path, "r") as f:
  102. meta_text = await f.read()
  103. meta = json.loads(meta_text)
  104. audio, sr = await asyncio.get_event_loop().run_in_executor(
  105. executor, lambda: sf.read(wav_path)
  106. )
  107. with io.BytesIO() as buf:
  108. # sf.write(buf, audio, sr, "WAV")
  109. with sf.SoundFile(buf, 'w', sr, channels=audio.shape[1] if audio.ndim > 1 else 1,
  110. format='WAV', subtype='PCM_16') as f:
  111. f.write(audio)
  112. audio_b64 = base64.b64encode(buf.getvalue()).decode()
  113. meta["audio"] = audio_b64
  114. return meta
  115. async def save_sentence_to_disk(key: str, audio: np.ndarray, sr: int, sentence: str):
  116. wav_path = sentence_cache_path(key)
  117. meta_path = disk_meta_path(key)
  118. await asyncio.get_event_loop().run_in_executor(
  119. executor, lambda: sf.write(wav_path, audio, sr, format="WAV")
  120. )
  121. meta = {
  122. "sentence": sentence,
  123. "sample_rate": sr,
  124. }
  125. async with aiofiles.open(meta_path, "w") as f:
  126. await f.write(json.dumps(meta))
  127. return meta
  128. async def clean_disk_cache():
  129. files = sorted(Path(CACHE_DIR).glob("*.wav"), key=os.path.getmtime)
  130. if len(files) <= DISK_CACHE_SIZE:
  131. return
  132. remove_n = len(files) - DISK_CACHE_SIZE
  133. for f in files[:remove_n]:
  134. try:
  135. json_path = disk_meta_path(f.stem)
  136. f.unlink()
  137. if Path(json_path).exists():
  138. Path(json_path).unlink()
  139. except:
  140. pass
  141. async def check_client_alive(client_id):
  142. while True:
  143. await asyncio.sleep(0.3)
  144. if client_id not in current_requests or current_requests[client_id].get("interrupt"):
  145. return False
  146. # ============================================================
  147. # 请求跟踪
  148. # ============================================================
  149. @app.middleware("http")
  150. async def track_clients(request: Request, call_next):
  151. client_id = (
  152. request.query_params.get("client_id")
  153. or request.headers.get("X-Client-ID")
  154. or str(uuid.uuid4())
  155. )
  156. if request.url.path == "/generate":
  157. current_requests[client_id] = {"active": True}
  158. response = await call_next(request)
  159. if request.url.path == "/generate":
  160. current_requests.pop(client_id, None)
  161. return response
# ============================================================
# Main endpoint
# ============================================================
@app.post("/generate")
async def generate_audio_stream(data: Dict = Body(...)):
    """Stream TTS audio for data["text"] as NDJSON, one sentence per line.

    Body fields: text (required), voice (default "af_heart"), speed
    (default 1.0), client_id (default: fresh UUID). Each emitted line is a
    JSON object: {"index", "sentence", "audio" (base64 WAV), "sample_rate"}.
    """
    # NOTE(review): the semaphore guards only this setup section — the handler
    # returns the StreamingResponse (releasing the semaphore) before stream()
    # performs the actual inference, so MAX_CONCURRENT_REQUESTS does not bound
    # concurrent synthesis. Confirm whether that is intended.
    async with request_semaphore:
        text = data.get("text", "")
        voice = data.get("voice", "af_heart")
        speed = float(data.get("speed", 1.0))
        client_id = data.get("client_id", str(uuid.uuid4()))
        if not text:
            raise HTTPException(400, "文本不能为空")
        await load_model()
        # Interrupt a previous in-flight request from the same client, then
        # give its watcher a moment to observe the flag.
        if client_id in current_requests:
            current_requests[client_id]["interrupt"] = True
            await asyncio.sleep(0.05)
        sentences = split_sentences(text)
        current_requests[client_id] = {"interrupt": False}
        # NOTE(review): the track_clients middleware pops this client_id as
        # soon as the response object is returned — if it saw the same id,
        # the watcher below completes while streaming is still running.
        # Verify against the middleware.

        async def stream():
            # Background watcher: completes when the client disappears or is
            # flagged interrupted; its .done() is polled once per sentence.
            client_alive_task = asyncio.create_task(check_client_alive(client_id))
            try:
                for idx, sentence in enumerate(sentences):
                    if client_alive_task.done():
                        break
                    key = sentence_cache_key(sentence, voice, speed)
                    # 1) Memory cache hit.
                    cache_item = get_memory_cache(key)
                    if cache_item:
                        yield json.dumps({
                            "index": idx,
                            **cache_item
                        }).encode() + b"\n"
                        continue
                    # 2) Disk cache hit (promoted into the memory cache).
                    disk_item = await load_sentence_from_disk(key)
                    if disk_item:
                        put_memory_cache(key, disk_item)
                        yield json.dumps({
                            "index": idx,
                            **disk_item
                        }).encode() + b"\n"
                        continue
                    # 3) Cache miss -> run inference.
                    generator = model_pipeline(sentence, voice=voice, speed=speed)
                    # Only the FIRST chunk per sentence is used (break below).
                    for _, _, audio in generator:
                        sr = 24000  # fixed output sample rate
                        # Encode the chunk as 16-bit PCM WAV in memory.
                        with io.BytesIO() as buf:
                            with sf.SoundFile(buf, 'w', sr, channels=audio.shape[1] if audio.ndim > 1 else 1,
                                              format='WAV', subtype='PCM_16') as f:
                                f.write(audio)
                            audio_b64 = base64.b64encode(buf.getvalue()).decode()
                        item = {
                            "sentence": sentence,
                            "audio": audio_b64,
                            "sample_rate": sr
                        }
                        put_memory_cache(key, item)
                        # Persist to the disk cache, then trim it to size.
                        await save_sentence_to_disk(key, audio, sr, sentence)
                        await clean_disk_cache()
                        yield json.dumps({
                            "index": idx,
                            **item
                        }).encode() + b"\n"
                        break
            finally:
                # Always drop the tracking entry and stop the watcher.
                current_requests.pop(client_id, None)
                if not client_alive_task.done():
                    client_alive_task.cancel()
        return StreamingResponse(stream(), media_type="application/x-ndjson")
  235. # ============================================================
  236. # 其它接口
  237. # ============================================================
  238. @app.get("/clear-cache")
  239. async def clear_cache():
  240. with cache_lock:
  241. memory_cache.clear()
  242. for f in Path(CACHE_DIR).glob("*"):
  243. f.unlink()
  244. return {"status": "success"}
  245. @app.get("/cache-info")
  246. async def get_cache_info():
  247. with cache_lock:
  248. mem_count = len(memory_cache)
  249. disk_files = list(Path(CACHE_DIR).glob("*.wav"))
  250. return {
  251. "memory_cache": mem_count,
  252. "disk_cache": len(disk_files)
  253. }
# ============================================================
# Entry point
# ============================================================
if __name__ == "__main__":
    import uvicorn
    # Single-process dev server; binds all interfaces on port 8028.
    uvicorn.run(app, host="0.0.0.0", port=8028)