from __future__ import annotations import os import asyncio from typing import Any, Dict, Optional import httpx from fastapi import FastAPI, HTTPException, UploadFile, File, Form, Request from fastapi.responses import JSONResponse from starlette.middleware.base import BaseHTTPMiddleware from pydantic import BaseModel, Field CLIP_URL = os.getenv("CLIP_URL", "http://clip:8000") BLIP_URL = os.getenv("BLIP_URL", "http://blip:8000") YOLO_URL = os.getenv("YOLO_URL", "http://yolo:8000") QDRANT_SVC_URL = os.getenv("QDRANT_SVC_URL", "http://qdrant-svc:8000") VISION_TIMEOUT = float(os.getenv("VISION_TIMEOUT", "20")) # API key (set via env var `API_KEY`). If not set, gateway will reject requests. API_KEY = os.getenv("API_KEY") class APIKeyMiddleware(BaseHTTPMiddleware): async def dispatch(self, request: Request, call_next): # allow health and docs endpoints without API key if request.url.path in ("/health", "/openapi.json", "/docs", "/redoc"): return await call_next(request) key = request.headers.get("x-api-key") or request.headers.get("X-API-Key") if not API_KEY or key != API_KEY: return JSONResponse(status_code=401, content={"detail": "Unauthorized"}) return await call_next(request) app = FastAPI(title="Skinbase Vision Gateway", version="1.0.0") app.add_middleware(APIKeyMiddleware) class ClipRequest(BaseModel): url: Optional[str] = None limit: int = Field(default=5, ge=1, le=50) threshold: Optional[float] = Field(default=None, ge=0.0, le=1.0) class BlipRequest(BaseModel): url: Optional[str] = None variants: int = Field(default=3, ge=0, le=10) max_length: int = Field(default=60, ge=10, le=200) class YoloRequest(BaseModel): url: Optional[str] = None conf: float = Field(default=0.25, ge=0.0, le=1.0) async def _get_health(client: httpx.AsyncClient, base: str) -> Dict[str, Any]: try: r = await client.get(f"{base}/health") return r.json() if r.status_code == 200 else {"status": "bad", "code": r.status_code} except Exception: return {"status": "unreachable"} async def _post_json(client: httpx.AsyncClient, url: str, payload: Dict[str, Any]) -> Dict[str, Any]: try: r = await client.post(url, json=payload) except httpx.RequestError as e: raise HTTPException(status_code=502, detail=f"Upstream request failed {url}: {str(e)}") if r.status_code >= 400: raise HTTPException(status_code=502, detail=f"Upstream error {url}: {r.status_code} {r.text[:1000]}") try: return r.json() except Exception: # upstream returned non-JSON (HTML error page or empty body) raise HTTPException(status_code=502, detail=f"Upstream returned non-JSON at {url}: {r.status_code} {r.text[:1000]}") async def _post_file(client: httpx.AsyncClient, url: str, data: bytes, fields: Dict[str, Any]) -> Dict[str, Any]: files = {"file": ("image", data, "application/octet-stream")} try: r = await client.post(url, data={k: str(v) for k, v in fields.items()}, files=files) except httpx.RequestError as e: raise HTTPException(status_code=502, detail=f"Upstream request failed {url}: {str(e)}") if r.status_code >= 400: raise HTTPException(status_code=502, detail=f"Upstream error {url}: {r.status_code} {r.text[:1000]}") try: return r.json() except Exception: raise HTTPException(status_code=502, detail=f"Upstream returned non-JSON at {url}: {r.status_code} {r.text[:1000]}") @app.get("/health") async def health(): async with httpx.AsyncClient(timeout=5) as client: clip_h, blip_h, yolo_h, qdrant_h = await asyncio.gather( _get_health(client, CLIP_URL), _get_health(client, BLIP_URL), _get_health(client, YOLO_URL), _get_health(client, QDRANT_SVC_URL), ) return {"status": "ok", "services": {"clip": clip_h, "blip": blip_h, "yolo": yolo_h, "qdrant": qdrant_h}} # ---- Individual analyze endpoints (URL) ---- @app.post("/analyze/clip") async def analyze_clip(req: ClipRequest): if not req.url: raise HTTPException(400, "url is required") async with httpx.AsyncClient(timeout=VISION_TIMEOUT) as client: return await _post_json(client, f"{CLIP_URL}/analyze", req.model_dump()) @app.post("/analyze/blip") async def analyze_blip(req: BlipRequest): if not req.url: raise HTTPException(400, "url is required") async with httpx.AsyncClient(timeout=VISION_TIMEOUT) as client: return await _post_json(client, f"{BLIP_URL}/caption", req.model_dump()) @app.post("/analyze/yolo") async def analyze_yolo(req: YoloRequest): if not req.url: raise HTTPException(400, "url is required") async with httpx.AsyncClient(timeout=VISION_TIMEOUT) as client: return await _post_json(client, f"{YOLO_URL}/detect", req.model_dump()) # ---- Individual analyze endpoints (file upload) ---- @app.post("/analyze/clip/file") async def analyze_clip_file( file: UploadFile = File(...), limit: int = Form(5), threshold: Optional[float] = Form(None), ): data = await file.read() fields: Dict[str, Any] = {"limit": int(limit)} if threshold is not None: fields["threshold"] = float(threshold) async with httpx.AsyncClient(timeout=VISION_TIMEOUT) as client: return await _post_file(client, f"{CLIP_URL}/analyze/file", data, fields) @app.post("/analyze/blip/file") async def analyze_blip_file( file: UploadFile = File(...), variants: int = Form(3), max_length: int = Form(60), ): data = await file.read() fields = {"variants": int(variants), "max_length": int(max_length)} async with httpx.AsyncClient(timeout=VISION_TIMEOUT) as client: return await _post_file(client, f"{BLIP_URL}/caption/file", data, fields) @app.post("/analyze/yolo/file") async def analyze_yolo_file( file: UploadFile = File(...), conf: float = Form(0.25), ): data = await file.read() fields = {"conf": float(conf)} async with httpx.AsyncClient(timeout=VISION_TIMEOUT) as client: return await _post_file(client, f"{YOLO_URL}/detect/file", data, fields) @app.post("/analyze/all") async def analyze_all(payload: Dict[str, Any]): url = payload.get("url") if not url: raise HTTPException(400, "url is required") clip_req = {"url": url, "limit": int(payload.get("limit", 5)), "threshold": payload.get("threshold")} blip_req = {"url": url, "variants": int(payload.get("variants", 3)), "max_length": int(payload.get("max_length", 60))} yolo_req = {"url": url, "conf": float(payload.get("conf", 0.25))} async with httpx.AsyncClient(timeout=VISION_TIMEOUT) as client: clip_task = _post_json(client, f"{CLIP_URL}/analyze", clip_req) blip_task = _post_json(client, f"{BLIP_URL}/caption", blip_req) yolo_task = _post_json(client, f"{YOLO_URL}/detect", yolo_req) clip_res, blip_res, yolo_res = await asyncio.gather(clip_task, blip_task, yolo_task) return {"clip": clip_res, "blip": blip_res, "yolo": yolo_res} # ---- Vector / Qdrant endpoints ---- @app.post("/vectors/upsert") async def vectors_upsert(payload: Dict[str, Any]): async with httpx.AsyncClient(timeout=VISION_TIMEOUT) as client: return await _post_json(client, f"{QDRANT_SVC_URL}/upsert", payload) @app.post("/vectors/upsert/file") async def vectors_upsert_file( file: UploadFile = File(...), id: Optional[str] = Form(None), collection: Optional[str] = Form(None), metadata_json: Optional[str] = Form(None), ): data = await file.read() fields: Dict[str, Any] = {} if id is not None: fields["id"] = id if collection is not None: fields["collection"] = collection if metadata_json is not None: fields["metadata_json"] = metadata_json async with httpx.AsyncClient(timeout=VISION_TIMEOUT) as client: return await _post_file(client, f"{QDRANT_SVC_URL}/upsert/file", data, fields) @app.post("/vectors/upsert/vector") async def vectors_upsert_vector(payload: Dict[str, Any]): async with httpx.AsyncClient(timeout=VISION_TIMEOUT) as client: return await _post_json(client, f"{QDRANT_SVC_URL}/upsert/vector", payload) @app.post("/vectors/search") async def vectors_search(payload: Dict[str, Any]): async with httpx.AsyncClient(timeout=VISION_TIMEOUT) as client: return await _post_json(client, f"{QDRANT_SVC_URL}/search", payload) @app.post("/vectors/search/file") async def vectors_search_file( file: UploadFile = File(...), limit: int = Form(5), score_threshold: Optional[float] = Form(None), collection: Optional[str] = Form(None), ): data = await file.read() fields: Dict[str, Any] = {"limit": int(limit)} if score_threshold is not None: fields["score_threshold"] = float(score_threshold) if collection is not None: fields["collection"] = collection async with httpx.AsyncClient(timeout=VISION_TIMEOUT) as client: return await _post_file(client, f"{QDRANT_SVC_URL}/search/file", data, fields) @app.post("/vectors/search/vector") async def vectors_search_vector(payload: Dict[str, Any]): async with httpx.AsyncClient(timeout=VISION_TIMEOUT) as client: return await _post_json(client, f"{QDRANT_SVC_URL}/search/vector", payload) @app.post("/vectors/delete") async def vectors_delete(payload: Dict[str, Any]): async with httpx.AsyncClient(timeout=VISION_TIMEOUT) as client: return await _post_json(client, f"{QDRANT_SVC_URL}/delete", payload) @app.get("/vectors/collections") async def vectors_collections(): async with httpx.AsyncClient(timeout=VISION_TIMEOUT) as client: r = await client.get(f"{QDRANT_SVC_URL}/collections") if r.status_code >= 400: raise HTTPException(status_code=502, detail=f"Upstream error: {r.status_code}") return r.json() @app.post("/vectors/collections") async def vectors_create_collection(payload: Dict[str, Any]): async with httpx.AsyncClient(timeout=VISION_TIMEOUT) as client: return await _post_json(client, f"{QDRANT_SVC_URL}/collections", payload) @app.get("/vectors/collections/{name}") async def vectors_collection_info(name: str): async with httpx.AsyncClient(timeout=VISION_TIMEOUT) as client: r = await client.get(f"{QDRANT_SVC_URL}/collections/{name}") if r.status_code >= 400: raise HTTPException(status_code=502, detail=f"Upstream error: {r.status_code}") return r.json() @app.delete("/vectors/collections/{name}") async def vectors_delete_collection(name: str): async with httpx.AsyncClient(timeout=VISION_TIMEOUT) as client: r = await client.delete(f"{QDRANT_SVC_URL}/collections/{name}") if r.status_code >= 400: raise HTTPException(status_code=502, detail=f"Upstream error: {r.status_code}") return r.json() @app.get("/vectors/points/{point_id}") async def vectors_get_point(point_id: str, collection: Optional[str] = None): async with httpx.AsyncClient(timeout=VISION_TIMEOUT) as client: params = {} if collection: params["collection"] = collection r = await client.get(f"{QDRANT_SVC_URL}/points/{point_id}", params=params) if r.status_code >= 400: raise HTTPException(status_code=502, detail=f"Upstream error: {r.status_code}") return r.json() # ---- File-based universal analyze ---- @app.post("/analyze/all/file") async def analyze_all_file( file: UploadFile = File(...), limit: int = Form(5), variants: int = Form(3), conf: float = Form(0.25), max_length: int = Form(60), ): data = await file.read() async with httpx.AsyncClient(timeout=VISION_TIMEOUT) as client: clip_task = _post_file(client, f"{CLIP_URL}/analyze/file", data, {"limit": limit}) blip_task = _post_file(client, f"{BLIP_URL}/caption/file", data, {"variants": variants, "max_length": max_length}) yolo_task = _post_file(client, f"{YOLO_URL}/detect/file", data, {"conf": conf}) clip_res, blip_res, yolo_res = await asyncio.gather(clip_task, blip_task, yolo_task) return {"clip": clip_res, "blip": blip_res, "yolo": yolo_res}