Files
SkinbaseNova/services/enhance-worker/app/image_io.py

236 lines
8.1 KiB
Python

from __future__ import annotations
import io
import os
import uuid
from dataclasses import dataclass
from datetime import datetime, timedelta, timezone
from pathlib import Path
import httpx
from fastapi import HTTPException, status
from PIL import Image, ImageOps
from .config import Settings
# MIME types the worker accepts for downloaded sources and generated outputs.
ALLOWED_MIMES = {"image/jpeg", "image/png", "image/webp"}
# Requested output format (lowercase extension) -> MIME type reported for the stored result.
FORMAT_TO_MIME = dict(jpg="image/jpeg", png="image/png", webp="image/webp")
# Pillow format name -> file extension used when naming stored results.
FORMAT_TO_EXTENSION = dict(JPEG="jpg", PNG="png", WEBP="webp")
# Requested output format (lowercase extension) -> Pillow format name passed to Image.save.
OUTPUT_FORMATS = dict(jpg="JPEG", png="PNG", webp="WEBP")
@dataclass(frozen=True)
class DownloadedImage:
    """A validated source image saved to the worker tmp directory.

    Instances are built by ``download_source_image``.
    """

    # Saved copy of the downloaded bytes (``<tmp_dir>/input-<hex>.<ext>``).
    path: Path
    # Pixel size reported by Pillow for the raw bytes (EXIF orientation not applied).
    width: int
    height: int
    # Detected MIME type; one of ALLOWED_MIMES.
    mime: str
    # Number of bytes downloaded.
    filesize: int
@dataclass(frozen=True)
class StoredImage:
    """An encoded result image written to the output directory by ``save_output_image``."""

    # Bare file name inside ``settings.output_dir`` (``job-<id>-<hex>.<ext>``).
    filename: str
    # Full path of the written file.
    path: Path
    # Pixel dimensions of the saved image.
    width: int
    height: int
    # Size of the encoded file on disk, in bytes.
    filesize: int
    # MIME type of the requested output format (a value of FORMAT_TO_MIME).
    mime: str
@dataclass(frozen=True)
class PreparedImage:
    """The engine-ready PNG produced by ``prepare_input_for_engine``."""

    # Temporary PNG in ``settings.tmp_dir`` (``prepared-<hex>.png``).
    path: Path
    # Dimensions after EXIF orientation has been applied.
    width: int
    height: int
    # "image/png" for every instance built by prepare_input_for_engine.
    mime: str
def ensure_directories(settings: Settings) -> None:
    """Create every directory the worker reads from or writes to.

    Creation is recursive and idempotent (existing directories are left as-is).
    """
    for attribute in ("tmp_dir", "output_dir", "model_dir", "realesrgan_model_dir"):
        Path(getattr(settings, attribute)).mkdir(parents=True, exist_ok=True)
    # realesrgan_bin names the executable itself, so only its parent folder is created.
    Path(settings.realesrgan_bin).parent.mkdir(parents=True, exist_ok=True)
def cleanup_expired_files(settings: Settings) -> None:
    """Delete files in the tmp and output directories older than the result TTL.

    A regular file is removed when its modification time is at or before
    ``now - settings.result_ttl_minutes``. Subdirectories are never touched and
    missing directories are skipped.

    Entries that vanish between listing and ``stat()`` (e.g. deleted by another
    request while the sweep runs) are skipped instead of aborting the whole
    cleanup with ``FileNotFoundError``.
    """
    threshold = datetime.now(timezone.utc) - timedelta(minutes=settings.result_ttl_minutes)
    for directory in (Path(settings.tmp_dir), Path(settings.output_dir)):
        try:
            entries = list(directory.iterdir())
        except FileNotFoundError:
            # Directory absent (or removed after a previous check): nothing to clean.
            continue
        for item in entries:
            try:
                if not item.is_file():
                    continue
                modified_at = datetime.fromtimestamp(item.stat().st_mtime, tz=timezone.utc)
            except FileNotFoundError:
                # Removed concurrently after iterdir() listed it.
                continue
            if modified_at <= threshold:
                item.unlink(missing_ok=True)
def validate_image_bytes(binary: bytes, max_width: int, max_height: int) -> tuple[int, int, str]:
    """Identify an in-memory image and enforce the accepted formats and size limits.

    Only the header is parsed (``Image.open`` is lazy), so pixel data is not
    decoded here.

    Returns:
        ``(width, height, mime)`` with ``mime`` lower-cased and in ALLOWED_MIMES.

    Raises:
        HTTPException: 422 if Pillow cannot identify the data, the format is not
            allowed, or either dimension is outside ``1..max_width`` /
            ``1..max_height``.
    """

    def reject() -> HTTPException:
        return HTTPException(status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, detail="Worker rejected the image.")

    try:
        with Image.open(io.BytesIO(binary)) as probe:
            width, height = probe.size
            detected_format = probe.format or ""
    except Exception as exc:  # pragma: no cover - Pillow raises multiple subclasses.
        raise reject() from exc

    mime = Image.MIME.get(detected_format, "").lower()
    fits = 1 <= width <= max_width and 1 <= height <= max_height
    if mime not in ALLOWED_MIMES or not fits:
        raise reject()
    return width, height, mime
def download_source_image(source_url: str, settings: Settings) -> DownloadedImage:
    """Download ``source_url``, validate it as an image, and save it to tmp_dir.

    The response body is streamed. The download is aborted with a 422 as soon as
    either the declared Content-Length or the running byte count exceeds
    ``settings.max_upload_mb`` MiB. The payload is then checked by
    ``validate_image_bytes`` against the configured maximum input dimensions and
    written to ``<tmp_dir>/input-<random hex>.<ext>``.

    Raises:
        HTTPException: 422 when the payload is too large, the request fails for
            any reason (network error, non-2xx status, malformed header), or the
            bytes are not an acceptable image.
    """
    byte_limit = settings.max_upload_mb * 1024 * 1024

    def oversized() -> HTTPException:
        return HTTPException(status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, detail="Worker rejected the image.")

    # NOTE(review): redirects are followed for whatever URL is supplied; presumably
    # source_url comes from a trusted upstream service — confirm (SSRF exposure).
    try:
        with httpx.stream("GET", source_url, follow_redirects=True, timeout=30.0) as response:
            response.raise_for_status()
            declared_length = response.headers.get("content-length")
            if declared_length and int(declared_length) > byte_limit:
                raise oversized()
            received = bytearray()
            for piece in response.iter_bytes():
                received += piece
                # Content-Length may be missing, so the running total is enforced as well.
                if len(received) > byte_limit:
                    raise oversized()
            payload = bytes(received)
    except HTTPException:
        raise
    except Exception as exc:
        raise HTTPException(
            status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
            detail="The source file could not be downloaded by the worker.",
        ) from exc

    width, height, mime = validate_image_bytes(payload, settings.max_input_width, settings.max_input_height)
    suffix = mime.rpartition("/")[2].replace("jpeg", "jpg")
    destination = Path(settings.tmp_dir) / f"input-{uuid.uuid4().hex}.{suffix}"
    destination.write_bytes(payload)
    return DownloadedImage(path=destination, width=width, height=height, mime=mime, filesize=len(payload))
def _flatten_for_jpeg(image: Image.Image) -> Image.Image:
    """Return an RGB version of ``image`` that the JPEG encoder accepts.

    Images with an alpha band (e.g. RGBA, LA, PA) or palette/colour-key
    transparency are composited onto an opaque white background. A bare
    ``convert("RGB")`` would simply discard alpha and expose whatever colour
    data happens to sit under transparent pixels.
    """
    if "A" in image.getbands() or "transparency" in image.info:
        rgba = image.convert("RGBA")
        background = Image.new("RGBA", rgba.size, (255, 255, 255, 255))
        return Image.alpha_composite(background, rgba).convert("RGB")
    return image.convert("RGB")


def save_output_image(image: Image.Image, output_format: str, settings: Settings, job_id: int) -> StoredImage:
    """Encode ``image`` into ``settings.output_dir`` in the requested format.

    Args:
        image: The image to store.
        output_format: A key of OUTPUT_FORMATS ("jpg", "png" or "webp").
        settings: Provides ``output_dir`` and the maximum output dimensions.
        job_id: Embedded in the generated file name (``job-<id>-<hex>.<ext>``).

    Returns:
        A StoredImage describing the written file.

    Raises:
        HTTPException: 422 if either dimension is outside ``1..max_output_*``.
        KeyError: if ``output_format`` is not a key of OUTPUT_FORMATS.
        Any exception raised by Pillow while encoding; the partial file is
        removed first.
    """
    width, height = image.size
    if width < 1 or height < 1 or width > settings.max_output_width or height > settings.max_output_height:
        raise HTTPException(status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, detail="Worker rejected the image.")
    target_format = OUTPUT_FORMATS[output_format]
    filename = f"job-{job_id}-{uuid.uuid4().hex}.{FORMAT_TO_EXTENSION[target_format]}"
    path = Path(settings.output_dir) / filename
    save_image = image
    if target_format == "JPEG" and image.mode not in {"RGB", "L"}:
        # JPEG has no alpha channel; flatten transparency instead of dropping it.
        save_image = _flatten_for_jpeg(image)
    kwargs: dict[str, int] = {}
    if target_format == "WEBP":
        kwargs = {"quality": 90, "method": 6}
    elif target_format == "JPEG":
        kwargs = {"quality": 92}
    try:
        save_image.save(path, target_format, **kwargs)
    except Exception:
        # Do not leave a partially written file behind in output_dir.
        path.unlink(missing_ok=True)
        raise
    return StoredImage(
        filename=filename,
        path=path,
        width=width,
        height=height,
        filesize=path.stat().st_size,
        mime=FORMAT_TO_MIME[output_format],
    )
def prepare_input_for_engine(downloaded: DownloadedImage, settings: Settings) -> PreparedImage:
    """Convert a downloaded image into a PNG the upscaling engine can read.

    The image is loaded with its EXIF orientation applied and checked against
    ``settings.realesrgan_preprocess_max_pixels``. Modes outside
    {RGB, RGBA, L, LA} are converted to RGBA when the image carries
    transparency, and to RGB otherwise. The result is written to
    ``<tmp_dir>/prepared-<hex>.png``.

    Raises:
        HTTPException: 422 when the file cannot be decoded (header validation
            alone does not catch truncated pixel data) or exceeds the pixel
            budget.
    """
    try:
        image = load_normalized_image(downloaded.path)
    except Exception as exc:  # Pillow raises several unrelated exception types on corrupt data.
        raise HTTPException(status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, detail="Worker rejected the image.") from exc
    width, height = image.size
    if width * height > settings.realesrgan_preprocess_max_pixels:
        raise HTTPException(status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, detail="Worker rejected the image.")
    prepared_image = image
    if prepared_image.mode not in {"RGB", "RGBA", "L", "LA"}:
        # Palette images keep their transparency in info["transparency"] rather
        # than in an "A" band; both must map to RGBA or the alpha is lost.
        has_alpha = "A" in prepared_image.getbands() or "transparency" in prepared_image.info
        prepared_image = prepared_image.convert("RGBA" if has_alpha else "RGB")
    prepared_path = Path(settings.tmp_dir) / f"prepared-{uuid.uuid4().hex}.png"
    prepared_path.parent.mkdir(parents=True, exist_ok=True)
    try:
        prepared_image.save(prepared_path, "PNG")
    except Exception:
        # Remove a partially written PNG before propagating the error.
        prepared_path.unlink(missing_ok=True)
        raise
    return PreparedImage(
        path=prepared_path,
        width=width,
        height=height,
        mime="image/png",
    )
def validate_generated_image(
    path: Path,
    settings: Settings,
    *,
    expected_width: int | None = None,
    expected_height: int | None = None,
) -> tuple[Image.Image, int, int, int, str]:
    """Validate an engine-produced image file and load it.

    Args:
        path: File written by the upscaling engine.
        settings: Provides ``max_output_width`` / ``max_output_height``.
        expected_width: If given together with ``expected_height``, the loaded
            image must have exactly these dimensions.
        expected_height: See ``expected_width``.

    Returns:
        ``(image, width, height, filesize, mime)``. ``image`` is EXIF-normalized
        and fully loaded, ``filesize`` is the on-disk size in bytes, and ``mime``
        comes from ``image.format`` with an ``image/png`` fallback when the
        loaded image carries no format.

    Raises:
        HTTPException: 422 when the file is missing or empty, cannot be decoded,
            exceeds the output limits, has unexpected dimensions, or is not an
            allowed format.
    """
    if not path.is_file():
        raise HTTPException(status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, detail="Worker rejected the image.")
    filesize = path.stat().st_size
    if filesize <= 0:
        raise HTTPException(status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, detail="Worker rejected the image.")
    try:
        image = load_normalized_image(path)
    except Exception as exc:  # Truncated/corrupt engine output; Pillow raises several exception types.
        raise HTTPException(status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, detail="Worker rejected the image.") from exc
    width, height = image.size
    if width > settings.max_output_width or height > settings.max_output_height:
        raise HTTPException(
            status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
            detail="Upscaled output exceeded the maximum allowed dimensions.",
        )
    if expected_width is not None and expected_height is not None and (width != expected_width or height != expected_height):
        raise HTTPException(status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, detail="Worker rejected the image.")
    mime = Image.MIME.get(image.format or "", "").lower() or "image/png"
    if mime not in ALLOWED_MIMES:
        raise HTTPException(status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, detail="Worker rejected the image.")
    return image, width, height, filesize, mime
def delete_temp_file(path: Path | None) -> None:
    """Remove ``path`` if one is given; ``None`` and already-missing files are no-ops."""
    if path is not None:
        path.unlink(missing_ok=True)
def resolve_result_path(settings: Settings, filename: str) -> Path:
    """Map a client-supplied result file name to a path inside ``output_dir``.

    Only bare file names are accepted. Anything containing a directory
    component, the special names ``"."`` / ``".."``, an empty string, or a NUL
    byte is rejected.

    Raises:
        HTTPException: 404 for any rejected name.
    """
    safe_name = os.path.basename(filename)
    # basename() strips directory parts but leaves "." and ".." unchanged; those
    # would resolve to output_dir itself or its parent. NUL bytes would only
    # fail later with a ValueError from the OS layer.
    if safe_name != filename or safe_name in {"", ".", ".."} or "\x00" in safe_name:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Not found")
    return Path(settings.output_dir) / safe_name
def load_normalized_image(path: Path) -> Image.Image:
    """Open ``path``, apply its EXIF orientation, and return a fully loaded image.

    ``ImageOps.exif_transpose`` returns a new image object (either the
    transposed image or a plain copy), which stays usable after the file is
    closed. That new object does not inherit the source ``format``, so the
    format is copied over explicitly. Callers that derive a MIME type from
    ``image.format`` then see the real container format instead of ``None``.
    """
    with Image.open(path) as image:
        source_format = image.format
        normalized = ImageOps.exif_transpose(image)
        normalized.load()
    normalized.format = source_format
    return normalized