first commit
This commit is contained in:
19
yolo/Dockerfile
Normal file
19
yolo/Dockerfile
Normal file
@@ -0,0 +1,19 @@
|
||||
FROM python:3.11-slim
|
||||
|
||||
WORKDIR /app
|
||||
|
||||
RUN apt-get update && apt-get install -y --no-install-recommends \
|
||||
ca-certificates \
|
||||
libgl1 \
|
||||
libglib2.0-0 \
|
||||
&& rm -rf /var/lib/apt/lists/*
|
||||
|
||||
COPY yolo/requirements.txt /app/requirements.txt
|
||||
RUN pip install --no-cache-dir -r /app/requirements.txt
|
||||
|
||||
COPY yolo /app
|
||||
COPY common /app/common
|
||||
|
||||
ENV PYTHONUNBUFFERED=1
|
||||
|
||||
CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "8000"]
|
||||
70
yolo/main.py
Normal file
70
yolo/main.py
Normal file
@@ -0,0 +1,70 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
from typing import Optional, Dict
|
||||
|
||||
import torch
|
||||
from fastapi import FastAPI, HTTPException, UploadFile, File, Form
|
||||
from pydantic import BaseModel, Field
|
||||
from ultralytics import YOLO
|
||||
|
||||
from common.image_io import fetch_url_bytes, bytes_to_pil, ImageLoadError
|
||||
|
||||
YOLO_MODEL = os.getenv("YOLO_MODEL", "yolov8n.pt")
|
||||
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
|
||||
app = FastAPI(title="Skinbase YOLO Service", version="1.0.0")
|
||||
|
||||
model = YOLO(YOLO_MODEL)
|
||||
|
||||
|
||||
class DetectRequest(BaseModel):
|
||||
url: Optional[str] = None
|
||||
conf: float = Field(default=0.25, ge=0.0, le=1.0)
|
||||
|
||||
|
||||
@app.get("/health")
|
||||
def health():
|
||||
return {"status": "ok", "device": DEVICE, "model": YOLO_MODEL}
|
||||
|
||||
|
||||
def _detect_bytes(data: bytes, conf: float):
|
||||
img = bytes_to_pil(data)
|
||||
|
||||
results = model(img)
|
||||
|
||||
best: Dict[str, float] = {}
|
||||
for r in results:
|
||||
for box in r.boxes:
|
||||
score = float(box.conf[0])
|
||||
if score < conf:
|
||||
continue
|
||||
cls_id = int(box.cls[0])
|
||||
label = model.names.get(cls_id, str(cls_id))
|
||||
if label not in best or best[label] < score:
|
||||
best[label] = score
|
||||
|
||||
detections = [{"label": k, "confidence": v} for k, v in best.items()]
|
||||
detections.sort(key=lambda x: x["confidence"], reverse=True)
|
||||
|
||||
return {"detections": detections, "model": YOLO_MODEL}
|
||||
|
||||
|
||||
@app.post("/detect")
|
||||
def detect(req: DetectRequest):
|
||||
if not req.url:
|
||||
raise HTTPException(400, "url is required")
|
||||
try:
|
||||
data = fetch_url_bytes(req.url)
|
||||
return _detect_bytes(data, float(req.conf))
|
||||
except ImageLoadError as e:
|
||||
raise HTTPException(400, str(e))
|
||||
|
||||
|
||||
@app.post("/detect/file")
|
||||
async def detect_file(
|
||||
file: UploadFile = File(...),
|
||||
conf: float = Form(0.25),
|
||||
):
|
||||
data = await file.read()
|
||||
return _detect_bytes(data, float(conf))
|
||||
8
yolo/requirements.txt
Normal file
8
yolo/requirements.txt
Normal file
@@ -0,0 +1,8 @@
|
||||
fastapi==0.115.5
|
||||
uvicorn[standard]==0.30.6
|
||||
python-multipart==0.0.9
|
||||
requests==2.32.3
|
||||
pillow==10.4.0
|
||||
torch==2.4.1
|
||||
torchvision==0.19.1
|
||||
ultralytics==8.3.5
|
||||
Reference in New Issue
Block a user