Files
vision/qdrant/main.py

411 lines
13 KiB
Python

from __future__ import annotations
import os
import uuid
from typing import Any, Dict, List, Optional
import httpx
import numpy as np
from fastapi import FastAPI, HTTPException, UploadFile, File, Form
from pydantic import BaseModel, Field
from qdrant_client import QdrantClient
from qdrant_client.models import (
Distance,
PointStruct,
VectorParams,
Filter,
FieldCondition,
MatchValue,
)
# ---------------------------------------------------------------------------
# Config
# ---------------------------------------------------------------------------
# Every setting is read from an environment variable; the second argument of
# each os.getenv call is the fallback used when the variable is unset.
QDRANT_HOST = os.getenv("QDRANT_HOST", "qdrant")  # Qdrant host passed to QdrantClient
QDRANT_PORT = int(os.getenv("QDRANT_PORT", "6333"))  # Qdrant port passed to QdrantClient
CLIP_URL = os.getenv("CLIP_URL", "http://clip:8000")  # base URL of the CLIP service (/embed, /embed/file)
COLLECTION_NAME = os.getenv("COLLECTION_NAME", "images")  # collection used when a request omits `collection`
# Vector size of the default collection created at startup; presumably must
# equal the CLIP embedding size — confirm against the CLIP service.
VECTOR_DIM = int(os.getenv("VECTOR_DIM", "512"))
app = FastAPI(title="Skinbase Qdrant Service", version="1.0.0")
# Shared Qdrant client. It is None at import time and is assigned by the
# startup hook before requests are served.
client: QdrantClient = None # type: ignore[assignment]
# ---------------------------------------------------------------------------
# Startup / shutdown
# ---------------------------------------------------------------------------
@app.on_event("startup")
def startup():
    """Open the shared Qdrant client and make sure the default collection exists.

    Runs once when FastAPI starts. It rebinds the module-level ``client``
    that every endpoint uses.
    """
    global client
    client = QdrantClient(port=QDRANT_PORT, host=QDRANT_HOST)
    _ensure_collection()
def _ensure_collection():
    """Create the default collection (``COLLECTION_NAME``) if it does not exist yet.

    The collection uses ``VECTOR_DIM``-sized vectors with cosine distance.

    This is safe when several processes start at the same time: if
    ``create_collection`` fails but the collection exists afterwards, another
    process won the race and the failure is ignored. Any other failure is
    re-raised.
    """

    def _exists() -> bool:
        # One round-trip to Qdrant listing all collection names.
        return COLLECTION_NAME in {c.name for c in client.get_collections().collections}

    if _exists():
        return
    try:
        client.create_collection(
            collection_name=COLLECTION_NAME,
            vectors_config=VectorParams(size=VECTOR_DIM, distance=Distance.COSINE),
        )
    except Exception:
        # Check-then-create is not atomic: a concurrent worker may have
        # created the collection between _exists() and create_collection().
        if not _exists():
            raise
# ---------------------------------------------------------------------------
# Request / Response models
# ---------------------------------------------------------------------------
class UpsertUrlRequest(BaseModel):
    """Body for ``POST /upsert``: an image URL to embed via CLIP and store."""

    url: str  # image URL forwarded to the CLIP /embed endpoint
    id: Optional[str] = Field(default=None)  # caller-chosen id, normalised by _point_id
    metadata: Dict[str, Any] = Field(default_factory=dict)  # stored in the point payload
    collection: Optional[str] = Field(default=None)  # falsy -> COLLECTION_NAME
class UpsertVectorRequest(BaseModel):
    """Body for ``POST /upsert/vector``: a pre-computed vector to store as-is."""

    vector: List[float]  # stored without calling CLIP
    id: Optional[str] = Field(default=None)  # caller-chosen id, normalised by _point_id
    metadata: Dict[str, Any] = Field(default_factory=dict)  # stored in the point payload
    collection: Optional[str] = Field(default=None)  # falsy -> COLLECTION_NAME
class SearchUrlRequest(BaseModel):
    """Body for ``POST /search``: find stored points similar to an image URL."""

    url: str  # image URL embedded via CLIP to form the query vector
    limit: int = Field(5, ge=1, le=100)  # maximum number of hits
    score_threshold: Optional[float] = Field(None, ge=0.0, le=1.0)  # forwarded to Qdrant's score_threshold
    collection: Optional[str] = Field(default=None)  # falsy -> COLLECTION_NAME
    filter_metadata: Dict[str, Any] = Field(default_factory=dict)  # exact-match payload conditions
class SearchVectorRequest(BaseModel):
    """Body for ``POST /search/vector``: search with a pre-computed query vector."""

    vector: List[float]  # query vector, used without calling CLIP
    limit: int = Field(5, ge=1, le=100)  # maximum number of hits
    score_threshold: Optional[float] = Field(None, ge=0.0, le=1.0)  # forwarded to Qdrant's score_threshold
    collection: Optional[str] = Field(default=None)  # falsy -> COLLECTION_NAME
    filter_metadata: Dict[str, Any] = Field(default_factory=dict)  # exact-match payload conditions
class DeleteRequest(BaseModel):
    """Body for ``POST /delete``: ids of the points to remove."""

    ids: List[str]  # point ids, sent as strings
    collection: Optional[str] = Field(default=None)  # falsy -> COLLECTION_NAME
class CollectionRequest(BaseModel):
    """Body for ``POST /collections``: parameters of a collection to create."""

    name: str  # new collection name
    vector_dim: int = Field(512, ge=1)  # vector size of the new collection
    distance: str = "cosine"  # cosine / euclid / dot, checked by the endpoint (case-insensitive)
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _col(name: Optional[str]) -> str:
    """Resolve the collection a request targets.

    Any falsy value (``None`` or ``""``) selects ``COLLECTION_NAME``;
    otherwise ``name`` is returned unchanged.
    """
    if name:
        return name
    return COLLECTION_NAME
def _clip_vector(r: httpx.Response, endpoint: str) -> List[float]:
    """Extract the ``vector`` field from a CLIP response, mapping failures to HTTP 502.

    ``endpoint`` (e.g. ``"/embed"``) is used only to label error messages.

    Raises:
        HTTPException 502: on an HTTP error status, a non-JSON body, or a JSON
            body without a ``vector`` key.
    """
    if r.status_code >= 400:
        raise HTTPException(502, f"CLIP {endpoint} error: {r.status_code} {r.text[:200]}")
    try:
        body = r.json()
    except ValueError as e:  # json.JSONDecodeError subclasses ValueError
        raise HTTPException(
            502, f"CLIP {endpoint} returned non-JSON: {r.status_code} {r.text[:200]}"
        ) from e
    try:
        return body["vector"]
    except (KeyError, TypeError) as e:  # TypeError: body is a list / scalar
        raise HTTPException(
            502, f"CLIP {endpoint} response has no 'vector' field: {r.text[:200]}"
        ) from e


async def _clip_post(endpoint: str, **request_kwargs: Any) -> List[float]:
    """POST to ``CLIP_URL + endpoint`` and return the embedding from the response.

    ``request_kwargs`` are passed through to ``httpx.AsyncClient.post``
    (e.g. ``json=`` or ``files=``). A new client with a 30 s timeout is
    created for every call.

    Raises:
        HTTPException 502: on transport errors or any failure reported by
            ``_clip_vector``.
    """
    async with httpx.AsyncClient(timeout=30) as http:
        try:
            r = await http.post(f"{CLIP_URL}{endpoint}", **request_kwargs)
        except httpx.RequestError as e:
            raise HTTPException(502, f"CLIP request failed: {str(e)}") from e
        return _clip_vector(r, endpoint)


async def _embed_url(url: str) -> List[float]:
    """Call the CLIP service to get an image embedding for the image at ``url``."""
    return await _clip_post("/embed", json={"url": url})


async def _embed_bytes(data: bytes) -> List[float]:
    """Call the CLIP service to embed uploaded file bytes."""
    files = {"file": ("image", data, "application/octet-stream")}
    return await _clip_post("/embed/file", files=files)
def _build_filter(metadata: Dict[str, Any]) -> Optional[Filter]:
    """Build a Qdrant filter requiring every ``key == value`` pair to match.

    An empty (or ``None``) mapping yields ``None``, meaning "no filter".
    """
    if metadata:
        return Filter(
            must=[
                FieldCondition(key=field, match=MatchValue(value=expected))
                for field, expected in metadata.items()
            ]
        )
    return None
def _point_id(raw: Optional[str]) -> int | str:
    """Return a Qdrant-compatible point id for a caller-supplied id.

    Qdrant accepts either an unsigned 64-bit integer or a UUID string.

    * ``None`` / ``""``: a fresh random UUID string.
    * Parses as an integer in ``[0, 2**64)``: that ``int``.
    * Parses as a UUID: its canonical hyphenated string.
    * Anything else (including negative or oversized integers): a UUID5
      derived deterministically from ``raw``. Upserting the same id twice
      therefore overwrites one point instead of creating duplicates.

    Whenever ``str(result) != raw``, the caller should store ``raw`` in the
    point payload under ``_original_id``.
    """
    if not raw:
        return str(uuid.uuid4())
    text = str(raw)
    # Integer ids, limited to the unsigned 64-bit range Qdrant accepts.
    try:
        number = int(text)
    except ValueError:
        pass
    else:
        if 0 <= number < 2**64:
            return number
    # UUID strings (any form uuid.UUID parses), normalised to canonical form.
    try:
        return str(uuid.UUID(text))
    except ValueError:
        pass
    # Stable mapping so the same raw id always targets the same point.
    return str(uuid.uuid5(uuid.NAMESPACE_URL, text))
# ---------------------------------------------------------------------------
# Health
# ---------------------------------------------------------------------------
@app.get("/health")
def health():
    """Report whether Qdrant is reachable, plus the current collection names.

    A failure is reported in the response body (``status: error``) rather
    than as an HTTP error status.
    """
    try:
        names = [collection.name for collection in client.get_collections().collections]
    except Exception as e:
        return {"status": "error", "detail": str(e)}
    return {"status": "ok", "qdrant": QDRANT_HOST, "collections": names}
# ---------------------------------------------------------------------------
# Collection management
# ---------------------------------------------------------------------------
@app.post("/collections")
def create_collection(req: CollectionRequest):
    """Create a collection with the requested vector size and distance metric.

    Raises:
        HTTPException 400: if ``req.distance`` is not cosine, euclid or dot.
        HTTPException 409: if a collection with that name already exists.
    """
    distances = {
        "cosine": Distance.COSINE,
        "euclid": Distance.EUCLID,
        "dot": Distance.DOT,
    }
    key = req.distance.lower()
    if key not in distances:
        raise HTTPException(400, f"Unknown distance: {req.distance}. Use cosine, euclid, or dot.")

    existing = {c.name for c in client.get_collections().collections}
    if req.name in existing:
        raise HTTPException(409, f"Collection '{req.name}' already exists")

    client.create_collection(
        collection_name=req.name,
        vectors_config=VectorParams(distance=distances[key], size=req.vector_dim),
    )
    return {"created": req.name, "vector_dim": req.vector_dim, "distance": req.distance}
@app.get("/collections")
def list_collections():
    """Return the names of every collection currently in Qdrant."""
    names = []
    for collection in client.get_collections().collections:
        names.append(collection.name)
    return {"collections": names}
@app.get("/collections/{name}")
def collection_info(name: str):
    """Return basic statistics for collection ``name``.

    Only the Qdrant lookup is mapped to 404. Errors while shaping the
    response are no longer disguised as "collection not found".

    Raises:
        HTTPException 404: if Qdrant cannot return info for the collection.
    """
    try:
        info = client.get_collection(name)
    except Exception as e:
        raise HTTPException(404, str(e))
    return {
        "name": name,
        # `vectors_count` is deprecated in qdrant-client and may be missing
        # depending on the installed version; read it defensively.
        "vectors_count": getattr(info, "vectors_count", None),
        "points_count": info.points_count,
        "status": info.status.value if info.status else None,
    }
@app.delete("/collections/{name}")
def delete_collection(name: str):
    """Drop collection ``name`` and echo the name back.

    The client's return value is ignored, and client errors propagate
    uncaught.
    """
    client.delete_collection(name)
    response = {"deleted": name}
    return response
# ---------------------------------------------------------------------------
# Upsert endpoints
# ---------------------------------------------------------------------------
@app.post("/upsert")
async def upsert_url(req: UpsertUrlRequest):
    """Embed an image by URL via CLIP, then store the vector in Qdrant."""
    vector = await _embed_url(req.url)
    point_id = _point_id(req.id)

    payload = dict(req.metadata)
    payload["source_url"] = req.url
    # Keep the caller's id whenever it could not be used verbatim as a Qdrant id.
    id_rewritten = req.id is not None and str(point_id) != str(req.id)
    if id_rewritten:
        payload["_original_id"] = req.id

    target = _col(req.collection)
    try:
        point = PointStruct(id=point_id, vector=vector, payload=payload)
        client.upsert(collection_name=target, points=[point])
    except Exception as e:
        raise HTTPException(500, str(e))
    return {"id": point_id, "collection": target, "dim": len(vector)}
@app.post("/upsert/file")
async def upsert_file(
    file: UploadFile = File(...),
    id: Optional[str] = Form(None),
    collection: Optional[str] = Form(None),
    metadata_json: Optional[str] = Form(None),
):
    """Embed an uploaded image via CLIP, then store the vector in Qdrant.

    ``metadata_json``, when given, must be a JSON object. It becomes the
    point payload.
    """
    import json

    # Validate the metadata before the (comparatively expensive) CLIP call.
    payload: Dict[str, Any] = {}
    if metadata_json:
        try:
            payload = json.loads(metadata_json)
        except json.JSONDecodeError as e:
            raise HTTPException(400, "metadata_json must be valid JSON") from e
        # A list / number / null would later break payload["_original_id"]
        # and PointStruct, producing a 500 instead of a client error.
        if not isinstance(payload, dict):
            raise HTTPException(400, "metadata_json must be a JSON object")

    data = await file.read()
    vector = await _embed_bytes(data)
    pid = _point_id(id)
    # preserve original user-provided id if it wasn't usable as a point id
    if id is not None and str(pid) != str(id):
        payload["_original_id"] = id
    col = _col(collection)
    try:
        client.upsert(
            collection_name=col,
            points=[PointStruct(id=pid, vector=vector, payload=payload)],
        )
    except Exception as e:
        raise HTTPException(500, str(e))
    return {"id": pid, "collection": col, "dim": len(vector)}
@app.post("/upsert/vector")
def upsert_vector(req: UpsertVectorRequest):
    """Store a pre-computed vector directly (skip CLIP embedding)."""
    point_id = _point_id(req.id)
    target = _col(req.collection)

    payload = dict(req.metadata or {})
    # Record the caller's id when _point_id had to substitute another one.
    if req.id is not None and str(point_id) != str(req.id):
        payload["_original_id"] = req.id

    try:
        point = PointStruct(id=point_id, vector=req.vector, payload=payload)
        client.upsert(collection_name=target, points=[point])
    except Exception as e:
        raise HTTPException(500, str(e))
    return {"id": point_id, "collection": target, "dim": len(req.vector)}
# ---------------------------------------------------------------------------
# Search endpoints
# ---------------------------------------------------------------------------
@app.post("/search")
async def search_url(req: SearchUrlRequest):
    """Embed an image by URL via CLIP, then search Qdrant for similar vectors."""
    query_vector = await _embed_url(req.url)
    return _do_search(
        vector=query_vector,
        limit=req.limit,
        score_threshold=req.score_threshold,
        collection=req.collection,
        filter_metadata=req.filter_metadata,
    )
@app.post("/search/file")
async def search_file(
    file: UploadFile = File(...),
    limit: int = Form(5, ge=1, le=100),
    score_threshold: Optional[float] = Form(None, ge=0.0, le=1.0),
    collection: Optional[str] = Form(None),
):
    """Embed an uploaded image via CLIP, then search Qdrant for similar vectors."""
    # The bounds on `limit` / `score_threshold` match SearchUrlRequest and
    # SearchVectorRequest, so all search endpoints validate them the same way.
    # FastAPI rejects out-of-range values before the CLIP call is made.
    data = await file.read()
    vector = await _embed_bytes(data)
    return _do_search(vector, int(limit), score_threshold, collection, {})
@app.post("/search/vector")
def search_vector(req: SearchVectorRequest):
    """Search Qdrant using a pre-computed vector."""
    return _do_search(
        vector=req.vector,
        limit=req.limit,
        score_threshold=req.score_threshold,
        collection=req.collection,
        filter_metadata=req.filter_metadata,
    )
def _do_search(
    vector: List[float],
    limit: int,
    score_threshold: Optional[float],
    collection: Optional[str],
    filter_metadata: Dict[str, Any],
):
    """Run a similarity query against Qdrant and shape the response.

    Args:
        vector: query embedding.
        limit: maximum number of hits.
        score_threshold: forwarded to Qdrant's ``score_threshold`` (``None``
            means no cut-off).
        collection: target collection; a falsy value selects ``COLLECTION_NAME``.
        filter_metadata: ``{key: value}`` pairs that must all match the payload.

    Returns:
        ``{"results": [...], "collection": str, "count": int}``. Each hit has
        ``id``, ``score`` and ``metadata`` (the point payload).

    Raises:
        HTTPException 400: if ``filter_metadata`` cannot be turned into a filter.
        HTTPException 500: if the Qdrant query fails (same mapping as the
            upsert endpoints).
    """
    col = _col(collection)
    try:
        qfilter = _build_filter(filter_metadata)
    except Exception as e:
        # Presumably value types MatchValue rejects (lists, floats, objects);
        # this is a client error, not a server crash.
        raise HTTPException(400, f"Invalid filter_metadata: {e}")
    try:
        results = client.query_points(
            collection_name=col,
            query=vector,
            limit=limit,
            score_threshold=score_threshold,
            query_filter=qfilter,
        )
    except Exception as e:
        raise HTTPException(500, str(e))
    hits = [
        {"id": point.id, "score": point.score, "metadata": point.payload}
        for point in results.points
    ]
    return {"results": hits, "collection": col, "count": len(hits)}
# ---------------------------------------------------------------------------
# Delete points
# ---------------------------------------------------------------------------
@app.post("/delete")
def delete_points(req: DeleteRequest):
    """Delete points by id from a collection.

    Ids arrive as strings. Ids made only of ASCII digits are converted to
    ``int``, because ``_point_id`` stores integer-looking ids as integer Qdrant
    ids, and Qdrant accepts string ids only in UUID form. Sending ``"42"``
    as a string would therefore fail instead of deleting point ``42``.

    Raises:
        HTTPException 500: if Qdrant rejects the delete (same mapping as the
            upsert endpoints).
    """
    col = _col(req.collection)
    point_ids = [int(i) if i.isascii() and i.isdigit() else i for i in req.ids]
    try:
        client.delete(
            collection_name=col,
            points_selector=point_ids,
        )
    except Exception as e:
        raise HTTPException(500, str(e))
    return {"deleted": req.ids, "collection": col}
# ---------------------------------------------------------------------------
# Get point by ID
# ---------------------------------------------------------------------------
@app.get("/points/{point_id}")
def get_point(point_id: str, collection: Optional[str] = None):
    """Fetch one point (vector + payload) by id.

    ``point_id`` arrives as a string. Ids made only of ASCII digits are
    converted to ``int``, because ``_point_id`` stores integer-looking ids as
    integer Qdrant ids, and Qdrant accepts string ids only in UUID form.
    Looking up ``"42"`` as a string would fail instead of finding point ``42``.

    Raises:
        HTTPException 404: if the point does not exist or the lookup fails.
    """
    col = _col(collection)
    lookup_id = int(point_id) if point_id.isascii() and point_id.isdigit() else point_id
    try:
        points = client.retrieve(collection_name=col, ids=[lookup_id], with_vectors=True)
    except Exception as e:
        raise HTTPException(404, str(e))
    if not points:
        raise HTTPException(404, f"Point '{point_id}' not found")
    p = points[0]
    return {
        "id": p.id,
        "vector": p.vector,
        "metadata": p.payload,
        "collection": col,
    }