auth.py 9.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287
  1. """User and authentication related helpers."""
  2. from __future__ import annotations
  3. import datetime
  4. import secrets
  5. from typing import Any, Dict, Optional
  6. from fastapi import Depends, HTTPException, status
  7. from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
  8. from pydantic import BaseModel
  9. from sqlalchemy import select
  10. from sqlalchemy.orm import Session
  11. from .. import config
  12. from ..db import SessionLocal, db_call, now_utc
  13. from ..models import AuthToken, User
# Bearer-token extractor used as the dependency of get_current_user.
# auto_error=False lets a request without credentials reach get_current_user
# with credentials=None, so that function can raise its own 401 message.
AUTH_SCHEME = HTTPBearer(auto_error=False)
class UserInfo(BaseModel):
    """Identity of an authenticated user, as produced by resolve_token."""

    id: int
    username: str
    role: str  # require_admin compares this against "admin"
class RegisterRequest(BaseModel):
    """Username/password payload for creating a new account.

    NOTE(review): not referenced inside this module; presumably consumed by
    the route layer that calls register_user -- confirm.
    """

    username: str
    password: str
class LoginRequest(BaseModel):
    """Username/password payload for logging in.

    NOTE(review): not referenced inside this module; presumably consumed by
    the route layer that calls login_user -- confirm.
    """

    username: str
    password: str
class AdminUserRequest(BaseModel):
    """Username/password payload for an administrator creating a user.

    NOTE(review): not referenced inside this module; presumably consumed by
    the route layer that calls admin_create_user -- confirm.
    """

    username: str
    password: str
class AdminUserUpdateRequest(BaseModel):
    """Partial update payload accepted by admin_update_user.

    Both fields are optional; admin_update_user skips any field that is
    None or empty.
    """

    username: Optional[str] = None
    password: Optional[str] = None
  31. def hash_password(password: str, salt: str) -> str:
  32. import hashlib
  33. return hashlib.sha256((password + salt).encode("utf-8")).hexdigest()
  34. def normalize_username(username: str) -> str:
  35. normalized = (username or "").strip()
  36. if len(normalized) < 3:
  37. raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="用户名至少需要 3 个字符")
  38. return normalized
  39. def enforce_password_strength(password: str) -> str:
  40. if not password or len(password) < 6:
  41. raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="密码至少需要 6 位")
  42. return password
  43. async def create_auth_token(user_id: int) -> Dict[str, Any]:
  44. def creator(session: Session) -> Dict[str, Any]:
  45. token_value = secrets.token_hex(32)
  46. expires_at = now_utc() + datetime.timedelta(hours=config.TOKEN_TTL_HOURS)
  47. token = AuthToken(token=token_value, user_id=user_id, expires_at=expires_at)
  48. session.add(token)
  49. session.commit()
  50. return {"token": token_value, "expires_at": expires_at.isoformat()}
  51. return await db_call(creator)
  52. async def revoke_token(token_value: str) -> None:
  53. def remover(session: Session) -> None:
  54. session.query(AuthToken).filter(AuthToken.token == token_value).delete()
  55. session.commit()
  56. await db_call(remover)
  57. async def resolve_token(token_value: str) -> Optional[UserInfo]:
  58. def resolver(session: Session) -> Optional[UserInfo]:
  59. token = session.execute(select(AuthToken).where(AuthToken.token == token_value)).scalar_one_or_none()
  60. if not token:
  61. return None
  62. if token.expires_at < now_utc():
  63. session.delete(token)
  64. session.commit()
  65. return None
  66. user = session.get(User, token.user_id)
  67. if not user:
  68. session.delete(token)
  69. session.commit()
  70. return None
  71. return UserInfo(id=user.id, username=user.username, role=user.role)
  72. return await db_call(resolver)
  73. async def cleanup_expired_tokens() -> None:
  74. def cleaner(session: Session) -> None:
  75. session.query(AuthToken).filter(AuthToken.expires_at < now_utc()).delete()
  76. session.commit()
  77. await db_call(cleaner)
  78. def ensure_default_admin() -> None:
  79. if SessionLocal is None:
  80. return
  81. with SessionLocal() as session:
  82. existing = session.execute(select(User).where(User.role == "admin")).first()
  83. if existing:
  84. return
  85. salt = secrets.token_hex(8)
  86. admin = User(
  87. username=config.DEFAULT_ADMIN_USERNAME,
  88. salt=salt,
  89. password_hash=hash_password(config.DEFAULT_ADMIN_PASSWORD, salt),
  90. role="admin",
  91. )
  92. session.add(admin)
  93. session.commit()
  94. async def get_current_user(credentials: Optional[HTTPAuthorizationCredentials] = Depends(AUTH_SCHEME)) -> UserInfo:
  95. if not credentials:
  96. raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="请先登录")
  97. user = await resolve_token(credentials.credentials)
  98. if not user:
  99. raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="登录已失效,请重新登录")
  100. return user
  101. async def require_admin(current_user: UserInfo = Depends(get_current_user)) -> UserInfo:
  102. if current_user.role != "admin":
  103. raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="需要管理员权限")
  104. return current_user
  105. async def admin_list_users(keyword: Optional[str], page: int, page_size: int) -> Dict[str, Any]:
  106. def lister(session: Session) -> Dict[str, Any]:
  107. stmt = select(User).order_by(User.created_at.desc())
  108. if keyword:
  109. stmt = stmt.where(User.username.like(f"%{keyword.strip()}%"))
  110. users = session.execute(stmt).scalars().all()
  111. total = len(users)
  112. start = max(page, 0) * page_size
  113. end = start + page_size
  114. subset = users[start:end]
  115. items = [
  116. {
  117. "id": user.id,
  118. "username": user.username,
  119. "role": user.role,
  120. "created_at": (user.created_at or now_utc()).isoformat(),
  121. }
  122. for user in subset
  123. ]
  124. return {"items": items, "total": total, "page": page, "page_size": page_size}
  125. return await db_call(lister)
  126. async def admin_create_user(username: str, password: str) -> Dict[str, Any]:
  127. username = normalize_username(username)
  128. password = enforce_password_strength(password)
  129. def creator(session: Session) -> Dict[str, Any]:
  130. existing = session.execute(select(User).where(User.username == username)).scalar_one_or_none()
  131. if existing:
  132. raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="用户名已存在")
  133. salt = secrets.token_hex(8)
  134. user = User(username=username, salt=salt, password_hash=hash_password(password, salt), role="user")
  135. session.add(user)
  136. session.commit()
  137. session.refresh(user)
  138. return {
  139. "id": user.id,
  140. "username": user.username,
  141. "role": user.role,
  142. "created_at": (user.created_at or now_utc()).isoformat(),
  143. }
  144. return await db_call(creator)
  145. async def admin_get_user(user_id: int) -> Dict[str, Any]:
  146. def getter(session: Session) -> Dict[str, Any]:
  147. user = session.get(User, user_id)
  148. if not user:
  149. raise HTTPException(status_code=404, detail="用户不存在")
  150. return {
  151. "id": user.id,
  152. "username": user.username,
  153. "role": user.role,
  154. "created_at": (user.created_at or now_utc()).isoformat(),
  155. }
  156. return await db_call(getter)
  157. async def admin_update_user(user_id: int, payload: AdminUserUpdateRequest) -> Dict[str, Any]:
  158. def updater(session: Session) -> Dict[str, Any]:
  159. user = session.get(User, user_id)
  160. if not user:
  161. raise HTTPException(status_code=404, detail="用户不存在")
  162. if user.role == "admin":
  163. raise HTTPException(status_code=400, detail="无法修改管理员信息")
  164. if payload.username:
  165. new_username = normalize_username(payload.username)
  166. conflict = (
  167. session.execute(select(User).where(User.username == new_username, User.id != user_id)).scalar_one_or_none()
  168. )
  169. if conflict:
  170. raise HTTPException(status_code=400, detail="用户名已被使用")
  171. user.username = new_username
  172. if payload.password:
  173. enforce_password_strength(payload.password)
  174. salt = secrets.token_hex(8)
  175. user.salt = salt
  176. user.password_hash = hash_password(payload.password, salt)
  177. session.commit()
  178. return {
  179. "id": user.id,
  180. "username": user.username,
  181. "role": user.role,
  182. "created_at": (user.created_at or now_utc()).isoformat(),
  183. }
  184. return await db_call(updater)
  185. async def admin_delete_user(user_id: int) -> Dict[str, Any]:
  186. def deleter(session: Session) -> Dict[str, Any]:
  187. user = session.get(User, user_id)
  188. if not user:
  189. raise HTTPException(status_code=404, detail="用户不存在")
  190. if user.role == "admin":
  191. raise HTTPException(status_code=400, detail="无法删除管理员")
  192. session.delete(user)
  193. session.commit()
  194. return {"status": "ok"}
  195. return await db_call(deleter)
  196. async def register_user(username: str, password: str) -> Dict[str, Any]:
  197. username = normalize_username(username)
  198. password = enforce_password_strength(password)
  199. def creator(session: Session) -> Dict[str, Any]:
  200. existing = session.execute(select(User).where(User.username == username)).scalar_one_or_none()
  201. if existing:
  202. raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="用户名已存在")
  203. salt = secrets.token_hex(8)
  204. user = User(username=username, salt=salt, password_hash=hash_password(password, salt), role="user")
  205. session.add(user)
  206. session.commit()
  207. session.refresh(user)
  208. return {"id": user.id, "username": user.username, "role": user.role}
  209. return await db_call(creator)
  210. async def login_user(username: str, password: str) -> Dict[str, Any]:
  211. username = (username or "").strip()
  212. if not username:
  213. raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="请输入用户名")
  214. password = password or ""
  215. def verifier(session: Session) -> Dict[str, Any]:
  216. user = session.execute(select(User).where(User.username == username)).scalar_one_or_none()
  217. if not user:
  218. raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="用户名或密码错误")
  219. hashed = hash_password(password, user.salt)
  220. if hashed != user.password_hash:
  221. raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="用户名或密码错误")
  222. return {"id": user.id, "username": user.username, "role": user.role}
  223. return await db_call(verifier)