236 lines
8.1 KiB
Python
236 lines
8.1 KiB
Python
from __future__ import annotations
|
|
|
|
import io
|
|
import os
|
|
import uuid
|
|
from dataclasses import dataclass
|
|
from datetime import datetime, timedelta, timezone
|
|
from pathlib import Path
|
|
|
|
import httpx
|
|
from fastapi import HTTPException, status
|
|
from PIL import Image, ImageOps
|
|
|
|
from .config import Settings
|
|
|
|
|
|
# MIME types that the validation helpers in this module accept.
ALLOWED_MIMES = {
    "image/jpeg",
    "image/png",
    "image/webp",
}

# Requested output format key -> MIME type reported for the stored result.
FORMAT_TO_MIME = {
    "jpg": "image/jpeg",
    "png": "image/png",
    "webp": "image/webp",
}

# Pillow format name -> file extension used when naming output files.
FORMAT_TO_EXTENSION = {
    "JPEG": "jpg",
    "PNG": "png",
    "WEBP": "webp",
}

# Requested output format key -> Pillow format name handed to Image.save().
OUTPUT_FORMATS = {
    "jpg": "JPEG",
    "png": "PNG",
    "webp": "WEBP",
}
|
|
|
|
|
|
@dataclass(frozen=True)
class DownloadedImage:
    """Immutable description of a source image that was fetched and written to disk."""

    # Location of the file holding the downloaded bytes.
    path: Path
    # Pixel dimensions of the image.
    width: int
    height: int
    # MIME type of the image, e.g. "image/png".
    mime: str
    # Size of the file content in bytes.
    filesize: int
|
|
|
|
|
|
@dataclass(frozen=True)
class StoredImage:
    """Immutable description of a result image saved in the output directory."""

    # File name of the stored result (no directory component).
    filename: str
    # Full location of the stored file.
    path: Path
    # Pixel dimensions of the image.
    width: int
    height: int
    # Size of the stored file in bytes.
    filesize: int
    # MIME type of the stored file, e.g. "image/webp".
    mime: str
|
|
|
|
|
|
@dataclass(frozen=True)
class PreparedImage:
    """Immutable description of an input image that was re-encoded for the engine."""

    # Location of the prepared file.
    path: Path
    # Pixel dimensions of the prepared image.
    width: int
    height: int
    # MIME type of the prepared file.
    mime: str
|
|
|
|
|
|
def ensure_directories(settings: Settings) -> None:
    """Create every directory the worker configuration refers to.

    Covers the temp, output, model and Real-ESRGAN model directories, plus the
    directory that contains the Real-ESRGAN binary. Missing parents are
    created as well, and directories that already exist are left untouched.
    """
    required = (
        Path(settings.tmp_dir),
        Path(settings.output_dir),
        Path(settings.model_dir),
        Path(settings.realesrgan_model_dir),
        # Only the folder holding the binary is created, not the binary itself.
        Path(settings.realesrgan_bin).parent,
    )

    for folder in required:
        folder.mkdir(parents=True, exist_ok=True)
|
|
|
|
|
|
def cleanup_expired_files(settings: Settings) -> None:
    """Delete expired files from the temp and output directories.

    A file is expired when its modification time is at or before
    ``now - settings.result_ttl_minutes`` (compared in UTC). Only regular
    files directly inside ``settings.tmp_dir`` and ``settings.output_dir`` are
    considered; sub-directories are left alone and a missing directory is
    skipped.

    Files that disappear while the sweep is running are ignored instead of
    aborting the sweep.
    """
    threshold = datetime.now(timezone.utc) - timedelta(minutes=settings.result_ttl_minutes)

    for directory in (Path(settings.tmp_dir), Path(settings.output_dir)):
        if not directory.exists():
            continue

        for item in directory.iterdir():
            try:
                if not item.is_file():
                    continue
                modified_at = datetime.fromtimestamp(item.stat().st_mtime, tz=timezone.utc)
            except FileNotFoundError:
                # The entry was removed between iterdir() and stat(). The
                # unlink(missing_ok=True) below tolerates the same situation,
                # so treat it as already cleaned up.
                continue

            if modified_at <= threshold:
                item.unlink(missing_ok=True)
|
|
|
|
|
|
def validate_image_bytes(binary: bytes, max_width: int, max_height: int) -> tuple[int, int, str]:
    """Check raw image bytes and return ``(width, height, mime)`` when acceptable.

    The bytes must be readable by Pillow, have a MIME type listed in
    ``ALLOWED_MIMES``, and have dimensions between 1 and the given maximums
    (inclusive).

    Raises:
        HTTPException: 422 with a generic rejection message if any of the
            checks fails.
    """

    def _rejection() -> HTTPException:
        # Every failure mode intentionally yields the same response.
        return HTTPException(
            status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
            detail="Worker rejected the image.",
        )

    try:
        with Image.open(io.BytesIO(binary)) as opened:
            pixel_width, pixel_height = opened.size
            detected_mime = Image.MIME.get(opened.format or "", "").lower()
    except Exception as exc:  # pragma: no cover - Pillow raises multiple subclasses.
        raise _rejection() from exc

    if detected_mime not in ALLOWED_MIMES:
        raise _rejection()

    within_bounds = 1 <= pixel_width <= max_width and 1 <= pixel_height <= max_height
    if not within_bounds:
        raise _rejection()

    return pixel_width, pixel_height, detected_mime
|
|
|
|
|
|
def download_source_image(source_url: str, settings: Settings) -> DownloadedImage:
    """Download ``source_url``, validate it as an image and store it in the temp dir.

    The body is streamed and limited to ``settings.max_upload_mb * 1024 * 1024``
    bytes. The downloaded bytes are then checked with ``validate_image_bytes``
    against ``settings.max_input_width`` / ``settings.max_input_height`` and
    written to ``settings.tmp_dir`` under a random ``input-<hex>.<ext>`` name.

    Raises:
        HTTPException: 422 "Worker rejected the image." when the body exceeds
            the size limit or fails image validation; 422 "The source file
            could not be downloaded by the worker." for any other failure
            while performing the request.
    """
    max_bytes = settings.max_upload_mb * 1024 * 1024

    # NOTE(review): source_url is fetched as given and redirects are followed.
    # Presumably the caller restricts which hosts may be requested - confirm,
    # otherwise this can be pointed at internal addresses.
    try:
        with httpx.stream("GET", source_url, follow_redirects=True, timeout=30.0) as response:
            response.raise_for_status()

            # Early exit when the server announces a body that is too large.
            content_length = response.headers.get("content-length")
            if content_length and int(content_length) > max_bytes:
                raise HTTPException(status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, detail="Worker rejected the image.")

            # The header may be absent or wrong, so the limit is enforced
            # again while the body is being read.
            buffer = bytearray()
            for chunk in response.iter_bytes():
                buffer.extend(chunk)
                if len(buffer) > max_bytes:
                    raise HTTPException(status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, detail="Worker rejected the image.")

            binary = bytes(buffer)
    except HTTPException:
        # Our own rejections pass through unchanged.
        raise
    except Exception as exc:
        raise HTTPException(
            status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
            detail="The source file could not be downloaded by the worker.",
        ) from exc

    width, height, mime = validate_image_bytes(binary, settings.max_input_width, settings.max_input_height)
    extension = mime.split("/")[-1].replace("jpeg", "jpg")
    path = Path(settings.tmp_dir) / f"input-{uuid.uuid4().hex}.{extension}"

    # Create the temp directory on demand, as prepare_input_for_engine does,
    # so a missing directory does not turn a valid download into an error.
    path.parent.mkdir(parents=True, exist_ok=True)
    path.write_bytes(binary)

    return DownloadedImage(path=path, width=width, height=height, mime=mime, filesize=len(binary))
|
|
|
|
|
|
def save_output_image(image: Image.Image, output_format: str, settings: Settings, job_id: int) -> StoredImage:
    """Encode ``image`` into the output directory and describe the stored file.

    ``output_format`` must be a key of ``OUTPUT_FORMATS``. The file is named
    ``job-<job_id>-<random hex>.<ext>`` inside ``settings.output_dir``. Images
    whose mode is neither RGB nor L are converted to RGB before being written
    as JPEG. WEBP is written with quality 90 / method 6 and JPEG with
    quality 92.

    Raises:
        HTTPException: 422 if the image dimensions are below 1 or exceed
            ``settings.max_output_width`` / ``settings.max_output_height``.
        KeyError: if ``output_format`` is not a key of ``OUTPUT_FORMATS``.
    """
    width, height = image.size

    if width < 1 or height < 1 or width > settings.max_output_width or height > settings.max_output_height:
        raise HTTPException(status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, detail="Worker rejected the image.")

    target_format = OUTPUT_FORMATS[output_format]
    filename = f"job-{job_id}-{uuid.uuid4().hex}.{FORMAT_TO_EXTENSION[target_format]}"
    path = Path(settings.output_dir) / filename
    save_image = image

    if target_format == "JPEG" and image.mode not in {"RGB", "L"}:
        save_image = image.convert("RGB")

    kwargs: dict[str, int] = {}
    if target_format == "WEBP":
        kwargs = {"quality": 90, "method": 6}
    elif target_format == "JPEG":
        kwargs = {"quality": 92}

    # Create the output directory on demand so a missing directory does not
    # fail an otherwise valid job.
    path.parent.mkdir(parents=True, exist_ok=True)

    try:
        save_image.save(path, target_format, **kwargs)
    except Exception:
        # Do not leave a truncated result behind under a servable name.
        path.unlink(missing_ok=True)
        raise

    return StoredImage(
        filename=filename,
        path=path,
        width=width,
        height=height,
        filesize=path.stat().st_size,
        mime=FORMAT_TO_MIME[output_format],
    )
|
|
|
|
|
|
def prepare_input_for_engine(downloaded: DownloadedImage, settings: Settings) -> PreparedImage:
    """Re-encode a downloaded image as a PNG in the temp dir for the engine.

    The image is loaded through ``load_normalized_image``, rejected if its
    pixel count exceeds ``settings.realesrgan_preprocess_max_pixels``,
    converted to RGB or RGBA when its mode is not one of RGB, RGBA, L or LA,
    and saved as ``prepared-<random hex>.png`` in ``settings.tmp_dir``.

    Raises:
        HTTPException: 422 if the image has too many pixels.
    """
    image = load_normalized_image(downloaded.path)
    width, height = image.size

    if width * height > settings.realesrgan_preprocess_max_pixels:
        raise HTTPException(status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, detail="Worker rejected the image.")

    prepared_path = Path(settings.tmp_dir) / f"prepared-{uuid.uuid4().hex}.png"
    prepared_path.parent.mkdir(parents=True, exist_ok=True)
    prepared_image = image

    if prepared_image.mode not in {"RGB", "RGBA", "L", "LA"}:
        # Palette images may store transparency in ``info`` instead of in an
        # alpha band. Converting those to RGB would silently drop it, so they
        # are promoted to RGBA as well.
        has_alpha = "A" in prepared_image.getbands() or (
            prepared_image.mode == "P" and "transparency" in prepared_image.info
        )
        prepared_image = prepared_image.convert("RGBA" if has_alpha else "RGB")

    prepared_image.save(prepared_path, "PNG")

    return PreparedImage(
        path=prepared_path,
        width=width,
        height=height,
        mime="image/png",
    )
|
|
|
|
|
|
def validate_generated_image(
    path: Path,
    settings: Settings,
    *,
    expected_width: int | None = None,
    expected_height: int | None = None,
) -> tuple[Image.Image, int, int, int, str]:
    """Validate an image file produced by the engine and load it.

    Returns ``(image, width, height, filesize, mime)``, where ``image`` comes
    from ``load_normalized_image`` and ``mime`` is derived from the file's
    actual container format.

    Raises:
        HTTPException: 422 if the file is missing, empty, larger than
            ``settings.max_output_width`` / ``settings.max_output_height``,
            does not match the expected dimensions (checked only when both
            are given), or is not one of ``ALLOWED_MIMES``.
    """
    # is_file() is already False for a path that does not exist.
    if not path.is_file():
        raise HTTPException(status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, detail="Worker rejected the image.")

    filesize = path.stat().st_size

    if filesize <= 0:
        raise HTTPException(status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, detail="Worker rejected the image.")

    # Read the container format from the file itself. The normalized image
    # below is produced by ImageOps.exif_transpose(), and Pillow sets
    # ``format`` to None on images created by its own methods, which would
    # make every file look like the PNG fallback.
    with Image.open(path) as probe:
        source_format = probe.format

    image = load_normalized_image(path)
    width, height = image.size

    if width > settings.max_output_width or height > settings.max_output_height:
        raise HTTPException(
            status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
            detail="Upscaled output exceeded the maximum allowed dimensions.",
        )

    if expected_width is not None and expected_height is not None and (width != expected_width or height != expected_height):
        raise HTTPException(status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, detail="Worker rejected the image.")

    # Keep the PNG fallback only for the case where no format was reported.
    mime = Image.MIME.get(source_format, "").lower() if source_format else "image/png"

    if mime not in ALLOWED_MIMES:
        raise HTTPException(status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, detail="Worker rejected the image.")

    return image, width, height, filesize, mime
|
|
|
|
|
|
def delete_temp_file(path: Path | None) -> None:
    """Remove the file at ``path``; do nothing for ``None`` or a missing file."""
    if path is not None:
        path.unlink(missing_ok=True)
|
|
|
|
|
|
def resolve_result_path(settings: Settings, filename: str) -> Path:
    """Map a requested result file name to a path inside the output directory.

    ``filename`` must be a bare file name. Anything containing a directory
    component, the empty string, and the special names ``.`` and ``..`` are
    rejected so the returned path cannot point at or above
    ``settings.output_dir``.

    Raises:
        HTTPException: 404 if ``filename`` is not an acceptable bare name.
    """
    safe_name = os.path.basename(filename)

    # basename() leaves "." and ".." unchanged, so they need an explicit check:
    # joined to the output directory they would name the directory itself or
    # its parent.
    if safe_name != filename or safe_name in {"", ".", ".."}:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Not found")

    return Path(settings.output_dir) / safe_name
|
|
|
|
|
|
def load_normalized_image(path: Path) -> Image.Image:
    """Open the image at ``path`` and return it after ``ImageOps.exif_transpose``.

    The pixel data is loaded before the underlying file is closed, so the
    returned image can be used after this function returns.
    """
    with Image.open(path) as source:
        upright = ImageOps.exif_transpose(source)
        # Force the pixel data into memory while the file is still open.
        upright.load()

    return upright