"""Chat session helpers.""" from __future__ import annotations import asyncio import json from pathlib import Path from typing import Any, Dict, List, Optional from fastapi import HTTPException, status from sqlalchemy import select, func, text from sqlalchemy.orm import Session from .. import config from .. import db as db_core from ..db import FILE_LOCK, MessageContent, db_call, now_utc from ..models import ChatMessage, ChatSession, ExportedContent, User def serialize_content(content: MessageContent) -> str: return json.dumps(content, ensure_ascii=False) def deserialize_content(raw: str) -> MessageContent: try: return json.loads(raw) except json.JSONDecodeError: return raw def text_from_content(content: MessageContent) -> str: if isinstance(content, str): return content if isinstance(content, list): pieces: List[str] = [] for part in content: if part.get("type") == "text": pieces.append(part.get("text", "")) return " ".join(pieces) return str(content) def extract_history_title(messages: List[Dict[str, Any]]) -> str: for message in messages: if message.get("role") != "user": continue title = text_from_content(message.get("content", "")).strip() if title: return title[:10] if messages: fallback = text_from_content(messages[0].get("content", "")).strip() if fallback: return fallback[:10] return "空的聊天"[:10] def history_backup_path(session_id: int) -> Path: return config.BACKUP_DIR / f"chat_history_{session_id}.json" def build_download_url(filename: str) -> str: base = config.DOWNLOAD_BASE or "" return f"{base}/{filename}" if base else filename def ensure_session_numbering() -> None: # make sure the database is ready before touching schema db_core.ensure_database_initialized() _ensure_session_number_column() _backfill_session_numbers() def _ensure_session_number_column() -> None: engine = db_core.ENGINE if engine is None: db_core.ensure_database_initialized() engine = db_core.ENGINE if engine is None: return with engine.begin() as connection: exists = 
connection.execute( text( """ SELECT 1 FROM information_schema.columns WHERE table_schema = :schema AND table_name = 'chat_sessions' AND column_name = 'user_session_no' """ ), {"schema": config.DATABASE_NAME}, ).first() if exists: return connection.execute(text("ALTER TABLE chat_sessions ADD COLUMN user_session_no INT NOT NULL DEFAULT 0")) def _backfill_session_numbers() -> None: session_factory = db_core.SessionLocal if session_factory is None: db_core.ensure_database_initialized() session_factory = db_core.SessionLocal if session_factory is None: return with session_factory() as session: users = session.execute(select(User.id)).scalars().all() for user_id in users: sessions = ( session.execute( select(ChatSession) .where(ChatSession.user_id == user_id) .order_by(ChatSession.created_at, ChatSession.id) ) .scalars() .all() ) dirty = False for index, chat_session in enumerate(sessions, start=1): if chat_session.user_session_no != index: chat_session.user_session_no = index dirty = True if dirty: session.commit() def _next_session_number(session: Session, user_id: int) -> int: current = ( session.execute(select(func.max(ChatSession.user_session_no)).where(ChatSession.user_id == user_id)).scalar() or 0 ) return current + 1 async def get_session_payload(session_id: int, user_id: int, allow_archived: bool = False) -> Dict[str, Any]: def loader(session: Session) -> Dict[str, Any]: stmt = select(ChatSession).where(ChatSession.id == session_id, ChatSession.user_id == user_id) chat_session = session.execute(stmt).scalar_one_or_none() if not chat_session: raise HTTPException(status_code=404, detail="会话不存在") if chat_session.archived and not allow_archived: raise HTTPException(status_code=404, detail="会话不存在") messages = ( session.execute( select(ChatMessage).where(ChatMessage.session_id == chat_session.id).order_by(ChatMessage.id) ) .scalars() .all() ) payload = [{"role": msg.role, "content": deserialize_content(msg.content)} for msg in messages] if not 
chat_session.user_session_no: chat_session.user_session_no = _next_session_number(session, user_id) session.commit() return { "session_id": chat_session.id, "session_number": chat_session.user_session_no, "messages": payload, "archived": chat_session.archived, } return await db_call(loader) async def load_messages(session_id: int, user_id: int) -> List[Dict[str, Any]]: payload = await get_session_payload(session_id, user_id, allow_archived=True) return payload["messages"] async def ensure_active_session(session_id: Optional[int], user_id: int) -> Dict[str, Any]: if session_id: try: payload = await get_session_payload(session_id, user_id, allow_archived=False) return payload except HTTPException as exc: if exc.status_code != 404: raise return await create_chat_session(user_id) async def create_chat_session(user_id: int) -> Dict[str, Any]: def creator(session: Session) -> Dict[str, Any]: chat_session = ChatSession(user_id=user_id, user_session_no=_next_session_number(session, user_id)) session.add(chat_session) session.commit() session.refresh(chat_session) return {"session_id": chat_session.id, "session_number": chat_session.user_session_no, "messages": []} return await db_call(creator) async def get_latest_session(user_id: int) -> Dict[str, Any]: def loader(session: Session) -> Dict[str, Any]: stmt = ( select(ChatSession) .where(ChatSession.user_id == user_id, ChatSession.archived.is_(False)) .order_by(ChatSession.updated_at.desc()) ) chat_session = session.execute(stmt).scalars().first() if not chat_session: chat_session = ChatSession(user_id=user_id) session.add(chat_session) session.commit() session.refresh(chat_session) messages: List[Dict[str, Any]] = [] else: messages = ( session.execute( select(ChatMessage).where(ChatMessage.session_id == chat_session.id).order_by(ChatMessage.id) ) .scalars() .all() ) messages = [{"role": msg.role, "content": deserialize_content(msg.content)} for msg in messages] if not chat_session.user_session_no: 
chat_session.user_session_no = _next_session_number(session, user_id) session.commit() return {"session_id": chat_session.id, "session_number": chat_session.user_session_no, "messages": messages} return await db_call(loader) async def append_message(session_id: int, user_id: int, role: str, content: MessageContent) -> None: def writer(session: Session) -> None: stmt = select(ChatSession).where(ChatSession.id == session_id, ChatSession.user_id == user_id) chat_session = session.execute(stmt).scalar_one_or_none() if not chat_session: raise HTTPException(status_code=404, detail="会话不存在") serialized = serialize_content(content) message = ChatMessage(session_id=chat_session.id, role=role, content=serialized) session.add(message) chat_session.updated_at = now_utc() if role == "user" and (not chat_session.title or not chat_session.title.strip()): candidate = text_from_content(content).strip() if candidate: chat_session.title = candidate[:30] session.commit() await db_call(writer) async def list_history(user_id: int, page: int, page_size: int) -> Dict[str, Any]: def lister(session: Session) -> Dict[str, Any]: stmt = ( select(ChatSession) .where(ChatSession.user_id == user_id, ChatSession.archived.is_(False)) .order_by(ChatSession.updated_at.desc()) ) sessions = session.execute(stmt).scalars().all() if not sessions: fresh = ChatSession(user_id=user_id, user_session_no=_next_session_number(session, user_id)) session.add(fresh) session.commit() session.refresh(fresh) sessions = [fresh] total = len(sessions) start = max(page, 0) * page_size end = start + page_size subset = sessions[start:end] items: List[Dict[str, Any]] = [] for item in subset: items.append( { "session_id": item.id, "session_number": item.user_session_no or 0, "title": (item.title or f"会话 #{item.id}")[:30], "updated_at": (item.updated_at or now_utc()).isoformat(), "filename": f"session_{item.id}.json", } ) return {"page": page, "page_size": page_size, "total": total, "items": items} return await db_call(lister) 
async def move_history_file(user_id: int, session_id: int) -> None:
    """Archive a session owned by *user_id* and write its messages to a JSON backup.

    The archived flag is committed first; the backup file (see
    ``history_backup_path``) is then written under ``FILE_LOCK``.

    Raises:
        HTTPException: 404 when the session does not belong to the user.
    """

    def mark_archived(session: Session) -> List[Dict[str, Any]]:
        stmt = select(ChatSession).where(ChatSession.id == session_id, ChatSession.user_id == user_id)
        chat_session = session.execute(stmt).scalar_one_or_none()
        if not chat_session:
            raise HTTPException(status_code=404, detail="历史记录不存在")
        rows = session.execute(
            select(ChatMessage).where(ChatMessage.session_id == chat_session.id).order_by(ChatMessage.id)
        ).scalars()
        payload = [{"role": msg.role, "content": deserialize_content(msg.content)} for msg in rows]
        chat_session.archived = True
        chat_session.updated_at = now_utc()
        session.commit()
        return payload

    messages = await db_call(mark_archived)
    backup_file = history_backup_path(session_id)

    def _write() -> None:
        backup_file.parent.mkdir(parents=True, exist_ok=True)
        with backup_file.open("w", encoding="utf-8") as fp:
            json.dump(messages, fp, ensure_ascii=False, indent=2)

    async with FILE_LOCK:
        await asyncio.to_thread(_write)


async def delete_history_file(user_id: int, session_id: int) -> None:
    """Delete a session owned by *user_id*.

    NOTE(review): whether the session's messages are removed too depends on the
    model's cascade / foreign-key configuration, which is not visible here.

    Raises:
        HTTPException: 404 when the session does not belong to the user.
    """

    def deleter(session: Session) -> None:
        stmt = select(ChatSession).where(ChatSession.id == session_id, ChatSession.user_id == user_id)
        chat_session = session.execute(stmt).scalar_one_or_none()
        if not chat_session:
            raise HTTPException(status_code=404, detail="历史记录不存在")
        session.delete(chat_session)
        session.commit()

    await db_call(deleter)


def _safe_filename_stem(value: str, length: int = 10) -> str:
    """Return the first *length* characters of *value* with filename-unsafe characters dropped.

    Removes path separators, characters reserved on common filesystems, backticks,
    and every whitespace / non-printable character (newlines, tabs, NUL, ...).
    """
    unsafe = set(' /\\:`*?"<>|')
    return "".join(ch for ch in value[:length] if ch not in unsafe and ch.isprintable() and not ch.isspace())


async def export_message_to_blog(content: MessageContent) -> str:
    """Write the message text to a new ``.txt`` file in ``config.BLOG_DIR``.

    The filename is ``<MMDDHHMM>_<first 10 sanitized chars>.txt`` (``export`` when
    nothing usable remains). Because the timestamp only has minute resolution, the
    file is created exclusively and a ``_<n>`` suffix is added on collision instead
    of overwriting an earlier export.

    Returns:
        The path of the written file, as a string.
    """
    processed = text_from_content(content).replace("\r\n", "\n")
    timestamp = now_utc().strftime("%m%d%H%M")
    stem = f"{timestamp}_{_safe_filename_stem(processed) or 'export'}"
    blog_dir = config.BLOG_DIR

    def _write() -> Path:
        blog_dir.mkdir(parents=True, exist_ok=True)
        attempt = 0
        while True:
            name = f"{stem}.txt" if attempt == 0 else f"{stem}_{attempt}.txt"
            target = blog_dir / name
            try:
                # "x" fails if the file exists, so concurrent/same-minute exports never clobber each other.
                with target.open("x", encoding="utf-8") as fp:
                    fp.write(processed)
            except FileExistsError:
                attempt += 1
                continue
            return target

    path = await asyncio.to_thread(_write)
    return str(path)


def _export_to_dict(export: ExportedContent, username: str) -> Dict[str, Any]:
    """Build the API dict for one export record (preview capped at 200 chars)."""
    return {
        "id": export.id,
        "user_id": export.user_id,
        "username": username,
        "filename": export.filename,
        "file_path": export.file_path,
        "created_at": (export.created_at or now_utc()).isoformat(),
        "content_preview": (export.content_preview or "")[:200],
    }


async def record_export_entry(user_id: int, session_id: Optional[int], file_path: str, content: MessageContent) -> Dict[str, Any]:
    """Insert an ``ExportedContent`` row for a written export and return it as a dict."""

    def recorder(session: Session) -> Dict[str, Any]:
        export = ExportedContent(
            user_id=user_id,
            source_session_id=session_id,
            filename=Path(file_path).name,
            file_path=file_path,
            content_preview=text_from_content(content).strip()[:200],
        )
        session.add(export)
        session.commit()
        session.refresh(export)
        user = session.get(User, user_id)
        return _export_to_dict(export, user.username if user else "")

    return await db_call(recorder)


async def list_exports_for_user(user_id: int) -> List[Dict[str, Any]]:
    """Return all exports of *user_id*, newest first."""

    def lister(session: Session) -> List[Dict[str, Any]]:
        stmt = (
            select(ExportedContent)
            .where(ExportedContent.user_id == user_id)
            .order_by(ExportedContent.created_at.desc())
        )
        exports = session.execute(stmt).scalars().all()
        user = session.get(User, user_id)
        username = user.username if user else ""
        return [_export_to_dict(item, username) for item in exports]

    return await db_call(lister)


async def list_exports_admin(keyword: Optional[str] = None) -> List[Dict[str, Any]]:
    """Return all exports (newest first), optionally filtered by a username substring.

    ``%`` and ``_`` in *keyword* are matched literally rather than as LIKE wildcards.
    """

    def lister(session: Session) -> List[Dict[str, Any]]:
        stmt = select(ExportedContent, User.username).join(User, ExportedContent.user_id == User.id)
        needle = (keyword or "").strip()
        if needle:
            stmt = stmt.where(User.username.contains(needle, autoescape=True))
        stmt = stmt.order_by(ExportedContent.created_at.desc())
        return [_export_to_dict(export, username) for export, username in session.execute(stmt).all()]

    return await db_call(lister)


async def get_export_record(export_id: int) -> Optional[Dict[str, Any]]:
    """Return basic info (id, owner, filename, path) for one export, or None if missing."""

    def fetcher(session: Session) -> Optional[Dict[str, Any]]:
        export = session.get(ExportedContent, export_id)
        if not export:
            return None
        user = session.get(User, export.user_id)
        return {
            "id": export.id,
            "user_id": export.user_id,
            "username": user.username if user else "",
            "filename": export.filename,
            "file_path": export.file_path,
        }

    return await db_call(fetcher)


async def prepare_messages_for_completion(
    messages: List[Dict[str, Any]],
    user_content: MessageContent,
    history_count: int,
) -> List[Dict[str, Any]]:
    """Select the messages to send for a completion.

    Returns the last *history_count* messages when that is positive and yields a
    non-empty list; otherwise a single user message built from *user_content*.
    NOTE(review): in the first case *user_content* is not appended — presumably
    callers already added it to *messages*; confirm.
    """
    if history_count > 0:
        trimmed = messages[-history_count:]
        if trimmed:
            return trimmed
    return [{"role": "user", "content": user_content}]


async def save_assistant_message(session_id: int, user_id: int, messages: List[Dict[str, Any]], content: MessageContent) -> None:
    """Persist an assistant reply, then append it to the in-memory *messages* list.

    Persisting first keeps *messages* consistent with the database if the write fails.
    """
    await append_message(session_id, user_id, "assistant", content)
    messages.append({"role": "assistant", "content": content})