diff --git a/schema/config_test.go b/schema/config_test.go index 391727a..7ad9c71 100644 --- a/schema/config_test.go +++ b/schema/config_test.go @@ -589,3 +589,42 @@ func TestConfig(t *testing.T) { } } } + +func TestValidateConfigParsesModelNotModelConfig(t *testing.T) { + // This test verifies that validateConfig correctly parses the full Model structure, + // not just ModelConfig. Previously, validateConfig unmarshaled into ModelConfig, + // which always succeeded because all fields are optional. + + // Test 1: Incomplete model with only config (should fail) + invalidJSON := `{ + "config": {"paramSize": "8b"} + }` + + err := schema.ValidatorMediaTypeModelConfig.Validate(strings.NewReader(invalidJSON)) + if err == nil { + t.Fatalf("expected validation to fail for incomplete model") + } + + // Test 2: Config-only JSON (should fail) + configOnlyJSON := `{ + "paramSize": "8b", + "architecture": "transformer" + }` + + err = schema.ValidatorMediaTypeModelConfig.Validate(strings.NewReader(configOnlyJSON)) + if err == nil { + t.Fatalf("expected failure for config-only JSON without descriptor/modelfs, but got nil") + } + + // Test 3: Valid full Model (should pass) + validJSON := `{ + "descriptor": {"name": "test-model"}, + "config": {"paramSize": "8b"}, + "modelfs": {"type": "layers", "diffIds": ["sha256:1234567890abcdef1234567890abcdef1234567890abcdef1234567890abcdef"]} + }` + + err = schema.ValidatorMediaTypeModelConfig.Validate(strings.NewReader(validJSON)) + if err != nil { + t.Fatalf("expected valid Model to pass, but got error: %v", err) + } +} diff --git a/schema/validator.go b/schema/validator.go index d526ec8..35c2e17 100644 --- a/schema/validator.go +++ b/schema/validator.go @@ -114,11 +114,19 @@ var validateByMediaType = map[Validator]validateFunc{ } func validateConfig(buf []byte) error { - mc := v1.ModelConfig{} + var model v1.Model - err := json.Unmarshal(buf, &mc) + err := json.Unmarshal(buf, &model) if err != nil { - return fmt.Errorf("config format mismatch: %w", err) + return fmt.Errorf("invalid model structure: %w", err) + } + + // Minimal structural validation for required fields + if model.Descriptor.Name == "" { + return fmt.Errorf("missing descriptor.name") + } + if len(model.ModelFS.DiffIDs) == 0 { + return fmt.Errorf("missing modelfs.diffIds") } return nil