|
|
@@ -2,9 +2,11 @@
|
|
|
import asyncio
|
|
|
import base64
|
|
|
import json
|
|
|
+import mimetypes
|
|
|
import threading
|
|
|
import uuid
|
|
|
from typing import Any, Dict, List, Optional
|
|
|
+from urllib.parse import urlparse
|
|
|
|
|
|
from fastapi import Body, Depends, FastAPI, HTTPException, UploadFile, File
|
|
|
from fastapi.middleware.cors import CORSMiddleware
|
|
|
@@ -13,6 +15,7 @@ from fastapi.staticfiles import StaticFiles
|
|
|
from pydantic import BaseModel
|
|
|
|
|
|
from openai import OpenAI
|
|
|
+from pathlib import Path
|
|
|
|
|
|
from chatfast.api import admin_router, auth_router, export_router
|
|
|
from chatfast.config import API_URL, DOWNLOAD_BASE, MODEL_KEYS, STATIC_DIR, UPLOAD_DIR
|
|
|
@@ -66,8 +69,69 @@ class HistoryActionRequest(BaseModel):
|
|
|
class UploadResponseItem(BaseModel):
|
|
|
type: str
|
|
|
filename: str
|
|
|
- data: Optional[str] = None
|
|
|
url: Optional[str] = None
|
|
|
+ path: Optional[str] = None
|
|
|
+
|
|
|
+
|
|
|
+def _is_data_url(value: str) -> bool:
|
|
|
+ return value.startswith("data:")
|
|
|
+
|
|
|
+
|
|
|
+def _resolve_upload_path(reference: str) -> Optional[Path]:
|
|
|
+ if not reference or _is_data_url(reference):
|
|
|
+ return None
|
|
|
+ parsed = urlparse(reference)
|
|
|
+ candidate = Path(parsed.path).name if parsed.scheme else Path(reference).name
|
|
|
+ if not candidate:
|
|
|
+ return None
|
|
|
+ return UPLOAD_DIR / candidate
|
|
|
+
|
|
|
+
|
|
|
+async def _inline_local_image(reference: str) -> str:
|
|
|
+ if not reference or _is_data_url(reference):
|
|
|
+ return reference
|
|
|
+ file_path = _resolve_upload_path(reference)
|
|
|
+ if not file_path or not file_path.exists():
|
|
|
+ return reference
|
|
|
+ try:
|
|
|
+ data = await asyncio.to_thread(file_path.read_bytes)
|
|
|
+ except OSError:
|
|
|
+ return reference
|
|
|
+ mime = mimetypes.guess_type(file_path.name)[0] or "application/octet-stream"
|
|
|
+ encoded = base64.b64encode(data).decode("utf-8")
|
|
|
+ return f"data:{mime};base64,{encoded}"
|
|
|
+
|
|
|
+
|
|
|
+async def _prepare_content_for_model(content: MessageContent) -> MessageContent:
|
|
|
+ if not isinstance(content, list):
|
|
|
+ return content
|
|
|
+ prepared: List[Dict[str, Any]] = []
|
|
|
+ for part in content:
|
|
|
+ if not isinstance(part, dict):
|
|
|
+ prepared.append(part)
|
|
|
+ continue
|
|
|
+ if part.get("type") != "image_url":
|
|
|
+ prepared.append({**part})
|
|
|
+ continue
|
|
|
+ image_data = dict(part.get("image_url") or {})
|
|
|
+ url_value = str(image_data.get("url") or "")
|
|
|
+ new_url = await _inline_local_image(url_value)
|
|
|
+ if new_url:
|
|
|
+ image_data["url"] = new_url
|
|
|
+ prepared.append({"type": "image_url", "image_url": image_data})
|
|
|
+ return prepared
|
|
|
+
|
|
|
+
|
|
|
+async def _prepare_messages_for_model(messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
|
|
|
+ prepared: List[Dict[str, Any]] = []
|
|
|
+ for message in messages:
|
|
|
+ prepared.append(
|
|
|
+ {
|
|
|
+ "role": message.get("role", ""),
|
|
|
+ "content": await _prepare_content_for_model(message.get("content")),
|
|
|
+ }
|
|
|
+ )
|
|
|
+ return prepared
|
|
|
|
|
|
|
|
|
# 确保静态与数据目录在应用初始化前存在
|
|
|
@@ -191,15 +255,14 @@ async def api_upload(
|
|
|
|
|
|
await asyncio.to_thread(_write)
|
|
|
|
|
|
+ download_url = build_download_url(unique_name)
|
|
|
if content_type.startswith("image/"):
|
|
|
- encoded = base64.b64encode(data).decode("utf-8")
|
|
|
- data_url = f"data:{content_type};base64,{encoded}"
|
|
|
responses.append(
|
|
|
UploadResponseItem(
|
|
|
type="image",
|
|
|
filename=safe_filename,
|
|
|
- data=data_url,
|
|
|
- url=build_download_url(unique_name),
|
|
|
+ url=download_url,
|
|
|
+ path=unique_name,
|
|
|
)
|
|
|
)
|
|
|
else:
|
|
|
@@ -207,7 +270,8 @@ async def api_upload(
|
|
|
UploadResponseItem(
|
|
|
type="file",
|
|
|
filename=safe_filename,
|
|
|
- url=build_download_url(unique_name),
|
|
|
+ url=download_url,
|
|
|
+ path=unique_name,
|
|
|
)
|
|
|
)
|
|
|
|
|
|
@@ -230,6 +294,7 @@ async def api_chat(payload: ChatRequest = Body(...), current_user: UserInfo = De
|
|
|
client.api_key = MODEL_KEYS[payload.model]
|
|
|
|
|
|
to_send = await prepare_messages_for_completion(messages, payload.content, max(payload.history_count, 0))
|
|
|
+ model_messages = await _prepare_messages_for_model(to_send)
|
|
|
|
|
|
if payload.stream:
|
|
|
queue: "asyncio.Queue[Dict[str, Any]]" = asyncio.Queue()
|
|
|
@@ -240,7 +305,7 @@ async def api_chat(payload: ChatRequest = Body(...), current_user: UserInfo = De
|
|
|
try:
|
|
|
response = client.chat.completions.create(
|
|
|
model=payload.model,
|
|
|
- messages=to_send,
|
|
|
+ messages=model_messages,
|
|
|
stream=True,
|
|
|
)
|
|
|
for chunk in response:
|
|
|
@@ -282,7 +347,7 @@ async def api_chat(payload: ChatRequest = Body(...), current_user: UserInfo = De
|
|
|
completion = await asyncio.to_thread(
|
|
|
client.chat.completions.create,
|
|
|
model=payload.model,
|
|
|
- messages=to_send,
|
|
|
+ messages=model_messages,
|
|
|
stream=False,
|
|
|
)
|
|
|
except Exception as exc: # pragma: no cover - 网络调用
|