fastchat.py 10 KB


  1. # -*- coding: utf-8 -*-
import asyncio
import base64
import json
import threading
import uuid
from pathlib import Path
from typing import Any, Dict, List, Optional

from fastapi import Body, Depends, FastAPI, File, HTTPException, UploadFile
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import FileResponse, HTMLResponse, StreamingResponse
from fastapi.staticfiles import StaticFiles
from openai import OpenAI
from pydantic import BaseModel

from chatfast.api import admin_router, auth_router, export_router
from chatfast.config import API_URL, DOWNLOAD_BASE, MODEL_KEYS, STATIC_DIR, UPLOAD_DIR
from chatfast.db import FILE_LOCK, MessageContent, ensure_database_initialized, ensure_directories
from chatfast.services.auth import (
    UserInfo,
    cleanup_expired_tokens,
    ensure_default_admin,
    get_current_user,
)
from chatfast.services.chat import (
    append_message,
    build_download_url,
    create_chat_session,
    delete_history_file,
    ensure_active_session,
    ensure_session_numbering,
    export_message_to_blog,
    get_export_record,
    get_latest_session,
    get_session_payload,
    list_exports_admin,
    list_exports_for_user,
    list_history,
    move_history_file,
    prepare_messages_for_completion,
    record_export_entry,
    save_assistant_message,
)
  42. client = OpenAI(api_key=next(iter(MODEL_KEYS.values()), ""), base_url=API_URL)
# Schema for a single chat message: a role plus MessageContent (type imported from chatfast.db).
# NOTE(review): this model is not referenced elsewhere in this module's visible code.
class MessageModel(BaseModel):
    role: str  # presumably "user" / "assistant" / "system" — confirm against callers
    content: MessageContent
# Request body for POST /api/chat.
class ChatRequest(BaseModel):
    session_id: int  # session to post into; resolved by ensure_active_session in api_chat
    model: str  # must be a key of MODEL_KEYS, otherwise api_chat answers HTTP 400
    content: MessageContent  # the user's new message content
    history_count: int = 0  # clamped to >= 0 and passed to prepare_messages_for_completion; presumably the number of prior messages to include
    stream: bool = True  # True: NDJSON streaming response; False: a single JSON reply
# Request body for POST /api/history/move: identifies the session to act on.
class HistoryActionRequest(BaseModel):
    session_id: int
# One entry of the POST /api/upload response list.
class UploadResponseItem(BaseModel):
    type: str  # "image" or "file" as produced by api_upload
    filename: str  # sanitized original filename (not the stored name)
    data: Optional[str] = None  # base64 data: URL, only set for image uploads
    url: Optional[str] = None  # download link from build_download_url
  59. # 确保静态与数据目录在应用初始化前存在
  60. ensure_directories()
  61. ensure_database_initialized()
  62. ensure_default_admin()
  63. ensure_session_numbering()
  64. app = FastAPI(title="ChatGPT-like Clone", version="1.0.0")
  65. app.add_middleware(
  66. CORSMiddleware,
  67. allow_origins=["*"],
  68. allow_credentials=True,
  69. allow_methods=["*"],
  70. allow_headers=["*"],
  71. )
  72. app.mount("/static", StaticFiles(directory=STATIC_DIR), name="static")
  73. app.include_router(auth_router)
  74. app.include_router(admin_router)
  75. app.include_router(export_router)
  76. @app.on_event("startup")
  77. async def on_startup() -> None:
  78. ensure_directories()
  79. ensure_database_initialized()
  80. ensure_default_admin()
  81. ensure_session_numbering()
  82. await cleanup_expired_tokens()
  83. INDEX_HTML = STATIC_DIR / "index.html"
  84. @app.get("/", response_class=HTMLResponse)
  85. async def serve_index() -> str:
  86. if not INDEX_HTML.exists():
  87. raise HTTPException(status_code=404, detail="UI 未找到,请确认 static/index.html 是否存在")
  88. return INDEX_HTML.read_text(encoding="utf-8")
  89. @app.get("/download/{filename}")
  90. async def download_file(filename: str) -> FileResponse:
  91. target = UPLOAD_DIR / filename
  92. if not target.exists():
  93. raise HTTPException(status_code=404, detail="File not found")
  94. return FileResponse(target, filename=filename)
  95. @app.get("/api/config")
  96. async def get_config() -> Dict[str, Any]:
  97. models = list(MODEL_KEYS.keys())
  98. return {
  99. "title": "ChatGPT-like Clone",
  100. "models": models,
  101. "default_model": models[0] if models else "",
  102. "output_modes": ["流式输出 (Stream)", "非流式输出 (Non-stream)"],
  103. "upload_base_url": DOWNLOAD_BASE + "/" if DOWNLOAD_BASE else "",
  104. }
  105. @app.get("/api/session/latest")
  106. async def api_latest_session(current_user: UserInfo = Depends(get_current_user)) -> Dict[str, Any]:
  107. return await get_latest_session(current_user.id)
  108. @app.get("/api/session/{session_id}")
  109. async def api_get_session(session_id: int, current_user: UserInfo = Depends(get_current_user)) -> Dict[str, Any]:
  110. return await get_session_payload(session_id, current_user.id, allow_archived=True)
  111. @app.post("/api/session/new")
  112. async def api_new_session(current_user: UserInfo = Depends(get_current_user)) -> Dict[str, Any]:
  113. return await create_chat_session(current_user.id)
  114. @app.get("/api/history")
  115. async def api_history(
  116. page: int = 0,
  117. page_size: int = 10,
  118. current_user: UserInfo = Depends(get_current_user),
  119. ) -> Dict[str, Any]:
  120. return await list_history(current_user.id, page, page_size)
  121. @app.post("/api/history/move")
  122. async def api_move_history(payload: HistoryActionRequest, current_user: UserInfo = Depends(get_current_user)) -> Dict[str, Any]:
  123. await move_history_file(current_user.id, payload.session_id)
  124. return {"status": "ok"}
  125. @app.delete("/api/history/{session_id}")
  126. async def api_delete_history(session_id: int, current_user: UserInfo = Depends(get_current_user)) -> Dict[str, Any]:
  127. await delete_history_file(current_user.id, session_id)
  128. return {"status": "ok"}
  129. @app.post("/api/upload")
  130. async def api_upload(
  131. files: List[UploadFile] = File(...),
  132. current_user: UserInfo = Depends(get_current_user),
  133. ) -> List[UploadResponseItem]:
  134. if not files:
  135. return []
  136. responses: List[UploadResponseItem] = []
  137. for upload in files:
  138. filename = upload.filename or "file"
  139. safe_filename = Path(filename).name or "file"
  140. content_type = (upload.content_type or "").lower()
  141. data = await upload.read()
  142. unique_name = f"{uuid.uuid4().hex}_{safe_filename}"
  143. target_path = UPLOAD_DIR / unique_name
  144. def _write() -> None:
  145. with target_path.open("wb") as fp:
  146. fp.write(data)
  147. await asyncio.to_thread(_write)
  148. if content_type.startswith("image/"):
  149. encoded = base64.b64encode(data).decode("utf-8")
  150. data_url = f"data:{content_type};base64,{encoded}"
  151. responses.append(
  152. UploadResponseItem(
  153. type="image",
  154. filename=safe_filename,
  155. data=data_url,
  156. url=build_download_url(unique_name),
  157. )
  158. )
  159. else:
  160. responses.append(
  161. UploadResponseItem(
  162. type="file",
  163. filename=safe_filename,
  164. url=build_download_url(unique_name),
  165. )
  166. )
  167. return responses
  168. @app.post("/api/chat")
  169. async def api_chat(payload: ChatRequest = Body(...), current_user: UserInfo = Depends(get_current_user)):
  170. if payload.model not in MODEL_KEYS:
  171. raise HTTPException(status_code=400, detail="未知的模型")
  172. session_payload = await ensure_active_session(payload.session_id, current_user.id)
  173. active_session_id = session_payload["session_id"]
  174. session_number = session_payload.get("session_number", active_session_id)
  175. messages = list(session_payload.get("messages") or [])
  176. user_message = {"role": "user", "content": payload.content}
  177. messages.append(user_message)
  178. await append_message(active_session_id, current_user.id, "user", payload.content)
  179. client.api_key = MODEL_KEYS[payload.model]
  180. to_send = await prepare_messages_for_completion(messages, payload.content, max(payload.history_count, 0))
  181. if payload.stream:
  182. queue: "asyncio.Queue[Dict[str, Any]]" = asyncio.Queue()
  183. aggregated: List[str] = []
  184. loop = asyncio.get_running_loop()
  185. def worker() -> None:
  186. try:
  187. response = client.chat.completions.create(
  188. model=payload.model,
  189. messages=to_send,
  190. stream=True,
  191. )
  192. for chunk in response:
  193. try:
  194. delta = chunk.choices[0].delta.content # type: ignore[attr-defined]
  195. except (IndexError, AttributeError):
  196. delta = None
  197. if delta:
  198. aggregated.append(delta)
  199. asyncio.run_coroutine_threadsafe(queue.put({"type": "delta", "text": delta}), loop)
  200. asyncio.run_coroutine_threadsafe(queue.put({"type": "complete"}), loop)
  201. except Exception as exc: # pragma: no cover - 网络调用
  202. asyncio.run_coroutine_threadsafe(queue.put({"type": "error", "message": str(exc)}), loop)
  203. threading.Thread(target=worker, daemon=True).start()
  204. async def streamer():
  205. meta = {"type": "meta", "session_id": active_session_id, "session_number": session_number}
  206. yield json.dumps(meta, ensure_ascii=False) + "\n"
  207. try:
  208. while True:
  209. item = await queue.get()
  210. if item["type"] == "delta":
  211. yield json.dumps(item, ensure_ascii=False) + "\n"
  212. elif item["type"] == "complete":
  213. assistant_text = "".join(aggregated)
  214. await save_assistant_message(active_session_id, current_user.id, messages, assistant_text)
  215. yield json.dumps({"type": "end"}, ensure_ascii=False) + "\n"
  216. break
  217. elif item["type"] == "error":
  218. yield json.dumps(item, ensure_ascii=False) + "\n"
  219. break
  220. except asyncio.CancelledError: # pragma: no cover - 流被取消
  221. raise
  222. return StreamingResponse(streamer(), media_type="application/x-ndjson")
  223. try:
  224. completion = await asyncio.to_thread(
  225. client.chat.completions.create,
  226. model=payload.model,
  227. messages=to_send,
  228. stream=False,
  229. )
  230. except Exception as exc: # pragma: no cover - 网络调用
  231. raise HTTPException(status_code=500, detail=str(exc)) from exc
  232. choice = completion.choices[0] if getattr(completion, "choices", None) else None # type: ignore[attr-defined]
  233. if not choice:
  234. raise HTTPException(status_code=500, detail="响应格式不正确")
  235. assistant_content = getattr(choice.message, "content", "")
  236. if not assistant_content:
  237. assistant_content = ""
  238. await save_assistant_message(active_session_id, current_user.id, messages, assistant_content)
  239. return {"session_id": active_session_id, "session_number": session_number, "message": assistant_content}
  240. if __name__ == "__main__":
  241. import uvicorn
  242. uvicorn.run("fastchat:app", host="0.0.0.0", port=16016, reload=True)