| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702
703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994 |
- from fastapi import FastAPI, File, UploadFile, HTTPException, Form, Response, Cookie
- from fastapi.encoders import jsonable_encoder
- from fastapi.middleware.cors import CORSMiddleware
- from fastapi.responses import JSONResponse, RedirectResponse, StreamingResponse, HTMLResponse
- from fastapi.staticfiles import StaticFiles
- from pydantic import BaseModel
- import os
- import shutil
- import hashlib
- import asyncio
- from typing import AsyncGenerator, Optional
- import aiohttp
- import io
- import logging
- import base64
- import json
- from datetime import datetime, timedelta, timezone
- import secrets
- import pymysql
- from config import MYSQL_HOST, MYSQL_PORT, MYSQL_USER, MYSQL_PASSWORD, MYSQL_DATABASE
# Logging: root logger at INFO; this module writes through its own named logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# The FastAPI application every route in this module is registered on.
app = FastAPI()

# CORS: every origin, method and header is accepted, with credentials.
# NOTE(review): a wildcard origin together with allow_credentials=True is very
# permissive for a cookie-authenticated app; consider an explicit origin list.
origins = ["*"]
app.add_middleware(
    CORSMiddleware,
    allow_credentials=True,
    allow_origins=origins,
    allow_headers=["*"],
    allow_methods=["*"],
)

# Root directory for uploaded PDFs; each user gets a sub-directory here.
BASE_STATIC_FILES_DIR = "static/files"
os.makedirs(BASE_STATIC_FILES_DIR, exist_ok=True)

# Static mounts, registered in this order so the more specific prefixes are
# added before the catch-all "/static" mount.
# NOTE(review): none of these mounts checks the session cookie, so uploaded
# PDFs under /static/files are readable by anyone who knows the URL.
for _mount_path, _mount_dir, _mount_name in (
    ("/static/files", BASE_STATIC_FILES_DIR, "static_files"),
    ("/static/web", "static/web", "static_web"),
    ("/static", "static", "static"),
):
    app.mount(_mount_path, StaticFiles(directory=_mount_dir), name=_mount_name)
del _mount_path, _mount_dir, _mount_name

# On-disk cache directory for synthesized audio.
CACHE_DIR = "audio_cache"
os.makedirs(CACHE_DIR, exist_ok=True)

# Session cookie name and lifetimes (in days): default vs. "remember me".
SESSION_COOKIE = "reader_pro_session"
SESSION_TTL_DAYS = 1
SESSION_TTL_DAYS_REMEMBER = 30
def db_conn():
    """Open a new autocommit MySQL connection using the ``config`` settings.

    Rows come back as dicts (``DictCursor``) and text uses utf8mb4. The
    caller owns the connection and is responsible for closing it.
    """
    connect_kwargs = dict(
        host=MYSQL_HOST,
        port=MYSQL_PORT,
        user=MYSQL_USER,
        password=MYSQL_PASSWORD,
        database=MYSQL_DATABASE,
        charset="utf8mb4",
        autocommit=True,
        cursorclass=pymysql.cursors.DictCursor,
    )
    return pymysql.connect(**connect_kwargs)
def init_db() -> None:
    """Create the schema if needed, backfill old columns and seed ``admin``.

    Safe to run on every startup: the tables use ``CREATE TABLE IF NOT EXISTS``,
    ``user.is_active`` is only added when the column is missing, and the
    ``admin`` account is only inserted when no user of that name exists.
    """
    user_table_ddl = """
CREATE TABLE IF NOT EXISTS user (
id BIGINT PRIMARY KEY AUTO_INCREMENT,
username VARCHAR(64) NOT NULL UNIQUE,
password_hash VARCHAR(255) NOT NULL,
is_admin TINYINT(1) NOT NULL DEFAULT 0,
is_active TINYINT(1) NOT NULL DEFAULT 1,
session_token VARCHAR(128) NULL,
session_expires_at DATETIME NULL,
last_file VARCHAR(1024) NULL,
last_page INT NULL,
created_at DATETIME NOT NULL,
updated_at DATETIME NOT NULL
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4
"""
    later_table_ddls = (
        """
CREATE TABLE IF NOT EXISTS user_progress (
id BIGINT PRIMARY KEY AUTO_INCREMENT,
user_id BIGINT NOT NULL,
file_path VARCHAR(512) NOT NULL,
page INT NOT NULL,
updated_at DATETIME NOT NULL,
UNIQUE KEY uniq_user_file (user_id, file_path),
KEY idx_user_updated (user_id, updated_at),
CONSTRAINT fk_user_progress_user
FOREIGN KEY (user_id) REFERENCES user(id)
ON DELETE CASCADE
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4
""",
        """
CREATE TABLE IF NOT EXISTS user_config (
config_key VARCHAR(128) PRIMARY KEY,
config_value TEXT NULL,
updated_at DATETIME NOT NULL
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4
""",
    )

    conn = db_conn()
    try:
        with conn.cursor() as cur:
            cur.execute(user_table_ddl)

            # Databases created before `is_active` existed need the column added.
            cur.execute("SHOW COLUMNS FROM user LIKE 'is_active'")
            if not cur.fetchone():
                cur.execute("ALTER TABLE user ADD COLUMN is_active TINYINT(1) NOT NULL DEFAULT 1 AFTER is_admin")

            # user_progress has a foreign key to user(id), so it comes after user.
            for ddl in later_table_ddls:
                cur.execute(ddl)

            cur.execute("SELECT id FROM user WHERE username=%s", ("admin",))
            if cur.fetchone() is None:
                # NOTE(review): the seeded password is literally "admin"; it
                # should be changed right after the first deployment.
                seeded_at = datetime.now(timezone.utc).replace(tzinfo=None)
                cur.execute(
                    """
INSERT INTO user (username, password_hash, is_admin, is_active, created_at, updated_at)
VALUES (%s, %s, 1, 1, %s, %s)
""",
                    ("admin", hash_password("admin"), seeded_at, seeded_at),
                )
                logger.info("Seeded default admin account: admin/admin")
    finally:
        conn.close()
def hash_password(password: str) -> str:
    """Return the hex SHA-256 digest of ``password`` encoded as UTF-8.

    NOTE(review): this is an unsalted, fast hash. It is kept because the stored
    hashes and the login comparison depend on this exact format; moving to a
    salted KDF (bcrypt/scrypt/argon2) needs a dual-verify migration.
    """
    digest = hashlib.sha256()
    digest.update(password.encode("utf-8"))
    return digest.hexdigest()
def sanitize_filename(name: str) -> str:
    """Keep only alphanumerics, spaces, '.', '_' and '-', then drop trailing whitespace.

    '.' is allowed, so inputs such as '..' pass through unchanged. Callers that
    build filesystem paths from the result must guard against that themselves.
    """
    extra_allowed = frozenset(" ._-")
    kept = []
    for ch in name:
        if ch.isalnum() or ch in extra_allowed:
            kept.append(ch)
    return "".join(kept).rstrip()
def get_user_dir(username: str) -> str:
    """Return the upload directory for ``username``, creating it if needed.

    The name is passed through ``sanitize_filename``. An empty result falls
    back to ``"user"``.

    Raises:
        ValueError: if the sanitized name is ``"."`` or ``".."``. Joined onto
            BASE_STATIC_FILES_DIR, those resolve to the uploads root itself or
            to its parent, not to a per-user sub-directory.
    """
    safe = sanitize_filename(username) or "user"
    if safe in (".", ".."):
        raise ValueError(f"invalid user directory name: {username!r}")
    path = os.path.join(BASE_STATIC_FILES_DIR, safe)
    os.makedirs(path, exist_ok=True)
    return path
def build_user_file_url(username: str, filename: str) -> str:
    """Return the public URL path of ``filename`` in ``username``'s upload directory.

    The directory segment uses the same sanitize-then-fall-back-to-"user" rule
    as ``get_user_dir``. Without that fallback, a name that sanitizes to "" gave
    ``/static/files//<file>``, which does not point where the file was written.
    ``filename`` is inserted verbatim (no URL-encoding).
    """
    safe_user = sanitize_filename(username) or "user"
    return f"/static/files/{safe_user}/{filename}"
def get_user_by_session(session_token: Optional[str]):
    """Look up the user row that owns an unexpired session token.

    Returns the full ``user`` row as a dict, or ``None`` when the token is
    missing/empty, unknown, or expired. Expiry is compared against naive UTC,
    the same form in which expiries are written. ``is_active`` is not checked.
    """
    if session_token:
        current_utc = datetime.now(timezone.utc).replace(tzinfo=None)
        conn = db_conn()
        try:
            with conn.cursor() as cur:
                cur.execute(
                    """
SELECT * FROM user
WHERE session_token=%s AND session_expires_at IS NOT NULL AND session_expires_at>%s
""",
                    (session_token, current_utc),
                )
                return cur.fetchone()
        finally:
            conn.close()
    return None
def require_user(session_token: Optional[str]):
    """Return the logged-in user for ``session_token`` or raise.

    Raises:
        HTTPException(401): no valid, unexpired session.
        HTTPException(403): the account is disabled. This check exists so a
            disabled account cannot keep using a session issued before it was
            disabled. The message matches the one login returns for disabled
            accounts.
    """
    user = get_user_by_session(session_token)
    if not user:
        raise HTTPException(status_code=401, detail="未登录或会话已过期")
    if int(user.get("is_active", 1)) != 1:
        raise HTTPException(status_code=403, detail="账号已被禁用")
    return user
def require_admin(session_token: Optional[str]):
    """Return the logged-in user when they are an administrator.

    Propagates whatever ``require_user`` raises. Additionally raises
    HTTPException(403) when the user is not an admin.
    """
    user = require_user(session_token)
    if user.get("is_admin"):
        return user
    raise HTTPException(status_code=403, detail="需要管理员权限")
def set_session_for_user(username: str, remember_me: bool):
    """Issue a fresh session token for ``username`` and store it on the user row.

    The new token overwrites any previous one, so each user has a single
    session. The lifetime is SESSION_TTL_DAYS_REMEMBER days when ``remember_me``
    is set, otherwise SESSION_TTL_DAYS.

    Returns:
        ``(token, expire_days)``
    """
    def naive_utc_now():
        return datetime.now(timezone.utc).replace(tzinfo=None)

    token = secrets.token_urlsafe(48)
    if remember_me:
        expire_days = SESSION_TTL_DAYS_REMEMBER
    else:
        expire_days = SESSION_TTL_DAYS
    expires_at = naive_utc_now() + timedelta(days=expire_days)

    conn = db_conn()
    try:
        with conn.cursor() as cur:
            cur.execute(
                """
UPDATE user
SET session_token=%s, session_expires_at=%s, updated_at=%s
WHERE username=%s
""",
                (token, expires_at, naive_utc_now(), username),
            )
    finally:
        conn.close()
    return token, expire_days
@app.on_event("startup")
def on_startup():
    """Startup hook: make sure the schema and the default admin exist."""
    return init_db()
@app.get("/")
def root(session_token: Optional[str] = Cookie(default=None, alias=SESSION_COOKIE)):
    """Send the visitor to the right page.

    Without a live session, redirect to /login. Otherwise open the PDF viewer
    on the first available of:
    - the user's last-read file;
    - the alphabetically first PDF in their upload directory;
    - the bundled sample document.

    The file URL is percent-encoded (keeping '/') before it goes into the
    viewer's ``file`` query parameter. ``last_file`` is client-supplied (saved
    through /reading-progress), and a raw ``&``, ``#`` or ``?`` in it would
    otherwise truncate or corrupt the query string.
    """
    from urllib.parse import quote

    def viewer_redirect(file_url: str) -> RedirectResponse:
        return RedirectResponse(
            url=f"/static/web/viewer.html?file={quote(file_url, safe='/')}",
            status_code=302,
        )

    user = get_user_by_session(session_token)
    if not user:
        return RedirectResponse(url="/login", status_code=302)

    last_file = user.get("last_file")
    if last_file:
        return viewer_redirect(last_file)

    # Fallback: first PDF in the user's own directory.
    user_dir = get_user_dir(user["username"])
    pdfs = sorted(f for f in os.listdir(user_dir) if f.lower().endswith(".pdf"))
    if pdfs:
        return viewer_redirect(build_user_file_url(user["username"], pdfs[0]))

    # Fallback: sample document.
    return viewer_redirect("/static/files/compress.pdf")
# Static markup for the login form. It posts JSON {username, password,
# remember_me} to /auth/login and navigates to "/" on success.
# NOTE(review): the page advertises the default admin credentials (admin / admin).
_LOGIN_PAGE_HTML = """
<!doctype html>
<html lang="zh-CN">
<head>
<meta charset="utf-8"/>
<meta name="viewport" content="width=device-width, initial-scale=1"/>
<title>VoiceFlow AI Reader 【小满TTS英文听书】 - 登录</title>
<style>
*{box-sizing:border-box}
body{margin:0;min-height:100vh;display:grid;place-items:center;background:linear-gradient(145deg,#f2f6fb,#e7eef8);font-family:Arial,sans-serif;color:#1f2937}
.card{width:min(92vw,420px);background:#fff;border-radius:14px;padding:24px;box-shadow:0 10px 35px rgba(17,24,39,.10)}
.title{margin:0 0 6px;font-size:22px;font-weight:700;color:#0f172a}
.sub{margin:0 0 18px;font-size:13px;color:#64748b}
input,button{width:100%;padding:11px 12px;border-radius:10px;font-size:14px}
input{border:1px solid #d7e0ec;margin-top:10px;outline:none}
input:focus{border-color:#3b82f6;box-shadow:0 0 0 3px rgba(59,130,246,.12)}
button{margin-top:12px;border:none;background:#2563eb;color:#fff;font-weight:600;cursor:pointer}
button:hover{background:#1d4ed8}
.row{display:flex;gap:8px;align-items:center;margin-top:10px;font-size:13px;color:#475569}
.row input{width:auto;margin:0}
.msg{margin-top:10px;color:#b00020;min-height:20px;font-size:13px}
.links{margin-top:8px;text-align:right}
.links a{font-size:13px;color:#2563eb;text-decoration:none}
.links a:hover{text-decoration:underline}
.hint{margin-top:12px;padding:10px;border-radius:10px;background:#f8fafc;color:#475569;font-size:12px}
</style>
</head>
<body>
<div class="card">
<h1 class="title">VoiceFlow AI Reader</h1>
<p class="sub">【小满TTS英文听书】</p>
<input id="loginUser" placeholder="用户名" autocomplete="username"/>
<input id="loginPass" type="password" placeholder="密码" autocomplete="current-password"/>
<label class="row"><input id="remember" type="checkbox"/>30天内免登录</label>
<button onclick="login()">登录</button>
<div id="loginMsg" class="msg"></div>
<div class="links"><a href="/register">没有账号?去注册</a></div>
<div class="hint">管理员默认账号:admin / admin</div>
</div>
<script>
async function login(){
const username=document.getElementById('loginUser').value.trim();
const password=document.getElementById('loginPass').value;
const remember_me=document.getElementById('remember').checked;
const msg=document.getElementById('loginMsg');
msg.textContent='';
const r=await fetch('/auth/login',{method:'POST',headers:{'Content-Type':'application/json'},body:JSON.stringify({username,password,remember_me})});
const d=await r.json();
if(!r.ok||!d.success){msg.textContent=d.error||'登录失败';return;}
window.location.href='/';
}
</script>
</body>
</html>
"""


@app.get("/login")
def login_page(session_token: Optional[str] = Cookie(default=None, alias=SESSION_COOKIE)):
    """Serve the login form, or send already-authenticated visitors to "/"."""
    if get_user_by_session(session_token):
        return RedirectResponse(url="/", status_code=302)
    return HTMLResponse(content=_LOGIN_PAGE_HTML)
# Static markup for the registration form. It posts JSON {username, password}
# to /auth/register and moves to /login two seconds after success.
_REGISTER_PAGE_HTML = """
<!doctype html>
<html lang="zh-CN">
<head>
<meta charset="utf-8"/>
<meta name="viewport" content="width=device-width, initial-scale=1"/>
<title>VoiceFlow AI Reader 【小满TTS英文听书】 - 注册</title>
<style>
*{box-sizing:border-box}
body{margin:0;min-height:100vh;display:grid;place-items:center;background:linear-gradient(145deg,#f2f6fb,#e7eef8);font-family:Arial,sans-serif;color:#1f2937}
.card{width:min(92vw,420px);background:#fff;border-radius:14px;padding:24px;box-shadow:0 10px 35px rgba(17,24,39,.10)}
.title{margin:0 0 6px;font-size:22px;font-weight:700;color:#0f172a}
.sub{margin:0 0 18px;font-size:13px;color:#64748b}
input,button{width:100%;padding:11px 12px;border-radius:10px;font-size:14px}
input{border:1px solid #d7e0ec;margin-top:10px;outline:none}
input:focus{border-color:#3b82f6;box-shadow:0 0 0 3px rgba(59,130,246,.12)}
button{margin-top:12px;border:none;background:#16a34a;color:#fff;font-weight:600;cursor:pointer}
button:hover{background:#15803d}
.msg{margin-top:10px;color:#b00020;min-height:20px;font-size:13px}
.ok{color:#0f7b0f}
.links{margin-top:8px;text-align:right}
.links a{font-size:13px;color:#2563eb;text-decoration:none}
.links a:hover{text-decoration:underline}
</style>
</head>
<body>
<div class="card">
<h1 class="title">创建新账号</h1>
<p class="sub">VoiceFlow AI Reader 【小满TTS英文听书】</p>
<input id="regUser" placeholder="用户名(至少3位)" autocomplete="username"/>
<input id="regPass" type="password" placeholder="密码(至少4位)" autocomplete="new-password"/>
<button onclick="registerUser()">注册</button>
<div id="regMsg" class="msg"></div>
<div class="links"><a href="/login">已有账号?去登录</a></div>
</div>
<script>
async function registerUser(){
const username=document.getElementById('regUser').value.trim();
const password=document.getElementById('regPass').value;
const msg=document.getElementById('regMsg');
msg.textContent='';
const r=await fetch('/auth/register',{method:'POST',headers:{'Content-Type':'application/json'},body:JSON.stringify({username,password})});
const d=await r.json();
if(!r.ok||!d.success){msg.textContent=d.error||'注册失败';return;}
msg.textContent='注册成功,2秒后跳转登录页';
msg.className='msg ok';
setTimeout(()=>{window.location.href='/login';},2000);
}
</script>
</body>
</html>
"""


@app.get("/register")
def register_page(session_token: Optional[str] = Cookie(default=None, alias=SESSION_COOKIE)):
    """Serve the registration form, or send already-authenticated visitors to "/"."""
    if get_user_by_session(session_token):
        return RedirectResponse(url="/", status_code=302)
    return HTMLResponse(content=_REGISTER_PAGE_HTML)
class LoginRequest(BaseModel):
    """JSON body for POST /auth/login."""

    username: str  # the handler strips and sanitizes this before the lookup
    password: str  # the handler strips this before hashing
    remember_me: bool = False  # True selects the SESSION_TTL_DAYS_REMEMBER lifetime
class RegisterRequest(BaseModel):
    """JSON body for POST /auth/register."""

    username: str  # stripped and sanitized by the handler; must end up >= 3 chars
    password: str  # stripped by the handler; must be >= 4 chars
@app.post("/auth/register")
def auth_register(request: RegisterRequest):
    """Register a new active, non-admin account.

    The username goes through ``sanitize_filename``, because it doubles as the
    user's upload directory name. Both fields are stripped. Usernames need at
    least 3 characters and passwords at least 4; a taken username gives 400.

    Declared with plain ``def``: the body does only blocking pymysql and
    filesystem calls. FastAPI runs sync handlers in its threadpool instead of
    on the event loop, which an ``async def`` body would stall.
    """
    username = sanitize_filename((request.username or "").strip())
    password = (request.password or "").strip()
    if len(username) < 3:
        return JSONResponse(status_code=400, content={"success": False, "error": "用户名至少3位"})
    if len(password) < 4:
        return JSONResponse(status_code=400, content={"success": False, "error": "密码至少4位"})
    now = datetime.now(timezone.utc).replace(tzinfo=None)
    conn = db_conn()
    try:
        with conn.cursor() as cur:
            cur.execute("SELECT id FROM user WHERE username=%s", (username,))
            if cur.fetchone():
                return JSONResponse(status_code=400, content={"success": False, "error": "用户名已存在"})
            try:
                cur.execute(
                    """
INSERT INTO user (username, password_hash, is_admin, created_at, updated_at)
VALUES (%s, %s, 0, %s, %s)
""",
                    (username, hash_password(password), now, now),
                )
            except pymysql.err.IntegrityError:
                # username is UNIQUE. A concurrent registration got in between
                # the SELECT above and this INSERT; report it like any taken
                # name instead of surfacing a 500.
                return JSONResponse(status_code=400, content={"success": False, "error": "用户名已存在"})
    finally:
        conn.close()
    get_user_dir(username)
    return JSONResponse(content={"success": True})
@app.post("/auth/login")
def auth_login(request: LoginRequest):
    """Check credentials and start a session.

    On success a new session token is stored for the user and sent as an
    HttpOnly cookie whose max-age matches the session lifetime (``remember_me``
    selects the longer one). A wrong username or password gives 401; a
    disabled account gives 403.

    Plain ``def`` so the blocking pymysql calls run in FastAPI's threadpool.
    """
    username = sanitize_filename((request.username or "").strip())
    password = (request.password or "").strip()
    conn = db_conn()
    try:
        with conn.cursor() as cur:
            cur.execute("SELECT * FROM user WHERE username=%s", (username,))
            user = cur.fetchone()
    finally:
        conn.close()
    # Constant-time comparison, so response timing does not reveal how much of
    # the stored hash a guess matched.
    password_ok = user is not None and secrets.compare_digest(
        str(user["password_hash"]).encode("utf-8"),
        hash_password(password).encode("utf-8"),
    )
    if not password_ok:
        return JSONResponse(status_code=401, content={"success": False, "error": "用户名或密码错误"})
    if int(user.get("is_active", 1)) != 1:
        return JSONResponse(status_code=403, content={"success": False, "error": "账号已被禁用"})
    token, expire_days = set_session_for_user(username, request.remember_me)
    resp = JSONResponse(content={"success": True, "is_admin": bool(user.get("is_admin"))})
    resp.set_cookie(
        key=SESSION_COOKIE,
        value=token,
        max_age=expire_days * 24 * 3600,
        httponly=True,
        samesite="lax",
        # NOTE(review): should be True whenever the app is served over HTTPS.
        secure=False,
        path="/",
    )
    return resp
@app.post("/auth/logout")
def auth_logout(session_token: Optional[str] = Cookie(default=None, alias=SESSION_COOKIE)):
    """End the current session.

    If the cookie maps to a live session, the token and its expiry are cleared
    in the database. The cookie is deleted in every case, so the response is
    always ``{"success": true}``.

    Plain ``def`` so the blocking database calls run in FastAPI's threadpool.
    """
    user = get_user_by_session(session_token)
    if user:
        conn = db_conn()
        try:
            with conn.cursor() as cur:
                cur.execute(
                    "UPDATE user SET session_token=NULL, session_expires_at=NULL, updated_at=%s WHERE id=%s",
                    (datetime.now(timezone.utc).replace(tzinfo=None), user["id"]),
                )
        finally:
            conn.close()
    resp = JSONResponse(content={"success": True})
    resp.delete_cookie(SESSION_COOKIE, path="/")
    return resp
@app.get("/auth/me")
def auth_me(session_token: Optional[str] = Cookie(default=None, alias=SESSION_COOKIE)):
    """Describe the logged-in user (username, admin and active flags).

    Returns 401 ``{"success": false}`` when there is no live session. Plain
    ``def`` because ``get_user_by_session`` does blocking database I/O.
    """
    user = get_user_by_session(session_token)
    if not user:
        return JSONResponse(status_code=401, content={"success": False})
    return JSONResponse(
        content={
            "success": True,
            "username": user["username"],
            "is_admin": bool(user.get("is_admin")),
            "is_active": bool(user.get("is_active", 1)),
        }
    )
@app.get("/admin")
def admin_page(session_token: Optional[str] = Cookie(default=None, alias=SESSION_COOKIE)):
    """Redirect administrators to the static admin console page.

    Non-admins get the 401/403 raised by ``require_admin``.
    NOTE(review): the page itself lives under the /static/web mount, which does
    not check sessions, so the real protection is the admin-only /admin/* APIs.
    """
    require_admin(session_token)
    console_url = "/static/web/admin.html"
    return RedirectResponse(url=console_url, status_code=302)
class AdminUserRequest(BaseModel):
    """JSON body for POST /admin/users (create a user or reset an existing one)."""

    username: str  # stripped and sanitized by the handler; must end up >= 3 chars
    password: str  # stripped by the handler; must be >= 4 chars
    is_admin: bool = False  # requested admin flag for the account
@app.get("/admin/users")
def admin_users(session_token: Optional[str] = Cookie(default=None, alias=SESSION_COOKIE)):
    """List every user (id, username, is_admin, is_active, created_at) by id; admin only.

    ``jsonable_encoder`` turns the DATETIME values into JSON-serializable
    strings. Plain ``def`` so the blocking query runs in FastAPI's threadpool.
    """
    require_admin(session_token)
    conn = db_conn()
    try:
        with conn.cursor() as cur:
            cur.execute("SELECT id, username, is_admin, is_active, created_at FROM user ORDER BY id ASC")
            users = cur.fetchall()
    finally:
        conn.close()
    return JSONResponse(content=jsonable_encoder({"success": True, "users": users}))
@app.post("/admin/users")
def admin_create_or_reset_user(payload: AdminUserRequest, session_token: Optional[str] = Cookie(default=None, alias=SESSION_COOKIE)):
    """Create a user, or reset an existing one (admin only).

    For an existing username, the password and admin flag are overwritten and
    the account is re-enabled. Otherwise a new active account is inserted. The
    upload directory is created in both cases. A single
    ``INSERT ... ON DUPLICATE KEY UPDATE`` makes this atomic, so two concurrent
    creates of the same name cannot collide on the UNIQUE username.

    The built-in ``admin`` account always keeps admin rights. ``is_admin``
    defaults to False, so resetting admin's password without re-sending the
    flag would otherwise silently demote it; admin is already protected from
    deletion and disabling by the other admin endpoints.

    Plain ``def`` so the blocking database calls run in FastAPI's threadpool.
    """
    require_admin(session_token)
    username = sanitize_filename((payload.username or "").strip())
    password = (payload.password or "").strip()
    if len(username) < 3 or len(password) < 4:
        return JSONResponse(status_code=400, content={"success": False, "error": "用户名或密码不合法"})
    is_admin = 1 if (payload.is_admin or username == "admin") else 0
    now = datetime.now(timezone.utc).replace(tzinfo=None)
    conn = db_conn()
    try:
        with conn.cursor() as cur:
            cur.execute(
                """
INSERT INTO user (username, password_hash, is_admin, is_active, created_at, updated_at)
VALUES (%s, %s, %s, 1, %s, %s)
ON DUPLICATE KEY UPDATE password_hash=VALUES(password_hash), is_admin=VALUES(is_admin),
is_active=1, updated_at=VALUES(updated_at)
""",
                (username, hash_password(password), is_admin, now, now),
            )
    finally:
        conn.close()
    get_user_dir(username)
    return JSONResponse(content={"success": True})
@app.delete("/admin/users/{username}")
def admin_delete_user(
    username: str,
    delete_files: bool = False,
    session_token: Optional[str] = Cookie(default=None, alias=SESSION_COOKIE),
):
    """Delete a user account (admin only), optionally removing their uploads.

    The user's ``user_progress`` rows are removed by the ON DELETE CASCADE
    foreign key. The protected ``admin`` account gives 400, and an unknown
    username gives 404 (matching the reset-password and status endpoints).

    File removal only ever targets a direct child directory of
    BASE_STATIC_FILES_DIR. ``username`` is raw path input, and a name such as
    ".." previously resolved to a parent directory that rmtree would then wipe.
    The directory is no longer created just so it can be deleted.

    Plain ``def`` so the blocking database and filesystem calls run in
    FastAPI's threadpool.
    """
    require_admin(session_token)
    if username == "admin":
        return JSONResponse(status_code=400, content={"success": False, "error": "不能删除admin"})
    conn = db_conn()
    try:
        with conn.cursor() as cur:
            cur.execute("DELETE FROM user WHERE username=%s", (username,))
            deleted = cur.rowcount
    finally:
        conn.close()
    if deleted == 0:
        return JSONResponse(status_code=404, content={"success": False, "error": "用户不存在"})
    if delete_files:
        base_dir = os.path.realpath(BASE_STATIC_FILES_DIR)
        user_dir = os.path.realpath(os.path.join(base_dir, sanitize_filename(username)))
        if os.path.dirname(user_dir) == base_dir and os.path.isdir(user_dir):
            shutil.rmtree(user_dir, ignore_errors=True)
    return JSONResponse(content={"success": True})
class AdminResetPasswordRequest(BaseModel):
    """JSON body for POST /admin/users/{username}/reset-password."""

    password: str  # new password; stripped by the handler and must be >= 4 chars
@app.post("/admin/users/{username}/reset-password")
def admin_reset_password(
    username: str,
    payload: AdminResetPasswordRequest,
    session_token: Optional[str] = Cookie(default=None, alias=SESSION_COOKIE),
):
    """Set a new password for ``username`` (admin only).

    The stripped password must have at least 4 characters (400 otherwise), and
    an unknown user gives 404. Existence is checked with a SELECT rather than
    the UPDATE's rowcount. Under MySQL's default (no CLIENT_FOUND_ROWS),
    rowcount counts only rows whose values actually changed, so it can be 0
    for an existing user; for example, the same password re-set within one
    second of the last update, since ``updated_at`` has second precision.

    Plain ``def`` so the blocking database calls run in FastAPI's threadpool.
    """
    require_admin(session_token)
    password = (payload.password or "").strip()
    if len(password) < 4:
        return JSONResponse(status_code=400, content={"success": False, "error": "密码至少4位"})
    now = datetime.now(timezone.utc).replace(tzinfo=None)
    conn = db_conn()
    try:
        with conn.cursor() as cur:
            cur.execute("SELECT id FROM user WHERE username=%s", (username,))
            if not cur.fetchone():
                return JSONResponse(status_code=404, content={"success": False, "error": "用户不存在"})
            cur.execute(
                "UPDATE user SET password_hash=%s, updated_at=%s WHERE username=%s",
                (hash_password(password), now, username),
            )
    finally:
        conn.close()
    return JSONResponse(content={"success": True})
class AdminToggleUserRequest(BaseModel):
    """JSON body for POST /admin/users/{username}/status."""

    is_active: bool  # True enables the account, False disables it
@app.post("/admin/users/{username}/status")
def admin_toggle_user_status(
    username: str,
    payload: AdminToggleUserRequest,
    session_token: Optional[str] = Cookie(default=None, alias=SESSION_COOKIE),
):
    """Enable or disable an account (admin only).

    Disabling ``admin`` is refused (400), and an unknown user gives 404.
    Existence is checked with a SELECT, because an UPDATE's rowcount counts only
    changed rows under MySQL's default and can be 0 for an existing user.

    Disabling also clears the user's session token. Without that, the account
    kept working until its current session expired (up to
    SESSION_TTL_DAYS_REMEMBER days), because disabling only blocks new logins.

    Plain ``def`` so the blocking database calls run in FastAPI's threadpool.
    """
    require_admin(session_token)
    if username == "admin" and not payload.is_active:
        return JSONResponse(status_code=400, content={"success": False, "error": "不能禁用admin"})
    now = datetime.now(timezone.utc).replace(tzinfo=None)
    conn = db_conn()
    try:
        with conn.cursor() as cur:
            cur.execute("SELECT id FROM user WHERE username=%s", (username,))
            if not cur.fetchone():
                return JSONResponse(status_code=404, content={"success": False, "error": "用户不存在"})
            if payload.is_active:
                cur.execute(
                    "UPDATE user SET is_active=%s, updated_at=%s WHERE username=%s",
                    (1, now, username),
                )
            else:
                cur.execute(
                    "UPDATE user SET is_active=%s, session_token=NULL, session_expires_at=NULL, "
                    "updated_at=%s WHERE username=%s",
                    (0, now, username),
                )
    finally:
        conn.close()
    return JSONResponse(content={"success": True})
class AdminConfigRequest(BaseModel):
    """JSON body for POST /admin/config (upsert one key/value pair)."""

    config_key: str  # stripped by the handler; must be non-empty
    config_value: Optional[str] = None  # stored as-is; None stores SQL NULL
@app.get("/admin/config")
def admin_get_config(session_token: Optional[str] = Cookie(default=None, alias=SESSION_COOKIE)):
    """Return all ``user_config`` rows ordered by key (admin only).

    ``jsonable_encoder`` converts the ``updated_at`` DATETIMEs to JSON-safe
    strings. Plain ``def`` so the blocking query runs in FastAPI's threadpool.
    """
    require_admin(session_token)
    conn = db_conn()
    try:
        with conn.cursor() as cur:
            cur.execute("SELECT config_key, config_value, updated_at FROM user_config ORDER BY config_key")
            rows = cur.fetchall()
    finally:
        conn.close()
    return JSONResponse(content=jsonable_encoder({"success": True, "configs": rows}))
@app.post("/admin/config")
def admin_set_config(payload: AdminConfigRequest, session_token: Optional[str] = Cookie(default=None, alias=SESSION_COOKIE)):
    """Insert or update one ``user_config`` entry (admin only).

    The stripped key must be non-empty and fit the VARCHAR(128) primary-key
    column. Longer keys get a 400 here instead of failing (or being truncated,
    depending on sql_mode) inside MySQL.

    Plain ``def`` so the blocking upsert runs in FastAPI's threadpool.
    """
    require_admin(session_token)
    config_key = (payload.config_key or "").strip()
    if not config_key:
        return JSONResponse(status_code=400, content={"success": False, "error": "config_key 不能为空"})
    if len(config_key) > 128:
        return JSONResponse(status_code=400, content={"success": False, "error": "config_key 过长(最多128个字符)"})
    now = datetime.now(timezone.utc).replace(tzinfo=None)
    conn = db_conn()
    try:
        with conn.cursor() as cur:
            cur.execute(
                """
INSERT INTO user_config (config_key, config_value, updated_at)
VALUES (%s, %s, %s)
ON DUPLICATE KEY UPDATE config_value=VALUES(config_value), updated_at=VALUES(updated_at)
""",
                (config_key, payload.config_value, now),
            )
    finally:
        conn.close()
    return JSONResponse(content={"success": True})
# PDF upload endpoint
@app.post("/upload-pdf")
def upload_pdf(
    file: UploadFile = File(...),
    custom_name: str = Form(...),
    session_token: Optional[str] = Cookie(default=None, alias=SESSION_COOKIE),
):
    """Store an uploaded PDF in the caller's directory as ``<custom_name>.pdf``.

    ``custom_name`` is sanitized. An empty result or an existing file of the
    same name gives 400. On success the file becomes the user's ``last_file``
    (at page 1) and its public URL path is returned.

    NOTE(review): ``file.content_type`` is whatever the client declared; the
    uploaded bytes are not checked to actually be a PDF.

    Plain ``def``: copying the upload and updating the DB are blocking calls,
    and FastAPI runs sync handlers in its threadpool.
    """
    user = require_user(session_token)
    if file.content_type != "application/pdf":
        raise HTTPException(status_code=400, detail="文件类型必须是 PDF")
    sanitized_name = sanitize_filename(custom_name)
    if not sanitized_name:
        return JSONResponse(status_code=400, content={"success": False, "error": "无效的文件名"})
    unique_filename = f"{sanitized_name}.pdf"
    file_path = os.path.join(get_user_dir(user["username"]), unique_filename)

    created = False
    try:
        # "xb" creates the file exclusively. A separate exists() check followed
        # by "wb" let two uploads racing for the same name overwrite each other.
        with open(file_path, "xb") as buffer:
            created = True
            shutil.copyfileobj(file.file, buffer)
    except FileExistsError:
        return JSONResponse(status_code=400, content={"success": False, "error": "文件名已存在,请使用其他名称"})
    except Exception as exc:
        logger.exception("PDF upload failed: %s", file_path)
        # Remove the partial file we created; left behind, it would block
        # re-uploading under the same name.
        if created:
            try:
                os.remove(file_path)
            except OSError:
                pass
        raise HTTPException(status_code=500, detail="上传过程中出错") from exc
    finally:
        file.file.close()

    file_relative_path = build_user_file_url(user["username"], unique_filename)
    now = datetime.now(timezone.utc).replace(tzinfo=None)
    conn = db_conn()
    try:
        with conn.cursor() as cur:
            cur.execute(
                "UPDATE user SET last_file=%s, last_page=1, updated_at=%s WHERE id=%s",
                (file_relative_path, now, user["id"]),
            )
    finally:
        conn.close()
    return JSONResponse(content={"success": True, "file_path": file_relative_path})
# List PDFs endpoint
@app.get("/list-pdfs")
def list_pdfs(session_token: Optional[str] = Cookie(default=None, alias=SESSION_COOKIE)):
    """List the caller's PDFs as ``{"name", "url"}`` items, sorted case-insensitively.

    A filesystem error is logged with its traceback and reported as HTTP 500.
    Previously a blanket ``except Exception`` discarded the cause. Plain ``def``
    so the blocking directory listing runs in FastAPI's threadpool.
    """
    user = require_user(session_token)
    username = user["username"]
    try:
        names = os.listdir(get_user_dir(username))
    except OSError as exc:
        logger.exception("Cannot list PDFs for user %s", username)
        raise HTTPException(status_code=500, detail="无法获取文件列表") from exc
    pdf_files = sorted(
        (
            {"name": name, "url": build_user_file_url(username, name)}
            for name in names
            if name.lower().endswith(".pdf")
        ),
        key=lambda item: item["name"].lower(),
    )
    return JSONResponse(content={"success": True, "files": pdf_files})
class ReadingProgressRequest(BaseModel):
    """JSON body for POST /reading-progress."""

    file: str  # file identifier as sent by the client; stripped and must be non-empty
    page: int  # 1-based page number; the handler rejects values < 1
@app.get("/reading-progress")
def get_reading_progress(file: str, session_token: Optional[str] = Cookie(default=None, alias=SESSION_COOKIE)):
    """Return the caller's saved page for ``file``; ``page`` is null when none is saved.

    ``file`` is stripped and must be non-empty (400 otherwise). Plain ``def``
    so the blocking query runs in FastAPI's threadpool.
    """
    user = require_user(session_token)
    normalized_file = (file or "").strip()
    if not normalized_file:
        return JSONResponse(status_code=400, content={"success": False, "error": "缺少 file 参数"})
    conn = db_conn()
    try:
        with conn.cursor() as cur:
            cur.execute(
                "SELECT page FROM user_progress WHERE user_id=%s AND file_path=%s",
                (user["id"], normalized_file),
            )
            row = cur.fetchone()
    finally:
        conn.close()
    page = row["page"] if row else None
    return JSONResponse(content={"success": True, "file": normalized_file, "page": page})
@app.post("/reading-progress")
def save_reading_progress(request: ReadingProgressRequest, session_token: Optional[str] = Cookie(default=None, alias=SESSION_COOKIE)):
    """Upsert the caller's page for a file and make that file their ``last_file``.

    ``file`` is stripped and must be non-empty and at most 512 characters, the
    size of ``user_progress.file_path``. Longer values get a 400 here instead
    of failing inside MySQL as a 500. ``page`` must be >= 1.

    Plain ``def`` so the blocking database calls run in FastAPI's threadpool.
    """
    user = require_user(session_token)
    normalized_file = (request.file or "").strip()
    page = int(request.page)
    if not normalized_file:
        return JSONResponse(status_code=400, content={"success": False, "error": "file 不能为空"})
    if len(normalized_file) > 512:
        return JSONResponse(status_code=400, content={"success": False, "error": "file 过长(最多512个字符)"})
    if page < 1:
        return JSONResponse(status_code=400, content={"success": False, "error": "page 必须 >= 1"})
    now = datetime.now(timezone.utc).replace(tzinfo=None)
    conn = db_conn()
    try:
        with conn.cursor() as cur:
            cur.execute(
                """
INSERT INTO user_progress (user_id, file_path, page, updated_at)
VALUES (%s, %s, %s, %s)
ON DUPLICATE KEY UPDATE page=VALUES(page), updated_at=VALUES(updated_at)
""",
                (user["id"], normalized_file, page, now),
            )
            cur.execute(
                "UPDATE user SET last_file=%s, last_page=%s, updated_at=%s WHERE id=%s",
                (normalized_file, page, now, user["id"]),
            )
    finally:
        conn.close()
    return JSONResponse(content={"success": True})
class TextToSpeechRequest(BaseModel):
    """JSON body shared by /generate, /text-to-speech/ and /page-to-speech/."""

    user_input: str  # text to synthesize; handlers strip it and reject it when empty
    voice: str = "af_heart"  # forwarded to the upstream TTS service as "voice"
    speed: float = 1.0  # forwarded upstream as "speed"; not range-checked here
@app.post("/generate")
async def generate_proxy(request: TextToSpeechRequest):
    """Proxy a TTS request upstream and stream its NDJSON reply back unchanged.

    The upstream request is made, and its status checked, *before* the
    StreamingResponse is returned. Raising HTTPException inside the stream
    generator (as before) cannot change the status, because a 200 has already
    been sent by then; upstream failures used to surface as a broken 200
    stream instead of a 500.

    Raises:
        HTTPException(400): empty ``user_input``.
        HTTPException(500): upstream unreachable or returned a non-200 status.

    NOTE(review): there is no session check on this endpoint.
    """
    user_input = request.user_input.strip()
    if not user_input:
        raise HTTPException(status_code=400, detail="输入文本为空")

    session = aiohttp.ClientSession()
    try:
        upstream = await session.post(
            "http://141.140.15.30:8028/generate",
            headers={"Content-Type": "application/json"},
            json={"text": user_input, "voice": request.voice, "speed": request.speed},
        )
    except Exception as exc:
        await session.close()
        logger.error("generate proxy error: %s", exc)
        raise HTTPException(status_code=500, detail=str(exc)) from exc
    if upstream.status != 200:
        upstream.close()
        await session.close()
        raise HTTPException(status_code=500, detail="TTS API 请求失败")

    async def stream_generator() -> AsyncGenerator[bytes, None]:
        try:
            async for chunk in upstream.content.iter_any():
                yield chunk
        except Exception as exc:
            # Headers are already sent: log, then abort the stream.
            logger.error("generate proxy error: %s", exc)
            raise
        finally:
            # Release the upstream connection however the stream ends.
            upstream.close()
            await session.close()

    return StreamingResponse(stream_generator(), media_type="application/x-ndjson")
@app.post("/text-to-speech/")
async def text_to_speech(request: TextToSpeechRequest):
    """Synthesize ``user_input`` with a single upstream call, caching the audio on disk.

    On a cache hit the stored audio is returned directly. On a miss the
    upstream TTS service is called. Each NDJSON line of its reply may carry
    base64 ``audio``, which is streamed to the client as it arrives, or an
    ``error``, which aborts the stream. When the reply completes with some
    audio, the concatenation is cached.

    Raises:
        HTTPException(400): empty ``user_input``.
        HTTPException(500): upstream unreachable or non-200. This is checked
            before streaming starts, so the client gets a real 500 rather than
            a 200 stream that breaks off.

    NOTE(review): there is no session check on this endpoint.
    """
    user_input = request.user_input.strip()
    if not user_input:
        raise HTTPException(status_code=400, detail="输入文本为空")

    # Voice and speed change the audio, so both are part of the cache key.
    # Keying on the text alone served one voice's audio for another.
    cache_key = f"{request.voice}|{request.speed}|{user_input}"
    text_hash = hashlib.md5(cache_key.encode("utf-8")).hexdigest()
    audio_path = os.path.join(CACHE_DIR, f"{text_hash}.wav")
    if os.path.exists(audio_path):
        with open(audio_path, "rb") as f:
            return Response(content=f.read(), media_type="audio/wav")

    session = aiohttp.ClientSession()
    try:
        upstream = await session.post(
            "http://141.140.15.30:8028/generate",
            headers={"Content-Type": "application/json"},
            json={"text": user_input, "voice": request.voice, "speed": request.speed},
        )
    except Exception as exc:
        await session.close()
        logger.error("TTS error: %s", exc)
        raise HTTPException(status_code=500, detail=str(exc)) from exc
    if upstream.status != 200:
        upstream.close()
        await session.close()
        raise HTTPException(status_code=500, detail="TTS API 请求失败")

    def decode_line(line: bytes) -> Optional[bytes]:
        """Return the audio carried by one NDJSON line, or None if it carries none."""
        if not line.strip():
            return None
        try:
            data = json.loads(line)
        except ValueError as e:  # JSONDecodeError, or bytes that are not valid text
            logger.error("JSON decode error: %s", e)
            return None
        if not isinstance(data, dict):
            return None
        if data.get("error"):
            raise RuntimeError(data["error"])
        audio_b64 = data.get("audio")
        return base64.b64decode(audio_b64) if audio_b64 else None

    async def audio_generator() -> AsyncGenerator[bytes, None]:
        pieces = []
        pending = b""
        try:
            async for chunk in upstream.content.iter_any():
                # Split raw bytes on newlines and decode whole lines only.
                # Decoding each network chunk separately fails when a multi-byte
                # UTF-8 character straddles two chunks.
                pending += chunk
                *complete_lines, pending = pending.split(b"\n")
                for line in complete_lines:
                    audio_bytes = decode_line(line)
                    if audio_bytes:
                        pieces.append(audio_bytes)
                        yield audio_bytes
            audio_bytes = decode_line(pending)
            if audio_bytes:
                pieces.append(audio_bytes)
                yield audio_bytes
        except Exception as e:
            # The 200 response is already under way: log, then abort the stream.
            logger.error("TTS error: %s", e)
            raise
        finally:
            upstream.close()
            await session.close()

        if not pieces:
            return
        # Write to a temp file, then rename atomically. Concurrent requests or
        # a crash mid-write then cannot leave a truncated file that later cache
        # hits would serve. The cache is best-effort, so a write failure is
        # only logged.
        tmp_path = f"{audio_path}.{secrets.token_hex(8)}.tmp"
        try:
            with open(tmp_path, "wb") as f:
                f.write(b"".join(pieces))
            os.replace(tmp_path, audio_path)
        except OSError as e:
            logger.error("TTS cache write failed: %s", e)
            try:
                os.remove(tmp_path)
            except OSError:
                pass

    return StreamingResponse(audio_generator(), media_type="audio/wav")
MAX_CHUNK_SIZE = 200


def _split_long_sentence(sentence: str, limit: int) -> list:
    """Break one over-long sentence into pieces of at most ``limit`` characters.

    Whitespace-separated words are packed greedily and re-joined with single
    spaces. A word is cut mid-way only when it is itself longer than ``limit``.
    """
    pieces = []
    current = ""
    for word in sentence.split():
        # A single word longer than the limit has no break point: hard-split it.
        while len(word) > limit:
            if current:
                pieces.append(current)
                current = ""
            pieces.append(word[:limit])
            word = word[limit:]
        if not current:
            current = word
        elif len(current) + 1 + len(word) <= limit:
            current = f"{current} {word}"
        else:
            pieces.append(current)
            current = word
    if current:
        pieces.append(current)
    return pieces


def split_text_into_chunks(text: str, max_chunk_size: int = MAX_CHUNK_SIZE) -> list:
    """Split ``text`` into TTS-sized chunks of at most ``max_chunk_size`` characters.

    Sentences end at '.', '!' or '?' followed by spaces. They are packed
    greedily into chunks, joined by single spaces. A sentence longer than the
    limit is broken at word boundaries instead of at arbitrary character
    offsets, which cut words in half and garbled the synthesized speech.

    Raises:
        ValueError: if ``max_chunk_size`` < 1.
    """
    import re

    if max_chunk_size < 1:
        raise ValueError("max_chunk_size must be >= 1")

    chunks = []
    current = ""
    for sentence in re.split(r"(?<=[.!?]) +", text):
        if len(current) + len(sentence) + 1 <= max_chunk_size:
            current = f"{current} {sentence}" if current else sentence
            continue
        if current:
            chunks.append(current)
        if len(sentence) > max_chunk_size:
            chunks.extend(_split_long_sentence(sentence, max_chunk_size))
            current = ""
        else:
            current = sentence
    if current:
        chunks.append(current)
    return chunks
async def generate_api_audio(chunk: str, voice: str, speed: float) -> AsyncGenerator[bytes, None]:
    """Yield synthesized audio for one text chunk, reading and filling the disk cache.

    On a cache hit the cached file is yielded as a single piece. On a miss the
    chunk is sent to the upstream TTS service. Each NDJSON line of the reply may
    carry base64 ``audio``, which is yielded as it arrives, or an ``error``.
    After the whole reply has been read, all pieces are written to the cache
    together.

    The cache file used to be overwritten with every piece, so it held only the
    *last* piece and later cache hits returned truncated audio. The cache key
    now also includes voice and speed, which change the audio.

    Raises:
        HTTPException(500): upstream unreachable or non-200, an ``error``
            reported upstream, or the stream failed.
    """
    cache_key = f"{voice}|{speed}|{chunk}"
    text_hash = hashlib.md5(cache_key.encode("utf-8")).hexdigest()
    audio_path = os.path.join(CACHE_DIR, f"{text_hash}.wav")
    if os.path.exists(audio_path):
        with open(audio_path, "rb") as f:
            yield f.read()
        return

    def decode_line(line: bytes) -> Optional[bytes]:
        """Return the audio carried by one NDJSON line, or None if it carries none."""
        if not line.strip():
            return None
        try:
            data = json.loads(line)
        except ValueError as e:  # JSONDecodeError, or bytes that are not valid text
            logger.error("JSON decode error: %s", e)
            return None
        if not isinstance(data, dict):
            return None
        if data.get("error"):
            raise RuntimeError(data["error"])
        audio_b64 = data.get("audio")
        return base64.b64decode(audio_b64) if audio_b64 else None

    pieces = []
    try:
        async with aiohttp.ClientSession() as session:
            async with session.post(
                "http://141.140.15.30:8028/generate",
                headers={"Content-Type": "application/json"},
                json={"text": chunk, "voice": voice, "speed": speed},
            ) as response:
                if response.status != 200:
                    raise RuntimeError("TTS API 请求失败")
                # Split raw bytes on newlines and decode whole lines only, so a
                # multi-byte UTF-8 character split across network reads is safe.
                pending = b""
                async for part in response.content.iter_any():
                    pending += part
                    *complete_lines, pending = pending.split(b"\n")
                    for line in complete_lines:
                        audio_bytes = decode_line(line)
                        if audio_bytes:
                            pieces.append(audio_bytes)
                            yield audio_bytes
                audio_bytes = decode_line(pending)
                if audio_bytes:
                    pieces.append(audio_bytes)
                    yield audio_bytes
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"TTS生成失败: {str(e)}") from e

    if not pieces:
        return
    # Write the complete audio once, via a temp file and an atomic rename, so a
    # cache reader never sees a partial file. The cache is best-effort.
    tmp_path = f"{audio_path}.{secrets.token_hex(8)}.tmp"
    try:
        with open(tmp_path, "wb") as f:
            f.write(b"".join(pieces))
        os.replace(tmp_path, audio_path)
    except OSError as e:
        logger.error("TTS cache write failed: %s", e)
        try:
            os.remove(tmp_path)
        except OSError:
            pass
@app.post("/page-to-speech/")
async def page_to_speech(request: TextToSpeechRequest):
    """Synthesize a whole page chunk by chunk, caching the combined audio.

    The text is split with ``split_text_into_chunks``, and each chunk's audio
    (from ``generate_api_audio``) is streamed as it arrives. When every chunk
    has been produced, the concatenation is cached as ``<hash>_full.wav``.

    Changes from the earlier version:
    - a cache hit is read eagerly into a Response, as /text-to-speech/ does,
      instead of streaming an ``open()`` handle that was never explicitly
      closed;
    - the cache key includes voice and speed;
    - the cache file is written atomically, and only when audio was produced.

    Raises:
        HTTPException(400): empty ``user_input``.

    NOTE(review): there is no session check on this endpoint.
    """
    user_input = request.user_input.strip()
    if not user_input:
        raise HTTPException(status_code=400, detail="输入文本为空")

    cache_key = f"{request.voice}|{request.speed}|{user_input}"
    full_text_hash = hashlib.md5(cache_key.encode("utf-8")).hexdigest()
    full_audio_path = os.path.join(CACHE_DIR, f"{full_text_hash}_full.wav")
    if os.path.exists(full_audio_path):
        with open(full_audio_path, "rb") as f:
            return Response(content=f.read(), media_type="audio/wav")

    chunks = split_text_into_chunks(user_input)

    async def audio_generator() -> AsyncGenerator[bytes, None]:
        pieces = []
        for chunk in chunks:
            async for audio_data in generate_api_audio(chunk, request.voice, request.speed):
                pieces.append(audio_data)
                yield audio_data
            # Let other tasks run between upstream calls.
            await asyncio.sleep(0)

        if not pieces:
            return
        # Temp file plus atomic rename, so a cache reader never sees a partial
        # file. The cache is best-effort, so a write failure is only logged.
        tmp_path = f"{full_audio_path}.{secrets.token_hex(8)}.tmp"
        try:
            with open(tmp_path, "wb") as f:
                f.write(b"".join(pieces))
            os.replace(tmp_path, full_audio_path)
        except OSError as e:
            logger.error("TTS cache write failed: %s", e)
            try:
                os.remove(tmp_path)
            except OSError:
                pass

    return StreamingResponse(audio_generator(), media_type="audio/wav")
@app.get("/health")
async def health_check():
    """Liveness probe: always reports healthy and checks no dependencies."""
    status = {"status": "healthy"}
    return status
if __name__ == "__main__":
    # Direct-run entry point: serve the app on all interfaces, port 8005.
    import uvicorn

    server_options = {"host": "0.0.0.0", "port": 8005}
    uvicorn.run(app, **server_options)
|