chat.py 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455
  1. """Chat session helpers."""
  2. from __future__ import annotations
  3. import asyncio
  4. import json
  5. from pathlib import Path
  6. from typing import Any, Dict, List, Optional
  7. from fastapi import HTTPException, status
  8. from sqlalchemy import select, func, text
  9. from sqlalchemy.orm import Session
  10. from .. import config
  11. from .. import db as db_core
  12. from ..db import FILE_LOCK, MessageContent, db_call, now_utc
  13. from ..models import ChatMessage, ChatSession, ExportedContent, User
  14. def serialize_content(content: MessageContent) -> str:
  15. return json.dumps(content, ensure_ascii=False)
  16. def deserialize_content(raw: str) -> MessageContent:
  17. try:
  18. return json.loads(raw)
  19. except json.JSONDecodeError:
  20. return raw
  21. def text_from_content(content: MessageContent) -> str:
  22. if isinstance(content, str):
  23. return content
  24. if isinstance(content, list):
  25. pieces: List[str] = []
  26. for part in content:
  27. if part.get("type") == "text":
  28. pieces.append(part.get("text", ""))
  29. return " ".join(pieces)
  30. return str(content)
  31. def extract_history_title(messages: List[Dict[str, Any]]) -> str:
  32. for message in messages:
  33. if message.get("role") != "user":
  34. continue
  35. title = text_from_content(message.get("content", "")).strip()
  36. if title:
  37. return title[:10]
  38. if messages:
  39. fallback = text_from_content(messages[0].get("content", "")).strip()
  40. if fallback:
  41. return fallback[:10]
  42. return "空的聊天"[:10]
  43. def history_backup_path(session_id: int) -> Path:
  44. return config.BACKUP_DIR / f"chat_history_{session_id}.json"
  45. def build_download_url(filename: str) -> str:
  46. base = config.DOWNLOAD_BASE or ""
  47. return f"{base}/{filename}" if base else filename
  48. def ensure_session_numbering() -> None:
  49. # make sure the database is ready before touching schema
  50. db_core.ensure_database_initialized()
  51. _ensure_session_number_column()
  52. _backfill_session_numbers()
  53. def _ensure_session_number_column() -> None:
  54. engine = db_core.ENGINE
  55. if engine is None:
  56. db_core.ensure_database_initialized()
  57. engine = db_core.ENGINE
  58. if engine is None:
  59. return
  60. with engine.begin() as connection:
  61. exists = connection.execute(
  62. text(
  63. """
  64. SELECT 1 FROM information_schema.columns
  65. WHERE table_schema = :schema
  66. AND table_name = 'chat_sessions'
  67. AND column_name = 'user_session_no'
  68. """
  69. ),
  70. {"schema": config.DATABASE_NAME},
  71. ).first()
  72. if exists:
  73. return
  74. connection.execute(text("ALTER TABLE chat_sessions ADD COLUMN user_session_no INT NOT NULL DEFAULT 0"))
  75. def _backfill_session_numbers() -> None:
  76. session_factory = db_core.SessionLocal
  77. if session_factory is None:
  78. db_core.ensure_database_initialized()
  79. session_factory = db_core.SessionLocal
  80. if session_factory is None:
  81. return
  82. with session_factory() as session:
  83. users = session.execute(select(User.id)).scalars().all()
  84. for user_id in users:
  85. sessions = (
  86. session.execute(
  87. select(ChatSession)
  88. .where(ChatSession.user_id == user_id)
  89. .order_by(ChatSession.created_at, ChatSession.id)
  90. )
  91. .scalars()
  92. .all()
  93. )
  94. dirty = False
  95. for index, chat_session in enumerate(sessions, start=1):
  96. if chat_session.user_session_no != index:
  97. chat_session.user_session_no = index
  98. dirty = True
  99. if dirty:
  100. session.commit()
  101. def _next_session_number(session: Session, user_id: int) -> int:
  102. current = (
  103. session.execute(select(func.max(ChatSession.user_session_no)).where(ChatSession.user_id == user_id)).scalar()
  104. or 0
  105. )
  106. return current + 1
  107. async def get_session_payload(session_id: int, user_id: int, allow_archived: bool = False) -> Dict[str, Any]:
  108. def loader(session: Session) -> Dict[str, Any]:
  109. stmt = select(ChatSession).where(ChatSession.id == session_id, ChatSession.user_id == user_id)
  110. chat_session = session.execute(stmt).scalar_one_or_none()
  111. if not chat_session:
  112. raise HTTPException(status_code=404, detail="会话不存在")
  113. if chat_session.archived and not allow_archived:
  114. raise HTTPException(status_code=404, detail="会话不存在")
  115. messages = (
  116. session.execute(
  117. select(ChatMessage).where(ChatMessage.session_id == chat_session.id).order_by(ChatMessage.id)
  118. )
  119. .scalars()
  120. .all()
  121. )
  122. payload = [{"role": msg.role, "content": deserialize_content(msg.content)} for msg in messages]
  123. if not chat_session.user_session_no:
  124. chat_session.user_session_no = _next_session_number(session, user_id)
  125. session.commit()
  126. return {
  127. "session_id": chat_session.id,
  128. "session_number": chat_session.user_session_no,
  129. "messages": payload,
  130. "archived": chat_session.archived,
  131. }
  132. return await db_call(loader)
  133. async def load_messages(session_id: int, user_id: int) -> List[Dict[str, Any]]:
  134. payload = await get_session_payload(session_id, user_id, allow_archived=True)
  135. return payload["messages"]
  136. async def ensure_active_session(session_id: Optional[int], user_id: int) -> Dict[str, Any]:
  137. if session_id:
  138. try:
  139. payload = await get_session_payload(session_id, user_id, allow_archived=False)
  140. return payload
  141. except HTTPException as exc:
  142. if exc.status_code != 404:
  143. raise
  144. return await create_chat_session(user_id)
  145. async def create_chat_session(user_id: int) -> Dict[str, Any]:
  146. def creator(session: Session) -> Dict[str, Any]:
  147. chat_session = ChatSession(user_id=user_id, user_session_no=_next_session_number(session, user_id))
  148. session.add(chat_session)
  149. session.commit()
  150. session.refresh(chat_session)
  151. return {"session_id": chat_session.id, "session_number": chat_session.user_session_no, "messages": []}
  152. return await db_call(creator)
  153. async def get_latest_session(user_id: int) -> Dict[str, Any]:
  154. def loader(session: Session) -> Dict[str, Any]:
  155. stmt = (
  156. select(ChatSession)
  157. .where(ChatSession.user_id == user_id, ChatSession.archived.is_(False))
  158. .order_by(ChatSession.updated_at.desc())
  159. )
  160. chat_session = session.execute(stmt).scalars().first()
  161. if not chat_session:
  162. chat_session = ChatSession(user_id=user_id)
  163. session.add(chat_session)
  164. session.commit()
  165. session.refresh(chat_session)
  166. messages: List[Dict[str, Any]] = []
  167. else:
  168. messages = (
  169. session.execute(
  170. select(ChatMessage).where(ChatMessage.session_id == chat_session.id).order_by(ChatMessage.id)
  171. )
  172. .scalars()
  173. .all()
  174. )
  175. messages = [{"role": msg.role, "content": deserialize_content(msg.content)} for msg in messages]
  176. if not chat_session.user_session_no:
  177. chat_session.user_session_no = _next_session_number(session, user_id)
  178. session.commit()
  179. return {"session_id": chat_session.id, "session_number": chat_session.user_session_no, "messages": messages}
  180. return await db_call(loader)
  181. async def append_message(session_id: int, user_id: int, role: str, content: MessageContent) -> None:
  182. def writer(session: Session) -> None:
  183. stmt = select(ChatSession).where(ChatSession.id == session_id, ChatSession.user_id == user_id)
  184. chat_session = session.execute(stmt).scalar_one_or_none()
  185. if not chat_session:
  186. raise HTTPException(status_code=404, detail="会话不存在")
  187. serialized = serialize_content(content)
  188. message = ChatMessage(session_id=chat_session.id, role=role, content=serialized)
  189. session.add(message)
  190. chat_session.updated_at = now_utc()
  191. if role == "user" and (not chat_session.title or not chat_session.title.strip()):
  192. candidate = text_from_content(content).strip()
  193. if candidate:
  194. chat_session.title = candidate[:30]
  195. session.commit()
  196. await db_call(writer)
  197. async def list_history(user_id: int, page: int, page_size: int) -> Dict[str, Any]:
  198. def lister(session: Session) -> Dict[str, Any]:
  199. stmt = (
  200. select(ChatSession)
  201. .where(ChatSession.user_id == user_id, ChatSession.archived.is_(False))
  202. .order_by(ChatSession.updated_at.desc())
  203. )
  204. sessions = session.execute(stmt).scalars().all()
  205. if not sessions:
  206. fresh = ChatSession(user_id=user_id, user_session_no=_next_session_number(session, user_id))
  207. session.add(fresh)
  208. session.commit()
  209. session.refresh(fresh)
  210. sessions = [fresh]
  211. total = len(sessions)
  212. start = max(page, 0) * page_size
  213. end = start + page_size
  214. subset = sessions[start:end]
  215. items: List[Dict[str, Any]] = []
  216. for item in subset:
  217. items.append(
  218. {
  219. "session_id": item.id,
  220. "session_number": item.user_session_no or 0,
  221. "title": (item.title or f"会话 #{item.id}")[:30],
  222. "updated_at": (item.updated_at or now_utc()).isoformat(),
  223. "filename": f"session_{item.id}.json",
  224. }
  225. )
  226. return {"page": page, "page_size": page_size, "total": total, "items": items}
  227. return await db_call(lister)
  228. async def move_history_file(user_id: int, session_id: int) -> None:
  229. def mark_archived(session: Session) -> List[Dict[str, Any]]:
  230. stmt = select(ChatSession).where(ChatSession.id == session_id, ChatSession.user_id == user_id)
  231. chat_session = session.execute(stmt).scalar_one_or_none()
  232. if not chat_session:
  233. raise HTTPException(status_code=404, detail="历史记录不存在")
  234. messages = (
  235. session.execute(
  236. select(ChatMessage).where(ChatMessage.session_id == chat_session.id).order_by(ChatMessage.id)
  237. )
  238. .scalars()
  239. .all()
  240. )
  241. payload = [{"role": msg.role, "content": deserialize_content(msg.content)} for msg in messages]
  242. chat_session.archived = True
  243. chat_session.updated_at = now_utc()
  244. session.commit()
  245. return payload
  246. messages = await db_call(mark_archived)
  247. backup_file = history_backup_path(session_id)
  248. backup_file.parent.mkdir(parents=True, exist_ok=True)
  249. def _write() -> None:
  250. with backup_file.open("w", encoding="utf-8") as fp:
  251. json.dump(messages, fp, ensure_ascii=False, indent=2)
  252. async with FILE_LOCK:
  253. await asyncio.to_thread(_write) # type: ignore[name-defined]
  254. async def delete_history_file(user_id: int, session_id: int) -> None:
  255. def deleter(session: Session) -> None:
  256. stmt = select(ChatSession).where(ChatSession.id == session_id, ChatSession.user_id == user_id)
  257. chat_session = session.execute(stmt).scalar_one_or_none()
  258. if not chat_session:
  259. raise HTTPException(status_code=404, detail="历史记录不存在")
  260. session.delete(chat_session)
  261. session.commit()
  262. await db_call(deleter)
  263. async def export_message_to_blog(content: MessageContent) -> str:
  264. processed = text_from_content(content)
  265. processed = processed.replace("\r\n", "\n")
  266. timestamp = now_utc().strftime("%m%d%H%M")
  267. first_10 = (
  268. processed[:10]
  269. .replace(" ", "")
  270. .replace("/", "")
  271. .replace("\\", "")
  272. .replace(":", "")
  273. .replace("`", "")
  274. )
  275. filename = f"{timestamp}_{first_10 or 'export'}.txt"
  276. path = config.BLOG_DIR / filename
  277. def _write() -> None:
  278. with path.open("w", encoding="utf-8") as fp:
  279. fp.write(processed)
  280. await asyncio.to_thread(_write) # type: ignore[name-defined]
  281. return str(path)
  282. async def record_export_entry(user_id: int, session_id: Optional[int], file_path: str, content: MessageContent) -> Dict[str, Any]:
  283. def recorder(session: Session) -> Dict[str, Any]:
  284. filename = Path(file_path).name
  285. preview = text_from_content(content).strip()[:200]
  286. export = ExportedContent(
  287. user_id=user_id,
  288. source_session_id=session_id,
  289. filename=filename,
  290. file_path=file_path,
  291. content_preview=preview,
  292. )
  293. session.add(export)
  294. session.commit()
  295. session.refresh(export)
  296. user = session.get(User, user_id)
  297. username = user.username if user else ""
  298. return {
  299. "id": export.id,
  300. "user_id": user_id,
  301. "username": username,
  302. "filename": filename,
  303. "file_path": file_path,
  304. "created_at": (export.created_at or now_utc()).isoformat(),
  305. "content_preview": preview,
  306. }
  307. return await db_call(recorder)
  308. async def list_exports_for_user(user_id: int) -> List[Dict[str, Any]]:
  309. def lister(session: Session) -> List[Dict[str, Any]]:
  310. stmt = (
  311. select(ExportedContent)
  312. .where(ExportedContent.user_id == user_id)
  313. .order_by(ExportedContent.created_at.desc())
  314. )
  315. exports = session.execute(stmt).scalars().all()
  316. user = session.get(User, user_id)
  317. username = user.username if user else ""
  318. results: List[Dict[str, Any]] = []
  319. for item in exports:
  320. results.append(
  321. {
  322. "id": item.id,
  323. "user_id": item.user_id,
  324. "username": username,
  325. "filename": item.filename,
  326. "file_path": item.file_path,
  327. "created_at": (item.created_at or now_utc()).isoformat(),
  328. "content_preview": (item.content_preview or "")[:200],
  329. }
  330. )
  331. return results
  332. return await db_call(lister)
  333. async def list_exports_admin(keyword: Optional[str] = None) -> List[Dict[str, Any]]:
  334. def lister(session: Session) -> List[Dict[str, Any]]:
  335. stmt = select(ExportedContent, User.username).join(User, ExportedContent.user_id == User.id)
  336. if keyword:
  337. stmt = stmt.where(User.username.like(f"%{keyword.strip()}%"))
  338. stmt = stmt.order_by(ExportedContent.created_at.desc())
  339. rows = session.execute(stmt).all()
  340. results: List[Dict[str, Any]] = []
  341. for export, username in rows:
  342. results.append(
  343. {
  344. "id": export.id,
  345. "user_id": export.user_id,
  346. "username": username,
  347. "filename": export.filename,
  348. "file_path": export.file_path,
  349. "created_at": (export.created_at or now_utc()).isoformat(),
  350. "content_preview": (export.content_preview or "")[:200],
  351. }
  352. )
  353. return results
  354. return await db_call(lister)
  355. async def get_export_record(export_id: int) -> Optional[Dict[str, Any]]:
  356. def fetcher(session: Session) -> Optional[Dict[str, Any]]:
  357. export = session.get(ExportedContent, export_id)
  358. if not export:
  359. return None
  360. user = session.get(User, export.user_id)
  361. username = user.username if user else ""
  362. return {
  363. "id": export.id,
  364. "user_id": export.user_id,
  365. "username": username,
  366. "filename": export.filename,
  367. "file_path": export.file_path,
  368. }
  369. return await db_call(fetcher)
  370. async def prepare_messages_for_completion(
  371. messages: List[Dict[str, Any]],
  372. user_content: MessageContent,
  373. history_count: int,
  374. ) -> List[Dict[str, Any]]:
  375. if history_count > 0:
  376. trimmed = messages[-history_count:]
  377. if trimmed:
  378. return trimmed
  379. return [{"role": "user", "content": user_content}]
  380. async def save_assistant_message(session_id: int, user_id: int, messages: List[Dict[str, Any]], content: MessageContent) -> None:
  381. messages.append({"role": "assistant", "content": content})
  382. await append_message(session_id, user_id, "assistant", content)