# -*- coding: utf-8 -*-
"""FastAPI application for a ChatGPT-like chat UI.

Serves the static front end, file uploads/downloads, per-user chat sessions and
history, and forwards chat completions to an OpenAI-compatible API at API_URL.
"""
import asyncio
import base64
import json
import threading
import uuid
from pathlib import Path
from typing import Any, Dict, List, Optional

from fastapi import Body, Depends, FastAPI, HTTPException, UploadFile, File
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import FileResponse, HTMLResponse, StreamingResponse
from fastapi.staticfiles import StaticFiles
from pydantic import BaseModel
from openai import OpenAI

from chatfast.api import admin_router, auth_router, export_router
from chatfast.config import API_URL, DOWNLOAD_BASE, MODEL_KEYS, STATIC_DIR, UPLOAD_DIR
from chatfast.db import FILE_LOCK, MessageContent, ensure_database_initialized, ensure_directories
from chatfast.services.auth import (
    UserInfo,
    cleanup_expired_tokens,
    ensure_default_admin,
    get_current_user,
)
from chatfast.services.chat import (
    append_message,
    build_download_url,
    create_chat_session,
    delete_history_file,
    ensure_active_session,
    ensure_session_numbering,
    export_message_to_blog,
    get_export_record,
    get_latest_session,
    get_session_payload,
    list_exports_admin,
    list_exports_for_user,
    list_history,
    move_history_file,
    prepare_messages_for_completion,
    record_export_entry,
    save_assistant_message,
)

# Shared OpenAI-compatible client pointed at API_URL. It is created with the
# first configured model key, or "" when MODEL_KEYS is empty.
client = OpenAI(api_key=next(iter(MODEL_KEYS.values()), ""), base_url=API_URL)


class MessageModel(BaseModel):
    """A single chat message: a role plus its content."""

    role: str
    content: MessageContent


class ChatRequest(BaseModel):
    """Request body for ``POST /api/chat``."""

    session_id: int
    model: str  # must be one of the keys of MODEL_KEYS
    content: MessageContent
    history_count: int = 0  # forwarded to prepare_messages_for_completion, clamped to >= 0
    stream: bool = True  # True: NDJSON streaming response; False: a single JSON object


class HistoryActionRequest(BaseModel):
    """Request body for history actions that target a single session."""

    session_id: int


class UploadResponseItem(BaseModel):
    """Per-file result returned by ``POST /api/upload``."""

    type: str  # "image" or "file"
    filename: str
    data: Optional[str] = None  # base64 data: URL, only set for images
    url: Optional[str] = None  # download URL for the stored copy


# Make sure the static and data directories exist before the app is initialised.
ensure_directories()
ensure_database_initialized()
ensure_default_admin()
ensure_session_numbering()

app = FastAPI(title="ChatGPT-like Clone", version="1.0.0")

# NOTE(review): wildcard origins together with allow_credentials=True lets any
# site issue credentialed cross-origin requests — confirm this is intended.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
app.mount("/static", StaticFiles(directory=STATIC_DIR), name="static")
app.include_router(auth_router)
app.include_router(admin_router)
app.include_router(export_router)


@app.on_event("startup")
async def on_startup() -> None:
    """Repeat the initialisation steps at server startup and purge expired tokens.

    These calls also run at import time above; presumably they are idempotent
    (TODO confirm in chatfast.db / chatfast.services).
    """
    ensure_directories()
    ensure_database_initialized()
    ensure_default_admin()
    ensure_session_numbering()
    await cleanup_expired_tokens()


INDEX_HTML = STATIC_DIR / "index.html"


@app.get("/", response_class=HTMLResponse)
async def serve_index() -> str:
    """Return the single-page UI from ``STATIC_DIR/index.html``.

    Raises:
        HTTPException: 404 when the file is missing.
    """
    if not INDEX_HTML.exists():
        raise HTTPException(status_code=404, detail="UI 未找到,请确认 static/index.html 是否存在")
    return INDEX_HTML.read_text(encoding="utf-8")


@app.get("/download/{filename}")
async def download_file(filename: str) -> FileResponse:
    """Serve a previously uploaded file from ``UPLOAD_DIR``.

    Only a plain file name located directly inside ``UPLOAD_DIR`` is accepted.
    This endpoint has no authentication dependency.

    Raises:
        HTTPException: 404 when the name is not a plain file name or no such
            regular file exists.
    """
    # Reject separators (including "\\", a separator on Windows) and
    # dot-segments so a crafted name cannot point outside UPLOAD_DIR.
    if filename in {"", ".", ".."} or "/" in filename or "\\" in filename:
        raise HTTPException(status_code=404, detail="File not found")
    target = UPLOAD_DIR / filename
    # is_file() rather than exists(): directories must not reach FileResponse.
    if not target.is_file():
        raise HTTPException(status_code=404, detail="File not found")
    return FileResponse(target, filename=filename)


@app.get("/api/config")
async def get_config() -> Dict[str, Any]:
    """Return UI configuration: title, available models, output modes, upload base URL."""
    models = list(MODEL_KEYS.keys())
    return {
        "title": "ChatGPT-like Clone",
        "models": models,
        "default_model": models[0] if models else "",
        "output_modes": ["流式输出 (Stream)", "非流式输出 (Non-stream)"],
        "upload_base_url": DOWNLOAD_BASE + "/" if DOWNLOAD_BASE else "",
    }


@app.get("/api/session/latest")
async def api_latest_session(current_user: UserInfo = Depends(get_current_user)) -> Dict[str, Any]:
    """Return the current user's latest session payload."""
    return await get_latest_session(current_user.id)


@app.get("/api/session/{session_id}")
async def api_get_session(session_id: int, current_user: UserInfo = Depends(get_current_user)) -> Dict[str, Any]:
    """Return one of the current user's sessions; archived sessions are allowed."""
    return await get_session_payload(session_id, current_user.id, allow_archived=True)


@app.post("/api/session/new")
async def api_new_session(current_user: UserInfo = Depends(get_current_user)) -> Dict[str, Any]:
    """Create a new chat session for the current user."""
    return await create_chat_session(current_user.id)


@app.get("/api/history")
async def api_history(
    page: int = 0,
    page_size: int = 10,
    current_user: UserInfo = Depends(get_current_user),
) -> Dict[str, Any]:
    """Return one page of the current user's session history."""
    return await list_history(current_user.id, page, page_size)
@app.post("/api/history/move")
async def api_move_history(payload: HistoryActionRequest, current_user: UserInfo = Depends(get_current_user)) -> Dict[str, Any]:
    """Apply ``move_history_file`` to one of the current user's sessions."""
    await move_history_file(current_user.id, payload.session_id)
    return {"status": "ok"}


@app.delete("/api/history/{session_id}")
async def api_delete_history(session_id: int, current_user: UserInfo = Depends(get_current_user)) -> Dict[str, Any]:
    """Delete one of the current user's sessions via ``delete_history_file``."""
    await delete_history_file(current_user.id, session_id)
    return {"status": "ok"}


@app.post("/api/upload")
async def api_upload(
    files: List[UploadFile] = File(...),
    current_user: UserInfo = Depends(get_current_user),
) -> List[UploadResponseItem]:
    """Store uploaded files in ``UPLOAD_DIR`` and describe each one.

    Each file is saved as ``<random hex>_<client file name>``. Images
    additionally get an inline base64 ``data:`` URL in the response.
    """
    # Local import so this handler does not depend on a module-level Path import.
    from pathlib import Path

    if not files:
        return []

    responses: List[UploadResponseItem] = []
    for upload in files:
        filename = upload.filename or "file"
        # Keep only the final path component of the client-supplied name.
        safe_filename = Path(filename).name or "file"
        content_type = (upload.content_type or "").lower()
        data = await upload.read()
        # A random prefix keeps uploads with identical names from overwriting each other.
        unique_name = f"{uuid.uuid4().hex}_{safe_filename}"
        target_path = UPLOAD_DIR / unique_name
        # Blocking disk write runs in a worker thread, off the event loop.
        await asyncio.to_thread(target_path.write_bytes, data)

        if content_type.startswith("image/"):
            encoded = base64.b64encode(data).decode("utf-8")
            responses.append(
                UploadResponseItem(
                    type="image",
                    filename=safe_filename,
                    data=f"data:{content_type};base64,{encoded}",
                    url=build_download_url(unique_name),
                )
            )
        else:
            responses.append(
                UploadResponseItem(
                    type="file",
                    filename=safe_filename,
                    url=build_download_url(unique_name),
                )
            )
    return responses


def _ndjson(obj: Dict[str, Any]) -> str:
    """Serialise ``obj`` as one NDJSON line (non-ASCII kept as-is, trailing newline)."""
    return json.dumps(obj, ensure_ascii=False) + "\n"


def _client_for_model(model: str) -> "OpenAI":
    """Return a client that uses ``model``'s API key.

    ``with_options`` returns a configured copy, so the shared module-level client
    is never mutated. Concurrent requests for different models cannot race on
    the key.
    """
    return client.with_options(api_key=MODEL_KEYS[model])


def _stream_completion(
    model_client: "OpenAI",
    model: str,
    to_send: Any,
    messages: List[Dict[str, Any]],
    session_id: Any,
    session_number: Any,
    user_id: Any,
) -> StreamingResponse:
    """Run a streaming completion on a worker thread and relay it as NDJSON.

    Output lines, in order: one ``meta`` line; one ``delta`` line per text
    fragment; then either ``end`` (after the full reply has been saved with
    ``save_assistant_message``) or ``error``.
    """
    loop = asyncio.get_running_loop()
    queue: "asyncio.Queue[Dict[str, Any]]" = asyncio.Queue()
    aggregated: List[str] = []
    # Set once the response generator finishes (normally or on client
    # disconnect) so the worker stops consuming the upstream stream.
    stop = threading.Event()

    def emit(item: Dict[str, Any]) -> None:
        # asyncio.Queue is not thread-safe; hand the put over to the loop thread.
        loop.call_soon_threadsafe(queue.put_nowait, item)

    def worker() -> None:
        try:
            response = model_client.chat.completions.create(
                model=model,
                messages=to_send,
                stream=True,
            )
            for chunk in response:
                if stop.is_set():
                    # Nobody is listening anymore: release the upstream connection.
                    close = getattr(response, "close", None)
                    if callable(close):
                        close()
                    return
                try:
                    delta = chunk.choices[0].delta.content  # type: ignore[attr-defined]
                except (IndexError, AttributeError):
                    delta = None
                if delta:
                    aggregated.append(delta)
                    emit({"type": "delta", "text": delta})
            emit({"type": "complete"})
        except Exception as exc:  # pragma: no cover - network call
            if not stop.is_set():
                emit({"type": "error", "message": str(exc)})

    async def streamer():
        yield _ndjson({"type": "meta", "session_id": session_id, "session_number": session_number})
        try:
            while True:
                item = await queue.get()
                kind = item["type"]
                if kind == "delta":
                    yield _ndjson(item)
                elif kind == "complete":
                    # "complete" is queued after every delta, so aggregated is final here.
                    await save_assistant_message(session_id, user_id, messages, "".join(aggregated))
                    yield _ndjson({"type": "end"})
                    break
                elif kind == "error":
                    yield _ndjson(item)
                    break
        finally:
            stop.set()

    threading.Thread(target=worker, daemon=True).start()
    return StreamingResponse(streamer(), media_type="application/x-ndjson")


@app.post("/api/chat")
async def api_chat(payload: ChatRequest = Body(...), current_user: UserInfo = Depends(get_current_user)):
    """Record the user's message in a session and return the model's reply.

    The user message is persisted before the model is called. With
    ``payload.stream`` the reply is streamed as NDJSON. Otherwise a dict with
    ``session_id``, ``session_number`` and ``message`` is returned. The
    assistant reply is saved once it is complete.

    Raises:
        HTTPException: 400 for an unknown model; 500 when the non-streaming
            upstream call fails or returns no choices.
    """
    if payload.model not in MODEL_KEYS:
        raise HTTPException(status_code=400, detail="未知的模型")

    session_payload = await ensure_active_session(payload.session_id, current_user.id)
    active_session_id = session_payload["session_id"]
    session_number = session_payload.get("session_number", active_session_id)
    messages = list(session_payload.get("messages") or [])
    messages.append({"role": "user", "content": payload.content})
    await append_message(active_session_id, current_user.id, "user", payload.content)

    model_client = _client_for_model(payload.model)
    to_send = await prepare_messages_for_completion(messages, payload.content, max(payload.history_count, 0))

    if payload.stream:
        return _stream_completion(
            model_client,
            payload.model,
            to_send,
            messages,
            active_session_id,
            session_number,
            current_user.id,
        )

    try:
        completion = await asyncio.to_thread(
            model_client.chat.completions.create,
            model=payload.model,
            messages=to_send,
            stream=False,
        )
    except Exception as exc:  # pragma: no cover - network call
        raise HTTPException(status_code=500, detail=str(exc)) from exc

    choices = getattr(completion, "choices", None)
    if not choices:
        raise HTTPException(status_code=500, detail="响应格式不正确")
    assistant_content = getattr(choices[0].message, "content", "") or ""
    await save_assistant_message(active_session_id, current_user.id, messages, assistant_content)
    return {"session_id": active_session_id, "session_number": session_number, "message": assistant_content}


if __name__ == "__main__":
    import uvicorn

    uvicorn.run("fastchat:app", host="0.0.0.0", port=16016, reload=True)