fastchat.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371
  1. # -*- coding: utf-8 -*-
  2. import asyncio
  3. import base64
  4. import json
  5. import mimetypes
  6. import threading
  7. import uuid
  8. from typing import Any, Dict, List, Optional
  9. from urllib.parse import urlparse
  10. from fastapi import Body, Depends, FastAPI, HTTPException, UploadFile, File
  11. from fastapi.middleware.cors import CORSMiddleware
  12. from fastapi.responses import FileResponse, HTMLResponse, StreamingResponse
  13. from fastapi.staticfiles import StaticFiles
  14. from pydantic import BaseModel
  15. from openai import OpenAI
  16. from pathlib import Path
  17. from chatfast.api import admin_router, auth_router, export_router
  18. from chatfast.config import API_URL, DOWNLOAD_BASE, MODEL_KEYS, STATIC_DIR, UPLOAD_DIR
  19. from chatfast.db import FILE_LOCK, MessageContent, ensure_database_initialized, ensure_directories
  20. from chatfast.services.auth import (
  21. UserInfo,
  22. cleanup_expired_tokens,
  23. ensure_default_admin,
  24. get_current_user,
  25. )
  26. from chatfast.services.chat import (
  27. append_message,
  28. build_download_url,
  29. create_chat_session,
  30. delete_history_file,
  31. ensure_active_session,
  32. ensure_session_numbering,
  33. export_message_to_blog,
  34. get_export_record,
  35. get_latest_session,
  36. get_session_payload,
  37. list_exports_admin,
  38. list_exports_for_user,
  39. list_history,
  40. move_history_file,
  41. prepare_messages_for_completion,
  42. record_export_entry,
  43. save_assistant_message,
  44. )
  45. client = OpenAI(api_key=next(iter(MODEL_KEYS.values()), ""), base_url=API_URL)
  46. class MessageModel(BaseModel):
  47. role: str
  48. content: MessageContent
  49. class ChatRequest(BaseModel):
  50. session_id: int
  51. model: str
  52. content: MessageContent
  53. history_count: int = 0
  54. stream: bool = True
  55. class HistoryActionRequest(BaseModel):
  56. session_id: int
  57. class UploadResponseItem(BaseModel):
  58. type: str
  59. filename: str
  60. url: Optional[str] = None
  61. path: Optional[str] = None
  62. def _is_data_url(value: str) -> bool:
  63. return value.startswith("data:")
  64. def _resolve_upload_path(reference: str) -> Optional[Path]:
  65. if not reference or _is_data_url(reference):
  66. return None
  67. parsed = urlparse(reference)
  68. candidate = Path(parsed.path).name if parsed.scheme else Path(reference).name
  69. if not candidate:
  70. return None
  71. return UPLOAD_DIR / candidate
  72. async def _inline_local_image(reference: str) -> str:
  73. if not reference or _is_data_url(reference):
  74. return reference
  75. file_path = _resolve_upload_path(reference)
  76. if not file_path or not file_path.exists():
  77. return reference
  78. try:
  79. data = await asyncio.to_thread(file_path.read_bytes)
  80. except OSError:
  81. return reference
  82. mime = mimetypes.guess_type(file_path.name)[0] or "application/octet-stream"
  83. encoded = base64.b64encode(data).decode("utf-8")
  84. return f"data:{mime};base64,{encoded}"
  85. async def _prepare_content_for_model(content: MessageContent) -> MessageContent:
  86. if not isinstance(content, list):
  87. return content
  88. prepared: List[Dict[str, Any]] = []
  89. for part in content:
  90. if not isinstance(part, dict):
  91. prepared.append(part)
  92. continue
  93. if part.get("type") != "image_url":
  94. prepared.append({**part})
  95. continue
  96. image_data = dict(part.get("image_url") or {})
  97. url_value = str(image_data.get("url") or "")
  98. new_url = await _inline_local_image(url_value)
  99. if new_url:
  100. image_data["url"] = new_url
  101. prepared.append({"type": "image_url", "image_url": image_data})
  102. return prepared
  103. async def _prepare_messages_for_model(messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
  104. prepared: List[Dict[str, Any]] = []
  105. for message in messages:
  106. prepared.append(
  107. {
  108. "role": message.get("role", ""),
  109. "content": await _prepare_content_for_model(message.get("content")),
  110. }
  111. )
  112. return prepared
  113. # 确保静态与数据目录在应用初始化前存在
  114. ensure_directories()
  115. ensure_database_initialized()
  116. ensure_default_admin()
  117. ensure_session_numbering()
  118. app = FastAPI(title="ChatGPT-like Clone", version="1.0.0")
  119. app.add_middleware(
  120. CORSMiddleware,
  121. allow_origins=["*"],
  122. allow_credentials=True,
  123. allow_methods=["*"],
  124. allow_headers=["*"],
  125. )
  126. app.mount("/static", StaticFiles(directory=STATIC_DIR), name="static")
  127. app.include_router(auth_router)
  128. app.include_router(admin_router)
  129. app.include_router(export_router)
  130. @app.on_event("startup")
  131. async def on_startup() -> None:
  132. ensure_directories()
  133. ensure_database_initialized()
  134. ensure_default_admin()
  135. ensure_session_numbering()
  136. await cleanup_expired_tokens()
  137. INDEX_HTML = STATIC_DIR / "index.html"
  138. @app.get("/", response_class=HTMLResponse)
  139. async def serve_index() -> str:
  140. if not INDEX_HTML.exists():
  141. raise HTTPException(status_code=404, detail="UI 未找到,请确认 static/index.html 是否存在")
  142. return INDEX_HTML.read_text(encoding="utf-8")
  143. @app.get("/download/{filename}")
  144. async def download_file(filename: str) -> FileResponse:
  145. target = UPLOAD_DIR / filename
  146. if not target.exists():
  147. raise HTTPException(status_code=404, detail="File not found")
  148. return FileResponse(target, filename=filename)
  149. @app.get("/api/config")
  150. async def get_config() -> Dict[str, Any]:
  151. models = list(MODEL_KEYS.keys())
  152. return {
  153. "title": "ChatGPT-like Clone",
  154. "models": models,
  155. "default_model": models[0] if models else "",
  156. "output_modes": ["流式输出 (Stream)", "非流式输出 (Non-stream)"],
  157. "upload_base_url": DOWNLOAD_BASE + "/" if DOWNLOAD_BASE else "",
  158. }
  159. @app.get("/api/session/latest")
  160. async def api_latest_session(current_user: UserInfo = Depends(get_current_user)) -> Dict[str, Any]:
  161. return await get_latest_session(current_user.id)
  162. @app.get("/api/session/{session_id}")
  163. async def api_get_session(session_id: int, current_user: UserInfo = Depends(get_current_user)) -> Dict[str, Any]:
  164. return await get_session_payload(session_id, current_user.id, allow_archived=True)
  165. @app.post("/api/session/new")
  166. async def api_new_session(current_user: UserInfo = Depends(get_current_user)) -> Dict[str, Any]:
  167. return await create_chat_session(current_user.id)
  168. @app.get("/api/history")
  169. async def api_history(
  170. page: int = 0,
  171. page_size: int = 10,
  172. current_user: UserInfo = Depends(get_current_user),
  173. ) -> Dict[str, Any]:
  174. return await list_history(current_user.id, page, page_size)
  175. @app.post("/api/history/move")
  176. async def api_move_history(payload: HistoryActionRequest, current_user: UserInfo = Depends(get_current_user)) -> Dict[str, Any]:
  177. await move_history_file(current_user.id, payload.session_id)
  178. return {"status": "ok"}
  179. @app.delete("/api/history/{session_id}")
  180. async def api_delete_history(session_id: int, current_user: UserInfo = Depends(get_current_user)) -> Dict[str, Any]:
  181. await delete_history_file(current_user.id, session_id)
  182. return {"status": "ok"}
  183. @app.post("/api/upload")
  184. async def api_upload(
  185. files: List[UploadFile] = File(...),
  186. current_user: UserInfo = Depends(get_current_user),
  187. ) -> List[UploadResponseItem]:
  188. if not files:
  189. return []
  190. responses: List[UploadResponseItem] = []
  191. for upload in files:
  192. filename = upload.filename or "file"
  193. safe_filename = Path(filename).name or "file"
  194. content_type = (upload.content_type or "").lower()
  195. data = await upload.read()
  196. unique_name = f"{uuid.uuid4().hex}_{safe_filename}"
  197. target_path = UPLOAD_DIR / unique_name
  198. def _write() -> None:
  199. with target_path.open("wb") as fp:
  200. fp.write(data)
  201. await asyncio.to_thread(_write)
  202. download_url = build_download_url(unique_name)
  203. if content_type.startswith("image/"):
  204. responses.append(
  205. UploadResponseItem(
  206. type="image",
  207. filename=safe_filename,
  208. url=download_url,
  209. path=unique_name,
  210. )
  211. )
  212. else:
  213. responses.append(
  214. UploadResponseItem(
  215. type="file",
  216. filename=safe_filename,
  217. url=download_url,
  218. path=unique_name,
  219. )
  220. )
  221. return responses
  222. @app.post("/api/chat")
  223. async def api_chat(payload: ChatRequest = Body(...), current_user: UserInfo = Depends(get_current_user)):
  224. if payload.model not in MODEL_KEYS:
  225. raise HTTPException(status_code=400, detail="未知的模型")
  226. session_payload = await ensure_active_session(payload.session_id, current_user.id)
  227. active_session_id = session_payload["session_id"]
  228. session_number = session_payload.get("session_number", active_session_id)
  229. messages = list(session_payload.get("messages") or [])
  230. user_message = {"role": "user", "content": payload.content}
  231. messages.append(user_message)
  232. await append_message(active_session_id, current_user.id, "user", payload.content)
  233. client.api_key = MODEL_KEYS[payload.model]
  234. to_send = await prepare_messages_for_completion(messages, payload.content, max(payload.history_count, 0))
  235. model_messages = await _prepare_messages_for_model(to_send)
  236. if payload.stream:
  237. queue: "asyncio.Queue[Dict[str, Any]]" = asyncio.Queue()
  238. aggregated: List[str] = []
  239. loop = asyncio.get_running_loop()
  240. def worker() -> None:
  241. try:
  242. response = client.chat.completions.create(
  243. model=payload.model,
  244. messages=model_messages,
  245. stream=True,
  246. )
  247. for chunk in response:
  248. try:
  249. delta = chunk.choices[0].delta.content # type: ignore[attr-defined]
  250. except (IndexError, AttributeError):
  251. delta = None
  252. if delta:
  253. aggregated.append(delta)
  254. asyncio.run_coroutine_threadsafe(queue.put({"type": "delta", "text": delta}), loop)
  255. asyncio.run_coroutine_threadsafe(queue.put({"type": "complete"}), loop)
  256. except Exception as exc: # pragma: no cover - 网络调用
  257. asyncio.run_coroutine_threadsafe(queue.put({"type": "error", "message": str(exc)}), loop)
  258. threading.Thread(target=worker, daemon=True).start()
  259. async def streamer():
  260. meta = {"type": "meta", "session_id": active_session_id, "session_number": session_number}
  261. yield json.dumps(meta, ensure_ascii=False) + "\n"
  262. try:
  263. while True:
  264. item = await queue.get()
  265. if item["type"] == "delta":
  266. yield json.dumps(item, ensure_ascii=False) + "\n"
  267. elif item["type"] == "complete":
  268. assistant_text = "".join(aggregated)
  269. await save_assistant_message(active_session_id, current_user.id, messages, assistant_text)
  270. yield json.dumps({"type": "end"}, ensure_ascii=False) + "\n"
  271. break
  272. elif item["type"] == "error":
  273. yield json.dumps(item, ensure_ascii=False) + "\n"
  274. break
  275. except asyncio.CancelledError: # pragma: no cover - 流被取消
  276. raise
  277. return StreamingResponse(streamer(), media_type="application/x-ndjson")
  278. try:
  279. completion = await asyncio.to_thread(
  280. client.chat.completions.create,
  281. model=payload.model,
  282. messages=model_messages,
  283. stream=False,
  284. )
  285. except Exception as exc: # pragma: no cover - 网络调用
  286. raise HTTPException(status_code=500, detail=str(exc)) from exc
  287. choice = completion.choices[0] if getattr(completion, "choices", None) else None # type: ignore[attr-defined]
  288. if not choice:
  289. raise HTTPException(status_code=500, detail="响应格式不正确")
  290. assistant_content = getattr(choice.message, "content", "")
  291. if not assistant_content:
  292. assistant_content = ""
  293. await save_assistant_message(active_session_id, current_user.id, messages, assistant_content)
  294. return {"session_id": active_session_id, "session_number": session_number, "message": assistant_content}
  295. if __name__ == "__main__":
  296. import uvicorn
  297. uvicorn.run("fastchat:app", host="0.0.0.0", port=16016, reload=True)