|
|
@@ -0,0 +1,455 @@
|
|
|
+"""Chat session helpers."""
|
|
|
+
|
|
|
+from __future__ import annotations
|
|
|
+
|
|
|
+import asyncio
|
|
|
+import json
|
|
|
+from pathlib import Path
|
|
|
+from typing import Any, Dict, List, Optional
|
|
|
+
|
|
|
+from fastapi import HTTPException, status
|
|
|
+from sqlalchemy import select, func, text
|
|
|
+from sqlalchemy.orm import Session
|
|
|
+
|
|
|
+from .. import config
|
|
|
+from .. import db as db_core
|
|
|
+from ..db import FILE_LOCK, MessageContent, db_call, now_utc
|
|
|
+from ..models import ChatMessage, ChatSession, ExportedContent, User
|
|
|
+
|
|
|
+
|
|
|
+def serialize_content(content: MessageContent) -> str:
|
|
|
+ return json.dumps(content, ensure_ascii=False)
|
|
|
+
|
|
|
+
|
|
|
+def deserialize_content(raw: str) -> MessageContent:
|
|
|
+ try:
|
|
|
+ return json.loads(raw)
|
|
|
+ except json.JSONDecodeError:
|
|
|
+ return raw
|
|
|
+
|
|
|
+
|
|
|
+def text_from_content(content: MessageContent) -> str:
|
|
|
+ if isinstance(content, str):
|
|
|
+ return content
|
|
|
+ if isinstance(content, list):
|
|
|
+ pieces: List[str] = []
|
|
|
+ for part in content:
|
|
|
+ if part.get("type") == "text":
|
|
|
+ pieces.append(part.get("text", ""))
|
|
|
+ return " ".join(pieces)
|
|
|
+ return str(content)
|
|
|
+
|
|
|
+
|
|
|
+def extract_history_title(messages: List[Dict[str, Any]]) -> str:
|
|
|
+ for message in messages:
|
|
|
+ if message.get("role") != "user":
|
|
|
+ continue
|
|
|
+ title = text_from_content(message.get("content", "")).strip()
|
|
|
+ if title:
|
|
|
+ return title[:10]
|
|
|
+
|
|
|
+ if messages:
|
|
|
+ fallback = text_from_content(messages[0].get("content", "")).strip()
|
|
|
+ if fallback:
|
|
|
+ return fallback[:10]
|
|
|
+
|
|
|
+ return "空的聊天"[:10]
|
|
|
+
|
|
|
+
|
|
|
+def history_backup_path(session_id: int) -> Path:
|
|
|
+ return config.BACKUP_DIR / f"chat_history_{session_id}.json"
|
|
|
+
|
|
|
+
|
|
|
+def build_download_url(filename: str) -> str:
|
|
|
+ base = config.DOWNLOAD_BASE or ""
|
|
|
+ return f"{base}/{filename}" if base else filename
|
|
|
+
|
|
|
+
|
|
|
+def ensure_session_numbering() -> None:
|
|
|
+ # make sure the database is ready before touching schema
|
|
|
+ db_core.ensure_database_initialized()
|
|
|
+ _ensure_session_number_column()
|
|
|
+ _backfill_session_numbers()
|
|
|
+
|
|
|
+
|
|
|
+def _ensure_session_number_column() -> None:
|
|
|
+ engine = db_core.ENGINE
|
|
|
+ if engine is None:
|
|
|
+ db_core.ensure_database_initialized()
|
|
|
+ engine = db_core.ENGINE
|
|
|
+ if engine is None:
|
|
|
+ return
|
|
|
+ with engine.begin() as connection:
|
|
|
+ exists = connection.execute(
|
|
|
+ text(
|
|
|
+ """
|
|
|
+ SELECT 1 FROM information_schema.columns
|
|
|
+ WHERE table_schema = :schema
|
|
|
+ AND table_name = 'chat_sessions'
|
|
|
+ AND column_name = 'user_session_no'
|
|
|
+ """
|
|
|
+ ),
|
|
|
+ {"schema": config.DATABASE_NAME},
|
|
|
+ ).first()
|
|
|
+ if exists:
|
|
|
+ return
|
|
|
+ connection.execute(text("ALTER TABLE chat_sessions ADD COLUMN user_session_no INT NOT NULL DEFAULT 0"))
|
|
|
+
|
|
|
+
|
|
|
+def _backfill_session_numbers() -> None:
|
|
|
+ session_factory = db_core.SessionLocal
|
|
|
+ if session_factory is None:
|
|
|
+ db_core.ensure_database_initialized()
|
|
|
+ session_factory = db_core.SessionLocal
|
|
|
+ if session_factory is None:
|
|
|
+ return
|
|
|
+ with session_factory() as session:
|
|
|
+ users = session.execute(select(User.id)).scalars().all()
|
|
|
+ for user_id in users:
|
|
|
+ sessions = (
|
|
|
+ session.execute(
|
|
|
+ select(ChatSession)
|
|
|
+ .where(ChatSession.user_id == user_id)
|
|
|
+ .order_by(ChatSession.created_at, ChatSession.id)
|
|
|
+ )
|
|
|
+ .scalars()
|
|
|
+ .all()
|
|
|
+ )
|
|
|
+ dirty = False
|
|
|
+ for index, chat_session in enumerate(sessions, start=1):
|
|
|
+ if chat_session.user_session_no != index:
|
|
|
+ chat_session.user_session_no = index
|
|
|
+ dirty = True
|
|
|
+ if dirty:
|
|
|
+ session.commit()
|
|
|
+
|
|
|
+
|
|
|
+def _next_session_number(session: Session, user_id: int) -> int:
|
|
|
+ current = (
|
|
|
+ session.execute(select(func.max(ChatSession.user_session_no)).where(ChatSession.user_id == user_id)).scalar()
|
|
|
+ or 0
|
|
|
+ )
|
|
|
+ return current + 1
|
|
|
+
|
|
|
+
|
|
|
+async def get_session_payload(session_id: int, user_id: int, allow_archived: bool = False) -> Dict[str, Any]:
|
|
|
+ def loader(session: Session) -> Dict[str, Any]:
|
|
|
+ stmt = select(ChatSession).where(ChatSession.id == session_id, ChatSession.user_id == user_id)
|
|
|
+ chat_session = session.execute(stmt).scalar_one_or_none()
|
|
|
+ if not chat_session:
|
|
|
+ raise HTTPException(status_code=404, detail="会话不存在")
|
|
|
+ if chat_session.archived and not allow_archived:
|
|
|
+ raise HTTPException(status_code=404, detail="会话不存在")
|
|
|
+ messages = (
|
|
|
+ session.execute(
|
|
|
+ select(ChatMessage).where(ChatMessage.session_id == chat_session.id).order_by(ChatMessage.id)
|
|
|
+ )
|
|
|
+ .scalars()
|
|
|
+ .all()
|
|
|
+ )
|
|
|
+ payload = [{"role": msg.role, "content": deserialize_content(msg.content)} for msg in messages]
|
|
|
+ if not chat_session.user_session_no:
|
|
|
+ chat_session.user_session_no = _next_session_number(session, user_id)
|
|
|
+ session.commit()
|
|
|
+ return {
|
|
|
+ "session_id": chat_session.id,
|
|
|
+ "session_number": chat_session.user_session_no,
|
|
|
+ "messages": payload,
|
|
|
+ "archived": chat_session.archived,
|
|
|
+ }
|
|
|
+
|
|
|
+ return await db_call(loader)
|
|
|
+
|
|
|
+
|
|
|
+async def load_messages(session_id: int, user_id: int) -> List[Dict[str, Any]]:
|
|
|
+ payload = await get_session_payload(session_id, user_id, allow_archived=True)
|
|
|
+ return payload["messages"]
|
|
|
+
|
|
|
+
|
|
|
+async def ensure_active_session(session_id: Optional[int], user_id: int) -> Dict[str, Any]:
|
|
|
+ if session_id:
|
|
|
+ try:
|
|
|
+ payload = await get_session_payload(session_id, user_id, allow_archived=False)
|
|
|
+ return payload
|
|
|
+ except HTTPException as exc:
|
|
|
+ if exc.status_code != 404:
|
|
|
+ raise
|
|
|
+ return await create_chat_session(user_id)
|
|
|
+
|
|
|
+
|
|
|
+async def create_chat_session(user_id: int) -> Dict[str, Any]:
|
|
|
+ def creator(session: Session) -> Dict[str, Any]:
|
|
|
+ chat_session = ChatSession(user_id=user_id, user_session_no=_next_session_number(session, user_id))
|
|
|
+ session.add(chat_session)
|
|
|
+ session.commit()
|
|
|
+ session.refresh(chat_session)
|
|
|
+ return {"session_id": chat_session.id, "session_number": chat_session.user_session_no, "messages": []}
|
|
|
+
|
|
|
+ return await db_call(creator)
|
|
|
+
|
|
|
+
|
|
|
+async def get_latest_session(user_id: int) -> Dict[str, Any]:
|
|
|
+ def loader(session: Session) -> Dict[str, Any]:
|
|
|
+ stmt = (
|
|
|
+ select(ChatSession)
|
|
|
+ .where(ChatSession.user_id == user_id, ChatSession.archived.is_(False))
|
|
|
+ .order_by(ChatSession.updated_at.desc())
|
|
|
+ )
|
|
|
+ chat_session = session.execute(stmt).scalars().first()
|
|
|
+ if not chat_session:
|
|
|
+ chat_session = ChatSession(user_id=user_id)
|
|
|
+ session.add(chat_session)
|
|
|
+ session.commit()
|
|
|
+ session.refresh(chat_session)
|
|
|
+ messages: List[Dict[str, Any]] = []
|
|
|
+ else:
|
|
|
+ messages = (
|
|
|
+ session.execute(
|
|
|
+ select(ChatMessage).where(ChatMessage.session_id == chat_session.id).order_by(ChatMessage.id)
|
|
|
+ )
|
|
|
+ .scalars()
|
|
|
+ .all()
|
|
|
+ )
|
|
|
+ messages = [{"role": msg.role, "content": deserialize_content(msg.content)} for msg in messages]
|
|
|
+ if not chat_session.user_session_no:
|
|
|
+ chat_session.user_session_no = _next_session_number(session, user_id)
|
|
|
+ session.commit()
|
|
|
+ return {"session_id": chat_session.id, "session_number": chat_session.user_session_no, "messages": messages}
|
|
|
+
|
|
|
+ return await db_call(loader)
|
|
|
+
|
|
|
+
|
|
|
+async def append_message(session_id: int, user_id: int, role: str, content: MessageContent) -> None:
|
|
|
+ def writer(session: Session) -> None:
|
|
|
+ stmt = select(ChatSession).where(ChatSession.id == session_id, ChatSession.user_id == user_id)
|
|
|
+ chat_session = session.execute(stmt).scalar_one_or_none()
|
|
|
+ if not chat_session:
|
|
|
+ raise HTTPException(status_code=404, detail="会话不存在")
|
|
|
+ serialized = serialize_content(content)
|
|
|
+ message = ChatMessage(session_id=chat_session.id, role=role, content=serialized)
|
|
|
+ session.add(message)
|
|
|
+ chat_session.updated_at = now_utc()
|
|
|
+ if role == "user" and (not chat_session.title or not chat_session.title.strip()):
|
|
|
+ candidate = text_from_content(content).strip()
|
|
|
+ if candidate:
|
|
|
+ chat_session.title = candidate[:30]
|
|
|
+ session.commit()
|
|
|
+
|
|
|
+ await db_call(writer)
|
|
|
+
|
|
|
+
|
|
|
+async def list_history(user_id: int, page: int, page_size: int) -> Dict[str, Any]:
|
|
|
+ def lister(session: Session) -> Dict[str, Any]:
|
|
|
+ stmt = (
|
|
|
+ select(ChatSession)
|
|
|
+ .where(ChatSession.user_id == user_id, ChatSession.archived.is_(False))
|
|
|
+ .order_by(ChatSession.updated_at.desc())
|
|
|
+ )
|
|
|
+ sessions = session.execute(stmt).scalars().all()
|
|
|
+ if not sessions:
|
|
|
+ fresh = ChatSession(user_id=user_id, user_session_no=_next_session_number(session, user_id))
|
|
|
+ session.add(fresh)
|
|
|
+ session.commit()
|
|
|
+ session.refresh(fresh)
|
|
|
+ sessions = [fresh]
|
|
|
+ total = len(sessions)
|
|
|
+ start = max(page, 0) * page_size
|
|
|
+ end = start + page_size
|
|
|
+ subset = sessions[start:end]
|
|
|
+ items: List[Dict[str, Any]] = []
|
|
|
+ for item in subset:
|
|
|
+ items.append(
|
|
|
+ {
|
|
|
+ "session_id": item.id,
|
|
|
+ "session_number": item.user_session_no or 0,
|
|
|
+ "title": (item.title or f"会话 #{item.id}")[:30],
|
|
|
+ "updated_at": (item.updated_at or now_utc()).isoformat(),
|
|
|
+ "filename": f"session_{item.id}.json",
|
|
|
+ }
|
|
|
+ )
|
|
|
+ return {"page": page, "page_size": page_size, "total": total, "items": items}
|
|
|
+
|
|
|
+ return await db_call(lister)
|
|
|
+
|
|
|
+
|
|
|
+async def move_history_file(user_id: int, session_id: int) -> None:
|
|
|
+ def mark_archived(session: Session) -> List[Dict[str, Any]]:
|
|
|
+ stmt = select(ChatSession).where(ChatSession.id == session_id, ChatSession.user_id == user_id)
|
|
|
+ chat_session = session.execute(stmt).scalar_one_or_none()
|
|
|
+ if not chat_session:
|
|
|
+ raise HTTPException(status_code=404, detail="历史记录不存在")
|
|
|
+ messages = (
|
|
|
+ session.execute(
|
|
|
+ select(ChatMessage).where(ChatMessage.session_id == chat_session.id).order_by(ChatMessage.id)
|
|
|
+ )
|
|
|
+ .scalars()
|
|
|
+ .all()
|
|
|
+ )
|
|
|
+ payload = [{"role": msg.role, "content": deserialize_content(msg.content)} for msg in messages]
|
|
|
+ chat_session.archived = True
|
|
|
+ chat_session.updated_at = now_utc()
|
|
|
+ session.commit()
|
|
|
+ return payload
|
|
|
+
|
|
|
+ messages = await db_call(mark_archived)
|
|
|
+ backup_file = history_backup_path(session_id)
|
|
|
+ backup_file.parent.mkdir(parents=True, exist_ok=True)
|
|
|
+
|
|
|
+ def _write() -> None:
|
|
|
+ with backup_file.open("w", encoding="utf-8") as fp:
|
|
|
+ json.dump(messages, fp, ensure_ascii=False, indent=2)
|
|
|
+
|
|
|
+ async with FILE_LOCK:
|
|
|
+ await asyncio.to_thread(_write) # type: ignore[name-defined]
|
|
|
+
|
|
|
+
|
|
|
+async def delete_history_file(user_id: int, session_id: int) -> None:
|
|
|
+ def deleter(session: Session) -> None:
|
|
|
+ stmt = select(ChatSession).where(ChatSession.id == session_id, ChatSession.user_id == user_id)
|
|
|
+ chat_session = session.execute(stmt).scalar_one_or_none()
|
|
|
+ if not chat_session:
|
|
|
+ raise HTTPException(status_code=404, detail="历史记录不存在")
|
|
|
+ session.delete(chat_session)
|
|
|
+ session.commit()
|
|
|
+
|
|
|
+ await db_call(deleter)
|
|
|
+
|
|
|
+
|
|
|
+async def export_message_to_blog(content: MessageContent) -> str:
|
|
|
+ processed = text_from_content(content)
|
|
|
+ processed = processed.replace("\r\n", "\n")
|
|
|
+ timestamp = now_utc().strftime("%m%d%H%M")
|
|
|
+ first_10 = (
|
|
|
+ processed[:10]
|
|
|
+ .replace(" ", "")
|
|
|
+ .replace("/", "")
|
|
|
+ .replace("\\", "")
|
|
|
+ .replace(":", "")
|
|
|
+ .replace("`", "")
|
|
|
+ )
|
|
|
+ filename = f"{timestamp}_{first_10 or 'export'}.txt"
|
|
|
+ path = config.BLOG_DIR / filename
|
|
|
+
|
|
|
+ def _write() -> None:
|
|
|
+ with path.open("w", encoding="utf-8") as fp:
|
|
|
+ fp.write(processed)
|
|
|
+
|
|
|
+ await asyncio.to_thread(_write) # type: ignore[name-defined]
|
|
|
+ return str(path)
|
|
|
+
|
|
|
+
|
|
|
+async def record_export_entry(user_id: int, session_id: Optional[int], file_path: str, content: MessageContent) -> Dict[str, Any]:
|
|
|
+ def recorder(session: Session) -> Dict[str, Any]:
|
|
|
+ filename = Path(file_path).name
|
|
|
+ preview = text_from_content(content).strip()[:200]
|
|
|
+ export = ExportedContent(
|
|
|
+ user_id=user_id,
|
|
|
+ source_session_id=session_id,
|
|
|
+ filename=filename,
|
|
|
+ file_path=file_path,
|
|
|
+ content_preview=preview,
|
|
|
+ )
|
|
|
+ session.add(export)
|
|
|
+ session.commit()
|
|
|
+ session.refresh(export)
|
|
|
+ user = session.get(User, user_id)
|
|
|
+ username = user.username if user else ""
|
|
|
+ return {
|
|
|
+ "id": export.id,
|
|
|
+ "user_id": user_id,
|
|
|
+ "username": username,
|
|
|
+ "filename": filename,
|
|
|
+ "file_path": file_path,
|
|
|
+ "created_at": (export.created_at or now_utc()).isoformat(),
|
|
|
+ "content_preview": preview,
|
|
|
+ }
|
|
|
+
|
|
|
+ return await db_call(recorder)
|
|
|
+
|
|
|
+
|
|
|
+async def list_exports_for_user(user_id: int) -> List[Dict[str, Any]]:
|
|
|
+ def lister(session: Session) -> List[Dict[str, Any]]:
|
|
|
+ stmt = (
|
|
|
+ select(ExportedContent)
|
|
|
+ .where(ExportedContent.user_id == user_id)
|
|
|
+ .order_by(ExportedContent.created_at.desc())
|
|
|
+ )
|
|
|
+ exports = session.execute(stmt).scalars().all()
|
|
|
+ user = session.get(User, user_id)
|
|
|
+ username = user.username if user else ""
|
|
|
+ results: List[Dict[str, Any]] = []
|
|
|
+ for item in exports:
|
|
|
+ results.append(
|
|
|
+ {
|
|
|
+ "id": item.id,
|
|
|
+ "user_id": item.user_id,
|
|
|
+ "username": username,
|
|
|
+ "filename": item.filename,
|
|
|
+ "file_path": item.file_path,
|
|
|
+ "created_at": (item.created_at or now_utc()).isoformat(),
|
|
|
+ "content_preview": (item.content_preview or "")[:200],
|
|
|
+ }
|
|
|
+ )
|
|
|
+ return results
|
|
|
+
|
|
|
+ return await db_call(lister)
|
|
|
+
|
|
|
+
|
|
|
+async def list_exports_admin(keyword: Optional[str] = None) -> List[Dict[str, Any]]:
|
|
|
+ def lister(session: Session) -> List[Dict[str, Any]]:
|
|
|
+ stmt = select(ExportedContent, User.username).join(User, ExportedContent.user_id == User.id)
|
|
|
+ if keyword:
|
|
|
+ stmt = stmt.where(User.username.like(f"%{keyword.strip()}%"))
|
|
|
+ stmt = stmt.order_by(ExportedContent.created_at.desc())
|
|
|
+ rows = session.execute(stmt).all()
|
|
|
+ results: List[Dict[str, Any]] = []
|
|
|
+ for export, username in rows:
|
|
|
+ results.append(
|
|
|
+ {
|
|
|
+ "id": export.id,
|
|
|
+ "user_id": export.user_id,
|
|
|
+ "username": username,
|
|
|
+ "filename": export.filename,
|
|
|
+ "file_path": export.file_path,
|
|
|
+ "created_at": (export.created_at or now_utc()).isoformat(),
|
|
|
+ "content_preview": (export.content_preview or "")[:200],
|
|
|
+ }
|
|
|
+ )
|
|
|
+ return results
|
|
|
+
|
|
|
+ return await db_call(lister)
|
|
|
+
|
|
|
+
|
|
|
+async def get_export_record(export_id: int) -> Optional[Dict[str, Any]]:
|
|
|
+ def fetcher(session: Session) -> Optional[Dict[str, Any]]:
|
|
|
+ export = session.get(ExportedContent, export_id)
|
|
|
+ if not export:
|
|
|
+ return None
|
|
|
+ user = session.get(User, export.user_id)
|
|
|
+ username = user.username if user else ""
|
|
|
+ return {
|
|
|
+ "id": export.id,
|
|
|
+ "user_id": export.user_id,
|
|
|
+ "username": username,
|
|
|
+ "filename": export.filename,
|
|
|
+ "file_path": export.file_path,
|
|
|
+ }
|
|
|
+
|
|
|
+ return await db_call(fetcher)
|
|
|
+
|
|
|
+
|
|
|
+async def prepare_messages_for_completion(
|
|
|
+ messages: List[Dict[str, Any]],
|
|
|
+ user_content: MessageContent,
|
|
|
+ history_count: int,
|
|
|
+) -> List[Dict[str, Any]]:
|
|
|
+ if history_count > 0:
|
|
|
+ trimmed = messages[-history_count:]
|
|
|
+ if trimmed:
|
|
|
+ return trimmed
|
|
|
+ return [{"role": "user", "content": user_content}]
|
|
|
+
|
|
|
+
|
|
|
+async def save_assistant_message(session_id: int, user_id: int, messages: List[Dict[str, Any]], content: MessageContent) -> None:
|
|
|
+ messages.append({"role": "assistant", "content": content})
|
|
|
+ await append_message(session_id, user_id, "assistant", content)
|