llm: add FastAPI shim, gateway LLM endpoints, tests, and docs

This commit is contained in:
2026-04-12 09:41:21 +02:00
parent baf497b015
commit 59c9584250
15 changed files with 1779 additions and 11 deletions

View File

@@ -1,6 +1,9 @@
from __future__ import annotations
import asyncio
import logging
import os
import time
import uuid
from typing import Any, Dict, List, Optional
@@ -39,6 +42,9 @@ SEARCH_HNSW_EF = int(os.getenv("SEARCH_HNSW_EF", "128"))
# FastAPI application object for the vector-search microservice.
app = FastAPI(title="Skinbase Qdrant Service", version="1.0.0")
# Populated in startup(); None until the Qdrant connection is established,
# hence the assignment ignore for the non-Optional annotation.
client: QdrantClient = None # type: ignore[assignment]
logger = logging.getLogger("qdrant_svc")
# NOTE(review): basicConfig mutates the root logger at import time — fine for a
# standalone service entrypoint, but confirm this module is never imported as a
# library where that would clobber the host application's logging setup.
logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(name)s %(message)s")
# ---------------------------------------------------------------------------
# Startup / shutdown
@@ -47,8 +53,24 @@ client: QdrantClient = None # type: ignore[assignment]
@app.on_event("startup")
def startup():
    """Connect to Qdrant, ensure the default collection, and warm the client.

    Runs once when the ASGI app boots. Failures of the warm-up ping are logged
    but never abort startup.
    """
    global client

    started = time.perf_counter()
    logger.info("qdrant_svc startup: connecting to %s:%s", QDRANT_HOST, QDRANT_PORT)

    client = QdrantClient(host=QDRANT_HOST, port=QDRANT_PORT)
    _ensure_collection()

    # A throwaway metadata fetch warms the connection and pulls collection
    # metadata into memory, so the first real request avoids the one-time
    # connect cost.
    try:
        info = client.get_collection(COLLECTION_NAME)
        logger.info(
            "qdrant_svc startup: warm ping OK collection=%s points=%s indexed=%s elapsed_ms=%.1f",
            COLLECTION_NAME,
            info.points_count,
            info.indexed_vectors_count,
            (time.perf_counter() - started) * 1000,
        )
    except Exception as exc:
        # Non-fatal by design: the service can still come up and connect lazily.
        logger.warning("qdrant_svc startup: warm ping failed (non-fatal): %s", exc)

    logger.info("qdrant_svc startup complete elapsed_ms=%.1f", (time.perf_counter() - started) * 1000)
def _ensure_collection():
@@ -68,6 +90,44 @@ def _ensure_collection():
default_segment_number=4, # parallelism-friendly segment count
),
)
_ensure_payload_indexes()
# Payload fields needed for filtered search. type values match PayloadSchemaType.
# Consumed by _ensure_payload_indexes(): each entry names a payload field and a
# short type tag ("keyword" / "bool" / "integer") that is resolved to a concrete
# Qdrant schema type via _SCHEMA_TYPE_MAP before index creation.
_REQUIRED_PAYLOAD_INDEXES: List[Dict[str, str]] = [
    {"field": "user_id", "type": "keyword"},
    {"field": "is_public", "type": "bool"},
    {"field": "is_nsfw", "type": "bool"},
    {"field": "is_deleted", "type": "bool"},
    {"field": "status", "type": "keyword"},
    {"field": "category_id", "type": "integer"},
    {"field": "content_type_id", "type": "integer"},
]
def _ensure_payload_indexes():
    """Ensure every required payload index exists on the default collection.

    Reads the current payload schema, then creates an index for each entry in
    _REQUIRED_PAYLOAD_INDEXES that is not already present. Per-field failures
    are logged and do not stop the remaining fields.
    """
    try:
        col_info = client.get_collection(COLLECTION_NAME)
    except Exception:
        # Collection isn't there yet; indexes get created once it exists.
        return

    already_indexed = set(col_info.payload_schema.keys()) if col_info.payload_schema else set()

    for entry in _REQUIRED_PAYLOAD_INDEXES:
        name = entry["field"]
        schema = _SCHEMA_TYPE_MAP.get(entry["type"])
        # Skip fields that are already indexed or whose type tag is unknown.
        if name in already_indexed or schema is None:
            continue
        try:
            client.create_payload_index(
                collection_name=COLLECTION_NAME,
                field_name=name,
                field_schema=schema,
            )
            logger.info("_ensure_payload_indexes: created index field=%s type=%s", name, entry["type"])
        except Exception as exc:
            logger.warning("_ensure_payload_indexes: could not index field=%s: %s", name, exc)
# ---------------------------------------------------------------------------
@@ -213,23 +273,31 @@ def health():
@app.get("/inspect")
def inspect():
async def inspect():
"""Return a full diagnostic summary for every collection.
Covers: vector counts, segment counts, HNSW config, optimizer config,
quantization, payload indexes and their coverage. Designed for production
health checks and the Qdrant optimization workflow.
"""
t0 = time.perf_counter()
logger.info("inspect: start")
try:
all_collections = client.get_collections().collections
all_collections = await asyncio.get_event_loop().run_in_executor(
None, lambda: client.get_collections().collections
)
except Exception as exc:
return {"status": "error", "detail": str(exc)}
result = {}
for col_desc in all_collections:
name = col_desc.name
t_col = time.perf_counter()
try:
info = client.get_collection(name)
info = await asyncio.get_event_loop().run_in_executor(
None, lambda n=name: client.get_collection(n)
)
cfg = info.config
hnsw = cfg.hnsw_config
opt = cfg.optimizer_config
@@ -281,9 +349,14 @@ def inspect():
"payload_index_count": len(info.payload_schema or {}),
"search_hnsw_ef": SEARCH_HNSW_EF,
}
logger.info(
"inspect: collection=%s points=%s elapsed_ms=%.1f",
name, points_count, (time.perf_counter() - t_col) * 1000,
)
except Exception as exc:
result[name] = {"error": str(exc)}
logger.info("inspect: done collections=%d total_elapsed_ms=%.1f", len(result), (time.perf_counter() - t0) * 1000)
return {"collections": result, "total": len(result)}
@@ -757,3 +830,54 @@ def configure_collection(name: str, req: CollectionConfigRequest):
}
except Exception as exc:
raise HTTPException(500, str(exc))
# ---------------------------------------------------------------------------
# Payload update (used by backfill / repair tooling)
# ---------------------------------------------------------------------------
class BatchUpdatePayloadRequest(BaseModel):
    """Update payload fields for a batch of points identified by their Qdrant IDs.

    ``updates`` is a list of ``{"id": "<qdrant-point-id>", "payload": {...}}`` items.
    Only the supplied payload keys are merged into existing payloads (set_payload
    semantics — existing keys not mentioned are left untouched).
    """

    # Batch items; entries missing "id" or carrying an empty "payload" are
    # skipped by the /points/batch-update-payload handler.
    updates: List[Dict[str, Any]]
    # Target collection name; None selects the service default (resolved via _col()).
    collection: Optional[str] = None
@app.post("/points/batch-update-payload")
def batch_update_payload(req: BatchUpdatePayloadRequest):
    """Merge payload fields for a list of points without touching vectors.

    Useful for backfilling metadata (is_public, category_id, etc.) for points
    that were upserted without full payload coverage.

    Returns a summary dict:
      - updated:    number of points whose payload was merged
      - collection: the resolved target collection name
      - skipped:    items ignored for having no id or an empty payload
                    (key present only when non-zero)
      - errors:     per-point failure messages (key present only when non-empty)
    """
    col = _col(req.collection)
    if not req.updates:
        return {"updated": 0, "collection": col}

    updated = 0
    skipped = 0
    errors: List[str] = []
    for item in req.updates:
        pid_raw = item.get("id")
        payload = item.get("payload", {})
        if pid_raw is None or not payload:
            # Previously these were dropped silently, making the "updated"
            # count impossible to reconcile with the number of items sent.
            skipped += 1
            continue
        pid = _point_id(str(pid_raw))
        try:
            # set_payload merges the given keys; existing keys not mentioned
            # in `payload` are left untouched.
            client.set_payload(
                collection_name=col,
                payload=payload,
                points=[pid],
            )
            updated += 1
        except Exception as exc:
            errors.append(f"{pid_raw}: {exc}")

    result: Dict[str, Any] = {"updated": updated, "collection": col}
    if skipped:
        result["skipped"] = skipped
    if errors:
        result["errors"] = errors
    return result