"""User and authentication related helpers.""" from __future__ import annotations import datetime import secrets from typing import Any, Dict, Optional from fastapi import Depends, HTTPException, status from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer from pydantic import BaseModel from sqlalchemy import select from sqlalchemy.orm import Session from .. import config from ..db import SessionLocal, db_call, now_utc from ..models import AuthToken, User AUTH_SCHEME = HTTPBearer(auto_error=False) class UserInfo(BaseModel): id: int username: str role: str class RegisterRequest(BaseModel): username: str password: str class LoginRequest(BaseModel): username: str password: str class AdminUserRequest(BaseModel): username: str password: str class AdminUserUpdateRequest(BaseModel): username: Optional[str] = None password: Optional[str] = None def hash_password(password: str, salt: str) -> str: import hashlib return hashlib.sha256((password + salt).encode("utf-8")).hexdigest() def normalize_username(username: str) -> str: normalized = (username or "").strip() if len(normalized) < 3: raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="用户名至少需要 3 个字符") return normalized def enforce_password_strength(password: str) -> str: if not password or len(password) < 6: raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="密码至少需要 6 位") return password async def create_auth_token(user_id: int) -> Dict[str, Any]: def creator(session: Session) -> Dict[str, Any]: token_value = secrets.token_hex(32) expires_at = now_utc() + datetime.timedelta(hours=config.TOKEN_TTL_HOURS) token = AuthToken(token=token_value, user_id=user_id, expires_at=expires_at) session.add(token) session.commit() return {"token": token_value, "expires_at": expires_at.isoformat()} return await db_call(creator) async def revoke_token(token_value: str) -> None: def remover(session: Session) -> None: session.query(AuthToken).filter(AuthToken.token == token_value).delete() session.commit() await db_call(remover) async def resolve_token(token_value: str) -> Optional[UserInfo]: def resolver(session: Session) -> Optional[UserInfo]: token = session.execute(select(AuthToken).where(AuthToken.token == token_value)).scalar_one_or_none() if not token: return None if token.expires_at < now_utc(): session.delete(token) session.commit() return None user = session.get(User, token.user_id) if not user: session.delete(token) session.commit() return None return UserInfo(id=user.id, username=user.username, role=user.role) return await db_call(resolver) async def cleanup_expired_tokens() -> None: def cleaner(session: Session) -> None: session.query(AuthToken).filter(AuthToken.expires_at < now_utc()).delete() session.commit() await db_call(cleaner) def ensure_default_admin() -> None: if SessionLocal is None: return with SessionLocal() as session: existing = session.execute(select(User).where(User.role == "admin")).first() if existing: return salt = secrets.token_hex(8) admin = User( username=config.DEFAULT_ADMIN_USERNAME, salt=salt, password_hash=hash_password(config.DEFAULT_ADMIN_PASSWORD, salt), role="admin", ) session.add(admin) session.commit() async def get_current_user(credentials: Optional[HTTPAuthorizationCredentials] = Depends(AUTH_SCHEME)) -> UserInfo: if not credentials: raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="请先登录") user = await resolve_token(credentials.credentials) if not user: raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="登录已失效,请重新登录") return user async def require_admin(current_user: UserInfo = Depends(get_current_user)) -> UserInfo: if current_user.role != "admin": raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="需要管理员权限") return current_user async def admin_list_users(keyword: Optional[str], page: int, page_size: int) -> Dict[str, Any]: def lister(session: Session) -> Dict[str, Any]: stmt = select(User).order_by(User.created_at.desc()) if keyword: stmt = stmt.where(User.username.like(f"%{keyword.strip()}%")) users = session.execute(stmt).scalars().all() total = len(users) start = max(page, 0) * page_size end = start + page_size subset = users[start:end] items = [ { "id": user.id, "username": user.username, "role": user.role, "created_at": (user.created_at or now_utc()).isoformat(), } for user in subset ] return {"items": items, "total": total, "page": page, "page_size": page_size} return await db_call(lister) async def admin_create_user(username: str, password: str) -> Dict[str, Any]: username = normalize_username(username) password = enforce_password_strength(password) def creator(session: Session) -> Dict[str, Any]: existing = session.execute(select(User).where(User.username == username)).scalar_one_or_none() if existing: raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="用户名已存在") salt = secrets.token_hex(8) user = User(username=username, salt=salt, password_hash=hash_password(password, salt), role="user") session.add(user) session.commit() session.refresh(user) return { "id": user.id, "username": user.username, "role": user.role, "created_at": (user.created_at or now_utc()).isoformat(), } return await db_call(creator) async def admin_get_user(user_id: int) -> Dict[str, Any]: def getter(session: Session) -> Dict[str, Any]: user = session.get(User, user_id) if not user: raise HTTPException(status_code=404, detail="用户不存在") return { "id": user.id, "username": user.username, "role": user.role, "created_at": (user.created_at or now_utc()).isoformat(), } return await db_call(getter) async def admin_update_user(user_id: int, payload: AdminUserUpdateRequest) -> Dict[str, Any]: def updater(session: Session) -> Dict[str, Any]: user = session.get(User, user_id) if not user: raise HTTPException(status_code=404, detail="用户不存在") if user.role == "admin": raise HTTPException(status_code=400, detail="无法修改管理员信息") if payload.username: new_username = normalize_username(payload.username) conflict = ( session.execute(select(User).where(User.username == new_username, User.id != user_id)).scalar_one_or_none() ) if conflict: raise HTTPException(status_code=400, detail="用户名已被使用") user.username = new_username if payload.password: enforce_password_strength(payload.password) salt = secrets.token_hex(8) user.salt = salt user.password_hash = hash_password(payload.password, salt) session.commit() return { "id": user.id, "username": user.username, "role": user.role, "created_at": (user.created_at or now_utc()).isoformat(), } return await db_call(updater) async def admin_delete_user(user_id: int) -> Dict[str, Any]: def deleter(session: Session) -> Dict[str, Any]: user = session.get(User, user_id) if not user: raise HTTPException(status_code=404, detail="用户不存在") if user.role == "admin": raise HTTPException(status_code=400, detail="无法删除管理员") session.delete(user) session.commit() return {"status": "ok"} return await db_call(deleter) async def register_user(username: str, password: str) -> Dict[str, Any]: username = normalize_username(username) password = enforce_password_strength(password) def creator(session: Session) -> Dict[str, Any]: existing = session.execute(select(User).where(User.username == username)).scalar_one_or_none() if existing: raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="用户名已存在") salt = secrets.token_hex(8) user = User(username=username, salt=salt, password_hash=hash_password(password, salt), role="user") session.add(user) session.commit() session.refresh(user) return {"id": user.id, "username": user.username, "role": user.role} return await db_call(creator) async def login_user(username: str, password: str) -> Dict[str, Any]: username = (username or "").strip() if not username: raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="请输入用户名") password = password or "" def verifier(session: Session) -> Dict[str, Any]: user = session.execute(select(User).where(User.username == username)).scalar_one_or_none() if not user: raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="用户名或密码错误") hashed = hash_password(password, user.salt) if hashed != user.password_hash: raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="用户名或密码错误") return {"id": user.id, "username": user.username, "role": user.role} return await db_call(verifier)