Skip to content
Merged
Show file tree
Hide file tree
Changes from 10 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 24 additions & 4 deletions src/arduino/app_bricks/object_detection/__init__.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
# SPDX-FileCopyrightText: Copyright (C) 2025 ARDUINO SA <http://www.arduino.cc>
#
# SPDX-License-Identifier: MPL-2.0
from typing import Any

from PIL import Image
from arduino.app_utils import brick, Logger, draw_bounding_boxes
from arduino.app_utils import brick, Logger, draw_bounding_boxes, Shape
from arduino.app_internal.core import EdgeImpulseRunnerFacade

logger = Logger("ObjectDetection")
Expand All @@ -20,8 +21,19 @@ class ObjectDetection(EdgeImpulseRunnerFacade):
"""

def __init__(self, confidence: float = 0.3):
"""Initialize the ObjectDetection module.

Args:
confidence (float): Minimum confidence threshold for detections. Default is 0.3 (30%).

Raises:
ValueError: If model information cannot be retrieved.
"""
self.confidence = confidence
super().__init__()
self._model_info = self.get_model_info()
if not self._model_info:
raise ValueError("Failed to retrieve model information. Ensure the Edge Impulse service is running.")

def detect_from_file(self, image_path: str, confidence: float = None) -> dict | None:
"""Process a local image file to detect and identify objects.
Expand All @@ -38,7 +50,7 @@ def detect_from_file(self, image_path: str, confidence: float = None) -> dict |
ret = super().infer_from_file(image_path)
return self._extract_detection(ret, confidence)

def detect(self, image_bytes, image_type: str = "jpg", confidence: float = None) -> dict:
def detect(self, image_bytes, image_type: str = "jpg", confidence: float = None) -> dict[str, list[Any]] | None:
"""Process an in-memory image to detect and identify objects.

Args:
Expand All @@ -63,9 +75,17 @@ def draw_bounding_boxes(self, image: Image.Image | bytes, detections: dict) -> I

Returns:
Image with bounding boxes and key points drawn.
None if no detection or invalid image.
None if input image or detections are invalid.
"""
return draw_bounding_boxes(image, detections)
if not image or not detections:
return None

shape = None
if self._model_info.model_type == "object_detection":
shape = Shape.RECTANGLE
elif self._model_info.model_type == "constrained_object_detection":
shape = Shape.CIRCLE
return draw_bounding_boxes(image, detections, shape=shape)

def _extract_detection(self, item, confidence: float = None):
if not item:
Expand Down
64 changes: 31 additions & 33 deletions src/arduino/app_utils/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,12 @@

logger = Logger(__name__)


class Shape:
RECTANGLE = "rectangle"
CIRCLE = "circle"


# Define a mapping of confidence ranges to colors for bounding boxes
CONFIDENCE_MAP = {
(0, 20): "#FF0976", # Pink
Expand Down Expand Up @@ -40,7 +46,7 @@ def _read(file_path: str) -> bytes:


def get_image_type(image_bytes: bytes | Image.Image) -> str | None:
"""Detect the type of an image from bytes or a PIL Image object.
"""Detect the type of image from bytes or a PIL Image object.

Returns:
str: The image type in lowercase (e.g., 'jpeg', 'png').
Expand All @@ -60,7 +66,7 @@ def get_image_type(image_bytes: bytes | Image.Image) -> str | None:
return None


def get_image_bytes(image: str | Image.Image | bytes) -> bytes:
def get_image_bytes(image: str | Image.Image | bytes) -> bytes | None:
"""Convert different type of image objects to bytes."""
if image is None:
return None
Expand All @@ -78,23 +84,12 @@ def get_image_bytes(image: str | Image.Image | bytes) -> bytes:
return None


def draw_colored_dot(draw, x, y, color, size):
"""Draws a large colored dot on a PIL Image at the specified coordinate.

Args:
draw: An ImageDraw object from PIL.
x: The x-coordinate of the center of the dot.
y: The y-coordinate of the center of the dot.
color: A color value that PIL understands (e.g., "red", (255, 0, 0), "#FF0000").
size: The radius of the dot (in pixels).
"""
# Calculate the bounding box for the circle
bounding_box = (x - size, y - size, x + size, y + size)
# Draw a filled ellipse (which looks like a circle if the bounding box is a square)
draw.ellipse(bounding_box, fill=color)


def draw_bounding_boxes(image: Image.Image | bytes, detection: dict, draw: ImageDraw.ImageDraw = None) -> Image.Image | None:
def draw_bounding_boxes(
image: Image.Image | bytes,
detection: dict,
draw: ImageDraw.ImageDraw = None,
shape: Shape = Shape.RECTANGLE,
) -> Image.Image | None:
"""Draw bounding boxes on an image using PIL.

The thickness of the box and font size are scaled based on image size.
Expand All @@ -104,6 +99,8 @@ def draw_bounding_boxes(image: Image.Image | bytes, detection: dict, draw: Image
detection (dict): A dictionary containing detection results with keys 'class_name', 'bounding_box_xyxy', and
'confidence'.
draw (ImageDraw.ImageDraw, optional): An existing ImageDraw object to use. If None, a new one is created.
shape (Shape, optional): Shape of the bounding box. Defaults to rectangle.
itself. Defaults to False.
"""
if isinstance(image, bytes):
image_box = Image.open(io.BytesIO(image))
Expand All @@ -116,6 +113,10 @@ def draw_bounding_boxes(image: Image.Image | bytes, detection: dict, draw: Image
if not detection or "detection" not in detection:
return None

if shape not in (Shape.RECTANGLE, Shape.CIRCLE):
logger.warning(f"Unsupported shape '{shape}'. Defaulting to rectangle.")
shape = Shape.RECTANGLE

detection = detection["detection"]

# Scale font size and box thickness based on image size and number of detections
Expand Down Expand Up @@ -163,12 +164,19 @@ def draw_bounding_boxes(image: Image.Image | bytes, detection: dict, draw: Image
x2_text = x1 + text_width + label_hpad * 2

# Draw bounding box
draw.rectangle([x1, y1, x2, y2], outline=box_color, width=box_thickness)
if shape == Shape.CIRCLE:
center_x = int((x1 + x2) / 2)
center_y = int((y1 + y2) / 2)
radius = 10
bounding_box = (center_x - radius, center_y - radius, center_x + radius, center_y + radius)
draw.ellipse(bounding_box, outline=box_color, width=2)
else:
draw.rectangle((x1, y1, x2, y2), outline=box_color, width=box_thickness)
# Draw label background (dark gray, semi-transparent) on overlay
label_bg_color = (0, 0, 0, 128)
overlay = Image.new("RGBA", image_box.size, (0, 0, 0, 0))
overlay_draw = ImageDraw.Draw(overlay)
overlay_draw.rectangle([x1, y1_text, x2_text, y2_text], fill=label_bg_color, outline=None)
overlay_draw.rectangle((x1, y1_text, x2_text, y2_text), fill=label_bg_color, outline=None)
image_box = image_box.convert("RGBA")
image_box = Image.alpha_composite(image_box, overlay)
draw = ImageDraw.Draw(image_box)
Expand All @@ -181,7 +189,6 @@ def draw_bounding_boxes(image: Image.Image | bytes, detection: dict, draw: Image
def draw_anomaly_markers(
image: Image.Image | bytes,
detection: dict,
draw: ImageDraw.ImageDraw = None,
) -> Image.Image | None:
"""Draw bounding boxes on an image using PIL.

Expand All @@ -191,10 +198,6 @@ def draw_anomaly_markers(
image (Image.Image|bytes): The image to draw on, can be a PIL Image or bytes.
detection (dict): A dictionary containing detection results with keys 'class_name', 'bounding_box_xyxy', and
'score'.
draw (ImageDraw.ImageDraw, optional): An existing ImageDraw object to use. If None, a new one is created.
label_above_box (bool, optional): If True, labels are drawn above the bounding box. Defaults to False.
colours (list, optional): List of colors to use for bounding boxes. Defaults to a predefined palette.
text_color (str, optional): Color of the text labels. Defaults to "white".
"""
if isinstance(image, bytes):
image_box = Image.open(io.BytesIO(image))
Expand All @@ -204,9 +207,6 @@ def draw_anomaly_markers(
if image_box.mode != "RGBA":
image_box = image_box.convert("RGBA")

if draw is None:
draw = ImageDraw.Draw(image_box)

max_anomaly_score = detection.get("anomaly_max_score", 0.0)

if not detection or "detection" not in detection:
Expand Down Expand Up @@ -239,10 +239,8 @@ def draw_anomaly_markers(
temp_layer = Image.new("RGBA", image_box.size, (0, 0, 0, 0))
temp_draw = ImageDraw.Draw(temp_layer)

temp_draw.rectangle([x1, y1, x2, y2], fill=fill_color_with_alpha)
temp_draw.rectangle([x1, y1, x2, y2], outline=outline_color, width=box_thickness)
temp_draw.rectangle((x1, y1, x2, y2), fill=fill_color_with_alpha)
temp_draw.rectangle((x1, y1, x2, y2), outline=outline_color, width=box_thickness)
image_box = Image.alpha_composite(image_box, temp_layer)

draw = ImageDraw.Draw(image_box)

return image_box
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,11 @@
from arduino.app_bricks.object_detection import ObjectDetection


class ModelInfo:
def __init__(self, model_type: str):
self.model_type = model_type


@pytest.fixture(autouse=True)
def mock_dependencies(monkeypatch: pytest.MonkeyPatch):
"""Mock external dependencies in __init__.
Expand All @@ -19,6 +24,7 @@ def mock_dependencies(monkeypatch: pytest.MonkeyPatch):
monkeypatch.setattr("arduino.app_internal.core.load_brick_compose_file", lambda cls: fake_compose)
monkeypatch.setattr("arduino.app_internal.core.resolve_address", lambda host: "127.0.0.1")
monkeypatch.setattr("arduino.app_internal.core.parse_docker_compose_variable", lambda x: [(None, None), (None, "8100")])
monkeypatch.setattr("arduino.app_bricks.object_detection.ObjectDetection.get_model_info", lambda self: ModelInfo("object-detection"))


@pytest.fixture
Expand Down