214 lines
8.3 KiB
Python
214 lines
8.3 KiB
Python
from __future__ import annotations

import logging
import os
import signal
import subprocess
import time
import uuid
from pathlib import Path

from fastapi import HTTPException, status
from PIL import Image

from ..config import Settings
from ..image_io import DownloadedImage, delete_temp_file, prepare_input_for_engine, validate_generated_image
from .base import EngineHealth, UpscaleEngine, UpscaleEngineUnavailable, UpscaleResult
|
|
|
|
|
|
# Module logger, registered under a fixed dotted name rather than __name__.
LOGGER = logging.getLogger("skinbase.enhance_worker.realesrgan_ncnn")

# Upscale mode -> model family. resolve_model() only special-cases the "anime"
# family; every other mode (and any unknown mode) uses the default model.
MODE_MODEL_MAP = dict(
    standard="default",
    artwork="default",
    photo="default",
    illustration="anime",
)
|
|
|
|
|
|
class RealEsrganNcnnEngine(UpscaleEngine):
    """Upscale engine that runs the configured Real-ESRGAN ncnn binary as a subprocess.

    All configuration (binary path, model directory, model names, timeouts,
    output limits) is read from the ``Settings`` object passed to the constructor.
    """

    # Scale factor the generated image is validated against; smaller requested
    # scales are produced by downsampling the native result.
    NATIVE_SCALE = 4

    # Single generic message used for every "engine unavailable" failure.
    _UNAVAILABLE_MESSAGE = "Upscale engine is not available. Check model files and worker installation."

    def __init__(self, settings: Settings) -> None:
        """Store the worker settings used by every other method."""
        self.settings = settings

    def health(self) -> EngineHealth:
        """Report whether the binary and the default model are usable.

        Returns:
            An ``EngineHealth`` whose status is ``"ok"`` only when the binary is
            an executable file, the model directory exists and the default
            model is present; otherwise ``"degraded"``. The individual checks
            are exposed under ``details["realesrgan"]``.
        """
        available_models = self.available_models()
        binary_configured = self.settings.realesrgan_bin.strip() != ""
        binary_path = Path(self.settings.realesrgan_bin)
        model_dir = Path(self.settings.realesrgan_model_dir)

        # Path("") collapses to "." (the current directory), which always
        # exists, so an unconfigured binary must not be probed on disk.
        binary_exists = binary_configured and binary_path.exists()
        binary_executable = binary_exists and binary_path.is_file() and os.access(binary_path, os.X_OK)
        model_dir_exists = model_dir.is_dir()
        models_loaded = self.settings.realesrgan_default_model in available_models

        # binary_executable already implies binary_exists.
        healthy = binary_executable and model_dir_exists and models_loaded

        return EngineHealth(
            status="ok" if healthy else "degraded",
            engine="realesrgan-ncnn",
            device=self.settings.device,
            models_loaded=models_loaded,
            details={
                "realesrgan": {
                    "binary_configured": binary_configured,
                    "binary_exists": binary_exists,
                    "binary_executable": binary_executable,
                    "model_dir_exists": model_dir_exists,
                    "available_models": available_models,
                    "default_model": self.settings.realesrgan_default_model,
                }
            },
        )

    def available_models(self) -> list[str]:
        """Return the sorted names of models found in the model directory.

        A model counts as available only when both ``<name>.param`` and
        ``<name>.bin`` exist. Returns an empty list when the model directory
        is missing or is not a directory.
        """
        model_dir = Path(self.settings.realesrgan_model_dir)

        if not model_dir.is_dir():
            return []

        params = {path.stem for path in model_dir.glob("*.param")}
        bins = {path.stem for path in model_dir.glob("*.bin")}

        return sorted(params & bins)

    def upscale(self, downloaded: DownloadedImage, scale: int, mode: str, output_format: str) -> UpscaleResult:
        """Upscale ``downloaded`` with the ncnn binary and return the image plus metadata.

        The generated file is validated against ``NATIVE_SCALE`` times the
        prepared input size. When ``scale`` is smaller than the native scale,
        the result is downsampled with Lanczos resampling to the requested size.

        Args:
            downloaded: Source image, passed to ``prepare_input_for_engine``.
            scale: Requested scale factor.
            mode: Upscale mode used to choose the model (see ``resolve_model``).
            output_format: Only recorded in the result metadata here.

        Raises:
            UpscaleEngineUnavailable: The health check fails, no usable model
                exists, or the binary fails or times out.
            HTTPException: 422 when the final image exceeds the configured
                maximum output dimensions.
        """
        if self.health().status != "ok":
            raise UpscaleEngineUnavailable(self._UNAVAILABLE_MESSAGE)

        prepared = prepare_input_for_engine(downloaded, self.settings)
        temp_output = Path(self.settings.tmp_dir) / f"realesrgan-output-{uuid.uuid4().hex}.png"
        started_at = time.perf_counter()

        try:
            requested_model, used_model, model_fallback = self.resolve_model(mode)
            command = self.build_command(prepared.path, temp_output, used_model)
            self.run_command(command)

            native_scale = self.NATIVE_SCALE
            image, _, _, _, _ = validate_generated_image(
                temp_output,
                self.settings,
                expected_width=prepared.width * native_scale,
                expected_height=prepared.height * native_scale,
            )

            # Any requested scale below the native one is honoured by
            # downsampling; previously only scale == 2 was handled and other
            # values silently returned the native-size image.
            post_downsampled = False
            if 0 < scale < native_scale:
                image = image.resize(
                    (prepared.width * scale, prepared.height * scale),
                    Image.Resampling.LANCZOS,
                )
                post_downsampled = True

            output_width, output_height = image.size

            if output_width > self.settings.max_output_width or output_height > self.settings.max_output_height:
                raise HTTPException(
                    status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
                    detail="Upscaled output exceeded the maximum allowed dimensions.",
                )

            return UpscaleResult(
                image=image,
                metadata={
                    "engine": "realesrgan-ncnn",
                    "model": used_model,
                    "requested_model": requested_model,
                    "used_model": used_model,
                    "model_fallback": model_fallback,
                    "requested_scale": scale,
                    "native_model_scale": native_scale,
                    "post_downsampled": post_downsampled,
                    "mode": mode,
                    "device": self.settings.device,
                    "processing_seconds": round(time.perf_counter() - started_at, 3),
                    "input_width": prepared.width,
                    "input_height": prepared.height,
                    "output_width": output_width,
                    "output_height": output_height,
                    "output_format": output_format,
                    "real_ai_upscale": True,
                    "configured_output_ext": self.settings.realesrgan_output_ext,
                },
            )
        finally:
            # Nested so the output file is still removed if deleting the
            # prepared input raises.
            try:
                delete_temp_file(prepared.path)
            finally:
                delete_temp_file(temp_output)

    def resolve_model(self, mode: str) -> tuple[str, str, bool]:
        """Pick the model to run for ``mode``.

        Returns:
            ``(requested_model, used_model, model_fallback)``. ``model_fallback``
            is True when the requested model is missing and the default model
            is used instead (only if fallback is enabled in settings).

        Raises:
            UpscaleEngineUnavailable: Neither the requested model nor an
                allowed fallback is available.
        """
        available_models = set(self.available_models())
        default_model = self.settings.realesrgan_default_model

        if MODE_MODEL_MAP.get(mode) == "anime":
            requested_model = self.settings.realesrgan_anime_model
        else:
            requested_model = default_model

        if requested_model in available_models:
            return requested_model, requested_model, False

        if self.settings.realesrgan_allow_model_fallback and default_model in available_models:
            return requested_model, default_model, True

        raise UpscaleEngineUnavailable(self._UNAVAILABLE_MESSAGE)

    def build_command(self, input_path: Path, output_path: Path, model_name: str) -> list[str]:
        """Build the argument list for one invocation of the binary.

        Input, output, model name and model directory are always passed. The
        GPU id (``-g``) and tile size (``-t``) are added only when the
        corresponding setting is non-negative / positive, and ``-x`` / ``-v``
        only when the TTA / verbose settings are truthy.
        """
        command = [
            self.settings.realesrgan_bin,
            "-i",
            str(input_path),
            "-o",
            str(output_path),
            "-n",
            model_name,
            "-m",
            self.settings.realesrgan_model_dir,
        ]

        if self.settings.realesrgan_gpu_id >= 0:
            command.extend(["-g", str(self.settings.realesrgan_gpu_id)])

        if self.settings.realesrgan_tile > 0:
            command.extend(["-t", str(self.settings.realesrgan_tile)])

        if self.settings.realesrgan_tta:
            command.append("-x")

        if self.settings.realesrgan_verbose:
            command.append("-v")

        return command

    def run_command(self, command: list[str]) -> None:
        """Run ``command`` to completion, killing its process group on timeout or error.

        Raises:
            UpscaleEngineUnavailable: The binary cannot be started, exceeds
                ``realesrgan_timeout_seconds``, or exits with a non-zero code.
        """
        try:
            proc = subprocess.Popen(
                command,
                stdout=subprocess.PIPE,
                stderr=subprocess.PIPE,
                text=True,
                # Output is only measured, never parsed; undecodable bytes must
                # not turn a finished run into a UnicodeDecodeError.
                errors="replace",
                start_new_session=True,  # new process group so we can kill all descendants
            )
        except OSError as exception:
            # Covers a missing binary (FileNotFoundError) as well as permission
            # and exec-format failures, which are also OSError subclasses.
            LOGGER.warning("Real-ESRGAN ncnn command could not be started: %s", type(exception).__name__)
            raise UpscaleEngineUnavailable(self._UNAVAILABLE_MESSAGE) from exception

        # start_new_session=True makes the child call setsid(), so it leads its
        # own process group and the group id equals its pid.
        pgid = proc.pid
        timeout_seconds = self.settings.realesrgan_timeout_seconds

        def _kill_group() -> None:
            try:
                os.killpg(pgid, signal.SIGKILL)
            except (ProcessLookupError, PermissionError):
                # The group is already gone (or can no longer be signalled);
                # there is nothing left to kill.
                pass

        try:
            stdout, stderr = proc.communicate(timeout=timeout_seconds)
        except subprocess.TimeoutExpired as exception:
            _kill_group()
            proc.communicate()  # reap the child and drain the pipes
            LOGGER.warning("Real-ESRGAN ncnn command timed out after %s seconds", timeout_seconds)
            raise UpscaleEngineUnavailable(self._UNAVAILABLE_MESSAGE) from exception
        except BaseException:
            # Cancellation or any other unexpected error: never leave the
            # subprocess (or its descendants) running.
            _kill_group()
            proc.communicate()
            raise

        if proc.returncode != 0:
            # Only the output sizes are logged, not the output itself. With
            # text=True these are character counts.
            LOGGER.warning(
                "Real-ESRGAN ncnn command failed with code %s; stdout bytes=%s stderr bytes=%s",
                proc.returncode,
                len(stdout or ""),
                len(stderr or ""),
            )
            raise UpscaleEngineUnavailable(self._UNAVAILABLE_MESSAGE)