diff --git a/pkg/importinpututil/import_input.go b/pkg/importinpututil/import_input.go new file mode 100644 index 00000000000..7f4f0a16dda --- /dev/null +++ b/pkg/importinpututil/import_input.go @@ -0,0 +1,86 @@ +package importinpututil + +import ( + "encoding/json" + "fmt" + "reflect" + "sort" + "strings" +) + +// ResolvePathValue resolves either a top-level input key ("count") or a one-level +// dotted object sub-key ("config.apiKey") from import inputs. +func ResolvePathValue(inputs map[string]any, inputPath string) (any, bool) { + top, sub, hasDot := strings.Cut(inputPath, ".") + if !hasDot { + value, ok := inputs[top] + return value, ok + } + topVal, topOK := inputs[top] + if !topOK { + return nil, false + } + obj, isMap := topVal.(map[string]any) + if !isMap { + return nil, false + } + value, ok := obj[sub] + return value, ok +} + +// FormatResolvedValue formats a resolved import input value for textual +// substitution. []any/map[string]any and typed slices/maps are normalized and +// JSON-marshaled, nil returns ("", false), and scalars use fmt.Sprintf("%v", v). +func FormatResolvedValue(value any) (string, bool) { + switch v := value.(type) { + case []any: + return marshalValue(v) + case map[string]any: + return marshalValue(v) + case nil: + return "", false + default: + return formatReflectiveValue(v) + } +} + +func formatReflectiveValue(value any) (string, bool) { + rv := reflect.ValueOf(value) + switch rv.Kind() { + case reflect.Slice: + return marshalValue(normalizeSlice(rv)) + case reflect.Map: + return marshalValue(normalizeMap(rv)) + default: + return fmt.Sprintf("%v", value), true + } +} + +func marshalValue(value any) (string, bool) { + b, err := json.Marshal(value) + if err != nil { + return "", false + } + return string(b), true +} + +func normalizeSlice(rv reflect.Value) []any { + normalized := make([]any, rv.Len()) + for i := range rv.Len() { + normalized[i] = rv.Index(i).Interface() + } + return normalized +} + +func normalizeMap(rv reflect.Value) map[string]any { + keys := make([]string, 0, rv.Len()) + for _, key := range rv.MapKeys() { + keys = append(keys, key.String()) + } + sort.Strings(keys) + normalized := make(map[string]any, rv.Len()) + for _, k := range keys { + normalized[k] = rv.MapIndex(reflect.ValueOf(k)).Interface() + } + return normalized +} diff --git a/pkg/importinpututil/import_input_test.go b/pkg/importinpututil/import_input_test.go new file mode 100644 index 00000000000..3a8c14b7a19 --- /dev/null +++ b/pkg/importinpututil/import_input_test.go @@ -0,0 +1,67 @@ +package importinpututil + +import "testing" + +func TestResolvePathValue(t *testing.T) { + inputs := map[string]any{ + "name": "alice", + "config": map[string]any{ + "token": "abc", + }, + "bad": "not-map", + } + + tests := []struct { + name string + path string + want any + found bool + }{ + {name: "top level", path: "name", want: "alice", found: true}, + {name: "dotted path", path: "config.token", want: "abc", found: true}, + {name: "missing top level", path: "missing", found: false}, + {name: "missing dotted key", path: "config.missing", found: false}, + {name: "dotted non map", path: "bad.token", found: false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, ok := ResolvePathValue(inputs, tt.path) + if ok != tt.found { + t.Fatalf("ResolvePathValue(%q) found = %v, want %v", tt.path, ok, tt.found) + } + if ok && got != tt.want { + t.Fatalf("ResolvePathValue(%q) = %#v, want %#v", tt.path, got, tt.want) + } + }) + } +} + +func TestFormatResolvedValue(t *testing.T) { + tests := []struct { + name string + value any + want string + ok bool + }{ + {name: "scalar string", value: "hello", want: "hello", ok: true}, + {name: "scalar int", value: 42, want: "42", ok: true}, + {name: "slice any", value: []any{"a", 1}, want: `["a",1]`, ok: true}, + {name: "typed slice", value: []string{"x", "y"}, want: `["x","y"]`, ok: true}, + {name: "map any", value: map[string]any{"k": "v"}, want: `{"k":"v"}`, ok: true}, + {name: "typed map", value: map[string]string{"k": "v"}, want: `{"k":"v"}`, ok: true}, + {name: "nil value", value: nil, want: "", ok: false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, ok := FormatResolvedValue(tt.value) + if ok != tt.ok { + t.Fatalf("FormatResolvedValue(%#v) ok = %v, want %v", tt.value, ok, tt.ok) + } + if got != tt.want { + t.Fatalf("FormatResolvedValue(%#v) = %q, want %q", tt.value, got, tt.want) + } + }) + } +} diff --git a/pkg/parser/import_field_extractor.go b/pkg/parser/import_field_extractor.go index 193364a4619..95e1359dfe6 100644 --- a/pkg/parser/import_field_extractor.go +++ b/pkg/parser/import_field_extractor.go @@ -9,10 +9,10 @@ import ( "fmt" "maps" "path/filepath" - "reflect" "regexp" - "sort" "strings" + + "github.com/github/gh-aw/pkg/importinpututil" ) // importAccumulator centralizes the builder/slice/set variables used during @@ -1238,77 +1238,9 @@ func resolveImportInputPath(inputs map[string]any, inputPath string) (string, bo if !ok { return "", false } - return formatResolvedImportInputValue(value) + return importinpututil.FormatResolvedValue(value) } func resolveImportInputValue(inputs map[string]any, inputPath string) (any, bool) { - top, sub, hasDot := strings.Cut(inputPath, ".") - if !hasDot { - value, ok := inputs[top] - return value, ok - } - topVal, topOK := inputs[top] - if !topOK { - return nil, false - } - obj, isMap := topVal.(map[string]any) - if !isMap { - return nil, false - } - value, ok := obj[sub] - return value, ok -} - -func formatResolvedImportInputValue(value any) (string, bool) { - switch v := value.(type) { - case []any: - return marshalImportInputValue(v) - case map[string]any: - return marshalImportInputValue(v) - case nil: - return "", false - default: - return formatReflectiveImportInputValue(v) - } -} - -func formatReflectiveImportInputValue(value any) (string, bool) { - rv := reflect.ValueOf(value) - switch rv.Kind() { - case reflect.Slice: - return marshalImportInputValue(normalizeSliceForImportInput(rv)) - case reflect.Map: - return marshalImportInputValue(normalizeMapForImportInput(rv)) - default: - return fmt.Sprintf("%v", value), true - } -} - -func marshalImportInputValue(value any) (string, bool) { - b, err := json.Marshal(value) - if err != nil { - return "", false - } - return string(b), true -} - -func normalizeSliceForImportInput(rv reflect.Value) []any { - normalized := make([]any, rv.Len()) - for i := range rv.Len() { - normalized[i] = rv.Index(i).Interface() - } - return normalized -} - -func normalizeMapForImportInput(rv reflect.Value) map[string]any { - keys := make([]string, 0, rv.Len()) - for _, key := range rv.MapKeys() { - keys = append(keys, key.String()) - } - sort.Strings(keys) - normalized := make(map[string]any, rv.Len()) - for _, k := range keys { - normalized[k] = rv.MapIndex(reflect.ValueOf(k)).Interface() - } - return normalized + return importinpututil.ResolvePathValue(inputs, inputPath) } diff --git a/pkg/workflow/expression_extraction.go b/pkg/workflow/expression_extraction.go index afab29dd0d9..3492f4e10b8 100644 --- a/pkg/workflow/expression_extraction.go +++ b/pkg/workflow/expression_extraction.go @@ -3,15 +3,14 @@ package workflow import ( "crypto/sha256" "encoding/hex" - "encoding/json" "fmt" "os" - "reflect" "regexp" "sort" "strings" "github.com/github/gh-aw/pkg/console" + "github.com/github/gh-aw/pkg/importinpututil" "github.com/github/gh-aw/pkg/logger" ) @@ -503,45 +502,12 @@ func SubstituteImportInputs(content string, importInputs map[string]any) string // goccy/go-yaml may produce typed slices (e.g. []string) instead of []any, so // a reflection fallback converts any slice kind to []any before JSON marshaling. func marshalImportInputValue(value any) string { - switch v := value.(type) { - case []any: - if b, err := json.Marshal(v); err == nil { - return string(b) - } - case map[string]any: - if b, err := json.Marshal(v); err == nil { - return string(b) - } - case nil: + if formatted, ok := importinpututil.FormatResolvedValue(value); ok { + return formatted + } + if value == nil { // Null import input — return empty string rather than panicking. return "" - default: - // Handle typed slices (e.g. []string) that goccy/go-yaml may produce - // instead of []any, and typed maps. - rv := reflect.ValueOf(v) - switch rv.Kind() { - case reflect.Slice: - normalized := make([]any, rv.Len()) - for i := range rv.Len() { - normalized[i] = rv.Index(i).Interface() - } - if b, err := json.Marshal(normalized); err == nil { - return string(b) - } - case reflect.Map: - keys := make([]string, 0, rv.Len()) - for _, key := range rv.MapKeys() { - keys = append(keys, key.String()) - } - sort.Strings(keys) - normalized := make(map[string]any, rv.Len()) - for _, k := range keys { - normalized[k] = rv.MapIndex(reflect.ValueOf(k)).Interface() - } - if b, err := json.Marshal(normalized); err == nil { - return string(b) - } - } } return fmt.Sprintf("%v", value) } @@ -552,20 +518,5 @@ func marshalImportInputValue(value any) string { // supporting one level of nesting as defined by import-schema object types. // Returns the resolved value and true on success, or nil and false when the path is not found. func resolveImportInputPath(importInputs map[string]any, path string) (any, bool) { - topKey, subKey, hasDot := strings.Cut(path, ".") - if !hasDot { - // Scalar: direct lookup - value, ok := importInputs[topKey] - return value, ok - } - // Object sub-key: one-level deep lookup - topValue, ok := importInputs[topKey] - if !ok { - return nil, false - } - if obj, ok := topValue.(map[string]any); ok { - value, ok := obj[subKey] - return value, ok - } - return nil, false + return importinpututil.ResolvePathValue(importInputs, path) }