diff --git a/src/arduino/app_bricks/object_detection/__init__.py b/src/arduino/app_bricks/object_detection/__init__.py index 93f2e290..04fb018a 100644 --- a/src/arduino/app_bricks/object_detection/__init__.py +++ b/src/arduino/app_bricks/object_detection/__init__.py @@ -1,9 +1,10 @@ # SPDX-FileCopyrightText: Copyright (C) 2025 ARDUINO SA # # SPDX-License-Identifier: MPL-2.0 +from typing import Any from PIL import Image -from arduino.app_utils import brick, Logger, draw_bounding_boxes +from arduino.app_utils import brick, Logger, draw_bounding_boxes, Shape from arduino.app_internal.core import EdgeImpulseRunnerFacade logger = Logger("ObjectDetection") @@ -20,8 +21,19 @@ class ObjectDetection(EdgeImpulseRunnerFacade): """ def __init__(self, confidence: float = 0.3): + """Initialize the ObjectDetection module. + + Args: + confidence (float): Minimum confidence threshold for detections. Default is 0.3 (30%). + + Raises: + ValueError: If model information cannot be retrieved. + """ self.confidence = confidence super().__init__() + self._model_info = self.get_model_info() + if not self._model_info: + raise ValueError("Failed to retrieve model information. Ensure the Edge Impulse service is running.") def detect_from_file(self, image_path: str, confidence: float = None) -> dict | None: """Process a local image file to detect and identify objects. @@ -38,7 +50,7 @@ def detect_from_file(self, image_path: str, confidence: float = None) -> dict | ret = super().infer_from_file(image_path) return self._extract_detection(ret, confidence) - def detect(self, image_bytes, image_type: str = "jpg", confidence: float = None) -> dict: + def detect(self, image_bytes, image_type: str = "jpg", confidence: float = None) -> dict[str, list[Any]] | None: """Process an in-memory image to detect and identify objects. Args: @@ -63,9 +75,17 @@ def draw_bounding_boxes(self, image: Image.Image | bytes, detections: dict) -> I Returns: Image with bounding boxes and key points drawn. - None if no detection or invalid image. + None if input image or detections are invalid. """ - return draw_bounding_boxes(image, detections) + if not image or not detections: + return None + + shape = None + if self._model_info.model_type == "object_detection": + shape = Shape.RECTANGLE + elif self._model_info.model_type == "constrained_object_detection": + shape = Shape.CIRCLE + return draw_bounding_boxes(image, detections, shape=shape) def _extract_detection(self, item, confidence: float = None): if not item: diff --git a/src/arduino/app_utils/image.py b/src/arduino/app_utils/image.py index 8870f9b1..48f2f25c 100644 --- a/src/arduino/app_utils/image.py +++ b/src/arduino/app_utils/image.py @@ -8,6 +8,12 @@ logger = Logger(__name__) + +class Shape: + RECTANGLE = "rectangle" + CIRCLE = "circle" + + # Define a mapping of confidence ranges to colors for bounding boxes CONFIDENCE_MAP = { (0, 20): "#FF0976", # Pink @@ -40,7 +46,7 @@ def _read(file_path: str) -> bytes: def get_image_type(image_bytes: bytes | Image.Image) -> str | None: - """Detect the type of an image from bytes or a PIL Image object. + """Detect the type of image from bytes or a PIL Image object. Returns: str: The image type in lowercase (e.g., 'jpeg', 'png'). @@ -60,7 +66,7 @@ def get_image_type(image_bytes: bytes | Image.Image) -> str | None: return None -def get_image_bytes(image: str | Image.Image | bytes) -> bytes: +def get_image_bytes(image: str | Image.Image | bytes) -> bytes | None: """Convert different type of image objects to bytes.""" if image is None: return None @@ -78,23 +84,12 @@ def get_image_bytes(image: str | Image.Image | bytes) -> bytes: return None -def draw_colored_dot(draw, x, y, color, size): - """Draws a large colored dot on a PIL Image at the specified coordinate. - - Args: - draw: An ImageDraw object from PIL. - x: The x-coordinate of the center of the dot. - y: The y-coordinate of the center of the dot. - color: A color value that PIL understands (e.g., "red", (255, 0, 0), "#FF0000"). - size: The radius of the dot (in pixels). - """ - # Calculate the bounding box for the circle - bounding_box = (x - size, y - size, x + size, y + size) - # Draw a filled ellipse (which looks like a circle if the bounding box is a square) - draw.ellipse(bounding_box, fill=color) - - -def draw_bounding_boxes(image: Image.Image | bytes, detection: dict, draw: ImageDraw.ImageDraw = None) -> Image.Image | None: +def draw_bounding_boxes( + image: Image.Image | bytes, + detection: dict, + draw: ImageDraw.ImageDraw = None, + shape: Shape = Shape.RECTANGLE, +) -> Image.Image | None: """Draw bounding boxes on an image using PIL. The thickness of the box and font size are scaled based on image size. @@ -104,6 +99,8 @@ def draw_bounding_boxes(image: Image.Image | bytes, detection: dict, draw: Image detection (dict): A dictionary containing detection results with keys 'class_name', 'bounding_box_xyxy', and 'confidence'. draw (ImageDraw.ImageDraw, optional): An existing ImageDraw object to use. If None, a new one is created. + shape (Shape, optional): Shape of the bounding box. Defaults to rectangle. + itself. Defaults to False. """ if isinstance(image, bytes): image_box = Image.open(io.BytesIO(image)) @@ -116,6 +113,10 @@ def draw_bounding_boxes(image: Image.Image | bytes, detection: dict, draw: Image if not detection or "detection" not in detection: return None + if shape not in (Shape.RECTANGLE, Shape.CIRCLE): + logger.warning(f"Unsupported shape '{shape}'. Defaulting to rectangle.") + shape = Shape.RECTANGLE + detection = detection["detection"] # Scale font size and box thickness based on image size and number of detections @@ -163,12 +164,19 @@ def draw_bounding_boxes(image: Image.Image | bytes, detection: dict, draw: Image x2_text = x1 + text_width + label_hpad * 2 # Draw bounding box - draw.rectangle([x1, y1, x2, y2], outline=box_color, width=box_thickness) + if shape == Shape.CIRCLE: + center_x = int((x1 + x2) / 2) + center_y = int((y1 + y2) / 2) + radius = 10 + bounding_box = (center_x - radius, center_y - radius, center_x + radius, center_y + radius) + draw.ellipse(bounding_box, outline=box_color, width=2) + else: + draw.rectangle((x1, y1, x2, y2), outline=box_color, width=box_thickness) # Draw label background (dark gray, semi-transparent) on overlay label_bg_color = (0, 0, 0, 128) overlay = Image.new("RGBA", image_box.size, (0, 0, 0, 0)) overlay_draw = ImageDraw.Draw(overlay) - overlay_draw.rectangle([x1, y1_text, x2_text, y2_text], fill=label_bg_color, outline=None) + overlay_draw.rectangle((x1, y1_text, x2_text, y2_text), fill=label_bg_color, outline=None) image_box = image_box.convert("RGBA") image_box = Image.alpha_composite(image_box, overlay) draw = ImageDraw.Draw(image_box) diff --git a/tests/arduino/app_bricks/objectdetection/test_objectdetection.py b/tests/arduino/app_bricks/objectdetection/test_objectdetection.py index f98f779f..da041ab6 100644 --- a/tests/arduino/app_bricks/objectdetection/test_objectdetection.py +++ b/tests/arduino/app_bricks/objectdetection/test_objectdetection.py @@ -9,6 +9,11 @@ from arduino.app_bricks.object_detection import ObjectDetection +class ModelInfo: + def __init__(self, model_type: str): + self.model_type = model_type + + @pytest.fixture(autouse=True) def mock_dependencies(monkeypatch: pytest.MonkeyPatch): """Mock external dependencies in __init__. @@ -19,6 +24,7 @@ def mock_dependencies(monkeypatch: pytest.MonkeyPatch): monkeypatch.setattr("arduino.app_internal.core.load_brick_compose_file", lambda cls: fake_compose) monkeypatch.setattr("arduino.app_internal.core.resolve_address", lambda host: "127.0.0.1") monkeypatch.setattr("arduino.app_internal.core.parse_docker_compose_variable", lambda x: [(None, None), (None, "8100")]) + monkeypatch.setattr("arduino.app_bricks.object_detection.ObjectDetection.get_model_info", lambda self: ModelInfo("object-detection")) @pytest.fixture