Skip to content

Commit 3d8b8ea

Browse files
authored
[API] Add compatible models in brickinstance details (#99)
* add compatible models in brickinstance details * fix test e2e * fix test e2e * add unit tests * fix tests * remove omitempty from field compatible_modules * update field name for brick details endpoint
1 parent 8d4eb51 commit 3d8b8ea

File tree

7 files changed

+254
-66
lines changed

7 files changed

+254
-66
lines changed

internal/api/docs/openapi.yaml

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1319,15 +1319,15 @@ components:
13191319
$ref: '#/components/schemas/CodeExample'
13201320
nullable: true
13211321
type: array
1322-
description:
1323-
type: string
1324-
id:
1325-
type: string
1326-
models:
1322+
compatible_models:
13271323
items:
13281324
$ref: '#/components/schemas/AIModel'
13291325
nullable: true
13301326
type: array
1327+
description:
1328+
type: string
1329+
id:
1330+
type: string
13311331
name:
13321332
type: string
13331333
readme:
@@ -1350,6 +1350,11 @@ components:
13501350
type: string
13511351
category:
13521352
type: string
1353+
compatible_models:
1354+
items:
1355+
$ref: '#/components/schemas/AIModel'
1356+
nullable: true
1357+
type: array
13531358
config_variables:
13541359
items:
13551360
$ref: '#/components/schemas/BrickConfigVariable'

internal/e2e/client/client.gen.go

Lines changed: 20 additions & 19 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

internal/e2e/daemon/brick_test.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -144,7 +144,7 @@ func TestBricksDetails(t *testing.T) {
144144
require.NotEmpty(t, *response.JSON200.Readme)
145145
require.NotNil(t, response.JSON200.UsedByApps, "UsedByApps should not be nil")
146146
require.Equal(t, expectedUsedByApps, *(response.JSON200.UsedByApps))
147-
require.NotNil(t, response.JSON200.Models, "Models should not be nil")
148-
require.Equal(t, expectedModelLiteInfo, *(response.JSON200.Models))
147+
require.NotNil(t, response.JSON200.CompatibleModels, "Models should not be nil")
148+
require.Equal(t, expectedModelLiteInfo, *(response.JSON200.CompatibleModels))
149149
})
150150
}

internal/e2e/daemon/instance_bricks_test.go renamed to internal/e2e/daemon/bricks_instance_test.go

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,18 @@ var (
5151
Value: f.Ptr("/models/ootb/ei/mobilenet-v2-224px.eim"),
5252
},
5353
}
54+
55+
expectedModelInfo = []client.AIModel{
56+
{
57+
Id: f.Ptr("mobilenet-image-classification"),
58+
Name: f.Ptr("General purpose image classification"),
59+
Description: f.Ptr("General purpose image classification model based on MobileNetV2. This model is trained on the ImageNet dataset and can classify images into 1000 categories."),
60+
},
61+
{
62+
Id: f.Ptr("person-classification"),
63+
Name: f.Ptr("Person classification"),
64+
Description: f.Ptr("Person classification model based on WakeVision dataset. This model is trained to classify images into two categories: person and not-person."),
65+
}}
5466
)
5567

5668
func setupTestApp(t *testing.T) (*client.CreateAppResp, *client.ClientWithResponses) {
@@ -78,7 +90,6 @@ func setupTestApp(t *testing.T) (*client.CreateAppResp, *client.ClientWithRespon
7890
)
7991
require.NoError(t, err)
8092
require.Equal(t, http.StatusOK, resp.StatusCode())
81-
8293
return createResp, httpClient
8394
}
8495

@@ -135,6 +146,20 @@ func TestGetAppBrickInstanceById(t *testing.T) {
135146
require.NotEmpty(t, brickInstance.JSON200)
136147
require.Equal(t, ImageClassifactionBrickID, *brickInstance.JSON200.Id)
137148
require.Equal(t, expectedConfigVariables, (*brickInstance.JSON200.ConfigVariables))
149+
require.NotNil(t, brickInstance.JSON200.CompatibleModels)
150+
require.Equal(t, expectedModelInfo, *(brickInstance.JSON200.CompatibleModels))
151+
})
152+
t.Run("GetAppBrickInstanceByBrickIDWithCompatibleModels_Success", func(t *testing.T) {
153+
brickInstance, err := httpClient.GetAppBrickInstanceByBrickIDWithResponse(
154+
t.Context(),
155+
*createResp.JSON201.Id,
156+
ImageClassifactionBrickID,
157+
func(ctx context.Context, req *http.Request) error { return nil })
158+
require.NoError(t, err)
159+
require.NotEmpty(t, brickInstance.JSON200)
160+
require.Equal(t, ImageClassifactionBrickID, *brickInstance.JSON200.Id)
161+
require.NotNil(t, brickInstance.JSON200.CompatibleModels)
162+
require.Equal(t, expectedModelInfo, *(brickInstance.JSON200.CompatibleModels))
138163
})
139164

140165
t.Run("GetAppBrickInstanceByBrickID_InvalidAppID_Fails", func(t *testing.T) {

internal/orchestrator/bricks/bricks.go

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -121,6 +121,13 @@ func (s *Service) AppBrickInstanceDetails(a *app.ArduinoApp, brickID string) (Br
121121
Variables: variables,
122122
ConfigVariables: configVariables,
123123
ModelID: modelID,
124+
CompatibleModels: f.Map(s.modelsIndex.GetModelsByBrick(brick.ID), func(m modelsindex.AIModel) AIModel {
125+
return AIModel{
126+
ID: m.ID,
127+
Name: m.Name,
128+
Description: m.ModuleDescription,
129+
}
130+
}),
124131
}, nil
125132
}
126133

@@ -202,7 +209,7 @@ func (s *Service) BricksDetails(id string, idProvider *app.IDProvider,
202209
ApiDocsPath: apiDocsPath,
203210
CodeExamples: codeExamples,
204211
UsedByApps: usedByApps,
205-
Models: f.Map(s.modelsIndex.GetModelsByBrick(brick.ID), func(m modelsindex.AIModel) AIModel {
212+
CompatibleModels: f.Map(s.modelsIndex.GetModelsByBrick(brick.ID), func(m modelsindex.AIModel) AIModel {
206213
return AIModel{
207214
ID: m.ID,
208215
Name: m.Name,

internal/orchestrator/bricks/bricks_test.go

Lines changed: 162 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -418,13 +418,13 @@ func TestBricksDetails(t *testing.T) {
418418
require.Len(t, res.UsedByApps, 1)
419419
require.Equal(t, "My App", res.UsedByApps[0].Name)
420420
require.NotEmpty(t, res.UsedByApps[0].ID)
421-
require.Len(t, res.Models, 2)
422-
require.Equal(t, "yolox-object-detection", res.Models[0].ID)
423-
require.Equal(t, "General purpose object detection - YoloX", res.Models[0].Name)
424-
require.Equal(t, "General purpose object detection...", res.Models[0].Description)
425-
require.Equal(t, "face-detection", res.Models[1].ID)
426-
require.Equal(t, "Lightweight-Face-Detection", res.Models[1].Name)
427-
require.Equal(t, "", res.Models[1].Description)
421+
require.Len(t, res.CompatibleModels, 2)
422+
require.Equal(t, "yolox-object-detection", res.CompatibleModels[0].ID)
423+
require.Equal(t, "General purpose object detection - YoloX", res.CompatibleModels[0].Name)
424+
require.Equal(t, "General purpose object detection...", res.CompatibleModels[0].Description)
425+
require.Equal(t, "face-detection", res.CompatibleModels[1].ID)
426+
require.Equal(t, "Lightweight-Face-Detection", res.CompatibleModels[1].Name)
427+
require.Equal(t, "", res.CompatibleModels[1].Description)
428428
})
429429

430430
t.Run("Success - Full Details - no models", func(t *testing.T) {
@@ -443,7 +443,7 @@ func TestBricksDetails(t *testing.T) {
443443
require.Len(t, res.UsedByApps, 1)
444444
require.Equal(t, "My App", res.UsedByApps[0].Name)
445445
require.NotEmpty(t, res.UsedByApps[0].ID)
446-
require.Len(t, res.Models, 0)
446+
require.Len(t, res.CompatibleModels, 0)
447447
})
448448

449449
t.Run("Success - Full Details - one model", func(t *testing.T) {
@@ -452,10 +452,10 @@ func TestBricksDetails(t *testing.T) {
452452

453453
require.Equal(t, "arduino:one_model_brick", res.ID)
454454
require.Equal(t, "one model brick", res.Name)
455-
require.Len(t, res.Models, 1)
456-
require.Equal(t, "face-detection", res.Models[0].ID)
457-
require.Equal(t, "Lightweight-Face-Detection", res.Models[0].Name)
458-
require.Equal(t, "", res.Models[0].Description)
455+
require.Len(t, res.CompatibleModels, 1)
456+
require.Equal(t, "face-detection", res.CompatibleModels[0].ID)
457+
require.Equal(t, "Lightweight-Face-Detection", res.CompatibleModels[0].Name)
458+
require.Equal(t, "", res.CompatibleModels[0].Description)
459459
})
460460
}
461461

@@ -489,3 +489,153 @@ bricks:
489489
require.NoError(t, os.MkdirAll(pythonDir, 0755))
490490
require.NoError(t, os.WriteFile(filepath.Join(pythonDir, "main.py"), []byte("print('hello')"), 0600))
491491
}
492+
493+
func TestAppBrickInstanceModelsDetails(t *testing.T) {
494+
495+
bIndex := &bricksindex.BricksIndex{
496+
Bricks: []bricksindex.Brick{
497+
{
498+
ID: "arduino:object_detection",
499+
Name: "Object Detection",
500+
Category: "video",
501+
ModelName: "yolox-object-detection", // Default model
502+
Variables: []bricksindex.BrickVariable{
503+
{Name: "EI_OBJ_DETECTION_MODEL", DefaultValue: "default_path", Description: "path to the model file"},
504+
{Name: "CUSTOM_MODEL_PATH", DefaultValue: "/home/arduino/.arduino-bricks/ei-models", Description: "path to the custom model directory"},
505+
},
506+
},
507+
{
508+
ID: "arduino:weather_forecast",
509+
Name: "Weather Forecast",
510+
Category: "miscellaneous",
511+
ModelName: "",
512+
},
513+
},
514+
}
515+
516+
mIndex := &modelsindex.ModelsIndex{
517+
Models: []modelsindex.AIModel{
518+
519+
{
520+
ID: "yolox-object-detection",
521+
Name: "General purpose object detection - YoloX",
522+
ModuleDescription: "General purpose object detection...",
523+
Bricks: []string{"arduino:object_detection", "arduino:video_object_detection"},
524+
},
525+
{
526+
ID: "face-detection",
527+
Name: "Lightweight-Face-Detection",
528+
Bricks: []string{"arduino:object_detection", "arduino:video_object_detection"},
529+
},
530+
}}
531+
532+
svc := &Service{
533+
bricksIndex: bIndex,
534+
modelsIndex: mIndex,
535+
}
536+
537+
tests := []struct {
538+
name string
539+
app *app.ArduinoApp
540+
brickID string
541+
expectedError string
542+
validate func(*testing.T, BrickInstance)
543+
}{
544+
{
545+
name: "Brick not found in global Index",
546+
brickID: "arduino:non_existent_brick",
547+
app: &app.ArduinoApp{
548+
Descriptor: app.AppDescriptor{Bricks: []app.Brick{}},
549+
},
550+
expectedError: "brick not found",
551+
},
552+
{
553+
name: "Brick found in Index but not added to App",
554+
brickID: "arduino:object_detection",
555+
app: &app.ArduinoApp{
556+
Descriptor: app.AppDescriptor{
557+
Bricks: []app.Brick{
558+
{ID: "arduino:weather_forecast"},
559+
},
560+
},
561+
},
562+
expectedError: "brick arduino:object_detection not added in the app",
563+
},
564+
{
565+
name: "Success - Standard Brick without Model",
566+
brickID: "arduino:weather_forecast",
567+
app: &app.ArduinoApp{
568+
Descriptor: app.AppDescriptor{
569+
Bricks: []app.Brick{
570+
{ID: "arduino:weather_forecast"},
571+
},
572+
},
573+
},
574+
validate: func(t *testing.T, res BrickInstance) {
575+
require.Equal(t, "arduino:weather_forecast", res.ID)
576+
require.Equal(t, "Weather Forecast", res.Name)
577+
require.Equal(t, "installed", res.Status)
578+
require.Empty(t, res.ModelID)
579+
require.Empty(t, res.CompatibleModels)
580+
},
581+
},
582+
{
583+
name: "Success - Brick with Default Model",
584+
brickID: "arduino:object_detection",
585+
app: &app.ArduinoApp{
586+
Descriptor: app.AppDescriptor{
587+
Bricks: []app.Brick{
588+
{
589+
ID: "arduino:object_detection",
590+
},
591+
},
592+
},
593+
},
594+
validate: func(t *testing.T, res BrickInstance) {
595+
require.Equal(t, "arduino:object_detection", res.ID)
596+
require.Equal(t, "yolox-object-detection", res.ModelID)
597+
require.Len(t, res.CompatibleModels, 2)
598+
require.Equal(t, "yolox-object-detection", res.CompatibleModels[0].ID)
599+
require.Equal(t, "face-detection", res.CompatibleModels[1].ID)
600+
},
601+
},
602+
{
603+
name: "Success - Brick with Overridden Model in App",
604+
brickID: "arduino:object_detection",
605+
app: &app.ArduinoApp{
606+
Descriptor: app.AppDescriptor{
607+
Bricks: []app.Brick{
608+
{
609+
ID: "arduino:object_detection",
610+
Model: "face-detection",
611+
},
612+
},
613+
},
614+
},
615+
validate: func(t *testing.T, res BrickInstance) {
616+
require.Equal(t, "arduino:object_detection", res.ID)
617+
require.Equal(t, "face-detection", res.ModelID)
618+
require.Len(t, res.CompatibleModels, 2)
619+
require.Equal(t, "yolox-object-detection", res.CompatibleModels[0].ID)
620+
require.Equal(t, "face-detection", res.CompatibleModels[1].ID)
621+
},
622+
},
623+
}
624+
625+
for _, tt := range tests {
626+
t.Run(tt.name, func(t *testing.T) {
627+
result, err := svc.AppBrickInstanceDetails(tt.app, tt.brickID)
628+
629+
if tt.expectedError != "" {
630+
require.Error(t, err)
631+
require.Equal(t, err.Error(), tt.expectedError)
632+
return
633+
}
634+
635+
require.NoError(t, err)
636+
if tt.validate != nil {
637+
tt.validate(t, result)
638+
}
639+
})
640+
}
641+
}

0 commit comments

Comments
 (0)