# -*- coding: utf-8 -*-
import asyncio
import base64
import datetime
import json
import os
import re
import shutil
import threading
import uuid
from pathlib import Path
from typing import Any, Dict, List, Optional, Union

from fastapi import Body, FastAPI, HTTPException, UploadFile, File
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import FileResponse, HTMLResponse, StreamingResponse
from fastapi.staticfiles import StaticFiles
from pydantic import BaseModel

from openai import OpenAI

# =============================
# Basic configuration
# =============================
BASE_DIR = Path(__file__).resolve().parent
DATA_DIR = BASE_DIR / "data"
BACKUP_DIR = BASE_DIR / "data_bak"
BLOG_DIR = BASE_DIR / "blog"
UPLOAD_DIR = BASE_DIR / "uploads"
STATIC_DIR = BASE_DIR / "static"
SESSION_ID_FILE = DATA_DIR / "session_id.txt"

# Default download base URL for uploaded files; can be overridden via an environment variable.
DEFAULT_UPLOAD_BASE = os.getenv("UPLOAD_BASE_URL", "/download/")
DOWNLOAD_BASE = DEFAULT_UPLOAD_BASE.rstrip("/")

# Model and API-key configuration, matching appchat.py (example values only).
default_key = "sk-re2NlaKIQn11ZNWzAbB6339cEbF94c6aAfC8B7Ab82879bEa"
MODEL_KEYS: Dict[str, str] = {
    "grok-3": default_key,
    "grok-4": default_key,
    "gpt-5.1-2025-11-13": default_key,
    "gpt-5-2025-08-07": default_key,
    "gpt-4o-mini": default_key,
    # "gpt-4.1-mini-2025-04-14": default_key,
    "o1-mini": default_key,
    "o4-mini": default_key,
    "deepseek-v3": default_key,
    "deepseek-r1": default_key,
    "gpt-4o-all": default_key,
    # "gpt-5-mini-2025-08-07": default_key,
    "o3-mini-all": default_key,
}
API_URL = "https://yunwu.ai/v1"
client = OpenAI(api_key=default_key, base_url=API_URL)

# Locks guard against data corruption caused by concurrent file writes.
FILE_LOCK = asyncio.Lock()
SESSION_LOCK = asyncio.Lock()

MessageContent = Union[str, List[Dict[str, Any]]]
SESSION_FILE_PATTERN = re.compile(r"chat_history_(\d+)\.json")
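# On-disk layout implied by the paths above:
#   data/chat_history_<id>.json  - one JSON file per chat session
#   data/session_id.txt          - highest session id handed out so far
#   data_bak/                    - archived ("moved") session files
#   blog/                        - messages exported as text files
#   uploads/                     - files stored by /api/upload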
def ensure_directories() -> None:
    for path in [DATA_DIR, BACKUP_DIR, BLOG_DIR, UPLOAD_DIR, STATIC_DIR]:
        path.mkdir(parents=True, exist_ok=True)


def extract_session_id(path: Path) -> Optional[int]:
    match = SESSION_FILE_PATTERN.search(path.name)
    if match:
        try:
            return int(match.group(1))
        except ValueError:
            return None
    return None


def load_session_counter() -> int:
    if SESSION_ID_FILE.exists():
        try:
            value = SESSION_ID_FILE.read_text(encoding="utf-8").strip()
            return int(value) if value.isdigit() else 0
        except Exception:
            return 0
    return 0


def save_session_counter(value: int) -> None:
    SESSION_ID_FILE.write_text(str(value), encoding="utf-8")


def sync_session_counter_with_history() -> None:
    max_session = load_session_counter()
    for path in DATA_DIR.glob("chat_history_*.json"):
        session_id = extract_session_id(path)
        if session_id is not None and session_id > max_session:
            max_session = session_id
    save_session_counter(max_session)


def text_from_content(content: MessageContent) -> str:
    if isinstance(content, str):
        return content
    if isinstance(content, list):
        pieces: List[str] = []
        for part in content:
            if part.get("type") == "text":
                pieces.append(part.get("text", ""))
        return " ".join(pieces)
    return str(content)


def extract_history_title(messages: List[Dict[str, Any]]) -> str:
    """Return the first meaningful title extracted from user messages."""
    for message in messages:
        if message.get("role") != "user":
            continue
        title = text_from_content(message.get("content", "")).strip()
        if title:
            return title[:10]
    if messages:
        fallback = text_from_content(messages[0].get("content", "")).strip()
        if fallback:
            return fallback[:10]
    return "空的聊天"


def history_path(session_id: int) -> Path:
    return DATA_DIR / f"chat_history_{session_id}.json"


def build_download_url(filename: str) -> str:
    base = DOWNLOAD_BASE or ""
    return f"{base}/{filename}" if base else filename
async def read_json_file(path: Path) -> List[Dict[str, Any]]:
    def _read() -> List[Dict[str, Any]]:
        with path.open("r", encoding="utf-8") as fp:
            return json.load(fp)

    return await asyncio.to_thread(_read)


async def write_json_file(path: Path, payload: List[Dict[str, Any]]) -> None:
    serialized = json.dumps(payload, ensure_ascii=False)

    def _write() -> None:
        with path.open("w", encoding="utf-8") as fp:
            fp.write(serialized)

    async with FILE_LOCK:
        await asyncio.to_thread(_write)
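# Note: writes are serialized through FILE_LOCK (here and in the move/delete
# helpers below), while reads run without a lock; blocking file I/O is pushed
# to worker threads via asyncio.to_thread so the event loop stays responsive.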
async def load_messages(session_id: int) -> List[Dict[str, Any]]:
    path = history_path(session_id)
    if not path.exists():
        return []
    try:
        return await read_json_file(path)
    except Exception:
        return []


async def save_messages(session_id: int, messages: List[Dict[str, Any]]) -> None:
    path = history_path(session_id)
    path.parent.mkdir(parents=True, exist_ok=True)
    await write_json_file(path, messages)


async def get_latest_session() -> Dict[str, Any]:
    history_files = sorted(DATA_DIR.glob("chat_history_*.json"), key=lambda p: p.stat().st_mtime)
    if history_files:
        latest = history_files[-1]
        session_id = extract_session_id(latest)
        if session_id is None:
            session_id = await asyncio.to_thread(load_session_counter)
        try:
            messages = await read_json_file(latest)
        except Exception:
            messages = []
        return {"session_id": session_id, "messages": messages}
    session_id = await asyncio.to_thread(load_session_counter)
    return {"session_id": session_id, "messages": []}


async def increment_session_id() -> int:
    async with SESSION_LOCK:
        current = await asyncio.to_thread(load_session_counter)
        next_session = current + 1
        await asyncio.to_thread(save_session_counter, next_session)
        return next_session
async def list_history(page: int, page_size: int) -> Dict[str, Any]:
    files = sorted(DATA_DIR.glob("chat_history_*.json"), key=lambda p: p.stat().st_mtime, reverse=True)
    total = len(files)
    start = max(page, 0) * page_size
    end = start + page_size
    items: List[Dict[str, Any]] = []
    for path in files[start:end]:
        session_id = extract_session_id(path)
        if session_id is None:
            continue
        try:
            messages = await read_json_file(path)
        except Exception:
            messages = []
        title = extract_history_title(messages)
        items.append({
            "session_id": session_id,
            "title": title,
            "updated_at": datetime.datetime.fromtimestamp(path.stat().st_mtime).isoformat(),
            "filename": path.name,
        })
    return {
        "page": page,
        "page_size": page_size,
        "total": total,
        "items": items,
    }
async def move_history_file(session_id: int) -> None:
    src = history_path(session_id)
    if not src.exists():
        raise HTTPException(status_code=404, detail="历史记录不存在")
    dst = BACKUP_DIR / src.name
    dst.parent.mkdir(parents=True, exist_ok=True)

    def _move() -> None:
        shutil.move(str(src), str(dst))

    async with FILE_LOCK:
        await asyncio.to_thread(_move)


async def delete_history_file(session_id: int) -> None:
    target = history_path(session_id)
    if not target.exists():
        raise HTTPException(status_code=404, detail="历史记录不存在")

    def _delete() -> None:
        target.unlink(missing_ok=True)

    async with FILE_LOCK:
        await asyncio.to_thread(_delete)


async def export_message_to_blog(content: MessageContent) -> str:
    processed = text_from_content(content)
    processed = processed.replace("\r\n", "\n")
    timestamp = datetime.datetime.now().strftime("%m%d%H%M")
    first_10 = (
        processed[:10]
        .replace(" ", "")
        .replace("/", "")
        .replace("\\", "")
        .replace(":", "")
        .replace("`", "")
    )
    filename = f"{timestamp}_{first_10 or 'export'}.txt"
    path = BLOG_DIR / filename

    def _write() -> None:
        with path.open("w", encoding="utf-8") as fp:
            fp.write(processed)

    await asyncio.to_thread(_write)
    return str(path)
class MessageModel(BaseModel):
    role: str
    content: MessageContent


class ChatRequest(BaseModel):
    session_id: int
    model: str
    content: MessageContent
    history_count: int = 0
    stream: bool = True


class HistoryActionRequest(BaseModel):
    session_id: int


class ExportRequest(BaseModel):
    content: MessageContent


class UploadResponseItem(BaseModel):
    type: str
    filename: str
    data: Optional[str] = None
    url: Optional[str] = None
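# Illustrative /api/chat request body matching ChatRequest above
# (the model name must be one of the MODEL_KEYS entries; values are examples):
# {"session_id": 1, "model": "gpt-4o-mini", "content": "Hello", "history_count": 0, "stream": true}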
# Ensure static and data directories exist before the application is initialized.
ensure_directories()

app = FastAPI(title="ChatGPT-like Clone", version="1.0.0")
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
app.mount("/static", StaticFiles(directory=STATIC_DIR), name="static")


@app.on_event("startup")
async def on_startup() -> None:
    ensure_directories()
    await asyncio.to_thread(sync_session_counter_with_history)


INDEX_HTML = STATIC_DIR / "index.html"


@app.get("/", response_class=HTMLResponse)
async def serve_index() -> str:
    if not INDEX_HTML.exists():
        raise HTTPException(status_code=404, detail="UI 未找到,请确认 static/index.html 是否存在")
    return INDEX_HTML.read_text(encoding="utf-8")
- @app.get("/download/{filename}")
- async def download_file(filename: str) -> FileResponse:
- target = UPLOAD_DIR / filename
- if not target.exists():
- raise HTTPException(status_code=404, detail="File not found")
- return FileResponse(target, filename=filename)
- @app.get("/api/config")
- async def get_config() -> Dict[str, Any]:
- models = list(MODEL_KEYS.keys())
- return {
- "title": "ChatGPT-like Clone",
- "models": models,
- "default_model": models[0] if models else "",
- "output_modes": ["流式输出 (Stream)", "非流式输出 (Non-stream)"],
- "upload_base_url": DOWNLOAD_BASE + "/" if DOWNLOAD_BASE else "",
- }
- @app.get("/api/session/latest")
- async def api_latest_session() -> Dict[str, Any]:
- return await get_latest_session()
- @app.get("/api/session/{session_id}")
- async def api_get_session(session_id: int) -> Dict[str, Any]:
- messages = await load_messages(session_id)
- path = history_path(session_id)
- if not messages and not path.exists():
- raise HTTPException(status_code=404, detail="会话不存在")
- return {"session_id": session_id, "messages": messages}
- @app.post("/api/session/new")
- async def api_new_session() -> Dict[str, Any]:
- session_id = await increment_session_id()
- await save_messages(session_id, [])
- return {"session_id": session_id, "messages": []}
- @app.get("/api/history")
- async def api_history(page: int = 0, page_size: int = 10) -> Dict[str, Any]:
- return await list_history(page, page_size)
- @app.post("/api/history/move")
- async def api_move_history(payload: HistoryActionRequest) -> Dict[str, Any]:
- await move_history_file(payload.session_id)
- return {"status": "ok"}
- @app.delete("/api/history/{session_id}")
- async def api_delete_history(session_id: int) -> Dict[str, Any]:
- await delete_history_file(session_id)
- return {"status": "ok"}
- @app.post("/api/export")
- async def api_export_message(payload: ExportRequest) -> Dict[str, Any]:
- path = await export_message_to_blog(payload.content)
- return {"status": "ok", "path": path}
- @app.post("/api/upload")
- async def api_upload(files: List[UploadFile] = File(...)) -> List[UploadResponseItem]:
- if not files:
- return []
- responses: List[UploadResponseItem] = []
- for upload in files:
- filename = upload.filename or "file"
- safe_filename = Path(filename).name or "file"
- content_type = (upload.content_type or "").lower()
- data = await upload.read()
- unique_name = f"{uuid.uuid4().hex}_{safe_filename}"
- target_path = UPLOAD_DIR / unique_name
- def _write() -> None:
- with target_path.open("wb") as fp:
- fp.write(data)
- await asyncio.to_thread(_write)
- if content_type.startswith("image/"):
- encoded = base64.b64encode(data).decode("utf-8")
- data_url = f"data:{content_type};base64,{encoded}"
- responses.append(
- UploadResponseItem(
- type="image",
- filename=safe_filename,
- data=data_url,
- url=build_download_url(unique_name),
- )
- )
- else:
- responses.append(
- UploadResponseItem(
- type="file",
- filename=safe_filename,
- url=build_download_url(unique_name),
- )
- )
- return responses
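# Note: image uploads are returned both as an inline base64 data URL (for
# immediate preview) and as a /download/ URL; non-image files get only the
# download URL. Every stored file is prefixed with a UUID to avoid collisions.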
async def prepare_messages_for_completion(
    messages: List[Dict[str, Any]],
    user_content: MessageContent,
    history_count: int,
) -> List[Dict[str, Any]]:
    if history_count > 0:
        trimmed = messages[-history_count:]
        if trimmed:
            return trimmed
    return [{"role": "user", "content": user_content}]


async def save_assistant_message(session_id: int, messages: List[Dict[str, Any]], content: MessageContent) -> None:
    messages.append({"role": "assistant", "content": content})
    await save_messages(session_id, messages)
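# Context-trimming behaviour: with history_count > 0 the last N stored messages
# (which already include the new user message) are sent to the model; with
# history_count == 0 only the current user message is sent.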
- @app.post("/api/chat")
- async def api_chat(payload: ChatRequest = Body(...)):
- if payload.model not in MODEL_KEYS:
- raise HTTPException(status_code=400, detail="未知的模型")
- messages = await load_messages(payload.session_id)
- user_message = {"role": "user", "content": payload.content}
- messages.append(user_message)
- await save_messages(payload.session_id, messages)
- client.api_key = MODEL_KEYS[payload.model]
- to_send = await prepare_messages_for_completion(messages, payload.content, max(payload.history_count, 0))
- if payload.stream:
- queue: "asyncio.Queue[Dict[str, Any]]" = asyncio.Queue()
- aggregated: List[str] = []
- loop = asyncio.get_running_loop()
- def worker() -> None:
- try:
- response = client.chat.completions.create(
- model=payload.model,
- messages=to_send,
- stream=True,
- )
- for chunk in response:
- try:
- delta = chunk.choices[0].delta.content # type: ignore[attr-defined]
- except (IndexError, AttributeError):
- delta = None
- if delta:
- aggregated.append(delta)
- asyncio.run_coroutine_threadsafe(queue.put({"type": "delta", "text": delta}), loop)
- asyncio.run_coroutine_threadsafe(queue.put({"type": "complete"}), loop)
- except Exception as exc: # pragma: no cover - 网络调用
- asyncio.run_coroutine_threadsafe(queue.put({"type": "error", "message": str(exc)}), loop)
- threading.Thread(target=worker, daemon=True).start()
- async def streamer():
- try:
- while True:
- item = await queue.get()
- if item["type"] == "delta":
- yield json.dumps(item, ensure_ascii=False) + "\n"
- elif item["type"] == "complete":
- assistant_text = "".join(aggregated)
- await save_assistant_message(payload.session_id, messages, assistant_text)
- yield json.dumps({"type": "end"}, ensure_ascii=False) + "\n"
- break
- elif item["type"] == "error":
- yield json.dumps(item, ensure_ascii=False) + "\n"
- break
- except asyncio.CancelledError: # pragma: no cover - 流被取消
- raise
- return StreamingResponse(streamer(), media_type="application/x-ndjson")
- try:
- completion = await asyncio.to_thread(
- client.chat.completions.create,
- model=payload.model,
- messages=to_send,
- stream=False,
- )
- except Exception as exc: # pragma: no cover - 网络调用
- raise HTTPException(status_code=500, detail=str(exc)) from exc
- choice = completion.choices[0] if getattr(completion, "choices", None) else None # type: ignore[attr-defined]
- if not choice:
- raise HTTPException(status_code=500, detail="响应格式不正确")
- assistant_content = getattr(choice.message, "content", "")
- if not assistant_content:
- assistant_content = ""
- await save_assistant_message(payload.session_id, messages, assistant_content)
- return {"message": assistant_content}
if __name__ == "__main__":
    import uvicorn

    uvicorn.run("fastchat:app", host="0.0.0.0", port=16016, reload=True)
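# Quick smoke test, assuming the server is running locally on the default port
# (the session id and model name below are illustrative):
#   curl -s -X POST http://localhost:16016/api/session/new
#   curl -N -X POST http://localhost:16016/api/chat \
#        -H "Content-Type: application/json" \
#        -d '{"session_id": 1, "model": "gpt-4o-mini", "content": "Hello", "stream": true}'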