# -*- coding: utf-8 -*-
"""FastAPI backend for a ChatGPT-like web UI.

What this module does:

* serves ``static/index.html`` and the ``/static`` assets;
* persists each chat session as ``data/chat_history_<id>.json``;
* proxies chat requests to an OpenAI-compatible endpoint, either streamed as
  NDJSON lines or returned in one piece;
* stores uploads under ``uploads/`` and exports messages as text into ``blog/``.
"""
import asyncio
import base64
import datetime
import json
import os
import re
import shutil
import threading
import uuid
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple, Union

from fastapi import Body, FastAPI, HTTPException, UploadFile, File
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import FileResponse, HTMLResponse, StreamingResponse
from fastapi.staticfiles import StaticFiles
from pydantic import BaseModel
from openai import OpenAI

# =============================
# Basic configuration
# =============================
BASE_DIR = Path(__file__).resolve().parent
DATA_DIR = BASE_DIR / "data"
BACKUP_DIR = BASE_DIR / "data_bak"
BLOG_DIR = BASE_DIR / "blog"
UPLOAD_DIR = BASE_DIR / "uploads"
STATIC_DIR = BASE_DIR / "static"
SESSION_ID_FILE = DATA_DIR / "session_id.txt"

# Base URL under which uploaded files are served; override with UPLOAD_BASE_URL.
DEFAULT_UPLOAD_BASE = os.getenv("UPLOAD_BASE_URL", "/download/")
DOWNLOAD_BASE = DEFAULT_UPLOAD_BASE.rstrip("/")

# Model and key configuration shared with appchat.py (example only).
# NOTE(security): a real-looking API key is committed to source below. Supply it
# through the CHAT_API_KEY environment variable instead, and rotate/remove the
# literal fallback.
default_key = os.getenv("CHAT_API_KEY", "sk-re2NlaKIQn11ZNWzAbB6339cEbF94c6aAfC8B7Ab82879bEa")
MODEL_KEYS: Dict[str, str] = {
    "grok-3": default_key,
    "grok-4": default_key,
    "gpt-5.1-2025-11-13": default_key,
    "gpt-5-2025-08-07": default_key,
    "gpt-4o-mini": default_key,
    # "gpt-4.1-mini-2025-04-14": default_key,
    "o1-mini": default_key,
    "o4-mini": default_key,
    "deepseek-v3": default_key,
    "deepseek-r1": default_key,
    "gpt-4o-all": default_key,
    # "gpt-5-mini-2025-08-07": default_key,
    "o3-mini-all": default_key,
}
API_URL = os.getenv("CHAT_API_URL", "https://yunwu.ai/v1")
client = OpenAI(api_key=default_key, base_url=API_URL)

# Locks that serialize file access so concurrent requests cannot corrupt data.
FILE_LOCK = asyncio.Lock()
SESSION_LOCK = asyncio.Lock()

MessageContent = Union[str, List[Dict[str, Any]]]
SESSION_FILE_PATTERN = re.compile(r"chat_history_(\d+)\.json")
HISTORY_GLOB = "chat_history_*.json"
# Characters stripped from exported file names: whitespace (incl. newlines),
# path separators, backticks, and characters Windows forbids in file names.
_UNSAFE_FILENAME_CHARS = re.compile(r'[\s\\/:*?"<>|`]')


def ensure_directories() -> None:
    """Create every directory this app reads from or writes to, if missing."""
    for path in (DATA_DIR, BACKUP_DIR, BLOG_DIR, UPLOAD_DIR, STATIC_DIR):
        path.mkdir(parents=True, exist_ok=True)


def extract_session_id(path: Path) -> Optional[int]:
    """Return the numeric id in a ``chat_history_<id>.json`` file name, or None."""
    match = SESSION_FILE_PATTERN.search(path.name)
    if not match:
        return None
    try:
        return int(match.group(1))
    except ValueError:
        return None


def load_session_counter() -> int:
    """Return the last issued session id from SESSION_ID_FILE (0 if absent or invalid)."""
    try:
        value = SESSION_ID_FILE.read_text(encoding="utf-8").strip()
        return int(value) if value.isdigit() else 0
    except (OSError, ValueError):
        return 0


def save_session_counter(value: int) -> None:
    """Persist *value* as the last issued session id."""
    SESSION_ID_FILE.write_text(str(value), encoding="utf-8")


def sync_session_counter_with_history() -> None:
    """Raise the stored counter to at least the largest session id on disk.

    Archived histories in BACKUP_DIR are scanned as well. This ensures an id
    whose history was moved away is never issued again, which would otherwise
    lead to clashing archive names.
    """
    max_session = load_session_counter()
    for directory in (DATA_DIR, BACKUP_DIR):
        for path in directory.glob(HISTORY_GLOB):
            session_id = extract_session_id(path)
            if session_id is not None:
                max_session = max(max_session, session_id)
    save_session_counter(max_session)


def text_from_content(content: MessageContent) -> str:
    """Flatten message content to plain text.

    Strings are returned as-is. For a list of content parts, the ``text`` of
    every ``{"type": "text"}`` part is joined with spaces; other parts (e.g.
    images) and non-dict entries are ignored. Anything else goes through ``str``.
    """
    if isinstance(content, str):
        return content
    if isinstance(content, list):
        return " ".join(
            str(part.get("text") or "")
            for part in content
            if isinstance(part, dict) and part.get("type") == "text"
        )
    return str(content)


def extract_history_title(messages: List[Dict[str, Any]]) -> str:
    """Return up to 10 characters of the first non-empty user message.

    Falls back to the first message of any role, then to a fixed placeholder.
    Entries that are not dicts are ignored.
    """
    dict_messages = [message for message in messages if isinstance(message, dict)]
    for message in dict_messages:
        if message.get("role") != "user":
            continue
        title = text_from_content(message.get("content", "")).strip()
        if title:
            return title[:10]
    if dict_messages:
        fallback = text_from_content(dict_messages[0].get("content", "")).strip()
        if fallback:
            return fallback[:10]
    return "空的聊天"


def history_path(session_id: int) -> Path:
    """Return the JSON file that stores *session_id*'s messages."""
    return DATA_DIR / f"chat_history_{session_id}.json"


def build_download_url(filename: str) -> str:
    """Return the public URL of an uploaded file stored as *filename*."""
    return f"{DOWNLOAD_BASE}/{filename}" if DOWNLOAD_BASE else filename


def _history_files_by_mtime(newest_first: bool = False) -> List[Tuple[Path, float]]:
    """Return ``(path, mtime)`` for every history file in DATA_DIR, sorted by mtime."""
    entries: List[Tuple[Path, float]] = []
    for path in DATA_DIR.glob(HISTORY_GLOB):
        try:
            entries.append((path, path.stat().st_mtime))
        except FileNotFoundError:
            # Moved or deleted between glob() and stat() by a concurrent request.
            continue
    entries.sort(key=lambda entry: entry[1], reverse=newest_first)
    return entries


async def read_json_file(path: Path) -> List[Dict[str, Any]]:
    """Load JSON from *path* in a worker thread.

    Raises OSError / ValueError (incl. json.JSONDecodeError) on failure.
    """

    def _read() -> List[Dict[str, Any]]:
        with path.open("r", encoding="utf-8") as fp:
            return json.load(fp)

    # Taken so reads never interleave with a write, move or delete of the same file.
    async with FILE_LOCK:
        return await asyncio.to_thread(_read)


async def write_json_file(path: Path, payload: List[Dict[str, Any]]) -> None:
    """Serialize *payload* to *path* atomically (temp file + ``os.replace``)."""
    serialized = json.dumps(payload, ensure_ascii=False)

    def _write() -> None:
        # The dot prefix and .tmp suffix keep the temp file out of HISTORY_GLOB.
        tmp_path = path.with_name(f".{path.name}.{uuid.uuid4().hex}.tmp")
        try:
            with tmp_path.open("w", encoding="utf-8") as fp:
                fp.write(serialized)
            # Atomic swap: readers, or a crash mid-write, never leave a truncated file.
            os.replace(tmp_path, path)
        except Exception:
            tmp_path.unlink(missing_ok=True)
            raise

    async with FILE_LOCK:
        await asyncio.to_thread(_write)


async def _read_messages_safely(path: Path) -> List[Dict[str, Any]]:
    """Return the message list stored at *path*.

    Returns [] if the file is missing, unreadable, not valid JSON, or not a
    JSON list.
    """
    try:
        data = await read_json_file(path)
    except (OSError, ValueError):
        return []
    return data if isinstance(data, list) else []


async def load_messages(session_id: int) -> List[Dict[str, Any]]:
    """Return the stored messages of *session_id* ([] if none or unreadable)."""
    return await _read_messages_safely(history_path(session_id))


async def save_messages(session_id: int, messages: List[Dict[str, Any]]) -> None:
    """Overwrite the stored messages of *session_id*."""
    path = history_path(session_id)
    path.parent.mkdir(parents=True, exist_ok=True)
    await write_json_file(path, messages)


async def get_latest_session() -> Dict[str, Any]:
    """Return the most recently modified session, or an empty one at the current counter."""
    entries = _history_files_by_mtime(newest_first=True)
    if not entries:
        session_id = await asyncio.to_thread(load_session_counter)
        return {"session_id": session_id, "messages": []}

    latest = entries[0][0]
    session_id = extract_session_id(latest)
    if session_id is None:
        session_id = await asyncio.to_thread(load_session_counter)
    messages = await _read_messages_safely(latest)
    return {"session_id": session_id, "messages": messages}


async def increment_session_id() -> int:
    """Reserve and persist the next session id."""
    async with SESSION_LOCK:
        current = await asyncio.to_thread(load_session_counter)
        next_session = current + 1
        await asyncio.to_thread(save_session_counter, next_session)
        return next_session


async def list_history(page: int, page_size: int) -> Dict[str, Any]:
    """Return one page of sessions, newest first.

    Args:
        page: zero-based page index (negative values are treated as 0).
        page_size: sessions per page (negative values are treated as 0).

    Returns:
        ``{"page", "page_size", "total", "items"}``. ``total`` counts all
        history files.
    """
    entries = _history_files_by_mtime(newest_first=True)
    size = max(page_size, 0)
    start = max(page, 0) * size
    items: List[Dict[str, Any]] = []
    for path, mtime in entries[start:start + size]:
        session_id = extract_session_id(path)
        if session_id is None:
            continue
        messages = await _read_messages_safely(path)
        items.append({
            "session_id": session_id,
            "title": extract_history_title(messages),
            "updated_at": datetime.datetime.fromtimestamp(mtime).isoformat(),
            "filename": path.name,
        })
    return {
        "page": page,
        "page_size": page_size,
        "total": len(entries),
        "items": items,
    }


async def move_history_file(session_id: int) -> None:
    """Archive a session's history file into BACKUP_DIR.

    Raises:
        HTTPException: 404 if the session has no history file.
    """
    src = history_path(session_id)
    if not src.exists():
        raise HTTPException(status_code=404, detail="历史记录不存在")
    BACKUP_DIR.mkdir(parents=True, exist_ok=True)

    def _move() -> None:
        dst = BACKUP_DIR / src.name
        if dst.exists():
            # Never overwrite an earlier archive of the same session id.
            stamp = datetime.datetime.now().strftime("%Y%m%d%H%M%S%f")
            dst = BACKUP_DIR / f"{src.stem}_{stamp}{src.suffix}"
        shutil.move(str(src), str(dst))

    async with FILE_LOCK:
        await asyncio.to_thread(_move)


async def delete_history_file(session_id: int) -> None:
    """Permanently delete a session's history file.

    Raises:
        HTTPException: 404 if the session has no history file.
    """
    target = history_path(session_id)
    if not target.exists():
        raise HTTPException(status_code=404, detail="历史记录不存在")

    def _delete() -> None:
        target.unlink(missing_ok=True)

    async with FILE_LOCK:
        await asyncio.to_thread(_delete)


async def export_message_to_blog(content: MessageContent) -> str:
    """Write a message's text into BLOG_DIR and return the created file path.

    The file is named ``MMDDHHMM_<first 10 chars, sanitized>.txt``. If that
    name is taken, ``_2``, ``_3``, ... is appended instead of overwriting.
    """
    processed = text_from_content(content).replace("\r\n", "\n")
    timestamp = datetime.datetime.now().strftime("%m%d%H%M")
    first_10 = _UNSAFE_FILENAME_CHARS.sub("", processed[:10])
    stem = f"{timestamp}_{first_10 or 'export'}"

    def _write() -> Path:
        candidate = BLOG_DIR / f"{stem}.txt"
        counter = 1
        while True:
            try:
                # Mode "x" fails instead of clobbering an existing export.
                with candidate.open("x", encoding="utf-8") as fp:
                    fp.write(processed)
                return candidate
            except FileExistsError:
                counter += 1
                candidate = BLOG_DIR / f"{stem}_{counter}.txt"

    path = await asyncio.to_thread(_write)
    return str(path)


class MessageModel(BaseModel):
    role: str
    content: MessageContent


class ChatRequest(BaseModel):
    session_id: int
    model: str
    content: MessageContent
    history_count: int = 0
    stream: bool = True


class HistoryActionRequest(BaseModel):
    session_id: int


class ExportRequest(BaseModel):
    content: MessageContent


class UploadResponseItem(BaseModel):
    type: str
    filename: str
    data: Optional[str] = None
    url: Optional[str] = None


# The static and data directories must exist before StaticFiles is mounted.
ensure_directories()

app = FastAPI(title="ChatGPT-like Clone", version="1.0.0")
# NOTE(review): wildcard origins together with allow_credentials=True is very
# permissive; restrict allow_origins before exposing this publicly.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
app.mount("/static", StaticFiles(directory=STATIC_DIR), name="static")


@app.on_event("startup")
async def on_startup() -> None:
    ensure_directories()
    await asyncio.to_thread(sync_session_counter_with_history)


INDEX_HTML = STATIC_DIR / "index.html"


@app.get("/", response_class=HTMLResponse)
async def serve_index() -> str:
    if not INDEX_HTML.exists():
        raise HTTPException(status_code=404, detail="UI 未找到,请确认 static/index.html 是否存在")
    return await asyncio.to_thread(INDEX_HTML.read_text, encoding="utf-8")


@app.get("/download/{filename}")
async def download_file(filename: str) -> FileResponse:
    """Serve a previously uploaded file by its stored name."""
    upload_root = UPLOAD_DIR.resolve()
    target = (UPLOAD_DIR / filename).resolve()
    # Uploads are stored flat in UPLOAD_DIR. Reject names that resolve elsewhere
    # (e.g. "..") and anything that is not a regular file.
    if target.parent != upload_root or not target.is_file():
        raise HTTPException(status_code=404, detail="File not found")
    return FileResponse(target, filename=filename)


@app.get("/api/config")
async def get_config() -> Dict[str, Any]:
    models = list(MODEL_KEYS.keys())
    return {
        "title": "ChatGPT-like Clone",
        "models": models,
        "default_model": models[0] if models else "",
        "output_modes": ["流式输出 (Stream)", "非流式输出 (Non-stream)"],
        "upload_base_url": (DOWNLOAD_BASE + "/") if DOWNLOAD_BASE else "",
    }


@app.get("/api/session/latest")
async def api_latest_session() -> Dict[str, Any]:
    return await get_latest_session()


@app.get("/api/session/{session_id}")
async def api_get_session(session_id: int) -> Dict[str, Any]:
    messages = await load_messages(session_id)
    path = history_path(session_id)
    if not messages and not path.exists():
        raise HTTPException(status_code=404, detail="会话不存在")
    return {"session_id": session_id, "messages": messages}


@app.post("/api/session/new")
async def api_new_session() -> Dict[str, Any]:
    session_id = await increment_session_id()
    await save_messages(session_id, [])
    return {"session_id": session_id, "messages": []}


@app.get("/api/history")
async def api_history(page: int = 0, page_size: int = 10) -> Dict[str, Any]:
    return await list_history(page, page_size)


@app.post("/api/history/move")
async def api_move_history(payload: HistoryActionRequest) -> Dict[str, Any]:
    await move_history_file(payload.session_id)
    return {"status": "ok"}


@app.delete("/api/history/{session_id}")
async def api_delete_history(session_id: int) -> Dict[str, Any]:
    await delete_history_file(session_id)
    return {"status": "ok"}


@app.post("/api/export")
async def api_export_message(payload: ExportRequest) -> Dict[str, Any]:
    path = await export_message_to_blog(payload.content)
    return {"status": "ok", "path": path}


@app.post("/api/upload")
async def api_upload(files: List[UploadFile] = File(...)) -> List[UploadResponseItem]:
    """Store uploaded files under unique names and describe them for the UI.

    Images are also returned inline as a base64 ``data:`` URL.
    """
    responses: List[UploadResponseItem] = []
    for upload in files or []:
        # Keep only the final path component of the client-supplied name.
        safe_filename = Path(upload.filename or "file").name or "file"
        content_type = (upload.content_type or "").lower()
        data = await upload.read()
        unique_name = f"{uuid.uuid4().hex}_{safe_filename}"
        await asyncio.to_thread((UPLOAD_DIR / unique_name).write_bytes, data)
        url = build_download_url(unique_name)

        if content_type.startswith("image/"):
            encoded = base64.b64encode(data).decode("utf-8")
            responses.append(
                UploadResponseItem(
                    type="image",
                    filename=safe_filename,
                    data=f"data:{content_type};base64,{encoded}",
                    url=url,
                )
            )
        else:
            responses.append(UploadResponseItem(type="file", filename=safe_filename, url=url))
    return responses


async def prepare_messages_for_completion(
    messages: List[Dict[str, Any]],
    user_content: MessageContent,
    history_count: int,
) -> List[Dict[str, Any]]:
    """Choose the messages sent to the model.

    With ``history_count > 0`` the last *history_count* stored messages are
    sent. The caller has already appended the new user message to *messages*,
    so it is included. Otherwise only the new user content is sent.
    """
    if history_count > 0:
        trimmed = messages[-history_count:]
        if trimmed:
            return trimmed
    return [{"role": "user", "content": user_content}]


async def save_assistant_message(session_id: int, messages: List[Dict[str, Any]], content: MessageContent) -> None:
    """Append the assistant reply to *messages* and persist the session."""
    messages.append({"role": "assistant", "content": content})
    await save_messages(session_id, messages)


def _client_for_model(model: str) -> OpenAI:
    """Return a client using *model*'s key, without mutating the shared client."""
    return client.with_options(api_key=MODEL_KEYS[model])


def _stream_chat_response(
    api: OpenAI,
    payload: ChatRequest,
    messages: List[Dict[str, Any]],
    to_send: List[Dict[str, Any]],
) -> StreamingResponse:
    """Stream a completion to the client as NDJSON lines.

    The yielded line types are ``delta``, then ``end``, or ``error``. The
    blocking OpenAI stream is consumed in a daemon thread, which hands chunks
    to the event loop through an asyncio.Queue. The full reply is saved when
    the upstream stream completes.
    """
    # Must be called from a coroutine running on the event loop.
    loop = asyncio.get_running_loop()
    queue: "asyncio.Queue[Dict[str, Any]]" = asyncio.Queue()
    aggregated: List[str] = []
    stop = threading.Event()

    def emit(item: Dict[str, Any]) -> None:
        try:
            loop.call_soon_threadsafe(queue.put_nowait, item)
        except RuntimeError:
            # The event loop is closed (server shutting down); nobody is listening.
            stop.set()

    def worker() -> None:
        try:
            response = api.chat.completions.create(
                model=payload.model,
                messages=to_send,
                stream=True,
            )
            for chunk in response:
                if stop.is_set():
                    # The client went away: stop consuming upstream tokens.
                    close = getattr(response, "close", None)
                    if callable(close):
                        close()
                    return
                try:
                    delta = chunk.choices[0].delta.content  # type: ignore[attr-defined]
                except (IndexError, AttributeError):
                    delta = None
                if delta:
                    aggregated.append(delta)
                    emit({"type": "delta", "text": delta})
            emit({"type": "complete"})
        except Exception as exc:  # pragma: no cover - network call
            emit({"type": "error", "message": str(exc)})

    threading.Thread(target=worker, daemon=True).start()

    async def streamer():
        try:
            while True:
                item = await queue.get()
                kind = item["type"]
                if kind == "delta":
                    yield json.dumps(item, ensure_ascii=False) + "\n"
                elif kind == "complete":
                    # "complete" is queued after every append, so aggregated is final here.
                    await save_assistant_message(payload.session_id, messages, "".join(aggregated))
                    yield json.dumps({"type": "end"}, ensure_ascii=False) + "\n"
                    break
                elif kind == "error":
                    yield json.dumps(item, ensure_ascii=False) + "\n"
                    break
        finally:
            # Runs on normal exit and on client disconnect; lets the worker exit early.
            stop.set()

    return StreamingResponse(streamer(), media_type="application/x-ndjson")


async def _complete_chat(
    api: OpenAI,
    payload: ChatRequest,
    messages: List[Dict[str, Any]],
    to_send: List[Dict[str, Any]],
) -> Dict[str, Any]:
    """Run a non-streaming completion, save the reply, and return ``{"message": ...}``."""
    try:
        completion = await asyncio.to_thread(
            api.chat.completions.create,
            model=payload.model,
            messages=to_send,
            stream=False,
        )
    except Exception as exc:  # pragma: no cover - network call
        raise HTTPException(status_code=500, detail=str(exc)) from exc

    choices = getattr(completion, "choices", None)
    if not choices:
        raise HTTPException(status_code=500, detail="响应格式不正确")
    assistant_content = getattr(choices[0].message, "content", "") or ""
    await save_assistant_message(payload.session_id, messages, assistant_content)
    return {"message": assistant_content}


@app.post("/api/chat")
async def api_chat(payload: ChatRequest = Body(...)):
    """Record the user's message, then answer it (streamed or not) and record the reply."""
    if payload.model not in MODEL_KEYS:
        raise HTTPException(status_code=400, detail="未知的模型")

    messages = await load_messages(payload.session_id)
    messages.append({"role": "user", "content": payload.content})
    await save_messages(payload.session_id, messages)

    api = _client_for_model(payload.model)
    to_send = await prepare_messages_for_completion(messages, payload.content, max(payload.history_count, 0))

    if payload.stream:
        return _stream_chat_response(api, payload, messages, to_send)
    return await _complete_chat(api, payload, messages, to_send)


if __name__ == "__main__":
    import uvicorn

    # Build the import string from this file's name so reload works whatever
    # the module is called.
    uvicorn.run(f"{Path(__file__).stem}:app", host="0.0.0.0", port=16016, reload=True)