diff --git a/core/cmd/clovapi/proxy_lifecycle.go b/core/cmd/clovapi/proxy_lifecycle.go index 2ea322a4..e8897a2b 100644 --- a/core/cmd/clovapi/proxy_lifecycle.go +++ b/core/cmd/clovapi/proxy_lifecycle.go @@ -185,6 +185,13 @@ func findListenPID(port int) (int, error) { return 0, errors.New("no listener found") } +var ( + probeProxyHealthForStop = probeProxyHealth + processAliveForStop = processAlive + killProcessTreeForStop = killProcessTree + findListenPIDForStop = findListenPID +) + func waitProxyDown(cfg profile.ProxyConfig, deadline time.Duration) error { deadlineAt := time.Now().Add(deadline) for time.Now().Before(deadlineAt) { @@ -200,8 +207,19 @@ func waitProxyDown(cfg profile.ProxyConfig, deadline time.Duration) error { return fmt.Errorf("proxy still healthy at %s", proxyHealthURL(cfg)) } +func verifyPortListenerIsClovapiProxy(cfg profile.ProxyConfig, pid int) error { + ok, err := probeProxyHealthForStop(cfg) + if err != nil { + return err + } + if ok { + return nil + } + return fmt.Errorf("refusing to stop process %d listening on %s: health endpoint does not identify as clovapi proxy", pid, proxyBaseURL(cfg)) +} + func runProxyStop(cfg profile.ProxyConfig, verbose bool) error { - wasHealthy, _ := probeProxyHealth(cfg) + wasHealthy, _ := probeProxyHealthForStop(cfg) if wasHealthy { _ = shutdownProxyViaHTTP(cfg) @@ -210,16 +228,19 @@ func runProxyStop(cfg profile.ProxyConfig, verbose bool) error { rec, pidErr := readProxyPIDFile() if pidErr == nil && rec.PID > 0 { - if processAlive(rec.PID) { - if err := killProcessTree(rec.PID); err != nil && verbose { + if processAliveForStop(rec.PID) { + if err := killProcessTreeForStop(rec.PID); err != nil && verbose { fmt.Fprintf(os.Stderr, "warning: kill proxy pid %d: %v\n", rec.PID, err) } } } - if listenPID, err := findListenPID(cfg.Port); err == nil && listenPID > 0 { + if listenPID, err := findListenPIDForStop(cfg.Port); err == nil && listenPID > 0 { if pidErr != nil || listenPID != rec.PID { - _ = killProcessTree(listenPID) + if err := verifyPortListenerIsClovapiProxy(cfg, listenPID); err != nil { + return err + } + _ = killProcessTreeForStop(listenPID) } } diff --git a/core/cmd/clovapi/proxy_lifecycle_test.go b/core/cmd/clovapi/proxy_lifecycle_test.go index c0a72c68..b8bd09d6 100644 --- a/core/cmd/clovapi/proxy_lifecycle_test.go +++ b/core/cmd/clovapi/proxy_lifecycle_test.go @@ -1,8 +1,10 @@ package main import ( + "errors" "os" "path/filepath" + "strings" "testing" "github.com/spf13/cobra" @@ -42,3 +44,132 @@ func TestProxyPIDFileRoundTrip(t *testing.T) { t.Fatalf("expected pid file removed, err=%v", err) } } + +func TestRunProxyStopRefusesForeignListenerWithoutPIDFile(t *testing.T) { + dir := t.TempDir() + config.SetDirOverride(dir) + t.Cleanup(func() { config.SetDirOverride("") }) + + cfg := profile.ProxyConfig{Host: "127.0.0.1", Port: 27483} + var killed []int + restoreProxyStopHooks(t, + func(profile.ProxyConfig) (bool, error) { return false, nil }, + func(int) bool { return false }, + func(pid int) error { + killed = append(killed, pid) + return nil + }, + func(int) (int, error) { return 4242, nil }, + ) + + err := runProxyStop(cfg, false) + if err == nil { + t.Fatal("expected foreign listener error") + } + if !strings.Contains(err.Error(), "does not identify as clovapi proxy") { + t.Fatalf("unexpected error: %v", err) + } + if len(killed) != 0 { + t.Fatalf("foreign listener was killed: %v", killed) + } +} + +func TestRunProxyStopRefusesForeignListenerWithStalePIDFile(t *testing.T) { + dir := t.TempDir() + config.SetDirOverride(dir) + t.Cleanup(func() { config.SetDirOverride("") }) + + cfg := profile.ProxyConfig{Host: "127.0.0.1", Port: 27483} + if err := writeProxyPIDFile(1111, cfg); err != nil { + t.Fatal(err) + } + var killed []int + restoreProxyStopHooks(t, + func(profile.ProxyConfig) (bool, error) { return false, nil }, + func(int) bool { return false }, + func(pid int) error { + killed = append(killed, pid) + return nil + }, + func(int) (int, error) { return 4242, nil }, + ) + + err := runProxyStop(cfg, false) + if err == nil { + t.Fatal("expected foreign listener error") + } + if !strings.Contains(err.Error(), "does not identify as clovapi proxy") { + t.Fatalf("unexpected error: %v", err) + } + if len(killed) != 0 { + t.Fatalf("foreign listener was killed: %v", killed) + } +} + +func TestRunProxyStopKillsVerifiedClovapiListenerWithoutPIDFile(t *testing.T) { + dir := t.TempDir() + config.SetDirOverride(dir) + t.Cleanup(func() { config.SetDirOverride("") }) + + cfg := profile.ProxyConfig{Host: "127.0.0.1", Port: 27483} + probes := 0 + var killed []int + restoreProxyStopHooks(t, + func(profile.ProxyConfig) (bool, error) { + probes++ + return probes > 1, nil + }, + func(int) bool { return false }, + func(pid int) error { + killed = append(killed, pid) + return nil + }, + func(int) (int, error) { return 4242, nil }, + ) + + if err := runProxyStop(cfg, false); err != nil { + t.Fatal(err) + } + if len(killed) != 1 || killed[0] != 4242 { + t.Fatalf("expected verified listener killed, got %v", killed) + } +} + +func restoreProxyStopHooks( + t *testing.T, + probe func(profile.ProxyConfig) (bool, error), + alive func(int) bool, + kill func(int) error, + find func(int) (int, error), +) { + t.Helper() + originalProbe := probeProxyHealthForStop + originalAlive := processAliveForStop + originalKill := killProcessTreeForStop + originalFind := findListenPIDForStop + probeProxyHealthForStop = probe + processAliveForStop = alive + killProcessTreeForStop = kill + findListenPIDForStop = find + t.Cleanup(func() { + probeProxyHealthForStop = originalProbe + processAliveForStop = originalAlive + killProcessTreeForStop = originalKill + findListenPIDForStop = originalFind + }) +} + +func TestVerifyPortListenerIsClovapiProxyPropagatesProbeError(t *testing.T) { + want := errors.New("probe failed") + restoreProxyStopHooks(t, + func(profile.ProxyConfig) (bool, error) { return false, want }, + func(int) bool { return false }, + func(int) error { return nil }, + func(int) (int, error) { return 0, nil }, + ) + + err := verifyPortListenerIsClovapiProxy(profile.ProxyConfig{Host: "127.0.0.1", Port: 27483}, 4242) + if !errors.Is(err, want) { + t.Fatalf("error = %v, want %v", err, want) + } +} diff --git a/core/internal/buildinfo/buildinfo.go b/core/internal/buildinfo/buildinfo.go index f60bc3c4..54c2dfdd 100644 --- a/core/internal/buildinfo/buildinfo.go +++ b/core/internal/buildinfo/buildinfo.go @@ -4,7 +4,7 @@ import "strings" // Set at link time via -ldflags (see .goreleaser.yaml). var ( - Version = "dev0.1.42" + Version = "dev0.1.43" Commit = "none" Date = "unknown" )