diff --git a/cmd/cli/commands/list.go b/cmd/cli/commands/list.go index 0c7f66750..b8d134d87 100644 --- a/cmd/cli/commands/list.go +++ b/cmd/cli/commands/list.go @@ -14,6 +14,7 @@ import ( "github.com/docker/model-runner/cmd/cli/commands/formatter" "github.com/docker/model-runner/cmd/cli/desktop" "github.com/docker/model-runner/cmd/cli/pkg/standalone" + "github.com/docker/model-runner/pkg/distribution/types" dmrm "github.com/docker/model-runner/pkg/inference/models" "github.com/olekukonko/tablewriter" "github.com/spf13/cobra" @@ -258,10 +259,10 @@ func appendRow(table *tablewriter.Table, tag string, model dmrm.Model) { // Strip default "ai/" prefix and ":latest" tag for display displayTag := stripDefaultsFromModelName(tag) contextSize := "" - if model.Config.ContextSize != nil { - contextSize = fmt.Sprintf("%d", *model.Config.ContextSize) - } else if model.Config.GGUF != nil { - if v, ok := model.Config.GGUF["llama.context_length"]; ok { + if model.Config.GetContextSize() != nil { + contextSize = fmt.Sprintf("%d", *model.Config.GetContextSize()) + } else if dockerConfig, ok := model.Config.(*types.Config); ok && dockerConfig.GGUF != nil { + if v, ok := dockerConfig.GGUF["llama.context_length"]; ok { if parsed, err := strconv.ParseUint(v, 10, 64); err == nil { contextSize = fmt.Sprintf("%d", parsed) } else { @@ -272,13 +273,13 @@ func appendRow(table *tablewriter.Table, tag string, model dmrm.Model) { table.Append([]string{ displayTag, - model.Config.Parameters, - model.Config.Quantization, - model.Config.Architecture, + model.Config.GetParameters(), + model.Config.GetQuantization(), + model.Config.GetArchitecture(), model.ID[7:19], units.HumanDuration(time.Since(time.Unix(model.Created, 0))) + " ago", contextSize, - model.Config.Size, + model.Config.GetSize(), }) } diff --git a/cmd/cli/commands/list_test.go b/cmd/cli/commands/list_test.go index 168af5188..95547df59 100644 --- a/cmd/cli/commands/list_test.go +++ b/cmd/cli/commands/list_test.go @@ -15,7 +15,7 @@ func testModel(id string, tags []string, created int64) dmrm.Model { ID: id, Tags: tags, Created: created, - Config: types.Config{ + Config: &types.Config{ Parameters: "7B", Quantization: "Q4_0", Architecture: "llama", @@ -177,7 +177,7 @@ func TestListModelsSingleModel(t *testing.T) { ID: "sha256:123456789012345678901234567890123456789012345678901234567890abcd", Tags: []string{"single:latest"}, Created: 1000, - Config: types.Config{ + Config: &types.Config{ Parameters: "7B", Quantization: "Q4_0", Architecture: "llama", @@ -234,7 +234,7 @@ func TestPrettyPrintModelsWithSortedInput(t *testing.T) { ID: "sha256:123456789012345678901234567890123456789012345678901234567890abcd", Tags: []string{"ai/apple:latest"}, Created: 1000, - Config: types.Config{ + Config: &types.Config{ Parameters: "7B", Quantization: "Q4_0", Architecture: "llama", @@ -245,7 +245,7 @@ func TestPrettyPrintModelsWithSortedInput(t *testing.T) { ID: "sha256:223456789012345678901234567890123456789012345678901234567890abcd", Tags: []string{"ai/banana:v1"}, Created: 2000, - Config: types.Config{ + Config: &types.Config{ Parameters: "13B", Quantization: "Q4_K_M", Architecture: "llama", @@ -282,7 +282,7 @@ func TestPrettyPrintModelsWithMultipleTags(t *testing.T) { ID: "sha256:123456789012345678901234567890123456789012345678901234567890abcd", Tags: []string{"qwen3:8B-Q4_K_M", "qwen3:latest", "qwen3:0.6B-F16"}, Created: 1000, - Config: types.Config{ + Config: &types.Config{ Parameters: "8B", Quantization: "Q4_K_M", Architecture: "qwen3", diff --git a/cmd/mdltool/main.go 
b/cmd/mdltool/main.go index 3c0c976d2..787b9c49c 100644 --- a/cmd/mdltool/main.go +++ b/cmd/mdltool/main.go @@ -519,10 +519,10 @@ func cmdGet(client *distribution.Client, args []string) int { fmt.Fprintf(os.Stderr, "Error reading model config: %v\n", err) return 1 } - fmt.Printf("Format: %s\n", cfg.Format) - fmt.Printf("Architecture: %s\n", cfg.Architecture) - fmt.Printf("Parameters: %s\n", cfg.Parameters) - fmt.Printf("Quantization: %s\n", cfg.Quantization) + fmt.Printf("Format: %s\n", cfg.GetFormat()) + fmt.Printf("Architecture: %s\n", cfg.GetArchitecture()) + fmt.Printf("Parameters: %s\n", cfg.GetParameters()) + fmt.Printf("Quantization: %s\n", cfg.GetQuantization()) return 0 } diff --git a/pkg/distribution/builder/builder_test.go b/pkg/distribution/builder/builder_test.go index bc379c6c5..b6b5cca01 100644 --- a/pkg/distribution/builder/builder_test.go +++ b/pkg/distribution/builder/builder_test.go @@ -137,8 +137,8 @@ func TestWithMultimodalProjectorChaining(t *testing.T) { t.Fatalf("Failed to get config: %v", err) } - if config.ContextSize == nil || *config.ContextSize != 4096 { - t.Errorf("Expected context size 4096, got %v", config.ContextSize) + if config.GetContextSize() == nil || *config.GetContextSize() != 4096 { + t.Errorf("Expected context size 4096, got %v", config.GetContextSize()) } // Note: We can't directly test GGUFPath() and MMPROJPath() on ModelArtifact interface @@ -172,8 +172,8 @@ func TestFromModel(t *testing.T) { if err != nil { t.Fatalf("Failed to get initial config: %v", err) } - if initialConfig.ContextSize == nil || *initialConfig.ContextSize != 2048 { - t.Fatalf("Expected initial context size 2048, got %v", initialConfig.ContextSize) + if initialConfig.GetContextSize() == nil || *initialConfig.GetContextSize() != 2048 { + t.Fatalf("Expected initial context size 2048, got %v", initialConfig.GetContextSize()) } // Step 2: Use FromModel() to create a new builder from the existing model @@ -197,8 +197,8 @@ func TestFromModel(t *testing.T) { t.Fatalf("Failed to get repackaged config: %v", err) } - if repackagedConfig.ContextSize == nil || *repackagedConfig.ContextSize != 4096 { - t.Errorf("Expected repackaged context size 4096, got %v", repackagedConfig.ContextSize) + if repackagedConfig.GetContextSize() == nil || *repackagedConfig.GetContextSize() != 4096 { + t.Errorf("Expected repackaged context size 4096, got %v", repackagedConfig.GetContextSize()) } // Step 6: Verify the original layers are preserved diff --git a/pkg/distribution/distribution/client.go b/pkg/distribution/distribution/client.go index 9382e87ac..ea7455d35 100644 --- a/pkg/distribution/distribution/client.go +++ b/pkg/distribution/distribution/client.go @@ -350,7 +350,7 @@ func (c *Client) PullModel(ctx context.Context, reference string, progressWriter return fmt.Errorf("getting cached model config: %w", err) } - err = progress.WriteSuccess(progressWriter, fmt.Sprintf("Using cached model: %s", cfg.Size)) + err = progress.WriteSuccess(progressWriter, fmt.Sprintf("Using cached model: %s", cfg.GetSize())) if err != nil { c.log.Warnf("Writing progress: %v", err) } @@ -623,10 +623,10 @@ func checkCompat(image types.ModelArtifact, log *logrus.Entry, reference string, return fmt.Errorf("reading model config: %w", err) } - if config.Format == "" { + if config.GetFormat() == "" { log.Warnf("Model format field is empty for %s, unable to verify format compatibility", utils.SanitizeForLog(reference)) - } else if !slices.Contains(GetSupportedFormats(), config.Format) { + } else if 
!slices.Contains(GetSupportedFormats(), config.GetFormat()) { // Write warning but continue with pull log.Warnln(warnUnsupportedFormat) if err := progress.WriteWarning(progressWriter, warnUnsupportedFormat); err != nil { diff --git a/pkg/distribution/internal/bundle/bundle.go b/pkg/distribution/internal/bundle/bundle.go index 4cae36e49..b17b241d0 100644 --- a/pkg/distribution/internal/bundle/bundle.go +++ b/pkg/distribution/internal/bundle/bundle.go @@ -17,7 +17,7 @@ type Bundle struct { mmprojPath string ggufFile string // path to GGUF file (first shard when model is split among files) safetensorsFile string // path to safetensors file (first shard when model is split among files) - runtimeConfig types.Config + runtimeConfig types.ModelConfig chatTemplatePath string } @@ -60,6 +60,7 @@ func (b *Bundle) SafetensorsPath() string { } // RuntimeConfig returns config that should be respected by the backend at runtime. -func (b *Bundle) RuntimeConfig() types.Config { +// Can return either Docker format (*types.Config) or ModelPack format (*modelpack.Model). +func (b *Bundle) RuntimeConfig() types.ModelConfig { return b.runtimeConfig } diff --git a/pkg/distribution/internal/bundle/parse.go b/pkg/distribution/internal/bundle/parse.go index 543c7a00e..4e2cf49ad 100644 --- a/pkg/distribution/internal/bundle/parse.go +++ b/pkg/distribution/internal/bundle/parse.go @@ -6,6 +6,7 @@ import ( "os" "path/filepath" + "github.com/docker/model-runner/pkg/distribution/modelpack" "github.com/docker/model-runner/pkg/distribution/types" ) @@ -60,17 +61,29 @@ func Parse(rootDir string) (*Bundle, error) { }, nil } -func parseRuntimeConfig(rootDir string) (types.Config, error) { - f, err := os.Open(filepath.Join(rootDir, "config.json")) +// parseRuntimeConfig parses the runtime config from the bundle. +// Natively supports both Docker format and ModelPack format without conversion. 
+func parseRuntimeConfig(rootDir string) (types.ModelConfig, error) { + raw, err := os.ReadFile(filepath.Join(rootDir, "config.json")) if err != nil { - return types.Config{}, fmt.Errorf("open runtime config: %w", err) + return nil, fmt.Errorf("read runtime config: %w", err) } - defer f.Close() + + // Detect and parse based on format + if modelpack.IsModelPackConfig(raw) { + var mp modelpack.Model + if err := json.Unmarshal(raw, &mp); err != nil { + return nil, fmt.Errorf("decode ModelPack runtime config: %w", err) + } + return &mp, nil + } + + // Docker format var cfg types.Config - if err := json.NewDecoder(f).Decode(&cfg); err != nil { - return types.Config{}, fmt.Errorf("decode runtime config: %w", err) + if err := json.Unmarshal(raw, &cfg); err != nil { + return nil, fmt.Errorf("decode Docker runtime config: %w", err) } - return cfg, nil + return &cfg, nil } func findGGUFFile(modelDir string) (string, error) { diff --git a/pkg/distribution/internal/gguf/model_test.go b/pkg/distribution/internal/gguf/model_test.go index 071dddc50..3e15c5f61 100644 --- a/pkg/distribution/internal/gguf/model_test.go +++ b/pkg/distribution/internal/gguf/model_test.go @@ -18,27 +18,31 @@ func TestGGUF(t *testing.T) { } t.Run("TestConfig", func(t *testing.T) { - cfg, err := mdl.Config() + cfgInterface, err := mdl.Config() if err != nil { t.Fatalf("Failed to get config: %v", err) } - if cfg.Format != types.FormatGGUF { - t.Fatalf("Unexpected format: got %s expected %s", cfg.Format, types.FormatGGUF) + if cfgInterface.GetFormat() != types.FormatGGUF { + t.Fatalf("Unexpected format: got %s expected %s", cfgInterface.GetFormat(), types.FormatGGUF) } - if cfg.Parameters != "183" { - t.Fatalf("Unexpected parameters: got %s expected %s", cfg.Parameters, "183") + if cfgInterface.GetParameters() != "183" { + t.Fatalf("Unexpected parameters: got %s expected %s", cfgInterface.GetParameters(), "183") } - if cfg.Architecture != "llama" { - t.Fatalf("Unexpected architecture: got %s expected %s", cfg.Parameters, "llama") + if cfgInterface.GetArchitecture() != "llama" { + t.Fatalf("Unexpected architecture: got %s expected %s", cfgInterface.GetArchitecture(), "llama") } - if cfg.Quantization != "Unknown" { // todo: testdata with a real value - t.Fatalf("Unexpected quantization: got %s expected %s", cfg.Quantization, "Unknown") + if cfgInterface.GetQuantization() != "Unknown" { // todo: testdata with a real value + t.Fatalf("Unexpected quantization: got %s expected %s", cfgInterface.GetQuantization(), "Unknown") } - if cfg.Size != "864B" { - t.Fatalf("Unexpected size: got %s expected %s", cfg.Size, "864B") + if cfgInterface.GetSize() != "864B" { + t.Fatalf("Unexpected size: got %s expected %s", cfgInterface.GetSize(), "864B") } - // Test GGUF metadata + // Test GGUF metadata (Docker format specific) + cfg, ok := cfgInterface.(*types.Config) + if !ok { + t.Fatal("Expected *types.Config for GGUF model") + } if cfg.GGUF == nil { t.Fatal("Expected GGUF metadata to be present") } @@ -169,27 +173,31 @@ func TestGGUFShards(t *testing.T) { } t.Run("TestConfig", func(t *testing.T) { - cfg, err := mdl.Config() + cfgInterface, err := mdl.Config() if err != nil { t.Fatalf("Failed to get config: %v", err) } - if cfg.Format != types.FormatGGUF { - t.Fatalf("Unexpected format: got %s expected %s", cfg.Format, types.FormatGGUF) + if cfgInterface.GetFormat() != types.FormatGGUF { + t.Fatalf("Unexpected format: got %s expected %s", cfgInterface.GetFormat(), types.FormatGGUF) } - if cfg.Parameters != "183" { - t.Fatalf("Unexpected parameters: 
got %s expected %s", cfg.Parameters, "183") + if cfgInterface.GetParameters() != "183" { + t.Fatalf("Unexpected parameters: got %s expected %s", cfgInterface.GetParameters(), "183") } - if cfg.Architecture != "llama" { - t.Fatalf("Unexpected architecture: got %s expected %s", cfg.Parameters, "llama") + if cfgInterface.GetArchitecture() != "llama" { + t.Fatalf("Unexpected architecture: got %s expected %s", cfgInterface.GetArchitecture(), "llama") } - if cfg.Quantization != "Unknown" { // todo: testdata with a real value - t.Fatalf("Unexpected quantization: got %s expected %s", cfg.Quantization, "Unknown") + if cfgInterface.GetQuantization() != "Unknown" { // todo: testdata with a real value + t.Fatalf("Unexpected quantization: got %s expected %s", cfgInterface.GetQuantization(), "Unknown") } - if cfg.Size != "864B" { - t.Fatalf("Unexpected size: got %s expected %s", cfg.Size, "864B") + if cfgInterface.GetSize() != "864B" { + t.Fatalf("Unexpected size: got %s expected %s", cfgInterface.GetSize(), "864B") } - // Test GGUF metadata + // Test GGUF metadata (Docker format specific) + cfg, ok := cfgInterface.(*types.Config) + if !ok { + t.Fatal("Expected *types.Config for GGUF model") + } if cfg.GGUF == nil { t.Fatal("Expected GGUF metadata to be present") } diff --git a/pkg/distribution/internal/mutate/model.go b/pkg/distribution/internal/mutate/model.go index 9c6207ae3..106f8b8a0 100644 --- a/pkg/distribution/internal/mutate/model.go +++ b/pkg/distribution/internal/mutate/model.go @@ -26,7 +26,7 @@ func (m *model) ID() (string, error) { return partial.ID(m) } -func (m *model) Config() (types.Config, error) { +func (m *model) Config() (types.ModelConfig, error) { return partial.Config(m) } diff --git a/pkg/distribution/internal/mutate/mutate_test.go b/pkg/distribution/internal/mutate/mutate_test.go index d4089fda6..5a46de1a4 100644 --- a/pkg/distribution/internal/mutate/mutate_test.go +++ b/pkg/distribution/internal/mutate/mutate_test.go @@ -89,8 +89,8 @@ func TestContextSize(t *testing.T) { if err != nil { t.Fatalf("Failed to get config file: %v", err) } - if cfg.ContextSize != nil { - t.Fatalf("Epected nil context size got %d", cfg.ContextSize) + if cfg.GetContextSize() != nil { + t.Fatalf("Epected nil context size got %d", *cfg.GetContextSize()) } // set the context size @@ -101,10 +101,10 @@ func TestContextSize(t *testing.T) { if err != nil { t.Fatalf("Failed to get config file: %v", err) } - if cfg2.ContextSize == nil { + if cfg2.GetContextSize() == nil { t.Fatal("Expected non-nil context") } - if *cfg2.ContextSize != 2096 { - t.Fatalf("Expected context size of 2096 got %d", *cfg2.ContextSize) + if *cfg2.GetContextSize() != 2096 { + t.Fatalf("Expected context size of 2096 got %d", *cfg2.GetContextSize()) } } diff --git a/pkg/distribution/internal/partial/model.go b/pkg/distribution/internal/partial/model.go index dc2945060..38610a04a 100644 --- a/pkg/distribution/internal/partial/model.go +++ b/pkg/distribution/internal/partial/model.go @@ -89,7 +89,7 @@ func (m *BaseModel) ID() (string, error) { return ID(m) } -func (m *BaseModel) Config() (types.Config, error) { +func (m *BaseModel) Config() (types.ModelConfig, error) { return Config(m) } diff --git a/pkg/distribution/internal/partial/partial.go b/pkg/distribution/internal/partial/partial.go index e69e7977b..9ba9c232f 100644 --- a/pkg/distribution/internal/partial/partial.go +++ b/pkg/distribution/internal/partial/partial.go @@ -4,6 +4,7 @@ import ( "encoding/json" "fmt" + "github.com/docker/model-runner/pkg/distribution/modelpack" 
"github.com/docker/model-runner/pkg/distribution/types" v1 "github.com/docker/model-runner/pkg/go-containerregistry/pkg/v1" "github.com/docker/model-runner/pkg/go-containerregistry/pkg/v1/partial" @@ -15,25 +16,43 @@ type WithRawConfigFile interface { RawConfigFile() ([]byte, error) } -func ConfigFile(i WithRawConfigFile) (*types.ConfigFile, error) { +// Config returns the model configuration. Returns *types.Config for Docker format +// or *modelpack.Model for ModelPack format, without any conversion. +func Config(i WithRawConfigFile) (types.ModelConfig, error) { raw, err := i.RawConfigFile() if err != nil { return nil, fmt.Errorf("get raw config file: %w", err) } + + // ModelPack format: parse directly into modelpack.Model without conversion + if modelpack.IsModelPackConfig(raw) { + var mp modelpack.Model + if err := json.Unmarshal(raw, &mp); err != nil { + return nil, fmt.Errorf("unmarshal modelpack config: %w", err) + } + return &mp, nil + } + + // Docker format: parse into types.Config var cf types.ConfigFile if err := json.Unmarshal(raw, &cf); err != nil { - return nil, fmt.Errorf("unmarshal : %w", err) + return nil, fmt.Errorf("unmarshal config: %w", err) } - return &cf, nil + return &cf.Config, nil } -// Config returns the types.Config for the model. -func Config(i WithRawConfigFile) (types.Config, error) { - cf, err := ConfigFile(i) +// ConfigFile returns the full Docker format config file (only for Docker format models). +func ConfigFile(i WithRawConfigFile) (*types.ConfigFile, error) { + raw, err := i.RawConfigFile() if err != nil { - return types.Config{}, fmt.Errorf("config file: %w", err) + return nil, fmt.Errorf("get raw config file: %w", err) } - return cf.Config, nil + + var cf types.ConfigFile + if err := json.Unmarshal(raw, &cf); err != nil { + return nil, fmt.Errorf("unmarshal config: %w", err) + } + return &cf, nil } // Descriptor returns the types.Descriptor for the model. @@ -117,7 +136,8 @@ func ConfigArchivePath(i WithLayers) (string, error) { return paths[0], err } -// layerPathsByMediaType is a generic helper function that finds a layer by media type and returns its path +// layerPathsByMediaType is a generic helper function that finds a layer by media type and returns its path. +// Natively supports both Docker and ModelPack media types without any conversion. func layerPathsByMediaType(i WithLayers, mediaType ggcr.MediaType) ([]string, error) { layers, err := i.Layers() if err != nil { @@ -126,7 +146,10 @@ func layerPathsByMediaType(i WithLayers, mediaType ggcr.MediaType) ([]string, er var paths []string for _, l := range layers { mt, err := l.MediaType() - if err != nil || mt != mediaType { + if err != nil { + continue + } + if !matchesMediaType(mt, mediaType) { continue } layer, ok := l.(*Layer) @@ -138,6 +161,29 @@ func layerPathsByMediaType(i WithLayers, mediaType ggcr.MediaType) ([]string, er return paths, nil } +// matchesMediaType checks if a layer media type matches the target type. +// Natively supports both Docker and ModelPack formats without any conversion. 
+func matchesMediaType(layerMT, targetMT ggcr.MediaType) bool { + // Exact match + if layerMT == targetMT { + return true + } + + // Native ModelPack support: check equivalent ModelPack types + //nolint:exhaustive // Only GGUF and Safetensors need cross-format matching + switch targetMT { + case types.MediaTypeGGUF: + // ModelPack GGUF layers also match Docker GGUF target + return layerMT == ggcr.MediaType(modelpack.MediaTypeWeightGGUF) + case types.MediaTypeSafetensors: + // ModelPack safetensors layers also match Docker safetensors target + return layerMT == ggcr.MediaType(modelpack.MediaTypeWeightSafetensors) + default: + // Other media types have no cross-format equivalents + return false + } +} + func ManifestForLayers(i WithLayers) (*v1.Manifest, error) { cfgLayer, err := partial.ConfigLayer(i) if err != nil { diff --git a/pkg/distribution/internal/partial/partial_test.go b/pkg/distribution/internal/partial/partial_test.go index 40337e68d..8a56384c0 100644 --- a/pkg/distribution/internal/partial/partial_test.go +++ b/pkg/distribution/internal/partial/partial_test.go @@ -8,8 +8,108 @@ import ( "github.com/docker/model-runner/pkg/distribution/internal/mutate" "github.com/docker/model-runner/pkg/distribution/internal/partial" "github.com/docker/model-runner/pkg/distribution/types" + ggcr "github.com/docker/model-runner/pkg/go-containerregistry/pkg/v1/types" ) +// mockConfig is used to test ConfigFile and Config functions +type mockConfig struct { + raw []byte + err error +} + +func (m *mockConfig) RawConfigFile() ([]byte, error) { + return m.raw, m.err +} + +// TestConfig_NativeFormatSupport tests that Config() returns native format without conversion +func TestConfig_NativeFormatSupport(t *testing.T) { + t.Run("Docker format returns *types.Config", func(t *testing.T) { + // Docker format config + dockerJSON := `{ + "config": {"format": "gguf", "parameters": "8B"}, + "descriptor": {}, + "rootfs": {"type": "layers", "diff_ids": []} + }` + + mock := &mockConfig{raw: []byte(dockerJSON)} + cfg, err := partial.Config(mock) + if err != nil { + t.Fatalf("Config() error = %v", err) + } + + if cfg.GetFormat() != types.FormatGGUF { + t.Errorf("GetFormat() = %v, want %v", cfg.GetFormat(), types.FormatGGUF) + } + if cfg.GetParameters() != "8B" { + t.Errorf("GetParameters() = %q, want %q", cfg.GetParameters(), "8B") + } + }) + + t.Run("ModelPack format returns *modelpack.Model without conversion", func(t *testing.T) { + // ModelPack format config (uses paramSize not parameters) + modelPackJSON := `{ + "descriptor": {"createdAt": "2025-01-15T10:30:00Z"}, + "config": {"format": "gguf", "paramSize": "8B"}, + "modelfs": {"type": "layers", "diffIds": []} + }` + + mock := &mockConfig{raw: []byte(modelPackJSON)} + cfg, err := partial.Config(mock) + if err != nil { + t.Fatalf("Config() error = %v", err) + } + + // Should return native format with interface methods working + if cfg.GetFormat() != types.FormatGGUF { + t.Errorf("GetFormat() = %v, want %v", cfg.GetFormat(), types.FormatGGUF) + } + // GetParameters() returns ParamSize for ModelPack + if cfg.GetParameters() != "8B" { + t.Errorf("GetParameters() = %q, want %q", cfg.GetParameters(), "8B") + } + }) + + t.Run("invalid JSON returns error", func(t *testing.T) { + mock := &mockConfig{raw: []byte("not valid json")} + _, err := partial.Config(mock) + if err == nil { + t.Error("expected error for invalid JSON") + } + }) +} + +// TestConfigFile tests ConfigFile() which is for Docker format only +func TestConfigFile(t *testing.T) { + t.Run("Docker format 
parses correctly", func(t *testing.T) { + dockerJSON := `{ + "config": {"format": "gguf", "parameters": "8B"}, + "descriptor": {}, + "rootfs": {"type": "layers", "diff_ids": []} + }` + + mock := &mockConfig{raw: []byte(dockerJSON)} + cf, err := partial.ConfigFile(mock) + if err != nil { + t.Fatalf("ConfigFile() error = %v", err) + } + + if cf.Config.Format != types.FormatGGUF { + t.Errorf("Format = %v, want %v", cf.Config.Format, types.FormatGGUF) + } + if cf.Config.Parameters != "8B" { + t.Errorf("Parameters = %q, want %q", cf.Config.Parameters, "8B") + } + }) + + t.Run("invalid JSON returns error", func(t *testing.T) { + mock := &mockConfig{raw: []byte("not valid json")} + _, err := partial.ConfigFile(mock) + if err == nil { + t.Error("expected error for invalid JSON") + } + }) +} + func TestMMPROJPath(t *testing.T) { // Create a model from GGUF file mdl, err := gguf.NewModel(filepath.Join("..", "..", "assets", "dummy.gguf")) @@ -122,3 +222,33 @@ func TestLayerPathByMediaType(t *testing.T) { } } + +// TestGGUFPaths_ModelPackMediaType tests that GGUFPaths can find ModelPack format layers +func TestGGUFPaths_ModelPackMediaType(t *testing.T) { + // Create a layer with ModelPack GGUF media type + modelPackGGUFType := ggcr.MediaType("application/vnd.cncf.model.weight.v1.gguf") + + layer, err := partial.NewLayer(filepath.Join("..", "..", "assets", "dummy.gguf"), modelPackGGUFType) + if err != nil { + t.Fatalf("Failed to create ModelPack layer: %v", err) + } + + // Create a model with mutate and add the layer + mdl, err := gguf.NewModel(filepath.Join("..", "..", "assets", "dummy.gguf")) + if err != nil { + t.Fatalf("Failed to create GGUF model: %v", err) + } + + mdlWithModelPackLayer := mutate.AppendLayers(mdl, layer) + + // GGUFPaths should be able to find ModelPack format GGUF layers + paths, err := partial.GGUFPaths(mdlWithModelPackLayer) + if err != nil { + t.Fatalf("GGUFPaths() error = %v", err) + } + + // Should find two: original Docker format + newly added ModelPack format + if len(paths) != 2 { + t.Errorf("Expected 2 GGUF paths, got %d", len(paths)) + } +} diff --git a/pkg/distribution/internal/safetensors/model_test.go b/pkg/distribution/internal/safetensors/model_test.go index d3024a6b5..11178144a 100644 --- a/pkg/distribution/internal/safetensors/model_test.go +++ b/pkg/distribution/internal/safetensors/model_test.go @@ -82,42 +82,48 @@ func TestNewModel_WithMetadata(t *testing.T) { } // Verify format - if config.Format != types.FormatSafetensors { - t.Errorf("Config.Format = %v, want %v", config.Format, types.FormatSafetensors) + if config.GetFormat() != types.FormatSafetensors { + t.Errorf("Config.Format = %v, want %v", config.GetFormat(), types.FormatSafetensors) } // Verify architecture - if config.Architecture != "LlamaForCausalLM" { - t.Errorf("Config.Architecture = %v, want %v", config.Architecture, "LlamaForCausalLM") + if config.GetArchitecture() != "LlamaForCausalLM" { + t.Errorf("Config.Architecture = %v, want %v", config.GetArchitecture(), "LlamaForCausalLM") } // Verify parameters (4096*4096 + 4096 = 16781312) expectedParams := "16.78M" - if config.Parameters != expectedParams { - t.Errorf("Config.Parameters = %v, want %v", config.Parameters, expectedParams) + if config.GetParameters() != expectedParams { + t.Errorf("Config.Parameters = %v, want %v", config.GetParameters(), expectedParams) } // Verify quantization - if config.Quantization != "F16" { - t.Errorf("Config.Quantization = %v, want %v", config.Quantization, "F16") + if config.GetQuantization() != "F16" { + 
t.Errorf("Config.Quantization = %v, want %v", config.GetQuantization(), "F16") } // Verify size is calculated - if config.Size == "" { + if config.GetSize() == "" { t.Error("Config.Size is empty") } + // Type assert to access Docker format specific fields + dockerConfig, ok := config.(*types.Config) + if !ok { + t.Fatal("Expected *types.Config for safetensors model") + } + // Verify safetensors metadata map - if config.Safetensors == nil { + if dockerConfig.Safetensors == nil { t.Fatal("Config.Safetensors is nil") } - if config.Safetensors["architecture"] != "LlamaForCausalLM" { - t.Errorf("Config.Safetensors[architecture] = %v, want %v", config.Safetensors["architecture"], "LlamaForCausalLM") + if dockerConfig.Safetensors["architecture"] != "LlamaForCausalLM" { + t.Errorf("Config.Safetensors[architecture] = %v, want %v", dockerConfig.Safetensors["architecture"], "LlamaForCausalLM") } - if config.Safetensors["tensor_count"] != "2" { - t.Errorf("Config.Safetensors[tensor_count] = %v, want %v", config.Safetensors["tensor_count"], "2") + if dockerConfig.Safetensors["tensor_count"] != "2" { + t.Errorf("Config.Safetensors[tensor_count] = %v, want %v", dockerConfig.Safetensors["tensor_count"], "2") } // Test annotations @@ -260,33 +266,39 @@ func TestNewModel_NoMetadata(t *testing.T) { } // Verify format - if config.Format != types.FormatSafetensors { - t.Errorf("Config.Format = %v, want %v", config.Format, types.FormatSafetensors) + if config.GetFormat() != types.FormatSafetensors { + t.Errorf("Config.Format = %v, want %v", config.GetFormat(), types.FormatSafetensors) } // Verify parameters (100*200 = 20000) expectedParams := "20.00K" - if config.Parameters != expectedParams { - t.Errorf("Config.Parameters = %v, want %v", config.Parameters, expectedParams) + if config.GetParameters() != expectedParams { + t.Errorf("Config.Parameters = %v, want %v", config.GetParameters(), expectedParams) } // Verify quantization - if config.Quantization != "F32" { - t.Errorf("Config.Quantization = %v, want %v", config.Quantization, "F32") + if config.GetQuantization() != "F32" { + t.Errorf("Config.Quantization = %v, want %v", config.GetQuantization(), "F32") } // Architecture should be empty when no metadata - if config.Architecture != "" { - t.Errorf("Config.Architecture = %v, want empty string", config.Architecture) + if config.GetArchitecture() != "" { + t.Errorf("Config.Architecture = %v, want empty string", config.GetArchitecture()) + } + + // Type assert to access Docker format specific fields + dockerConfig, ok := config.(*types.Config) + if !ok { + t.Fatal("Expected *types.Config for safetensors model") } // Verify safetensors metadata map exists with tensor count - if config.Safetensors == nil { + if dockerConfig.Safetensors == nil { t.Fatal("Config.Safetensors is nil") } - if config.Safetensors["tensor_count"] != "1" { - t.Errorf("Config.Safetensors[tensor_count] = %v, want %v", config.Safetensors["tensor_count"], "1") + if dockerConfig.Safetensors["tensor_count"] != "1" { + t.Errorf("Config.Safetensors[tensor_count] = %v, want %v", dockerConfig.Safetensors["tensor_count"], "1") } // Test annotations diff --git a/pkg/distribution/internal/store/model.go b/pkg/distribution/internal/store/model.go index ce3b45448..6778994a8 100644 --- a/pkg/distribution/internal/store/model.go +++ b/pkg/distribution/internal/store/model.go @@ -145,7 +145,7 @@ func (m *Model) ID() (string, error) { return mdpartial.ID(m) } -func (m *Model) Config() (mdtypes.Config, error) { +func (m *Model) Config() (mdtypes.ModelConfig, 
error) { return mdpartial.Config(m) } diff --git a/pkg/distribution/modelpack/convert.go b/pkg/distribution/modelpack/convert.go new file mode 100644 index 000000000..caf5b8624 --- /dev/null +++ b/pkg/distribution/modelpack/convert.go @@ -0,0 +1,155 @@ +package modelpack + +import ( + "encoding/json" + "fmt" + "strings" + + "github.com/docker/model-runner/pkg/distribution/types" + v1 "github.com/docker/model-runner/pkg/go-containerregistry/pkg/v1" + "github.com/opencontainers/go-digest" +) + +// IsModelPackMediaType checks if the given media type indicates a CNCF ModelPack format. +// It returns true if the media type has the CNCF model prefix. +func IsModelPackMediaType(mediaType string) bool { + return strings.HasPrefix(mediaType, MediaTypePrefix) +} + +// IsModelPackConfig detects if raw config bytes are in ModelPack format. +// It parses the JSON structure for precise detection, avoiding false positives from string matching. +// ModelPack format characteristics: config.paramSize or descriptor.createdAt +// Docker format uses: config.parameters and descriptor.created +func IsModelPackConfig(raw []byte) bool { + if len(raw) == 0 { + return false + } + + // Parse as map to check actual JSON structure + var parsed map[string]json.RawMessage + if err := json.Unmarshal(raw, &parsed); err != nil { + return false + } + + // Check for config.paramSize (ModelPack-specific field) + if configRaw, ok := parsed["config"]; ok { + var config map[string]json.RawMessage + if err := json.Unmarshal(configRaw, &config); err == nil { + if _, hasParamSize := config["paramSize"]; hasParamSize { + return true + } + } + } + + // Check for descriptor.createdAt (ModelPack uses camelCase) + if descRaw, ok := parsed["descriptor"]; ok { + var desc map[string]json.RawMessage + if err := json.Unmarshal(descRaw, &desc); err == nil { + if _, hasCreatedAt := desc["createdAt"]; hasCreatedAt { + return true + } + } + } + + // Check for modelfs (ModelPack-specific field name) + if _, hasModelFS := parsed["modelfs"]; hasModelFS { + return true + } + + return false +} + +// MapLayerMediaType maps ModelPack layer media types to Docker format. +// Returns the original value if not a ModelPack type. +func MapLayerMediaType(mediaType string) string { + // Only process ModelPack weight layers + if !strings.HasPrefix(mediaType, MediaTypePrefix) { + return mediaType + } + + // Determine corresponding Docker type based on media type format + switch { + case strings.Contains(mediaType, "weight") && strings.Contains(mediaType, "gguf"): + return string(types.MediaTypeGGUF) + case strings.Contains(mediaType, "weight") && strings.Contains(mediaType, "safetensors"): + return string(types.MediaTypeSafetensors) + default: + // Keep other layer types (doc, code, etc.) as-is + return mediaType + } +} + +// ConvertToDockerConfig converts a raw ModelPack config JSON to Docker model-spec ConfigFile. +// It maps common fields directly. Note: Extended ModelPack metadata is not preserved +// since types.Config no longer has a ModelPack field. 
+func ConvertToDockerConfig(rawConfig []byte) (*types.ConfigFile, error) { + var mp Model + if err := json.Unmarshal(rawConfig, &mp); err != nil { + return nil, fmt.Errorf("unmarshal modelpack config: %w", err) + } + + // Build the Docker format config + dockerConfig := &types.ConfigFile{ + Config: types.Config{ + Format: convertFormat(mp.Config.Format), + Architecture: mp.Config.Architecture, + Quantization: mp.Config.Quantization, + Parameters: mp.Config.ParamSize, + Size: "0", // ModelPack doesn't have an equivalent field + }, + Descriptor: types.Descriptor{ + Created: mp.Descriptor.CreatedAt, + }, + RootFS: v1.RootFS{ + Type: normalizeRootFSType(mp.ModelFS.Type), + DiffIDs: convertDiffIDs(mp.ModelFS.DiffIDs), + }, + } + + return dockerConfig, nil +} + +// convertFormat maps ModelPack format strings to Docker Format type. +// Format strings are normalized to lowercase for consistent matching. +func convertFormat(mpFormat string) types.Format { + switch strings.ToLower(mpFormat) { + case "gguf": + return types.FormatGGUF + case "safetensors": + return types.FormatSafetensors + default: + // Pass through unknown formats as-is + return types.Format(strings.ToLower(mpFormat)) + } +} + +// normalizeRootFSType ensures the rootfs type is set correctly. +// ModelPack uses "layers" as the type, which maps to Docker's "layers". +func normalizeRootFSType(mpType string) string { + if mpType == "" { + return "layers" + } + return mpType +} + +// convertDiffIDs converts opencontainers digest.Digest slice to go-containerregistry v1.Hash slice. +// Note: Invalid digests are silently skipped here because they will be caught +// during layer validation when the model is actually loaded. This avoids +// failing early for formats we might not fully understand yet. +func convertDiffIDs(digests []digest.Digest) []v1.Hash { + if len(digests) == 0 { + return nil + } + + result := make([]v1.Hash, 0, len(digests)) + for _, d := range digests { + // digest.Digest format is "algorithm:hex", same as v1.Hash + hash, err := v1.NewHash(d.String()) + if err != nil { + // Skip invalid digests; they will be caught during layer validation + continue + } + result = append(result, hash) + } + return result +} diff --git a/pkg/distribution/modelpack/convert_test.go b/pkg/distribution/modelpack/convert_test.go new file mode 100644 index 000000000..26b859ae6 --- /dev/null +++ b/pkg/distribution/modelpack/convert_test.go @@ -0,0 +1,439 @@ +package modelpack + +import ( + "encoding/json" + "testing" + "time" + + "github.com/docker/model-runner/pkg/distribution/types" + "github.com/opencontainers/go-digest" +) + +func TestIsModelPackMediaType(t *testing.T) { + tests := []struct { + name string + mediaType string + expected bool + }{ + { + name: "CNCF v1 config", + mediaType: "application/vnd.cncf.model.config.v1+json", + expected: true, + }, + { + name: "CNCF future version", + mediaType: "application/vnd.cncf.model.config.v2+json", + expected: true, + }, + { + name: "CNCF weight media type", + mediaType: "application/vnd.cncf.model.weight.v1.raw", + expected: true, + }, + { + name: "Docker format", + mediaType: "application/vnd.docker.ai.model.config.v0.1+json", + expected: false, + }, + { + name: "Generic JSON", + mediaType: "application/json", + expected: false, + }, + { + name: "Empty string", + mediaType: "", + expected: false, + }, + { + name: "OCI image config", + mediaType: "application/vnd.oci.image.config.v1+json", + expected: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result 
:= IsModelPackMediaType(tt.mediaType) + if result != tt.expected { + t.Errorf("IsModelPackMediaType(%q) = %v, want %v", tt.mediaType, result, tt.expected) + } + }) + } +} + +func TestConvertToDockerConfig(t *testing.T) { + t.Run("full config conversion", func(t *testing.T) { + created := time.Date(2025, 1, 15, 10, 30, 0, 0, time.UTC) + knowledgeCutoff := time.Date(2024, 6, 1, 0, 0, 0, 0, time.UTC) + reasoning := true + toolUsage := true + + mpConfig := Model{ + Descriptor: ModelDescriptor{ + CreatedAt: &created, + Authors: []string{"Author1", "Author2"}, + Family: "llama", + Name: "llama3-8b-instruct", + DocURL: "https://example.com/docs", + SourceURL: "https://example.com/source", + DatasetsURL: []string{"https://example.com/dataset1", "https://example.com/dataset2"}, + Version: "1.0.0", + Revision: "abc123", + Vendor: "TestVendor", + Licenses: []string{"MIT", "Apache-2.0"}, + Title: "Llama 3 8B Instruct", + Description: "A test model for testing", + }, + Config: ModelConfig{ + Architecture: "transformer", + Format: "gguf", + ParamSize: "8B", + Precision: "fp16", + Quantization: "Q4_K_M", + Capabilities: &ModelCapabilities{ + InputTypes: []string{"text"}, + OutputTypes: []string{"text"}, + KnowledgeCutoff: &knowledgeCutoff, + Reasoning: &reasoning, + ToolUsage: &toolUsage, + Languages: []string{"en", "zh"}, + }, + }, + ModelFS: ModelFS{ + Type: "layers", + DiffIDs: []digest.Digest{"sha256:abc123def456abc123def456abc123def456abc123def456abc123def456abc1"}, + }, + } + + rawConfig, err := json.Marshal(mpConfig) + if err != nil { + t.Fatalf("Failed to marshal test config: %v", err) + } + + dockerConfig, err := ConvertToDockerConfig(rawConfig) + if err != nil { + t.Fatalf("ConvertToDockerConfig failed: %v", err) + } + + // Verify direct field mappings + if dockerConfig.Config.Format != types.FormatGGUF { + t.Errorf("Format = %v, want %v", dockerConfig.Config.Format, types.FormatGGUF) + } + if dockerConfig.Config.Architecture != "transformer" { + t.Errorf("Architecture = %q, want %q", dockerConfig.Config.Architecture, "transformer") + } + if dockerConfig.Config.Quantization != "Q4_K_M" { + t.Errorf("Quantization = %q, want %q", dockerConfig.Config.Quantization, "Q4_K_M") + } + if dockerConfig.Config.Parameters != "8B" { + t.Errorf("Parameters = %q, want %q", dockerConfig.Config.Parameters, "8B") + } + if dockerConfig.Config.Size != "0" { + t.Errorf("Size = %q, want %q", dockerConfig.Config.Size, "0") + } + + // Verify descriptor + if dockerConfig.Descriptor.Created == nil { + t.Error("Descriptor.Created should not be nil") + } else if !dockerConfig.Descriptor.Created.Equal(created) { + t.Errorf("Descriptor.Created = %v, want %v", dockerConfig.Descriptor.Created, created) + } + + // Verify RootFS + if dockerConfig.RootFS.Type != "layers" { + t.Errorf("RootFS.Type = %q, want %q", dockerConfig.RootFS.Type, "layers") + } + if len(dockerConfig.RootFS.DiffIDs) != 1 { + t.Errorf("RootFS.DiffIDs length = %d, want 1", len(dockerConfig.RootFS.DiffIDs)) + } + // Note: Extended metadata (ModelPack field) is no longer preserved since + // types.Config no longer has a ModelPack field. Native format support (Option B) + // handles ModelPack configs directly without conversion. 
+ }) + + t.Run("minimal config", func(t *testing.T) { + mpConfig := Model{ + Config: ModelConfig{ + Format: "gguf", + }, + ModelFS: ModelFS{ + Type: "layers", + DiffIDs: []digest.Digest{"sha256:abc123"}, + }, + } + + rawConfig, _ := json.Marshal(mpConfig) + dockerConfig, err := ConvertToDockerConfig(rawConfig) + if err != nil { + t.Fatalf("ConvertToDockerConfig failed for minimal config: %v", err) + } + + if dockerConfig.Config.Format != types.FormatGGUF { + t.Errorf("Format = %v, want %v", dockerConfig.Config.Format, types.FormatGGUF) + } + }) + + t.Run("empty config", func(t *testing.T) { + mpConfig := Model{} + rawConfig, _ := json.Marshal(mpConfig) + + dockerConfig, err := ConvertToDockerConfig(rawConfig) + if err != nil { + t.Fatalf("ConvertToDockerConfig failed for empty config: %v", err) + } + + if dockerConfig.Config.Format != "" { + t.Errorf("Format should be empty, got %v", dockerConfig.Config.Format) + } + if dockerConfig.RootFS.Type != "layers" { + t.Errorf("RootFS.Type should default to 'layers', got %q", dockerConfig.RootFS.Type) + } + }) + + t.Run("invalid JSON", func(t *testing.T) { + _, err := ConvertToDockerConfig([]byte("invalid json")) + if err == nil { + t.Error("Expected error for invalid JSON, got nil") + } + }) + + t.Run("empty input", func(t *testing.T) { + _, err := ConvertToDockerConfig([]byte("")) + if err == nil { + t.Error("Expected error for empty input, got nil") + } + }) +} + +func TestConvertFormat(t *testing.T) { + tests := []struct { + input string + expected types.Format + }{ + {"gguf", types.FormatGGUF}, + {"GGUF", types.FormatGGUF}, + {"GgUf", types.FormatGGUF}, + {"safetensors", types.FormatSafetensors}, + {"SafeTensors", types.FormatSafetensors}, + {"SAFETENSORS", types.FormatSafetensors}, + {"onnx", types.Format("onnx")}, + {"pytorch", types.Format("pytorch")}, + {"", types.Format("")}, + {"unknown", types.Format("unknown")}, + } + + for _, tt := range tests { + t.Run(tt.input, func(t *testing.T) { + result := convertFormat(tt.input) + if result != tt.expected { + t.Errorf("convertFormat(%q) = %v, want %v", tt.input, result, tt.expected) + } + }) + } +} + +func TestConvertDiffIDs(t *testing.T) { + t.Run("valid digests", func(t *testing.T) { + digests := []digest.Digest{ + "sha256:abc123def456abc123def456abc123def456abc123def456abc123def456abc1", + "sha256:123456789012345678901234567890123456789012345678901234567890abcd", + } + + result := convertDiffIDs(digests) + if len(result) != 2 { + t.Errorf("Expected 2 hashes, got %d", len(result)) + } + }) + + t.Run("empty slice", func(t *testing.T) { + result := convertDiffIDs([]digest.Digest{}) + if result != nil { + t.Errorf("Expected nil for empty slice, got %v", result) + } + }) + + t.Run("nil slice", func(t *testing.T) { + result := convertDiffIDs(nil) + if result != nil { + t.Errorf("Expected nil for nil slice, got %v", result) + } + }) + + t.Run("invalid digest skipped", func(t *testing.T) { + digests := []digest.Digest{ + "sha256:abc123def456abc123def456abc123def456abc123def456abc123def456abc1", + "invalid-digest-format", // This should be skipped + "sha256:123456789012345678901234567890123456789012345678901234567890abcd", + } + + result := convertDiffIDs(digests) + // Should only have 2 valid hashes, invalid one skipped + if len(result) != 2 { + t.Errorf("Expected 2 valid hashes (invalid skipped), got %d", len(result)) + } + }) +} + +// Note: TestExtractExtendedMetadata was removed because the extractExtendedMetadata +// function was removed. 
With Option B (native format support), ModelPack configs +// are handled directly without conversion to Docker format. + +func TestNormalizeRootFSType(t *testing.T) { + tests := []struct { + input string + expected string + }{ + {"layers", "layers"}, + {"", "layers"}, + {"rootfs", "rootfs"}, + } + + for _, tt := range tests { + t.Run(tt.input, func(t *testing.T) { + result := normalizeRootFSType(tt.input) + if result != tt.expected { + t.Errorf("normalizeRootFSType(%q) = %q, want %q", tt.input, result, tt.expected) + } + }) + } +} + +// TestMapLayerMediaType tests layer media type mapping +func TestMapLayerMediaType(t *testing.T) { + tests := []struct { + name string + input string + expected string + }{ + // ModelPack GGUF related media types + { + name: "ModelPack weight gguf v1", + input: "application/vnd.cncf.model.weight.v1.gguf", + expected: "application/vnd.docker.ai.gguf.v3", + }, + { + name: "ModelPack weight gguf no version", + input: "application/vnd.cncf.model.weight.gguf", + expected: "application/vnd.docker.ai.gguf.v3", + }, + // ModelPack safetensors related + { + name: "ModelPack weight safetensors", + input: "application/vnd.cncf.model.weight.v1.safetensors", + expected: "application/vnd.docker.ai.safetensors", + }, + // Docker format passthrough + { + name: "Docker GGUF passthrough", + input: "application/vnd.docker.ai.gguf.v3", + expected: "application/vnd.docker.ai.gguf.v3", + }, + { + name: "Docker safetensors passthrough", + input: "application/vnd.docker.ai.safetensors", + expected: "application/vnd.docker.ai.safetensors", + }, + // Other types unchanged + { + name: "generic octet-stream", + input: "application/octet-stream", + expected: "application/octet-stream", + }, + { + name: "ModelPack doc layer unchanged", + input: "application/vnd.cncf.model.doc.v1", + expected: "application/vnd.cncf.model.doc.v1", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := MapLayerMediaType(tt.input) + if got != tt.expected { + t.Errorf("MapLayerMediaType(%q) = %q, want %q", tt.input, got, tt.expected) + } + }) + } +} + +// TestIsModelPackConfig tests detecting ModelPack format from raw config bytes +func TestIsModelPackConfig(t *testing.T) { + // Prepare test ModelPack format config (has paramSize field) + modelPackConfig := `{ + "descriptor": {"createdAt": "2025-01-15T10:30:00Z"}, + "config": {"paramSize": "8B", "format": "gguf"} + }` + + // Docker format config (uses parameters instead of paramSize) + dockerConfig := `{ + "config": {"parameters": "8B", "format": "gguf"}, + "descriptor": {"created": "2025-01-15T10:30:00Z"} + }` + + tests := []struct { + name string + input []byte + expected bool + }{ + { + name: "ModelPack config with paramSize", + input: []byte(modelPackConfig), + expected: true, + }, + { + name: "Docker config with parameters", + input: []byte(dockerConfig), + expected: false, + }, + { + name: "empty JSON object", + input: []byte("{}"), + expected: false, + }, + { + name: "invalid JSON", + input: []byte("not json"), + expected: false, + }, + { + name: "nil input", + input: nil, + expected: false, + }, + { + name: "empty input", + input: []byte(""), + expected: false, + }, + { + name: "config with createdAt field", + input: []byte(`{"descriptor": {"createdAt": "2025-01-01T00:00:00Z"}}`), + expected: true, + }, + { + name: "config with modelfs field", + input: []byte(`{"modelfs": {"type": "layers", "diffIds": []}}`), + expected: true, + }, + { + name: "false positive prevention - paramSize as value", + input: []byte(`{"config": 
{"description": "paramSize is 8B"}}`), + expected: false, + }, + { + name: "false positive prevention - createdAt as value", + input: []byte(`{"descriptor": {"note": "createdAt was yesterday"}}`), + expected: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := IsModelPackConfig(tt.input) + if got != tt.expected { + t.Errorf("IsModelPackConfig() = %v, want %v", got, tt.expected) + } + }) + } +} diff --git a/pkg/distribution/modelpack/types.go b/pkg/distribution/modelpack/types.go new file mode 100644 index 000000000..d6d1b5e8d --- /dev/null +++ b/pkg/distribution/modelpack/types.go @@ -0,0 +1,191 @@ +// Package modelpack provides native support for CNCF ModelPack format models. +// It enables docker/model-runner to pull, store, and run models in ModelPack format +// without conversion. Both Docker and ModelPack formats are supported natively through +// the types.ModelConfig interface. +// +// Note: JSON tags in this package use camelCase (e.g., "createdAt", "paramSize") to match +// the CNCF ModelPack spec, which differs from Docker model-spec's snake_case convention +// (e.g., "context_size"). +// +// See: https://github.com/modelpack/model-spec +package modelpack + +import ( + "strings" + "time" + + "github.com/docker/model-runner/pkg/distribution/types" + "github.com/opencontainers/go-digest" +) + +const ( + // MediaTypePrefix is the prefix for all CNCF model config media types. + MediaTypePrefix = "application/vnd.cncf.model." + + // MediaTypeModelConfigV1 is the CNCF model config v1 media type. + MediaTypeModelConfigV1 = "application/vnd.cncf.model.config.v1+json" + + // MediaTypeWeightGGUF is the CNCF ModelPack media type for GGUF weight layers. + MediaTypeWeightGGUF = "application/vnd.cncf.model.weight.v1.gguf" + + // MediaTypeWeightSafetensors is the CNCF ModelPack media type for safetensors weight layers. + MediaTypeWeightSafetensors = "application/vnd.cncf.model.weight.v1.safetensors" +) + +// Model represents the CNCF ModelPack config structure. +// It provides the `application/vnd.cncf.model.config.v1+json` mediatype when marshalled to JSON. +type Model struct { + // Descriptor provides metadata about the model provenance and identity. + Descriptor ModelDescriptor `json:"descriptor"` + + // ModelFS describes the layer content addresses. + ModelFS ModelFS `json:"modelfs"` + + // Config defines the execution parameters for the model. + Config ModelConfig `json:"config,omitempty"` +} + +// ModelDescriptor defines the general information of a model. +type ModelDescriptor struct { + // CreatedAt is the date and time on which the model was built. + CreatedAt *time.Time `json:"createdAt,omitempty"` + + // Authors contains the contact details of the people or organization responsible for the model. + Authors []string `json:"authors,omitempty"` + + // Family is the model family, such as llama3, gpt2, qwen2, etc. + Family string `json:"family,omitempty"` + + // Name is the model name, such as llama3-8b-instruct, gpt2-xl, etc. + Name string `json:"name,omitempty"` + + // DocURL is the URL to get documentation on the model. + DocURL string `json:"docURL,omitempty"` + + // SourceURL is the URL to get source code for building the model. + SourceURL string `json:"sourceURL,omitempty"` + + // DatasetsURL contains URLs referencing datasets that the model was trained upon. + DatasetsURL []string `json:"datasetsURL,omitempty"` + + // Version is the version of the packaged software. 
+ Version string `json:"version,omitempty"` + + // Revision is the source control revision identifier for the packaged software. + Revision string `json:"revision,omitempty"` + + // Vendor is the name of the distributing entity, organization or individual. + Vendor string `json:"vendor,omitempty"` + + // Licenses contains the license(s) under which contained software is distributed + // as an SPDX License Expression. + Licenses []string `json:"licenses,omitempty"` + + // Title is the human-readable title of the model. + Title string `json:"title,omitempty"` + + // Description is the human-readable description of the software packaged in the model. + Description string `json:"description,omitempty"` +} + +// ModelConfig defines the execution parameters which should be used as a base +// when running a model using an inference engine. +type ModelConfig struct { + // Architecture is the model architecture, such as transformer, cnn, rnn, etc. + Architecture string `json:"architecture,omitempty"` + + // Format is the model format, such as gguf, safetensors, onnx, etc. + Format string `json:"format,omitempty"` + + // ParamSize is the size of the model parameters, such as "8b", "16b", "32b", etc. + ParamSize string `json:"paramSize,omitempty"` + + // Precision is the model precision, such as bf16, fp16, int8, mixed etc. + Precision string `json:"precision,omitempty"` + + // Quantization is the model quantization method, such as awq, gptq, etc. + Quantization string `json:"quantization,omitempty"` + + // Capabilities defines special capabilities that the model supports. + Capabilities *ModelCapabilities `json:"capabilities,omitempty"` +} + +// ModelCapabilities defines the special capabilities that the model supports. +type ModelCapabilities struct { + // InputTypes specifies what input modalities the model can process. + // Values can be: "text", "image", "audio", "video", "embedding", "other". + InputTypes []string `json:"inputTypes,omitempty"` + + // OutputTypes specifies what output modalities the model can produce. + // Values can be: "text", "image", "audio", "video", "embedding", "other". + OutputTypes []string `json:"outputTypes,omitempty"` + + // KnowledgeCutoff is the date of the datasets that the model was trained on. + KnowledgeCutoff *time.Time `json:"knowledgeCutoff,omitempty"` + + // Reasoning indicates whether the model can perform reasoning tasks. + Reasoning *bool `json:"reasoning,omitempty"` + + // ToolUsage indicates whether the model can use external tools. + ToolUsage *bool `json:"toolUsage,omitempty"` + + // Reward indicates whether the model is a reward model. + Reward *bool `json:"reward,omitempty"` + + // Languages indicates the languages that the model can speak. + // Encoded as ISO 639 two letter codes. For example, ["en", "fr", "zh"]. + Languages []string `json:"languages,omitempty"` +} + +// ModelFS describes the layer content addresses. +type ModelFS struct { + // Type is the type of the rootfs. MUST be set to "layers". + Type string `json:"type"` + + // DiffIDs is an array of layer content hashes (DiffIDs), + // in order from bottom-most to top-most. + DiffIDs []digest.Digest `json:"diffIds"` +} + +// Ensure Model implements types.ModelConfig +var _ types.ModelConfig = (*Model)(nil) + +// GetFormat returns the model format, converted to types.Format. 
+func (m *Model) GetFormat() types.Format { + f := strings.ToLower(m.Config.Format) + switch f { + case "gguf": + return types.FormatGGUF + case "safetensors": + return types.FormatSafetensors + default: + return types.Format(f) + } +} + +// GetContextSize returns the context size. ModelPack spec does not define this field, +// so it always returns nil. +func (m *Model) GetContextSize() *int32 { + return nil +} + +// GetSize returns the parameter size (e.g., "8b"). +func (m *Model) GetSize() string { + return m.Config.ParamSize +} + +// GetArchitecture returns the model architecture. +func (m *Model) GetArchitecture() string { + return m.Config.Architecture +} + +// GetParameters returns the parameters description. +// ModelPack uses ParamSize instead of Parameters, so return ParamSize. +func (m *Model) GetParameters() string { + return m.Config.ParamSize +} + +// GetQuantization returns the quantization method. +func (m *Model) GetQuantization() string { + return m.Config.Quantization +} diff --git a/pkg/distribution/registry/artifact.go b/pkg/distribution/registry/artifact.go index 3bd8dfe85..d20001e0c 100644 --- a/pkg/distribution/registry/artifact.go +++ b/pkg/distribution/registry/artifact.go @@ -16,7 +16,7 @@ func (a *artifact) ID() (string, error) { return partial.ID(a) } -func (a *artifact) Config() (types.Config, error) { +func (a *artifact) Config() (types.ModelConfig, error) { return partial.Config(a) } diff --git a/pkg/distribution/types/config.go b/pkg/distribution/types/config.go index 62e45ebad..ccd78847a 100644 --- a/pkg/distribution/types/config.go +++ b/pkg/distribution/types/config.go @@ -52,6 +52,19 @@ const ( type Format string +// ModelConfig provides a unified interface for accessing model configuration. +// Both Docker format (*Config) and ModelPack format (*modelpack.Model) implement +// this interface, allowing schedulers and backends to access config without +// knowing the underlying format. +type ModelConfig interface { + GetFormat() Format + GetContextSize() *int32 + GetSize() string + GetArchitecture() string + GetParameters() string + GetQuantization() string +} + type ConfigFile struct { Config Config `json:"config"` Descriptor Descriptor `json:"descriptor"` @@ -75,6 +88,39 @@ type Descriptor struct { Created *time.Time `json:"created,omitempty"` } +// Ensure Config implements ModelConfig +var _ ModelConfig = (*Config)(nil) + +// GetFormat returns the model format. +func (c *Config) GetFormat() Format { + return c.Format +} + +// GetContextSize returns the context size configuration. +func (c *Config) GetContextSize() *int32 { + return c.ContextSize +} + +// GetSize returns the parameter size (e.g., "8B"). +func (c *Config) GetSize() string { + return c.Size +} + +// GetArchitecture returns the model architecture. +func (c *Config) GetArchitecture() string { + return c.Architecture +} + +// GetParameters returns the parameters description. +func (c *Config) GetParameters() string { + return c.Parameters +} + +// GetQuantization returns the quantization method. +func (c *Config) GetQuantization() string { + return c.Quantization +} + // FileMetadata represents the metadata of file, which is the value definition of AnnotationFileMetadata. // This follows the OCI image specification for model artifacts. 
type FileMetadata struct { diff --git a/pkg/distribution/types/model.go b/pkg/distribution/types/model.go index edb9a4a4a..5c1f81e95 100644 --- a/pkg/distribution/types/model.go +++ b/pkg/distribution/types/model.go @@ -10,7 +10,7 @@ type Model interface { SafetensorsPaths() ([]string, error) ConfigArchivePath() (string, error) MMPROJPath() (string, error) - Config() (Config, error) + Config() (ModelConfig, error) Tags() []string Descriptor() (Descriptor, error) ChatTemplatePath() (string, error) @@ -18,7 +18,7 @@ type Model interface { type ModelArtifact interface { ID() (string, error) - Config() (Config, error) + Config() (ModelConfig, error) Descriptor() (Descriptor, error) v1.Image } @@ -29,5 +29,5 @@ type ModelBundle interface { SafetensorsPath() string ChatTemplatePath() string MMPROJPath() string - RuntimeConfig() Config + RuntimeConfig() ModelConfig } diff --git a/pkg/inference/backends/llamacpp/llamacpp.go b/pkg/inference/backends/llamacpp/llamacpp.go index 935428eeb..9e1d6bd1e 100644 --- a/pkg/inference/backends/llamacpp/llamacpp.go +++ b/pkg/inference/backends/llamacpp/llamacpp.go @@ -16,12 +16,14 @@ import ( "github.com/docker/model-runner/pkg/diskusage" "github.com/docker/model-runner/pkg/distribution/types" + v1 "github.com/docker/model-runner/pkg/go-containerregistry/pkg/v1" "github.com/docker/model-runner/pkg/inference" "github.com/docker/model-runner/pkg/inference/backends" "github.com/docker/model-runner/pkg/inference/config" "github.com/docker/model-runner/pkg/inference/models" "github.com/docker/model-runner/pkg/logging" "github.com/docker/model-runner/pkg/sandbox" + parser "github.com/gpustack/gguf-parser-go" ) const ( @@ -194,6 +196,143 @@ func (l *llamaCpp) GetDiskUsage() (int64, error) { return size, nil } +func (l *llamaCpp) GetRequiredMemoryForModel(ctx context.Context, model string, config *inference.BackendConfiguration) (inference.RequiredMemory, error) { + mdlGguf, mdlConfig, err := l.parseModel(ctx, model) + if err != nil { + return inference.RequiredMemory{}, &inference.ErrGGUFParse{Err: err} + } + + configuredContextSize := GetContextSize(mdlConfig, config) + contextSize := int32(4096) // default context size + if configuredContextSize != nil { + contextSize = *configuredContextSize + } + + var ngl uint64 + if l.gpuSupported { + ngl = 999 + if runtime.GOOS == "windows" && runtime.GOARCH == "arm64" && mdlConfig.GetQuantization() != "Q4_0" { + ngl = 0 // only Q4_0 models can be accelerated on Adreno + } + } + + memory := l.estimateMemoryFromGGUF(mdlGguf, contextSize, ngl) + + if config != nil && config.Speculative != nil && config.Speculative.DraftModel != "" { + draftGguf, _, err := l.parseModel(ctx, config.Speculative.DraftModel) + if err != nil { + return inference.RequiredMemory{}, fmt.Errorf("estimating draft model memory: %w", &inference.ErrGGUFParse{Err: err}) + } + draftMemory := l.estimateMemoryFromGGUF(draftGguf, contextSize, ngl) + memory.RAM += draftMemory.RAM + memory.VRAM += draftMemory.VRAM + } + + if runtime.GOOS == "windows" && runtime.GOARCH == "arm64" { + memory.VRAM = 1 + } + + return memory, nil +} + +// parseModel parses a model (local or remote) and returns the GGUF file and config. 
+// parseModel parses a model (local or remote) and returns the GGUF file and config.
+func (l *llamaCpp) parseModel(ctx context.Context, model string) (*parser.GGUFFile, types.ModelConfig, error) {
+	inStore, err := l.modelManager.InStore(model)
+	if err != nil {
+		return nil, nil, fmt.Errorf("checking if model is in local store: %w", err)
+	}
+	if inStore {
+		return l.parseLocalModel(model)
+	}
+	return l.parseRemoteModel(ctx, model)
+}
+
+// estimateMemoryFromGGUF estimates memory requirements from a parsed GGUF file.
+func (l *llamaCpp) estimateMemoryFromGGUF(ggufFile *parser.GGUFFile, contextSize int32, ngl uint64) inference.RequiredMemory {
+	estimate := ggufFile.EstimateLLaMACppRun(
+		parser.WithLLaMACppContextSize(contextSize),
+		parser.WithLLaMACppLogicalBatchSize(2048),
+		parser.WithLLaMACppOffloadLayers(ngl),
+	)
+	ram := uint64(estimate.Devices[0].Weight.Sum() + estimate.Devices[0].KVCache.Sum() + estimate.Devices[0].Computation.Sum())
+	var vram uint64
+	if len(estimate.Devices) > 1 {
+		vram = uint64(estimate.Devices[1].Weight.Sum() + estimate.Devices[1].KVCache.Sum() + estimate.Devices[1].Computation.Sum())
+	}
+
+	return inference.RequiredMemory{
+		RAM:  ram,
+		VRAM: vram,
+	}
+}
+
+func (l *llamaCpp) parseLocalModel(model string) (*parser.GGUFFile, types.ModelConfig, error) {
+	bundle, err := l.modelManager.GetBundle(model)
+	if err != nil {
+		return nil, nil, fmt.Errorf("getting model(%s): %w", model, err)
+	}
+	modelGGUF, err := parser.ParseGGUFFile(bundle.GGUFPath())
+	if err != nil {
+		return nil, nil, fmt.Errorf("parsing gguf(%s): %w", bundle.GGUFPath(), err)
+	}
+	return modelGGUF, bundle.RuntimeConfig(), nil
+}
+
+func (l *llamaCpp) parseRemoteModel(ctx context.Context, model string) (*parser.GGUFFile, types.ModelConfig, error) {
+	mdl, err := l.modelManager.GetRemote(ctx, model)
+	if err != nil {
+		return nil, nil, fmt.Errorf("getting remote model(%s): %w", model, err)
+	}
+	layers, err := mdl.Layers()
+	if err != nil {
+		return nil, nil, fmt.Errorf("getting layers of model(%s): %w", model, err)
+	}
+	ggufLayers := getGGUFLayers(layers)
+	if len(ggufLayers) != 1 {
+		return nil, nil, fmt.Errorf(
+			"remote memory estimation only supported for models with single GGUF layer, found %d layers", len(ggufLayers),
+		)
+	}
+	ggufDigest, err := ggufLayers[0].Digest()
+	if err != nil {
+		return nil, nil, fmt.Errorf("getting digest of GGUF layer for model(%s): %w", model, err)
+	}
+	if ggufDigest.String() == "" {
+		return nil, nil, fmt.Errorf("model(%s) has no GGUF layer", model)
+	}
+	blobURL, err := l.modelManager.GetRemoteBlobURL(model, ggufDigest)
+	if err != nil {
+		return nil, nil, fmt.Errorf("getting GGUF blob URL for model(%s): %w", model, err)
+	}
+	tok, err := l.modelManager.BearerTokenForModel(ctx, model)
+	if err != nil {
+		return nil, nil, fmt.Errorf("getting bearer token for model(%s): %w", model, err)
+	}
+	mdlGguf, err := parser.ParseGGUFFileRemote(ctx, blobURL, parser.UseBearerAuth(tok))
+	if err != nil {
+		return nil, nil, fmt.Errorf("parsing GGUF for model(%s): %w", model, err)
+	}
+	config, err := mdl.Config()
+	if err != nil {
+		return nil, nil, fmt.Errorf("getting config for model(%s): %w", model, err)
+	}
+	return mdlGguf, config, nil
+}
+
+func getGGUFLayers(layers []v1.Layer) []v1.Layer {
+	var filtered []v1.Layer
+	for _, layer := range layers {
+		mt, err := layer.MediaType()
+		if err != nil {
+			continue
+		}
+		if mt == types.MediaTypeGGUF {
+			filtered = append(filtered, layer)
+		}
+	}
+	return filtered
+}
+
 func (l *llamaCpp) checkGPUSupport(ctx context.Context) bool {
 	binPath := l.vendoredServerStoragePath
 	if l.updatedLlamaCpp {
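For remote models, parseRemoteModel resolves the GGUF layer's blob URL and a bearer token so that gguf-parser-go can inspect the file in place; as I understand the parser's remote mode, it reads only the metadata it needs rather than pulling the entire blob. A standalone sketch of that estimation path is below, using only parser calls that already appear in this patch; the URL, token, and fixed context size are placeholders, and the Devices[0]/Devices[1] split mirrors the backend code above.

package example

import (
	"context"

	parser "github.com/gpustack/gguf-parser-go"
)

// estimateRemote returns an approximate host-RAM / VRAM split for a GGUF blob
// reachable at blobURL, assuming full GPU offload (999 layers) and the same
// 4096-token default context the backend falls back to.
func estimateRemote(ctx context.Context, blobURL, bearer string) (ram, vram uint64, err error) {
	gf, err := parser.ParseGGUFFileRemote(ctx, blobURL, parser.UseBearerAuth(bearer))
	if err != nil {
		return 0, 0, err
	}
	est := gf.EstimateLLaMACppRun(
		parser.WithLLaMACppContextSize(4096),
		parser.WithLLaMACppLogicalBatchSize(2048),
		parser.WithLLaMACppOffloadLayers(999),
	)
	// Device 0 is the host; any further devices are GPUs.
	ram = uint64(est.Devices[0].Weight.Sum() + est.Devices[0].KVCache.Sum() + est.Devices[0].Computation.Sum())
	if len(est.Devices) > 1 {
		vram = uint64(est.Devices[1].Weight.Sum() + est.Devices[1].KVCache.Sum() + est.Devices[1].Computation.Sum())
	}
	return ram, vram, nil
}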
diff --git a/pkg/inference/backends/llamacpp/llamacpp_config.go b/pkg/inference/backends/llamacpp/llamacpp_config.go
index c0fad1240..f0ed4106f 100644
--- a/pkg/inference/backends/llamacpp/llamacpp_config.go
+++ b/pkg/inference/backends/llamacpp/llamacpp_config.go
@@ -94,10 +94,12 @@ func (c *Config) GetArgs(bundle types.ModelBundle, socket string, mode inference
 	return args, nil
 }
 
-func GetContextSize(modelCfg types.Config, backendCfg *inference.BackendConfiguration) *int32 {
+func GetContextSize(modelCfg types.ModelConfig, backendCfg *inference.BackendConfiguration) *int32 {
 	// Model config takes precedence
-	if modelCfg.ContextSize != nil && (*modelCfg.ContextSize == UnlimitedContextSize || *modelCfg.ContextSize > 0) {
-		return modelCfg.ContextSize
+	if modelCfg != nil {
+		if ctxSize := modelCfg.GetContextSize(); ctxSize != nil && (*ctxSize == UnlimitedContextSize || *ctxSize > 0) {
+			return ctxSize
+		}
 	}
 	// Fallback to backend config
 	if backendCfg != nil && backendCfg.ContextSize != nil && (*backendCfg.ContextSize == UnlimitedContextSize || *backendCfg.ContextSize > 0) {
diff --git a/pkg/inference/backends/llamacpp/llamacpp_config_test.go b/pkg/inference/backends/llamacpp/llamacpp_config_test.go
index fc0c976e4..1a53a1c85 100644
--- a/pkg/inference/backends/llamacpp/llamacpp_config_test.go
+++ b/pkg/inference/backends/llamacpp/llamacpp_config_test.go
@@ -195,7 +195,7 @@ func TestGetArgs(t *testing.T) {
 			mode: inference.BackendModeEmbedding,
 			bundle: &fakeBundle{
 				ggufPath: modelPath,
-				config: types.Config{
+				config: &types.Config{
 					ContextSize: int32ptr(2096),
 				},
 			},
@@ -423,7 +423,7 @@ var _ types.ModelBundle = &fakeBundle{}
 
 type fakeBundle struct {
 	ggufPath     string
-	config       types.Config
+	config       *types.Config
 	templatePath string
 	mmprojPath   string
 }
@@ -448,7 +448,10 @@ func (f *fakeBundle) SafetensorsPath() string {
 	return ""
 }
 
-func (f *fakeBundle) RuntimeConfig() types.Config {
+func (f *fakeBundle) RuntimeConfig() types.ModelConfig {
+	if f.config == nil {
+		return nil
+	}
 	return f.config
 }
 
diff --git a/pkg/inference/backends/mlx/mlx_config.go b/pkg/inference/backends/mlx/mlx_config.go
index 025f5b89c..bc4f605c6 100644
--- a/pkg/inference/backends/mlx/mlx_config.go
+++ b/pkg/inference/backends/mlx/mlx_config.go
@@ -64,6 +64,6 @@ func (c *Config) GetArgs(bundle types.ModelBundle, socket string, mode inference
 // GetMaxTokens returns the max tokens (context size) from model config or backend config.
 // Model config takes precedence over backend config.
 // Returns nil if neither is specified (MLX will use model defaults).
-func GetMaxTokens(modelCfg types.Config, backendCfg *inference.BackendConfiguration) *uint64 {
+func GetMaxTokens(modelCfg types.ModelConfig, backendCfg *inference.BackendConfiguration) *uint64 {
 	return nil
 }
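The llamacpp GetContextSize change above keeps the precedence (model config first, then backend config) but now also tolerates a nil ModelConfig, which matters for bundles whose RuntimeConfig can return nil. A small sketch of the resulting behaviour; the llamacpp import path is assumed from the module layout, and int32ptr is a tiny local helper, not something this patch exports:

package main

import (
	"fmt"

	"github.com/docker/model-runner/pkg/distribution/types"
	"github.com/docker/model-runner/pkg/inference"
	"github.com/docker/model-runner/pkg/inference/backends/llamacpp"
)

func int32ptr(v int32) *int32 { return &v }

func main() {
	backendCfg := &inference.BackendConfiguration{ContextSize: int32ptr(4096)}

	// A usable value in the model config wins...
	fmt.Println(*llamacpp.GetContextSize(&types.Config{ContextSize: int32ptr(16384)}, backendCfg)) // 16384

	// ...otherwise the backend configuration is used...
	fmt.Println(*llamacpp.GetContextSize(&types.Config{}, backendCfg)) // 4096

	// ...and a nil ModelConfig simply falls through instead of panicking.
	fmt.Println(*llamacpp.GetContextSize(nil, backendCfg)) // 4096
}

The sglang and vllm helpers that follow apply the same precedence with their own guard conditions.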
diff --git a/pkg/inference/backends/sglang/sglang_config.go b/pkg/inference/backends/sglang/sglang_config.go
index 8b500a906..4d220d96c 100644
--- a/pkg/inference/backends/sglang/sglang_config.go
+++ b/pkg/inference/backends/sglang/sglang_config.go
@@ -65,10 +65,10 @@ func (c *Config) GetArgs(bundle types.ModelBundle, socket string, mode inference
 // GetContextLength returns the context length (context size) from model config or backend config.
 // Model config takes precedence over backend config.
 // Returns nil if neither is specified (SGLang will auto-derive from model).
-func GetContextLength(modelCfg types.Config, backendCfg *inference.BackendConfiguration) *int32 {
+func GetContextLength(modelCfg types.ModelConfig, backendCfg *inference.BackendConfiguration) *int32 {
 	// Model config takes precedence
-	if modelCfg.ContextSize != nil && *modelCfg.ContextSize > 0 {
-		return modelCfg.ContextSize
+	if cs := modelCfg.GetContextSize(); cs != nil && *cs > 0 {
+		return cs
 	}
 	// Fallback to backend config
 	if backendCfg != nil && backendCfg.ContextSize != nil && *backendCfg.ContextSize > 0 {
diff --git a/pkg/inference/backends/sglang/sglang_config_test.go b/pkg/inference/backends/sglang/sglang_config_test.go
index 28886527a..e4aed9255 100644
--- a/pkg/inference/backends/sglang/sglang_config_test.go
+++ b/pkg/inference/backends/sglang/sglang_config_test.go
@@ -9,7 +9,7 @@ import (
 
 type mockModelBundle struct {
 	safetensorsPath string
-	runtimeConfig   types.Config
+	runtimeConfig   *types.Config
 }
 
 func (m *mockModelBundle) GGUFPath() string {
@@ -28,7 +28,10 @@ func (m *mockModelBundle) MMPROJPath() string {
 	return ""
 }
 
-func (m *mockModelBundle) RuntimeConfig() types.Config {
+func (m *mockModelBundle) RuntimeConfig() types.ModelConfig {
+	if m.runtimeConfig == nil {
+		return &types.Config{}
+	}
 	return m.runtimeConfig
 }
 
@@ -99,7 +102,7 @@ func TestGetArgs(t *testing.T) {
 			name: "with model context size (takes precedence)",
 			bundle: &mockModelBundle{
 				safetensorsPath: "/path/to/model/model.safetensors",
-				runtimeConfig: types.Config{
+				runtimeConfig: &types.Config{
 					ContextSize: int32ptr(16384),
 				},
 			},
@@ -191,19 +194,19 @@ func TestGetArgs(t *testing.T) {
 func TestGetContextLength(t *testing.T) {
 	tests := []struct {
 		name          string
-		modelCfg      types.Config
+		modelCfg      types.ModelConfig
 		backendCfg    *inference.BackendConfiguration
 		expectedValue *int32
 	}{
 		{
 			name:          "no config",
-			modelCfg:      types.Config{},
+			modelCfg:      &types.Config{},
 			backendCfg:    nil,
 			expectedValue: nil,
 		},
 		{
 			name:     "backend config only",
-			modelCfg: types.Config{},
+			modelCfg: &types.Config{},
 			backendCfg: &inference.BackendConfiguration{
 				ContextSize: int32ptr(4096),
 			},
@@ -211,7 +214,7 @@ func TestGetContextLength(t *testing.T) {
 		},
 		{
 			name: "model config only",
-			modelCfg: types.Config{
+			modelCfg: &types.Config{
 				ContextSize: int32ptr(8192),
 			},
 			backendCfg:    nil,
@@ -219,7 +222,7 @@ func TestGetContextLength(t *testing.T) {
 		},
 		{
 			name: "model config takes precedence",
-			modelCfg: types.Config{
+			modelCfg: &types.Config{
 				ContextSize: int32ptr(16384),
 			},
 			backendCfg: &inference.BackendConfiguration{
@@ -229,7 +232,7 @@ func TestGetContextLength(t *testing.T) {
 		},
 		{
 			name:       "zero context size in backend config returns nil",
-			modelCfg:   types.Config{},
+			modelCfg:   &types.Config{},
 			backendCfg: &inference.BackendConfiguration{
 				ContextSize: int32ptr(0),
 			},
diff --git a/pkg/inference/backends/vllm/vllm_config.go b/pkg/inference/backends/vllm/vllm_config.go
index b07c2f2e5..b172637f2 100644
--- a/pkg/inference/backends/vllm/vllm_config.go
+++ b/pkg/inference/backends/vllm/vllm_config.go
@@ -89,10 +89,12 @@ func (c *Config) GetArgs(bundle types.ModelBundle, socket string, mode inference
 // GetMaxModelLen returns the max model length (context size) from model config or backend config.
 // Model config takes precedence over backend config.
 // Returns nil if neither is specified (vLLM will auto-derive from model).
-func GetMaxModelLen(modelCfg types.Config, backendCfg *inference.BackendConfiguration) *int32 {
+func GetMaxModelLen(modelCfg types.ModelConfig, backendCfg *inference.BackendConfiguration) *int32 {
 	// Model config takes precedence
-	if modelCfg.ContextSize != nil {
-		return modelCfg.ContextSize
+	if modelCfg != nil {
+		if ctxSize := modelCfg.GetContextSize(); ctxSize != nil {
+			return ctxSize
+		}
 	}
 	// Fallback to backend config
 	if backendCfg != nil && backendCfg.ContextSize != nil && *backendCfg.ContextSize > 0 {
diff --git a/pkg/inference/backends/vllm/vllm_config_test.go b/pkg/inference/backends/vllm/vllm_config_test.go
index ee9304f98..c52d65e19 100644
--- a/pkg/inference/backends/vllm/vllm_config_test.go
+++ b/pkg/inference/backends/vllm/vllm_config_test.go
@@ -9,7 +9,7 @@ import (
 
 type mockModelBundle struct {
 	safetensorsPath string
-	runtimeConfig   types.Config
+	runtimeConfig   *types.Config
 }
 
 func (m *mockModelBundle) GGUFPath() string {
@@ -28,7 +28,10 @@ func (m *mockModelBundle) MMPROJPath() string {
 	return ""
 }
 
-func (m *mockModelBundle) RuntimeConfig() types.Config {
+func (m *mockModelBundle) RuntimeConfig() types.ModelConfig {
+	if m.runtimeConfig == nil {
+		return nil
+	}
 	return m.runtimeConfig
 }
 
@@ -104,7 +107,7 @@ func TestGetArgs(t *testing.T) {
 			name: "with model context size (takes precedence)",
 			bundle: &mockModelBundle{
 				safetensorsPath: "/path/to/model",
-				runtimeConfig: types.Config{
+				runtimeConfig: &types.Config{
 					ContextSize: int32ptr(16384),
 				},
 			},
@@ -383,19 +386,19 @@ func TestGetArgs(t *testing.T) {
 func TestGetMaxModelLen(t *testing.T) {
 	tests := []struct {
 		name          string
-		modelCfg      types.Config
+		modelCfg      types.ModelConfig
 		backendCfg    *inference.BackendConfiguration
 		expectedValue *int32
 	}{
 		{
 			name:          "no config",
-			modelCfg:      types.Config{},
+			modelCfg:      &types.Config{},
 			backendCfg:    nil,
 			expectedValue: nil,
 		},
 		{
 			name:     "backend config only",
-			modelCfg: types.Config{},
+			modelCfg: &types.Config{},
 			backendCfg: &inference.BackendConfiguration{
 				ContextSize: int32ptr(4096),
 			},
@@ -403,7 +406,7 @@ func TestGetMaxModelLen(t *testing.T) {
 		},
 		{
 			name: "model config only",
-			modelCfg: types.Config{
+			modelCfg: &types.Config{
 				ContextSize: int32ptr(8192),
 			},
 			backendCfg: nil,
@@ -411,7 +414,7 @@ func TestGetMaxModelLen(t *testing.T) {
 		},
 		{
 			name: "model config takes precedence",
-			modelCfg: types.Config{
+			modelCfg: &types.Config{
 				ContextSize: int32ptr(16384),
 			},
 			backendCfg: &inference.BackendConfiguration{
diff --git a/pkg/inference/models/api.go b/pkg/inference/models/api.go
index 2b8d93fc1..727612b71 100644
--- a/pkg/inference/models/api.go
+++ b/pkg/inference/models/api.go
@@ -21,11 +21,11 @@ type ModelCreateRequest struct {
 // SimpleModel is a wrapper that allows creating a model with modified configuration
 type SimpleModel struct {
 	types.Model
-	ConfigValue     types.Config
+	ConfigValue     types.ModelConfig
 	DescriptorValue types.Descriptor
 }
 
-func (s *SimpleModel) Config() (types.Config, error) {
+func (s *SimpleModel) Config() (types.ModelConfig, error) {
 	return s.ConfigValue, nil
 }
 
@@ -108,6 +108,7 @@ type Model struct {
 	Tags []string `json:"tags,omitempty"`
 	// Created is the Unix epoch timestamp corresponding to the model creation.
 	Created int64 `json:"created"`
-	// Config describes the model.
-	Config types.Config `json:"config"`
+	// Config describes the model. Can be either Docker format (*types.Config)
+	// or ModelPack format (*modelpack.Model).
+	Config types.ModelConfig `json:"config"`
 }
diff --git a/pkg/inference/models/handler_test.go b/pkg/inference/models/handler_test.go
index 5e53663d2..5c4284cda 100644
--- a/pkg/inference/models/handler_test.go
+++ b/pkg/inference/models/handler_test.go
@@ -262,7 +262,13 @@ func TestHandleGetModel(t *testing.T) {
 				}
 			} else {
 				// For successful responses, verify we got a valid JSON response
-				var response Model
+				// Use a test struct with json.RawMessage for Config since ModelConfig is an interface
+				var response struct {
+					ID      string          `json:"id"`
+					Tags    []string        `json:"tags,omitempty"`
+					Created int64           `json:"created"`
+					Config  json.RawMessage `json:"config"`
+				}
 				if err := json.NewDecoder(w.Body).Decode(&response); err != nil {
 					t.Errorf("Failed to decode response body: %v", err)
 				}
diff --git a/pkg/inference/scheduling/scheduler.go b/pkg/inference/scheduling/scheduler.go
index 7f4247a10..631b8218c 100644
--- a/pkg/inference/scheduling/scheduler.go
+++ b/pkg/inference/scheduling/scheduler.go
@@ -100,7 +100,7 @@ func (s *Scheduler) selectBackendForModel(model types.Model, backend inference.B
 		return backend
 	}
 
-	if config.Format == types.FormatSafetensors {
+	if config.GetFormat() == types.FormatSafetensors {
 		// Prefer vLLM for safetensors models
 		if vllmBackend, ok := s.backends[vllm.Name]; ok && vllmBackend != nil {
 			return vllmBackend
diff --git a/pkg/ollama/http_handler.go b/pkg/ollama/http_handler.go
index f061bd32e..1ce017c90 100644
--- a/pkg/ollama/http_handler.go
+++ b/pkg/ollama/http_handler.go
@@ -209,10 +209,10 @@ func (h *HTTPHandler) handleListModels(w http.ResponseWriter, r *http.Request) {
 		// Extract details from the model
 		details := ModelDetails{
 			Format:            "gguf", // Default to gguf for now
-			Family:            model.Config.Architecture,
-			Families:          []string{model.Config.Architecture},
-			ParameterSize:     model.Config.Parameters,
-			QuantizationLevel: model.Config.Quantization,
+			Family:            model.Config.GetArchitecture(),
+			Families:          []string{model.Config.GetArchitecture()},
+			ParameterSize:     model.Config.GetParameters(),
+			QuantizationLevel: model.Config.GetQuantization(),
 		}
 
 		// Parse size from config string to int64
@@ -336,10 +336,10 @@ func (h *HTTPHandler) handleShowModel(w http.ResponseWriter, r *http.Request) {
 	response := ShowResponse{
 		Details: ModelDetails{
 			Format:            "gguf",
-			Family:            config.Architecture,
-			Families:          []string{config.Architecture},
-			ParameterSize:     config.Parameters,
-			QuantizationLevel: config.Quantization,
+			Family:            config.GetArchitecture(),
+			Families:          []string{config.GetArchitecture()},
+			ParameterSize:     config.GetParameters(),
+			QuantizationLevel: config.GetQuantization(),
 		},
 	}
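One consequence of Model.Config becoming an interface, visible in the handler test change above, is that JSON decoding now needs a concrete target: encoding/json can marshal whichever concrete config is stored, but it cannot unmarshal into types.ModelConfig. A client-side sketch of the same pattern follows; wireModel and decodeDockerConfig are hypothetical names for illustration, not part of the patch, and a real client could fall back to the ModelPack shape when the Docker decode does not fit.

package example

import (
	"encoding/json"

	"github.com/docker/model-runner/pkg/distribution/types"
)

// wireModel mirrors the JSON shape of the models API response while leaving
// the config untouched as raw bytes.
type wireModel struct {
	ID      string          `json:"id"`
	Tags    []string        `json:"tags,omitempty"`
	Created int64           `json:"created"`
	Config  json.RawMessage `json:"config"`
}

// decodeDockerConfig interprets the raw config as the Docker format; the
// resulting *types.Config satisfies types.ModelConfig.
func decodeDockerConfig(w wireModel) (types.ModelConfig, error) {
	var cfg types.Config
	if err := json.Unmarshal(w.Config, &cfg); err != nil {
		return nil, err
	}
	return &cfg, nil
}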