from fastapi import FastAPI, Request, File, UploadFile, HTTPException, Form, Response
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse, RedirectResponse, StreamingResponse
from fastapi.staticfiles import StaticFiles
import os
import shutil
import uuid
from pydantic import BaseModel
import hashlib
import asyncio
from typing import AsyncGenerator
import soundfile as sf
import io
import logging
import numpy as np
import re
from kokoro import KPipeline  # third-party TTS pipeline; assumed installed

# Logging setup
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# FastAPI application
app = FastAPI()

# CORS configuration.
# NOTE(review): browsers reject `Access-Control-Allow-Origin: *` together with
# credentials; if credentialed cross-origin requests are actually needed, list
# explicit origins here — confirm against the deployed front-end.
origins = ["*"]
app.add_middleware(
    CORSMiddleware,
    allow_origins=origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Directory that holds uploaded PDF files.
UPLOAD_DIRECTORY = "static/files"
# exist_ok avoids the racy exists()-then-makedirs pattern and matches CACHE_DIR below.
os.makedirs(UPLOAD_DIRECTORY, exist_ok=True)

# Static mounts: the most specific prefixes are mounted first so they win.
app.mount("/static/files", StaticFiles(directory=UPLOAD_DIRECTORY), name="static_files")
app.mount("/static/web", StaticFiles(directory="static/web"), name="static_web")
app.mount("/static", StaticFiles(directory="static"), name="static")

# Directory that caches generated TTS audio (keyed by MD5 of the input text).
CACHE_DIR = "audio_cache"
os.makedirs(CACHE_DIR, exist_ok=True)


@app.get("/")
def root():
    """Redirect the root URL to the bundled PDF.js viewer with a default document."""
    return RedirectResponse(url="/static/web/viewer.html?file=/static/files/compress.pdf")


def sanitize_filename(name: str) -> str:
    """Keep only alphanumeric characters (incl. non-ASCII letters such as CJK),
    space, '.', '_' and '-'; strip trailing whitespace. May return ''."""
    return "".join(c for c in name if c.isalnum() or c in (' ', '.', '_', '-')).rstrip()
@app.post("/upload-pdf")
async def upload_pdf(file: UploadFile = File(...), custom_name: str = Form(...)):
    """Store an uploaded PDF under a user-chosen (sanitized) name.

    Returns a JSON payload with ``success`` and either ``file_path`` or
    ``error``; raises 400 for a non-PDF upload and 500 when writing fails.
    """
    if file.content_type != 'application/pdf':
        raise HTTPException(status_code=400, detail="文件类型必须是 PDF")
    sanitized_name = sanitize_filename(custom_name)
    if not sanitized_name:
        return JSONResponse(status_code=400, content={"success": False, "error": "无效的文件名"})
    unique_filename = f"{sanitized_name}.pdf"
    file_path = os.path.join(UPLOAD_DIRECTORY, unique_filename)
    if os.path.exists(file_path):
        return JSONResponse(status_code=400, content={"success": False, "error": "文件名已存在,请使用其他名称"})
    try:
        with open(file_path, "wb") as buffer:
            shutil.copyfileobj(file.file, buffer)
    except Exception as e:
        logger.error("PDF upload failed: %s", e)  # keep the root cause in the log
        raise HTTPException(status_code=500, detail="上传过程中出错") from e
    finally:
        file.file.close()
    file_relative_path = f"/static/files/{unique_filename}"
    return JSONResponse(content={"success": True, "file_path": file_relative_path})


@app.get("/list-pdfs")
async def list_pdfs():
    """List every *.pdf in the upload directory with its public URL."""
    try:
        files = os.listdir(UPLOAD_DIRECTORY)
        pdf_files = [
            {"name": file, "url": f"/static/files/{file}"}
            for file in files if file.lower().endswith(".pdf")
        ]
        return JSONResponse(content={"success": True, "files": pdf_files})
    except Exception as e:
        logger.error("Listing PDFs failed: %s", e)
        raise HTTPException(status_code=500, detail="无法获取文件列表") from e


class TextToSpeechServer:
    """Holds the lazily loaded Kokoro pipeline; ``pipeline`` is None until
    :meth:`load_model` succeeds."""

    def __init__(self):
        self.pipeline = None  # set by load_model(); endpoints must check for None

    def load_model(self, lang_code='a'):
        """Load the KPipeline model; re-raises on failure so startup aborts loudly."""
        try:
            logger.info("加载 KPipeline 模型...")
            self.pipeline = KPipeline(lang_code=lang_code)
            logger.info("模型加载成功")
        except Exception as e:
            logger.error(f"模型加载失败: {str(e)}")
            raise


# Module-level singleton used by all TTS endpoints.
tts_server = TextToSpeechServer()


# NOTE(review): on_event is deprecated in recent FastAPI in favor of lifespan
# handlers, but it still works and is kept for compatibility with this codebase.
@app.on_event("startup")
async def startup_event():
    tts_server.load_model()


class TextToSpeechRequest(BaseModel):
    user_input: str          # text to synthesize
    voice: str = 'af_heart'  # default Kokoro voice
    speed: float = 1.0       # playback speed multiplier


@app.post("/text-to-speech/")
async def text_to_speech(request: TextToSpeechRequest):
    """Synthesize the whole input as one WAV, serving from / filling the cache.

    Responds with cached bytes when available, otherwise streams freshly
    generated audio and writes it to the cache afterwards.
    """
    user_input = request.user_input.strip()
    if not user_input:
        raise HTTPException(status_code=400, detail="输入文本为空")
    text_hash = hashlib.md5(user_input.encode('utf-8')).hexdigest()
    audio_path = os.path.join(CACHE_DIR, f"{text_hash}.wav")
    if os.path.exists(audio_path):
        with open(audio_path, "rb") as f:
            return Response(content=f.read(), media_type="audio/wav")
    # Fail fast with a real 503 *before* streaming: once StreamingResponse has
    # sent the 200 headers, an exception inside the generator cannot change the
    # status code any more.
    if not tts_server.pipeline:
        raise HTTPException(status_code=503, detail="模型未初始化")

    async def audio_generator() -> AsyncGenerator[bytes, None]:
        try:
            generator = tts_server.pipeline(
                text=user_input,
                voice=request.voice,
                speed=request.speed,
                split_pattern=r'\n+'
            )
            segments = [audio for _gs, _ps, audio in generator]
            if not segments:
                # np.concatenate([]) would raise ValueError; end the stream instead.
                logger.warning("TTS produced no audio segments for hash %s", text_hash)
                return
            buffer = io.BytesIO()
            sf.write(buffer, np.concatenate(segments), 24000, format='WAV')
            audio_data = buffer.getvalue()
            yield audio_data
            # Cache only after the full audio was produced and sent.
            with open(audio_path, "wb") as f:
                f.write(audio_data)
        except Exception as e:
            # Headers are already sent; raising here would only produce an ASGI
            # error. Log and terminate the stream.
            logger.error(f"TTS 错误: {str(e)}")

    return StreamingResponse(audio_generator(), media_type="audio/wav")


def split_text_into_sentences(text: str) -> list:
    """Split text on whitespace following '.', '!' or '?'; drop empty pieces."""
    sentences = re.split(r'(?<=[.!?])\s+', text.strip())
    return [s.strip() for s in sentences if s.strip()]


async def generate_kokoro_audio(chunk: str, voice: str, speed: float) -> AsyncGenerator[bytes, None]:
    """Yield WAV bytes for one sentence, serving from / populating the cache.

    Raises HTTPException(503) when the model is not loaded and
    HTTPException(500) when synthesis fails.
    """
    text_hash = hashlib.md5(chunk.encode('utf-8')).hexdigest()
    audio_path = os.path.join(CACHE_DIR, f"{text_hash}.wav")
    if os.path.exists(audio_path):
        with open(audio_path, "rb") as f:
            yield f.read()
        return
    try:
        if not tts_server.pipeline:
            raise HTTPException(status_code=503, detail="模型未初始化")
        generator = tts_server.pipeline(
            text=chunk,
            voice=voice,
            speed=speed,
            split_pattern=r'\n+'
        )
        audio_data = None
        for _gs, _ps, audio in generator:
            buffer = io.BytesIO()
            sf.write(buffer, audio, 24000, format='WAV')
            audio_data = buffer.getvalue()
            yield audio_data
            break  # only the first segment, matching the original behavior
        # Cache only when something was actually generated; the original wrote
        # the file unconditionally and could cache an empty WAV forever.
        if audio_data is not None:
            with open(audio_path, "wb") as f:
                f.write(audio_data)
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"TTS生成失败: {str(e)}") from e


@app.post("/page-to-speech/")
async def page_to_speech(request: TextToSpeechRequest):
    """Convert a page of text sentence-by-sentence and stream each sentence's
    audio as soon as it is ready; the concatenated bytes are cached."""
    user_input = request.user_input.strip()
    if not user_input:
        raise HTTPException(status_code=400, detail="输入文本为空")
    full_text_hash = hashlib.md5(user_input.encode('utf-8')).hexdigest()
    full_audio_path = os.path.join(CACHE_DIR, f"{full_text_hash}_full.wav")
    if os.path.exists(full_audio_path):
        # NOTE(review): this cached file is a byte-concatenation of several
        # complete WAV files (one per sentence), mirroring what the streaming
        # path sends; strict decoders may only play the first sentence —
        # confirm the client tolerates concatenated WAV streams.
        return StreamingResponse(open(full_audio_path, "rb"), media_type="audio/wav")
    sentences = split_text_into_sentences(user_input)
    if not sentences:
        raise HTTPException(status_code=400, detail="没有有效的句子")
    # Check before streaming so the client gets a proper 503 status.
    if not tts_server.pipeline:
        raise HTTPException(status_code=503, detail="模型未初始化")

    async def audio_generator() -> AsyncGenerator[bytes, None]:
        full_audio_buffer = io.BytesIO()  # accumulates the full stream for caching
        try:
            for sentence in sentences:
                logger.info(f"处理句子: {sentence}")
                async for audio_data in generate_kokoro_audio(sentence, request.voice, request.speed):
                    yield audio_data  # stream each sentence's audio immediately
                    full_audio_buffer.write(audio_data)
                await asyncio.sleep(0)  # yield control to the event loop between sentences
        except Exception as e:
            # Headers are already sent; log, skip caching the partial audio,
            # and end the stream.
            logger.error("page-to-speech streaming failed: %s", e)
            return
        # Cache only a fully generated stream.
        with open(full_audio_path, "wb") as f:
            f.write(full_audio_buffer.getvalue())

    return StreamingResponse(audio_generator(), media_type="audio/wav")


@app.get("/health")
async def health_check():
    """Liveness probe: reports whether the Kokoro model finished loading."""
    return {"status": "healthy" if tts_server.pipeline else "model_not_loaded"}


if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8005)