from __future__ import annotations

import json
import os
import uuid
from typing import Any, Dict, List, Optional, Union

import httpx
from fastapi import FastAPI, HTTPException, UploadFile, File, Form
from pydantic import BaseModel, Field
from qdrant_client import QdrantClient
from qdrant_client.models import (
    Distance,
    PointStruct,
    VectorParams,
    Filter,
    FieldCondition,
    MatchValue,
    HnswConfigDiff,
    OptimizersConfigDiff,
    SearchParams,
    PayloadSchemaType,
    ScalarQuantization,
    ScalarQuantizationConfig,
    ScalarType,
)

# ---------------------------------------------------------------------------
# Config
# ---------------------------------------------------------------------------

QDRANT_HOST = os.getenv("QDRANT_HOST", "qdrant")
QDRANT_PORT = int(os.getenv("QDRANT_PORT", "6333"))
CLIP_URL = os.getenv("CLIP_URL", "http://clip:8000")
COLLECTION_NAME = os.getenv("COLLECTION_NAME", "images")
VECTOR_DIM = int(os.getenv("VECTOR_DIM", "512"))
# hnsw_ef at query time: higher = better recall, slightly more latency (Qdrant default ~100)
SEARCH_HNSW_EF = int(os.getenv("SEARCH_HNSW_EF", "128"))

app = FastAPI(title="Skinbase Qdrant Service", version="1.0.0")

client: QdrantClient = None  # type: ignore[assignment]


# ---------------------------------------------------------------------------
# Startup / shutdown
# ---------------------------------------------------------------------------

@app.on_event("startup")
def startup():
    global client
    client = QdrantClient(host=QDRANT_HOST, port=QDRANT_PORT)
    _ensure_collection()


def _ensure_collection():
    """Create the default collection with production-friendly defaults if it
    does not exist yet."""
    collections = [c.name for c in client.get_collections().collections]
    if COLLECTION_NAME not in collections:
        client.create_collection(
            collection_name=COLLECTION_NAME,
            vectors_config=VectorParams(size=VECTOR_DIM, distance=Distance.COSINE),
            hnsw_config=HnswConfigDiff(
                m=16,
                ef_construct=200,  # higher than default 100 = better index quality
                on_disk=False,  # keep HNSW graph in RAM for fast traversal
            ),
            optimizers_config=OptimizersConfigDiff(
                indexing_threshold=20000,  # in KB of vectors (not a vector count): build HNSW once a segment holds ~20 MB
                default_segment_number=4,  # parallelism-friendly segment count
            ),
        )


# ---------------------------------------------------------------------------
# Request / Response models
# ---------------------------------------------------------------------------

class UpsertUrlRequest(BaseModel):
    url: str
    id: Optional[str] = None
    metadata: Dict[str, Any] = Field(default_factory=dict)
    collection: Optional[str] = None


class UpsertVectorRequest(BaseModel):
    vector: List[float]
    id: Optional[str] = None
    metadata: Dict[str, Any] = Field(default_factory=dict)
    collection: Optional[str] = None


class SearchUrlRequest(BaseModel):
    url: str
    limit: int = Field(default=5, ge=1, le=100)
    score_threshold: Optional[float] = Field(default=None, ge=0.0, le=1.0)
    collection: Optional[str] = None
    filter_metadata: Dict[str, Any] = Field(default_factory=dict)
    hnsw_ef: Optional[int] = Field(
        default=None, ge=1, le=512,
        description="Override ef at query time. Higher = better recall, slightly higher latency.",
    )
    exact: bool = Field(default=False, description="Brute-force exact search. Avoid on large collections.")
    indexed_only: bool = Field(default=False, description="Search only fully indexed segments. Useful during bulk ingest.")
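
# A minimal /search request body exercising the fields above (URL and
# metadata values are illustrative, not part of the API):
#
#   {
#       "url": "https://example.com/item.jpg",
#       "limit": 10,
#       "score_threshold": 0.8,
#       "filter_metadata": {"category": "sneakers"},
#       "hnsw_ef": 256
#   }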

class SearchVectorRequest(BaseModel):
    vector: List[float]
    limit: int = Field(default=5, ge=1, le=100)
    score_threshold: Optional[float] = Field(default=None, ge=0.0, le=1.0)
    collection: Optional[str] = None
    filter_metadata: Dict[str, Any] = Field(default_factory=dict)
    hnsw_ef: Optional[int] = Field(default=None, ge=1, le=512)
    exact: bool = False
    indexed_only: bool = False


class DeleteRequest(BaseModel):
    ids: List[str]
    collection: Optional[str] = None


class CollectionRequest(BaseModel):
    name: str
    vector_dim: int = Field(default=512, ge=1)
    distance: str = Field(default="cosine")


# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------

def _col(name: Optional[str]) -> str:
    return name or COLLECTION_NAME


async def _embed_url(url: str) -> List[float]:
    """Call the CLIP service to get an image embedding."""
    async with httpx.AsyncClient(timeout=30) as http:
        try:
            r = await http.post(f"{CLIP_URL}/embed", json={"url": url})
        except httpx.RequestError as e:
            raise HTTPException(502, f"CLIP request failed: {e}")
        if r.status_code >= 400:
            raise HTTPException(502, f"CLIP /embed error: {r.status_code} {r.text[:200]}")
        try:
            return r.json()["vector"]
        except Exception:
            raise HTTPException(502, f"CLIP /embed returned an unexpected response: {r.status_code} {r.text[:200]}")


async def _embed_bytes(data: bytes) -> List[float]:
    """Call the CLIP service to embed uploaded file bytes."""
    async with httpx.AsyncClient(timeout=30) as http:
        files = {"file": ("image", data, "application/octet-stream")}
        try:
            r = await http.post(f"{CLIP_URL}/embed/file", files=files)
        except httpx.RequestError as e:
            raise HTTPException(502, f"CLIP request failed: {e}")
        if r.status_code >= 400:
            raise HTTPException(502, f"CLIP /embed/file error: {r.status_code} {r.text[:200]}")
        try:
            return r.json()["vector"]
        except Exception:
            raise HTTPException(502, f"CLIP /embed/file returned an unexpected response: {r.status_code} {r.text[:200]}")


def _build_filter(metadata: Dict[str, Any]) -> Optional[Filter]:
    if not metadata:
        return None
    conditions = [
        FieldCondition(key=k, match=MatchValue(value=v))
        for k, v in metadata.items()
    ]
    return Filter(must=conditions)


def _id_filter(original_id: str) -> Filter:
    return Filter(must=[FieldCondition(key="_original_id", match=MatchValue(value=original_id))])


def _point_id(raw: Optional[str]) -> Union[int, str]:
    """Return a Qdrant-compatible point id.

    Qdrant accepts either an unsigned integer or a UUID string (with hyphens).
    If the provided `raw` value is an int or valid UUID we return it (as int
    or str). Otherwise we generate a new UUID string and the caller should
    store the original `raw` value in the point payload under `_original_id`.
    """
    if not raw:
        return str(uuid.uuid4())
    # allow integer ids
    try:
        return int(raw)
    except Exception:
        pass
    # allow UUID strings
    try:
        return str(uuid.UUID(raw))
    except Exception:
        # fallback: generate a UUID
        return str(uuid.uuid4())
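
# How the id normalization plays out (values are illustrative):
#   _point_id("42")       -> 42                # integer ids pass through
#   _point_id("550e8400-e29b-41d4-a716-446655440000")
#                         -> same UUID string  # UUIDs pass through
#   _point_id("sku-123")  -> fresh uuid4 str   # caller stores "sku-123" under _original_id
#   _point_id(None)       -> fresh uuid4 str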

# ---------------------------------------------------------------------------
# Health
# ---------------------------------------------------------------------------

@app.get("/health")
def health():
    try:
        info = client.get_collections()
        names = [c.name for c in info.collections]
        return {"status": "ok", "qdrant": QDRANT_HOST, "collections": names}
    except Exception as e:
        return {"status": "error", "detail": str(e)}


@app.get("/inspect")
def inspect():
    """Return a full diagnostic summary for every collection.

    Covers: vector counts, segment counts, HNSW config, optimizer config,
    quantization, payload indexes and their coverage. Designed for production
    health checks and the Qdrant optimization workflow.
    """
    try:
        all_collections = client.get_collections().collections
    except Exception as exc:
        return {"status": "error", "detail": str(exc)}
    result = {}
    for col_desc in all_collections:
        name = col_desc.name
        try:
            info = client.get_collection(name)
            cfg = info.config
            hnsw = cfg.hnsw_config
            opt = cfg.optimizer_config
            quant = cfg.quantization_config
            params = cfg.params
            # Estimate raw RAM footprint: vectors * dim * 4 bytes * 1.5 safety factor
            vec_count = info.vectors_count or 0
            vec_dim = (
                params.vectors.size
                if hasattr(params.vectors, "size")
                else VECTOR_DIM
            )
            ram_estimate_mb = round(vec_count * vec_dim * 4 * 1.5 / 1_048_576, 1)
            result[name] = {
                "status": info.status.value if info.status else None,
                "optimizer_status": str(info.optimizer_status) if info.optimizer_status else None,
                "vectors_count": vec_count,
                "indexed_vectors_count": info.indexed_vectors_count,
                "points_count": info.points_count,
                "segments_count": info.segments_count,
                "ram_estimate_mb": ram_estimate_mb,
                "hnsw": {
                    "m": hnsw.m,
                    "ef_construct": hnsw.ef_construct,
                    "on_disk": hnsw.on_disk,
                    "full_scan_threshold": hnsw.full_scan_threshold,
                    "max_indexing_threads": hnsw.max_indexing_threads,
                } if hnsw else None,
                "optimizer": {
                    "indexing_threshold": opt.indexing_threshold,
                    "default_segment_number": opt.default_segment_number,
                    "max_segment_size": opt.max_segment_size,
                    "memmap_threshold": opt.memmap_threshold,
                    "flush_interval_sec": opt.flush_interval_sec,
                } if opt else None,
                "quantization": str(quant) if quant else None,
                "payload_indexes": {
                    k: {
                        "type": v.data_type.value if hasattr(v.data_type, "value") else str(v.data_type),
                        "points": v.points,
                        "coverage_pct": round((v.points or 0) / max(vec_count, 1) * 100, 1),
                    }
                    for k, v in (info.payload_schema or {}).items()
                },
                "payload_index_count": len(info.payload_schema or {}),
                "search_hnsw_ef": SEARCH_HNSW_EF,
            }
        except Exception as exc:
            result[name] = {"error": str(exc)}
    return {"collections": result, "total": len(result)}
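
# Shape of a single /inspect entry (keys mirror the dict built above; the
# values here are illustrative):
#
#   "images": {
#       "status": "green",
#       "vectors_count": 120000,
#       "segments_count": 4,
#       "ram_estimate_mb": 351.6,
#       "hnsw": {"m": 16, "ef_construct": 200, "on_disk": False, ...},
#       "optimizer": {"indexing_threshold": 20000, ...},
#       "payload_indexes": {"category": {"type": "keyword", "points": 120000, "coverage_pct": 100.0}},
#       "search_hnsw_ef": 128
#   }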

# ---------------------------------------------------------------------------
# Collection management
# ---------------------------------------------------------------------------

@app.post("/collections")
def create_collection(req: CollectionRequest):
    dist_map = {"cosine": Distance.COSINE, "euclid": Distance.EUCLID, "dot": Distance.DOT}
    dist = dist_map.get(req.distance.lower())
    if dist is None:
        raise HTTPException(400, f"Unknown distance: {req.distance}. Use cosine, euclid, or dot.")
    collections = [c.name for c in client.get_collections().collections]
    if req.name in collections:
        raise HTTPException(409, f"Collection '{req.name}' already exists")
    # Apply the same production defaults as _ensure_collection so all
    # collections start with tuned HNSW and optimizer settings.
    client.create_collection(
        collection_name=req.name,
        vectors_config=VectorParams(size=req.vector_dim, distance=dist),
        hnsw_config=HnswConfigDiff(m=16, ef_construct=200, on_disk=False),
        optimizers_config=OptimizersConfigDiff(indexing_threshold=20000, default_segment_number=4),
    )
    return {"created": req.name, "vector_dim": req.vector_dim, "distance": req.distance}


@app.get("/collections")
def list_collections():
    info = client.get_collections()
    return {"collections": [c.name for c in info.collections]}


@app.get("/collections/{name}")
def collection_info(name: str):
    try:
        info = client.get_collection(name)
        cfg = info.config
        hnsw = cfg.hnsw_config
        opt = cfg.optimizer_config
        quant = cfg.quantization_config
        return {
            "name": name,
            "vectors_count": info.vectors_count,
            "indexed_vectors_count": info.indexed_vectors_count,
            "points_count": info.points_count,
            "segments_count": info.segments_count,
            "status": info.status.value if info.status else None,
            "optimizer_status": str(info.optimizer_status) if info.optimizer_status else None,
            "hnsw": {
                "m": hnsw.m,
                "ef_construct": hnsw.ef_construct,
                "on_disk": hnsw.on_disk,
                "full_scan_threshold": hnsw.full_scan_threshold,
                "max_indexing_threads": hnsw.max_indexing_threads,
            } if hnsw else None,
            "optimizer": {
                "indexing_threshold": opt.indexing_threshold,
                "default_segment_number": opt.default_segment_number,
                "max_segment_size": opt.max_segment_size,
                "memmap_threshold": opt.memmap_threshold,
                "flush_interval_sec": opt.flush_interval_sec,
            } if opt else None,
            "quantization": str(quant) if quant else None,
            "payload_schema": {
                k: {
                    "type": v.data_type.value if hasattr(v.data_type, "value") else str(v.data_type),
                    "points": v.points,
                }
                for k, v in (info.payload_schema or {}).items()
            },
        }
    except Exception as e:
        raise HTTPException(404, str(e))


@app.delete("/collections/{name}")
def delete_collection(name: str):
    client.delete_collection(name)
    return {"deleted": name}


# ---------------------------------------------------------------------------
# Upsert endpoints
# ---------------------------------------------------------------------------

@app.post("/upsert")
async def upsert_url(req: UpsertUrlRequest):
    """Embed an image by URL via CLIP, then store the vector in Qdrant."""
    vector = await _embed_url(req.url)
    pid = _point_id(req.id)
    payload = {**req.metadata, "source_url": req.url}
    # preserve original user-provided id if it wasn't usable as a point id
    if req.id is not None and str(pid) != str(req.id):
        payload["_original_id"] = req.id
    col = _col(req.collection)
    try:
        client.upsert(
            collection_name=col,
            points=[PointStruct(id=pid, vector=vector, payload=payload)],
        )
    except Exception as e:
        raise HTTPException(500, str(e))
    return {"id": pid, "collection": col, "dim": len(vector)}


@app.post("/upsert/file")
async def upsert_file(
    file: UploadFile = File(...),
    id: Optional[str] = Form(None),
    collection: Optional[str] = Form(None),
    metadata_json: Optional[str] = Form(None),
):
    """Embed an uploaded image via CLIP, then store the vector in Qdrant."""
    data = await file.read()
    vector = await _embed_bytes(data)
    pid = _point_id(id)
    payload: Dict[str, Any] = {}
    if metadata_json:
        try:
            payload = json.loads(metadata_json)
        except json.JSONDecodeError:
            raise HTTPException(400, "metadata_json must be valid JSON")
    # preserve original user-provided id if it wasn't usable as a point id
    if id is not None and str(pid) != str(id):
        payload["_original_id"] = id
    col = _col(collection)
    try:
        client.upsert(
            collection_name=col,
            points=[PointStruct(id=pid, vector=vector, payload=payload)],
        )
    except Exception as e:
        raise HTTPException(500, str(e))
    return {"id": pid, "collection": col, "dim": len(vector)}
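
# Uploading a file with curl (a sketch: the path, id, metadata, and port
# :8000 are illustrative assumptions):
#
#   curl -X POST http://localhost:8000/upsert/file \
#     -F "file=@./shoe.jpg" \
#     -F "id=sku-123" \
#     -F 'metadata_json={"category": "sneakers"}'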

@app.post("/upsert/vector")
def upsert_vector(req: UpsertVectorRequest):
    """Store a pre-computed vector directly (skip CLIP embedding)."""
    pid = _point_id(req.id)
    col = _col(req.collection)
    payload = dict(req.metadata or {})
    if req.id is not None and str(pid) != str(req.id):
        payload["_original_id"] = req.id
    try:
        client.upsert(
            collection_name=col,
            points=[PointStruct(id=pid, vector=req.vector, payload=payload)],
        )
    except Exception as e:
        raise HTTPException(500, str(e))
    return {"id": pid, "collection": col, "dim": len(req.vector)}


# ---------------------------------------------------------------------------
# Search endpoints
# ---------------------------------------------------------------------------

@app.post("/search")
async def search_url(req: SearchUrlRequest):
    """Embed an image by URL via CLIP, then search Qdrant for similar vectors."""
    vector = await _embed_url(req.url)
    return _do_search(
        vector, req.limit, req.score_threshold, req.collection,
        req.filter_metadata, req.hnsw_ef, req.exact, req.indexed_only,
    )


@app.post("/search/file")
async def search_file(
    file: UploadFile = File(...),
    limit: int = Form(5),
    score_threshold: Optional[float] = Form(None),
    collection: Optional[str] = Form(None),
    hnsw_ef: Optional[int] = Form(None),
    exact: bool = Form(False),
    indexed_only: bool = Form(False),
    filter_metadata_json: Optional[str] = Form(None),
):
    """Embed an uploaded image via CLIP, then search Qdrant for similar vectors."""
    filter_metadata: Dict[str, Any] = {}
    if filter_metadata_json:
        try:
            filter_metadata = json.loads(filter_metadata_json)
        except json.JSONDecodeError:
            raise HTTPException(400, "filter_metadata_json must be valid JSON")
    data = await file.read()
    vector = await _embed_bytes(data)
    return _do_search(
        vector, int(limit), score_threshold, collection,
        filter_metadata, hnsw_ef, exact, indexed_only,
    )


@app.post("/search/vector")
def search_vector(req: SearchVectorRequest):
    """Search Qdrant using a pre-computed vector."""
    return _do_search(
        req.vector, req.limit, req.score_threshold, req.collection,
        req.filter_metadata, req.hnsw_ef, req.exact, req.indexed_only,
    )


def _do_search(
    vector: List[float],
    limit: int,
    score_threshold: Optional[float],
    collection: Optional[str],
    filter_metadata: Dict[str, Any],
    hnsw_ef: Optional[int] = None,
    exact: bool = False,
    indexed_only: bool = False,
):
    col = _col(collection)
    qfilter = _build_filter(filter_metadata)
    ef = hnsw_ef if hnsw_ef is not None else SEARCH_HNSW_EF
    results = client.query_points(
        collection_name=col,
        query=vector,
        limit=limit,
        score_threshold=score_threshold,
        query_filter=qfilter,
        search_params=SearchParams(hnsw_ef=ef, exact=exact, indexed_only=indexed_only),
    )
    hits = [
        {"id": point.id, "score": point.score, "metadata": point.payload}
        for point in results.points
    ]
    return {"results": hits, "collection": col, "count": len(hits)}


# ---------------------------------------------------------------------------
# Delete points
# ---------------------------------------------------------------------------

@app.post("/delete")
def delete_points(req: DeleteRequest):
    col = _col(req.collection)
    client.delete(
        collection_name=col,
        points_selector=req.ids,
    )
    return {"deleted": req.ids, "collection": col}
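
# Calling the vector-search endpoint from Python (a sketch; the host/port
# and the 512-dim zero vector are placeholders):
#
#   import httpx
#   resp = httpx.post(
#       "http://localhost:8000/search/vector",
#       json={"vector": [0.0] * 512, "limit": 5, "filter_metadata": {"category": "sneakers"}},
#   )
#   print(resp.json()["results"])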

# ---------------------------------------------------------------------------
# Get point by ID
# ---------------------------------------------------------------------------

@app.get("/points/{point_id}")
def get_point(point_id: str, collection: Optional[str] = None):
    col = _col(collection)
    # path params arrive as strings; convert numeric ids so integer points resolve
    pid: Union[int, str] = int(point_id) if point_id.isdigit() else point_id
    try:
        points = client.retrieve(collection_name=col, ids=[pid], with_vectors=True)
        if not points:
            raise HTTPException(404, f"Point '{point_id}' not found")
        p = points[0]
        return {
            "id": p.id,
            "vector": p.vector,
            "metadata": p.payload,
            "collection": col,
        }
    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(404, str(e))


@app.get("/points/by-original-id/{original_id}")
def get_point_by_original_id(original_id: str, collection: Optional[str] = None):
    col = _col(collection)
    try:
        points, _ = client.scroll(
            collection_name=col,
            scroll_filter=_id_filter(original_id),
            limit=1,
            with_vectors=True,
            with_payload=True,
        )
        if not points:
            raise HTTPException(404, f"Point with _original_id '{original_id}' not found")
        point = points[0]
        return {
            "id": point.id,
            "vector": point.vector,
            "metadata": point.payload,
            "collection": col,
        }
    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(404, str(e))


# ---------------------------------------------------------------------------
# Payload index management
# ---------------------------------------------------------------------------

_SCHEMA_TYPE_MAP: Dict[str, PayloadSchemaType] = {t.value: t for t in PayloadSchemaType}


def _resolve_schema_type(type_str: str) -> PayloadSchemaType:
    schema = _SCHEMA_TYPE_MAP.get(type_str.lower())
    if schema is None:
        raise HTTPException(400, f"Unknown index type '{type_str}'. Valid: {', '.join(_SCHEMA_TYPE_MAP)}")
    return schema


class PayloadIndexRequest(BaseModel):
    field: str
    type: str = Field(
        default="keyword",
        description="keyword | integer | float | bool | geo | datetime | text | uuid",
    )
    collection: Optional[str] = None


class EnsureIndexesRequest(BaseModel):
    """List of field specs, each with 'field' and optional 'type' keys."""
    fields: List[Dict[str, str]]
    collection: Optional[str] = None


@app.get("/collections/{name}/indexes")
def collection_indexes(name: str):
    """List all payload indexes for a collection."""
    try:
        info = client.get_collection(name)
        schema = info.payload_schema or {}
        return {
            "collection": name,
            "indexes": {
                k: {
                    "type": v.data_type.value if hasattr(v.data_type, "value") else str(v.data_type),
                    "points": v.points,
                }
                for k, v in schema.items()
            },
            "count": len(schema),
        }
    except Exception as e:
        raise HTTPException(404, str(e))


@app.post("/collections/{name}/indexes")
def create_index(name: str, req: PayloadIndexRequest):
    """Create a payload index on a single field."""
    col = req.collection or name
    schema = _resolve_schema_type(req.type)
    try:
        client.create_payload_index(
            collection_name=col,
            field_name=req.field,
            field_schema=schema,
        )
        return {"collection": col, "field": req.field, "type": req.type, "status": "created"}
    except Exception as e:
        raise HTTPException(500, str(e))
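
# Creating a single keyword index with curl (collection name, field, and
# port :8000 are placeholders):
#
#   curl -X POST http://localhost:8000/collections/images/indexes \
#     -H "Content-Type: application/json" \
#     -d '{"field": "category", "type": "keyword"}'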

@app.post("/collections/{name}/ensure-indexes")
def ensure_indexes(name: str, req: EnsureIndexesRequest):
    """Idempotently ensure payload indexes exist for a list of fields.

    Skips fields that are already indexed; only creates the missing ones.
    Example body:
        {"fields": [{"field": "is_public", "type": "bool"},
                    {"field": "category_id", "type": "integer"}]}
    """
    col = req.collection or name
    try:
        info = client.get_collection(col)
    except Exception as e:
        raise HTTPException(404, str(e))
    existing = set(info.payload_schema.keys()) if info.payload_schema else set()
    created: List[str] = []
    skipped: List[str] = []
    for field_spec in req.fields:
        field = field_spec.get("field")
        type_str = field_spec.get("type", "keyword")
        if not field:
            raise HTTPException(400, "Each field spec must include a 'field' key")
        if field in existing:
            skipped.append(field)
            continue
        schema = _resolve_schema_type(type_str)
        try:
            client.create_payload_index(
                collection_name=col,
                field_name=field,
                field_schema=schema,
            )
            created.append(field)
        except Exception as exc:
            raise HTTPException(500, f"Failed to index '{field}': {exc}")
    return {"collection": col, "created": created, "skipped": skipped}


# ---------------------------------------------------------------------------
# Collection HNSW + optimizer configuration
# ---------------------------------------------------------------------------

class CollectionConfigRequest(BaseModel):
    hnsw_m: Optional[int] = Field(
        default=None, ge=4, le=64,
        description="Edges per node in the HNSW graph.",
    )
    hnsw_ef_construct: Optional[int] = Field(
        default=None, ge=10, le=1000,
        description="ef during index construction. Changes apply to new segments only.",
    )
    hnsw_on_disk: Optional[bool] = Field(
        default=None,
        description="Store HNSW graph on disk (saves RAM, slightly slower queries).",
    )
    indexing_threshold: Optional[int] = Field(
        default=None, ge=0,
        description="Minimal segment size (in KB of vectors) before HNSW indexing kicks in.",
    )
    default_segment_number: Optional[int] = Field(
        default=None, ge=1, le=32,
        description="Target number of segments for parallelism.",
    )
    # Scalar quantization: reduces RAM ~4x, often speeds up search on large collections.
    # Set quantization_type='int8' to enable. Use always_ram=True to keep quantized
    # vectors in RAM (recommended on a VPS with limited memory but fast disk).
    quantization_type: Optional[str] = Field(
        default=None,
        description="Enable scalar quantization: 'int8'. Set to null to keep the current setting.",
    )
    quantization_quantile: float = Field(
        default=0.99, ge=0.5, le=1.0,
        description="Fraction of vectors used to calibrate the quantization range (0.99 recommended).",
    )
    quantization_always_ram: bool = Field(
        default=True,
        description="Keep quantized vectors in RAM even when raw vectors are on disk.",
    )
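
# Enabling int8 quantization on an existing collection (body values are
# illustrative; assumes the service listens on :8000):
#
#   curl -X POST http://localhost:8000/collections/images/configure \
#     -H "Content-Type: application/json" \
#     -d '{"quantization_type": "int8", "quantization_quantile": 0.99, "quantization_always_ram": true}'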

@app.post("/collections/{name}/configure")
def configure_collection(name: str, req: CollectionConfigRequest):
    """Apply HNSW and optimizer configuration updates to an existing collection.

    Changes are applied in-place without data loss or re-ingestion.
    Note: hnsw_m and hnsw_ef_construct only affect newly created segments.
    """
    hnsw_kwargs = {k: v for k, v in {
        "m": req.hnsw_m,
        "ef_construct": req.hnsw_ef_construct,
        "on_disk": req.hnsw_on_disk,
    }.items() if v is not None}
    opt_kwargs = {k: v for k, v in {
        "indexing_threshold": req.indexing_threshold,
        "default_segment_number": req.default_segment_number,
    }.items() if v is not None}
    # Build optional scalar quantization config. update_collection expects the
    # ScalarQuantization wrapper model, not a bare ScalarQuantizationConfig.
    quant_config = None
    if req.quantization_type is not None:
        if req.quantization_type.lower() != "int8":
            raise HTTPException(400, f"Unsupported quantization_type '{req.quantization_type}'. Only 'int8' is supported.")
        quant_config = ScalarQuantization(
            scalar=ScalarQuantizationConfig(
                type=ScalarType.INT8,
                quantile=req.quantization_quantile,
                always_ram=req.quantization_always_ram,
            )
        )
    if not hnsw_kwargs and not opt_kwargs and quant_config is None:
        raise HTTPException(400, "No configuration fields provided")
    try:
        client.update_collection(
            collection_name=name,
            hnsw_config=HnswConfigDiff(**hnsw_kwargs) if hnsw_kwargs else None,
            optimizers_config=OptimizersConfigDiff(**opt_kwargs) if opt_kwargs else None,
            quantization_config=quant_config,
        )
        return {
            "collection": name,
            "status": "updated",
            "hnsw_changes": hnsw_kwargs,
            "optimizer_changes": opt_kwargs,
            "quantization": {
                "type": req.quantization_type,
                "quantile": req.quantization_quantile,
                "always_ram": req.quantization_always_ram,
            } if quant_config else None,
        }
    except Exception as exc:
        raise HTTPException(500, str(exc))
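

# A minimal local-run entry point (a sketch: uvicorn and port 8000 are
# assumptions, not part of the original deployment config).
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)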