From 516f9edda0903d12c039c4a457ecd860a62cf9dd Mon Sep 17 00:00:00 2001 From: sathiraumesh Date: Sat, 18 Apr 2026 10:42:24 +0200 Subject: [PATCH] inital draft for extract context size --- pkg/distribution/format/gguf.go | 12 ++- pkg/distribution/format/safetensors.go | 53 +++++++++- pkg/distribution/format/safetensors_test.go | 107 ++++++++++++++++++++ 3 files changed, 168 insertions(+), 4 deletions(-) diff --git a/pkg/distribution/format/gguf.go b/pkg/distribution/format/gguf.go index 2d8260bb2..68c2cbf64 100644 --- a/pkg/distribution/format/gguf.go +++ b/pkg/distribution/format/gguf.go @@ -2,6 +2,7 @@ package format import ( "fmt" + "math" "regexp" "strings" @@ -54,14 +55,21 @@ func (g *GGUFFormat) ExtractConfig(paths []string) (types.Config, error) { return types.Config{Format: types.FormatGGUF}, nil } - return types.Config{ + cfg := types.Config{ Format: types.FormatGGUF, Parameters: normalizeUnitString(gguf.Metadata().Parameters.String()), Architecture: strings.TrimSpace(gguf.Metadata().Architecture), Quantization: strings.TrimSpace(gguf.Metadata().FileType.String()), Size: normalizeUnitString(gguf.Metadata().Size.String()), GGUF: extractGGUFMetadata(&gguf.Header), - }, nil + } + + if ctx := gguf.Architecture().MaximumContextLength; ctx > 0 && ctx <= math.MaxInt32 { + ctxSize := int32(ctx) + cfg.ContextSize = &ctxSize + } + + return cfg, nil } var ( diff --git a/pkg/distribution/format/safetensors.go b/pkg/distribution/format/safetensors.go index fa6563799..d6163fece 100644 --- a/pkg/distribution/format/safetensors.go +++ b/pkg/distribution/format/safetensors.go @@ -5,6 +5,7 @@ import ( "encoding/json" "fmt" "io" + "math" "os" "path/filepath" "regexp" @@ -115,14 +116,62 @@ func (s *SafetensorsFormat) ExtractConfig(paths []string) (types.Config, error) architecture = fmt.Sprintf("%v", arch) } - return types.Config{ + cfg := types.Config{ Format: types.FormatSafetensors, Parameters: formatParameters(params), Quantization: header.getQuantization(), Size: formatSize(totalSize), Architecture: architecture, Safetensors: header.extractMetadata(), - }, nil + } + + if ctx := readContextSizeFromConfigJSON(filepath.Dir(paths[0])); ctx != nil { + cfg.ContextSize = ctx + } + + return cfg, nil +} + +// contextSizeConfigKeys lists the HuggingFace config.json keys that may hold +// the model's maximum context length, in priority order. This mirrors +// llama.cpp's convert_hf_to_gguf.py (TextModel.set_gguf_parameters), which is +// the canonical HuggingFace-to-GGUF converter. +var contextSizeConfigKeys = []string{ + "max_position_embeddings", + "n_ctx", + "n_positions", + "max_length", + "max_sequence_length", + "model_max_length", +} + +func readContextSizeFromConfigJSON(dir string) *int32 { + data, err := os.ReadFile(filepath.Join(dir, "config.json")) + if err != nil { + return nil + } + + var raw map[string]json.RawMessage + if err := json.Unmarshal(data, &raw); err != nil { + return nil + } + + for _, key := range contextSizeConfigKeys { + v, ok := raw[key] + if !ok { + continue + } + var n int64 + if err := json.Unmarshal(v, &n); err != nil { + continue + } + if n <= 0 || n > math.MaxInt32 { + continue + } + ctx := int32(n) + return &ctx + } + return nil } const ( diff --git a/pkg/distribution/format/safetensors_test.go b/pkg/distribution/format/safetensors_test.go index fee404384..ade9e70ad 100644 --- a/pkg/distribution/format/safetensors_test.go +++ b/pkg/distribution/format/safetensors_test.go @@ -34,3 +34,110 @@ func TestParseSafetensorsHeader_TruncatedFile(t *testing.T) { t.Fatal("expected error for truncated safetensors header, got nil") } } + +func TestReadContextSizeFromConfigJSON(t *testing.T) { + tests := []struct { + name string + contents string + expected *int32 + }{ + { + name: "max_position_embeddings", + contents: `{"max_position_embeddings": 4096}`, + expected: int32Ptr(4096), + }, + { + name: "n_ctx fallback", + contents: `{"n_ctx": 8192}`, + expected: int32Ptr(8192), + }, + { + name: "n_positions fallback", + contents: `{"n_positions": 2048}`, + expected: int32Ptr(2048), + }, + { + name: "max_length fallback", + contents: `{"max_length": 1024}`, + expected: int32Ptr(1024), + }, + { + name: "max_sequence_length fallback", + contents: `{"max_sequence_length": 512}`, + expected: int32Ptr(512), + }, + { + name: "model_max_length fallback", + contents: `{"model_max_length": 256}`, + expected: int32Ptr(256), + }, + { + name: "max_position_embeddings preferred over fallbacks", + contents: `{"max_position_embeddings": 4096, "n_positions": 2048, "n_ctx": 1024}`, + expected: int32Ptr(4096), + }, + { + name: "n_ctx preferred over n_positions", + contents: `{"n_ctx": 8192, "n_positions": 2048}`, + expected: int32Ptr(8192), + }, + { + name: "no recognized key", + contents: `{"hidden_size": 768}`, + expected: nil, + }, + { + name: "zero value ignored", + contents: `{"max_position_embeddings": 0}`, + expected: nil, + }, + { + name: "negative value ignored", + contents: `{"max_position_embeddings": -1}`, + expected: nil, + }, + { + name: "value exceeding int32 ignored", + contents: `{"max_position_embeddings": 9999999999}`, + expected: nil, + }, + { + name: "non-numeric value falls through", + contents: `{"max_position_embeddings": "not-a-number", "n_positions": 512}`, + expected: int32Ptr(512), + }, + { + name: "malformed json", + contents: `{not json}`, + expected: nil, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tmpDir := t.TempDir() + if err := os.WriteFile(filepath.Join(tmpDir, "config.json"), []byte(tt.contents), 0o644); err != nil { + t.Fatalf("failed to write config.json: %v", err) + } + + got := readContextSizeFromConfigJSON(tmpDir) + if (got == nil) != (tt.expected == nil) { + t.Fatalf("expected nil=%v, got nil=%v (got value: %v)", tt.expected == nil, got == nil, got) + } + if got != nil && *got != *tt.expected { + t.Errorf("expected %d, got %d", *tt.expected, *got) + } + }) + } +} + +func TestReadContextSizeFromConfigJSON_MissingFile(t *testing.T) { + tmpDir := t.TempDir() + if got := readContextSizeFromConfigJSON(tmpDir); got != nil { + t.Errorf("expected nil for missing config.json, got %d", *got) + } +} + +func int32Ptr(v int32) *int32 { + return &v +}