llm: add FastAPI shim, gateway LLM endpoints, tests, and docs
gateway/main.py (338 changed lines)
@@ -1,17 +1,18 @@
 from __future__ import annotations
 
 import asyncio
+import json
 import logging
 import os
 import time
 from contextlib import asynccontextmanager
-from typing import Any, Dict, Optional
+from typing import Any, Dict, List, Literal, Optional
 
 import httpx
 from fastapi import FastAPI, HTTPException, UploadFile, File, Form, Request
 from fastapi.responses import JSONResponse, Response
 from starlette.middleware.base import BaseHTTPMiddleware
-from pydantic import BaseModel, Field
+from pydantic import BaseModel, Field, ValidationError, field_validator
 
 logger = logging.getLogger("gateway")
 logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(name)s %(message)s")
@@ -23,6 +24,16 @@ QDRANT_SVC_URL = os.getenv("QDRANT_SVC_URL", "http://qdrant-svc:8000")
 CARD_RENDERER_URL = os.getenv("CARD_RENDERER_URL", "http://card-renderer:8000")
 MATURITY_URL = os.getenv("MATURITY_URL", "http://maturity:8000")
 MATURITY_ENABLED = os.getenv("MATURITY_ENABLED", "true").lower() not in ("0", "false", "no")
+LLM_URL = os.getenv("LLM_URL", "http://llm:8080")
+LLM_ENABLED = os.getenv("LLM_ENABLED", "false").lower() not in ("0", "false", "no")
+LLM_DEFAULT_MODEL = os.getenv("LLM_DEFAULT_MODEL", "qwen3-1.7b-instruct-q4_k_m")
+LLM_TIMEOUT = float(os.getenv("LLM_TIMEOUT", "120"))
+LLM_MAX_TOKENS_HARD_LIMIT = max(1, int(os.getenv("LLM_MAX_TOKENS_HARD_LIMIT", "1024")))
+LLM_MAX_TOKENS_DEFAULT = min(
+    LLM_MAX_TOKENS_HARD_LIMIT,
+    max(1, int(os.getenv("LLM_MAX_TOKENS_DEFAULT", "256"))),
+)
+LLM_MAX_REQUEST_BYTES = max(1024, int(os.getenv("LLM_MAX_REQUEST_BYTES", "65536")))
 VISION_TIMEOUT = float(os.getenv("VISION_TIMEOUT", "20"))
 
 # API key (set via env var `API_KEY`). If not set, gateway will reject requests.
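Not part of the diff: a minimal sketch of an environment that turns the new LLM routes on, assuming the variables are read exactly as defined above. Every value is a placeholder, not something this commit requires, and the settings have to exist before gateway/main.py is imported because the module reads them at import time.

# Illustrative only: example environment for enabling the LLM proxy routes.
# All values are assumptions; set them in the container/env before import.
import os

os.environ.update({
    "LLM_ENABLED": "true",                        # when false, LLM routes answer 503 "llm_disabled"
    "LLM_URL": "http://llm:8080",                 # upstream serving /v1/chat/completions and /v1/models
    "LLM_DEFAULT_MODEL": "qwen3-1.7b-instruct-q4_k_m",
    "LLM_MAX_TOKENS_DEFAULT": "256",              # clamped to LLM_MAX_TOKENS_HARD_LIMIT
    "LLM_MAX_TOKENS_HARD_LIMIT": "1024",
    "LLM_MAX_REQUEST_BYTES": "65536",             # larger request bodies get 413 "payload_too_large"
    "LLM_TIMEOUT": "120",                         # seconds; connect timeout is capped at 10
})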
@@ -36,6 +47,21 @@ API_KEY = os.getenv("API_KEY")
 _http_client: httpx.AsyncClient | None = None
 
 
+class LLMGatewayError(Exception):
+    def __init__(
+        self,
+        status_code: int,
+        code: str,
+        message: str,
+        details: Optional[Any] = None,
+    ):
+        self.status_code = status_code
+        self.code = code
+        self.message = message
+        self.details = details
+        super().__init__(message)
+
+
 def get_http_client() -> httpx.AsyncClient:
     """Return the shared httpx client. Raises if called before lifespan starts."""
     if _http_client is None:
@@ -74,6 +100,17 @@ async def lifespan(app: FastAPI):
     except Exception as exc:
         logger.warning("gateway startup: qdrant-svc warm ping failed (non-fatal): %s", exc)
 
+    if LLM_ENABLED:
+        try:
+            t_warm = time.perf_counter()
+            r = await _http_client.get(f"{LLM_URL}/health", timeout=min(LLM_TIMEOUT, 10))
+            logger.info(
+                "gateway startup: llm warm ping done status=%s elapsed_ms=%.1f",
+                r.status_code, (time.perf_counter() - t_warm) * 1000,
+            )
+        except Exception as exc:
+            logger.warning("gateway startup: llm warm ping failed (non-fatal): %s", exc)
+
     logger.info("gateway startup complete elapsed_ms=%.1f", (time.perf_counter() - t0) * 1000)
 
     yield  # application runs
@@ -90,13 +127,31 @@ class APIKeyMiddleware(BaseHTTPMiddleware):
             return await call_next(request)
         key = request.headers.get("x-api-key") or request.headers.get("X-API-Key")
         if not API_KEY or key != API_KEY:
+            if _is_llm_path(request.url.path):
+                return JSONResponse(
+                    status_code=401,
+                    content={"error": {"code": "unauthorized", "message": "Unauthorized"}},
+                )
             return JSONResponse(status_code=401, content={"detail": "Unauthorized"})
         return await call_next(request)
 
+
+def _is_llm_path(path: str) -> bool:
+    return path.startswith("/v1/") or path.startswith("/ai/")
+
+
 app = FastAPI(title="Skinbase Vision Gateway", version="1.0.0", lifespan=lifespan)
 app.add_middleware(APIKeyMiddleware)
 
 
+@app.exception_handler(LLMGatewayError)
+async def handle_llm_gateway_error(_: Request, exc: LLMGatewayError):
+    error: Dict[str, Any] = {"code": exc.code, "message": exc.message}
+    if exc.details is not None:
+        error["details"] = exc.details
+    return JSONResponse(status_code=exc.status_code, content={"error": error})
+
+
 class ClipRequest(BaseModel):
     url: Optional[str] = None
     limit: int = Field(default=5, ge=1, le=50)
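Not part of the diff: for illustration, the two 401 bodies the APIKeyMiddleware above can now produce, shown as Python dict literals. They mirror the JSONResponse calls verbatim; which one a client sees depends only on whether the protected path falls under _is_llm_path (/v1/... or /ai/...).

# 401 for an LLM-prefixed path (e.g. /v1/chat/completions, /ai/chat) with a bad or missing x-api-key:
{"error": {"code": "unauthorized", "message": "Unauthorized"}}

# 401 for any other protected route keeps the pre-existing FastAPI-style body:
{"detail": "Unauthorized"}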
@@ -118,6 +173,219 @@ class MaturityRequest(BaseModel):
     url: Optional[str] = None
 
 
+class ChatMessage(BaseModel):
+    role: Literal["system", "user", "assistant"]
+    content: str
+
+    @field_validator("content")
+    @classmethod
+    def validate_content(cls, value: str) -> str:
+        if not value or not value.strip():
+            raise ValueError("message content must not be empty")
+        return value
+
+
+class ChatCompletionRequest(BaseModel):
+    model: Optional[str] = None
+    messages: List[ChatMessage] = Field(min_length=1, max_length=100)
+    temperature: Optional[float] = None
+    max_tokens: Optional[int] = Field(default=None, ge=1)
+    stream: bool = False
+    top_p: Optional[float] = Field(default=None, ge=0.0, le=1.0)
+    stop: Optional[str | List[str]] = None
+    presence_penalty: Optional[float] = Field(default=None, ge=-2.0, le=2.0)
+    frequency_penalty: Optional[float] = Field(default=None, ge=-2.0, le=2.0)
+
+    @field_validator("model")
+    @classmethod
+    def validate_model(cls, value: Optional[str]) -> Optional[str]:
+        if value is None:
+            return value
+        model = value.strip()
+        if not model:
+            raise ValueError("model must not be empty")
+        return model
+
+    @field_validator("temperature")
+    @classmethod
+    def validate_temperature(cls, value: Optional[float]) -> Optional[float]:
+        if value is None:
+            return value
+        if value < 0.0 or value > 2.0:
+            raise ValueError("temperature must be between 0 and 2")
+        return value
+
+
+def _llm_timeout() -> httpx.Timeout:
+    return httpx.Timeout(LLM_TIMEOUT, connect=min(LLM_TIMEOUT, 10))
+
+
+def _assert_llm_enabled() -> None:
+    if not LLM_ENABLED:
+        raise LLMGatewayError(503, "llm_disabled", "LLM service is disabled")
+
+
+def _extract_upstream_error_message(response: httpx.Response) -> str:
+    try:
+        payload = response.json()
+    except Exception:
+        payload = None
+
+    if isinstance(payload, dict):
+        error = payload.get("error")
+        if isinstance(error, dict) and error.get("message"):
+            return str(error["message"])
+        if payload.get("message"):
+            return str(payload["message"])
+        if payload.get("detail"):
+            return str(payload["detail"])
+
+    text = response.text.strip()
+    return text[:500] if text else f"Upstream returned HTTP {response.status_code}"
+
+
+def _map_upstream_llm_status(status_code: int) -> int:
+    if status_code in (400, 413, 422):
+        return status_code
+    if 400 <= status_code < 500:
+        return 422
+    return 503
+
+
+def _normalize_chat_payload(payload: ChatCompletionRequest) -> Dict[str, Any]:
+    normalized = payload.model_dump(exclude_none=True)
+    normalized["model"] = normalized.get("model") or LLM_DEFAULT_MODEL
+    normalized["max_tokens"] = min(
+        int(normalized.get("max_tokens") or LLM_MAX_TOKENS_DEFAULT),
+        LLM_MAX_TOKENS_HARD_LIMIT,
+    )
+
+    if "temperature" in normalized:
+        normalized["temperature"] = max(0.0, min(2.0, float(normalized["temperature"])))
+
+    if normalized.get("stream"):
+        raise LLMGatewayError(
+            422,
+            "streaming_not_supported",
+            "Streaming responses are not enabled for this gateway",
+        )
+
+    return normalized
+
+
+async def _parse_llm_request(request: Request) -> ChatCompletionRequest:
+    content_length = request.headers.get("content-length")
+    if content_length:
+        try:
+            if int(content_length) > LLM_MAX_REQUEST_BYTES:
+                raise LLMGatewayError(
+                    413,
+                    "payload_too_large",
+                    f"Request exceeds {LLM_MAX_REQUEST_BYTES} bytes",
+                )
+        except ValueError:
+            raise LLMGatewayError(400, "invalid_request", "Invalid Content-Length header")
+
+    body = await request.body()
+    if not body:
+        raise LLMGatewayError(400, "invalid_request", "Request body is required")
+    if len(body) > LLM_MAX_REQUEST_BYTES:
+        raise LLMGatewayError(
+            413,
+            "payload_too_large",
+            f"Request exceeds {LLM_MAX_REQUEST_BYTES} bytes",
+        )
+
+    try:
+        payload = json.loads(body)
+    except json.JSONDecodeError:
+        raise LLMGatewayError(400, "invalid_json", "Request body must be valid JSON")
+
+    if not isinstance(payload, dict):
+        raise LLMGatewayError(400, "invalid_request", "JSON body must be an object")
+
+    try:
+        return ChatCompletionRequest.model_validate(payload)
+    except ValidationError as exc:
+        raise LLMGatewayError(422, "validation_error", "Invalid chat request", exc.errors())
+
+
+async def _llm_request(
+    method: str,
+    path: str,
+    *,
+    json_payload: Optional[Dict[str, Any]] = None,
+) -> Dict[str, Any]:
+    _assert_llm_enabled()
+
+    url = f"{LLM_URL}{path}"
+    try:
+        response = await get_http_client().request(
+            method,
+            url,
+            json=json_payload,
+            timeout=_llm_timeout(),
+        )
+    except httpx.TimeoutException:
+        raise LLMGatewayError(504, "llm_timeout", "LLM request timed out")
+    except httpx.RequestError as exc:
+        raise LLMGatewayError(503, "llm_unavailable", f"LLM service is unavailable: {exc}")
+
+    if response.status_code >= 500:
+        raise LLMGatewayError(503, "llm_unavailable", _extract_upstream_error_message(response))
+    if response.status_code >= 400:
+        raise LLMGatewayError(
+            _map_upstream_llm_status(response.status_code),
+            "llm_rejected_request",
+            _extract_upstream_error_message(response),
+        )
+
+    try:
+        return response.json()
+    except Exception:
+        raise LLMGatewayError(503, "llm_invalid_response", "LLM service returned invalid JSON")
+
+
+def _normalize_ai_chat_response(response: Dict[str, Any]) -> Dict[str, Any]:
+    choices = response.get("choices")
+    if not isinstance(choices, list) or not choices:
+        raise LLMGatewayError(503, "llm_invalid_response", "LLM response did not contain choices")
+
+    first_choice = choices[0] if isinstance(choices[0], dict) else {}
+    message = first_choice.get("message") if isinstance(first_choice.get("message"), dict) else {}
+    content = message.get("content")
+    if not isinstance(content, str):
+        raise LLMGatewayError(503, "llm_invalid_response", "LLM response did not contain message content")
+
+    usage = response.get("usage") if isinstance(response.get("usage"), dict) else {}
+    return {
+        "model": response.get("model") or LLM_DEFAULT_MODEL,
+        "content": content,
+        "finish_reason": first_choice.get("finish_reason") or "stop",
+        "usage": {
+            "prompt_tokens": int(usage.get("prompt_tokens") or 0),
+            "completion_tokens": int(usage.get("completion_tokens") or 0),
+            "total_tokens": int(usage.get("total_tokens") or 0),
+        },
+    }
+
+
+async def _get_llm_models_payload() -> Dict[str, Any]:
+    models = await _llm_request("GET", "/v1/models")
+    if isinstance(models.get("data"), list) and models["data"]:
+        return models
+    return {
+        "object": "list",
+        "data": [
+            {
+                "id": LLM_DEFAULT_MODEL,
+                "object": "model",
+                "owned_by": "self-hosted",
+            }
+        ],
+    }
+
+
 async def _get_health(base: str) -> Dict[str, Any]:
     try:
         r = await get_http_client().get(f"{base}/health", timeout=5)
@@ -184,8 +452,12 @@ async def health():
         _get_health(YOLO_URL),
         _get_health(QDRANT_SVC_URL),
     ]
+    llm_index: Optional[int] = None
     if MATURITY_ENABLED:
         health_checks.append(_get_health(MATURITY_URL))
+    if LLM_ENABLED:
+        llm_index = len(health_checks)
+        health_checks.append(_get_health(LLM_URL))
 
     results = await asyncio.gather(*health_checks)
     services: Dict[str, Any] = {
@@ -196,9 +468,71 @@ async def health():
     }
     if MATURITY_ENABLED:
         services["maturity"] = results[4]
+    if LLM_ENABLED and llm_index is not None:
+        services["llm"] = {
+            "enabled": True,
+            "default_model": LLM_DEFAULT_MODEL,
+            "upstream": results[llm_index],
+        }
+    else:
+        services["llm"] = {
+            "enabled": False,
+            "default_model": LLM_DEFAULT_MODEL,
+            "upstream": {"status": "disabled"},
+        }
     return {"status": "ok", "services": services}
+
+
+@app.post("/v1/chat/completions")
+async def llm_chat_completions(request: Request):
+    payload = _normalize_chat_payload(await _parse_llm_request(request))
+    return await _llm_request("POST", "/v1/chat/completions", json_payload=payload)
+
+
+@app.get("/v1/models")
+async def llm_models():
+    return await _get_llm_models_payload()
+
+
+@app.post("/ai/chat")
+async def ai_chat(request: Request):
+    payload = _normalize_chat_payload(await _parse_llm_request(request))
+    response = await _llm_request("POST", "/v1/chat/completions", json_payload=payload)
+    return _normalize_ai_chat_response(response)
+
+
+@app.get("/ai/models")
+async def ai_models():
+    models = await _get_llm_models_payload()
+    return {
+        "enabled": LLM_ENABLED,
+        "default_model": LLM_DEFAULT_MODEL,
+        "models": models.get("data", []),
+    }
+
+
+@app.get("/ai/health")
+async def ai_health():
+    if not LLM_ENABLED:
+        return {
+            "status": "ok",
+            "enabled": False,
+            "reachable": False,
+            "default_model": LLM_DEFAULT_MODEL,
+            "upstream": {"status": "disabled"},
+        }
+
+    upstream = await _get_health(LLM_URL)
+    reachable = upstream.get("status") == "ok"
+    return {
+        "status": "ok" if reachable else "degraded",
+        "enabled": True,
+        "reachable": reachable,
+        "default_model": LLM_DEFAULT_MODEL,
+        "upstream": upstream,
+    }
 
 
 # ---- Individual analyze endpoints (URL) ----
 
 @app.post("/analyze/clip")
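Not part of the diff: a minimal client sketch against the new routes. The base URL http://localhost:8000 and the API key value are assumptions for illustration; the request body shape and the response shapes follow the handlers above.

# Hypothetical client usage; host, port and key are placeholders, not part of the commit.
import httpx

headers = {"x-api-key": "change-me"}   # APIKeyMiddleware checks this header on /v1/* and /ai/*
body = {
    "messages": [{"role": "user", "content": "Summarize the last analysis."}],
    "max_tokens": 128,                 # the gateway clamps this to LLM_MAX_TOKENS_HARD_LIMIT
}                                      # "model" omitted, so LLM_DEFAULT_MODEL is filled in

with httpx.Client(base_url="http://localhost:8000", headers=headers, timeout=130) as client:
    # OpenAI-style passthrough: the body is forwarded upstream and returned as-is.
    openai_style = client.post("/v1/chat/completions", json=body).json()

    # Simplified wrapper: flat dict with model, content, finish_reason and usage.
    simple = client.post("/ai/chat", json=body).json()
    print(simple["content"], simple["usage"]["total_tokens"])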