Files
SkinbaseNova/services/enhance-worker/app/engines/realesrgan_ncnn_engine.py

214 lines
8.3 KiB
Python

from __future__ import annotations
import logging
import os
import subprocess
import time
import uuid
from pathlib import Path
from fastapi import HTTPException, status
from PIL import Image
from ..config import Settings
from ..image_io import DownloadedImage, delete_temp_file, prepare_input_for_engine, validate_generated_image
from .base import EngineHealth, UpscaleEngine, UpscaleEngineUnavailable, UpscaleResult
# Module-level logger shared by everything in this engine module.
LOGGER = logging.getLogger("skinbase.enhance_worker.realesrgan_ncnn")

# Upscale mode -> model family. The engine only special-cases the "anime"
# family; every other family (and any unknown mode) resolves to the
# configured default model.
MODE_MODEL_MAP = {
    **dict.fromkeys(("standard", "artwork", "photo"), "default"),
    "illustration": "anime",
}
class RealEsrganNcnnEngine(UpscaleEngine):
    """Upscale engine that runs an external Real-ESRGAN ncnn binary.

    Every path, model name and command-line switch is read from the
    ``Settings`` instance handed to the constructor.
    """

    def __init__(self, settings: Settings) -> None:
        """Keep a reference to the worker settings."""
        self.settings = settings

    def health(self) -> EngineHealth:
        """Report whether the binary and the default model are usable.

        Returns:
            An ``EngineHealth`` whose status is ``"ok"`` only when the
            configured binary is an executable file, the model directory
            exists and the default model is among the available models;
            otherwise ``"degraded"``. The individual checks are exposed under
            ``details["realesrgan"]``.
        """
        available_models = self.available_models()
        binary_configured = self.settings.realesrgan_bin.strip() != ""
        binary_path = Path(self.settings.realesrgan_bin)
        model_dir = Path(self.settings.realesrgan_model_dir)
        # Path("") normalises to ".", which always exists, so an unconfigured
        # binary must not be reported as present.
        binary_exists = binary_configured and binary_path.exists()
        binary_executable = binary_exists and binary_path.is_file() and os.access(binary_path, os.X_OK)
        model_dir_exists = model_dir.exists() and model_dir.is_dir()
        models_loaded = self.settings.realesrgan_default_model in available_models
        return EngineHealth(
            status="ok" if binary_exists and binary_executable and model_dir_exists and models_loaded else "degraded",
            engine="realesrgan-ncnn",
            device=self.settings.device,
            models_loaded=models_loaded,
            details={
                "realesrgan": {
                    "binary_configured": binary_configured,
                    "binary_exists": binary_exists,
                    "binary_executable": binary_executable,
                    "model_dir_exists": model_dir_exists,
                    "available_models": available_models,
                    "default_model": self.settings.realesrgan_default_model,
                }
            },
        )

    def available_models(self) -> list[str]:
        """List the model names found in the configured model directory.

        A model counts as available only when both ``<name>.param`` and
        ``<name>.bin`` exist in the directory.

        Returns:
            The sorted model names, or an empty list when the directory is
            missing or is not a directory.
        """
        model_dir = Path(self.settings.realesrgan_model_dir)
        if not model_dir.exists() or not model_dir.is_dir():
            return []
        params = {path.stem for path in model_dir.glob("*.param")}
        bins = {path.stem for path in model_dir.glob("*.bin")}
        return sorted(params & bins)

    def upscale(self, downloaded: DownloadedImage, scale: int, mode: str, output_format: str) -> UpscaleResult:
        """Upscale ``downloaded`` with the external binary.

        The binary is always run at its native 4x scale. A requested scale of
        2 is produced by resizing that result down with Lanczos resampling;
        any other requested scale is returned at the native 4x size.

        Args:
            downloaded: The source image, as fetched by the worker.
            scale: Requested scale factor; recorded in the metadata.
            mode: Upscale mode, used to choose the model (see
                ``resolve_model``).
            output_format: Only recorded in the metadata here; this method
                does not encode the image.

        Returns:
            An ``UpscaleResult`` holding the image and a metadata dict.

        Raises:
            UpscaleEngineUnavailable: If the health check fails, no usable
                model is found, or the command cannot run or fails.
            HTTPException: 422 when the final image exceeds the configured
                maximum output dimensions.
        """
        if self.health().status != "ok":
            raise UpscaleEngineUnavailable("Upscale engine is not available. Check model files and worker installation.")

        prepared = prepare_input_for_engine(downloaded, self.settings)
        temp_output = Path(self.settings.tmp_dir) / f"realesrgan-output-{uuid.uuid4().hex}.png"
        started_at = time.perf_counter()

        try:
            requested_model, used_model, model_fallback = self.resolve_model(mode)
            command = self.build_command(prepared.path, temp_output, used_model)
            self.run_command(command)

            native_scale = 4
            # NOTE(review): presumably this checks the generated file against
            # the expected native-scale dimensions -- confirm in image_io.
            image, _, _, _, _ = validate_generated_image(
                temp_output,
                self.settings,
                expected_width=prepared.width * native_scale,
                expected_height=prepared.height * native_scale,
            )

            post_downsampled = False
            if scale == 2:
                image = image.resize((prepared.width * 2, prepared.height * 2), Image.Resampling.LANCZOS)
                post_downsampled = True

            output_width, output_height = image.size
            if output_width > self.settings.max_output_width or output_height > self.settings.max_output_height:
                raise HTTPException(
                    status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
                    detail="Upscaled output exceeded the maximum allowed dimensions.",
                )

            return UpscaleResult(
                image=image,
                metadata={
                    "engine": "realesrgan-ncnn",
                    "model": used_model,
                    "requested_model": requested_model,
                    "used_model": used_model,
                    "model_fallback": model_fallback,
                    "requested_scale": scale,
                    "native_model_scale": native_scale,
                    "post_downsampled": post_downsampled,
                    "mode": mode,
                    "device": self.settings.device,
                    "processing_seconds": round(time.perf_counter() - started_at, 3),
                    "input_width": prepared.width,
                    "input_height": prepared.height,
                    "output_width": output_width,
                    "output_height": output_height,
                    "output_format": output_format,
                    "real_ai_upscale": True,
                    "configured_output_ext": self.settings.realesrgan_output_ext,
                },
            )
        finally:
            # Both temporary files are removed on success and on every failure.
            delete_temp_file(prepared.path)
            delete_temp_file(temp_output)

    def resolve_model(self, mode: str) -> tuple[str, str, bool]:
        """Choose the model to run for ``mode``.

        Modes mapped to ``"anime"`` in ``MODE_MODEL_MAP`` request the
        configured anime model; every other mode requests the default model.

        Returns:
            ``(requested_model, used_model, model_fallback)``. The last item
            is ``True`` when the requested model is missing and the default
            model is used in its place.

        Raises:
            UpscaleEngineUnavailable: If the requested model is missing and
                fallback is disabled or the default model is missing too.
        """
        available_models = set(self.available_models())
        requested_model = self.settings.realesrgan_default_model
        if MODE_MODEL_MAP.get(mode) == "anime":
            requested_model = self.settings.realesrgan_anime_model

        if requested_model in available_models:
            return requested_model, requested_model, False

        if self.settings.realesrgan_allow_model_fallback and self.settings.realesrgan_default_model in available_models:
            return requested_model, self.settings.realesrgan_default_model, True

        raise UpscaleEngineUnavailable("Upscale engine is not available. Check model files and worker installation.")

    def build_command(self, input_path: Path, output_path: Path, model_name: str) -> list[str]:
        """Build the argument list for one invocation of the binary.

        The optional switches are added only when enabled in the settings:
        ``-g`` for a non-negative GPU id, ``-t`` for a positive tile size,
        ``-x`` for TTA and ``-v`` for verbose output.
        """
        command = [
            self.settings.realesrgan_bin,
            "-i",
            str(input_path),
            "-o",
            str(output_path),
            "-n",
            model_name,
            "-m",
            self.settings.realesrgan_model_dir,
        ]
        if self.settings.realesrgan_gpu_id >= 0:
            command.extend(["-g", str(self.settings.realesrgan_gpu_id)])
        if self.settings.realesrgan_tile > 0:
            command.extend(["-t", str(self.settings.realesrgan_tile)])
        if self.settings.realesrgan_tta:
            command.append("-x")
        if self.settings.realesrgan_verbose:
            command.append("-v")
        return command

    def run_command(self, command: list[str]) -> None:
        """Run ``command`` and wait for it, enforcing the configured timeout.

        The child is started in its own session so that the whole process
        group can be killed on timeout or on any unexpected error.

        Raises:
            UpscaleEngineUnavailable: If the process cannot be started, times
                out, or exits with a non-zero return code.
        """
        import signal

        try:
            proc = subprocess.Popen(
                command,
                stdout=subprocess.PIPE,
                stderr=subprocess.PIPE,
                text=True,
                # A stray undecodable byte in the tool's output must not turn
                # a successful run into a UnicodeDecodeError.
                errors="replace",
                start_new_session=True,  # new process group so we can kill all descendants
            )
        except OSError as exception:
            # Covers a missing binary as well as permission and exec-format
            # errors, all of which mean the engine cannot be used.
            LOGGER.warning("Real-ESRGAN ncnn command could not be started: %s", exception)
            raise UpscaleEngineUnavailable("Upscale engine is not available. Check model files and worker installation.") from exception

        pgid = os.getpgid(proc.pid)

        def _kill_group() -> None:
            try:
                os.killpg(pgid, signal.SIGKILL)
            except ProcessLookupError:
                # The group is already gone; nothing left to kill.
                pass

        try:
            stdout, stderr = proc.communicate(timeout=self.settings.realesrgan_timeout_seconds)
        except subprocess.TimeoutExpired as exception:
            _kill_group()
            proc.communicate()  # reap the child and drain its pipes
            LOGGER.warning("Real-ESRGAN ncnn command timed out after %s seconds", self.settings.realesrgan_timeout_seconds)
            raise UpscaleEngineUnavailable("Upscale engine is not available. Check model files and worker installation.") from exception
        except BaseException:
            # Thread cancellation or other unexpected error — ensure the process is killed
            _kill_group()
            proc.communicate()
            raise

        if proc.returncode != 0:
            # Only the output sizes are logged, not the output itself.
            LOGGER.warning(
                "Real-ESRGAN ncnn command failed with code %s; stdout bytes=%s stderr bytes=%s",
                proc.returncode,
                len(stdout or ""),
                len(stderr or ""),
            )
            raise UpscaleEngineUnavailable("Upscale engine is not available. Check model files and worker installation.")