Skip to content

Commit fb43b08

Browse files
authored
feat(inspect): Select default model (#3190)
select completed models Signed-off-by: Colorado, Camilo <camilo.colorado@intel.com>
1 parent 95890c7 commit fb43b08

File tree

10 files changed

+86
-74
lines changed

10 files changed

+86
-74
lines changed

application/ui/src/features/inspect/dataset/dataset-status-panel.component.tsx

Lines changed: 1 addition & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,12 @@
11
import { ComponentProps, Suspense, useEffect, useRef } from 'react';
22

3-
import { $api } from '@geti-inspect/api';
43
import { SchemaJob as Job, SchemaJob, SchemaJobStatus } from '@geti-inspect/api/spec';
54
import { useProjectIdentifier } from '@geti-inspect/hooks';
65
import { Content, Flex, Heading, InlineAlert, IntelBrandedLoading, ProgressBar, Text } from '@geti/ui';
76
import { useQueryClient } from '@tanstack/react-query';
87
import { isEqual } from 'lodash-es';
98

9+
import { useProjectTrainingJobs } from '../../../hooks/use-project-trainingJobs.hook';
1010
import { ShowJobLogs } from '../jobs/show-job-logs.component';
1111
import { REQUIRED_NUMBER_OF_NORMAL_IMAGES_TO_TRIGGER_TRAINING } from './utils';
1212

@@ -88,25 +88,6 @@ const TrainingInProgress = ({ job }: TrainingInProgressProps) => {
8888
);
8989
};
9090

91-
const REFETCH_INTERVAL_WITH_TRAINING = 1_000;
92-
93-
export const useProjectTrainingJobs = () => {
94-
const { projectId } = useProjectIdentifier();
95-
96-
const { data } = $api.useQuery('get', '/api/jobs', undefined, {
97-
refetchInterval: ({ state }) => {
98-
const projectHasTrainingJob = state.data?.jobs.some(
99-
({ project_id, type, status }) =>
100-
projectId === project_id && type === 'training' && (status === 'running' || status === 'pending')
101-
);
102-
103-
return projectHasTrainingJob ? REFETCH_INTERVAL_WITH_TRAINING : undefined;
104-
},
105-
});
106-
107-
return { jobs: data?.jobs.filter((job) => job.project_id === projectId) };
108-
};
109-
11091
export const useRefreshModelsOnJobUpdates = (jobs: Job[] | undefined) => {
11192
const queryClient = useQueryClient();
11293
const { projectId } = useProjectIdentifier();

application/ui/src/features/inspect/footer/footer.component.tsx

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,9 @@ import { SchemaJob as Job } from '@geti-inspect/api/spec';
88
import { usePatchPipeline, usePipeline, useProjectIdentifier } from '@geti-inspect/hooks';
99
import { Flex, Loading, Text, View } from '@geti/ui';
1010
import { WaitingIcon } from '@geti/ui/icons';
11+
import { isEmpty } from 'lodash-es';
1112

12-
import { useTrainedModels } from '../../../hooks/use-model';
13+
import { useCompletedModels } from '../../../hooks/use-completed-models.hook';
1314
import { TrainingStatusItem } from './training-status-item.component';
1415

1516
const useCurrentJob = () => {
@@ -26,13 +27,13 @@ const useCurrentJob = () => {
2627
};
2728

2829
const useDefaultModel = () => {
29-
const models = useTrainedModels();
30+
const models = useCompletedModels();
3031
const { data: pipeline } = usePipeline();
3132
const { projectId } = useProjectIdentifier();
3233
const patchPipeline = usePatchPipeline(projectId);
3334

3435
const hasSelectedModel = pipeline?.model?.id !== undefined;
35-
const hasNonAvailableModels = models.length === 0;
36+
const hasNonAvailableModels = isEmpty(models.filter(({ status }) => status === 'Completed'));
3637

3738
useEffect(() => {
3839
if (hasSelectedModel || hasNonAvailableModels || patchPipeline.isPending) {

application/ui/src/features/inspect/models/export-model-dialog.component.tsx

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,8 @@ import { useMutation } from '@tanstack/react-query';
2020
import type { SchemaCompressionType, SchemaExportType } from 'src/api/openapi-spec';
2121
import { Onnx, OpenVino, PyTorch } from 'src/assets/icons';
2222

23+
import type { ModelData } from '../../../hooks/utils';
2324
import { downloadBlob, sanitizeFilename } from '../utils';
24-
import type { ModelData } from './model-types';
2525

2626
import classes from './export-model-dialog.module.scss';
2727

application/ui/src/features/inspect/models/model-actions-menu.component.tsx

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,9 @@ import { usePatchPipeline, useProjectIdentifier } from '@geti-inspect/hooks';
55
import { ActionButton, AlertDialog, DialogContainer, Item, Menu, MenuTrigger, toast, type Key } from '@geti/ui';
66
import { MoreMenu } from '@geti/ui/icons';
77

8+
import type { ModelData } from '../../../hooks/utils';
89
import { JobLogsDialog } from '../jobs/show-job-logs.component';
910
import { ExportModelDialog } from './export-model-dialog.component';
10-
import type { ModelData } from './model-types';
1111

1212
interface ModelActionsMenuProps {
1313
model: ModelData;

application/ui/src/features/inspect/models/models-view.component.tsx

Lines changed: 7 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
import { useMemo } from 'react';
22

3-
import { $api } from '@geti-inspect/api';
4-
import { usePipeline, useProjectIdentifier } from '@geti-inspect/hooks';
3+
import { usePipeline } from '@geti-inspect/hooks';
54
import {
65
Cell,
76
Column,
@@ -18,63 +17,25 @@ import {
1817
import { sortBy } from 'lodash-es';
1918
import { useDateFormatter } from 'react-aria';
2019

21-
import { useProjectTrainingJobs, useRefreshModelsOnJobUpdates } from '../dataset/dataset-status-panel.component';
20+
import { useCompletedModels } from '../../../hooks/use-completed-models.hook';
21+
import { useProjectTrainingJobs } from '../../../hooks/use-project-trainingJobs.hook';
22+
import type { ModelData } from '../../../hooks/utils';
23+
import { useRefreshModelsOnJobUpdates } from '../dataset/dataset-status-panel.component';
2224
import { formatSize } from '../utils';
2325
import { ModelActionsMenu } from './model-actions-menu.component';
2426
import { ModelStatusBadges } from './model-status-badges.component';
25-
import { ModelData } from './model-types';
2627

2728
import classes from './models-view.module.scss';
2829

29-
const useModels = () => {
30-
const { projectId } = useProjectIdentifier();
31-
const modelsQuery = $api.useSuspenseQuery('get', '/api/projects/{project_id}/models', {
32-
params: { path: { project_id: projectId } },
33-
});
34-
const models = modelsQuery.data.models;
35-
36-
return models;
37-
};
38-
3930
export const ModelsView = () => {
4031
const { data: pipeline } = usePipeline();
4132
const { jobs = [] } = useProjectTrainingJobs();
4233

4334
const dateFormatter = useDateFormatter({ dateStyle: 'medium', timeStyle: 'short' });
4435
const selectedModelId = pipeline.model?.id;
45-
useRefreshModelsOnJobUpdates(jobs);
46-
47-
const models = useModels()
48-
.filter((model) => model.is_ready)
49-
.map((model): ModelData | null => {
50-
const job = jobs.find(({ id }) => id === model.train_job_id);
51-
if (job === undefined) {
52-
return null;
53-
}
54-
55-
let timestamp = '';
56-
let durationInSeconds = 0;
57-
const start = job.start_time ? new Date(job.start_time) : new Date();
58-
if (job) {
59-
const end = job.end_time ? new Date(job.end_time) : new Date();
60-
durationInSeconds = Math.floor((end.getTime() - start.getTime()) / 1000);
61-
timestamp = dateFormatter.format(start);
62-
}
36+
const models = useCompletedModels();
6337

64-
return {
65-
id: model.id!,
66-
name: model.name!,
67-
status: 'Completed',
68-
architecture: model.name!,
69-
startTime: start.getTime(),
70-
timestamp,
71-
durationInSeconds,
72-
progress: 1.0,
73-
job,
74-
sizeBytes: model.size ?? null,
75-
};
76-
})
77-
.filter((model): model is ModelData => model !== null);
38+
useRefreshModelsOnJobUpdates(jobs);
7839

7940
const completedModelsJobsIDs = new Set(models.map((model) => model.job?.id));
8041

application/ui/src/features/inspect/toolbar/models-list/models-list.component.tsx

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ import { clsx } from 'clsx';
77
import { isEmpty } from 'lodash-es';
88
import { NotFound } from 'packages/ui/icons';
99

10-
import { useTrainedModels } from '../../../../hooks/use-model';
10+
import { useTrainedModels } from '../../../../hooks/use-trained-models';
1111

1212
import classes from './model-list.module.scss';
1313

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
// Copyright (C) 2025 Intel Corporation
2+
// SPDX-License-Identifier: Apache-2.0
3+
4+
import { useDateFormatter } from '@react-aria/i18n';
5+
6+
import { useProjectTrainingJobs } from './use-project-trainingJobs.hook';
7+
import { useTrainedModels } from './use-trained-models';
8+
import { ModelData } from './utils';
9+
10+
export const useCompletedModels = () => {
11+
const { jobs = [] } = useProjectTrainingJobs();
12+
13+
const dateFormatter = useDateFormatter({ dateStyle: 'medium', timeStyle: 'short' });
14+
15+
const models = useTrainedModels()
16+
.filter((model) => model.is_ready)
17+
.map((model): ModelData | null => {
18+
const job = jobs.find(({ id }) => id === model.train_job_id);
19+
if (job === undefined) {
20+
return null;
21+
}
22+
23+
let timestamp = '';
24+
let durationInSeconds = 0;
25+
const start = job.start_time ? new Date(job.start_time) : new Date();
26+
if (job) {
27+
const end = job.end_time ? new Date(job.end_time) : new Date();
28+
durationInSeconds = Math.floor((end.getTime() - start.getTime()) / 1000);
29+
timestamp = dateFormatter.format(start);
30+
}
31+
32+
return {
33+
id: model.id!,
34+
name: model.name!,
35+
status: 'Completed',
36+
architecture: model.name!,
37+
startTime: start.getTime(),
38+
timestamp,
39+
durationInSeconds,
40+
progress: 1.0,
41+
job,
42+
sizeBytes: model.size ?? null,
43+
};
44+
})
45+
.filter((model): model is ModelData => model !== null);
46+
47+
return models;
48+
};
Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
import { $api } from '../api/client';
2+
import { useProjectIdentifier } from './use-project-identifier.hook';
3+
4+
const REFETCH_INTERVAL_WITH_TRAINING = 1_000;
5+
6+
export const useProjectTrainingJobs = () => {
7+
const { projectId } = useProjectIdentifier();
8+
9+
const { data } = $api.useQuery('get', '/api/jobs', undefined, {
10+
refetchInterval: ({ state }) => {
11+
const projectHasTrainingJob = state.data?.jobs.some(
12+
({ project_id, type, status }) =>
13+
projectId === project_id && type === 'training' && (status === 'running' || status === 'pending')
14+
);
15+
16+
return projectHasTrainingJob ? REFETCH_INTERVAL_WITH_TRAINING : undefined;
17+
},
18+
});
19+
20+
return { jobs: data?.jobs.filter((job) => job.project_id === projectId) };
21+
};

application/ui/src/hooks/use-model.tsx renamed to application/ui/src/hooks/use-trained-models.ts

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,13 +6,13 @@ import { useProjectIdentifier } from '@geti-inspect/hooks';
66

77
export const useTrainedModels = () => {
88
const { projectId } = useProjectIdentifier();
9-
const { data } = $api.useQuery('get', '/api/projects/{project_id}/models', {
9+
const { data } = $api.useSuspenseQuery('get', '/api/projects/{project_id}/models', {
1010
params: {
1111
path: {
1212
project_id: projectId,
1313
},
1414
},
1515
});
1616

17-
return data?.models.map((model) => ({ id: model.id, name: model.name })) || [];
17+
return data.models;
1818
};

application/ui/src/features/inspect/models/model-types.ts renamed to application/ui/src/hooks/utils.ts

File renamed without changes.

0 commit comments

Comments
 (0)