import asyncio
import base64
import hashlib
import hmac
import io
import json
import logging
import os
import re
import secrets
import shutil
from datetime import datetime, timedelta, timezone
from typing import AsyncGenerator, Optional

import aiohttp
import pymysql
from fastapi import FastAPI, File, UploadFile, HTTPException, Form, Response, Cookie
from fastapi.concurrency import run_in_threadpool
from fastapi.encoders import jsonable_encoder
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import (
    FileResponse,
    HTMLResponse,
    JSONResponse,
    RedirectResponse,
    StreamingResponse,
)
from fastapi.staticfiles import StaticFiles
from pydantic import BaseModel

from config import MYSQL_HOST, MYSQL_PORT, MYSQL_USER, MYSQL_PASSWORD, MYSQL_DATABASE
# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Initialize FastAPI app
app = FastAPI()
# Configure CORS
# NOTE(review): a wildcard origin combined with allow_credentials=True means any
# web origin is allowed to make credentialed (cookie-carrying) requests — as far
# as I know Starlette echoes the caller's Origin back in that case. Restrict
# `origins` to the real frontend origin(s) before exposing this publicly; verify.
origins = ["*"]
app.add_middleware(
    CORSMiddleware,
    allow_origins=origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Base directories
# Uploaded PDFs are stored per user under static/files/<sanitized username>/.
BASE_STATIC_FILES_DIR = "static/files"
os.makedirs(BASE_STATIC_FILES_DIR, exist_ok=True)
# Mount static files
# Keep this order: routes are matched in registration order, so the specific
# /static/files and /static/web prefixes must be mounted before the /static catch-all.
# NOTE(review): these mounts have no session check — anyone who knows a file URL
# can download another user's PDF.
app.mount("/static/files", StaticFiles(directory=BASE_STATIC_FILES_DIR), name="static_files")
app.mount("/static/web", StaticFiles(directory="static/web"), name="static_web")
app.mount("/static", StaticFiles(directory="static"), name="static")
# Audio cache directory
# The TTS endpoints cache synthesized audio here as <md5 hex>.wav files.
CACHE_DIR = "audio_cache"
os.makedirs(CACHE_DIR, exist_ok=True)
# Name of the HttpOnly cookie that carries the session token.
SESSION_COOKIE = "reader_pro_session"
# Session lifetime in days for a normal login vs. a "remember me" login.
SESSION_TTL_DAYS = 1
SESSION_TTL_DAYS_REMEMBER = 30
def db_conn():
    """Open a new autocommit pymysql connection to the configured MySQL database.

    Rows come back as dicts (DictCursor). There is no pooling: every caller in
    this module opens its own connection and closes it in a ``finally`` block.
    """
    settings = dict(
        host=MYSQL_HOST,
        port=MYSQL_PORT,
        user=MYSQL_USER,
        password=MYSQL_PASSWORD,
        database=MYSQL_DATABASE,
        charset="utf8mb4",
        autocommit=True,
        cursorclass=pymysql.cursors.DictCursor,
    )
    return pymysql.connect(**settings)
def init_db() -> None:
    """Create the application tables if they are missing and seed the default admin.

    Called once from the startup hook. All statements run in order on one
    autocommit connection; ``user`` must be created before ``user_progress``
    because of the foreign key on ``user_progress.user_id``.

    NOTE(review): the seeded account is admin/admin, and passwords are plain
    SHA-256 (see hash_password). Change this password on any shared deployment.
    """
    conn = db_conn()
    try:
        with conn.cursor() as cur:
            cur.execute(
                """
                CREATE TABLE IF NOT EXISTS user (
                    id BIGINT PRIMARY KEY AUTO_INCREMENT,
                    username VARCHAR(64) NOT NULL UNIQUE,
                    password_hash VARCHAR(255) NOT NULL,
                    is_admin TINYINT(1) NOT NULL DEFAULT 0,
                    is_active TINYINT(1) NOT NULL DEFAULT 1,
                    session_token VARCHAR(128) NULL,
                    session_expires_at DATETIME NULL,
                    last_file VARCHAR(1024) NULL,
                    last_page INT NULL,
                    created_at DATETIME NOT NULL,
                    updated_at DATETIME NOT NULL
                ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4
                """
            )
            # Backfill databases created before the is_active column existed.
            cur.execute("SHOW COLUMNS FROM user LIKE 'is_active'")
            if not cur.fetchone():
                cur.execute("ALTER TABLE user ADD COLUMN is_active TINYINT(1) NOT NULL DEFAULT 1 AFTER is_admin")
            # Per-user, per-file last page; rows are removed with their user (ON DELETE CASCADE).
            cur.execute(
                """
                CREATE TABLE IF NOT EXISTS user_progress (
                    id BIGINT PRIMARY KEY AUTO_INCREMENT,
                    user_id BIGINT NOT NULL,
                    file_path VARCHAR(512) NOT NULL,
                    page INT NOT NULL,
                    updated_at DATETIME NOT NULL,
                    UNIQUE KEY uniq_user_file (user_id, file_path),
                    KEY idx_user_updated (user_id, updated_at),
                    CONSTRAINT fk_user_progress_user
                        FOREIGN KEY (user_id) REFERENCES user(id)
                        ON DELETE CASCADE
                ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4
                """
            )
            # Key/value settings edited through the /admin/config endpoints.
            cur.execute(
                """
                CREATE TABLE IF NOT EXISTS user_config (
                    config_key VARCHAR(128) PRIMARY KEY,
                    config_value TEXT NULL,
                    updated_at DATETIME NOT NULL
                ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4
                """
            )
            # Seed the built-in admin account the first time (only if no "admin" row exists).
            cur.execute("SELECT id FROM user WHERE username=%s", ("admin",))
            admin = cur.fetchone()
            if not admin:
                # Timestamps are stored as naive UTC throughout this module.
                now = datetime.now(timezone.utc).replace(tzinfo=None)
                cur.execute(
                    """
                    INSERT INTO user (username, password_hash, is_admin, is_active, created_at, updated_at)
                    VALUES (%s, %s, 1, 1, %s, %s)
                    """,
                    ("admin", hash_password("admin"), now, now),
                )
                logger.info("Seeded default admin account: admin/admin")
    finally:
        conn.close()
def hash_password(password: str) -> str:
    """Return the hex SHA-256 digest of *password* encoded as UTF-8.

    Deterministic and unsalted, kept so that hashes already stored in
    ``user.password_hash`` keep verifying.
    NOTE(review): unsalted SHA-256 is weak for passwords. Moving to a salted KDF
    (hashlib.scrypt / pbkdf2_hmac, or bcrypt) needs a rehash-on-login migration.
    """
    digest = hashlib.sha256()
    digest.update(password.encode("utf-8"))
    return digest.hexdigest()
def sanitize_filename(name: str) -> str:
    """Keep only Unicode alphanumerics, spaces, '.', '_' and '-', then drop trailing spaces.

    Path separators are removed, but dots are kept, so the result can still be
    "." or "..". Callers that turn it into a path must guard against that.
    """
    kept = []
    for ch in name:
        if ch.isalnum() or ch in " ._-":
            kept.append(ch)
    return "".join(kept).rstrip()
def get_user_dir(username: str) -> str:
    """Return the per-user upload directory under BASE_STATIC_FILES_DIR, creating it if needed.

    The directory name is ``sanitize_filename(username)``. Names that sanitize
    to nothing, or to a bare "." or "..", fall back to "user". Otherwise "../"
    would become ".." and the returned path would point at the parent of
    BASE_STATIC_FILES_DIR (which admin_delete_user may rmtree).
    """
    safe = sanitize_filename(username)
    if safe in ("", ".", ".."):
        safe = "user"
    path = os.path.join(BASE_STATIC_FILES_DIR, safe)
    os.makedirs(path, exist_ok=True)
    return path
def build_user_file_url(username: str, filename: str) -> str:
    """Return the public /static/files URL of *filename* in *username*'s upload directory.

    The directory segment mirrors get_user_dir's fallback: names that sanitize
    to nothing map to "user", so the URL points at the directory the file was
    actually written to. "." and ".." are mapped the same way, so the URL can
    never address the parent directory.
    *filename* is inserted verbatim (not URL-encoded). Callers pass either a
    name built from sanitize_filename output or an entry from os.listdir.
    """
    segment = sanitize_filename(username)
    if segment in ("", ".", ".."):
        segment = "user"
    return f"/static/files/{segment}/{filename}"
def get_user_by_session(session_token: Optional[str]):
    """Return the ``user`` row (dict) that owns an unexpired *session_token*, or None.

    Disabled accounts (``is_active = 0``) are treated as having no session.
    Without this, a user disabled by an admin kept full access until their
    token expired, because only the login endpoint checked ``is_active``.
    Expiry is compared in naive UTC, matching how set_session_for_user
    stores ``session_expires_at``.
    """
    if not session_token:
        return None
    now = datetime.now(timezone.utc).replace(tzinfo=None)
    conn = db_conn()
    try:
        with conn.cursor() as cur:
            cur.execute(
                """
                SELECT * FROM user
                WHERE session_token=%s
                  AND session_expires_at IS NOT NULL
                  AND session_expires_at>%s
                  AND is_active=1
                """,
                (session_token, now),
            )
            return cur.fetchone()
    finally:
        conn.close()
def require_user(session_token: Optional[str]):
    """Return the user row for *session_token*, or reject the request.

    Raises:
        HTTPException(401): no token, an unknown token, or an expired session.
    """
    current = get_user_by_session(session_token)
    if current:
        return current
    raise HTTPException(status_code=401, detail="未登录或会话已过期")
def require_admin(session_token: Optional[str]):
    """Return the session's user row if it has ``is_admin`` set.

    Raises:
        HTTPException(401): not logged in (propagated from require_user).
        HTTPException(403): logged in but not an administrator.
    """
    current = require_user(session_token)
    if current.get("is_admin"):
        return current
    raise HTTPException(status_code=403, detail="需要管理员权限")
def set_session_for_user(username: str, remember_me: bool):
    """Issue a fresh random session token for *username* and store it on the user row.

    The previous token, if any, is overwritten, so each user has at most one
    live session. Returns ``(token, lifetime_days)``; the login endpoint uses
    the lifetime for the cookie's max-age.
    """
    lifetime_days = SESSION_TTL_DAYS_REMEMBER if remember_me else SESSION_TTL_DAYS
    new_token = secrets.token_urlsafe(48)
    expiry = datetime.now(timezone.utc).replace(tzinfo=None) + timedelta(days=lifetime_days)
    conn = db_conn()
    try:
        with conn.cursor() as cur:
            stamp = datetime.now(timezone.utc).replace(tzinfo=None)
            cur.execute(
                """
                UPDATE user
                SET session_token=%s, session_expires_at=%s, updated_at=%s
                WHERE username=%s
                """,
                (new_token, expiry, stamp, username),
            )
    finally:
        conn.close()
    return new_token, lifetime_days
@app.on_event("startup")
def on_startup():
    """Create missing tables and the default admin before the app serves requests."""
    # NOTE(review): on_event is deprecated in recent FastAPI in favour of a lifespan handler.
    init_db()
@app.get("/")
def root(session_token: Optional[str] = Cookie(default=None, alias=SESSION_COOKIE)):
    """Redirect to the PDF viewer, opened on the most relevant document for the caller.

    Preference order:
    1. /login when there is no valid session;
    2. the user's ``last_file``;
    3. the first PDF in the user's directory;
    4. the bundled sample /static/files/compress.pdf.

    "First" uses the same case-insensitive name order that /list-pdfs returns,
    so this fallback matches the top entry of the user's file list.
    """
    user = get_user_by_session(session_token)
    if not user:
        return RedirectResponse(url="/login", status_code=302)
    last_file = user.get("last_file")
    if last_file:
        return RedirectResponse(url=f"/static/web/viewer.html?file={last_file}", status_code=302)
    # Fallback: first PDF in the user's own directory (min() avoids sorting the whole listing).
    user_dir = get_user_dir(user["username"])
    first_pdf = min(
        (f for f in os.listdir(user_dir) if f.lower().endswith(".pdf")),
        key=lambda name: (name.lower(), name),
        default=None,
    )
    if first_pdf is not None:
        file_url = build_user_file_url(user["username"], first_pdf)
        return RedirectResponse(url=f"/static/web/viewer.html?file={file_url}", status_code=302)
    # Fallback to the sample document.
    return RedirectResponse(url="/static/web/viewer.html?file=/static/files/compress.pdf", status_code=302)
@app.get("/login")
def login_page(session_token: Optional[str] = Cookie(default=None, alias=SESSION_COOKIE)):
    """Serve the login page. Callers with a valid session are redirected to "/" instead."""
    if get_user_by_session(session_token):
        return RedirectResponse(url="/", status_code=302)
    # NOTE(review): the page body is only a title line here; confirm the full HTML was not lost.
    page = "\nVoiceFlow AI Reader 【小满TTS英文听书】 - 登录\n"
    return HTMLResponse(content=page)
@app.get("/register")
def register_page(session_token: Optional[str] = Cookie(default=None, alias=SESSION_COOKIE)):
    """Serve the sign-up page. Callers with a valid session are redirected to "/" instead."""
    if get_user_by_session(session_token):
        return RedirectResponse(url="/", status_code=302)
    # NOTE(review): the page body is only a title line here; confirm the full HTML was not lost.
    page = "\nVoiceFlow AI Reader 【小满TTS英文听书】 - 注册\n"
    return HTMLResponse(content=page)
class LoginRequest(BaseModel):
    # JSON body of POST /auth/login; username is sanitized and both fields are stripped there.
    username: str
    password: str
    # True -> session lasts SESSION_TTL_DAYS_REMEMBER days instead of SESSION_TTL_DAYS.
    remember_me: bool = False
class RegisterRequest(BaseModel):
    # JSON body of POST /auth/register (username >= 3 chars after sanitizing, password >= 4).
    username: str
    password: str
@app.post("/auth/register")
async def auth_register(request: RegisterRequest):
    """Self-service sign-up that creates a regular (non-admin) account.

    The username goes through sanitize_filename because it also names the
    user's upload directory. Responds 400 with ``{"success": False, "error": ...}``
    for a too-short username/password or a taken name, else ``{"success": True}``.
    """
    # pymysql is blocking; run the work in the threadpool so the event loop
    # (and any in-flight TTS streams) is not stalled by database round-trips.
    return await run_in_threadpool(_register_user, request)


def _register_user(request: RegisterRequest) -> JSONResponse:
    """Synchronous body of auth_register: validate, insert the row, create the user directory."""
    username = sanitize_filename((request.username or "").strip())
    password = (request.password or "").strip()
    if len(username) < 3:
        return JSONResponse(status_code=400, content={"success": False, "error": "用户名至少3位"})
    if len(password) < 4:
        return JSONResponse(status_code=400, content={"success": False, "error": "密码至少4位"})
    now = datetime.now(timezone.utc).replace(tzinfo=None)
    conn = db_conn()
    try:
        with conn.cursor() as cur:
            cur.execute("SELECT id FROM user WHERE username=%s", (username,))
            if cur.fetchone():
                return JSONResponse(status_code=400, content={"success": False, "error": "用户名已存在"})
            try:
                cur.execute(
                    """
                    INSERT INTO user (username, password_hash, is_admin, created_at, updated_at)
                    VALUES (%s, %s, 0, %s, %s)
                    """,
                    (username, hash_password(password), now, now),
                )
            except pymysql.err.IntegrityError:
                # A concurrent sign-up won the race between the SELECT and this INSERT
                # (UNIQUE username). Report it like any taken name instead of a 500.
                return JSONResponse(status_code=400, content={"success": False, "error": "用户名已存在"})
    finally:
        conn.close()
    get_user_dir(username)
    return JSONResponse(content={"success": True})
@app.post("/auth/login")
async def auth_login(request: LoginRequest):
    """Verify credentials and, on success, set the session cookie.

    Responses: 401 for an unknown user or wrong password, 403 for a disabled
    account, otherwise ``{"success": True, "is_admin": ...}`` plus an HttpOnly
    cookie whose max-age matches the session lifetime.
    """
    # Blocking DB work runs in the threadpool to keep the event loop free.
    return await run_in_threadpool(_login_user, request)


def _login_user(request: LoginRequest) -> JSONResponse:
    """Synchronous body of auth_login."""
    username = sanitize_filename((request.username or "").strip())
    password = (request.password or "").strip()
    conn = db_conn()
    try:
        with conn.cursor() as cur:
            cur.execute("SELECT * FROM user WHERE username=%s", (username,))
            user = cur.fetchone()
    finally:
        conn.close()
    # Constant-time comparison of the stored and freshly computed hex digests.
    if not user or not hmac.compare_digest(user["password_hash"], hash_password(password)):
        return JSONResponse(status_code=401, content={"success": False, "error": "用户名或密码错误"})
    if int(user.get("is_active", 1)) != 1:
        return JSONResponse(status_code=403, content={"success": False, "error": "账号已被禁用"})
    # Use the stored spelling: the lookup above may have matched case-insensitively.
    token, expire_days = set_session_for_user(user["username"], request.remember_me)
    resp = JSONResponse(content={"success": True, "is_admin": bool(user.get("is_admin"))})
    max_age = expire_days * 24 * 3600
    resp.set_cookie(
        key=SESSION_COOKIE,
        value=token,
        max_age=max_age,
        httponly=True,
        samesite="lax",
        # NOTE(review): should be True once the app is served over HTTPS.
        secure=False,
        path="/",
    )
    return resp
@app.post("/auth/logout")
async def auth_logout(session_token: Optional[str] = Cookie(default=None, alias=SESSION_COOKIE)):
    """Revoke the caller's stored session (if still valid) and delete the cookie.

    Always answers ``{"success": True}``.
    """
    # Blocking DB work runs in the threadpool to keep the event loop free.
    await run_in_threadpool(_clear_session, session_token)
    resp = JSONResponse(content={"success": True})
    resp.delete_cookie(SESSION_COOKIE, path="/")
    return resp


def _clear_session(session_token: Optional[str]) -> None:
    """Null out the session columns of the user that owns *session_token*, if any."""
    user = get_user_by_session(session_token)
    if not user:
        return
    conn = db_conn()
    try:
        with conn.cursor() as cur:
            cur.execute(
                "UPDATE user SET session_token=NULL, session_expires_at=NULL, updated_at=%s WHERE id=%s",
                (datetime.now(timezone.utc).replace(tzinfo=None), user["id"]),
            )
    finally:
        conn.close()
@app.get("/auth/me")
async def auth_me(session_token: Optional[str] = Cookie(default=None, alias=SESSION_COOKIE)):
    """Describe the logged-in user (username and admin/active flags), or answer 401."""
    # The session lookup hits MySQL synchronously; keep it off the event loop.
    user = await run_in_threadpool(get_user_by_session, session_token)
    if not user:
        return JSONResponse(status_code=401, content={"success": False})
    return JSONResponse(
        content={
            "success": True,
            "username": user["username"],
            "is_admin": bool(user.get("is_admin")),
            "is_active": bool(user.get("is_active", 1)),
        }
    )
@app.get("/admin")
def admin_page(session_token: Optional[str] = Cookie(default=None, alias=SESSION_COOKIE)):
    """Send administrators to the admin console; everyone else gets 401/403 from require_admin.

    NOTE(review): admin.html itself is served by the public /static/web mount, so this
    redirect is a convenience, not an access control; the /admin/* APIs do the checks.
    """
    require_admin(session_token)
    console_url = "/static/web/admin.html"
    return RedirectResponse(url=console_url, status_code=302)
class AdminUserRequest(BaseModel):
    # JSON body of POST /admin/users (create a user, or reset an existing one's password/admin flag).
    username: str
    password: str
    # Stored into user.is_admin as 1/0.
    is_admin: bool = False
@app.get("/admin/users")
async def admin_users(session_token: Optional[str] = Cookie(default=None, alias=SESSION_COOKIE)):
    """List every account (id, username, is_admin, is_active, created_at) ordered by id. Admin only."""
    # Blocking DB work runs in the threadpool to keep the event loop free.
    return await run_in_threadpool(_list_users, session_token)


def _list_users(session_token: Optional[str]) -> JSONResponse:
    """Synchronous body of admin_users."""
    require_admin(session_token)
    conn = db_conn()
    try:
        with conn.cursor() as cur:
            cur.execute("SELECT id, username, is_admin, is_active, created_at FROM user ORDER BY id ASC")
            users = cur.fetchall()
    finally:
        conn.close()
    # jsonable_encoder turns the DATETIME values into ISO strings.
    return JSONResponse(content=jsonable_encoder({"success": True, "users": users}))
@app.post("/admin/users")
async def admin_create_or_reset_user(payload: AdminUserRequest, session_token: Optional[str] = Cookie(default=None, alias=SESSION_COOKIE)):
    """Create an account, or overwrite an existing one's password and admin flag. Admin only.

    An existing account is also re-enabled (is_active=1). Removing admin rights
    from the built-in "admin" account is refused (400), consistent with the
    delete/disable endpoints, which also protect it.
    """
    # Blocking DB work runs in the threadpool to keep the event loop free.
    return await run_in_threadpool(_upsert_user, payload, session_token)


def _upsert_user(payload: AdminUserRequest, session_token: Optional[str]) -> JSONResponse:
    """Synchronous body of admin_create_or_reset_user."""
    require_admin(session_token)
    username = sanitize_filename((payload.username or "").strip())
    password = (payload.password or "").strip()
    if len(username) < 3 or len(password) < 4:
        return JSONResponse(status_code=400, content={"success": False, "error": "用户名或密码不合法"})
    is_admin = 1 if payload.is_admin else 0
    now = datetime.now(timezone.utc).replace(tzinfo=None)
    conn = db_conn()
    try:
        with conn.cursor() as cur:
            cur.execute("SELECT id, username FROM user WHERE username=%s", (username,))
            row = cur.fetchone()
            if row:
                # Check the *stored* name: the table has no explicit COLLATE, so with MySQL's
                # usual case-insensitive default "ADMIN" also matches the built-in admin row.
                if row["username"] == "admin" and not is_admin:
                    return JSONResponse(
                        status_code=400,
                        content={"success": False, "error": "不能取消admin的管理员权限"},
                    )
                username = row["username"]
                cur.execute(
                    "UPDATE user SET password_hash=%s, is_admin=%s, is_active=1, updated_at=%s WHERE id=%s",
                    (hash_password(password), is_admin, now, row["id"]),
                )
            else:
                cur.execute(
                    "INSERT INTO user (username, password_hash, is_admin, is_active, created_at, updated_at) VALUES (%s,%s,%s,1,%s,%s)",
                    (username, hash_password(password), is_admin, now, now),
                )
    finally:
        conn.close()
    get_user_dir(username)
    return JSONResponse(content={"success": True})
@app.delete("/admin/users/{username}")
async def admin_delete_user(
    username: str,
    delete_files: bool = False,
    session_token: Optional[str] = Cookie(default=None, alias=SESSION_COOKIE),
):
    """Delete an account (its progress rows cascade) and optionally its upload directory. Admin only.

    The built-in "admin" account cannot be deleted (400), in any letter case.
    With ``delete_files`` the user's directory is removed only if it is a
    direct child of BASE_STATIC_FILES_DIR. The response is ``{"success": True}``
    whether or not the account existed.
    """
    # Blocking DB/filesystem work runs in the threadpool to keep the event loop free.
    return await run_in_threadpool(_delete_user, username, delete_files, session_token)


def _delete_user(username: str, delete_files: bool, session_token: Optional[str]) -> JSONResponse:
    """Synchronous body of admin_delete_user."""
    require_admin(session_token)
    if username == "admin":
        return JSONResponse(status_code=400, content={"success": False, "error": "不能删除admin"})
    conn = db_conn()
    try:
        with conn.cursor() as cur:
            # Look the row up first and check the stored name. The table has no explicit
            # COLLATE, and MySQL's usual default is case-insensitive, so "ADMIN" would
            # otherwise slip past the check above and delete the built-in admin.
            cur.execute("SELECT id, username FROM user WHERE username=%s", (username,))
            row = cur.fetchone()
            if row and row["username"] == "admin":
                return JSONResponse(status_code=400, content={"success": False, "error": "不能删除admin"})
            if row:
                cur.execute("DELETE FROM user WHERE id=%s", (row["id"],))
    finally:
        conn.close()
    if delete_files:
        _remove_user_dir(row["username"] if row else username)
    return JSONResponse(content={"success": True})


def _remove_user_dir(username: str) -> None:
    """Best-effort rmtree of BASE_STATIC_FILES_DIR/<sanitized username> if it is a direct child directory."""
    base = os.path.realpath(BASE_STATIC_FILES_DIR)
    target = os.path.realpath(os.path.join(BASE_STATIC_FILES_DIR, sanitize_filename(username)))
    # "" / "." sanitize to the base itself and ".." to its parent. rmtree on either would
    # wipe every user's files (or all of static/), so only a direct child is accepted.
    if os.path.dirname(target) != base or not os.path.isdir(target):
        return
    shutil.rmtree(target, ignore_errors=True)
class AdminResetPasswordRequest(BaseModel):
    # JSON body of POST /admin/users/{username}/reset-password; stripped, must be >= 4 chars.
    password: str
@app.post("/admin/users/{username}/reset-password")
async def admin_reset_password(
    username: str,
    payload: AdminResetPasswordRequest,
    session_token: Optional[str] = Cookie(default=None, alias=SESSION_COOKIE),
):
    """Set a new password for *username*. Admin only; 400 if too short, 404 if no such user.

    NOTE(review): existing sessions of that user stay valid after the reset.
    """
    # Blocking DB work runs in the threadpool to keep the event loop free.
    return await run_in_threadpool(_reset_password, username, payload, session_token)


def _reset_password(username: str, payload: AdminResetPasswordRequest, session_token: Optional[str]) -> JSONResponse:
    """Synchronous body of admin_reset_password."""
    require_admin(session_token)
    password = (payload.password or "").strip()
    if len(password) < 4:
        return JSONResponse(status_code=400, content={"success": False, "error": "密码至少4位"})
    now = datetime.now(timezone.utc).replace(tzinfo=None)
    conn = db_conn()
    try:
        with conn.cursor() as cur:
            # Existence is checked with a SELECT rather than UPDATE's rowcount. db_conn does
            # not set CLIENT.FOUND_ROWS, so rowcount counts *changed* rows. Re-applying the
            # same password within the same second (updated_at is DATETIME) would report 0
            # and produce a spurious 404.
            cur.execute("SELECT id FROM user WHERE username=%s", (username,))
            row = cur.fetchone()
            if not row:
                return JSONResponse(status_code=404, content={"success": False, "error": "用户不存在"})
            cur.execute(
                "UPDATE user SET password_hash=%s, updated_at=%s WHERE id=%s",
                (hash_password(password), now, row["id"]),
            )
    finally:
        conn.close()
    return JSONResponse(content={"success": True})
class AdminToggleUserRequest(BaseModel):
    # JSON body of POST /admin/users/{username}/status; stored into user.is_active as 1/0.
    is_active: bool
@app.post("/admin/users/{username}/status")
async def admin_toggle_user_status(
    username: str,
    payload: AdminToggleUserRequest,
    session_token: Optional[str] = Cookie(default=None, alias=SESSION_COOKIE),
):
    """Enable or disable an account. Admin only.

    The built-in "admin" account cannot be disabled (400), in any letter case;
    an unknown user gives 404. Disabling also revokes the user's current session.
    """
    # Blocking DB work runs in the threadpool to keep the event loop free.
    return await run_in_threadpool(_set_user_status, username, payload, session_token)


def _set_user_status(username: str, payload: AdminToggleUserRequest, session_token: Optional[str]) -> JSONResponse:
    """Synchronous body of admin_toggle_user_status."""
    require_admin(session_token)
    if username == "admin" and not payload.is_active:
        return JSONResponse(status_code=400, content={"success": False, "error": "不能禁用admin"})
    now = datetime.now(timezone.utc).replace(tzinfo=None)
    conn = db_conn()
    try:
        with conn.cursor() as cur:
            # SELECT first: the stored name is what must be protected ("ADMIN" can match the
            # admin row under a case-insensitive collation). UPDATE's rowcount counts only
            # changed rows, so it cannot reliably signal "no such user".
            cur.execute("SELECT id, username FROM user WHERE username=%s", (username,))
            row = cur.fetchone()
            if not row:
                return JSONResponse(status_code=404, content={"success": False, "error": "用户不存在"})
            if row["username"] == "admin" and not payload.is_active:
                return JSONResponse(status_code=400, content={"success": False, "error": "不能禁用admin"})
            if payload.is_active:
                cur.execute("UPDATE user SET is_active=1, updated_at=%s WHERE id=%s", (now, row["id"]))
            else:
                # Revoke the live session too, so the lock-out takes effect immediately.
                cur.execute(
                    "UPDATE user SET is_active=0, session_token=NULL, session_expires_at=NULL, updated_at=%s WHERE id=%s",
                    (now, row["id"]),
                )
    finally:
        conn.close()
    return JSONResponse(content={"success": True})
class AdminConfigRequest(BaseModel):
    # JSON body of POST /admin/config; upserted into the user_config key/value table.
    config_key: str
    # None is stored as SQL NULL.
    config_value: Optional[str] = None
@app.get("/admin/config")
async def admin_get_config(session_token: Optional[str] = Cookie(default=None, alias=SESSION_COOKIE)):
    """Return every user_config row (key, value, updated_at) ordered by key. Admin only."""
    # Blocking DB work runs in the threadpool to keep the event loop free.
    return await run_in_threadpool(_load_config, session_token)


def _load_config(session_token: Optional[str]) -> JSONResponse:
    """Synchronous body of admin_get_config."""
    require_admin(session_token)
    conn = db_conn()
    try:
        with conn.cursor() as cur:
            cur.execute("SELECT config_key, config_value, updated_at FROM user_config ORDER BY config_key")
            rows = cur.fetchall()
    finally:
        conn.close()
    return JSONResponse(content=jsonable_encoder({"success": True, "configs": rows}))
@app.post("/admin/config")
async def admin_set_config(payload: AdminConfigRequest, session_token: Optional[str] = Cookie(default=None, alias=SESSION_COOKIE)):
    """Insert or overwrite one user_config entry. Admin only.

    The key is stripped. Responds 400 if it is empty or longer than the
    128-character ``config_key`` column; the DB would otherwise raise a 500.
    """
    # Blocking DB work runs in the threadpool to keep the event loop free.
    return await run_in_threadpool(_store_config, payload, session_token)


def _store_config(payload: AdminConfigRequest, session_token: Optional[str]) -> JSONResponse:
    """Synchronous body of admin_set_config."""
    require_admin(session_token)
    config_key = (payload.config_key or "").strip()
    if not config_key:
        return JSONResponse(status_code=400, content={"success": False, "error": "config_key 不能为空"})
    if len(config_key) > 128:
        return JSONResponse(status_code=400, content={"success": False, "error": "config_key 过长"})
    now = datetime.now(timezone.utc).replace(tzinfo=None)
    conn = db_conn()
    try:
        with conn.cursor() as cur:
            cur.execute(
                """
                INSERT INTO user_config (config_key, config_value, updated_at)
                VALUES (%s, %s, %s)
                ON DUPLICATE KEY UPDATE config_value=VALUES(config_value), updated_at=VALUES(updated_at)
                """,
                (config_key, payload.config_value, now),
            )
    finally:
        conn.close()
    return JSONResponse(content={"success": True})
# PDF upload endpoint
@app.post("/upload-pdf")
async def upload_pdf(
    file: UploadFile = File(...),
    custom_name: str = Form(...),
    session_token: Optional[str] = Cookie(default=None, alias=SESSION_COOKIE),
):
    """Save an uploaded PDF as ``<user dir>/<sanitized custom_name>.pdf``.

    On success the file becomes the user's ``last_file`` (page 1) and the
    response is ``{"success": True, "file_path": <public URL>}``. Responds with
    400 JSON when the name sanitizes to nothing or the file already exists.

    Raises:
        HTTPException(401): no valid session.
        HTTPException(400): the declared content type is not application/pdf.
        HTTPException(500): the file could not be written.
    """
    # Blocking disk and DB work runs in the threadpool to keep the event loop free.
    return await run_in_threadpool(_store_uploaded_pdf, file, custom_name, session_token)


def _store_uploaded_pdf(file: UploadFile, custom_name: str, session_token: Optional[str]) -> JSONResponse:
    """Synchronous body of upload_pdf."""
    user = require_user(session_token)
    # NOTE(review): content_type is whatever the client declared; the bytes are not inspected.
    if file.content_type != "application/pdf":
        raise HTTPException(status_code=400, detail="文件类型必须是 PDF")
    sanitized_name = sanitize_filename(custom_name)
    if not sanitized_name:
        return JSONResponse(status_code=400, content={"success": False, "error": "无效的文件名"})
    unique_filename = f"{sanitized_name}.pdf"
    user_dir = get_user_dir(user["username"])
    file_path = os.path.join(user_dir, unique_filename)
    created = False
    try:
        # "xb" creates the file exclusively. An exists-check followed by "wb" would let
        # two concurrent uploads with the same name overwrite each other.
        with open(file_path, "xb") as buffer:
            created = True
            shutil.copyfileobj(file.file, buffer)
    except FileExistsError:
        return JSONResponse(status_code=400, content={"success": False, "error": "文件名已存在,请使用其他名称"})
    except Exception as e:
        logger.exception("Failed to store upload %s", file_path)
        if created:
            # Remove the truncated file; otherwise it would block re-uploading under this name.
            try:
                os.remove(file_path)
            except OSError:
                logger.warning("Could not remove partial upload %s", file_path)
        raise HTTPException(status_code=500, detail="上传过程中出错") from e
    finally:
        file.file.close()
    file_relative_path = build_user_file_url(user["username"], unique_filename)
    now = datetime.now(timezone.utc).replace(tzinfo=None)
    conn = db_conn()
    try:
        with conn.cursor() as cur:
            cur.execute(
                "UPDATE user SET last_file=%s, last_page=1, updated_at=%s WHERE id=%s",
                (file_relative_path, now, user["id"]),
            )
    finally:
        conn.close()
    return JSONResponse(content={"success": True, "file_path": file_relative_path})
# List PDFs endpoint
@app.get("/list-pdfs")
async def list_pdfs(session_token: Optional[str] = Cookie(default=None, alias=SESSION_COOKIE)):
    """Return the caller's PDFs as ``{"name", "url"}`` entries, sorted case-insensitively by name.

    Raises:
        HTTPException(401): no valid session.
        HTTPException(500): the user's directory could not be created or read.
    """
    # Blocking filesystem/DB work runs in the threadpool to keep the event loop free.
    return await run_in_threadpool(_list_user_pdfs, session_token)


def _list_user_pdfs(session_token: Optional[str]) -> JSONResponse:
    """Synchronous body of list_pdfs."""
    user = require_user(session_token)
    try:
        user_dir = get_user_dir(user["username"])
        files = os.listdir(user_dir)
    except OSError as e:
        logger.exception("Cannot list PDFs for %s", user["username"])
        raise HTTPException(status_code=500, detail="无法获取文件列表") from e
    pdf_files = [
        {"name": name, "url": build_user_file_url(user["username"], name)}
        for name in files
        if name.lower().endswith(".pdf")
    ]
    pdf_files.sort(key=lambda entry: entry["name"].lower())
    return JSONResponse(content={"success": True, "files": pdf_files})
class ReadingProgressRequest(BaseModel):
    # JSON body of POST /reading-progress.
    # file: document identifier as sent by the viewer; stripped and used as user_progress.file_path.
    file: str
    # 1-based page number; values < 1 are rejected by the endpoint.
    page: int
@app.get("/reading-progress")
async def get_reading_progress(file: str, session_token: Optional[str] = Cookie(default=None, alias=SESSION_COOKIE)):
    """Return the caller's saved page for *file*; ``page`` is null when nothing was saved.

    Responds 400 when *file* is blank after stripping.
    """
    # Blocking DB work runs in the threadpool to keep the event loop free.
    return await run_in_threadpool(_load_progress, file, session_token)


def _load_progress(file: str, session_token: Optional[str]) -> JSONResponse:
    """Synchronous body of get_reading_progress."""
    user = require_user(session_token)
    normalized_file = (file or "").strip()
    if not normalized_file:
        return JSONResponse(status_code=400, content={"success": False, "error": "缺少 file 参数"})
    conn = db_conn()
    try:
        with conn.cursor() as cur:
            cur.execute(
                "SELECT page FROM user_progress WHERE user_id=%s AND file_path=%s",
                (user["id"], normalized_file),
            )
            row = cur.fetchone()
            page = row["page"] if row else None
    finally:
        conn.close()
    return JSONResponse(content={"success": True, "file": normalized_file, "page": page})
@app.post("/reading-progress")
async def save_reading_progress(request: ReadingProgressRequest, session_token: Optional[str] = Cookie(default=None, alias=SESSION_COOKIE)):
    """Upsert the caller's page for a file and make that file their ``last_file``.

    Responds 400 for a blank file, a file longer than the 512-character
    ``user_progress.file_path`` column (the DB would otherwise fail with a 500),
    or a page < 1.
    """
    # Blocking DB work runs in the threadpool to keep the event loop free.
    return await run_in_threadpool(_store_progress, request, session_token)


def _store_progress(request: ReadingProgressRequest, session_token: Optional[str]) -> JSONResponse:
    """Synchronous body of save_reading_progress."""
    user = require_user(session_token)
    normalized_file = (request.file or "").strip()
    page = request.page
    if not normalized_file:
        return JSONResponse(status_code=400, content={"success": False, "error": "file 不能为空"})
    if len(normalized_file) > 512:
        return JSONResponse(status_code=400, content={"success": False, "error": "file 过长"})
    if page < 1:
        return JSONResponse(status_code=400, content={"success": False, "error": "page 必须 >= 1"})
    now = datetime.now(timezone.utc).replace(tzinfo=None)
    conn = db_conn()
    try:
        with conn.cursor() as cur:
            cur.execute(
                """
                INSERT INTO user_progress (user_id, file_path, page, updated_at)
                VALUES (%s, %s, %s, %s)
                ON DUPLICATE KEY UPDATE page=VALUES(page), updated_at=VALUES(updated_at)
                """,
                (user["id"], normalized_file, page, now),
            )
            cur.execute(
                "UPDATE user SET last_file=%s, last_page=%s, updated_at=%s WHERE id=%s",
                (normalized_file, page, now, user["id"]),
            )
    finally:
        conn.close()
    return JSONResponse(content={"success": True})
class TextToSpeechRequest(BaseModel):
    # JSON body of the /text-to-speech/ and /page-to-speech/ endpoints.
    # Text to synthesize; stripped by the endpoints, and empty text is rejected with 400.
    user_input: str
    # Voice id forwarded unchanged to the remote TTS service.
    voice: str = "af_heart"
    # Speed value forwarded unchanged to the remote TTS service.
    speed: float = 1.0
@app.post("/text-to-speech/")
async def text_to_speech(request: TextToSpeechRequest):
    """Synthesize ``request.user_input`` with the remote TTS service and return the audio.

    A cached result is returned in a single ``Response``. Otherwise the
    upstream reply is decoded and streamed segment by segment. The reply is
    newline-delimited JSON objects, each carrying a base64 ``audio`` field or
    an ``error`` field. The concatenated audio is cached in CACHE_DIR once the
    whole reply has been read.

    The cache key covers text, voice and speed. Keying on the text alone would
    replay audio that was rendered with a different voice or speed.

    Raises:
        HTTPException(400): the input is empty after stripping.
    NOTE(review): upstream failures are raised from inside the stream, i.e.
    presumably after StreamingResponse has already sent a 200 status line.
    """
    user_input = request.user_input.strip()
    if not user_input:
        raise HTTPException(status_code=400, detail="输入文本为空")

    cache_key = f"{request.voice}\n{request.speed}\n{user_input}"
    text_hash = hashlib.md5(cache_key.encode("utf-8")).hexdigest()
    audio_path = os.path.join(CACHE_DIR, f"{text_hash}.wav")
    if os.path.exists(audio_path):
        with open(audio_path, "rb") as f:
            return Response(content=f.read(), media_type="audio/wav")

    def decode_line(line: bytes) -> Optional[bytes]:
        """Return the audio carried by one reply line; None for blank/malformed lines; raise on an upstream error."""
        if not line.strip():
            return None
        try:
            data = json.loads(line)
        except ValueError as e:  # JSONDecodeError, or invalid UTF-8 in the line
            logger.error("JSON decode error: %s", e)
            return None
        if not isinstance(data, dict):
            return None
        if data.get("error"):
            raise HTTPException(status_code=500, detail=data["error"])
        audio_b64 = data.get("audio")
        return base64.b64decode(audio_b64) if audio_b64 else None

    async def audio_generator() -> AsyncGenerator[bytes, None]:
        full_audio = io.BytesIO()
        try:
            async with aiohttp.ClientSession() as session:
                async with session.post(
                    "http://141.140.15.30:8028/generate",
                    headers={"Content-Type": "application/json"},
                    json={"text": user_input, "voice": request.voice, "speed": request.speed},
                ) as response:
                    if response.status != 200:
                        raise HTTPException(status_code=500, detail="TTS API 请求失败")
                    # Split raw bytes on b"\n". Decoding each network chunk separately could
                    # cut a multi-byte UTF-8 character in half and raise UnicodeDecodeError.
                    buffer = b""
                    async for chunk in response.content.iter_any():
                        buffer += chunk
                        *lines, buffer = buffer.split(b"\n")
                        for line in lines:
                            audio_bytes = decode_line(line)
                            if audio_bytes:
                                full_audio.write(audio_bytes)
                                yield audio_bytes
                    # The last line of the reply may have no trailing newline.
                    audio_bytes = decode_line(buffer)
                    if audio_bytes:
                        full_audio.write(audio_bytes)
                        yield audio_bytes
        except HTTPException:
            logger.exception("TTS error")
            raise
        except Exception as e:
            logger.exception("TTS error")
            raise HTTPException(status_code=500, detail=str(e)) from e

        audio = full_audio.getvalue()
        if not audio:
            # Never cache an empty result: it would be replayed for every later request.
            return
        # Write to a unique temp file, then rename into place, so a concurrent request
        # never reads a half-written cache entry.
        tmp_path = f"{audio_path}.{secrets.token_hex(8)}.tmp"
        try:
            with open(tmp_path, "wb") as f:
                f.write(audio)
            os.replace(tmp_path, audio_path)
        except OSError:
            # Caching is best-effort; the audio has already been streamed to the client.
            logger.exception("Failed to cache TTS audio at %s", audio_path)

    return StreamingResponse(audio_generator(), media_type="audio/wav")
MAX_CHUNK_SIZE = 200

# Sentence boundary: a run of whitespace (spaces, tabs or newlines) right after '.', '!' or '?'.
_SENTENCE_BOUNDARY_RE = re.compile(r"(?<=[.!?])\s+")


def _split_long_sentence(sentence: str, max_chunk_size: int) -> list:
    """Break *sentence* into pieces of at most *max_chunk_size* chars, cutting between words when possible."""
    if len(sentence) <= max_chunk_size:
        return [sentence] if sentence else []
    pieces = []
    current = ""
    for word in sentence.split():
        # Only a single word longer than the limit is ever cut mid-word.
        while len(word) > max_chunk_size:
            if current:
                pieces.append(current)
                current = ""
            pieces.append(word[:max_chunk_size])
            word = word[max_chunk_size:]
        if not current:
            current = word
        elif len(current) + 1 + len(word) <= max_chunk_size:
            current += " " + word
        else:
            pieces.append(current)
            current = word
    if current:
        pieces.append(current)
    return pieces


def split_text_into_chunks(text: str, max_chunk_size: int = MAX_CHUNK_SIZE) -> list:
    """Split *text* into TTS-sized chunks of at most *max_chunk_size* characters.

    Sentences (split after '.', '!' or '?' followed by any whitespace) are
    packed greedily into chunks joined by single spaces. A sentence longer than
    the limit is split between words, not at a fixed character offset, so the
    TTS engine never receives half-words. A word that alone exceeds the limit
    is the only thing cut mid-word.

    Raises:
        ValueError: *max_chunk_size* is smaller than 1.
    """
    if max_chunk_size < 1:
        raise ValueError("max_chunk_size must be >= 1")
    chunks = []
    current = ""
    for sentence in _SENTENCE_BOUNDARY_RE.split(text):
        for piece in _split_long_sentence(sentence, max_chunk_size):
            if not current:
                current = piece
            elif len(current) + 1 + len(piece) <= max_chunk_size:
                current += " " + piece
            else:
                chunks.append(current)
                current = piece
    if current:
        chunks.append(current)
    return chunks
async def generate_api_audio(chunk: str, voice: str, speed: float) -> AsyncGenerator[bytes, None]:
    """Yield synthesized audio for one text *chunk*, using the on-disk cache when possible.

    On a cache hit the whole cached file is yielded as one bytes object. On a
    miss the remote TTS service is called and each decoded audio segment is
    yielded as it arrives. The concatenation of all segments is cached once
    the full reply has been read, so a later hit replays the complete chunk
    rather than only its last segment. The cache key covers text, voice and speed.

    Raises:
        HTTPException(500): non-200 upstream status, an ``error`` object in the
            reply, or any other failure (wrapped as "TTS生成失败: ...").
    """
    cache_key = f"{voice}\n{speed}\n{chunk}"
    text_hash = hashlib.md5(cache_key.encode("utf-8")).hexdigest()
    audio_path = os.path.join(CACHE_DIR, f"{text_hash}.wav")
    if os.path.exists(audio_path):
        with open(audio_path, "rb") as f:
            yield f.read()
        return

    def decode_line(line: bytes) -> Optional[bytes]:
        """Return the audio carried by one reply line; None for blank/malformed lines; raise on an upstream error."""
        if not line.strip():
            return None
        try:
            data = json.loads(line)
        except ValueError as e:  # JSONDecodeError, or invalid UTF-8 in the line
            logger.error("JSON decode error: %s", e)
            return None
        if not isinstance(data, dict):
            return None
        if data.get("error"):
            raise HTTPException(status_code=500, detail=data["error"])
        audio_b64 = data.get("audio")
        return base64.b64decode(audio_b64) if audio_b64 else None

    full_audio = io.BytesIO()
    try:
        async with aiohttp.ClientSession() as session:
            async with session.post(
                "http://141.140.15.30:8028/generate",
                headers={"Content-Type": "application/json"},
                json={"text": chunk, "voice": voice, "speed": speed},
            ) as response:
                if response.status != 200:
                    raise HTTPException(status_code=500, detail="TTS API 请求失败")
                # Split raw bytes on b"\n" so a multi-byte UTF-8 character split across
                # network chunks cannot break decoding.
                buffer = b""
                async for part in response.content.iter_any():
                    buffer += part
                    *lines, buffer = buffer.split(b"\n")
                    for line in lines:
                        audio_bytes = decode_line(line)
                        if audio_bytes:
                            full_audio.write(audio_bytes)
                            yield audio_bytes
                # The last line of the reply may have no trailing newline.
                audio_bytes = decode_line(buffer)
                if audio_bytes:
                    full_audio.write(audio_bytes)
                    yield audio_bytes
    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"TTS生成失败: {str(e)}") from e

    audio = full_audio.getvalue()
    if not audio:
        # Never cache an empty result: it would be replayed for every later request.
        return
    # Unique temp file + atomic rename: concurrent readers never see a partial entry.
    tmp_path = f"{audio_path}.{secrets.token_hex(8)}.tmp"
    try:
        with open(tmp_path, "wb") as f:
            f.write(audio)
        os.replace(tmp_path, audio_path)
    except OSError:
        # Caching is best-effort; the audio has already been yielded.
        logger.exception("Failed to cache TTS audio at %s", audio_path)
@app.post("/page-to-speech/")
async def page_to_speech(request: TextToSpeechRequest):
    """Synthesize a whole page by feeding sentence-sized chunks to the TTS service.

    A cached full-page result is served from disk. Otherwise the text is split
    with split_text_into_chunks, and each chunk's audio (from
    generate_api_audio) is streamed as it arrives. The concatenated page audio
    is cached as ``<hash>_full.wav`` once every chunk has been produced. The
    cache key covers text, voice and speed.

    Raises:
        HTTPException(400): the input is empty after stripping.
    """
    user_input = request.user_input.strip()
    if not user_input:
        raise HTTPException(status_code=400, detail="输入文本为空")
    cache_key = f"{request.voice}\n{request.speed}\n{user_input}"
    full_text_hash = hashlib.md5(cache_key.encode("utf-8")).hexdigest()
    full_audio_path = os.path.join(CACHE_DIR, f"{full_text_hash}_full.wav")
    if os.path.exists(full_audio_path):
        # FileResponse streams from disk and closes the handle when done. A bare
        # open() passed to StreamingResponse was never closed explicitly.
        return FileResponse(full_audio_path, media_type="audio/wav")
    chunks = split_text_into_chunks(user_input)

    async def audio_generator() -> AsyncGenerator[bytes, None]:
        full_audio_buffer = io.BytesIO()
        for chunk in chunks:
            async for audio_data in generate_api_audio(chunk, request.voice, request.speed):
                yield audio_data
                full_audio_buffer.write(audio_data)
                # Give other tasks a turn between segments.
                await asyncio.sleep(0)
        audio = full_audio_buffer.getvalue()
        if not audio:
            # Never cache an empty result: it would be replayed for every later request.
            return
        # Unique temp file + atomic rename so concurrent readers never see a partial file.
        tmp_path = f"{full_audio_path}.{secrets.token_hex(8)}.tmp"
        try:
            with open(tmp_path, "wb") as f:
                f.write(audio)
            os.replace(tmp_path, full_audio_path)
        except OSError:
            # Caching is best-effort; the audio has already been streamed to the client.
            logger.exception("Failed to cache page audio at %s", full_audio_path)

    return StreamingResponse(audio_generator(), media_type="audio/wav")
@app.get("/health")
async def health_check():
    """Liveness probe that always reports healthy; it does not check MySQL or the TTS backend."""
    status = {"status": "healthy"}
    return status
if __name__ == "__main__":
    # Development entry point: serve the app with uvicorn on every interface, port 8005.
    import uvicorn

    uvicorn.run(app, port=8005, host="0.0.0.0")