894 lines
31 KiB
Python
894 lines
31 KiB
Python
from __future__ import annotations

import asyncio
import hmac
import json
import logging
import os
import time
from contextlib import asynccontextmanager
from typing import Any, Dict, List, Literal, Optional

import httpx
from fastapi import FastAPI, HTTPException, UploadFile, File, Form, Request
from fastapi.responses import JSONResponse, Response
from pydantic import BaseModel, Field, ValidationError, field_validator
from starlette.middleware.base import BaseHTTPMiddleware
|
|
|
|
logger = logging.getLogger("gateway")
logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(name)s %(message)s")

# Upstream service base URLs; each is overridable per-environment via env vars.
CLIP_URL = os.getenv("CLIP_URL", "http://clip:8000")
BLIP_URL = os.getenv("BLIP_URL", "http://blip:8000")
YOLO_URL = os.getenv("YOLO_URL", "http://yolo:8000")
QDRANT_SVC_URL = os.getenv("QDRANT_SVC_URL", "http://qdrant-svc:8000")
CARD_RENDERER_URL = os.getenv("CARD_RENDERER_URL", "http://card-renderer:8000")
MATURITY_URL = os.getenv("MATURITY_URL", "http://maturity:8000")
# Maturity checks default ON; any of "0"/"false"/"no" (case-insensitive) disables them.
MATURITY_ENABLED = os.getenv("MATURITY_ENABLED", "true").lower() not in ("0", "false", "no")

# LLM proxy configuration. Note the default is "false": LLM is opt-in,
# unlike maturity which defaults to enabled.
LLM_URL = os.getenv("LLM_URL", "http://llm:8080")
LLM_ENABLED = os.getenv("LLM_ENABLED", "false").lower() not in ("0", "false", "no")
LLM_DEFAULT_MODEL = os.getenv("LLM_DEFAULT_MODEL", "qwen3-1.7b-instruct-q4_k_m")
LLM_TIMEOUT = float(os.getenv("LLM_TIMEOUT", "120"))  # seconds
# Absolute cap on generated tokens; normalized requests can never exceed this.
LLM_MAX_TOKENS_HARD_LIMIT = max(1, int(os.getenv("LLM_MAX_TOKENS_HARD_LIMIT", "1024")))
# Default token budget when a caller omits max_tokens, clamped to the hard cap.
LLM_MAX_TOKENS_DEFAULT = min(
    LLM_MAX_TOKENS_HARD_LIMIT,
    max(1, int(os.getenv("LLM_MAX_TOKENS_DEFAULT", "256"))),
)
# Maximum accepted request body for LLM endpoints, in bytes (floor of 1 KiB).
LLM_MAX_REQUEST_BYTES = max(1024, int(os.getenv("LLM_MAX_REQUEST_BYTES", "65536")))
VISION_TIMEOUT = float(os.getenv("VISION_TIMEOUT", "20"))  # seconds, vision upstreams

# API key (set via env var `API_KEY`). If not set, gateway will reject requests.
API_KEY = os.getenv("API_KEY")

# ---------------------------------------------------------------------------
# Shared persistent HTTP client — created once at startup, reused across all
# requests. This eliminates per-request TCP connect + DNS latency (the main
# cause of 20 s first-request latency observed on /vectors/inspect).
# ---------------------------------------------------------------------------
_http_client: httpx.AsyncClient | None = None
|
|
|
|
|
|
class LLMGatewayError(Exception):
    """Error raised by the LLM proxy layer.

    Carries the HTTP status the gateway should return, a machine-readable
    error code, a human-readable message, and optional structured details.
    Rendered into a JSON envelope by the registered exception handler.
    """

    def __init__(
        self,
        status_code: int,
        code: str,
        message: str,
        details: Optional[Any] = None,
    ):
        super().__init__(message)
        self.status_code = status_code
        self.code = code
        self.message = message
        self.details = details
|
|
|
|
|
|
def get_http_client() -> httpx.AsyncClient:
    """Return the shared httpx client. Raises if called before lifespan starts."""
    client = _http_client
    if client is None:
        raise RuntimeError("HTTP client not initialised — lifespan not running")
    return client
|
|
|
|
|
|
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Application lifespan: create shared HTTP client and warm upstream connections.

    Runs once at startup (before `yield`) and once at shutdown (after).
    Warm pings are best-effort: failures are logged but never block startup.
    """
    global _http_client

    t0 = time.perf_counter()
    logger.info("gateway startup: creating shared HTTP client")

    # Pool sizing: up to 100 concurrent connections, 20 kept alive (30 s expiry).
    limits = httpx.Limits(
        max_connections=100,
        max_keepalive_connections=20,
        keepalive_expiry=30,
    )
    # Default timeout is the vision timeout; LLM calls override per-request.
    _http_client = httpx.AsyncClient(
        timeout=httpx.Timeout(VISION_TIMEOUT, connect=10),
        limits=limits,
    )

    # Warm the qdrant-svc connection so the first real request does not pay
    # the TCP handshake + DNS cost. Failure is non-fatal — the service may
    # still be starting when the gateway starts.
    try:
        t_warm = time.perf_counter()
        r = await _http_client.get(f"{QDRANT_SVC_URL}/health", timeout=8)
        logger.info(
            "gateway startup: qdrant-svc warm ping done status=%s elapsed_ms=%.1f",
            r.status_code, (time.perf_counter() - t_warm) * 1000,
        )
    except Exception as exc:
        logger.warning("gateway startup: qdrant-svc warm ping failed (non-fatal): %s", exc)

    # Same warm-up for the LLM upstream, only when the integration is enabled.
    if LLM_ENABLED:
        try:
            t_warm = time.perf_counter()
            r = await _http_client.get(f"{LLM_URL}/health", timeout=min(LLM_TIMEOUT, 10))
            logger.info(
                "gateway startup: llm warm ping done status=%s elapsed_ms=%.1f",
                r.status_code, (time.perf_counter() - t_warm) * 1000,
            )
        except Exception as exc:
            logger.warning("gateway startup: llm warm ping failed (non-fatal): %s", exc)

    logger.info("gateway startup complete elapsed_ms=%.1f", (time.perf_counter() - t0) * 1000)

    yield  # application runs

    # Shutdown: close the pool and clear the global so late calls fail loudly.
    logger.info("gateway shutdown: closing shared HTTP client")
    await _http_client.aclose()
    _http_client = None
|
|
|
|
|
|
class APIKeyMiddleware(BaseHTTPMiddleware):
    """Reject any request lacking a valid `X-API-Key` header.

    Health and docs endpoints are exempt. LLM routes (/v1/*, /ai/*) receive
    the structured `{"error": {...}}` envelope; all others the FastAPI-style
    `{"detail": ...}` body. If `API_KEY` is unset the gateway fails closed
    and rejects everything.
    """

    # Endpoints that must stay reachable without a key (probes, API docs).
    _EXEMPT_PATHS = ("/health", "/openapi.json", "/docs", "/redoc")

    async def dispatch(self, request: Request, call_next):
        if request.url.path in self._EXEMPT_PATHS:
            return await call_next(request)

        # Starlette header lookup is case-insensitive, so one get() covers
        # both "x-api-key" and "X-API-Key" spellings.
        key = request.headers.get("x-api-key")

        # hmac.compare_digest performs a constant-time comparison, closing
        # the timing side-channel of a plain `!=` on secret material.
        authorized = (
            bool(API_KEY)
            and key is not None
            and hmac.compare_digest(key.encode("utf-8"), API_KEY.encode("utf-8"))
        )
        if not authorized:
            if _is_llm_path(request.url.path):
                return JSONResponse(
                    status_code=401,
                    content={"error": {"code": "unauthorized", "message": "Unauthorized"}},
                )
            return JSONResponse(status_code=401, content={"detail": "Unauthorized"})
        return await call_next(request)
|
|
|
|
|
|
def _is_llm_path(path: str) -> bool:
|
|
return path.startswith("/v1/") or path.startswith("/ai/")
|
|
|
|
|
|
# FastAPI application; `lifespan` manages the shared HTTP client lifecycle.
app = FastAPI(title="Skinbase Vision Gateway", version="1.0.0", lifespan=lifespan)
# Every route except health/docs requires the X-API-Key header.
app.add_middleware(APIKeyMiddleware)
|
|
|
|
|
|
@app.exception_handler(LLMGatewayError)
async def handle_llm_gateway_error(_: Request, exc: LLMGatewayError):
    """Render an LLMGatewayError as the structured error JSON envelope."""
    body: Dict[str, Any] = {"code": exc.code, "message": exc.message}
    if exc.details is not None:
        body["details"] = exc.details
    return JSONResponse(status_code=exc.status_code, content={"error": body})
|
|
|
|
|
|
class ClipRequest(BaseModel):
    """Request body for /analyze/clip: classify an image fetched from a URL."""

    # Remote image URL; required in practice (the handler returns 400 if missing).
    url: Optional[str] = None
    # Maximum number of results to return (1-50).
    limit: int = Field(default=5, ge=1, le=50)
    # Optional minimum-score cutoff in [0, 1].
    threshold: Optional[float] = Field(default=None, ge=0.0, le=1.0)
|
|
|
|
|
|
class BlipRequest(BaseModel):
    """Request body for /analyze/blip: caption an image fetched from a URL."""

    # Remote image URL; required in practice (the handler returns 400 if missing).
    url: Optional[str] = None
    # Number of caption variants to generate (0-10).
    variants: int = Field(default=3, ge=0, le=10)
    # Maximum caption length in tokens/characters per the BLIP service (10-200).
    max_length: int = Field(default=60, ge=10, le=200)
|
|
|
|
|
|
class YoloRequest(BaseModel):
    """Request body for /analyze/yolo: detect objects in an image from a URL."""

    # Remote image URL; required in practice (the handler returns 400 if missing).
    url: Optional[str] = None
    # Detection confidence threshold in [0, 1].
    conf: float = Field(default=0.25, ge=0.0, le=1.0)
|
|
|
|
|
|
class MaturityRequest(BaseModel):
    """Request body for /analyze/maturity: image URL to score for NSFW content."""

    # Remote image URL; required in practice (the handler returns 400 if missing).
    url: Optional[str] = None
|
|
|
|
|
|
class ChatMessage(BaseModel):
    """One turn in a chat conversation (OpenAI-style role/content pair)."""

    role: Literal["system", "user", "assistant"]
    content: str

    @field_validator("content")
    @classmethod
    def validate_content(cls, value: str) -> str:
        """Reject empty or whitespace-only message content."""
        if not value or not value.strip():
            raise ValueError("message content must not be empty")
        return value
|
|
|
|
|
|
class ChatCompletionRequest(BaseModel):
    """OpenAI-compatible chat completion request accepted by the gateway."""

    # Model id; normalization substitutes LLM_DEFAULT_MODEL when omitted.
    model: Optional[str] = None
    messages: List[ChatMessage] = Field(min_length=1, max_length=100)
    temperature: Optional[float] = None
    # Requested completion budget; clamped to LLM_MAX_TOKENS_HARD_LIMIT downstream.
    max_tokens: Optional[int] = Field(default=None, ge=1)
    # Accepted for schema compatibility but rejected by _normalize_chat_payload.
    stream: bool = False
    top_p: Optional[float] = Field(default=None, ge=0.0, le=1.0)
    stop: Optional[str | List[str]] = None
    presence_penalty: Optional[float] = Field(default=None, ge=-2.0, le=2.0)
    frequency_penalty: Optional[float] = Field(default=None, ge=-2.0, le=2.0)

    @field_validator("model")
    @classmethod
    def validate_model(cls, value: Optional[str]) -> Optional[str]:
        """Strip surrounding whitespace; an all-whitespace model name is invalid."""
        if value is None:
            return value
        model = value.strip()
        if not model:
            raise ValueError("model must not be empty")
        return model

    @field_validator("temperature")
    @classmethod
    def validate_temperature(cls, value: Optional[float]) -> Optional[float]:
        """Enforce the OpenAI-style temperature range [0, 2]."""
        if value is None:
            return value
        if value < 0.0 or value > 2.0:
            raise ValueError("temperature must be between 0 and 2")
        return value
|
|
|
|
|
|
def _llm_timeout() -> httpx.Timeout:
    """Build the per-request timeout for LLM calls (connect capped at 10 s)."""
    connect_budget = min(LLM_TIMEOUT, 10)
    return httpx.Timeout(LLM_TIMEOUT, connect=connect_budget)
|
|
|
|
|
|
def _assert_llm_enabled() -> None:
    """Raise a 503 gateway error when the LLM integration is switched off."""
    if LLM_ENABLED:
        return
    raise LLMGatewayError(503, "llm_disabled", "LLM service is disabled")
|
|
|
|
|
|
def _extract_upstream_error_message(response: httpx.Response) -> str:
|
|
try:
|
|
payload = response.json()
|
|
except Exception:
|
|
payload = None
|
|
|
|
if isinstance(payload, dict):
|
|
error = payload.get("error")
|
|
if isinstance(error, dict) and error.get("message"):
|
|
return str(error["message"])
|
|
if payload.get("message"):
|
|
return str(payload["message"])
|
|
if payload.get("detail"):
|
|
return str(payload["detail"])
|
|
|
|
text = response.text.strip()
|
|
return text[:500] if text else f"Upstream returned HTTP {response.status_code}"
|
|
|
|
|
|
def _map_upstream_llm_status(status_code: int) -> int:
|
|
if status_code in (400, 413, 422):
|
|
return status_code
|
|
if 400 <= status_code < 500:
|
|
return 422
|
|
return 503
|
|
|
|
|
|
def _normalize_chat_payload(payload: ChatCompletionRequest) -> Dict[str, Any]:
    """Apply gateway defaults and limits to a validated chat request.

    Rejects streaming requests, fills in the default model, clamps
    max_tokens to the hard limit, and clamps temperature into [0, 2].
    """
    body = payload.model_dump(exclude_none=True)

    if body.get("stream"):
        raise LLMGatewayError(
            422,
            "streaming_not_supported",
            "Streaming responses are not enabled for this gateway",
        )

    body["model"] = body.get("model") or LLM_DEFAULT_MODEL
    requested_tokens = int(body.get("max_tokens") or LLM_MAX_TOKENS_DEFAULT)
    body["max_tokens"] = min(requested_tokens, LLM_MAX_TOKENS_HARD_LIMIT)

    if "temperature" in body:
        body["temperature"] = min(2.0, max(0.0, float(body["temperature"])))

    return body
|
|
|
|
|
|
async def _parse_llm_request(request: Request) -> ChatCompletionRequest:
    """Read, size-check, and validate the JSON body of an LLM chat request.

    Raises LLMGatewayError with 400/413/422 codes for the respective
    malformed-header, oversize, bad-JSON, and schema-validation failures.
    """
    # Cheap pre-check on the declared size before reading the body at all.
    declared = request.headers.get("content-length")
    if declared:
        try:
            declared_size = int(declared)
        except ValueError:
            raise LLMGatewayError(400, "invalid_request", "Invalid Content-Length header")
        if declared_size > LLM_MAX_REQUEST_BYTES:
            raise LLMGatewayError(
                413,
                "payload_too_large",
                f"Request exceeds {LLM_MAX_REQUEST_BYTES} bytes",
            )

    body = await request.body()
    if not body:
        raise LLMGatewayError(400, "invalid_request", "Request body is required")
    # Re-check the actual size: Content-Length may be absent or lie.
    if len(body) > LLM_MAX_REQUEST_BYTES:
        raise LLMGatewayError(
            413,
            "payload_too_large",
            f"Request exceeds {LLM_MAX_REQUEST_BYTES} bytes",
        )

    try:
        decoded = json.loads(body)
    except json.JSONDecodeError:
        raise LLMGatewayError(400, "invalid_json", "Request body must be valid JSON")

    if not isinstance(decoded, dict):
        raise LLMGatewayError(400, "invalid_request", "JSON body must be an object")

    try:
        return ChatCompletionRequest.model_validate(decoded)
    except ValidationError as exc:
        raise LLMGatewayError(422, "validation_error", "Invalid chat request", exc.errors())
|
|
|
|
|
|
async def _llm_request(
    method: str,
    path: str,
    *,
    json_payload: Optional[Dict[str, Any]] = None,
) -> Dict[str, Any]:
    """Forward a request to the LLM upstream and return its decoded JSON body.

    Transport failures, timeouts, upstream errors and non-JSON replies are
    all converted into LLMGatewayError with an appropriate gateway status.
    """
    _assert_llm_enabled()

    try:
        response = await get_http_client().request(
            method,
            f"{LLM_URL}{path}",
            json=json_payload,
            timeout=_llm_timeout(),
        )
    except httpx.TimeoutException:
        raise LLMGatewayError(504, "llm_timeout", "LLM request timed out")
    except httpx.RequestError as exc:
        raise LLMGatewayError(503, "llm_unavailable", f"LLM service is unavailable: {exc}")

    status = response.status_code
    if status >= 500:
        raise LLMGatewayError(503, "llm_unavailable", _extract_upstream_error_message(response))
    if status >= 400:
        raise LLMGatewayError(
            _map_upstream_llm_status(status),
            "llm_rejected_request",
            _extract_upstream_error_message(response),
        )

    try:
        return response.json()
    except Exception:
        raise LLMGatewayError(503, "llm_invalid_response", "LLM service returned invalid JSON")
|
|
|
|
|
|
def _normalize_ai_chat_response(response: Dict[str, Any]) -> Dict[str, Any]:
    """Flatten an OpenAI-style chat completion into the gateway's /ai/chat shape.

    Raises LLMGatewayError(503) when the upstream payload is missing
    choices or message content.
    """
    choices = response.get("choices")
    if not isinstance(choices, list) or not choices:
        raise LLMGatewayError(503, "llm_invalid_response", "LLM response did not contain choices")

    head = choices[0] if isinstance(choices[0], dict) else {}
    raw_message = head.get("message")
    message = raw_message if isinstance(raw_message, dict) else {}
    content = message.get("content")
    if not isinstance(content, str):
        raise LLMGatewayError(503, "llm_invalid_response", "LLM response did not contain message content")

    raw_usage = response.get("usage")
    usage = raw_usage if isinstance(raw_usage, dict) else {}
    token_counts = {
        key: int(usage.get(key) or 0)
        for key in ("prompt_tokens", "completion_tokens", "total_tokens")
    }
    return {
        "model": response.get("model") or LLM_DEFAULT_MODEL,
        "content": content,
        "finish_reason": head.get("finish_reason") or "stop",
        "usage": token_counts,
    }
|
|
|
|
|
|
async def _get_llm_models_payload() -> Dict[str, Any]:
    """Fetch /v1/models from the LLM upstream, falling back to the default model.

    If the upstream list is missing or empty, synthesize a single-entry
    listing for LLM_DEFAULT_MODEL so clients always see at least one model.
    """
    upstream = await _llm_request("GET", "/v1/models")
    data = upstream.get("data")
    if isinstance(data, list) and data:
        return upstream

    fallback_entry = {
        "id": LLM_DEFAULT_MODEL,
        "object": "model",
        "owned_by": "self-hosted",
    }
    return {"object": "list", "data": [fallback_entry]}
|
|
|
|
|
|
async def _get_health(base: str) -> Dict[str, Any]:
    """Probe *base*/health with a short timeout; never raises.

    Returns the upstream JSON on 200, a status/code dict on other HTTP
    statuses, and {"status": "unreachable"} on any exception.
    """
    try:
        resp = await get_http_client().get(f"{base}/health", timeout=5)
        if resp.status_code == 200:
            return resp.json()
        return {"status": "bad", "code": resp.status_code}
    except Exception:
        return {"status": "unreachable"}
|
|
|
|
|
|
async def _post_json(url: str, payload: Dict[str, Any]) -> Dict[str, Any]:
    """POST *payload* as JSON to *url* and return the decoded JSON response.

    Transport failures, upstream 4xx/5xx, and non-JSON bodies are all
    surfaced as HTTP 502 with diagnostic detail.
    """
    started = time.perf_counter()
    try:
        resp = await get_http_client().post(url, json=payload)
    except httpx.RequestError as e:
        raise HTTPException(status_code=502, detail=f"Upstream request failed {url}: {str(e)}")
    logger.debug(
        "POST %s status=%s elapsed_ms=%.1f",
        url, resp.status_code, (time.perf_counter() - started) * 1000,
    )
    if resp.status_code >= 400:
        raise HTTPException(status_code=502, detail=f"Upstream error {url}: {resp.status_code} {resp.text[:1000]}")
    try:
        return resp.json()
    except Exception:
        # upstream returned non-JSON (HTML error page or empty body)
        raise HTTPException(status_code=502, detail=f"Upstream returned non-JSON at {url}: {resp.status_code} {resp.text[:1000]}")
|
|
|
|
|
|
async def _post_file(url: str, data: bytes, fields: Dict[str, Any]) -> Dict[str, Any]:
    """POST *data* as a multipart file upload (plus stringified form *fields*).

    Mirrors _post_json's error handling: transport failures, upstream
    errors, and non-JSON bodies all surface as HTTP 502.
    """
    files = {"file": ("image", data, "application/octet-stream")}
    form = {k: str(v) for k, v in fields.items()}
    started = time.perf_counter()
    try:
        resp = await get_http_client().post(url, data=form, files=files)
    except httpx.RequestError as e:
        raise HTTPException(status_code=502, detail=f"Upstream request failed {url}: {str(e)}")
    logger.debug(
        "POST(file) %s status=%s elapsed_ms=%.1f",
        url, resp.status_code, (time.perf_counter() - started) * 1000,
    )
    if resp.status_code >= 400:
        raise HTTPException(status_code=502, detail=f"Upstream error {url}: {resp.status_code} {resp.text[:1000]}")
    try:
        return resp.json()
    except Exception:
        raise HTTPException(status_code=502, detail=f"Upstream returned non-JSON at {url}: {resp.status_code} {resp.text[:1000]}")
|
|
|
|
|
|
async def _get_json(url: str, params: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
    """GET *url* (with optional query *params*) and return the decoded JSON.

    Mirrors _post_json's error handling: transport failures, upstream
    errors, and non-JSON bodies all surface as HTTP 502.
    """
    started = time.perf_counter()
    try:
        resp = await get_http_client().get(url, params=params)
    except httpx.RequestError as e:
        raise HTTPException(status_code=502, detail=f"Upstream request failed {url}: {str(e)}")
    logger.debug(
        "GET %s status=%s elapsed_ms=%.1f",
        url, resp.status_code, (time.perf_counter() - started) * 1000,
    )
    if resp.status_code >= 400:
        raise HTTPException(status_code=502, detail=f"Upstream error {url}: {resp.status_code} {resp.text[:1000]}")
    try:
        return resp.json()
    except Exception:
        raise HTTPException(status_code=502, detail=f"Upstream returned non-JSON at {url}: {resp.status_code} {resp.text[:1000]}")
|
|
|
|
|
|
@app.get("/health")
async def health():
    """Aggregate health of every upstream service, probed concurrently.

    Returns {"status": "ok", "services": {...}} with per-service health;
    the "llm" entry additionally reports enabled state and default model.
    """
    # Pair each service name with its probe so the result order can never
    # drift from the name order — the previous implementation relied on
    # fragile positional indices (results[4], llm_index) into the gather
    # output, which breaks silently if the probe list changes.
    probes = [
        ("clip", CLIP_URL),
        ("blip", BLIP_URL),
        ("yolo", YOLO_URL),
        ("qdrant", QDRANT_SVC_URL),
    ]
    if MATURITY_ENABLED:
        probes.append(("maturity", MATURITY_URL))
    if LLM_ENABLED:
        probes.append(("llm", LLM_URL))

    results = await asyncio.gather(*(_get_health(base) for _, base in probes))
    services: Dict[str, Any] = {name: result for (name, _), result in zip(probes, results)}

    # Wrap (or synthesize) the llm entry with gateway-level metadata.
    if LLM_ENABLED:
        services["llm"] = {
            "enabled": True,
            "default_model": LLM_DEFAULT_MODEL,
            "upstream": services["llm"],
        }
    else:
        services["llm"] = {
            "enabled": False,
            "default_model": LLM_DEFAULT_MODEL,
            "upstream": {"status": "disabled"},
        }
    return {"status": "ok", "services": services}
|
|
|
|
|
|
@app.post("/v1/chat/completions")
async def llm_chat_completions(request: Request):
    """OpenAI-compatible chat completions endpoint, proxied to the LLM upstream."""
    parsed = await _parse_llm_request(request)
    normalized = _normalize_chat_payload(parsed)
    return await _llm_request("POST", "/v1/chat/completions", json_payload=normalized)
|
|
|
|
|
|
@app.get("/v1/models")
async def llm_models():
    """OpenAI-compatible model listing, proxied to the LLM upstream."""
    listing = await _get_llm_models_payload()
    return listing
|
|
|
|
|
|
@app.post("/ai/chat")
async def ai_chat(request: Request):
    """Simplified chat endpoint returning a flat {model, content, usage} shape."""
    normalized = _normalize_chat_payload(await _parse_llm_request(request))
    upstream = await _llm_request("POST", "/v1/chat/completions", json_payload=normalized)
    return _normalize_ai_chat_response(upstream)
|
|
|
|
|
|
@app.get("/ai/models")
async def ai_models():
    """List available LLM models together with gateway configuration."""
    listing = await _get_llm_models_payload()
    return {
        "enabled": LLM_ENABLED,
        "default_model": LLM_DEFAULT_MODEL,
        "models": listing.get("data", []),
    }
|
|
|
|
|
|
@app.get("/ai/health")
async def ai_health():
    """Report LLM availability for the /ai/* surface without raising."""
    if not LLM_ENABLED:
        return {
            "status": "ok",
            "enabled": False,
            "reachable": False,
            "default_model": LLM_DEFAULT_MODEL,
            "upstream": {"status": "disabled"},
        }

    upstream_health = await _get_health(LLM_URL)
    is_up = upstream_health.get("status") == "ok"
    return {
        "status": "ok" if is_up else "degraded",
        "enabled": True,
        "reachable": is_up,
        "default_model": LLM_DEFAULT_MODEL,
        "upstream": upstream_health,
    }
|
|
|
|
|
|
# ---- Individual analyze endpoints (URL) ----
|
|
|
|
@app.post("/analyze/clip")
async def analyze_clip(req: ClipRequest):
    """Proxy a URL-based CLIP classification request."""
    image_url = req.url
    if not image_url:
        raise HTTPException(400, "url is required")
    return await _post_json(f"{CLIP_URL}/analyze", req.model_dump())
|
|
|
|
|
|
@app.post("/analyze/blip")
async def analyze_blip(req: BlipRequest):
    """Proxy a URL-based BLIP captioning request."""
    image_url = req.url
    if not image_url:
        raise HTTPException(400, "url is required")
    return await _post_json(f"{BLIP_URL}/caption", req.model_dump())
|
|
|
|
|
|
@app.post("/analyze/yolo")
async def analyze_yolo(req: YoloRequest):
    """Proxy a URL-based YOLO object-detection request."""
    image_url = req.url
    if not image_url:
        raise HTTPException(400, "url is required")
    return await _post_json(f"{YOLO_URL}/detect", req.model_dump())
|
|
|
|
|
|
# ---- Individual analyze endpoints (file upload) ----
|
|
|
|
|
|
@app.post("/analyze/clip/file")
async def analyze_clip_file(
    file: UploadFile = File(...),
    limit: int = Form(5),
    threshold: Optional[float] = Form(None),
):
    """Proxy an uploaded image to CLIP for classification."""
    form_fields: Dict[str, Any] = {"limit": int(limit)}
    if threshold is not None:
        form_fields["threshold"] = float(threshold)
    raw = await file.read()
    return await _post_file(f"{CLIP_URL}/analyze/file", raw, form_fields)
|
|
|
|
|
|
@app.post("/analyze/blip/file")
async def analyze_blip_file(
    file: UploadFile = File(...),
    variants: int = Form(3),
    max_length: int = Form(60),
):
    """Proxy an uploaded image to BLIP for captioning."""
    raw = await file.read()
    form_fields = {"variants": int(variants), "max_length": int(max_length)}
    return await _post_file(f"{BLIP_URL}/caption/file", raw, form_fields)
|
|
|
|
|
|
@app.post("/analyze/yolo/file")
async def analyze_yolo_file(
    file: UploadFile = File(...),
    conf: float = Form(0.25),
):
    """Proxy an uploaded image to YOLO for object detection."""
    raw = await file.read()
    form_fields = {"conf": float(conf)}
    return await _post_file(f"{YOLO_URL}/detect/file", raw, form_fields)
|
|
|
|
|
|
@app.post("/analyze/all")
async def analyze_all(payload: Dict[str, Any]):
    """Run CLIP, BLIP and YOLO on one image URL concurrently and merge results."""
    url = payload.get("url")
    if not url:
        raise HTTPException(400, "url is required")

    # Service name -> (upstream endpoint, request body) mapping keeps the
    # fan-out and the merged response keys in one place.
    upstream_calls = {
        "clip": (
            f"{CLIP_URL}/analyze",
            {"url": url, "limit": int(payload.get("limit", 5)), "threshold": payload.get("threshold")},
        ),
        "blip": (
            f"{BLIP_URL}/caption",
            {"url": url, "variants": int(payload.get("variants", 3)), "max_length": int(payload.get("max_length", 60))},
        ),
        "yolo": (
            f"{YOLO_URL}/detect",
            {"url": url, "conf": float(payload.get("conf", 0.25))},
        ),
    }
    results = await asyncio.gather(*(_post_json(u, body) for u, body in upstream_calls.values()))
    return dict(zip(upstream_calls.keys(), results))
|
|
|
|
|
|
# ---- Vector / Qdrant endpoints ----
|
|
|
|
@app.post("/vectors/upsert")
async def vectors_upsert(payload: Dict[str, Any]):
    """Proxy a vector upsert request to qdrant-svc."""
    target = f"{QDRANT_SVC_URL}/upsert"
    return await _post_json(target, payload)
|
|
|
|
|
|
@app.post("/vectors/upsert/file")
async def vectors_upsert_file(
    file: UploadFile = File(...),
    id: Optional[str] = Form(None),
    collection: Optional[str] = Form(None),
    metadata_json: Optional[str] = Form(None),
):
    """Proxy an uploaded image to qdrant-svc for embedding and upsert."""
    raw = await file.read()
    provided = {"id": id, "collection": collection, "metadata_json": metadata_json}
    form_fields: Dict[str, Any] = {k: v for k, v in provided.items() if v is not None}
    return await _post_file(f"{QDRANT_SVC_URL}/upsert/file", raw, form_fields)
|
|
|
|
|
|
@app.post("/vectors/upsert/vector")
async def vectors_upsert_vector(payload: Dict[str, Any]):
    """Proxy a raw-vector upsert request to qdrant-svc."""
    target = f"{QDRANT_SVC_URL}/upsert/vector"
    return await _post_json(target, payload)
|
|
|
|
|
|
@app.post("/vectors/search")
async def vectors_search(payload: Dict[str, Any]):
    """Proxy a similarity search request to qdrant-svc."""
    target = f"{QDRANT_SVC_URL}/search"
    return await _post_json(target, payload)
|
|
|
|
|
|
@app.post("/vectors/search/file")
async def vectors_search_file(
    file: UploadFile = File(...),
    limit: int = Form(5),
    score_threshold: Optional[float] = Form(None),
    collection: Optional[str] = Form(None),
    hnsw_ef: Optional[int] = Form(None),
    exact: bool = Form(False),
    indexed_only: bool = Form(False),
    filter_metadata_json: Optional[str] = Form(None),
):
    """Proxy an uploaded image to qdrant-svc for embed-and-search."""
    raw = await file.read()
    form_fields: Dict[str, Any] = {"limit": int(limit), "exact": exact, "indexed_only": indexed_only}
    # Optional knobs are forwarded only when the caller supplied them.
    optional: Dict[str, Any] = {
        "score_threshold": float(score_threshold) if score_threshold is not None else None,
        "collection": collection,
        "hnsw_ef": int(hnsw_ef) if hnsw_ef is not None else None,
        "filter_metadata_json": filter_metadata_json,
    }
    form_fields.update({k: v for k, v in optional.items() if v is not None})
    return await _post_file(f"{QDRANT_SVC_URL}/search/file", raw, form_fields)
|
|
|
|
|
|
@app.post("/vectors/search/vector")
async def vectors_search_vector(payload: Dict[str, Any]):
    """Proxy a raw-vector similarity search to qdrant-svc."""
    target = f"{QDRANT_SVC_URL}/search/vector"
    return await _post_json(target, payload)
|
|
|
|
|
|
@app.post("/vectors/delete")
async def vectors_delete(payload: Dict[str, Any]):
    """Proxy a point-deletion request to qdrant-svc."""
    target = f"{QDRANT_SVC_URL}/delete"
    return await _post_json(target, payload)
|
|
|
|
|
|
@app.get("/vectors/collections")
async def vectors_collections():
    """List all Qdrant collections via qdrant-svc."""
    target = f"{QDRANT_SVC_URL}/collections"
    return await _get_json(target)
|
|
|
|
|
|
@app.post("/vectors/collections")
async def vectors_create_collection(payload: Dict[str, Any]):
    """Create a Qdrant collection via qdrant-svc."""
    target = f"{QDRANT_SVC_URL}/collections"
    return await _post_json(target, payload)
|
|
|
|
|
|
@app.get("/vectors/collections/{name}")
async def vectors_collection_info(name: str):
    """Fetch metadata for a single Qdrant collection via qdrant-svc."""
    target = f"{QDRANT_SVC_URL}/collections/{name}"
    return await _get_json(target)
|
|
|
|
|
|
@app.get("/vectors/inspect")
async def vectors_inspect():
    """Full diagnostic summary for all Qdrant collections (HNSW, optimizer, payload indexes, RAM estimate)."""
    started = time.perf_counter()
    logger.info("vectors_inspect: start")
    summary = await _get_json(f"{QDRANT_SVC_URL}/inspect")
    logger.info("vectors_inspect: done elapsed_ms=%.1f", (time.perf_counter() - started) * 1000)
    return summary
|
|
|
|
|
|
@app.delete("/vectors/collections/{name}")
async def vectors_delete_collection(name: str):
    """Delete a Qdrant collection via qdrant-svc.

    Brought in line with the other proxy helpers (_get_json/_post_json):
    the previous version had no guard around the final `r.json()` (a
    non-JSON upstream body crashed with an unhandled exception → 500) and
    its error detail omitted the URL and response body.
    """
    url = f"{QDRANT_SVC_URL}/collections/{name}"
    try:
        r = await get_http_client().delete(url)
    except httpx.RequestError as exc:
        raise HTTPException(status_code=502, detail=f"Upstream request failed {url}: {exc}")
    if r.status_code >= 400:
        # Include a slice of the body so operators can see why the delete failed.
        raise HTTPException(status_code=502, detail=f"Upstream error {url}: {r.status_code} {r.text[:1000]}")
    try:
        return r.json()
    except Exception:
        raise HTTPException(status_code=502, detail=f"Upstream returned non-JSON at {url}: {r.status_code} {r.text[:1000]}")
|
|
|
|
|
|
@app.get("/vectors/points/{point_id}")
async def vectors_get_point(point_id: str, collection: Optional[str] = None):
    """Fetch a single point by Qdrant id, optionally scoped to a collection."""
    query = {"collection": collection} if collection else {}
    return await _get_json(f"{QDRANT_SVC_URL}/points/{point_id}", params=query)
|
|
|
|
|
|
@app.get("/vectors/points/by-original-id/{original_id}")
async def vectors_get_point_by_original_id(original_id: str, collection: Optional[str] = None):
    """Fetch a point by its original (pre-ingest) id, optionally scoped to a collection."""
    query = {"collection": collection} if collection else {}
    return await _get_json(f"{QDRANT_SVC_URL}/points/by-original-id/{original_id}", params=query)
|
|
|
|
|
|
# ---- File-based universal analyze ----
|
|
|
|
@app.post("/analyze/all/file")
async def analyze_all_file(
    file: UploadFile = File(...),
    limit: int = Form(5),
    variants: int = Form(3),
    conf: float = Form(0.25),
    max_length: int = Form(60),
):
    """Fan an uploaded image out to CLIP, BLIP and YOLO concurrently."""
    raw = await file.read()
    tasks = (
        _post_file(f"{CLIP_URL}/analyze/file", raw, {"limit": limit}),
        _post_file(f"{BLIP_URL}/caption/file", raw, {"variants": variants, "max_length": max_length}),
        _post_file(f"{YOLO_URL}/detect/file", raw, {"conf": conf}),
    )
    clip_out, blip_out, yolo_out = await asyncio.gather(*tasks)
    return {"clip": clip_out, "blip": blip_out, "yolo": yolo_out}
|
|
|
|
|
|
# ---- Maturity / NSFW analysis endpoints ----
|
|
|
|
|
|
def _assert_maturity_enabled() -> None:
    """Raise a 503 when the maturity service is switched off via env config."""
    if MATURITY_ENABLED:
        return
    raise HTTPException(status_code=503, detail="Maturity service is disabled")
|
|
|
|
|
|
@app.post("/analyze/maturity")
async def analyze_maturity(req: MaturityRequest):
    """Analyze an image URL for maturity / NSFW content.

    Returns a normalized maturity signal including maturity_label (safe/mature),
    confidence, score, optional sublabels, and an action_hint for Nova moderation.
    """
    _assert_maturity_enabled()
    image_url = req.url
    if not image_url:
        raise HTTPException(status_code=400, detail="url is required")
    logger.info("analyze_maturity: url=%s", image_url)
    return await _post_json(f"{MATURITY_URL}/analyze", req.model_dump())
|
|
|
|
|
|
@app.post("/analyze/maturity/file")
async def analyze_maturity_file(file: UploadFile = File(...)):
    """Analyze an uploaded image file for maturity / NSFW content.

    Returns the same normalized maturity signal as /analyze/maturity.
    """
    _assert_maturity_enabled()
    raw = await file.read()
    logger.info("analyze_maturity_file: filename=%s size=%d", file.filename, len(raw))
    return await _post_file(f"{MATURITY_URL}/analyze/file", raw, {})
|
|
|
|
|
|
# ---- Card renderer endpoints ----
|
|
|
|
@app.get("/cards/templates")
async def cards_templates():
    """List available card templates."""
    templates_url = f"{CARD_RENDERER_URL}/templates"
    return await _get_json(templates_url)
|
|
|
|
|
|
@app.post("/cards/render")
async def cards_render(payload: Dict[str, Any]):
    """Render a Nova card from a remote image URL. Returns binary image bytes."""
    try:
        upstream = await get_http_client().post(f"{CARD_RENDERER_URL}/render", json=payload)
    except httpx.RequestError as exc:
        raise HTTPException(status_code=502, detail=f"card-renderer unreachable: {exc}")
    if upstream.status_code >= 400:
        raise HTTPException(status_code=502, detail=f"card-renderer error {upstream.status_code}: {upstream.text[:1000]}")
    # Pass the renderer's bytes straight through, keeping its content type.
    content_type = upstream.headers.get("content-type", "image/webp")
    return Response(content=upstream.content, media_type=content_type)
|
|
|
|
|
|
@app.post("/cards/render/file")
async def cards_render_file(
    file: UploadFile = File(...),
    template: str = Form("nova-artwork-v1"),
    width: int = Form(1200),
    height: int = Form(630),
    output: str = Form("webp"),
    quality: int = Form(90),
    title: Optional[str] = Form(None),
    subtitle: Optional[str] = Form(None),
    username: Optional[str] = Form(None),
    category: Optional[str] = Form(None),
    tags_json: Optional[str] = Form(None),
    show_logo: bool = Form(True),
):
    """Render a Nova card from an uploaded image file. Returns binary image bytes."""
    raw = await file.read()

    # Always-present rendering parameters.
    form_fields: Dict[str, Any] = {
        "template": template,
        "width": width,
        "height": height,
        "output": output,
        "quality": quality,
        "show_logo": show_logo,
    }
    # Text overlays are forwarded only when the caller supplied them.
    overlays = {
        "title": title,
        "subtitle": subtitle,
        "username": username,
        "category": category,
        "tags_json": tags_json,
    }
    form_fields.update({k: v for k, v in overlays.items() if v is not None})

    multipart = {"file": (file.filename or "image", raw, file.content_type or "application/octet-stream")}
    try:
        upstream = await get_http_client().post(
            f"{CARD_RENDERER_URL}/render/file",
            data={k: str(v) for k, v in form_fields.items()},
            files=multipart,
        )
    except httpx.RequestError as exc:
        raise HTTPException(status_code=502, detail=f"card-renderer unreachable: {exc}")
    if upstream.status_code >= 400:
        raise HTTPException(status_code=502, detail=f"card-renderer error {upstream.status_code}: {upstream.text[:1000]}")
    return Response(
        content=upstream.content,
        media_type=upstream.headers.get("content-type", "image/webp"),
    )
|
|
|
|
|
|
@app.post("/cards/render/meta")
async def cards_render_meta(payload: Dict[str, Any]):
    """Return crop and layout metadata for a card render (no image produced)."""
    target = f"{CARD_RENDERER_URL}/render/meta"
    return await _post_json(target, payload)
|
|
|
|
|
|
# ---- Qdrant administration endpoints (index management + collection config) ----
|
|
|
|
@app.get("/vectors/collections/{name}/indexes")
async def vectors_collection_indexes(name: str):
    """List payload indexes for a collection."""
    target = f"{QDRANT_SVC_URL}/collections/{name}/indexes"
    return await _get_json(target)
|
|
|
|
|
|
@app.post("/vectors/collections/{name}/indexes")
async def vectors_create_payload_index(name: str, payload: Dict[str, Any]):
    """Create a payload index on a field in a collection."""
    target = f"{QDRANT_SVC_URL}/collections/{name}/indexes"
    return await _post_json(target, payload)
|
|
|
|
|
|
@app.post("/vectors/collections/{name}/ensure-indexes")
async def vectors_ensure_indexes(name: str, payload: Dict[str, Any]):
    """Idempotently ensure payload indexes exist for a list of fields."""
    target = f"{QDRANT_SVC_URL}/collections/{name}/ensure-indexes"
    return await _post_json(target, payload)
|
|
|
|
|
|
@app.post("/vectors/collections/{name}/configure")
async def vectors_configure_collection(name: str, payload: Dict[str, Any]):
    """Update HNSW and optimizer configuration for a collection."""
    target = f"{QDRANT_SVC_URL}/collections/{name}/configure"
    return await _post_json(target, payload)
|