||
- # -*- coding: utf-8 -*-
- import asyncio
- import base64
- import datetime
- import hashlib
- import json
- import os
- import re
- import secrets
- import shutil
- import threading
- import uuid
- from pathlib import Path
- from typing import Any, Callable, Dict, List, Optional, Union
- from urllib.parse import quote_plus
- from fastapi import Body, Depends, FastAPI, HTTPException, UploadFile, File, status
- from fastapi.middleware.cors import CORSMiddleware
- from fastapi.responses import FileResponse, HTMLResponse, StreamingResponse
- from fastapi.staticfiles import StaticFiles
- from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
- from pydantic import BaseModel
- from sqlalchemy import (
- Boolean,
- Column,
- DateTime,
- ForeignKey,
- Integer,
- String,
- Text,
- create_engine,
- select,
- text,
- )
- from sqlalchemy.orm import Session, declarative_base, relationship, sessionmaker
- from openai import OpenAI
# =============================
# Base configuration
# =============================
BASE_DIR = Path(__file__).resolve().parent
DATA_DIR = BASE_DIR / "data"
BACKUP_DIR = BASE_DIR / "data_bak"
BLOG_DIR = BASE_DIR / "blog"
UPLOAD_DIR = BASE_DIR / "uploads"
STATIC_DIR = BASE_DIR / "static"

# Database connection settings; each one can be overridden through the environment.
# NOTE(review): the fallback password below is committed to source control.
# Prefer supplying CHATFAST_DB_PASSWORD in every deployment and rotating this value.
MYSQL_HOST = os.getenv("CHATFAST_DB_HOST", "127.0.0.1")
MYSQL_PORT = int(os.getenv("CHATFAST_DB_PORT", "3306"))
MYSQL_USER = os.getenv("CHATFAST_DB_USER", "root")
MYSQL_PASSWORD = os.getenv("CHATFAST_DB_PASSWORD", "792199Zhao*")
DATABASE_NAME = os.getenv("CHATFAST_DB_NAME", "chat_fast")
# The password is URL-quoted because it is embedded in the connection URLs below.
ENCODED_PASSWORD = quote_plus(MYSQL_PASSWORD)
# Server-level URL (no schema selected); used to create the database itself.
RAW_DATABASE_URL = (
    f"mysql+pymysql://{MYSQL_USER}:{ENCODED_PASSWORD}@{MYSQL_HOST}:{MYSQL_PORT}/?charset=utf8mb4"
)
# Schema-level URL used by the application engine.
DATABASE_URL = (
    f"mysql+pymysql://{MYSQL_USER}:{ENCODED_PASSWORD}@{MYSQL_HOST}:{MYSQL_PORT}/{DATABASE_NAME}?charset=utf8mb4"
)

TOKEN_TTL_HOURS = int(os.getenv("CHATFAST_TOKEN_TTL_HOURS", "72"))
DEFAULT_ADMIN_USERNAME = os.getenv("CHATFAST_DEFAULT_ADMIN", "admin")
# NOTE(review): well-known fallback admin password; override it in production.
DEFAULT_ADMIN_PASSWORD = os.getenv("CHATFAST_DEFAULT_ADMIN_PASSWORD", "Admin@123")

Base = declarative_base()
# Both are populated lazily by ensure_database_initialized().
ENGINE = None
SessionLocal: Optional[sessionmaker] = None
# auto_error=False lets the dependencies raise their own 401 responses.
AUTH_SCHEME = HTTPBearer(auto_error=False)

# Default download address for uploaded files; can be overridden via environment variable.
DEFAULT_UPLOAD_BASE = os.getenv("UPLOAD_BASE_URL", "/download/")
DOWNLOAD_BASE = DEFAULT_UPLOAD_BASE.rstrip("/")

# Same model and key configuration as appchat.py (example only).
# The key and endpoint can now be supplied through CHATFAST_API_KEY / CHATFAST_API_URL,
# matching how every other setting in this module is configured.
# NOTE(review): the fallback key is committed to source control and should be rotated.
default_key = os.getenv("CHATFAST_API_KEY", "sk-re2NlaKIQn11ZNWzAbB6339cEbF94c6aAfC8B7Ab82879bEa")
MODEL_KEYS: Dict[str, str] = {
    "grok-3": default_key,
    "grok-4": default_key,
    "gpt-5.1-2025-11-13": default_key,
    "gpt-5-2025-08-07": default_key,
    "gpt-4o-mini": default_key,
    # "gpt-4.1-mini-2025-04-14": default_key,
    "o1-mini": default_key,
    "o4-mini": default_key,
    "deepseek-v3": default_key,
    "deepseek-r1": default_key,
    "gpt-4o-all": default_key,
    # "gpt-5-mini-2025-08-07": default_key,
    "o3-mini-all": default_key,
}
API_URL = os.getenv("CHATFAST_API_URL", "https://yunwu.ai/v1")
client = OpenAI(api_key=default_key, base_url=API_URL)

# Lock used to avoid data corruption caused by concurrent file writes.
FILE_LOCK = asyncio.Lock()

# A message body is either plain text or a list of typed content parts.
MessageContent = Union[str, List[Dict[str, Any]]]
def ensure_directories() -> None:
    """Create every working directory the application relies on.

    Missing parent directories are created too; directories that already
    exist are left untouched.
    """
    required = (DATA_DIR, BACKUP_DIR, BLOG_DIR, UPLOAD_DIR, STATIC_DIR)
    for directory in required:
        directory.mkdir(parents=True, exist_ok=True)
def ensure_database_initialized() -> None:
    """Create the database, engine, session factory and tables once.

    Subsequent calls return immediately because ``ENGINE`` is already set.
    """
    global ENGINE, SessionLocal
    if ENGINE is None:
        # A schema-less connection is needed to issue CREATE DATABASE.
        bootstrap_engine = create_engine(RAW_DATABASE_URL, future=True, pool_pre_ping=True)
        create_statement = text(
            f"CREATE DATABASE IF NOT EXISTS `{DATABASE_NAME}` CHARACTER SET utf8mb4 COLLATE utf8mb4_unicode_ci"
        )
        with bootstrap_engine.connect() as connection:
            connection.execute(create_statement)
        bootstrap_engine.dispose()

        ENGINE = create_engine(DATABASE_URL, future=True, pool_pre_ping=True)
        SessionLocal = sessionmaker(bind=ENGINE, autoflush=False, expire_on_commit=False, future=True)
        Base.metadata.create_all(bind=ENGINE)
def now_utc() -> datetime.datetime:
    """Return the current UTC time as a naive ``datetime``.

    ``datetime.utcnow()`` is deprecated since Python 3.12, so the value is
    built from an aware UTC timestamp and the tzinfo is stripped. The result
    stays naive so it remains comparable with the naive values stored in the
    ``DateTime`` columns.
    """
    return datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
async def db_call(func: Callable[[Session], Any], *args: Any, **kwargs: Any) -> Any:
    """Run ``func(session, *args, **kwargs)`` in a worker thread.

    A fresh session is opened for the call and closed afterwards, so the
    blocking database work does not run on the event loop.

    Raises:
        RuntimeError: if the session factory has not been initialised yet.
    """

    def run_in_session() -> Any:
        factory = SessionLocal
        if factory is None:
            raise RuntimeError("数据库尚未初始化")
        with factory() as db:
            return func(db, *args, **kwargs)

    return await asyncio.to_thread(run_in_session)
class User(Base):
    """Account row stored in the ``users`` table."""

    __tablename__ = "users"
    id = Column(Integer, primary_key=True)
    username = Column(String(64), unique=True, nullable=False, index=True)
    # Hex digest produced by hash_password() from the password and ``salt``.
    password_hash = Column(String(128), nullable=False)
    salt = Column(String(32), nullable=False)
    # Role name; this file uses "user" and "admin".
    role = Column(String(16), nullable=False, default="user")
    created_at = Column(DateTime, default=now_utc, nullable=False)
    updated_at = Column(DateTime, default=now_utc, onupdate=now_utc, nullable=False)
    # Deleting a user through the ORM also deletes that user's chat sessions.
    sessions = relationship("ChatSession", back_populates="user", cascade="all, delete-orphan")
class ChatSession(Base):
    """Conversation row stored in the ``chat_sessions`` table."""

    __tablename__ = "chat_sessions"
    id = Column(Integer, primary_key=True)
    user_id = Column(Integer, ForeignKey("users.id", ondelete="CASCADE"), nullable=False, index=True)
    title = Column(String(128), nullable=True)
    # Archived sessions are excluded from the latest-session and history queries.
    archived = Column(Boolean, default=False, nullable=False)
    created_at = Column(DateTime, default=now_utc, nullable=False)
    # No ``onupdate`` here: callers assign ``updated_at`` explicitly when they modify a session.
    updated_at = Column(DateTime, default=now_utc, nullable=False)
    user = relationship("User", back_populates="sessions")
    # Deleting a session through the ORM also deletes its messages.
    messages = relationship("ChatMessage", back_populates="session", cascade="all, delete-orphan")
class ChatMessage(Base):
    """Single message row stored in the ``chat_messages`` table."""

    __tablename__ = "chat_messages"
    id = Column(Integer, primary_key=True)
    session_id = Column(Integer, ForeignKey("chat_sessions.id", ondelete="CASCADE"), nullable=False, index=True)
    role = Column(String(16), nullable=False)
    # JSON text written by serialize_content() and read back with deserialize_content().
    content = Column(Text, nullable=False)
    created_at = Column(DateTime, default=now_utc, nullable=False)
    session = relationship("ChatSession", back_populates="messages")
class AuthToken(Base):
    """Bearer token row stored in the ``auth_tokens`` table."""

    __tablename__ = "auth_tokens"
    # The token string itself is the primary key.
    token = Column(String(128), primary_key=True)
    user_id = Column(Integer, ForeignKey("users.id", ondelete="CASCADE"), nullable=False, index=True)
    # Naive timestamp compared against now_utc() when the token is resolved.
    expires_at = Column(DateTime, nullable=False)
    created_at = Column(DateTime, default=now_utc, nullable=False)
    user = relationship("User")
class ExportedContent(Base):
    """Record of an exported message file, stored in ``exported_contents``."""

    __tablename__ = "exported_contents"
    id = Column(Integer, primary_key=True)
    user_id = Column(Integer, ForeignKey("users.id", ondelete="CASCADE"), nullable=False, index=True)
    # Optional link to the originating session; set to NULL if that session is deleted.
    source_session_id = Column(Integer, ForeignKey("chat_sessions.id", ondelete="SET NULL"), nullable=True)
    filename = Column(String(255), nullable=False)
    # Path of the exported file as returned by export_message_to_blog().
    file_path = Column(String(500), nullable=False)
    content_preview = Column(Text, nullable=True)
    created_at = Column(DateTime, default=now_utc, nullable=False)
    user = relationship("User")
def hash_password(password: str, salt: str) -> str:
    """Return the hex SHA-256 digest of ``password`` concatenated with ``salt``."""
    # NOTE(review): a single SHA-256 round is fast to brute-force; changing the
    # scheme would invalidate stored hashes, so it is left as-is here.
    salted_bytes = (password + salt).encode("utf-8")
    digest = hashlib.sha256()
    digest.update(salted_bytes)
    return digest.hexdigest()
def normalize_username(username: str) -> str:
    """Strip surrounding whitespace from ``username`` and validate its length.

    Raises:
        HTTPException: 400 when fewer than three characters remain.
    """
    cleaned = username.strip() if username else ""
    if len(cleaned) >= 3:
        return cleaned
    raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="用户名至少需要 3 个字符")
def enforce_password_strength(password: str) -> str:
    """Return ``password`` unchanged when it has at least six characters.

    Raises:
        HTTPException: 400 when the password is empty or shorter than six characters.
    """
    if password and len(password) >= 6:
        return password
    raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="密码至少需要 6 位")
def serialize_content(content: MessageContent) -> str:
    """Encode message content as JSON text, keeping non-ASCII characters readable."""
    encoded = json.dumps(content, ensure_ascii=False)
    return encoded
def deserialize_content(raw: str) -> MessageContent:
    """Decode stored JSON text; return ``raw`` itself when it is not valid JSON."""
    try:
        decoded = json.loads(raw)
    except json.JSONDecodeError:
        return raw
    return decoded
def ensure_default_admin() -> None:
    """Insert the default administrator when no admin account exists.

    Does nothing if the session factory has not been initialised yet.
    """
    if SessionLocal is None:
        return
    with SessionLocal() as db:
        admin_row = db.execute(select(User).where(User.role == "admin")).first()
        if admin_row is not None:
            return
        new_salt = secrets.token_hex(8)
        db.add(
            User(
                username=DEFAULT_ADMIN_USERNAME,
                salt=new_salt,
                password_hash=hash_password(DEFAULT_ADMIN_PASSWORD, new_salt),
                role="admin",
            )
        )
        db.commit()
# Identity of the authenticated user, as returned by resolve_token().
class UserInfo(BaseModel):
    id: int
    username: str
    # Role name; the admin endpoints require "admin".
    role: str
async def create_auth_token(user_id: int) -> Dict[str, Any]:
    """Issue and store a new bearer token for ``user_id``.

    Returns:
        A dict with the token string and its ISO-formatted expiry time.
    """

    def issue(db: Session) -> Dict[str, Any]:
        value = secrets.token_hex(32)
        expiry = now_utc() + datetime.timedelta(hours=TOKEN_TTL_HOURS)
        db.add(AuthToken(token=value, user_id=user_id, expires_at=expiry))
        db.commit()
        return {"token": value, "expires_at": expiry.isoformat()}

    return await db_call(issue)
async def revoke_token(token_value: str) -> None:
    """Delete the stored token that matches ``token_value``, if any."""

    def remove(db: Session) -> None:
        matching = db.query(AuthToken).filter(AuthToken.token == token_value)
        matching.delete()
        db.commit()

    await db_call(remove)
async def resolve_token(token_value: str) -> Optional[UserInfo]:
    """Map a bearer token to its user.

    Expired tokens, and tokens whose user no longer exists, are deleted and
    treated as unknown.

    Returns:
        The matching ``UserInfo``, or ``None`` when the token is not usable.
    """

    def lookup(db: Session) -> Optional[UserInfo]:
        record = db.execute(select(AuthToken).where(AuthToken.token == token_value)).scalar_one_or_none()
        if record is None:
            return None

        owner = None
        # The user is only looked up while the token is still valid.
        if record.expires_at >= now_utc():
            owner = db.get(User, record.user_id)

        if owner is None:
            db.delete(record)
            db.commit()
            return None
        return UserInfo(id=owner.id, username=owner.username, role=owner.role)

    return await db_call(lookup)
async def cleanup_expired_tokens() -> None:
    """Delete every stored token whose expiry time has already passed."""

    def purge(db: Session) -> None:
        cutoff = now_utc()
        db.query(AuthToken).filter(AuthToken.expires_at < cutoff).delete()
        db.commit()

    await db_call(purge)
async def get_current_user(credentials: Optional[HTTPAuthorizationCredentials] = Depends(AUTH_SCHEME)) -> UserInfo:
    """FastAPI dependency that resolves the bearer token to a user.

    Raises:
        HTTPException: 401 when no credentials are sent or the token is not valid.
    """
    unauthorized = status.HTTP_401_UNAUTHORIZED
    if not credentials:
        raise HTTPException(status_code=unauthorized, detail="请先登录")
    resolved = await resolve_token(credentials.credentials)
    if resolved:
        return resolved
    raise HTTPException(status_code=unauthorized, detail="登录已失效,请重新登录")
async def require_admin(current_user: UserInfo = Depends(get_current_user)) -> UserInfo:
    """FastAPI dependency that only lets users with the ``admin`` role through.

    Raises:
        HTTPException: 403 for any other role.
    """
    if current_user.role == "admin":
        return current_user
    raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="需要管理员权限")
def text_from_content(content: MessageContent) -> str:
    """Flatten message content into plain text.

    Strings are returned unchanged. For a list of content parts, the ``text``
    of every part whose ``type`` is ``"text"`` is joined with single spaces.
    Any other value is converted with ``str()``.

    Malformed parts no longer raise: list items that are not dicts are
    skipped, a ``None`` text counts as an empty string, and other non-string
    text values are converted with ``str()``.
    """
    if isinstance(content, str):
        return content
    if isinstance(content, list):
        pieces: List[str] = []
        for part in content:
            # Content decoded from storage is arbitrary JSON, so a part may
            # not be a dict; calling .get() on it would raise AttributeError.
            if not isinstance(part, dict) or part.get("type") != "text":
                continue
            value = part.get("text", "")
            # str.join() rejects non-strings, e.g. {"type": "text", "text": null}.
            pieces.append("" if value is None else str(value))
        return " ".join(pieces)
    return str(content)
def extract_history_title(messages: List[Dict[str, Any]]) -> str:
    """Derive a short title (at most 10 characters) for a conversation.

    The first non-blank user message wins; otherwise the first message of any
    role is used; otherwise a fixed placeholder is returned.
    """
    user_texts = (
        text_from_content(entry.get("content", "")).strip()
        for entry in messages
        if entry.get("role") == "user"
    )
    first_user_text = next((candidate for candidate in user_texts if candidate), None)
    if first_user_text:
        return first_user_text[:10]

    if messages:
        fallback = text_from_content(messages[0].get("content", "")).strip()
        if fallback:
            return fallback[:10]

    placeholder = "空的聊天"
    return placeholder[:10]
def history_backup_path(session_id: int) -> Path:
    """Return the backup file location for the given session id."""
    backup_name = f"chat_history_{session_id}.json"
    return BACKUP_DIR / backup_name
def build_download_url(filename: str) -> str:
    """Prefix ``filename`` with the download base, or return it bare when the base is empty."""
    prefix = DOWNLOAD_BASE or ""
    if not prefix:
        return filename
    return f"{prefix}/{filename}"
async def load_messages(session_id: int, user_id: int) -> List[Dict[str, Any]]:
    """Return the messages of a session owned by ``user_id``, ordered by message id.

    Raises:
        HTTPException: 404 when the session does not exist for this user.
    """

    def fetch(db: Session) -> List[Dict[str, Any]]:
        owned_session = select(ChatSession).where(ChatSession.id == session_id, ChatSession.user_id == user_id)
        chat = db.execute(owned_session).scalar_one_or_none()
        if not chat:
            raise HTTPException(status_code=404, detail="会话不存在")

        ordered_messages = select(ChatMessage).where(ChatMessage.session_id == chat.id).order_by(ChatMessage.id)
        collected: List[Dict[str, Any]] = []
        for row in db.execute(ordered_messages).scalars().all():
            collected.append({"role": row.role, "content": deserialize_content(row.content)})
        return collected

    return await db_call(fetch)
async def create_chat_session(user_id: int) -> Dict[str, Any]:
    """Create an empty chat session for ``user_id``.

    Returns:
        A dict with the new session id and an empty message list.
    """

    def build(db: Session) -> Dict[str, Any]:
        fresh = ChatSession(user_id=user_id)
        db.add(fresh)
        db.commit()
        # Refresh so the generated primary key is loaded.
        db.refresh(fresh)
        empty_history: List[Any] = []
        return {"session_id": fresh.id, "messages": empty_history}

    return await db_call(build)
async def get_latest_session(user_id: int) -> Dict[str, Any]:
    """Return the user's most recently updated non-archived session.

    When the user has no such session, a new empty one is created instead.
    """

    def fetch(db: Session) -> Dict[str, Any]:
        newest_first = (
            select(ChatSession)
            .where(ChatSession.user_id == user_id, ChatSession.archived.is_(False))
            .order_by(ChatSession.updated_at.desc())
        )
        latest = db.execute(newest_first).scalars().first()

        if not latest:
            latest = ChatSession(user_id=user_id)
            db.add(latest)
            db.commit()
            db.refresh(latest)
            return {"session_id": latest.id, "messages": []}

        ordered_messages = select(ChatMessage).where(ChatMessage.session_id == latest.id).order_by(ChatMessage.id)
        rows = db.execute(ordered_messages).scalars().all()
        history = [{"role": row.role, "content": deserialize_content(row.content)} for row in rows]
        return {"session_id": latest.id, "messages": history}

    return await db_call(fetch)
async def append_message(session_id: int, user_id: int, role: str, content: MessageContent) -> None:
    """Store a message in a session owned by ``user_id``.

    The session's ``updated_at`` is bumped, and a session that has no title
    yet takes the first 30 characters of the first user message.

    Raises:
        HTTPException: 404 when the session does not exist for this user.
    """

    def store(db: Session) -> None:
        owned_session = select(ChatSession).where(ChatSession.id == session_id, ChatSession.user_id == user_id)
        chat = db.execute(owned_session).scalar_one_or_none()
        if not chat:
            raise HTTPException(status_code=404, detail="会话不存在")

        payload_text = serialize_content(content)
        db.add(ChatMessage(session_id=chat.id, role=role, content=payload_text))
        chat.updated_at = now_utc()

        if role == "user":
            has_title = bool(chat.title and chat.title.strip())
            if not has_title:
                candidate = text_from_content(content).strip()
                if candidate:
                    chat.title = candidate[:30]
        db.commit()

    await db_call(store)
async def list_history(user_id: int, page: int, page_size: int) -> Dict[str, Any]:
    """Return one page of the user's non-archived sessions, newest first.

    Pagination now happens in the database (COUNT plus LIMIT/OFFSET) instead
    of loading every session row and slicing in Python. A negative ``page``
    is treated as page 0 and a non-positive ``page_size`` yields no items.

    Returns:
        A dict with the requested ``page`` and ``page_size``, the ``total``
        number of matching sessions, and the ``items`` on this page.
    """
    # Local import: ``func`` is not part of the module-level sqlalchemy import list.
    from sqlalchemy import func

    page_index = max(page, 0)
    # Clamp so a negative size cannot produce a negative OFFSET or LIMIT.
    limit = max(page_size, 0)

    def lister(session: Session) -> Dict[str, Any]:
        visible = (ChatSession.user_id == user_id, ChatSession.archived.is_(False))
        total = session.execute(select(func.count()).select_from(ChatSession).where(*visible)).scalar_one()
        stmt = (
            select(ChatSession)
            .where(*visible)
            # The id tie-breaker keeps page boundaries stable when timestamps are equal.
            .order_by(ChatSession.updated_at.desc(), ChatSession.id.desc())
            .offset(page_index * limit)
            .limit(limit)
        )
        items: List[Dict[str, Any]] = [
            {
                "session_id": item.id,
                "title": (item.title or f"会话 #{item.id}")[:30],
                "updated_at": (item.updated_at or now_utc()).isoformat(),
                "filename": f"session_{item.id}.json",
            }
            for item in session.execute(stmt).scalars().all()
        ]
        return {"page": page, "page_size": page_size, "total": total, "items": items}

    return await db_call(lister)
async def move_history_file(user_id: int, session_id: int) -> None:
    """Archive a session and write its messages to a JSON backup file.

    The session is marked archived in the database first; the message
    snapshot is then written under ``FILE_LOCK``.

    Raises:
        HTTPException: 404 when the session does not exist for this user.
    """

    def archive(db: Session) -> List[Dict[str, Any]]:
        owned_session = select(ChatSession).where(ChatSession.id == session_id, ChatSession.user_id == user_id)
        chat = db.execute(owned_session).scalar_one_or_none()
        if not chat:
            raise HTTPException(status_code=404, detail="历史记录不存在")

        ordered_messages = select(ChatMessage).where(ChatMessage.session_id == chat.id).order_by(ChatMessage.id)
        rows = db.execute(ordered_messages).scalars().all()
        snapshot = [{"role": row.role, "content": deserialize_content(row.content)} for row in rows]

        chat.archived = True
        chat.updated_at = now_utc()
        db.commit()
        return snapshot

    messages = await db_call(archive)
    destination = history_backup_path(session_id)
    destination.parent.mkdir(parents=True, exist_ok=True)

    def dump_snapshot() -> None:
        with destination.open("w", encoding="utf-8") as handle:
            json.dump(messages, handle, ensure_ascii=False, indent=2)

    async with FILE_LOCK:
        await asyncio.to_thread(dump_snapshot)
async def delete_history_file(user_id: int, session_id: int) -> None:
    """Delete a session owned by ``user_id``.

    Raises:
        HTTPException: 404 when the session does not exist for this user.
    """

    def remove(db: Session) -> None:
        owned_session = select(ChatSession).where(ChatSession.id == session_id, ChatSession.user_id == user_id)
        chat = db.execute(owned_session).scalar_one_or_none()
        if chat:
            db.delete(chat)
            db.commit()
            return
        raise HTTPException(status_code=404, detail="历史记录不存在")

    await db_call(remove)
async def export_message_to_blog(content: MessageContent) -> str:
    """Write the text of ``content`` to a new file in ``BLOG_DIR``.

    The file is named ``<MMDDHHMM>_<prefix>_<random>.txt``. ``prefix`` is taken
    from the first 10 characters of the text with whitespace, control
    characters and filename-hostile characters removed (``export`` when
    nothing remains). The random suffix stops two exports made in the same
    minute with the same prefix from overwriting each other.

    Returns:
        The path of the written file, as a string.
    """
    processed = text_from_content(content).replace("\r\n", "\n")
    timestamp = datetime.datetime.now().strftime("%m%d%H%M")
    # Newlines, control characters and <>:"/\|?*` are either path separators or
    # not allowed in filenames on common filesystems.
    prefix = re.sub(r'[\s\x00-\x1f<>:"/\\|?*`]', "", processed[:10])
    unique_suffix = uuid.uuid4().hex[:8]
    filename = f"{timestamp}_{prefix or 'export'}_{unique_suffix}.txt"
    path = BLOG_DIR / filename

    def _write() -> None:
        # Create the target directory if it is missing before writing.
        path.parent.mkdir(parents=True, exist_ok=True)
        with path.open("w", encoding="utf-8") as fp:
            fp.write(processed)

    await asyncio.to_thread(_write)
    return str(path)
async def record_export_entry(user_id: int, session_id: Optional[int], file_path: str, content: MessageContent) -> Dict[str, Any]:
    """Store an export record and return it as a plain dict.

    The preview is the first 200 characters of the stripped text content.
    """

    def persist(db: Session) -> Dict[str, Any]:
        base_name = os.path.basename(file_path)
        snippet = text_from_content(content).strip()[:200]
        entry = ExportedContent(
            user_id=user_id,
            source_session_id=session_id,
            filename=base_name,
            file_path=file_path,
            content_preview=snippet,
        )
        db.add(entry)
        db.commit()
        db.refresh(entry)

        owner = db.get(User, user_id)
        created = entry.created_at or now_utc()
        return {
            "id": entry.id,
            "user_id": user_id,
            "username": owner.username if owner else "",
            "filename": base_name,
            "file_path": file_path,
            "created_at": created.isoformat(),
            "content_preview": snippet,
        }

    return await db_call(persist)
async def list_exports_for_user(user_id: int) -> List[Dict[str, Any]]:
    """Return the export records of ``user_id``, newest first."""

    def fetch(db: Session) -> List[Dict[str, Any]]:
        newest_first = (
            select(ExportedContent)
            .where(ExportedContent.user_id == user_id)
            .order_by(ExportedContent.created_at.desc())
        )
        records = db.execute(newest_first).scalars().all()
        owner = db.get(User, user_id)
        owner_name = owner.username if owner else ""
        return [
            {
                "id": record.id,
                "user_id": record.user_id,
                "username": owner_name,
                "filename": record.filename,
                "file_path": record.file_path,
                "created_at": (record.created_at or now_utc()).isoformat(),
                "content_preview": (record.content_preview or "")[:200],
            }
            for record in records
        ]

    return await db_call(fetch)
async def list_exports_admin(keyword: Optional[str] = None) -> List[Dict[str, Any]]:
    """Return export records of all users, newest first.

    When ``keyword`` is given, only exports whose owner's username contains
    it (via SQL ``LIKE``) are returned.
    """

    def fetch(db: Session) -> List[Dict[str, Any]]:
        query = select(ExportedContent, User.username).join(User, ExportedContent.user_id == User.id)
        if keyword:
            pattern = f"%{keyword.strip()}%"
            query = query.where(User.username.like(pattern))
        query = query.order_by(ExportedContent.created_at.desc())

        return [
            {
                "id": record.id,
                "user_id": record.user_id,
                "username": owner_name,
                "filename": record.filename,
                "file_path": record.file_path,
                "created_at": (record.created_at or now_utc()).isoformat(),
                "content_preview": (record.content_preview or "")[:200],
            }
            for record, owner_name in db.execute(query).all()
        ]

    return await db_call(fetch)
async def get_export_record(export_id: int) -> Optional[Dict[str, Any]]:
    """Return the export record with ``export_id`` as a dict, or ``None`` when absent."""

    def fetch(db: Session) -> Optional[Dict[str, Any]]:
        record = db.get(ExportedContent, export_id)
        if not record:
            return None
        owner = db.get(User, record.user_id)
        return {
            "id": record.id,
            "user_id": record.user_id,
            "username": owner.username if owner else "",
            "filename": record.filename,
            "file_path": record.file_path,
        }

    return await db_call(fetch)
# Request body for POST /api/auth/register.
class RegisterRequest(BaseModel):
    username: str
    password: str
# Request body for POST /api/auth/login.
class LoginRequest(BaseModel):
    username: str
    password: str
# Request body for POST /api/admin/users (admin creates a user).
class AdminUserRequest(BaseModel):
    username: str
    password: str
# Request body for PUT /api/admin/users/{user_id}; a field that is omitted or empty is left unchanged.
class AdminUserUpdateRequest(BaseModel):
    username: Optional[str] = None
    password: Optional[str] = None
# A single chat message: a role name plus its content.
class MessageModel(BaseModel):
    role: str
    content: MessageContent
# Request body for POST /api/chat.
class ChatRequest(BaseModel):
    session_id: int
    # Must be one of the keys of MODEL_KEYS.
    model: str
    content: MessageContent
    # Number of trailing messages sent to the model; 0 sends only the new message.
    history_count: int = 0
    # True returns an NDJSON stream; False returns a single JSON response.
    stream: bool = True
# Request body for POST /api/history/move.
class HistoryActionRequest(BaseModel):
    session_id: int
# Request body for POST /api/export.
class ExportRequest(BaseModel):
    content: MessageContent
    # Optional id of the session the exported content came from.
    session_id: Optional[int] = None
# One entry of the POST /api/upload response.
class UploadResponseItem(BaseModel):
    # "image" or "file".
    type: str
    filename: str
    # Base64 data URL; only set for image uploads.
    data: Optional[str] = None
    # Download URL of the stored file.
    url: Optional[str] = None
# Make sure the static and data directories exist before the application is initialised.
ensure_directories()
ensure_database_initialized()
ensure_default_admin()

app = FastAPI(title="ChatGPT-like Clone", version="1.0.0")
# NOTE(review): wildcard origins combined with allow_credentials=True is very permissive;
# authentication in this file uses bearer tokens, so confirm whether credentials are needed.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Serve the files under STATIC_DIR at /static.
app.mount("/static", StaticFiles(directory=STATIC_DIR), name="static")
@app.on_event("startup")
async def on_startup() -> None:
    """Repeat the bootstrap steps at application startup and purge expired tokens."""
    ensure_directories()
    # Returns immediately when the engine already exists.
    ensure_database_initialized()
    ensure_default_admin()
    await cleanup_expired_tokens()
INDEX_HTML = STATIC_DIR / "index.html"


@app.get("/", response_class=HTMLResponse)
async def serve_index() -> str:
    """Return the UI page from ``static/index.html``.

    Raises:
        HTTPException: 404 when the file does not exist.
    """
    if INDEX_HTML.exists():
        return INDEX_HTML.read_text(encoding="utf-8")
    raise HTTPException(status_code=404, detail="UI 未找到,请确认 static/index.html 是否存在")
@app.get("/download/{filename}")
async def download_file(filename: str) -> FileResponse:
    """Serve a previously uploaded file from ``UPLOAD_DIR``.

    The requested name is resolved and must point at a regular file located
    inside ``UPLOAD_DIR``. Directories (for example ``..``) and paths that
    escape the upload directory are answered with 404 instead of being handed
    to ``FileResponse``.

    Raises:
        HTTPException: 404 when no such uploaded file exists.
    """
    upload_root = UPLOAD_DIR.resolve()
    target = (upload_root / filename).resolve()
    # ``exists()`` alone would also accept directories and traversal results.
    if upload_root not in target.parents or not target.is_file():
        raise HTTPException(status_code=404, detail="File not found")
    return FileResponse(target, filename=filename)
@app.get("/api/config")
async def get_config() -> Dict[str, Any]:
    """Return the static UI configuration: title, model list, output modes and upload base URL."""
    available_models = list(MODEL_KEYS)
    first_model = available_models[0] if available_models else ""
    upload_base = f"{DOWNLOAD_BASE}/" if DOWNLOAD_BASE else ""
    return {
        "title": "ChatGPT-like Clone",
        "models": available_models,
        "default_model": first_model,
        "output_modes": ["流式输出 (Stream)", "非流式输出 (Non-stream)"],
        "upload_base_url": upload_base,
    }
@app.post("/api/auth/register")
async def api_register(payload: RegisterRequest) -> Dict[str, Any]:
    """Register a new ``user``-role account and log it in.

    Returns:
        The created user together with a fresh token and its expiry.

    Raises:
        HTTPException: 400 for an invalid username/password or a taken username.
    """
    username = normalize_username(payload.username)
    password = enforce_password_strength(payload.password)

    def create_account(db: Session) -> Dict[str, Any]:
        taken = db.execute(select(User).where(User.username == username)).scalar_one_or_none()
        if taken:
            raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="用户名已存在")
        new_salt = secrets.token_hex(8)
        account = User(
            username=username,
            salt=new_salt,
            password_hash=hash_password(password, new_salt),
            role="user",
        )
        db.add(account)
        db.commit()
        db.refresh(account)
        return {"id": account.id, "username": account.username, "role": account.role}

    created = await db_call(create_account)
    issued = await create_auth_token(created["id"])
    return {"user": created, "token": issued["token"], "expires_at": issued["expires_at"]}
@app.post("/api/auth/login")
async def api_login(payload: LoginRequest) -> Dict[str, Any]:
    """Verify a username/password pair and issue a token.

    The stored and computed password digests are compared with
    ``secrets.compare_digest`` so the comparison time does not depend on how
    many leading characters match.

    Returns:
        The user together with a fresh token and its expiry.

    Raises:
        HTTPException: 400 when the username is blank or the credentials are wrong.
    """
    username = (payload.username or "").strip()
    if not username:
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="请输入用户名")
    password = payload.password or ""

    def verifier(session: Session) -> Dict[str, Any]:
        user = session.execute(select(User).where(User.username == username)).scalar_one_or_none()
        if not user:
            raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="用户名或密码错误")
        expected = hash_password(password, user.salt)
        # Compare as bytes: compare_digest only accepts ASCII-only str values.
        matches = secrets.compare_digest(
            expected.encode("utf-8"),
            (user.password_hash or "").encode("utf-8"),
        )
        if not matches:
            raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="用户名或密码错误")
        return {"id": user.id, "username": user.username, "role": user.role}

    user = await db_call(verifier)
    token_data = await create_auth_token(user["id"])
    return {"user": user, "token": token_data["token"], "expires_at": token_data["expires_at"]}
@app.post("/api/auth/logout")
async def api_logout(
    credentials: HTTPAuthorizationCredentials = Depends(AUTH_SCHEME),
) -> Dict[str, str]:
    """Revoke the bearer token used for this request.

    Raises:
        HTTPException: 401 when no token is sent or the token is not valid.
    """
    unauthorized = status.HTTP_401_UNAUTHORIZED
    if not credentials:
        raise HTTPException(status_code=unauthorized, detail="未登录")
    token_value = credentials.credentials
    owner = await resolve_token(token_value)
    if not owner:
        raise HTTPException(status_code=unauthorized, detail="登录已失效")
    await revoke_token(token_value)
    return {"status": "ok"}
@app.get("/api/auth/me")
async def api_me(current_user: UserInfo = Depends(get_current_user)) -> UserInfo:
    """Return the user resolved from the request's bearer token."""
    return current_user
@app.get("/api/session/latest")
async def api_latest_session(current_user: UserInfo = Depends(get_current_user)) -> Dict[str, Any]:
    """Return the caller's latest non-archived session; get_latest_session creates one if none exists."""
    return await get_latest_session(current_user.id)
@app.get("/api/session/{session_id}")
async def api_get_session(session_id: int, current_user: UserInfo = Depends(get_current_user)) -> Dict[str, Any]:
    """Return the messages of one of the caller's sessions.

    load_messages raises a 404 when the session does not belong to the caller.
    """
    messages = await load_messages(session_id, current_user.id)
    return {"session_id": session_id, "messages": messages}
@app.post("/api/session/new")
async def api_new_session(current_user: UserInfo = Depends(get_current_user)) -> Dict[str, Any]:
    """Create a new empty session for the caller."""
    return await create_chat_session(current_user.id)
@app.get("/api/history")
async def api_history(
    page: int = 0,
    page_size: int = 10,
    current_user: UserInfo = Depends(get_current_user),
) -> Dict[str, Any]:
    """Return one page of the caller's non-archived sessions (``page`` is zero-based)."""
    return await list_history(current_user.id, page, page_size)
@app.post("/api/history/move")
async def api_move_history(payload: HistoryActionRequest, current_user: UserInfo = Depends(get_current_user)) -> Dict[str, Any]:
    """Archive one of the caller's sessions and write its backup file."""
    await move_history_file(current_user.id, payload.session_id)
    return {"status": "ok"}
@app.delete("/api/history/{session_id}")
async def api_delete_history(session_id: int, current_user: UserInfo = Depends(get_current_user)) -> Dict[str, Any]:
    """Delete one of the caller's sessions."""
    await delete_history_file(current_user.id, session_id)
    return {"status": "ok"}
@app.get("/api/admin/users")
async def admin_list_users(
    keyword: Optional[str] = None,
    page: int = 0,
    page_size: int = 20,
    admin: UserInfo = Depends(require_admin),
) -> Dict[str, Any]:
    """Return one page of users, newest first (admin only).

    Pagination now happens in the database (COUNT plus LIMIT/OFFSET) instead
    of loading every user row and slicing in Python. A negative ``page`` is
    treated as page 0 and a non-positive ``page_size`` yields no items.

    Args:
        keyword: optional substring the username must contain (SQL ``LIKE``).
        page: zero-based page index.
        page_size: number of users per page.
        admin: present only to enforce the admin dependency.
    """
    # Local import: ``func`` is not part of the module-level sqlalchemy import list.
    from sqlalchemy import func

    page_index = max(page, 0)
    # Clamp so a negative size cannot produce a negative OFFSET or LIMIT.
    limit = max(page_size, 0)

    def lister(session: Session) -> Dict[str, Any]:
        count_stmt = select(func.count()).select_from(User)
        # The id tie-breaker keeps page boundaries stable when timestamps are equal.
        page_stmt = select(User).order_by(User.created_at.desc(), User.id.desc())
        if keyword:
            name_filter = User.username.like(f"%{keyword.strip()}%")
            count_stmt = count_stmt.where(name_filter)
            page_stmt = page_stmt.where(name_filter)

        total = session.execute(count_stmt).scalar_one()
        users = session.execute(page_stmt.offset(page_index * limit).limit(limit)).scalars().all()
        items = [
            {
                "id": user.id,
                "username": user.username,
                "role": user.role,
                "created_at": (user.created_at or now_utc()).isoformat(),
            }
            for user in users
        ]
        return {"items": items, "total": total, "page": page, "page_size": page_size}

    return await db_call(lister)
@app.post("/api/admin/users")
async def admin_create_user(payload: AdminUserRequest, admin: UserInfo = Depends(require_admin)) -> Dict[str, Any]:
    """Create a ``user``-role account (admin only).

    Raises:
        HTTPException: 400 for an invalid username/password or a taken username.
    """
    username = normalize_username(payload.username)
    password = enforce_password_strength(payload.password)

    def create_account(db: Session) -> Dict[str, Any]:
        taken = db.execute(select(User).where(User.username == username)).scalar_one_or_none()
        if taken:
            raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="用户名已存在")
        new_salt = secrets.token_hex(8)
        account = User(
            username=username,
            salt=new_salt,
            password_hash=hash_password(password, new_salt),
            role="user",
        )
        db.add(account)
        db.commit()
        db.refresh(account)
        created = account.created_at or now_utc()
        return {
            "id": account.id,
            "username": account.username,
            "role": account.role,
            "created_at": created.isoformat(),
        }

    return await db_call(create_account)
@app.get("/api/admin/users/{user_id}")
async def admin_get_user(user_id: int, admin: UserInfo = Depends(require_admin)) -> Dict[str, Any]:
    """Return a single user's public fields (admin only).

    Raises:
        HTTPException: 404 when the user does not exist.
    """

    def fetch(db: Session) -> Dict[str, Any]:
        account = db.get(User, user_id)
        if account:
            created = account.created_at or now_utc()
            return {
                "id": account.id,
                "username": account.username,
                "role": account.role,
                "created_at": created.isoformat(),
            }
        raise HTTPException(status_code=404, detail="用户不存在")

    return await db_call(fetch)
@app.put("/api/admin/users/{user_id}")
async def admin_update_user(
    user_id: int,
    payload: AdminUserUpdateRequest,
    admin: UserInfo = Depends(require_admin),
) -> Dict[str, Any]:
    """Change a non-admin user's username and/or password (admin only).

    A password change also generates a new salt.

    Raises:
        HTTPException: 404 when the user does not exist; 400 when the target is
            an admin, the new username is invalid or taken, or the new password is too short.
    """

    def apply_changes(db: Session) -> Dict[str, Any]:
        account = db.get(User, user_id)
        if not account:
            raise HTTPException(status_code=404, detail="用户不存在")
        if account.role == "admin":
            raise HTTPException(status_code=400, detail="无法修改管理员信息")

        if payload.username:
            desired_name = normalize_username(payload.username)
            same_name_elsewhere = select(User).where(User.username == desired_name, User.id != user_id)
            if db.execute(same_name_elsewhere).scalar_one_or_none():
                raise HTTPException(status_code=400, detail="用户名已被使用")
            account.username = desired_name

        if payload.password:
            enforce_password_strength(payload.password)
            fresh_salt = secrets.token_hex(8)
            account.salt = fresh_salt
            account.password_hash = hash_password(payload.password, fresh_salt)

        db.commit()
        created = account.created_at or now_utc()
        return {
            "id": account.id,
            "username": account.username,
            "role": account.role,
            "created_at": created.isoformat(),
        }

    return await db_call(apply_changes)
@app.delete("/api/admin/users/{user_id}")
async def admin_delete_user(user_id: int, admin: UserInfo = Depends(require_admin)) -> Dict[str, Any]:
    """Delete a non-admin user (admin only).

    Raises:
        HTTPException: 404 when the user does not exist; 400 when the target is an admin.
    """

    def remove(db: Session) -> Dict[str, Any]:
        account = db.get(User, user_id)
        if not account:
            raise HTTPException(status_code=404, detail="用户不存在")
        if account.role == "admin":
            raise HTTPException(status_code=400, detail="无法删除管理员")
        db.delete(account)
        db.commit()
        outcome = {"status": "ok"}
        return outcome

    return await db_call(remove)
@app.post("/api/export")
async def api_export_message(payload: ExportRequest, current_user: UserInfo = Depends(get_current_user)) -> Dict[str, Any]:
    """Write the content to a blog file and record the export for the caller."""
    path = await export_message_to_blog(payload.content)
    # NOTE(review): payload.session_id is stored as-is; it is not checked against the caller's sessions here.
    record = await record_export_entry(current_user.id, payload.session_id, path, payload.content)
    return {"status": "ok", "path": path, "export": record}
@app.get("/api/exports/me")
async def api_my_exports(current_user: UserInfo = Depends(get_current_user)) -> Dict[str, Any]:
    """Return the caller's own export records."""
    items = await list_exports_for_user(current_user.id)
    return {"items": items}
@app.get("/api/admin/exports")
async def api_admin_exports(
    keyword: Optional[str] = None,
    admin: UserInfo = Depends(require_admin),
) -> Dict[str, Any]:
    """Return export records of all users, optionally filtered by username keyword (admin only)."""
    items = await list_exports_admin(keyword)
    return {"items": items}
@app.get("/api/exports/{export_id}/download")
async def api_download_export(export_id: int, current_user: UserInfo = Depends(get_current_user)) -> FileResponse:
    """Send an exported file to its owner or to an admin.

    Raises:
        HTTPException: 404 when the record or the file is missing; 403 when the
            caller is neither the owner nor an admin.
    """
    record = await get_export_record(export_id)
    if not record:
        raise HTTPException(status_code=404, detail="导出记录不存在")

    is_owner = record["user_id"] == current_user.id
    is_admin = current_user.role == "admin"
    if not (is_owner or is_admin):
        raise HTTPException(status_code=403, detail="无权下载该内容")

    stored_file = Path(record["file_path"])
    if stored_file.exists():
        return FileResponse(stored_file, filename=record["filename"])
    raise HTTPException(status_code=404, detail="导出文件不存在")
@app.post("/api/upload")
async def api_upload(
    files: List[UploadFile] = File(...),
    current_user: UserInfo = Depends(get_current_user),
) -> List[UploadResponseItem]:
    """Store uploaded files under ``UPLOAD_DIR`` and describe each one.

    Every file is saved under a random hex prefix plus its base name. Images
    (content type starting with ``image/``) are additionally returned inline
    as a base64 data URL.
    """
    if not files:
        return []

    described: List[UploadResponseItem] = []
    for upload in files:
        raw_name = upload.filename or "file"
        # Keep only the final path component of the client-supplied name.
        display_name = Path(raw_name).name or "file"
        mime = (upload.content_type or "").lower()
        blob = await upload.read()

        stored_name = f"{uuid.uuid4().hex}_{display_name}"
        destination = UPLOAD_DIR / stored_name
        await asyncio.to_thread(destination.write_bytes, blob)

        fields: Dict[str, Any] = {"filename": display_name, "url": build_download_url(stored_name)}
        if mime.startswith("image/"):
            b64_text = base64.b64encode(blob).decode("utf-8")
            fields["type"] = "image"
            fields["data"] = f"data:{mime};base64,{b64_text}"
        else:
            fields["type"] = "file"
        described.append(UploadResponseItem(**fields))
    return described
async def prepare_messages_for_completion(
    messages: List[Dict[str, Any]],
    user_content: MessageContent,
    history_count: int,
) -> List[Dict[str, Any]]:
    """Choose the messages to send to the model.

    Returns the last ``history_count`` entries of ``messages`` when that
    selection is non-empty; otherwise a single user message built from
    ``user_content``.
    """
    recent = messages[-history_count:] if history_count > 0 else []
    if recent:
        return recent
    return [{"role": "user", "content": user_content}]
async def save_assistant_message(session_id: int, user_id: int, messages: List[Dict[str, Any]], content: MessageContent) -> None:
    """Append the assistant reply to ``messages`` (mutated in place) and persist it."""
    messages.append({"role": "assistant", "content": content})
    await append_message(session_id, user_id, "assistant", content)
@app.post("/api/chat")
async def api_chat(payload: ChatRequest = Body(...), current_user: UserInfo = Depends(get_current_user)):
    """Send a user message to the selected model and return the reply.

    The user message is stored first. With ``payload.stream`` set, the reply
    is returned as an NDJSON stream of ``delta`` items followed by ``end`` (or
    ``error``). Otherwise a single ``{"message": ...}`` dict is returned.

    Raises:
        HTTPException: 400 for an unknown model; 404 (from load_messages) for an
            unknown session; 500 when the non-streaming upstream call fails or
            returns no choices.
    """
    if payload.model not in MODEL_KEYS:
        raise HTTPException(status_code=400, detail="未知的模型")
    messages = await load_messages(payload.session_id, current_user.id)
    user_message = {"role": "user", "content": payload.content}
    messages.append(user_message)
    await append_message(payload.session_id, current_user.id, "user", payload.content)
    # NOTE(review): this assigns to the module-level client shared by all requests;
    # with per-model keys that differ, concurrent requests could race on it.
    client.api_key = MODEL_KEYS[payload.model]
    to_send = await prepare_messages_for_completion(messages, payload.content, max(payload.history_count, 0))
    if payload.stream:
        # The worker thread feeds this queue; the async generator below drains it.
        queue: "asyncio.Queue[Dict[str, Any]]" = asyncio.Queue()
        aggregated: List[str] = []
        loop = asyncio.get_running_loop()

        def worker() -> None:
            # Runs in a separate thread, so queue.put is scheduled on the event
            # loop with run_coroutine_threadsafe.
            try:
                response = client.chat.completions.create(
                    model=payload.model,
                    messages=to_send,
                    stream=True,
                )
                for chunk in response:
                    try:
                        delta = chunk.choices[0].delta.content  # type: ignore[attr-defined]
                    except (IndexError, AttributeError):
                        # Chunks without choices or delta content are skipped.
                        delta = None
                    if delta:
                        aggregated.append(delta)
                        asyncio.run_coroutine_threadsafe(queue.put({"type": "delta", "text": delta}), loop)
                asyncio.run_coroutine_threadsafe(queue.put({"type": "complete"}), loop)
            except Exception as exc:  # pragma: no cover - network call
                asyncio.run_coroutine_threadsafe(queue.put({"type": "error", "message": str(exc)}), loop)

        threading.Thread(target=worker, daemon=True).start()

        async def streamer():
            try:
                while True:
                    item = await queue.get()
                    if item["type"] == "delta":
                        yield json.dumps(item, ensure_ascii=False) + "\n"
                    elif item["type"] == "complete":
                        # The full reply is persisted only after the upstream stream completes.
                        assistant_text = "".join(aggregated)
                        await save_assistant_message(payload.session_id, current_user.id, messages, assistant_text)
                        yield json.dumps({"type": "end"}, ensure_ascii=False) + "\n"
                        break
                    elif item["type"] == "error":
                        # On error, the text aggregated so far is not saved.
                        yield json.dumps(item, ensure_ascii=False) + "\n"
                        break
            except asyncio.CancelledError:  # pragma: no cover - stream cancelled
                raise

        return StreamingResponse(streamer(), media_type="application/x-ndjson")
    try:
        # The blocking SDK call runs in a thread to keep the event loop free.
        completion = await asyncio.to_thread(
            client.chat.completions.create,
            model=payload.model,
            messages=to_send,
            stream=False,
        )
    except Exception as exc:  # pragma: no cover - network call
        raise HTTPException(status_code=500, detail=str(exc)) from exc
    choice = completion.choices[0] if getattr(completion, "choices", None) else None  # type: ignore[attr-defined]
    if not choice:
        raise HTTPException(status_code=500, detail="响应格式不正确")
    assistant_content = getattr(choice.message, "content", "")
    # Normalise a missing or None reply to an empty string before saving.
    if not assistant_content:
        assistant_content = ""
    await save_assistant_message(payload.session_id, current_user.id, messages, assistant_content)
    return {"message": assistant_content}
# Development entry point: serve the app with uvicorn on all interfaces, port 16016, with auto-reload.
if __name__ == "__main__":
    import uvicorn

    # NOTE(review): the import string assumes this module is named fastchat.py — confirm.
    uvicorn.run("fastchat:app", host="0.0.0.0", port=16016, reload=True)
|