from __future__ import annotations import io import os import uuid from dataclasses import dataclass from datetime import datetime, timedelta, timezone from pathlib import Path import httpx from fastapi import HTTPException, status from PIL import Image, ImageOps from .config import Settings ALLOWED_MIMES = {"image/jpeg", "image/png", "image/webp"} FORMAT_TO_MIME = {"jpg": "image/jpeg", "png": "image/png", "webp": "image/webp"} FORMAT_TO_EXTENSION = {"JPEG": "jpg", "PNG": "png", "WEBP": "webp"} OUTPUT_FORMATS = {"jpg": "JPEG", "png": "PNG", "webp": "WEBP"} @dataclass(frozen=True) class DownloadedImage: path: Path width: int height: int mime: str filesize: int @dataclass(frozen=True) class StoredImage: filename: str path: Path width: int height: int filesize: int mime: str @dataclass(frozen=True) class PreparedImage: path: Path width: int height: int mime: str def ensure_directories(settings: Settings) -> None: Path(settings.tmp_dir).mkdir(parents=True, exist_ok=True) Path(settings.output_dir).mkdir(parents=True, exist_ok=True) Path(settings.model_dir).mkdir(parents=True, exist_ok=True) Path(settings.realesrgan_model_dir).mkdir(parents=True, exist_ok=True) Path(settings.realesrgan_bin).parent.mkdir(parents=True, exist_ok=True) def cleanup_expired_files(settings: Settings) -> None: threshold = datetime.now(timezone.utc) - timedelta(minutes=settings.result_ttl_minutes) for directory in (Path(settings.tmp_dir), Path(settings.output_dir)): if not directory.exists(): continue for item in directory.iterdir(): if not item.is_file(): continue modified_at = datetime.fromtimestamp(item.stat().st_mtime, tz=timezone.utc) if modified_at <= threshold: item.unlink(missing_ok=True) def validate_image_bytes(binary: bytes, max_width: int, max_height: int) -> tuple[int, int, str]: try: with Image.open(io.BytesIO(binary)) as image: width, height = image.size mime = Image.MIME.get(image.format or "", "").lower() except Exception as exc: # pragma: no cover - Pillow raises multiple subclasses. raise HTTPException(status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, detail="Worker rejected the image.") from exc if mime not in ALLOWED_MIMES: raise HTTPException(status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, detail="Worker rejected the image.") if width < 1 or height < 1 or width > max_width or height > max_height: raise HTTPException(status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, detail="Worker rejected the image.") return width, height, mime def download_source_image(source_url: str, settings: Settings) -> DownloadedImage: max_bytes = settings.max_upload_mb * 1024 * 1024 try: with httpx.stream("GET", source_url, follow_redirects=True, timeout=30.0) as response: response.raise_for_status() content_length = response.headers.get("content-length") if content_length and int(content_length) > max_bytes: raise HTTPException(status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, detail="Worker rejected the image.") buffer = bytearray() for chunk in response.iter_bytes(): buffer.extend(chunk) if len(buffer) > max_bytes: raise HTTPException(status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, detail="Worker rejected the image.") binary = bytes(buffer) except HTTPException: raise except Exception as exc: raise HTTPException( status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, detail="The source file could not be downloaded by the worker.", ) from exc width, height, mime = validate_image_bytes(binary, settings.max_input_width, settings.max_input_height) extension = mime.split("/")[-1].replace("jpeg", "jpg") path = Path(settings.tmp_dir) / f"input-{uuid.uuid4().hex}.{extension}" path.write_bytes(binary) return DownloadedImage(path=path, width=width, height=height, mime=mime, filesize=len(binary)) def save_output_image(image: Image.Image, output_format: str, settings: Settings, job_id: int) -> StoredImage: width, height = image.size if width < 1 or height < 1 or width > settings.max_output_width or height > settings.max_output_height: raise HTTPException(status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, detail="Worker rejected the image.") target_format = OUTPUT_FORMATS[output_format] filename = f"job-{job_id}-{uuid.uuid4().hex}.{FORMAT_TO_EXTENSION[target_format]}" path = Path(settings.output_dir) / filename save_image = image if target_format == "JPEG" and image.mode not in {"RGB", "L"}: save_image = image.convert("RGB") kwargs: dict[str, int] = {} if target_format == "WEBP": kwargs = {"quality": 90, "method": 6} elif target_format == "JPEG": kwargs = {"quality": 92} save_image.save(path, target_format, **kwargs) return StoredImage( filename=filename, path=path, width=width, height=height, filesize=path.stat().st_size, mime=FORMAT_TO_MIME[output_format], ) def prepare_input_for_engine(downloaded: DownloadedImage, settings: Settings) -> PreparedImage: image = load_normalized_image(downloaded.path) width, height = image.size if width * height > settings.realesrgan_preprocess_max_pixels: raise HTTPException(status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, detail="Worker rejected the image.") prepared_path = Path(settings.tmp_dir) / f"prepared-{uuid.uuid4().hex}.png" prepared_path.parent.mkdir(parents=True, exist_ok=True) prepared_image = image if prepared_image.mode not in {"RGB", "RGBA", "L", "LA"}: prepared_image = prepared_image.convert("RGBA" if "A" in prepared_image.getbands() else "RGB") prepared_image.save(prepared_path, "PNG") return PreparedImage( path=prepared_path, width=width, height=height, mime="image/png", ) def validate_generated_image( path: Path, settings: Settings, *, expected_width: int | None = None, expected_height: int | None = None, ) -> tuple[Image.Image, int, int, int, str]: if not path.exists() or not path.is_file(): raise HTTPException(status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, detail="Worker rejected the image.") filesize = path.stat().st_size if filesize <= 0: raise HTTPException(status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, detail="Worker rejected the image.") image = load_normalized_image(path) width, height = image.size if width > settings.max_output_width or height > settings.max_output_height: raise HTTPException( status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, detail="Upscaled output exceeded the maximum allowed dimensions.", ) if expected_width is not None and expected_height is not None and (width != expected_width or height != expected_height): raise HTTPException(status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, detail="Worker rejected the image.") mime = Image.MIME.get(image.format or "", "").lower() or "image/png" if mime not in ALLOWED_MIMES: raise HTTPException(status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, detail="Worker rejected the image.") return image, width, height, filesize, mime def delete_temp_file(path: Path | None) -> None: if path is None: return path.unlink(missing_ok=True) def resolve_result_path(settings: Settings, filename: str) -> Path: safe_name = os.path.basename(filename) if safe_name != filename or safe_name == "": raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Not found") return Path(settings.output_dir) / safe_name def load_normalized_image(path: Path) -> Image.Image: with Image.open(path) as image: normalized = ImageOps.exif_transpose(image) normalized.load() return normalized