Compare commits

5 commits: 3f925e17d5 ... main

| Author | SHA1 | Date |
|---|---|---|
|  | 50c64e0541 |  |
|  | bfa6a4ad26 |  |
|  | 59c9584250 |  |
|  | baf497b015 |  |
|  | f681ab980d |  |

.env.example (14)
@@ -7,6 +7,13 @@ CLIP_URL=http://clip:8000
BLIP_URL=http://blip:8000
YOLO_URL=http://yolo:8000
QDRANT_SVC_URL=http://qdrant-svc:8000
LLM_URL=http://llm:8080
LLM_ENABLED=false
LLM_TIMEOUT=120
LLM_DEFAULT_MODEL=qwen3-1.7b-instruct-q4_k_m
LLM_MAX_TOKENS_DEFAULT=256
LLM_MAX_TOKENS_HARD_LIMIT=1024
LLM_MAX_REQUEST_BYTES=65536

# HuggingFace token for private/gated models (optional). Leave empty if unused.
# Never commit a real token to this file.
@@ -21,3 +28,10 @@ VECTOR_DIM=512
# Gateway runtime
VISION_TIMEOUT=300
MAX_IMAGE_BYTES=52428800

# Local llama.cpp LLM service (only needed when you run the llm profile locally)
MODEL_PATH=/models/Qwen3-1.7B-Instruct-Q4_K_M.gguf
LLM_CONTEXT_SIZE=4096
LLM_THREADS=4
LLM_GPU_LAYERS=0
LLM_EXTRA_ARGS=
.gitignore (vendored, 1)
@@ -49,6 +49,7 @@ qdrant_data/
*.pth
*.bin
*.ckpt
*.gguf

# Numpy arrays
*.npy
README.md (190)
@@ -1,7 +1,7 @@
# Skinbase Vision Stack (CLIP + BLIP + YOLO + Qdrant + Card Renderer) – Dockerized FastAPI
# Skinbase Vision Stack (CLIP + BLIP + YOLO + Qdrant + Card Renderer + Maturity + LLM) – Dockerized FastAPI

This repository provides **five standalone vision services** (CLIP / BLIP / YOLO / Qdrant / Card Renderer)
and a **Gateway API** that can call them individually or together.
This repository provides internal AI services for image analysis, vector search, card rendering, moderation,
and text generation behind a single **Gateway API**.

## Services & Ports

@@ -12,6 +12,8 @@ and a **Gateway API** that can call them individually or together.
- `qdrant`: vector DB (port `6333` exposed for direct access)
- `qdrant-svc`: internal Qdrant API wrapper
- `card-renderer`: internal card rendering service
- `maturity`: internal NSFW/maturity classifier service
- `llm`: internal text-generation service using a thin FastAPI shim over `llama-server` (profile-based, internal only)

## Run

@@ -19,6 +21,16 @@ and a **Gateway API** that can call them individually or together.
docker compose up -d --build
```

That starts the default vision stack only. The LLM service is disabled by default so operators are not forced to run Qwen3 on the same host.

To also start the local llama.cpp service:

```bash
docker compose --profile llm up -d --build
```

Before enabling the `llm` profile locally, place the GGUF model file described in [models/qwen3/README.md](models/qwen3/README.md) and set `LLM_ENABLED=true` in `.env`.

If you use BLIP, create a `.env` file first.

Required variables:
@@ -30,6 +42,35 @@ HUGGINGFACE_TOKEN=your_huggingface_token_here

`HUGGINGFACE_TOKEN` is required when the configured BLIP model is private, gated, or otherwise requires Hugging Face authentication.

Optional maturity configuration (override in `.env` if needed):

```bash
MATURITY_MODEL=Falconsai/nsfw_image_detection
MATURITY_THRESHOLD_MATURE=0.80
MATURITY_THRESHOLD_REVIEW=0.60
MATURITY_ENABLED=true
```

Optional LLM configuration:

```bash
LLM_ENABLED=false
LLM_URL=http://llm:8080
LLM_DEFAULT_MODEL=qwen3-1.7b-instruct-q4_k_m
LLM_TIMEOUT=120
LLM_MAX_TOKENS_DEFAULT=256
LLM_MAX_TOKENS_HARD_LIMIT=1024
LLM_MAX_REQUEST_BYTES=65536

# Local llm profile only
MODEL_PATH=/models/Qwen3-1.7B-Instruct-Q4_K_M.gguf
LLM_CONTEXT_SIZE=4096
LLM_THREADS=4
LLM_GPU_LAYERS=0
```

Recommended production topology for the LLM: keep the gateway on the current vision host and point `LLM_URL` at a separate private machine or VPN-reachable container host. Running the full vision stack and Qwen3 together on a small 4c/8GB VPS will usually degrade both.

Service startup now waits on container healthchecks, so first boot may take longer while models finish loading.

## Health
@@ -38,6 +79,71 @@ Service startup now waits on container healthchecks, so first boot may take long
curl -H "X-API-Key: <your-api-key>" https://vision.klevze.net/health
```

LLM-specific gateway health:

```bash
curl -H "X-API-Key: <your-api-key>" https://vision.klevze.net/ai/health
```

## LLM Smoke Test

Use this checklist on a Docker-capable host after provisioning the GGUF file and setting `LLM_ENABLED=true`.

1. Start the gateway and local LLM profile.

```bash
docker compose --profile llm up -d --build gateway llm
```

2. Confirm the LLM container is running and healthy.

```bash
docker compose ps llm
docker compose logs --tail=100 llm
```

3. Check the internal LLM health contract.

```bash
curl http://127.0.0.1:8080/health
```

Expected fields: `status`, `model`, `context_size`, `threads`.

4. Check gateway health and LLM reachability.

```bash
curl -H "X-API-Key: <your-api-key>" http://127.0.0.1:8003/health
curl -H "X-API-Key: <your-api-key>" http://127.0.0.1:8003/ai/health
```

5. Verify model discovery through the gateway.

```bash
curl -H "X-API-Key: <your-api-key>" http://127.0.0.1:8003/v1/models
```

6. Run a short non-streaming chat completion.

```bash
curl -H "X-API-Key: <your-api-key>" -X POST http://127.0.0.1:8003/ai/chat \
  -H "Content-Type: application/json" \
  -d '{
    "messages": [
      {"role": "system", "content": "You are a concise assistant for Skinbase Nova."},
      {"role": "user", "content": "Write one sentence about an artist who creates cinematic sci-fi wallpaper packs."}
    ],
    "max_tokens": 80
  }'
```

7. If anything fails, inspect the two relevant services first.

```bash
docker compose logs --tail=200 llm
docker compose logs --tail=200 gateway
```
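If you prefer to script steps 3–5, a minimal check could look like the sketch below. The ports and expected fields are the ones documented above; everything else (running `httpx` on the host, reading the key from an `API_KEY` environment variable) is an assumption.

```python
# Sketch that automates smoke-test steps 3-5. Ports and expected fields come from
# the checklist above; httpx on the host and the API_KEY env var are assumptions.
import os

import httpx

API_KEY = os.environ["API_KEY"]
HEADERS = {"X-API-Key": API_KEY}

# Step 3: internal LLM health contract
llm_health = httpx.get("http://127.0.0.1:8080/health", timeout=10).json()
assert {"status", "model", "context_size", "threads"} <= llm_health.keys()

# Step 4: gateway health and LLM reachability
for path in ("/health", "/ai/health"):
    r = httpx.get(f"http://127.0.0.1:8003{path}", headers=HEADERS, timeout=30)
    r.raise_for_status()
    print(path, r.json().get("status"))

# Step 5: model discovery through the gateway
models = httpx.get("http://127.0.0.1:8003/v1/models", headers=HEADERS, timeout=30).json()
print([m.get("id") for m in models.get("data", [])])
```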
## Universal analyze (ALL)

### With URL
@@ -96,6 +202,41 @@ curl -H "X-API-Key: <your-api-key>" -X POST https://vision.klevze.net/analyze/yo
  -F "conf=0.25"
```

## Maturity / NSFW analysis

Analyzes an image and returns a normalized maturity signal for Nova moderation workflows.

### Analyze by URL
```bash
curl -H "X-API-Key: <your-api-key>" -X POST https://vision.klevze.net/analyze/maturity \
  -H "Content-Type: application/json" \
  -d '{"url":"https://files.skinbase.org/img/aa/bb/cc/md.webp"}'
```

### Analyze from file upload
```bash
curl -H "X-API-Key: <your-api-key>" -X POST https://vision.klevze.net/analyze/maturity/file \
  -F "file=@/path/to/image.webp"
```

Example response:
```json
{
  "maturity_label": "mature",
  "confidence": 0.94,
  "score": 0.94,
  "labels": ["nsfw"],
  "model": "Falconsai/nsfw_image_detection",
  "threshold_used": 0.80,
  "analysis_time_ms": 183.0,
  "source": "maturity-service",
  "action_hint": "flag_high",
  "advisory": "High-confidence mature content detected"
}
```

`action_hint` values: `safe`, `review`, `flag_high`. Nova should use these to decide blur/queue/flag behaviour.

## Vector DB (Qdrant) via gateway

Qdrant point IDs must be either:
@@ -226,10 +367,51 @@ curl -H "X-API-Key: <your-api-key>" -X POST https://vision.klevze.net/cards/rend
  -d '{"url":"https://files.skinbase.org/img/aa/bb/cc/md.webp","title":"Artwork Title"}'
```

## LLM / Chat Completions

The gateway exposes stable text-generation endpoints backed by the internal `llm` service. They reuse the existing `X-API-Key` protection and keep the LLM container internal-only.

### OpenAI-style chat endpoint
```bash
curl -H "X-API-Key: <your-api-key>" -X POST https://vision.klevze.net/v1/chat/completions \
  -H "Content-Type: application/json" \
  -d '{
    "messages": [
      {"role": "system", "content": "You are a concise assistant for Skinbase Nova."},
      {"role": "user", "content": "Write a short creator biography for an artist who just hit 10,000 followers."}
    ],
    "temperature": 0.7,
    "max_tokens": 220,
    "stream": false
  }'
```

### Project-friendly chat endpoint
```bash
curl -H "X-API-Key: <your-api-key>" -X POST https://vision.klevze.net/ai/chat \
  -H "Content-Type: application/json" \
  -d '{
    "messages": [
      {"role": "system", "content": "You are a concise assistant for Skinbase Nova."},
      {"role": "user", "content": "Suggest metadata tags for a cyberpunk wallpaper pack."}
    ],
    "max_tokens": 180
  }'
```

### List models
```bash
curl -H "X-API-Key: <your-api-key>" https://vision.klevze.net/v1/models
curl -H "X-API-Key: <your-api-key>" https://vision.klevze.net/ai/models
```

## Notes

- This is a **starter scaffold**. Models are loaded at service startup.
- Models are loaded at service startup; initial container start can take 1–2 minutes as model weights are downloaded.
- Qdrant data is persisted in the project folder at `./data/qdrant`, so it survives container restarts and recreates.
- The local `llm` profile does **not** auto-download Qwen3 weights. Mount the GGUF file explicitly and let startup fail fast if it is missing.
- Remote image URLs are restricted to public `http`/`https` hosts. Localhost, private IP ranges, and non-image content types are rejected.
- The maturity service uses `Falconsai/nsfw_image_detection` (ViT-based). Thresholds are configurable via `.env`. The model handles photos and stylized digital art but should be calibrated against real Skinbase content before production use.
- For small VPS deployments, prefer `LLM_ENABLED=true` with `LLM_URL` pointing to a separate LLM host instead of running the `llm` profile on the same machine.
- For production: add auth, rate limits, and restrict gateway exposure (private network).
- GPU: you can add NVIDIA runtime later (compose profiles) if needed.
USAGE.md (276)
@@ -1,10 +1,10 @@
# Skinbase Vision Stack — Usage Guide

This document explains how to run and use the Skinbase Vision Stack (Gateway + CLIP, BLIP, YOLO, Qdrant services).
This document explains how to run and use the Skinbase Vision Stack (Gateway + CLIP, BLIP, YOLO, Qdrant, Card Renderer, Maturity, and optional LLM services).

## Overview

- Services: `gateway`, `clip`, `blip`, `yolo`, `qdrant`, `qdrant-svc`, `card-renderer` (FastAPI each, except `qdrant` which is the official Qdrant DB).
- Services: `gateway`, `clip`, `blip`, `yolo`, `qdrant`, `qdrant-svc`, `card-renderer`, `maturity`, `llm` (FastAPI each except `qdrant`; `llm` is a thin FastAPI shim that manages an internal `llama-server` process).
- Gateway is the public API endpoint; the other services are internal.

## Model overview
@@ -19,6 +19,10 @@ This document explains how to run and use the Skinbase Vision Stack (Gateway + C

- **Card Renderer**: Generates branded social-card images (e.g. Open Graph previews) from artwork images. Applies smart center-weighted cropping, gradient overlays, title/username/tag text, and an optional logo. Returns binary image bytes (WebP by default). Template: `nova-artwork-v1`.

- **Maturity**: Dedicated NSFW/maturity classifier. Accepts an image and returns a normalized safety signal including `maturity_label` (`safe`/`mature`), `confidence`, raw `score`, optional sublabels (e.g. `nsfw`), and an `action_hint` (`safe`, `review`, `flag_high`) designed for Nova moderation workflows. Powered by `Falconsai/nsfw_image_detection` (ViT-based, HuggingFace). Thresholds are configurable via environment variables.

- **LLM**: Internal text-generation service backed by `llama.cpp` and a GGUF Qwen3 model. Exposed through the gateway for non-streaming chat completions and model discovery. Intended for Nova workflows such as creator bios, metadata suggestions, moderation helper text, and other short internal generation tasks.

## Prerequisites

- Docker Desktop (with `docker compose`) or a Docker environment.
@@ -40,12 +44,61 @@ Notes:
- `HUGGINGFACE_TOKEN` is required if the configured BLIP model requires Hugging Face authentication.
- Startup uses container healthchecks, so initial boot can take longer while models download and warm up.

Optional maturity configuration (can be added to `.env` to override defaults):

```bash
MATURITY_MODEL=Falconsai/nsfw_image_detection
MATURITY_THRESHOLD_MATURE=0.80
MATURITY_THRESHOLD_REVIEW=0.60
MATURITY_ENABLED=true
```

- `MATURITY_THRESHOLD_MATURE`: score above this → `mature` + `flag_high` (default `0.80`).
- `MATURITY_THRESHOLD_REVIEW`: score above this but below mature threshold → `mature` + `review` (default `0.60`).
- `MATURITY_ENABLED`: set to `false` to disable maturity endpoints at the gateway without removing the service.

Optional LLM configuration:

```bash
LLM_URL=http://llm:8080
LLM_ENABLED=false
LLM_TIMEOUT=120
LLM_DEFAULT_MODEL=qwen3-1.7b-instruct-q4_k_m
LLM_MAX_TOKENS_DEFAULT=256
LLM_MAX_TOKENS_HARD_LIMIT=1024
LLM_MAX_REQUEST_BYTES=65536

# Local llm profile only
MODEL_PATH=/models/Qwen3-1.7B-Instruct-Q4_K_M.gguf
LLM_CONTEXT_SIZE=4096
LLM_THREADS=4
LLM_GPU_LAYERS=0
LLM_EXTRA_ARGS=
```

Run from repository root:

```bash
docker compose up -d --build
```

That starts the default vision stack only.

To also start the local LLM service:

```bash
docker compose --profile llm up -d --build
```

Before enabling the `llm` profile, provision the GGUF model described in [models/qwen3/README.md](models/qwen3/README.md) and set `LLM_ENABLED=true` in `.env`.

For small production hosts, the preferred setup is usually to keep the gateway local and point `LLM_URL` at a separate private LLM host:

```bash
LLM_ENABLED=true
LLM_URL=http://private-llm-host:8080
```

Stop:

```bash
@@ -67,6 +120,74 @@ Check the gateway health endpoint:
curl https://vision.klevze.net/health
```

Check LLM-specific gateway health:

```bash
curl -H "X-API-Key: <your-api-key>" https://vision.klevze.net/ai/health
```

## LLM smoke test checklist

Use this sequence on a machine with Docker available after you have mounted the GGUF model and enabled the gateway with `LLM_ENABLED=true`.

1. Start the gateway with the `llm` profile.

```bash
docker compose --profile llm up -d --build gateway llm
```

2. Confirm the LLM service came up cleanly.

```bash
docker compose ps llm
docker compose logs --tail=100 llm
```

3. Check the repo-owned internal health endpoint.

```bash
curl http://127.0.0.1:8080/health
```

Expected fields: `status`, `model`, `context_size`, `threads`.

4. Confirm the gateway sees the LLM backend.

```bash
curl -H "X-API-Key: <your-api-key>" http://127.0.0.1:8003/health
curl -H "X-API-Key: <your-api-key>" http://127.0.0.1:8003/ai/health
```

5. Verify model discovery.

```bash
curl -H "X-API-Key: <your-api-key>" http://127.0.0.1:8003/v1/models
curl -H "X-API-Key: <your-api-key>" http://127.0.0.1:8003/ai/models
```

6. Run a small chat request through the gateway.

```bash
curl -X POST http://127.0.0.1:8003/v1/chat/completions \
  -H "X-API-Key: <your-api-key>" \
  -H "Content-Type: application/json" \
  -d '{
    "messages": [
      {"role": "system", "content": "You are a concise assistant for Skinbase Nova."},
      {"role": "user", "content": "Write one short admin help sentence about reviewing wallpaper metadata."}
    ],
    "max_tokens": 60,
    "stream": false
  }'
```

7. If startup or health fails, inspect the relevant logs.

```bash
docker compose logs --tail=200 llm
docker compose logs --tail=200 gateway
```

## Universal analyze (ALL)

Analyze an image by URL (gateway aggregates CLIP, BLIP, YOLO):
@@ -168,9 +289,151 @@ Parameters:

Return: detected objects with `class`, `confidence`, and `bbox` (bounding box coordinates).

### Qdrant — vector storage & similarity search
### Maturity — NSFW / maturity analysis

The Qdrant integration lets you store image embeddings and find visually similar images. Embeddings are generated automatically by the CLIP service.
Analyzes an image for mature or NSFW content and returns a structured signal intended for Nova moderation workflows.

URL request:

```bash
curl -X POST https://vision.klevze.net/analyze/maturity \
  -H "X-API-Key: <your-api-key>" \
  -H "Content-Type: application/json" \
  -d '{"url":"https://files.skinbase.org/img/aa/bb/cc/md.webp"}'
```

File upload:

```bash
curl -X POST https://vision.klevze.net/analyze/maturity/file \
  -H "X-API-Key: <your-api-key>" \
  -F "file=@/path/to/image.webp"
```

Example response:

```json
{
  "maturity_label": "mature",
  "confidence": 0.94,
  "score": 0.94,
  "labels": ["nsfw"],
  "model": "Falconsai/nsfw_image_detection",
  "threshold_used": 0.80,
  "analysis_time_ms": 183.0,
  "source": "maturity-service",
  "action_hint": "flag_high",
  "advisory": "High-confidence mature content detected"
}
```

Response fields:

| Field | Type | Description |
|---|---|---|
| `maturity_label` | string | `safe` or `mature` |
| `confidence` | float | Confidence in the label decision (0–1). For `safe`, this is `1 - score`. |
| `score` | float | Raw NSFW probability from the model (0–1). |
| `labels` | array | Sublabels when mature: currently `["nsfw"]`. Empty for safe results. |
| `model` | string | Model identifier / HuggingFace model ID. |
| `threshold_used` | float | The threshold value that determined the label. |
| `analysis_time_ms` | float | Inference time in milliseconds. |
| `source` | string | Always `maturity-service`. |
| `action_hint` | string | `safe`, `review`, or `flag_high`. Use this in Nova to drive blur/queue/flag decisions. |
| `advisory` | string | Short human-readable explanation. |

`action_hint` decision logic:
- `flag_high`: score ≥ `MATURITY_THRESHOLD_MATURE` (default 0.80) — high-confidence mature, flag for moderation.
- `review`: score ≥ `MATURITY_THRESHOLD_REVIEW` (default 0.60) but below mature threshold — possible mature, queue for human review.
- `safe`: score below both thresholds — content appears safe.

If the maturity service is unavailable the gateway returns a `502` or `503` error. **Nova must not treat a gateway failure as a `safe` result** — retry or queue for later processing.
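A minimal Nova-side sketch of how these hints could be consumed. It is illustrative only: the helper names and the httpx call are assumptions, not part of the gateway; the mapping and the "never treat failure as safe" rule follow the guidance above.

```python
# Illustrative sketch; function names and the httpx usage are assumptions.
import httpx

GATEWAY = "https://vision.klevze.net"

def decide_moderation(result: dict) -> str:
    # Map the gateway's action_hint onto Nova-side behaviour.
    hint = result.get("action_hint", "review")
    if hint == "flag_high":
        return "flag"    # hide and send to the moderation queue
    if hint == "review":
        return "queue"   # blur and queue for human review
    return "publish"     # safe

def check_image(url: str, api_key: str) -> str:
    try:
        r = httpx.post(
            f"{GATEWAY}/analyze/maturity",
            headers={"X-API-Key": api_key},
            json={"url": url},
            timeout=30,
        )
        r.raise_for_status()
    except httpx.HTTPError:
        # A gateway failure is NOT a safe result: retry or queue instead.
        return "queue"
    return decide_moderation(r.json())
```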
## LLM / Chat endpoints

The gateway validates requests, clamps `max_tokens` to configured limits, rejects oversized payloads, and normalizes downstream failures into JSON under an `error` key.

### OpenAI-style chat completions

```bash
curl -X POST https://vision.klevze.net/v1/chat/completions \
  -H "X-API-Key: <your-api-key>" \
  -H "Content-Type: application/json" \
  -d '{
    "messages": [
      {"role": "system", "content": "You are a concise assistant for Skinbase Nova."},
      {"role": "user", "content": "Write a short biography for a creator known for sci-fi environments."}
    ],
    "temperature": 0.7,
    "max_tokens": 220,
    "stream": false
  }'
```

Supported request fields:
- `messages` (required)
- `temperature`
- `max_tokens`
- `stream` (`false` only in v1)
- `top_p`
- `stop`
- `presence_penalty`
- `frequency_penalty`

Validation rules:
- At least one message is required.
- Roles must be `system`, `user`, or `assistant`.
- Empty message content is rejected.
- Oversized request bodies return `413`.
- `max_tokens` is clamped to `LLM_MAX_TOKENS_HARD_LIMIT`.
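A small client-side pre-flight that mirrors these rules before the request leaves Nova (a sketch: the constants are the documented defaults, and the helper name is an assumption):

```python
# Sketch mirroring the gateway's validation rules; constants match the documented
# defaults (LLM_MAX_TOKENS_HARD_LIMIT=1024, LLM_MAX_REQUEST_BYTES=65536).
import json

ALLOWED_ROLES = {"system", "user", "assistant"}
MAX_TOKENS_HARD_LIMIT = 1024
MAX_REQUEST_BYTES = 65536

def preflight(messages: list[dict], max_tokens: int | None = None) -> dict:
    if not messages:
        raise ValueError("at least one message is required")
    for m in messages:
        if m.get("role") not in ALLOWED_ROLES:
            raise ValueError(f"unsupported role: {m.get('role')!r}")
        if not str(m.get("content", "")).strip():
            raise ValueError("empty message content is rejected")
    body = {"messages": messages, "stream": False}
    if max_tokens is not None:
        body["max_tokens"] = min(max_tokens, MAX_TOKENS_HARD_LIMIT)
    if len(json.dumps(body).encode()) > MAX_REQUEST_BYTES:
        raise ValueError("request body exceeds LLM_MAX_REQUEST_BYTES")
    return body
```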
### Project-friendly chat response

```bash
curl -X POST https://vision.klevze.net/ai/chat \
  -H "X-API-Key: <your-api-key>" \
  -H "Content-Type: application/json" \
  -d '{
    "messages": [
      {"role": "system", "content": "You are a helpful metadata assistant."},
      {"role": "user", "content": "Suggest five tags for a fantasy castle wallpaper."}
    ]
  }'
```

Example response:

```json
{
  "model": "qwen3-1.7b-instruct-q4_k_m",
  "content": "fantasy castle, moonlit fortress, medieval towers, epic landscape, digital painting",
  "finish_reason": "stop",
  "usage": {
    "prompt_tokens": 48,
    "completion_tokens": 19,
    "total_tokens": 67
  }
}
```

### Model discovery

```bash
curl -H "X-API-Key: <your-api-key>" https://vision.klevze.net/v1/models
curl -H "X-API-Key: <your-api-key>" https://vision.klevze.net/ai/models
```

### Failure modes

- `401`: missing or invalid API key
- `413`: request body exceeds `LLM_MAX_REQUEST_BYTES`
- `422`: validation failure or unsupported streaming request
- `503`: LLM disabled or upstream unavailable
- `504`: upstream timeout
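A minimal caller sketch that handles these failure modes explicitly, relying on the `error` envelope described above (the retry/raise policy and the function name are assumptions, not project code):

```python
# Sketch of a caller for /ai/chat; the error-handling policy is an assumption.
import httpx

def ai_chat(messages: list[dict], api_key: str, base: str = "https://vision.klevze.net") -> str:
    r = httpx.post(
        f"{base}/ai/chat",
        headers={"X-API-Key": api_key},
        json={"messages": messages, "max_tokens": 180},
        timeout=120,  # mirrors LLM_TIMEOUT
    )
    if r.status_code in (503, 504):
        # LLM disabled, unreachable, or timed out: degrade gracefully upstream.
        raise RuntimeError(r.json().get("error", {}).get("message", "LLM unavailable"))
    if r.status_code >= 400:
        # 401 / 413 / 422 are caller errors; surface the normalized message.
        raise ValueError(r.json().get("error", {}).get("message", f"HTTP {r.status_code}"))
    return r.json()["content"]
```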
## Vector DB (Qdrant)

Use the Qdrant gateway endpoints to store image embeddings and find visually similar images. Embeddings are generated automatically by the CLIP service.

Qdrant point IDs must be either an unsigned integer or a UUID string. If you send another string value, the wrapper may replace it with a generated UUID and store the original value in metadata as `_original_id`.
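If you want the point ID to stay stable across re-upserts without relying on `_original_id`, one option is to derive a deterministic UUID from your own key (a sketch; the namespace label and the `artwork:` key format are assumptions):

```python
# Sketch: derive a stable UUID from an existing string key so the same artwork
# always maps to the same Qdrant point ID. uuid5 is deterministic for a given input.
import uuid

NAMESPACE = uuid.uuid5(uuid.NAMESPACE_URL, "skinbase-artworks")  # assumed namespace label

def point_id_for(artwork_key: str) -> str:
    return str(uuid.uuid5(NAMESPACE, artwork_key))

print(point_id_for("artwork:12345"))  # always the same UUID for this key
```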
@@ -457,7 +720,9 @@ uvicorn main:app --host 0.0.0.0 --port 8000
- Qdrant upsert error about invalid point ID: use a UUID or unsigned integer for `id`, or omit it and use the returned generated `id`.
- Image URL rejected before download: the URL may point to localhost, a private IP, a non-`http/https` scheme, or a non-image content type.
- High memory / OOM: increase host memory or reduce model footprint; consider GPUs.
- Slow startup: model weights load on service startup — expect extra time.
- Slow startup: model weights load on service startup — expect extra time. The maturity service (`start_period: 90s`) may take longer on first boot as it downloads the classifier weights (~330 MB). Mount `~/.cache/huggingface` as a volume to persist across rebuilds.
- Maturity endpoint returns `503`: `MATURITY_ENABLED` is set to `false` in environment configuration.
- Maturity endpoint returns `502`: the maturity container is unhealthy or still starting up; wait and retry.

## Extending

@@ -469,6 +734,7 @@ uvicorn main:app --host 0.0.0.0 --port 8000
- `docker-compose.yml` — composition and service definitions.
- `gateway/` — gateway FastAPI server.
- `clip/`, `blip/`, `yolo/` — service implementations and Dockerfiles.
- `maturity/` — NSFW/maturity classifier service (ViT-based, HuggingFace `Falconsai/nsfw_image_detection`).
- `qdrant/` — Qdrant API wrapper service (FastAPI).
- `card-renderer/` — card rendering service (FastAPI).
- `common/` — shared helpers (e.g., image I/O).
docker-compose.yml
@@ -13,6 +13,15 @@ services:
      - YOLO_URL=http://yolo:8000
      - QDRANT_SVC_URL=http://qdrant-svc:8000
      - CARD_RENDERER_URL=http://card-renderer:8000
      - MATURITY_URL=http://maturity:8000
      - LLM_URL=${LLM_URL:-http://llm:8080}
      - LLM_ENABLED=${LLM_ENABLED:-false}
      - LLM_TIMEOUT=${LLM_TIMEOUT:-120}
      - LLM_DEFAULT_MODEL=${LLM_DEFAULT_MODEL:-qwen3-1.7b-instruct-q4_k_m}
      - LLM_MAX_TOKENS_DEFAULT=${LLM_MAX_TOKENS_DEFAULT:-256}
      - LLM_MAX_TOKENS_HARD_LIMIT=${LLM_MAX_TOKENS_HARD_LIMIT:-1024}
      - LLM_MAX_REQUEST_BYTES=${LLM_MAX_REQUEST_BYTES:-65536}
      - MATURITY_ENABLED=true
      - API_KEY=${API_KEY}
      - VISION_TIMEOUT=300
      - MAX_IMAGE_BYTES=52428800
@@ -27,6 +36,8 @@ services:
        condition: service_healthy
      card-renderer:
        condition: service_healthy
      maturity:
        condition: service_healthy
    healthcheck:
      test: ["CMD", "python", "-c", "import urllib.request; urllib.request.urlopen('http://127.0.0.1:8000/health', timeout=5).read()"]
      interval: 30s
@@ -131,3 +142,42 @@ services:
      retries: 5
      start_period: 60s

  maturity:
    build:
      context: .
      dockerfile: maturity/Dockerfile
    environment:
      - MATURITY_MODEL=${MATURITY_MODEL:-Falconsai/nsfw_image_detection}
      - MATURITY_THRESHOLD_MATURE=${MATURITY_THRESHOLD_MATURE:-0.80}
      - MATURITY_THRESHOLD_REVIEW=${MATURITY_THRESHOLD_REVIEW:-0.60}
      - MAX_IMAGE_BYTES=52428800
    healthcheck:
      test: ["CMD", "python", "-c", "import urllib.request; urllib.request.urlopen('http://127.0.0.1:8000/health', timeout=5).read()"]
      interval: 30s
      timeout: 10s
      retries: 5
      start_period: 90s

  llm:
    build:
      context: .
      dockerfile: llm/Dockerfile
    environment:
      - MODEL_PATH=${MODEL_PATH:-/models/Qwen3-1.7B-Instruct-Q4_K_M.gguf}
      - LLM_MODEL_NAME=${LLM_DEFAULT_MODEL:-qwen3-1.7b-instruct-q4_k_m}
      - LLM_CONTEXT_SIZE=${LLM_CONTEXT_SIZE:-4096}
      - LLM_THREADS=${LLM_THREADS:-4}
      - LLM_GPU_LAYERS=${LLM_GPU_LAYERS:-0}
      - LLM_PORT=8080
      - LLM_EXTRA_ARGS=${LLM_EXTRA_ARGS:-}
    volumes:
      - ./models/qwen3:/models:ro
    healthcheck:
      test: ["CMD", "curl", "-fsS", "http://127.0.0.1:8080/health"]
      interval: 30s
      timeout: 10s
      retries: 5
      start_period: 120s
    profiles:
      - llm
gateway/main.py (670)
@@ -1,25 +1,125 @@
from __future__ import annotations

import os
import asyncio
from typing import Any, Dict, Optional
import json
import logging
import os
import time
from contextlib import asynccontextmanager
from typing import Any, Dict, List, Literal, Optional

import httpx
from fastapi import FastAPI, HTTPException, UploadFile, File, Form, Request
from fastapi.responses import JSONResponse, Response
from starlette.middleware.base import BaseHTTPMiddleware
from pydantic import BaseModel, Field
from pydantic import BaseModel, Field, ValidationError, field_validator

logger = logging.getLogger("gateway")
logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(name)s %(message)s")

CLIP_URL = os.getenv("CLIP_URL", "http://clip:8000")
BLIP_URL = os.getenv("BLIP_URL", "http://blip:8000")
YOLO_URL = os.getenv("YOLO_URL", "http://yolo:8000")
QDRANT_SVC_URL = os.getenv("QDRANT_SVC_URL", "http://qdrant-svc:8000")
CARD_RENDERER_URL = os.getenv("CARD_RENDERER_URL", "http://card-renderer:8000")
MATURITY_URL = os.getenv("MATURITY_URL", "http://maturity:8000")
MATURITY_ENABLED = os.getenv("MATURITY_ENABLED", "true").lower() not in ("0", "false", "no")
LLM_URL = os.getenv("LLM_URL", "http://llm:8080")
LLM_ENABLED = os.getenv("LLM_ENABLED", "false").lower() not in ("0", "false", "no")
LLM_DEFAULT_MODEL = os.getenv("LLM_DEFAULT_MODEL", "qwen3-1.7b-instruct-q4_k_m")
LLM_TIMEOUT = float(os.getenv("LLM_TIMEOUT", "120"))
LLM_MAX_TOKENS_HARD_LIMIT = max(1, int(os.getenv("LLM_MAX_TOKENS_HARD_LIMIT", "1024")))
LLM_MAX_TOKENS_DEFAULT = min(
    LLM_MAX_TOKENS_HARD_LIMIT,
    max(1, int(os.getenv("LLM_MAX_TOKENS_DEFAULT", "256"))),
)
LLM_MAX_REQUEST_BYTES = max(1024, int(os.getenv("LLM_MAX_REQUEST_BYTES", "65536")))
VISION_TIMEOUT = float(os.getenv("VISION_TIMEOUT", "20"))

# API key (set via env var `API_KEY`). If not set, gateway will reject requests.
API_KEY = os.getenv("API_KEY")

# ---------------------------------------------------------------------------
# Shared persistent HTTP client — created once at startup, reused across all
# requests. This eliminates per-request TCP connect + DNS latency (the main
# cause of 20 s first-request latency observed on /vectors/inspect).
# ---------------------------------------------------------------------------
_http_client: httpx.AsyncClient | None = None


class LLMGatewayError(Exception):
    def __init__(
        self,
        status_code: int,
        code: str,
        message: str,
        details: Optional[Any] = None,
    ):
        self.status_code = status_code
        self.code = code
        self.message = message
        self.details = details
        super().__init__(message)


def get_http_client() -> httpx.AsyncClient:
    """Return the shared httpx client. Raises if called before lifespan starts."""
    if _http_client is None:
        raise RuntimeError("HTTP client not initialised — lifespan not running")
    return _http_client


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Application lifespan: create shared HTTP client and warm upstream connections."""
    global _http_client

    t0 = time.perf_counter()
    logger.info("gateway startup: creating shared HTTP client")

    limits = httpx.Limits(
        max_connections=100,
        max_keepalive_connections=20,
        keepalive_expiry=30,
    )
    _http_client = httpx.AsyncClient(
        timeout=httpx.Timeout(VISION_TIMEOUT, connect=10),
        limits=limits,
    )

    # Warm the qdrant-svc connection so the first real request does not pay
    # the TCP handshake + DNS cost. Failure is non-fatal — the service may
    # still be starting when the gateway starts.
    try:
        t_warm = time.perf_counter()
        r = await _http_client.get(f"{QDRANT_SVC_URL}/health", timeout=8)
        logger.info(
            "gateway startup: qdrant-svc warm ping done status=%s elapsed_ms=%.1f",
            r.status_code, (time.perf_counter() - t_warm) * 1000,
        )
    except Exception as exc:
        logger.warning("gateway startup: qdrant-svc warm ping failed (non-fatal): %s", exc)

    if LLM_ENABLED:
        try:
            t_warm = time.perf_counter()
            r = await _http_client.get(f"{LLM_URL}/health", timeout=min(LLM_TIMEOUT, 10))
            logger.info(
                "gateway startup: llm warm ping done status=%s elapsed_ms=%.1f",
                r.status_code, (time.perf_counter() - t_warm) * 1000,
            )
        except Exception as exc:
            logger.warning("gateway startup: llm warm ping failed (non-fatal): %s", exc)

    logger.info("gateway startup complete elapsed_ms=%.1f", (time.perf_counter() - t0) * 1000)

    yield  # application runs

    logger.info("gateway shutdown: closing shared HTTP client")
    await _http_client.aclose()
    _http_client = None


class APIKeyMiddleware(BaseHTTPMiddleware):
    async def dispatch(self, request: Request, call_next):
        # allow health and docs endpoints without API key
@@ -27,13 +127,31 @@ class APIKeyMiddleware(BaseHTTPMiddleware):
            return await call_next(request)
        key = request.headers.get("x-api-key") or request.headers.get("X-API-Key")
        if not API_KEY or key != API_KEY:
            if _is_llm_path(request.url.path):
                return JSONResponse(
                    status_code=401,
                    content={"error": {"code": "unauthorized", "message": "Unauthorized"}},
                )
            return JSONResponse(status_code=401, content={"detail": "Unauthorized"})
        return await call_next(request)

app = FastAPI(title="Skinbase Vision Gateway", version="1.0.0")

def _is_llm_path(path: str) -> bool:
    return path.startswith("/v1/") or path.startswith("/ai/")


app = FastAPI(title="Skinbase Vision Gateway", version="1.0.0", lifespan=lifespan)
app.add_middleware(APIKeyMiddleware)


@app.exception_handler(LLMGatewayError)
async def handle_llm_gateway_error(_: Request, exc: LLMGatewayError):
    error: Dict[str, Any] = {"code": exc.code, "message": exc.message}
    if exc.details is not None:
        error["details"] = exc.details
    return JSONResponse(status_code=exc.status_code, content={"error": error})


class ClipRequest(BaseModel):
    url: Optional[str] = None
    limit: int = Field(default=5, ge=1, le=50)
@@ -51,19 +169,239 @@ class YoloRequest(BaseModel):
    conf: float = Field(default=0.25, ge=0.0, le=1.0)


async def _get_health(client: httpx.AsyncClient, base: str) -> Dict[str, Any]:
class MaturityRequest(BaseModel):
    url: Optional[str] = None


class ChatMessage(BaseModel):
    role: Literal["system", "user", "assistant"]
    content: str

    @field_validator("content")
    @classmethod
    def validate_content(cls, value: str) -> str:
        if not value or not value.strip():
            raise ValueError("message content must not be empty")
        return value


class ChatCompletionRequest(BaseModel):
    model: Optional[str] = None
    messages: List[ChatMessage] = Field(min_length=1, max_length=100)
    temperature: Optional[float] = None
    max_tokens: Optional[int] = Field(default=None, ge=1)
    stream: bool = False
    top_p: Optional[float] = Field(default=None, ge=0.0, le=1.0)
    stop: Optional[str | List[str]] = None
    presence_penalty: Optional[float] = Field(default=None, ge=-2.0, le=2.0)
    frequency_penalty: Optional[float] = Field(default=None, ge=-2.0, le=2.0)

    @field_validator("model")
    @classmethod
    def validate_model(cls, value: Optional[str]) -> Optional[str]:
        if value is None:
            return value
        model = value.strip()
        if not model:
            raise ValueError("model must not be empty")
        return model

    @field_validator("temperature")
    @classmethod
    def validate_temperature(cls, value: Optional[float]) -> Optional[float]:
        if value is None:
            return value
        if value < 0.0 or value > 2.0:
            raise ValueError("temperature must be between 0 and 2")
        return value


def _llm_timeout() -> httpx.Timeout:
    return httpx.Timeout(LLM_TIMEOUT, connect=min(LLM_TIMEOUT, 10))


def _assert_llm_enabled() -> None:
    if not LLM_ENABLED:
        raise LLMGatewayError(503, "llm_disabled", "LLM service is disabled")


def _extract_upstream_error_message(response: httpx.Response) -> str:
    try:
        r = await client.get(f"{base}/health")
        payload = response.json()
    except Exception:
        payload = None

    if isinstance(payload, dict):
        error = payload.get("error")
        if isinstance(error, dict) and error.get("message"):
            return str(error["message"])
        if payload.get("message"):
            return str(payload["message"])
        if payload.get("detail"):
            return str(payload["detail"])

    text = response.text.strip()
    return text[:500] if text else f"Upstream returned HTTP {response.status_code}"


def _map_upstream_llm_status(status_code: int) -> int:
    if status_code in (400, 413, 422):
        return status_code
    if 400 <= status_code < 500:
        return 422
    return 503


def _normalize_chat_payload(payload: ChatCompletionRequest) -> Dict[str, Any]:
    normalized = payload.model_dump(exclude_none=True)
    normalized["model"] = normalized.get("model") or LLM_DEFAULT_MODEL
    normalized["max_tokens"] = min(
        int(normalized.get("max_tokens") or LLM_MAX_TOKENS_DEFAULT),
        LLM_MAX_TOKENS_HARD_LIMIT,
    )

    if "temperature" in normalized:
        normalized["temperature"] = max(0.0, min(2.0, float(normalized["temperature"])))

    if normalized.get("stream"):
        raise LLMGatewayError(
            422,
            "streaming_not_supported",
            "Streaming responses are not enabled for this gateway",
        )

    return normalized


async def _parse_llm_request(request: Request) -> ChatCompletionRequest:
    content_length = request.headers.get("content-length")
    if content_length:
        try:
            if int(content_length) > LLM_MAX_REQUEST_BYTES:
                raise LLMGatewayError(
                    413,
                    "payload_too_large",
                    f"Request exceeds {LLM_MAX_REQUEST_BYTES} bytes",
                )
        except ValueError:
            raise LLMGatewayError(400, "invalid_request", "Invalid Content-Length header")

    body = await request.body()
    if not body:
        raise LLMGatewayError(400, "invalid_request", "Request body is required")
    if len(body) > LLM_MAX_REQUEST_BYTES:
        raise LLMGatewayError(
            413,
            "payload_too_large",
            f"Request exceeds {LLM_MAX_REQUEST_BYTES} bytes",
        )

    try:
        payload = json.loads(body)
    except json.JSONDecodeError:
        raise LLMGatewayError(400, "invalid_json", "Request body must be valid JSON")

    if not isinstance(payload, dict):
        raise LLMGatewayError(400, "invalid_request", "JSON body must be an object")

    try:
        return ChatCompletionRequest.model_validate(payload)
    except ValidationError as exc:
        raise LLMGatewayError(422, "validation_error", "Invalid chat request", exc.errors())


async def _llm_request(
    method: str,
    path: str,
    *,
    json_payload: Optional[Dict[str, Any]] = None,
) -> Dict[str, Any]:
    _assert_llm_enabled()

    url = f"{LLM_URL}{path}"
    try:
        response = await get_http_client().request(
            method,
            url,
            json=json_payload,
            timeout=_llm_timeout(),
        )
    except httpx.TimeoutException:
        raise LLMGatewayError(504, "llm_timeout", "LLM request timed out")
    except httpx.RequestError as exc:
        raise LLMGatewayError(503, "llm_unavailable", f"LLM service is unavailable: {exc}")

    if response.status_code >= 500:
        raise LLMGatewayError(503, "llm_unavailable", _extract_upstream_error_message(response))
    if response.status_code >= 400:
        raise LLMGatewayError(
            _map_upstream_llm_status(response.status_code),
            "llm_rejected_request",
            _extract_upstream_error_message(response),
        )

    try:
        return response.json()
    except Exception:
        raise LLMGatewayError(503, "llm_invalid_response", "LLM service returned invalid JSON")


def _normalize_ai_chat_response(response: Dict[str, Any]) -> Dict[str, Any]:
    choices = response.get("choices")
    if not isinstance(choices, list) or not choices:
        raise LLMGatewayError(503, "llm_invalid_response", "LLM response did not contain choices")

    first_choice = choices[0] if isinstance(choices[0], dict) else {}
    message = first_choice.get("message") if isinstance(first_choice.get("message"), dict) else {}
    content = message.get("content")
    if not isinstance(content, str):
        raise LLMGatewayError(503, "llm_invalid_response", "LLM response did not contain message content")

    usage = response.get("usage") if isinstance(response.get("usage"), dict) else {}
    return {
        "model": response.get("model") or LLM_DEFAULT_MODEL,
        "content": content,
        "finish_reason": first_choice.get("finish_reason") or "stop",
        "usage": {
            "prompt_tokens": int(usage.get("prompt_tokens") or 0),
            "completion_tokens": int(usage.get("completion_tokens") or 0),
            "total_tokens": int(usage.get("total_tokens") or 0),
        },
    }


async def _get_llm_models_payload() -> Dict[str, Any]:
    models = await _llm_request("GET", "/v1/models")
    if isinstance(models.get("data"), list) and models["data"]:
        return models
    return {
        "object": "list",
        "data": [
            {
                "id": LLM_DEFAULT_MODEL,
                "object": "model",
                "owned_by": "self-hosted",
            }
        ],
    }


async def _get_health(base: str) -> Dict[str, Any]:
    try:
        r = await get_http_client().get(f"{base}/health", timeout=5)
        return r.json() if r.status_code == 200 else {"status": "bad", "code": r.status_code}
    except Exception:
        return {"status": "unreachable"}


async def _post_json(client: httpx.AsyncClient, url: str, payload: Dict[str, Any]) -> Dict[str, Any]:
async def _post_json(url: str, payload: Dict[str, Any]) -> Dict[str, Any]:
    t0 = time.perf_counter()
    try:
        r = await client.post(url, json=payload)
        r = await get_http_client().post(url, json=payload)
    except httpx.RequestError as e:
        raise HTTPException(status_code=502, detail=f"Upstream request failed {url}: {str(e)}")
    elapsed = (time.perf_counter() - t0) * 1000
    logger.debug("POST %s status=%s elapsed_ms=%.1f", url, r.status_code, elapsed)
    if r.status_code >= 400:
        raise HTTPException(status_code=502, detail=f"Upstream error {url}: {r.status_code} {r.text[:1000]}")
    try:
@@ -73,12 +411,15 @@ async def _post_json(client: httpx.AsyncClient, url: str, payload: Dict[str, Any
        raise HTTPException(status_code=502, detail=f"Upstream returned non-JSON at {url}: {r.status_code} {r.text[:1000]}")


async def _post_file(client: httpx.AsyncClient, url: str, data: bytes, fields: Dict[str, Any]) -> Dict[str, Any]:
async def _post_file(url: str, data: bytes, fields: Dict[str, Any]) -> Dict[str, Any]:
    files = {"file": ("image", data, "application/octet-stream")}
    t0 = time.perf_counter()
    try:
        r = await client.post(url, data={k: str(v) for k, v in fields.items()}, files=files)
        r = await get_http_client().post(url, data={k: str(v) for k, v in fields.items()}, files=files)
    except httpx.RequestError as e:
        raise HTTPException(status_code=502, detail=f"Upstream request failed {url}: {str(e)}")
    elapsed = (time.perf_counter() - t0) * 1000
    logger.debug("POST(file) %s status=%s elapsed_ms=%.1f", url, r.status_code, elapsed)
    if r.status_code >= 400:
        raise HTTPException(status_code=502, detail=f"Upstream error {url}: {r.status_code} {r.text[:1000]}")
    try:
@@ -87,11 +428,14 @@ async def _post_file(client: httpx.AsyncClient, url: str, data: bytes, fields: D
        raise HTTPException(status_code=502, detail=f"Upstream returned non-JSON at {url}: {r.status_code} {r.text[:1000]}")


async def _get_json(client: httpx.AsyncClient, url: str, params: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
async def _get_json(url: str, params: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
    t0 = time.perf_counter()
    try:
        r = await client.get(url, params=params)
        r = await get_http_client().get(url, params=params)
    except httpx.RequestError as e:
        raise HTTPException(status_code=502, detail=f"Upstream request failed {url}: {str(e)}")
    elapsed = (time.perf_counter() - t0) * 1000
    logger.debug("GET %s status=%s elapsed_ms=%.1f", url, r.status_code, elapsed)
    if r.status_code >= 400:
        raise HTTPException(status_code=502, detail=f"Upstream error {url}: {r.status_code} {r.text[:1000]}")
    try:
@@ -102,14 +446,91 @@ async def _get_json(client: httpx.AsyncClient, url: str, params: Optional[Dict[s
@app.get("/health")
|
||||
async def health():
|
||||
async with httpx.AsyncClient(timeout=5) as client:
|
||||
clip_h, blip_h, yolo_h, qdrant_h = await asyncio.gather(
|
||||
_get_health(client, CLIP_URL),
|
||||
_get_health(client, BLIP_URL),
|
||||
_get_health(client, YOLO_URL),
|
||||
_get_health(client, QDRANT_SVC_URL),
|
||||
)
|
||||
return {"status": "ok", "services": {"clip": clip_h, "blip": blip_h, "yolo": yolo_h, "qdrant": qdrant_h}}
|
||||
health_checks = [
|
||||
_get_health(CLIP_URL),
|
||||
_get_health(BLIP_URL),
|
||||
_get_health(YOLO_URL),
|
||||
_get_health(QDRANT_SVC_URL),
|
||||
]
|
||||
llm_index: Optional[int] = None
|
||||
if MATURITY_ENABLED:
|
||||
health_checks.append(_get_health(MATURITY_URL))
|
||||
if LLM_ENABLED:
|
||||
llm_index = len(health_checks)
|
||||
health_checks.append(_get_health(LLM_URL))
|
||||
|
||||
results = await asyncio.gather(*health_checks)
|
||||
services: Dict[str, Any] = {
|
||||
"clip": results[0],
|
||||
"blip": results[1],
|
||||
"yolo": results[2],
|
||||
"qdrant": results[3],
|
||||
}
|
||||
if MATURITY_ENABLED:
|
||||
services["maturity"] = results[4]
|
||||
if LLM_ENABLED and llm_index is not None:
|
||||
services["llm"] = {
|
||||
"enabled": True,
|
||||
"default_model": LLM_DEFAULT_MODEL,
|
||||
"upstream": results[llm_index],
|
||||
}
|
||||
else:
|
||||
services["llm"] = {
|
||||
"enabled": False,
|
||||
"default_model": LLM_DEFAULT_MODEL,
|
||||
"upstream": {"status": "disabled"},
|
||||
}
|
||||
return {"status": "ok", "services": services}
|
||||
|
||||
|
||||
@app.post("/v1/chat/completions")
|
||||
async def llm_chat_completions(request: Request):
|
||||
payload = _normalize_chat_payload(await _parse_llm_request(request))
|
||||
return await _llm_request("POST", "/v1/chat/completions", json_payload=payload)
|
||||
|
||||
|
||||
@app.get("/v1/models")
|
||||
async def llm_models():
|
||||
return await _get_llm_models_payload()
|
||||
|
||||
|
||||
@app.post("/ai/chat")
|
||||
async def ai_chat(request: Request):
|
||||
payload = _normalize_chat_payload(await _parse_llm_request(request))
|
||||
response = await _llm_request("POST", "/v1/chat/completions", json_payload=payload)
|
||||
return _normalize_ai_chat_response(response)
|
||||
|
||||
|
||||
@app.get("/ai/models")
|
||||
async def ai_models():
|
||||
models = await _get_llm_models_payload()
|
||||
return {
|
||||
"enabled": LLM_ENABLED,
|
||||
"default_model": LLM_DEFAULT_MODEL,
|
||||
"models": models.get("data", []),
|
||||
}
|
||||
|
||||
|
||||
@app.get("/ai/health")
|
||||
async def ai_health():
|
||||
if not LLM_ENABLED:
|
||||
return {
|
||||
"status": "ok",
|
||||
"enabled": False,
|
||||
"reachable": False,
|
||||
"default_model": LLM_DEFAULT_MODEL,
|
||||
"upstream": {"status": "disabled"},
|
||||
}
|
||||
|
||||
upstream = await _get_health(LLM_URL)
|
||||
reachable = upstream.get("status") == "ok"
|
||||
return {
|
||||
"status": "ok" if reachable else "degraded",
|
||||
"enabled": True,
|
||||
"reachable": reachable,
|
||||
"default_model": LLM_DEFAULT_MODEL,
|
||||
"upstream": upstream,
|
||||
}
|
||||
|
||||
|
||||
# ---- Individual analyze endpoints (URL) ----
@@ -118,24 +539,21 @@ async def health():
async def analyze_clip(req: ClipRequest):
    if not req.url:
        raise HTTPException(400, "url is required")
    async with httpx.AsyncClient(timeout=VISION_TIMEOUT) as client:
        return await _post_json(client, f"{CLIP_URL}/analyze", req.model_dump())
    return await _post_json(f"{CLIP_URL}/analyze", req.model_dump())


@app.post("/analyze/blip")
async def analyze_blip(req: BlipRequest):
    if not req.url:
        raise HTTPException(400, "url is required")
    async with httpx.AsyncClient(timeout=VISION_TIMEOUT) as client:
        return await _post_json(client, f"{BLIP_URL}/caption", req.model_dump())
    return await _post_json(f"{BLIP_URL}/caption", req.model_dump())


@app.post("/analyze/yolo")
async def analyze_yolo(req: YoloRequest):
    if not req.url:
        raise HTTPException(400, "url is required")
    async with httpx.AsyncClient(timeout=VISION_TIMEOUT) as client:
        return await _post_json(client, f"{YOLO_URL}/detect", req.model_dump())
    return await _post_json(f"{YOLO_URL}/detect", req.model_dump())


# ---- Individual analyze endpoints (file upload) ----
@@ -151,8 +569,7 @@ async def analyze_clip_file(
    fields: Dict[str, Any] = {"limit": int(limit)}
    if threshold is not None:
        fields["threshold"] = float(threshold)
    async with httpx.AsyncClient(timeout=VISION_TIMEOUT) as client:
        return await _post_file(client, f"{CLIP_URL}/analyze/file", data, fields)
    return await _post_file(f"{CLIP_URL}/analyze/file", data, fields)


@app.post("/analyze/blip/file")
@@ -163,8 +580,7 @@ async def analyze_blip_file(
):
    data = await file.read()
    fields = {"variants": int(variants), "max_length": int(max_length)}
    async with httpx.AsyncClient(timeout=VISION_TIMEOUT) as client:
        return await _post_file(client, f"{BLIP_URL}/caption/file", data, fields)
    return await _post_file(f"{BLIP_URL}/caption/file", data, fields)


@app.post("/analyze/yolo/file")
@@ -174,8 +590,7 @@ async def analyze_yolo_file(
):
    data = await file.read()
    fields = {"conf": float(conf)}
    async with httpx.AsyncClient(timeout=VISION_TIMEOUT) as client:
        return await _post_file(client, f"{YOLO_URL}/detect/file", data, fields)
    return await _post_file(f"{YOLO_URL}/detect/file", data, fields)


@app.post("/analyze/all")
@@ -188,13 +603,11 @@ async def analyze_all(payload: Dict[str, Any]):
    blip_req = {"url": url, "variants": int(payload.get("variants", 3)), "max_length": int(payload.get("max_length", 60))}
    yolo_req = {"url": url, "conf": float(payload.get("conf", 0.25))}

    async with httpx.AsyncClient(timeout=VISION_TIMEOUT) as client:
        clip_task = _post_json(client, f"{CLIP_URL}/analyze", clip_req)
        blip_task = _post_json(client, f"{BLIP_URL}/caption", blip_req)
        yolo_task = _post_json(client, f"{YOLO_URL}/detect", yolo_req)

        clip_res, blip_res, yolo_res = await asyncio.gather(clip_task, blip_task, yolo_task)

    clip_res, blip_res, yolo_res = await asyncio.gather(
        _post_json(f"{CLIP_URL}/analyze", clip_req),
        _post_json(f"{BLIP_URL}/caption", blip_req),
        _post_json(f"{YOLO_URL}/detect", yolo_req),
    )
    return {"clip": clip_res, "blip": blip_res, "yolo": yolo_res}


@@ -202,8 +615,7 @@ async def analyze_all(payload: Dict[str, Any]):

@app.post("/vectors/upsert")
|
||||
async def vectors_upsert(payload: Dict[str, Any]):
|
||||
async with httpx.AsyncClient(timeout=VISION_TIMEOUT) as client:
|
||||
return await _post_json(client, f"{QDRANT_SVC_URL}/upsert", payload)
|
||||
return await _post_json(f"{QDRANT_SVC_URL}/upsert", payload)
|
||||
|
||||
|
||||
@app.post("/vectors/upsert/file")
|
||||
@@ -221,20 +633,17 @@ async def vectors_upsert_file(
|
||||
fields["collection"] = collection
|
||||
if metadata_json is not None:
|
||||
fields["metadata_json"] = metadata_json
|
||||
async with httpx.AsyncClient(timeout=VISION_TIMEOUT) as client:
|
||||
return await _post_file(client, f"{QDRANT_SVC_URL}/upsert/file", data, fields)
|
||||
return await _post_file(f"{QDRANT_SVC_URL}/upsert/file", data, fields)
|
||||
|
||||
|
||||
@app.post("/vectors/upsert/vector")
|
||||
async def vectors_upsert_vector(payload: Dict[str, Any]):
|
||||
async with httpx.AsyncClient(timeout=VISION_TIMEOUT) as client:
|
||||
return await _post_json(client, f"{QDRANT_SVC_URL}/upsert/vector", payload)
|
||||
return await _post_json(f"{QDRANT_SVC_URL}/upsert/vector", payload)
|
||||
|
||||
|
||||
@app.post("/vectors/search")
|
||||
async def vectors_search(payload: Dict[str, Any]):
|
||||
async with httpx.AsyncClient(timeout=VISION_TIMEOUT) as client:
|
||||
return await _post_json(client, f"{QDRANT_SVC_URL}/search", payload)
|
||||
return await _post_json(f"{QDRANT_SVC_URL}/search", payload)
|
||||
|
||||
|
||||
@app.post("/vectors/search/file")
|
||||
@@ -258,72 +667,69 @@ async def vectors_search_file(
|
||||
fields["hnsw_ef"] = int(hnsw_ef)
|
||||
if filter_metadata_json is not None:
|
||||
fields["filter_metadata_json"] = filter_metadata_json
|
||||
async with httpx.AsyncClient(timeout=VISION_TIMEOUT) as client:
|
||||
return await _post_file(client, f"{QDRANT_SVC_URL}/search/file", data, fields)
|
||||
return await _post_file(f"{QDRANT_SVC_URL}/search/file", data, fields)
|
||||
|
||||
|
||||
@app.post("/vectors/search/vector")
|
||||
async def vectors_search_vector(payload: Dict[str, Any]):
|
||||
async with httpx.AsyncClient(timeout=VISION_TIMEOUT) as client:
|
||||
return await _post_json(client, f"{QDRANT_SVC_URL}/search/vector", payload)
|
||||
return await _post_json(f"{QDRANT_SVC_URL}/search/vector", payload)
|
||||
|
||||
|
||||
@app.post("/vectors/delete")
|
||||
async def vectors_delete(payload: Dict[str, Any]):
|
||||
async with httpx.AsyncClient(timeout=VISION_TIMEOUT) as client:
|
||||
return await _post_json(client, f"{QDRANT_SVC_URL}/delete", payload)
|
||||
return await _post_json(f"{QDRANT_SVC_URL}/delete", payload)
|
||||
|
||||
|
||||
@app.get("/vectors/collections")
|
||||
async def vectors_collections():
|
||||
async with httpx.AsyncClient(timeout=VISION_TIMEOUT) as client:
|
||||
return await _get_json(client, f"{QDRANT_SVC_URL}/collections")
|
||||
return await _get_json(f"{QDRANT_SVC_URL}/collections")
|
||||
|
||||
|
||||
@app.post("/vectors/collections")
|
||||
async def vectors_create_collection(payload: Dict[str, Any]):
|
||||
async with httpx.AsyncClient(timeout=VISION_TIMEOUT) as client:
|
||||
return await _post_json(client, f"{QDRANT_SVC_URL}/collections", payload)
|
||||
return await _post_json(f"{QDRANT_SVC_URL}/collections", payload)
|
||||
|
||||
|
||||
@app.get("/vectors/collections/{name}")
|
||||
async def vectors_collection_info(name: str):
|
||||
async with httpx.AsyncClient(timeout=VISION_TIMEOUT) as client:
|
||||
return await _get_json(client, f"{QDRANT_SVC_URL}/collections/{name}")
|
||||
return await _get_json(f"{QDRANT_SVC_URL}/collections/{name}")
|
||||
|
||||
|
||||
@app.get("/vectors/inspect")
|
||||
async def vectors_inspect():
|
||||
"""Full diagnostic summary for all Qdrant collections (HNSW, optimizer, payload indexes, RAM estimate)."""
|
||||
async with httpx.AsyncClient(timeout=VISION_TIMEOUT) as client:
|
||||
return await _get_json(client, f"{QDRANT_SVC_URL}/inspect")
|
||||
t0 = time.perf_counter()
|
||||
logger.info("vectors_inspect: start")
|
||||
result = await _get_json(f"{QDRANT_SVC_URL}/inspect")
|
||||
logger.info("vectors_inspect: done elapsed_ms=%.1f", (time.perf_counter() - t0) * 1000)
|
||||
return result
|
||||
|
||||
|
||||
@app.delete("/vectors/collections/{name}")
|
||||
async def vectors_delete_collection(name: str):
|
||||
async with httpx.AsyncClient(timeout=VISION_TIMEOUT) as client:
|
||||
r = await client.delete(f"{QDRANT_SVC_URL}/collections/{name}")
|
||||
if r.status_code >= 400:
|
||||
raise HTTPException(status_code=502, detail=f"Upstream error: {r.status_code}")
|
||||
return r.json()
|
||||
try:
|
||||
r = await get_http_client().delete(f"{QDRANT_SVC_URL}/collections/{name}")
|
||||
except httpx.RequestError as exc:
|
||||
raise HTTPException(status_code=502, detail=f"Upstream request failed: {exc}")
|
||||
if r.status_code >= 400:
|
||||
raise HTTPException(status_code=502, detail=f"Upstream error: {r.status_code}")
|
||||
return r.json()
|
||||
|
||||
|
||||
@app.get("/vectors/points/{point_id}")
|
||||
async def vectors_get_point(point_id: str, collection: Optional[str] = None):
|
||||
async with httpx.AsyncClient(timeout=VISION_TIMEOUT) as client:
|
||||
params = {}
|
||||
if collection:
|
||||
params["collection"] = collection
|
||||
return await _get_json(client, f"{QDRANT_SVC_URL}/points/{point_id}", params=params)
|
||||
params = {}
|
||||
if collection:
|
||||
params["collection"] = collection
|
||||
return await _get_json(f"{QDRANT_SVC_URL}/points/{point_id}", params=params)
|
||||
|
||||
|
||||
@app.get("/vectors/points/by-original-id/{original_id}")
|
||||
async def vectors_get_point_by_original_id(original_id: str, collection: Optional[str] = None):
|
||||
async with httpx.AsyncClient(timeout=VISION_TIMEOUT) as client:
|
||||
params = {}
|
||||
if collection:
|
||||
params["collection"] = collection
|
||||
return await _get_json(client, f"{QDRANT_SVC_URL}/points/by-original-id/{original_id}", params=params)
|
||||
params = {}
|
||||
if collection:
|
||||
params["collection"] = collection
|
||||
return await _get_json(f"{QDRANT_SVC_URL}/points/by-original-id/{original_id}", params=params)
|
||||
|
||||
|
||||
# ---- File-based universal analyze ----
|
||||
@@ -337,39 +743,69 @@ async def analyze_all_file(
|
||||
max_length: int = Form(60),
|
||||
):
|
||||
data = await file.read()
|
||||
async with httpx.AsyncClient(timeout=VISION_TIMEOUT) as client:
|
||||
clip_task = _post_file(client, f"{CLIP_URL}/analyze/file", data, {"limit": limit})
|
||||
blip_task = _post_file(client, f"{BLIP_URL}/caption/file", data, {"variants": variants, "max_length": max_length})
|
||||
yolo_task = _post_file(client, f"{YOLO_URL}/detect/file", data, {"conf": conf})
|
||||
|
||||
clip_res, blip_res, yolo_res = await asyncio.gather(clip_task, blip_task, yolo_task)
|
||||
|
||||
clip_res, blip_res, yolo_res = await asyncio.gather(
|
||||
_post_file(f"{CLIP_URL}/analyze/file", data, {"limit": limit}),
|
||||
_post_file(f"{BLIP_URL}/caption/file", data, {"variants": variants, "max_length": max_length}),
|
||||
_post_file(f"{YOLO_URL}/detect/file", data, {"conf": conf}),
|
||||
)
|
||||
return {"clip": clip_res, "blip": blip_res, "yolo": yolo_res}
|
||||
|
||||
|
||||
# ---- Maturity / NSFW analysis endpoints ----
|
||||
|
||||
|
||||
def _assert_maturity_enabled() -> None:
|
||||
if not MATURITY_ENABLED:
|
||||
raise HTTPException(status_code=503, detail="Maturity service is disabled")
|
||||
|
||||
|
||||
@app.post("/analyze/maturity")
|
||||
async def analyze_maturity(req: MaturityRequest):
|
||||
"""Analyze an image URL for maturity / NSFW content.
|
||||
|
||||
Returns a normalized maturity signal including maturity_label (safe/mature),
|
||||
confidence, score, optional sublabels, and an action_hint for Nova moderation.
|
||||
"""
|
||||
_assert_maturity_enabled()
|
||||
if not req.url:
|
||||
raise HTTPException(status_code=400, detail="url is required")
|
||||
logger.info("analyze_maturity: url=%s", req.url)
|
||||
return await _post_json(f"{MATURITY_URL}/analyze", req.model_dump())
|
||||
|
||||
|
||||
@app.post("/analyze/maturity/file")
|
||||
async def analyze_maturity_file(file: UploadFile = File(...)):
|
||||
"""Analyze an uploaded image file for maturity / NSFW content.
|
||||
|
||||
Returns the same normalized maturity signal as /analyze/maturity.
|
||||
"""
|
||||
_assert_maturity_enabled()
|
||||
data = await file.read()
|
||||
logger.info("analyze_maturity_file: filename=%s size=%d", file.filename, len(data))
|
||||
return await _post_file(f"{MATURITY_URL}/analyze/file", data, {})
|
||||
|
||||
|
||||
# ---- Card renderer endpoints ----
|
||||
|
||||
@app.get("/cards/templates")
|
||||
async def cards_templates():
|
||||
"""List available card templates."""
|
||||
async with httpx.AsyncClient(timeout=VISION_TIMEOUT) as client:
|
||||
return await _get_json(client, f"{CARD_RENDERER_URL}/templates")
|
||||
return await _get_json(f"{CARD_RENDERER_URL}/templates")
|
||||
|
||||
|
||||
@app.post("/cards/render")
|
||||
async def cards_render(payload: Dict[str, Any]):
|
||||
"""Render a Nova card from a remote image URL. Returns binary image bytes."""
|
||||
async with httpx.AsyncClient(timeout=VISION_TIMEOUT) as client:
|
||||
try:
|
||||
resp = await client.post(f"{CARD_RENDERER_URL}/render", json=payload)
|
||||
except httpx.RequestError as exc:
|
||||
raise HTTPException(status_code=502, detail=f"card-renderer unreachable: {exc}")
|
||||
if resp.status_code >= 400:
|
||||
raise HTTPException(status_code=502, detail=f"card-renderer error {resp.status_code}: {resp.text[:1000]}")
|
||||
return Response(
|
||||
content=resp.content,
|
||||
media_type=resp.headers.get("content-type", "image/webp"),
|
||||
)
|
||||
try:
|
||||
resp = await get_http_client().post(f"{CARD_RENDERER_URL}/render", json=payload)
|
||||
except httpx.RequestError as exc:
|
||||
raise HTTPException(status_code=502, detail=f"card-renderer unreachable: {exc}")
|
||||
if resp.status_code >= 400:
|
||||
raise HTTPException(status_code=502, detail=f"card-renderer error {resp.status_code}: {resp.text[:1000]}")
|
||||
return Response(
|
||||
content=resp.content,
|
||||
media_type=resp.headers.get("content-type", "image/webp"),
|
||||
)
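Because this endpoint returns raw image bytes rather than JSON, callers should write `resp.content` straight to disk or object storage. A hedged sketch follows; the host, API key and payload keys beyond the remote image URL mentioned in the docstring are assumptions.

```python
# Hedged sketch for /cards/render; the "template" key is an assumption based on
# the /cards/templates endpoint existing, and host/API key are placeholders.
import httpx

resp = httpx.post(
    "http://localhost:8000/cards/render",
    headers={"X-API-Key": "your-api-key"},
    json={"url": "https://example.com/artwork.png", "template": "default"},
    timeout=300,
)
resp.raise_for_status()
with open("card.webp", "wb") as fh:
    fh.write(resp.content)  # binary card image, typically image/webp
```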
|
||||
|
||||
|
||||
@app.post("/cards/render/file")
|
||||
@@ -409,28 +845,26 @@ async def cards_render_file(
|
||||
fields["tags_json"] = tags_json
|
||||
|
||||
upload_files = {"file": (file.filename or "image", data, file.content_type or "application/octet-stream")}
|
||||
async with httpx.AsyncClient(timeout=VISION_TIMEOUT) as client:
|
||||
try:
|
||||
resp = await client.post(
|
||||
f"{CARD_RENDERER_URL}/render/file",
|
||||
data={k: str(v) for k, v in fields.items()},
|
||||
files=upload_files,
|
||||
)
|
||||
except httpx.RequestError as exc:
|
||||
raise HTTPException(status_code=502, detail=f"card-renderer unreachable: {exc}")
|
||||
if resp.status_code >= 400:
|
||||
raise HTTPException(status_code=502, detail=f"card-renderer error {resp.status_code}: {resp.text[:1000]}")
|
||||
return Response(
|
||||
content=resp.content,
|
||||
media_type=resp.headers.get("content-type", "image/webp"),
|
||||
try:
|
||||
resp = await get_http_client().post(
|
||||
f"{CARD_RENDERER_URL}/render/file",
|
||||
data={k: str(v) for k, v in fields.items()},
|
||||
files=upload_files,
|
||||
)
|
||||
except httpx.RequestError as exc:
|
||||
raise HTTPException(status_code=502, detail=f"card-renderer unreachable: {exc}")
|
||||
if resp.status_code >= 400:
|
||||
raise HTTPException(status_code=502, detail=f"card-renderer error {resp.status_code}: {resp.text[:1000]}")
|
||||
return Response(
|
||||
content=resp.content,
|
||||
media_type=resp.headers.get("content-type", "image/webp"),
|
||||
)
|
||||
|
||||
|
||||
@app.post("/cards/render/meta")
|
||||
async def cards_render_meta(payload: Dict[str, Any]):
|
||||
"""Return crop and layout metadata for a card render (no image produced)."""
|
||||
async with httpx.AsyncClient(timeout=VISION_TIMEOUT) as client:
|
||||
return await _post_json(client, f"{CARD_RENDERER_URL}/render/meta", payload)
|
||||
return await _post_json(f"{CARD_RENDERER_URL}/render/meta", payload)
|
||||
|
||||
|
||||
# ---- Qdrant administration endpoints (index management + collection config) ----
|
||||
@@ -438,26 +872,22 @@ async def cards_render_meta(payload: Dict[str, Any]):
|
||||
@app.get("/vectors/collections/{name}/indexes")
|
||||
async def vectors_collection_indexes(name: str):
|
||||
"""List payload indexes for a collection."""
|
||||
async with httpx.AsyncClient(timeout=VISION_TIMEOUT) as client:
|
||||
return await _get_json(client, f"{QDRANT_SVC_URL}/collections/{name}/indexes")
|
||||
return await _get_json(f"{QDRANT_SVC_URL}/collections/{name}/indexes")
|
||||
|
||||
|
||||
@app.post("/vectors/collections/{name}/indexes")
|
||||
async def vectors_create_payload_index(name: str, payload: Dict[str, Any]):
|
||||
"""Create a payload index on a field in a collection."""
|
||||
async with httpx.AsyncClient(timeout=VISION_TIMEOUT) as client:
|
||||
return await _post_json(client, f"{QDRANT_SVC_URL}/collections/{name}/indexes", payload)
|
||||
return await _post_json(f"{QDRANT_SVC_URL}/collections/{name}/indexes", payload)
|
||||
|
||||
|
||||
@app.post("/vectors/collections/{name}/ensure-indexes")
|
||||
async def vectors_ensure_indexes(name: str, payload: Dict[str, Any]):
|
||||
"""Idempotently ensure payload indexes exist for a list of fields."""
|
||||
async with httpx.AsyncClient(timeout=VISION_TIMEOUT) as client:
|
||||
return await _post_json(client, f"{QDRANT_SVC_URL}/collections/{name}/ensure-indexes", payload)
|
||||
return await _post_json(f"{QDRANT_SVC_URL}/collections/{name}/ensure-indexes", payload)
|
||||
|
||||
|
||||
@app.post("/vectors/collections/{name}/configure")
|
||||
async def vectors_configure_collection(name: str, payload: Dict[str, Any]):
|
||||
"""Update HNSW and optimizer configuration for a collection."""
|
||||
async with httpx.AsyncClient(timeout=VISION_TIMEOUT) as client:
|
||||
return await _post_json(client, f"{QDRANT_SVC_URL}/collections/{name}/configure", payload)
|
||||
return await _post_json(f"{QDRANT_SVC_URL}/collections/{name}/configure", payload)
llm/Dockerfile (new file, 55 lines)
@@ -0,0 +1,55 @@
FROM debian:bookworm-slim AS builder

ARG LLAMA_CPP_REPO=https://github.com/ggml-org/llama.cpp.git
ARG LLAMA_CPP_REF=

RUN apt-get update && apt-get install -y --no-install-recommends \
    build-essential \
    ca-certificates \
    cmake \
    git \
    && rm -rf /var/lib/apt/lists/*

WORKDIR /src
RUN git clone --depth 1 ${LLAMA_CPP_REPO} llama.cpp \
    && if [ -n "${LLAMA_CPP_REF}" ]; then cd llama.cpp && git fetch --depth 1 origin "${LLAMA_CPP_REF}" && git checkout "${LLAMA_CPP_REF}"; fi

WORKDIR /src/llama.cpp
RUN cmake -B build -DCMAKE_BUILD_TYPE=Release -DGGML_NATIVE=OFF -DLLAMA_BUILD_SERVER=ON \
    && cmake --build build --config Release --target llama-server -j"$(nproc)"

FROM python:3.11-slim

RUN apt-get update && apt-get install -y --no-install-recommends \
    bash \
    ca-certificates \
    curl \
    libgomp1 \
    && rm -rf /var/lib/apt/lists/*

WORKDIR /app

COPY llm/requirements.txt /app/requirements.txt
RUN pip install --no-cache-dir -r /app/requirements.txt

COPY --from=builder /src/llama.cpp/build/bin/llama-server /usr/local/bin/llama-server
COPY --from=builder /src/llama.cpp/build/bin/*.so* /usr/local/lib/
RUN ldconfig
COPY llm/main.py /app/main.py
COPY llm/entrypoint.sh /entrypoint.sh

RUN chmod +x /entrypoint.sh /usr/local/bin/llama-server

ENV MODEL_PATH=/models/Qwen3-1.7B-Instruct-Q4_K_M.gguf \
    LLM_MODEL_NAME=qwen3-1.7b-instruct-q4_k_m \
    LLM_CONTEXT_SIZE=4096 \
    LLM_THREADS=4 \
    LLM_GPU_LAYERS=0 \
    LLM_PORT=8080 \
    LLAMA_SERVER_PORT=8081 \
    LLM_STARTUP_TIMEOUT=120 \
    LLM_EXTRA_ARGS=

EXPOSE 8080

ENTRYPOINT ["/entrypoint.sh"]
llm/entrypoint.sh (new file, 25 lines)
@@ -0,0 +1,25 @@
#!/usr/bin/env bash
set -eu

MODEL_PATH="${MODEL_PATH:-/models/Qwen3-1.7B-Instruct-Q4_K_M.gguf}"
LLM_MODEL_NAME="${LLM_MODEL_NAME:-qwen3-1.7b-instruct-q4_k_m}"
LLM_CONTEXT_SIZE="${LLM_CONTEXT_SIZE:-4096}"
LLM_THREADS="${LLM_THREADS:-4}"
LLM_GPU_LAYERS="${LLM_GPU_LAYERS:-0}"
LLM_PORT="${LLM_PORT:-8080}"
LLAMA_SERVER_PORT="${LLAMA_SERVER_PORT:-8081}"

if [ ! -f "$MODEL_PATH" ]; then
    echo "llm startup failed: model file not found at $MODEL_PATH" >&2
    echo "Mount a GGUF model into ./models/qwen3 and set MODEL_PATH if the filename differs." >&2
    exit 1
fi

if [ ! -r "$MODEL_PATH" ]; then
    echo "llm startup failed: model file is not readable at $MODEL_PATH" >&2
    exit 1
fi

echo "Starting llm shim model=$LLM_MODEL_NAME model_path=$MODEL_PATH public_port=$LLM_PORT upstream_port=$LLAMA_SERVER_PORT ctx=$LLM_CONTEXT_SIZE threads=$LLM_THREADS gpu_layers=$LLM_GPU_LAYERS"

exec python -m uvicorn main:app --host 0.0.0.0 --port "$LLM_PORT"
llm/main.py (new file, 211 lines)
@@ -0,0 +1,211 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
import shlex
|
||||
import subprocess
|
||||
import time
|
||||
from contextlib import asynccontextmanager
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
import httpx
|
||||
from fastapi import FastAPI, HTTPException, Request
|
||||
from fastapi.responses import JSONResponse
|
||||
|
||||
logger = logging.getLogger("llm")
|
||||
logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(name)s %(message)s")
|
||||
|
||||
MODEL_PATH = os.getenv("MODEL_PATH", "/models/Qwen3-1.7B-Instruct-Q4_K_M.gguf")
|
||||
LLM_MODEL_NAME = os.getenv("LLM_MODEL_NAME", "qwen3-1.7b-instruct-q4_k_m")
|
||||
LLM_CONTEXT_SIZE = int(os.getenv("LLM_CONTEXT_SIZE", "4096"))
|
||||
LLM_THREADS = int(os.getenv("LLM_THREADS", "4"))
|
||||
LLM_GPU_LAYERS = int(os.getenv("LLM_GPU_LAYERS", "0"))
|
||||
LLAMA_SERVER_PORT = int(os.getenv("LLAMA_SERVER_PORT", "8081"))
|
||||
LLM_STARTUP_TIMEOUT = float(os.getenv("LLM_STARTUP_TIMEOUT", "120"))
|
||||
LLM_EXTRA_ARGS = os.getenv("LLM_EXTRA_ARGS", "")
|
||||
|
||||
_llama_process: subprocess.Popen[bytes] | None = None
|
||||
_http_client: httpx.AsyncClient | None = None
|
||||
|
||||
|
||||
def _upstream_base_url() -> str:
|
||||
return f"http://127.0.0.1:{LLAMA_SERVER_PORT}"
|
||||
|
||||
|
||||
def _ensure_http_client() -> httpx.AsyncClient:
|
||||
if _http_client is None:
|
||||
raise RuntimeError("HTTP client not initialised")
|
||||
return _http_client
|
||||
|
||||
|
||||
def _validate_model_path() -> None:
|
||||
model_file = Path(MODEL_PATH)
|
||||
if not model_file.is_file():
|
||||
raise RuntimeError(f"model file not found at {MODEL_PATH}")
|
||||
if not os.access(model_file, os.R_OK):
|
||||
raise RuntimeError(f"model file is not readable at {MODEL_PATH}")
|
||||
|
||||
|
||||
def _build_llama_command() -> list[str]:
|
||||
command = [
|
||||
"/usr/local/bin/llama-server",
|
||||
"--host",
|
||||
"127.0.0.1",
|
||||
"--port",
|
||||
str(LLAMA_SERVER_PORT),
|
||||
"--model",
|
||||
MODEL_PATH,
|
||||
"--alias",
|
||||
LLM_MODEL_NAME,
|
||||
"--ctx-size",
|
||||
str(LLM_CONTEXT_SIZE),
|
||||
"--threads",
|
||||
str(LLM_THREADS),
|
||||
"--n-gpu-layers",
|
||||
str(LLM_GPU_LAYERS),
|
||||
]
|
||||
if LLM_EXTRA_ARGS.strip():
|
||||
command.extend(shlex.split(LLM_EXTRA_ARGS))
|
||||
return command
|
||||
|
||||
|
||||
def _llama_running() -> bool:
|
||||
return _llama_process is not None and _llama_process.poll() is None
|
||||
|
||||
|
||||
async def _wait_for_llama_ready() -> None:
|
||||
deadline = time.monotonic() + LLM_STARTUP_TIMEOUT
|
||||
last_error: Optional[Exception] = None
|
||||
|
||||
while time.monotonic() < deadline:
|
||||
if _llama_process is not None and _llama_process.poll() is not None:
|
||||
raise RuntimeError(f"llama-server exited with code {_llama_process.poll()}")
|
||||
|
||||
try:
|
||||
response = await _ensure_http_client().get(f"{_upstream_base_url()}/v1/models", timeout=5)
|
||||
if response.status_code == 200:
|
||||
logger.info("llm service: llama-server ready")
|
||||
return
|
||||
except Exception as exc:
|
||||
last_error = exc
|
||||
|
||||
await asyncio.sleep(1)
|
||||
|
||||
raise RuntimeError(f"llama-server did not become ready within {LLM_STARTUP_TIMEOUT}s: {last_error}")
|
||||
|
||||
|
||||
async def _stop_llama_process() -> None:
|
||||
global _llama_process
|
||||
|
||||
if _llama_process is None:
|
||||
return
|
||||
|
||||
if _llama_process.poll() is None:
|
||||
_llama_process.terminate()
|
||||
try:
|
||||
await asyncio.to_thread(_llama_process.wait, timeout=10)
|
||||
except subprocess.TimeoutExpired:
|
||||
_llama_process.kill()
|
||||
await asyncio.to_thread(_llama_process.wait, timeout=5)
|
||||
|
||||
_llama_process = None
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
global _http_client, _llama_process
|
||||
|
||||
_validate_model_path()
|
||||
_http_client = httpx.AsyncClient(timeout=httpx.Timeout(120, connect=5))
|
||||
|
||||
command = _build_llama_command()
|
||||
logger.info("llm service: starting llama-server model=%s ctx=%s threads=%s gpu_layers=%s upstream_port=%s", LLM_MODEL_NAME, LLM_CONTEXT_SIZE, LLM_THREADS, LLM_GPU_LAYERS, LLAMA_SERVER_PORT)
|
||||
_llama_process = subprocess.Popen(command)
|
||||
|
||||
try:
|
||||
await _wait_for_llama_ready()
|
||||
yield
|
||||
finally:
|
||||
await _stop_llama_process()
|
||||
if _http_client is not None:
|
||||
await _http_client.aclose()
|
||||
_http_client = None
|
||||
|
||||
|
||||
app = FastAPI(title="Skinbase LLM Service", version="1.0.0", lifespan=lifespan)
|
||||
|
||||
|
||||
def _health_payload(status: str) -> Dict[str, Any]:
|
||||
return {
|
||||
"status": status,
|
||||
"model": Path(MODEL_PATH).name,
|
||||
"model_alias": LLM_MODEL_NAME,
|
||||
"context_size": LLM_CONTEXT_SIZE,
|
||||
"threads": LLM_THREADS,
|
||||
"gpu_layers": LLM_GPU_LAYERS,
|
||||
}
|
||||
|
||||
|
||||
async def _proxy_request(method: str, path: str, *, body: bytes | None = None) -> Dict[str, Any]:
|
||||
if not _llama_running():
|
||||
raise HTTPException(status_code=503, detail="llama-server is not running")
|
||||
|
||||
headers = {"content-type": "application/json"} if body is not None else None
|
||||
try:
|
||||
response = await _ensure_http_client().request(
|
||||
method,
|
||||
f"{_upstream_base_url()}{path}",
|
||||
content=body,
|
||||
headers=headers,
|
||||
timeout=httpx.Timeout(120, connect=5),
|
||||
)
|
||||
except httpx.TimeoutException as exc:
|
||||
raise HTTPException(status_code=504, detail=f"llama-server timed out: {exc}")
|
||||
except httpx.RequestError as exc:
|
||||
raise HTTPException(status_code=503, detail=f"llama-server unavailable: {exc}")
|
||||
|
||||
if response.status_code >= 400:
|
||||
detail: Any
|
||||
try:
|
||||
detail = response.json()
|
||||
except Exception:
|
||||
detail = response.text[:1000]
|
||||
raise HTTPException(status_code=response.status_code, detail=detail)
|
||||
|
||||
try:
|
||||
return response.json()
|
||||
except Exception as exc:
|
||||
raise HTTPException(status_code=502, detail=f"llama-server returned invalid JSON: {exc}")
|
||||
|
||||
|
||||
@app.exception_handler(HTTPException)
|
||||
async def handle_http_exception(_: Request, exc: HTTPException):
|
||||
return JSONResponse(status_code=exc.status_code, content={"error": {"code": "llm_service_error", "message": str(exc.detail)}})
|
||||
|
||||
|
||||
@app.get("/health")
|
||||
async def health():
|
||||
if not _llama_running():
|
||||
return JSONResponse(status_code=503, content=_health_payload("unavailable"))
|
||||
|
||||
try:
|
||||
response = await _ensure_http_client().get(f"{_upstream_base_url()}/v1/models", timeout=5)
|
||||
if response.status_code != 200:
|
||||
return JSONResponse(status_code=503, content=_health_payload("degraded"))
|
||||
except Exception:
|
||||
return JSONResponse(status_code=503, content=_health_payload("degraded"))
|
||||
|
||||
return _health_payload("ok")
|
||||
|
||||
|
||||
@app.get("/v1/models")
|
||||
async def list_models():
|
||||
return await _proxy_request("GET", "/v1/models")
|
||||
|
||||
|
||||
@app.post("/v1/chat/completions")
|
||||
async def chat_completions(request: Request):
|
||||
body = await request.body()
|
||||
return await _proxy_request("POST", "/v1/chat/completions", body=body)
llm/requirements.txt (new file, 3 lines)
@@ -0,0 +1,3 @@
fastapi==0.115.5
uvicorn[standard]==0.30.6
httpx==0.27.2
maturity/Dockerfile (new file, 17 lines)
@@ -0,0 +1,17 @@
FROM python:3.11-slim

WORKDIR /app

RUN apt-get update && apt-get install -y --no-install-recommends \
    ca-certificates \
    && rm -rf /var/lib/apt/lists/*

COPY maturity/requirements.txt /app/requirements.txt
RUN pip install --no-cache-dir -r /app/requirements.txt

COPY maturity /app
COPY common /app/common

ENV PYTHONUNBUFFERED=1

CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "8000"]
maturity/main.py (new file, 221 lines)
@@ -0,0 +1,221 @@
"""Skinbase Maturity Analysis Service.

Uses a dedicated NSFW/maturity ViT classifier (Falconsai/nsfw_image_detection)
to produce a structured, normalized maturity signal consumable by Nova moderation
workflows.

Endpoints
---------
GET  /health        — liveness + model info
POST /analyze       — URL-based maturity analysis
POST /analyze/file  — multipart file-upload maturity analysis

Response contract
-----------------
{
    "maturity_label": "safe" | "mature",
    "confidence": float,        # confidence in the maturity_label decision
    "score": float,             # raw NSFW probability from model (0.0 – 1.0)
    "labels": list[str],        # sublabels when mature, e.g. ["nsfw"]
    "model": str,               # model identifier
    "threshold_used": float,    # threshold that produced the label
    "analysis_time_ms": float,
    "source": "maturity-service",
    "action_hint": "safe" | "review" | "flag_high",
    "advisory": str             # short human-readable reason
}
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import os
|
||||
import time
|
||||
from typing import List, Optional
|
||||
|
||||
import torch
|
||||
from fastapi import FastAPI, HTTPException, UploadFile, File, Form
|
||||
from pydantic import BaseModel, Field
|
||||
from transformers import AutoImageProcessor, AutoModelForImageClassification
|
||||
|
||||
from common.image_io import fetch_url_bytes, bytes_to_pil, ImageLoadError
|
||||
|
||||
logger = logging.getLogger("maturity")
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format="%(asctime)s %(levelname)s %(name)s %(message)s",
|
||||
)
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Configuration — all tunable via environment variables
|
||||
# ---------------------------------------------------------------------------
|
||||
MATURITY_MODEL = os.getenv("MATURITY_MODEL", "Falconsai/nsfw_image_detection")
|
||||
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
|
||||
# Main threshold: score >= this → "mature" + "flag_high"
|
||||
THRESHOLD_MATURE = float(os.getenv("MATURITY_THRESHOLD_MATURE", "0.80"))
|
||||
# Review band: score >= this (but below MATURE) → "mature" + "review"
|
||||
THRESHOLD_REVIEW = float(os.getenv("MATURITY_THRESHOLD_REVIEW", "0.60"))
|
||||
|
||||
# Max image bytes — same default as the rest of the stack (50 MB)
|
||||
MAX_IMAGE_BYTES = int(os.getenv("MAX_IMAGE_BYTES", str(50 * 1024 * 1024)))
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Model loading — done once at import time so Docker start captures it
|
||||
# ---------------------------------------------------------------------------
|
||||
logger.info("maturity service: loading model %s on %s", MATURITY_MODEL, DEVICE)
|
||||
_t_load = time.perf_counter()
|
||||
|
||||
_processor = AutoImageProcessor.from_pretrained(MATURITY_MODEL)
|
||||
_model = AutoModelForImageClassification.from_pretrained(MATURITY_MODEL).to(DEVICE).eval()
|
||||
|
||||
# Build a label→index map from the model config so we are not fragile to label
|
||||
# ordering changes.
|
||||
_ID2LABEL: dict[int, str] = _model.config.id2label # e.g. {0: "normal", 1: "nsfw"}
|
||||
_NSFW_IDX: int = next(
|
||||
(i for i, lbl in _ID2LABEL.items() if "nsfw" in lbl.lower() or "explicit" in lbl.lower()),
|
||||
1, # fallback: assume index 1 is the NSFW class
|
||||
)
|
||||
|
||||
logger.info(
|
||||
"maturity service: model loaded elapsed_ms=%.1f device=%s id2label=%s nsfw_idx=%s",
|
||||
(time.perf_counter() - _t_load) * 1000,
|
||||
DEVICE,
|
||||
_ID2LABEL,
|
||||
_NSFW_IDX,
|
||||
)
|
||||
|
||||
app = FastAPI(title="Skinbase Maturity Service", version="1.0.0")
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Schemas
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class MaturityRequest(BaseModel):
|
||||
url: Optional[str] = Field(default=None, description="Public image URL to analyse")
|
||||
|
||||
|
||||
class MaturityResponse(BaseModel):
|
||||
maturity_label: str = Field(description='"safe" or "mature"')
|
||||
confidence: float = Field(description="Confidence in the maturity_label decision (0–1)")
|
||||
score: float = Field(description="Raw NSFW probability from the model (0–1)")
|
||||
labels: List[str] = Field(description="Sublabels when mature content is detected")
|
||||
model: str = Field(description="Model identifier / version")
|
||||
threshold_used: float = Field(description="Threshold applied to produce the label")
|
||||
analysis_time_ms: float
|
||||
source: str = "maturity-service"
|
||||
action_hint: str = Field(description='"safe", "review", or "flag_high"')
|
||||
advisory: str = Field(description="Short human-readable reason for the decision")
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Inference helper
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _run_inference(data: bytes) -> MaturityResponse:
|
||||
"""Run maturity inference on raw image bytes and return a structured response."""
|
||||
t0 = time.perf_counter()
|
||||
|
||||
try:
|
||||
img = bytes_to_pil(data)
|
||||
except Exception as exc:
|
||||
raise HTTPException(status_code=400, detail=f"Cannot decode image: {exc}") from exc
|
||||
|
||||
inputs = _processor(images=img, return_tensors="pt").to(DEVICE)
|
||||
|
||||
with torch.no_grad():
|
||||
logits = _model(**inputs).logits
|
||||
|
||||
probs = torch.softmax(logits, dim=-1)[0]
|
||||
nsfw_score = float(probs[_NSFW_IDX])
|
||||
|
||||
elapsed_ms = (time.perf_counter() - t0) * 1000
|
||||
|
||||
# Derive label, action_hint, advisory, sublabels
|
||||
if nsfw_score >= THRESHOLD_MATURE:
|
||||
maturity_label = "mature"
|
||||
action_hint = "flag_high"
|
||||
advisory = "High-confidence mature content detected"
|
||||
labels = ["nsfw"]
|
||||
threshold_used = THRESHOLD_MATURE
|
||||
confidence = nsfw_score
|
||||
elif nsfw_score >= THRESHOLD_REVIEW:
|
||||
maturity_label = "mature"
|
||||
action_hint = "review"
|
||||
advisory = "Possible mature content — review recommended"
|
||||
labels = ["nsfw"]
|
||||
threshold_used = THRESHOLD_REVIEW
|
||||
confidence = nsfw_score
|
||||
else:
|
||||
maturity_label = "safe"
|
||||
action_hint = "safe"
|
||||
advisory = "Content appears safe"
|
||||
labels = []
|
||||
threshold_used = THRESHOLD_REVIEW
|
||||
confidence = 1.0 - nsfw_score # confidence in the "safe" verdict
|
||||
|
||||
logger.info(
|
||||
"maturity inference: maturity_label=%s action_hint=%s score=%.4f "
|
||||
"confidence=%.4f threshold_mature=%.2f threshold_review=%.2f elapsed_ms=%.1f",
|
||||
maturity_label,
|
||||
action_hint,
|
||||
nsfw_score,
|
||||
confidence,
|
||||
THRESHOLD_MATURE,
|
||||
THRESHOLD_REVIEW,
|
||||
elapsed_ms,
|
||||
)
|
||||
|
||||
return MaturityResponse(
|
||||
maturity_label=maturity_label,
|
||||
confidence=round(confidence, 4),
|
||||
score=round(nsfw_score, 4),
|
||||
labels=labels,
|
||||
model=MATURITY_MODEL,
|
||||
threshold_used=threshold_used,
|
||||
analysis_time_ms=round(elapsed_ms, 1),
|
||||
source="maturity-service",
|
||||
action_hint=action_hint,
|
||||
advisory=advisory,
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Endpoints
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@app.get("/health")
|
||||
def health():
|
||||
return {
|
||||
"status": "ok",
|
||||
"device": DEVICE,
|
||||
"model": MATURITY_MODEL,
|
||||
"threshold_mature": THRESHOLD_MATURE,
|
||||
"threshold_review": THRESHOLD_REVIEW,
|
||||
}
|
||||
|
||||
|
||||
@app.post("/analyze", response_model=MaturityResponse)
|
||||
def analyze(req: MaturityRequest):
|
||||
"""URL-based maturity analysis."""
|
||||
if not req.url:
|
||||
raise HTTPException(status_code=400, detail="url is required")
|
||||
try:
|
||||
data = fetch_url_bytes(req.url, max_bytes=MAX_IMAGE_BYTES)
|
||||
except ImageLoadError as exc:
|
||||
logger.warning("maturity analyze: image load failed url=%s error=%s", req.url, exc)
|
||||
raise HTTPException(status_code=400, detail=str(exc)) from exc
|
||||
return _run_inference(data)
|
||||
|
||||
|
||||
@app.post("/analyze/file", response_model=MaturityResponse)
|
||||
async def analyze_file(file: UploadFile = File(...)):
|
||||
"""Multipart file-upload maturity analysis."""
|
||||
data = await file.read()
|
||||
if len(data) > MAX_IMAGE_BYTES:
|
||||
raise HTTPException(
|
||||
status_code=413,
|
||||
detail=f"File exceeds maximum allowed size of {MAX_IMAGE_BYTES} bytes",
|
||||
)
|
||||
if not data:
|
||||
raise HTTPException(status_code=400, detail="Empty file upload")
|
||||
return _run_inference(data)
maturity/requirements.txt (new file, 8 lines)
@@ -0,0 +1,8 @@
fastapi==0.115.5
uvicorn[standard]==0.30.6
python-multipart==0.0.9
requests==2.32.3
pillow==10.4.0
torch==2.4.1
torchvision==0.19.1
transformers==4.44.2
models/qwen3/README.md (new file, 9 lines)
@@ -0,0 +1,9 @@
Place the Qwen3 GGUF model file for the local llm profile in this directory.

Expected default filename:

- `Qwen3-1.7B-Instruct-Q4_K_M.gguf`

You can use a different filename, but then set `MODEL_PATH` in `.env` to match the mounted path inside the container.

The model is intentionally not auto-downloaded at startup. Operators should provision it explicitly so container startup is predictable.
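A quick pre-flight check before enabling the profile can save a failed container start; this is a hedged sketch that assumes the default filename and that it is run from the repository root.

```python
# Hedged pre-flight check (run from the repo root) before setting LLM_ENABLED=true;
# adjust the path if you override MODEL_PATH with a different filename.
from pathlib import Path

model = Path("models/qwen3/Qwen3-1.7B-Instruct-Q4_K_M.gguf")
if not model.is_file():
    raise SystemExit(f"missing model file: {model} (place the GGUF here or set MODEL_PATH)")
print(f"found {model.name} ({model.stat().st_size / 1e9:.2f} GB)")
```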
qdrant/backfill_payloads.py (new file, 244 lines)
@@ -0,0 +1,244 @@
#!/usr/bin/env python3
"""
backfill_payloads.py — Repair missing payload fields for existing Qdrant points.

WHY THIS EXISTS
---------------
If artworks were initially upserted without the full payload (is_public, is_nsfw,
category_id, content_type_id, is_deleted, status), those fields will have near-0%
coverage in the payload index. This prevents filtered searches (e.g., is_public=true)
from returning correct results.

This script scrolls through all points in the collection, detects which ones are
missing the required fields, and lets you supply a lookup function that fetches the
correct values from your source-of-truth (database, API, CSV, etc.).

HOW TO ADAPT
------------
1. Fill in `fetch_payloads_for_ids()` to return a dict mapping qdrant-point-id ->
   payload patch for each missing ID. The simplest approach is a SQL query to your
   Skinbase database using the `_original_id` stored in the Qdrant payload.

2. Run the script directly (no app container needed, just qdrant-client installed):

       # Inside Docker network:
       docker exec -it vision-qdrant-svc-1 python /app/backfill_payloads.py

       # Or from host with qdrant-client installed:
       pip install qdrant-client
       QDRANT_HOST=localhost QDRANT_PORT=6333 python qdrant/backfill_payloads.py

3. The script is resumable: it prints the last-processed offset ID so you can
   restart from where you left off by setting RESUME_OFFSET env var.

REQUIRED ENV VARS (all optional, sensible defaults for Docker Compose):
    QDRANT_HOST        default: qdrant
    QDRANT_PORT        default: 6333
    COLLECTION_NAME    default: images
    BATCH_SIZE         default: 256
    DRY_RUN            default: 0 (set to 1 to only report, no writes)
    RESUME_OFFSET      default: None (UUID or int of last seen point to skip to)

FIELDS CHECKED
--------------
user_id, is_public, is_nsfw, category_id, content_type_id, is_deleted, status
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
import logging
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from qdrant_client import QdrantClient
|
||||
from qdrant_client.models import PointStruct
|
||||
|
||||
logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s")
|
||||
log = logging.getLogger("backfill")
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Config
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
QDRANT_HOST = os.getenv("QDRANT_HOST", "qdrant")
|
||||
QDRANT_PORT = int(os.getenv("QDRANT_PORT", "6333"))
|
||||
COLLECTION_NAME = os.getenv("COLLECTION_NAME", "images")
|
||||
BATCH_SIZE = int(os.getenv("BATCH_SIZE", "256"))
|
||||
DRY_RUN = os.getenv("DRY_RUN", "0") == "1"
|
||||
RESUME_OFFSET: Optional[str] = os.getenv("RESUME_OFFSET") # point id to continue from
|
||||
|
||||
# Fields that MUST be present in every point payload for filtered search to work.
|
||||
REQUIRED_FIELDS = [
|
||||
"user_id",
|
||||
"is_public",
|
||||
"is_nsfw",
|
||||
"category_id",
|
||||
"content_type_id",
|
||||
"is_deleted",
|
||||
"status",
|
||||
]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# TODO: implement this function to fetch correct payload values from your DB.
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def fetch_payloads_for_ids(
|
||||
missing_ids: List[Any],
|
||||
original_ids: Dict[Any, str],
|
||||
) -> Dict[Any, Dict[str, Any]]:
|
||||
"""Return a mapping of qdrant_point_id -> payload_patch for the given IDs.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
missing_ids:
|
||||
List of Qdrant point IDs (UUID strings or ints) that need patching.
|
||||
original_ids:
|
||||
Dict mapping qdrant_point_id -> original application ID (stored in
|
||||
`_original_id` payload field, or the point id itself if they match).
|
||||
|
||||
Returns
|
||||
-------
|
||||
Dict mapping each point id to a dict of fields to set.
|
||||
Only include the fields you want to SET — existing fields are not cleared.
|
||||
|
||||
Example implementation (pseudo-code for your database):
|
||||
|
||||
import psycopg2
|
||||
conn = psycopg2.connect(os.environ["DATABASE_URL"])
|
||||
cur = conn.cursor()
|
||||
orig_id_list = list(original_ids.values())
|
||||
cur.execute(
|
||||
"SELECT id, user_id, is_public, is_nsfw, category_id, "
|
||||
" content_type_id, is_deleted, status "
|
||||
"FROM artworks WHERE id = ANY(%s)",
|
||||
(orig_id_list,)
|
||||
)
|
||||
rows = cur.fetchall()
|
||||
by_orig = {str(r[0]): r for r in rows}
|
||||
result = {}
|
||||
for qdrant_id, orig_id in original_ids.items():
|
||||
row = by_orig.get(str(orig_id))
|
||||
if row:
|
||||
result[qdrant_id] = {
|
||||
"user_id": str(row[1]),
|
||||
"is_public": bool(row[2]),
|
||||
"is_nsfw": bool(row[3]),
|
||||
"category_id": int(row[4]) if row[4] is not None else None,
|
||||
"content_type_id": int(row[5]) if row[5] is not None else None,
|
||||
"is_deleted": bool(row[6]),
|
||||
"status": str(row[7]),
|
||||
}
|
||||
return result
|
||||
"""
|
||||
# ---- STUB: replace with your real implementation ----
|
||||
log.warning(
|
||||
"fetch_payloads_for_ids() is a stub — no data will be patched.\n"
|
||||
"Edit qdrant/backfill_payloads.py and implement this function."
|
||||
)
|
||||
return {}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Core backfill logic
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def run_backfill():
|
||||
log.info(
|
||||
"backfill start collection=%s host=%s:%s dry_run=%s batch=%d",
|
||||
COLLECTION_NAME, QDRANT_HOST, QDRANT_PORT, DRY_RUN, BATCH_SIZE,
|
||||
)
|
||||
|
||||
qclient = QdrantClient(host=QDRANT_HOST, port=QDRANT_PORT)
|
||||
|
||||
# Verify collection exists
|
||||
collections = [c.name for c in qclient.get_collections().collections]
|
||||
if COLLECTION_NAME not in collections:
|
||||
log.error("Collection '%s' not found. Existing: %s", COLLECTION_NAME, collections)
|
||||
sys.exit(1)
|
||||
|
||||
info = qclient.get_collection(COLLECTION_NAME)
|
||||
total_points = info.points_count or 0
|
||||
log.info("collection points_count=%d indexed_vectors=%d", total_points, info.indexed_vectors_count or 0)
|
||||
|
||||
offset = RESUME_OFFSET
|
||||
scanned = 0
|
||||
missing_count = 0
|
||||
patched = 0
|
||||
errors = 0
|
||||
t_start = time.perf_counter()
|
||||
|
||||
while True:
|
||||
points, next_offset = qclient.scroll(
|
||||
collection_name=COLLECTION_NAME,
|
||||
offset=offset,
|
||||
limit=BATCH_SIZE,
|
||||
with_payload=True,
|
||||
with_vectors=False,
|
||||
)
|
||||
|
||||
if not points:
|
||||
break
|
||||
|
||||
scanned += len(points)
|
||||
|
||||
# Find points missing any required field
|
||||
needs_patch: List[Any] = []
|
||||
original_ids: Dict[Any, str] = {}
|
||||
for pt in points:
|
||||
payload = pt.payload or {}
|
||||
missing = [f for f in REQUIRED_FIELDS if f not in payload or payload[f] is None]
|
||||
if missing:
|
||||
needs_patch.append(pt.id)
|
||||
# Use _original_id if present (IDs that couldn't be stored as Qdrant IDs)
|
||||
original_ids[pt.id] = str(payload.get("_original_id", pt.id))
|
||||
missing_count += 1
|
||||
|
||||
if needs_patch:
|
||||
patches = fetch_payloads_for_ids(needs_patch, original_ids)
|
||||
for pid, patch in patches.items():
|
||||
if not patch:
|
||||
continue
|
||||
if DRY_RUN:
|
||||
log.info("[DRY RUN] would patch id=%s fields=%s", pid, list(patch.keys()))
|
||||
else:
|
||||
try:
|
||||
qclient.set_payload(
|
||||
collection_name=COLLECTION_NAME,
|
||||
payload=patch,
|
||||
points=[pid],
|
||||
)
|
||||
patched += 1
|
||||
except Exception as exc:
|
||||
log.error("failed to patch id=%s: %s", pid, exc)
|
||||
errors += 1
|
||||
|
||||
elapsed = time.perf_counter() - t_start
|
||||
rate = scanned / elapsed if elapsed > 0 else 0
|
||||
log.info(
|
||||
"progress scanned=%d/%d missing=%d patched=%d errors=%d rate=%.0f/s offset=%s",
|
||||
scanned, total_points, missing_count, patched, errors, rate, next_offset,
|
||||
)
|
||||
|
||||
if next_offset is None:
|
||||
break
|
||||
offset = next_offset
|
||||
|
||||
elapsed = time.perf_counter() - t_start
|
||||
log.info(
|
||||
"backfill complete scanned=%d missing=%d patched=%d errors=%d elapsed=%.1fs",
|
||||
scanned, missing_count, patched, errors, elapsed,
|
||||
)
|
||||
|
||||
if missing_count > 0 and patched == 0 and not DRY_RUN:
|
||||
log.warning(
|
||||
"%d points are missing payload fields but 0 were patched. "
|
||||
"Implement fetch_payloads_for_ids() in this script.",
|
||||
missing_count,
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
run_backfill()
qdrant/main.py (modified, 140 lines)
@@ -1,6 +1,9 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
import time
|
||||
import uuid
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
@@ -39,6 +42,9 @@ SEARCH_HNSW_EF = int(os.getenv("SEARCH_HNSW_EF", "128"))
|
||||
app = FastAPI(title="Skinbase Qdrant Service", version="1.0.0")
|
||||
client: QdrantClient = None # type: ignore[assignment]
|
||||
|
||||
logger = logging.getLogger("qdrant_svc")
|
||||
logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(name)s %(message)s")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Startup / shutdown
|
||||
@@ -47,8 +53,24 @@ client: QdrantClient = None # type: ignore[assignment]
|
||||
@app.on_event("startup")
|
||||
def startup():
|
||||
global client
|
||||
t0 = time.perf_counter()
|
||||
logger.info("qdrant_svc startup: connecting to %s:%s", QDRANT_HOST, QDRANT_PORT)
|
||||
client = QdrantClient(host=QDRANT_HOST, port=QDRANT_PORT)
|
||||
_ensure_collection()
|
||||
# Warm the gRPC/HTTP connection and load collection metadata into memory
|
||||
# so the first real request does not pay the one-time connect cost.
|
||||
try:
|
||||
info = client.get_collection(COLLECTION_NAME)
|
||||
logger.info(
|
||||
"qdrant_svc startup: warm ping OK collection=%s points=%s indexed=%s elapsed_ms=%.1f",
|
||||
COLLECTION_NAME,
|
||||
info.points_count,
|
||||
info.indexed_vectors_count,
|
||||
(time.perf_counter() - t0) * 1000,
|
||||
)
|
||||
except Exception as exc:
|
||||
logger.warning("qdrant_svc startup: warm ping failed (non-fatal): %s", exc)
|
||||
logger.info("qdrant_svc startup complete elapsed_ms=%.1f", (time.perf_counter() - t0) * 1000)
|
||||
|
||||
|
||||
def _ensure_collection():
|
||||
@@ -68,6 +90,44 @@ def _ensure_collection():
|
||||
default_segment_number=4, # parallelism-friendly segment count
|
||||
),
|
||||
)
|
||||
_ensure_payload_indexes()
|
||||
|
||||
|
||||
# Payload fields needed for filtered search. type values match PayloadSchemaType.
|
||||
_REQUIRED_PAYLOAD_INDEXES: List[Dict[str, str]] = [
|
||||
{"field": "user_id", "type": "keyword"},
|
||||
{"field": "is_public", "type": "bool"},
|
||||
{"field": "is_nsfw", "type": "bool"},
|
||||
{"field": "is_deleted", "type": "bool"},
|
||||
{"field": "status", "type": "keyword"},
|
||||
{"field": "category_id", "type": "integer"},
|
||||
{"field": "content_type_id", "type": "integer"},
|
||||
]
|
||||
|
||||
|
||||
def _ensure_payload_indexes():
|
||||
"""Create any missing payload indexes for the default collection."""
|
||||
try:
|
||||
info = client.get_collection(COLLECTION_NAME)
|
||||
except Exception:
|
||||
return # collection doesn't exist yet, will be created next
|
||||
existing = set(info.payload_schema.keys()) if info.payload_schema else set()
|
||||
for spec in _REQUIRED_PAYLOAD_INDEXES:
|
||||
field = spec["field"]
|
||||
if field in existing:
|
||||
continue
|
||||
schema = _SCHEMA_TYPE_MAP.get(spec["type"])
|
||||
if schema is None:
|
||||
continue
|
||||
try:
|
||||
client.create_payload_index(
|
||||
collection_name=COLLECTION_NAME,
|
||||
field_name=field,
|
||||
field_schema=schema,
|
||||
)
|
||||
logger.info("_ensure_payload_indexes: created index field=%s type=%s", field, spec["type"])
|
||||
except Exception as exc:
|
||||
logger.warning("_ensure_payload_indexes: could not index field=%s: %s", field, exc)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -213,37 +273,47 @@ def health():
|
||||
|
||||
|
||||
@app.get("/inspect")
|
||||
def inspect():
|
||||
async def inspect():
|
||||
"""Return a full diagnostic summary for every collection.
|
||||
|
||||
Covers: vector counts, segment counts, HNSW config, optimizer config,
|
||||
quantization, payload indexes and their coverage. Designed for production
|
||||
health checks and the Qdrant optimization workflow.
|
||||
"""
|
||||
t0 = time.perf_counter()
|
||||
logger.info("inspect: start")
|
||||
|
||||
try:
|
||||
all_collections = client.get_collections().collections
|
||||
all_collections = await asyncio.get_event_loop().run_in_executor(
|
||||
None, lambda: client.get_collections().collections
|
||||
)
|
||||
except Exception as exc:
|
||||
return {"status": "error", "detail": str(exc)}
|
||||
|
||||
result = {}
|
||||
for col_desc in all_collections:
|
||||
name = col_desc.name
|
||||
t_col = time.perf_counter()
|
||||
try:
|
||||
info = client.get_collection(name)
|
||||
info = await asyncio.get_event_loop().run_in_executor(
|
||||
None, lambda n=name: client.get_collection(n)
|
||||
)
|
||||
cfg = info.config
|
||||
hnsw = cfg.hnsw_config
|
||||
opt = cfg.optimizer_config
|
||||
quant = cfg.quantization_config
|
||||
params = cfg.params
|
||||
|
||||
# Estimate raw RAM footprint: vectors * dim * 4 bytes * 1.5 safety factor
|
||||
vec_count = info.vectors_count or 0
|
||||
# `vectors_count` is deprecated and returns 0 in newer Qdrant versions;
|
||||
# use `points_count` as the canonical count for coverage and RAM estimates.
|
||||
points_count = info.points_count or 0
|
||||
vec_count = info.vectors_count or points_count # kept for backwards compat display
|
||||
vec_dim = (
|
||||
params.vectors.size
|
||||
if hasattr(params.vectors, "size")
|
||||
else VECTOR_DIM
|
||||
)
|
||||
ram_estimate_mb = round(vec_count * vec_dim * 4 * 1.5 / 1_048_576, 1)
|
||||
ram_estimate_mb = round(points_count * vec_dim * 4 * 1.5 / 1_048_576, 1)
|
||||
|
||||
result[name] = {
|
||||
"status": info.status.value if info.status else None,
|
||||
@@ -272,16 +342,21 @@ def inspect():
|
||||
k: {
|
||||
"type": v.data_type.value if hasattr(v.data_type, "value") else str(v.data_type),
|
||||
"points": v.points,
|
||||
"coverage_pct": round(v.points / max(vec_count, 1) * 100, 1),
|
||||
"coverage_pct": round(v.points / max(points_count, 1) * 100, 1),
|
||||
}
|
||||
for k, v in (info.payload_schema or {}).items()
|
||||
},
|
||||
"payload_index_count": len(info.payload_schema or {}),
|
||||
"search_hnsw_ef": SEARCH_HNSW_EF,
|
||||
}
|
||||
logger.info(
|
||||
"inspect: collection=%s points=%s elapsed_ms=%.1f",
|
||||
name, points_count, (time.perf_counter() - t_col) * 1000,
|
||||
)
|
||||
except Exception as exc:
|
||||
result[name] = {"error": str(exc)}
|
||||
|
||||
logger.info("inspect: done collections=%d total_elapsed_ms=%.1f", len(result), (time.perf_counter() - t0) * 1000)
|
||||
return {"collections": result, "total": len(result)}
|
||||
|
||||
|
||||
@@ -755,3 +830,54 @@ def configure_collection(name: str, req: CollectionConfigRequest):
|
||||
}
|
||||
except Exception as exc:
|
||||
raise HTTPException(500, str(exc))
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Payload update (used by backfill / repair tooling)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class BatchUpdatePayloadRequest(BaseModel):
|
||||
"""Update payload fields for a batch of points identified by their Qdrant IDs.
|
||||
|
||||
``updates`` is a list of ``{"id": "<qdrant-point-id>", "payload": {...}}`` items.
|
||||
Only the supplied payload keys are merged into existing payloads (set_payload
|
||||
semantics — existing keys not mentioned are left untouched).
|
||||
"""
|
||||
updates: List[Dict[str, Any]]
|
||||
collection: Optional[str] = None
|
||||
|
||||
|
||||
@app.post("/points/batch-update-payload")
|
||||
def batch_update_payload(req: BatchUpdatePayloadRequest):
|
||||
"""Merge payload fields for a list of points without touching vectors.
|
||||
|
||||
Useful for backfilling metadata (is_public, category_id, etc.) for points
|
||||
that were upserted without full payload coverage.
|
||||
"""
|
||||
if not req.updates:
|
||||
return {"updated": 0, "collection": _col(req.collection)}
|
||||
|
||||
col = _col(req.collection)
|
||||
updated = 0
|
||||
errors: List[str] = []
|
||||
|
||||
for item in req.updates:
|
||||
pid_raw = item.get("id")
|
||||
payload = item.get("payload", {})
|
||||
if pid_raw is None or not payload:
|
||||
continue
|
||||
pid = _point_id(str(pid_raw))
|
||||
try:
|
||||
client.set_payload(
|
||||
collection_name=col,
|
||||
payload=payload,
|
||||
points=[pid],
|
||||
)
|
||||
updated += 1
|
||||
except Exception as exc:
|
||||
errors.append(f"{pid_raw}: {exc}")
|
||||
|
||||
result: Dict[str, Any] = {"updated": updated, "collection": col}
|
||||
if errors:
|
||||
result["errors"] = errors
|
||||
return result
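A hedged example of the request shape this endpoint expects; the point IDs and field values below are invented for illustration.

```python
# Hedged example payload for POST /points/batch-update-payload; IDs and field
# values are invented, and only the keys you send are merged (set_payload semantics).
import httpx

payload = {
    "collection": "images",
    "updates": [
        {"id": "7d8b8a1e-0000-0000-0000-000000000001", "payload": {"is_public": True}},
        {"id": "7d8b8a1e-0000-0000-0000-000000000002", "payload": {"is_deleted": True}},
    ],
}
resp = httpx.post("http://qdrant-svc:8000/points/batch-update-payload", json=payload, timeout=60)
resp.raise_for_status()
print(resp.json())  # e.g. {"updated": 2, "collection": "images"}
```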
|
||||
|
||||
tests/test_gateway_llm.py (new file, 313 lines)
@@ -0,0 +1,313 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import importlib
|
||||
import os
|
||||
import unittest
|
||||
from typing import Any, Dict, Optional
|
||||
from unittest.mock import patch
|
||||
|
||||
import httpx
|
||||
|
||||
|
||||
BASE_ENV = {
|
||||
"API_KEY": "test-key",
|
||||
"CLIP_URL": "http://clip:8000",
|
||||
"BLIP_URL": "http://blip:8000",
|
||||
"YOLO_URL": "http://yolo:8000",
|
||||
"QDRANT_SVC_URL": "http://qdrant-svc:8000",
|
||||
"CARD_RENDERER_URL": "http://card-renderer:8000",
|
||||
"MATURITY_URL": "http://maturity:8000",
|
||||
"LLM_URL": "http://llm:8080",
|
||||
"LLM_TIMEOUT": "5",
|
||||
"LLM_DEFAULT_MODEL": "qwen3-1.7b-instruct-q4_k_m",
|
||||
"LLM_MAX_TOKENS_DEFAULT": "256",
|
||||
"LLM_MAX_TOKENS_HARD_LIMIT": "1024",
|
||||
"LLM_MAX_REQUEST_BYTES": "65536",
|
||||
}
|
||||
|
||||
|
||||
def load_gateway_module(*, llm_enabled: bool, extra_env: Optional[Dict[str, str]] = None):
|
||||
env = BASE_ENV | {"LLM_ENABLED": "true" if llm_enabled else "false"}
|
||||
if extra_env:
|
||||
env |= extra_env
|
||||
with patch.dict(os.environ, env, clear=False):
|
||||
import gateway.main as gateway_main
|
||||
|
||||
return importlib.reload(gateway_main)
|
||||
|
||||
|
||||
class StubUpstreamClient:
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
request_responses: Optional[Dict[tuple[str, str], httpx.Response]] = None,
|
||||
get_responses: Optional[Dict[str, httpx.Response]] = None,
|
||||
request_exception: Optional[Exception] = None,
|
||||
get_exception: Optional[Exception] = None,
|
||||
):
|
||||
self.request_responses = request_responses or {}
|
||||
self.get_responses = get_responses or {}
|
||||
self.request_exception = request_exception
|
||||
self.get_exception = get_exception
|
||||
|
||||
async def request(self, method: str, url: str, **_: Any) -> httpx.Response:
|
||||
if self.request_exception is not None:
|
||||
raise self.request_exception
|
||||
response = self.request_responses.get((method.upper(), url))
|
||||
if response is None:
|
||||
return httpx.Response(404, json={"error": {"message": f"No stub for {method} {url}"}})
|
||||
return response
|
||||
|
||||
async def get(self, url: str, **_: Any) -> httpx.Response:
|
||||
if self.get_exception is not None:
|
||||
raise self.get_exception
|
||||
response = self.get_responses.get(url)
|
||||
if response is None:
|
||||
return httpx.Response(404, json={"detail": f"No stub for GET {url}"})
|
||||
return response
|
||||
|
||||
|
||||
class GatewayLLMTests(unittest.IsolatedAsyncioTestCase):
|
||||
async def _request(
|
||||
self,
|
||||
module: Any,
|
||||
method: str,
|
||||
path: str,
|
||||
*,
|
||||
headers: Optional[Dict[str, str]] = None,
|
||||
json_payload: Optional[Dict[str, Any]] = None,
|
||||
content: Optional[bytes] = None,
|
||||
) -> httpx.Response:
|
||||
transport = httpx.ASGITransport(app=module.app)
|
||||
async with httpx.AsyncClient(transport=transport, base_url="http://testserver") as client:
|
||||
return await client.request(method, path, headers=headers, json=json_payload, content=content)
|
||||
|
||||
async def test_llm_endpoint_requires_api_key(self):
|
||||
module = load_gateway_module(llm_enabled=True)
|
||||
|
||||
response = await self._request(
|
||||
module,
|
||||
"POST",
|
||||
"/ai/chat",
|
||||
json_payload={"messages": [{"role": "user", "content": "hello"}]},
|
||||
)
|
||||
|
||||
self.assertEqual(response.status_code, 401)
|
||||
self.assertEqual(response.json()["error"]["code"], "unauthorized")
|
||||
|
||||
async def test_llm_disabled_returns_503(self):
|
||||
module = load_gateway_module(llm_enabled=False)
|
||||
|
||||
response = await self._request(
|
||||
module,
|
||||
"POST",
|
||||
"/ai/chat",
|
||||
headers={"X-API-Key": "test-key"},
|
||||
json_payload={"messages": [{"role": "user", "content": "hello"}]},
|
||||
)
|
||||
|
||||
self.assertEqual(response.status_code, 503)
|
||||
self.assertEqual(response.json()["error"]["code"], "llm_disabled")
|
||||
|
||||
async def test_unreachable_llm_returns_normalized_503(self):
|
||||
module = load_gateway_module(llm_enabled=True)
|
||||
stub_client = StubUpstreamClient(
|
||||
request_exception=httpx.ConnectError("boom", request=httpx.Request("POST", f"{module.LLM_URL}/v1/chat/completions")),
|
||||
)
|
||||
|
||||
with patch.object(module, "get_http_client", return_value=stub_client):
|
||||
response = await self._request(
|
||||
module,
|
||||
"POST",
|
||||
"/ai/chat",
|
||||
headers={"X-API-Key": "test-key"},
|
||||
json_payload={"messages": [{"role": "user", "content": "hello"}]},
|
||||
)
|
||||
|
||||
self.assertEqual(response.status_code, 503)
|
||||
self.assertEqual(response.json()["error"]["code"], "llm_unavailable")
|
||||
|
||||
async def test_validation_error_is_normalized(self):
|
||||
module = load_gateway_module(llm_enabled=True)
|
||||
|
||||
response = await self._request(
|
||||
module,
|
||||
"POST",
|
||||
"/ai/chat",
|
||||
headers={"X-API-Key": "test-key"},
|
||||
json_payload={"messages": []},
|
||||
)
|
||||
|
||||
self.assertEqual(response.status_code, 422)
|
||||
self.assertEqual(response.json()["error"]["code"], "validation_error")
|
||||
|
||||
async def test_invalid_json_returns_400(self):
|
||||
module = load_gateway_module(llm_enabled=True)
|
||||
|
||||
response = await self._request(
|
||||
module,
|
||||
"POST",
|
||||
"/v1/chat/completions",
|
||||
headers={"X-API-Key": "test-key", "Content-Type": "application/json"},
|
||||
content=b'{"messages": [',
|
||||
)
|
||||
|
||||
self.assertEqual(response.status_code, 400)
|
||||
self.assertEqual(response.json()["error"]["code"], "invalid_json")
|
||||
|
||||
async def test_oversized_payload_returns_413(self):
|
||||
module = load_gateway_module(llm_enabled=True, extra_env={"LLM_MAX_REQUEST_BYTES": "64"})
|
||||
|
||||
response = await self._request(
|
||||
module,
|
||||
"POST",
|
||||
"/v1/chat/completions",
|
||||
headers={"X-API-Key": "test-key"},
|
||||
json_payload={"messages": [{"role": "user", "content": "x" * 5000}]},
|
||||
)
|
||||
|
||||
self.assertEqual(response.status_code, 413)
|
||||
self.assertEqual(response.json()["error"]["code"], "payload_too_large")
|
||||
|
||||
async def test_ai_chat_normalizes_successful_response(self):
|
||||
module = load_gateway_module(llm_enabled=True)
|
||||
upstream_response = httpx.Response(
|
||||
200,
|
||||
json={
|
||||
"id": "chatcmpl-1",
|
||||
"object": "chat.completion",
|
||||
"model": "qwen3-1.7b-instruct-q4_k_m",
|
||||
"choices": [
|
||||
{
|
||||
"index": 0,
|
||||
"finish_reason": "stop",
|
||||
"message": {"role": "assistant", "content": "Generated text here."},
|
||||
}
|
||||
],
|
||||
"usage": {"prompt_tokens": 12, "completion_tokens": 8, "total_tokens": 20},
|
||||
},
|
||||
request=httpx.Request("POST", f"{module.LLM_URL}/v1/chat/completions"),
|
||||
)
|
||||
stub_client = StubUpstreamClient(
|
||||
request_responses={("POST", f"{module.LLM_URL}/v1/chat/completions"): upstream_response},
|
||||
)
|
||||
|
||||
with patch.object(module, "get_http_client", return_value=stub_client):
|
||||
response = await self._request(
|
||||
module,
|
||||
"POST",
|
||||
"/ai/chat",
|
||||
headers={"X-API-Key": "test-key"},
|
||||
json_payload={"messages": [{"role": "user", "content": "hello"}]},
|
||||
)
|
||||
|
||||
self.assertEqual(response.status_code, 200)
|
||||
self.assertEqual(
|
||||
response.json(),
|
||||
{
|
||||
"model": "qwen3-1.7b-instruct-q4_k_m",
|
||||
"content": "Generated text here.",
|
||||
"finish_reason": "stop",
|
||||
"usage": {"prompt_tokens": 12, "completion_tokens": 8, "total_tokens": 20},
|
||||
},
|
||||
)
|
||||
|
||||
async def test_ai_health_reports_reachable_llm(self):
|
||||
module = load_gateway_module(llm_enabled=True)
|
||||
stub_client = StubUpstreamClient(
|
||||
get_responses={
|
||||
f"{module.LLM_URL}/health": httpx.Response(
|
||||
200,
|
||||
json={"status": "ok", "model": "Qwen3-1.7B-Instruct-Q4_K_M.gguf", "context_size": 4096, "threads": 4},
|
||||
request=httpx.Request("GET", f"{module.LLM_URL}/health"),
|
||||
)
|
||||
},
|
||||
)
|
||||
|
||||
with patch.object(module, "get_http_client", return_value=stub_client):
|
||||
response = await self._request(
|
||||
module,
|
||||
"GET",
|
||||
"/ai/health",
|
||||
headers={"X-API-Key": "test-key"},
|
||||
)
|
||||
|
||||
self.assertEqual(response.status_code, 200)
|
||||
self.assertTrue(response.json()["reachable"])
|
||||
self.assertEqual(response.json()["default_model"], "qwen3-1.7b-instruct-q4_k_m")
|
||||
|
||||
async def test_timeout_returns_504(self):
|
||||
module = load_gateway_module(llm_enabled=True)
|
||||
stub_client = StubUpstreamClient(
|
||||
request_exception=httpx.ReadTimeout("timeout", request=httpx.Request("POST", f"{module.LLM_URL}/v1/chat/completions")),
|
||||
)
|
||||
|
||||
with patch.object(module, "get_http_client", return_value=stub_client):
|
||||
response = await self._request(
|
||||
module,
|
||||
"POST",
|
||||
"/ai/chat",
|
||||
headers={"X-API-Key": "test-key"},
|
||||
json_payload={"messages": [{"role": "user", "content": "hello"}]},
|
||||
)
|
||||
|
||||
self.assertEqual(response.status_code, 504)
|
||||
self.assertEqual(response.json()["error"]["code"], "llm_timeout")
|
||||
|
||||
async def test_upstream_400_is_preserved(self):
|
||||
module = load_gateway_module(llm_enabled=True)
|
||||
bad_request_response = httpx.Response(
|
||||
400,
|
||||
json={"error": {"message": "Bad prompt"}},
|
||||
request=httpx.Request("POST", f"{module.LLM_URL}/v1/chat/completions"),
|
||||
)
|
||||
stub_client = StubUpstreamClient(
|
||||
request_responses={("POST", f"{module.LLM_URL}/v1/chat/completions"): bad_request_response},
|
||||
)
|
||||
|
||||
with patch.object(module, "get_http_client", return_value=stub_client):
|
||||
response = await self._request(
|
||||
module,
|
||||
"POST",
|
||||
"/v1/chat/completions",
|
||||
headers={"X-API-Key": "test-key"},
|
||||
json_payload={"messages": [{"role": "user", "content": "hello"}]},
|
||||
)
|
||||
|
||||
self.assertEqual(response.status_code, 400)
|
||||
self.assertEqual(response.json()["error"]["code"], "llm_rejected_request")
|
||||
|
||||
async def test_models_endpoint_returns_upstream_metadata(self):
|
||||
module = load_gateway_module(llm_enabled=True)
|
||||
models_response = httpx.Response(
|
||||
200,
|
||||
json={
|
||||
"object": "list",
|
||||
"data": [
|
||||
{
|
||||
"id": "qwen3-1.7b-instruct-q4_k_m",
|
||||
"object": "model",
|
||||
"owned_by": "self-hosted",
|
||||
}
|
||||
],
|
||||
},
|
||||
request=httpx.Request("GET", f"{module.LLM_URL}/v1/models"),
|
||||
)
|
||||
stub_client = StubUpstreamClient(
|
||||
request_responses={("GET", f"{module.LLM_URL}/v1/models"): models_response},
|
||||
)
|
||||
|
||||
with patch.object(module, "get_http_client", return_value=stub_client):
|
||||
response = await self._request(
|
||||
module,
|
||||
"GET",
|
||||
"/v1/models",
|
||||
headers={"X-API-Key": "test-key"},
|
||||
)
|
||||
|
||||
self.assertEqual(response.status_code, 200)
|
||||
self.assertEqual(response.json()["data"][0]["id"], "qwen3-1.7b-instruct-q4_k_m")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
79
tests/test_llm_service.py
Normal file
@@ -0,0 +1,79 @@
from __future__ import annotations

import importlib
import os
import unittest
from types import SimpleNamespace
from unittest.mock import patch

import httpx


BASE_ENV = {
    "MODEL_PATH": "D:/Sites/vision/models/qwen3/Qwen3-1.7B-Instruct-Q4_K_M.gguf",
    "LLM_MODEL_NAME": "qwen3-1.7b-instruct-q4_k_m",
    "LLM_CONTEXT_SIZE": "4096",
    "LLM_THREADS": "4",
    "LLM_GPU_LAYERS": "0",
    "LLM_PORT": "8080",
    "LLAMA_SERVER_PORT": "8081",
}


def load_llm_module():
    with patch.dict(os.environ, BASE_ENV, clear=False):
        import llm.main as llm_main

        return importlib.reload(llm_main)


class StubHTTPClient:
    def __init__(self, response: httpx.Response):
        self.response = response

    async def get(self, *_args, **_kwargs):
        return self.response


class LLMServiceTests(unittest.IsolatedAsyncioTestCase):
    async def test_health_returns_repo_owned_contract(self):
        module = load_llm_module()
        module._llama_process = SimpleNamespace(poll=lambda: None)
        module._http_client = StubHTTPClient(
            httpx.Response(200, json={"object": "list", "data": []}, request=httpx.Request("GET", "http://127.0.0.1:8081/v1/models"))
        )

        transport = httpx.ASGITransport(app=module.app)
        async with httpx.AsyncClient(transport=transport, base_url="http://testserver") as client:
            response = await client.get("/health")

        self.assertEqual(response.status_code, 200)
        self.assertEqual(
            response.json(),
            {
                "status": "ok",
                "model": "Qwen3-1.7B-Instruct-Q4_K_M.gguf",
                "model_alias": "qwen3-1.7b-instruct-q4_k_m",
                "context_size": 4096,
                "threads": 4,
                "gpu_layers": 0,
            },
        )

    async def test_health_reports_unavailable_when_process_is_down(self):
        module = load_llm_module()
        module._llama_process = SimpleNamespace(poll=lambda: 1)
        module._http_client = StubHTTPClient(
            httpx.Response(200, json={"object": "list", "data": []}, request=httpx.Request("GET", "http://127.0.0.1:8081/v1/models"))
        )

        transport = httpx.ASGITransport(app=module.app)
        async with httpx.AsyncClient(transport=transport, base_url="http://testserver") as client:
            response = await client.get("/health")

        self.assertEqual(response.status_code, 503)
        self.assertEqual(response.json()["status"], "unavailable")


if __name__ == "__main__":
    unittest.main()