Files
vision/common/image_io.py

92 lines
3.1 KiB
Python

from __future__ import annotations
import io
import ipaddress
import socket
from urllib.parse import urljoin, urlparse
import requests
from PIL import Image
# Default cap on a downloaded body, in bytes (50 MiB).
DEFAULT_MAX_BYTES = 50 * 1024 ** 2
# Default number of HTTP redirects followed before giving up.
DEFAULT_MAX_REDIRECTS = 3
class ImageLoadError(Exception):
    """Raised when an image cannot be fetched from a URL or decoded."""
def _validate_public_url(url: str) -> str:
    """Check that *url* is an http(s) URL whose host resolves only to public IPs.

    The hostname is resolved with ``socket.getaddrinfo`` and every returned
    address must be globally routable; loopback, private, link-local,
    multicast, reserved, unspecified and other non-global ranges are rejected.

    Note that the URL is returned unchanged, so a caller that subsequently
    connects by hostname performs its own DNS lookup; this check alone does
    not protect against a host whose DNS answer changes between the two
    lookups.

    Args:
        url: The URL to validate.

    Returns:
        The same *url* string, when it passes every check.

    Raises:
        ImageLoadError: If the URL is malformed, uses a scheme other than
            http/https, has no hostname, cannot be resolved, or resolves to
            any non-public address.
    """
    try:
        parsed = urlparse(url)
    except ValueError as e:
        # urlparse rejects some malformed netlocs (e.g. unbalanced brackets).
        raise ImageLoadError(f"Invalid URL: {e}") from e

    if parsed.scheme not in ("http", "https"):
        raise ImageLoadError("Only http and https URLs are allowed")
    if not parsed.hostname:
        raise ImageLoadError("URL must include a hostname")

    hostname = parsed.hostname.strip().lower()
    if hostname in {"localhost", "127.0.0.1", "::1"}:
        raise ImageLoadError("Localhost URLs are not allowed")

    try:
        # ``.port`` raises ValueError for non-numeric or out-of-range ports.
        explicit_port = parsed.port
    except ValueError as e:
        raise ImageLoadError(f"Invalid URL: {e}") from e
    port = explicit_port or (443 if parsed.scheme == "https" else 80)

    try:
        resolved = socket.getaddrinfo(hostname, port, type=socket.SOCK_STREAM)
    except (OSError, UnicodeError) as e:
        # OSError covers socket.gaierror; UnicodeError is raised when the
        # hostname cannot be IDNA-encoded (e.g. an over-long label).
        raise ImageLoadError(f"Cannot resolve host: {e}") from e

    if not resolved:
        # Without at least one address the loop below would pass vacuously.
        raise ImageLoadError("Cannot resolve host: no addresses returned")

    for entry in resolved:
        address = entry[4][0]
        try:
            ip = ipaddress.ip_address(address)
        except ValueError as e:
            # An address we cannot parse cannot be proven public; refuse it.
            raise ImageLoadError(f"Cannot validate resolved address: {address}") from e
        if (
            ip.is_private
            or ip.is_loopback
            or ip.is_link_local
            or ip.is_multicast
            or ip.is_reserved
            or ip.is_unspecified
            # Catches non-global ranges that is_private does not cover,
            # such as the 100.64.0.0/10 shared address space.
            or not ip.is_global
        ):
            raise ImageLoadError("URLs resolving to private or reserved addresses are not allowed")
    return url
def fetch_url_bytes(
    url: str,
    timeout: float = 10.0,
    max_bytes: int = DEFAULT_MAX_BYTES,
    max_redirects: int = DEFAULT_MAX_REDIRECTS,
) -> bytes:
    """Download the body at *url* and return it as bytes.

    Redirects are followed manually (``allow_redirects=False``) so that each
    hop is passed through ``_validate_public_url`` before it is requested.
    The body is streamed in 64 KiB chunks and the download is aborted as soon
    as more than *max_bytes* have been received.

    Args:
        url: The http(s) URL to fetch.
        timeout: Value passed as ``timeout`` to every ``requests.get`` call.
        max_bytes: Largest body size accepted, in bytes.
        max_redirects: Maximum number of redirect responses followed before
            giving up.

    Returns:
        The complete response body.

    Raises:
        ImageLoadError: If URL validation fails, a redirect has no
            ``Location`` header, the redirect limit is exceeded, the response
            has an error status, a non-empty ``Content-Type`` does not start
            with ``image/``, the body exceeds *max_bytes*, or any other error
            occurs while fetching.
    """
    try:
        # Validate inside the ``try`` so that any unexpected exception raised
        # while checking the initial URL is wrapped the same way as failures
        # on redirect hops.
        current_url = _validate_public_url(url)
        for _ in range(max_redirects + 1):
            with requests.get(
                current_url,
                stream=True,
                timeout=timeout,
                allow_redirects=False,
            ) as response:
                if 300 <= response.status_code < 400:
                    location = response.headers.get("location")
                    if not location:
                        raise ImageLoadError("Redirect response missing location header")
                    # Location may be relative; resolve it against the current hop.
                    current_url = _validate_public_url(urljoin(current_url, location))
                    continue

                response.raise_for_status()

                # A missing Content-Type is tolerated; a present one must be image/*.
                content_type = (response.headers.get("content-type") or "").lower()
                if content_type and not content_type.startswith("image/"):
                    raise ImageLoadError(f"URL does not point to an image content type: {content_type}")

                buf = io.BytesIO()
                total = 0
                for chunk in response.iter_content(chunk_size=1024 * 64):
                    if not chunk:
                        continue
                    total += len(chunk)
                    if total > max_bytes:
                        raise ImageLoadError(f"Image exceeds max_bytes={max_bytes}")
                    buf.write(chunk)
                return buf.getvalue()
        raise ImageLoadError(f"Too many redirects (>{max_redirects})")
    except ImageLoadError:
        raise
    except Exception as e:
        raise ImageLoadError(f"Cannot fetch image url: {e}") from e
def bytes_to_pil(data: bytes) -> Image.Image:
    """Decode raw image bytes into a PIL image converted to RGB mode.

    Args:
        data: The encoded image file contents.

    Returns:
        The image produced by ``Image.open(...).convert("RGB")``.

    Raises:
        ImageLoadError: If opening or converting the image fails for any
            reason; the underlying exception is chained as the cause.
    """
    try:
        stream = io.BytesIO(data)
        decoded = Image.open(stream)
        rgb = decoded.convert("RGB")
    except Exception as exc:
        raise ImageLoadError(f"Cannot decode image: {exc}") from exc
    return rgb