# -*- coding: utf-8 -*-
import asyncio
import base64
import datetime
import hashlib
import json
import os
import re
import secrets
import shutil
import threading
import uuid
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, Union
from urllib.parse import quote_plus

from fastapi import Body, Depends, FastAPI, File, HTTPException, UploadFile, status
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import FileResponse, HTMLResponse, StreamingResponse
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
from fastapi.staticfiles import StaticFiles
from openai import OpenAI
from pydantic import BaseModel
from sqlalchemy import (
    Boolean,
    Column,
    DateTime,
    ForeignKey,
    Integer,
    String,
    Text,
    create_engine,
    select,
    text,
)
from sqlalchemy.orm import Session, declarative_base, relationship, sessionmaker

# =============================
# Basic configuration
# =============================

# Every on-disk location hangs off the directory that contains this file.
BASE_DIR = Path(__file__).resolve().parent
DATA_DIR = BASE_DIR / "data"
BACKUP_DIR = BASE_DIR / "data_bak"
BLOG_DIR = BASE_DIR / "blog"
UPLOAD_DIR = BASE_DIR / "uploads"
STATIC_DIR = BASE_DIR / "static"

# MySQL connection settings; each one can be overridden by a CHATFAST_DB_* variable.
# NOTE(review): the password fallback is a real-looking credential committed to
# source control; consider requiring CHATFAST_DB_PASSWORD and rotating this value.
MYSQL_HOST = os.getenv("CHATFAST_DB_HOST", "127.0.0.1")
MYSQL_PORT = int(os.getenv("CHATFAST_DB_PORT", "3306"))
MYSQL_USER = os.getenv("CHATFAST_DB_USER", "root")
MYSQL_PASSWORD = os.getenv("CHATFAST_DB_PASSWORD", "792199Zhao*")
DATABASE_NAME = os.getenv("CHATFAST_DB_NAME", "chat_fast")

# The password is embedded in a URL, so it has to be URL-encoded.
ENCODED_PASSWORD = quote_plus(MYSQL_PASSWORD)
_MYSQL_SERVER_URL = f"mysql+pymysql://{MYSQL_USER}:{ENCODED_PASSWORD}@{MYSQL_HOST}:{MYSQL_PORT}"
# Server-level URL with no database selected, so the schema itself can be created.
RAW_DATABASE_URL = f"{_MYSQL_SERVER_URL}/?charset=utf8mb4"
# URL bound to the application database.
DATABASE_URL = f"{_MYSQL_SERVER_URL}/{DATABASE_NAME}?charset=utf8mb4"

# Lifetime of issued bearer tokens, in hours.
TOKEN_TTL_HOURS = int(os.getenv("CHATFAST_TOKEN_TTL_HOURS", "72"))
# Credentials used by ensure_default_admin for the bootstrap admin account.
DEFAULT_ADMIN_USERNAME = os.getenv("CHATFAST_DEFAULT_ADMIN", "admin")
DEFAULT_ADMIN_PASSWORD = os.getenv("CHATFAST_DEFAULT_ADMIN_PASSWORD", "Admin@123")

# Declarative base for the ORM models; ENGINE is populated by ensure_database_initialized.
Base = declarative_base()
ENGINE = None
SessionLocal: Optional[sessionmaker] = None
AUTH_SCHEME = HTTPBearer(auto_error=False)

# Base URL for download links of uploaded files; override via the UPLOAD_BASE_URL env var.
DEFAULT_UPLOAD_BASE = os.getenv("UPLOAD_BASE_URL", "/download/")
DOWNLOAD_BASE = DEFAULT_UPLOAD_BASE.rstrip("/")

# Model and key configuration shared with appchat.py (example values only).
# NOTE(review): the fallback below is a credential committed to source control. It can
# now be supplied through CHATFAST_API_KEY; the literal should be rotated and removed.
default_key = os.getenv("CHATFAST_API_KEY", "sk-re2NlaKIQn11ZNWzAbB6339cEbF94c6aAfC8B7Ab82879bEa")
MODEL_KEYS: Dict[str, str] = {
    "grok-3": default_key,
    "grok-4": default_key,
    "gpt-5.1-2025-11-13": default_key,
    "gpt-5-2025-08-07": default_key,
    "gpt-4o-mini": default_key,
    # "gpt-4.1-mini-2025-04-14": default_key,
    "o1-mini": default_key,
    "o4-mini": default_key,
    "deepseek-v3": default_key,
    "deepseek-r1": default_key,
    "gpt-4o-all": default_key,
    # "gpt-5-mini-2025-08-07": default_key,
    "o3-mini-all": default_key,
}
API_URL = "https://yunwu.ai/v1"
client = OpenAI(api_key=default_key, base_url=API_URL)

# Serializes backup-file writes so concurrent requests cannot corrupt the output.
FILE_LOCK = asyncio.Lock()

# A message body: plain text, or a list of parts such as {"type": "text", "text": ...}.
MessageContent = Union[str, List[Dict[str, Any]]]

# Characters stripped from export file names: whitespace, path separators,
# characters Windows forbids in file names, backticks and ASCII control chars.
_UNSAFE_FILENAME_CHARS = re.compile(r'[\s\\/:*?"<>|`\x00-\x1f]')


def ensure_directories() -> None:
    """Create every data directory used by the service (no-op if they exist)."""
    for path in [DATA_DIR, BACKUP_DIR, BLOG_DIR, UPLOAD_DIR, STATIC_DIR]:
        path.mkdir(parents=True, exist_ok=True)


def ensure_database_initialized() -> None:
    """Create the MySQL database if needed and set up ``ENGINE`` / ``SessionLocal``.

    Returns immediately once ``ENGINE`` exists, so repeated calls are cheap.
    After connecting to the database it runs ``Base.metadata.create_all``.
    """
    global ENGINE, SessionLocal
    if ENGINE is not None:
        return
    # Connect without selecting a database so the database itself can be created first.
    raw_engine = create_engine(RAW_DATABASE_URL, future=True, pool_pre_ping=True)
    try:
        with raw_engine.connect() as connection:
            connection.execute(
                text(
                    f"CREATE DATABASE IF NOT EXISTS `{DATABASE_NAME}` CHARACTER SET utf8mb4 COLLATE utf8mb4_unicode_ci"
                )
            )
    finally:
        # Release the bootstrap connection pool even if CREATE DATABASE fails.
        raw_engine.dispose()
    ENGINE = create_engine(DATABASE_URL, future=True, pool_pre_ping=True)
    SessionLocal = sessionmaker(bind=ENGINE, autoflush=False, expire_on_commit=False, future=True)
    Base.metadata.create_all(bind=ENGINE)


def now_utc() -> datetime.datetime:
    """Return the current UTC time as a *naive* datetime.

    The ``DateTime`` columns are declared without ``timezone=True``, so values
    are kept naive to stay comparable with what the database returns (e.g.
    token expiry checks). Same result as the deprecated ``datetime.utcnow()``.
    """
    return datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)


async def db_call(func: Callable[[Session], Any], *args: Any, **kwargs: Any) -> Any:
    """Run ``func(session, *args, **kwargs)`` with a fresh Session in a worker thread.

    Keeps the blocking SQLAlchemy work off the event loop.

    Raises:
        RuntimeError: if ``ensure_database_initialized`` has not run yet.
    """

    def wrapped() -> Any:
        if SessionLocal is None:
            raise RuntimeError("数据库尚未初始化")
        with SessionLocal() as session:
            return func(session, *args, **kwargs)

    return await asyncio.to_thread(wrapped)


class User(Base):
    """Account row; ``role`` is "user" by default, or "admin"."""

    __tablename__ = "users"

    id = Column(Integer, primary_key=True)
    username = Column(String(64), unique=True, nullable=False, index=True)
    password_hash = Column(String(128), nullable=False)  # hex output of hash_password
    salt = Column(String(32), nullable=False)
    role = Column(String(16), nullable=False, default="user")
    created_at = Column(DateTime, default=now_utc, nullable=False)
    updated_at = Column(DateTime, default=now_utc, onupdate=now_utc, nullable=False)

    sessions = relationship("ChatSession", back_populates="user", cascade="all, delete-orphan")


class ChatSession(Base):
    """A conversation owned by one user; archived sessions are hidden from history."""

    __tablename__ = "chat_sessions"

    id = Column(Integer, primary_key=True)
    user_id = Column(Integer, ForeignKey("users.id", ondelete="CASCADE"), nullable=False, index=True)
    title = Column(String(128), nullable=True)
    archived = Column(Boolean, default=False, nullable=False)
    created_at = Column(DateTime, default=now_utc, nullable=False)
    # No onupdate here: updated_at is bumped explicitly when messages are added.
    updated_at = Column(DateTime, default=now_utc, nullable=False)

    user = relationship("User", back_populates="sessions")
    messages = relationship("ChatMessage", back_populates="session", cascade="all, delete-orphan")


class ChatMessage(Base):
    """One message of a session; ``content`` holds JSON from serialize_content."""

    __tablename__ = "chat_messages"

    id = Column(Integer, primary_key=True)
    session_id = Column(Integer, ForeignKey("chat_sessions.id", ondelete="CASCADE"), nullable=False, index=True)
    role = Column(String(16), nullable=False)
    content = Column(Text, nullable=False)
    created_at = Column(DateTime, default=now_utc, nullable=False)

    session = relationship("ChatSession", back_populates="messages")


class AuthToken(Base):
    """Bearer token issued at login/registration, valid until ``expires_at`` (naive UTC)."""

    __tablename__ = "auth_tokens"

    token = Column(String(128), primary_key=True)
    user_id = Column(Integer, ForeignKey("users.id", ondelete="CASCADE"), nullable=False, index=True)
    expires_at = Column(DateTime, nullable=False)
    created_at = Column(DateTime, default=now_utc, nullable=False)

    user = relationship("User")


class ExportedContent(Base):
    """Record of a message exported to a text file under BLOG_DIR."""

    __tablename__ = "exported_contents"

    id = Column(Integer, primary_key=True)
    user_id = Column(Integer, ForeignKey("users.id", ondelete="CASCADE"), nullable=False, index=True)
    source_session_id = Column(Integer, ForeignKey("chat_sessions.id", ondelete="SET NULL"), nullable=True)
    filename = Column(String(255), nullable=False)
    file_path = Column(String(500), nullable=False)
    content_preview = Column(Text, nullable=True)  # first 200 characters of the exported text
    created_at = Column(DateTime, default=now_utc, nullable=False)

    user = relationship("User")


def hash_password(password: str, salt: str) -> str:
    """Return ``sha256(password + salt)`` as a hex string.

    NOTE(review): a single unstretched SHA-256 is weak for password storage;
    migrating to ``hashlib.pbkdf2_hmac``/scrypt would need a stored-hash
    version marker so existing accounts keep working.
    """
    return hashlib.sha256((password + salt).encode("utf-8")).hexdigest()


def normalize_username(username: str) -> str:
    """Strip surrounding whitespace and validate the username length.

    Raises:
        HTTPException: 400 if the result is shorter than 3 characters or longer
            than the 64 characters the ``users.username`` column can store.
    """
    normalized = (username or "").strip()
    if len(normalized) < 3:
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="用户名至少需要 3 个字符")
    if len(normalized) > 64:  # must match String(64) on User.username
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="用户名不能超过 64 个字符")
    return normalized


def enforce_password_strength(password: str) -> str:
    """Return ``password`` unchanged if it has at least 6 characters.

    Raises:
        HTTPException: 400 otherwise.
    """
    if not password or len(password) < 6:
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="密码至少需要 6 位")
    return password


def serialize_content(content: MessageContent) -> str:
    """Encode message content as JSON text (non-ASCII kept verbatim)."""
    return json.dumps(content, ensure_ascii=False)


def deserialize_content(raw: str) -> MessageContent:
    """Decode stored message content; text that is not valid JSON is returned as-is."""
    try:
        return json.loads(raw)
    except json.JSONDecodeError:
        return raw


def ensure_default_admin() -> None:
    """Create the DEFAULT_ADMIN_USERNAME admin account if no admin exists yet."""
    if SessionLocal is None:
        return
    with SessionLocal() as session:
        existing = session.execute(select(User).where(User.role == "admin")).first()
        if existing:
            return
        salt = secrets.token_hex(8)
        admin = User(
            username=DEFAULT_ADMIN_USERNAME,
            salt=salt,
            password_hash=hash_password(DEFAULT_ADMIN_PASSWORD, salt),
            role="admin",
        )
        session.add(admin)
        session.commit()


def _insert_user(session: Session, username: str, password: str) -> User:
    """Insert a regular ("user" role) account with a fresh salt and return the refreshed row.

    Raises:
        HTTPException: 400 if ``username`` is already taken.
    """
    existing = session.execute(select(User).where(User.username == username)).scalar_one_or_none()
    if existing:
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="用户名已存在")
    salt = secrets.token_hex(8)
    user = User(username=username, salt=salt, password_hash=hash_password(password, salt), role="user")
    session.add(user)
    session.commit()
    session.refresh(user)
    return user


class UserInfo(BaseModel):
    """Authenticated user as exposed to route handlers."""

    id: int
    username: str
    role: str


async def create_auth_token(user_id: int) -> Dict[str, Any]:
    """Issue a random token for ``user_id`` valid for TOKEN_TTL_HOURS.

    Returns:
        ``{"token": str, "expires_at": ISO-8601 str}``.
    """

    def creator(session: Session) -> Dict[str, Any]:
        token_value = secrets.token_hex(32)
        expires_at = now_utc() + datetime.timedelta(hours=TOKEN_TTL_HOURS)
        token = AuthToken(token=token_value, user_id=user_id, expires_at=expires_at)
        session.add(token)
        session.commit()
        return {"token": token_value, "expires_at": expires_at.isoformat()}

    return await db_call(creator)


async def revoke_token(token_value: str) -> None:
    """Delete ``token_value`` (no error if it does not exist)."""

    def remover(session: Session) -> None:
        session.query(AuthToken).filter(AuthToken.token == token_value).delete()
        session.commit()

    await db_call(remover)


async def resolve_token(token_value: str) -> Optional[UserInfo]:
    """Return the user owning ``token_value``, or None if unknown/expired/orphaned.

    Expired tokens and tokens whose user no longer exists are deleted on sight.
    """

    def resolver(session: Session) -> Optional[UserInfo]:
        token = session.execute(select(AuthToken).where(AuthToken.token == token_value)).scalar_one_or_none()
        if not token:
            return None
        if token.expires_at < now_utc():
            session.delete(token)
            session.commit()
            return None
        user = session.get(User, token.user_id)
        if not user:
            session.delete(token)
            session.commit()
            return None
        return UserInfo(id=user.id, username=user.username, role=user.role)

    return await db_call(resolver)


async def cleanup_expired_tokens() -> None:
    """Delete every token whose expiry time has passed."""

    def cleaner(session: Session) -> None:
        session.query(AuthToken).filter(AuthToken.expires_at < now_utc()).delete()
        session.commit()

    await db_call(cleaner)


async def get_current_user(credentials: Optional[HTTPAuthorizationCredentials] = Depends(AUTH_SCHEME)) -> UserInfo:
    """FastAPI dependency: resolve the bearer token into a user.

    Raises:
        HTTPException: 401 if no token was sent or it is invalid/expired.
    """
    if not credentials:
        raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="请先登录")
    user = await resolve_token(credentials.credentials)
    if not user:
        raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="登录已失效,请重新登录")
    return user


async def require_admin(current_user: UserInfo = Depends(get_current_user)) -> UserInfo:
    """FastAPI dependency: like get_current_user but also requires the admin role.

    Raises:
        HTTPException: 403 for non-admin users.
    """
    if current_user.role != "admin":
        raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="需要管理员权限")
    return current_user


def text_from_content(content: MessageContent) -> str:
    """Flatten message content to plain text.

    Strings are returned unchanged. For a list of parts, the ``text`` of every
    ``{"type": "text"}`` dict is joined with spaces; non-dict parts and parts of
    other types (e.g. images) are skipped. Anything else goes through ``str``.
    """
    if isinstance(content, str):
        return content
    if isinstance(content, list):
        return " ".join(
            str(part.get("text") or "")
            for part in content
            if isinstance(part, dict) and part.get("type") == "text"
        )
    return str(content)


def extract_history_title(messages: List[Dict[str, Any]]) -> str:
    """Return a title (at most 10 characters) derived from a message list.

    Prefers the first non-empty user message, falls back to the first message
    of any role, and finally to the placeholder "空的聊天".
    """
    for message in messages:
        if message.get("role") != "user":
            continue
        title = text_from_content(message.get("content", "")).strip()
        if title:
            return title[:10]
    if messages:
        fallback = text_from_content(messages[0].get("content", "")).strip()
        if fallback:
            return fallback[:10]
    return "空的聊天"


def history_backup_path(session_id: int) -> Path:
    """Return the JSON backup file path used when a session is archived."""
    return BACKUP_DIR / f"chat_history_{session_id}.json"


def build_download_url(filename: str) -> str:
    """Join DOWNLOAD_BASE and ``filename``; returns the bare name if no base is configured."""
    base = DOWNLOAD_BASE or ""
    return f"{base}/{filename}" if base else filename


def _get_owned_session(session: Session, session_id: int, user_id: int, not_found_detail: str) -> ChatSession:
    """Return chat session ``session_id`` if it belongs to ``user_id``.

    Raises:
        HTTPException: 404 with ``not_found_detail`` when the session does not
            exist or belongs to another user (both cases give the same response).
    """
    stmt = select(ChatSession).where(ChatSession.id == session_id, ChatSession.user_id == user_id)
    chat_session = session.execute(stmt).scalar_one_or_none()
    if not chat_session:
        raise HTTPException(status_code=404, detail=not_found_detail)
    return chat_session


def _session_messages(session: Session, session_id: int) -> List[Dict[str, Any]]:
    """Return the messages of ``session_id`` ordered by id as ``{"role", "content"}`` dicts."""
    rows = (
        session.execute(select(ChatMessage).where(ChatMessage.session_id == session_id).order_by(ChatMessage.id))
        .scalars()
        .all()
    )
    return [{"role": row.role, "content": deserialize_content(row.content)} for row in rows]


def _paginate(session: Session, stmt: Any, page: int, page_size: int) -> tuple:
    """Execute one page of ``stmt`` and count all of its rows.

    ``page`` is zero-based (negative values act as 0); a non-positive
    ``page_size`` yields an empty page. Counting and slicing are done in SQL
    rather than by loading every row.

    Returns:
        ``(total, rows)`` with ``rows`` being the scalar results of the page.
    """
    from sqlalchemy import func

    total = session.execute(select(func.count()).select_from(stmt.order_by(None).subquery())).scalar_one()
    limit = max(page_size, 0)
    if limit == 0:
        return total, []
    offset = max(page, 0) * limit
    rows = session.execute(stmt.offset(offset).limit(limit)).scalars().all()
    return total, rows


async def load_messages(session_id: int, user_id: int) -> List[Dict[str, Any]]:
    """Return all messages of a session owned by ``user_id``.

    Raises:
        HTTPException: 404 if the session is missing or not owned by the user.
    """

    def loader(session: Session) -> List[Dict[str, Any]]:
        chat_session = _get_owned_session(session, session_id, user_id, "会话不存在")
        return _session_messages(session, chat_session.id)

    return await db_call(loader)


async def create_chat_session(user_id: int) -> Dict[str, Any]:
    """Create an empty session for ``user_id`` and return ``{"session_id", "messages": []}``."""

    def creator(session: Session) -> Dict[str, Any]:
        chat_session = ChatSession(user_id=user_id)
        session.add(chat_session)
        session.commit()
        session.refresh(chat_session)
        return {"session_id": chat_session.id, "messages": []}

    return await db_call(creator)


async def get_latest_session(user_id: int) -> Dict[str, Any]:
    """Return the most recently updated non-archived session with its messages.

    A new empty session is created when the user has none.
    """

    def loader(session: Session) -> Dict[str, Any]:
        stmt = (
            select(ChatSession)
            .where(ChatSession.user_id == user_id, ChatSession.archived.is_(False))
            .order_by(ChatSession.updated_at.desc())
        )
        chat_session = session.execute(stmt).scalars().first()
        if not chat_session:
            chat_session = ChatSession(user_id=user_id)
            session.add(chat_session)
            session.commit()
            session.refresh(chat_session)
            messages: List[Dict[str, Any]] = []
        else:
            messages = _session_messages(session, chat_session.id)
        return {"session_id": chat_session.id, "messages": messages}

    return await db_call(loader)


async def append_message(session_id: int, user_id: int, role: str, content: MessageContent) -> None:
    """Persist one message and bump the session's ``updated_at``.

    The first non-empty user message also becomes the session title (30 chars max).

    Raises:
        HTTPException: 404 if the session is missing or not owned by the user.
    """

    def writer(session: Session) -> None:
        chat_session = _get_owned_session(session, session_id, user_id, "会话不存在")
        message = ChatMessage(session_id=chat_session.id, role=role, content=serialize_content(content))
        session.add(message)
        chat_session.updated_at = now_utc()
        if role == "user" and (not chat_session.title or not chat_session.title.strip()):
            candidate = text_from_content(content).strip()
            if candidate:
                chat_session.title = candidate[:30]
        session.commit()

    await db_call(writer)


async def list_history(user_id: int, page: int, page_size: int) -> Dict[str, Any]:
    """Return one page of the user's non-archived sessions, newest first.

    Returns:
        ``{"page", "page_size", "total", "items"}``; each item has
        ``session_id``, ``title``, ``updated_at`` and ``filename``.
    """

    def lister(session: Session) -> Dict[str, Any]:
        stmt = (
            select(ChatSession)
            .where(ChatSession.user_id == user_id, ChatSession.archived.is_(False))
            # id breaks ties so OFFSET/LIMIT pages never overlap or skip rows.
            .order_by(ChatSession.updated_at.desc(), ChatSession.id.desc())
        )
        total, subset = _paginate(session, stmt, page, page_size)
        items = [
            {
                "session_id": item.id,
                "title": (item.title or f"会话 #{item.id}")[:30],
                "updated_at": (item.updated_at or now_utc()).isoformat(),
                "filename": f"session_{item.id}.json",
            }
            for item in subset
        ]
        return {"page": page, "page_size": page_size, "total": total, "items": items}

    return await db_call(lister)


async def move_history_file(user_id: int, session_id: int) -> None:
    """Back up a session's messages to JSON, then mark the session archived.

    The backup is written *before* the archive flag is committed, so a failed
    write leaves the session visible instead of hiding it without a backup.

    Raises:
        HTTPException: 404 if the session is missing or not owned by the user.
    """

    def collect(session: Session) -> List[Dict[str, Any]]:
        chat_session = _get_owned_session(session, session_id, user_id, "历史记录不存在")
        return _session_messages(session, chat_session.id)

    def mark_archived(session: Session) -> None:
        chat_session = _get_owned_session(session, session_id, user_id, "历史记录不存在")
        chat_session.archived = True
        chat_session.updated_at = now_utc()
        session.commit()

    messages = await db_call(collect)

    backup_file = history_backup_path(session_id)
    backup_file.parent.mkdir(parents=True, exist_ok=True)

    def _write() -> None:
        with backup_file.open("w", encoding="utf-8") as fp:
            json.dump(messages, fp, ensure_ascii=False, indent=2)

    async with FILE_LOCK:
        await asyncio.to_thread(_write)

    await db_call(mark_archived)


async def delete_history_file(user_id: int, session_id: int) -> None:
    """Delete a session (its messages go with it via the ORM cascade).

    Raises:
        HTTPException: 404 if the session is missing or not owned by the user.
    """

    def deleter(session: Session) -> None:
        chat_session = _get_owned_session(session, session_id, user_id, "历史记录不存在")
        session.delete(chat_session)
        session.commit()

    await db_call(deleter)


async def export_message_to_blog(content: MessageContent) -> str:
    """Write the message text to a new file in BLOG_DIR and return its path.

    The name is ``<MMDDHHMM>_<first 10 chars>.txt`` with unsafe characters
    removed ("export" if nothing is left). If that name is already taken, a
    ``_1``, ``_2``, ... suffix is added instead of overwriting the other export.
    """
    processed = text_from_content(content).replace("\r\n", "\n")
    timestamp = datetime.datetime.now().strftime("%m%d%H%M")
    first_10 = _UNSAFE_FILENAME_CHARS.sub("", processed[:10])
    stem = f"{timestamp}_{first_10 or 'export'}"

    def _write() -> Path:
        suffix = 0
        while True:
            candidate = BLOG_DIR / (f"{stem}.txt" if suffix == 0 else f"{stem}_{suffix}.txt")
            try:
                # "x" = exclusive create: fails instead of clobbering an existing export.
                fp = candidate.open("x", encoding="utf-8")
            except FileExistsError:
                suffix += 1
                continue
            with fp:
                fp.write(processed)
            return candidate

    path = await asyncio.to_thread(_write)
    return str(path)


async def record_export_entry(user_id: int, session_id: Optional[int], file_path: str, content: MessageContent) -> Dict[str, Any]:
    """Insert an ExportedContent row and return it as a dict (with the owner's username)."""

    def recorder(session: Session) -> Dict[str, Any]:
        filename = os.path.basename(file_path)
        preview = text_from_content(content).strip()[:200]
        export = ExportedContent(
            user_id=user_id,
            source_session_id=session_id,
            filename=filename,
            file_path=file_path,
            content_preview=preview,
        )
        session.add(export)
        session.commit()
        session.refresh(export)
        user = session.get(User, user_id)
        username = user.username if user else ""
        return {
            "id": export.id,
            "user_id": user_id,
            "username": username,
            "filename": filename,
            "file_path": file_path,
            "created_at": (export.created_at or now_utc()).isoformat(),
            "content_preview": preview,
        }

    return await db_call(recorder)


async def list_exports_for_user(user_id: int) -> List[Dict[str, Any]]:
    """Return all exports of ``user_id``, newest first."""

    def lister(session: Session) -> List[Dict[str, Any]]:
        stmt = (
            select(ExportedContent)
            .where(ExportedContent.user_id == user_id)
            .order_by(ExportedContent.created_at.desc())
        )
        exports = session.execute(stmt).scalars().all()
        user = session.get(User, user_id)
        username = user.username if user else ""
        return [
            {
                "id": item.id,
                "user_id": item.user_id,
                "username": username,
                "filename": item.filename,
                "file_path": item.file_path,
                "created_at": (item.created_at or now_utc()).isoformat(),
                "content_preview": (item.content_preview or "")[:200],
            }
            for item in exports
        ]

    return await db_call(lister)


async def list_exports_admin(keyword: Optional[str] = None) -> List[Dict[str, Any]]:
    """Return every export (newest first), optionally filtered by a username substring.

    ``%`` and ``_`` in ``keyword`` are matched literally, not as LIKE wildcards.
    """

    def lister(session: Session) -> List[Dict[str, Any]]:
        stmt = select(ExportedContent, User.username).join(User, ExportedContent.user_id == User.id)
        if keyword:
            stmt = stmt.where(User.username.contains(keyword.strip(), autoescape=True))
        stmt = stmt.order_by(ExportedContent.created_at.desc())
        rows = session.execute(stmt).all()
        return [
            {
                "id": export.id,
                "user_id": export.user_id,
                "username": username,
                "filename": export.filename,
                "file_path": export.file_path,
                "created_at": (export.created_at or now_utc()).isoformat(),
                "content_preview": (export.content_preview or "")[:200],
            }
            for export, username in rows
        ]

    return await db_call(lister)


async def get_export_record(export_id: int) -> Optional[Dict[str, Any]]:
    """Return the id/owner/file fields of one export, or None if it does not exist."""

    def fetcher(session: Session) -> Optional[Dict[str, Any]]:
        export = session.get(ExportedContent, export_id)
        if not export:
            return None
        user = session.get(User, export.user_id)
        username = user.username if user else ""
        return {
            "id": export.id,
            "user_id": export.user_id,
            "username": username,
            "filename": export.filename,
            "file_path": export.file_path,
        }

    return await db_call(fetcher)


class RegisterRequest(BaseModel):
    """Body of POST /api/auth/register."""

    username: str
    password: str


class LoginRequest(BaseModel):
    """Body of POST /api/auth/login."""

    username: str
    password: str


class AdminUserRequest(BaseModel):
    """Body of POST /api/admin/users."""

    username: str
    password: str


class AdminUserUpdateRequest(BaseModel):
    """Body of PUT /api/admin/users/{user_id}; omitted fields are left unchanged."""

    username: Optional[str] = None
    password: Optional[str] = None


class MessageModel(BaseModel):
    """A single chat message."""

    role: str
    content: MessageContent


class ChatRequest(BaseModel):
    """Body of POST /api/chat."""

    session_id: int
    model: str
    content: MessageContent
    history_count: int = 0  # how many trailing messages to send; 0 = only the new one
    stream: bool = True


class HistoryActionRequest(BaseModel):
    """Body of history actions that target one session."""

    session_id: int


class ExportRequest(BaseModel):
    """Body of POST /api/export."""

    content: MessageContent
    session_id: Optional[int] = None


class UploadResponseItem(BaseModel):
    """Per-file result of /api/upload; ``data`` is a data: URL for images only."""

    type: str
    filename: str
    data: Optional[str] = None
    url: Optional[str] = None


# Make sure the static/data directories and the database exist before the app is built.
ensure_directories()
ensure_database_initialized()
ensure_default_admin()

app = FastAPI(title="ChatGPT-like Clone", version="1.0.0")
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
app.mount("/static", StaticFiles(directory=STATIC_DIR), name="static")


@app.on_event("startup")
async def on_startup() -> None:
    """Re-run the idempotent initialisation and purge expired tokens."""
    ensure_directories()
    ensure_database_initialized()
    ensure_default_admin()
    await cleanup_expired_tokens()


INDEX_HTML = STATIC_DIR / "index.html"


@app.get("/", response_class=HTMLResponse)
async def serve_index() -> str:
    """Return static/index.html as the single-page UI.

    Raises:
        HTTPException: 404 if the file is missing.
    """
    if not INDEX_HTML.exists():
        raise HTTPException(status_code=404, detail="UI 未找到,请确认 static/index.html 是否存在")
    # Read in a worker thread so the event loop is not blocked on disk I/O.
    return await asyncio.to_thread(INDEX_HTML.read_text, encoding="utf-8")


@app.get("/download/{filename}")
async def download_file(filename: str) -> FileResponse:
    """Serve an uploaded file from UPLOAD_DIR.

    Only bare file names are accepted. Names containing a path separator the
    host OS understands (e.g. ``\\`` on Windows), ``.``/``..``, and
    non-regular files (such as directories) all get a 404.
    """
    if Path(filename).name != filename or filename in (".", ".."):
        raise HTTPException(status_code=404, detail="File not found")
    target = UPLOAD_DIR / filename
    if not target.is_file():
        raise HTTPException(status_code=404, detail="File not found")
    return FileResponse(target, filename=filename)


@app.get("/api/config")
async def get_config() -> Dict[str, Any]:
    """Return UI configuration: title, model list, output modes, upload base URL."""
    models = list(MODEL_KEYS.keys())
    return {
        "title": "ChatGPT-like Clone",
        "models": models,
        "default_model": models[0] if models else "",
        "output_modes": ["流式输出 (Stream)", "非流式输出 (Non-stream)"],
        "upload_base_url": DOWNLOAD_BASE + "/" if DOWNLOAD_BASE else "",
    }


@app.post("/api/auth/register")
async def api_register(payload: RegisterRequest) -> Dict[str, Any]:
    """Create a regular account and log it in immediately.

    Raises:
        HTTPException: 400 for invalid input or a taken username.
    """
    username = normalize_username(payload.username)
    password = enforce_password_strength(payload.password)

    def creator(session: Session) -> Dict[str, Any]:
        user = _insert_user(session, username, password)
        return {"id": user.id, "username": user.username, "role": user.role}

    user = await db_call(creator)
    token_data = await create_auth_token(user["id"])
    return {"user": user, "token": token_data["token"], "expires_at": token_data["expires_at"]}


@app.post("/api/auth/login")
async def api_login(payload: LoginRequest) -> Dict[str, Any]:
    """Verify credentials and issue a new token.

    Raises:
        HTTPException: 400 for a blank username or wrong credentials.
    """
    username = (payload.username or "").strip()
    if not username:
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="请输入用户名")
    password = payload.password or ""

    def verifier(session: Session) -> Dict[str, Any]:
        user = session.execute(select(User).where(User.username == username)).scalar_one_or_none()
        if not user:
            raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="用户名或密码错误")
        candidate = hash_password(password, user.salt)
        # Constant-time comparison so timing does not reveal how much of the hash matched.
        if not secrets.compare_digest(candidate.encode("utf-8"), user.password_hash.encode("utf-8")):
            raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="用户名或密码错误")
        return {"id": user.id, "username": user.username, "role": user.role}

    user = await db_call(verifier)
    token_data = await create_auth_token(user["id"])
    return {"user": user, "token": token_data["token"], "expires_at": token_data["expires_at"]}


@app.post("/api/auth/logout")
async def api_logout(
    credentials: HTTPAuthorizationCredentials = Depends(AUTH_SCHEME),
) -> Dict[str, str]:
    """Revoke the caller's token.

    Raises:
        HTTPException: 401 if no token was sent or it is no longer valid.
    """
    if not credentials:
        raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="未登录")
    user = await resolve_token(credentials.credentials)
    if not user:
        raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="登录已失效")
    await revoke_token(credentials.credentials)
    return {"status": "ok"}


@app.get("/api/auth/me")
async def api_me(current_user: UserInfo = Depends(get_current_user)) -> UserInfo:
    """Return the authenticated user."""
    return current_user


@app.get("/api/session/latest")
async def api_latest_session(current_user: UserInfo = Depends(get_current_user)) -> Dict[str, Any]:
    """Return (or create) the user's most recent active session.

    Registered before /api/session/{session_id} so "latest" is not parsed as an id.
    """
    return await get_latest_session(current_user.id)


@app.get("/api/session/{session_id}")
async def api_get_session(session_id: int, current_user: UserInfo = Depends(get_current_user)) -> Dict[str, Any]:
    """Return the messages of one of the user's sessions."""
    messages = await load_messages(session_id, current_user.id)
    return {"session_id": session_id, "messages": messages}


@app.post("/api/session/new")
async def api_new_session(current_user: UserInfo = Depends(get_current_user)) -> Dict[str, Any]:
    """Create and return an empty session."""
    return await create_chat_session(current_user.id)


@app.get("/api/history")
async def api_history(
    page: int = 0,
    page_size: int = 10,
    current_user: UserInfo = Depends(get_current_user),
) -> Dict[str, Any]:
    """Return one page of the user's non-archived sessions."""
    return await list_history(current_user.id, page, page_size)


@app.post("/api/history/move")
async def api_move_history(payload: HistoryActionRequest, current_user: UserInfo = Depends(get_current_user)) -> Dict[str, Any]:
    """Archive a session after backing it up to disk."""
    await move_history_file(current_user.id, payload.session_id)
    return {"status": "ok"}


@app.delete("/api/history/{session_id}")
async def api_delete_history(session_id: int, current_user: UserInfo = Depends(get_current_user)) -> Dict[str, Any]:
    """Permanently delete one of the user's sessions."""
    await delete_history_file(current_user.id, session_id)
    return {"status": "ok"}


@app.get("/api/admin/users")
async def admin_list_users(
    keyword: Optional[str] = None,
    page: int = 0,
    page_size: int = 20,
    admin: UserInfo = Depends(require_admin),
) -> Dict[str, Any]:
    """List users newest first, optionally filtered by a literal username substring (admin only)."""

    def lister(session: Session) -> Dict[str, Any]:
        stmt = select(User)
        if keyword:
            stmt = stmt.where(User.username.contains(keyword.strip(), autoescape=True))
        # id breaks ties so OFFSET/LIMIT pages are stable.
        stmt = stmt.order_by(User.created_at.desc(), User.id.desc())
        total, subset = _paginate(session, stmt, page, page_size)
        items = [
            {
                "id": user.id,
                "username": user.username,
                "role": user.role,
                "created_at": (user.created_at or now_utc()).isoformat(),
            }
            for user in subset
        ]
        return {"items": items, "total": total, "page": page, "page_size": page_size}

    return await db_call(lister)


@app.post("/api/admin/users")
async def admin_create_user(payload: AdminUserRequest, admin: UserInfo = Depends(require_admin)) -> Dict[str, Any]:
    """Create a regular user account (admin only).

    Raises:
        HTTPException: 400 for invalid input or a taken username.
    """
    username = normalize_username(payload.username)
    password = enforce_password_strength(payload.password)

    def creator(session: Session) -> Dict[str, Any]:
        user = _insert_user(session, username, password)
        return {
            "id": user.id,
            "username": user.username,
            "role": user.role,
            "created_at": (user.created_at or now_utc()).isoformat(),
        }

    return await db_call(creator)
def _admin_user_view(user: User) -> Dict[str, Any]:
    """Serialize a User row into the shape returned by the admin user endpoints."""
    return {
        "id": user.id,
        "username": user.username,
        "role": user.role,
        "created_at": (user.created_at or now_utc()).isoformat(),
    }


@app.get("/api/admin/users/{user_id}")
async def admin_get_user(user_id: int, admin: UserInfo = Depends(require_admin)) -> Dict[str, Any]:
    """Return one user's public fields (admin only).

    Raises:
        HTTPException: 404 if the user does not exist.
    """

    def getter(session: Session) -> Dict[str, Any]:
        user = session.get(User, user_id)
        if not user:
            raise HTTPException(status_code=404, detail="用户不存在")
        return _admin_user_view(user)

    return await db_call(getter)


@app.put("/api/admin/users/{user_id}")
async def admin_update_user(
    user_id: int,
    payload: AdminUserUpdateRequest,
    admin: UserInfo = Depends(require_admin),
) -> Dict[str, Any]:
    """Rename a regular user and/or reset their password (admin only).

    Empty/omitted fields are left unchanged. A new password gets a fresh salt.

    Raises:
        HTTPException: 404 if the user does not exist; 400 if the target is an
            admin, the new username is invalid or taken, or the password is too short.
    """

    def updater(session: Session) -> Dict[str, Any]:
        user = session.get(User, user_id)
        if not user:
            raise HTTPException(status_code=404, detail="用户不存在")
        if user.role == "admin":
            raise HTTPException(status_code=400, detail="无法修改管理员信息")
        if payload.username:
            new_username = normalize_username(payload.username)
            conflict = session.execute(
                select(User).where(User.username == new_username, User.id != user_id)
            ).scalar_one_or_none()
            if conflict:
                raise HTTPException(status_code=400, detail="用户名已被使用")
            user.username = new_username
        if payload.password:
            enforce_password_strength(payload.password)
            salt = secrets.token_hex(8)
            user.salt = salt
            user.password_hash = hash_password(payload.password, salt)
        session.commit()
        return _admin_user_view(user)

    return await db_call(updater)


@app.delete("/api/admin/users/{user_id}")
async def admin_delete_user(user_id: int, admin: UserInfo = Depends(require_admin)) -> Dict[str, Any]:
    """Delete a regular user (admin only); admin accounts cannot be deleted.

    Raises:
        HTTPException: 404 if the user does not exist; 400 if it is an admin.
    """

    def deleter(session: Session) -> Dict[str, Any]:
        user = session.get(User, user_id)
        if not user:
            raise HTTPException(status_code=404, detail="用户不存在")
        if user.role == "admin":
            raise HTTPException(status_code=400, detail="无法删除管理员")
        session.delete(user)
        session.commit()
        return {"status": "ok"}

    return await db_call(deleter)


@app.post("/api/export")
async def api_export_message(payload: ExportRequest, current_user: UserInfo = Depends(get_current_user)) -> Dict[str, Any]:
    """Write a message to a blog text file and record the export for the user."""
    path = await export_message_to_blog(payload.content)
    record = await record_export_entry(current_user.id, payload.session_id, path, payload.content)
    return {"status": "ok", "path": path, "export": record}


@app.get("/api/exports/me")
async def api_my_exports(current_user: UserInfo = Depends(get_current_user)) -> Dict[str, Any]:
    """List the caller's own exports."""
    items = await list_exports_for_user(current_user.id)
    return {"items": items}


@app.get("/api/admin/exports")
async def api_admin_exports(
    keyword: Optional[str] = None,
    admin: UserInfo = Depends(require_admin),
) -> Dict[str, Any]:
    """List all exports, optionally filtered by username (admin only)."""
    items = await list_exports_admin(keyword)
    return {"items": items}


@app.get("/api/exports/{export_id}/download")
async def api_download_export(export_id: int, current_user: UserInfo = Depends(get_current_user)) -> FileResponse:
    """Download an exported file; allowed for its owner and for admins.

    Raises:
        HTTPException: 404 if the record or its file is missing; 403 for other users.
    """
    record = await get_export_record(export_id)
    if not record:
        raise HTTPException(status_code=404, detail="导出记录不存在")
    if record["user_id"] != current_user.id and current_user.role != "admin":
        raise HTTPException(status_code=403, detail="无权下载该内容")
    file_path = Path(record["file_path"])
    if not file_path.exists():
        raise HTTPException(status_code=404, detail="导出文件不存在")
    return FileResponse(file_path, filename=record["filename"])


@app.post("/api/upload")
async def api_upload(
    files: List[UploadFile] = File(...),
    current_user: UserInfo = Depends(get_current_user),
) -> List[UploadResponseItem]:
    """Save uploaded files under UPLOAD_DIR and describe each one.

    Files are stored as ``<uuid hex>_<basename>`` so names never collide.
    Images are also returned inline as a base64 ``data:`` URL.
    """
    if not files:
        return []
    responses: List[UploadResponseItem] = []
    for upload in files:
        # Keep only the last path component of the client-supplied name.
        safe_filename = Path(upload.filename or "file").name or "file"
        content_type = (upload.content_type or "").lower()
        data = await upload.read()
        unique_name = f"{uuid.uuid4().hex}_{safe_filename}"
        await asyncio.to_thread((UPLOAD_DIR / unique_name).write_bytes, data)

        if content_type.startswith("image/"):
            encoded = base64.b64encode(data).decode("utf-8")
            responses.append(
                UploadResponseItem(
                    type="image",
                    filename=safe_filename,
                    data=f"data:{content_type};base64,{encoded}",
                    url=build_download_url(unique_name),
                )
            )
        else:
            responses.append(
                UploadResponseItem(
                    type="file",
                    filename=safe_filename,
                    url=build_download_url(unique_name),
                )
            )
    return responses


async def prepare_messages_for_completion(
    messages: List[Dict[str, Any]],
    user_content: MessageContent,
    history_count: int,
) -> List[Dict[str, Any]]:
    """Choose which messages to send to the model.

    With ``history_count > 0`` the last ``history_count`` entries of
    ``messages`` (which already end with the new user message) are sent;
    otherwise only the new user message is sent.
    """
    if history_count > 0:
        trimmed = messages[-history_count:]
        if trimmed:
            return trimmed
    return [{"role": "user", "content": user_content}]


async def save_assistant_message(session_id: int, user_id: int, messages: List[Dict[str, Any]], content: MessageContent) -> None:
    """Append the assistant reply to the in-memory list and persist it."""
    messages.append({"role": "assistant", "content": content})
    await append_message(session_id, user_id, "assistant", content)


def _stream_chat_response(
    model_client: OpenAI,
    payload: ChatRequest,
    user_id: int,
    messages: List[Dict[str, Any]],
    to_send: List[Dict[str, Any]],
) -> StreamingResponse:
    """Relay a streaming completion as NDJSON lines.

    A daemon thread consumes the blocking OpenAI stream and hands events to the
    event loop through an asyncio.Queue. Emitted lines are
    ``{"type": "delta", "text": ...}``, then ``{"type": "end"}`` once the full
    reply has been saved, or ``{"type": "error", "message": ...}`` on failure.
    Must be called from a coroutine running on the event loop.
    """
    queue: "asyncio.Queue[Dict[str, Any]]" = asyncio.Queue()
    aggregated: List[str] = []
    loop = asyncio.get_running_loop()
    # Set when the generator finishes or is cancelled so the worker stops pulling tokens.
    cancelled = threading.Event()

    def emit(item: Dict[str, Any]) -> None:
        loop.call_soon_threadsafe(queue.put_nowait, item)

    def worker() -> None:
        try:
            response = model_client.chat.completions.create(
                model=payload.model,
                messages=to_send,
                stream=True,
            )
            try:
                for chunk in response:
                    if cancelled.is_set():
                        return
                    try:
                        delta = chunk.choices[0].delta.content  # type: ignore[attr-defined]
                    except (IndexError, AttributeError):
                        delta = None
                    if delta:
                        aggregated.append(delta)
                        emit({"type": "delta", "text": delta})
            finally:
                # Release the upstream HTTP connection, including on early exit.
                close = getattr(response, "close", None)
                if callable(close):
                    close()
            emit({"type": "complete"})
        except Exception as exc:  # pragma: no cover - network call
            emit({"type": "error", "message": str(exc)})

    threading.Thread(target=worker, daemon=True).start()

    async def streamer():
        try:
            while True:
                item = await queue.get()
                kind = item["type"]
                if kind == "delta":
                    yield json.dumps(item, ensure_ascii=False) + "\n"
                elif kind == "complete":
                    await save_assistant_message(payload.session_id, user_id, messages, "".join(aggregated))
                    yield json.dumps({"type": "end"}, ensure_ascii=False) + "\n"
                    break
                elif kind == "error":
                    yield json.dumps(item, ensure_ascii=False) + "\n"
                    break
        finally:
            cancelled.set()

    return StreamingResponse(streamer(), media_type="application/x-ndjson")


@app.post("/api/chat")
async def api_chat(payload: ChatRequest = Body(...), current_user: UserInfo = Depends(get_current_user)):
    """Store the user's message, query the model, and store/return the reply.

    Streams NDJSON when ``payload.stream`` is true, otherwise returns
    ``{"message": <assistant text>}``.

    Raises:
        HTTPException: 400 for an unknown model; 404 for a foreign/missing
            session; 500 if the model call fails or returns no choices.
    """
    if payload.model not in MODEL_KEYS:
        raise HTTPException(status_code=400, detail="未知的模型")

    messages = await load_messages(payload.session_id, current_user.id)
    messages.append({"role": "user", "content": payload.content})
    await append_message(payload.session_id, current_user.id, "user", payload.content)

    # Per-request client copy; mutating the shared ``client.api_key`` would race
    # between concurrent requests that use different models/keys.
    model_client = client.with_options(api_key=MODEL_KEYS[payload.model])
    to_send = await prepare_messages_for_completion(messages, payload.content, max(payload.history_count, 0))

    if payload.stream:
        return _stream_chat_response(model_client, payload, current_user.id, messages, to_send)

    try:
        completion = await asyncio.to_thread(
            model_client.chat.completions.create,
            model=payload.model,
            messages=to_send,
            stream=False,
        )
    except Exception as exc:  # pragma: no cover - network call
        raise HTTPException(status_code=500, detail=str(exc)) from exc

    choices = getattr(completion, "choices", None)
    choice = choices[0] if choices else None
    if not choice:
        raise HTTPException(status_code=500, detail="响应格式不正确")
    assistant_content = getattr(choice.message, "content", "") or ""
    await save_assistant_message(payload.session_id, current_user.id, messages, assistant_content)
    return {"message": assistant_content}


if __name__ == "__main__":
    import uvicorn

    # reload=True needs an import string; derive it from this file's name so a
    # rename does not break startup (identical to "fastchat:app" for fastchat.py).
    uvicorn.run(f"{Path(__file__).stem}:app", host="0.0.0.0", port=16016, reload=True)