Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 10 additions & 2 deletions pkg/distribution/format/gguf.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package format

import (
"fmt"
"math"
"regexp"
"strings"

Expand Down Expand Up @@ -54,14 +55,21 @@ func (g *GGUFFormat) ExtractConfig(paths []string) (types.Config, error) {
return types.Config{Format: types.FormatGGUF}, nil
}

return types.Config{
cfg := types.Config{
Format: types.FormatGGUF,
Parameters: normalizeUnitString(gguf.Metadata().Parameters.String()),
Architecture: strings.TrimSpace(gguf.Metadata().Architecture),
Quantization: strings.TrimSpace(gguf.Metadata().FileType.String()),
Size: normalizeUnitString(gguf.Metadata().Size.String()),
GGUF: extractGGUFMetadata(&gguf.Header),
}, nil
}

if ctx := gguf.Architecture().MaximumContextLength; ctx > 0 && ctx <= math.MaxInt32 {
ctxSize := int32(ctx)
cfg.ContextSize = &ctxSize
}

return cfg, nil
}

var (
Expand Down
53 changes: 51 additions & 2 deletions pkg/distribution/format/safetensors.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"encoding/json"
"fmt"
"io"
"math"
"os"
"path/filepath"
"regexp"
Expand Down Expand Up @@ -115,14 +116,62 @@ func (s *SafetensorsFormat) ExtractConfig(paths []string) (types.Config, error)
architecture = fmt.Sprintf("%v", arch)
}

return types.Config{
cfg := types.Config{
Format: types.FormatSafetensors,
Parameters: formatParameters(params),
Quantization: header.getQuantization(),
Size: formatSize(totalSize),
Architecture: architecture,
Safetensors: header.extractMetadata(),
}, nil
}

if ctx := readContextSizeFromConfigJSON(filepath.Dir(paths[0])); ctx != nil {
cfg.ContextSize = ctx
}

return cfg, nil
}

// contextSizeConfigKeys enumerates, highest priority first, the HuggingFace
// config.json fields that can carry a model's maximum context length. The
// names and their ordering follow TextModel.set_gguf_parameters in llama.cpp's
// convert_hf_to_gguf.py, the canonical HuggingFace-to-GGUF converter.
var contextSizeConfigKeys = []string{
	"max_position_embeddings", "n_ctx", "n_positions",
	"max_length", "max_sequence_length", "model_max_length",
}

func readContextSizeFromConfigJSON(dir string) *int32 {
data, err := os.ReadFile(filepath.Join(dir, "config.json"))
if err != nil {
return nil
}
Comment on lines +148 to +152
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

security-critical critical

The readContextSizeFromConfigJSON function reads the entire config.json file into memory without a size limit. While model configuration files are typically small, an attacker could provide a path to a very large file named config.json to cause an Out-Of-Memory (OOM) condition. For consistency and security, a size limit should be enforced, similar to the 100MB limit used in parseSafetensorsHeader at line 210.

Suggested change
func readContextSizeFromConfigJSON(dir string) *int32 {
data, err := os.ReadFile(filepath.Join(dir, "config.json"))
if err != nil {
return nil
}
func readContextSizeFromConfigJSON(dir string) *int32 {
f, err := os.Open(filepath.Join(dir, "config.json"))
if err != nil {
return nil
}
defer f.Close()
data, err := io.ReadAll(io.LimitReader(f, 10*1024*1024))
if err != nil {
return nil
}
References
  1. Repository Style Guide: Critical issues include security flaws and must be fixed before merge. (link)
  2. Security rule: Input must be validated at system boundaries to prevent resource exhaustion (DoS).


// Decode only the top-level object and keep every value as raw JSON, so a
// field with an unexpected type cannot fail the parse as a whole; each
// candidate key is decoded individually below.
var raw map[string]json.RawMessage
if err := json.Unmarshal(data, &raw); err != nil {
return nil
}

// Try the candidate keys in the order given by contextSizeConfigKeys; the
// first key that yields a usable value wins.
for _, key := range contextSizeConfigKeys {
v, ok := raw[key]
if !ok {
continue
}
// A value that does not decode into an int64 is skipped rather than treated
// as an error, so later keys still get a chance.
var n int64
if err := json.Unmarshal(v, &n); err != nil {
continue
}
// Skip non-positive values and values that would not fit in an int32.
if n <= 0 || n > math.MaxInt32 {
continue
}
ctx := int32(n)
return &ctx
}
// No candidate key held a usable value.
return nil
}

const (
Expand Down
107 changes: 107 additions & 0 deletions pkg/distribution/format/safetensors_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,3 +34,110 @@ func TestParseSafetensorsHeader_TruncatedFile(t *testing.T) {
t.Fatal("expected error for truncated safetensors header, got nil")
}
}

// TestReadContextSizeFromConfigJSON writes each case's contents to a
// config.json in a fresh temporary directory and checks the value that
// readContextSizeFromConfigJSON reports for that directory. The table covers
// every recognized key on its own, priority between keys, values that must be
// ignored (zero, negative, larger than int32), a non-numeric value that falls
// through to a later key, and malformed JSON.
func TestReadContextSizeFromConfigJSON(t *testing.T) {
	tests := []struct {
		name     string
		contents string
		expected *int32 // nil means no context size should be reported
	}{
		{
			name:     "max_position_embeddings",
			contents: `{"max_position_embeddings": 4096}`,
			expected: int32Ptr(4096),
		},
		{
			name:     "n_ctx fallback",
			contents: `{"n_ctx": 8192}`,
			expected: int32Ptr(8192),
		},
		{
			name:     "n_positions fallback",
			contents: `{"n_positions": 2048}`,
			expected: int32Ptr(2048),
		},
		{
			name:     "max_length fallback",
			contents: `{"max_length": 1024}`,
			expected: int32Ptr(1024),
		},
		{
			name:     "max_sequence_length fallback",
			contents: `{"max_sequence_length": 512}`,
			expected: int32Ptr(512),
		},
		{
			name:     "model_max_length fallback",
			contents: `{"model_max_length": 256}`,
			expected: int32Ptr(256),
		},
		{
			name:     "max_position_embeddings preferred over fallbacks",
			contents: `{"max_position_embeddings": 4096, "n_positions": 2048, "n_ctx": 1024}`,
			expected: int32Ptr(4096),
		},
		{
			name:     "n_ctx preferred over n_positions",
			contents: `{"n_ctx": 8192, "n_positions": 2048}`,
			expected: int32Ptr(8192),
		},
		{
			name:     "no recognized key",
			contents: `{"hidden_size": 768}`,
			expected: nil,
		},
		{
			name:     "zero value ignored",
			contents: `{"max_position_embeddings": 0}`,
			expected: nil,
		},
		{
			name:     "negative value ignored",
			contents: `{"max_position_embeddings": -1}`,
			expected: nil,
		},
		{
			name:     "value exceeding int32 ignored",
			contents: `{"max_position_embeddings": 9999999999}`,
			expected: nil,
		},
		{
			name:     "non-numeric value falls through",
			contents: `{"max_position_embeddings": "not-a-number", "n_positions": 512}`,
			expected: int32Ptr(512),
		},
		{
			name:     "malformed json",
			contents: `{not json}`,
			expected: nil,
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			tmpDir := t.TempDir()
			if err := os.WriteFile(filepath.Join(tmpDir, "config.json"), []byte(tt.contents), 0o644); err != nil {
				t.Fatalf("failed to write config.json: %v", err)
			}

			got := readContextSizeFromConfigJSON(tmpDir)

			// Report dereferenced values on every mismatch; formatting the
			// *int32 itself would only print a pointer address.
			switch {
			case got == nil && tt.expected == nil:
				// Both absent: nothing to compare.
			case got == nil:
				t.Fatalf("expected %d, got nil", *tt.expected)
			case tt.expected == nil:
				t.Fatalf("expected nil, got %d", *got)
			case *got != *tt.expected:
				t.Errorf("expected %d, got %d", *tt.expected, *got)
			}
		})
	}
}

// TestReadContextSizeFromConfigJSON_MissingFile checks that a directory with
// no config.json in it produces a nil result.
func TestReadContextSizeFromConfigJSON_MissingFile(t *testing.T) {
	// A freshly created temp dir is empty, so config.json cannot exist in it.
	emptyDir := t.TempDir()

	got := readContextSizeFromConfigJSON(emptyDir)
	if got == nil {
		return
	}
	t.Errorf("expected nil for missing config.json, got %d", *got)
}

// int32Ptr returns a pointer to a newly allocated int32 holding v. Each call
// yields a distinct pointer.
func int32Ptr(v int32) *int32 {
	p := new(int32)
	*p = v
	return p
}