|
@@ -1,17 +1,25 @@
|
|
|
-from fastapi import FastAPI, Request, File, UploadFile, HTTPException, Form, Response
|
|
|
|
|
|
|
+from fastapi import FastAPI, File, UploadFile, HTTPException, Form, Response, Cookie
|
|
|
|
|
+from fastapi.encoders import jsonable_encoder
|
|
|
from fastapi.middleware.cors import CORSMiddleware
|
|
from fastapi.middleware.cors import CORSMiddleware
|
|
|
-from fastapi.responses import JSONResponse, RedirectResponse, StreamingResponse
|
|
|
|
|
|
|
+from fastapi.responses import JSONResponse, RedirectResponse, StreamingResponse, HTMLResponse
|
|
|
from fastapi.staticfiles import StaticFiles
|
|
from fastapi.staticfiles import StaticFiles
|
|
|
|
|
+from pydantic import BaseModel
|
|
|
|
|
+
|
|
|
import os
|
|
import os
|
|
|
import shutil
|
|
import shutil
|
|
|
import hashlib
|
|
import hashlib
|
|
|
import asyncio
|
|
import asyncio
|
|
|
-from typing import AsyncGenerator
|
|
|
|
|
|
|
+from typing import AsyncGenerator, Optional
|
|
|
import aiohttp
|
|
import aiohttp
|
|
|
import io
|
|
import io
|
|
|
import logging
|
|
import logging
|
|
|
import base64
|
|
import base64
|
|
|
import json
|
|
import json
|
|
|
|
|
+from datetime import datetime, timedelta, timezone
|
|
|
|
|
+import secrets
|
|
|
|
|
+import pymysql
|
|
|
|
|
+
|
|
|
|
|
+from config import MYSQL_HOST, MYSQL_PORT, MYSQL_USER, MYSQL_PASSWORD, MYSQL_DATABASE
|
|
|
|
|
|
|
|
# Set up logging
|
|
# Set up logging
|
|
|
logging.basicConfig(level=logging.INFO)
|
|
logging.basicConfig(level=logging.INFO)
|
|
@@ -30,13 +38,12 @@ app.add_middleware(
|
|
|
allow_headers=["*"],
|
|
allow_headers=["*"],
|
|
|
)
|
|
)
|
|
|
|
|
|
|
|
-# Directory for uploaded files
|
|
|
|
|
-UPLOAD_DIRECTORY = "static/files"
|
|
|
|
|
-if not os.path.exists(UPLOAD_DIRECTORY):
|
|
|
|
|
- os.makedirs(UPLOAD_DIRECTORY)
|
|
|
|
|
|
|
+# Base directories
|
|
|
|
|
+BASE_STATIC_FILES_DIR = "static/files"
|
|
|
|
|
+os.makedirs(BASE_STATIC_FILES_DIR, exist_ok=True)
|
|
|
|
|
|
|
|
# Mount static files
|
|
# Mount static files
|
|
|
-app.mount("/static/files", StaticFiles(directory=UPLOAD_DIRECTORY), name="static_files")
|
|
|
|
|
|
|
+app.mount("/static/files", StaticFiles(directory=BASE_STATIC_FILES_DIR), name="static_files")
|
|
|
app.mount("/static/web", StaticFiles(directory="static/web"), name="static_web")
|
|
app.mount("/static/web", StaticFiles(directory="static/web"), name="static_web")
|
|
|
app.mount("/static", StaticFiles(directory="static"), name="static")
|
|
app.mount("/static", StaticFiles(directory="static"), name="static")
|
|
|
|
|
|
|
@@ -44,26 +51,623 @@ app.mount("/static", StaticFiles(directory="static"), name="static")
|
|
|
CACHE_DIR = "audio_cache"
|
|
CACHE_DIR = "audio_cache"
|
|
|
os.makedirs(CACHE_DIR, exist_ok=True)
|
|
os.makedirs(CACHE_DIR, exist_ok=True)
|
|
|
|
|
|
|
|
-# Root redirect to PDF viewer
|
|
|
|
|
-@app.get("/")
|
|
|
|
|
-def root():
|
|
|
|
|
- return RedirectResponse(url="/static/web/viewer.html?file=/static/files/compress.pdf")
|
|
|
|
|
|
|
+SESSION_COOKIE = "reader_pro_session"
|
|
|
|
|
+SESSION_TTL_DAYS = 1
|
|
|
|
|
+SESSION_TTL_DAYS_REMEMBER = 30
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
+def db_conn():
|
|
|
|
|
+ return pymysql.connect(
|
|
|
|
|
+ host=MYSQL_HOST,
|
|
|
|
|
+ port=MYSQL_PORT,
|
|
|
|
|
+ user=MYSQL_USER,
|
|
|
|
|
+ password=MYSQL_PASSWORD,
|
|
|
|
|
+ database=MYSQL_DATABASE,
|
|
|
|
|
+ charset="utf8mb4",
|
|
|
|
|
+ autocommit=True,
|
|
|
|
|
+ cursorclass=pymysql.cursors.DictCursor,
|
|
|
|
|
+ )
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
+def init_db() -> None:
|
|
|
|
|
+ conn = db_conn()
|
|
|
|
|
+ try:
|
|
|
|
|
+ with conn.cursor() as cur:
|
|
|
|
|
+ cur.execute(
|
|
|
|
|
+ """
|
|
|
|
|
+ CREATE TABLE IF NOT EXISTS user (
|
|
|
|
|
+ id BIGINT PRIMARY KEY AUTO_INCREMENT,
|
|
|
|
|
+ username VARCHAR(64) NOT NULL UNIQUE,
|
|
|
|
|
+ password_hash VARCHAR(255) NOT NULL,
|
|
|
|
|
+ is_admin TINYINT(1) NOT NULL DEFAULT 0,
|
|
|
|
|
+ is_active TINYINT(1) NOT NULL DEFAULT 1,
|
|
|
|
|
+ session_token VARCHAR(128) NULL,
|
|
|
|
|
+ session_expires_at DATETIME NULL,
|
|
|
|
|
+ last_file VARCHAR(1024) NULL,
|
|
|
|
|
+ last_page INT NULL,
|
|
|
|
|
+ created_at DATETIME NOT NULL,
|
|
|
|
|
+ updated_at DATETIME NOT NULL
|
|
|
|
|
+ ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4
|
|
|
|
|
+ """
|
|
|
|
|
+ )
|
|
|
|
|
+ # backfill old schema without is_active column
|
|
|
|
|
+ cur.execute("SHOW COLUMNS FROM user LIKE 'is_active'")
|
|
|
|
|
+ if not cur.fetchone():
|
|
|
|
|
+ cur.execute("ALTER TABLE user ADD COLUMN is_active TINYINT(1) NOT NULL DEFAULT 1 AFTER is_admin")
|
|
|
|
|
+ cur.execute(
|
|
|
|
|
+ """
|
|
|
|
|
+ CREATE TABLE IF NOT EXISTS user_progress (
|
|
|
|
|
+ id BIGINT PRIMARY KEY AUTO_INCREMENT,
|
|
|
|
|
+ user_id BIGINT NOT NULL,
|
|
|
|
|
+ file_path VARCHAR(512) NOT NULL,
|
|
|
|
|
+ page INT NOT NULL,
|
|
|
|
|
+ updated_at DATETIME NOT NULL,
|
|
|
|
|
+ UNIQUE KEY uniq_user_file (user_id, file_path),
|
|
|
|
|
+ KEY idx_user_updated (user_id, updated_at),
|
|
|
|
|
+ CONSTRAINT fk_user_progress_user
|
|
|
|
|
+ FOREIGN KEY (user_id) REFERENCES user(id)
|
|
|
|
|
+ ON DELETE CASCADE
|
|
|
|
|
+ ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4
|
|
|
|
|
+ """
|
|
|
|
|
+ )
|
|
|
|
|
+ cur.execute(
|
|
|
|
|
+ """
|
|
|
|
|
+ CREATE TABLE IF NOT EXISTS user_config (
|
|
|
|
|
+ config_key VARCHAR(128) PRIMARY KEY,
|
|
|
|
|
+ config_value TEXT NULL,
|
|
|
|
|
+ updated_at DATETIME NOT NULL
|
|
|
|
|
+ ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4
|
|
|
|
|
+ """
|
|
|
|
|
+ )
|
|
|
|
|
+ # seed admin
|
|
|
|
|
+ cur.execute("SELECT id FROM user WHERE username=%s", ("admin",))
|
|
|
|
|
+ admin = cur.fetchone()
|
|
|
|
|
+ if not admin:
|
|
|
|
|
+ now = datetime.now(timezone.utc).replace(tzinfo=None)
|
|
|
|
|
+ cur.execute(
|
|
|
|
|
+ """
|
|
|
|
|
+ INSERT INTO user (username, password_hash, is_admin, is_active, created_at, updated_at)
|
|
|
|
|
+ VALUES (%s, %s, 1, 1, %s, %s)
|
|
|
|
|
+ """,
|
|
|
|
|
+ ("admin", hash_password("admin"), now, now),
|
|
|
|
|
+ )
|
|
|
|
|
+ logger.info("Seeded default admin account: admin/admin")
|
|
|
|
|
+ finally:
|
|
|
|
|
+ conn.close()
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
+def hash_password(password: str) -> str:
|
|
|
|
|
+ # Keep simple deterministic hash for compatibility; can migrate to bcrypt later.
|
|
|
|
|
+ return hashlib.sha256(password.encode("utf-8")).hexdigest()
|
|
|
|
|
+
|
|
|
|
|
|
|
|
-# Sanitize filename
|
|
|
|
|
def sanitize_filename(name: str) -> str:
|
|
def sanitize_filename(name: str) -> str:
|
|
|
- return "".join(c for c in name if c.isalnum() or c in (' ', '.', '_', '-')).rstrip()
|
|
|
|
|
|
|
+ return "".join(c for c in name if c.isalnum() or c in (" ", ".", "_", "-")).rstrip()
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
+def get_user_dir(username: str) -> str:
|
|
|
|
|
+ safe = sanitize_filename(username) or "user"
|
|
|
|
|
+ path = os.path.join(BASE_STATIC_FILES_DIR, safe)
|
|
|
|
|
+ os.makedirs(path, exist_ok=True)
|
|
|
|
|
+ return path
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
+def build_user_file_url(username: str, filename: str) -> str:
|
|
|
|
|
+ return f"/static/files/{sanitize_filename(username)}/{filename}"
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
+def get_user_by_session(session_token: Optional[str]):
|
|
|
|
|
+ if not session_token:
|
|
|
|
|
+ return None
|
|
|
|
|
+ now = datetime.now(timezone.utc).replace(tzinfo=None)
|
|
|
|
|
+ conn = db_conn()
|
|
|
|
|
+ try:
|
|
|
|
|
+ with conn.cursor() as cur:
|
|
|
|
|
+ cur.execute(
|
|
|
|
|
+ """
|
|
|
|
|
+ SELECT * FROM user
|
|
|
|
|
+ WHERE session_token=%s AND session_expires_at IS NOT NULL AND session_expires_at>%s
|
|
|
|
|
+ """,
|
|
|
|
|
+ (session_token, now),
|
|
|
|
|
+ )
|
|
|
|
|
+ return cur.fetchone()
|
|
|
|
|
+ finally:
|
|
|
|
|
+ conn.close()
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
+def require_user(session_token: Optional[str]):
|
|
|
|
|
+ user = get_user_by_session(session_token)
|
|
|
|
|
+ if not user:
|
|
|
|
|
+ raise HTTPException(status_code=401, detail="未登录或会话已过期")
|
|
|
|
|
+ return user
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
+def require_admin(session_token: Optional[str]):
|
|
|
|
|
+ user = require_user(session_token)
|
|
|
|
|
+ if not user.get("is_admin"):
|
|
|
|
|
+ raise HTTPException(status_code=403, detail="需要管理员权限")
|
|
|
|
|
+ return user
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
+def set_session_for_user(username: str, remember_me: bool):
|
|
|
|
|
+ token = secrets.token_urlsafe(48)
|
|
|
|
|
+ expire_days = SESSION_TTL_DAYS_REMEMBER if remember_me else SESSION_TTL_DAYS
|
|
|
|
|
+ expires_at = datetime.now(timezone.utc).replace(tzinfo=None) + timedelta(days=expire_days)
|
|
|
|
|
+
|
|
|
|
|
+ conn = db_conn()
|
|
|
|
|
+ try:
|
|
|
|
|
+ with conn.cursor() as cur:
|
|
|
|
|
+ cur.execute(
|
|
|
|
|
+ """
|
|
|
|
|
+ UPDATE user
|
|
|
|
|
+ SET session_token=%s, session_expires_at=%s, updated_at=%s
|
|
|
|
|
+ WHERE username=%s
|
|
|
|
|
+ """,
|
|
|
|
|
+ (token, expires_at, datetime.now(timezone.utc).replace(tzinfo=None), username),
|
|
|
|
|
+ )
|
|
|
|
|
+ finally:
|
|
|
|
|
+ conn.close()
|
|
|
|
|
+
|
|
|
|
|
+ return token, expire_days
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
+@app.on_event("startup")
|
|
|
|
|
+def on_startup():
|
|
|
|
|
+ init_db()
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
+@app.get("/")
|
|
|
|
|
+def root(session_token: Optional[str] = Cookie(default=None, alias=SESSION_COOKIE)):
|
|
|
|
|
+ user = get_user_by_session(session_token)
|
|
|
|
|
+ if not user:
|
|
|
|
|
+ return RedirectResponse(url="/login", status_code=302)
|
|
|
|
|
+
|
|
|
|
|
+ last_file = user.get("last_file")
|
|
|
|
|
+ if last_file:
|
|
|
|
|
+ return RedirectResponse(url=f"/static/web/viewer.html?file={last_file}", status_code=302)
|
|
|
|
|
+
|
|
|
|
|
+ # fallback: first file in user's own directory
|
|
|
|
|
+ user_dir = get_user_dir(user["username"])
|
|
|
|
|
+ files = sorted([f for f in os.listdir(user_dir) if f.lower().endswith(".pdf")])
|
|
|
|
|
+ if files:
|
|
|
|
|
+ file_url = build_user_file_url(user["username"], files[0])
|
|
|
|
|
+ return RedirectResponse(url=f"/static/web/viewer.html?file={file_url}", status_code=302)
|
|
|
|
|
+
|
|
|
|
|
+ # fallback to sample document
|
|
|
|
|
+ return RedirectResponse(url="/static/web/viewer.html?file=/static/files/compress.pdf", status_code=302)
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
+@app.get("/login")
|
|
|
|
|
+def login_page(session_token: Optional[str] = Cookie(default=None, alias=SESSION_COOKIE)):
|
|
|
|
|
+ user = get_user_by_session(session_token)
|
|
|
|
|
+ if user:
|
|
|
|
|
+ return RedirectResponse(url="/", status_code=302)
|
|
|
|
|
+
|
|
|
|
|
+ html = """
|
|
|
|
|
+<!doctype html>
|
|
|
|
|
+<html lang="zh-CN">
|
|
|
|
|
+<head>
|
|
|
|
|
+ <meta charset="utf-8"/>
|
|
|
|
|
+ <meta name="viewport" content="width=device-width, initial-scale=1"/>
|
|
|
|
|
+ <title>VoiceFlow AI Reader 【小满TTS英文听书】 - 登录</title>
|
|
|
|
|
+ <style>
|
|
|
|
|
+ *{box-sizing:border-box}
|
|
|
|
|
+ body{margin:0;min-height:100vh;display:grid;place-items:center;background:linear-gradient(145deg,#f2f6fb,#e7eef8);font-family:Arial,sans-serif;color:#1f2937}
|
|
|
|
|
+ .card{width:min(92vw,420px);background:#fff;border-radius:14px;padding:24px;box-shadow:0 10px 35px rgba(17,24,39,.10)}
|
|
|
|
|
+ .title{margin:0 0 6px;font-size:22px;font-weight:700;color:#0f172a}
|
|
|
|
|
+ .sub{margin:0 0 18px;font-size:13px;color:#64748b}
|
|
|
|
|
+ input,button{width:100%;padding:11px 12px;border-radius:10px;font-size:14px}
|
|
|
|
|
+ input{border:1px solid #d7e0ec;margin-top:10px;outline:none}
|
|
|
|
|
+ input:focus{border-color:#3b82f6;box-shadow:0 0 0 3px rgba(59,130,246,.12)}
|
|
|
|
|
+ button{margin-top:12px;border:none;background:#2563eb;color:#fff;font-weight:600;cursor:pointer}
|
|
|
|
|
+ button:hover{background:#1d4ed8}
|
|
|
|
|
+ .row{display:flex;gap:8px;align-items:center;margin-top:10px;font-size:13px;color:#475569}
|
|
|
|
|
+ .row input{width:auto;margin:0}
|
|
|
|
|
+ .msg{margin-top:10px;color:#b00020;min-height:20px;font-size:13px}
|
|
|
|
|
+ .links{margin-top:8px;text-align:right}
|
|
|
|
|
+ .links a{font-size:13px;color:#2563eb;text-decoration:none}
|
|
|
|
|
+ .links a:hover{text-decoration:underline}
|
|
|
|
|
+ .hint{margin-top:12px;padding:10px;border-radius:10px;background:#f8fafc;color:#475569;font-size:12px}
|
|
|
|
|
+ </style>
|
|
|
|
|
+</head>
|
|
|
|
|
+<body>
|
|
|
|
|
+ <div class="card">
|
|
|
|
|
+ <h1 class="title">VoiceFlow AI Reader</h1>
|
|
|
|
|
+ <p class="sub">【小满TTS英文听书】</p>
|
|
|
|
|
+ <input id="loginUser" placeholder="用户名" autocomplete="username"/>
|
|
|
|
|
+ <input id="loginPass" type="password" placeholder="密码" autocomplete="current-password"/>
|
|
|
|
|
+ <label class="row"><input id="remember" type="checkbox"/>30天内免登录</label>
|
|
|
|
|
+ <button onclick="login()">登录</button>
|
|
|
|
|
+ <div id="loginMsg" class="msg"></div>
|
|
|
|
|
+ <div class="links"><a href="/register">没有账号?去注册</a></div>
|
|
|
|
|
+ <div class="hint">管理员默认账号:admin / admin</div>
|
|
|
|
|
+ </div>
|
|
|
|
|
+<script>
|
|
|
|
|
+async function login(){
|
|
|
|
|
+ const username=document.getElementById('loginUser').value.trim();
|
|
|
|
|
+ const password=document.getElementById('loginPass').value;
|
|
|
|
|
+ const remember_me=document.getElementById('remember').checked;
|
|
|
|
|
+ const msg=document.getElementById('loginMsg');
|
|
|
|
|
+ msg.textContent='';
|
|
|
|
|
+ const r=await fetch('/auth/login',{method:'POST',headers:{'Content-Type':'application/json'},body:JSON.stringify({username,password,remember_me})});
|
|
|
|
|
+ const d=await r.json();
|
|
|
|
|
+ if(!r.ok||!d.success){msg.textContent=d.error||'登录失败';return;}
|
|
|
|
|
+ window.location.href='/';
|
|
|
|
|
+}
|
|
|
|
|
+</script>
|
|
|
|
|
+</body>
|
|
|
|
|
+</html>
|
|
|
|
|
+"""
|
|
|
|
|
+ return HTMLResponse(content=html)
|
|
|
|
|
+
|
|
|
|
|
+@app.get("/register")
|
|
|
|
|
+def register_page(session_token: Optional[str] = Cookie(default=None, alias=SESSION_COOKIE)):
|
|
|
|
|
+ user = get_user_by_session(session_token)
|
|
|
|
|
+ if user:
|
|
|
|
|
+ return RedirectResponse(url="/", status_code=302)
|
|
|
|
|
+
|
|
|
|
|
+ html = """
|
|
|
|
|
+<!doctype html>
|
|
|
|
|
+<html lang="zh-CN">
|
|
|
|
|
+<head>
|
|
|
|
|
+ <meta charset="utf-8"/>
|
|
|
|
|
+ <meta name="viewport" content="width=device-width, initial-scale=1"/>
|
|
|
|
|
+ <title>VoiceFlow AI Reader 【小满TTS英文听书】 - 注册</title>
|
|
|
|
|
+ <style>
|
|
|
|
|
+ *{box-sizing:border-box}
|
|
|
|
|
+ body{margin:0;min-height:100vh;display:grid;place-items:center;background:linear-gradient(145deg,#f2f6fb,#e7eef8);font-family:Arial,sans-serif;color:#1f2937}
|
|
|
|
|
+ .card{width:min(92vw,420px);background:#fff;border-radius:14px;padding:24px;box-shadow:0 10px 35px rgba(17,24,39,.10)}
|
|
|
|
|
+ .title{margin:0 0 6px;font-size:22px;font-weight:700;color:#0f172a}
|
|
|
|
|
+ .sub{margin:0 0 18px;font-size:13px;color:#64748b}
|
|
|
|
|
+ input,button{width:100%;padding:11px 12px;border-radius:10px;font-size:14px}
|
|
|
|
|
+ input{border:1px solid #d7e0ec;margin-top:10px;outline:none}
|
|
|
|
|
+ input:focus{border-color:#3b82f6;box-shadow:0 0 0 3px rgba(59,130,246,.12)}
|
|
|
|
|
+ button{margin-top:12px;border:none;background:#16a34a;color:#fff;font-weight:600;cursor:pointer}
|
|
|
|
|
+ button:hover{background:#15803d}
|
|
|
|
|
+ .msg{margin-top:10px;color:#b00020;min-height:20px;font-size:13px}
|
|
|
|
|
+ .ok{color:#0f7b0f}
|
|
|
|
|
+ .links{margin-top:8px;text-align:right}
|
|
|
|
|
+ .links a{font-size:13px;color:#2563eb;text-decoration:none}
|
|
|
|
|
+ .links a:hover{text-decoration:underline}
|
|
|
|
|
+ </style>
|
|
|
|
|
+</head>
|
|
|
|
|
+<body>
|
|
|
|
|
+ <div class="card">
|
|
|
|
|
+ <h1 class="title">创建新账号</h1>
|
|
|
|
|
+ <p class="sub">VoiceFlow AI Reader 【小满TTS英文听书】</p>
|
|
|
|
|
+ <input id="regUser" placeholder="用户名(至少3位)" autocomplete="username"/>
|
|
|
|
|
+ <input id="regPass" type="password" placeholder="密码(至少4位)" autocomplete="new-password"/>
|
|
|
|
|
+ <button onclick="registerUser()">注册</button>
|
|
|
|
|
+ <div id="regMsg" class="msg"></div>
|
|
|
|
|
+ <div class="links"><a href="/login">已有账号?去登录</a></div>
|
|
|
|
|
+ </div>
|
|
|
|
|
+<script>
|
|
|
|
|
+async function registerUser(){
|
|
|
|
|
+ const username=document.getElementById('regUser').value.trim();
|
|
|
|
|
+ const password=document.getElementById('regPass').value;
|
|
|
|
|
+ const msg=document.getElementById('regMsg');
|
|
|
|
|
+ msg.textContent='';
|
|
|
|
|
+ const r=await fetch('/auth/register',{method:'POST',headers:{'Content-Type':'application/json'},body:JSON.stringify({username,password})});
|
|
|
|
|
+ const d=await r.json();
|
|
|
|
|
+ if(!r.ok||!d.success){msg.textContent=d.error||'注册失败';return;}
|
|
|
|
|
+ msg.textContent='注册成功,2秒后跳转登录页';
|
|
|
|
|
+ msg.className='msg ok';
|
|
|
|
|
+ setTimeout(()=>{window.location.href='/login';},2000);
|
|
|
|
|
+}
|
|
|
|
|
+</script>
|
|
|
|
|
+</body>
|
|
|
|
|
+</html>
|
|
|
|
|
+"""
|
|
|
|
|
+ return HTMLResponse(content=html)
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
+class LoginRequest(BaseModel):
|
|
|
|
|
+ username: str
|
|
|
|
|
+ password: str
|
|
|
|
|
+ remember_me: bool = False
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
+class RegisterRequest(BaseModel):
|
|
|
|
|
+ username: str
|
|
|
|
|
+ password: str
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
+@app.post("/auth/register")
|
|
|
|
|
+async def auth_register(request: RegisterRequest):
|
|
|
|
|
+ username = sanitize_filename((request.username or "").strip())
|
|
|
|
|
+ password = (request.password or "").strip()
|
|
|
|
|
+ if len(username) < 3:
|
|
|
|
|
+ return JSONResponse(status_code=400, content={"success": False, "error": "用户名至少3位"})
|
|
|
|
|
+ if len(password) < 4:
|
|
|
|
|
+ return JSONResponse(status_code=400, content={"success": False, "error": "密码至少4位"})
|
|
|
|
|
+
|
|
|
|
|
+ now = datetime.now(timezone.utc).replace(tzinfo=None)
|
|
|
|
|
+ conn = db_conn()
|
|
|
|
|
+ try:
|
|
|
|
|
+ with conn.cursor() as cur:
|
|
|
|
|
+ cur.execute("SELECT id FROM user WHERE username=%s", (username,))
|
|
|
|
|
+ if cur.fetchone():
|
|
|
|
|
+ return JSONResponse(status_code=400, content={"success": False, "error": "用户名已存在"})
|
|
|
|
|
+ cur.execute(
|
|
|
|
|
+ """
|
|
|
|
|
+ INSERT INTO user (username, password_hash, is_admin, created_at, updated_at)
|
|
|
|
|
+ VALUES (%s, %s, 0, %s, %s)
|
|
|
|
|
+ """,
|
|
|
|
|
+ (username, hash_password(password), now, now),
|
|
|
|
|
+ )
|
|
|
|
|
+ finally:
|
|
|
|
|
+ conn.close()
|
|
|
|
|
+
|
|
|
|
|
+ get_user_dir(username)
|
|
|
|
|
+ return JSONResponse(content={"success": True})
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
+@app.post("/auth/login")
|
|
|
|
|
+async def auth_login(request: LoginRequest):
|
|
|
|
|
+ username = sanitize_filename((request.username or "").strip())
|
|
|
|
|
+ password = (request.password or "").strip()
|
|
|
|
|
+ conn = db_conn()
|
|
|
|
|
+ try:
|
|
|
|
|
+ with conn.cursor() as cur:
|
|
|
|
|
+ cur.execute("SELECT * FROM user WHERE username=%s", (username,))
|
|
|
|
|
+ user = cur.fetchone()
|
|
|
|
|
+ finally:
|
|
|
|
|
+ conn.close()
|
|
|
|
|
+
|
|
|
|
|
+ if not user or user["password_hash"] != hash_password(password):
|
|
|
|
|
+ return JSONResponse(status_code=401, content={"success": False, "error": "用户名或密码错误"})
|
|
|
|
|
+ if int(user.get("is_active", 1)) != 1:
|
|
|
|
|
+ return JSONResponse(status_code=403, content={"success": False, "error": "账号已被禁用"})
|
|
|
|
|
+
|
|
|
|
|
+ token, expire_days = set_session_for_user(username, request.remember_me)
|
|
|
|
|
+ resp = JSONResponse(content={"success": True, "is_admin": bool(user.get("is_admin"))})
|
|
|
|
|
+ max_age = expire_days * 24 * 3600
|
|
|
|
|
+ resp.set_cookie(
|
|
|
|
|
+ key=SESSION_COOKIE,
|
|
|
|
|
+ value=token,
|
|
|
|
|
+ max_age=max_age,
|
|
|
|
|
+ httponly=True,
|
|
|
|
|
+ samesite="lax",
|
|
|
|
|
+ secure=False,
|
|
|
|
|
+ path="/",
|
|
|
|
|
+ )
|
|
|
|
|
+ return resp
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
+@app.post("/auth/logout")
|
|
|
|
|
+async def auth_logout(session_token: Optional[str] = Cookie(default=None, alias=SESSION_COOKIE)):
|
|
|
|
|
+ user = get_user_by_session(session_token)
|
|
|
|
|
+ if user:
|
|
|
|
|
+ conn = db_conn()
|
|
|
|
|
+ try:
|
|
|
|
|
+ with conn.cursor() as cur:
|
|
|
|
|
+ cur.execute(
|
|
|
|
|
+ "UPDATE user SET session_token=NULL, session_expires_at=NULL, updated_at=%s WHERE id=%s",
|
|
|
|
|
+ (datetime.now(timezone.utc).replace(tzinfo=None), user["id"]),
|
|
|
|
|
+ )
|
|
|
|
|
+ finally:
|
|
|
|
|
+ conn.close()
|
|
|
|
|
+
|
|
|
|
|
+ resp = JSONResponse(content={"success": True})
|
|
|
|
|
+ resp.delete_cookie(SESSION_COOKIE, path="/")
|
|
|
|
|
+ return resp
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
+@app.get("/auth/me")
|
|
|
|
|
+async def auth_me(session_token: Optional[str] = Cookie(default=None, alias=SESSION_COOKIE)):
|
|
|
|
|
+ user = get_user_by_session(session_token)
|
|
|
|
|
+ if not user:
|
|
|
|
|
+ return JSONResponse(status_code=401, content={"success": False})
|
|
|
|
|
+ return JSONResponse(
|
|
|
|
|
+ content={
|
|
|
|
|
+ "success": True,
|
|
|
|
|
+ "username": user["username"],
|
|
|
|
|
+ "is_admin": bool(user.get("is_admin")),
|
|
|
|
|
+ "is_active": bool(user.get("is_active", 1)),
|
|
|
|
|
+ }
|
|
|
|
|
+ )
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
+@app.get("/admin")
|
|
|
|
|
+def admin_page(session_token: Optional[str] = Cookie(default=None, alias=SESSION_COOKIE)):
|
|
|
|
|
+ require_admin(session_token)
|
|
|
|
|
+ return RedirectResponse(url="/static/web/admin.html", status_code=302)
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
+class AdminUserRequest(BaseModel):
|
|
|
|
|
+ username: str
|
|
|
|
|
+ password: str
|
|
|
|
|
+ is_admin: bool = False
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
+@app.get("/admin/users")
|
|
|
|
|
+async def admin_users(session_token: Optional[str] = Cookie(default=None, alias=SESSION_COOKIE)):
|
|
|
|
|
+ require_admin(session_token)
|
|
|
|
|
+ conn = db_conn()
|
|
|
|
|
+ try:
|
|
|
|
|
+ with conn.cursor() as cur:
|
|
|
|
|
+ cur.execute("SELECT id, username, is_admin, is_active, created_at FROM user ORDER BY id ASC")
|
|
|
|
|
+ users = cur.fetchall()
|
|
|
|
|
+ finally:
|
|
|
|
|
+ conn.close()
|
|
|
|
|
+ return JSONResponse(content=jsonable_encoder({"success": True, "users": users}))
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
+@app.post("/admin/users")
|
|
|
|
|
+async def admin_create_or_reset_user(payload: AdminUserRequest, session_token: Optional[str] = Cookie(default=None, alias=SESSION_COOKIE)):
|
|
|
|
|
+ require_admin(session_token)
|
|
|
|
|
+ username = sanitize_filename((payload.username or "").strip())
|
|
|
|
|
+ password = (payload.password or "").strip()
|
|
|
|
|
+ if len(username) < 3 or len(password) < 4:
|
|
|
|
|
+ return JSONResponse(status_code=400, content={"success": False, "error": "用户名或密码不合法"})
|
|
|
|
|
+
|
|
|
|
|
+ now = datetime.now(timezone.utc).replace(tzinfo=None)
|
|
|
|
|
+ conn = db_conn()
|
|
|
|
|
+ try:
|
|
|
|
|
+ with conn.cursor() as cur:
|
|
|
|
|
+ cur.execute("SELECT id FROM user WHERE username=%s", (username,))
|
|
|
|
|
+ row = cur.fetchone()
|
|
|
|
|
+ if row:
|
|
|
|
|
+ cur.execute(
|
|
|
|
|
+ "UPDATE user SET password_hash=%s, is_admin=%s, is_active=1, updated_at=%s WHERE username=%s",
|
|
|
|
|
+ (hash_password(password), 1 if payload.is_admin else 0, now, username),
|
|
|
|
|
+ )
|
|
|
|
|
+ else:
|
|
|
|
|
+ cur.execute(
|
|
|
|
|
+ "INSERT INTO user (username, password_hash, is_admin, is_active, created_at, updated_at) VALUES (%s,%s,%s,1,%s,%s)",
|
|
|
|
|
+ (username, hash_password(password), 1 if payload.is_admin else 0, now, now),
|
|
|
|
|
+ )
|
|
|
|
|
+ finally:
|
|
|
|
|
+ conn.close()
|
|
|
|
|
+
|
|
|
|
|
+ get_user_dir(username)
|
|
|
|
|
+ return JSONResponse(content={"success": True})
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
+@app.delete("/admin/users/{username}")
|
|
|
|
|
+async def admin_delete_user(
|
|
|
|
|
+ username: str,
|
|
|
|
|
+ delete_files: bool = False,
|
|
|
|
|
+ session_token: Optional[str] = Cookie(default=None, alias=SESSION_COOKIE),
|
|
|
|
|
+):
|
|
|
|
|
+ require_admin(session_token)
|
|
|
|
|
+ if username == "admin":
|
|
|
|
|
+ return JSONResponse(status_code=400, content={"success": False, "error": "不能删除admin"})
|
|
|
|
|
+ conn = db_conn()
|
|
|
|
|
+ try:
|
|
|
|
|
+ with conn.cursor() as cur:
|
|
|
|
|
+ cur.execute("DELETE FROM user WHERE username=%s", (username,))
|
|
|
|
|
+ finally:
|
|
|
|
|
+ conn.close()
|
|
|
|
|
+
|
|
|
|
|
+ if delete_files:
|
|
|
|
|
+ user_dir = get_user_dir(username)
|
|
|
|
|
+ if os.path.isdir(user_dir):
|
|
|
|
|
+ shutil.rmtree(user_dir, ignore_errors=True)
|
|
|
|
|
+
|
|
|
|
|
+ return JSONResponse(content={"success": True})
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
+class AdminResetPasswordRequest(BaseModel):
|
|
|
|
|
+ password: str
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
+@app.post("/admin/users/{username}/reset-password")
|
|
|
|
|
+async def admin_reset_password(
|
|
|
|
|
+ username: str,
|
|
|
|
|
+ payload: AdminResetPasswordRequest,
|
|
|
|
|
+ session_token: Optional[str] = Cookie(default=None, alias=SESSION_COOKIE),
|
|
|
|
|
+):
|
|
|
|
|
+ require_admin(session_token)
|
|
|
|
|
+ password = (payload.password or "").strip()
|
|
|
|
|
+ if len(password) < 4:
|
|
|
|
|
+ return JSONResponse(status_code=400, content={"success": False, "error": "密码至少4位"})
|
|
|
|
|
+
|
|
|
|
|
+ now = datetime.now(timezone.utc).replace(tzinfo=None)
|
|
|
|
|
+ conn = db_conn()
|
|
|
|
|
+ try:
|
|
|
|
|
+ with conn.cursor() as cur:
|
|
|
|
|
+ cur.execute(
|
|
|
|
|
+ "UPDATE user SET password_hash=%s, updated_at=%s WHERE username=%s",
|
|
|
|
|
+ (hash_password(password), now, username),
|
|
|
|
|
+ )
|
|
|
|
|
+ if cur.rowcount == 0:
|
|
|
|
|
+ return JSONResponse(status_code=404, content={"success": False, "error": "用户不存在"})
|
|
|
|
|
+ finally:
|
|
|
|
|
+ conn.close()
|
|
|
|
|
+ return JSONResponse(content={"success": True})
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
+class AdminToggleUserRequest(BaseModel):
|
|
|
|
|
+ is_active: bool
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
+@app.post("/admin/users/{username}/status")
|
|
|
|
|
+async def admin_toggle_user_status(
|
|
|
|
|
+ username: str,
|
|
|
|
|
+ payload: AdminToggleUserRequest,
|
|
|
|
|
+ session_token: Optional[str] = Cookie(default=None, alias=SESSION_COOKIE),
|
|
|
|
|
+):
|
|
|
|
|
+ require_admin(session_token)
|
|
|
|
|
+ if username == "admin" and not payload.is_active:
|
|
|
|
|
+ return JSONResponse(status_code=400, content={"success": False, "error": "不能禁用admin"})
|
|
|
|
|
+
|
|
|
|
|
+ now = datetime.now(timezone.utc).replace(tzinfo=None)
|
|
|
|
|
+ conn = db_conn()
|
|
|
|
|
+ try:
|
|
|
|
|
+ with conn.cursor() as cur:
|
|
|
|
|
+ cur.execute(
|
|
|
|
|
+ "UPDATE user SET is_active=%s, updated_at=%s WHERE username=%s",
|
|
|
|
|
+ (1 if payload.is_active else 0, now, username),
|
|
|
|
|
+ )
|
|
|
|
|
+ if cur.rowcount == 0:
|
|
|
|
|
+ return JSONResponse(status_code=404, content={"success": False, "error": "用户不存在"})
|
|
|
|
|
+ finally:
|
|
|
|
|
+ conn.close()
|
|
|
|
|
+ return JSONResponse(content={"success": True})
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
+class AdminConfigRequest(BaseModel):
|
|
|
|
|
+ config_key: str
|
|
|
|
|
+ config_value: Optional[str] = None
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
+@app.get("/admin/config")
|
|
|
|
|
+async def admin_get_config(session_token: Optional[str] = Cookie(default=None, alias=SESSION_COOKIE)):
|
|
|
|
|
+ require_admin(session_token)
|
|
|
|
|
+ conn = db_conn()
|
|
|
|
|
+ try:
|
|
|
|
|
+ with conn.cursor() as cur:
|
|
|
|
|
+ cur.execute("SELECT config_key, config_value, updated_at FROM user_config ORDER BY config_key")
|
|
|
|
|
+ rows = cur.fetchall()
|
|
|
|
|
+ finally:
|
|
|
|
|
+ conn.close()
|
|
|
|
|
+ return JSONResponse(content=jsonable_encoder({"success": True, "configs": rows}))
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
+@app.post("/admin/config")
|
|
|
|
|
+async def admin_set_config(payload: AdminConfigRequest, session_token: Optional[str] = Cookie(default=None, alias=SESSION_COOKIE)):
|
|
|
|
|
+ require_admin(session_token)
|
|
|
|
|
+ config_key = (payload.config_key or "").strip()
|
|
|
|
|
+ if not config_key:
|
|
|
|
|
+ return JSONResponse(status_code=400, content={"success": False, "error": "config_key 不能为空"})
|
|
|
|
|
+
|
|
|
|
|
+ now = datetime.now(timezone.utc).replace(tzinfo=None)
|
|
|
|
|
+ conn = db_conn()
|
|
|
|
|
+ try:
|
|
|
|
|
+ with conn.cursor() as cur:
|
|
|
|
|
+ cur.execute(
|
|
|
|
|
+ """
|
|
|
|
|
+ INSERT INTO user_config (config_key, config_value, updated_at)
|
|
|
|
|
+ VALUES (%s, %s, %s)
|
|
|
|
|
+ ON DUPLICATE KEY UPDATE config_value=VALUES(config_value), updated_at=VALUES(updated_at)
|
|
|
|
|
+ """,
|
|
|
|
|
+ (config_key, payload.config_value, now),
|
|
|
|
|
+ )
|
|
|
|
|
+ finally:
|
|
|
|
|
+ conn.close()
|
|
|
|
|
+
|
|
|
|
|
+ return JSONResponse(content={"success": True})
|
|
|
|
|
+
|
|
|
|
|
|
|
|
# PDF upload endpoint
|
|
# PDF upload endpoint
|
|
|
@app.post("/upload-pdf")
|
|
@app.post("/upload-pdf")
|
|
|
-async def upload_pdf(file: UploadFile = File(...), custom_name: str = Form(...)):
|
|
|
|
|
- if file.content_type != 'application/pdf':
|
|
|
|
|
|
|
+async def upload_pdf(
|
|
|
|
|
+ file: UploadFile = File(...),
|
|
|
|
|
+ custom_name: str = Form(...),
|
|
|
|
|
+ session_token: Optional[str] = Cookie(default=None, alias=SESSION_COOKIE),
|
|
|
|
|
+):
|
|
|
|
|
+ user = require_user(session_token)
|
|
|
|
|
+ if file.content_type != "application/pdf":
|
|
|
raise HTTPException(status_code=400, detail="文件类型必须是 PDF")
|
|
raise HTTPException(status_code=400, detail="文件类型必须是 PDF")
|
|
|
|
|
+
|
|
|
sanitized_name = sanitize_filename(custom_name)
|
|
sanitized_name = sanitize_filename(custom_name)
|
|
|
if not sanitized_name:
|
|
if not sanitized_name:
|
|
|
return JSONResponse(status_code=400, content={"success": False, "error": "无效的文件名"})
|
|
return JSONResponse(status_code=400, content={"success": False, "error": "无效的文件名"})
|
|
|
-
|
|
|
|
|
|
|
+
|
|
|
unique_filename = f"{sanitized_name}.pdf"
|
|
unique_filename = f"{sanitized_name}.pdf"
|
|
|
- file_path = os.path.join(UPLOAD_DIRECTORY, unique_filename)
|
|
|
|
|
|
|
+ user_dir = get_user_dir(user["username"])
|
|
|
|
|
+ file_path = os.path.join(user_dir, unique_filename)
|
|
|
|
|
|
|
|
if os.path.exists(file_path):
|
|
if os.path.exists(file_path):
|
|
|
return JSONResponse(status_code=400, content={"success": False, "error": "文件名已存在,请使用其他名称"})
|
|
return JSONResponse(status_code=400, content={"success": False, "error": "文件名已存在,请使用其他名称"})
|
|
@@ -71,79 +675,148 @@ async def upload_pdf(file: UploadFile = File(...), custom_name: str = Form(...))
|
|
|
try:
|
|
try:
|
|
|
with open(file_path, "wb") as buffer:
|
|
with open(file_path, "wb") as buffer:
|
|
|
shutil.copyfileobj(file.file, buffer)
|
|
shutil.copyfileobj(file.file, buffer)
|
|
|
- except Exception as e:
|
|
|
|
|
|
|
+ except Exception:
|
|
|
raise HTTPException(status_code=500, detail="上传过程中出错")
|
|
raise HTTPException(status_code=500, detail="上传过程中出错")
|
|
|
finally:
|
|
finally:
|
|
|
file.file.close()
|
|
file.file.close()
|
|
|
|
|
|
|
|
- file_relative_path = f"/static/files/{unique_filename}"
|
|
|
|
|
|
|
+ file_relative_path = build_user_file_url(user["username"], unique_filename)
|
|
|
|
|
+ now = datetime.now(timezone.utc).replace(tzinfo=None)
|
|
|
|
|
+ conn = db_conn()
|
|
|
|
|
+ try:
|
|
|
|
|
+ with conn.cursor() as cur:
|
|
|
|
|
+ cur.execute(
|
|
|
|
|
+ "UPDATE user SET last_file=%s, last_page=1, updated_at=%s WHERE id=%s",
|
|
|
|
|
+ (file_relative_path, now, user["id"]),
|
|
|
|
|
+ )
|
|
|
|
|
+ finally:
|
|
|
|
|
+ conn.close()
|
|
|
|
|
+
|
|
|
return JSONResponse(content={"success": True, "file_path": file_relative_path})
|
|
return JSONResponse(content={"success": True, "file_path": file_relative_path})
|
|
|
|
|
|
|
|
|
|
+
|
|
|
# List PDFs endpoint
|
|
# List PDFs endpoint
|
|
|
@app.get("/list-pdfs")
|
|
@app.get("/list-pdfs")
|
|
|
-async def list_pdfs():
|
|
|
|
|
|
|
+async def list_pdfs(session_token: Optional[str] = Cookie(default=None, alias=SESSION_COOKIE)):
|
|
|
|
|
+ user = require_user(session_token)
|
|
|
try:
|
|
try:
|
|
|
- files = os.listdir(UPLOAD_DIRECTORY)
|
|
|
|
|
|
|
+ user_dir = get_user_dir(user["username"])
|
|
|
|
|
+ files = os.listdir(user_dir)
|
|
|
pdf_files = [
|
|
pdf_files = [
|
|
|
- {"name": file, "url": f"/static/files/{file}"}
|
|
|
|
|
- for file in files if file.lower().endswith(".pdf")
|
|
|
|
|
|
|
+ {"name": file, "url": build_user_file_url(user["username"], file)}
|
|
|
|
|
+ for file in files
|
|
|
|
|
+ if file.lower().endswith(".pdf")
|
|
|
]
|
|
]
|
|
|
|
|
+ pdf_files.sort(key=lambda x: x["name"].lower())
|
|
|
return JSONResponse(content={"success": True, "files": pdf_files})
|
|
return JSONResponse(content={"success": True, "files": pdf_files})
|
|
|
- except Exception as e:
|
|
|
|
|
|
|
+ except Exception:
|
|
|
raise HTTPException(status_code=500, detail="无法获取文件列表")
|
|
raise HTTPException(status_code=500, detail="无法获取文件列表")
|
|
|
|
|
|
|
|
-# Request models
|
|
|
|
|
-from pydantic import BaseModel
|
|
|
|
|
|
|
+
|
|
|
|
|
+class ReadingProgressRequest(BaseModel):
|
|
|
|
|
+ file: str
|
|
|
|
|
+ page: int
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
+@app.get("/reading-progress")
|
|
|
|
|
+async def get_reading_progress(file: str, session_token: Optional[str] = Cookie(default=None, alias=SESSION_COOKIE)):
|
|
|
|
|
+ user = require_user(session_token)
|
|
|
|
|
+ normalized_file = (file or "").strip()
|
|
|
|
|
+ if not normalized_file:
|
|
|
|
|
+ return JSONResponse(status_code=400, content={"success": False, "error": "缺少 file 参数"})
|
|
|
|
|
+
|
|
|
|
|
+ conn = db_conn()
|
|
|
|
|
+ try:
|
|
|
|
|
+ with conn.cursor() as cur:
|
|
|
|
|
+ cur.execute(
|
|
|
|
|
+ "SELECT page FROM user_progress WHERE user_id=%s AND file_path=%s",
|
|
|
|
|
+ (user["id"], normalized_file),
|
|
|
|
|
+ )
|
|
|
|
|
+ row = cur.fetchone()
|
|
|
|
|
+ page = row["page"] if row else None
|
|
|
|
|
+ finally:
|
|
|
|
|
+ conn.close()
|
|
|
|
|
+
|
|
|
|
|
+ return JSONResponse(content={"success": True, "file": normalized_file, "page": page})
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
+@app.post("/reading-progress")
|
|
|
|
|
+async def save_reading_progress(request: ReadingProgressRequest, session_token: Optional[str] = Cookie(default=None, alias=SESSION_COOKIE)):
|
|
|
|
|
+ user = require_user(session_token)
|
|
|
|
|
+ normalized_file = (request.file or "").strip()
|
|
|
|
|
+ page = int(request.page)
|
|
|
|
|
+ if not normalized_file:
|
|
|
|
|
+ return JSONResponse(status_code=400, content={"success": False, "error": "file 不能为空"})
|
|
|
|
|
+ if page < 1:
|
|
|
|
|
+ return JSONResponse(status_code=400, content={"success": False, "error": "page 必须 >= 1"})
|
|
|
|
|
+
|
|
|
|
|
+ now = datetime.now(timezone.utc).replace(tzinfo=None)
|
|
|
|
|
+ conn = db_conn()
|
|
|
|
|
+ try:
|
|
|
|
|
+ with conn.cursor() as cur:
|
|
|
|
|
+ cur.execute(
|
|
|
|
|
+ """
|
|
|
|
|
+ INSERT INTO user_progress (user_id, file_path, page, updated_at)
|
|
|
|
|
+ VALUES (%s, %s, %s, %s)
|
|
|
|
|
+ ON DUPLICATE KEY UPDATE page=VALUES(page), updated_at=VALUES(updated_at)
|
|
|
|
|
+ """,
|
|
|
|
|
+ (user["id"], normalized_file, page, now),
|
|
|
|
|
+ )
|
|
|
|
|
+ cur.execute(
|
|
|
|
|
+ "UPDATE user SET last_file=%s, last_page=%s, updated_at=%s WHERE id=%s",
|
|
|
|
|
+ (normalized_file, page, now, user["id"]),
|
|
|
|
|
+ )
|
|
|
|
|
+ finally:
|
|
|
|
|
+ conn.close()
|
|
|
|
|
+
|
|
|
|
|
+ return JSONResponse(content={"success": True})
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
class TextToSpeechRequest(BaseModel):
|
|
class TextToSpeechRequest(BaseModel):
|
|
|
user_input: str
|
|
user_input: str
|
|
|
- voice: str = 'af_heart' # Default voice
|
|
|
|
|
- speed: float = 1.0 # Default speed
|
|
|
|
|
|
|
+ voice: str = "af_heart"
|
|
|
|
|
+ speed: float = 1.0
|
|
|
|
|
+
|
|
|
|
|
|
|
|
-# Text-to-speech endpoint (streaming)
|
|
|
|
|
@app.post("/text-to-speech/")
|
|
@app.post("/text-to-speech/")
|
|
|
async def text_to_speech(request: TextToSpeechRequest):
|
|
async def text_to_speech(request: TextToSpeechRequest):
|
|
|
user_input = request.user_input.strip()
|
|
user_input = request.user_input.strip()
|
|
|
if not user_input:
|
|
if not user_input:
|
|
|
raise HTTPException(status_code=400, detail="输入文本为空")
|
|
raise HTTPException(status_code=400, detail="输入文本为空")
|
|
|
-
|
|
|
|
|
- text_hash = hashlib.md5(user_input.encode('utf-8')).hexdigest()
|
|
|
|
|
|
|
+
|
|
|
|
|
+ text_hash = hashlib.md5(user_input.encode("utf-8")).hexdigest()
|
|
|
audio_path = os.path.join(CACHE_DIR, f"{text_hash}.wav")
|
|
audio_path = os.path.join(CACHE_DIR, f"{text_hash}.wav")
|
|
|
|
|
|
|
|
if os.path.exists(audio_path):
|
|
if os.path.exists(audio_path):
|
|
|
with open(audio_path, "rb") as f:
|
|
with open(audio_path, "rb") as f:
|
|
|
return Response(content=f.read(), media_type="audio/wav")
|
|
return Response(content=f.read(), media_type="audio/wav")
|
|
|
-
|
|
|
|
|
|
|
+
|
|
|
async def audio_generator() -> AsyncGenerator[bytes, None]:
|
|
async def audio_generator() -> AsyncGenerator[bytes, None]:
|
|
|
try:
|
|
try:
|
|
|
async with aiohttp.ClientSession() as session:
|
|
async with aiohttp.ClientSession() as session:
|
|
|
async with session.post(
|
|
async with session.post(
|
|
|
- 'http://141.140.15.30:8028/generate',
|
|
|
|
|
- headers={'Content-Type': 'application/json'},
|
|
|
|
|
- json={
|
|
|
|
|
- 'text': user_input,
|
|
|
|
|
- 'voice': request.voice,
|
|
|
|
|
- 'speed': request.speed
|
|
|
|
|
- }
|
|
|
|
|
|
|
+ "http://141.140.15.30:8028/generate",
|
|
|
|
|
+ headers={"Content-Type": "application/json"},
|
|
|
|
|
+ json={"text": user_input, "voice": request.voice, "speed": request.speed},
|
|
|
) as response:
|
|
) as response:
|
|
|
if response.status != 200:
|
|
if response.status != 200:
|
|
|
raise HTTPException(status_code=500, detail="TTS API 请求失败")
|
|
raise HTTPException(status_code=500, detail="TTS API 请求失败")
|
|
|
-
|
|
|
|
|
- # Read NDJSON response
|
|
|
|
|
|
|
+
|
|
|
buffer = ""
|
|
buffer = ""
|
|
|
full_audio = io.BytesIO()
|
|
full_audio = io.BytesIO()
|
|
|
async for chunk in response.content.iter_any():
|
|
async for chunk in response.content.iter_any():
|
|
|
- buffer += chunk.decode('utf-8')
|
|
|
|
|
- lines = buffer.split('\n')
|
|
|
|
|
- buffer = lines[-1] # Keep incomplete line
|
|
|
|
|
-
|
|
|
|
|
|
|
+ buffer += chunk.decode("utf-8")
|
|
|
|
|
+ lines = buffer.split("\n")
|
|
|
|
|
+ buffer = lines[-1]
|
|
|
|
|
+
|
|
|
for line in lines[:-1]:
|
|
for line in lines[:-1]:
|
|
|
if not line.strip():
|
|
if not line.strip():
|
|
|
continue
|
|
continue
|
|
|
try:
|
|
try:
|
|
|
data = json.loads(line)
|
|
data = json.loads(line)
|
|
|
- if data.get('error'):
|
|
|
|
|
- raise HTTPException(status_code=500, detail=data['error'])
|
|
|
|
|
- audio_b64 = data.get('audio')
|
|
|
|
|
|
|
+ if data.get("error"):
|
|
|
|
|
+ raise HTTPException(status_code=500, detail=data["error"])
|
|
|
|
|
+ audio_b64 = data.get("audio")
|
|
|
if audio_b64:
|
|
if audio_b64:
|
|
|
audio_bytes = base64.b64decode(audio_b64)
|
|
audio_bytes = base64.b64decode(audio_b64)
|
|
|
full_audio.write(audio_bytes)
|
|
full_audio.write(audio_bytes)
|
|
@@ -151,35 +824,35 @@ async def text_to_speech(request: TextToSpeechRequest):
|
|
|
except json.JSONDecodeError as e:
|
|
except json.JSONDecodeError as e:
|
|
|
logger.error(f"JSON decode error: {str(e)}")
|
|
logger.error(f"JSON decode error: {str(e)}")
|
|
|
continue
|
|
continue
|
|
|
-
|
|
|
|
|
- # Handle final buffer
|
|
|
|
|
|
|
+
|
|
|
if buffer.strip():
|
|
if buffer.strip():
|
|
|
try:
|
|
try:
|
|
|
data = json.loads(buffer)
|
|
data = json.loads(buffer)
|
|
|
- if data.get('audio'):
|
|
|
|
|
- audio_bytes = base64.b64decode(data['audio'])
|
|
|
|
|
|
|
+ if data.get("audio"):
|
|
|
|
|
+ audio_bytes = base64.b64decode(data["audio"])
|
|
|
full_audio.write(audio_bytes)
|
|
full_audio.write(audio_bytes)
|
|
|
yield audio_bytes
|
|
yield audio_bytes
|
|
|
except json.JSONDecodeError:
|
|
except json.JSONDecodeError:
|
|
|
pass
|
|
pass
|
|
|
-
|
|
|
|
|
- # Save to cache
|
|
|
|
|
|
|
+
|
|
|
full_audio.seek(0)
|
|
full_audio.seek(0)
|
|
|
with open(audio_path, "wb") as f:
|
|
with open(audio_path, "wb") as f:
|
|
|
f.write(full_audio.getvalue())
|
|
f.write(full_audio.getvalue())
|
|
|
-
|
|
|
|
|
|
|
+
|
|
|
except Exception as e:
|
|
except Exception as e:
|
|
|
logger.error(f"TTS error: {str(e)}")
|
|
logger.error(f"TTS error: {str(e)}")
|
|
|
raise HTTPException(status_code=500, detail=str(e))
|
|
raise HTTPException(status_code=500, detail=str(e))
|
|
|
|
|
|
|
|
return StreamingResponse(audio_generator(), media_type="audio/wav")
|
|
return StreamingResponse(audio_generator(), media_type="audio/wav")
|
|
|
|
|
|
|
|
-# Page-to-speech endpoint (chunked streaming)
|
|
|
|
|
|
|
+
|
|
|
MAX_CHUNK_SIZE = 200
|
|
MAX_CHUNK_SIZE = 200
|
|
|
|
|
|
|
|
|
|
+
|
|
|
def split_text_into_chunks(text: str, max_chunk_size: int = MAX_CHUNK_SIZE) -> list:
|
|
def split_text_into_chunks(text: str, max_chunk_size: int = MAX_CHUNK_SIZE) -> list:
|
|
|
import re
|
|
import re
|
|
|
- sentences = re.split('(?<=[.!?]) +', text)
|
|
|
|
|
|
|
+
|
|
|
|
|
+ sentences = re.split(r"(?<=[.!?]) +", text)
|
|
|
chunks = []
|
|
chunks = []
|
|
|
current_chunk = ""
|
|
current_chunk = ""
|
|
|
for sentence in sentences:
|
|
for sentence in sentences:
|
|
@@ -190,7 +863,7 @@ def split_text_into_chunks(text: str, max_chunk_size: int = MAX_CHUNK_SIZE) -> l
|
|
|
chunks.append(current_chunk)
|
|
chunks.append(current_chunk)
|
|
|
if len(sentence) > max_chunk_size:
|
|
if len(sentence) > max_chunk_size:
|
|
|
for i in range(0, len(sentence), max_chunk_size):
|
|
for i in range(0, len(sentence), max_chunk_size):
|
|
|
- chunks.append(sentence[i:i + max_chunk_size])
|
|
|
|
|
|
|
+ chunks.append(sentence[i : i + max_chunk_size])
|
|
|
current_chunk = ""
|
|
current_chunk = ""
|
|
|
else:
|
|
else:
|
|
|
current_chunk = sentence
|
|
current_chunk = sentence
|
|
@@ -198,10 +871,11 @@ def split_text_into_chunks(text: str, max_chunk_size: int = MAX_CHUNK_SIZE) -> l
|
|
|
chunks.append(current_chunk)
|
|
chunks.append(current_chunk)
|
|
|
return chunks
|
|
return chunks
|
|
|
|
|
|
|
|
|
|
+
|
|
|
async def generate_api_audio(chunk: str, voice: str, speed: float) -> AsyncGenerator[bytes, None]:
|
|
async def generate_api_audio(chunk: str, voice: str, speed: float) -> AsyncGenerator[bytes, None]:
|
|
|
- text_hash = hashlib.md5(chunk.encode('utf-8')).hexdigest()
|
|
|
|
|
|
|
+ text_hash = hashlib.md5(chunk.encode("utf-8")).hexdigest()
|
|
|
audio_path = os.path.join(CACHE_DIR, f"{text_hash}.wav")
|
|
audio_path = os.path.join(CACHE_DIR, f"{text_hash}.wav")
|
|
|
-
|
|
|
|
|
|
|
+
|
|
|
if os.path.exists(audio_path):
|
|
if os.path.exists(audio_path):
|
|
|
with open(audio_path, "rb") as f:
|
|
with open(audio_path, "rb") as f:
|
|
|
yield f.read()
|
|
yield f.read()
|
|
@@ -209,91 +883,86 @@ async def generate_api_audio(chunk: str, voice: str, speed: float) -> AsyncGener
|
|
|
try:
|
|
try:
|
|
|
async with aiohttp.ClientSession() as session:
|
|
async with aiohttp.ClientSession() as session:
|
|
|
async with session.post(
|
|
async with session.post(
|
|
|
- 'http://141.140.15.30:8028/generate',
|
|
|
|
|
- headers={'Content-Type': 'application/json'},
|
|
|
|
|
- json={
|
|
|
|
|
- 'text': chunk,
|
|
|
|
|
- 'voice': voice,
|
|
|
|
|
- 'speed': speed
|
|
|
|
|
- }
|
|
|
|
|
|
|
+ "http://141.140.15.30:8028/generate",
|
|
|
|
|
+ headers={"Content-Type": "application/json"},
|
|
|
|
|
+ json={"text": chunk, "voice": voice, "speed": speed},
|
|
|
) as response:
|
|
) as response:
|
|
|
if response.status != 200:
|
|
if response.status != 200:
|
|
|
raise HTTPException(status_code=500, detail="TTS API 请求失败")
|
|
raise HTTPException(status_code=500, detail="TTS API 请求失败")
|
|
|
-
|
|
|
|
|
- # Read NDJSON response
|
|
|
|
|
|
|
+
|
|
|
buffer = ""
|
|
buffer = ""
|
|
|
- async for chunk in response.content.iter_any():
|
|
|
|
|
- buffer += chunk.decode('utf-8')
|
|
|
|
|
- lines = buffer.split('\n')
|
|
|
|
|
|
|
+ async for part in response.content.iter_any():
|
|
|
|
|
+ buffer += part.decode("utf-8")
|
|
|
|
|
+ lines = buffer.split("\n")
|
|
|
buffer = lines[-1]
|
|
buffer = lines[-1]
|
|
|
-
|
|
|
|
|
|
|
+
|
|
|
for line in lines[:-1]:
|
|
for line in lines[:-1]:
|
|
|
if not line.strip():
|
|
if not line.strip():
|
|
|
continue
|
|
continue
|
|
|
try:
|
|
try:
|
|
|
data = json.loads(line)
|
|
data = json.loads(line)
|
|
|
- if data.get('error'):
|
|
|
|
|
- raise HTTPException(status_code=500, detail=data['error'])
|
|
|
|
|
- audio_b64 = data.get('audio')
|
|
|
|
|
|
|
+ if data.get("error"):
|
|
|
|
|
+ raise HTTPException(status_code=500, detail=data["error"])
|
|
|
|
|
+ audio_b64 = data.get("audio")
|
|
|
if audio_b64:
|
|
if audio_b64:
|
|
|
audio_bytes = base64.b64decode(audio_b64)
|
|
audio_bytes = base64.b64decode(audio_b64)
|
|
|
yield audio_bytes
|
|
yield audio_bytes
|
|
|
- # Cache the chunk
|
|
|
|
|
with open(audio_path, "wb") as f:
|
|
with open(audio_path, "wb") as f:
|
|
|
f.write(audio_bytes)
|
|
f.write(audio_bytes)
|
|
|
except json.JSONDecodeError as e:
|
|
except json.JSONDecodeError as e:
|
|
|
logger.error(f"JSON decode error: {str(e)}")
|
|
logger.error(f"JSON decode error: {str(e)}")
|
|
|
continue
|
|
continue
|
|
|
-
|
|
|
|
|
- # Handle final buffer
|
|
|
|
|
|
|
+
|
|
|
if buffer.strip():
|
|
if buffer.strip():
|
|
|
try:
|
|
try:
|
|
|
data = json.loads(buffer)
|
|
data = json.loads(buffer)
|
|
|
- if data.get('audio'):
|
|
|
|
|
- audio_bytes = base64.b64decode(data['audio'])
|
|
|
|
|
|
|
+ if data.get("audio"):
|
|
|
|
|
+ audio_bytes = base64.b64decode(data["audio"])
|
|
|
yield audio_bytes
|
|
yield audio_bytes
|
|
|
with open(audio_path, "wb") as f:
|
|
with open(audio_path, "wb") as f:
|
|
|
f.write(audio_bytes)
|
|
f.write(audio_bytes)
|
|
|
except json.JSONDecodeError:
|
|
except json.JSONDecodeError:
|
|
|
pass
|
|
pass
|
|
|
-
|
|
|
|
|
|
|
+
|
|
|
except Exception as e:
|
|
except Exception as e:
|
|
|
raise HTTPException(status_code=500, detail=f"TTS生成失败: {str(e)}")
|
|
raise HTTPException(status_code=500, detail=f"TTS生成失败: {str(e)}")
|
|
|
|
|
|
|
|
|
|
+
|
|
|
@app.post("/page-to-speech/")
|
|
@app.post("/page-to-speech/")
|
|
|
async def page_to_speech(request: TextToSpeechRequest):
|
|
async def page_to_speech(request: TextToSpeechRequest):
|
|
|
user_input = request.user_input.strip()
|
|
user_input = request.user_input.strip()
|
|
|
if not user_input:
|
|
if not user_input:
|
|
|
raise HTTPException(status_code=400, detail="输入文本为空")
|
|
raise HTTPException(status_code=400, detail="输入文本为空")
|
|
|
-
|
|
|
|
|
- full_text_hash = hashlib.md5(user_input.encode('utf-8')).hexdigest()
|
|
|
|
|
|
|
+
|
|
|
|
|
+ full_text_hash = hashlib.md5(user_input.encode("utf-8")).hexdigest()
|
|
|
full_audio_path = os.path.join(CACHE_DIR, f"{full_text_hash}_full.wav")
|
|
full_audio_path = os.path.join(CACHE_DIR, f"{full_text_hash}_full.wav")
|
|
|
-
|
|
|
|
|
|
|
+
|
|
|
if os.path.exists(full_audio_path):
|
|
if os.path.exists(full_audio_path):
|
|
|
return StreamingResponse(open(full_audio_path, "rb"), media_type="audio/wav")
|
|
return StreamingResponse(open(full_audio_path, "rb"), media_type="audio/wav")
|
|
|
-
|
|
|
|
|
|
|
+
|
|
|
chunks = split_text_into_chunks(user_input)
|
|
chunks = split_text_into_chunks(user_input)
|
|
|
|
|
|
|
|
async def audio_generator() -> AsyncGenerator[bytes, None]:
|
|
async def audio_generator() -> AsyncGenerator[bytes, None]:
|
|
|
- full_audio_buffer = io.BytesIO() # For caching full audio
|
|
|
|
|
|
|
+ full_audio_buffer = io.BytesIO()
|
|
|
for chunk in chunks:
|
|
for chunk in chunks:
|
|
|
async for audio_data in generate_api_audio(chunk, request.voice, request.speed):
|
|
async for audio_data in generate_api_audio(chunk, request.voice, request.speed):
|
|
|
- yield audio_data # Stream each chunk's audio
|
|
|
|
|
|
|
+ yield audio_data
|
|
|
full_audio_buffer.write(audio_data)
|
|
full_audio_buffer.write(audio_data)
|
|
|
- await asyncio.sleep(0) # Yield control to event loop
|
|
|
|
|
-
|
|
|
|
|
- # Save the full audio to cache
|
|
|
|
|
|
|
+ await asyncio.sleep(0)
|
|
|
|
|
+
|
|
|
full_audio_buffer.seek(0)
|
|
full_audio_buffer.seek(0)
|
|
|
with open(full_audio_path, "wb") as f:
|
|
with open(full_audio_path, "wb") as f:
|
|
|
f.write(full_audio_buffer.getvalue())
|
|
f.write(full_audio_buffer.getvalue())
|
|
|
|
|
|
|
|
return StreamingResponse(audio_generator(), media_type="audio/wav")
|
|
return StreamingResponse(audio_generator(), media_type="audio/wav")
|
|
|
|
|
|
|
|
-# Health check
|
|
|
|
|
|
|
+
|
|
|
@app.get("/health")
|
|
@app.get("/health")
|
|
|
async def health_check():
|
|
async def health_check():
|
|
|
return {"status": "healthy"}
|
|
return {"status": "healthy"}
|
|
|
|
|
|
|
|
|
|
+
|
|
|
if __name__ == "__main__":
|
|
if __name__ == "__main__":
|
|
|
import uvicorn
|
|
import uvicorn
|
|
|
|
|
+
|
|
|
uvicorn.run(app, host="0.0.0.0", port=8005)
|
|
uvicorn.run(app, host="0.0.0.0", port=8005)
|