fastchat.py 39 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132
  1. # -*- coding: utf-8 -*-
  2. import asyncio
  3. import base64
  4. import datetime
  5. import hashlib
  6. import json
  7. import os
  8. import re
  9. import secrets
  10. import shutil
  11. import threading
  12. import uuid
  13. from pathlib import Path
  14. from typing import Any, Callable, Dict, List, Optional, Union
  15. from urllib.parse import quote_plus
  16. from fastapi import Body, Depends, FastAPI, HTTPException, UploadFile, File, status
  17. from fastapi.middleware.cors import CORSMiddleware
  18. from fastapi.responses import FileResponse, HTMLResponse, StreamingResponse
  19. from fastapi.staticfiles import StaticFiles
  20. from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
  21. from pydantic import BaseModel
  22. from sqlalchemy import (
  23. Boolean,
  24. Column,
  25. DateTime,
  26. ForeignKey,
  27. Integer,
  28. String,
  29. Text,
  30. create_engine,
  31. select,
  32. text,
  33. )
  34. from sqlalchemy.orm import Session, declarative_base, relationship, sessionmaker
  35. from openai import OpenAI
  36. # =============================
  37. # 基础配置
  38. # =============================
  39. BASE_DIR = Path(__file__).resolve().parent
  40. DATA_DIR = BASE_DIR / "data"
  41. BACKUP_DIR = BASE_DIR / "data_bak"
  42. BLOG_DIR = BASE_DIR / "blog"
  43. UPLOAD_DIR = BASE_DIR / "uploads"
  44. STATIC_DIR = BASE_DIR / "static"
  45. MYSQL_HOST = os.getenv("CHATFAST_DB_HOST", "127.0.0.1")
  46. MYSQL_PORT = int(os.getenv("CHATFAST_DB_PORT", "3306"))
  47. MYSQL_USER = os.getenv("CHATFAST_DB_USER", "root")
  48. MYSQL_PASSWORD = os.getenv("CHATFAST_DB_PASSWORD", "792199Zhao*")
  49. DATABASE_NAME = os.getenv("CHATFAST_DB_NAME", "chat_fast")
  50. ENCODED_PASSWORD = quote_plus(MYSQL_PASSWORD)
  51. RAW_DATABASE_URL = (
  52. f"mysql+pymysql://{MYSQL_USER}:{ENCODED_PASSWORD}@{MYSQL_HOST}:{MYSQL_PORT}/?charset=utf8mb4"
  53. )
  54. DATABASE_URL = (
  55. f"mysql+pymysql://{MYSQL_USER}:{ENCODED_PASSWORD}@{MYSQL_HOST}:{MYSQL_PORT}/{DATABASE_NAME}?charset=utf8mb4"
  56. )
  57. TOKEN_TTL_HOURS = int(os.getenv("CHATFAST_TOKEN_TTL_HOURS", "72"))
  58. DEFAULT_ADMIN_USERNAME = os.getenv("CHATFAST_DEFAULT_ADMIN", "admin")
  59. DEFAULT_ADMIN_PASSWORD = os.getenv("CHATFAST_DEFAULT_ADMIN_PASSWORD", "Admin@123")
  60. Base = declarative_base()
  61. ENGINE = None
  62. SessionLocal: Optional[sessionmaker] = None
  63. AUTH_SCHEME = HTTPBearer(auto_error=False)
  64. # 默认上传文件下载地址,可通过环境变量覆盖
  65. DEFAULT_UPLOAD_BASE = os.getenv("UPLOAD_BASE_URL", "/download/")
  66. DOWNLOAD_BASE = DEFAULT_UPLOAD_BASE.rstrip("/")
  67. # 与 appchat.py 相同的模型与密钥配置(仅示例)
  68. default_key = "sk-re2NlaKIQn11ZNWzAbB6339cEbF94c6aAfC8B7Ab82879bEa"
  69. MODEL_KEYS: Dict[str, str] = {
  70. "grok-3": default_key,
  71. "grok-4": default_key,
  72. "gpt-5.1-2025-11-13": default_key,
  73. "gpt-5-2025-08-07": default_key,
  74. "gpt-4o-mini": default_key,
  75. # "gpt-4.1-mini-2025-04-14": default_key,
  76. "o1-mini": default_key,
  77. "o4-mini": default_key,
  78. "deepseek-v3": default_key,
  79. "deepseek-r1": default_key,
  80. "gpt-4o-all": default_key,
  81. # "gpt-5-mini-2025-08-07": default_key,
  82. "o3-mini-all": default_key,
  83. }
  84. API_URL = "https://yunwu.ai/v1"
  85. client = OpenAI(api_key=default_key, base_url=API_URL)
  86. # 锁用于避免并发文件写入导致的数据损坏
  87. FILE_LOCK = asyncio.Lock()
  88. MessageContent = Union[str, List[Dict[str, Any]]]
  89. def ensure_directories() -> None:
  90. for path in [DATA_DIR, BACKUP_DIR, BLOG_DIR, UPLOAD_DIR, STATIC_DIR]:
  91. path.mkdir(parents=True, exist_ok=True)
  92. def ensure_database_initialized() -> None:
  93. global ENGINE, SessionLocal
  94. if ENGINE is not None:
  95. return
  96. raw_engine = create_engine(RAW_DATABASE_URL, future=True, pool_pre_ping=True)
  97. with raw_engine.connect() as connection:
  98. connection.execute(
  99. text(
  100. f"CREATE DATABASE IF NOT EXISTS `{DATABASE_NAME}` CHARACTER SET utf8mb4 COLLATE utf8mb4_unicode_ci"
  101. )
  102. )
  103. raw_engine.dispose()
  104. ENGINE = create_engine(DATABASE_URL, future=True, pool_pre_ping=True)
  105. SessionLocal = sessionmaker(bind=ENGINE, autoflush=False, expire_on_commit=False, future=True)
  106. Base.metadata.create_all(bind=ENGINE)
  107. def now_utc() -> datetime.datetime:
  108. return datetime.datetime.utcnow()
  109. async def db_call(func: Callable[[Session], Any], *args: Any, **kwargs: Any) -> Any:
  110. def wrapped() -> Any:
  111. if SessionLocal is None:
  112. raise RuntimeError("数据库尚未初始化")
  113. with SessionLocal() as session:
  114. return func(session, *args, **kwargs)
  115. return await asyncio.to_thread(wrapped)
class User(Base):
    """A registered account. In this module ``role`` is set to ``"user"`` or ``"admin"``."""

    __tablename__ = "users"
    id = Column(Integer, primary_key=True)
    # Login name; unique and indexed.
    username = Column(String(64), unique=True, nullable=False, index=True)
    # Hex digest produced by hash_password(password, salt).
    password_hash = Column(String(128), nullable=False)
    # Per-user random salt (this module generates it with secrets.token_hex(8)).
    salt = Column(String(32), nullable=False)
    role = Column(String(16), nullable=False, default="user")
    created_at = Column(DateTime, default=now_utc, nullable=False)
    # Refreshed by SQLAlchemy on every ORM UPDATE of the row.
    updated_at = Column(DateTime, default=now_utc, onupdate=now_utc, nullable=False)
    # Deleting a user through the ORM also deletes their chat sessions.
    sessions = relationship("ChatSession", back_populates="user", cascade="all, delete-orphan")
class ChatSession(Base):
    """One conversation owned by a user; archived sessions are hidden from history lists."""

    __tablename__ = "chat_sessions"
    id = Column(Integer, primary_key=True)
    # Removed at the database level when the owning user row is deleted.
    user_id = Column(Integer, ForeignKey("users.id", ondelete="CASCADE"), nullable=False, index=True)
    # Filled from the first user message by append_message (truncated to 30 chars).
    title = Column(String(128), nullable=True)
    # Set by move_history_file; list_history and get_latest_session skip archived rows.
    archived = Column(Boolean, default=False, nullable=False)
    created_at = Column(DateTime, default=now_utc, nullable=False)
    # No onupdate hook here: callers bump this explicitly when they modify the session.
    updated_at = Column(DateTime, default=now_utc, nullable=False)
    user = relationship("User", back_populates="sessions")
    # Deleting a session through the ORM also deletes its messages.
    messages = relationship("ChatMessage", back_populates="session", cascade="all, delete-orphan")
class ChatMessage(Base):
    """A single message inside a chat session; ordering is by ascending ``id``."""

    __tablename__ = "chat_messages"
    id = Column(Integer, primary_key=True)
    session_id = Column(Integer, ForeignKey("chat_sessions.id", ondelete="CASCADE"), nullable=False, index=True)
    # Chat role string, e.g. "user" (other values come from callers outside this view).
    role = Column(String(16), nullable=False)
    # JSON text written by serialize_content and read back by deserialize_content.
    content = Column(Text, nullable=False)
    created_at = Column(DateTime, default=now_utc, nullable=False)
    session = relationship("ChatSession", back_populates="messages")
class AuthToken(Base):
    """A bearer token issued at login/registration, valid until ``expires_at``."""

    __tablename__ = "auth_tokens"
    # The opaque token string itself (create_auth_token uses secrets.token_hex(32)).
    token = Column(String(128), primary_key=True)
    user_id = Column(Integer, ForeignKey("users.id", ondelete="CASCADE"), nullable=False, index=True)
    # Naive UTC timestamp; compared against now_utc() in resolve_token.
    expires_at = Column(DateTime, nullable=False)
    created_at = Column(DateTime, default=now_utc, nullable=False)
    user = relationship("User")
class ExportedContent(Base):
    """Metadata row for a message exported to a text file."""

    __tablename__ = "exported_contents"
    id = Column(Integer, primary_key=True)
    user_id = Column(Integer, ForeignKey("users.id", ondelete="CASCADE"), nullable=False, index=True)
    # Originating chat session; the database nulls it when that session is deleted.
    source_session_id = Column(Integer, ForeignKey("chat_sessions.id", ondelete="SET NULL"), nullable=True)
    # Base name of file_path.
    filename = Column(String(255), nullable=False)
    # Path of the exported file on disk (presumably the path returned by
    # export_message_to_blog — confirm against callers).
    file_path = Column(String(500), nullable=False)
    # First 200 characters of the exported text (see record_export_entry).
    content_preview = Column(Text, nullable=True)
    created_at = Column(DateTime, default=now_utc, nullable=False)
    user = relationship("User")
  161. def hash_password(password: str, salt: str) -> str:
  162. return hashlib.sha256((password + salt).encode("utf-8")).hexdigest()
  163. def normalize_username(username: str) -> str:
  164. normalized = (username or "").strip()
  165. if len(normalized) < 3:
  166. raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="用户名至少需要 3 个字符")
  167. return normalized
  168. def enforce_password_strength(password: str) -> str:
  169. if not password or len(password) < 6:
  170. raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="密码至少需要 6 位")
  171. return password
  172. def serialize_content(content: MessageContent) -> str:
  173. return json.dumps(content, ensure_ascii=False)
  174. def deserialize_content(raw: str) -> MessageContent:
  175. try:
  176. return json.loads(raw)
  177. except json.JSONDecodeError:
  178. return raw
  179. def ensure_default_admin() -> None:
  180. if SessionLocal is None:
  181. return
  182. with SessionLocal() as session:
  183. existing = session.execute(select(User).where(User.role == "admin")).first()
  184. if existing:
  185. return
  186. salt = secrets.token_hex(8)
  187. admin = User(
  188. username=DEFAULT_ADMIN_USERNAME,
  189. salt=salt,
  190. password_hash=hash_password(DEFAULT_ADMIN_PASSWORD, salt),
  191. role="admin",
  192. )
  193. session.add(admin)
  194. session.commit()
class UserInfo(BaseModel):
    # Public view of an authenticated user, as returned by resolve_token and
    # the /api/auth/me endpoint (no password hash or salt).
    id: int
    username: str
    # "user" or "admin" in this module; require_admin checks for "admin".
    role: str
  199. async def create_auth_token(user_id: int) -> Dict[str, Any]:
  200. def creator(session: Session) -> Dict[str, Any]:
  201. token_value = secrets.token_hex(32)
  202. expires_at = now_utc() + datetime.timedelta(hours=TOKEN_TTL_HOURS)
  203. token = AuthToken(token=token_value, user_id=user_id, expires_at=expires_at)
  204. session.add(token)
  205. session.commit()
  206. return {"token": token_value, "expires_at": expires_at.isoformat()}
  207. return await db_call(creator)
  208. async def revoke_token(token_value: str) -> None:
  209. def remover(session: Session) -> None:
  210. session.query(AuthToken).filter(AuthToken.token == token_value).delete()
  211. session.commit()
  212. await db_call(remover)
  213. async def resolve_token(token_value: str) -> Optional[UserInfo]:
  214. def resolver(session: Session) -> Optional[UserInfo]:
  215. token = session.execute(select(AuthToken).where(AuthToken.token == token_value)).scalar_one_or_none()
  216. if not token:
  217. return None
  218. if token.expires_at < now_utc():
  219. session.delete(token)
  220. session.commit()
  221. return None
  222. user = session.get(User, token.user_id)
  223. if not user:
  224. session.delete(token)
  225. session.commit()
  226. return None
  227. return UserInfo(id=user.id, username=user.username, role=user.role)
  228. return await db_call(resolver)
  229. async def cleanup_expired_tokens() -> None:
  230. def cleaner(session: Session) -> None:
  231. session.query(AuthToken).filter(AuthToken.expires_at < now_utc()).delete()
  232. session.commit()
  233. await db_call(cleaner)
  234. async def get_current_user(credentials: Optional[HTTPAuthorizationCredentials] = Depends(AUTH_SCHEME)) -> UserInfo:
  235. if not credentials:
  236. raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="请先登录")
  237. user = await resolve_token(credentials.credentials)
  238. if not user:
  239. raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="登录已失效,请重新登录")
  240. return user
  241. async def require_admin(current_user: UserInfo = Depends(get_current_user)) -> UserInfo:
  242. if current_user.role != "admin":
  243. raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="需要管理员权限")
  244. return current_user
  245. def text_from_content(content: MessageContent) -> str:
  246. if isinstance(content, str):
  247. return content
  248. if isinstance(content, list):
  249. pieces: List[str] = []
  250. for part in content:
  251. if part.get("type") == "text":
  252. pieces.append(part.get("text", ""))
  253. return " ".join(pieces)
  254. return str(content)
  255. def extract_history_title(messages: List[Dict[str, Any]]) -> str:
  256. """Return the first meaningful title extracted from user messages."""
  257. for message in messages:
  258. if message.get("role") != "user":
  259. continue
  260. title = text_from_content(message.get("content", "")).strip()
  261. if title:
  262. return title[:10]
  263. if messages:
  264. fallback = text_from_content(messages[0].get("content", "")).strip()
  265. if fallback:
  266. return fallback[:10]
  267. return "空的聊天"[:10]
  268. def history_backup_path(session_id: int) -> Path:
  269. return BACKUP_DIR / f"chat_history_{session_id}.json"
  270. def build_download_url(filename: str) -> str:
  271. base = DOWNLOAD_BASE or ""
  272. return f"{base}/{filename}" if base else filename
  273. async def load_messages(session_id: int, user_id: int) -> List[Dict[str, Any]]:
  274. def loader(session: Session) -> List[Dict[str, Any]]:
  275. stmt = select(ChatSession).where(ChatSession.id == session_id, ChatSession.user_id == user_id)
  276. chat_session = session.execute(stmt).scalar_one_or_none()
  277. if not chat_session:
  278. raise HTTPException(status_code=404, detail="会话不存在")
  279. messages = (
  280. session.execute(
  281. select(ChatMessage).where(ChatMessage.session_id == chat_session.id).order_by(ChatMessage.id)
  282. )
  283. .scalars()
  284. .all()
  285. )
  286. return [{"role": message.role, "content": deserialize_content(message.content)} for message in messages]
  287. return await db_call(loader)
  288. async def create_chat_session(user_id: int) -> Dict[str, Any]:
  289. def creator(session: Session) -> Dict[str, Any]:
  290. chat_session = ChatSession(user_id=user_id)
  291. session.add(chat_session)
  292. session.commit()
  293. session.refresh(chat_session)
  294. return {"session_id": chat_session.id, "messages": []}
  295. return await db_call(creator)
  296. async def get_latest_session(user_id: int) -> Dict[str, Any]:
  297. def loader(session: Session) -> Dict[str, Any]:
  298. stmt = (
  299. select(ChatSession)
  300. .where(ChatSession.user_id == user_id, ChatSession.archived.is_(False))
  301. .order_by(ChatSession.updated_at.desc())
  302. )
  303. chat_session = session.execute(stmt).scalars().first()
  304. if not chat_session:
  305. chat_session = ChatSession(user_id=user_id)
  306. session.add(chat_session)
  307. session.commit()
  308. session.refresh(chat_session)
  309. messages: List[Dict[str, Any]] = []
  310. else:
  311. messages = (
  312. session.execute(
  313. select(ChatMessage).where(ChatMessage.session_id == chat_session.id).order_by(ChatMessage.id)
  314. )
  315. .scalars()
  316. .all()
  317. )
  318. messages = [{"role": msg.role, "content": deserialize_content(msg.content)} for msg in messages]
  319. return {"session_id": chat_session.id, "messages": messages}
  320. return await db_call(loader)
  321. async def append_message(session_id: int, user_id: int, role: str, content: MessageContent) -> None:
  322. def writer(session: Session) -> None:
  323. stmt = select(ChatSession).where(ChatSession.id == session_id, ChatSession.user_id == user_id)
  324. chat_session = session.execute(stmt).scalar_one_or_none()
  325. if not chat_session:
  326. raise HTTPException(status_code=404, detail="会话不存在")
  327. serialized = serialize_content(content)
  328. message = ChatMessage(session_id=chat_session.id, role=role, content=serialized)
  329. session.add(message)
  330. chat_session.updated_at = now_utc()
  331. if role == "user" and (not chat_session.title or not chat_session.title.strip()):
  332. candidate = text_from_content(content).strip()
  333. if candidate:
  334. chat_session.title = candidate[:30]
  335. session.commit()
  336. await db_call(writer)
  337. async def list_history(user_id: int, page: int, page_size: int) -> Dict[str, Any]:
  338. def lister(session: Session) -> Dict[str, Any]:
  339. stmt = (
  340. select(ChatSession)
  341. .where(ChatSession.user_id == user_id, ChatSession.archived.is_(False))
  342. .order_by(ChatSession.updated_at.desc())
  343. )
  344. sessions = session.execute(stmt).scalars().all()
  345. total = len(sessions)
  346. start = max(page, 0) * page_size
  347. end = start + page_size
  348. subset = sessions[start:end]
  349. items: List[Dict[str, Any]] = []
  350. for item in subset:
  351. items.append(
  352. {
  353. "session_id": item.id,
  354. "title": (item.title or f"会话 #{item.id}")[:30],
  355. "updated_at": (item.updated_at or now_utc()).isoformat(),
  356. "filename": f"session_{item.id}.json",
  357. }
  358. )
  359. return {"page": page, "page_size": page_size, "total": total, "items": items}
  360. return await db_call(lister)
  361. async def move_history_file(user_id: int, session_id: int) -> None:
  362. def mark_archived(session: Session) -> List[Dict[str, Any]]:
  363. stmt = select(ChatSession).where(ChatSession.id == session_id, ChatSession.user_id == user_id)
  364. chat_session = session.execute(stmt).scalar_one_or_none()
  365. if not chat_session:
  366. raise HTTPException(status_code=404, detail="历史记录不存在")
  367. messages = (
  368. session.execute(
  369. select(ChatMessage).where(ChatMessage.session_id == chat_session.id).order_by(ChatMessage.id)
  370. )
  371. .scalars()
  372. .all()
  373. )
  374. payload = [{"role": msg.role, "content": deserialize_content(msg.content)} for msg in messages]
  375. chat_session.archived = True
  376. chat_session.updated_at = now_utc()
  377. session.commit()
  378. return payload
  379. messages = await db_call(mark_archived)
  380. backup_file = history_backup_path(session_id)
  381. backup_file.parent.mkdir(parents=True, exist_ok=True)
  382. def _write() -> None:
  383. with backup_file.open("w", encoding="utf-8") as fp:
  384. json.dump(messages, fp, ensure_ascii=False, indent=2)
  385. async with FILE_LOCK:
  386. await asyncio.to_thread(_write)
  387. async def delete_history_file(user_id: int, session_id: int) -> None:
  388. def deleter(session: Session) -> None:
  389. stmt = select(ChatSession).where(ChatSession.id == session_id, ChatSession.user_id == user_id)
  390. chat_session = session.execute(stmt).scalar_one_or_none()
  391. if not chat_session:
  392. raise HTTPException(status_code=404, detail="历史记录不存在")
  393. session.delete(chat_session)
  394. session.commit()
  395. await db_call(deleter)
  396. async def export_message_to_blog(content: MessageContent) -> str:
  397. processed = text_from_content(content)
  398. processed = processed.replace("\r\n", "\n")
  399. timestamp = datetime.datetime.now().strftime("%m%d%H%M")
  400. first_10 = (
  401. processed[:10]
  402. .replace(" ", "")
  403. .replace("/", "")
  404. .replace("\\", "")
  405. .replace(":", "")
  406. .replace("`", "")
  407. )
  408. filename = f"{timestamp}_{first_10 or 'export'}.txt"
  409. path = BLOG_DIR / filename
  410. def _write() -> None:
  411. with path.open("w", encoding="utf-8") as fp:
  412. fp.write(processed)
  413. await asyncio.to_thread(_write)
  414. return str(path)
  415. async def record_export_entry(user_id: int, session_id: Optional[int], file_path: str, content: MessageContent) -> Dict[str, Any]:
  416. def recorder(session: Session) -> Dict[str, Any]:
  417. filename = os.path.basename(file_path)
  418. preview = text_from_content(content).strip()[:200]
  419. export = ExportedContent(
  420. user_id=user_id,
  421. source_session_id=session_id,
  422. filename=filename,
  423. file_path=file_path,
  424. content_preview=preview,
  425. )
  426. session.add(export)
  427. session.commit()
  428. session.refresh(export)
  429. user = session.get(User, user_id)
  430. username = user.username if user else ""
  431. return {
  432. "id": export.id,
  433. "user_id": user_id,
  434. "username": username,
  435. "filename": filename,
  436. "file_path": file_path,
  437. "created_at": (export.created_at or now_utc()).isoformat(),
  438. "content_preview": preview,
  439. }
  440. return await db_call(recorder)
  441. async def list_exports_for_user(user_id: int) -> List[Dict[str, Any]]:
  442. def lister(session: Session) -> List[Dict[str, Any]]:
  443. stmt = (
  444. select(ExportedContent)
  445. .where(ExportedContent.user_id == user_id)
  446. .order_by(ExportedContent.created_at.desc())
  447. )
  448. exports = session.execute(stmt).scalars().all()
  449. user = session.get(User, user_id)
  450. username = user.username if user else ""
  451. results: List[Dict[str, Any]] = []
  452. for item in exports:
  453. results.append(
  454. {
  455. "id": item.id,
  456. "user_id": item.user_id,
  457. "username": username,
  458. "filename": item.filename,
  459. "file_path": item.file_path,
  460. "created_at": (item.created_at or now_utc()).isoformat(),
  461. "content_preview": (item.content_preview or "")[:200],
  462. }
  463. )
  464. return results
  465. return await db_call(lister)
  466. async def list_exports_admin(keyword: Optional[str] = None) -> List[Dict[str, Any]]:
  467. def lister(session: Session) -> List[Dict[str, Any]]:
  468. stmt = select(ExportedContent, User.username).join(User, ExportedContent.user_id == User.id)
  469. if keyword:
  470. stmt = stmt.where(User.username.like(f"%{keyword.strip()}%"))
  471. stmt = stmt.order_by(ExportedContent.created_at.desc())
  472. rows = session.execute(stmt).all()
  473. results: List[Dict[str, Any]] = []
  474. for export, username in rows:
  475. results.append(
  476. {
  477. "id": export.id,
  478. "user_id": export.user_id,
  479. "username": username,
  480. "filename": export.filename,
  481. "file_path": export.file_path,
  482. "created_at": (export.created_at or now_utc()).isoformat(),
  483. "content_preview": (export.content_preview or "")[:200],
  484. }
  485. )
  486. return results
  487. return await db_call(lister)
  488. async def get_export_record(export_id: int) -> Optional[Dict[str, Any]]:
  489. def fetcher(session: Session) -> Optional[Dict[str, Any]]:
  490. export = session.get(ExportedContent, export_id)
  491. if not export:
  492. return None
  493. user = session.get(User, export.user_id)
  494. username = user.username if user else ""
  495. return {
  496. "id": export.id,
  497. "user_id": export.user_id,
  498. "username": username,
  499. "filename": export.filename,
  500. "file_path": export.file_path,
  501. }
  502. return await db_call(fetcher)
class RegisterRequest(BaseModel):
    # Body of POST /api/auth/register.
    username: str
    password: str
class LoginRequest(BaseModel):
    # Body of POST /api/auth/login.
    username: str
    password: str
class AdminUserRequest(BaseModel):
    # Body of POST /api/admin/users (admin creates a regular user).
    username: str
    password: str
class AdminUserUpdateRequest(BaseModel):
    # Partial update of a user; presumably used by an admin endpoint outside
    # this view. A None field means "leave unchanged" — TODO confirm.
    username: Optional[str] = None
    password: Optional[str] = None
class MessageModel(BaseModel):
    # A single role/content message (not referenced in the visible code).
    role: str
    content: MessageContent
class ChatRequest(BaseModel):
    # Presumably the chat-completion request body (endpoint not visible here).
    session_id: int
    # Presumably one of the keys of MODEL_KEYS — TODO confirm.
    model: str
    content: MessageContent
    # Presumably how many earlier messages to send as context — TODO confirm.
    history_count: int = 0
    stream: bool = True
class HistoryActionRequest(BaseModel):
    # Body of POST /api/history/move.
    session_id: int
class ExportRequest(BaseModel):
    # Content to export, plus an optional originating session id.
    content: MessageContent
    session_id: Optional[int] = None
class UploadResponseItem(BaseModel):
    # One uploaded file as reported back to the client (endpoint not visible
    # here); presumably either inline `data` or a download `url` is set.
    type: str
    filename: str
    data: Optional[str] = None
    url: Optional[str] = None
  534. # 确保静态与数据目录在应用初始化前存在
  535. ensure_directories()
  536. ensure_database_initialized()
  537. ensure_default_admin()
  538. app = FastAPI(title="ChatGPT-like Clone", version="1.0.0")
  539. app.add_middleware(
  540. CORSMiddleware,
  541. allow_origins=["*"],
  542. allow_credentials=True,
  543. allow_methods=["*"],
  544. allow_headers=["*"],
  545. )
  546. app.mount("/static", StaticFiles(directory=STATIC_DIR), name="static")
  547. @app.on_event("startup")
  548. async def on_startup() -> None:
  549. ensure_directories()
  550. ensure_database_initialized()
  551. ensure_default_admin()
  552. await cleanup_expired_tokens()
  553. INDEX_HTML = STATIC_DIR / "index.html"
  554. @app.get("/", response_class=HTMLResponse)
  555. async def serve_index() -> str:
  556. if not INDEX_HTML.exists():
  557. raise HTTPException(status_code=404, detail="UI 未找到,请确认 static/index.html 是否存在")
  558. return INDEX_HTML.read_text(encoding="utf-8")
  559. @app.get("/download/{filename}")
  560. async def download_file(filename: str) -> FileResponse:
  561. target = UPLOAD_DIR / filename
  562. if not target.exists():
  563. raise HTTPException(status_code=404, detail="File not found")
  564. return FileResponse(target, filename=filename)
  565. @app.get("/api/config")
  566. async def get_config() -> Dict[str, Any]:
  567. models = list(MODEL_KEYS.keys())
  568. return {
  569. "title": "ChatGPT-like Clone",
  570. "models": models,
  571. "default_model": models[0] if models else "",
  572. "output_modes": ["流式输出 (Stream)", "非流式输出 (Non-stream)"],
  573. "upload_base_url": DOWNLOAD_BASE + "/" if DOWNLOAD_BASE else "",
  574. }
  575. @app.post("/api/auth/register")
  576. async def api_register(payload: RegisterRequest) -> Dict[str, Any]:
  577. username = normalize_username(payload.username)
  578. password = enforce_password_strength(payload.password)
  579. def creator(session: Session) -> Dict[str, Any]:
  580. existing = session.execute(select(User).where(User.username == username)).scalar_one_or_none()
  581. if existing:
  582. raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="用户名已存在")
  583. salt = secrets.token_hex(8)
  584. user = User(username=username, salt=salt, password_hash=hash_password(password, salt), role="user")
  585. session.add(user)
  586. session.commit()
  587. session.refresh(user)
  588. return {"id": user.id, "username": user.username, "role": user.role}
  589. user = await db_call(creator)
  590. token_data = await create_auth_token(user["id"])
  591. return {"user": user, "token": token_data["token"], "expires_at": token_data["expires_at"]}
  592. @app.post("/api/auth/login")
  593. async def api_login(payload: LoginRequest) -> Dict[str, Any]:
  594. username = (payload.username or "").strip()
  595. if not username:
  596. raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="请输入用户名")
  597. password = payload.password or ""
  598. def verifier(session: Session) -> Dict[str, Any]:
  599. user = session.execute(select(User).where(User.username == username)).scalar_one_or_none()
  600. if not user:
  601. raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="用户名或密码错误")
  602. hashed = hash_password(password, user.salt)
  603. if hashed != user.password_hash:
  604. raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="用户名或密码错误")
  605. return {"id": user.id, "username": user.username, "role": user.role}
  606. user = await db_call(verifier)
  607. token_data = await create_auth_token(user["id"])
  608. return {"user": user, "token": token_data["token"], "expires_at": token_data["expires_at"]}
  609. @app.post("/api/auth/logout")
  610. async def api_logout(
  611. credentials: HTTPAuthorizationCredentials = Depends(AUTH_SCHEME),
  612. ) -> Dict[str, str]:
  613. if not credentials:
  614. raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="未登录")
  615. user = await resolve_token(credentials.credentials)
  616. if not user:
  617. raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="登录已失效")
  618. await revoke_token(credentials.credentials)
  619. return {"status": "ok"}
  620. @app.get("/api/auth/me")
  621. async def api_me(current_user: UserInfo = Depends(get_current_user)) -> UserInfo:
  622. return current_user
  623. @app.get("/api/session/latest")
  624. async def api_latest_session(current_user: UserInfo = Depends(get_current_user)) -> Dict[str, Any]:
  625. return await get_latest_session(current_user.id)
  626. @app.get("/api/session/{session_id}")
  627. async def api_get_session(session_id: int, current_user: UserInfo = Depends(get_current_user)) -> Dict[str, Any]:
  628. messages = await load_messages(session_id, current_user.id)
  629. return {"session_id": session_id, "messages": messages}
  630. @app.post("/api/session/new")
  631. async def api_new_session(current_user: UserInfo = Depends(get_current_user)) -> Dict[str, Any]:
  632. return await create_chat_session(current_user.id)
  633. @app.get("/api/history")
  634. async def api_history(
  635. page: int = 0,
  636. page_size: int = 10,
  637. current_user: UserInfo = Depends(get_current_user),
  638. ) -> Dict[str, Any]:
  639. return await list_history(current_user.id, page, page_size)
  640. @app.post("/api/history/move")
  641. async def api_move_history(payload: HistoryActionRequest, current_user: UserInfo = Depends(get_current_user)) -> Dict[str, Any]:
  642. await move_history_file(current_user.id, payload.session_id)
  643. return {"status": "ok"}
  644. @app.delete("/api/history/{session_id}")
  645. async def api_delete_history(session_id: int, current_user: UserInfo = Depends(get_current_user)) -> Dict[str, Any]:
  646. await delete_history_file(current_user.id, session_id)
  647. return {"status": "ok"}
  648. @app.get("/api/admin/users")
  649. async def admin_list_users(
  650. keyword: Optional[str] = None,
  651. page: int = 0,
  652. page_size: int = 20,
  653. admin: UserInfo = Depends(require_admin),
  654. ) -> Dict[str, Any]:
  655. def lister(session: Session) -> Dict[str, Any]:
  656. stmt = select(User).order_by(User.created_at.desc())
  657. if keyword:
  658. stmt = stmt.where(User.username.like(f"%{keyword.strip()}%"))
  659. users = session.execute(stmt).scalars().all()
  660. total = len(users)
  661. start = max(page, 0) * page_size
  662. end = start + page_size
  663. subset = users[start:end]
  664. items = [
  665. {
  666. "id": user.id,
  667. "username": user.username,
  668. "role": user.role,
  669. "created_at": (user.created_at or now_utc()).isoformat(),
  670. }
  671. for user in subset
  672. ]
  673. return {"items": items, "total": total, "page": page, "page_size": page_size}
  674. return await db_call(lister)
  675. @app.post("/api/admin/users")
  676. async def admin_create_user(payload: AdminUserRequest, admin: UserInfo = Depends(require_admin)) -> Dict[str, Any]:
  677. username = normalize_username(payload.username)
  678. password = enforce_password_strength(payload.password)
  679. def creator(session: Session) -> Dict[str, Any]:
  680. existing = session.execute(select(User).where(User.username == username)).scalar_one_or_none()
  681. if existing:
  682. raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="用户名已存在")
  683. salt = secrets.token_hex(8)
  684. user = User(username=username, salt=salt, password_hash=hash_password(password, salt), role="user")
  685. session.add(user)
  686. session.commit()
  687. session.refresh(user)
  688. return {
  689. "id": user.id,
  690. "username": user.username,
  691. "role": user.role,
  692. "created_at": (user.created_at or now_utc()).isoformat(),
  693. }
  694. return await db_call(creator)
  695. @app.get("/api/admin/users/{user_id}")
  696. async def admin_get_user(user_id: int, admin: UserInfo = Depends(require_admin)) -> Dict[str, Any]:
  697. def getter(session: Session) -> Dict[str, Any]:
  698. user = session.get(User, user_id)
  699. if not user:
  700. raise HTTPException(status_code=404, detail="用户不存在")
  701. return {
  702. "id": user.id,
  703. "username": user.username,
  704. "role": user.role,
  705. "created_at": (user.created_at or now_utc()).isoformat(),
  706. }
  707. return await db_call(getter)
  708. @app.put("/api/admin/users/{user_id}")
  709. async def admin_update_user(
  710. user_id: int,
  711. payload: AdminUserUpdateRequest,
  712. admin: UserInfo = Depends(require_admin),
  713. ) -> Dict[str, Any]:
  714. def updater(session: Session) -> Dict[str, Any]:
  715. user = session.get(User, user_id)
  716. if not user:
  717. raise HTTPException(status_code=404, detail="用户不存在")
  718. if user.role == "admin":
  719. raise HTTPException(status_code=400, detail="无法修改管理员信息")
  720. if payload.username:
  721. new_username = normalize_username(payload.username)
  722. conflict = (
  723. session.execute(select(User).where(User.username == new_username, User.id != user_id)).scalar_one_or_none()
  724. )
  725. if conflict:
  726. raise HTTPException(status_code=400, detail="用户名已被使用")
  727. user.username = new_username
  728. if payload.password:
  729. enforce_password_strength(payload.password)
  730. salt = secrets.token_hex(8)
  731. user.salt = salt
  732. user.password_hash = hash_password(payload.password, salt)
  733. session.commit()
  734. return {
  735. "id": user.id,
  736. "username": user.username,
  737. "role": user.role,
  738. "created_at": (user.created_at or now_utc()).isoformat(),
  739. }
  740. return await db_call(updater)
  741. @app.delete("/api/admin/users/{user_id}")
  742. async def admin_delete_user(user_id: int, admin: UserInfo = Depends(require_admin)) -> Dict[str, Any]:
  743. def deleter(session: Session) -> Dict[str, Any]:
  744. user = session.get(User, user_id)
  745. if not user:
  746. raise HTTPException(status_code=404, detail="用户不存在")
  747. if user.role == "admin":
  748. raise HTTPException(status_code=400, detail="无法删除管理员")
  749. session.delete(user)
  750. session.commit()
  751. return {"status": "ok"}
  752. return await db_call(deleter)
  753. @app.post("/api/export")
  754. async def api_export_message(payload: ExportRequest, current_user: UserInfo = Depends(get_current_user)) -> Dict[str, Any]:
  755. path = await export_message_to_blog(payload.content)
  756. record = await record_export_entry(current_user.id, payload.session_id, path, payload.content)
  757. return {"status": "ok", "path": path, "export": record}
  758. @app.get("/api/exports/me")
  759. async def api_my_exports(current_user: UserInfo = Depends(get_current_user)) -> Dict[str, Any]:
  760. items = await list_exports_for_user(current_user.id)
  761. return {"items": items}
  762. @app.get("/api/admin/exports")
  763. async def api_admin_exports(
  764. keyword: Optional[str] = None,
  765. admin: UserInfo = Depends(require_admin),
  766. ) -> Dict[str, Any]:
  767. items = await list_exports_admin(keyword)
  768. return {"items": items}
  769. @app.get("/api/exports/{export_id}/download")
  770. async def api_download_export(export_id: int, current_user: UserInfo = Depends(get_current_user)) -> FileResponse:
  771. record = await get_export_record(export_id)
  772. if not record:
  773. raise HTTPException(status_code=404, detail="导出记录不存在")
  774. if record["user_id"] != current_user.id and current_user.role != "admin":
  775. raise HTTPException(status_code=403, detail="无权下载该内容")
  776. file_path = Path(record["file_path"])
  777. if not file_path.exists():
  778. raise HTTPException(status_code=404, detail="导出文件不存在")
  779. return FileResponse(file_path, filename=record["filename"])
  780. @app.post("/api/upload")
  781. async def api_upload(
  782. files: List[UploadFile] = File(...),
  783. current_user: UserInfo = Depends(get_current_user),
  784. ) -> List[UploadResponseItem]:
  785. if not files:
  786. return []
  787. responses: List[UploadResponseItem] = []
  788. for upload in files:
  789. filename = upload.filename or "file"
  790. safe_filename = Path(filename).name or "file"
  791. content_type = (upload.content_type or "").lower()
  792. data = await upload.read()
  793. unique_name = f"{uuid.uuid4().hex}_{safe_filename}"
  794. target_path = UPLOAD_DIR / unique_name
  795. def _write() -> None:
  796. with target_path.open("wb") as fp:
  797. fp.write(data)
  798. await asyncio.to_thread(_write)
  799. if content_type.startswith("image/"):
  800. encoded = base64.b64encode(data).decode("utf-8")
  801. data_url = f"data:{content_type};base64,{encoded}"
  802. responses.append(
  803. UploadResponseItem(
  804. type="image",
  805. filename=safe_filename,
  806. data=data_url,
  807. url=build_download_url(unique_name),
  808. )
  809. )
  810. else:
  811. responses.append(
  812. UploadResponseItem(
  813. type="file",
  814. filename=safe_filename,
  815. url=build_download_url(unique_name),
  816. )
  817. )
  818. return responses
  819. async def prepare_messages_for_completion(
  820. messages: List[Dict[str, Any]],
  821. user_content: MessageContent,
  822. history_count: int,
  823. ) -> List[Dict[str, Any]]:
  824. if history_count > 0:
  825. trimmed = messages[-history_count:]
  826. if trimmed:
  827. return trimmed
  828. return [{"role": "user", "content": user_content}]
  829. async def save_assistant_message(session_id: int, user_id: int, messages: List[Dict[str, Any]], content: MessageContent) -> None:
  830. messages.append({"role": "assistant", "content": content})
  831. await append_message(session_id, user_id, "assistant", content)
  832. @app.post("/api/chat")
  833. async def api_chat(payload: ChatRequest = Body(...), current_user: UserInfo = Depends(get_current_user)):
  834. if payload.model not in MODEL_KEYS:
  835. raise HTTPException(status_code=400, detail="未知的模型")
  836. messages = await load_messages(payload.session_id, current_user.id)
  837. user_message = {"role": "user", "content": payload.content}
  838. messages.append(user_message)
  839. await append_message(payload.session_id, current_user.id, "user", payload.content)
  840. client.api_key = MODEL_KEYS[payload.model]
  841. to_send = await prepare_messages_for_completion(messages, payload.content, max(payload.history_count, 0))
  842. if payload.stream:
  843. queue: "asyncio.Queue[Dict[str, Any]]" = asyncio.Queue()
  844. aggregated: List[str] = []
  845. loop = asyncio.get_running_loop()
  846. def worker() -> None:
  847. try:
  848. response = client.chat.completions.create(
  849. model=payload.model,
  850. messages=to_send,
  851. stream=True,
  852. )
  853. for chunk in response:
  854. try:
  855. delta = chunk.choices[0].delta.content # type: ignore[attr-defined]
  856. except (IndexError, AttributeError):
  857. delta = None
  858. if delta:
  859. aggregated.append(delta)
  860. asyncio.run_coroutine_threadsafe(queue.put({"type": "delta", "text": delta}), loop)
  861. asyncio.run_coroutine_threadsafe(queue.put({"type": "complete"}), loop)
  862. except Exception as exc: # pragma: no cover - 网络调用
  863. asyncio.run_coroutine_threadsafe(queue.put({"type": "error", "message": str(exc)}), loop)
  864. threading.Thread(target=worker, daemon=True).start()
  865. async def streamer():
  866. try:
  867. while True:
  868. item = await queue.get()
  869. if item["type"] == "delta":
  870. yield json.dumps(item, ensure_ascii=False) + "\n"
  871. elif item["type"] == "complete":
  872. assistant_text = "".join(aggregated)
  873. await save_assistant_message(payload.session_id, current_user.id, messages, assistant_text)
  874. yield json.dumps({"type": "end"}, ensure_ascii=False) + "\n"
  875. break
  876. elif item["type"] == "error":
  877. yield json.dumps(item, ensure_ascii=False) + "\n"
  878. break
  879. except asyncio.CancelledError: # pragma: no cover - 流被取消
  880. raise
  881. return StreamingResponse(streamer(), media_type="application/x-ndjson")
  882. try:
  883. completion = await asyncio.to_thread(
  884. client.chat.completions.create,
  885. model=payload.model,
  886. messages=to_send,
  887. stream=False,
  888. )
  889. except Exception as exc: # pragma: no cover - 网络调用
  890. raise HTTPException(status_code=500, detail=str(exc)) from exc
  891. choice = completion.choices[0] if getattr(completion, "choices", None) else None # type: ignore[attr-defined]
  892. if not choice:
  893. raise HTTPException(status_code=500, detail="响应格式不正确")
  894. assistant_content = getattr(choice.message, "content", "")
  895. if not assistant_content:
  896. assistant_content = ""
  897. await save_assistant_message(payload.session_id, current_user.id, messages, assistant_content)
  898. return {"message": assistant_content}
  899. if __name__ == "__main__":
  900. import uvicorn
  901. uvicorn.run("fastchat:app", host="0.0.0.0", port=16016, reload=True)