Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions runner/internal/shim/api/schemas.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,9 @@ type TaskInfoResponse struct {
TerminationReason string `json:"termination_reason"`
TerminationMessage string `json:"termination_message"`
Ports []shim.PortMapping `json:"ports"`

ImagePullProgress *shim.ImagePullProgress `json:"image_pull_progress"`

// The following fields are for debugging only; the server doesn't need them.
ContainerName string `json:"container_name"`
ContainerID string `json:"container_id"`
Expand Down
125 changes: 89 additions & 36 deletions runner/internal/shim/docker.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ import (
rt "runtime"
"strconv"
"strings"
"sync"
"time"

"github.com/docker/docker/api/types/container"
Expand Down Expand Up @@ -50,6 +51,86 @@ const (
LabelValueTrue = "true"
)

// PullMessage is a single record of the image pull progress stream returned by dockerd.
//
// dockerd reports pulling progress as a stream of JSON Lines. The format of the records is not described in the API documentation,
// although it's occasionally mentioned, e.g., https://docs.docker.com/reference/api/engine/version-history/#v148-api-changes
// The moby source is the closest thing to a reference:
// https://github.com/moby/moby/blob/e77ff99ede5ee5952b3a9227863552ae6e5b6fb1/pkg/jsonmessage/jsonmessage.go#L144
// All fields are optional.
type PullMessage struct {
Id string `json:"id"` // layer id; non-layer records may reuse it, e.g. {"status":"Pulling from library/python","id":"3.11"}
Status string `json:"status"` // status text, e.g. "Downloading", "Extracting", "Pull complete"
ProgressDetail struct {
Current uint64 `json:"current"` // bytes
Total uint64 `json:"total"` // bytes
} `json:"progressDetail"`
ErrorDetail struct {
Message string `json:"message"` // error text; empty when the record carries no error
} `json:"errorDetail"`
}

// layerProgress is the last known pull state of a single image layer.
type layerProgress struct {
Status string // last per-layer status recorded for the layer, e.g. "Downloading"
DownloadedBytes uint64 // bytes downloaded so far
ExtractedBytes uint64 // bytes extracted so far
TotalBytes uint64 // layer size in bytes; 0 until a record reports it
}

// PullTracker accumulates per-layer image pull progress.
// mu guards layers: Update takes the write lock, Progress takes the read lock.
type PullTracker struct {
mu sync.RWMutex
layers map[string]layerProgress // layer id -> last known progress of that layer
}

// newPullTracker returns a tracker with an empty, ready-to-write layer map.
func newPullTracker() *PullTracker {
	tracker := new(PullTracker)
	tracker.layers = map[string]layerProgress{}
	return tracker
}

// Update folds one dockerd pull progress record into the per-layer state.
//
// Records with an empty id are ignored, as are records whose status is not one
// of the known per-layer statuses (non-layer events such as
// {"status":"Pulling from library/python","id":"3.11"}).
// The write lock is held for the whole update.
func (t *PullTracker) Update(msg PullMessage) {
	if msg.Id == "" {
		return
	}
	t.mu.Lock()
	defer t.mu.Unlock()
	layer := t.layers[msg.Id]
	current, total := msg.ProgressDetail.Current, msg.ProgressDetail.Total
	switch msg.Status {
	case "Pulling fs layer", "Waiting", "Verifying Checksum", "Already exists":
		// No byte counters to update, only the status is tracked.
	case "Downloading":
		layer.DownloadedBytes = current
		// A record without a total must not erase a size learned earlier.
		if total > 0 {
			layer.TotalBytes = total
		}
	case "Download complete":
		layer.DownloadedBytes = layer.TotalBytes
	case "Extracting":
		// Progress counters are not guaranteed to be bytes: the record format
		// allows other units, and such records carry no byte total. Taking
		// them at face value would zero the known layer size and the
		// downloaded counter, so only the status is recorded in that case.
		if total > 0 {
			layer.ExtractedBytes = current
			// Extraction implies the download has finished.
			layer.DownloadedBytes = total
			layer.TotalBytes = total
		}
	case "Pull complete":
		layer.ExtractedBytes = layer.TotalBytes
		layer.DownloadedBytes = layer.TotalBytes
	default:
		// Non-layer events, such as {"status":"Pulling from library/python","id":"3.11"}
		return
	}
	layer.Status = msg.Status
	t.layers[msg.Id] = layer
}

// Progress returns the pull progress summed over all tracked layers, or nil
// when no layer has been recorded yet.
//
// IsTotalBytesFinal is false as long as at least one layer has an unknown
// (zero) size and is neither "Already exists" nor "Pull complete".
func (t *PullTracker) Progress() *ImagePullProgress {
	t.mu.RLock()
	defer t.mu.RUnlock()
	if len(t.layers) == 0 {
		return nil
	}
	var downloaded, extracted, total uint64
	totalIsFinal := true
	for _, layer := range t.layers {
		downloaded += layer.DownloadedBytes
		extracted += layer.ExtractedBytes
		total += layer.TotalBytes

		sizeUnknown := layer.TotalBytes == 0
		settled := layer.Status == "Already exists" || layer.Status == "Pull complete"
		if sizeUnknown && !settled {
			totalIsFinal = false
		}
	}
	return &ImagePullProgress{
		DownloadedBytes:   downloaded,
		ExtractedBytes:    extracted,
		TotalBytes:        total,
		IsTotalBytesFinal: totalIsFinal,
	}
}

type DockerRunner struct {
client *docker.Client
dockerParams DockerParameters
Expand Down Expand Up @@ -239,6 +320,7 @@ func (d *DockerRunner) TaskInfo(taskID string) TaskInfo {
ContainerName: task.containerName,
ContainerID: task.containerID,
GpuIDs: task.gpuIDs,
ImagePullProgress: task.pullTracker.Progress(),
}
}

Expand Down Expand Up @@ -350,7 +432,7 @@ func (d *DockerRunner) Run(ctx context.Context, taskID string) error {
// Although it's called "runner dir", we also use it for shim task-related data.
// Maybe we should rename it to "task dir" (including the `/root/.dstack/runners` dir on the host).
pullLogPath := filepath.Join(runnerDir, "pull.log")
if err = pullImage(pullCtx, d.client, cfg, pullLogPath); err != nil {
if err = pullImage(pullCtx, d.client, cfg, pullLogPath, task.pullTracker); err != nil {
errMessage := fmt.Sprintf("pullImage error: %s", err.Error())
log.Error(ctx, errMessage)
task.SetStatusTerminated(string(types.TerminationReasonCreatingContainerError), errMessage)
Expand Down Expand Up @@ -670,7 +752,7 @@ func mountDisk(ctx context.Context, deviceName, mountPoint string, fsRootPerms o
return nil
}

func pullImage(ctx context.Context, client docker.APIClient, taskConfig TaskConfig, logPath string) error {
func pullImage(ctx context.Context, client docker.APIClient, taskConfig TaskConfig, logPath string, tracker *PullTracker) error {
if !strings.Contains(taskConfig.ImageName, ":") {
taskConfig.ImageName += ":latest"
}
Expand Down Expand Up @@ -710,26 +792,6 @@ func pullImage(ctx context.Context, client docker.APIClient, taskConfig TaskConf

teeReader := io.TeeReader(reader, logFile)

current := make(map[string]uint)
total := make(map[string]uint)

// dockerd reports pulling progress as a stream of JSON Lines. The format of records is not documented in the API documentation,
// although it's occasionally mentioned, e.g., https://docs.docker.com/reference/api/engine/version-history/#v148-api-changes

// https://github.com/moby/moby/blob/e77ff99ede5ee5952b3a9227863552ae6e5b6fb1/pkg/jsonmessage/jsonmessage.go#L144
// All fields are optional
type PullMessage struct {
Id string `json:"id"` // layer id
Status string `json:"status"`
ProgressDetail struct {
Current uint `json:"current"` // bytes
Total uint `json:"total"` // bytes
} `json:"progressDetail"`
ErrorDetail struct {
Message string `json:"message"`
} `json:"errorDetail"`
}

var pullCompleted bool
pullErrors := make([]string, 0)

Expand All @@ -740,13 +802,7 @@ func pullImage(ctx context.Context, client docker.APIClient, taskConfig TaskConf
if err := json.Unmarshal(line, &pullMessage); err != nil {
continue
}
if pullMessage.Status == "Downloading" {
current[pullMessage.Id] = pullMessage.ProgressDetail.Current
total[pullMessage.Id] = pullMessage.ProgressDetail.Total
}
if pullMessage.Status == "Download complete" {
current[pullMessage.Id] = total[pullMessage.Id]
}
tracker.Update(pullMessage)
if pullMessage.ErrorDetail.Message != "" {
log.Error(ctx, "error pulling image", "name", taskConfig.ImageName, "err", pullMessage.ErrorDetail.Message)
pullErrors = append(pullErrors, pullMessage.ErrorDetail.Message)
Expand All @@ -764,13 +820,10 @@ func pullImage(ctx context.Context, client docker.APIClient, taskConfig TaskConf
}

duration := time.Since(startTime)
var currentBytes uint
var totalBytes uint
for _, v := range current {
currentBytes += v
}
for _, v := range total {
totalBytes += v
p := tracker.Progress()
var currentBytes, totalBytes uint64
if p != nil {
currentBytes, totalBytes = p.DownloadedBytes, p.TotalBytes
}
speed := bytesize.New(float64(currentBytes) / duration.Seconds())

Expand Down
109 changes: 109 additions & 0 deletions runner/internal/shim/docker_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -155,3 +155,112 @@ func createTaskConfig(t *testing.T) TaskConfig {
ImageName: "ubuntu",
}
}

// pullMsg builds a PullMessage for the given layer id and status with the
// given progress counters (bytes).
func pullMsg(id, status string, current, total uint64) PullMessage {
	var msg PullMessage
	msg.Id, msg.Status = id, status
	msg.ProgressDetail.Current, msg.ProgressDetail.Total = current, total
	return msg
}

// TestPullTracker_Empty checks that a tracker that has seen no records
// reports no progress at all.
func TestPullTracker_Empty(t *testing.T) {
	progress := newPullTracker().Progress()
	assert.Nil(t, progress)
}

func TestPullTracker_AlreadyExists(t *testing.T) {
tracker := newPullTracker()
tracker.Update(PullMessage{Id: "3.11", Status: "Pulling from library/python"})
for _, id := range []string{"aaa", "bbb", "ccc"} {
tracker.Update(PullMessage{Id: id, Status: "Already exists"})
}
tracker.Update(PullMessage{Status: "Digest: sha256:***"})
tracker.Update(PullMessage{Status: "Status: Image is up to date for python:3.11"})
p := tracker.Progress()
require.NotNil(t, p)
assert.Equal(t, uint64(0), p.DownloadedBytes)
assert.Equal(t, uint64(0), p.ExtractedBytes)
assert.Equal(t, uint64(0), p.TotalBytes)
assert.True(t, p.IsTotalBytesFinal)
}

// TestPullTracker_FullPull replays a complete two-layer pull (announce,
// download, extract, finish) and checks the aggregated progress after each
// stage.
func TestPullTracker_FullPull(t *testing.T) {
	const sizeA, sizeB uint64 = 111, 222

	tracker := newPullTracker()

	// expect snapshots the tracker and compares it with the expected values.
	// Every snapshot is nil-checked before it is dereferenced, so a regression
	// fails the test instead of panicking.
	expect := func(stage string, downloaded, extracted, total uint64, final bool) {
		t.Helper()
		p := tracker.Progress()
		require.NotNil(t, p, stage)
		assert.Equal(t, downloaded, p.DownloadedBytes, stage)
		assert.Equal(t, extracted, p.ExtractedBytes, stage)
		assert.Equal(t, total, p.TotalBytes, stage)
		assert.Equal(t, final, p.IsTotalBytesFinal, stage)
	}

	tracker.Update(PullMessage{Id: "3.11", Status: "Pulling from library/python"})
	tracker.Update(PullMessage{Id: "aaa", Status: "Pulling fs layer"})
	tracker.Update(PullMessage{Id: "bbb", Status: "Pulling fs layer"})
	tracker.Update(PullMessage{Id: "aaa", Status: "Waiting"})
	tracker.Update(PullMessage{Id: "bbb", Status: "Waiting"})

	// Layers announced but sizes unknown yet
	expect("layers announced", 0, 0, 0, false)

	// Both layers start downloading - sizes now known
	tracker.Update(pullMsg("aaa", "Downloading", 100, sizeA))
	tracker.Update(pullMsg("bbb", "Downloading", 200, sizeB))
	expect("downloading", 300, 0, sizeA+sizeB, true)

	// Downloads complete
	tracker.Update(pullMsg("aaa", "Downloading", sizeA, sizeA))
	tracker.Update(PullMessage{Id: "aaa", Status: "Download complete"})
	tracker.Update(pullMsg("bbb", "Downloading", sizeB, sizeB))
	tracker.Update(PullMessage{Id: "bbb", Status: "Download complete"})
	expect("downloads complete", sizeA+sizeB, 0, sizeA+sizeB, true)

	// Both layers start extracting
	tracker.Update(pullMsg("aaa", "Extracting", 100, sizeA))
	tracker.Update(pullMsg("bbb", "Extracting", 200, sizeB))
	expect("extracting", sizeA+sizeB, 300, sizeA+sizeB, true)

	// Extractions complete
	tracker.Update(pullMsg("aaa", "Extracting", sizeA, sizeA))
	tracker.Update(PullMessage{Id: "aaa", Status: "Pull complete"})
	tracker.Update(pullMsg("bbb", "Extracting", sizeB, sizeB))
	tracker.Update(PullMessage{Id: "bbb", Status: "Pull complete"})
	tracker.Update(PullMessage{Status: "Digest: sha256:***"})
	tracker.Update(PullMessage{Status: "Status: Downloaded newer image for python:3.11"})
	expect("pull complete", sizeA+sizeB, sizeA+sizeB, sizeA+sizeB, true)
}

func TestPullTracker_MixedLayerStatuses(t *testing.T) {
tracker := newPullTracker()

tracker.Update(PullMessage{Id: "layer-exists", Status: "Already exists"})
tracker.Update(pullMsg("layer-downloading", "Downloading", 50, 100))
tracker.Update(pullMsg("layer-extracting", "Extracting", 100, 200))
tracker.Update(PullMessage{Id: "layer-waiting", Status: "Waiting"})

p := tracker.Progress()
require.NotNil(t, p)
assert.Equal(t, uint64(50+200), p.DownloadedBytes)
assert.Equal(t, uint64(100), p.ExtractedBytes)
assert.Equal(t, uint64(100+200), p.TotalBytes)
assert.False(t, p.IsTotalBytesFinal) // layer-waiting size unknown
}
8 changes: 8 additions & 0 deletions runner/internal/shim/models.go
Original file line number Diff line number Diff line change
Expand Up @@ -108,12 +108,20 @@ type TaskListItem struct {
Status TaskStatus `json:"status"`
}

// ImagePullProgress is the aggregate progress of an image pull, summed over all tracked layers.
type ImagePullProgress struct {
DownloadedBytes uint64 `json:"downloaded_bytes"` // bytes downloaded so far
ExtractedBytes uint64 `json:"extracted_bytes"` // bytes extracted so far
TotalBytes uint64 `json:"total_bytes"` // sum of the layer sizes known so far
IsTotalBytesFinal bool `json:"is_total_bytes_final"` // false while some layer sizes are still unknown, i.e. TotalBytes is a lower bound
}

type TaskInfo struct {
ID string
Status TaskStatus
TerminationReason string
TerminationMessage string
Ports []PortMapping
ImagePullProgress *ImagePullProgress
ContainerName string
ContainerID string
GpuIDs []string
Expand Down
4 changes: 4 additions & 0 deletions runner/internal/shim/task.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,8 @@ type Task struct {
ports []PortMapping
runnerDir string // path on host mapped to consts.RunnerDir in container

pullTracker *PullTracker

mu *sync.Mutex
}

Expand Down Expand Up @@ -128,6 +130,7 @@ func NewTask(id string, status TaskStatus, containerName string, containerID str
runnerDir: runnerDir,
gpuIDs: gpuIDs,
ports: ports,
pullTracker: newPullTracker(),
mu: &sync.Mutex{},
}
}
Expand All @@ -138,6 +141,7 @@ func NewTaskFromConfig(cfg TaskConfig) Task {
Status: TaskStatusPending,
config: cfg,
containerName: generateUniqueName(cfg.Name, cfg.ID),
pullTracker: newPullTracker(),
mu: &sync.Mutex{},
}
}
Expand Down
23 changes: 22 additions & 1 deletion src/dstack/_internal/cli/utils/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
TerminationPolicy,
)
from dstack._internal.core.models.runs import (
ImagePullProgress,
Job,
JobStatus,
JobSubmission,
Expand Down Expand Up @@ -316,12 +317,32 @@ def _format_job_submission_status(job_submission: JobSubmission, verbose: bool)
color = color_map.get(job_status, "white")
status_style = f"bold {color}" if not job_status.is_finished() else color
formatted_status_message = f"[{status_style}]{status_message}[/]"
if job_status == JobStatus.PULLING and job_submission.image_pull_progress is not None:
formatted_status_message += (
f" [secondary]{_format_pull_progress(job_submission.image_pull_progress)}[/]"
)
if verbose and job_submission.inactivity_secs:
inactive_for = format_duration_multiunit(job_submission.inactivity_secs)
formatted_status_message += f" (inactive for {inactive_for})"
return formatted_status_message


def _format_pull_progress(progress: ImagePullProgress) -> str:
if progress.total_bytes >= 2**30: # 1GB
unit = "GB"

def f(x: int) -> str:
return f"{x / 2**30:.2f}"
else:
unit = "MB"

def f(x: int) -> str:
return f"{x / 2**20:.0f}"

total_sign = "≥" if not progress.is_total_bytes_final else ""
return f"{f(progress.extracted_bytes)}/{f(progress.downloaded_bytes)}/{total_sign}{f(progress.total_bytes)}{unit}"


def _get_show_deployment_replica_job(run: CoreRun, verbose: bool) -> tuple[bool, bool, bool]:
show_deployment_num = (
verbose and run.run_spec.configuration.type == "service"
Expand Down Expand Up @@ -434,7 +455,7 @@ def get_runs_table(
else:
table.add_column("GPU", ratio=2)
table.add_column("PRICE", style="grey58", ratio=1)
table.add_column("STATUS", no_wrap=True, ratio=1)
table.add_column("STATUS", ratio=1)
if verbose or any(
run._run.is_deployment_in_progress()
and any(job.job_submissions[-1].probes for job in run._run.jobs)
Expand Down
Loading
Loading