diff --git a/command_setup.go b/command_setup.go index cac4a30314..fca499cccb 100644 --- a/command_setup.go +++ b/command_setup.go @@ -46,18 +46,33 @@ func (cmd *Command) setupDefaults(osArgs []string) { } if cmd.Reader == nil { - tracef("setting default Reader as os.Stdin (cmd=%[1]q)", cmd.Name) - cmd.Reader = os.Stdin + if cmd.parent != nil && cmd.parent.Reader != nil { + tracef("inheriting Reader from parent (cmd=%[1]q)", cmd.Name) + cmd.Reader = cmd.parent.Reader + } else { + tracef("setting default Reader as os.Stdin (cmd=%[1]q)", cmd.Name) + cmd.Reader = os.Stdin + } } if cmd.Writer == nil { - tracef("setting default Writer as os.Stdout (cmd=%[1]q)", cmd.Name) - cmd.Writer = os.Stdout + if cmd.parent != nil && cmd.parent.Writer != nil { + tracef("inheriting Writer from parent (cmd=%[1]q)", cmd.Name) + cmd.Writer = cmd.parent.Writer + } else { + tracef("setting default Writer as os.Stdout (cmd=%[1]q)", cmd.Name) + cmd.Writer = os.Stdout + } } if cmd.ErrWriter == nil { - tracef("setting default ErrWriter as os.Stderr (cmd=%[1]q)", cmd.Name) - cmd.ErrWriter = os.Stderr + if cmd.parent != nil && cmd.parent.ErrWriter != nil { + tracef("inheriting ErrWriter from parent (cmd=%[1]q)", cmd.Name) + cmd.ErrWriter = cmd.parent.ErrWriter + } else { + tracef("setting default ErrWriter as os.Stderr (cmd=%[1]q)", cmd.Name) + cmd.ErrWriter = os.Stderr + } } if cmd.AllowExtFlags { diff --git a/command_test.go b/command_test.go index 9d77c12f23..b8bd59c5a6 100644 --- a/command_test.go +++ b/command_test.go @@ -2833,6 +2833,32 @@ func TestSetupInitializesOnlyNilWriters(t *testing.T) { assert.Equal(t, cmd.Writer, os.Stdout, "expected a.Writer to be os.Stdout") } +// Regression for #2325. A Writer set on the root command should reach +// the subcommand's Action via c.Writer, not get silently replaced by +// os.Stdout the first time the subcommand runs setupDefaults. +func TestSubcommandInheritsRootWriters(t *testing.T) { + var out, errOut bytes.Buffer + root := &Command{ + Name: "demo", + Writer: &out, + ErrWriter: &errOut, + Commands: []*Command{ + { + Name: "sub", + Action: func(_ context.Context, c *Command) error { + _, _ = fmt.Fprintln(c.Writer, "from sub") + _, _ = fmt.Fprintln(c.ErrWriter, "errors from sub") + return nil + }, + }, + }, + } + + assert.NoError(t, root.Run(buildTestContext(t), []string{"demo", "sub"})) + assert.Equal(t, "from sub\n", out.String()) + assert.Equal(t, "errors from sub\n", errOut.String()) +} + func TestFlagAction(t *testing.T) { now := time.Now().UTC().Truncate(time.Minute) testCases := []struct {