from __future__ import annotations

import asyncio
import logging
import os
import shlex
import subprocess
import time
from contextlib import asynccontextmanager
from pathlib import Path
from typing import Any, Dict, Optional

import httpx
from fastapi import FastAPI, HTTPException, Request
from fastapi.responses import JSONResponse

logger = logging.getLogger("llm")
logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(name)s %(message)s")

# Configuration, overridable via environment variables.
MODEL_PATH = os.getenv("MODEL_PATH", "/models/Qwen3-1.7B-Instruct-Q4_K_M.gguf")
LLM_MODEL_NAME = os.getenv("LLM_MODEL_NAME", "qwen3-1.7b-instruct-q4_k_m")
LLM_CONTEXT_SIZE = int(os.getenv("LLM_CONTEXT_SIZE", "4096"))
LLM_THREADS = int(os.getenv("LLM_THREADS", "4"))
LLM_GPU_LAYERS = int(os.getenv("LLM_GPU_LAYERS", "0"))
LLAMA_SERVER_PORT = int(os.getenv("LLAMA_SERVER_PORT", "8081"))
LLM_STARTUP_TIMEOUT = float(os.getenv("LLM_STARTUP_TIMEOUT", "120"))
LLM_EXTRA_ARGS = os.getenv("LLM_EXTRA_ARGS", "")

# Managed llama-server subprocess and shared HTTP client, owned by the lifespan handler.
_llama_process: subprocess.Popen[bytes] | None = None
_http_client: httpx.AsyncClient | None = None


def _upstream_base_url() -> str:
    return f"http://127.0.0.1:{LLAMA_SERVER_PORT}"


def _ensure_http_client() -> httpx.AsyncClient:
    if _http_client is None:
        raise RuntimeError("HTTP client not initialised")
    return _http_client


def _validate_model_path() -> None:
    model_file = Path(MODEL_PATH)
    if not model_file.is_file():
        raise RuntimeError(f"model file not found at {MODEL_PATH}")
    if not os.access(model_file, os.R_OK):
        raise RuntimeError(f"model file is not readable at {MODEL_PATH}")


def _build_llama_command() -> list[str]:
    command = [
        "/usr/local/bin/llama-server",
        "--host", "127.0.0.1",
        "--port", str(LLAMA_SERVER_PORT),
        "--model", MODEL_PATH,
        "--alias", LLM_MODEL_NAME,
        "--ctx-size", str(LLM_CONTEXT_SIZE),
        "--threads", str(LLM_THREADS),
        "--n-gpu-layers", str(LLM_GPU_LAYERS),
    ]
    if LLM_EXTRA_ARGS.strip():
        command.extend(shlex.split(LLM_EXTRA_ARGS))
    return command


def _llama_running() -> bool:
    return _llama_process is not None and _llama_process.poll() is None


async def _wait_for_llama_ready() -> None:
    # Poll the upstream /v1/models endpoint until llama-server responds or the deadline passes.
    deadline = time.monotonic() + LLM_STARTUP_TIMEOUT
    last_error: Optional[Exception] = None
    while time.monotonic() < deadline:
        if _llama_process is not None and _llama_process.poll() is not None:
            raise RuntimeError(f"llama-server exited with code {_llama_process.poll()}")
        try:
            response = await _ensure_http_client().get(f"{_upstream_base_url()}/v1/models", timeout=5)
            if response.status_code == 200:
                logger.info("llm service: llama-server ready")
                return
        except Exception as exc:
            last_error = exc
        await asyncio.sleep(1)
    raise RuntimeError(f"llama-server did not become ready within {LLM_STARTUP_TIMEOUT}s: {last_error}")


async def _stop_llama_process() -> None:
    global _llama_process
    if _llama_process is None:
        return
    if _llama_process.poll() is None:
        # Ask the process to exit gracefully, then force-kill if it does not stop in time.
        _llama_process.terminate()
        try:
            await asyncio.to_thread(_llama_process.wait, timeout=10)
        except subprocess.TimeoutExpired:
            _llama_process.kill()
            await asyncio.to_thread(_llama_process.wait, timeout=5)
    _llama_process = None


@asynccontextmanager
async def lifespan(app: FastAPI):
    global _http_client, _llama_process
    _validate_model_path()
    _http_client = httpx.AsyncClient(timeout=httpx.Timeout(120, connect=5))
    command = _build_llama_command()
    logger.info(
        "llm service: starting llama-server model=%s ctx=%s threads=%s gpu_layers=%s upstream_port=%s",
        LLM_MODEL_NAME, LLM_CONTEXT_SIZE, LLM_THREADS, LLM_GPU_LAYERS, LLAMA_SERVER_PORT,
    )
    _llama_process = subprocess.Popen(command)
    try:
        await _wait_for_llama_ready()
        yield
    finally:
        await _stop_llama_process()
        if _http_client is not None:
            await _http_client.aclose()
            _http_client = None


app = FastAPI(title="Skinbase LLM Service", version="1.0.0", lifespan=lifespan)


def _health_payload(status: str) -> Dict[str, Any]:
    return {
        "status": status,
        "model": Path(MODEL_PATH).name,
        "model_alias": LLM_MODEL_NAME,
        "context_size": LLM_CONTEXT_SIZE,
        "threads": LLM_THREADS,
        "gpu_layers": LLM_GPU_LAYERS,
    }


async def _proxy_request(method: str, path: str, *, body: bytes | None = None) -> Dict[str, Any]:
    # Forward a request to llama-server and return its parsed JSON body (non-streaming).
    if not _llama_running():
        raise HTTPException(status_code=503, detail="llama-server is not running")
    headers = {"content-type": "application/json"} if body is not None else None
    try:
        response = await _ensure_http_client().request(
            method,
            f"{_upstream_base_url()}{path}",
            content=body,
            headers=headers,
            timeout=httpx.Timeout(120, connect=5),
        )
    except httpx.TimeoutException as exc:
        raise HTTPException(status_code=504, detail=f"llama-server timed out: {exc}")
    except httpx.RequestError as exc:
        raise HTTPException(status_code=503, detail=f"llama-server unavailable: {exc}")
    if response.status_code >= 400:
        detail: Any
        try:
            detail = response.json()
        except Exception:
            detail = response.text[:1000]
        raise HTTPException(status_code=response.status_code, detail=detail)
    try:
        return response.json()
    except Exception as exc:
        raise HTTPException(status_code=502, detail=f"llama-server returned invalid JSON: {exc}")


@app.exception_handler(HTTPException)
async def handle_http_exception(_: Request, exc: HTTPException):
    return JSONResponse(
        status_code=exc.status_code,
        content={"error": {"code": "llm_service_error", "message": str(exc.detail)}},
    )


@app.get("/health")
async def health():
    if not _llama_running():
        return JSONResponse(status_code=503, content=_health_payload("unavailable"))
    try:
        response = await _ensure_http_client().get(f"{_upstream_base_url()}/v1/models", timeout=5)
        if response.status_code != 200:
            return JSONResponse(status_code=503, content=_health_payload("degraded"))
    except Exception:
        return JSONResponse(status_code=503, content=_health_payload("degraded"))
    return _health_payload("ok")


@app.get("/v1/models")
async def list_models():
    return await _proxy_request("GET", "/v1/models")


@app.post("/v1/chat/completions")
async def chat_completions(request: Request):
    body = await request.body()
    return await _proxy_request("POST", "/v1/chat/completions", body=body)
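
# Minimal launch sketch (assumption): one way this app could be started directly with uvicorn.
# The listen host/port below are illustrative only; the actual deployment may instead run
# `uvicorn <module>:app` or use a process manager, which this file does not specify.
if __name__ == "__main__":
    import uvicorn  # assumed to be installed alongside FastAPI

    uvicorn.run(app, host="0.0.0.0", port=8080)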