| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455 |
- """Chat session helpers."""
- from __future__ import annotations
- import asyncio
- import json
- from pathlib import Path
- from typing import Any, Dict, List, Optional
- from fastapi import HTTPException, status
- from sqlalchemy import select, func, text
- from sqlalchemy.orm import Session
- from .. import config
- from .. import db as db_core
- from ..db import FILE_LOCK, MessageContent, db_call, now_utc
- from ..models import ChatMessage, ChatSession, ExportedContent, User
def serialize_content(content: MessageContent) -> str:
    """Encode message content as a JSON string for database storage.

    ``ensure_ascii=False`` keeps non-ASCII text (e.g. Chinese) readable in
    the stored JSON instead of emitting ``\\uXXXX`` escapes.
    """
    encoded: str = json.dumps(content, ensure_ascii=False)
    return encoded
def deserialize_content(raw: str) -> MessageContent:
    """Decode stored message content.

    Returns the JSON-decoded value when *raw* is valid JSON. Otherwise the
    raw text is returned unchanged (presumably rows stored as plain text;
    verify against the writers).
    """
    try:
        decoded = json.loads(raw)
    except json.JSONDecodeError:
        # Not JSON: hand the stored string back as-is.
        return raw
    return decoded
def text_from_content(content: MessageContent) -> str:
    """Flatten message content into plain text.

    - ``str`` content is returned unchanged.
    - ``list`` content (multi-part messages) is reduced to the ``"text"`` of
      every part whose ``"type"`` is ``"text"``, joined with single spaces.
      Parts that are not dicts are skipped. A missing or ``None`` text
      counts as ``""``.
    - ``None`` yields ``""``.
    - Anything else falls back to ``str(content)``.
    """
    if content is None:
        # Avoid leaking the literal string "None" into titles/previews.
        return ""
    if isinstance(content, str):
        return content
    if isinstance(content, list):
        pieces: List[str] = []
        for part in content:
            # Tolerate malformed parts instead of raising AttributeError.
            if not isinstance(part, dict) or part.get("type") != "text":
                continue
            value = part.get("text")
            pieces.append("" if value is None else str(value))
        return " ".join(pieces)
    return str(content)
def extract_history_title(messages: List[Dict[str, Any]]) -> str:
    """Pick a short title (at most 10 characters) for a list of chat messages.

    Preference order:
    1. the first ``user`` message whose stripped text is non-blank;
    2. the first message of any role;
    3. the fixed placeholder "空的聊天".
    """
    limit = 10
    # Lazily evaluated, so extraction stops at the first non-blank user text.
    user_texts = (
        text_from_content(message.get("content", "")).strip()
        for message in messages
        if message.get("role") == "user"
    )
    chosen = next((candidate for candidate in user_texts if candidate), "")
    if not chosen and messages:
        chosen = text_from_content(messages[0].get("content", "")).strip()
    return (chosen or "空的聊天")[:limit]
def history_backup_path(session_id: int) -> Path:
    """Return the JSON backup file path for *session_id* inside ``config.BACKUP_DIR``."""
    backup_name = f"chat_history_{session_id}.json"
    return config.BACKUP_DIR / backup_name
def build_download_url(filename: str) -> str:
    """Return the download URL for *filename*.

    When ``config.DOWNLOAD_BASE`` is set, *filename* is appended to it with
    exactly one ``/`` separator, so a base with a trailing slash no longer
    yields ``base//filename``. When it is unset or empty, the bare filename
    is returned.
    """
    base = config.DOWNLOAD_BASE or ""
    if not base:
        return filename
    # rstrip keeps a root base ("/") working: it becomes "/<filename>".
    return f"{base.rstrip('/')}/{filename}"
def ensure_session_numbering() -> None:
    """Prepare per-user session numbering.

    The steps run in order:
    1. initialize the database (the schema check needs a live engine);
    2. add the ``user_session_no`` column to ``chat_sessions`` if it is missing;
    3. renumber every user's existing sessions.
    """
    for step in (
        db_core.ensure_database_initialized,
        _ensure_session_number_column,
        _backfill_session_numbers,
    ):
        step()
def _ensure_session_number_column() -> None:
    """Add ``chat_sessions.user_session_no`` if the column does not exist yet.

    Existence is probed through ``information_schema.columns`` for the schema
    named by ``config.DATABASE_NAME``. A missing column is added as
    ``INT NOT NULL DEFAULT 0``. If no engine is available even after
    initialization, the function returns without doing anything.
    """
    engine = db_core.ENGINE
    if engine is None:
        db_core.ensure_database_initialized()
        engine = db_core.ENGINE
    if engine is None:
        return
    column_probe = text(
        """
        SELECT 1 FROM information_schema.columns
        WHERE table_schema = :schema
        AND table_name = 'chat_sessions'
        AND column_name = 'user_session_no'
        """
    )
    with engine.begin() as connection:
        found = connection.execute(column_probe, {"schema": config.DATABASE_NAME}).first()
        if not found:
            connection.execute(
                text("ALTER TABLE chat_sessions ADD COLUMN user_session_no INT NOT NULL DEFAULT 0")
            )
def _backfill_session_numbers() -> None:
    """Renumber each user's sessions 1..N by (created_at, id).

    Only sessions whose ``user_session_no`` differs from its position are
    modified. Each user's changes are committed separately, and only when
    something actually changed. Returns silently if no session factory is
    available even after initialization.
    """
    factory = db_core.SessionLocal
    if factory is None:
        db_core.ensure_database_initialized()
        factory = db_core.SessionLocal
    if factory is None:
        return
    with factory() as db:
        user_ids = db.execute(select(User.id)).scalars().all()
        for uid in user_ids:
            ordered_stmt = (
                select(ChatSession)
                .where(ChatSession.user_id == uid)
                .order_by(ChatSession.created_at, ChatSession.id)
            )
            ordered = db.execute(ordered_stmt).scalars().all()
            changed = 0
            for expected_no, row in enumerate(ordered, start=1):
                if row.user_session_no == expected_no:
                    continue
                row.user_session_no = expected_no
                changed += 1
            if changed:
                db.commit()
def _next_session_number(session: Session, user_id: int) -> int:
    """Return one more than the user's highest ``user_session_no`` (1 if none).

    No locking is done here, so concurrent callers may observe the same
    maximum and compute the same number.
    """
    highest_stmt = select(func.max(ChatSession.user_session_no)).where(ChatSession.user_id == user_id)
    highest = session.execute(highest_stmt).scalar()
    return (highest or 0) + 1
async def get_session_payload(session_id: int, user_id: int, allow_archived: bool = False) -> Dict[str, Any]:
    """Load a session owned by *user_id* together with its messages.

    Archived sessions are hidden (reported as 404) unless *allow_archived*
    is true. A session without a number is given the next per-user number
    and committed.

    Returns ``session_id``, ``session_number``, ``messages`` (role/content
    dicts in id order) and ``archived``.

    Raises:
        HTTPException: 404 when the session is missing, owned by someone
            else, or archived while *allow_archived* is false.
    """

    def loader(session: Session) -> Dict[str, Any]:
        owned = session.execute(
            select(ChatSession).where(ChatSession.id == session_id, ChatSession.user_id == user_id)
        ).scalar_one_or_none()
        hidden = owned is None or (owned.archived and not allow_archived)
        if hidden:
            raise HTTPException(status_code=404, detail="会话不存在")
        rows = (
            session.execute(
                select(ChatMessage).where(ChatMessage.session_id == owned.id).order_by(ChatMessage.id)
            )
            .scalars()
            .all()
        )
        history = [{"role": row.role, "content": deserialize_content(row.content)} for row in rows]
        if not owned.user_session_no:
            owned.user_session_no = _next_session_number(session, user_id)
            session.commit()
        return {
            "session_id": owned.id,
            "session_number": owned.user_session_no,
            "messages": history,
            "archived": owned.archived,
        }

    return await db_call(loader)
async def load_messages(session_id: int, user_id: int) -> List[Dict[str, Any]]:
    """Return the messages of a session owned by *user_id*, archived or not."""
    return (await get_session_payload(session_id, user_id, allow_archived=True))["messages"]
async def ensure_active_session(session_id: Optional[int], user_id: int) -> Dict[str, Any]:
    """Return the requested active session, or a brand-new one.

    A new session is created when *session_id* is falsy, or when the lookup
    fails with a 404 (missing, not owned, or archived). Any other
    HTTPException is re-raised.
    """
    if not session_id:
        return await create_chat_session(user_id)
    try:
        return await get_session_payload(session_id, user_id, allow_archived=False)
    except HTTPException as exc:
        if exc.status_code != 404:
            raise
    return await create_chat_session(user_id)
async def create_chat_session(user_id: int) -> Dict[str, Any]:
    """Create an empty session for *user_id* with the next per-user number.

    Returns ``session_id``, ``session_number`` and an empty ``messages`` list.
    """

    def creator(session: Session) -> Dict[str, Any]:
        number = _next_session_number(session, user_id)
        new_session = ChatSession(user_id=user_id, user_session_no=number)
        session.add(new_session)
        session.commit()
        session.refresh(new_session)
        return {
            "session_id": new_session.id,
            "session_number": new_session.user_session_no,
            "messages": [],
        }

    return await db_call(creator)
async def get_latest_session(user_id: int) -> Dict[str, Any]:
    """Return the user's most recently updated non-archived session.

    If the user has no active session, a new empty one is created and given
    the next per-user session number in the same commit. An existing session
    that still lacks a number (0/NULL) is assigned one on the fly.

    Returns ``session_id``, ``session_number`` and ``messages`` (role/content
    dicts in id order).
    """

    def loader(session: Session) -> Dict[str, Any]:
        stmt = (
            select(ChatSession)
            .where(ChatSession.user_id == user_id, ChatSession.archived.is_(False))
            # id breaks ties so the same session is chosen when timestamps collide
            .order_by(ChatSession.updated_at.desc(), ChatSession.id.desc())
        )
        chat_session = session.execute(stmt).scalars().first()
        if chat_session is None:
            # Number the session at creation (as create_chat_session does)
            # instead of inserting it unnumbered and patching it in a second commit.
            chat_session = ChatSession(user_id=user_id, user_session_no=_next_session_number(session, user_id))
            session.add(chat_session)
            session.commit()
            session.refresh(chat_session)
            return {
                "session_id": chat_session.id,
                "session_number": chat_session.user_session_no,
                "messages": [],
            }
        rows = (
            session.execute(
                select(ChatMessage).where(ChatMessage.session_id == chat_session.id).order_by(ChatMessage.id)
            )
            .scalars()
            .all()
        )
        messages: List[Dict[str, Any]] = [
            {"role": row.role, "content": deserialize_content(row.content)} for row in rows
        ]
        if not chat_session.user_session_no:
            chat_session.user_session_no = _next_session_number(session, user_id)
            session.commit()
        return {
            "session_id": chat_session.id,
            "session_number": chat_session.user_session_no,
            "messages": messages,
        }

    return await db_call(loader)
async def append_message(session_id: int, user_id: int, role: str, content: MessageContent) -> None:
    """Persist one message in a session owned by *user_id*.

    The session's ``updated_at`` is bumped. If this is a ``user`` message
    and the session has no (or only a blank) title, the message text,
    truncated to 30 characters, becomes the title.

    Raises:
        HTTPException: 404 when the session is missing or owned by someone else.
    """

    def writer(session: Session) -> None:
        owner_stmt = select(ChatSession).where(ChatSession.id == session_id, ChatSession.user_id == user_id)
        target = session.execute(owner_stmt).scalar_one_or_none()
        if target is None:
            raise HTTPException(status_code=404, detail="会话不存在")
        session.add(ChatMessage(session_id=target.id, role=role, content=serialize_content(content)))
        target.updated_at = now_utc()
        needs_title = role == "user" and not (target.title and target.title.strip())
        if needs_title:
            candidate = text_from_content(content).strip()
            if candidate:
                target.title = candidate[:30]
        session.commit()

    await db_call(writer)
async def list_history(user_id: int, page: int, page_size: int) -> Dict[str, Any]:
    """Return one page of the user's non-archived sessions, newest first.

    *page* is zero-based, and negative values are treated as 0. *page_size*
    is the number of items per page, and negative values are treated as 0.
    If the user has no active session at all, a fresh empty session is
    created so the history is never empty.

    Returns ``page`` and ``page_size`` (echoed as passed in), ``total``
    (number of active sessions) and ``items`` (per-session summaries).
    """

    def lister(session: Session) -> Dict[str, Any]:
        active = (ChatSession.user_id == user_id, ChatSession.archived.is_(False))
        limit = max(page_size, 0)
        offset = max(page, 0) * limit
        # Count and paginate in SQL rather than loading every session into memory.
        total = session.execute(select(func.count(ChatSession.id)).where(*active)).scalar_one()
        if total == 0:
            fresh = ChatSession(user_id=user_id, user_session_no=_next_session_number(session, user_id))
            session.add(fresh)
            session.commit()
            session.refresh(fresh)
            total = 1
            subset = [fresh][offset : offset + limit]
        else:
            stmt = (
                select(ChatSession)
                .where(*active)
                # id is a tie-breaker so pages are stable when timestamps collide
                .order_by(ChatSession.updated_at.desc(), ChatSession.id.desc())
                .offset(offset)
                .limit(limit)
            )
            subset = session.execute(stmt).scalars().all()
        items: List[Dict[str, Any]] = [
            {
                "session_id": item.id,
                "session_number": item.user_session_no or 0,
                "title": (item.title or f"会话 #{item.id}")[:30],
                "updated_at": (item.updated_at or now_utc()).isoformat(),
                "filename": f"session_{item.id}.json",
            }
            for item in subset
        ]
        return {"page": page, "page_size": page_size, "total": total, "items": items}

    return await db_call(lister)
async def move_history_file(user_id: int, session_id: int) -> None:
    """Archive a session and write its messages to a JSON backup file.

    First, the session is marked ``archived`` (with ``updated_at`` bumped)
    in the database. Then its messages are written as pretty-printed UTF-8
    JSON to ``history_backup_path(session_id)``. The file is written to a
    temporary sibling and renamed into place, so a failure never leaves a
    truncated backup.

    Raises:
        HTTPException: 404 when the session is missing or owned by someone else.
    """

    def mark_archived(session: Session) -> List[Dict[str, Any]]:
        stmt = select(ChatSession).where(ChatSession.id == session_id, ChatSession.user_id == user_id)
        chat_session = session.execute(stmt).scalar_one_or_none()
        if not chat_session:
            raise HTTPException(status_code=404, detail="历史记录不存在")
        messages = (
            session.execute(
                select(ChatMessage).where(ChatMessage.session_id == chat_session.id).order_by(ChatMessage.id)
            )
            .scalars()
            .all()
        )
        payload = [{"role": msg.role, "content": deserialize_content(msg.content)} for msg in messages]
        chat_session.archived = True
        chat_session.updated_at = now_utc()
        session.commit()
        return payload

    messages = await db_call(mark_archived)
    backup_file = history_backup_path(session_id)

    def _write() -> None:
        # Runs in a worker thread: mkdir and file I/O stay off the event loop.
        backup_file.parent.mkdir(parents=True, exist_ok=True)
        tmp_file = backup_file.with_name(backup_file.name + ".tmp")
        try:
            with tmp_file.open("w", encoding="utf-8") as fp:
                json.dump(messages, fp, ensure_ascii=False, indent=2)
            # os.replace semantics: the final file appears complete or not at all.
            tmp_file.replace(backup_file)
        finally:
            # No-op after a successful replace; removes partial output on failure.
            tmp_file.unlink(missing_ok=True)

    async with FILE_LOCK:
        await asyncio.to_thread(_write)
async def delete_history_file(user_id: int, session_id: int) -> None:
    """Delete a session owned by *user_id* from the database.

    Only the ``ChatSession`` row is deleted here. No file on disk is touched,
    so a backup written by ``move_history_file`` stays in place. Whether the
    session's messages go too depends on the model's cascade configuration,
    which is not visible in this module.

    Raises:
        HTTPException: 404 when the session is missing or owned by someone else.
    """

    def deleter(session: Session) -> None:
        doomed = session.execute(
            select(ChatSession).where(ChatSession.id == session_id, ChatSession.user_id == user_id)
        ).scalar_one_or_none()
        if doomed is None:
            raise HTTPException(status_code=404, detail="历史记录不存在")
        session.delete(doomed)
        session.commit()

    await db_call(deleter)
async def export_message_to_blog(content: MessageContent) -> str:
    """Write the plain text of *content* to a new ``.txt`` file in ``config.BLOG_DIR``.

    CRLF line endings are normalized to LF before writing. The filename is
    ``<MMDDHHMM>_<prefix>.txt``:

    - ``<prefix>`` is the first 10 characters of the text, with whitespace,
      control characters and characters that are unsafe in file names removed.
    - It falls back to ``export`` when nothing is left.
    - If the name is already taken, ``_2``, ``_3``, ... is appended to the
      stem, so an earlier export is never overwritten.

    The directory is created if it does not exist.

    Returns:
        The path of the written file, as a string.
    """
    processed = text_from_content(content).replace("\r\n", "\n")
    timestamp = now_utc().strftime("%m%d%H%M")
    # Path separators plus characters rejected by common filesystems; whitespace
    # (including newlines) and non-printable characters are dropped as well.
    unsafe = set('/\\:`*?"<>|')
    first_10 = "".join(
        ch for ch in processed[:10] if ch not in unsafe and ch.isprintable() and not ch.isspace()
    )
    stem = f"{timestamp}_{first_10 or 'export'}"
    blog_dir = config.BLOG_DIR

    def _write() -> Path:
        blog_dir.mkdir(parents=True, exist_ok=True)
        candidate = blog_dir / f"{stem}.txt"
        counter = 1
        while True:
            try:
                # Exclusive-create mode fails instead of clobbering an existing
                # export made in the same minute with the same prefix.
                with candidate.open("x", encoding="utf-8") as fp:
                    fp.write(processed)
                return candidate
            except FileExistsError:
                counter += 1
                candidate = blog_dir / f"{stem}_{counter}.txt"

    path = await asyncio.to_thread(_write)
    return str(path)
async def record_export_entry(user_id: int, session_id: Optional[int], file_path: str, content: MessageContent) -> Dict[str, Any]:
    """Insert an ``ExportedContent`` row for an exported file and describe it.

    The stored preview is the stripped message text truncated to 200
    characters. ``username`` is ``""`` if the user row cannot be found.
    """

    def recorder(session: Session) -> Dict[str, Any]:
        name_only = Path(file_path).name
        snippet = text_from_content(content).strip()[:200]
        row = ExportedContent(
            user_id=user_id,
            source_session_id=session_id,
            filename=name_only,
            file_path=file_path,
            content_preview=snippet,
        )
        session.add(row)
        session.commit()
        session.refresh(row)
        owner = session.get(User, user_id)
        display_name = owner.username if owner else ""
        return {
            "id": row.id,
            "user_id": user_id,
            "username": display_name,
            "filename": name_only,
            "file_path": file_path,
            "created_at": (row.created_at or now_utc()).isoformat(),
            "content_preview": snippet,
        }

    return await db_call(recorder)
async def list_exports_for_user(user_id: int) -> List[Dict[str, Any]]:
    """Return all exports of *user_id*, newest first, as summary dicts."""

    def lister(session: Session) -> List[Dict[str, Any]]:
        query = (
            select(ExportedContent)
            .where(ExportedContent.user_id == user_id)
            .order_by(ExportedContent.created_at.desc())
        )
        rows = session.execute(query).scalars().all()
        owner = session.get(User, user_id)
        display_name = owner.username if owner else ""
        return [
            {
                "id": row.id,
                "user_id": row.user_id,
                "username": display_name,
                "filename": row.filename,
                "file_path": row.file_path,
                "created_at": (row.created_at or now_utc()).isoformat(),
                "content_preview": (row.content_preview or "")[:200],
            }
            for row in rows
        ]

    return await db_call(lister)
async def list_exports_admin(keyword: Optional[str] = None) -> List[Dict[str, Any]]:
    """Return all exports (newest first), optionally filtered by username.

    Args:
        keyword: Case handling follows the database collation. When non-blank
            after stripping, only exports whose owner's username contains it
            as a literal substring are returned. ``%`` and ``_`` are not
            treated as LIKE wildcards.

    Only exports whose user row exists are listed (inner join on ``User``).
    """

    def lister(session: Session) -> List[Dict[str, Any]]:
        stmt = select(ExportedContent, User.username).join(User, ExportedContent.user_id == User.id)
        term = (keyword or "").strip()
        if term:
            # autoescape escapes "%", "_" and the escape char so user input
            # matches literally instead of acting as LIKE wildcards.
            stmt = stmt.where(User.username.contains(term, autoescape=True))
        stmt = stmt.order_by(ExportedContent.created_at.desc())
        rows = session.execute(stmt).all()
        return [
            {
                "id": export.id,
                "user_id": export.user_id,
                "username": username,
                "filename": export.filename,
                "file_path": export.file_path,
                "created_at": (export.created_at or now_utc()).isoformat(),
                "content_preview": (export.content_preview or "")[:200],
            }
            for export, username in rows
        ]

    return await db_call(lister)
async def get_export_record(export_id: int) -> Optional[Dict[str, Any]]:
    """Look up an export by id.

    Returns its id, owner id/username and file location, or ``None`` if no
    such export exists. ``username`` is ``""`` when the owner row is missing.
    """

    def fetcher(session: Session) -> Optional[Dict[str, Any]]:
        record = session.get(ExportedContent, export_id)
        if record is None:
            return None
        owner = session.get(User, record.user_id)
        return {
            "id": record.id,
            "user_id": record.user_id,
            "username": "" if owner is None else owner.username,
            "filename": record.filename,
            "file_path": record.file_path,
        }

    return await db_call(fetcher)
async def prepare_messages_for_completion(
    messages: List[Dict[str, Any]],
    user_content: MessageContent,
    history_count: int,
) -> List[Dict[str, Any]]:
    """Choose the messages to send to the completion model.

    With a positive *history_count*, the last *history_count* entries of
    *messages* are returned. When *history_count* is not positive, or that
    slice is empty, a single user message built from *user_content* is
    returned instead.
    """
    fallback = [{"role": "user", "content": user_content}]
    if history_count <= 0:
        return fallback
    recent = messages[-history_count:]
    return recent or fallback
async def save_assistant_message(session_id: int, user_id: int, messages: List[Dict[str, Any]], content: MessageContent) -> None:
    """Record an assistant reply.

    The reply is first appended to the caller's in-memory *messages* list
    (mutated in place) and then persisted via ``append_message``.
    """
    reply = {"role": "assistant", "content": content}
    messages.append(reply)
    await append_message(session_id, user_id, "assistant", content)
|