main_server.py 36 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990
  1. from fastapi import FastAPI, File, UploadFile, HTTPException, Form, Response, Cookie
  2. from fastapi.encoders import jsonable_encoder
  3. from fastapi.middleware.cors import CORSMiddleware
  4. from fastapi.responses import JSONResponse, RedirectResponse, StreamingResponse, HTMLResponse
  5. from fastapi.staticfiles import StaticFiles
  6. from pydantic import BaseModel
  7. import os
  8. import shutil
  9. import hashlib
  10. import asyncio
  11. from typing import AsyncGenerator, Optional
  12. import aiohttp
  13. import io
  14. import logging
  15. import base64
  16. import json
  17. from datetime import datetime, timedelta, timezone
  18. import secrets
  19. import pymysql
  20. from config import MYSQL_HOST, MYSQL_PORT, MYSQL_USER, MYSQL_PASSWORD, MYSQL_DATABASE
  21. OPENAI_TTS_BASE_URL = os.getenv("OPENAI_TTS_BASE_URL", "https://api.aimanyi.top")
  22. OPENAI_TTS_API_KEY = os.getenv(
  23. "OPENAI_TTS_API_KEY",
  24. "sk-re2NlaKIQn11ZNWzAbB6339cEbF94c6aAfC8B7Ab82879bEa",
  25. )
  26. OPENAI_TTS_MODEL = os.getenv("OPENAI_TTS_MODEL", "gpt-4o-mini-tts")
  27. OPENAI_TTS_DEFAULT_VOICE = os.getenv("OPENAI_TTS_DEFAULT_VOICE", "sage")
  28. OPENAI_TTS_FORMAT = os.getenv("OPENAI_TTS_FORMAT", "wav")
  29. CLIENT_COOKIE = "reader_pro_client"
  30. PROGRESS_FILE = "reading_progress.json"
  31. # Set up logging
  32. logging.basicConfig(level=logging.INFO)
  33. logger = logging.getLogger(__name__)
  34. # Initialize FastAPI app
  35. app = FastAPI()
  36. # Configure CORS
  37. origins = ["*"]
  38. app.add_middleware(
  39. CORSMiddleware,
  40. allow_origins=origins,
  41. allow_credentials=True,
  42. allow_methods=["*"],
  43. allow_headers=["*"],
  44. )
  45. # Base directories
  46. BASE_STATIC_FILES_DIR = "static/files"
  47. os.makedirs(BASE_STATIC_FILES_DIR, exist_ok=True)
  48. # Mount static files
  49. app.mount("/static/files", StaticFiles(directory=BASE_STATIC_FILES_DIR), name="static_files")
  50. app.mount("/static/web", StaticFiles(directory="static/web"), name="static_web")
  51. app.mount("/static", StaticFiles(directory="static"), name="static")
  52. # Audio cache directory
  53. CACHE_DIR = "audio_cache"
  54. os.makedirs(CACHE_DIR, exist_ok=True)
  55. SESSION_COOKIE = "reader_pro_session"
  56. SESSION_TTL_DAYS = 1
  57. SESSION_TTL_DAYS_REMEMBER = 30
  58. def db_conn():
  59. return pymysql.connect(
  60. host=MYSQL_HOST,
  61. port=MYSQL_PORT,
  62. user=MYSQL_USER,
  63. password=MYSQL_PASSWORD,
  64. database=MYSQL_DATABASE,
  65. charset="utf8mb4",
  66. autocommit=True,
  67. cursorclass=pymysql.cursors.DictCursor,
  68. )
  69. def init_db() -> None:
  70. conn = db_conn()
  71. try:
  72. with conn.cursor() as cur:
  73. cur.execute(
  74. """
  75. CREATE TABLE IF NOT EXISTS user (
  76. id BIGINT PRIMARY KEY AUTO_INCREMENT,
  77. username VARCHAR(64) NOT NULL UNIQUE,
  78. password_hash VARCHAR(255) NOT NULL,
  79. is_admin TINYINT(1) NOT NULL DEFAULT 0,
  80. is_active TINYINT(1) NOT NULL DEFAULT 1,
  81. session_token VARCHAR(128) NULL,
  82. session_expires_at DATETIME NULL,
  83. last_file VARCHAR(1024) NULL,
  84. last_page INT NULL,
  85. created_at DATETIME NOT NULL,
  86. updated_at DATETIME NOT NULL
  87. ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4
  88. """
  89. )
  90. # backfill old schema without is_active column
  91. cur.execute("SHOW COLUMNS FROM user LIKE 'is_active'")
  92. if not cur.fetchone():
  93. cur.execute("ALTER TABLE user ADD COLUMN is_active TINYINT(1) NOT NULL DEFAULT 1 AFTER is_admin")
  94. cur.execute(
  95. """
  96. CREATE TABLE IF NOT EXISTS user_progress (
  97. id BIGINT PRIMARY KEY AUTO_INCREMENT,
  98. user_id BIGINT NOT NULL,
  99. file_path VARCHAR(512) NOT NULL,
  100. page INT NOT NULL,
  101. updated_at DATETIME NOT NULL,
  102. UNIQUE KEY uniq_user_file (user_id, file_path),
  103. KEY idx_user_updated (user_id, updated_at),
  104. CONSTRAINT fk_user_progress_user
  105. FOREIGN KEY (user_id) REFERENCES user(id)
  106. ON DELETE CASCADE
  107. ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4
  108. """
  109. )
  110. cur.execute(
  111. """
  112. CREATE TABLE IF NOT EXISTS user_config (
  113. config_key VARCHAR(128) PRIMARY KEY,
  114. config_value TEXT NULL,
  115. updated_at DATETIME NOT NULL
  116. ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4
  117. """
  118. )
  119. # seed admin
  120. cur.execute("SELECT id FROM user WHERE username=%s", ("admin",))
  121. admin = cur.fetchone()
  122. if not admin:
  123. now = datetime.now(timezone.utc).replace(tzinfo=None)
  124. cur.execute(
  125. """
  126. INSERT INTO user (username, password_hash, is_admin, is_active, created_at, updated_at)
  127. VALUES (%s, %s, 1, 1, %s, %s)
  128. """,
  129. ("admin", hash_password("admin"), now, now),
  130. )
  131. logger.info("Seeded default admin account: admin/admin")
  132. finally:
  133. conn.close()
  134. def hash_password(password: str) -> str:
  135. # Keep simple deterministic hash for compatibility; can migrate to bcrypt later.
  136. return hashlib.sha256(password.encode("utf-8")).hexdigest()
  137. def sanitize_filename(name: str) -> str:
  138. return "".join(c for c in name if c.isalnum() or c in (" ", ".", "_", "-")).rstrip()
  139. def build_file_url(filename: str) -> str:
  140. return f"/static/files/{filename}"
  141. def get_or_create_client_id(client_id: Optional[str]) -> str:
  142. normalized = (client_id or "").strip()
  143. return normalized or secrets.token_urlsafe(24)
  144. def load_progress_store() -> dict:
  145. if not os.path.exists(PROGRESS_FILE):
  146. return {}
  147. try:
  148. with open(PROGRESS_FILE, "r", encoding="utf-8") as f:
  149. data = json.load(f)
  150. return data if isinstance(data, dict) else {}
  151. except Exception:
  152. return {}
  153. def save_progress_store(data: dict) -> None:
  154. with open(PROGRESS_FILE, "w", encoding="utf-8") as f:
  155. json.dump(data, f, ensure_ascii=False, indent=2)
  156. def get_user_dir(username: str) -> str:
  157. safe = sanitize_filename(username) or "user"
  158. path = os.path.join(BASE_STATIC_FILES_DIR, safe)
  159. os.makedirs(path, exist_ok=True)
  160. return path
  161. def build_user_file_url(username: str, filename: str) -> str:
  162. return f"/static/files/{sanitize_filename(username)}/{filename}"
  163. def get_user_by_session(session_token: Optional[str]):
  164. if not session_token:
  165. return None
  166. now = datetime.now(timezone.utc).replace(tzinfo=None)
  167. conn = db_conn()
  168. try:
  169. with conn.cursor() as cur:
  170. cur.execute(
  171. """
  172. SELECT * FROM user
  173. WHERE session_token=%s AND session_expires_at IS NOT NULL AND session_expires_at>%s
  174. """,
  175. (session_token, now),
  176. )
  177. return cur.fetchone()
  178. finally:
  179. conn.close()
  180. def require_user(session_token: Optional[str]):
  181. user = get_user_by_session(session_token)
  182. if not user:
  183. raise HTTPException(status_code=401, detail="未登录或会话已过期")
  184. return user
  185. def require_admin(session_token: Optional[str]):
  186. user = require_user(session_token)
  187. if not user.get("is_admin"):
  188. raise HTTPException(status_code=403, detail="需要管理员权限")
  189. return user
  190. def set_session_for_user(username: str, remember_me: bool):
  191. token = secrets.token_urlsafe(48)
  192. expire_days = SESSION_TTL_DAYS_REMEMBER if remember_me else SESSION_TTL_DAYS
  193. expires_at = datetime.now(timezone.utc).replace(tzinfo=None) + timedelta(days=expire_days)
  194. conn = db_conn()
  195. try:
  196. with conn.cursor() as cur:
  197. cur.execute(
  198. """
  199. UPDATE user
  200. SET session_token=%s, session_expires_at=%s, updated_at=%s
  201. WHERE username=%s
  202. """,
  203. (token, expires_at, datetime.now(timezone.utc).replace(tzinfo=None), username),
  204. )
  205. finally:
  206. conn.close()
  207. return token, expire_days
  208. @app.on_event("startup")
  209. def on_startup():
  210. init_db()
  211. @app.get("/")
  212. def root(client_id: Optional[str] = Cookie(default=None, alias=CLIENT_COOKIE)):
  213. current_client_id = get_or_create_client_id(client_id)
  214. progress = load_progress_store().get(current_client_id, {})
  215. last_file = (progress.get("last_file") or "").strip()
  216. if last_file:
  217. response = RedirectResponse(url=f"/static/web/viewer.html?file={last_file}", status_code=302)
  218. else:
  219. files = sorted([f for f in os.listdir(BASE_STATIC_FILES_DIR) if f.lower().endswith(".pdf")])
  220. if files:
  221. response = RedirectResponse(
  222. url=f"/static/web/viewer.html?file={build_file_url(files[0])}",
  223. status_code=302,
  224. )
  225. else:
  226. response = RedirectResponse(url="/static/web/viewer.html?file=/static/files/compress.pdf", status_code=302)
  227. response.set_cookie(key=CLIENT_COOKIE, value=current_client_id, httponly=True, samesite="lax", secure=False, path="/")
  228. return response
  229. @app.get("/login")
  230. def login_page(session_token: Optional[str] = Cookie(default=None, alias=SESSION_COOKIE)):
  231. user = get_user_by_session(session_token)
  232. if user:
  233. return RedirectResponse(url="/", status_code=302)
  234. html = """
  235. <!doctype html>
  236. <html lang="zh-CN">
  237. <head>
  238. <meta charset="utf-8"/>
  239. <meta name="viewport" content="width=device-width, initial-scale=1"/>
  240. <title>VoiceFlow AI Reader 【小满TTS英文听书】 - 登录</title>
  241. <style>
  242. *{box-sizing:border-box}
  243. body{margin:0;min-height:100vh;display:grid;place-items:center;background:linear-gradient(145deg,#f2f6fb,#e7eef8);font-family:Arial,sans-serif;color:#1f2937}
  244. .card{width:min(92vw,420px);background:#fff;border-radius:14px;padding:24px;box-shadow:0 10px 35px rgba(17,24,39,.10)}
  245. .title{margin:0 0 6px;font-size:22px;font-weight:700;color:#0f172a}
  246. .sub{margin:0 0 18px;font-size:13px;color:#64748b}
  247. input,button{width:100%;padding:11px 12px;border-radius:10px;font-size:14px}
  248. input{border:1px solid #d7e0ec;margin-top:10px;outline:none}
  249. input:focus{border-color:#3b82f6;box-shadow:0 0 0 3px rgba(59,130,246,.12)}
  250. button{margin-top:12px;border:none;background:#2563eb;color:#fff;font-weight:600;cursor:pointer}
  251. button:hover{background:#1d4ed8}
  252. .row{display:flex;gap:8px;align-items:center;margin-top:10px;font-size:13px;color:#475569}
  253. .row input{width:auto;margin:0}
  254. .msg{margin-top:10px;color:#b00020;min-height:20px;font-size:13px}
  255. .links{margin-top:8px;text-align:right}
  256. .links a{font-size:13px;color:#2563eb;text-decoration:none}
  257. .links a:hover{text-decoration:underline}
  258. .hint{margin-top:12px;padding:10px;border-radius:10px;background:#f8fafc;color:#475569;font-size:12px}
  259. </style>
  260. </head>
  261. <body>
  262. <div class="card">
  263. <h1 class="title">VoiceFlow AI Reader</h1>
  264. <p class="sub">【小满TTS英文听书】</p>
  265. <input id="loginUser" placeholder="用户名" autocomplete="username"/>
  266. <input id="loginPass" type="password" placeholder="密码" autocomplete="current-password"/>
  267. <label class="row"><input id="remember" type="checkbox"/>30天内免登录</label>
  268. <button onclick="login()">登录</button>
  269. <div id="loginMsg" class="msg"></div>
  270. <div class="links"><a href="/register">没有账号?去注册</a></div>
  271. <div class="hint">管理员默认账号:admin / admin</div>
  272. </div>
  273. <script>
  274. async function login(){
  275. const username=document.getElementById('loginUser').value.trim();
  276. const password=document.getElementById('loginPass').value;
  277. const remember_me=document.getElementById('remember').checked;
  278. const msg=document.getElementById('loginMsg');
  279. msg.textContent='';
  280. const r=await fetch('/auth/login',{method:'POST',headers:{'Content-Type':'application/json'},body:JSON.stringify({username,password,remember_me})});
  281. const d=await r.json();
  282. if(!r.ok||!d.success){msg.textContent=d.error||'登录失败';return;}
  283. window.location.href='/';
  284. }
  285. </script>
  286. </body>
  287. </html>
  288. """
  289. return HTMLResponse(content=html)
  290. @app.get("/register")
  291. def register_page(session_token: Optional[str] = Cookie(default=None, alias=SESSION_COOKIE)):
  292. user = get_user_by_session(session_token)
  293. if user:
  294. return RedirectResponse(url="/", status_code=302)
  295. html = """
  296. <!doctype html>
  297. <html lang="zh-CN">
  298. <head>
  299. <meta charset="utf-8"/>
  300. <meta name="viewport" content="width=device-width, initial-scale=1"/>
  301. <title>VoiceFlow AI Reader 【小满TTS英文听书】 - 注册</title>
  302. <style>
  303. *{box-sizing:border-box}
  304. body{margin:0;min-height:100vh;display:grid;place-items:center;background:linear-gradient(145deg,#f2f6fb,#e7eef8);font-family:Arial,sans-serif;color:#1f2937}
  305. .card{width:min(92vw,420px);background:#fff;border-radius:14px;padding:24px;box-shadow:0 10px 35px rgba(17,24,39,.10)}
  306. .title{margin:0 0 6px;font-size:22px;font-weight:700;color:#0f172a}
  307. .sub{margin:0 0 18px;font-size:13px;color:#64748b}
  308. input,button{width:100%;padding:11px 12px;border-radius:10px;font-size:14px}
  309. input{border:1px solid #d7e0ec;margin-top:10px;outline:none}
  310. input:focus{border-color:#3b82f6;box-shadow:0 0 0 3px rgba(59,130,246,.12)}
  311. button{margin-top:12px;border:none;background:#16a34a;color:#fff;font-weight:600;cursor:pointer}
  312. button:hover{background:#15803d}
  313. .msg{margin-top:10px;color:#b00020;min-height:20px;font-size:13px}
  314. .ok{color:#0f7b0f}
  315. .links{margin-top:8px;text-align:right}
  316. .links a{font-size:13px;color:#2563eb;text-decoration:none}
  317. .links a:hover{text-decoration:underline}
  318. </style>
  319. </head>
  320. <body>
  321. <div class="card">
  322. <h1 class="title">创建新账号</h1>
  323. <p class="sub">VoiceFlow AI Reader 【小满TTS英文听书】</p>
  324. <input id="regUser" placeholder="用户名(至少3位)" autocomplete="username"/>
  325. <input id="regPass" type="password" placeholder="密码(至少4位)" autocomplete="new-password"/>
  326. <button onclick="registerUser()">注册</button>
  327. <div id="regMsg" class="msg"></div>
  328. <div class="links"><a href="/login">已有账号?去登录</a></div>
  329. </div>
  330. <script>
  331. async function registerUser(){
  332. const username=document.getElementById('regUser').value.trim();
  333. const password=document.getElementById('regPass').value;
  334. const msg=document.getElementById('regMsg');
  335. msg.textContent='';
  336. const r=await fetch('/auth/register',{method:'POST',headers:{'Content-Type':'application/json'},body:JSON.stringify({username,password})});
  337. const d=await r.json();
  338. if(!r.ok||!d.success){msg.textContent=d.error||'注册失败';return;}
  339. msg.textContent='注册成功,2秒后跳转登录页';
  340. msg.className='msg ok';
  341. setTimeout(()=>{window.location.href='/login';},2000);
  342. }
  343. </script>
  344. </body>
  345. </html>
  346. """
  347. return HTMLResponse(content=html)
  348. class LoginRequest(BaseModel):
  349. username: str
  350. password: str
  351. remember_me: bool = False
  352. class RegisterRequest(BaseModel):
  353. username: str
  354. password: str
  355. @app.post("/auth/register")
  356. async def auth_register(request: RegisterRequest):
  357. username = sanitize_filename((request.username or "").strip())
  358. password = (request.password or "").strip()
  359. if len(username) < 3:
  360. return JSONResponse(status_code=400, content={"success": False, "error": "用户名至少3位"})
  361. if len(password) < 4:
  362. return JSONResponse(status_code=400, content={"success": False, "error": "密码至少4位"})
  363. now = datetime.now(timezone.utc).replace(tzinfo=None)
  364. conn = db_conn()
  365. try:
  366. with conn.cursor() as cur:
  367. cur.execute("SELECT id FROM user WHERE username=%s", (username,))
  368. if cur.fetchone():
  369. return JSONResponse(status_code=400, content={"success": False, "error": "用户名已存在"})
  370. cur.execute(
  371. """
  372. INSERT INTO user (username, password_hash, is_admin, created_at, updated_at)
  373. VALUES (%s, %s, 0, %s, %s)
  374. """,
  375. (username, hash_password(password), now, now),
  376. )
  377. finally:
  378. conn.close()
  379. get_user_dir(username)
  380. return JSONResponse(content={"success": True})
  381. @app.post("/auth/login")
  382. async def auth_login(request: LoginRequest):
  383. username = sanitize_filename((request.username or "").strip())
  384. password = (request.password or "").strip()
  385. conn = db_conn()
  386. try:
  387. with conn.cursor() as cur:
  388. cur.execute("SELECT * FROM user WHERE username=%s", (username,))
  389. user = cur.fetchone()
  390. finally:
  391. conn.close()
  392. if not user or user["password_hash"] != hash_password(password):
  393. return JSONResponse(status_code=401, content={"success": False, "error": "用户名或密码错误"})
  394. if int(user.get("is_active", 1)) != 1:
  395. return JSONResponse(status_code=403, content={"success": False, "error": "账号已被禁用"})
  396. token, expire_days = set_session_for_user(username, request.remember_me)
  397. resp = JSONResponse(content={"success": True, "is_admin": bool(user.get("is_admin"))})
  398. max_age = expire_days * 24 * 3600
  399. resp.set_cookie(
  400. key=SESSION_COOKIE,
  401. value=token,
  402. max_age=max_age,
  403. httponly=True,
  404. samesite="lax",
  405. secure=False,
  406. path="/",
  407. )
  408. return resp
  409. @app.post("/auth/logout")
  410. async def auth_logout(session_token: Optional[str] = Cookie(default=None, alias=SESSION_COOKIE)):
  411. user = get_user_by_session(session_token)
  412. if user:
  413. conn = db_conn()
  414. try:
  415. with conn.cursor() as cur:
  416. cur.execute(
  417. "UPDATE user SET session_token=NULL, session_expires_at=NULL, updated_at=%s WHERE id=%s",
  418. (datetime.now(timezone.utc).replace(tzinfo=None), user["id"]),
  419. )
  420. finally:
  421. conn.close()
  422. resp = JSONResponse(content={"success": True})
  423. resp.delete_cookie(SESSION_COOKIE, path="/")
  424. return resp
  425. @app.get("/auth/me")
  426. async def auth_me(session_token: Optional[str] = Cookie(default=None, alias=SESSION_COOKIE)):
  427. user = get_user_by_session(session_token)
  428. if not user:
  429. return JSONResponse(status_code=401, content={"success": False})
  430. return JSONResponse(
  431. content={
  432. "success": True,
  433. "username": user["username"],
  434. "is_admin": bool(user.get("is_admin")),
  435. "is_active": bool(user.get("is_active", 1)),
  436. }
  437. )
  438. @app.get("/admin")
  439. def admin_page(session_token: Optional[str] = Cookie(default=None, alias=SESSION_COOKIE)):
  440. require_admin(session_token)
  441. return RedirectResponse(url="/static/web/admin.html", status_code=302)
  442. class AdminUserRequest(BaseModel):
  443. username: str
  444. password: str
  445. is_admin: bool = False
  446. @app.get("/admin/users")
  447. async def admin_users(session_token: Optional[str] = Cookie(default=None, alias=SESSION_COOKIE)):
  448. require_admin(session_token)
  449. conn = db_conn()
  450. try:
  451. with conn.cursor() as cur:
  452. cur.execute("SELECT id, username, is_admin, is_active, created_at FROM user ORDER BY id ASC")
  453. users = cur.fetchall()
  454. finally:
  455. conn.close()
  456. return JSONResponse(content=jsonable_encoder({"success": True, "users": users}))
  457. @app.post("/admin/users")
  458. async def admin_create_or_reset_user(payload: AdminUserRequest, session_token: Optional[str] = Cookie(default=None, alias=SESSION_COOKIE)):
  459. require_admin(session_token)
  460. username = sanitize_filename((payload.username or "").strip())
  461. password = (payload.password or "").strip()
  462. if len(username) < 3 or len(password) < 4:
  463. return JSONResponse(status_code=400, content={"success": False, "error": "用户名或密码不合法"})
  464. now = datetime.now(timezone.utc).replace(tzinfo=None)
  465. conn = db_conn()
  466. try:
  467. with conn.cursor() as cur:
  468. cur.execute("SELECT id FROM user WHERE username=%s", (username,))
  469. row = cur.fetchone()
  470. if row:
  471. cur.execute(
  472. "UPDATE user SET password_hash=%s, is_admin=%s, is_active=1, updated_at=%s WHERE username=%s",
  473. (hash_password(password), 1 if payload.is_admin else 0, now, username),
  474. )
  475. else:
  476. cur.execute(
  477. "INSERT INTO user (username, password_hash, is_admin, is_active, created_at, updated_at) VALUES (%s,%s,%s,1,%s,%s)",
  478. (username, hash_password(password), 1 if payload.is_admin else 0, now, now),
  479. )
  480. finally:
  481. conn.close()
  482. get_user_dir(username)
  483. return JSONResponse(content={"success": True})
  484. @app.delete("/admin/users/{username}")
  485. async def admin_delete_user(
  486. username: str,
  487. delete_files: bool = False,
  488. session_token: Optional[str] = Cookie(default=None, alias=SESSION_COOKIE),
  489. ):
  490. require_admin(session_token)
  491. if username == "admin":
  492. return JSONResponse(status_code=400, content={"success": False, "error": "不能删除admin"})
  493. conn = db_conn()
  494. try:
  495. with conn.cursor() as cur:
  496. cur.execute("DELETE FROM user WHERE username=%s", (username,))
  497. finally:
  498. conn.close()
  499. if delete_files:
  500. user_dir = get_user_dir(username)
  501. if os.path.isdir(user_dir):
  502. shutil.rmtree(user_dir, ignore_errors=True)
  503. return JSONResponse(content={"success": True})
  504. class AdminResetPasswordRequest(BaseModel):
  505. password: str
  506. @app.post("/admin/users/{username}/reset-password")
  507. async def admin_reset_password(
  508. username: str,
  509. payload: AdminResetPasswordRequest,
  510. session_token: Optional[str] = Cookie(default=None, alias=SESSION_COOKIE),
  511. ):
  512. require_admin(session_token)
  513. password = (payload.password or "").strip()
  514. if len(password) < 4:
  515. return JSONResponse(status_code=400, content={"success": False, "error": "密码至少4位"})
  516. now = datetime.now(timezone.utc).replace(tzinfo=None)
  517. conn = db_conn()
  518. try:
  519. with conn.cursor() as cur:
  520. cur.execute(
  521. "UPDATE user SET password_hash=%s, updated_at=%s WHERE username=%s",
  522. (hash_password(password), now, username),
  523. )
  524. if cur.rowcount == 0:
  525. return JSONResponse(status_code=404, content={"success": False, "error": "用户不存在"})
  526. finally:
  527. conn.close()
  528. return JSONResponse(content={"success": True})
  529. class AdminToggleUserRequest(BaseModel):
  530. is_active: bool
  531. @app.post("/admin/users/{username}/status")
  532. async def admin_toggle_user_status(
  533. username: str,
  534. payload: AdminToggleUserRequest,
  535. session_token: Optional[str] = Cookie(default=None, alias=SESSION_COOKIE),
  536. ):
  537. require_admin(session_token)
  538. if username == "admin" and not payload.is_active:
  539. return JSONResponse(status_code=400, content={"success": False, "error": "不能禁用admin"})
  540. now = datetime.now(timezone.utc).replace(tzinfo=None)
  541. conn = db_conn()
  542. try:
  543. with conn.cursor() as cur:
  544. cur.execute(
  545. "UPDATE user SET is_active=%s, updated_at=%s WHERE username=%s",
  546. (1 if payload.is_active else 0, now, username),
  547. )
  548. if cur.rowcount == 0:
  549. return JSONResponse(status_code=404, content={"success": False, "error": "用户不存在"})
  550. finally:
  551. conn.close()
  552. return JSONResponse(content={"success": True})
  553. class AdminConfigRequest(BaseModel):
  554. config_key: str
  555. config_value: Optional[str] = None
  556. @app.get("/admin/config")
  557. async def admin_get_config(session_token: Optional[str] = Cookie(default=None, alias=SESSION_COOKIE)):
  558. require_admin(session_token)
  559. conn = db_conn()
  560. try:
  561. with conn.cursor() as cur:
  562. cur.execute("SELECT config_key, config_value, updated_at FROM user_config ORDER BY config_key")
  563. rows = cur.fetchall()
  564. finally:
  565. conn.close()
  566. return JSONResponse(content=jsonable_encoder({"success": True, "configs": rows}))
  567. @app.post("/admin/config")
  568. async def admin_set_config(payload: AdminConfigRequest, session_token: Optional[str] = Cookie(default=None, alias=SESSION_COOKIE)):
  569. require_admin(session_token)
  570. config_key = (payload.config_key or "").strip()
  571. if not config_key:
  572. return JSONResponse(status_code=400, content={"success": False, "error": "config_key 不能为空"})
  573. now = datetime.now(timezone.utc).replace(tzinfo=None)
  574. conn = db_conn()
  575. try:
  576. with conn.cursor() as cur:
  577. cur.execute(
  578. """
  579. INSERT INTO user_config (config_key, config_value, updated_at)
  580. VALUES (%s, %s, %s)
  581. ON DUPLICATE KEY UPDATE config_value=VALUES(config_value), updated_at=VALUES(updated_at)
  582. """,
  583. (config_key, payload.config_value, now),
  584. )
  585. finally:
  586. conn.close()
  587. return JSONResponse(content={"success": True})
  588. # PDF upload endpoint
  589. @app.post("/upload-pdf")
  590. async def upload_pdf(
  591. file: UploadFile = File(...),
  592. custom_name: str = Form(...),
  593. client_id: Optional[str] = Cookie(default=None, alias=CLIENT_COOKIE),
  594. ):
  595. if file.content_type != "application/pdf":
  596. raise HTTPException(status_code=400, detail="文件类型必须是 PDF")
  597. sanitized_name = sanitize_filename(custom_name)
  598. if not sanitized_name:
  599. return JSONResponse(status_code=400, content={"success": False, "error": "无效的文件名"})
  600. unique_filename = f"{sanitized_name}.pdf"
  601. file_path = os.path.join(BASE_STATIC_FILES_DIR, unique_filename)
  602. if os.path.exists(file_path):
  603. return JSONResponse(status_code=400, content={"success": False, "error": "文件名已存在,请使用其他名称"})
  604. try:
  605. with open(file_path, "wb") as buffer:
  606. shutil.copyfileobj(file.file, buffer)
  607. except Exception:
  608. raise HTTPException(status_code=500, detail="上传过程中出错")
  609. finally:
  610. file.file.close()
  611. current_client_id = get_or_create_client_id(client_id)
  612. file_relative_path = build_file_url(unique_filename)
  613. store = load_progress_store()
  614. store[current_client_id] = {
  615. "last_file": file_relative_path,
  616. "last_page": 1,
  617. "updated_at": datetime.now(timezone.utc).isoformat(),
  618. }
  619. save_progress_store(store)
  620. response = JSONResponse(content={"success": True, "file_path": file_relative_path})
  621. response.set_cookie(key=CLIENT_COOKIE, value=current_client_id, httponly=True, samesite="lax", secure=False, path="/")
  622. return response
  623. # List PDFs endpoint
  624. @app.get("/list-pdfs")
  625. async def list_pdfs():
  626. try:
  627. files = os.listdir(BASE_STATIC_FILES_DIR)
  628. pdf_files = [
  629. {"name": file, "url": build_file_url(file)}
  630. for file in files
  631. if file.lower().endswith(".pdf")
  632. ]
  633. pdf_files.sort(key=lambda x: x["name"].lower())
  634. return JSONResponse(content={"success": True, "files": pdf_files})
  635. except Exception:
  636. raise HTTPException(status_code=500, detail="无法获取文件列表")
  637. class ReadingProgressRequest(BaseModel):
  638. file: str
  639. page: int
  640. @app.get("/reading-progress")
  641. async def get_reading_progress(file: str, client_id: Optional[str] = Cookie(default=None, alias=CLIENT_COOKIE)):
  642. normalized_file = (file or "").strip()
  643. if not normalized_file:
  644. return JSONResponse(status_code=400, content={"success": False, "error": "缺少 file 参数"})
  645. current_client_id = get_or_create_client_id(client_id)
  646. progress = load_progress_store().get(current_client_id, {})
  647. page = progress.get("last_page") if progress.get("last_file") == normalized_file else None
  648. response = JSONResponse(content={"success": True, "file": normalized_file, "page": page})
  649. response.set_cookie(key=CLIENT_COOKIE, value=current_client_id, httponly=True, samesite="lax", secure=False, path="/")
  650. return response
  651. @app.post("/reading-progress")
  652. async def save_reading_progress(request: ReadingProgressRequest, client_id: Optional[str] = Cookie(default=None, alias=CLIENT_COOKIE)):
  653. normalized_file = (request.file or "").strip()
  654. page = int(request.page)
  655. if not normalized_file:
  656. return JSONResponse(status_code=400, content={"success": False, "error": "file 不能为空"})
  657. if page < 1:
  658. return JSONResponse(status_code=400, content={"success": False, "error": "page 必须 >= 1"})
  659. current_client_id = get_or_create_client_id(client_id)
  660. store = load_progress_store()
  661. store[current_client_id] = {
  662. "last_file": normalized_file,
  663. "last_page": page,
  664. "updated_at": datetime.now(timezone.utc).isoformat(),
  665. }
  666. save_progress_store(store)
  667. response = JSONResponse(content={"success": True})
  668. response.set_cookie(key=CLIENT_COOKIE, value=current_client_id, httponly=True, samesite="lax", secure=False, path="/")
  669. return response
  670. class TextToSpeechRequest(BaseModel):
  671. user_input: str
  672. voice: str = OPENAI_TTS_DEFAULT_VOICE
  673. speed: float = 1.0
  674. @app.post("/generate")
  675. async def generate_proxy(request: TextToSpeechRequest):
  676. user_input = request.user_input.strip()
  677. if not user_input:
  678. raise HTTPException(status_code=400, detail="输入文本为空")
  679. async def stream_generator() -> AsyncGenerator[bytes, None]:
  680. try:
  681. chunks = split_text_into_chunks(user_input)
  682. for index, chunk in enumerate(chunks):
  683. audio_bytes = await request_openai_tts_audio(chunk, request.voice)
  684. payload = {
  685. "index": index,
  686. "text": chunk,
  687. "audio": base64.b64encode(audio_bytes).decode("utf-8"),
  688. }
  689. yield (json.dumps(payload, ensure_ascii=False) + "\n").encode("utf-8")
  690. await asyncio.sleep(0)
  691. except HTTPException as e:
  692. logger.error("generate proxy http error: %s", e.detail)
  693. yield (json.dumps({"error": e.detail}, ensure_ascii=False) + "\n").encode("utf-8")
  694. except Exception as e:
  695. logger.error(f"generate proxy error: {str(e)}")
  696. yield (json.dumps({"error": "TTS生成失败"}, ensure_ascii=False) + "\n").encode("utf-8")
  697. return StreamingResponse(stream_generator(), media_type="application/x-ndjson")
  698. def normalize_openai_voice(voice: str) -> str:
  699. allowed_voices = {"alloy", "ash", "ballad", "coral", "echo", "sage", "shimmer", "verse"}
  700. normalized = (voice or "").strip().lower()
  701. return normalized if normalized in allowed_voices else OPENAI_TTS_DEFAULT_VOICE
  702. def get_audio_media_type(audio_format: str) -> str:
  703. mapping = {
  704. "wav": "audio/wav",
  705. "mp3": "audio/mpeg",
  706. "flac": "audio/flac",
  707. "opus": "audio/opus",
  708. "pcm16": "audio/L16",
  709. }
  710. return mapping.get(audio_format.lower(), "application/octet-stream")
  711. async def request_openai_tts_audio(text: str, voice: str) -> bytes:
  712. payload = {
  713. "model": OPENAI_TTS_MODEL,
  714. "voice": normalize_openai_voice(voice),
  715. "input": text,
  716. "response_format": OPENAI_TTS_FORMAT,
  717. "speed": 1.0,
  718. }
  719. headers = {
  720. "Authorization": f"Bearer {OPENAI_TTS_API_KEY}",
  721. "Content-Type": "application/json",
  722. }
  723. async with aiohttp.ClientSession() as session:
  724. async with session.post(
  725. f"{OPENAI_TTS_BASE_URL.rstrip('/')}/v1/audio/speech",
  726. headers=headers,
  727. json=payload,
  728. ) as response:
  729. if response.status != 200:
  730. response_text = await response.text()
  731. logger.error("OpenAI TTS request failed: %s", response_text)
  732. error_detail = "OpenAI TTS API 请求失败"
  733. try:
  734. error_data = json.loads(response_text)
  735. error_obj = error_data.get("error", {})
  736. error_message = error_obj.get("message")
  737. error_code = error_obj.get("code")
  738. if error_message:
  739. error_detail = f"{error_detail}: {error_message}"
  740. if response.status == 429 or error_code in {"rate_limit_exceeded", "model_not_found", "upstream_error"}:
  741. raise HTTPException(status_code=503, detail=error_detail)
  742. except json.JSONDecodeError:
  743. pass
  744. raise HTTPException(status_code=500 if response.status < 500 else 502, detail=error_detail)
  745. return await response.read()
  746. @app.post("/text-to-speech/")
  747. async def text_to_speech(request: TextToSpeechRequest):
  748. user_input = request.user_input.strip()
  749. if not user_input:
  750. raise HTTPException(status_code=400, detail="输入文本为空")
  751. text_hash = hashlib.md5(user_input.encode("utf-8")).hexdigest()
  752. audio_path = os.path.join(CACHE_DIR, f"{text_hash}.{OPENAI_TTS_FORMAT}")
  753. media_type = get_audio_media_type(OPENAI_TTS_FORMAT)
  754. if os.path.exists(audio_path):
  755. with open(audio_path, "rb") as f:
  756. return Response(content=f.read(), media_type=media_type)
  757. try:
  758. audio_bytes = await request_openai_tts_audio(user_input, request.voice)
  759. with open(audio_path, "wb") as f:
  760. f.write(audio_bytes)
  761. return Response(content=audio_bytes, media_type=media_type)
  762. except HTTPException:
  763. raise
  764. except Exception as e:
  765. logger.error(f"TTS error: {str(e)}")
  766. raise HTTPException(status_code=500, detail="TTS生成失败")
  767. MAX_CHUNK_SIZE = 200
  768. def split_text_into_chunks(text: str, max_chunk_size: int = MAX_CHUNK_SIZE) -> list:
  769. import re
  770. sentences = re.split(r"(?<=[.!?]) +", text)
  771. chunks = []
  772. current_chunk = ""
  773. for sentence in sentences:
  774. if len(current_chunk) + len(sentence) + 1 <= max_chunk_size:
  775. current_chunk += " " + sentence if current_chunk else sentence
  776. else:
  777. if current_chunk:
  778. chunks.append(current_chunk)
  779. if len(sentence) > max_chunk_size:
  780. for i in range(0, len(sentence), max_chunk_size):
  781. chunks.append(sentence[i : i + max_chunk_size])
  782. current_chunk = ""
  783. else:
  784. current_chunk = sentence
  785. if current_chunk:
  786. chunks.append(current_chunk)
  787. return chunks
  788. async def generate_api_audio(chunk: str, voice: str, speed: float) -> AsyncGenerator[bytes, None]:
  789. text_hash = hashlib.md5(chunk.encode("utf-8")).hexdigest()
  790. audio_path = os.path.join(CACHE_DIR, f"{text_hash}.{OPENAI_TTS_FORMAT}")
  791. if os.path.exists(audio_path):
  792. with open(audio_path, "rb") as f:
  793. yield f.read()
  794. else:
  795. try:
  796. audio_bytes = await request_openai_tts_audio(chunk, voice)
  797. with open(audio_path, "wb") as f:
  798. f.write(audio_bytes)
  799. yield audio_bytes
  800. except HTTPException as e:
  801. raise HTTPException(status_code=e.status_code, detail=e.detail)
  802. except Exception as e:
  803. raise HTTPException(status_code=500, detail=f"TTS生成失败: {str(e)}")
  804. @app.post("/page-to-speech/")
  805. async def page_to_speech(request: TextToSpeechRequest):
  806. user_input = request.user_input.strip()
  807. if not user_input:
  808. raise HTTPException(status_code=400, detail="输入文本为空")
  809. full_text_hash = hashlib.md5(user_input.encode("utf-8")).hexdigest()
  810. full_audio_path = os.path.join(CACHE_DIR, f"{full_text_hash}_full.{OPENAI_TTS_FORMAT}")
  811. media_type = get_audio_media_type(OPENAI_TTS_FORMAT)
  812. if os.path.exists(full_audio_path):
  813. return StreamingResponse(open(full_audio_path, "rb"), media_type=media_type)
  814. chunks = split_text_into_chunks(user_input)
  815. async def audio_generator() -> AsyncGenerator[bytes, None]:
  816. full_audio_buffer = io.BytesIO()
  817. for chunk in chunks:
  818. async for audio_data in generate_api_audio(chunk, request.voice, request.speed):
  819. yield audio_data
  820. full_audio_buffer.write(audio_data)
  821. await asyncio.sleep(0)
  822. full_audio_buffer.seek(0)
  823. with open(full_audio_path, "wb") as f:
  824. f.write(full_audio_buffer.getvalue())
  825. return StreamingResponse(audio_generator(), media_type=media_type)
  826. @app.get("/health")
  827. async def health_check():
  828. return {"status": "healthy"}
  829. if __name__ == "__main__":
  830. import uvicorn
  831. uvicorn.run(app, host="0.0.0.0", port=8005)