llm: add FastAPI shim, gateway LLM endpoints, tests, and docs
tests/test_gateway_llm.py (new file, 313 lines)
@@ -0,0 +1,313 @@
from __future__ import annotations

import importlib
import os
import unittest
from typing import Any, Dict, Optional
from unittest.mock import patch

import httpx

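# Baseline environment for the gateway under test; individual tests layer
# LLM_ENABLED (and occasionally a tighter LLM_MAX_REQUEST_BYTES) on top.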
BASE_ENV = {
    "API_KEY": "test-key",
    "CLIP_URL": "http://clip:8000",
    "BLIP_URL": "http://blip:8000",
    "YOLO_URL": "http://yolo:8000",
    "QDRANT_SVC_URL": "http://qdrant-svc:8000",
    "CARD_RENDERER_URL": "http://card-renderer:8000",
    "MATURITY_URL": "http://maturity:8000",
    "LLM_URL": "http://llm:8080",
    "LLM_TIMEOUT": "5",
    "LLM_DEFAULT_MODEL": "qwen3-1.7b-instruct-q4_k_m",
    "LLM_MAX_TOKENS_DEFAULT": "256",
    "LLM_MAX_TOKENS_HARD_LIMIT": "1024",
    "LLM_MAX_REQUEST_BYTES": "65536",
}

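# gateway.main reads its settings from os.environ at import time, so each
# test re-imports it under a patched environment and forces a re-read with
# importlib.reload().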
def load_gateway_module(*, llm_enabled: bool, extra_env: Optional[Dict[str, str]] = None):
    env = BASE_ENV | {"LLM_ENABLED": "true" if llm_enabled else "false"}
    if extra_env:
        env |= extra_env
    with patch.dict(os.environ, env, clear=False):
        import gateway.main as gateway_main

        return importlib.reload(gateway_main)

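# In-memory stand-in for the gateway's shared httpx client: tests register
# canned responses keyed by (METHOD, url) or plain url, or arrange for a
# transport exception to be raised, so no real network I/O ever happens.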
class StubUpstreamClient:
    def __init__(
        self,
        *,
        request_responses: Optional[Dict[tuple[str, str], httpx.Response]] = None,
        get_responses: Optional[Dict[str, httpx.Response]] = None,
        request_exception: Optional[Exception] = None,
        get_exception: Optional[Exception] = None,
    ):
        self.request_responses = request_responses or {}
        self.get_responses = get_responses or {}
        self.request_exception = request_exception
        self.get_exception = get_exception

    async def request(self, method: str, url: str, **_: Any) -> httpx.Response:
        if self.request_exception is not None:
            raise self.request_exception
        response = self.request_responses.get((method.upper(), url))
        if response is None:
            return httpx.Response(404, json={"error": {"message": f"No stub for {method} {url}"}})
        return response

    async def get(self, url: str, **_: Any) -> httpx.Response:
        if self.get_exception is not None:
            raise self.get_exception
        response = self.get_responses.get(url)
        if response is None:
            return httpx.Response(404, json={"detail": f"No stub for GET {url}"})
        return response

class GatewayLLMTests(unittest.IsolatedAsyncioTestCase):
    async def _request(
        self,
        module: Any,
        method: str,
        path: str,
        *,
        headers: Optional[Dict[str, str]] = None,
        json_payload: Optional[Dict[str, Any]] = None,
        content: Optional[bytes] = None,
    ) -> httpx.Response:
        # Drive the app in-process over ASGI; no socket is ever opened.
        transport = httpx.ASGITransport(app=module.app)
        async with httpx.AsyncClient(transport=transport, base_url="http://testserver") as client:
            return await client.request(method, path, headers=headers, json=json_payload, content=content)

    async def test_llm_endpoint_requires_api_key(self):
        module = load_gateway_module(llm_enabled=True)

        response = await self._request(
            module,
            "POST",
            "/ai/chat",
            json_payload={"messages": [{"role": "user", "content": "hello"}]},
        )

        self.assertEqual(response.status_code, 401)
        self.assertEqual(response.json()["error"]["code"], "unauthorized")

    async def test_llm_disabled_returns_503(self):
        module = load_gateway_module(llm_enabled=False)

        response = await self._request(
            module,
            "POST",
            "/ai/chat",
            headers={"X-API-Key": "test-key"},
            json_payload={"messages": [{"role": "user", "content": "hello"}]},
        )

        self.assertEqual(response.status_code, 503)
        self.assertEqual(response.json()["error"]["code"], "llm_disabled")

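    # Transport-level failures from the upstream client must surface as a
    # normalized error envelope, never as an unhandled exception.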
    async def test_unreachable_llm_returns_normalized_503(self):
        module = load_gateway_module(llm_enabled=True)
        stub_client = StubUpstreamClient(
            request_exception=httpx.ConnectError("boom", request=httpx.Request("POST", f"{module.LLM_URL}/v1/chat/completions")),
        )

        with patch.object(module, "get_http_client", return_value=stub_client):
            response = await self._request(
                module,
                "POST",
                "/ai/chat",
                headers={"X-API-Key": "test-key"},
                json_payload={"messages": [{"role": "user", "content": "hello"}]},
            )

        self.assertEqual(response.status_code, 503)
        self.assertEqual(response.json()["error"]["code"], "llm_unavailable")

    async def test_validation_error_is_normalized(self):
        module = load_gateway_module(llm_enabled=True)

        response = await self._request(
            module,
            "POST",
            "/ai/chat",
            headers={"X-API-Key": "test-key"},
            json_payload={"messages": []},
        )

        self.assertEqual(response.status_code, 422)
        self.assertEqual(response.json()["error"]["code"], "validation_error")

    async def test_invalid_json_returns_400(self):
        module = load_gateway_module(llm_enabled=True)

        response = await self._request(
            module,
            "POST",
            "/v1/chat/completions",
            headers={"X-API-Key": "test-key", "Content-Type": "application/json"},
            content=b'{"messages": [',
        )

        self.assertEqual(response.status_code, 400)
        self.assertEqual(response.json()["error"]["code"], "invalid_json")

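    # The request-size guard is exercised with a deliberately tiny 64-byte
    # limit so a modest payload trips the 413 path.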
    async def test_oversized_payload_returns_413(self):
        module = load_gateway_module(llm_enabled=True, extra_env={"LLM_MAX_REQUEST_BYTES": "64"})

        response = await self._request(
            module,
            "POST",
            "/v1/chat/completions",
            headers={"X-API-Key": "test-key"},
            json_payload={"messages": [{"role": "user", "content": "x" * 5000}]},
        )

        self.assertEqual(response.status_code, 413)
        self.assertEqual(response.json()["error"]["code"], "payload_too_large")

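    # /ai/chat flattens the OpenAI-style completion into a compact
    # {model, content, finish_reason, usage} object for gateway clients.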
    async def test_ai_chat_normalizes_successful_response(self):
        module = load_gateway_module(llm_enabled=True)
        upstream_response = httpx.Response(
            200,
            json={
                "id": "chatcmpl-1",
                "object": "chat.completion",
                "model": "qwen3-1.7b-instruct-q4_k_m",
                "choices": [
                    {
                        "index": 0,
                        "finish_reason": "stop",
                        "message": {"role": "assistant", "content": "Generated text here."},
                    }
                ],
                "usage": {"prompt_tokens": 12, "completion_tokens": 8, "total_tokens": 20},
            },
            request=httpx.Request("POST", f"{module.LLM_URL}/v1/chat/completions"),
        )
        stub_client = StubUpstreamClient(
            request_responses={("POST", f"{module.LLM_URL}/v1/chat/completions"): upstream_response},
        )

        with patch.object(module, "get_http_client", return_value=stub_client):
            response = await self._request(
                module,
                "POST",
                "/ai/chat",
                headers={"X-API-Key": "test-key"},
                json_payload={"messages": [{"role": "user", "content": "hello"}]},
            )

        self.assertEqual(response.status_code, 200)
        self.assertEqual(
            response.json(),
            {
                "model": "qwen3-1.7b-instruct-q4_k_m",
                "content": "Generated text here.",
                "finish_reason": "stop",
                "usage": {"prompt_tokens": 12, "completion_tokens": 8, "total_tokens": 20},
            },
        )

    async def test_ai_health_reports_reachable_llm(self):
        module = load_gateway_module(llm_enabled=True)
        stub_client = StubUpstreamClient(
            get_responses={
                f"{module.LLM_URL}/health": httpx.Response(
                    200,
                    json={"status": "ok", "model": "Qwen3-1.7B-Instruct-Q4_K_M.gguf", "context_size": 4096, "threads": 4},
                    request=httpx.Request("GET", f"{module.LLM_URL}/health"),
                )
            },
        )

        with patch.object(module, "get_http_client", return_value=stub_client):
            response = await self._request(
                module,
                "GET",
                "/ai/health",
                headers={"X-API-Key": "test-key"},
            )

        self.assertEqual(response.status_code, 200)
        self.assertTrue(response.json()["reachable"])
        self.assertEqual(response.json()["default_model"], "qwen3-1.7b-instruct-q4_k_m")

    async def test_timeout_returns_504(self):
        module = load_gateway_module(llm_enabled=True)
        stub_client = StubUpstreamClient(
            request_exception=httpx.ReadTimeout("timeout", request=httpx.Request("POST", f"{module.LLM_URL}/v1/chat/completions")),
        )

        with patch.object(module, "get_http_client", return_value=stub_client):
            response = await self._request(
                module,
                "POST",
                "/ai/chat",
                headers={"X-API-Key": "test-key"},
                json_payload={"messages": [{"role": "user", "content": "hello"}]},
            )

        self.assertEqual(response.status_code, 504)
        self.assertEqual(response.json()["error"]["code"], "llm_timeout")

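    # Upstream 4xx responses keep their status code but are re-wrapped in
    # the gateway's error envelope under a stable code.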
    async def test_upstream_400_is_preserved(self):
        module = load_gateway_module(llm_enabled=True)
        bad_request_response = httpx.Response(
            400,
            json={"error": {"message": "Bad prompt"}},
            request=httpx.Request("POST", f"{module.LLM_URL}/v1/chat/completions"),
        )
        stub_client = StubUpstreamClient(
            request_responses={("POST", f"{module.LLM_URL}/v1/chat/completions"): bad_request_response},
        )

        with patch.object(module, "get_http_client", return_value=stub_client):
            response = await self._request(
                module,
                "POST",
                "/v1/chat/completions",
                headers={"X-API-Key": "test-key"},
                json_payload={"messages": [{"role": "user", "content": "hello"}]},
            )

        self.assertEqual(response.status_code, 400)
        self.assertEqual(response.json()["error"]["code"], "llm_rejected_request")

    async def test_models_endpoint_returns_upstream_metadata(self):
        module = load_gateway_module(llm_enabled=True)
        models_response = httpx.Response(
            200,
            json={
                "object": "list",
                "data": [
                    {
                        "id": "qwen3-1.7b-instruct-q4_k_m",
                        "object": "model",
                        "owned_by": "self-hosted",
                    }
                ],
            },
            request=httpx.Request("GET", f"{module.LLM_URL}/v1/models"),
        )
        stub_client = StubUpstreamClient(
            request_responses={("GET", f"{module.LLM_URL}/v1/models"): models_response},
        )

        with patch.object(module, "get_http_client", return_value=stub_client):
            response = await self._request(
                module,
                "GET",
                "/v1/models",
                headers={"X-API-Key": "test-key"},
            )

        self.assertEqual(response.status_code, 200)
        self.assertEqual(response.json()["data"][0]["id"], "qwen3-1.7b-instruct-q4_k_m")


if __name__ == "__main__":
    unittest.main()
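
For readers skimming the tests without the rest of the diff, a minimal sketch of the contract they pin down follows. This is an illustration only, not the gateway.main added by this commit: the endpoint path, the X-API-Key header, the {"error": {"code": ...}} envelope, the exception-to-status mapping, and the module-level get_http_client() hook come from the tests above, while everything else (request validation, size limits, /ai/health, /v1/models) is assumed or omitted here.

# Hypothetical sketch -- NOT the gateway.main shipped in this commit.
import os

import httpx
from fastapi import FastAPI, Request
from fastapi.responses import JSONResponse

app = FastAPI()

API_KEY = os.environ.get("API_KEY", "")
LLM_URL = os.environ.get("LLM_URL", "http://llm:8080")
LLM_ENABLED = os.environ.get("LLM_ENABLED", "false").lower() == "true"


def error_response(status: int, code: str, message: str) -> JSONResponse:
    # Every failure path shares one envelope so clients can branch on "code".
    return JSONResponse(status_code=status, content={"error": {"code": code, "message": message}})


def get_http_client() -> httpx.AsyncClient:
    # Module-level so tests can swap it out with patch.object(module, "get_http_client", ...).
    return httpx.AsyncClient()


@app.post("/ai/chat")
async def ai_chat(request: Request):
    if request.headers.get("X-API-Key") != API_KEY:
        return error_response(401, "unauthorized", "missing or invalid API key")
    if not LLM_ENABLED:
        return error_response(503, "llm_disabled", "LLM features are disabled")
    payload = await request.json()
    client = get_http_client()
    try:
        upstream = await client.request("POST", f"{LLM_URL}/v1/chat/completions", json=payload)
    except httpx.TimeoutException:  # ReadTimeout and friends -> 504
        return error_response(504, "llm_timeout", "upstream LLM timed out")
    except httpx.TransportError:  # ConnectError and friends -> 503
        return error_response(503, "llm_unavailable", "upstream LLM is unreachable")
    if upstream.status_code == 400:
        return error_response(400, "llm_rejected_request", "upstream rejected the request")
    data = upstream.json()
    choice = data["choices"][0]
    # Flatten the OpenAI-style completion into the gateway's compact shape.
    return {
        "model": data.get("model"),
        "content": choice["message"]["content"],
        "finish_reason": choice.get("finish_reason"),
        "usage": data.get("usage"),
    }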