# NOTE(review): removed non-Python scrape artifacts ("760 lines", "28 KiB",
# "Python") that preceded the module header; they were not valid source.
from __future__ import annotations
|
|
|
|
import os
|
|
import uuid
|
|
from typing import Any, Dict, List, Optional
|
|
|
|
import httpx
|
|
import numpy as np
|
|
from fastapi import FastAPI, HTTPException, UploadFile, File, Form
|
|
from pydantic import BaseModel, Field
|
|
from qdrant_client import QdrantClient
|
|
from qdrant_client.models import (
|
|
Distance,
|
|
PointStruct,
|
|
VectorParams,
|
|
Filter,
|
|
FieldCondition,
|
|
MatchValue,
|
|
HnswConfigDiff,
|
|
OptimizersConfigDiff,
|
|
SearchParams,
|
|
PayloadSchemaType,
|
|
ScalarQuantizationConfig,
|
|
ScalarType,
|
|
)
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Config
|
|
# ---------------------------------------------------------------------------
|
|
|
|
# All settings are overridable via environment variables (container-friendly).
QDRANT_HOST = os.getenv("QDRANT_HOST", "qdrant")
QDRANT_PORT = int(os.getenv("QDRANT_PORT", "6333"))
CLIP_URL = os.getenv("CLIP_URL", "http://clip:8000")  # embedding sidecar service
COLLECTION_NAME = os.getenv("COLLECTION_NAME", "images")  # default collection used when requests omit one
VECTOR_DIM = int(os.getenv("VECTOR_DIM", "512"))  # must match the CLIP model's output dim -- TODO confirm
# hnsw_ef at query time: higher = better recall, slightly more latency (Qdrant default ~100)
SEARCH_HNSW_EF = int(os.getenv("SEARCH_HNSW_EF", "128"))

app = FastAPI(title="Skinbase Qdrant Service", version="1.0.0")

# Module-level Qdrant client; assigned in the startup hook below, so it is
# None only before the app has started.
client: QdrantClient = None  # type: ignore[assignment]
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Startup / shutdown
|
|
# ---------------------------------------------------------------------------
|
|
|
|
@app.on_event("startup")
def startup():
    """Connect to Qdrant and ensure the default collection exists.

    Runs once when the FastAPI app starts; populates the module-level
    `client` used by every endpoint.
    """
    global client
    client = QdrantClient(host=QDRANT_HOST, port=QDRANT_PORT)
    _ensure_collection()
|
|
|
|
|
|
def _ensure_collection():
    """Create the default collection with production-friendly defaults if it does not exist yet."""
    existing = {c.name for c in client.get_collections().collections}
    if COLLECTION_NAME in existing:
        return
    client.create_collection(
        collection_name=COLLECTION_NAME,
        vectors_config=VectorParams(size=VECTOR_DIM, distance=Distance.COSINE),
        hnsw_config=HnswConfigDiff(
            m=16,
            ef_construct=200,  # higher than default 100 = better index quality
            on_disk=False,  # keep HNSW graph in RAM for fast traversal
        ),
        optimizers_config=OptimizersConfigDiff(
            indexing_threshold=20000,  # start indexing after 20k accumulated vectors
            default_segment_number=4,  # parallelism-friendly segment count
        ),
    )
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Request / Response models
|
|
# ---------------------------------------------------------------------------
|
|
|
|
class UpsertUrlRequest(BaseModel):
    """Body for POST /upsert: embed an image fetched by URL and store it."""

    url: str  # image URL forwarded to the CLIP service
    id: Optional[str] = None  # optional caller id; normalized by _point_id
    metadata: Dict[str, Any] = Field(default_factory=dict)  # arbitrary payload stored with the point
    collection: Optional[str] = None  # falls back to COLLECTION_NAME when omitted
|
|
|
|
|
|
class UpsertVectorRequest(BaseModel):
    """Body for POST /upsert/vector: store a pre-computed embedding directly."""

    vector: List[float]  # pre-computed embedding; should match the collection's dim
    id: Optional[str] = None  # optional caller id; normalized by _point_id
    metadata: Dict[str, Any] = Field(default_factory=dict)  # arbitrary payload stored with the point
    collection: Optional[str] = None  # falls back to COLLECTION_NAME when omitted
|
|
|
|
|
|
class SearchUrlRequest(BaseModel):
    """Body for POST /search: embed an image by URL and find similar vectors."""

    url: str  # image URL forwarded to the CLIP service
    limit: int = Field(default=5, ge=1, le=100)  # max hits returned
    score_threshold: Optional[float] = Field(default=None, ge=0.0, le=1.0)  # drop hits below this score
    collection: Optional[str] = None  # falls back to COLLECTION_NAME when omitted
    filter_metadata: Dict[str, Any] = Field(default_factory=dict)  # exact-match payload constraints (ANDed)
    hnsw_ef: Optional[int] = Field(default=None, ge=1, le=512, description="Override ef at query time. Higher = better recall, slightly higher latency.")
    exact: bool = Field(default=False, description="Brute-force exact search. Avoid on large collections.")
    indexed_only: bool = Field(default=False, description="Search only fully indexed segments. Useful during bulk ingest.")
|
|
|
|
|
|
class SearchVectorRequest(BaseModel):
    """Body for POST /search/vector: search with a pre-computed embedding."""

    vector: List[float]  # query embedding; should match the collection's dim
    limit: int = Field(default=5, ge=1, le=100)  # max hits returned
    score_threshold: Optional[float] = Field(default=None, ge=0.0, le=1.0)  # drop hits below this score
    collection: Optional[str] = None  # falls back to COLLECTION_NAME when omitted
    filter_metadata: Dict[str, Any] = Field(default_factory=dict)  # exact-match payload constraints (ANDed)
    hnsw_ef: Optional[int] = Field(default=None, ge=1, le=512)  # per-query ef override
    exact: bool = False  # brute-force exact search
    indexed_only: bool = False  # search only fully indexed segments
|
|
|
|
|
|
class DeleteRequest(BaseModel):
    """Body for POST /delete: remove points by their Qdrant ids."""

    ids: List[str]  # Qdrant point ids (not _original_id values)
    collection: Optional[str] = None  # falls back to COLLECTION_NAME when omitted
|
|
|
|
|
|
class CollectionRequest(BaseModel):
    """Body for POST /collections: create a new collection."""

    name: str  # collection name (must not already exist)
    vector_dim: int = Field(default=512, ge=1)  # embedding dimension
    distance: str = Field(default="cosine")  # one of: cosine, euclid, dot
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Helpers
|
|
# ---------------------------------------------------------------------------
|
|
|
|
def _col(name: Optional[str]) -> str:
    """Resolve a collection name, falling back to the configured default."""
    if name:
        return name
    return COLLECTION_NAME
|
|
|
|
|
|
async def _embed_url(url: str) -> List[float]:
    """Call the CLIP service to get an image embedding."""
    async with httpx.AsyncClient(timeout=30) as http:
        # Network-level failures (DNS, refused, timeout) become a 502.
        try:
            resp = await http.post(f"{CLIP_URL}/embed", json={"url": url})
        except httpx.RequestError as exc:
            raise HTTPException(502, f"CLIP request failed: {str(exc)}")
        # CLIP-side errors are surfaced with a truncated body for debugging.
        if resp.status_code >= 400:
            raise HTTPException(502, f"CLIP /embed error: {resp.status_code} {resp.text[:200]}")
        try:
            body = resp.json()
            return body["vector"]
        except Exception:
            raise HTTPException(502, f"CLIP /embed returned non-JSON: {resp.status_code} {resp.text[:200]}")
|
|
|
|
|
|
async def _embed_bytes(data: bytes) -> List[float]:
    """Call the CLIP service to embed uploaded file bytes."""
    upload = {"file": ("image", data, "application/octet-stream")}
    async with httpx.AsyncClient(timeout=30) as http:
        # Network-level failures (DNS, refused, timeout) become a 502.
        try:
            resp = await http.post(f"{CLIP_URL}/embed/file", files=upload)
        except httpx.RequestError as exc:
            raise HTTPException(502, f"CLIP request failed: {str(exc)}")
        # CLIP-side errors are surfaced with a truncated body for debugging.
        if resp.status_code >= 400:
            raise HTTPException(502, f"CLIP /embed/file error: {resp.status_code} {resp.text[:200]}")
        try:
            body = resp.json()
            return body["vector"]
        except Exception:
            raise HTTPException(502, f"CLIP /embed/file returned non-JSON: {resp.status_code} {resp.text[:200]}")
|
|
|
|
|
|
def _build_filter(metadata: Dict[str, Any]) -> Optional[Filter]:
    """Translate a flat metadata dict into a Qdrant must-match filter.

    Returns None when there is nothing to filter on, so callers can pass
    the result straight to `query_filter=`.
    """
    if not metadata:
        return None
    must = []
    for key, value in metadata.items():
        must.append(FieldCondition(key=key, match=MatchValue(value=value)))
    return Filter(must=must)
|
|
|
|
|
|
def _id_filter(original_id: str) -> Filter:
    """Filter matching points whose payload stored the caller's original id."""
    condition = FieldCondition(key="_original_id", match=MatchValue(value=original_id))
    return Filter(must=[condition])
|
|
|
|
|
|
def _point_id(raw: Optional[str]) -> str:
|
|
"""Return a Qdrant-compatible point id.
|
|
|
|
Qdrant accepts either an unsigned integer or a UUID string (with hyphens).
|
|
If the provided `raw` value is an int or valid UUID we return it (as int or str).
|
|
Otherwise we generate a new UUID string and the caller should store the
|
|
original `raw` value in the point payload under `_original_id`.
|
|
"""
|
|
if not raw:
|
|
return str(uuid.uuid4())
|
|
# allow integer ids
|
|
try:
|
|
return int(raw)
|
|
except Exception:
|
|
pass
|
|
# allow UUID strings
|
|
try:
|
|
u = uuid.UUID(raw)
|
|
return str(u)
|
|
except Exception:
|
|
# fallback: generate a UUID
|
|
return str(uuid.uuid4())
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Health
|
|
# ---------------------------------------------------------------------------
|
|
|
|
@app.get("/health")
def health():
    """Liveness probe: report Qdrant reachability and the known collections."""
    try:
        names = [c.name for c in client.get_collections().collections]
    except Exception as exc:  # boundary endpoint: report, never raise
        return {"status": "error", "detail": str(exc)}
    return {"status": "ok", "qdrant": QDRANT_HOST, "collections": names}
|
|
|
|
|
|
@app.get("/inspect")
def inspect():
    """Return a full diagnostic summary for every collection.

    Covers: vector counts, segment counts, HNSW config, optimizer config,
    quantization, payload indexes and their coverage. Designed for production
    health checks and the Qdrant optimization workflow.
    """
    try:
        all_collections = client.get_collections().collections
    except Exception as exc:
        # Mirror /health: report connectivity failures instead of raising.
        return {"status": "error", "detail": str(exc)}

    result = {}
    for col_desc in all_collections:
        name = col_desc.name
        try:
            info = client.get_collection(name)
            cfg = info.config
            hnsw = cfg.hnsw_config
            opt = cfg.optimizer_config
            quant = cfg.quantization_config
            params = cfg.params

            # `vectors_count` is deprecated and returns 0 in newer Qdrant versions;
            # use `points_count` as the canonical count for coverage and RAM estimates.
            points_count = info.points_count or 0
            vec_count = info.vectors_count or points_count  # kept for backwards compat display
            # Single unnamed vector configs expose `.size`; fall back to the
            # configured default otherwise (e.g. named-vector collections).
            vec_dim = (
                params.vectors.size
                if hasattr(params.vectors, "size")
                else VECTOR_DIM
            )
            # float32 vectors (4 bytes/dim) plus ~50% graph overhead, in MiB.
            ram_estimate_mb = round(points_count * vec_dim * 4 * 1.5 / 1_048_576, 1)

            result[name] = {
                "status": info.status.value if info.status else None,
                "optimizer_status": str(info.optimizer_status) if info.optimizer_status else None,
                "vectors_count": vec_count,
                "indexed_vectors_count": info.indexed_vectors_count,
                "points_count": info.points_count,
                "segments_count": info.segments_count,
                "ram_estimate_mb": ram_estimate_mb,
                "hnsw": {
                    "m": hnsw.m,
                    "ef_construct": hnsw.ef_construct,
                    "on_disk": hnsw.on_disk,
                    "full_scan_threshold": hnsw.full_scan_threshold,
                    "max_indexing_threads": hnsw.max_indexing_threads,
                } if hnsw else None,
                "optimizer": {
                    "indexing_threshold": opt.indexing_threshold,
                    "default_segment_number": opt.default_segment_number,
                    "max_segment_size": opt.max_segment_size,
                    "memmap_threshold": opt.memmap_threshold,
                    "flush_interval_sec": opt.flush_interval_sec,
                } if opt else None,
                "quantization": str(quant) if quant else None,
                "payload_indexes": {
                    k: {
                        "type": v.data_type.value if hasattr(v.data_type, "value") else str(v.data_type),
                        "points": v.points,
                        # share of points covered by this index (guard div-by-zero)
                        "coverage_pct": round(v.points / max(points_count, 1) * 100, 1),
                    }
                    for k, v in (info.payload_schema or {}).items()
                },
                "payload_index_count": len(info.payload_schema or {}),
                "search_hnsw_ef": SEARCH_HNSW_EF,
            }
        except Exception as exc:
            # Per-collection failures are reported inline so one bad
            # collection does not hide the rest of the summary.
            result[name] = {"error": str(exc)}

    return {"collections": result, "total": len(result)}
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Collection management
|
|
# ---------------------------------------------------------------------------
|
|
|
|
@app.post("/collections")
def create_collection(req: CollectionRequest):
    """Create a new collection, rejecting unknown distances and duplicates."""
    distance_by_name = {"cosine": Distance.COSINE, "euclid": Distance.EUCLID, "dot": Distance.DOT}
    key = req.distance.lower()
    if key not in distance_by_name:
        raise HTTPException(400, f"Unknown distance: {req.distance}. Use cosine, euclid, or dot.")
    dist = distance_by_name[key]

    existing = {c.name for c in client.get_collections().collections}
    if req.name in existing:
        raise HTTPException(409, f"Collection '{req.name}' already exists")

    # Apply the same production defaults as _ensure_collection so all
    # collections start with tuned HNSW and optimizer settings.
    client.create_collection(
        collection_name=req.name,
        vectors_config=VectorParams(size=req.vector_dim, distance=dist),
        hnsw_config=HnswConfigDiff(m=16, ef_construct=200, on_disk=False),
        optimizers_config=OptimizersConfigDiff(indexing_threshold=20000, default_segment_number=4),
    )
    return {"created": req.name, "vector_dim": req.vector_dim, "distance": req.distance}
|
|
|
|
|
|
@app.get("/collections")
def list_collections():
    """Return the names of all collections in Qdrant."""
    names = [c.name for c in client.get_collections().collections]
    return {"collections": names}
|
|
|
|
|
|
@app.get("/collections/{name}")
def collection_info(name: str):
    """Return counts, status, and tuning config for one collection.

    Responds 404 when the collection does not exist (or any other client
    error occurs while fetching it).
    """
    try:
        info = client.get_collection(name)
        cfg = info.config
        hnsw = cfg.hnsw_config
        opt = cfg.optimizer_config
        quant = cfg.quantization_config
        return {
            "name": name,
            "vectors_count": info.vectors_count,
            "indexed_vectors_count": info.indexed_vectors_count,
            "points_count": info.points_count,
            "segments_count": info.segments_count,
            "status": info.status.value if info.status else None,
            "optimizer_status": str(info.optimizer_status) if info.optimizer_status else None,
            "hnsw": {
                "m": hnsw.m,
                "ef_construct": hnsw.ef_construct,
                "on_disk": hnsw.on_disk,
                "full_scan_threshold": hnsw.full_scan_threshold,
                "max_indexing_threads": hnsw.max_indexing_threads,
            } if hnsw else None,
            "optimizer": {
                "indexing_threshold": opt.indexing_threshold,
                "default_segment_number": opt.default_segment_number,
                "max_segment_size": opt.max_segment_size,
                "memmap_threshold": opt.memmap_threshold,
                "flush_interval_sec": opt.flush_interval_sec,
            } if opt else None,
            "quantization": str(quant) if quant else None,
            "payload_schema": {
                k: {
                    "type": v.data_type.value if hasattr(v.data_type, "value") else str(v.data_type),
                    "points": v.points,
                }
                for k, v in (info.payload_schema or {}).items()
            },
        }
    except Exception as e:
        # Unknown collection (and any fetch error) surfaces as 404.
        raise HTTPException(404, str(e))
|
|
|
|
|
|
@app.delete("/collections/{name}")
def delete_collection(name: str):
    """Drop a collection and all of its points."""
    client.delete_collection(name)
    return {"deleted": name}
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Upsert endpoints
|
|
# ---------------------------------------------------------------------------
|
|
|
|
@app.post("/upsert")
async def upsert_url(req: UpsertUrlRequest):
    """Embed an image by URL via CLIP, then store the vector in Qdrant."""
    vector = await _embed_url(req.url)
    pid = _point_id(req.id)
    col = _col(req.collection)

    payload = {**req.metadata, "source_url": req.url}
    # preserve original user-provided id if it wasn't usable as a point id
    if req.id is not None and str(pid) != str(req.id):
        payload["_original_id"] = req.id

    point = PointStruct(id=pid, vector=vector, payload=payload)
    try:
        client.upsert(collection_name=col, points=[point])
    except Exception as exc:
        raise HTTPException(500, str(exc))

    return {"id": pid, "collection": col, "dim": len(vector)}
|
|
|
|
|
|
@app.post("/upsert/file")
async def upsert_file(
    file: UploadFile = File(...),
    id: Optional[str] = Form(None),
    collection: Optional[str] = Form(None),
    metadata_json: Optional[str] = Form(None),
):
    """Embed an uploaded image via CLIP, then store the vector in Qdrant.

    Form fields:
        file: image bytes to embed.
        id: optional caller id; preserved in payload when not Qdrant-compatible.
        collection: target collection (defaults to the configured one).
        metadata_json: JSON *object* of payload fields stored with the point.
    """
    import json

    data = await file.read()
    vector = await _embed_bytes(data)
    pid = _point_id(id)

    payload: Dict[str, Any] = {}
    if metadata_json:
        try:
            payload = json.loads(metadata_json)
        except json.JSONDecodeError:
            raise HTTPException(400, "metadata_json must be valid JSON")
        # json.loads also accepts lists/scalars; those would crash the
        # payload assignment below with an opaque 500, so reject up front.
        if not isinstance(payload, dict):
            raise HTTPException(400, "metadata_json must be a JSON object")
    # preserve original user-provided id if it wasn't usable as a point id
    if id is not None and str(pid) != str(id):
        payload["_original_id"] = id

    col = _col(collection)
    try:
        client.upsert(
            collection_name=col,
            points=[PointStruct(id=pid, vector=vector, payload=payload)],
        )
    except Exception as e:
        raise HTTPException(500, str(e))

    return {"id": pid, "collection": col, "dim": len(vector)}
|
|
|
|
|
|
@app.post("/upsert/vector")
def upsert_vector(req: UpsertVectorRequest):
    """Store a pre-computed vector directly (skip CLIP embedding)."""
    pid = _point_id(req.id)
    col = _col(req.collection)

    payload = dict(req.metadata or {})
    # keep the caller's id recoverable when it had to be remapped
    if req.id is not None and str(pid) != str(req.id):
        payload["_original_id"] = req.id

    point = PointStruct(id=pid, vector=req.vector, payload=payload)
    try:
        client.upsert(collection_name=col, points=[point])
    except Exception as exc:
        raise HTTPException(500, str(exc))
    return {"id": pid, "collection": col, "dim": len(req.vector)}
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Search endpoints
|
|
# ---------------------------------------------------------------------------
|
|
|
|
@app.post("/search")
async def search_url(req: SearchUrlRequest):
    """Embed an image by URL via CLIP, then search Qdrant for similar vectors."""
    vector = await _embed_url(req.url)
    return _do_search(
        vector,
        req.limit,
        req.score_threshold,
        req.collection,
        req.filter_metadata,
        req.hnsw_ef,
        req.exact,
        req.indexed_only,
    )
|
|
|
|
|
|
@app.post("/search/file")
async def search_file(
    file: UploadFile = File(...),
    limit: int = Form(5),
    score_threshold: Optional[float] = Form(None),
    collection: Optional[str] = Form(None),
    hnsw_ef: Optional[int] = Form(None),
    exact: bool = Form(False),
    indexed_only: bool = Form(False),
    filter_metadata_json: Optional[str] = Form(None),
):
    """Embed an uploaded image via CLIP, then search Qdrant for similar vectors.

    `filter_metadata_json` must be a JSON *object* of exact-match payload
    constraints (same semantics as SearchUrlRequest.filter_metadata).
    """
    import json

    filter_metadata: Dict[str, Any] = {}
    if filter_metadata_json:
        try:
            filter_metadata = json.loads(filter_metadata_json)
        except json.JSONDecodeError:
            raise HTTPException(400, "filter_metadata_json must be valid JSON")
        # A JSON list/scalar would crash _build_filter's .items() call with
        # an opaque 500, so reject non-objects explicitly.
        if not isinstance(filter_metadata, dict):
            raise HTTPException(400, "filter_metadata_json must be a JSON object")
    data = await file.read()
    vector = await _embed_bytes(data)
    return _do_search(vector, int(limit), score_threshold, collection, filter_metadata, hnsw_ef, exact, indexed_only)
|
|
|
|
|
|
@app.post("/search/vector")
def search_vector(req: SearchVectorRequest):
    """Search Qdrant using a pre-computed vector."""
    return _do_search(
        req.vector,
        req.limit,
        req.score_threshold,
        req.collection,
        req.filter_metadata,
        req.hnsw_ef,
        req.exact,
        req.indexed_only,
    )
|
|
|
|
|
|
def _do_search(
    vector: List[float],
    limit: int,
    score_threshold: Optional[float],
    collection: Optional[str],
    filter_metadata: Dict[str, Any],
    hnsw_ef: Optional[int] = None,
    exact: bool = False,
    indexed_only: bool = False,
):
    """Run a vector similarity query and shape the hits for the API response.

    Args:
        vector: query embedding.
        limit: max hits to return.
        score_threshold: drop hits scoring below this (None = no floor).
        collection: target collection (None = configured default).
        filter_metadata: exact-match payload constraints (ANDed together).
        hnsw_ef: per-query ef override; falls back to SEARCH_HNSW_EF.
        exact: brute-force search instead of HNSW.
        indexed_only: skip segments not yet fully indexed.
    """
    col = _col(collection)
    qfilter = _build_filter(filter_metadata)
    ef = hnsw_ef if hnsw_ef is not None else SEARCH_HNSW_EF

    try:
        results = client.query_points(
            collection_name=col,
            query=vector,
            limit=limit,
            score_threshold=score_threshold,
            query_filter=qfilter,
            search_params=SearchParams(hnsw_ef=ef, exact=exact, indexed_only=indexed_only),
        )
    except Exception as e:
        # Consistent with the upsert endpoints: surface Qdrant errors (e.g.
        # unknown collection) as an HTTP 500 instead of an unhandled traceback.
        raise HTTPException(500, str(e))

    hits = [
        {"id": point.id, "score": point.score, "metadata": point.payload}
        for point in results.points
    ]
    return {"results": hits, "collection": col, "count": len(hits)}
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Delete points
|
|
# ---------------------------------------------------------------------------
|
|
|
|
@app.post("/delete")
def delete_points(req: DeleteRequest):
    """Delete the given point ids from a collection."""
    col = _col(req.collection)
    client.delete(collection_name=col, points_selector=req.ids)
    return {"deleted": req.ids, "collection": col}
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Get point by ID
|
|
# ---------------------------------------------------------------------------
|
|
|
|
@app.get("/points/{point_id}")
def get_point(point_id: str, collection: Optional[str] = None):
    """Fetch a single point (vector + payload) by its Qdrant id."""
    col = _col(collection)
    try:
        points = client.retrieve(collection_name=col, ids=[point_id], with_vectors=True)
    except Exception as e:
        # Retrieval failures (e.g. unknown collection) surface as 404.
        raise HTTPException(404, str(e))
    if not points:
        raise HTTPException(404, f"Point '{point_id}' not found")
    record = points[0]
    return {
        "id": record.id,
        "vector": record.vector,
        "metadata": record.payload,
        "collection": col,
    }
|
|
|
|
|
|
@app.get("/points/by-original-id/{original_id}")
def get_point_by_original_id(original_id: str, collection: Optional[str] = None):
    """Look up a point via the `_original_id` payload field set at upsert time."""
    col = _col(collection)
    try:
        records, _ = client.scroll(
            collection_name=col,
            scroll_filter=_id_filter(original_id),
            limit=1,
            with_vectors=True,
            with_payload=True,
        )
    except Exception as e:
        # Scroll failures (e.g. unknown collection) surface as 404.
        raise HTTPException(404, str(e))
    if not records:
        raise HTTPException(404, f"Point with _original_id '{original_id}' not found")
    match = records[0]
    return {
        "id": match.id,
        "vector": match.vector,
        "metadata": match.payload,
        "collection": col,
    }
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Payload index management
|
|
# ---------------------------------------------------------------------------
|
|
|
|
# Lowercase schema-type names (e.g. "keyword", "integer") mapped to their
# PayloadSchemaType enum members, derived from the enum's own values so it
# stays in sync with the installed qdrant-client version.
_SCHEMA_TYPE_MAP: Dict[str, PayloadSchemaType] = {
    t.value: t for t in PayloadSchemaType
}
|
|
|
|
|
|
def _resolve_schema_type(type_str: str) -> PayloadSchemaType:
    """Map a user-supplied type name to a PayloadSchemaType, or raise 400."""
    resolved = _SCHEMA_TYPE_MAP.get(type_str.lower())
    if resolved is not None:
        return resolved
    raise HTTPException(400, f"Unknown index type '{type_str}'. Valid: {', '.join(_SCHEMA_TYPE_MAP)}")
|
|
|
|
|
|
class PayloadIndexRequest(BaseModel):
    """Body for POST /collections/{name}/indexes: index one payload field."""

    field: str  # payload field name to index
    type: str = Field(default="keyword", description="keyword | integer | float | bool | geo | datetime | text | uuid")
    collection: Optional[str] = None  # overrides the path collection when set
|
|
|
|
|
|
class EnsureIndexesRequest(BaseModel):
    """List of field specs, each with 'field' and optional 'type' keys."""

    fields: List[Dict[str, str]]  # e.g. [{"field": "is_public", "type": "bool"}]
    collection: Optional[str] = None  # overrides the path collection when set
|
|
|
|
|
|
@app.get("/collections/{name}/indexes")
def collection_indexes(name: str):
    """List all payload indexes for a collection."""
    try:
        schema = client.get_collection(name).payload_schema or {}
    except Exception as e:
        raise HTTPException(404, str(e))

    indexes = {}
    for field, meta in schema.items():
        dtype = meta.data_type.value if hasattr(meta.data_type, "value") else str(meta.data_type)
        indexes[field] = {"type": dtype, "points": meta.points}

    return {"collection": name, "indexes": indexes, "count": len(schema)}
|
|
|
|
|
|
@app.post("/collections/{name}/indexes")
def create_index(name: str, req: PayloadIndexRequest):
    """Create a payload index on a single field."""
    col = req.collection or name
    schema = _resolve_schema_type(req.type)
    try:
        client.create_payload_index(
            collection_name=col,
            field_name=req.field,
            field_schema=schema,
        )
    except Exception as e:
        raise HTTPException(500, str(e))
    return {"collection": col, "field": req.field, "type": req.type, "status": "created"}
|
|
|
|
|
|
@app.post("/collections/{name}/ensure-indexes")
def ensure_indexes(name: str, req: EnsureIndexesRequest):
    """Idempotently ensure payload indexes exist for a list of fields.

    Skips fields that are already indexed; only creates the missing ones.
    Example body: {"fields": [{"field": "is_public", "type": "bool"}, {"field": "category_id", "type": "integer"}]}
    """
    col = req.collection or name
    try:
        info = client.get_collection(col)
    except Exception as e:
        raise HTTPException(404, str(e))

    already_indexed = set(info.payload_schema.keys()) if info.payload_schema else set()
    created: List[str] = []
    skipped: List[str] = []

    for spec in req.fields:
        field = spec.get("field")
        if not field:
            raise HTTPException(400, "Each field spec must include a 'field' key")
        if field in already_indexed:
            skipped.append(field)
            continue
        schema = _resolve_schema_type(spec.get("type", "keyword"))
        try:
            client.create_payload_index(
                collection_name=col,
                field_name=field,
                field_schema=schema,
            )
        except Exception as exc:
            raise HTTPException(500, f"Failed to index '{field}': {exc}")
        created.append(field)

    return {"collection": col, "created": created, "skipped": skipped}
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Collection HNSW + optimizer configuration
|
|
# ---------------------------------------------------------------------------
|
|
|
|
class CollectionConfigRequest(BaseModel):
    """Body for POST /collections/{name}/configure: in-place tuning updates.

    All fields are optional; only the ones provided are applied.
    """

    hnsw_m: Optional[int] = Field(default=None, ge=4, le=64, description="Edges per node in the HNSW graph.")
    hnsw_ef_construct: Optional[int] = Field(default=None, ge=10, le=1000, description="ef during index construction. Changes apply to new segments only.")
    hnsw_on_disk: Optional[bool] = Field(default=None, description="Store HNSW graph on disk (saves RAM, slightly slower queries).")
    indexing_threshold: Optional[int] = Field(default=None, ge=0, description="Min payload changes before a segment is indexed.")
    default_segment_number: Optional[int] = Field(default=None, ge=1, le=32, description="Target number of segments for parallelism.")
    # Scalar quantization — reduces RAM ~4x, often speeds up search on large collections.
    # Set quantization_type='int8' to enable. Use always_ram=True to keep quantized
    # vectors in RAM (recommended on VPS with limited memory but fast disk).
    quantization_type: Optional[str] = Field(default=None, description="Enable scalar quantization: 'int8'. Set to null to keep current setting.")
    quantization_quantile: float = Field(default=0.99, ge=0.5, le=1.0, description="Fraction of vectors used to calibrate quantization range (0.99 recommended).")
    quantization_always_ram: bool = Field(default=True, description="Keep quantized vectors in RAM even when raw vectors are on disk.")
|
|
|
|
|
|
@app.post("/collections/{name}/configure")
def configure_collection(name: str, req: CollectionConfigRequest):
    """Apply HNSW and optimizer configuration updates to an existing collection.

    Changes are applied in-place without data loss or re-ingestion.
    Note: hnsw_m and hnsw_ef_construct only affect newly created segments.
    """
    # Collect only the HNSW fields the caller actually set.
    hnsw_kwargs: Dict[str, Any] = {}
    if req.hnsw_m is not None:
        hnsw_kwargs["m"] = req.hnsw_m
    if req.hnsw_ef_construct is not None:
        hnsw_kwargs["ef_construct"] = req.hnsw_ef_construct
    if req.hnsw_on_disk is not None:
        hnsw_kwargs["on_disk"] = req.hnsw_on_disk

    # Same for the optimizer fields.
    opt_kwargs: Dict[str, Any] = {}
    if req.indexing_threshold is not None:
        opt_kwargs["indexing_threshold"] = req.indexing_threshold
    if req.default_segment_number is not None:
        opt_kwargs["default_segment_number"] = req.default_segment_number

    # Build optional scalar quantization config
    quant_config = None
    if req.quantization_type is not None:
        if req.quantization_type.lower() != "int8":
            raise HTTPException(400, f"Unsupported quantization_type '{req.quantization_type}'. Only 'int8' is supported.")
        quant_config = ScalarQuantizationConfig(
            type=ScalarType.INT8,
            quantile=req.quantization_quantile,
            always_ram=req.quantization_always_ram,
        )

    if not hnsw_kwargs and not opt_kwargs and quant_config is None:
        raise HTTPException(400, "No configuration fields provided")

    try:
        client.update_collection(
            collection_name=name,
            hnsw_config=HnswConfigDiff(**hnsw_kwargs) if hnsw_kwargs else None,
            optimizers_config=OptimizersConfigDiff(**opt_kwargs) if opt_kwargs else None,
            quantization_config=quant_config,
        )
    except Exception as exc:
        raise HTTPException(500, str(exc))

    return {
        "collection": name,
        "status": "updated",
        "hnsw_changes": hnsw_kwargs,
        "optimizer_changes": opt_kwargs,
        "quantization": {"type": req.quantization_type, "quantile": req.quantization_quantile, "always_ram": req.quantization_always_ram} if quant_config else None,
    }
|