Skip to content

Commit 3aba9d2

Browse files
app_bricks/object_detection: draw different bounding boxes based on the model type (#22)
1 parent 8c67fce commit 3aba9d2

File tree

3 files changed

+59
-25
lines changed

3 files changed

+59
-25
lines changed

src/arduino/app_bricks/object_detection/__init__.py

Lines changed: 24 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
11
# SPDX-FileCopyrightText: Copyright (C) 2025 ARDUINO SA <http://www.arduino.cc>
22
#
33
# SPDX-License-Identifier: MPL-2.0
4+
from typing import Any
45

56
from PIL import Image
6-
from arduino.app_utils import brick, Logger, draw_bounding_boxes
7+
from arduino.app_utils import brick, Logger, draw_bounding_boxes, Shape
78
from arduino.app_internal.core import EdgeImpulseRunnerFacade
89

910
logger = Logger("ObjectDetection")
@@ -20,8 +21,19 @@ class ObjectDetection(EdgeImpulseRunnerFacade):
2021
"""
2122

2223
def __init__(self, confidence: float = 0.3):
24+
"""Initialize the ObjectDetection module.
25+
26+
Args:
27+
confidence (float): Minimum confidence threshold for detections. Default is 0.3 (30%).
28+
29+
Raises:
30+
ValueError: If model information cannot be retrieved.
31+
"""
2332
self.confidence = confidence
2433
super().__init__()
34+
self._model_info = self.get_model_info()
35+
if not self._model_info:
36+
raise ValueError("Failed to retrieve model information. Ensure the Edge Impulse service is running.")
2537

2638
def detect_from_file(self, image_path: str, confidence: float = None) -> dict | None:
2739
"""Process a local image file to detect and identify objects.
@@ -38,7 +50,7 @@ def detect_from_file(self, image_path: str, confidence: float = None) -> dict |
3850
ret = super().infer_from_file(image_path)
3951
return self._extract_detection(ret, confidence)
4052

41-
def detect(self, image_bytes, image_type: str = "jpg", confidence: float = None) -> dict:
53+
def detect(self, image_bytes, image_type: str = "jpg", confidence: float = None) -> dict[str, list[Any]] | None:
4254
"""Process an in-memory image to detect and identify objects.
4355
4456
Args:
@@ -63,9 +75,17 @@ def draw_bounding_boxes(self, image: Image.Image | bytes, detections: dict) -> I
6375
6476
Returns:
6577
Image with bounding boxes and key points drawn.
66-
None if no detection or invalid image.
78+
None if input image or detections are invalid.
6779
"""
68-
return draw_bounding_boxes(image, detections)
80+
if not image or not detections:
81+
return None
82+
83+
shape = None
84+
if self._model_info.model_type == "object_detection":
85+
shape = Shape.RECTANGLE
86+
elif self._model_info.model_type == "constrained_object_detection":
87+
shape = Shape.CIRCLE
88+
return draw_bounding_boxes(image, detections, shape=shape)
6989

7090
def _extract_detection(self, item, confidence: float = None):
7191
if not item:

src/arduino/app_utils/image.py

Lines changed: 29 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,12 @@
88

99
logger = Logger(__name__)
1010

11+
12+
class Shape:
13+
RECTANGLE = "rectangle"
14+
CIRCLE = "circle"
15+
16+
1117
# Define a mapping of confidence ranges to colors for bounding boxes
1218
CONFIDENCE_MAP = {
1319
(0, 20): "#FF0976", # Pink
@@ -40,7 +46,7 @@ def _read(file_path: str) -> bytes:
4046

4147

4248
def get_image_type(image_bytes: bytes | Image.Image) -> str | None:
43-
"""Detect the type of an image from bytes or a PIL Image object.
49+
"""Detect the type of image from bytes or a PIL Image object.
4450
4551
Returns:
4652
str: The image type in lowercase (e.g., 'jpeg', 'png').
@@ -60,7 +66,7 @@ def get_image_type(image_bytes: bytes | Image.Image) -> str | None:
6066
return None
6167

6268

63-
def get_image_bytes(image: str | Image.Image | bytes) -> bytes:
69+
def get_image_bytes(image: str | Image.Image | bytes) -> bytes | None:
6470
"""Convert different type of image objects to bytes."""
6571
if image is None:
6672
return None
@@ -78,23 +84,12 @@ def get_image_bytes(image: str | Image.Image | bytes) -> bytes:
7884
return None
7985

8086

81-
def draw_colored_dot(draw, x, y, color, size):
82-
"""Draws a large colored dot on a PIL Image at the specified coordinate.
83-
84-
Args:
85-
draw: An ImageDraw object from PIL.
86-
x: The x-coordinate of the center of the dot.
87-
y: The y-coordinate of the center of the dot.
88-
color: A color value that PIL understands (e.g., "red", (255, 0, 0), "#FF0000").
89-
size: The radius of the dot (in pixels).
90-
"""
91-
# Calculate the bounding box for the circle
92-
bounding_box = (x - size, y - size, x + size, y + size)
93-
# Draw a filled ellipse (which looks like a circle if the bounding box is a square)
94-
draw.ellipse(bounding_box, fill=color)
95-
96-
97-
def draw_bounding_boxes(image: Image.Image | bytes, detection: dict, draw: ImageDraw.ImageDraw = None) -> Image.Image | None:
87+
def draw_bounding_boxes(
88+
image: Image.Image | bytes,
89+
detection: dict,
90+
draw: ImageDraw.ImageDraw = None,
91+
shape: Shape = Shape.RECTANGLE,
92+
) -> Image.Image | None:
9893
"""Draw bounding boxes on an image using PIL.
9994
10095
The thickness of the box and font size are scaled based on image size.
@@ -104,6 +99,8 @@ def draw_bounding_boxes(image: Image.Image | bytes, detection: dict, draw: Image
10499
detection (dict): A dictionary containing detection results with keys 'class_name', 'bounding_box_xyxy', and
105100
'confidence'.
106101
draw (ImageDraw.ImageDraw, optional): An existing ImageDraw object to use. If None, a new one is created.
102+
shape (Shape, optional): Shape of the bounding box. Defaults to rectangle.
103+
itself. Defaults to False.
107104
"""
108105
if isinstance(image, bytes):
109106
image_box = Image.open(io.BytesIO(image))
@@ -116,6 +113,10 @@ def draw_bounding_boxes(image: Image.Image | bytes, detection: dict, draw: Image
116113
if not detection or "detection" not in detection:
117114
return None
118115

116+
if shape not in (Shape.RECTANGLE, Shape.CIRCLE):
117+
logger.warning(f"Unsupported shape '{shape}'. Defaulting to rectangle.")
118+
shape = Shape.RECTANGLE
119+
119120
detection = detection["detection"]
120121

121122
# Scale font size and box thickness based on image size and number of detections
@@ -163,12 +164,19 @@ def draw_bounding_boxes(image: Image.Image | bytes, detection: dict, draw: Image
163164
x2_text = x1 + text_width + label_hpad * 2
164165

165166
# Draw bounding box
166-
draw.rectangle([x1, y1, x2, y2], outline=box_color, width=box_thickness)
167+
if shape == Shape.CIRCLE:
168+
center_x = int((x1 + x2) / 2)
169+
center_y = int((y1 + y2) / 2)
170+
radius = 10
171+
bounding_box = (center_x - radius, center_y - radius, center_x + radius, center_y + radius)
172+
draw.ellipse(bounding_box, outline=box_color, width=2)
173+
else:
174+
draw.rectangle((x1, y1, x2, y2), outline=box_color, width=box_thickness)
167175
# Draw label background (dark gray, semi-transparent) on overlay
168176
label_bg_color = (0, 0, 0, 128)
169177
overlay = Image.new("RGBA", image_box.size, (0, 0, 0, 0))
170178
overlay_draw = ImageDraw.Draw(overlay)
171-
overlay_draw.rectangle([x1, y1_text, x2_text, y2_text], fill=label_bg_color, outline=None)
179+
overlay_draw.rectangle((x1, y1_text, x2_text, y2_text), fill=label_bg_color, outline=None)
172180
image_box = image_box.convert("RGBA")
173181
image_box = Image.alpha_composite(image_box, overlay)
174182
draw = ImageDraw.Draw(image_box)

tests/arduino/app_bricks/objectdetection/test_objectdetection.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,11 @@
99
from arduino.app_bricks.object_detection import ObjectDetection
1010

1111

12+
class ModelInfo:
13+
def __init__(self, model_type: str):
14+
self.model_type = model_type
15+
16+
1217
@pytest.fixture(autouse=True)
1318
def mock_dependencies(monkeypatch: pytest.MonkeyPatch):
1419
"""Mock external dependencies in __init__.
@@ -19,6 +24,7 @@ def mock_dependencies(monkeypatch: pytest.MonkeyPatch):
1924
monkeypatch.setattr("arduino.app_internal.core.load_brick_compose_file", lambda cls: fake_compose)
2025
monkeypatch.setattr("arduino.app_internal.core.resolve_address", lambda host: "127.0.0.1")
2126
monkeypatch.setattr("arduino.app_internal.core.parse_docker_compose_variable", lambda x: [(None, None), (None, "8100")])
27+
monkeypatch.setattr("arduino.app_bricks.object_detection.ObjectDetection.get_model_info", lambda self: ModelInfo("object-detection"))
2228

2329

2430
@pytest.fixture

0 commit comments

Comments
 (0)