llm: add FastAPI shim, gateway LLM endpoints, tests, and docs
qdrant/backfill_payloads.py (new file, +244)
@@ -0,0 +1,244 @@
#!/usr/bin/env python3
"""
backfill_payloads.py — Repair missing payload fields for existing Qdrant points.

WHY THIS EXISTS
---------------
If artworks were initially upserted without the full payload (is_public, is_nsfw,
category_id, content_type_id, is_deleted, status), those fields will have near-0%
coverage in the payload index. This prevents filtered searches (e.g., is_public=true)
from returning correct results.

This script scrolls through all points in the collection, detects which ones are
missing the required fields, and lets you supply a lookup function that fetches the
correct values from your source of truth (database, API, CSV, etc.).

HOW TO ADAPT
------------
1. Fill in `fetch_payloads_for_ids()` to return a dict mapping qdrant-point-id ->
   payload patch for each missing ID. The simplest approach is a SQL query against
   your Skinbase database using the `_original_id` stored in the Qdrant payload.

2. Run the script directly (no app container needed, just qdrant-client installed):

       # Inside the Docker network:
       docker exec -it vision-qdrant-svc-1 python /app/backfill_payloads.py

       # Or from the host with qdrant-client installed:
       pip install qdrant-client
       QDRANT_HOST=localhost QDRANT_PORT=6333 python qdrant/backfill_payloads.py

3. The script is resumable: it prints the last-processed offset ID, so you can
   restart from where you left off by setting the RESUME_OFFSET env var.

ENVIRONMENT VARIABLES (all optional; defaults suit Docker Compose):
    QDRANT_HOST      default: qdrant
    QDRANT_PORT      default: 6333
    COLLECTION_NAME  default: images
    BATCH_SIZE       default: 256
    DRY_RUN          default: 0 (set to 1 to only report, no writes)
    RESUME_OFFSET    default: None (UUID or int of the last seen point to skip to)

FIELDS CHECKED
--------------
user_id, is_public, is_nsfw, category_id, content_type_id, is_deleted, status
"""

from __future__ import annotations

import logging
import os
import sys
import time
from typing import Any, Dict, List, Optional

from qdrant_client import QdrantClient

logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s")
log = logging.getLogger("backfill")

# ---------------------------------------------------------------------------
# Config
# ---------------------------------------------------------------------------

QDRANT_HOST = os.getenv("QDRANT_HOST", "qdrant")
QDRANT_PORT = int(os.getenv("QDRANT_PORT", "6333"))
COLLECTION_NAME = os.getenv("COLLECTION_NAME", "images")
BATCH_SIZE = int(os.getenv("BATCH_SIZE", "256"))
DRY_RUN = os.getenv("DRY_RUN", "0") == "1"
RESUME_OFFSET: Optional[str] = os.getenv("RESUME_OFFSET")  # point id to continue from

# Fields that MUST be present in every point payload for filtered search to work.
REQUIRED_FIELDS = [
    "user_id",
    "is_public",
    "is_nsfw",
    "category_id",
    "content_type_id",
    "is_deleted",
    "status",
]

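# Suggested workflow (illustrative): audit first with DRY_RUN=1, then run for
# real. If a run is interrupted, feed the offset printed in the progress log
# back in via RESUME_OFFSET (the offset placeholder below is hypothetical):
#
#   DRY_RUN=1 QDRANT_HOST=localhost python qdrant/backfill_payloads.py
#   QDRANT_HOST=localhost python qdrant/backfill_payloads.py
#   RESUME_OFFSET=<offset-from-progress-log> QDRANT_HOST=localhost python qdrant/backfill_payloads.py
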
# ---------------------------------------------------------------------------
# TODO: implement this function to fetch correct payload values from your DB.
# ---------------------------------------------------------------------------

def fetch_payloads_for_ids(
    missing_ids: List[Any],
    original_ids: Dict[Any, str],
) -> Dict[Any, Dict[str, Any]]:
    """Return a mapping of qdrant_point_id -> payload_patch for the given IDs.

    Parameters
    ----------
    missing_ids:
        List of Qdrant point IDs (UUID strings or ints) that need patching.
    original_ids:
        Dict mapping qdrant_point_id -> original application ID (stored in
        `_original_id` payload field, or the point id itself if they match).

    Returns
    -------
    Dict mapping each point id to a dict of fields to set.
    Only include the fields you want to SET — existing fields are not cleared.

    Example implementation (pseudo-code for your database):

        import psycopg2
        conn = psycopg2.connect(os.environ["DATABASE_URL"])
        cur = conn.cursor()
        orig_id_list = list(original_ids.values())
        cur.execute(
            "SELECT id, user_id, is_public, is_nsfw, category_id, "
            "       content_type_id, is_deleted, status "
            "FROM artworks WHERE id = ANY(%s)",
            (orig_id_list,)
        )
        rows = cur.fetchall()
        by_orig = {str(r[0]): r for r in rows}
        result = {}
        for qdrant_id, orig_id in original_ids.items():
            row = by_orig.get(str(orig_id))
            if row:
                result[qdrant_id] = {
                    "user_id": str(row[1]),
                    "is_public": bool(row[2]),
                    "is_nsfw": bool(row[3]),
                    "category_id": int(row[4]) if row[4] is not None else None,
                    "content_type_id": int(row[5]) if row[5] is not None else None,
                    "is_deleted": bool(row[6]),
                    "status": str(row[7]),
                }
        return result
    """
    # ---- STUB: replace with your real implementation ----
    log.warning(
        "fetch_payloads_for_ids() is a stub — no data will be patched.\n"
        "Edit qdrant/backfill_payloads.py and implement this function."
    )
    return {}

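# For reference, a correct implementation returns a mapping shaped like the
# following (the point ID and values are hypothetical):
#
#   {
#       "2f6c0a52-9e1b-4c6e-8f3a-0d9a1b2c3d4e": {
#           "user_id": "42",
#           "is_public": True,
#           "is_nsfw": False,
#           "category_id": 3,
#           "content_type_id": 1,
#           "is_deleted": False,
#           "status": "active",
#       },
#   }
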
# ---------------------------------------------------------------------------
# Core backfill logic
# ---------------------------------------------------------------------------

def run_backfill():
    log.info(
        "backfill start collection=%s host=%s:%s dry_run=%s batch=%d",
        COLLECTION_NAME, QDRANT_HOST, QDRANT_PORT, DRY_RUN, BATCH_SIZE,
    )

    qclient = QdrantClient(host=QDRANT_HOST, port=QDRANT_PORT)

    # Verify collection exists
    collections = [c.name for c in qclient.get_collections().collections]
    if COLLECTION_NAME not in collections:
        log.error("Collection '%s' not found. Existing: %s", COLLECTION_NAME, collections)
        sys.exit(1)

    info = qclient.get_collection(COLLECTION_NAME)
    total_points = info.points_count or 0
    log.info("collection points_count=%d indexed_vectors=%d", total_points, info.indexed_vectors_count or 0)

    offset = RESUME_OFFSET
    scanned = 0
    missing_count = 0
    patched = 0
    errors = 0
    t_start = time.perf_counter()

    while True:
        points, next_offset = qclient.scroll(
            collection_name=COLLECTION_NAME,
            offset=offset,
            limit=BATCH_SIZE,
            with_payload=True,
            with_vectors=False,
        )

        if not points:
            break

        scanned += len(points)

        # Find points missing any required field
        needs_patch: List[Any] = []
        original_ids: Dict[Any, str] = {}
        for pt in points:
            payload = pt.payload or {}
            missing = [f for f in REQUIRED_FIELDS if f not in payload or payload[f] is None]
            if missing:
                needs_patch.append(pt.id)
                # Use _original_id if present (IDs that couldn't be stored as Qdrant IDs)
                original_ids[pt.id] = str(payload.get("_original_id", pt.id))
                missing_count += 1

        if needs_patch:
            patches = fetch_payloads_for_ids(needs_patch, original_ids)
            for pid, patch in patches.items():
                if not patch:
                    continue
                if DRY_RUN:
                    log.info("[DRY RUN] would patch id=%s fields=%s", pid, list(patch.keys()))
                else:
                    try:
                        qclient.set_payload(
                            collection_name=COLLECTION_NAME,
                            payload=patch,
                            points=[pid],
                        )
                        patched += 1
                    except Exception as exc:
                        log.error("failed to patch id=%s: %s", pid, exc)
                        errors += 1

        elapsed = time.perf_counter() - t_start
        rate = scanned / elapsed if elapsed > 0 else 0
        log.info(
            "progress scanned=%d/%d missing=%d patched=%d errors=%d rate=%.0f/s offset=%s",
            scanned, total_points, missing_count, patched, errors, rate, next_offset,
        )

        if next_offset is None:
            break
        offset = next_offset

    elapsed = time.perf_counter() - t_start
    log.info(
        "backfill complete scanned=%d missing=%d patched=%d errors=%d elapsed=%.1fs",
        scanned, missing_count, patched, errors, elapsed,
    )

    if missing_count > 0 and patched == 0 and not DRY_RUN:
        log.warning(
            "%d points are missing payload fields but 0 were patched. "
            "Implement fetch_payloads_for_ids() in this script.",
            missing_count,
        )


if __name__ == "__main__":
    run_backfill()
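
After a run, a quick way to confirm the backfilled fields are actually filterable is an exact count against one of the repaired fields. A minimal sketch, assuming qdrant-client on the host and the script's defaults (the host, port, and chosen field are illustrative):

    from qdrant_client import QdrantClient
    from qdrant_client.models import FieldCondition, Filter, MatchValue

    qc = QdrantClient(host="localhost", port=6333)
    public_count = qc.count(
        collection_name="images",
        count_filter=Filter(must=[FieldCondition(key="is_public", match=MatchValue(value=True))]),
        exact=True,  # exact scan, not an HNSW estimate
    ).count
    print(f"points with is_public=true: {public_count}")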
qdrant/main.py (130 changed lines)
@@ -1,6 +1,9 @@
from __future__ import annotations

import asyncio
import logging
import os
import time
import uuid
from typing import Any, Dict, List, Optional

@@ -39,6 +42,9 @@ SEARCH_HNSW_EF = int(os.getenv("SEARCH_HNSW_EF", "128"))
app = FastAPI(title="Skinbase Qdrant Service", version="1.0.0")
client: QdrantClient = None  # type: ignore[assignment]

logger = logging.getLogger("qdrant_svc")
logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(name)s %(message)s")


# ---------------------------------------------------------------------------
# Startup / shutdown
@@ -47,8 +53,24 @@ client: QdrantClient = None  # type: ignore[assignment]
@app.on_event("startup")
def startup():
    global client
    t0 = time.perf_counter()
    logger.info("qdrant_svc startup: connecting to %s:%s", QDRANT_HOST, QDRANT_PORT)
    client = QdrantClient(host=QDRANT_HOST, port=QDRANT_PORT)
    _ensure_collection()
    # Warm the gRPC/HTTP connection and load collection metadata into memory
    # so the first real request does not pay the one-time connect cost.
    try:
        info = client.get_collection(COLLECTION_NAME)
        logger.info(
            "qdrant_svc startup: warm ping OK collection=%s points=%s indexed=%s elapsed_ms=%.1f",
            COLLECTION_NAME,
            info.points_count,
            info.indexed_vectors_count,
            (time.perf_counter() - t0) * 1000,
        )
    except Exception as exc:
        logger.warning("qdrant_svc startup: warm ping failed (non-fatal): %s", exc)
    logger.info("qdrant_svc startup complete elapsed_ms=%.1f", (time.perf_counter() - t0) * 1000)


def _ensure_collection():
@@ -68,6 +90,44 @@ def _ensure_collection():
            default_segment_number=4,  # parallelism-friendly segment count
        ),
    )
    _ensure_payload_indexes()


# Payload fields needed for filtered search. type values match PayloadSchemaType.
_REQUIRED_PAYLOAD_INDEXES: List[Dict[str, str]] = [
    {"field": "user_id", "type": "keyword"},
    {"field": "is_public", "type": "bool"},
    {"field": "is_nsfw", "type": "bool"},
    {"field": "is_deleted", "type": "bool"},
    {"field": "status", "type": "keyword"},
    {"field": "category_id", "type": "integer"},
    {"field": "content_type_id", "type": "integer"},
]


def _ensure_payload_indexes():
    """Create any missing payload indexes for the default collection."""
    try:
        info = client.get_collection(COLLECTION_NAME)
    except Exception:
        return  # collection doesn't exist yet, will be created next
    existing = set(info.payload_schema.keys()) if info.payload_schema else set()
    for spec in _REQUIRED_PAYLOAD_INDEXES:
        field = spec["field"]
        if field in existing:
            continue
        schema = _SCHEMA_TYPE_MAP.get(spec["type"])
        if schema is None:
            continue
        try:
            client.create_payload_index(
                collection_name=COLLECTION_NAME,
                field_name=field,
                field_schema=schema,
            )
            logger.info("_ensure_payload_indexes: created index field=%s type=%s", field, spec["type"])
        except Exception as exc:
            logger.warning("_ensure_payload_indexes: could not index field=%s: %s", field, exc)


# ---------------------------------------------------------------------------
@@ -213,23 +273,31 @@ def health():
@app.get("/inspect")
-def inspect():
+async def inspect():
    """Return a full diagnostic summary for every collection.

    Covers: vector counts, segment counts, HNSW config, optimizer config,
    quantization, payload indexes and their coverage. Designed for production
    health checks and the Qdrant optimization workflow.
    """
    t0 = time.perf_counter()
    logger.info("inspect: start")

    try:
-        all_collections = client.get_collections().collections
+        all_collections = await asyncio.get_event_loop().run_in_executor(
+            None, lambda: client.get_collections().collections
+        )
    except Exception as exc:
        return {"status": "error", "detail": str(exc)}

    result = {}
    for col_desc in all_collections:
        name = col_desc.name
        t_col = time.perf_counter()
        try:
-            info = client.get_collection(name)
+            info = await asyncio.get_event_loop().run_in_executor(
+                None, lambda n=name: client.get_collection(n)
+            )
            cfg = info.config
            hnsw = cfg.hnsw_config
            opt = cfg.optimizer_config
@@ -281,9 +349,14 @@ def inspect():
                "payload_index_count": len(info.payload_schema or {}),
                "search_hnsw_ef": SEARCH_HNSW_EF,
            }
            logger.info(
                "inspect: collection=%s points=%s elapsed_ms=%.1f",
                name, points_count, (time.perf_counter() - t_col) * 1000,
            )
        except Exception as exc:
            result[name] = {"error": str(exc)}

    logger.info("inspect: done collections=%d total_elapsed_ms=%.1f", len(result), (time.perf_counter() - t0) * 1000)
    return {"collections": result, "total": len(result)}

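With the handler converted to async, the blocking qdrant-client calls are pushed onto the default thread-pool executor, so /inspect no longer stalls the event loop. A minimal smoke test (a sketch; the localhost:8000 base URL is an assumption):

    import requests

    info = requests.get("http://localhost:8000/inspect", timeout=30).json()
    # Per the handler above, the response is {"collections": {...}, "total": <int>}.
    print(info["total"], sorted(info["collections"]))
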
@@ -757,3 +830,54 @@ def configure_collection(name: str, req: CollectionConfigRequest):
        }
    except Exception as exc:
        raise HTTPException(500, str(exc))


# ---------------------------------------------------------------------------
# Payload update (used by backfill / repair tooling)
# ---------------------------------------------------------------------------

class BatchUpdatePayloadRequest(BaseModel):
    """Update payload fields for a batch of points identified by their Qdrant IDs.

    ``updates`` is a list of ``{"id": "<qdrant-point-id>", "payload": {...}}`` items.
    Only the supplied payload keys are merged into existing payloads (set_payload
    semantics — existing keys not mentioned are left untouched).
    """
    updates: List[Dict[str, Any]]
    collection: Optional[str] = None


@app.post("/points/batch-update-payload")
def batch_update_payload(req: BatchUpdatePayloadRequest):
    """Merge payload fields for a list of points without touching vectors.

    Useful for backfilling metadata (is_public, category_id, etc.) for points
    that were upserted without full payload coverage.
    """
    if not req.updates:
        return {"updated": 0, "collection": _col(req.collection)}

    col = _col(req.collection)
    updated = 0
    errors: List[str] = []

    for item in req.updates:
        pid_raw = item.get("id")
        payload = item.get("payload", {})
        if pid_raw is None or not payload:
            continue
        pid = _point_id(str(pid_raw))
        try:
            client.set_payload(
                collection_name=col,
                payload=payload,
                points=[pid],
            )
            updated += 1
        except Exception as exc:
            errors.append(f"{pid_raw}: {exc}")

    result: Dict[str, Any] = {"updated": updated, "collection": col}
    if errors:
        result["errors"] = errors
    return result
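
A client-side sketch for the new endpoint; the base URL, point ID, and collection name are assumptions, and the body shape follows BatchUpdatePayloadRequest above:

    import requests

    resp = requests.post(
        "http://localhost:8000/points/batch-update-payload",
        json={
            "updates": [
                {
                    "id": "2f6c0a52-9e1b-4c6e-8f3a-0d9a1b2c3d4e",  # hypothetical point id
                    "payload": {"is_public": True, "status": "active"},
                }
            ],
            "collection": "images",
        },
        timeout=30,
    )
    print(resp.json())  # e.g. {"updated": 1, "collection": "images"}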