diff --git a/cmd/invoke.go b/cmd/invoke.go index 9ff96d37fa..206dc27dd4 100644 --- a/cmd/invoke.go +++ b/cmd/invoke.go @@ -186,6 +186,10 @@ func runInvoke(cmd *cobra.Command, _ []string, newClient ClientFactory) (err err // Message to send the running function built from parameters gathered // from the user (or defaults) + exts, err := cfg.parseExtensions() + if err != nil { + return err + } m := fn.InvokeMessage{ ID: cfg.ID, Source: cfg.Source, @@ -194,7 +198,7 @@ func runInvoke(cmd *cobra.Command, _ []string, newClient ClientFactory) (err err RequestType: strings.ToUpper(cfg.RequestType), Data: cfg.Data, Format: cfg.Format, - Extensions: cfg.extensionsMap(), + Extensions: exts, } // If --file was specified, use its content for message data @@ -315,15 +319,16 @@ func newInvokeConfig() (cfg invokeConfig, err error) { return } -func (c invokeConfig) extensionsMap() map[string]string { - extensionsMap := make(map[string]string) +func (c invokeConfig) parseExtensions() (map[string]string, error) { + result := make(map[string]string) for _, ext := range c.Extensions { parts := strings.SplitN(ext, "=", 2) - if len(parts) == 2 { - extensionsMap[parts[0]] = parts[1] + if len(parts) != 2 || strings.TrimSpace(parts[0]) == "" { + return nil, fmt.Errorf("invalid --extension %q: must be in key=value format", ext) } + result[parts[0]] = parts[1] } - return extensionsMap + return result, nil } func (c invokeConfig) prompt() (invokeConfig, error) { diff --git a/cmd/invoke_test.go b/cmd/invoke_test.go index 7feae43f03..ce4fbb3f3c 100644 --- a/cmd/invoke_test.go +++ b/cmd/invoke_test.go @@ -78,3 +78,40 @@ func TestInvoke(t *testing.T) { t.Fatal("function was not invoked") } } + +// TestInvokeExtensionsMapValid ensures well-formed key=value extensions are +// parsed correctly, including values that themselves contain '='. +func TestInvokeExtensionsMapValid(t *testing.T) { + c := invokeConfig{Extensions: []string{"key=value", "foo=bar=baz"}} + m, err := c.parseExtensions() + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if m["key"] != "value" { + t.Fatalf("expected key=value, got %q", m["key"]) + } + if m["foo"] != "bar=baz" { + t.Fatalf("expected foo=bar=baz, got %q", m["foo"]) + } +} + +// TestInvokeExtensionsMapMalformed ensures that extension entries missing '=' +// or with an empty key return an error rather than being silently dropped. +func TestInvokeExtensionsMapMalformed(t *testing.T) { + cases := []struct { + name string + exts []string + }{ + {"missing equals", []string{"valid=ok", "badformat"}}, + {"empty key", []string{"valid=ok", "=nokey"}}, + } + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + c := invokeConfig{Extensions: tc.exts} + _, err := c.parseExtensions() + if err == nil { + t.Fatalf("expected error for %v, got nil", tc.exts) + } + }) + } +}