diff --git a/go.mod b/go.mod index a6c3807..cb11c09 100644 --- a/go.mod +++ b/go.mod @@ -4,6 +4,7 @@ go 1.24.11 require ( github.com/BurntSushi/toml v0.4.1 + github.com/Masterminds/semver/v3 v3.4.0 github.com/creativeprojects/go-selfupdate v1.5.2 github.com/fatih/color v1.16.0 github.com/golang/protobuf v1.5.4 @@ -26,7 +27,6 @@ require ( require ( code.gitea.io/sdk/gitea v0.22.1 // indirect github.com/42wim/httpsig v1.2.3 // indirect - github.com/Masterminds/semver/v3 v3.4.0 // indirect github.com/chzyer/readline v0.0.0-20180603132655-2972be24d48e // indirect github.com/cpuguy83/go-md2man/v2 v2.0.0 // indirect github.com/davecgh/go-spew v1.1.1 // indirect diff --git a/internal/selfupdate.go b/internal/selfupdate.go index 1520561..c94527f 100644 --- a/internal/selfupdate.go +++ b/internal/selfupdate.go @@ -52,6 +52,7 @@ func (m InstallMethod) String() string { type UpdateOptions struct { SkipConfirm bool TargetVersion string // e.g., "v1.2.3", "master", or "" for latest + ForceUpdate bool // skip "already on latest" check } // DetectInstallMethod determines the installation method by examining the binary path. @@ -207,7 +208,7 @@ func RunUpdate(ctx context.Context, currentVersion string, opts UpdateOptions) e return nil } logrus.WithField("latest", release.Version()).Debug("Latest release detected") - if release.LessOrEqual(currentVersion) { + if !opts.ForceUpdate && release.LessOrEqual(currentVersion) { fmt.Printf("You are already running the latest version (%s).\n", currentVersion) return nil } @@ -215,18 +216,20 @@ func RunUpdate(ctx context.Context, currentVersion string, opts UpdateOptions) e targetVersion := release.Version() - // Show update info with release notes BEFORE confirmation - fmt.Printf("\nUpdate available: %s → %s\n", currentVersion, targetVersion) - - if notes := release.ReleaseNotes; notes != "" { - formatted := formatReleaseNotes(notes, 15) - if formatted != "" { - fmt.Println() - fmt.Println("Release notes:") - fmt.Println(formatted) + // Show update info with release notes BEFORE confirmation (skip if caller already showed it) + if !opts.SkipConfirm { + fmt.Printf("\nUpdate available: %s → %s\n", currentVersion, targetVersion) + + if notes := release.ReleaseNotes; notes != "" { + formatted := formatReleaseNotes(notes, 15) + if formatted != "" { + fmt.Println() + fmt.Println("Release notes:") + fmt.Println(formatted) + } } + fmt.Println() } - fmt.Println() logrus.WithField("method", method.String()).Debug("Updating via install method") @@ -288,7 +291,6 @@ func updateViaCommand(ctx context.Context, tool string, opts UpdateOptions, curr } fmt.Println(color.GreenString("\nSuccessfully updated to %s.", targetVer)) - return nil } diff --git a/internal/update.go b/internal/update.go index 843fc75..2c6edb1 100644 --- a/internal/update.go +++ b/internal/update.go @@ -7,21 +7,32 @@ import ( "net/http" "os" "path" + "path/filepath" + "runtime" "strings" "time" + semver "github.com/Masterminds/semver/v3" "github.com/fatih/color" "github.com/sirupsen/logrus" ) type releaseResponse struct { TagName string `json:"tag_name"` + Body string `json:"body"` } const repoURL = "https://github.com/ThreeDotsLabs/cli" const releasesURL = "https://api.github.com/repos/ThreeDotsLabs/cli/releases/latest" +const updateCheckInterval = 30 * time.Minute +const dismissalDuration = 30 * time.Minute -func CheckForUpdate(currentVersion string) { +type latestRelease struct { + Version string + ReleaseNotes string +} + +func CheckForUpdate(currentVersion string, commandName string, forcePrompt bool) { if os.Getenv("TDL_NO_UPDATE_CHECK") != "" { logrus.Debug("Update check disabled via TDL_NO_UPDATE_CHECK") return @@ -31,34 +42,202 @@ func CheckForUpdate(currentVersion string) { return } + isUpdateCommand := commandName == "update" || commandName == "u" + updateInfo, _ := getUpdateInfo() - if updateInfo.UpdateAvailable && updateInfo.CurrentVersion == currentVersion { - printVersionNotice(updateInfo.CurrentVersion, updateInfo.AvailableVersion) + // Fast path: cached update available — no API call needed + if updateInfo.UpdateAvailable && isNewerVersion(updateInfo.AvailableVersion, currentVersion) { + showUpdatePromptOrNotice(updateInfo, currentVersion, isUpdateCommand, forcePrompt) return } - if time.Since(updateInfo.LastChecked) < time.Hour { + // Fast path: check interval not elapsed — return immediately + if !forcePrompt && time.Since(updateInfo.LastChecked) < updateCheckInterval { return } - latestVersion := getLatestVersion() + release := getLatestRelease() + if release == nil { + return + } - if latestVersion != "" && latestVersion != currentVersion { + isNewer := release.Version != "" && isNewerVersion(release.Version, currentVersion) + isDifferent := release.Version != "" && release.Version != currentVersion + + if isNewer || (forcePrompt && isDifferent) { updateInfo.CurrentVersion = currentVersion - updateInfo.AvailableVersion = latestVersion + updateInfo.AvailableVersion = release.Version updateInfo.UpdateAvailable = true + updateInfo.ReleaseNotes = release.ReleaseNotes + + updateInfo.LastChecked = time.Now() + _ = storeUpdateInfo(updateInfo) - printVersionNotice(currentVersion, latestVersion) + showUpdatePromptOrNotice(updateInfo, currentVersion, isUpdateCommand, forcePrompt) } else { updateInfo.CurrentVersion = currentVersion updateInfo.AvailableVersion = "" updateInfo.UpdateAvailable = false + updateInfo.ReleaseNotes = "" + // Clear stale dismissal since there's no pending update + updateInfo.DismissedVersion = "" + updateInfo.DismissedAt = time.Time{} + + updateInfo.LastChecked = time.Now() + _ = storeUpdateInfo(updateInfo) } +} - updateInfo.LastChecked = time.Now() +func showUpdatePromptOrNotice(updateInfo UpdateInfo, currentVersion string, isUpdateCommand bool, forcePrompt bool) { + // If user is running "tdl update", skip — they're already updating + if isUpdateCommand { + return + } + + // Non-interactive terminal (CI, piped stdin) — passive notice only + if !IsStdinTerminal() { + printVersionNotice(currentVersion, updateInfo.AvailableVersion) + return + } - _ = storeUpdateInfo(updateInfo) + if forcePrompt || shouldShowBlockingPrompt(updateInfo) { + showBlockingUpdatePrompt(updateInfo, currentVersion) + } else { + printVersionNotice(currentVersion, updateInfo.AvailableVersion) + } +} + +func shouldShowBlockingPrompt(info UpdateInfo) bool { + // Never dismissed — show prompt + if info.DismissedVersion == "" { + return true + } + + // Dismissed a different version — new release, re-prompt + if info.DismissedVersion != info.AvailableVersion { + return true + } + + // Dismissed same version — only re-prompt after dismissal duration + return time.Since(info.DismissedAt) > dismissalDuration +} + +func showBlockingUpdatePrompt(updateInfo UpdateInfo, currentVersion string) { + c := color.New(color.FgHiYellow) + _, _ = c.Printf("A new version of the CLI is available: %s \u2192 %s\n", currentVersion, updateInfo.AvailableVersion) + _, _ = c.Printf("Some features may be missing or not work correctly.\n") + + if updateInfo.ReleaseNotes != "" { + formatted := formatReleaseNotes(updateInfo.ReleaseNotes, 15) + if formatted != "" { + fmt.Println() + fmt.Println("Release notes:") + fmt.Println(formatted) + } + } + fmt.Println() + + method := DetectInstallMethod() + + // Check if binary requires elevated permissions (direct binary install) + if method == InstallMethodDirectBinary || method == InstallMethodUnknown { + binaryPath, err := os.Executable() + if err == nil { + binaryPath, _ = filepath.EvalSymlinks(binaryPath) + } + if err != nil || !canWriteBinary(binaryPath) { + cmdName := os.Args[0] + var updateCmd string + if runtime.GOOS == "windows" { + updateCmd = fmt.Sprintf("%s update", cmdName) + fmt.Println("The binary requires elevated permissions to update.") + fmt.Println("To update, re-open your terminal as Administrator and run:") + } else { + updateCmd = fmt.Sprintf("sudo %s update", cmdName) + fmt.Printf("The binary at %s requires elevated permissions to update.\n", binaryPath) + fmt.Println("To update, run:") + } + fmt.Println(" " + SprintCommand(updateCmd)) + fmt.Printf("\nOr download from: %s/releases/latest\n", repoURL) + fmt.Println() + + result := Prompt( + Actions{ + {Shortcut: '\n', Action: "exit", ShortcutAliases: []rune{'\r'}}, + {Shortcut: 's', Action: "skip and continue"}, + }, + os.Stdin, + os.Stdout, + ) + + // Store dismissal regardless of choice + updateInfo.DismissedVersion = updateInfo.AvailableVersion + updateInfo.DismissedAt = time.Now() + _ = storeUpdateInfo(updateInfo) + + if result == '\n' { + os.Exit(0) + } + fmt.Println() + return + } + } + + hint := updateCommandHint(method) + action := "update now" + if hint != "" { + action = fmt.Sprintf("run %s", SprintCommand(hint)) + } + + result := Prompt( + Actions{ + {Shortcut: '\n', Action: action, ShortcutAliases: []rune{'\r'}}, + {Shortcut: 's', Action: "skip"}, + }, + os.Stdin, + os.Stdout, + ) + + if result == 's' { + // User declined — record dismissal + updateInfo.DismissedVersion = updateInfo.AvailableVersion + updateInfo.DismissedAt = time.Now() + _ = storeUpdateInfo(updateInfo) + fmt.Println() + return + } + + // User pressed ENTER — run update with SkipConfirm (no double confirmation) + fmt.Println() + ctx := context.Background() + err := RunUpdate(ctx, currentVersion, UpdateOptions{SkipConfirm: true, ForceUpdate: true}) + if err != nil { + fmt.Println(color.RedString("Update failed: %v", err)) + fmt.Println(color.HiBlackString("Continuing with current version...")) + fmt.Println() + return + } + + // Update succeeded — binary is replaced, must exit + fmt.Println() + fmt.Println("Please re-run your command.") + os.Exit(0) +} + +func updateCommandHint(method InstallMethod) string { + switch method { + case InstallMethodHomebrew: + return "brew upgrade tdl" + case InstallMethodGoInstall: + return "go install github.com/ThreeDotsLabs/cli/tdl@latest" + case InstallMethodNix: + return "nix profile upgrade --flake github:ThreeDotsLabs/cli" + case InstallMethodScoop: + return "scoop update tdl" + default: + return "" + } } func printVersionNotice(currentVersion string, availableVersion string) { @@ -69,18 +248,18 @@ func printVersionNotice(currentVersion string, availableVersion string) { fmt.Println() } -func getLatestVersion() string { +func getLatestRelease() *latestRelease { ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) defer cancel() req, err := http.NewRequestWithContext(ctx, http.MethodGet, releasesURL, nil) if err != nil { - return "" + return nil } resp, err := http.DefaultClient.Do(req) if err != nil { - return "" + return nil } defer func() { _ = resp.Body.Close() @@ -88,10 +267,18 @@ func getLatestVersion() string { var release releaseResponse if err := json.NewDecoder(resp.Body).Decode(&release); err != nil { - return "" + return nil + } + + version := strings.TrimLeft(release.TagName, "v") + if version == "" { + return nil } - return strings.TrimLeft(release.TagName, "v") + return &latestRelease{ + Version: version, + ReleaseNotes: release.Body, + } } func updateInfoPath() string { @@ -103,6 +290,9 @@ type UpdateInfo struct { AvailableVersion string `json:"available_version"` UpdateAvailable bool `json:"update_available"` LastChecked time.Time `json:"last_checked"` + ReleaseNotes string `json:"release_notes,omitempty"` + DismissedVersion string `json:"dismissed_version,omitempty"` + DismissedAt time.Time `json:"dismissed_at,omitempty"` } func getUpdateInfo() (UpdateInfo, error) { @@ -136,6 +326,18 @@ func storeUpdateInfo(info UpdateInfo) error { return nil } +func isNewerVersion(latest, current string) bool { + latestV, err := semver.NewVersion(latest) + if err != nil { + return latest != current + } + currentV, err := semver.NewVersion(current) + if err != nil { + return latest != current + } + return latestV.GreaterThan(currentV) +} + func fileExists(path string) bool { _, err := os.Stat(path) if err == nil { diff --git a/internal/update_test.go b/internal/update_test.go new file mode 100644 index 0000000..3eee5b5 --- /dev/null +++ b/internal/update_test.go @@ -0,0 +1,108 @@ +package internal + +import ( + "encoding/json" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestShouldShowBlockingPrompt(t *testing.T) { + t.Run("never dismissed", func(t *testing.T) { + info := UpdateInfo{ + AvailableVersion: "1.2.0", + } + assert.True(t, shouldShowBlockingPrompt(info)) + }) + + t.Run("dismissed different version", func(t *testing.T) { + info := UpdateInfo{ + AvailableVersion: "1.3.0", + DismissedVersion: "1.2.0", + DismissedAt: time.Now(), + } + assert.True(t, shouldShowBlockingPrompt(info)) + }) + + t.Run("dismissed same version recently", func(t *testing.T) { + info := UpdateInfo{ + AvailableVersion: "1.2.0", + DismissedVersion: "1.2.0", + DismissedAt: time.Now().Add(-10 * time.Minute), + } + assert.False(t, shouldShowBlockingPrompt(info)) + }) + + t.Run("dismissed same version expired", func(t *testing.T) { + info := UpdateInfo{ + AvailableVersion: "1.2.0", + DismissedVersion: "1.2.0", + DismissedAt: time.Now().Add(-31 * time.Minute), + } + assert.True(t, shouldShowBlockingPrompt(info)) + }) +} + +func TestUpdateCommandHint(t *testing.T) { + tests := []struct { + method InstallMethod + expected string + }{ + {InstallMethodHomebrew, "brew upgrade tdl"}, + {InstallMethodGoInstall, "go install github.com/ThreeDotsLabs/cli/tdl@latest"}, + {InstallMethodNix, "nix profile upgrade --flake github:ThreeDotsLabs/cli"}, + {InstallMethodScoop, "scoop update tdl"}, + {InstallMethodDirectBinary, ""}, + {InstallMethodUnknown, ""}, + } + + for _, tt := range tests { + t.Run(tt.method.String(), func(t *testing.T) { + assert.Equal(t, tt.expected, updateCommandHint(tt.method)) + }) + } +} + +func TestUpdateInfoBackwardCompatibility(t *testing.T) { + t.Run("old JSON without new fields deserializes cleanly", func(t *testing.T) { + oldJSON := `{"current_version":"1.0.0","available_version":"1.1.0","update_available":true,"last_checked":"2025-01-01T00:00:00Z"}` + var info UpdateInfo + err := json.Unmarshal([]byte(oldJSON), &info) + require.NoError(t, err) + + assert.Equal(t, "1.0.0", info.CurrentVersion) + assert.Equal(t, "1.1.0", info.AvailableVersion) + assert.True(t, info.UpdateAvailable) + // New fields default to zero values + assert.Empty(t, info.ReleaseNotes) + assert.Empty(t, info.DismissedVersion) + assert.True(t, info.DismissedAt.IsZero()) + }) + + t.Run("round trip with all fields", func(t *testing.T) { + now := time.Now().Truncate(time.Second) + info := UpdateInfo{ + CurrentVersion: "1.0.0", + AvailableVersion: "1.1.0", + UpdateAvailable: true, + LastChecked: now, + ReleaseNotes: "- bug fix", + DismissedVersion: "1.1.0", + DismissedAt: now, + } + data, err := json.Marshal(info) + require.NoError(t, err) + + var decoded UpdateInfo + err = json.Unmarshal(data, &decoded) + require.NoError(t, err) + + assert.Equal(t, info.CurrentVersion, decoded.CurrentVersion) + assert.Equal(t, info.AvailableVersion, decoded.AvailableVersion) + assert.Equal(t, info.ReleaseNotes, decoded.ReleaseNotes) + assert.Equal(t, info.DismissedVersion, decoded.DismissedVersion) + assert.True(t, info.DismissedAt.Equal(decoded.DismissedAt)) + }) +} diff --git a/tdl/main.go b/tdl/main.go index ddafcf6..817acd0 100644 --- a/tdl/main.go +++ b/tdl/main.go @@ -120,6 +120,11 @@ var app = &cli.App{ Aliases: []string{"v"}, EnvVars: []string{"VERBOSE"}, }, + &cli.BoolFlag{ + Name: "force-update-prompt", + Usage: "force the update prompt to appear (for testing)", + Hidden: true, + }, }, Before: func(c *cli.Context) error { if verbose := c.Bool("verbose"); verbose { @@ -129,7 +134,15 @@ var app = &cli.App{ logrus.SetLevel(logrus.WarnLevel) } - internal.CheckForUpdate(version) + commandName := "" + for _, arg := range os.Args[1:] { + if !strings.HasPrefix(arg, "-") { + commandName = arg + break + } + } + + internal.CheckForUpdate(version, commandName, c.Bool("force-update-prompt")) return nil },