| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306 |
- # -*- coding: utf-8 -*-
import asyncio
import base64
import json
import threading
import uuid
from pathlib import Path
from typing import Any, Dict, List, Optional

from fastapi import Body, Depends, FastAPI, HTTPException, UploadFile, File
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import FileResponse, HTMLResponse, StreamingResponse
from fastapi.staticfiles import StaticFiles
from openai import OpenAI
from pydantic import BaseModel

from chatfast.api import admin_router, auth_router, export_router
from chatfast.config import API_URL, DOWNLOAD_BASE, MODEL_KEYS, STATIC_DIR, UPLOAD_DIR
from chatfast.db import FILE_LOCK, MessageContent, ensure_database_initialized, ensure_directories
from chatfast.services.auth import (
    UserInfo,
    cleanup_expired_tokens,
    ensure_default_admin,
    get_current_user,
)
from chatfast.services.chat import (
    append_message,
    build_download_url,
    create_chat_session,
    delete_history_file,
    ensure_active_session,
    ensure_session_numbering,
    export_message_to_blog,
    get_export_record,
    get_latest_session,
    get_session_payload,
    list_exports_admin,
    list_exports_for_user,
    list_history,
    move_history_file,
    prepare_messages_for_completion,
    record_export_entry,
    save_assistant_message,
)
# Shared OpenAI client. Its api_key is mutated per request in /api/chat to the
# selected model's key; falls back to "" when MODEL_KEYS is empty.
client = OpenAI(api_key=next(iter(MODEL_KEYS.values()), ""), base_url=API_URL)
class MessageModel(BaseModel):
    """A single chat message as exchanged with the client."""

    # presumably "user"/"assistant" — confirm against callers
    role: str
    content: MessageContent
class ChatRequest(BaseModel):
    """Request body for POST /api/chat."""

    session_id: int
    model: str  # must be a key of MODEL_KEYS (validated in api_chat)
    content: MessageContent
    history_count: int = 0  # prior messages to include; clamped to >= 0 in api_chat
    stream: bool = True  # True -> NDJSON streaming response, False -> single JSON
class HistoryActionRequest(BaseModel):
    """Request body for history actions that target one session (e.g. /api/history/move)."""

    session_id: int
class UploadResponseItem(BaseModel):
    """Description of one stored upload returned by POST /api/upload."""

    type: str  # "image" or "file" (see api_upload)
    filename: str
    data: Optional[str] = None  # base64 data URL, set for images only
    url: Optional[str] = None  # download URL built via build_download_url
# Make sure the static/data directories and database exist before the app,
# its routers, and the static mount are constructed.
ensure_directories()
ensure_database_initialized()
ensure_default_admin()
ensure_session_numbering()

app = FastAPI(title="ChatGPT-like Clone", version="1.0.0")

# NOTE(review): allow_origins=["*"] combined with allow_credentials=True is
# rejected by browsers for credentialed requests — confirm the intended CORS policy.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
app.mount("/static", StaticFiles(directory=STATIC_DIR), name="static")
app.include_router(auth_router)
app.include_router(admin_router)
app.include_router(export_router)
@app.on_event("startup")
async def on_startup() -> None:
    """Re-run the idempotent bootstrap steps and purge expired auth tokens.

    The ensure_* calls also run at import time; repeating them here covers
    deployments where the module is imported once but restarted in place.
    """
    ensure_directories()
    ensure_database_initialized()
    ensure_default_admin()
    ensure_session_numbering()
    await cleanup_expired_tokens()
# Location of the bundled single-page UI.
INDEX_HTML = STATIC_DIR / "index.html"


@app.get("/", response_class=HTMLResponse)
async def serve_index() -> str:
    """Return the single-page UI, or 404 when the bundled HTML is missing."""
    if INDEX_HTML.exists():
        return INDEX_HTML.read_text(encoding="utf-8")
    raise HTTPException(status_code=404, detail="UI 未找到,请确认 static/index.html 是否存在")
@app.get("/download/{filename}")
async def download_file(filename: str) -> FileResponse:
    """Serve a previously uploaded file from UPLOAD_DIR by name.

    Only the basename of the requested name is honoured, so a crafted value
    such as ``..%2F..%2Fetc%2Fpasswd`` cannot escape UPLOAD_DIR (path
    traversal guard).

    Raises:
        HTTPException: 404 when no such file exists.
    """
    safe_name = Path(filename).name  # drop any directory components
    target = UPLOAD_DIR / safe_name
    if not target.exists():
        raise HTTPException(status_code=404, detail="File not found")
    return FileResponse(target, filename=safe_name)
@app.get("/api/config")
async def get_config() -> Dict[str, Any]:
    """Expose UI configuration: title, available models, and output modes."""
    model_names = list(MODEL_KEYS)
    default_model = model_names[0] if model_names else ""
    upload_base = DOWNLOAD_BASE + "/" if DOWNLOAD_BASE else ""
    return {
        "title": "ChatGPT-like Clone",
        "models": model_names,
        "default_model": default_model,
        "output_modes": ["流式输出 (Stream)", "非流式输出 (Non-stream)"],
        "upload_base_url": upload_base,
    }
@app.get("/api/session/latest")
async def api_latest_session(current_user: UserInfo = Depends(get_current_user)) -> Dict[str, Any]:
    """Return the authenticated user's most recent chat session payload."""
    return await get_latest_session(current_user.id)
@app.get("/api/session/{session_id}")
async def api_get_session(session_id: int, current_user: UserInfo = Depends(get_current_user)) -> Dict[str, Any]:
    """Return one of the user's sessions by id; archived sessions are included."""
    return await get_session_payload(session_id, current_user.id, allow_archived=True)
@app.post("/api/session/new")
async def api_new_session(current_user: UserInfo = Depends(get_current_user)) -> Dict[str, Any]:
    """Create a fresh chat session for the authenticated user."""
    return await create_chat_session(current_user.id)
@app.get("/api/history")
async def api_history(
    page: int = 0,
    page_size: int = 10,
    current_user: UserInfo = Depends(get_current_user),
) -> Dict[str, Any]:
    """List the user's session history, paginated (page is 0-based here)."""
    return await list_history(current_user.id, page, page_size)
@app.post("/api/history/move")
async def api_move_history(payload: HistoryActionRequest, current_user: UserInfo = Depends(get_current_user)) -> Dict[str, Any]:
    """Move (presumably archive) the given session's history file — semantics live in move_history_file."""
    await move_history_file(current_user.id, payload.session_id)
    return {"status": "ok"}
@app.delete("/api/history/{session_id}")
async def api_delete_history(session_id: int, current_user: UserInfo = Depends(get_current_user)) -> Dict[str, Any]:
    """Delete the stored history for one of the user's sessions."""
    await delete_history_file(current_user.id, session_id)
    return {"status": "ok"}
@app.post("/api/upload")
async def api_upload(
    files: List[UploadFile] = File(...),
    current_user: UserInfo = Depends(get_current_user),
) -> List[UploadResponseItem]:
    """Persist uploaded files and describe them for the chat UI.

    Each file is stored under UPLOAD_DIR with a uuid-prefixed name to avoid
    collisions. Images are additionally returned inline as a base64 data URL
    so the client can render a preview without a second request.

    Returns:
        One UploadResponseItem per uploaded file, in upload order.
    """
    if not files:
        return []
    responses: List[UploadResponseItem] = []
    for upload in files:
        # Keep only the basename of the client-supplied name (path traversal guard).
        safe_filename = Path(upload.filename or "file").name or "file"
        content_type = (upload.content_type or "").lower()
        data = await upload.read()
        unique_name = f"{uuid.uuid4().hex}_{safe_filename}"
        target_path = UPLOAD_DIR / unique_name
        # Write on a worker thread so disk I/O does not block the event loop.
        await asyncio.to_thread(target_path.write_bytes, data)
        if content_type.startswith("image/"):
            encoded = base64.b64encode(data).decode("utf-8")
            responses.append(
                UploadResponseItem(
                    type="image",
                    filename=safe_filename,
                    data=f"data:{content_type};base64,{encoded}",
                    url=build_download_url(unique_name),
                )
            )
        else:
            responses.append(
                UploadResponseItem(
                    type="file",
                    filename=safe_filename,
                    url=build_download_url(unique_name),
                )
            )
    return responses
@app.post("/api/chat")
async def api_chat(payload: ChatRequest = Body(...), current_user: UserInfo = Depends(get_current_user)):
    """Handle one chat turn: record the user message, call the model, persist the reply.

    Two output modes selected by ``payload.stream``:
      * streaming: returns NDJSON — a meta line, then delta lines, then end/error;
      * non-streaming: returns a single JSON object with the full assistant message.

    Raises:
        HTTPException: 400 for an unknown model, 500 on SDK failure (non-stream path).
    """
    if payload.model not in MODEL_KEYS:
        raise HTTPException(status_code=400, detail="未知的模型")
    # Resolve/validate the session for this user (helper may create or reopen one).
    session_payload = await ensure_active_session(payload.session_id, current_user.id)
    active_session_id = session_payload["session_id"]
    session_number = session_payload.get("session_number", active_session_id)
    messages = list(session_payload.get("messages") or [])
    user_message = {"role": "user", "content": payload.content}
    messages.append(user_message)
    # Persist the user's message before calling the model.
    await append_message(active_session_id, current_user.id, "user", payload.content)
    # NOTE(review): mutating the shared client's api_key per request is racy if
    # concurrent requests target different models — confirm this is acceptable.
    client.api_key = MODEL_KEYS[payload.model]
    to_send = await prepare_messages_for_completion(messages, payload.content, max(payload.history_count, 0))
    if payload.stream:
        # The OpenAI SDK call is blocking, so it runs on a daemon thread and
        # forwards chunks to the event loop through an asyncio.Queue.
        queue: "asyncio.Queue[Dict[str, Any]]" = asyncio.Queue()
        aggregated: List[str] = []  # accumulates delta texts for final persistence
        loop = asyncio.get_running_loop()

        def worker() -> None:
            # Runs in a plain thread: stream chunks, hand each delta to the loop.
            try:
                response = client.chat.completions.create(
                    model=payload.model,
                    messages=to_send,
                    stream=True,
                )
                for chunk in response:
                    try:
                        delta = chunk.choices[0].delta.content  # type: ignore[attr-defined]
                    except (IndexError, AttributeError):
                        delta = None
                    if delta:
                        aggregated.append(delta)
                        asyncio.run_coroutine_threadsafe(queue.put({"type": "delta", "text": delta}), loop)
                asyncio.run_coroutine_threadsafe(queue.put({"type": "complete"}), loop)
            except Exception as exc:  # pragma: no cover - network call
                asyncio.run_coroutine_threadsafe(queue.put({"type": "error", "message": str(exc)}), loop)

        threading.Thread(target=worker, daemon=True).start()

        async def streamer():
            # NDJSON generator consumed by StreamingResponse.
            meta = {"type": "meta", "session_id": active_session_id, "session_number": session_number}
            yield json.dumps(meta, ensure_ascii=False) + "\n"
            try:
                while True:
                    item = await queue.get()
                    if item["type"] == "delta":
                        yield json.dumps(item, ensure_ascii=False) + "\n"
                    elif item["type"] == "complete":
                        assistant_text = "".join(aggregated)
                        # Persist the full assistant reply before closing the stream.
                        await save_assistant_message(active_session_id, current_user.id, messages, assistant_text)
                        yield json.dumps({"type": "end"}, ensure_ascii=False) + "\n"
                        break
                    elif item["type"] == "error":
                        yield json.dumps(item, ensure_ascii=False) + "\n"
                        break
            except asyncio.CancelledError:  # pragma: no cover - stream cancelled
                raise

        return StreamingResponse(streamer(), media_type="application/x-ndjson")
    # Non-streaming path: run the blocking SDK call on a worker thread.
    try:
        completion = await asyncio.to_thread(
            client.chat.completions.create,
            model=payload.model,
            messages=to_send,
            stream=False,
        )
    except Exception as exc:  # pragma: no cover - network call
        raise HTTPException(status_code=500, detail=str(exc)) from exc
    choice = completion.choices[0] if getattr(completion, "choices", None) else None  # type: ignore[attr-defined]
    if not choice:
        raise HTTPException(status_code=500, detail="响应格式不正确")
    assistant_content = getattr(choice.message, "content", "")
    if not assistant_content:
        assistant_content = ""
    await save_assistant_message(active_session_id, current_user.id, messages, assistant_content)
    return {"session_id": active_session_id, "session_number": session_number, "message": assistant_content}
if __name__ == "__main__":
    import uvicorn

    # NOTE(review): "fastchat:app" assumes this module is named fastchat.py — confirm.
    uvicorn.run("fastchat:app", host="0.0.0.0", port=16016, reload=True)
|