diff --git a/runner/internal/shim/api/schemas.go b/runner/internal/shim/api/schemas.go index cd0db6a20..0e96028a5 100644 --- a/runner/internal/shim/api/schemas.go +++ b/runner/internal/shim/api/schemas.go @@ -29,6 +29,9 @@ type TaskInfoResponse struct { TerminationReason string `json:"termination_reason"` TerminationMessage string `json:"termination_message"` Ports []shim.PortMapping `json:"ports"` + + ImagePullProgress *shim.ImagePullProgress `json:"image_pull_progress"` + // The following fields are for debugging only, server doesn't need them ContainerName string `json:"container_name"` ContainerID string `json:"container_id"` diff --git a/runner/internal/shim/docker.go b/runner/internal/shim/docker.go index 6acfb27a5..d2df8fc86 100644 --- a/runner/internal/shim/docker.go +++ b/runner/internal/shim/docker.go @@ -16,6 +16,7 @@ import ( rt "runtime" "strconv" "strings" + "sync" "time" "github.com/docker/docker/api/types/container" @@ -50,6 +51,86 @@ const ( LabelValueTrue = "true" ) +// dockerd reports pulling progress as a stream of JSON Lines. The format of records is not documented in the API documentation, +// although it's occasionally mentioned, e.g., https://docs.docker.com/reference/api/engine/version-history/#v148-api-changes +// https://github.com/moby/moby/blob/e77ff99ede5ee5952b3a9227863552ae6e5b6fb1/pkg/jsonmessage/jsonmessage.go#L144 +// All fields are optional. +type PullMessage struct { + Id string `json:"id"` // layer id + Status string `json:"status"` + ProgressDetail struct { + Current uint64 `json:"current"` // bytes + Total uint64 `json:"total"` // bytes + } `json:"progressDetail"` + ErrorDetail struct { + Message string `json:"message"` + } `json:"errorDetail"` +} + +type layerProgress struct { + Status string + DownloadedBytes uint64 + ExtractedBytes uint64 + TotalBytes uint64 +} + +type PullTracker struct { + mu sync.RWMutex + layers map[string]layerProgress +} + +func newPullTracker() *PullTracker { + return &PullTracker{layers: make(map[string]layerProgress)} +} + +func (t *PullTracker) Update(msg PullMessage) { + if msg.Id == "" { + return + } + t.mu.Lock() + defer t.mu.Unlock() + layer := t.layers[msg.Id] + switch msg.Status { + case "Pulling fs layer", "Waiting", "Verifying Checksum", "Already exists": + // no bytes to update, just track status + case "Downloading": + layer.DownloadedBytes = msg.ProgressDetail.Current + layer.TotalBytes = msg.ProgressDetail.Total + case "Download complete": + layer.DownloadedBytes = layer.TotalBytes + case "Extracting": + layer.ExtractedBytes = msg.ProgressDetail.Current + layer.DownloadedBytes = msg.ProgressDetail.Total + layer.TotalBytes = msg.ProgressDetail.Total + case "Pull complete": + layer.ExtractedBytes = layer.TotalBytes + layer.DownloadedBytes = layer.TotalBytes + default: + // Non-layer events, such as {"status":"Pulling from library/python","id":"3.11"} + return + } + layer.Status = msg.Status + t.layers[msg.Id] = layer +} + +func (t *PullTracker) Progress() *ImagePullProgress { + t.mu.RLock() + defer t.mu.RUnlock() + if len(t.layers) == 0 { + return nil + } + p := ImagePullProgress{IsTotalBytesFinal: true} + for _, l := range t.layers { + if l.TotalBytes == 0 && l.Status != "Already exists" && l.Status != "Pull complete" { + p.IsTotalBytesFinal = false + } + p.DownloadedBytes += l.DownloadedBytes + p.ExtractedBytes += l.ExtractedBytes + p.TotalBytes += l.TotalBytes + } + return &p +} + type DockerRunner struct { client *docker.Client dockerParams DockerParameters @@ -239,6 +320,7 @@ func (d *DockerRunner) TaskInfo(taskID string) TaskInfo { ContainerName: task.containerName, ContainerID: task.containerID, GpuIDs: task.gpuIDs, + ImagePullProgress: task.pullTracker.Progress(), } } @@ -350,7 +432,7 @@ func (d *DockerRunner) Run(ctx context.Context, taskID string) error { // Although it's called "runner dir", we also use it for shim task-related data. // Maybe we should rename it to "task dir" (including the `/root/.dstack/runners` dir on the host). pullLogPath := filepath.Join(runnerDir, "pull.log") - if err = pullImage(pullCtx, d.client, cfg, pullLogPath); err != nil { + if err = pullImage(pullCtx, d.client, cfg, pullLogPath, task.pullTracker); err != nil { errMessage := fmt.Sprintf("pullImage error: %s", err.Error()) log.Error(ctx, errMessage) task.SetStatusTerminated(string(types.TerminationReasonCreatingContainerError), errMessage) @@ -670,7 +752,7 @@ func mountDisk(ctx context.Context, deviceName, mountPoint string, fsRootPerms o return nil } -func pullImage(ctx context.Context, client docker.APIClient, taskConfig TaskConfig, logPath string) error { +func pullImage(ctx context.Context, client docker.APIClient, taskConfig TaskConfig, logPath string, tracker *PullTracker) error { if !strings.Contains(taskConfig.ImageName, ":") { taskConfig.ImageName += ":latest" } @@ -710,26 +792,6 @@ func pullImage(ctx context.Context, client docker.APIClient, taskConfig TaskConf teeReader := io.TeeReader(reader, logFile) - current := make(map[string]uint) - total := make(map[string]uint) - - // dockerd reports pulling progress as a stream of JSON Lines. The format of records is not documented in the API documentation, - // although it's occasionally mentioned, e.g., https://docs.docker.com/reference/api/engine/version-history/#v148-api-changes - - // https://github.com/moby/moby/blob/e77ff99ede5ee5952b3a9227863552ae6e5b6fb1/pkg/jsonmessage/jsonmessage.go#L144 - // All fields are optional - type PullMessage struct { - Id string `json:"id"` // layer id - Status string `json:"status"` - ProgressDetail struct { - Current uint `json:"current"` // bytes - Total uint `json:"total"` // bytes - } `json:"progressDetail"` - ErrorDetail struct { - Message string `json:"message"` - } `json:"errorDetail"` - } - var pullCompleted bool pullErrors := make([]string, 0) @@ -740,13 +802,7 @@ func pullImage(ctx context.Context, client docker.APIClient, taskConfig TaskConf if err := json.Unmarshal(line, &pullMessage); err != nil { continue } - if pullMessage.Status == "Downloading" { - current[pullMessage.Id] = pullMessage.ProgressDetail.Current - total[pullMessage.Id] = pullMessage.ProgressDetail.Total - } - if pullMessage.Status == "Download complete" { - current[pullMessage.Id] = total[pullMessage.Id] - } + tracker.Update(pullMessage) if pullMessage.ErrorDetail.Message != "" { log.Error(ctx, "error pulling image", "name", taskConfig.ImageName, "err", pullMessage.ErrorDetail.Message) pullErrors = append(pullErrors, pullMessage.ErrorDetail.Message) @@ -764,13 +820,10 @@ func pullImage(ctx context.Context, client docker.APIClient, taskConfig TaskConf } duration := time.Since(startTime) - var currentBytes uint - var totalBytes uint - for _, v := range current { - currentBytes += v - } - for _, v := range total { - totalBytes += v + p := tracker.Progress() + var currentBytes, totalBytes uint64 + if p != nil { + currentBytes, totalBytes = p.DownloadedBytes, p.TotalBytes } speed := bytesize.New(float64(currentBytes) / duration.Seconds()) diff --git a/runner/internal/shim/docker_test.go b/runner/internal/shim/docker_test.go index 18f8c31fc..68e45ed1c 100644 --- a/runner/internal/shim/docker_test.go +++ b/runner/internal/shim/docker_test.go @@ -155,3 +155,112 @@ func createTaskConfig(t *testing.T) TaskConfig { ImageName: "ubuntu", } } + +func pullMsg(id, status string, current, total uint64) PullMessage { + m := PullMessage{Id: id, Status: status} + m.ProgressDetail.Current = current + m.ProgressDetail.Total = total + return m +} + +func TestPullTracker_Empty(t *testing.T) { + tracker := newPullTracker() + assert.Nil(t, tracker.Progress()) +} + +func TestPullTracker_AlreadyExists(t *testing.T) { + tracker := newPullTracker() + tracker.Update(PullMessage{Id: "3.11", Status: "Pulling from library/python"}) + for _, id := range []string{"aaa", "bbb", "ccc"} { + tracker.Update(PullMessage{Id: id, Status: "Already exists"}) + } + tracker.Update(PullMessage{Status: "Digest: sha256:***"}) + tracker.Update(PullMessage{Status: "Status: Image is up to date for python:3.11"}) + p := tracker.Progress() + require.NotNil(t, p) + assert.Equal(t, uint64(0), p.DownloadedBytes) + assert.Equal(t, uint64(0), p.ExtractedBytes) + assert.Equal(t, uint64(0), p.TotalBytes) + assert.True(t, p.IsTotalBytesFinal) +} + +func TestPullTracker_FullPull(t *testing.T) { + const sizeA, sizeB uint64 = 111, 222 + + tracker := newPullTracker() + tracker.Update(PullMessage{Id: "3.11", Status: "Pulling from library/python"}) + tracker.Update(PullMessage{Id: "aaa", Status: "Pulling fs layer"}) + tracker.Update(PullMessage{Id: "bbb", Status: "Pulling fs layer"}) + tracker.Update(PullMessage{Id: "aaa", Status: "Waiting"}) + tracker.Update(PullMessage{Id: "bbb", Status: "Waiting"}) + + // Layers announced but sizes unknown yet + p := tracker.Progress() + require.NotNil(t, p) + assert.Equal(t, uint64(0), p.DownloadedBytes) + assert.Equal(t, uint64(0), p.ExtractedBytes) + assert.Equal(t, uint64(0), p.TotalBytes) + assert.False(t, p.IsTotalBytesFinal) + + // Both layers start downloading - sizes now known + tracker.Update(pullMsg("aaa", "Downloading", 100, sizeA)) + tracker.Update(pullMsg("bbb", "Downloading", 200, sizeB)) + + p = tracker.Progress() + assert.Equal(t, uint64(300), p.DownloadedBytes) + assert.Equal(t, uint64(0), p.ExtractedBytes) + assert.Equal(t, sizeA+sizeB, p.TotalBytes) + assert.True(t, p.IsTotalBytesFinal) + + // Downloads complete + tracker.Update(pullMsg("aaa", "Downloading", sizeA, sizeA)) + tracker.Update(PullMessage{Id: "aaa", Status: "Download complete"}) + tracker.Update(pullMsg("bbb", "Downloading", sizeB, sizeB)) + tracker.Update(PullMessage{Id: "bbb", Status: "Download complete"}) + + p = tracker.Progress() + assert.Equal(t, sizeA+sizeB, p.DownloadedBytes) + assert.Equal(t, uint64(0), p.ExtractedBytes) + assert.Equal(t, sizeA+sizeB, p.TotalBytes) + assert.True(t, p.IsTotalBytesFinal) + + // Both layers start extracting + tracker.Update(pullMsg("aaa", "Extracting", 100, sizeA)) + tracker.Update(pullMsg("bbb", "Extracting", 200, sizeB)) + + p = tracker.Progress() + assert.Equal(t, sizeA+sizeB, p.DownloadedBytes) + assert.Equal(t, uint64(300), p.ExtractedBytes) + assert.Equal(t, sizeA+sizeB, p.TotalBytes) + assert.True(t, p.IsTotalBytesFinal) + + // Extractions complete + tracker.Update(pullMsg("aaa", "Extracting", sizeA, sizeA)) + tracker.Update(PullMessage{Id: "aaa", Status: "Pull complete"}) + tracker.Update(pullMsg("bbb", "Extracting", sizeB, sizeB)) + tracker.Update(PullMessage{Id: "bbb", Status: "Pull complete"}) + tracker.Update(PullMessage{Status: "Digest: sha256:***"}) + tracker.Update(PullMessage{Status: "Status: Downloaded newer image for python:3.11"}) + + p = tracker.Progress() + assert.Equal(t, sizeA+sizeB, p.DownloadedBytes) + assert.Equal(t, sizeA+sizeB, p.ExtractedBytes) + assert.Equal(t, sizeA+sizeB, p.TotalBytes) + assert.True(t, p.IsTotalBytesFinal) +} + +func TestPullTracker_MixedLayerStatuses(t *testing.T) { + tracker := newPullTracker() + + tracker.Update(PullMessage{Id: "layer-exists", Status: "Already exists"}) + tracker.Update(pullMsg("layer-downloading", "Downloading", 50, 100)) + tracker.Update(pullMsg("layer-extracting", "Extracting", 100, 200)) + tracker.Update(PullMessage{Id: "layer-waiting", Status: "Waiting"}) + + p := tracker.Progress() + require.NotNil(t, p) + assert.Equal(t, uint64(50+200), p.DownloadedBytes) + assert.Equal(t, uint64(100), p.ExtractedBytes) + assert.Equal(t, uint64(100+200), p.TotalBytes) + assert.False(t, p.IsTotalBytesFinal) // layer-waiting size unknown +} diff --git a/runner/internal/shim/models.go b/runner/internal/shim/models.go index d50fe6e29..b0b7852d8 100644 --- a/runner/internal/shim/models.go +++ b/runner/internal/shim/models.go @@ -108,12 +108,20 @@ type TaskListItem struct { Status TaskStatus `json:"status"` } +type ImagePullProgress struct { + DownloadedBytes uint64 `json:"downloaded_bytes"` + ExtractedBytes uint64 `json:"extracted_bytes"` + TotalBytes uint64 `json:"total_bytes"` + IsTotalBytesFinal bool `json:"is_total_bytes_final"` +} + type TaskInfo struct { ID string Status TaskStatus TerminationReason string TerminationMessage string Ports []PortMapping + ImagePullProgress *ImagePullProgress ContainerName string ContainerID string GpuIDs []string diff --git a/runner/internal/shim/task.go b/runner/internal/shim/task.go index d2fef7e02..ea3ad7c96 100644 --- a/runner/internal/shim/task.go +++ b/runner/internal/shim/task.go @@ -42,6 +42,8 @@ type Task struct { ports []PortMapping runnerDir string // path on host mapped to consts.RunnerDir in container + pullTracker *PullTracker + mu *sync.Mutex } @@ -128,6 +130,7 @@ func NewTask(id string, status TaskStatus, containerName string, containerID str runnerDir: runnerDir, gpuIDs: gpuIDs, ports: ports, + pullTracker: newPullTracker(), mu: &sync.Mutex{}, } } @@ -138,6 +141,7 @@ func NewTaskFromConfig(cfg TaskConfig) Task { Status: TaskStatusPending, config: cfg, containerName: generateUniqueName(cfg.Name, cfg.ID), + pullTracker: newPullTracker(), mu: &sync.Mutex{}, } } diff --git a/src/dstack/_internal/cli/utils/run.py b/src/dstack/_internal/cli/utils/run.py index 351bea9c0..2e53f93ce 100644 --- a/src/dstack/_internal/cli/utils/run.py +++ b/src/dstack/_internal/cli/utils/run.py @@ -27,6 +27,7 @@ TerminationPolicy, ) from dstack._internal.core.models.runs import ( + ImagePullProgress, Job, JobStatus, JobSubmission, @@ -316,12 +317,32 @@ def _format_job_submission_status(job_submission: JobSubmission, verbose: bool) color = color_map.get(job_status, "white") status_style = f"bold {color}" if not job_status.is_finished() else color formatted_status_message = f"[{status_style}]{status_message}[/]" + if job_status == JobStatus.PULLING and job_submission.image_pull_progress is not None: + formatted_status_message += ( + f" [secondary]{_format_pull_progress(job_submission.image_pull_progress)}[/]" + ) if verbose and job_submission.inactivity_secs: inactive_for = format_duration_multiunit(job_submission.inactivity_secs) formatted_status_message += f" (inactive for {inactive_for})" return formatted_status_message +def _format_pull_progress(progress: ImagePullProgress) -> str: + if progress.total_bytes >= 2**30: # 1GB + unit = "GB" + + def f(x: int) -> str: + return f"{x / 2**30:.2f}" + else: + unit = "MB" + + def f(x: int) -> str: + return f"{x / 2**20:.0f}" + + total_sign = "≥" if not progress.is_total_bytes_final else "" + return f"{f(progress.extracted_bytes)}/{f(progress.downloaded_bytes)}/{total_sign}{f(progress.total_bytes)}{unit}" + + def _get_show_deployment_replica_job(run: CoreRun, verbose: bool) -> tuple[bool, bool, bool]: show_deployment_num = ( verbose and run.run_spec.configuration.type == "service" @@ -434,7 +455,7 @@ def get_runs_table( else: table.add_column("GPU", ratio=2) table.add_column("PRICE", style="grey58", ratio=1) - table.add_column("STATUS", no_wrap=True, ratio=1) + table.add_column("STATUS", ratio=1) if verbose or any( run._run.is_deployment_in_progress() and any(job.job_submissions[-1].probes for job in run._run.jobs) diff --git a/src/dstack/_internal/core/compatibility/runs.py b/src/dstack/_internal/core/compatibility/runs.py index 68566f3a5..b9b958f03 100644 --- a/src/dstack/_internal/core/compatibility/runs.py +++ b/src/dstack/_internal/core/compatibility/runs.py @@ -147,6 +147,9 @@ def get_job_submission_excludes(job_submissions: list[JobSubmission]) -> Include jrd_excludes["working_dir"] = True submission_excludes["job_runtime_data"] = jrd_excludes + if all(s.image_pull_progress is None for s in job_submissions): + submission_excludes["image_pull_progress"] = True + return submission_excludes diff --git a/src/dstack/_internal/core/models/runs.py b/src/dstack/_internal/core/models/runs.py index 4d542fc0e..baecfac44 100644 --- a/src/dstack/_internal/core/models/runs.py +++ b/src/dstack/_internal/core/models/runs.py @@ -399,6 +399,19 @@ class Probe(CoreModel): success_streak: int +class ImagePullProgress(CoreModel): + downloaded_bytes: int + extracted_bytes: int + total_bytes: int + """An estimate of the number of bytes to be downloaded and extracted during this pull. + Does not include cached layers that existed on the instance before the pull. + """ + is_total_bytes_final: bool + """Whether `total_bytes` is believed to be the correct final value. + If `False`, then `total_bytes` is a lower estimate. + """ + + class JobSubmission(CoreModel): id: UUID4 submission_num: int @@ -421,6 +434,7 @@ class JobSubmission(CoreModel): job_runtime_data: Optional[JobRuntimeData] = None error: Optional[str] = None probes: list[Probe] = [] + image_pull_progress: Optional[ImagePullProgress] = None @property def age(self) -> timedelta: diff --git a/src/dstack/_internal/server/background/pipeline_tasks/jobs_running.py b/src/dstack/_internal/server/background/pipeline_tasks/jobs_running.py index 0bd995527..7c7483a34 100644 --- a/src/dstack/_internal/server/background/pipeline_tasks/jobs_running.py +++ b/src/dstack/_internal/server/background/pipeline_tasks/jobs_running.py @@ -25,6 +25,7 @@ from dstack._internal.core.models.repos import RemoteRepoCreds from dstack._internal.core.models.runs import ( ClusterInfo, + ImagePullProgress, Job, JobProvisioningData, JobRuntimeData, @@ -337,6 +338,7 @@ class _JobUpdateMap(ItemUpdateMap, total=False): inactivity_secs: Optional[int] exit_status: Optional[int] registered: bool + image_pull_progress: Optional[str] @dataclass @@ -730,6 +732,9 @@ async def _process_pulling_status( if shim_state.job_runtime_data is not None: _set_job_runtime_data(result, shim_state.job_runtime_data) + if shim_state.image_pull_progress is not None: + result.job_update_map["image_pull_progress"] = shim_state.image_pull_progress.json() + if shim_state.state == _ShimPullingState.WAITING: _reset_disconnected_at(context.job_model, result) return @@ -1288,6 +1293,7 @@ class _SyncShimPullingStateResult: termination_reason: Optional[JobTerminationReason] = None termination_reason_message: Optional[str] = None job_runtime_data: Optional[JobRuntimeData] = None + image_pull_progress: Optional[ImagePullProgress] = None @runner_ssh_tunnel(ports=[DSTACK_RUNNER_HTTP_PORT], retries=1) @@ -1305,8 +1311,12 @@ def _sync_shim_pulling_state( jrd: Optional[JobRuntimeData] = None, ) -> Union[_SyncShimPullingStateResult, Literal[False]]: shim_client = client.ShimClient(port=ports[DSTACK_SHIM_HTTP_PORT]) + image_pull_progress: Optional[ImagePullProgress] = None if shim_client.is_api_v2_supported(): task = shim_client.get_task(job_model.id) + if task.image_pull_progress is not None: + image_pull_progress = task.image_pull_progress + if task.status == TaskStatus.TERMINATED: logger.warning( "shim failed to execute job %s: %s (%s)", @@ -1319,14 +1329,21 @@ def _sync_shim_pulling_state( state=_ShimPullingState.FAILED, termination_reason=JobTerminationReason(task.termination_reason.lower()), termination_reason_message=task.termination_message, + image_pull_progress=image_pull_progress, ) if task.status != TaskStatus.RUNNING: - return _SyncShimPullingStateResult(state=_ShimPullingState.WAITING) + return _SyncShimPullingStateResult( + state=_ShimPullingState.WAITING, + image_pull_progress=image_pull_progress, + ) if jrd is not None: if task.ports is None: - return _SyncShimPullingStateResult(state=_ShimPullingState.WAITING) + return _SyncShimPullingStateResult( + state=_ShimPullingState.WAITING, + image_pull_progress=image_pull_progress, + ) jrd = jrd.copy(update={"ports": {pm.container: pm.host for pm in task.ports}}) else: shim_status = shim_client.pull() @@ -1346,14 +1363,19 @@ def _sync_shim_pulling_state( state=_ShimPullingState.FAILED, termination_reason=JobTerminationReason(shim_status.result.reason.lower()), termination_reason_message=shim_status.result.reason_message, + image_pull_progress=image_pull_progress, ) if shim_status.state in ("pulling", "creating"): - return _SyncShimPullingStateResult(state=_ShimPullingState.WAITING) + return _SyncShimPullingStateResult( + state=_ShimPullingState.WAITING, + image_pull_progress=image_pull_progress, + ) return _SyncShimPullingStateResult( state=_ShimPullingState.READY, job_runtime_data=jrd, + image_pull_progress=image_pull_progress, ) diff --git a/src/dstack/_internal/server/migrations/versions/2026/04_18_1822_f48b23790053_add_jobmodel_image_pull_progress.py b/src/dstack/_internal/server/migrations/versions/2026/04_18_1822_f48b23790053_add_jobmodel_image_pull_progress.py new file mode 100644 index 000000000..33b46a73f --- /dev/null +++ b/src/dstack/_internal/server/migrations/versions/2026/04_18_1822_f48b23790053_add_jobmodel_image_pull_progress.py @@ -0,0 +1,32 @@ +"""Add JobModel.image_pull_progress + +Revision ID: f48b23790053 +Revises: 94fcd7e38b7e +Create Date: 2026-04-18 18:22:47.121819+00:00 + +""" + +import sqlalchemy as sa +from alembic import op + +# revision identifiers, used by Alembic. +revision = "f48b23790053" +down_revision = "94fcd7e38b7e" +branch_labels = None +depends_on = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table("jobs", schema=None) as batch_op: + batch_op.add_column(sa.Column("image_pull_progress", sa.Text(), nullable=True)) + + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table("jobs", schema=None) as batch_op: + batch_op.drop_column("image_pull_progress") + + # ### end Alembic commands ### diff --git a/src/dstack/_internal/server/models.py b/src/dstack/_internal/server/models.py index 64990e96a..0b7595629 100644 --- a/src/dstack/_internal/server/models.py +++ b/src/dstack/_internal/server/models.py @@ -578,6 +578,7 @@ class JobModel(PipelineModelMixin, BaseModel): for example to provision instances for all jobs when processing master. If not set, all jobs should be processed only one-by-one. """ + image_pull_progress: Mapped[Optional[str]] = mapped_column(Text) __table_args__ = ( Index( diff --git a/src/dstack/_internal/server/schemas/runner.py b/src/dstack/_internal/server/schemas/runner.py index 549ff7914..c1ad0407d 100644 --- a/src/dstack/_internal/server/schemas/runner.py +++ b/src/dstack/_internal/server/schemas/runner.py @@ -9,6 +9,7 @@ from dstack._internal.core.models.repos.remote import RemoteRepoCreds from dstack._internal.core.models.runs import ( ClusterInfo, + ImagePullProgress, JobSpec, JobStatus, JobSubmission, @@ -223,6 +224,7 @@ class TaskInfoResponse(CoreModel): """`ports` uses a default value for backward compatibility with 0.18.34. It can be removed after a few releases. """ + image_pull_progress: Optional[ImagePullProgress] = None class TaskSubmitRequest(CoreModel): diff --git a/src/dstack/_internal/server/services/jobs/__init__.py b/src/dstack/_internal/server/services/jobs/__init__.py index 32828cf3d..da39e661e 100644 --- a/src/dstack/_internal/server/services/jobs/__init__.py +++ b/src/dstack/_internal/server/services/jobs/__init__.py @@ -18,6 +18,7 @@ from dstack._internal.core.models.backends.base import BackendType from dstack._internal.core.models.configurations import RunConfigurationType from dstack._internal.core.models.runs import ( + ImagePullProgress, Job, JobConnectionInfo, JobProvisioningData, @@ -261,6 +262,7 @@ def job_model_to_job_submission( job_runtime_data=get_job_runtime_data(job_model), error=error, probes=probes, + image_pull_progress=_get_image_pull_progress(job_model), ) @@ -276,6 +278,12 @@ def get_job_runtime_data(job_model: JobModel) -> Optional[JobRuntimeData]: return JobRuntimeData.__response__.parse_raw(job_model.job_runtime_data) +def _get_image_pull_progress(job_model: JobModel) -> Optional[ImagePullProgress]: + if job_model.image_pull_progress is None: + return None + return ImagePullProgress.__response__.parse_raw(job_model.image_pull_progress) + + def get_job_spec(job_model: JobModel) -> JobSpec: return JobSpec.__response__.parse_raw(job_model.job_spec_data) diff --git a/src/tests/_internal/cli/utils/test_run.py b/src/tests/_internal/cli/utils/test_run.py index 3ed665d93..4ff911439 100644 --- a/src/tests/_internal/cli/utils/test_run.py +++ b/src/tests/_internal/cli/utils/test_run.py @@ -10,7 +10,7 @@ from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.orm import selectinload -from dstack._internal.cli.utils.run import get_runs_table +from dstack._internal.cli.utils.run import _format_pull_progress, get_runs_table from dstack._internal.core.models.backends.base import BackendType from dstack._internal.core.models.configurations import ( AnyRunConfiguration, @@ -21,6 +21,7 @@ from dstack._internal.core.models.profiles import Profile from dstack._internal.core.models.resources import Range from dstack._internal.core.models.runs import ( + ImagePullProgress, JobProvisioningData, JobStatus, JobTerminationReason, @@ -200,13 +201,9 @@ async def create_run_with_job( ) -pytestmark = [ - pytest.mark.asyncio, - pytest.mark.usefixtures("test_db", "image_config_mock"), - pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True), -] - - +@pytest.mark.asyncio +@pytest.mark.usefixtures("test_db", "image_config_mock") +@pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) class TestGetRunsTable: async def test_simple_run(self, session: AsyncSession): api_run = await create_run_with_job(session=session) @@ -519,3 +516,62 @@ async def test_service_with_multiple_replicas_and_jobs(self, session: AsyncSessi assert f"replica={i - 1}" in job_row["NAME"] assert "job=" not in job_row["NAME"] assert job_row["STATUS"] == "running" + + +@pytest.mark.parametrize( + "progress,expected", + [ + pytest.param( + ImagePullProgress( + downloaded_bytes=300 * 2**20, + extracted_bytes=200 * 2**20, + total_bytes=500 * 2**20, + is_total_bytes_final=True, + ), + "200/300/500MB", + id="mb_final", + ), + pytest.param( + ImagePullProgress( + downloaded_bytes=300 * 2**20, + extracted_bytes=200 * 2**20, + total_bytes=500 * 2**20, + is_total_bytes_final=False, + ), + "200/300/≥500MB", + id="mb_non_final", + ), + pytest.param( + ImagePullProgress( + downloaded_bytes=int(1.5 * 2**30), + extracted_bytes=1 * 2**30, + total_bytes=2 * 2**30, + is_total_bytes_final=True, + ), + "1.00/1.50/2.00GB", + id="gb_final", + ), + pytest.param( + ImagePullProgress( + downloaded_bytes=int(1.5 * 2**30), + extracted_bytes=1 * 2**30, + total_bytes=2 * 2**30, + is_total_bytes_final=False, + ), + "1.00/1.50/≥2.00GB", + id="gb_non_final", + ), + pytest.param( + ImagePullProgress( + downloaded_bytes=0, + extracted_bytes=0, + total_bytes=2**30, + is_total_bytes_final=True, + ), + "0.00/0.00/1.00GB", + id="gb_boundary", + ), + ], +) +def test_format_pull_progress(progress: ImagePullProgress, expected: str) -> None: + assert _format_pull_progress(progress) == expected diff --git a/src/tests/_internal/server/background/pipeline_tasks/test_running_jobs.py b/src/tests/_internal/server/background/pipeline_tasks/test_running_jobs.py index a3423f8cf..ea485f3ec 100644 --- a/src/tests/_internal/server/background/pipeline_tasks/test_running_jobs.py +++ b/src/tests/_internal/server/background/pipeline_tasks/test_running_jobs.py @@ -26,6 +26,7 @@ from dstack._internal.core.models.instances import InstanceStatus from dstack._internal.core.models.profiles import StartupOrder, UtilizationPolicy from dstack._internal.core.models.runs import ( + ImagePullProgress, JobRuntimeData, JobStatus, JobTerminationReason, @@ -114,6 +115,7 @@ def ssh_tunnel_mock(monkeypatch: pytest.MonkeyPatch) -> Mock: def shim_client_mock(monkeypatch: pytest.MonkeyPatch) -> Mock: mock = Mock(spec_set=ShimClient) mock.healthcheck.return_value = HealthcheckResponse(service="dstack-shim", version="latest") + mock.get_task.return_value.image_pull_progress = None monkeypatch.setattr( "dstack._internal.server.services.runner.client.ShimClient", Mock(return_value=mock) ) @@ -1088,6 +1090,42 @@ async def test_pulling_shim_failed( assert job.termination_reason == JobTerminationReason.INSTANCE_UNREACHABLE assert job.remove_at is None + async def test_pulling_shim_stores_pull_progress( + self, + test_db, + session: AsyncSession, + worker: JobRunningWorker, + ssh_tunnel_mock: Mock, + shim_client_mock: Mock, + ): + project = await create_project(session=session) + user = await create_user(session=session) + repo = await create_repo(session=session, project_id=project.id) + run = await create_run(session=session, project=project, repo=repo, user=user) + instance = await create_instance( + session=session, project=project, status=InstanceStatus.BUSY + ) + job = await create_job( + session=session, + run=run, + status=JobStatus.PULLING, + submitted_at=get_current_datetime(), + job_provisioning_data=get_job_provisioning_data(dockerized=True), + instance=instance, + instance_assigned=True, + ) + progress = ImagePullProgress( + downloaded_bytes=512, extracted_bytes=0, total_bytes=1024, is_total_bytes_final=True + ) + shim_client_mock.get_task.return_value.status = TaskStatus.PULLING + shim_client_mock.get_task.return_value.image_pull_progress = progress + + await _process_job(session, worker, job) + + await session.refresh(job) + assert job.status == JobStatus.PULLING + assert job.image_pull_progress == progress.json() + async def test_provisioning_shim_force_stop_if_already_running_api_v1( self, monkeypatch: pytest.MonkeyPatch, diff --git a/src/tests/_internal/server/routers/test_runs.py b/src/tests/_internal/server/routers/test_runs.py index ee8220bc6..7dbc567ca 100644 --- a/src/tests/_internal/server/routers/test_runs.py +++ b/src/tests/_internal/server/routers/test_runs.py @@ -562,6 +562,7 @@ def get_dev_env_run_dict( "job_provisioning_data": None, "job_runtime_data": None, "probes": [], + "image_pull_progress": None, } ], "job_connection_info": None, @@ -584,6 +585,7 @@ def get_dev_env_run_dict( "job_provisioning_data": None, "job_runtime_data": None, "probes": [], + "image_pull_progress": None, }, "cost": 0.0, "service": None, @@ -728,6 +730,7 @@ async def test_lists_runs(self, test_db, session: AsyncSession, client: AsyncCli "job_provisioning_data": None, "job_runtime_data": None, "probes": [], + "image_pull_progress": None, } ], "job_connection_info": None, @@ -750,6 +753,7 @@ async def test_lists_runs(self, test_db, session: AsyncSession, client: AsyncCli "job_provisioning_data": None, "job_runtime_data": None, "probes": [], + "image_pull_progress": None, }, "cost": 0, "service": None, @@ -919,6 +923,7 @@ async def test_limits_job_submissions( "job_provisioning_data": None, "job_runtime_data": None, "probes": [], + "image_pull_progress": None, } ], "job_connection_info": None, @@ -941,6 +946,7 @@ async def test_limits_job_submissions( "job_provisioning_data": None, "job_runtime_data": None, "probes": [], + "image_pull_progress": None, }, "cost": 0, "service": None,