main_server.py 36 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968
  1. from fastapi import FastAPI, File, UploadFile, HTTPException, Form, Response, Cookie
  2. from fastapi.encoders import jsonable_encoder
  3. from fastapi.middleware.cors import CORSMiddleware
  4. from fastapi.responses import JSONResponse, RedirectResponse, StreamingResponse, HTMLResponse
  5. from fastapi.staticfiles import StaticFiles
  6. from pydantic import BaseModel
  7. import os
  8. import shutil
  9. import hashlib
  10. import asyncio
  11. from typing import AsyncGenerator, Optional
  12. import aiohttp
  13. import io
  14. import logging
  15. import base64
  16. import json
  17. from datetime import datetime, timedelta, timezone
  18. import secrets
  19. import pymysql
  20. from config import MYSQL_HOST, MYSQL_PORT, MYSQL_USER, MYSQL_PASSWORD, MYSQL_DATABASE
  # Logging: root logger at INFO plus a module-level logger.
  logging.basicConfig(level=logging.INFO)
  logger = logging.getLogger(__name__)

  # FastAPI application instance.
  app = FastAPI()

  # CORS: every origin, method and header is accepted, with credentials enabled.
  # NOTE(review): wildcard origins together with allow_credentials=True is very
  # permissive for a cookie-authenticated app -- confirm this is intended.
  origins = ["*"]
  app.add_middleware(
      CORSMiddleware,
      allow_origins=origins,
      allow_credentials=True,
      allow_methods=["*"],
      allow_headers=["*"],
  )

  # Root directory holding one sub-directory of uploads per user.
  BASE_STATIC_FILES_DIR = "static/files"
  os.makedirs(BASE_STATIC_FILES_DIR, exist_ok=True)

  # Static mounts, registered in the original order (most specific path first).
  for _mount_path, _mount_dir, _mount_name in (
      ("/static/files", BASE_STATIC_FILES_DIR, "static_files"),
      ("/static/web", "static/web", "static_web"),
      ("/static", "static", "static"),
  ):
      app.mount(_mount_path, StaticFiles(directory=_mount_dir), name=_mount_name)

  # Directory where synthesized audio is cached on disk.
  CACHE_DIR = "audio_cache"
  os.makedirs(CACHE_DIR, exist_ok=True)

  # Session cookie name and session lifetimes in days
  # (default login vs. "remember me" login).
  SESSION_COOKIE = "reader_pro_session"
  SESSION_TTL_DAYS = 1
  SESSION_TTL_DAYS_REMEMBER = 30
  def db_conn():
      """Open and return a new PyMySQL connection built from the config values.

      The connection uses the utf8mb4 charset, has autocommit enabled and
      yields rows as dictionaries (DictCursor). The caller must close it.
      """
      connect_kwargs = {
          "host": MYSQL_HOST,
          "port": MYSQL_PORT,
          "user": MYSQL_USER,
          "password": MYSQL_PASSWORD,
          "database": MYSQL_DATABASE,
          "charset": "utf8mb4",
          "autocommit": True,
          "cursorclass": pymysql.cursors.DictCursor,
      }
      return pymysql.connect(**connect_kwargs)
  def init_db() -> None:
      """Create the application tables if needed and seed the default admin.

      Steps, in order: create `user`, backfill the `is_active` column on an
      older `user` table, create `user_progress` and `user_config`, then insert
      an `admin` account (password "admin") when no such user exists.
      """
      create_user_sql = """
          CREATE TABLE IF NOT EXISTS user (
          id BIGINT PRIMARY KEY AUTO_INCREMENT,
          username VARCHAR(64) NOT NULL UNIQUE,
          password_hash VARCHAR(255) NOT NULL,
          is_admin TINYINT(1) NOT NULL DEFAULT 0,
          is_active TINYINT(1) NOT NULL DEFAULT 1,
          session_token VARCHAR(128) NULL,
          session_expires_at DATETIME NULL,
          last_file VARCHAR(1024) NULL,
          last_page INT NULL,
          created_at DATETIME NOT NULL,
          updated_at DATETIME NOT NULL
          ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4
      """
      create_progress_sql = """
          CREATE TABLE IF NOT EXISTS user_progress (
          id BIGINT PRIMARY KEY AUTO_INCREMENT,
          user_id BIGINT NOT NULL,
          file_path VARCHAR(512) NOT NULL,
          page INT NOT NULL,
          updated_at DATETIME NOT NULL,
          UNIQUE KEY uniq_user_file (user_id, file_path),
          KEY idx_user_updated (user_id, updated_at),
          CONSTRAINT fk_user_progress_user
          FOREIGN KEY (user_id) REFERENCES user(id)
          ON DELETE CASCADE
          ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4
      """
      create_config_sql = """
          CREATE TABLE IF NOT EXISTS user_config (
          config_key VARCHAR(128) PRIMARY KEY,
          config_value TEXT NULL,
          updated_at DATETIME NOT NULL
          ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4
      """
      seed_admin_sql = """
          INSERT INTO user (username, password_hash, is_admin, is_active, created_at, updated_at)
          VALUES (%s, %s, 1, 1, %s, %s)
      """

      connection = db_conn()
      try:
          with connection.cursor() as cursor:
              cursor.execute(create_user_sql)

              # Older deployments created `user` without is_active; add it.
              cursor.execute("SHOW COLUMNS FROM user LIKE 'is_active'")
              has_is_active = cursor.fetchone()
              if not has_is_active:
                  cursor.execute("ALTER TABLE user ADD COLUMN is_active TINYINT(1) NOT NULL DEFAULT 1 AFTER is_admin")

              cursor.execute(create_progress_sql)
              cursor.execute(create_config_sql)

              # Seed the default administrator only when it does not exist yet.
              cursor.execute("SELECT id FROM user WHERE username=%s", ("admin",))
              if cursor.fetchone():
                  return
              # Timestamps are stored as naive UTC datetimes.
              created = datetime.now(timezone.utc).replace(tzinfo=None)
              cursor.execute(seed_admin_sql, ("admin", hash_password("admin"), created, created))
              logger.info("Seeded default admin account: admin/admin")
      finally:
          connection.close()
  def hash_password(password: str) -> str:
      """Return the hex SHA-256 digest of the UTF-8 encoded password.

      The hash is deliberately simple and deterministic (no salt) so stored
      hashes stay comparable; a migration to bcrypt can follow later.
      """
      hasher = hashlib.sha256()
      hasher.update(password.encode("utf-8"))
      return hasher.hexdigest()
  def sanitize_filename(name: str) -> str:
      """Drop every character that is not alphanumeric, space, '.', '_' or '-'.

      Trailing whitespace left after filtering is stripped; leading characters
      are kept as they are.
      """
      allowed_punctuation = " ._-"
      kept_chars = [ch for ch in name if ch.isalnum() or ch in allowed_punctuation]
      return "".join(kept_chars).rstrip()
  def get_user_dir(username: str) -> str:
      """Return the upload directory for a user, creating it when missing.

      The directory name is the sanitized username, or "user" when nothing
      survives sanitizing.
      """
      folder_name = sanitize_filename(username)
      if not folder_name:
          folder_name = "user"
      user_path = os.path.join(BASE_STATIC_FILES_DIR, folder_name)
      os.makedirs(user_path, exist_ok=True)
      return user_path
  def build_user_file_url(username: str, filename: str) -> str:
      """Return the /static/files URL of a file in the user's directory.

      Only the username is sanitized; the filename is inserted unchanged.
      """
      folder_name = sanitize_filename(username)
      return "/static/files/" + folder_name + "/" + filename
  def get_user_by_session(session_token: Optional[str]):
      """Return the user row that owns a live session token, or None.

      A session is live when the token matches, its expiry is set and lies in
      the future (naive UTC), and the account is still active. The is_active
      condition makes a disabled account lose access immediately instead of
      staying logged in until its session expires.

      Args:
          session_token: Value of the session cookie; may be None or empty.

      Returns:
          The matching `user` row as a dict, or None when there is no match.
      """
      if not session_token:
          return None
      # Timestamps are stored as naive UTC datetimes.
      now = datetime.now(timezone.utc).replace(tzinfo=None)
      conn = db_conn()
      try:
          with conn.cursor() as cur:
              cur.execute(
                  """
                  SELECT * FROM user
                  WHERE session_token=%s
                    AND session_expires_at IS NOT NULL
                    AND session_expires_at>%s
                    AND is_active=1
                  """,
                  (session_token, now),
              )
              return cur.fetchone()
      finally:
          conn.close()
  def require_user(session_token: Optional[str]):
      """Return the user for this session token or raise HTTP 401."""
      account = get_user_by_session(session_token)
      if account:
          return account
      raise HTTPException(status_code=401, detail="未登录或会话已过期")
  def require_admin(session_token: Optional[str]):
      """Return the logged-in user if it is an admin.

      Raises HTTP 401 through require_user when there is no valid session and
      HTTP 403 when the user's is_admin flag is falsy.
      """
      account = require_user(session_token)
      if account.get("is_admin"):
          return account
      raise HTTPException(status_code=403, detail="需要管理员权限")
  def set_session_for_user(username: str, remember_me: bool):
      """Issue a fresh session token for the user and store it in the database.

      Returns a (token, expire_days) tuple; expire_days is the "remember me"
      lifetime when remember_me is true, otherwise the default lifetime.
      """
      expire_days = SESSION_TTL_DAYS
      if remember_me:
          expire_days = SESSION_TTL_DAYS_REMEMBER
      token = secrets.token_urlsafe(48)
      # Expiry is stored as a naive UTC datetime.
      expires_at = datetime.now(timezone.utc).replace(tzinfo=None) + timedelta(days=expire_days)
      update_sql = """
          UPDATE user
          SET session_token=%s, session_expires_at=%s, updated_at=%s
          WHERE username=%s
      """
      connection = db_conn()
      try:
          with connection.cursor() as cursor:
              updated_at = datetime.now(timezone.utc).replace(tzinfo=None)
              cursor.execute(update_sql, (token, expires_at, updated_at, username))
      finally:
          connection.close()
      return token, expire_days
  # Startup hook: make sure the database schema and default admin exist.
  @app.on_event("startup")
  def on_startup():
      init_db()
  # Landing route: send anonymous visitors to /login, otherwise open the viewer
  # on the user's last file, their first PDF (sorted by name), or the sample PDF.
  @app.get("/")
  def root(session_token: Optional[str] = Cookie(default=None, alias=SESSION_COOKIE)):
      account = get_user_by_session(session_token)
      if not account:
          return RedirectResponse(url="/login", status_code=302)

      target = account.get("last_file")
      if not target:
          # No remembered file: fall back to the user's own directory.
          owner = account["username"]
          pdf_names = [
              entry for entry in os.listdir(get_user_dir(owner)) if entry.lower().endswith(".pdf")
          ]
          pdf_names.sort()
          if pdf_names:
              target = build_user_file_url(owner, pdf_names[0])
          else:
              # Nothing uploaded yet: fall back to the sample document.
              target = "/static/files/compress.pdf"
      return RedirectResponse(url=f"/static/web/viewer.html?file={target}", status_code=302)
  # Login page: logged-in users are redirected to "/", everyone else gets a
  # self-contained HTML form that posts JSON to /auth/login.
  @app.get("/login")
  def login_page(session_token: Optional[str] = Cookie(default=None, alias=SESSION_COOKIE)):
      if get_user_by_session(session_token):
          return RedirectResponse(url="/", status_code=302)
      page_markup = """
      <!doctype html>
      <html lang="zh-CN">
      <head>
      <meta charset="utf-8"/>
      <meta name="viewport" content="width=device-width, initial-scale=1"/>
      <title>VoiceFlow AI Reader 【小满TTS英文听书】 - 登录</title>
      <style>
      *{box-sizing:border-box}
      body{margin:0;min-height:100vh;display:grid;place-items:center;background:linear-gradient(145deg,#f2f6fb,#e7eef8);font-family:Arial,sans-serif;color:#1f2937}
      .card{width:min(92vw,420px);background:#fff;border-radius:14px;padding:24px;box-shadow:0 10px 35px rgba(17,24,39,.10)}
      .title{margin:0 0 6px;font-size:22px;font-weight:700;color:#0f172a}
      .sub{margin:0 0 18px;font-size:13px;color:#64748b}
      input,button{width:100%;padding:11px 12px;border-radius:10px;font-size:14px}
      input{border:1px solid #d7e0ec;margin-top:10px;outline:none}
      input:focus{border-color:#3b82f6;box-shadow:0 0 0 3px rgba(59,130,246,.12)}
      button{margin-top:12px;border:none;background:#2563eb;color:#fff;font-weight:600;cursor:pointer}
      button:hover{background:#1d4ed8}
      .row{display:flex;gap:8px;align-items:center;margin-top:10px;font-size:13px;color:#475569}
      .row input{width:auto;margin:0}
      .msg{margin-top:10px;color:#b00020;min-height:20px;font-size:13px}
      .links{margin-top:8px;text-align:right}
      .links a{font-size:13px;color:#2563eb;text-decoration:none}
      .links a:hover{text-decoration:underline}
      .hint{margin-top:12px;padding:10px;border-radius:10px;background:#f8fafc;color:#475569;font-size:12px}
      </style>
      </head>
      <body>
      <div class="card">
      <h1 class="title">VoiceFlow AI Reader</h1>
      <p class="sub">【小满TTS英文听书】</p>
      <input id="loginUser" placeholder="用户名" autocomplete="username"/>
      <input id="loginPass" type="password" placeholder="密码" autocomplete="current-password"/>
      <label class="row"><input id="remember" type="checkbox"/>30天内免登录</label>
      <button onclick="login()">登录</button>
      <div id="loginMsg" class="msg"></div>
      <div class="links"><a href="/register">没有账号?去注册</a></div>
      <div class="hint">管理员默认账号:admin / admin</div>
      </div>
      <script>
      async function login(){
      const username=document.getElementById('loginUser').value.trim();
      const password=document.getElementById('loginPass').value;
      const remember_me=document.getElementById('remember').checked;
      const msg=document.getElementById('loginMsg');
      msg.textContent='';
      const r=await fetch('/auth/login',{method:'POST',headers:{'Content-Type':'application/json'},body:JSON.stringify({username,password,remember_me})});
      const d=await r.json();
      if(!r.ok||!d.success){msg.textContent=d.error||'登录失败';return;}
      window.location.href='/';
      }
      </script>
      </body>
      </html>
      """
      return HTMLResponse(content=page_markup)
  # Registration page: logged-in users are redirected to "/", everyone else gets
  # a self-contained HTML form that posts JSON to /auth/register.
  @app.get("/register")
  def register_page(session_token: Optional[str] = Cookie(default=None, alias=SESSION_COOKIE)):
      if get_user_by_session(session_token):
          return RedirectResponse(url="/", status_code=302)
      page_markup = """
      <!doctype html>
      <html lang="zh-CN">
      <head>
      <meta charset="utf-8"/>
      <meta name="viewport" content="width=device-width, initial-scale=1"/>
      <title>VoiceFlow AI Reader 【小满TTS英文听书】 - 注册</title>
      <style>
      *{box-sizing:border-box}
      body{margin:0;min-height:100vh;display:grid;place-items:center;background:linear-gradient(145deg,#f2f6fb,#e7eef8);font-family:Arial,sans-serif;color:#1f2937}
      .card{width:min(92vw,420px);background:#fff;border-radius:14px;padding:24px;box-shadow:0 10px 35px rgba(17,24,39,.10)}
      .title{margin:0 0 6px;font-size:22px;font-weight:700;color:#0f172a}
      .sub{margin:0 0 18px;font-size:13px;color:#64748b}
      input,button{width:100%;padding:11px 12px;border-radius:10px;font-size:14px}
      input{border:1px solid #d7e0ec;margin-top:10px;outline:none}
      input:focus{border-color:#3b82f6;box-shadow:0 0 0 3px rgba(59,130,246,.12)}
      button{margin-top:12px;border:none;background:#16a34a;color:#fff;font-weight:600;cursor:pointer}
      button:hover{background:#15803d}
      .msg{margin-top:10px;color:#b00020;min-height:20px;font-size:13px}
      .ok{color:#0f7b0f}
      .links{margin-top:8px;text-align:right}
      .links a{font-size:13px;color:#2563eb;text-decoration:none}
      .links a:hover{text-decoration:underline}
      </style>
      </head>
      <body>
      <div class="card">
      <h1 class="title">创建新账号</h1>
      <p class="sub">VoiceFlow AI Reader 【小满TTS英文听书】</p>
      <input id="regUser" placeholder="用户名(至少3位)" autocomplete="username"/>
      <input id="regPass" type="password" placeholder="密码(至少4位)" autocomplete="new-password"/>
      <button onclick="registerUser()">注册</button>
      <div id="regMsg" class="msg"></div>
      <div class="links"><a href="/login">已有账号?去登录</a></div>
      </div>
      <script>
      async function registerUser(){
      const username=document.getElementById('regUser').value.trim();
      const password=document.getElementById('regPass').value;
      const msg=document.getElementById('regMsg');
      msg.textContent='';
      const r=await fetch('/auth/register',{method:'POST',headers:{'Content-Type':'application/json'},body:JSON.stringify({username,password})});
      const d=await r.json();
      if(!r.ok||!d.success){msg.textContent=d.error||'注册失败';return;}
      msg.textContent='注册成功,2秒后跳转登录页';
      msg.className='msg ok';
      setTimeout(()=>{window.location.href='/login';},2000);
      }
      </script>
      </body>
      </html>
      """
      return HTMLResponse(content=page_markup)
  # JSON body of POST /auth/login.
  class LoginRequest(BaseModel):
      # Account name; sanitized by the endpoint before the lookup.
      username: str
      # Plain-text password; hashed by the endpoint for comparison.
      password: str
      # When true the session uses the longer "remember me" lifetime.
      remember_me: bool = False
  # JSON body of POST /auth/register.
  class RegisterRequest(BaseModel):
      # Desired account name; the endpoint requires at least 3 characters
      # after sanitizing.
      username: str
      # Plain-text password; the endpoint requires at least 4 characters
      # after stripping.
      password: str
  # Self-service registration: validates the name/password, inserts a non-admin
  # user and creates the user's upload directory.
  # Responds 400 with {"success": False, "error": ...} for a short username,
  # a short password or a username that is already taken.
  @app.post("/auth/register")
  async def auth_register(request: RegisterRequest):
      username = sanitize_filename((request.username or "").strip())
      password = (request.password or "").strip()
      if len(username) < 3:
          return JSONResponse(status_code=400, content={"success": False, "error": "用户名至少3位"})
      if len(password) < 4:
          return JSONResponse(status_code=400, content={"success": False, "error": "密码至少4位"})

      duplicate_response = JSONResponse(
          status_code=400, content={"success": False, "error": "用户名已存在"}
      )
      # Timestamps are stored as naive UTC datetimes.
      now = datetime.now(timezone.utc).replace(tzinfo=None)
      conn = db_conn()
      try:
          with conn.cursor() as cur:
              cur.execute("SELECT id FROM user WHERE username=%s", (username,))
              if cur.fetchone():
                  return duplicate_response
              try:
                  cur.execute(
                      """
                      INSERT INTO user (username, password_hash, is_admin, created_at, updated_at)
                      VALUES (%s, %s, 0, %s, %s)
                      """,
                      (username, hash_password(password), now, now),
                  )
              except pymysql.IntegrityError:
                  # Another request registered the same name between the SELECT
                  # and the INSERT; the UNIQUE constraint on username rejects
                  # the second insert. Report it like any other duplicate
                  # instead of failing with a 500.
                  return duplicate_response
      finally:
          conn.close()
      get_user_dir(username)
      return JSONResponse(content={"success": True})
  # Login: verifies the credentials, rejects disabled accounts, issues a new
  # session token and returns it in an HttpOnly cookie.
  # Responds 401 for a wrong username/password and 403 for a disabled account.
  @app.post("/auth/login")
  async def auth_login(request: LoginRequest):
      username = sanitize_filename((request.username or "").strip())
      password = (request.password or "").strip()
      conn = db_conn()
      try:
          with conn.cursor() as cur:
              cur.execute("SELECT * FROM user WHERE username=%s", (username,))
              user = cur.fetchone()
      finally:
          conn.close()

      # Compare the hashes in constant time so response timing does not reveal
      # how many leading characters of the stored hash matched.
      credentials_ok = bool(user) and secrets.compare_digest(
          str(user["password_hash"]).encode("utf-8"),
          hash_password(password).encode("utf-8"),
      )
      if not credentials_ok:
          return JSONResponse(status_code=401, content={"success": False, "error": "用户名或密码错误"})
      if int(user.get("is_active", 1)) != 1:
          return JSONResponse(status_code=403, content={"success": False, "error": "账号已被禁用"})

      token, expire_days = set_session_for_user(username, request.remember_me)
      resp = JSONResponse(content={"success": True, "is_admin": bool(user.get("is_admin"))})
      # Cookie lifetime matches the server-side session lifetime (in seconds).
      max_age = expire_days * 24 * 3600
      resp.set_cookie(
          key=SESSION_COOKIE,
          value=token,
          max_age=max_age,
          httponly=True,
          samesite="lax",
          secure=False,
          path="/",
      )
      return resp
  # Logout: clears the stored session of the current user (if any) and always
  # deletes the session cookie.
  @app.post("/auth/logout")
  async def auth_logout(session_token: Optional[str] = Cookie(default=None, alias=SESSION_COOKIE)):
      account = get_user_by_session(session_token)
      if account:
          clear_sql = "UPDATE user SET session_token=NULL, session_expires_at=NULL, updated_at=%s WHERE id=%s"
          connection = db_conn()
          try:
              with connection.cursor() as cursor:
                  stamp = datetime.now(timezone.utc).replace(tzinfo=None)
                  cursor.execute(clear_sql, (stamp, account["id"]))
          finally:
              connection.close()
      reply = JSONResponse(content={"success": True})
      reply.delete_cookie(SESSION_COOKIE, path="/")
      return reply
  # Current-user info: 401 with {"success": False} when there is no valid
  # session, otherwise the username and the admin/active flags.
  @app.get("/auth/me")
  async def auth_me(session_token: Optional[str] = Cookie(default=None, alias=SESSION_COOKIE)):
      account = get_user_by_session(session_token)
      if not account:
          return JSONResponse(status_code=401, content={"success": False})
      profile = {
          "success": True,
          "username": account["username"],
          "is_admin": bool(account.get("is_admin")),
          "is_active": bool(account.get("is_active", 1)),
      }
      return JSONResponse(content=profile)
  # Admin entry point: only admins are redirected to the static admin page;
  # require_admin raises 401/403 for everyone else.
  @app.get("/admin")
  def admin_page(session_token: Optional[str] = Cookie(default=None, alias=SESSION_COOKIE)):
      require_admin(session_token)
      admin_url = "/static/web/admin.html"
      return RedirectResponse(url=admin_url, status_code=302)
  # JSON body of POST /admin/users (create a user or reset an existing one).
  class AdminUserRequest(BaseModel):
      # Account name; sanitized by the endpoint.
      username: str
      # New plain-text password.
      password: str
      # Whether the account gets the admin flag.
      is_admin: bool = False
  # Admin: list all users (id, name, flags, creation time) ordered by id.
  @app.get("/admin/users")
  async def admin_users(session_token: Optional[str] = Cookie(default=None, alias=SESSION_COOKIE)):
      require_admin(session_token)
      list_sql = "SELECT id, username, is_admin, is_active, created_at FROM user ORDER BY id ASC"
      connection = db_conn()
      try:
          with connection.cursor() as cursor:
              cursor.execute(list_sql)
              rows = cursor.fetchall()
      finally:
          connection.close()
      # jsonable_encoder converts the datetime columns into JSON-safe values.
      body = jsonable_encoder({"success": True, "users": rows})
      return JSONResponse(content=body)
  # Admin: create a user, or for an existing username reset the password,
  # set the admin flag and re-activate the account. Also ensures the user's
  # upload directory exists.
  @app.post("/admin/users")
  async def admin_create_or_reset_user(payload: AdminUserRequest, session_token: Optional[str] = Cookie(default=None, alias=SESSION_COOKIE)):
      require_admin(session_token)
      username = sanitize_filename((payload.username or "").strip())
      password = (payload.password or "").strip()
      name_ok = len(username) >= 3
      password_ok = len(password) >= 4
      if not (name_ok and password_ok):
          return JSONResponse(status_code=400, content={"success": False, "error": "用户名或密码不合法"})

      stamp = datetime.now(timezone.utc).replace(tzinfo=None)
      admin_flag = 1 if payload.is_admin else 0
      password_hash = hash_password(password)
      connection = db_conn()
      try:
          with connection.cursor() as cursor:
              cursor.execute("SELECT id FROM user WHERE username=%s", (username,))
              existing = cursor.fetchone()
              if not existing:
                  cursor.execute(
                      "INSERT INTO user (username, password_hash, is_admin, is_active, created_at, updated_at) VALUES (%s,%s,%s,1,%s,%s)",
                      (username, password_hash, admin_flag, stamp, stamp),
                  )
              else:
                  cursor.execute(
                      "UPDATE user SET password_hash=%s, is_admin=%s, is_active=1, updated_at=%s WHERE username=%s",
                      (password_hash, admin_flag, stamp, username),
                  )
      finally:
          connection.close()
      get_user_dir(username)
      return JSONResponse(content={"success": True})
  # Admin: delete a user row and, when delete_files is true, the user's upload
  # directory. The built-in "admin" account can never be deleted.
  #
  # Hardening compared with the naive implementation:
  # * The admin guard is also checked against the username stored in the row
  #   that the WHERE clause matched, so a variant that the column collation
  #   treats as equal to "admin" cannot delete it.
  #   NOTE(review): assumes the column may use a case-insensitive collation;
  #   the table definition does not name one -- confirm.
  # * Files are only removed for a name that is already in sanitized form and
  #   is not "." or "..". Stored usernames are always sanitized, so any other
  #   value cannot name a real user directory; without this check "." / ".."
  #   would delete the whole static tree and a name that sanitizes to ""
  #   would delete the directory of the user called "user".
  # * The directory is no longer created just to be deleted.
  @app.delete("/admin/users/{username}")
  async def admin_delete_user(
      username: str,
      delete_files: bool = False,
      session_token: Optional[str] = Cookie(default=None, alias=SESSION_COOKIE),
  ):
      require_admin(session_token)
      refuse_admin = JSONResponse(status_code=400, content={"success": False, "error": "不能删除admin"})
      if username == "admin":
          return refuse_admin

      stored_name = None
      conn = db_conn()
      try:
          with conn.cursor() as cur:
              cur.execute("SELECT id, username FROM user WHERE username=%s", (username,))
              row = cur.fetchone()
              if row:
                  if row["username"] == "admin":
                      return refuse_admin
                  stored_name = row["username"]
                  cur.execute("DELETE FROM user WHERE id=%s", (row["id"],))
      finally:
          conn.close()

      if delete_files:
          # Prefer the stored spelling; fall back to the requested name so the
          # files of an already-deleted user can still be cleaned up.
          dir_name = stored_name if stored_name is not None else username
          is_plain_name = (
              bool(dir_name)
              and dir_name == sanitize_filename(dir_name)
              and dir_name not in (".", "..")
          )
          if is_plain_name:
              user_dir = os.path.join(BASE_STATIC_FILES_DIR, dir_name)
              if os.path.isdir(user_dir):
                  shutil.rmtree(user_dir, ignore_errors=True)
      return JSONResponse(content={"success": True})
  # JSON body of POST /admin/users/{username}/reset-password.
  class AdminResetPasswordRequest(BaseModel):
      # New plain-text password; the endpoint requires at least 4 characters
      # after stripping.
      password: str
  # Admin: set a new password for an existing user.
  # Responds 400 for a password shorter than 4 characters and 404 when the
  # user does not exist.
  @app.post("/admin/users/{username}/reset-password")
  async def admin_reset_password(
      username: str,
      payload: AdminResetPasswordRequest,
      session_token: Optional[str] = Cookie(default=None, alias=SESSION_COOKIE),
  ):
      require_admin(session_token)
      password = (payload.password or "").strip()
      if len(password) < 4:
          return JSONResponse(status_code=400, content={"success": False, "error": "密码至少4位"})
      # Timestamps are stored as naive UTC datetimes.
      now = datetime.now(timezone.utc).replace(tzinfo=None)
      conn = db_conn()
      try:
          with conn.cursor() as cur:
              # Existence is checked with a SELECT rather than the UPDATE's
              # rowcount: MySQL reports only rows whose values changed, so
              # re-applying the same password within the same second would
              # look like "user not found".
              cur.execute("SELECT id FROM user WHERE username=%s", (username,))
              row = cur.fetchone()
              if not row:
                  return JSONResponse(status_code=404, content={"success": False, "error": "用户不存在"})
              cur.execute(
                  "UPDATE user SET password_hash=%s, updated_at=%s WHERE id=%s",
                  (hash_password(password), now, row["id"]),
              )
      finally:
          conn.close()
      return JSONResponse(content={"success": True})
  # JSON body of POST /admin/users/{username}/status.
  class AdminToggleUserRequest(BaseModel):
      # True enables the account, False disables it.
      is_active: bool
  # Admin: enable or disable an account. The built-in "admin" account cannot
  # be disabled. Responds 404 when the user does not exist.
  #
  # * Existence is checked with a SELECT rather than the UPDATE's rowcount:
  #   MySQL reports only rows whose values changed, so repeating the same
  #   status within the same second would look like "user not found".
  # * The admin guard is also applied to the username stored in the matched
  #   row, so a variant the column collation treats as equal to "admin"
  #   cannot disable it.
  #   NOTE(review): assumes the column may use a case-insensitive collation;
  #   the table definition does not name one -- confirm.
  # * Disabling also clears the stored session so the user is logged out now
  #   instead of when the session would expire.
  @app.post("/admin/users/{username}/status")
  async def admin_toggle_user_status(
      username: str,
      payload: AdminToggleUserRequest,
      session_token: Optional[str] = Cookie(default=None, alias=SESSION_COOKIE),
  ):
      require_admin(session_token)
      refuse_admin = JSONResponse(status_code=400, content={"success": False, "error": "不能禁用admin"})
      if username == "admin" and not payload.is_active:
          return refuse_admin
      # Timestamps are stored as naive UTC datetimes.
      now = datetime.now(timezone.utc).replace(tzinfo=None)
      conn = db_conn()
      try:
          with conn.cursor() as cur:
              cur.execute("SELECT id, username FROM user WHERE username=%s", (username,))
              row = cur.fetchone()
              if not row:
                  return JSONResponse(status_code=404, content={"success": False, "error": "用户不存在"})
              if row["username"] == "admin" and not payload.is_active:
                  return refuse_admin
              if payload.is_active:
                  cur.execute(
                      "UPDATE user SET is_active=1, updated_at=%s WHERE id=%s",
                      (now, row["id"]),
                  )
              else:
                  cur.execute(
                      "UPDATE user SET is_active=0, session_token=NULL, session_expires_at=NULL, updated_at=%s WHERE id=%s",
                      (now, row["id"]),
                  )
      finally:
          conn.close()
      return JSONResponse(content={"success": True})
  # JSON body of POST /admin/config.
  class AdminConfigRequest(BaseModel):
      # Configuration key; must be non-empty after stripping.
      config_key: str
      # Value to store for the key; may be omitted or null.
      config_value: Optional[str] = None
  # Admin: return every row of user_config ordered by key.
  @app.get("/admin/config")
  async def admin_get_config(session_token: Optional[str] = Cookie(default=None, alias=SESSION_COOKIE)):
      require_admin(session_token)
      select_sql = "SELECT config_key, config_value, updated_at FROM user_config ORDER BY config_key"
      connection = db_conn()
      try:
          with connection.cursor() as cursor:
              cursor.execute(select_sql)
              entries = cursor.fetchall()
      finally:
          connection.close()
      # jsonable_encoder converts the datetime column into a JSON-safe value.
      body = jsonable_encoder({"success": True, "configs": entries})
      return JSONResponse(content=body)
  # Admin: insert a configuration entry or overwrite the value of an existing
  # key. Responds 400 when the key is empty after stripping.
  @app.post("/admin/config")
  async def admin_set_config(payload: AdminConfigRequest, session_token: Optional[str] = Cookie(default=None, alias=SESSION_COOKIE)):
      require_admin(session_token)
      key = (payload.config_key or "").strip()
      if key == "":
          return JSONResponse(status_code=400, content={"success": False, "error": "config_key 不能为空"})
      upsert_sql = """
          INSERT INTO user_config (config_key, config_value, updated_at)
          VALUES (%s, %s, %s)
          ON DUPLICATE KEY UPDATE config_value=VALUES(config_value), updated_at=VALUES(updated_at)
      """
      stamp = datetime.now(timezone.utc).replace(tzinfo=None)
      connection = db_conn()
      try:
          with connection.cursor() as cursor:
              cursor.execute(upsert_sql, (key, payload.config_value, stamp))
      finally:
          connection.close()
      return JSONResponse(content={"success": True})
  # PDF upload endpoint: stores the uploaded PDF as "<custom_name>.pdf" in the
  # user's directory and records it as the user's last opened file (page 1).
  # Raises 400 for a non-PDF content type and 500 when writing fails; responds
  # 400 for an unusable name or a name that already exists.
  @app.post("/upload-pdf")
  async def upload_pdf(
      file: UploadFile = File(...),
      custom_name: str = Form(...),
      session_token: Optional[str] = Cookie(default=None, alias=SESSION_COOKIE),
  ):
      user = require_user(session_token)
      if file.content_type != "application/pdf":
          raise HTTPException(status_code=400, detail="文件类型必须是 PDF")
      sanitized_name = sanitize_filename(custom_name)
      if not sanitized_name:
          return JSONResponse(status_code=400, content={"success": False, "error": "无效的文件名"})
      unique_filename = f"{sanitized_name}.pdf"
      user_dir = get_user_dir(user["username"])
      file_path = os.path.join(user_dir, unique_filename)
      try:
          # "xb" creates the file exclusively, so the existence check and the
          # creation are one atomic step: two concurrent uploads with the same
          # name can no longer overwrite each other.
          with open(file_path, "xb") as buffer:
              shutil.copyfileobj(file.file, buffer)
      except FileExistsError:
          return JSONResponse(status_code=400, content={"success": False, "error": "文件名已存在,请使用其他名称"})
      except Exception as exc:
          logger.error(f"Upload failed for {file_path}: {exc}")
          # Do not leave a truncated PDF behind; it would be listed as a valid
          # file and would block a retry under the same name.
          try:
              os.remove(file_path)
          except OSError:
              pass
          raise HTTPException(status_code=500, detail="上传过程中出错") from exc
      finally:
          file.file.close()
      file_relative_path = build_user_file_url(user["username"], unique_filename)
      # Timestamps are stored as naive UTC datetimes.
      now = datetime.now(timezone.utc).replace(tzinfo=None)
      conn = db_conn()
      try:
          with conn.cursor() as cur:
              cur.execute(
                  "UPDATE user SET last_file=%s, last_page=1, updated_at=%s WHERE id=%s",
                  (file_relative_path, now, user["id"]),
              )
      finally:
          conn.close()
      return JSONResponse(content={"success": True, "file_path": file_relative_path})
  # List PDFs endpoint: returns the name and URL of every *.pdf in the user's
  # directory, sorted case-insensitively by name. Any failure while reading
  # the directory is reported as HTTP 500.
  @app.get("/list-pdfs")
  async def list_pdfs(session_token: Optional[str] = Cookie(default=None, alias=SESSION_COOKIE)):
      account = require_user(session_token)
      try:
          owner = account["username"]
          entries = os.listdir(get_user_dir(owner))
          pdf_names = sorted(
              (entry for entry in entries if entry.lower().endswith(".pdf")),
              key=str.lower,
          )
          listing = []
          for pdf_name in pdf_names:
              listing.append({"name": pdf_name, "url": build_user_file_url(owner, pdf_name)})
          return JSONResponse(content={"success": True, "files": listing})
      except Exception:
          raise HTTPException(status_code=500, detail="无法获取文件列表")
  # JSON body of POST /reading-progress.
  class ReadingProgressRequest(BaseModel):
      # Identifier of the file being read; stored as user_progress.file_path.
      file: str
      # Page number; the endpoint requires a value >= 1.
      page: int
  # Reading progress lookup: returns the saved page of the given file for the
  # current user, or page=None when nothing was saved. Responds 400 when the
  # file parameter is empty after stripping.
  @app.get("/reading-progress")
  async def get_reading_progress(file: str, session_token: Optional[str] = Cookie(default=None, alias=SESSION_COOKIE)):
      account = require_user(session_token)
      file_key = (file or "").strip()
      if file_key == "":
          return JSONResponse(status_code=400, content={"success": False, "error": "缺少 file 参数"})
      lookup_sql = "SELECT page FROM user_progress WHERE user_id=%s AND file_path=%s"
      connection = db_conn()
      try:
          with connection.cursor() as cursor:
              cursor.execute(lookup_sql, (account["id"], file_key))
              found = cursor.fetchone()
              if found:
                  page = found["page"]
              else:
                  page = None
      finally:
          connection.close()
      return JSONResponse(content={"success": True, "file": file_key, "page": page})
  # Reading progress save: upserts the page for (user, file) and records the
  # file/page as the user's last position. Responds 400 for an empty file or
  # a page below 1.
  @app.post("/reading-progress")
  async def save_reading_progress(request: ReadingProgressRequest, session_token: Optional[str] = Cookie(default=None, alias=SESSION_COOKIE)):
      account = require_user(session_token)
      file_key = (request.file or "").strip()
      page_number = int(request.page)
      if file_key == "":
          return JSONResponse(status_code=400, content={"success": False, "error": "file 不能为空"})
      if not page_number >= 1:
          return JSONResponse(status_code=400, content={"success": False, "error": "page 必须 >= 1"})

      upsert_sql = """
          INSERT INTO user_progress (user_id, file_path, page, updated_at)
          VALUES (%s, %s, %s, %s)
          ON DUPLICATE KEY UPDATE page=VALUES(page), updated_at=VALUES(updated_at)
      """
      last_position_sql = "UPDATE user SET last_file=%s, last_page=%s, updated_at=%s WHERE id=%s"
      stamp = datetime.now(timezone.utc).replace(tzinfo=None)
      user_id = account["id"]
      connection = db_conn()
      try:
          with connection.cursor() as cursor:
              cursor.execute(upsert_sql, (user_id, file_key, page_number, stamp))
              cursor.execute(last_position_sql, (file_key, page_number, stamp, user_id))
      finally:
          connection.close()
      return JSONResponse(content={"success": True})
  # JSON body of POST /text-to-speech/.
  class TextToSpeechRequest(BaseModel):
      # Text to synthesize; must be non-empty after stripping.
      user_input: str
      # Voice identifier forwarded to the TTS service.
      voice: str = "af_heart"
      # Speed value forwarded to the TTS service.
      speed: float = 1.0
  # Text-to-speech endpoint: returns cached audio when available, otherwise
  # streams audio from the upstream TTS service and caches the full result.
  #
  # The upstream reply is read as newline-delimited JSON objects carrying
  # either an "error" message or a base64 "audio" fragment.
  #
  # Fixes compared with the previous implementation:
  # * The cache key covers text, voice and speed, so a request for another
  #   voice or speed no longer receives audio cached for different settings.
  #   (Entries cached under the old text-only key are simply not reused.)
  # * Lines are split on raw bytes before decoding, so a multi-byte UTF-8
  #   character cut in half by a network chunk cannot abort the stream.
  # * The cache file is written to a temporary name and renamed into place,
  #   so a concurrent request never reads a half-written file, and an empty
  #   result is not cached.
  #
  # NOTE: errors raised once streaming has begun cannot change the HTTP status
  # any more; they only terminate the response body early.
  @app.post("/text-to-speech/")
  async def text_to_speech(request: TextToSpeechRequest):
      user_input = request.user_input.strip()
      if not user_input:
          raise HTTPException(status_code=400, detail="输入文本为空")

      cache_key = json.dumps([user_input, request.voice, request.speed], ensure_ascii=False)
      text_hash = hashlib.md5(cache_key.encode("utf-8")).hexdigest()
      audio_path = os.path.join(CACHE_DIR, f"{text_hash}.wav")
      if os.path.exists(audio_path):
          with open(audio_path, "rb") as f:
              return Response(content=f.read(), media_type="audio/wav")

      def extract_audio(raw_line: bytes) -> Optional[bytes]:
          """Decode one upstream line into audio bytes; None if it has none."""
          if not raw_line.strip():
              return None
          try:
              data = json.loads(raw_line)
          except (json.JSONDecodeError, UnicodeDecodeError) as e:
              # A malformed line is skipped; the rest of the stream continues.
              logger.error(f"JSON decode error: {str(e)}")
              return None
          if data.get("error"):
              raise HTTPException(status_code=500, detail=data["error"])
          audio_b64 = data.get("audio")
          return base64.b64decode(audio_b64) if audio_b64 else None

      async def audio_generator() -> AsyncGenerator[bytes, None]:
          try:
              async with aiohttp.ClientSession() as session:
                  async with session.post(
                      "http://141.140.15.30:8028/generate",
                      headers={"Content-Type": "application/json"},
                      json={"text": user_input, "voice": request.voice, "speed": request.speed},
                  ) as response:
                      if response.status != 200:
                          raise HTTPException(status_code=500, detail="TTS API 请求失败")
                      pending = b""
                      full_audio = io.BytesIO()
                      async for chunk in response.content.iter_any():
                          pending += chunk
                          # Everything before the last newline is complete;
                          # the remainder waits for the next chunk.
                          *complete_lines, pending = pending.split(b"\n")
                          for raw_line in complete_lines:
                              audio_bytes = extract_audio(raw_line)
                              if audio_bytes:
                                  full_audio.write(audio_bytes)
                                  yield audio_bytes
                      # A final line may arrive without a trailing newline.
                      audio_bytes = extract_audio(pending)
                      if audio_bytes:
                          full_audio.write(audio_bytes)
                          yield audio_bytes
                      cached_audio = full_audio.getvalue()
                      if cached_audio:
                          tmp_path = f"{audio_path}.{secrets.token_hex(8)}.tmp"
                          with open(tmp_path, "wb") as f:
                              f.write(cached_audio)
                          os.replace(tmp_path, audio_path)
          except HTTPException:
              raise
          except Exception as e:
              logger.error(f"TTS error: {str(e)}")
              raise HTTPException(status_code=500, detail=str(e)) from e

      return StreamingResponse(audio_generator(), media_type="audio/wav")
  722. MAX_CHUNK_SIZE = 200
  723. def split_text_into_chunks(text: str, max_chunk_size: int = MAX_CHUNK_SIZE) -> list:
  724. import re
  725. sentences = re.split(r"(?<=[.!?]) +", text)
  726. chunks = []
  727. current_chunk = ""
  728. for sentence in sentences:
  729. if len(current_chunk) + len(sentence) + 1 <= max_chunk_size:
  730. current_chunk += " " + sentence if current_chunk else sentence
  731. else:
  732. if current_chunk:
  733. chunks.append(current_chunk)
  734. if len(sentence) > max_chunk_size:
  735. for i in range(0, len(sentence), max_chunk_size):
  736. chunks.append(sentence[i : i + max_chunk_size])
  737. current_chunk = ""
  738. else:
  739. current_chunk = sentence
  740. if current_chunk:
  741. chunks.append(current_chunk)
  742. return chunks
  743. async def generate_api_audio(chunk: str, voice: str, speed: float) -> AsyncGenerator[bytes, None]:
  744. text_hash = hashlib.md5(chunk.encode("utf-8")).hexdigest()
  745. audio_path = os.path.join(CACHE_DIR, f"{text_hash}.wav")
  746. if os.path.exists(audio_path):
  747. with open(audio_path, "rb") as f:
  748. yield f.read()
  749. else:
  750. try:
  751. async with aiohttp.ClientSession() as session:
  752. async with session.post(
  753. "http://141.140.15.30:8028/generate",
  754. headers={"Content-Type": "application/json"},
  755. json={"text": chunk, "voice": voice, "speed": speed},
  756. ) as response:
  757. if response.status != 200:
  758. raise HTTPException(status_code=500, detail="TTS API 请求失败")
  759. buffer = ""
  760. async for part in response.content.iter_any():
  761. buffer += part.decode("utf-8")
  762. lines = buffer.split("\n")
  763. buffer = lines[-1]
  764. for line in lines[:-1]:
  765. if not line.strip():
  766. continue
  767. try:
  768. data = json.loads(line)
  769. if data.get("error"):
  770. raise HTTPException(status_code=500, detail=data["error"])
  771. audio_b64 = data.get("audio")
  772. if audio_b64:
  773. audio_bytes = base64.b64decode(audio_b64)
  774. yield audio_bytes
  775. with open(audio_path, "wb") as f:
  776. f.write(audio_bytes)
  777. except json.JSONDecodeError as e:
  778. logger.error(f"JSON decode error: {str(e)}")
  779. continue
  780. if buffer.strip():
  781. try:
  782. data = json.loads(buffer)
  783. if data.get("audio"):
  784. audio_bytes = base64.b64decode(data["audio"])
  785. yield audio_bytes
  786. with open(audio_path, "wb") as f:
  787. f.write(audio_bytes)
  788. except json.JSONDecodeError:
  789. pass
  790. except Exception as e:
  791. raise HTTPException(status_code=500, detail=f"TTS生成失败: {str(e)}")
  792. @app.post("/page-to-speech/")
  793. async def page_to_speech(request: TextToSpeechRequest):
  794. user_input = request.user_input.strip()
  795. if not user_input:
  796. raise HTTPException(status_code=400, detail="输入文本为空")
  797. full_text_hash = hashlib.md5(user_input.encode("utf-8")).hexdigest()
  798. full_audio_path = os.path.join(CACHE_DIR, f"{full_text_hash}_full.wav")
  799. if os.path.exists(full_audio_path):
  800. return StreamingResponse(open(full_audio_path, "rb"), media_type="audio/wav")
  801. chunks = split_text_into_chunks(user_input)
  802. async def audio_generator() -> AsyncGenerator[bytes, None]:
  803. full_audio_buffer = io.BytesIO()
  804. for chunk in chunks:
  805. async for audio_data in generate_api_audio(chunk, request.voice, request.speed):
  806. yield audio_data
  807. full_audio_buffer.write(audio_data)
  808. await asyncio.sleep(0)
  809. full_audio_buffer.seek(0)
  810. with open(full_audio_path, "wb") as f:
  811. f.write(full_audio_buffer.getvalue())
  812. return StreamingResponse(audio_generator(), media_type="audio/wav")
  813. @app.get("/health")
  814. async def health_check():
  815. return {"status": "healthy"}
  816. if __name__ == "__main__":
  817. import uvicorn
  818. uvicorn.run(app, host="0.0.0.0", port=8005)