-
Notifications
You must be signed in to change notification settings - Fork 32
feat(schema): add minimal architecture_config for transformer models #196
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -589,3 +589,51 @@ func TestConfig(t *testing.T) { | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| func TestArchitectureConfigValid(t *testing.T) { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| validJSON := `{ | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| "descriptor": {"name": "test-model"}, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| "config": { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| "paramSize": "8b", | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| "architecture_config": { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| "type": "transformer", | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| "numLayers": 32, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| "hiddenSize": 4096, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| "numAttentionHeads": 32 | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| }, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| "modelfs": { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| "type": "layers", | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| "diffIds": ["sha256:1234567890abcdef1234567890abcdef1234567890abcdef1234567890abcdef"] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| }` | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| err := schema.ValidatorMediaTypeModelConfig.Validate(strings.NewReader(validJSON)) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if err != nil { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| t.Fatalf("expected valid architecture_config to pass, got error: %v", err) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| func TestArchitectureConfigMissingRequiredField(t *testing.T) { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| // Missing numLayers field | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| invalidJSON := `{ | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| "descriptor": {"name": "test-model"}, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| "config": { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| "paramSize": "8b", | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| "architecture_config": { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| "type": "transformer", | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| "hiddenSize": 4096, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| "numAttentionHeads": 32 | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| }, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| "modelfs": { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| "type": "layers", | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| "diffIds": ["sha256:1234567890abcdef1234567890abcdef1234567890abcdef1234567890abcdef"] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| }` | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| err := schema.ValidatorMediaTypeModelConfig.Validate(strings.NewReader(invalidJSON)) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if err == nil { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| t.Fatalf("expected architecture_config with missing numLayers to fail validation") | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
Comment on lines
+593
to
+639
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The new tests for The new test should cover not only missing required fields but also other constraints from the schema, such as func TestArchitectureConfig(t *testing.T) {
testCases := []struct {
name string
config string
wantErr bool
}{
{
name: "valid",
config: `{
"descriptor": {"name": "test-model"},
"config": {
"paramSize": "8b",
"architecture_config": {
"type": "transformer", "numLayers": 32, "hiddenSize": 4096, "numAttentionHeads": 32
}
},
"modelfs": {"type": "layers", "diffIds": ["sha256:abc"]}
}`,
wantErr: false,
},
{
name: "missing numLayers",
config: `{
"descriptor": {"name": "test-model"},
"config": {
"paramSize": "8b",
"architecture_config": {
"type": "transformer", "hiddenSize": 4096, "numAttentionHeads": 32
}
},
"modelfs": {"type": "layers", "diffIds": ["sha256:abc"]}
}`,
wantErr: true,
},
{
name: "numLayers is zero",
config: `{
"descriptor": {"name": "test-model"},
"config": {
"paramSize": "8b",
"architecture_config": {
"type": "transformer", "numLayers": 0, "hiddenSize": 4096, "numAttentionHeads": 32
}
},
"modelfs": {"type": "layers", "diffIds": ["sha256:abc"]}
}`,
wantErr: true,
},
// TODO: Add more cases for other required fields and constraints.
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
err := schema.ValidatorMediaTypeModelConfig.Validate(strings.NewReader(tc.config))
if (err != nil) != tc.wantErr {
t.Errorf("Validate() error = %v, wantErr %v", err, tc.wantErr)
}
})
}
}
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| } | |
| } | |
| func TestArchitectureConfigInvalidType(t *testing.T) { | |
| // Invalid type value that is not part of the allowed enum | |
| invalidJSON := `{ | |
| "descriptor": {"name": "test-model"}, | |
| "config": { | |
| "paramSize": "8b", | |
| "architecture_config": { | |
| "type": "invalid-type", | |
| "numLayers": 32, | |
| "hiddenSize": 4096, | |
| "numAttentionHeads": 32 | |
| } | |
| }, | |
| "modelfs": { | |
| "type": "layers", | |
| "diffIds": ["sha256:1234567890abcdef1234567890abcdef1234567890abcdef1234567890abcdef"] | |
| } | |
| }` | |
| err := schema.ValidatorMediaTypeModelConfig.Validate(strings.NewReader(invalidJSON)) | |
| if err == nil { | |
| t.Fatalf("expected architecture_config with invalid type to fail validation") | |
| } | |
| } | |
| func TestArchitectureConfigUnknownExtraField(t *testing.T) { | |
| // Unknown extra field under architecture_config should be rejected when additionalProperties is false | |
| invalidJSON := `{ | |
| "descriptor": {"name": "test-model"}, | |
| "config": { | |
| "paramSize": "8b", | |
| "architecture_config": { | |
| "type": "transformer", | |
| "numLayers": 32, | |
| "hiddenSize": 4096, | |
| "numAttentionHeads": 32, | |
| "unknownField": 123 | |
| } | |
| }, | |
| "modelfs": { | |
| "type": "layers", | |
| "diffIds": ["sha256:1234567890abcdef1234567890abcdef1234567890abcdef1234567890abcdef"] | |
| } | |
| }` | |
| err := schema.ValidatorMediaTypeModelConfig.Validate(strings.NewReader(invalidJSON)) | |
| if err == nil { | |
| t.Fatalf("expected architecture_config with unknown extra field to fail validation") | |
| } | |
| } |
| Original file line number | Diff line number | Diff line change | ||||
|---|---|---|---|---|---|---|
|
|
@@ -42,6 +42,24 @@ type ModelConfig struct { | |||||
|
|
||||||
| // Special capabilities that the model supports | ||||||
| Capabilities *ModelCapabilities `json:"capabilities,omitempty"` | ||||||
|
|
||||||
| // Architecture-specific configuration parameters | ||||||
| ArchitectureConfig *ArchitectureConfig `json:"architecture_config,omitempty"` | ||||||
|
||||||
| ArchitectureConfig *ArchitectureConfig `json:"architecture_config,omitempty"` | |
| ArchitectureConfig *ArchitectureConfig `json:"architectureConfig,omitempty"` |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
ModelConfigalready has anarchitecturefield, andarchitecture_configintroduces another architecture discriminator (type). As written, the schema allows contradictory values (e.g.,architecture: "cnn"alongsidearchitecture_config.type: "transformer"). Consider removing the nestedtypefield or adding a schema dependency so that whenarchitecture_configis present,architectureis required and constrained to"transformer"to keep configs internally consistent.