"""FastAPI service that hosts uploaded PDFs and proxies text-to-speech audio.

Endpoints:
    GET  /                 redirect to the bundled PDF viewer
    POST /upload-pdf       store an uploaded PDF under ``static/files``
    GET  /list-pdfs        list the stored PDFs
    POST /text-to-speech/  stream TTS audio for a piece of text
    POST /page-to-speech/  stream TTS audio for a long text, chunk by chunk
    GET  /health           liveness probe

Generated audio is cached on disk in ``audio_cache`` keyed by text, voice and
speed.
"""
import asyncio
import base64
import contextlib
import hashlib
import io
import json
import logging
import os
import re
import shutil
from typing import AsyncGenerator, List, Optional

import aiohttp
from fastapi import FastAPI, Request, File, UploadFile, HTTPException, Form, Response
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse, RedirectResponse, StreamingResponse
from fastapi.staticfiles import StaticFiles
from pydantic import BaseModel

# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Initialize FastAPI app
app = FastAPI()

# Configure CORS
# NOTE(review): a wildcard origin combined with allow_credentials=True is a
# very permissive policy; consider listing the trusted origins explicitly.
origins = ["*"]
app.add_middleware(
    CORSMiddleware,
    allow_origins=origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Directory for uploaded files
UPLOAD_DIRECTORY = "static/files"
if not os.path.exists(UPLOAD_DIRECTORY):
    os.makedirs(UPLOAD_DIRECTORY)

# Mount static files
app.mount("/static/files", StaticFiles(directory=UPLOAD_DIRECTORY), name="static_files")
app.mount("/static/web", StaticFiles(directory="static/web"), name="static_web")
app.mount("/static", StaticFiles(directory="static"), name="static")

# Audio cache directory
CACHE_DIR = "audio_cache"
os.makedirs(CACHE_DIR, exist_ok=True)

# Upstream TTS endpoint; overridable through the environment so deployments
# do not need a code change to point at another server.
TTS_API_URL = os.environ.get("TTS_API_URL", "http://141.140.15.30:8028/generate")

# Maximum number of characters sent to the TTS service per chunk.
MAX_CHUNK_SIZE = 200

# Sentence boundary: terminal punctuation followed by one or more spaces.
_SENTENCE_BOUNDARY = re.compile(r"(?<=[.!?]) +")


# Root redirect to PDF viewer
@app.get("/")
def root():
    """Redirect the site root to the PDF viewer with the default document."""
    return RedirectResponse(url="/static/web/viewer.html?file=/static/files/compress.pdf")


# Sanitize filename
def sanitize_filename(name: str) -> str:
    """Return *name* reduced to alphanumerics, space, '.', '_' and '-'.

    Every other character (including path separators) is dropped and trailing
    whitespace is stripped. The result may be empty.
    """
    return "".join(c for c in name if c.isalnum() or c in (' ', '.', '_', '-')).rstrip()


# PDF upload endpoint
@app.post("/upload-pdf")
async def upload_pdf(file: UploadFile = File(...), custom_name: str = Form(...)):
    """Store an uploaded PDF as ``<sanitized custom_name>.pdf``.

    Returns a JSON body with ``success`` and, on success, the public
    ``file_path``. Responds with 400 when the upload is not a PDF, the name
    sanitizes to nothing, or the name is already taken, and with 500 when
    writing the file fails.
    """
    if file.content_type != 'application/pdf':
        raise HTTPException(status_code=400, detail="文件类型必须是 PDF")

    sanitized_name = sanitize_filename(custom_name)
    if not sanitized_name:
        return JSONResponse(status_code=400, content={"success": False, "error": "无效的文件名"})

    unique_filename = f"{sanitized_name}.pdf"
    file_path = os.path.join(UPLOAD_DIRECTORY, unique_filename)

    try:
        # Exclusive-create mode makes "check that the name is free" and
        # "create the file" a single atomic step, so two concurrent uploads
        # with the same name cannot overwrite each other.
        with open(file_path, "xb") as buffer:
            shutil.copyfileobj(file.file, buffer)
    except FileExistsError:
        return JSONResponse(status_code=400, content={"success": False, "error": "文件名已存在,请使用其他名称"})
    except Exception:
        logger.exception("Failed to store uploaded PDF at %s", file_path)
        # Do not leave a truncated PDF behind; it would block the name and
        # show up in /list-pdfs.
        with contextlib.suppress(OSError):
            os.remove(file_path)
        raise HTTPException(status_code=500, detail="上传过程中出错")
    finally:
        file.file.close()

    file_relative_path = f"/static/files/{unique_filename}"
    return JSONResponse(content={"success": True, "file_path": file_relative_path})


# List PDFs endpoint
@app.get("/list-pdfs")
async def list_pdfs():
    """Return the name and public URL of every ``.pdf`` in the upload directory."""
    try:
        entries = os.listdir(UPLOAD_DIRECTORY)
    except OSError:
        logger.exception("Failed to list %s", UPLOAD_DIRECTORY)
        raise HTTPException(status_code=500, detail="无法获取文件列表")

    pdf_files = [
        {"name": entry, "url": f"/static/files/{entry}"}
        for entry in entries
        if entry.lower().endswith(".pdf")
    ]
    return JSONResponse(content={"success": True, "files": pdf_files})


# Request models
class TextToSpeechRequest(BaseModel):
    """Request body shared by the text-to-speech endpoints."""

    user_input: str
    voice: str = 'af_heart'  # Default voice
    speed: float = 1.0  # Default speed


def _cache_path(text: str, voice: str, speed: float, suffix: str = "") -> str:
    """Return the cache file path for one (text, voice, speed) combination.

    Voice and speed are part of the key so that the same text requested with
    different settings is never answered with audio generated for another
    voice or speed.
    """
    key = json.dumps([text, voice, float(speed)], ensure_ascii=False)
    digest = hashlib.md5(key.encode('utf-8')).hexdigest()
    return os.path.join(CACHE_DIR, f"{digest}{suffix}.wav")


def _read_cached_audio(path: str) -> Optional[bytes]:
    """Return the cached audio stored at *path*, or ``None`` on a cache miss.

    A missing file and an empty file are both treated as a miss.
    """
    try:
        with open(path, "rb") as cached:
            return cached.read() or None
    except FileNotFoundError:
        return None


def _save_audio_to_cache(path: str, audio: bytes) -> None:
    """Best-effort, atomic write of *audio* to the cache file *path*.

    Empty audio is never cached. The data is written to a temporary sibling
    file and moved into place so readers never observe a partial file. A
    failure is logged and otherwise ignored: the audio has already been
    streamed to the client at this point.
    """
    if not audio:
        return
    tmp_path = f"{path}.{os.getpid()}.tmp"
    try:
        with open(tmp_path, "wb") as tmp:
            tmp.write(audio)
        os.replace(tmp_path, path)
    except OSError:
        logger.exception("Failed to write audio cache file %s", path)
        with contextlib.suppress(OSError):
            os.remove(tmp_path)


def _decode_audio_record(line: bytes) -> Optional[bytes]:
    """Decode one NDJSON record from the TTS service into raw audio bytes.

    Returns ``None`` for blank lines, malformed JSON (logged), non-object
    records, and records without an ``audio`` field.

    Raises:
        HTTPException: 500 when the record carries a truthy ``error`` field.
    """
    if not line.strip():
        return None
    try:
        record = json.loads(line)
    except (json.JSONDecodeError, UnicodeDecodeError) as exc:
        logger.error("JSON decode error: %s", exc)
        return None
    if not isinstance(record, dict):
        logger.error("Unexpected TTS record type: %s", type(record).__name__)
        return None
    if record.get('error'):
        raise HTTPException(status_code=500, detail=record['error'])
    audio_b64 = record.get('audio')
    return base64.b64decode(audio_b64) if audio_b64 else None


async def _fetch_tts_audio(text: str, voice: str, speed: float) -> AsyncGenerator[bytes, None]:
    """Request speech for *text* from the TTS service and yield audio fragments.

    The service answers with newline-delimited JSON; each record may carry a
    base64 ``audio`` field or an ``error`` field.

    Raises:
        HTTPException: 500 when the service answers with a non-200 status or
            an error record.
    """
    async with aiohttp.ClientSession() as session:
        async with session.post(
            TTS_API_URL,
            headers={'Content-Type': 'application/json'},
            json={'text': text, 'voice': voice, 'speed': speed},
        ) as response:
            if response.status != 200:
                raise HTTPException(status_code=500, detail="TTS API 请求失败")

            # Buffer raw bytes and split on newlines before decoding: network
            # chunks can end in the middle of a multi-byte UTF-8 character.
            pending = b""
            async for data in response.content.iter_any():
                pending += data
                *complete_lines, pending = pending.split(b"\n")
                for line in complete_lines:
                    audio = _decode_audio_record(line)
                    if audio:
                        yield audio

            # The last record may arrive without a trailing newline.
            audio = _decode_audio_record(pending)
            if audio:
                yield audio


async def generate_api_audio(chunk: str, voice: str, speed: float) -> AsyncGenerator[bytes, None]:
    """Yield the audio for *chunk*, from the cache or from the TTS service.

    On a cache hit the whole cached file is yielded once. On a miss the audio
    fragments are yielded as they arrive and, once the stream has completed
    successfully, the complete audio is stored in the cache.

    Raises:
        HTTPException: 500 when the TTS request fails.
    """
    audio_path = _cache_path(chunk, voice, speed)
    cached = _read_cached_audio(audio_path)
    if cached is not None:
        yield cached
        return

    collected = bytearray()
    try:
        async for piece in _fetch_tts_audio(chunk, voice, speed):
            collected.extend(piece)
            yield piece
    except HTTPException:
        raise
    except Exception as exc:
        logger.exception("TTS error")
        raise HTTPException(status_code=500, detail=f"TTS生成失败: {str(exc)}") from exc

    # Cache the complete audio once; writing per fragment would leave only
    # the last fragment in the file.
    _save_audio_to_cache(audio_path, bytes(collected))


# Text-to-speech endpoint (streaming)
@app.post("/text-to-speech/")
async def text_to_speech(request: TextToSpeechRequest):
    """Return speech audio for ``request.user_input``.

    Cached audio is returned in one response; otherwise the audio is streamed
    while it is generated and cached afterwards.

    NOTE(review): errors raised while streaming happen after the response has
    started, so the client presumably sees a truncated stream rather than a
    JSON error body - confirm against the ASGI server in use.
    """
    user_input = request.user_input.strip()
    if not user_input:
        raise HTTPException(status_code=400, detail="输入文本为空")

    cached = _read_cached_audio(_cache_path(user_input, request.voice, request.speed))
    if cached is not None:
        return Response(content=cached, media_type="audio/wav")

    return StreamingResponse(
        generate_api_audio(user_input, request.voice, request.speed),
        media_type="audio/wav",
    )


# Page-to-speech endpoint (chunked streaming)
def split_text_into_chunks(text: str, max_chunk_size: int = MAX_CHUNK_SIZE) -> List[str]:
    """Split *text* into chunks of at most *max_chunk_size* characters.

    Text is split at sentence boundaries ('.', '!' or '?' followed by spaces)
    and consecutive sentences are joined with a single space while they fit.
    A single sentence longer than *max_chunk_size* is cut into fixed-size
    slices. Empty text yields an empty list.
    """
    chunks: List[str] = []
    current = ""
    for sentence in _SENTENCE_BOUNDARY.split(text):
        if len(current) + len(sentence) + 1 <= max_chunk_size:
            current = f"{current} {sentence}" if current else sentence
            continue

        if current:
            chunks.append(current)
        if len(sentence) > max_chunk_size:
            chunks.extend(
                sentence[start:start + max_chunk_size]
                for start in range(0, len(sentence), max_chunk_size)
            )
            current = ""
        else:
            current = sentence

    if current:
        chunks.append(current)
    return chunks


@app.post("/page-to-speech/")
async def page_to_speech(request: TextToSpeechRequest):
    """Stream speech audio for a long text, generated chunk by chunk.

    The text is split with ``split_text_into_chunks``; each chunk's audio is
    streamed as soon as it is available, and the concatenation of all chunks
    is cached once every chunk has succeeded.

    NOTE(review): per-chunk audio is concatenated as-is. If the TTS service
    returns complete WAV files (each with its own header) the result may not
    be a strictly valid WAV file - confirm the upstream payload format.
    """
    user_input = request.user_input.strip()
    if not user_input:
        raise HTTPException(status_code=400, detail="输入文本为空")

    full_audio_path = _cache_path(user_input, request.voice, request.speed, suffix="_full")
    cached = _read_cached_audio(full_audio_path)
    if cached is not None:
        # Serve the bytes directly instead of handing an open file object to
        # StreamingResponse, which would never close it explicitly.
        return Response(content=cached, media_type="audio/wav")

    chunks = split_text_into_chunks(user_input)

    async def audio_generator() -> AsyncGenerator[bytes, None]:
        full_audio = bytearray()  # For caching full audio
        for text_chunk in chunks:
            async for audio_data in generate_api_audio(text_chunk, request.voice, request.speed):
                yield audio_data  # Stream each chunk's audio
                full_audio.extend(audio_data)
            await asyncio.sleep(0)  # Yield control to event loop

        # Save the full audio to cache
        _save_audio_to_cache(full_audio_path, bytes(full_audio))

    return StreamingResponse(audio_generator(), media_type="audio/wav")


# Health check
@app.get("/health")
async def health_check():
    """Liveness probe."""
    return {"status": "healthy"}


if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8005)