| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287 |
- """User and authentication related helpers."""
- from __future__ import annotations
- import datetime
- import secrets
- from typing import Any, Dict, Optional
- from fastapi import Depends, HTTPException, status
- from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
- from pydantic import BaseModel
- from sqlalchemy import select
- from sqlalchemy.orm import Session
- from .. import config
- from ..db import SessionLocal, db_call, now_utc
- from ..models import AuthToken, User
# Bearer-token extractor used by ``get_current_user``. ``auto_error=False`` makes
# it hand back ``None`` instead of raising when no bearer credentials are sent,
# so the dependency can answer with its own 401 message.
AUTH_SCHEME: HTTPBearer = HTTPBearer(auto_error=False)
# Identity of an authenticated account, as built by ``resolve_token`` and
# returned by the auth dependencies. Documented with comments rather than a
# docstring because pydantic copies class docstrings into the OpenAPI schema.
class UserInfo(BaseModel):
    id: int  # primary key of the User row
    username: str
    role: str  # this module writes either "admin" or "user"
# Request body for self-service sign-up (presumably handed to ``register_user``,
# which validates both fields). Comments instead of a docstring keep the
# generated OpenAPI schema unchanged.
class RegisterRequest(BaseModel):
    username: str
    password: str  # plaintext as submitted; only a salted hash is persisted
# Credentials submitted at login (presumably handed to ``login_user``).
# Comments instead of a docstring keep the generated OpenAPI schema unchanged.
class LoginRequest(BaseModel):
    username: str
    password: str  # plaintext as submitted
# Body for an administrator creating an account (presumably forwarded to
# ``admin_create_user``). Comments instead of a docstring keep the generated
# OpenAPI schema unchanged.
class AdminUserRequest(BaseModel):
    username: str
    password: str  # plaintext as submitted
# Partial update applied by ``admin_update_user``: a field left as ``None``
# (or empty) is not changed. Comments instead of a docstring keep the
# generated OpenAPI schema unchanged.
class AdminUserUpdateRequest(BaseModel):
    username: Optional[str] = None  # new login name, if renaming
    password: Optional[str] = None  # new plaintext password, if resetting
def hash_password(password: str, salt: str) -> str:
    """Return the hex SHA-256 digest of ``password`` concatenated with ``salt``.

    The concatenated string is UTF-8 encoded before hashing.

    NOTE(review): a single unstretched SHA-256 round is a weak password hash.
    Moving to a KDF (``hashlib.scrypt`` / ``hashlib.pbkdf2_hmac``) needs a
    migration path for already-stored hashes, so it is not changed here.
    """
    import hashlib

    material = password + salt
    return hashlib.sha256(material.encode("utf-8")).hexdigest()
def normalize_username(username: str) -> str:
    """Strip surrounding whitespace from ``username`` and check its length.

    A falsy ``username`` (``None`` or ``""``) is treated as an empty string.

    Returns:
        The stripped username.

    Raises:
        HTTPException: 400 when fewer than 3 characters remain after stripping.
    """
    candidate = username.strip() if username else ""
    if len(candidate) >= 3:
        return candidate
    raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="用户名至少需要 3 个字符")
def enforce_password_strength(password: str) -> str:
    """Return ``password`` unchanged if it has at least 6 characters.

    Only the length is checked; the password is not stripped or otherwise
    normalized.

    Raises:
        HTTPException: 400 when the password is empty/``None`` or shorter
            than 6 characters.
    """
    if password and len(password) >= 6:
        return password
    raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="密码至少需要 6 位")
async def create_auth_token(user_id: int) -> Dict[str, Any]:
    """Issue and persist a fresh random bearer token for ``user_id``.

    The token is ``secrets.token_hex(32)`` (64 hex characters). It expires
    ``config.TOKEN_TTL_HOURS`` hours after creation.

    Returns:
        ``{"token": <token string>, "expires_at": <expiry in ISO-8601>}``.
    """

    def _issue(session: Session) -> Dict[str, Any]:
        value = secrets.token_hex(32)
        expiry = now_utc() + datetime.timedelta(hours=config.TOKEN_TTL_HOURS)
        record = AuthToken(token=value, user_id=user_id, expires_at=expiry)
        session.add(record)
        session.commit()
        return {"token": value, "expires_at": expiry.isoformat()}

    return await db_call(_issue)
async def revoke_token(token_value: str) -> None:
    """Delete the stored token whose value equals ``token_value``.

    Uses a bulk query delete, so an unknown token simply deletes nothing.
    """

    def _drop(session: Session) -> None:
        matching = session.query(AuthToken).filter(AuthToken.token == token_value)
        matching.delete()
        session.commit()

    await db_call(_drop)
async def resolve_token(token_value: str) -> Optional[UserInfo]:
    """Look up the user that owns the bearer token ``token_value``.

    Returns:
        A ``UserInfo`` for the token's user, or ``None`` when any of these hold:
        the token is unknown; it has expired; its user no longer exists.
        In the expired and missing-user cases the stale token row is deleted
        and committed as a side effect.
    """

    def _lookup(session: Session) -> Optional[UserInfo]:
        def _discard(stale: AuthToken) -> None:
            # Remove a token that can never authenticate again.
            session.delete(stale)
            session.commit()

        query = select(AuthToken).where(AuthToken.token == token_value)
        record = session.execute(query).scalar_one_or_none()
        if not record:
            return None

        if record.expires_at < now_utc():
            _discard(record)
            return None

        owner = session.get(User, record.user_id)
        if not owner:
            _discard(record)
            return None

        return UserInfo(id=owner.id, username=owner.username, role=owner.role)

    return await db_call(_lookup)
async def cleanup_expired_tokens() -> None:
    """Bulk-delete every stored token whose ``expires_at`` is before now."""

    def _purge(session: Session) -> None:
        cutoff = now_utc()
        expired = session.query(AuthToken).filter(AuthToken.expires_at < cutoff)
        expired.delete()
        session.commit()

    await db_call(_purge)
def ensure_default_admin() -> None:
    """Create the configured default admin account when no admin exists yet.

    Does nothing in two cases: ``SessionLocal`` is ``None`` (no database
    configured), or at least one user with role ``"admin"`` already exists.
    Credentials come from ``config.DEFAULT_ADMIN_USERNAME`` and
    ``config.DEFAULT_ADMIN_PASSWORD``.

    NOTE(review): if a non-admin account already uses the default username,
    this insert may collide depending on the table's constraints — confirm.
    """
    if SessionLocal is None:
        return

    with SessionLocal() as session:
        admin_query = select(User).where(User.role == "admin")
        if session.execute(admin_query).first():
            return

        fresh_salt = secrets.token_hex(8)
        session.add(
            User(
                username=config.DEFAULT_ADMIN_USERNAME,
                salt=fresh_salt,
                password_hash=hash_password(config.DEFAULT_ADMIN_PASSWORD, fresh_salt),
                role="admin",
            )
        )
        session.commit()
async def get_current_user(credentials: Optional[HTTPAuthorizationCredentials] = Depends(AUTH_SCHEME)) -> UserInfo:
    """FastAPI dependency that returns the user owning the request's bearer token.

    Raises:
        HTTPException: 401 when no bearer credentials were sent, or when the
            token does not resolve to a live user. Both responses carry a
            ``WWW-Authenticate: Bearer`` challenge header.
    """
    # A 401 response must advertise the accepted auth scheme
    # (RFC 7235 §3.1, RFC 6750 §3); clients and proxies rely on it.
    challenge = {"WWW-Authenticate": "Bearer"}
    if not credentials:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="请先登录",
            headers=challenge,
        )
    user = await resolve_token(credentials.credentials)
    if not user:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="登录已失效,请重新登录",
            headers=challenge,
        )
    return user
async def require_admin(current_user: UserInfo = Depends(get_current_user)) -> UserInfo:
    """FastAPI dependency that only admits users whose role is ``"admin"``.

    Authentication is delegated to ``get_current_user``.

    Raises:
        HTTPException: 403 for an authenticated user who is not an admin.
    """
    is_admin = current_user.role == "admin"
    if is_admin:
        return current_user
    raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="需要管理员权限")
async def admin_list_users(keyword: Optional[str], page: int, page_size: int) -> Dict[str, Any]:
    """Return one page of users, newest first, optionally filtered by username.

    Args:
        keyword: Substring to search for in usernames; surrounding whitespace
            is ignored. ``%`` and ``_`` are matched literally rather than as
            LIKE wildcards. A falsy or blank keyword applies no filter.
        page: Zero-based page index; negative values are treated as 0.
        page_size: Maximum number of items per page; values <= 0 yield an
            empty ``items`` list.

    Returns:
        ``{"items": [...], "total": <number of matching users>, "page": page,
        "page_size": page_size}``. ``page`` and ``page_size`` echo the
        arguments unchanged. Each item has ``id``, ``username``, ``role`` and
        ``created_at`` (ISO-8601; the current time if the column is empty).
    """
    # Local import so this block does not depend on changes to the module imports.
    from sqlalchemy import func

    limit = max(page_size, 0)
    offset = max(page, 0) * limit

    def lister(session: Session) -> Dict[str, Any]:
        count_stmt = select(func.count()).select_from(User)
        page_stmt = select(User)
        needle = keyword.strip() if keyword else ""
        if needle:
            # autoescape=True escapes LIKE wildcards so the search is a literal substring match.
            criterion = User.username.contains(needle, autoescape=True)
            count_stmt = count_stmt.where(criterion)
            page_stmt = page_stmt.where(criterion)

        # Count and paginate in the database instead of loading every user into memory.
        total = session.execute(count_stmt).scalar_one()
        if limit == 0:
            return {"items": [], "total": total, "page": page, "page_size": page_size}

        # ``id`` breaks ties between equal ``created_at`` values so that
        # consecutive pages neither repeat nor skip rows.
        page_stmt = (
            page_stmt.order_by(User.created_at.desc(), User.id.desc())
            .offset(offset)
            .limit(limit)
        )
        users = session.execute(page_stmt).scalars().all()
        items = [
            {
                "id": user.id,
                "username": user.username,
                "role": user.role,
                "created_at": (user.created_at or now_utc()).isoformat(),
            }
            for user in users
        ]
        return {"items": items, "total": total, "page": page, "page_size": page_size}

    return await db_call(lister)
async def admin_create_user(username: str, password: str) -> Dict[str, Any]:
    """Create a regular account (role ``"user"``) on behalf of an administrator.

    Both fields are validated before the database is touched.

    Returns:
        ``{"id", "username", "role", "created_at"}`` for the new account.
        ``created_at`` is ISO-8601, falling back to the current time if the
        column came back empty.

    Raises:
        HTTPException: 400 for an invalid username/password or a taken username.
    """
    clean_name = normalize_username(username)
    clean_password = enforce_password_strength(password)

    def _insert(session: Session) -> Dict[str, Any]:
        lookup = select(User).where(User.username == clean_name)
        if session.execute(lookup).scalar_one_or_none():
            raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="用户名已存在")

        user_salt = secrets.token_hex(8)
        account = User(
            username=clean_name,
            salt=user_salt,
            password_hash=hash_password(clean_password, user_salt),
            role="user",
        )
        session.add(account)
        session.commit()
        # Reload server-generated columns such as id/created_at.
        session.refresh(account)
        return {
            "id": account.id,
            "username": account.username,
            "role": account.role,
            "created_at": (account.created_at or now_utc()).isoformat(),
        }

    return await db_call(_insert)
async def admin_get_user(user_id: int) -> Dict[str, Any]:
    """Fetch one user's public fields by primary key.

    Returns:
        ``{"id", "username", "role", "created_at"}``. ``created_at`` is
        ISO-8601, falling back to the current time if the column is empty.

    Raises:
        HTTPException: 404 when no user has ``user_id``.
    """

    def _fetch(session: Session) -> Dict[str, Any]:
        found = session.get(User, user_id)
        if not found:
            raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="用户不存在")
        created = found.created_at or now_utc()
        return {
            "id": found.id,
            "username": found.username,
            "role": found.role,
            "created_at": created.isoformat(),
        }

    return await db_call(_fetch)
async def admin_update_user(user_id: int, payload: AdminUserUpdateRequest) -> Dict[str, Any]:
    """Change a non-admin user's username and/or password.

    Only truthy fields of ``payload`` are applied. Changing the password also
    deletes all of the user's auth tokens in the same transaction, so sessions
    opened with the old password stop working.

    Returns:
        ``{"id", "username", "role", "created_at"}`` after the update.

    Raises:
        HTTPException: 404 when the user does not exist. 400 for any of:
            admin accounts, an invalid username/password, or a username
            already used by another account.
    """

    def updater(session: Session) -> Dict[str, Any]:
        user = session.get(User, user_id)
        if not user:
            raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="用户不存在")
        if user.role == "admin":
            raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="无法修改管理员信息")

        if payload.username:
            new_username = normalize_username(payload.username)
            conflict = session.execute(
                select(User).where(User.username == new_username, User.id != user_id)
            ).scalar_one_or_none()
            if conflict:
                raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="用户名已被使用")
            user.username = new_username

        if payload.password:
            enforce_password_strength(payload.password)
            salt = secrets.token_hex(8)
            user.salt = salt
            user.password_hash = hash_password(payload.password, salt)
            # A password reset must end sessions authenticated with the old
            # password; committed together with the new hash below.
            session.query(AuthToken).filter(AuthToken.user_id == user_id).delete()

        session.commit()
        return {
            "id": user.id,
            "username": user.username,
            "role": user.role,
            "created_at": (user.created_at or now_utc()).isoformat(),
        }

    return await db_call(updater)
async def admin_delete_user(user_id: int) -> Dict[str, Any]:
    """Delete a non-admin user together with all of their auth tokens.

    The tokens are removed in the same transaction so they cannot outlive the
    account. If primary keys can be reused (e.g. SQLite rowids without
    AUTOINCREMENT), an orphaned token could otherwise authenticate as a later
    user that receives the same id. A non-cascading foreign key from tokens to
    users could also make the user delete fail.

    Returns:
        ``{"status": "ok"}``.

    Raises:
        HTTPException: 404 when the user does not exist; 400 for admin accounts.
    """

    def deleter(session: Session) -> Dict[str, Any]:
        user = session.get(User, user_id)
        if not user:
            raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="用户不存在")
        if user.role == "admin":
            raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="无法删除管理员")
        session.query(AuthToken).filter(AuthToken.user_id == user_id).delete()
        session.delete(user)
        session.commit()
        return {"status": "ok"}

    return await db_call(deleter)
async def register_user(username: str, password: str) -> Dict[str, Any]:
    """Self-register a new account with role ``"user"``.

    Both fields are validated before the database is touched.

    Returns:
        ``{"id", "username", "role"}`` of the newly created account.

    Raises:
        HTTPException: 400 for an invalid username/password or a taken username.
    """
    checked_name = normalize_username(username)
    checked_password = enforce_password_strength(password)

    def _register(session: Session) -> Dict[str, Any]:
        lookup = select(User).where(User.username == checked_name)
        if session.execute(lookup).scalar_one_or_none():
            raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="用户名已存在")

        user_salt = secrets.token_hex(8)
        newcomer = User(
            username=checked_name,
            salt=user_salt,
            password_hash=hash_password(checked_password, user_salt),
            role="user",
        )
        session.add(newcomer)
        session.commit()
        # Reload server-generated columns (e.g. the new id).
        session.refresh(newcomer)
        return {"id": newcomer.id, "username": newcomer.username, "role": newcomer.role}

    return await db_call(_register)
async def login_user(username: str, password: str) -> Dict[str, Any]:
    """Verify a username/password pair and return the matching user's fields.

    The username is stripped of surrounding whitespace but not otherwise
    validated. A ``None`` password is treated as empty.

    Returns:
        ``{"id", "username", "role"}`` of the authenticated user.

    Raises:
        HTTPException: 400 when the username is blank. Also 400, with one
            shared message, for both an unknown user and a wrong password,
            so the response does not reveal which check failed.
    """
    username = (username or "").strip()
    if not username:
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="请输入用户名")
    password = password or ""

    def verifier(session: Session) -> Dict[str, Any]:
        user = session.execute(select(User).where(User.username == username)).scalar_one_or_none()
        if not user:
            raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="用户名或密码错误")
        candidate = hash_password(password, user.salt)
        stored = user.password_hash or ""
        # Constant-time comparison: ``!=`` stops at the first differing
        # character and leaks, via timing, how much of the hash matched.
        if not secrets.compare_digest(candidate.encode("utf-8"), stored.encode("utf-8")):
            raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="用户名或密码错误")
        return {"id": user.id, "username": user.username, "role": user.role}

    return await db_call(verifier)
|