llm: add FastAPI shim, gateway LLM endpoints, tests, and docs
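
Adds OpenAI-compatible /v1/chat/completions and /v1/models pass-through
routes, plus simplified /ai/chat, /ai/models, and /ai/health endpoints behind
the existing x-api-key middleware. A minimal smoke call against /ai/chat,
sketched under the assumption that the gateway listens on localhost:8000 with
LLM_ENABLED=true and a valid API_KEY:

    import httpx

    resp = httpx.post(
        "http://localhost:8000/ai/chat",
        headers={"x-api-key": "<API_KEY>"},
        json={"messages": [{"role": "user", "content": "ping"}]},
        timeout=120.0,
    )
    print(resp.json()["content"])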

2026-04-12 09:41:21 +02:00
parent baf497b015
commit 59c9584250
15 changed files with 1779 additions and 11 deletions


@@ -1,17 +1,18 @@
from __future__ import annotations
import asyncio
import json
import logging
import os
import time
from contextlib import asynccontextmanager
from typing import Any, Dict, List, Literal, Optional
import httpx
from fastapi import FastAPI, HTTPException, UploadFile, File, Form, Request
from fastapi.responses import JSONResponse, Response
from starlette.middleware.base import BaseHTTPMiddleware
from pydantic import BaseModel, Field, ValidationError, field_validator
logger = logging.getLogger("gateway")
logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(name)s %(message)s")
@@ -23,6 +24,16 @@ QDRANT_SVC_URL = os.getenv("QDRANT_SVC_URL", "http://qdrant-svc:8000")
CARD_RENDERER_URL = os.getenv("CARD_RENDERER_URL", "http://card-renderer:8000")
MATURITY_URL = os.getenv("MATURITY_URL", "http://maturity:8000")
MATURITY_ENABLED = os.getenv("MATURITY_ENABLED", "true").lower() not in ("0", "false", "no")
LLM_URL = os.getenv("LLM_URL", "http://llm:8080")
LLM_ENABLED = os.getenv("LLM_ENABLED", "false").lower() not in ("0", "false", "no")
LLM_DEFAULT_MODEL = os.getenv("LLM_DEFAULT_MODEL", "qwen3-1.7b-instruct-q4_k_m")
LLM_TIMEOUT = float(os.getenv("LLM_TIMEOUT", "120"))
LLM_MAX_TOKENS_HARD_LIMIT = max(1, int(os.getenv("LLM_MAX_TOKENS_HARD_LIMIT", "1024")))
LLM_MAX_TOKENS_DEFAULT = min(
LLM_MAX_TOKENS_HARD_LIMIT,
max(1, int(os.getenv("LLM_MAX_TOKENS_DEFAULT", "256"))),
)
LLM_MAX_REQUEST_BYTES = max(1024, int(os.getenv("LLM_MAX_REQUEST_BYTES", "65536")))
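# Illustrative configuration for the LLM path. LLM_ENABLED defaults to
# "false" and must be switched on explicitly; the other values below simply
# restate the defaults parsed above:
#   LLM_ENABLED=true
#   LLM_URL=http://llm:8080
#   LLM_DEFAULT_MODEL=qwen3-1.7b-instruct-q4_k_m
#   LLM_TIMEOUT=120
#   LLM_MAX_TOKENS_DEFAULT=256
#   LLM_MAX_TOKENS_HARD_LIMIT=1024
#   LLM_MAX_REQUEST_BYTES=65536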
VISION_TIMEOUT = float(os.getenv("VISION_TIMEOUT", "20"))
# API key (set via env var `API_KEY`). If unset, the gateway rejects all requests.
@@ -36,6 +47,21 @@ API_KEY = os.getenv("API_KEY")
_http_client: httpx.AsyncClient | None = None
class LLMGatewayError(Exception):
def __init__(
self,
status_code: int,
code: str,
message: str,
details: Optional[Any] = None,
):
self.status_code = status_code
self.code = code
self.message = message
self.details = details
super().__init__(message)
def get_http_client() -> httpx.AsyncClient:
"""Return the shared httpx client. Raises if called before lifespan starts."""
if _http_client is None:
@@ -74,6 +100,17 @@ async def lifespan(app: FastAPI):
except Exception as exc:
logger.warning("gateway startup: qdrant-svc warm ping failed (non-fatal): %s", exc)
if LLM_ENABLED:
try:
t_warm = time.perf_counter()
r = await _http_client.get(f"{LLM_URL}/health", timeout=min(LLM_TIMEOUT, 10))
logger.info(
"gateway startup: llm warm ping done status=%s elapsed_ms=%.1f",
r.status_code, (time.perf_counter() - t_warm) * 1000,
)
except Exception as exc:
logger.warning("gateway startup: llm warm ping failed (non-fatal): %s", exc)
logger.info("gateway startup complete elapsed_ms=%.1f", (time.perf_counter() - t0) * 1000)
yield # application runs
@@ -90,13 +127,31 @@ class APIKeyMiddleware(BaseHTTPMiddleware):
return await call_next(request)
key = request.headers.get("x-api-key")  # Starlette header lookup is case-insensitive
if not API_KEY or key != API_KEY:
if _is_llm_path(request.url.path):
return JSONResponse(
status_code=401,
content={"error": {"code": "unauthorized", "message": "Unauthorized"}},
)
return JSONResponse(status_code=401, content={"detail": "Unauthorized"})
return await call_next(request)
def _is_llm_path(path: str) -> bool:
return path.startswith("/v1/") or path.startswith("/ai/")
app = FastAPI(title="Skinbase Vision Gateway", version="1.0.0", lifespan=lifespan)
app.add_middleware(APIKeyMiddleware)
@app.exception_handler(LLMGatewayError)
async def handle_llm_gateway_error(_: Request, exc: LLMGatewayError):
error: Dict[str, Any] = {"code": exc.code, "message": exc.message}
if exc.details is not None:
error["details"] = exc.details
return JSONResponse(status_code=exc.status_code, content={"error": error})
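# Every LLMGatewayError is rendered through this handler as one envelope,
# e.g. a timeout raised in _llm_request below comes back as
#   HTTP 504 {"error": {"code": "llm_timeout", "message": "LLM request timed out"}}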
class ClipRequest(BaseModel):
url: Optional[str] = None
limit: int = Field(default=5, ge=1, le=50)
@@ -118,6 +173,219 @@ class MaturityRequest(BaseModel):
url: Optional[str] = None
class ChatMessage(BaseModel):
role: Literal["system", "user", "assistant"]
content: str
@field_validator("content")
@classmethod
def validate_content(cls, value: str) -> str:
if not value or not value.strip():
raise ValueError("message content must not be empty")
return value
class ChatCompletionRequest(BaseModel):
model: Optional[str] = None
messages: List[ChatMessage] = Field(min_length=1, max_length=100)
temperature: Optional[float] = None
max_tokens: Optional[int] = Field(default=None, ge=1)
stream: bool = False
top_p: Optional[float] = Field(default=None, ge=0.0, le=1.0)
stop: Optional[str | List[str]] = None
presence_penalty: Optional[float] = Field(default=None, ge=-2.0, le=2.0)
frequency_penalty: Optional[float] = Field(default=None, ge=-2.0, le=2.0)
@field_validator("model")
@classmethod
def validate_model(cls, value: Optional[str]) -> Optional[str]:
if value is None:
return value
model = value.strip()
if not model:
raise ValueError("model must not be empty")
return model
@field_validator("temperature")
@classmethod
def validate_temperature(cls, value: Optional[float]) -> Optional[float]:
if value is None:
return value
if value < 0.0 or value > 2.0:
raise ValueError("temperature must be between 0 and 2")
return value
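# A minimal body that passes validation; "model" may be omitted and falls
# back to LLM_DEFAULT_MODEL in _normalize_chat_payload below:
#   {"messages": [{"role": "user", "content": "Hello"}],
#    "temperature": 0.7, "max_tokens": 64}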
def _llm_timeout() -> httpx.Timeout:
return httpx.Timeout(LLM_TIMEOUT, connect=min(LLM_TIMEOUT, 10))
def _assert_llm_enabled() -> None:
if not LLM_ENABLED:
raise LLMGatewayError(503, "llm_disabled", "LLM service is disabled")
def _extract_upstream_error_message(response: httpx.Response) -> str:
try:
payload = response.json()
except Exception:
payload = None
if isinstance(payload, dict):
error = payload.get("error")
if isinstance(error, dict) and error.get("message"):
return str(error["message"])
if payload.get("message"):
return str(payload["message"])
if payload.get("detail"):
return str(payload["detail"])
text = response.text.strip()
return text[:500] if text else f"Upstream returned HTTP {response.status_code}"
def _map_upstream_llm_status(status_code: int) -> int:
# Pass through client-actionable statuses; 429 keeps its rate-limit
# semantics rather than being folded into 422. Other 4xx become 422 and
# everything else is treated as upstream unavailability.
if status_code in (400, 413, 422, 429):
return status_code
if 400 <= status_code < 500:
return 422
return 503
def _normalize_chat_payload(payload: ChatCompletionRequest) -> Dict[str, Any]:
normalized = payload.model_dump(exclude_none=True)
normalized["model"] = normalized.get("model") or LLM_DEFAULT_MODEL
normalized["max_tokens"] = min(
int(normalized.get("max_tokens") or LLM_MAX_TOKENS_DEFAULT),
LLM_MAX_TOKENS_HARD_LIMIT,
)
if "temperature" in normalized:
normalized["temperature"] = max(0.0, min(2.0, float(normalized["temperature"])))
if normalized.get("stream"):
raise LLMGatewayError(
422,
"streaming_not_supported",
"Streaming responses are not enabled for this gateway",
)
return normalized
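# Normalization example with the defaults above: {"messages": [...],
# "max_tokens": 4000} becomes {"model": "qwen3-1.7b-instruct-q4_k_m",
# "messages": [...], "max_tokens": 1024, "stream": False}, i.e. max_tokens
# is clamped to LLM_MAX_TOKENS_HARD_LIMIT before the call is proxied.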
async def _parse_llm_request(request: Request) -> ChatCompletionRequest:
content_length = request.headers.get("content-length")
if content_length:
try:
if int(content_length) > LLM_MAX_REQUEST_BYTES:
raise LLMGatewayError(
413,
"payload_too_large",
f"Request exceeds {LLM_MAX_REQUEST_BYTES} bytes",
)
except ValueError:
raise LLMGatewayError(400, "invalid_request", "Invalid Content-Length header")
body = await request.body()
if not body:
raise LLMGatewayError(400, "invalid_request", "Request body is required")
if len(body) > LLM_MAX_REQUEST_BYTES:
raise LLMGatewayError(
413,
"payload_too_large",
f"Request exceeds {LLM_MAX_REQUEST_BYTES} bytes",
)
try:
payload = json.loads(body)
except (json.JSONDecodeError, UnicodeDecodeError):
# json.loads(bytes) raises UnicodeDecodeError, not JSONDecodeError, on
# malformed UTF-8; without this clause such bodies would surface as 500s.
raise LLMGatewayError(400, "invalid_json", "Request body must be valid JSON")
if not isinstance(payload, dict):
raise LLMGatewayError(400, "invalid_request", "JSON body must be an object")
try:
return ChatCompletionRequest.model_validate(payload)
except ValidationError as exc:
# exc.errors() may embed raw exception objects under "ctx", which the
# default JSON encoder rejects; exc.json() is always serializable.
raise LLMGatewayError(422, "validation_error", "Invalid chat request", json.loads(exc.json()))
async def _llm_request(
method: str,
path: str,
*,
json_payload: Optional[Dict[str, Any]] = None,
) -> Dict[str, Any]:
_assert_llm_enabled()
url = f"{LLM_URL}{path}"
try:
response = await get_http_client().request(
method,
url,
json=json_payload,
timeout=_llm_timeout(),
)
except httpx.TimeoutException:
raise LLMGatewayError(504, "llm_timeout", "LLM request timed out")
except httpx.RequestError as exc:
raise LLMGatewayError(503, "llm_unavailable", f"LLM service is unavailable: {exc}")
if response.status_code >= 500:
raise LLMGatewayError(503, "llm_unavailable", _extract_upstream_error_message(response))
if response.status_code >= 400:
raise LLMGatewayError(
_map_upstream_llm_status(response.status_code),
"llm_rejected_request",
_extract_upstream_error_message(response),
)
try:
return response.json()
except Exception:
raise LLMGatewayError(503, "llm_invalid_response", "LLM service returned invalid JSON")
def _normalize_ai_chat_response(response: Dict[str, Any]) -> Dict[str, Any]:
choices = response.get("choices")
if not isinstance(choices, list) or not choices:
raise LLMGatewayError(503, "llm_invalid_response", "LLM response did not contain choices")
first_choice = choices[0] if isinstance(choices[0], dict) else {}
message = first_choice.get("message") if isinstance(first_choice.get("message"), dict) else {}
content = message.get("content")
if not isinstance(content, str):
raise LLMGatewayError(503, "llm_invalid_response", "LLM response did not contain message content")
usage = response.get("usage") if isinstance(response.get("usage"), dict) else {}
return {
"model": response.get("model") or LLM_DEFAULT_MODEL,
"content": content,
"finish_reason": first_choice.get("finish_reason") or "stop",
"usage": {
"prompt_tokens": int(usage.get("prompt_tokens") or 0),
"completion_tokens": int(usage.get("completion_tokens") or 0),
"total_tokens": int(usage.get("total_tokens") or 0),
},
}
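# Example (illustrative token counts): a raw OpenAI-style completion
# collapses to
#   {"model": "qwen3-1.7b-instruct-q4_k_m", "content": "...",
#    "finish_reason": "stop",
#    "usage": {"prompt_tokens": 12, "completion_tokens": 34, "total_tokens": 46}}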
async def _get_llm_models_payload() -> Dict[str, Any]:
models = await _llm_request("GET", "/v1/models")
if isinstance(models.get("data"), list) and models["data"]:
return models
return {
"object": "list",
"data": [
{
"id": LLM_DEFAULT_MODEL,
"object": "model",
"owned_by": "self-hosted",
}
],
}
async def _get_health(base: str) -> Dict[str, Any]:
try:
r = await get_http_client().get(f"{base}/health", timeout=5)
@@ -184,8 +452,12 @@ async def health():
_get_health(YOLO_URL),
_get_health(QDRANT_SVC_URL),
]
llm_index: Optional[int] = None
if MATURITY_ENABLED:
health_checks.append(_get_health(MATURITY_URL))
if LLM_ENABLED:
llm_index = len(health_checks)
health_checks.append(_get_health(LLM_URL))
results = await asyncio.gather(*health_checks)
services: Dict[str, Any] = {
@@ -196,9 +468,71 @@ async def health():
}
if MATURITY_ENABLED:
services["maturity"] = results[4]
if LLM_ENABLED and llm_index is not None:
services["llm"] = {
"enabled": True,
"default_model": LLM_DEFAULT_MODEL,
"upstream": results[llm_index],
}
else:
services["llm"] = {
"enabled": False,
"default_model": LLM_DEFAULT_MODEL,
"upstream": {"status": "disabled"},
}
return {"status": "ok", "services": services}
@app.post("/v1/chat/completions")
async def llm_chat_completions(request: Request):
payload = _normalize_chat_payload(await _parse_llm_request(request))
return await _llm_request("POST", "/v1/chat/completions", json_payload=payload)
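# Sketch of a round trip through this pass-through route, assuming the
# gateway listens on localhost:8000 and a valid key:
#   httpx.post("http://localhost:8000/v1/chat/completions",
#              headers={"x-api-key": "<API_KEY>"},
#              json={"messages": [{"role": "user", "content": "Hi"}]})
# The upstream OpenAI-compatible JSON is returned unmodified.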
@app.get("/v1/models")
async def llm_models():
return await _get_llm_models_payload()
@app.post("/ai/chat")
async def ai_chat(request: Request):
payload = _normalize_chat_payload(await _parse_llm_request(request))
response = await _llm_request("POST", "/v1/chat/completions", json_payload=payload)
return _normalize_ai_chat_response(response)
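# Same proxy path as /v1/chat/completions above, but the response is reduced
# to the flat {model, content, finish_reason, usage} shape produced by
# _normalize_ai_chat_response.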
@app.get("/ai/models")
async def ai_models():
models = await _get_llm_models_payload()
return {
"enabled": LLM_ENABLED,
"default_model": LLM_DEFAULT_MODEL,
"models": models.get("data", []),
}
@app.get("/ai/health")
async def ai_health():
if not LLM_ENABLED:
return {
"status": "ok",
"enabled": False,
"reachable": False,
"default_model": LLM_DEFAULT_MODEL,
"upstream": {"status": "disabled"},
}
upstream = await _get_health(LLM_URL)
reachable = upstream.get("status") == "ok"
return {
"status": "ok" if reachable else "degraded",
"enabled": True,
"reachable": reachable,
"default_model": LLM_DEFAULT_MODEL,
"upstream": upstream,
}
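# Example degraded reading: upstream unreachable while the gateway itself is
# up yields {"status": "degraded", "enabled": True, "reachable": False,
# "default_model": LLM_DEFAULT_MODEL, "upstream": ...}, where "upstream" is
# whatever _get_health(LLM_URL) reported.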
# ---- Individual analyze endpoints (URL) ----
@app.post("/analyze/clip")