| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299 |
- from fastapi import FastAPI, Request, File, UploadFile, HTTPException, Form, Response
- from fastapi.middleware.cors import CORSMiddleware
- from fastapi.responses import JSONResponse, RedirectResponse, StreamingResponse
- from fastapi.staticfiles import StaticFiles
- import os
- import shutil
- import hashlib
- import asyncio
- from typing import AsyncGenerator
- import aiohttp
- import io
- import logging
- import base64
- import json
# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Initialize FastAPI app
app = FastAPI()

# Configure CORS
# NOTE(review): a wildcard origin list together with allow_credentials=True is
# discouraged by the FastAPI CORS documentation, which asks for explicit
# origins when credentials are allowed. Left unchanged here because the set of
# legitimate front-end origins is not visible from this file -- confirm and
# replace "*" with the real origins.
origins = ["*"]
app.add_middleware(
    CORSMiddleware,
    allow_origins=origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Directory for uploaded files.
# exist_ok=True replaces the former exists()-then-makedirs() sequence, which
# could fail if the directory appeared between the two calls, and matches how
# CACHE_DIR is created below.
UPLOAD_DIRECTORY = "static/files"
os.makedirs(UPLOAD_DIRECTORY, exist_ok=True)

# Mount static files; the two specific prefixes are registered before the
# general "/static" mount.
app.mount("/static/files", StaticFiles(directory=UPLOAD_DIRECTORY), name="static_files")
app.mount("/static/web", StaticFiles(directory="static/web"), name="static_web")
app.mount("/static", StaticFiles(directory="static"), name="static")

# Audio cache directory
CACHE_DIR = "audio_cache"
os.makedirs(CACHE_DIR, exist_ok=True)
# Root redirect to PDF viewer
@app.get("/")
def root():
    """Redirect the site root to the PDF viewer opened on the default document."""
    viewer_url = "/static/web/viewer.html?file=/static/files/compress.pdf"
    return RedirectResponse(url=viewer_url)
# Sanitize filename
def sanitize_filename(name: str) -> str:
    """Return *name* reduced to a conservative set of filename characters.

    Only alphanumeric characters (as judged by ``str.isalnum``), spaces, dots,
    underscores and hyphens are kept; everything else, including path
    separators, is dropped. Trailing whitespace is stripped from the result,
    which may therefore be an empty string.
    """
    permitted_punctuation = " ._-"
    kept = [ch for ch in name if ch.isalnum() or ch in permitted_punctuation]
    return "".join(kept).rstrip()
# PDF upload endpoint
@app.post("/upload-pdf")
async def upload_pdf(file: UploadFile = File(...), custom_name: str = Form(...)):
    """Store an uploaded PDF under a caller-chosen, sanitized name.

    Args:
        file: The uploaded file; its declared content type must be
            ``application/pdf``.
        custom_name: Desired base name (without extension). It is passed
            through ``sanitize_filename`` before use.

    Returns:
        A JSON response ``{"success": True, "file_path": ...}`` on success, or
        a 400 JSON response with ``success: False`` when the name is invalid
        or already taken.

    Raises:
        HTTPException: 400 if the content type is not PDF, 500 if writing the
            file fails.
    """
    if file.content_type != 'application/pdf':
        raise HTTPException(status_code=400, detail="文件类型必须是 PDF")

    sanitized_name = sanitize_filename(custom_name)
    if not sanitized_name:
        return JSONResponse(status_code=400, content={"success": False, "error": "无效的文件名"})

    unique_filename = f"{sanitized_name}.pdf"
    file_path = os.path.join(UPLOAD_DIRECTORY, unique_filename)

    try:
        # "xb" creates the file exclusively: the existence check and the
        # creation happen in one step, so two concurrent uploads with the same
        # name can no longer overwrite each other.
        with open(file_path, "xb") as buffer:
            shutil.copyfileobj(file.file, buffer)
    except FileExistsError:
        return JSONResponse(status_code=400, content={"success": False, "error": "文件名已存在,请使用其他名称"})
    except Exception as e:
        logger.exception("Failed to store uploaded PDF at %s", file_path)
        # Do not leave a truncated file behind: it would be listed as a valid
        # PDF and would block the name for later uploads.
        try:
            os.remove(file_path)
        except OSError:
            pass
        raise HTTPException(status_code=500, detail="上传过程中出错") from e
    finally:
        file.file.close()

    file_relative_path = f"/static/files/{unique_filename}"
    return JSONResponse(content={"success": True, "file_path": file_relative_path})
# List PDFs endpoint
@app.get("/list-pdfs")
async def list_pdfs():
    """List the PDF files currently present in the upload directory.

    Returns:
        A JSON response ``{"success": True, "files": [...]}`` where each entry
        has the file ``name`` and its static ``url``. Entries are sorted by
        name so the listing is stable between calls.

    Raises:
        HTTPException: 500 if the upload directory cannot be read.
    """
    # Only the directory read can fail, so it is the only statement guarded.
    try:
        entries = os.listdir(UPLOAD_DIRECTORY)
    except OSError as e:
        logger.exception("Failed to list upload directory %s", UPLOAD_DIRECTORY)
        raise HTTPException(status_code=500, detail="无法获取文件列表") from e

    pdf_files = [
        {"name": name, "url": f"/static/files/{name}"}
        for name in sorted(entries)
        if name.lower().endswith(".pdf")
    ]
    return JSONResponse(content={"success": True, "files": pdf_files})
- # Request models
- from pydantic import BaseModel
class TextToSpeechRequest(BaseModel):
    """Request body accepted by the text-to-speech and page-to-speech endpoints."""

    # Text to synthesize (required).
    user_input: str
    # Voice identifier forwarded to the TTS service.
    voice: str = "af_heart"
    # Speed value forwarded to the TTS service.
    speed: float = 1.0
# Text-to-speech endpoint (streaming)
@app.post("/text-to-speech/")
async def text_to_speech(request: TextToSpeechRequest):
    """Synthesize speech for the given text, streaming audio as it is produced.

    A previously generated result is served from CACHE_DIR when available.
    Otherwise the text is sent to the TTS service, whose newline-delimited
    JSON response is decoded line by line; each decoded audio fragment is
    streamed to the client and the concatenation is cached afterwards.

    Args:
        request: Text, voice and speed to synthesize.

    Returns:
        A ``Response`` with the cached audio, or a ``StreamingResponse`` that
        yields audio fragments.

    Raises:
        HTTPException: 400 if the text is empty after stripping. Failures of
            the TTS service are raised as 500 from inside the stream.
            NOTE(review): by then the response has presumably already
            started, so the client would see an aborted stream rather than a
            clean 500 -- confirm against the framework behaviour.
    """
    user_input = request.user_input.strip()
    if not user_input:
        raise HTTPException(status_code=400, detail="输入文本为空")

    # The cache key covers everything that influences the audio. Hashing the
    # text alone would serve a clip generated with another voice or speed.
    cache_key = json.dumps([user_input, request.voice, request.speed], ensure_ascii=False)
    text_hash = hashlib.md5(cache_key.encode('utf-8')).hexdigest()
    audio_path = os.path.join(CACHE_DIR, f"{text_hash}.wav")
    if os.path.exists(audio_path):
        with open(audio_path, "rb") as f:
            return Response(content=f.read(), media_type="audio/wav")

    def decode_line(raw_line: bytes) -> bytes:
        """Return the audio bytes carried by one NDJSON line, or b"" if none."""
        if not raw_line.strip():
            return b""
        try:
            data = json.loads(raw_line.decode('utf-8'))
        except json.JSONDecodeError as e:
            # A malformed line is skipped so one bad record does not end the stream.
            logger.error(f"JSON decode error: {str(e)}")
            return b""
        if data.get('error'):
            raise HTTPException(status_code=500, detail=data['error'])
        audio_b64 = data.get('audio')
        return base64.b64decode(audio_b64) if audio_b64 else b""

    async def audio_generator() -> AsyncGenerator[bytes, None]:
        full_audio = io.BytesIO()
        try:
            async with aiohttp.ClientSession() as session:
                async with session.post(
                    'http://141.140.15.30:8028/generate',
                    headers={'Content-Type': 'application/json'},
                    json={
                        'text': user_input,
                        'voice': request.voice,
                        'speed': request.speed
                    }
                ) as response:
                    if response.status != 200:
                        raise HTTPException(status_code=500, detail="TTS API 请求失败")

                    # Buffer raw bytes and decode only complete lines: a
                    # network chunk may end in the middle of a multi-byte
                    # UTF-8 character, which per-chunk decoding cannot handle.
                    pending = b""
                    async for chunk in response.content.iter_any():
                        pending += chunk
                        *complete_lines, pending = pending.split(b'\n')
                        for line in complete_lines:
                            audio_bytes = decode_line(line)
                            if audio_bytes:
                                full_audio.write(audio_bytes)
                                yield audio_bytes

                    # The last record may arrive without a trailing newline.
                    audio_bytes = decode_line(pending)
                    if audio_bytes:
                        full_audio.write(audio_bytes)
                        yield audio_bytes
        except HTTPException as e:
            logger.error(f"TTS error: {e.detail}")
            raise
        except Exception as e:
            logger.error(f"TTS error: {str(e)}")
            raise HTTPException(status_code=500, detail=str(e)) from e

        # Save to cache. An empty result is not cached, otherwise every later
        # request for the same input would be answered with empty audio. The
        # audio has already been delivered, so a cache failure is only logged.
        cached_audio = full_audio.getvalue()
        if cached_audio:
            try:
                with open(audio_path, "wb") as f:
                    f.write(cached_audio)
            except OSError:
                logger.exception("Could not write audio cache file %s", audio_path)

    return StreamingResponse(audio_generator(), media_type="audio/wav")
# Page-to-speech endpoint (chunked streaming)
MAX_CHUNK_SIZE = 200


def split_text_into_chunks(text: str, max_chunk_size: int = MAX_CHUNK_SIZE) -> list:
    """Split *text* into chunks of at most *max_chunk_size* characters.

    The text is first split into sentences at spaces that follow ``.``, ``!``
    or ``?``. Consecutive sentences are joined with a single space for as long
    as the chunk stays within the limit. A single sentence longer than the
    limit is cut into fixed-size slices.

    Args:
        text: The text to split.
        max_chunk_size: Maximum length of each returned chunk; must be >= 1.

    Returns:
        The list of chunks in order. Empty for empty input.

    Raises:
        ValueError: If *max_chunk_size* is smaller than 1. Previously a
            negative value silently dropped over-long sentences and zero
            failed with an unrelated ``range()`` error.
    """
    import re  # kept as a function-local import, as in the original block

    if max_chunk_size < 1:
        raise ValueError("max_chunk_size must be a positive integer")

    chunks = []
    current_chunk = ""
    for sentence in re.split(r'(?<=[.!?]) +', text):
        # The +1 accounts for the joining space.
        if len(current_chunk) + len(sentence) + 1 <= max_chunk_size:
            current_chunk = f"{current_chunk} {sentence}" if current_chunk else sentence
            continue

        # The sentence does not fit: flush what has been accumulated so far.
        if current_chunk:
            chunks.append(current_chunk)

        if len(sentence) > max_chunk_size:
            # Too long even on its own: cut into fixed-size slices.
            chunks.extend(
                sentence[start:start + max_chunk_size]
                for start in range(0, len(sentence), max_chunk_size)
            )
            current_chunk = ""
        else:
            current_chunk = sentence

    if current_chunk:
        chunks.append(current_chunk)
    return chunks
def _audio_from_ndjson_line(raw_line: bytes) -> bytes:
    """Return the audio bytes carried by one NDJSON line of the TTS response.

    Blank lines, lines that are not valid JSON (logged and skipped) and
    records without an ``audio`` field yield ``b""``. A record with a truthy
    ``error`` field raises ``HTTPException`` (500).
    """
    if not raw_line.strip():
        return b""
    try:
        data = json.loads(raw_line.decode('utf-8'))
    except json.JSONDecodeError as e:
        logger.error(f"JSON decode error: {str(e)}")
        return b""
    if data.get('error'):
        raise HTTPException(status_code=500, detail=data['error'])
    audio_b64 = data.get('audio')
    return base64.b64decode(audio_b64) if audio_b64 else b""


async def generate_api_audio(chunk: str, voice: str, speed: float) -> AsyncGenerator[bytes, None]:
    """Yield the audio for one text chunk, from cache or from the TTS service.

    On a cache hit the whole cached file is yielded once. Otherwise the chunk
    is posted to the TTS service and each audio fragment decoded from its
    newline-delimited JSON response is yielded as it arrives; the complete
    audio is written to the cache once the response has been fully read.

    Args:
        chunk: Text to synthesize.
        voice: Voice identifier forwarded to the TTS service.
        speed: Speed value forwarded to the TTS service.

    Yields:
        Raw audio bytes.

    Raises:
        HTTPException: 500 if the service answers with a non-200 status,
            reports an error record, or the request otherwise fails.
    """
    # The cache key covers text, voice and speed so audio generated with
    # different settings is never confused.
    cache_key = json.dumps([chunk, voice, speed], ensure_ascii=False)
    text_hash = hashlib.md5(cache_key.encode('utf-8')).hexdigest()
    audio_path = os.path.join(CACHE_DIR, f"{text_hash}.wav")

    if os.path.exists(audio_path):
        with open(audio_path, "rb") as f:
            yield f.read()
        return

    collected = io.BytesIO()
    try:
        async with aiohttp.ClientSession() as session:
            async with session.post(
                'http://141.140.15.30:8028/generate',
                headers={'Content-Type': 'application/json'},
                json={
                    'text': chunk,
                    'voice': voice,
                    'speed': speed
                }
            ) as response:
                if response.status != 200:
                    raise HTTPException(status_code=500, detail="TTS API 请求失败")

                # Buffer raw bytes and decode only complete lines: a network
                # payload may end in the middle of a multi-byte character.
                # The loop variable is deliberately not named `chunk`, so the
                # text parameter is not shadowed.
                pending = b""
                async for payload in response.content.iter_any():
                    pending += payload
                    *complete_lines, pending = pending.split(b'\n')
                    for line in complete_lines:
                        audio_bytes = _audio_from_ndjson_line(line)
                        if audio_bytes:
                            collected.write(audio_bytes)
                            yield audio_bytes

                # The last record may arrive without a trailing newline.
                audio_bytes = _audio_from_ndjson_line(pending)
                if audio_bytes:
                    collected.write(audio_bytes)
                    yield audio_bytes
    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"TTS生成失败: {str(e)}") from e

    # Cache the complete audio in a single write. Re-opening the file with
    # "wb" for every fragment kept only the last fragment. Empty results are
    # not cached, and a cache failure is only logged because the audio has
    # already been delivered.
    cached_audio = collected.getvalue()
    if cached_audio:
        try:
            with open(audio_path, "wb") as f:
                f.write(cached_audio)
        except OSError:
            logger.exception("Could not write audio cache file %s", audio_path)
@app.post("/page-to-speech/")
async def page_to_speech(request: TextToSpeechRequest):
    """Synthesize speech for a whole page of text, chunk by chunk.

    The text is split with ``split_text_into_chunks`` and each chunk is
    synthesized through ``generate_api_audio``; the audio is streamed in chunk
    order and the concatenation is cached under a ``_full`` cache entry.

    Args:
        request: Text, voice and speed to synthesize.

    Returns:
        A ``Response`` with the cached audio, or a ``StreamingResponse`` that
        yields audio as each chunk is produced.

    Raises:
        HTTPException: 400 if the text is empty after stripping.
    """
    user_input = request.user_input.strip()
    if not user_input:
        raise HTTPException(status_code=400, detail="输入文本为空")

    # The cache key covers text, voice and speed so a page rendered with one
    # voice is not replayed for a request asking for another.
    cache_key = json.dumps([user_input, request.voice, request.speed], ensure_ascii=False)
    full_text_hash = hashlib.md5(cache_key.encode('utf-8')).hexdigest()
    full_audio_path = os.path.join(CACHE_DIR, f"{full_text_hash}_full.wav")

    if os.path.exists(full_audio_path):
        # Read inside a context manager so the handle is always closed. This
        # mirrors the cached path of /text-to-speech/ and avoids handing a
        # bare file object to StreamingResponse, which would iterate binary
        # data line-wise and never close it explicitly.
        with open(full_audio_path, "rb") as f:
            return Response(content=f.read(), media_type="audio/wav")

    chunks = split_text_into_chunks(user_input)

    async def audio_generator() -> AsyncGenerator[bytes, None]:
        full_audio_buffer = io.BytesIO()  # For caching full audio
        for chunk in chunks:
            async for audio_data in generate_api_audio(chunk, request.voice, request.speed):
                full_audio_buffer.write(audio_data)
                yield audio_data  # Stream each chunk's audio
            await asyncio.sleep(0)  # Yield control to event loop

        # Save the full audio to cache. Empty results are not cached, and a
        # cache failure is only logged because the audio was already sent.
        full_audio = full_audio_buffer.getvalue()
        if full_audio:
            try:
                with open(full_audio_path, "wb") as f:
                    f.write(full_audio)
            except OSError:
                logger.exception("Could not write audio cache file %s", full_audio_path)

    return StreamingResponse(audio_generator(), media_type="audio/wav")
# Health check
@app.get("/health")
async def health_check():
    """Report a fixed healthy status."""
    status_payload = {"status": "healthy"}
    return status_payload
if __name__ == "__main__":
    # Start the uvicorn server with this app when the file is run as a script.
    import uvicorn

    bind_host, bind_port = "0.0.0.0", 8005
    uvicorn.run(app, host=bind_host, port=bind_port)
|