-
Notifications
You must be signed in to change notification settings - Fork 33
feat(tools): add HuggingFace config converter #197
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -589,3 +589,51 @@ func TestConfig(t *testing.T) { | |
| } | ||
| } | ||
| } | ||
|
|
||
| func TestArchitectureConfigValid(t *testing.T) { | ||
| validJSON := `{ | ||
| "descriptor": {"name": "test-model"}, | ||
| "config": { | ||
| "paramSize": "8b", | ||
| "architecture_config": { | ||
| "type": "transformer", | ||
| "numLayers": 32, | ||
| "hiddenSize": 4096, | ||
| "numAttentionHeads": 32 | ||
| } | ||
|
Comment on lines
+596
to
+603
|
||
| }, | ||
| "modelfs": { | ||
| "type": "layers", | ||
| "diffIds": ["sha256:1234567890abcdef1234567890abcdef1234567890abcdef1234567890abcdef"] | ||
| } | ||
| }` | ||
|
|
||
| err := schema.ValidatorMediaTypeModelConfig.Validate(strings.NewReader(validJSON)) | ||
| if err != nil { | ||
| t.Fatalf("expected valid architecture_config to pass, got error: %v", err) | ||
| } | ||
| } | ||
|
|
||
| func TestArchitectureConfigMissingRequiredField(t *testing.T) { | ||
| // Missing numLayers field | ||
| invalidJSON := `{ | ||
| "descriptor": {"name": "test-model"}, | ||
| "config": { | ||
| "paramSize": "8b", | ||
| "architecture_config": { | ||
| "type": "transformer", | ||
| "hiddenSize": 4096, | ||
| "numAttentionHeads": 32 | ||
| } | ||
| }, | ||
| "modelfs": { | ||
| "type": "layers", | ||
| "diffIds": ["sha256:1234567890abcdef1234567890abcdef1234567890abcdef1234567890abcdef"] | ||
| } | ||
| }` | ||
|
|
||
| err := schema.ValidatorMediaTypeModelConfig.Validate(strings.NewReader(invalidJSON)) | ||
| if err == nil { | ||
| t.Fatalf("expected architecture_config with missing numLayers to fail validation") | ||
| } | ||
| } | ||
| Original file line number | Diff line number | Diff line change | ||||
|---|---|---|---|---|---|---|
|
|
@@ -42,6 +42,24 @@ type ModelConfig struct { | |||||
|
|
||||||
| // Special capabilities that the model supports | ||||||
| Capabilities *ModelCapabilities `json:"capabilities,omitempty"` | ||||||
|
|
||||||
| // Architecture-specific configuration parameters | ||||||
| ArchitectureConfig *ArchitectureConfig `json:"architecture_config,omitempty"` | ||||||
|
||||||
| ArchitectureConfig *ArchitectureConfig `json:"architecture_config,omitempty"` | |
| ArchitectureConfig *ArchitectureConfig `json:"architectureConfig,omitempty"` |
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| @@ -0,0 +1,59 @@ | ||||||||||||||||||||
| #!/usr/bin/env python3 | ||||||||||||||||||||
| """Convert HuggingFace config.json to architecture_config format.""" | ||||||||||||||||||||
|
|
||||||||||||||||||||
| import json | ||||||||||||||||||||
| import sys | ||||||||||||||||||||
|
|
||||||||||||||||||||
|
|
||||||||||||||||||||
| REQUIRED_MAPPINGS = { | ||||||||||||||||||||
| "numLayers": "num_hidden_layers", | ||||||||||||||||||||
| "hiddenSize": "hidden_size", | ||||||||||||||||||||
| "numAttentionHeads": "num_attention_heads", | ||||||||||||||||||||
| } | ||||||||||||||||||||
|
|
||||||||||||||||||||
|
|
||||||||||||||||||||
| def convert_hf_config(hf_config: dict) -> dict: | ||||||||||||||||||||
| """Convert HuggingFace config to architecture_config format.""" | ||||||||||||||||||||
| arch_config = {"type": "transformer"} | ||||||||||||||||||||
|
|
||||||||||||||||||||
| for arch_key, hf_key in REQUIRED_MAPPINGS.items(): | ||||||||||||||||||||
| if hf_key not in hf_config: | ||||||||||||||||||||
| raise ValueError(f"missing required field: {hf_key}") | ||||||||||||||||||||
| value = hf_config[hf_key] | ||||||||||||||||||||
| if not isinstance(value, int) or isinstance(value, bool): | ||||||||||||||||||||
|
Comment on lines
+15
to
+23
|
||||||||||||||||||||
| raise ValueError(f"field {hf_key} must be an integer, got {type(value).__name__}") | ||||||||||||||||||||
| if value < 1: | ||||||||||||||||||||
| raise ValueError(f"field {hf_key} must be >= 1, got {value}") | ||||||||||||||||||||
| arch_config[arch_key] = value | ||||||||||||||||||||
|
|
||||||||||||||||||||
| return arch_config | ||||||||||||||||||||
|
|
||||||||||||||||||||
|
|
||||||||||||||||||||
| def main(): | ||||||||||||||||||||
| if len(sys.argv) != 2: | ||||||||||||||||||||
| print(f"usage: {sys.argv[0]} <config.json>", file=sys.stderr) | ||||||||||||||||||||
| sys.exit(1) | ||||||||||||||||||||
|
|
||||||||||||||||||||
| config_path = sys.argv[1] | ||||||||||||||||||||
|
Comment on lines
+33
to
+37
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. For more robust and maintainable command-line argument parsing, consider using the `argparse` module. You will need to add `import argparse` at the top of the file.
Suggested change
|
||||||||||||||||||||
|
|
||||||||||||||||||||
| try: | ||||||||||||||||||||
| with open(config_path, "r") as f: | ||||||||||||||||||||
| hf_config = json.load(f) | ||||||||||||||||||||
| except FileNotFoundError: | ||||||||||||||||||||
| print(f"error: file not found: {config_path}", file=sys.stderr) | ||||||||||||||||||||
| sys.exit(1) | ||||||||||||||||||||
| except json.JSONDecodeError as e: | ||||||||||||||||||||
| print(f"error: invalid JSON: {e}", file=sys.stderr) | ||||||||||||||||||||
| sys.exit(1) | ||||||||||||||||||||
|
|
||||||||||||||||||||
| try: | ||||||||||||||||||||
| arch_config = convert_hf_config(hf_config) | ||||||||||||||||||||
| except ValueError as e: | ||||||||||||||||||||
| print(f"error: {e}", file=sys.stderr) | ||||||||||||||||||||
| sys.exit(1) | ||||||||||||||||||||
|
|
||||||||||||||||||||
| print(json.dumps(arch_config, indent=2, sort_keys=True)) | ||||||||||||||||||||
|
|
||||||||||||||||||||
|
|
||||||||||||||||||||
| if __name__ == "__main__": | ||||||||||||||||||||
| main() | ||||||||||||||||||||
| Original file line number | Diff line number | Diff line change | ||||
|---|---|---|---|---|---|---|
| @@ -0,0 +1,146 @@ | ||||||
| #!/usr/bin/env python3 | ||||||
| """Tests for hf_to_arch.py""" | ||||||
|
|
||||||
| import json | ||||||
| import subprocess | ||||||
| import sys | ||||||
| import tempfile | ||||||
| import os | ||||||
|
|
||||||
| SCRIPT_PATH = os.path.join(os.path.dirname(__file__), "hf_to_arch.py") | ||||||
|
|
||||||
|
|
||||||
| def run_script(config_content: str) -> tuple: | ||||||
| """Run hf_to_arch.py with given config content, return (exitcode, stdout, stderr).""" | ||||||
| with tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False) as f: | ||||||
| f.write(config_content) | ||||||
| f.flush() | ||||||
| temp_path = f.name | ||||||
|
|
||||||
| try: | ||||||
| result = subprocess.run( | ||||||
| [sys.executable, SCRIPT_PATH, temp_path], | ||||||
| capture_output=True, | ||||||
| text=True, | ||||||
| ) | ||||||
| return result.returncode, result.stdout, result.stderr | ||||||
| finally: | ||||||
| os.unlink(temp_path) | ||||||
|
|
||||||
|
|
||||||
| def test_valid_config(): | ||||||
| """Valid HuggingFace config produces correct output.""" | ||||||
| config = json.dumps({ | ||||||
| "num_hidden_layers": 32, | ||||||
| "hidden_size": 4096, | ||||||
| "num_attention_heads": 32, | ||||||
| "vocab_size": 32000, | ||||||
| }) | ||||||
|
|
||||||
| exitcode, stdout, stderr = run_script(config) | ||||||
|
|
||||||
| assert exitcode == 0, f"expected exit 0, got {exitcode}: {stderr}" | ||||||
| output = json.loads(stdout) | ||||||
| assert output == { | ||||||
| "type": "transformer", | ||||||
| "numLayers": 32, | ||||||
| "hiddenSize": 4096, | ||||||
| "numAttentionHeads": 32, | ||||||
| }, f"unexpected output: {output}" | ||||||
| print("PASS: test_valid_config") | ||||||
|
|
||||||
|
|
||||||
| def test_missing_field(): | ||||||
| """Missing required field produces error.""" | ||||||
| config = json.dumps({ | ||||||
| "num_hidden_layers": 32, | ||||||
| "hidden_size": 4096, | ||||||
| }) | ||||||
|
|
||||||
| exitcode, stdout, stderr = run_script(config) | ||||||
|
|
||||||
| assert exitcode != 0, "expected non-zero exit for missing field" | ||||||
| assert "num_attention_heads" in stderr, f"error should mention missing field: {stderr}" | ||||||
| print("PASS: test_missing_field") | ||||||
|
|
||||||
|
|
||||||
| def test_invalid_json(): | ||||||
| """Invalid JSON produces error.""" | ||||||
| exitcode, stdout, stderr = run_script("not valid json {") | ||||||
|
|
||||||
| assert exitcode != 0, "expected non-zero exit for invalid JSON" | ||||||
| assert "invalid JSON" in stderr.lower() or "json" in stderr.lower(), f"error should mention JSON: {stderr}" | ||||||
|
||||||
| assert "invalid JSON" in stderr.lower() or "json" in stderr.lower(), f"error should mention JSON: {stderr}" | |
| assert "invalid json" in stderr.lower() or "json" in stderr.lower(), f"error should mention JSON: {stderr}" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
While this test script is functional, structuring it with Python's standard unittest framework would make it more robust and align with common testing practices. unittest provides a test runner, test discovery, and a richer set of assertion methods (e.g., self.assertEqual(), self.assertIn()). This would also remove the need for a custom main function to run tests and manual print statements for test status.
For example, you could structure your tests in a class:
```python
import unittest
# ... other imports

class HFToArchTests(unittest.TestCase):
    def test_valid_config(self):
        # ... test logic ...
        self.assertEqual(exitcode, 0, f"...")
    # ...

if __name__ == '__main__':
    unittest.main()
```

This approach is more maintainable and scalable as more tests are added.
Copilot
AI
Mar 26, 2026
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
These Python tests are not picked up by the repository's default make test (which currently runs go test ./... only). If these are meant to run in CI, consider integrating them into the test target/CI workflow (or converting to a Go test that shells out to the tool), otherwise the PR's "Tests included" claim may be misleading.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The new property name `architecture_config` introduces snake_case into the schema, while other ModelConfig fields use lowerCamelCase (e.g., `paramSize`, `docURL`, `diffIds`). Consider renaming this to `architectureConfig` (and updating the Go struct tags/tests/tools accordingly) to keep JSON field naming consistent.