|
|
@@ -21,6 +21,17 @@ import pymysql
|
|
|
|
|
|
from config import MYSQL_HOST, MYSQL_PORT, MYSQL_USER, MYSQL_PASSWORD, MYSQL_DATABASE
|
|
|
|
|
|
+OPENAI_TTS_BASE_URL = os.getenv("OPENAI_TTS_BASE_URL", "https://api.aimanyi.top")
|
|
|
+OPENAI_TTS_API_KEY = os.getenv(
|
|
|
+ "OPENAI_TTS_API_KEY",
|
|
|
+ "sk-re2NlaKIQn11ZNWzAbB6339cEbF94c6aAfC8B7Ab82879bEa",
|
|
|
+)
|
|
|
+OPENAI_TTS_MODEL = os.getenv("OPENAI_TTS_MODEL", "gpt-4o-mini-tts")
|
|
|
+OPENAI_TTS_DEFAULT_VOICE = os.getenv("OPENAI_TTS_DEFAULT_VOICE", "sage")
|
|
|
+OPENAI_TTS_FORMAT = os.getenv("OPENAI_TTS_FORMAT", "wav")
|
|
|
+CLIENT_COOKIE = "reader_pro_client"
|
|
|
+PROGRESS_FILE = "reading_progress.json"
|
|
|
+
|
|
|
# Set up logging
|
|
|
logging.basicConfig(level=logging.INFO)
|
|
|
logger = logging.getLogger(__name__)
|
|
|
@@ -145,6 +156,31 @@ def sanitize_filename(name: str) -> str:
|
|
|
return "".join(c for c in name if c.isalnum() or c in (" ", ".", "_", "-")).rstrip()
|
|
|
|
|
|
|
|
|
+def build_file_url(filename: str) -> str:
|
|
|
+ return f"/static/files/{filename}"
|
|
|
+
|
|
|
+
|
|
|
+def get_or_create_client_id(client_id: Optional[str]) -> str:
|
|
|
+ normalized = (client_id or "").strip()
|
|
|
+ return normalized or secrets.token_urlsafe(24)
|
|
|
+
|
|
|
+
|
|
|
+def load_progress_store() -> dict:
|
|
|
+ if not os.path.exists(PROGRESS_FILE):
|
|
|
+ return {}
|
|
|
+ try:
|
|
|
+ with open(PROGRESS_FILE, "r", encoding="utf-8") as f:
|
|
|
+ data = json.load(f)
|
|
|
+ return data if isinstance(data, dict) else {}
|
|
|
+ except Exception:
|
|
|
+ return {}
|
|
|
+
|
|
|
+
|
|
|
+def save_progress_store(data: dict) -> None:
|
|
|
+ with open(PROGRESS_FILE, "w", encoding="utf-8") as f:
|
|
|
+ json.dump(data, f, ensure_ascii=False, indent=2)
|
|
|
+
|
|
|
+
|
|
|
def get_user_dir(username: str) -> str:
|
|
|
safe = sanitize_filename(username) or "user"
|
|
|
path = os.path.join(BASE_STATIC_FILES_DIR, safe)
|
|
|
@@ -217,24 +253,25 @@ def on_startup():
|
|
|
|
|
|
|
|
|
@app.get("/")
|
|
|
-def root(session_token: Optional[str] = Cookie(default=None, alias=SESSION_COOKIE)):
|
|
|
- user = get_user_by_session(session_token)
|
|
|
- if not user:
|
|
|
- return RedirectResponse(url="/login", status_code=302)
|
|
|
+def root(client_id: Optional[str] = Cookie(default=None, alias=CLIENT_COOKIE)):
|
|
|
+ current_client_id = get_or_create_client_id(client_id)
|
|
|
+ progress = load_progress_store().get(current_client_id, {})
|
|
|
+ last_file = (progress.get("last_file") or "").strip()
|
|
|
|
|
|
- last_file = user.get("last_file")
|
|
|
if last_file:
|
|
|
- return RedirectResponse(url=f"/static/web/viewer.html?file={last_file}", status_code=302)
|
|
|
-
|
|
|
- # fallback: first file in user's own directory
|
|
|
- user_dir = get_user_dir(user["username"])
|
|
|
- files = sorted([f for f in os.listdir(user_dir) if f.lower().endswith(".pdf")])
|
|
|
- if files:
|
|
|
- file_url = build_user_file_url(user["username"], files[0])
|
|
|
- return RedirectResponse(url=f"/static/web/viewer.html?file={file_url}", status_code=302)
|
|
|
+ response = RedirectResponse(url=f"/static/web/viewer.html?file={last_file}", status_code=302)
|
|
|
+ else:
|
|
|
+ files = sorted([f for f in os.listdir(BASE_STATIC_FILES_DIR) if f.lower().endswith(".pdf")])
|
|
|
+ if files:
|
|
|
+ response = RedirectResponse(
|
|
|
+ url=f"/static/web/viewer.html?file={build_file_url(files[0])}",
|
|
|
+ status_code=302,
|
|
|
+ )
|
|
|
+ else:
|
|
|
+ response = RedirectResponse(url="/static/web/viewer.html?file=/static/files/compress.pdf", status_code=302)
|
|
|
|
|
|
- # fallback to sample document
|
|
|
- return RedirectResponse(url="/static/web/viewer.html?file=/static/files/compress.pdf", status_code=302)
|
|
|
+ response.set_cookie(key=CLIENT_COOKIE, value=current_client_id, httponly=True, samesite="lax", secure=False, path="/")
|
|
|
+ return response
|
|
|
|
|
|
|
|
|
@app.get("/login")
|
|
|
@@ -655,9 +692,8 @@ async def admin_set_config(payload: AdminConfigRequest, session_token: Optional[
|
|
|
async def upload_pdf(
|
|
|
file: UploadFile = File(...),
|
|
|
custom_name: str = Form(...),
|
|
|
- session_token: Optional[str] = Cookie(default=None, alias=SESSION_COOKIE),
|
|
|
+ client_id: Optional[str] = Cookie(default=None, alias=CLIENT_COOKIE),
|
|
|
):
|
|
|
- user = require_user(session_token)
|
|
|
if file.content_type != "application/pdf":
|
|
|
raise HTTPException(status_code=400, detail="文件类型必须是 PDF")
|
|
|
|
|
|
@@ -666,8 +702,7 @@ async def upload_pdf(
|
|
|
return JSONResponse(status_code=400, content={"success": False, "error": "无效的文件名"})
|
|
|
|
|
|
unique_filename = f"{sanitized_name}.pdf"
|
|
|
- user_dir = get_user_dir(user["username"])
|
|
|
- file_path = os.path.join(user_dir, unique_filename)
|
|
|
+ file_path = os.path.join(BASE_STATIC_FILES_DIR, unique_filename)
|
|
|
|
|
|
if os.path.exists(file_path):
|
|
|
return JSONResponse(status_code=400, content={"success": False, "error": "文件名已存在,请使用其他名称"})
|
|
|
@@ -680,30 +715,28 @@ async def upload_pdf(
|
|
|
finally:
|
|
|
file.file.close()
|
|
|
|
|
|
- file_relative_path = build_user_file_url(user["username"], unique_filename)
|
|
|
- now = datetime.now(timezone.utc).replace(tzinfo=None)
|
|
|
- conn = db_conn()
|
|
|
- try:
|
|
|
- with conn.cursor() as cur:
|
|
|
- cur.execute(
|
|
|
- "UPDATE user SET last_file=%s, last_page=1, updated_at=%s WHERE id=%s",
|
|
|
- (file_relative_path, now, user["id"]),
|
|
|
- )
|
|
|
- finally:
|
|
|
- conn.close()
|
|
|
+ current_client_id = get_or_create_client_id(client_id)
|
|
|
+ file_relative_path = build_file_url(unique_filename)
|
|
|
+ store = load_progress_store()
|
|
|
+ store[current_client_id] = {
|
|
|
+ "last_file": file_relative_path,
|
|
|
+ "last_page": 1,
|
|
|
+ "updated_at": datetime.now(timezone.utc).isoformat(),
|
|
|
+ }
|
|
|
+ save_progress_store(store)
|
|
|
|
|
|
- return JSONResponse(content={"success": True, "file_path": file_relative_path})
|
|
|
+ response = JSONResponse(content={"success": True, "file_path": file_relative_path})
|
|
|
+ response.set_cookie(key=CLIENT_COOKIE, value=current_client_id, httponly=True, samesite="lax", secure=False, path="/")
|
|
|
+ return response
|
|
|
|
|
|
|
|
|
# List PDFs endpoint
|
|
|
@app.get("/list-pdfs")
|
|
|
-async def list_pdfs(session_token: Optional[str] = Cookie(default=None, alias=SESSION_COOKIE)):
|
|
|
- user = require_user(session_token)
|
|
|
+async def list_pdfs():
|
|
|
try:
|
|
|
- user_dir = get_user_dir(user["username"])
|
|
|
- files = os.listdir(user_dir)
|
|
|
+ files = os.listdir(BASE_STATIC_FILES_DIR)
|
|
|
pdf_files = [
|
|
|
- {"name": file, "url": build_user_file_url(user["username"], file)}
|
|
|
+ {"name": file, "url": build_file_url(file)}
|
|
|
for file in files
|
|
|
if file.lower().endswith(".pdf")
|
|
|
]
|
|
|
@@ -719,30 +752,21 @@ class ReadingProgressRequest(BaseModel):
|
|
|
|
|
|
|
|
|
@app.get("/reading-progress")
|
|
|
-async def get_reading_progress(file: str, session_token: Optional[str] = Cookie(default=None, alias=SESSION_COOKIE)):
|
|
|
- user = require_user(session_token)
|
|
|
+async def get_reading_progress(file: str, client_id: Optional[str] = Cookie(default=None, alias=CLIENT_COOKIE)):
|
|
|
normalized_file = (file or "").strip()
|
|
|
if not normalized_file:
|
|
|
return JSONResponse(status_code=400, content={"success": False, "error": "缺少 file 参数"})
|
|
|
|
|
|
- conn = db_conn()
|
|
|
- try:
|
|
|
- with conn.cursor() as cur:
|
|
|
- cur.execute(
|
|
|
- "SELECT page FROM user_progress WHERE user_id=%s AND file_path=%s",
|
|
|
- (user["id"], normalized_file),
|
|
|
- )
|
|
|
- row = cur.fetchone()
|
|
|
- page = row["page"] if row else None
|
|
|
- finally:
|
|
|
- conn.close()
|
|
|
-
|
|
|
- return JSONResponse(content={"success": True, "file": normalized_file, "page": page})
|
|
|
+ current_client_id = get_or_create_client_id(client_id)
|
|
|
+ progress = load_progress_store().get(current_client_id, {})
|
|
|
+ page = progress.get("last_page") if progress.get("last_file") == normalized_file else None
|
|
|
+ response = JSONResponse(content={"success": True, "file": normalized_file, "page": page})
|
|
|
+ response.set_cookie(key=CLIENT_COOKIE, value=current_client_id, httponly=True, samesite="lax", secure=False, path="/")
|
|
|
+ return response
|
|
|
|
|
|
|
|
|
@app.post("/reading-progress")
|
|
|
-async def save_reading_progress(request: ReadingProgressRequest, session_token: Optional[str] = Cookie(default=None, alias=SESSION_COOKIE)):
|
|
|
- user = require_user(session_token)
|
|
|
+async def save_reading_progress(request: ReadingProgressRequest, client_id: Optional[str] = Cookie(default=None, alias=CLIENT_COOKIE)):
|
|
|
normalized_file = (request.file or "").strip()
|
|
|
page = int(request.page)
|
|
|
if not normalized_file:
|
|
|
@@ -750,31 +774,22 @@ async def save_reading_progress(request: ReadingProgressRequest, session_token:
|
|
|
if page < 1:
|
|
|
return JSONResponse(status_code=400, content={"success": False, "error": "page 必须 >= 1"})
|
|
|
|
|
|
- now = datetime.now(timezone.utc).replace(tzinfo=None)
|
|
|
- conn = db_conn()
|
|
|
- try:
|
|
|
- with conn.cursor() as cur:
|
|
|
- cur.execute(
|
|
|
- """
|
|
|
- INSERT INTO user_progress (user_id, file_path, page, updated_at)
|
|
|
- VALUES (%s, %s, %s, %s)
|
|
|
- ON DUPLICATE KEY UPDATE page=VALUES(page), updated_at=VALUES(updated_at)
|
|
|
- """,
|
|
|
- (user["id"], normalized_file, page, now),
|
|
|
- )
|
|
|
- cur.execute(
|
|
|
- "UPDATE user SET last_file=%s, last_page=%s, updated_at=%s WHERE id=%s",
|
|
|
- (normalized_file, page, now, user["id"]),
|
|
|
- )
|
|
|
- finally:
|
|
|
- conn.close()
|
|
|
-
|
|
|
- return JSONResponse(content={"success": True})
|
|
|
+ current_client_id = get_or_create_client_id(client_id)
|
|
|
+ store = load_progress_store()
|
|
|
+ store[current_client_id] = {
|
|
|
+ "last_file": normalized_file,
|
|
|
+ "last_page": page,
|
|
|
+ "updated_at": datetime.now(timezone.utc).isoformat(),
|
|
|
+ }
|
|
|
+ save_progress_store(store)
|
|
|
+ response = JSONResponse(content={"success": True})
|
|
|
+ response.set_cookie(key=CLIENT_COOKIE, value=current_client_id, httponly=True, samesite="lax", secure=False, path="/")
|
|
|
+ return response
|
|
|
|
|
|
|
|
|
class TextToSpeechRequest(BaseModel):
|
|
|
user_input: str
|
|
|
- voice: str = "af_heart"
|
|
|
+ voice: str = OPENAI_TTS_DEFAULT_VOICE
|
|
|
speed: float = 1.0
|
|
|
|
|
|
|
|
|
@@ -786,24 +801,83 @@ async def generate_proxy(request: TextToSpeechRequest):
|
|
|
|
|
|
async def stream_generator() -> AsyncGenerator[bytes, None]:
|
|
|
try:
|
|
|
- async with aiohttp.ClientSession() as session:
|
|
|
- async with session.post(
|
|
|
- "http://141.140.15.30:8028/generate",
|
|
|
- headers={"Content-Type": "application/json"},
|
|
|
- json={"text": user_input, "voice": request.voice, "speed": request.speed},
|
|
|
- ) as response:
|
|
|
- if response.status != 200:
|
|
|
- raise HTTPException(status_code=500, detail="TTS API 请求失败")
|
|
|
-
|
|
|
- async for chunk in response.content.iter_any():
|
|
|
- yield chunk
|
|
|
+ chunks = split_text_into_chunks(user_input)
|
|
|
+ for index, chunk in enumerate(chunks):
|
|
|
+ audio_bytes = await request_openai_tts_audio(chunk, request.voice)
|
|
|
+ payload = {
|
|
|
+ "index": index,
|
|
|
+ "text": chunk,
|
|
|
+ "audio": base64.b64encode(audio_bytes).decode("utf-8"),
|
|
|
+ }
|
|
|
+ yield (json.dumps(payload, ensure_ascii=False) + "\n").encode("utf-8")
|
|
|
+ await asyncio.sleep(0)
|
|
|
+ except HTTPException as e:
|
|
|
+ logger.error("generate proxy http error: %s", e.detail)
|
|
|
+ yield (json.dumps({"error": e.detail}, ensure_ascii=False) + "\n").encode("utf-8")
|
|
|
except Exception as e:
|
|
|
logger.error(f"generate proxy error: {str(e)}")
|
|
|
- raise HTTPException(status_code=500, detail=str(e))
|
|
|
+ yield (json.dumps({"error": "TTS生成失败"}, ensure_ascii=False) + "\n").encode("utf-8")
|
|
|
|
|
|
return StreamingResponse(stream_generator(), media_type="application/x-ndjson")
|
|
|
|
|
|
|
|
|
+def normalize_openai_voice(voice: str) -> str:
|
|
|
+ allowed_voices = {"alloy", "ash", "ballad", "coral", "echo", "sage", "shimmer", "verse"}
|
|
|
+ normalized = (voice or "").strip().lower()
|
|
|
+ return normalized if normalized in allowed_voices else OPENAI_TTS_DEFAULT_VOICE
|
|
|
+
|
|
|
+
|
|
|
+def get_audio_media_type(audio_format: str) -> str:
|
|
|
+ mapping = {
|
|
|
+ "wav": "audio/wav",
|
|
|
+ "mp3": "audio/mpeg",
|
|
|
+ "flac": "audio/flac",
|
|
|
+ "opus": "audio/opus",
|
|
|
+ "pcm16": "audio/L16",
|
|
|
+ }
|
|
|
+ return mapping.get(audio_format.lower(), "application/octet-stream")
|
|
|
+
|
|
|
+
|
|
|
+async def request_openai_tts_audio(text: str, voice: str) -> bytes:
|
|
|
+ payload = {
|
|
|
+ "model": OPENAI_TTS_MODEL,
|
|
|
+ "voice": normalize_openai_voice(voice),
|
|
|
+ "input": text,
|
|
|
+ "response_format": OPENAI_TTS_FORMAT,
|
|
|
+ "speed": 1.0,
|
|
|
+ }
|
|
|
+
|
|
|
+ headers = {
|
|
|
+ "Authorization": f"Bearer {OPENAI_TTS_API_KEY}",
|
|
|
+ "Content-Type": "application/json",
|
|
|
+ }
|
|
|
+
|
|
|
+ async with aiohttp.ClientSession() as session:
|
|
|
+ async with session.post(
|
|
|
+ f"{OPENAI_TTS_BASE_URL.rstrip('/')}/v1/audio/speech",
|
|
|
+ headers=headers,
|
|
|
+ json=payload,
|
|
|
+ ) as response:
|
|
|
+ if response.status != 200:
|
|
|
+ response_text = await response.text()
|
|
|
+ logger.error("OpenAI TTS request failed: %s", response_text)
|
|
|
+ error_detail = "OpenAI TTS API 请求失败"
|
|
|
+ try:
|
|
|
+ error_data = json.loads(response_text)
|
|
|
+ error_obj = error_data.get("error", {})
|
|
|
+ error_message = error_obj.get("message")
|
|
|
+ error_code = error_obj.get("code")
|
|
|
+ if error_message:
|
|
|
+ error_detail = f"{error_detail}: {error_message}"
|
|
|
+ if response.status == 429 or error_code in {"rate_limit_exceeded", "model_not_found", "upstream_error"}:
|
|
|
+ raise HTTPException(status_code=503, detail=error_detail)
|
|
|
+ except json.JSONDecodeError:
|
|
|
+ pass
|
|
|
+ raise HTTPException(status_code=500 if response.status < 500 else 502, detail=error_detail)
|
|
|
+
|
|
|
+ return await response.read()
|
|
|
+
|
|
|
+
|
|
|
@app.post("/text-to-speech/")
|
|
|
async def text_to_speech(request: TextToSpeechRequest):
|
|
|
user_input = request.user_input.strip()
|
|
|
@@ -811,65 +885,23 @@ async def text_to_speech(request: TextToSpeechRequest):
|
|
|
raise HTTPException(status_code=400, detail="输入文本为空")
|
|
|
|
|
|
text_hash = hashlib.md5(user_input.encode("utf-8")).hexdigest()
|
|
|
- audio_path = os.path.join(CACHE_DIR, f"{text_hash}.wav")
|
|
|
+ audio_path = os.path.join(CACHE_DIR, f"{text_hash}.{OPENAI_TTS_FORMAT}")
|
|
|
+ media_type = get_audio_media_type(OPENAI_TTS_FORMAT)
|
|
|
|
|
|
if os.path.exists(audio_path):
|
|
|
with open(audio_path, "rb") as f:
|
|
|
- return Response(content=f.read(), media_type="audio/wav")
|
|
|
+ return Response(content=f.read(), media_type=media_type)
|
|
|
|
|
|
- async def audio_generator() -> AsyncGenerator[bytes, None]:
|
|
|
- try:
|
|
|
- async with aiohttp.ClientSession() as session:
|
|
|
- async with session.post(
|
|
|
- "http://141.140.15.30:8028/generate",
|
|
|
- headers={"Content-Type": "application/json"},
|
|
|
- json={"text": user_input, "voice": request.voice, "speed": request.speed},
|
|
|
- ) as response:
|
|
|
- if response.status != 200:
|
|
|
- raise HTTPException(status_code=500, detail="TTS API 请求失败")
|
|
|
-
|
|
|
- buffer = ""
|
|
|
- full_audio = io.BytesIO()
|
|
|
- async for chunk in response.content.iter_any():
|
|
|
- buffer += chunk.decode("utf-8")
|
|
|
- lines = buffer.split("\n")
|
|
|
- buffer = lines[-1]
|
|
|
-
|
|
|
- for line in lines[:-1]:
|
|
|
- if not line.strip():
|
|
|
- continue
|
|
|
- try:
|
|
|
- data = json.loads(line)
|
|
|
- if data.get("error"):
|
|
|
- raise HTTPException(status_code=500, detail=data["error"])
|
|
|
- audio_b64 = data.get("audio")
|
|
|
- if audio_b64:
|
|
|
- audio_bytes = base64.b64decode(audio_b64)
|
|
|
- full_audio.write(audio_bytes)
|
|
|
- yield audio_bytes
|
|
|
- except json.JSONDecodeError as e:
|
|
|
- logger.error(f"JSON decode error: {str(e)}")
|
|
|
- continue
|
|
|
-
|
|
|
- if buffer.strip():
|
|
|
- try:
|
|
|
- data = json.loads(buffer)
|
|
|
- if data.get("audio"):
|
|
|
- audio_bytes = base64.b64decode(data["audio"])
|
|
|
- full_audio.write(audio_bytes)
|
|
|
- yield audio_bytes
|
|
|
- except json.JSONDecodeError:
|
|
|
- pass
|
|
|
-
|
|
|
- full_audio.seek(0)
|
|
|
- with open(audio_path, "wb") as f:
|
|
|
- f.write(full_audio.getvalue())
|
|
|
-
|
|
|
- except Exception as e:
|
|
|
- logger.error(f"TTS error: {str(e)}")
|
|
|
- raise HTTPException(status_code=500, detail=str(e))
|
|
|
-
|
|
|
- return StreamingResponse(audio_generator(), media_type="audio/wav")
|
|
|
+ try:
|
|
|
+ audio_bytes = await request_openai_tts_audio(user_input, request.voice)
|
|
|
+ with open(audio_path, "wb") as f:
|
|
|
+ f.write(audio_bytes)
|
|
|
+ return Response(content=audio_bytes, media_type=media_type)
|
|
|
+ except HTTPException:
|
|
|
+ raise
|
|
|
+ except Exception as e:
|
|
|
+ logger.error(f"TTS error: {str(e)}")
|
|
|
+ raise HTTPException(status_code=500, detail="TTS生成失败")
|
|
|
|
|
|
|
|
|
MAX_CHUNK_SIZE = 200
|
|
|
@@ -900,56 +932,19 @@ def split_text_into_chunks(text: str, max_chunk_size: int = MAX_CHUNK_SIZE) -> l
|
|
|
|
|
|
async def generate_api_audio(chunk: str, voice: str, speed: float) -> AsyncGenerator[bytes, None]:
|
|
|
text_hash = hashlib.md5(chunk.encode("utf-8")).hexdigest()
|
|
|
- audio_path = os.path.join(CACHE_DIR, f"{text_hash}.wav")
|
|
|
+ audio_path = os.path.join(CACHE_DIR, f"{text_hash}.{OPENAI_TTS_FORMAT}")
|
|
|
|
|
|
if os.path.exists(audio_path):
|
|
|
with open(audio_path, "rb") as f:
|
|
|
yield f.read()
|
|
|
else:
|
|
|
try:
|
|
|
- async with aiohttp.ClientSession() as session:
|
|
|
- async with session.post(
|
|
|
- "http://141.140.15.30:8028/generate",
|
|
|
- headers={"Content-Type": "application/json"},
|
|
|
- json={"text": chunk, "voice": voice, "speed": speed},
|
|
|
- ) as response:
|
|
|
- if response.status != 200:
|
|
|
- raise HTTPException(status_code=500, detail="TTS API 请求失败")
|
|
|
-
|
|
|
- buffer = ""
|
|
|
- async for part in response.content.iter_any():
|
|
|
- buffer += part.decode("utf-8")
|
|
|
- lines = buffer.split("\n")
|
|
|
- buffer = lines[-1]
|
|
|
-
|
|
|
- for line in lines[:-1]:
|
|
|
- if not line.strip():
|
|
|
- continue
|
|
|
- try:
|
|
|
- data = json.loads(line)
|
|
|
- if data.get("error"):
|
|
|
- raise HTTPException(status_code=500, detail=data["error"])
|
|
|
- audio_b64 = data.get("audio")
|
|
|
- if audio_b64:
|
|
|
- audio_bytes = base64.b64decode(audio_b64)
|
|
|
- yield audio_bytes
|
|
|
- with open(audio_path, "wb") as f:
|
|
|
- f.write(audio_bytes)
|
|
|
- except json.JSONDecodeError as e:
|
|
|
- logger.error(f"JSON decode error: {str(e)}")
|
|
|
- continue
|
|
|
-
|
|
|
- if buffer.strip():
|
|
|
- try:
|
|
|
- data = json.loads(buffer)
|
|
|
- if data.get("audio"):
|
|
|
- audio_bytes = base64.b64decode(data["audio"])
|
|
|
- yield audio_bytes
|
|
|
- with open(audio_path, "wb") as f:
|
|
|
- f.write(audio_bytes)
|
|
|
- except json.JSONDecodeError:
|
|
|
- pass
|
|
|
-
|
|
|
+ audio_bytes = await request_openai_tts_audio(chunk, voice)
|
|
|
+ with open(audio_path, "wb") as f:
|
|
|
+ f.write(audio_bytes)
|
|
|
+ yield audio_bytes
|
|
|
+ except HTTPException as e:
|
|
|
+ raise HTTPException(status_code=e.status_code, detail=e.detail)
|
|
|
except Exception as e:
|
|
|
raise HTTPException(status_code=500, detail=f"TTS生成失败: {str(e)}")
|
|
|
|
|
|
@@ -961,10 +956,11 @@ async def page_to_speech(request: TextToSpeechRequest):
|
|
|
raise HTTPException(status_code=400, detail="输入文本为空")
|
|
|
|
|
|
full_text_hash = hashlib.md5(user_input.encode("utf-8")).hexdigest()
|
|
|
- full_audio_path = os.path.join(CACHE_DIR, f"{full_text_hash}_full.wav")
|
|
|
+ full_audio_path = os.path.join(CACHE_DIR, f"{full_text_hash}_full.{OPENAI_TTS_FORMAT}")
|
|
|
+ media_type = get_audio_media_type(OPENAI_TTS_FORMAT)
|
|
|
|
|
|
if os.path.exists(full_audio_path):
|
|
|
- return StreamingResponse(open(full_audio_path, "rb"), media_type="audio/wav")
|
|
|
+ return StreamingResponse(open(full_audio_path, "rb"), media_type=media_type)
|
|
|
|
|
|
chunks = split_text_into_chunks(user_input)
|
|
|
|
|
|
@@ -980,7 +976,7 @@ async def page_to_speech(request: TextToSpeechRequest):
|
|
|
with open(full_audio_path, "wb") as f:
|
|
|
f.write(full_audio_buffer.getvalue())
|
|
|
|
|
|
- return StreamingResponse(audio_generator(), media_type="audio/wav")
|
|
|
+ return StreamingResponse(audio_generator(), media_type=media_type)
|
|
|
|
|
|
|
|
|
@app.get("/health")
|