Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 9 additions & 8 deletions cmd/cli/commands/list.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ import (
"github.com/docker/model-runner/cmd/cli/commands/formatter"
"github.com/docker/model-runner/cmd/cli/desktop"
"github.com/docker/model-runner/cmd/cli/pkg/standalone"
"github.com/docker/model-runner/pkg/distribution/types"
dmrm "github.com/docker/model-runner/pkg/inference/models"
"github.com/olekukonko/tablewriter"
"github.com/spf13/cobra"
Expand Down Expand Up @@ -258,10 +259,10 @@ func appendRow(table *tablewriter.Table, tag string, model dmrm.Model) {
// Strip default "ai/" prefix and ":latest" tag for display
displayTag := stripDefaultsFromModelName(tag)
contextSize := ""
if model.Config.ContextSize != nil {
contextSize = fmt.Sprintf("%d", *model.Config.ContextSize)
} else if model.Config.GGUF != nil {
if v, ok := model.Config.GGUF["llama.context_length"]; ok {
if model.Config.GetContextSize() != nil {
contextSize = fmt.Sprintf("%d", *model.Config.GetContextSize())
} else if dockerConfig, ok := model.Config.(*types.Config); ok && dockerConfig.GGUF != nil {
if v, ok := dockerConfig.GGUF["llama.context_length"]; ok {
if parsed, err := strconv.ParseUint(v, 10, 64); err == nil {
contextSize = fmt.Sprintf("%d", parsed)
} else {
Expand All @@ -272,13 +273,13 @@ func appendRow(table *tablewriter.Table, tag string, model dmrm.Model) {

table.Append([]string{
displayTag,
model.Config.Parameters,
model.Config.Quantization,
model.Config.Architecture,
model.Config.GetParameters(),
model.Config.GetQuantization(),
model.Config.GetArchitecture(),
model.ID[7:19],
units.HumanDuration(time.Since(time.Unix(model.Created, 0))) + " ago",
contextSize,
model.Config.Size,
model.Config.GetSize(),
})
}

Expand Down
10 changes: 5 additions & 5 deletions cmd/cli/commands/list_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ func testModel(id string, tags []string, created int64) dmrm.Model {
ID: id,
Tags: tags,
Created: created,
Config: types.Config{
Config: &types.Config{
Parameters: "7B",
Quantization: "Q4_0",
Architecture: "llama",
Expand Down Expand Up @@ -177,7 +177,7 @@ func TestListModelsSingleModel(t *testing.T) {
ID: "sha256:123456789012345678901234567890123456789012345678901234567890abcd",
Tags: []string{"single:latest"},
Created: 1000,
Config: types.Config{
Config: &types.Config{
Parameters: "7B",
Quantization: "Q4_0",
Architecture: "llama",
Expand Down Expand Up @@ -234,7 +234,7 @@ func TestPrettyPrintModelsWithSortedInput(t *testing.T) {
ID: "sha256:123456789012345678901234567890123456789012345678901234567890abcd",
Tags: []string{"ai/apple:latest"},
Created: 1000,
Config: types.Config{
Config: &types.Config{
Parameters: "7B",
Quantization: "Q4_0",
Architecture: "llama",
Expand All @@ -245,7 +245,7 @@ func TestPrettyPrintModelsWithSortedInput(t *testing.T) {
ID: "sha256:223456789012345678901234567890123456789012345678901234567890abcd",
Tags: []string{"ai/banana:v1"},
Created: 2000,
Config: types.Config{
Config: &types.Config{
Parameters: "13B",
Quantization: "Q4_K_M",
Architecture: "llama",
Expand Down Expand Up @@ -282,7 +282,7 @@ func TestPrettyPrintModelsWithMultipleTags(t *testing.T) {
ID: "sha256:123456789012345678901234567890123456789012345678901234567890abcd",
Tags: []string{"qwen3:8B-Q4_K_M", "qwen3:latest", "qwen3:0.6B-F16"},
Created: 1000,
Config: types.Config{
Config: &types.Config{
Parameters: "8B",
Quantization: "Q4_K_M",
Architecture: "qwen3",
Expand Down
8 changes: 4 additions & 4 deletions cmd/mdltool/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -519,10 +519,10 @@ func cmdGet(client *distribution.Client, args []string) int {
fmt.Fprintf(os.Stderr, "Error reading model config: %v\n", err)
return 1
}
fmt.Printf("Format: %s\n", cfg.Format)
fmt.Printf("Architecture: %s\n", cfg.Architecture)
fmt.Printf("Parameters: %s\n", cfg.Parameters)
fmt.Printf("Quantization: %s\n", cfg.Quantization)
fmt.Printf("Format: %s\n", cfg.GetFormat())
fmt.Printf("Architecture: %s\n", cfg.GetArchitecture())
fmt.Printf("Parameters: %s\n", cfg.GetParameters())
fmt.Printf("Quantization: %s\n", cfg.GetQuantization())
return 0
}

Expand Down
12 changes: 6 additions & 6 deletions pkg/distribution/builder/builder_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -137,8 +137,8 @@ func TestWithMultimodalProjectorChaining(t *testing.T) {
t.Fatalf("Failed to get config: %v", err)
}

if config.ContextSize == nil || *config.ContextSize != 4096 {
t.Errorf("Expected context size 4096, got %v", config.ContextSize)
if config.GetContextSize() == nil || *config.GetContextSize() != 4096 {
t.Errorf("Expected context size 4096, got %v", config.GetContextSize())
}

// Note: We can't directly test GGUFPath() and MMPROJPath() on ModelArtifact interface
Expand Down Expand Up @@ -172,8 +172,8 @@ func TestFromModel(t *testing.T) {
if err != nil {
t.Fatalf("Failed to get initial config: %v", err)
}
if initialConfig.ContextSize == nil || *initialConfig.ContextSize != 2048 {
t.Fatalf("Expected initial context size 2048, got %v", initialConfig.ContextSize)
if initialConfig.GetContextSize() == nil || *initialConfig.GetContextSize() != 2048 {
t.Fatalf("Expected initial context size 2048, got %v", initialConfig.GetContextSize())
}

// Step 2: Use FromModel() to create a new builder from the existing model
Expand All @@ -197,8 +197,8 @@ func TestFromModel(t *testing.T) {
t.Fatalf("Failed to get repackaged config: %v", err)
}

if repackagedConfig.ContextSize == nil || *repackagedConfig.ContextSize != 4096 {
t.Errorf("Expected repackaged context size 4096, got %v", repackagedConfig.ContextSize)
if repackagedConfig.GetContextSize() == nil || *repackagedConfig.GetContextSize() != 4096 {
t.Errorf("Expected repackaged context size 4096, got %v", repackagedConfig.GetContextSize())
}

// Step 6: Verify the original layers are preserved
Expand Down
6 changes: 3 additions & 3 deletions pkg/distribution/distribution/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -350,7 +350,7 @@ func (c *Client) PullModel(ctx context.Context, reference string, progressWriter
return fmt.Errorf("getting cached model config: %w", err)
}

err = progress.WriteSuccess(progressWriter, fmt.Sprintf("Using cached model: %s", cfg.Size))
err = progress.WriteSuccess(progressWriter, fmt.Sprintf("Using cached model: %s", cfg.GetSize()))
if err != nil {
c.log.Warnf("Writing progress: %v", err)
}
Expand Down Expand Up @@ -623,10 +623,10 @@ func checkCompat(image types.ModelArtifact, log *logrus.Entry, reference string,
return fmt.Errorf("reading model config: %w", err)
}

if config.Format == "" {
if config.GetFormat() == "" {
log.Warnf("Model format field is empty for %s, unable to verify format compatibility",
utils.SanitizeForLog(reference))
} else if !slices.Contains(GetSupportedFormats(), config.Format) {
} else if !slices.Contains(GetSupportedFormats(), config.GetFormat()) {
// Write warning but continue with pull
log.Warnln(warnUnsupportedFormat)
if err := progress.WriteWarning(progressWriter, warnUnsupportedFormat); err != nil {
Expand Down
5 changes: 3 additions & 2 deletions pkg/distribution/internal/bundle/bundle.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ type Bundle struct {
mmprojPath string
ggufFile string // path to GGUF file (first shard when model is split among files)
safetensorsFile string // path to safetensors file (first shard when model is split among files)
runtimeConfig types.Config
runtimeConfig types.ModelConfig
chatTemplatePath string
}

Expand Down Expand Up @@ -60,6 +60,7 @@ func (b *Bundle) SafetensorsPath() string {
}

// RuntimeConfig returns config that should be respected by the backend at runtime.
func (b *Bundle) RuntimeConfig() types.Config {
// Can return either Docker format (*types.Config) or ModelPack format (*modelpack.Model).
func (b *Bundle) RuntimeConfig() types.ModelConfig {
return b.runtimeConfig
}
27 changes: 20 additions & 7 deletions pkg/distribution/internal/bundle/parse.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"os"
"path/filepath"

"github.com/docker/model-runner/pkg/distribution/modelpack"
"github.com/docker/model-runner/pkg/distribution/types"
)

Expand Down Expand Up @@ -60,17 +61,29 @@ func Parse(rootDir string) (*Bundle, error) {
}, nil
}

func parseRuntimeConfig(rootDir string) (types.Config, error) {
f, err := os.Open(filepath.Join(rootDir, "config.json"))
// parseRuntimeConfig parses the runtime config from the bundle.
// Natively supports both Docker format and ModelPack format without conversion.
func parseRuntimeConfig(rootDir string) (types.ModelConfig, error) {
raw, err := os.ReadFile(filepath.Join(rootDir, "config.json"))
if err != nil {
return types.Config{}, fmt.Errorf("open runtime config: %w", err)
return nil, fmt.Errorf("read runtime config: %w", err)
}
defer f.Close()

// Detect and parse based on format
if modelpack.IsModelPackConfig(raw) {
var mp modelpack.Model
if err := json.Unmarshal(raw, &mp); err != nil {
return nil, fmt.Errorf("decode ModelPack runtime config: %w", err)
}
return &mp, nil
}

// Docker format
var cfg types.Config
if err := json.NewDecoder(f).Decode(&cfg); err != nil {
return types.Config{}, fmt.Errorf("decode runtime config: %w", err)
if err := json.Unmarshal(raw, &cfg); err != nil {
return nil, fmt.Errorf("decode Docker runtime config: %w", err)
}
return cfg, nil
return &cfg, nil
}

func findGGUFFile(modelDir string) (string, error) {
Expand Down
56 changes: 32 additions & 24 deletions pkg/distribution/internal/gguf/model_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,27 +18,31 @@ func TestGGUF(t *testing.T) {
}

t.Run("TestConfig", func(t *testing.T) {
cfg, err := mdl.Config()
cfgInterface, err := mdl.Config()
if err != nil {
t.Fatalf("Failed to get config: %v", err)
}
if cfg.Format != types.FormatGGUF {
t.Fatalf("Unexpected format: got %s expected %s", cfg.Format, types.FormatGGUF)
if cfgInterface.GetFormat() != types.FormatGGUF {
t.Fatalf("Unexpected format: got %s expected %s", cfgInterface.GetFormat(), types.FormatGGUF)
}
if cfg.Parameters != "183" {
t.Fatalf("Unexpected parameters: got %s expected %s", cfg.Parameters, "183")
if cfgInterface.GetParameters() != "183" {
t.Fatalf("Unexpected parameters: got %s expected %s", cfgInterface.GetParameters(), "183")
}
if cfg.Architecture != "llama" {
t.Fatalf("Unexpected architecture: got %s expected %s", cfg.Parameters, "llama")
if cfgInterface.GetArchitecture() != "llama" {
t.Fatalf("Unexpected architecture: got %s expected %s", cfgInterface.GetArchitecture(), "llama")
}
if cfg.Quantization != "Unknown" { // todo: testdata with a real value
t.Fatalf("Unexpected quantization: got %s expected %s", cfg.Quantization, "Unknown")
if cfgInterface.GetQuantization() != "Unknown" { // todo: testdata with a real value
t.Fatalf("Unexpected quantization: got %s expected %s", cfgInterface.GetQuantization(), "Unknown")
}
if cfg.Size != "864B" {
t.Fatalf("Unexpected size: got %s expected %s", cfg.Size, "864B")
if cfgInterface.GetSize() != "864B" {
t.Fatalf("Unexpected size: got %s expected %s", cfgInterface.GetSize(), "864B")
}

// Test GGUF metadata
// Test GGUF metadata (Docker format specific)
cfg, ok := cfgInterface.(*types.Config)
if !ok {
t.Fatal("Expected *types.Config for GGUF model")
}
if cfg.GGUF == nil {
t.Fatal("Expected GGUF metadata to be present")
}
Expand Down Expand Up @@ -169,27 +173,31 @@ func TestGGUFShards(t *testing.T) {
}

t.Run("TestConfig", func(t *testing.T) {
cfg, err := mdl.Config()
cfgInterface, err := mdl.Config()
if err != nil {
t.Fatalf("Failed to get config: %v", err)
}
if cfg.Format != types.FormatGGUF {
t.Fatalf("Unexpected format: got %s expected %s", cfg.Format, types.FormatGGUF)
if cfgInterface.GetFormat() != types.FormatGGUF {
t.Fatalf("Unexpected format: got %s expected %s", cfgInterface.GetFormat(), types.FormatGGUF)
}
if cfg.Parameters != "183" {
t.Fatalf("Unexpected parameters: got %s expected %s", cfg.Parameters, "183")
if cfgInterface.GetParameters() != "183" {
t.Fatalf("Unexpected parameters: got %s expected %s", cfgInterface.GetParameters(), "183")
}
if cfg.Architecture != "llama" {
t.Fatalf("Unexpected architecture: got %s expected %s", cfg.Parameters, "llama")
if cfgInterface.GetArchitecture() != "llama" {
t.Fatalf("Unexpected architecture: got %s expected %s", cfgInterface.GetArchitecture(), "llama")
}
if cfg.Quantization != "Unknown" { // todo: testdata with a real value
t.Fatalf("Unexpected quantization: got %s expected %s", cfg.Quantization, "Unknown")
if cfgInterface.GetQuantization() != "Unknown" { // todo: testdata with a real value
t.Fatalf("Unexpected quantization: got %s expected %s", cfgInterface.GetQuantization(), "Unknown")
}
if cfg.Size != "864B" {
t.Fatalf("Unexpected size: got %s expected %s", cfg.Size, "864B")
if cfgInterface.GetSize() != "864B" {
t.Fatalf("Unexpected size: got %s expected %s", cfgInterface.GetSize(), "864B")
}

// Test GGUF metadata
// Test GGUF metadata (Docker format specific)
cfg, ok := cfgInterface.(*types.Config)
if !ok {
t.Fatal("Expected *types.Config for GGUF model")
}
if cfg.GGUF == nil {
t.Fatal("Expected GGUF metadata to be present")
}
Expand Down
2 changes: 1 addition & 1 deletion pkg/distribution/internal/mutate/model.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ func (m *model) ID() (string, error) {
return partial.ID(m)
}

func (m *model) Config() (types.Config, error) {
func (m *model) Config() (types.ModelConfig, error) {
return partial.Config(m)
}

Expand Down
10 changes: 5 additions & 5 deletions pkg/distribution/internal/mutate/mutate_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -89,8 +89,8 @@ func TestContextSize(t *testing.T) {
if err != nil {
t.Fatalf("Failed to get config file: %v", err)
}
if cfg.ContextSize != nil {
t.Fatalf("Epected nil context size got %d", cfg.ContextSize)
if cfg.GetContextSize() != nil {
t.Fatalf("Epected nil context size got %d", *cfg.GetContextSize())
}

// set the context size
Expand All @@ -101,10 +101,10 @@ func TestContextSize(t *testing.T) {
if err != nil {
t.Fatalf("Failed to get config file: %v", err)
}
if cfg2.ContextSize == nil {
if cfg2.GetContextSize() == nil {
t.Fatal("Expected non-nil context")
}
if *cfg2.ContextSize != 2096 {
t.Fatalf("Expected context size of 2096 got %d", *cfg2.ContextSize)
if *cfg2.GetContextSize() != 2096 {
t.Fatalf("Expected context size of 2096 got %d", *cfg2.GetContextSize())
}
}
2 changes: 1 addition & 1 deletion pkg/distribution/internal/partial/model.go
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ func (m *BaseModel) ID() (string, error) {
return ID(m)
}

func (m *BaseModel) Config() (types.Config, error) {
func (m *BaseModel) Config() (types.ModelConfig, error) {
return Config(m)
}

Expand Down
Loading
Loading