Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
146 changes: 116 additions & 30 deletions cli/cmd/validate.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,15 @@ import (
"context"
"fmt"
"log/slog"
"os"
"strings"
"time"

"github.com/dreadnode/dreadgoad/internal/provider"
"github.com/dreadnode/dreadgoad/internal/validate"
"github.com/fatih/color"
"github.com/spf13/cobra"
"github.com/spf13/viper"
"golang.org/x/term"
)

var validateCmd = &cobra.Command{
Expand All @@ -25,7 +28,10 @@ LLMNR/NBT-NS, GPO abuse, gMSA, LAPS, and services.`,
dreadgoad validate --env staging --verbose
dreadgoad validate --format json --output /tmp/results.json
dreadgoad validate --no-fail
dreadgoad validate --quick`,
dreadgoad validate --quick
dreadgoad validate --plain # disable the live dashboard
dreadgoad validate --poll 5m # rerun every 5 minutes (minimum 1m)
dreadgoad validate --poll never # one-shot (default)`,
RunE: runValidate,
}

Expand All @@ -37,15 +43,87 @@ func init() {
validateCmd.Flags().Bool("verbose", false, "Enable verbose output")
validateCmd.Flags().Bool("no-fail", false, "Don't exit with error on failed checks")
validateCmd.Flags().Bool("quick", false, "Quick validation of critical vulnerabilities only")
validateCmd.Flags().Bool("plain", false, "Disable the live dashboard; stream results to stdout")
validateCmd.Flags().String("poll", "never", "Re-run cadence for the live dashboard (e.g. 1m, 5m, or 'never'; minimum 1m)")

if err := viper.BindPFlag("validate.poll", validateCmd.Flags().Lookup("poll")); err != nil {
panic(fmt.Sprintf("failed to bind validate.poll: %v", err))
}
}

// minPollInterval is the floor for --poll. Shorter cadences don't give the
// validation pass enough time to finish before the next iteration kicks in,
// leaving the dashboard perpetually mid-run.
const minPollInterval = time.Minute

// parsePollInterval interprets the --poll / validate.poll setting. The string
// "never" (case-insensitive) plus other off-style values disable polling and
// return 0; otherwise it must parse as a Go duration of at least
// minPollInterval.
func parsePollInterval(s string) (time.Duration, error) {
switch strings.ToLower(strings.TrimSpace(s)) {
case "", "never", "off", "no", "false", "0", "0s":
return 0, nil
}
d, err := time.ParseDuration(s)
if err != nil {
return 0, fmt.Errorf("invalid poll interval %q: %w (use a Go duration like 1m/5m, or 'never')", s, err)
}
if d < minPollInterval {
return 0, fmt.Errorf("poll interval must be at least %s: %s", minPollInterval, s)
}
return d, nil
}

type validateOpts struct {
verbose bool
outputPath string
noFail bool
quick bool
plain bool
pollInterval time.Duration
}

func validateOptsFromFlags(cmd *cobra.Command) (validateOpts, error) {
opts := validateOpts{}
opts.verbose, _ = cmd.Flags().GetBool("verbose")
opts.outputPath, _ = cmd.Flags().GetString("output")
opts.noFail, _ = cmd.Flags().GetBool("no-fail")
opts.quick, _ = cmd.Flags().GetBool("quick")
opts.plain, _ = cmd.Flags().GetBool("plain")

d, err := parsePollInterval(viper.GetString("validate.poll"))
if err != nil {
return opts, err
}
opts.pollInterval = d
return opts, nil
}

// makeRunChecks wraps the validator's check entry-points and any provider drain
// in a single function suitable for TUIConfig.Run / one-shot invocation.
func makeRunChecks(v *validate.Validator, p provider.Provider, quick bool) func(context.Context) {
return func(c context.Context) {
if quick {
v.RunQuickChecks(c)
} else {
v.RunAllChecks(c)
}
// Wait for any provider-side cleanup (e.g. Azure Run Command DELETEs)
// to finish so orphan subresources don't accumulate across runs.
if d, ok := p.(provider.Drainer); ok {
d.Drain()
}
}
}

func runValidate(cmd *cobra.Command, args []string) error {
ctx := context.Background()

verbose, _ := cmd.Flags().GetBool("verbose")
outputPath, _ := cmd.Flags().GetString("output")
noFail, _ := cmd.Flags().GetBool("no-fail")
quick, _ := cmd.Flags().GetBool("quick")
opts, err := validateOptsFromFlags(cmd)
if err != nil {
return err
}

fmt.Println("==========================================")
fmt.Println("GOAD Vulnerability Validation")
Expand All @@ -56,53 +134,61 @@ func runValidate(cmd *cobra.Command, args []string) error {
return err
}

useTUI := !opts.plain && term.IsTerminal(int(os.Stdout.Fd()))
if opts.pollInterval > 0 && !useTUI {
fmt.Fprintf(os.Stderr, "Warning: --poll is ignored without the live dashboard (TTY/--plain)\n")
opts.pollInterval = 0
}

fmt.Printf("Environment: %s\n", infra.Env)
fmt.Printf("Region: %s\n", infra.Region)

v := validate.NewValidator(infra.Provider, infra.Env, verbose, slog.Default(), infra.Lab)
v := validate.NewValidator(infra.Provider, infra.Env, opts.verbose, slog.Default(), infra.Lab)

if err := v.DiscoverHosts(ctx); err != nil {
return fmt.Errorf("discover hosts: %w", err)
}

if quick {
v.RunQuickChecks(ctx)
runChecks := makeRunChecks(v, infra.Provider, opts.quick)
runStart := time.Now()
if useTUI {
if err := validate.RunTUI(ctx, validate.TUIConfig{
Validator: v,
Env: infra.Env,
Region: infra.Region,
Run: runChecks,
PollInterval: opts.pollInterval,
}); err != nil {
return err
}
} else {
v.RunAllChecks(ctx)
}

// Wait for any provider-side cleanup (e.g. Azure Run Command DELETEs)
// to finish so orphan subresources don't accumulate across runs.
if d, ok := infra.Provider.(provider.Drainer); ok {
d.Drain()
runChecks(ctx)
}

report := v.GetReport()

outputPath := opts.outputPath
if outputPath == "" {
outputPath = fmt.Sprintf("/tmp/goad-validation-%s.json", time.Now().Format("20060102-150405"))
}
if err := v.SaveReport(outputPath); err != nil {
fmt.Printf("Warning: could not save report: %v\n", err)
}

fmt.Println("\n==========================================")
fmt.Println("Validation Summary")
fmt.Println("==========================================")
fmt.Printf("Total Checks: %d\n", report.Total)
color.Green("Passed: %d", report.Passed)
color.Red("Failed: %d", report.Failed)
color.Yellow("Warnings: %d", report.Warnings)

if report.Total > 0 {
pct := report.Passed * 100 / report.Total
fmt.Printf("\nSuccess Rate: %d%%\n", pct)
if !useTUI {
fmt.Println()
fmt.Println(validate.RenderSummaryPanel(report, infra.Env, infra.Region, time.Since(runStart), terminalWidth()))
}

fmt.Printf("\nResults saved to: %s\n", outputPath)

if !noFail && report.Failed > 0 {
if !opts.noFail && report.Failed > 0 {
return fmt.Errorf("validation failed with %d errors", report.Failed)
}
return nil
}

func terminalWidth() int {
if w, _, err := term.GetSize(int(os.Stdout.Fd())); err == nil && w > 0 {
return w
}
return 120
}
55 changes: 55 additions & 0 deletions cli/cmd/validate_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
package cmd

import (
"strings"
"testing"
"time"
)

func TestParsePollInterval(t *testing.T) {
tests := []struct {
name string
in string
want time.Duration
wantErr string
}{
{name: "empty disables polling", in: "", want: 0},
{name: "never disables polling", in: "never", want: 0},
{name: "NEVER case-insensitive", in: "NEVER", want: 0},
{name: "off disables polling", in: "off", want: 0},
{name: "zero disables polling", in: "0", want: 0},
{name: "0s disables polling", in: "0s", want: 0},
{name: "whitespace trimmed", in: " never ", want: 0},

{name: "exactly 1m allowed", in: "1m", want: time.Minute},
{name: "5m allowed", in: "5m", want: 5 * time.Minute},
{name: "60s equals 1m", in: "60s", want: time.Minute},
{name: "1h allowed", in: "1h", want: time.Hour},

{name: "30s rejected", in: "30s", wantErr: "at least 1m"},
{name: "59s rejected", in: "59s", wantErr: "at least 1m"},
{name: "negative rejected", in: "-1m", wantErr: "at least 1m"},
{name: "garbage rejected", in: "abc", wantErr: "invalid poll interval"},
}

for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
got, err := parsePollInterval(tc.in)
if tc.wantErr != "" {
if err == nil {
t.Fatalf("parsePollInterval(%q) = %v, want error containing %q", tc.in, got, tc.wantErr)
}
if !strings.Contains(err.Error(), tc.wantErr) {
t.Fatalf("parsePollInterval(%q) error = %q, want substring %q", tc.in, err.Error(), tc.wantErr)
}
return
}
if err != nil {
t.Fatalf("parsePollInterval(%q) unexpected error: %v", tc.in, err)
}
if got != tc.want {
t.Fatalf("parsePollInterval(%q) = %v, want %v", tc.in, got, tc.want)
}
})
}
}
2 changes: 1 addition & 1 deletion cli/go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ require (
go.yaml.in/yaml/v3 v3.0.4
golang.org/x/crypto v0.51.0
golang.org/x/net v0.54.0
golang.org/x/term v0.43.0
gopkg.in/yaml.v3 v3.0.1
)

Expand Down Expand Up @@ -108,7 +109,6 @@ require (
github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e // indirect
golang.org/x/sync v0.20.0 // indirect
golang.org/x/sys v0.44.0 // indirect
golang.org/x/term v0.43.0 // indirect
golang.org/x/text v0.37.0 // indirect
gotest.tools/v3 v3.5.2 // indirect
)
Loading
Loading