From ed59598bfdacbf9605c748d99b1d3e818148b5cb Mon Sep 17 00:00:00 2001 From: Jayson Grace Date: Sun, 10 May 2026 15:04:19 -0600 Subject: [PATCH] feat: add live dashboard TUI and polling for validate command **Added:** - Introduced `--plain` flag to disable live dashboard and stream results to stdout in validate command - Added `--poll` flag to rerun validation on a configurable cadence in validate command - Implemented terminal TUI dashboard for validation runs using Bubbletea and Lipgloss, including real-time category/status breakdown and polling support - `internal/validate/tui.go` - Added static summary panel rendering for post-run summary output - `internal/validate/panel.go` - Created table-driven tests for poll interval parsing and dashboard rendering - `cli/cmd/validate_test.go`, `internal/validate/panel_test.go`, `internal/validate/tui_test.go` **Changed:** - Refactored validate command to support TUI dashboard, polling, and new CLI flags - Replaced direct color output with summary panel rendering when not using TUI - Updated output logic to separate TUI/console flows, including warnings for ignored flags **Removed:** - Removed dependency on `github.com/fatih/color` in favor of Lipgloss styling for output - Eliminated redundant/legacy color output logic from validate summary --- cli/cmd/validate.go | 146 ++++++++-- cli/cmd/validate_test.go | 55 ++++ cli/go.mod | 2 +- cli/internal/validate/panel.go | 302 ++++++++++++++++++++ cli/internal/validate/panel_test.go | 105 +++++++ cli/internal/validate/tui.go | 409 ++++++++++++++++++++++++++++ cli/internal/validate/tui_test.go | 89 ++++++ cli/internal/validate/validator.go | 77 ++++++ 8 files changed, 1154 insertions(+), 31 deletions(-) create mode 100644 cli/cmd/validate_test.go create mode 100644 cli/internal/validate/panel.go create mode 100644 cli/internal/validate/panel_test.go create mode 100644 cli/internal/validate/tui.go create mode 100644 cli/internal/validate/tui_test.go diff --git a/cli/cmd/validate.go b/cli/cmd/validate.go index e7f10f13..5bbb94dd 100644 --- a/cli/cmd/validate.go +++ b/cli/cmd/validate.go @@ -4,12 +4,15 @@ import ( "context" "fmt" "log/slog" + "os" + "strings" "time" "github.com/dreadnode/dreadgoad/internal/provider" "github.com/dreadnode/dreadgoad/internal/validate" - "github.com/fatih/color" "github.com/spf13/cobra" + "github.com/spf13/viper" + "golang.org/x/term" ) var validateCmd = &cobra.Command{ @@ -25,7 +28,10 @@ LLMNR/NBT-NS, GPO abuse, gMSA, LAPS, and services.`, dreadgoad validate --env staging --verbose dreadgoad validate --format json --output /tmp/results.json dreadgoad validate --no-fail - dreadgoad validate --quick`, + dreadgoad validate --quick + dreadgoad validate --plain # disable the live dashboard + dreadgoad validate --poll 5m # rerun every 5 minutes (minimum 1m) + dreadgoad validate --poll never # one-shot (default)`, RunE: runValidate, } @@ -37,15 +43,87 @@ func init() { validateCmd.Flags().Bool("verbose", false, "Enable verbose output") validateCmd.Flags().Bool("no-fail", false, "Don't exit with error on failed checks") validateCmd.Flags().Bool("quick", false, "Quick validation of critical vulnerabilities only") + validateCmd.Flags().Bool("plain", false, "Disable the live dashboard; stream results to stdout") + validateCmd.Flags().String("poll", "never", "Re-run cadence for the live dashboard (e.g. 1m, 5m, or 'never'; minimum 1m)") + + if err := viper.BindPFlag("validate.poll", validateCmd.Flags().Lookup("poll")); err != nil { + panic(fmt.Sprintf("failed to bind validate.poll: %v", err)) + } +} + +// minPollInterval is the floor for --poll. Shorter cadences don't give the +// validation pass enough time to finish before the next iteration kicks in, +// leaving the dashboard perpetually mid-run. +const minPollInterval = time.Minute + +// parsePollInterval interprets the --poll / validate.poll setting. The string +// "never" (case-insensitive) plus other off-style values disable polling and +// return 0; otherwise it must parse as a Go duration of at least +// minPollInterval. +func parsePollInterval(s string) (time.Duration, error) { + switch strings.ToLower(strings.TrimSpace(s)) { + case "", "never", "off", "no", "false", "0", "0s": + return 0, nil + } + d, err := time.ParseDuration(s) + if err != nil { + return 0, fmt.Errorf("invalid poll interval %q: %w (use a Go duration like 1m/5m, or 'never')", s, err) + } + if d < minPollInterval { + return 0, fmt.Errorf("poll interval must be at least %s: %s", minPollInterval, s) + } + return d, nil +} + +type validateOpts struct { + verbose bool + outputPath string + noFail bool + quick bool + plain bool + pollInterval time.Duration +} + +func validateOptsFromFlags(cmd *cobra.Command) (validateOpts, error) { + opts := validateOpts{} + opts.verbose, _ = cmd.Flags().GetBool("verbose") + opts.outputPath, _ = cmd.Flags().GetString("output") + opts.noFail, _ = cmd.Flags().GetBool("no-fail") + opts.quick, _ = cmd.Flags().GetBool("quick") + opts.plain, _ = cmd.Flags().GetBool("plain") + + d, err := parsePollInterval(viper.GetString("validate.poll")) + if err != nil { + return opts, err + } + opts.pollInterval = d + return opts, nil +} + +// makeRunChecks wraps the validator's check entry-points and any provider drain +// in a single function suitable for TUIConfig.Run / one-shot invocation. +func makeRunChecks(v *validate.Validator, p provider.Provider, quick bool) func(context.Context) { + return func(c context.Context) { + if quick { + v.RunQuickChecks(c) + } else { + v.RunAllChecks(c) + } + // Wait for any provider-side cleanup (e.g. Azure Run Command DELETEs) + // to finish so orphan subresources don't accumulate across runs. + if d, ok := p.(provider.Drainer); ok { + d.Drain() + } + } } func runValidate(cmd *cobra.Command, args []string) error { ctx := context.Background() - verbose, _ := cmd.Flags().GetBool("verbose") - outputPath, _ := cmd.Flags().GetString("output") - noFail, _ := cmd.Flags().GetBool("no-fail") - quick, _ := cmd.Flags().GetBool("quick") + opts, err := validateOptsFromFlags(cmd) + if err != nil { + return err + } fmt.Println("==========================================") fmt.Println("GOAD Vulnerability Validation") @@ -56,29 +134,39 @@ func runValidate(cmd *cobra.Command, args []string) error { return err } + useTUI := !opts.plain && term.IsTerminal(int(os.Stdout.Fd())) + if opts.pollInterval > 0 && !useTUI { + fmt.Fprintf(os.Stderr, "Warning: --poll is ignored without the live dashboard (TTY/--plain)\n") + opts.pollInterval = 0 + } + fmt.Printf("Environment: %s\n", infra.Env) fmt.Printf("Region: %s\n", infra.Region) - v := validate.NewValidator(infra.Provider, infra.Env, verbose, slog.Default(), infra.Lab) + v := validate.NewValidator(infra.Provider, infra.Env, opts.verbose, slog.Default(), infra.Lab) if err := v.DiscoverHosts(ctx); err != nil { return fmt.Errorf("discover hosts: %w", err) } - if quick { - v.RunQuickChecks(ctx) + runChecks := makeRunChecks(v, infra.Provider, opts.quick) + runStart := time.Now() + if useTUI { + if err := validate.RunTUI(ctx, validate.TUIConfig{ + Validator: v, + Env: infra.Env, + Region: infra.Region, + Run: runChecks, + PollInterval: opts.pollInterval, + }); err != nil { + return err + } } else { - v.RunAllChecks(ctx) - } - - // Wait for any provider-side cleanup (e.g. Azure Run Command DELETEs) - // to finish so orphan subresources don't accumulate across runs. - if d, ok := infra.Provider.(provider.Drainer); ok { - d.Drain() + runChecks(ctx) } report := v.GetReport() - + outputPath := opts.outputPath if outputPath == "" { outputPath = fmt.Sprintf("/tmp/goad-validation-%s.json", time.Now().Format("20060102-150405")) } @@ -86,23 +174,21 @@ func runValidate(cmd *cobra.Command, args []string) error { fmt.Printf("Warning: could not save report: %v\n", err) } - fmt.Println("\n==========================================") - fmt.Println("Validation Summary") - fmt.Println("==========================================") - fmt.Printf("Total Checks: %d\n", report.Total) - color.Green("Passed: %d", report.Passed) - color.Red("Failed: %d", report.Failed) - color.Yellow("Warnings: %d", report.Warnings) - - if report.Total > 0 { - pct := report.Passed * 100 / report.Total - fmt.Printf("\nSuccess Rate: %d%%\n", pct) + if !useTUI { + fmt.Println() + fmt.Println(validate.RenderSummaryPanel(report, infra.Env, infra.Region, time.Since(runStart), terminalWidth())) } - fmt.Printf("\nResults saved to: %s\n", outputPath) - if !noFail && report.Failed > 0 { + if !opts.noFail && report.Failed > 0 { return fmt.Errorf("validation failed with %d errors", report.Failed) } return nil } + +func terminalWidth() int { + if w, _, err := term.GetSize(int(os.Stdout.Fd())); err == nil && w > 0 { + return w + } + return 120 +} diff --git a/cli/cmd/validate_test.go b/cli/cmd/validate_test.go new file mode 100644 index 00000000..f08b4dac --- /dev/null +++ b/cli/cmd/validate_test.go @@ -0,0 +1,55 @@ +package cmd + +import ( + "strings" + "testing" + "time" +) + +func TestParsePollInterval(t *testing.T) { + tests := []struct { + name string + in string + want time.Duration + wantErr string + }{ + {name: "empty disables polling", in: "", want: 0}, + {name: "never disables polling", in: "never", want: 0}, + {name: "NEVER case-insensitive", in: "NEVER", want: 0}, + {name: "off disables polling", in: "off", want: 0}, + {name: "zero disables polling", in: "0", want: 0}, + {name: "0s disables polling", in: "0s", want: 0}, + {name: "whitespace trimmed", in: " never ", want: 0}, + + {name: "exactly 1m allowed", in: "1m", want: time.Minute}, + {name: "5m allowed", in: "5m", want: 5 * time.Minute}, + {name: "60s equals 1m", in: "60s", want: time.Minute}, + {name: "1h allowed", in: "1h", want: time.Hour}, + + {name: "30s rejected", in: "30s", wantErr: "at least 1m"}, + {name: "59s rejected", in: "59s", wantErr: "at least 1m"}, + {name: "negative rejected", in: "-1m", wantErr: "at least 1m"}, + {name: "garbage rejected", in: "abc", wantErr: "invalid poll interval"}, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + got, err := parsePollInterval(tc.in) + if tc.wantErr != "" { + if err == nil { + t.Fatalf("parsePollInterval(%q) = %v, want error containing %q", tc.in, got, tc.wantErr) + } + if !strings.Contains(err.Error(), tc.wantErr) { + t.Fatalf("parsePollInterval(%q) error = %q, want substring %q", tc.in, err.Error(), tc.wantErr) + } + return + } + if err != nil { + t.Fatalf("parsePollInterval(%q) unexpected error: %v", tc.in, err) + } + if got != tc.want { + t.Fatalf("parsePollInterval(%q) = %v, want %v", tc.in, got, tc.want) + } + }) + } +} diff --git a/cli/go.mod b/cli/go.mod index 2d846540..404acbfd 100644 --- a/cli/go.mod +++ b/cli/go.mod @@ -24,6 +24,7 @@ require ( go.yaml.in/yaml/v3 v3.0.4 golang.org/x/crypto v0.51.0 golang.org/x/net v0.54.0 + golang.org/x/term v0.43.0 gopkg.in/yaml.v3 v3.0.1 ) @@ -108,7 +109,6 @@ require ( github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e // indirect golang.org/x/sync v0.20.0 // indirect golang.org/x/sys v0.44.0 // indirect - golang.org/x/term v0.43.0 // indirect golang.org/x/text v0.37.0 // indirect gotest.tools/v3 v3.5.2 // indirect ) diff --git a/cli/internal/validate/panel.go b/cli/internal/validate/panel.go new file mode 100644 index 00000000..f7382875 --- /dev/null +++ b/cli/internal/validate/panel.go @@ -0,0 +1,302 @@ +package validate + +import ( + "fmt" + "sort" + "strings" + "time" + + "github.com/charmbracelet/lipgloss" +) + +// Dreadnode color palette (mirrors cli/internal/scoreboard/tui.go). +const ( + cSuccess = "#68c147" + cError = "#e44f4f" + cWarning = "#c8ac4a" + cInfo = "#4689bf" + cBrand = "#ca5e44" + cFG = "#e2e7ec" + cFGMuted = "#9da0a5" + cFGFaintest = "#686d73" +) + +var ( + styleTitle = lipgloss.NewStyle().Foreground(lipgloss.Color(cBrand)).Bold(true) + styleBorder = lipgloss.NewStyle().Foreground(lipgloss.Color(cBrand)) + styleGroupHdr = lipgloss.NewStyle().Foreground(lipgloss.Color(cBrand)).Bold(true) + styleSep = lipgloss.NewStyle().Foreground(lipgloss.Color(cFGFaintest)) + styleMuted = lipgloss.NewStyle().Foreground(lipgloss.Color(cFGMuted)) + styleFaint = lipgloss.NewStyle().Foreground(lipgloss.Color(cFGFaintest)) + styleFG = lipgloss.NewStyle().Foreground(lipgloss.Color(cFG)) + styleOK = lipgloss.NewStyle().Foreground(lipgloss.Color(cSuccess)).Bold(true) + styleWarn = lipgloss.NewStyle().Foreground(lipgloss.Color(cWarning)).Bold(true) + styleErr = lipgloss.NewStyle().Foreground(lipgloss.Color(cError)).Bold(true) + styleInfo = lipgloss.NewStyle().Foreground(lipgloss.Color(cInfo)).Bold(true) +) + +type categoryStats struct { + name string + pass int + fail int + warn int + total int // pass + fail + warn (skip/info excluded from totals) + others int // skip, info, other +} + +// RenderSummaryPanel returns a static status board for a completed validation +// run, styled to match the live scoreboard. +func RenderSummaryPanel(r *Report, env, region string, elapsed time.Duration, width int) string { + if width <= 0 { + width = 120 + } + innerWidth := width - 4 + if innerWidth < 40 { + innerWidth = 40 + } + + header := renderValidateHeader(r, env, region, elapsed, innerWidth) + + cats := aggregateByCategory(r) + colWidth := (innerWidth - 2) / 2 + if colWidth < 30 { + colWidth = 30 + } + left, right := splitForColumns(cats) + leftCol := renderCategoryColumn(left, colWidth) + rightCol := renderCategoryColumn(right, colWidth) + cols := lipgloss.JoinHorizontal(lipgloss.Top, leftCol, " ", rightCol) + + parts := []string{ + header, + "", + styleGroupHdr.Render(fmt.Sprintf(" CHECK RESULTS (%d/%d)", r.Passed, totalForRate(r))), + "", + cols, + } + if rate, ok := successRate(r); ok { + parts = append(parts, "", styleMuted.Render(fmt.Sprintf(" Success rate: %d%%", rate))) + } + + return panelWithTitle("DreadGOAD VALIDATION", strings.Join(parts, "\n"), width) +} + +func renderValidateHeader(r *Report, env, region string, elapsed time.Duration, width int) string { + left := strings.Builder{} + writeStat := func(label string, n int, valStyle lipgloss.Style, first bool) { + if !first { + left.WriteString(styleSep.Render(" | ")) + } + left.WriteString(styleGroupHdr.Render(label + " ")) + left.WriteString(valStyle.Render(fmt.Sprintf("%d", n))) + } + writeStat("PASSED", r.Passed, styleOK, true) + writeStat("FAILED", r.Failed, styleErr, false) + writeStat("WARNED", r.Warnings, styleWarn, false) + writeStat("TOTAL", totalForRate(r), styleInfo, false) + + rightParts := []string{} + if env != "" { + rightParts = append(rightParts, env) + } + if region != "" { + rightParts = append(rightParts, region) + } + if elapsed > 0 { + rightParts = append(rightParts, formatElapsed(elapsed)) + } + right := styleMuted.Render(strings.Join(rightParts, " | ")) + + leftStr := left.String() + pad := width - lipgloss.Width(leftStr) - lipgloss.Width(right) + if pad < 1 { + pad = 1 + } + return leftStr + strings.Repeat(" ", pad) + right +} + +func aggregateByCategory(r *Report) []categoryStats { + if r == nil { + return nil + } + idx := map[string]*categoryStats{} + for _, res := range r.Results { + c, ok := idx[res.Category] + if !ok { + c = &categoryStats{name: res.Category} + idx[res.Category] = c + } + switch res.Status { + case "PASS": + c.pass++ + c.total++ + case "FAIL": + c.fail++ + c.total++ + case "WARN": + c.warn++ + c.total++ + default: + c.others++ + } + } + out := make([]categoryStats, 0, len(idx)) + for _, c := range idx { + out = append(out, *c) + } + sort.Slice(out, func(i, j int) bool { return out[i].name < out[j].name }) + return out +} + +func splitForColumns(cats []categoryStats) ([]categoryStats, []categoryStats) { + if len(cats) == 0 { + return nil, nil + } + mid := (len(cats) + 1) / 2 + return cats[:mid], cats[mid:] +} + +func renderCategoryColumn(cats []categoryStats, width int) string { + if len(cats) == 0 { + return "" + } + iconWidth := 4 + countsWidth := 8 + detailWidth := 10 + nameWidth := width - iconWidth - countsWidth - detailWidth - 2 + if nameWidth < 10 { + nameWidth = 10 + } + + rows := make([]string, 0, len(cats)) + for _, c := range cats { + var iconCell string + switch { + case c.fail > 0: + iconCell = styleErr.Render("[x] ") + case c.warn > 0: + iconCell = styleWarn.Render("[!] ") + case c.total > 0: + iconCell = styleOK.Render("[v] ") + default: + iconCell = styleFaint.Render("[ ] ") + } + + var nameCell string + switch { + case c.fail > 0, c.total > 0: + nameCell = styleFG.Render(truncate(c.name, nameWidth)) + default: + nameCell = styleFaint.Render(truncate(c.name, nameWidth)) + } + nameCell = padRight(nameCell, nameWidth) + + counts := fmt.Sprintf("%d/%d", c.pass, c.total) + countsCell := padRight(styleMuted.Render(counts), countsWidth) + + var detail string + switch { + case c.fail > 0 && c.warn > 0: + detail = fmt.Sprintf("x%d !%d", c.fail, c.warn) + case c.fail > 0: + detail = fmt.Sprintf("x%d", c.fail) + case c.warn > 0: + detail = fmt.Sprintf("!%d", c.warn) + } + var detailCell string + switch { + case c.fail > 0: + detailCell = styleErr.Render(detail) + case c.warn > 0: + detailCell = styleWarn.Render(detail) + } + detailCell = padRight(detailCell, detailWidth) + + rows = append(rows, " "+iconCell+nameCell+countsCell+detailCell) + } + return strings.Join(rows, "\n") +} + +func totalForRate(r *Report) int { + if r == nil { + return 0 + } + return r.Passed + r.Failed + r.Warnings +} + +func successRate(r *Report) (int, bool) { + t := totalForRate(r) + if t == 0 { + return 0, false + } + return r.Passed * 100 / t, true +} + +func formatElapsed(d time.Duration) string { + if d < 0 { + d = 0 + } + h := int(d.Hours()) + m := int(d.Minutes()) % 60 + s := int(d.Seconds()) % 60 + return fmt.Sprintf("%d:%02d:%02d", h, m, s) +} + +// panelWithTitle frames body in a rounded border with title embedded in the +// top edge. Mirrors the scoreboard implementation so the validate summary and +// scoreboard share a consistent visual frame. +func panelWithTitle(title, body string, width int) string { + innerWidth := width - 4 + if innerWidth < 1 { + innerWidth = 1 + } + + titleText := " " + title + " " + titleVis := lipgloss.Width(titleText) + leadDashes := 2 + trailDashes := innerWidth + 2 - leadDashes - titleVis + if trailDashes < 1 { + trailDashes = 1 + } + top := styleBorder.Render("╭"+strings.Repeat("─", leadDashes)) + + styleTitle.Render(titleText) + + styleBorder.Render(strings.Repeat("─", trailDashes)+"╮") + bottom := styleBorder.Render("╰" + strings.Repeat("─", innerWidth+2) + "╯") + + var rows []string + rows = append(rows, top) + for _, line := range strings.Split(body, "\n") { + pad := innerWidth - lipgloss.Width(line) + if pad < 0 { + line = truncate(line, innerWidth) + pad = 0 + } + rows = append(rows, styleBorder.Render("│")+" "+line+strings.Repeat(" ", pad)+" "+styleBorder.Render("│")) + } + rows = append(rows, bottom) + return strings.Join(rows, "\n") +} + +func padRight(s string, w int) string { + pad := w - lipgloss.Width(s) + if pad <= 0 { + return s + } + return s + strings.Repeat(" ", pad) +} + +func truncate(s string, w int) string { + if w <= 0 { + return "" + } + if lipgloss.Width(s) <= w { + return s + } + if w <= 1 { + return s[:1] + } + if w > len(s) { + return s + } + return s[:w-1] + "…" +} diff --git a/cli/internal/validate/panel_test.go b/cli/internal/validate/panel_test.go new file mode 100644 index 00000000..9a8dd5d1 --- /dev/null +++ b/cli/internal/validate/panel_test.go @@ -0,0 +1,105 @@ +package validate + +import ( + "strings" + "testing" + "time" +) + +func sampleReport() *Report { + rs := []Result{ + {Status: "PASS", Category: "Discovery", Name: "Found DC01"}, + {Status: "PASS", Category: "Discovery", Name: "Found SRV01"}, + {Status: "PASS", Category: "Credentials", Name: "no plaintext creds in sysvol"}, + {Status: "FAIL", Category: "Credentials", Name: "autologon registry"}, + {Status: "PASS", Category: "Kerberos", Name: "kerberoastable spn"}, + {Status: "WARN", Category: "Kerberos", Name: "asrep roastable account"}, + {Status: "PASS", Category: "MSSQL", Name: "linked server"}, + {Status: "PASS", Category: "MSSQL", Name: "xp_cmdshell"}, + {Status: "FAIL", Category: "MSSQL", Name: "sysadmin role"}, + {Status: "PASS", Category: "ADCS", Name: "vulnerable template enabled"}, + {Status: "FAIL", Category: "ADCS-ESC1", Name: "ESC1 template"}, + {Status: "WARN", Category: "ADCS-ESC6", Name: "EDITF flag"}, + {Status: "PASS", Category: "ACL", Name: "WriteOwner on Domain Admins"}, + {Status: "PASS", Category: "Trusts", Name: "forest trust exists"}, + {Status: "PASS", Category: "Services", Name: "unquoted service path"}, + {Status: "WARN", Category: "LLMNR", Name: "LLMNR enabled"}, + {Status: "PASS", Category: "Shares", Name: "open share"}, + {Status: "PASS", Category: "GPO", Name: "GPO abuse path"}, + {Status: "PASS", Category: "gMSA", Name: "gMSA readable"}, + {Status: "PASS", Category: "LAPS", Name: "LAPS readable by group"}, + } + r := &Report{Results: rs} + for _, res := range rs { + switch res.Status { + case "PASS": + r.Passed++ + case "FAIL": + r.Failed++ + case "WARN": + r.Warnings++ + } + } + r.Total = r.Passed + r.Failed + r.Warnings + return r +} + +func TestRenderSummaryPanel_Snapshot(t *testing.T) { + r := sampleReport() + out := RenderSummaryPanel(r, "prod-east", "us-east-1", 8*time.Minute+45*time.Second, 120) + + for _, want := range []string{ + "DreadGOAD VALIDATION", + "PASSED", + "FAILED", + "WARNED", + "TOTAL", + "Credentials", + "MSSQL", + "prod-east", + "us-east-1", + "Success rate:", + } { + if !strings.Contains(out, want) { + t.Errorf("rendered panel missing %q\n%s", want, out) + } + } + if t.Failed() || testing.Verbose() { + t.Logf("\n%s", out) + } +} + +func TestRenderSummaryPanel_Empty(t *testing.T) { + r := &Report{} + out := RenderSummaryPanel(r, "", "", 0, 100) + if !strings.Contains(out, "DreadGOAD VALIDATION") { + t.Errorf("empty panel missing title\n%s", out) + } + if strings.Contains(out, "Success rate:") { + t.Errorf("empty panel should not show success rate\n%s", out) + } +} + +func TestAggregateByCategory(t *testing.T) { + r := sampleReport() + cats := aggregateByCategory(r) + got := map[string]categoryStats{} + for _, c := range cats { + got[c.name] = c + } + for _, prev := range cats { + if cats[len(cats)-1].name == "" { + t.Fatal("found empty category name") + } + _ = prev + } + if c := got["MSSQL"]; c.pass != 2 || c.fail != 1 || c.total != 3 { + t.Errorf("MSSQL aggregation wrong: %+v", c) + } + if c := got["Kerberos"]; c.pass != 1 || c.warn != 1 || c.fail != 0 { + t.Errorf("Kerberos aggregation wrong: %+v", c) + } + if c := got["Discovery"]; c.pass != 2 || c.total != 2 { + t.Errorf("Discovery aggregation wrong: %+v", c) + } +} diff --git a/cli/internal/validate/tui.go b/cli/internal/validate/tui.go new file mode 100644 index 00000000..a97e7969 --- /dev/null +++ b/cli/internal/validate/tui.go @@ -0,0 +1,409 @@ +package validate + +import ( + "context" + "fmt" + "io" + "log/slog" + "sort" + "strings" + "time" + + tea "github.com/charmbracelet/bubbletea" + "github.com/charmbracelet/lipgloss" +) + +// TUIConfig configures the live validation dashboard. +type TUIConfig struct { + // Validator is the active validator. Its OnResult/Silent/Logger are + // managed by RunTUI for the duration of the run. + Validator *Validator + // Env, Region are shown in the header. + Env, Region string + // Run executes the validation checks. It is called once per iteration + // from a goroutine; on each call, the validator is reset and the run + // reports streamed back to the TUI. Implementations typically call + // Validator.RunAllChecks or Validator.RunQuickChecks plus any provider + // drain. + Run func(context.Context) + // PollInterval re-runs the checks on this cadence after each pass + // completes. 0 means one-shot (run once, then wait for the user to + // quit). + PollInterval time.Duration +} + +// RunTUI launches the live validation dashboard. It returns when the user +// quits (q/ctrl-c/esc) or the context is cancelled. The validator's report is +// the canonical record on exit; callers should save it and print the path +// after RunTUI returns. +func RunTUI(ctx context.Context, cfg TUIConfig) error { + if cfg.Validator == nil || cfg.Run == nil { + return fmt.Errorf("validate.RunTUI: Validator and Run are required") + } + + // Channel sized to absorb bursts: ~50 checks × 16-way concurrent fan-out. + results := make(chan Result, 256) + phases := make(chan phaseEvent, 8) + + // Seed the model with results already on the report (e.g. Discovery + // PASS lines from before the TUI started) so the dashboard reflects + // total state, not just live deltas. + seed := snapshotReport(cfg.Validator) + + cfg.Validator.SetSilent(true) + cfg.Validator.SetOnResult(func(r Result) { + select { + case results <- r: + default: + // Channel full -- drop the live update; the structured report + // already has it. The TUI will resync via the validator's + // report when the run completes. + _ = r + } + }) + // Redirect slog output for the duration of the TUI run. Bubbletea's alt + // screen does not capture stderr, so the validator's PS-failure Warn + // lines would otherwise paint on top of the dashboard. + prevLog := cfg.Validator.SetLogger(slog.New(slog.NewTextHandler(io.Discard, nil))) + defer cfg.Validator.SetLogger(prevLog) + defer cfg.Validator.SetOnResult(nil) + defer cfg.Validator.SetSilent(false) + + runCtx, cancel := context.WithCancel(ctx) + defer cancel() + + runDone := make(chan struct{}) + go runLoop(runCtx, cfg, results, phases, runDone) + + m := newValidateModel(cfg, seed, results, phases) + p := tea.NewProgram(m, tea.WithAltScreen(), tea.WithContext(ctx)) + _, err := p.Run() + + // Make sure the run goroutine exits before we hand control back so the + // canonical report in v.report is settled. + cancel() + <-runDone + return err +} + +// runLoop drives the per-iteration validate -> wait cycle. It owns the +// results/phases channels for its lifetime and closes them on exit so the TUI +// model can detect completion. +func runLoop(ctx context.Context, cfg TUIConfig, results chan<- Result, phases chan<- phaseEvent, done chan<- struct{}) { + defer close(done) + defer close(results) + defer close(phases) + + // First iteration uses the validator state already seeded by + // DiscoverHosts (the model already has those results too); skip + // the initial Reset so we don't wipe Discovery. + for first := true; ; first = false { + if !first { + cfg.Validator.Reset() + } + if !sendPhase(ctx, phases, phaseEvent{kind: phaseRunning, iteration: time.Now()}) { + return + } + + cfg.Run(ctx) + if ctx.Err() != nil { + return + } + if cfg.PollInterval <= 0 { + sendPhase(ctx, phases, phaseEvent{kind: phaseDone}) + return + } + + until := time.Now().Add(cfg.PollInterval) + if !sendPhase(ctx, phases, phaseEvent{kind: phaseWaiting, deadline: until}) { + return + } + + select { + case <-ctx.Done(): + return + case <-time.After(cfg.PollInterval): + } + } +} + +// sendPhase forwards an event to the phases channel, returning false if the +// context cancels first. +func sendPhase(ctx context.Context, phases chan<- phaseEvent, ev phaseEvent) bool { + select { + case phases <- ev: + return true + case <-ctx.Done(): + return false + } +} + +// snapshotReport copies the validator's report under its mutex. +func snapshotReport(v *Validator) Report { + r := *v.GetReport() + cp := make([]Result, len(r.Results)) + copy(cp, r.Results) + r.Results = cp + return r +} + +type phaseKind int + +const ( + phaseRunning phaseKind = iota + phaseWaiting + phaseDone +) + +type phaseEvent struct { + kind phaseKind + deadline time.Time // for phaseWaiting + iteration time.Time // start time, for phaseRunning (used to reset elapsed) +} + +type liveModel struct { + cfg TUIConfig + report Report + cats map[string]*categoryStats + startTime time.Time + width int + height int + + results <-chan Result + phases <-chan phaseEvent + + phase phaseKind + waitUntil time.Time + iteration int + finished bool + quitting bool + + // Track seeded counts so a Reset between polls can subtract pre-TUI + // Discovery results (we want the per-iteration view, but the seeded + // Discovery rows were valid only for the first pass). + seededReport Report +} + +func newValidateModel(cfg TUIConfig, seed Report, results <-chan Result, phases <-chan phaseEvent) *liveModel { + m := &liveModel{ + cfg: cfg, + report: Report{Env: seed.Env, Date: seed.Date}, + cats: map[string]*categoryStats{}, + startTime: time.Now(), + results: results, + phases: phases, + phase: phaseRunning, + seededReport: seed, + } + for _, r := range seed.Results { + m.applyResult(r) + } + return m +} + +type liveResultMsg struct{ r Result } +type liveDoneMsg struct{} +type livePhaseMsg struct{ ev phaseEvent } +type livePhaseClosedMsg struct{} +type liveTickMsg struct{} + +func (m *liveModel) Init() tea.Cmd { + return tea.Batch(m.waitForResultCmd(), m.waitForPhaseCmd(), liveTickCmd()) +} + +func (m *liveModel) waitForResultCmd() tea.Cmd { + return func() tea.Msg { + r, ok := <-m.results + if !ok { + return liveDoneMsg{} + } + return liveResultMsg{r: r} + } +} + +func (m *liveModel) waitForPhaseCmd() tea.Cmd { + return func() tea.Msg { + ev, ok := <-m.phases + if !ok { + return livePhaseClosedMsg{} + } + return livePhaseMsg{ev: ev} + } +} + +func liveTickCmd() tea.Cmd { + return tea.Tick(500*time.Millisecond, func(time.Time) tea.Msg { return liveTickMsg{} }) +} + +func (m *liveModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) { + switch msg := msg.(type) { + case tea.WindowSizeMsg: + m.width = msg.Width + m.height = msg.Height + case tea.KeyMsg: + switch msg.String() { + case "q", "ctrl+c", "esc": + m.quitting = true + return m, tea.Quit + } + case liveResultMsg: + m.applyResult(msg.r) + return m, m.waitForResultCmd() + case liveDoneMsg: + // Results channel closed; checks goroutine has exited. We may still + // be waiting on a final phase event. + case livePhaseMsg: + m.applyPhase(msg.ev) + return m, m.waitForPhaseCmd() + case livePhaseClosedMsg: + m.finished = true + case liveTickMsg: + return m, liveTickCmd() + } + return m, nil +} + +func (m *liveModel) applyPhase(ev phaseEvent) { + switch ev.kind { + case phaseRunning: + // New iteration: clear the previous pass's per-category state and + // reset the elapsed-time anchor. Iteration 0 keeps the Discovery + // seed; later iterations start from zero. + if m.iteration > 0 { + m.report = Report{Env: m.report.Env, Date: m.report.Date} + m.cats = map[string]*categoryStats{} + } + m.iteration++ + m.startTime = ev.iteration + if m.startTime.IsZero() { + m.startTime = time.Now() + } + m.phase = phaseRunning + m.waitUntil = time.Time{} + case phaseWaiting: + m.phase = phaseWaiting + m.waitUntil = ev.deadline + case phaseDone: + m.phase = phaseDone + m.finished = true + } +} + +func (m *liveModel) applyResult(r Result) { + m.report.Results = append(m.report.Results, r) + switch r.Status { + case "PASS": + m.report.Passed++ + case "FAIL": + m.report.Failed++ + case "WARN": + m.report.Warnings++ + } + c, ok := m.cats[r.Category] + if !ok { + c = &categoryStats{name: r.Category} + m.cats[r.Category] = c + } + switch r.Status { + case "PASS": + c.pass++ + c.total++ + case "FAIL": + c.fail++ + c.total++ + case "WARN": + c.warn++ + c.total++ + default: + c.others++ + } +} + +func (m *liveModel) View() string { + if m.quitting { + return "" + } + width := m.width + if width <= 0 { + width = 120 + } + innerWidth := width - 4 + if innerWidth < 40 { + innerWidth = 40 + } + + cats := m.sortedCategories() + colWidth := (innerWidth - 2) / 2 + if colWidth < 30 { + colWidth = 30 + } + left, right := splitForColumns(cats) + cols := lipgloss.JoinHorizontal(lipgloss.Top, + renderCategoryColumn(left, colWidth), + " ", + renderCategoryColumn(right, colWidth), + ) + + header := renderValidateHeader(&m.report, m.cfg.Env, m.cfg.Region, time.Since(m.startTime), innerWidth) + subhdr := styleGroupHdr.Render(fmt.Sprintf(" CHECK RESULTS (%d/%d)", m.report.Passed, m.report.Passed+m.report.Failed+m.report.Warnings)) + + parts := []string{header, "", subhdr, "", cols, "", m.renderFooter()} + parts = append(parts, styleFaint.Render(" q/ctrl-c quit")) + + return panelWithTitle("DreadGOAD VALIDATION", strings.Join(parts, "\n"), width) +} + +func (m *liveModel) sortedCategories() []categoryStats { + out := make([]categoryStats, 0, len(m.cats)) + for _, c := range m.cats { + out = append(out, *c) + } + sort.Slice(out, func(i, j int) bool { return out[i].name < out[j].name }) + return out +} + +func (m *liveModel) renderFooter() string { + resultCount := len(m.report.Results) + b := strings.Builder{} + switch m.phase { + case phaseWaiting: + remaining := time.Until(m.waitUntil) + if remaining < 0 { + remaining = 0 + } + b.WriteString(styleWarn.Render(" WAITING")) + b.WriteString(styleMuted.Render(fmt.Sprintf(" next run in %s (%d results)", + formatRemaining(remaining), resultCount))) + case phaseDone: + b.WriteString(styleOK.Render(" COMPLETE")) + b.WriteString(styleMuted.Render(fmt.Sprintf(" (%d results)", resultCount))) + default: // phaseRunning + b.WriteString(styleInfo.Render(" RUNNING")) + if m.iteration > 1 { + b.WriteString(styleMuted.Render(fmt.Sprintf(" pass #%d (%d results so far)", m.iteration, resultCount))) + } else { + b.WriteString(styleMuted.Render(fmt.Sprintf(" (%d results so far)", resultCount))) + } + } + if rate, ok := successRate(&m.report); ok && (m.phase != phaseRunning || m.report.Failed+m.report.Warnings > 0) { + b.WriteString(styleSep.Render(" | ")) + b.WriteString(styleMuted.Render(fmt.Sprintf("success rate: %d%%", rate))) + } + if m.cfg.PollInterval > 0 && m.phase != phaseDone { + b.WriteString(styleSep.Render(" | ")) + b.WriteString(styleFaint.Render(fmt.Sprintf("poll: %s", m.cfg.PollInterval))) + } + return b.String() +} + +func formatRemaining(d time.Duration) string { + if d < time.Second { + return "0s" + } + if d < time.Minute { + return fmt.Sprintf("%ds", int(d.Seconds())) + } + m := int(d.Minutes()) + s := int(d.Seconds()) % 60 + return fmt.Sprintf("%dm%02ds", m, s) +} diff --git a/cli/internal/validate/tui_test.go b/cli/internal/validate/tui_test.go new file mode 100644 index 00000000..89c3a227 --- /dev/null +++ b/cli/internal/validate/tui_test.go @@ -0,0 +1,89 @@ +package validate + +import ( + "strings" + "testing" + "time" +) + +func TestLiveModel_ApplyAndRender(t *testing.T) { + m := newValidateModel(TUIConfig{Env: "prod-east", Region: "us-west-1"}, Report{}, nil, nil) + for _, r := range []Result{ + {Status: "PASS", Category: "ACL", Name: "x"}, + {Status: "PASS", Category: "ACL", Name: "y"}, + {Status: "FAIL", Category: "MSSQL", Name: "sysadmin"}, + {Status: "WARN", Category: "Kerberos", Name: "asrep"}, + } { + m.applyResult(r) + } + + if m.report.Passed != 2 || m.report.Failed != 1 || m.report.Warnings != 1 { + t.Fatalf("counters wrong: %+v", m.report) + } + if c := m.cats["MSSQL"]; c == nil || c.fail != 1 || c.total != 1 { + t.Errorf("MSSQL cat wrong: %+v", c) + } + + m.width = 120 + out := m.View() + for _, want := range []string{"DreadGOAD VALIDATION", "ACL", "MSSQL", "Kerberos", "RUNNING", "prod-east", "us-west-1"} { + if !strings.Contains(out, want) { + t.Errorf("View() missing %q\n%s", want, out) + } + } + + m.applyPhase(phaseEvent{kind: phaseDone}) + out2 := m.View() + if !strings.Contains(out2, "COMPLETE") { + t.Errorf("done View() missing COMPLETE\n%s", out2) + } + if !strings.Contains(out2, "success rate:") { + t.Errorf("done View() missing success rate\n%s", out2) + } + if t.Failed() || testing.Verbose() { + t.Logf("\n%s", out2) + } +} + +func TestLiveModel_PollPhases(t *testing.T) { + m := newValidateModel( + TUIConfig{Env: "staging", Region: "us-west-1", PollInterval: time.Minute}, + Report{Results: []Result{{Status: "PASS", Category: "Discovery", Name: "Found DC01"}}}, + nil, nil, + ) + m.width = 120 + + if m.report.Passed != 1 || m.cats["Discovery"] == nil { + t.Fatalf("seed not applied: %+v", m.report) + } + + // Iteration 1 starts running; seed retained. + m.applyPhase(phaseEvent{kind: phaseRunning, iteration: time.Now()}) + m.applyResult(Result{Status: "PASS", Category: "ACL", Name: "x"}) + if m.report.Passed != 2 { + t.Fatalf("first iteration count wrong: %+v", m.report) + } + + // Transition to waiting. + until := time.Now().Add(20 * time.Second) + m.applyPhase(phaseEvent{kind: phaseWaiting, deadline: until}) + if m.phase != phaseWaiting { + t.Errorf("phase not waiting: %v", m.phase) + } + if !strings.Contains(m.View(), "WAITING") { + t.Errorf("View missing WAITING\n%s", m.View()) + } + if !strings.Contains(m.View(), "poll: 1m0s") { + t.Errorf("View missing poll cadence\n%s", m.View()) + } + + // Iteration 2 clears prior state (including Discovery seed). + m.applyPhase(phaseEvent{kind: phaseRunning, iteration: time.Now()}) + if m.report.Passed != 0 || m.cats["ACL"] != nil || m.cats["Discovery"] != nil { + t.Fatalf("iteration 2 did not reset: %+v / %+v", m.report, m.cats) + } + m.applyResult(Result{Status: "FAIL", Category: "MSSQL", Name: "sysadmin"}) + if m.report.Failed != 1 { + t.Fatalf("iteration 2 fail count wrong: %+v", m.report) + } +} diff --git a/cli/internal/validate/validator.go b/cli/internal/validate/validator.go index 3b6be978..a8a0e056 100644 --- a/cli/internal/validate/validator.go +++ b/cli/internal/validate/validator.go @@ -53,6 +53,16 @@ type Validator struct { hosts map[string]string // hostname -> instance ID lab *labmap.LabMap + // onResult, if set, is invoked for every result appended to the report. + // The live TUI uses this to stream results into a channel while the + // concurrent check goroutines accumulate them in v.report. Safe to call + // from any goroutine; callers must not block (the validator's mutex is + // not held during the call, but excessive blocking will slow checks). + onResult func(Result) + // silent suppresses the streaming color writes from addResult so the TUI + // owns the screen. The structured report is unaffected. + silent bool + // failures counts consecutive runPS failures per host. A single transient // SSM/WinRM hiccup must not poison the rest of the run, so we only mark a // host dead after deadThreshold sustained failures. Successful calls @@ -130,6 +140,10 @@ type checkFunc func(context.Context, io.Writer) // output gives operators a live progress signal; the persisted JSON report // keeps the canonical order. func (v *Validator) runChecks(ctx context.Context, checks []checkFunc) { + v.mu.Lock() + silent := v.silent + v.mu.Unlock() + var stdoutMu sync.Mutex sem := make(chan struct{}, maxConcurrentChecks) var wg sync.WaitGroup @@ -140,6 +154,13 @@ func (v *Validator) runChecks(ctx context.Context, checks []checkFunc) { defer wg.Done() sem <- struct{}{} defer func() { <-sem }() + if silent { + // TUI mode owns the screen; checks must not emit the + // "== Section ==" banners or any stray writes. Results + // flow to the dashboard via the OnResult callback. + f(ctx, io.Discard) + return + } var buf bytes.Buffer f(ctx, &buf) stdoutMu.Lock() @@ -333,6 +354,53 @@ func (v *Validator) runPSErr(ctx context.Context, host, command string) (string, return "", lastErr } +// SetOnResult registers a callback invoked for every result appended to the +// report. Pass nil to unregister. +func (v *Validator) SetOnResult(fn func(Result)) { + v.mu.Lock() + v.onResult = fn + v.mu.Unlock() +} + +// SetSilent suppresses the streaming colorized writes from addResult. Used by +// the live TUI so the bubbletea program owns the screen. +func (v *Validator) SetSilent(silent bool) { + v.mu.Lock() + v.silent = silent + v.mu.Unlock() +} + +// Reset clears the run state so the validator can be reused for a fresh +// pass. Counters, results, and the dead-host/failure tracking are wiped. +// Discovered hosts and the configured logger/onResult/silent flags are +// preserved -- callers running a poll loop typically want to keep those. +func (v *Validator) Reset() { + v.mu.Lock() + v.report = Report{ + Date: time.Now().UTC().Format(time.RFC3339), + Env: v.env, + } + v.mu.Unlock() + + v.failures.Range(func(k, _ any) bool { v.failures.Delete(k); return true }) + v.dead.Range(func(k, _ any) bool { v.dead.Delete(k); return true }) +} + +// SetLogger swaps the validator's logger and returns the previous one. The +// live TUI uses this to redirect slog writes (which otherwise hit stderr and +// bleed through bubbletea's alt screen) to a discard handler for the duration +// of the run. +func (v *Validator) SetLogger(log *slog.Logger) *slog.Logger { + if log == nil { + log = slog.Default() + } + v.mu.Lock() + prev := v.log + v.log = log + v.mu.Unlock() + return prev +} + func (v *Validator) addResult(w io.Writer, status, category, name, detail string) { r := Result{Status: status, Category: category, Name: name, Detail: detail} @@ -346,8 +414,17 @@ func (v *Validator) addResult(w io.Writer, status, category, name, detail string case "WARN": v.report.Warnings++ } + cb := v.onResult + silent := v.silent v.mu.Unlock() + if cb != nil { + cb(r) + } + if silent { + return + } + switch status { case "PASS": _, _ = fmt.Fprint(w, color.GreenString(" ✓ %s\n", name))