fastchat.py 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552
  1. # -*- coding: utf-8 -*-
  2. import asyncio
  3. import base64
  4. import datetime
  5. import json
  6. import os
  7. import re
  8. import shutil
  9. import threading
  10. import uuid
  11. from pathlib import Path
  12. from typing import Any, Dict, List, Optional, Union
  13. from fastapi import Body, FastAPI, HTTPException, UploadFile, File
  14. from fastapi.middleware.cors import CORSMiddleware
  15. from fastapi.responses import FileResponse, HTMLResponse, StreamingResponse
  16. from fastapi.staticfiles import StaticFiles
  17. from pydantic import BaseModel
  18. from openai import OpenAI
  19. # =============================
  20. # 基础配置
  21. # =============================
  22. BASE_DIR = Path(__file__).resolve().parent
  23. DATA_DIR = BASE_DIR / "data"
  24. BACKUP_DIR = BASE_DIR / "data_bak"
  25. BLOG_DIR = BASE_DIR / "blog"
  26. UPLOAD_DIR = BASE_DIR / "uploads"
  27. STATIC_DIR = BASE_DIR / "static"
  28. SESSION_ID_FILE = DATA_DIR / "session_id.txt"
  29. # 默认上传文件下载地址,可通过环境变量覆盖
  30. DEFAULT_UPLOAD_BASE = os.getenv("UPLOAD_BASE_URL", "/download/")
  31. DOWNLOAD_BASE = DEFAULT_UPLOAD_BASE.rstrip("/")
  32. # 与 appchat.py 相同的模型与密钥配置(仅示例)
  33. default_key = "sk-re2NlaKIQn11ZNWzAbB6339cEbF94c6aAfC8B7Ab82879bEa"
  34. MODEL_KEYS: Dict[str, str] = {
  35. "grok-3": default_key,
  36. "grok-4": default_key,
  37. "gpt-5.1-2025-11-13": default_key,
  38. "gpt-5-2025-08-07": default_key,
  39. "gpt-4o-mini": default_key,
  40. # "gpt-4.1-mini-2025-04-14": default_key,
  41. "o1-mini": default_key,
  42. "o4-mini": default_key,
  43. "deepseek-v3": default_key,
  44. "deepseek-r1": default_key,
  45. "gpt-4o-all": default_key,
  46. # "gpt-5-mini-2025-08-07": default_key,
  47. "o3-mini-all": default_key,
  48. }
  49. API_URL = "https://yunwu.ai/v1"
  50. client = OpenAI(api_key=default_key, base_url=API_URL)
  51. # 锁用于避免并发文件写入导致的数据损坏
  52. FILE_LOCK = asyncio.Lock()
  53. SESSION_LOCK = asyncio.Lock()
  54. MessageContent = Union[str, List[Dict[str, Any]]]
  55. SESSION_FILE_PATTERN = re.compile(r"chat_history_(\d+)\.json")
  56. def ensure_directories() -> None:
  57. for path in [DATA_DIR, BACKUP_DIR, BLOG_DIR, UPLOAD_DIR, STATIC_DIR]:
  58. path.mkdir(parents=True, exist_ok=True)
  59. def extract_session_id(path: Path) -> Optional[int]:
  60. match = SESSION_FILE_PATTERN.search(path.name)
  61. if match:
  62. try:
  63. return int(match.group(1))
  64. except ValueError:
  65. return None
  66. return None
  67. def load_session_counter() -> int:
  68. if SESSION_ID_FILE.exists():
  69. try:
  70. value = SESSION_ID_FILE.read_text(encoding="utf-8").strip()
  71. return int(value) if value.isdigit() else 0
  72. except Exception:
  73. return 0
  74. return 0
  75. def save_session_counter(value: int) -> None:
  76. SESSION_ID_FILE.write_text(str(value), encoding="utf-8")
  77. def sync_session_counter_with_history() -> None:
  78. max_session = load_session_counter()
  79. for path in DATA_DIR.glob("chat_history_*.json"):
  80. session_id = extract_session_id(path)
  81. if session_id is not None and session_id > max_session:
  82. max_session = session_id
  83. save_session_counter(max_session)
  84. def text_from_content(content: MessageContent) -> str:
  85. if isinstance(content, str):
  86. return content
  87. if isinstance(content, list):
  88. pieces: List[str] = []
  89. for part in content:
  90. if part.get("type") == "text":
  91. pieces.append(part.get("text", ""))
  92. return " ".join(pieces)
  93. return str(content)
  94. def extract_history_title(messages: List[Dict[str, Any]]) -> str:
  95. """Return the first meaningful title extracted from user messages."""
  96. for message in messages:
  97. if message.get("role") != "user":
  98. continue
  99. title = text_from_content(message.get("content", "")).strip()
  100. if title:
  101. return title[:10]
  102. if messages:
  103. fallback = text_from_content(messages[0].get("content", "")).strip()
  104. if fallback:
  105. return fallback[:10]
  106. return "空的聊天"[:10]
  107. def history_path(session_id: int) -> Path:
  108. return DATA_DIR / f"chat_history_{session_id}.json"
  109. def build_download_url(filename: str) -> str:
  110. base = DOWNLOAD_BASE or ""
  111. return f"{base}/{filename}" if base else filename
  112. async def read_json_file(path: Path) -> List[Dict[str, Any]]:
  113. def _read() -> List[Dict[str, Any]]:
  114. with path.open("r", encoding="utf-8") as fp:
  115. return json.load(fp)
  116. return await asyncio.to_thread(_read)
  117. async def write_json_file(path: Path, payload: List[Dict[str, Any]]) -> None:
  118. serialized = json.dumps(payload, ensure_ascii=False)
  119. def _write() -> None:
  120. with path.open("w", encoding="utf-8") as fp:
  121. fp.write(serialized)
  122. async with FILE_LOCK:
  123. await asyncio.to_thread(_write)
  124. async def load_messages(session_id: int) -> List[Dict[str, Any]]:
  125. path = history_path(session_id)
  126. if not path.exists():
  127. return []
  128. try:
  129. return await read_json_file(path)
  130. except Exception:
  131. return []
  132. async def save_messages(session_id: int, messages: List[Dict[str, Any]]) -> None:
  133. path = history_path(session_id)
  134. path.parent.mkdir(parents=True, exist_ok=True)
  135. await write_json_file(path, messages)
  136. async def get_latest_session() -> Dict[str, Any]:
  137. history_files = sorted(DATA_DIR.glob("chat_history_*.json"), key=lambda p: p.stat().st_mtime)
  138. if history_files:
  139. latest = history_files[-1]
  140. session_id = extract_session_id(latest)
  141. if session_id is None:
  142. session_id = await asyncio.to_thread(load_session_counter)
  143. try:
  144. messages = await read_json_file(latest)
  145. except Exception:
  146. messages = []
  147. return {"session_id": session_id, "messages": messages}
  148. session_id = await asyncio.to_thread(load_session_counter)
  149. return {"session_id": session_id, "messages": []}
  150. async def increment_session_id() -> int:
  151. async with SESSION_LOCK:
  152. current = await asyncio.to_thread(load_session_counter)
  153. next_session = current + 1
  154. await asyncio.to_thread(save_session_counter, next_session)
  155. return next_session
  156. async def list_history(page: int, page_size: int) -> Dict[str, Any]:
  157. files = sorted(DATA_DIR.glob("chat_history_*.json"), key=lambda p: p.stat().st_mtime, reverse=True)
  158. total = len(files)
  159. start = max(page, 0) * page_size
  160. end = start + page_size
  161. items: List[Dict[str, Any]] = []
  162. for path in files[start:end]:
  163. session_id = extract_session_id(path)
  164. if session_id is None:
  165. continue
  166. try:
  167. messages = await read_json_file(path)
  168. except Exception:
  169. messages = []
  170. title = extract_history_title(messages)
  171. items.append({
  172. "session_id": session_id,
  173. "title": title,
  174. "updated_at": datetime.datetime.fromtimestamp(path.stat().st_mtime).isoformat(),
  175. "filename": path.name,
  176. })
  177. return {
  178. "page": page,
  179. "page_size": page_size,
  180. "total": total,
  181. "items": items,
  182. }
  183. async def move_history_file(session_id: int) -> None:
  184. src = history_path(session_id)
  185. if not src.exists():
  186. raise HTTPException(status_code=404, detail="历史记录不存在")
  187. dst = BACKUP_DIR / src.name
  188. dst.parent.mkdir(parents=True, exist_ok=True)
  189. def _move() -> None:
  190. shutil.move(str(src), str(dst))
  191. async with FILE_LOCK:
  192. await asyncio.to_thread(_move)
  193. async def delete_history_file(session_id: int) -> None:
  194. target = history_path(session_id)
  195. if not target.exists():
  196. raise HTTPException(status_code=404, detail="历史记录不存在")
  197. def _delete() -> None:
  198. target.unlink(missing_ok=True)
  199. async with FILE_LOCK:
  200. await asyncio.to_thread(_delete)
  201. async def export_message_to_blog(content: MessageContent) -> str:
  202. processed = text_from_content(content)
  203. processed = processed.replace("\r\n", "\n")
  204. timestamp = datetime.datetime.now().strftime("%m%d%H%M")
  205. first_10 = (
  206. processed[:10]
  207. .replace(" ", "")
  208. .replace("/", "")
  209. .replace("\\", "")
  210. .replace(":", "")
  211. .replace("`", "")
  212. )
  213. filename = f"{timestamp}_{first_10 or 'export'}.txt"
  214. path = BLOG_DIR / filename
  215. def _write() -> None:
  216. with path.open("w", encoding="utf-8") as fp:
  217. fp.write(processed)
  218. await asyncio.to_thread(_write)
  219. return str(path)
# Role/content pair describing one chat message.
# NOTE(review): not referenced by any other definition visible in this file.
class MessageModel(BaseModel):
    role: str
    # Either plain text or a list of dict parts (see MessageContent).
    content: MessageContent
# Request body accepted by POST /api/chat.
class ChatRequest(BaseModel):
    # Session whose history the new message is appended to.
    session_id: int
    # Must be a key of MODEL_KEYS; the endpoint answers 400 otherwise.
    model: str
    # The new user message (plain text or a list of dict parts).
    content: MessageContent
    # Number of trailing history messages sent to the model; values <= 0
    # send only the new user message.
    history_count: int = 0
    # True: NDJSON streaming response; False: a single JSON response.
    stream: bool = True
# Request body accepted by POST /api/history/move.
class HistoryActionRequest(BaseModel):
    # Session whose history file the action applies to.
    session_id: int
# Request body accepted by POST /api/export.
class ExportRequest(BaseModel):
    # Message content to export (plain text or a list of dict parts).
    content: MessageContent
# One entry of the list returned by POST /api/upload.
class UploadResponseItem(BaseModel):
    # "image" for uploads whose content type starts with "image/", else "file".
    type: str
    # Base name of the uploaded file as sent by the client.
    filename: str
    # Base64 data URL of the file; only set for images.
    data: Optional[str] = None
    # Download URL of the stored copy, built with build_download_url.
    url: Optional[str] = None
# Make sure the static and data directories exist before the application is
# created (StaticFiles below is pointed at STATIC_DIR).
ensure_directories()
app = FastAPI(title="ChatGPT-like Clone", version="1.0.0")
# NOTE(review): every origin, method and header is allowed together with
# allow_credentials=True — confirm this is intended outside local development.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Serve the contents of STATIC_DIR under /static.
app.mount("/static", StaticFiles(directory=STATIC_DIR), name="static")
  249. @app.on_event("startup")
  250. async def on_startup() -> None:
  251. ensure_directories()
  252. await asyncio.to_thread(sync_session_counter_with_history)
  253. INDEX_HTML = STATIC_DIR / "index.html"
  254. @app.get("/", response_class=HTMLResponse)
  255. async def serve_index() -> str:
  256. if not INDEX_HTML.exists():
  257. raise HTTPException(status_code=404, detail="UI 未找到,请确认 static/index.html 是否存在")
  258. return INDEX_HTML.read_text(encoding="utf-8")
  259. @app.get("/download/{filename}")
  260. async def download_file(filename: str) -> FileResponse:
  261. target = UPLOAD_DIR / filename
  262. if not target.exists():
  263. raise HTTPException(status_code=404, detail="File not found")
  264. return FileResponse(target, filename=filename)
  265. @app.get("/api/config")
  266. async def get_config() -> Dict[str, Any]:
  267. models = list(MODEL_KEYS.keys())
  268. return {
  269. "title": "ChatGPT-like Clone",
  270. "models": models,
  271. "default_model": models[0] if models else "",
  272. "output_modes": ["流式输出 (Stream)", "非流式输出 (Non-stream)"],
  273. "upload_base_url": DOWNLOAD_BASE + "/" if DOWNLOAD_BASE else "",
  274. }
  275. @app.get("/api/session/latest")
  276. async def api_latest_session() -> Dict[str, Any]:
  277. return await get_latest_session()
  278. @app.get("/api/session/{session_id}")
  279. async def api_get_session(session_id: int) -> Dict[str, Any]:
  280. messages = await load_messages(session_id)
  281. path = history_path(session_id)
  282. if not messages and not path.exists():
  283. raise HTTPException(status_code=404, detail="会话不存在")
  284. return {"session_id": session_id, "messages": messages}
  285. @app.post("/api/session/new")
  286. async def api_new_session() -> Dict[str, Any]:
  287. session_id = await increment_session_id()
  288. await save_messages(session_id, [])
  289. return {"session_id": session_id, "messages": []}
  290. @app.get("/api/history")
  291. async def api_history(page: int = 0, page_size: int = 10) -> Dict[str, Any]:
  292. return await list_history(page, page_size)
  293. @app.post("/api/history/move")
  294. async def api_move_history(payload: HistoryActionRequest) -> Dict[str, Any]:
  295. await move_history_file(payload.session_id)
  296. return {"status": "ok"}
  297. @app.delete("/api/history/{session_id}")
  298. async def api_delete_history(session_id: int) -> Dict[str, Any]:
  299. await delete_history_file(session_id)
  300. return {"status": "ok"}
  301. @app.post("/api/export")
  302. async def api_export_message(payload: ExportRequest) -> Dict[str, Any]:
  303. path = await export_message_to_blog(payload.content)
  304. return {"status": "ok", "path": path}
  305. @app.post("/api/upload")
  306. async def api_upload(files: List[UploadFile] = File(...)) -> List[UploadResponseItem]:
  307. if not files:
  308. return []
  309. responses: List[UploadResponseItem] = []
  310. for upload in files:
  311. filename = upload.filename or "file"
  312. safe_filename = Path(filename).name or "file"
  313. content_type = (upload.content_type or "").lower()
  314. data = await upload.read()
  315. unique_name = f"{uuid.uuid4().hex}_{safe_filename}"
  316. target_path = UPLOAD_DIR / unique_name
  317. def _write() -> None:
  318. with target_path.open("wb") as fp:
  319. fp.write(data)
  320. await asyncio.to_thread(_write)
  321. if content_type.startswith("image/"):
  322. encoded = base64.b64encode(data).decode("utf-8")
  323. data_url = f"data:{content_type};base64,{encoded}"
  324. responses.append(
  325. UploadResponseItem(
  326. type="image",
  327. filename=safe_filename,
  328. data=data_url,
  329. url=build_download_url(unique_name),
  330. )
  331. )
  332. else:
  333. responses.append(
  334. UploadResponseItem(
  335. type="file",
  336. filename=safe_filename,
  337. url=build_download_url(unique_name),
  338. )
  339. )
  340. return responses
  341. async def prepare_messages_for_completion(
  342. messages: List[Dict[str, Any]],
  343. user_content: MessageContent,
  344. history_count: int,
  345. ) -> List[Dict[str, Any]]:
  346. if history_count > 0:
  347. trimmed = messages[-history_count:]
  348. if trimmed:
  349. return trimmed
  350. return [{"role": "user", "content": user_content}]
  351. async def save_assistant_message(session_id: int, messages: List[Dict[str, Any]], content: MessageContent) -> None:
  352. messages.append({"role": "assistant", "content": content})
  353. await save_messages(session_id, messages)
  354. @app.post("/api/chat")
  355. async def api_chat(payload: ChatRequest = Body(...)):
  356. if payload.model not in MODEL_KEYS:
  357. raise HTTPException(status_code=400, detail="未知的模型")
  358. messages = await load_messages(payload.session_id)
  359. user_message = {"role": "user", "content": payload.content}
  360. messages.append(user_message)
  361. await save_messages(payload.session_id, messages)
  362. client.api_key = MODEL_KEYS[payload.model]
  363. to_send = await prepare_messages_for_completion(messages, payload.content, max(payload.history_count, 0))
  364. if payload.stream:
  365. queue: "asyncio.Queue[Dict[str, Any]]" = asyncio.Queue()
  366. aggregated: List[str] = []
  367. loop = asyncio.get_running_loop()
  368. def worker() -> None:
  369. try:
  370. response = client.chat.completions.create(
  371. model=payload.model,
  372. messages=to_send,
  373. stream=True,
  374. )
  375. for chunk in response:
  376. try:
  377. delta = chunk.choices[0].delta.content # type: ignore[attr-defined]
  378. except (IndexError, AttributeError):
  379. delta = None
  380. if delta:
  381. aggregated.append(delta)
  382. asyncio.run_coroutine_threadsafe(queue.put({"type": "delta", "text": delta}), loop)
  383. asyncio.run_coroutine_threadsafe(queue.put({"type": "complete"}), loop)
  384. except Exception as exc: # pragma: no cover - 网络调用
  385. asyncio.run_coroutine_threadsafe(queue.put({"type": "error", "message": str(exc)}), loop)
  386. threading.Thread(target=worker, daemon=True).start()
  387. async def streamer():
  388. try:
  389. while True:
  390. item = await queue.get()
  391. if item["type"] == "delta":
  392. yield json.dumps(item, ensure_ascii=False) + "\n"
  393. elif item["type"] == "complete":
  394. assistant_text = "".join(aggregated)
  395. await save_assistant_message(payload.session_id, messages, assistant_text)
  396. yield json.dumps({"type": "end"}, ensure_ascii=False) + "\n"
  397. break
  398. elif item["type"] == "error":
  399. yield json.dumps(item, ensure_ascii=False) + "\n"
  400. break
  401. except asyncio.CancelledError: # pragma: no cover - 流被取消
  402. raise
  403. return StreamingResponse(streamer(), media_type="application/x-ndjson")
  404. try:
  405. completion = await asyncio.to_thread(
  406. client.chat.completions.create,
  407. model=payload.model,
  408. messages=to_send,
  409. stream=False,
  410. )
  411. except Exception as exc: # pragma: no cover - 网络调用
  412. raise HTTPException(status_code=500, detail=str(exc)) from exc
  413. choice = completion.choices[0] if getattr(completion, "choices", None) else None # type: ignore[attr-defined]
  414. if not choice:
  415. raise HTTPException(status_code=500, detail="响应格式不正确")
  416. assistant_content = getattr(choice.message, "content", "")
  417. if not assistant_content:
  418. assistant_content = ""
  419. await save_assistant_message(payload.session_id, messages, assistant_content)
  420. return {"message": assistant_content}
# Script entry point: run the app with uvicorn when executed directly.
if __name__ == "__main__":
    import uvicorn

    # The app is passed as the import string "fastchat:app", so this file has
    # to be importable as module `fastchat`.
    # NOTE(review): host 0.0.0.0 together with reload=True looks like a
    # development setup — confirm before using this entry point in production.
    uvicorn.run("fastchat:app", host="0.0.0.0", port=16016, reload=True)