From 4e319da776ebcff2bf44c9768a4e86b4f56a493a Mon Sep 17 00:00:00 2001 From: Juan Antonio Osorio Date: Mon, 23 Mar 2026 16:12:13 +0200 Subject: [PATCH 1/4] Wire server discovery protocol into thv serve The discovery package (pkg/server/discovery/) was already implemented and tested but had zero imports in the codebase. This wires it into the serve command so clients (CLI, Studio) can auto-discover a running server without hardcoded ports or environment variables. On startup, thv serve now generates a cryptographic nonce, writes a discovery file to $XDG_STATE_HOME/toolhive/server/server.json with the actual listen URL (supporting port 0 and Unix sockets), and returns the nonce via the X-Toolhive-Nonce health check header. On shutdown the file is removed. The skills client now tries discovery before falling back to the TOOLHIVE_API_URL env var or the default localhost:8080, with loopback and socket-path validation on discovered URLs. Additional fixes: SIGTERM handling in the serve command, a 30-second shutdown timeout (was unbounded), symlink rejection on the discovery file read path, directory permission tightening after MkdirAll, and constant-time nonce comparison. Co-Authored-By: Claude Opus 4.6 (1M context) Signed-off-by: Juan Antonio Osorio --- pkg/api/server.go | 96 ++++++++++- pkg/api/server_test.go | 71 ++++++++ pkg/api/v1/healtcheck_test.go | 99 ++++++++++++ pkg/api/v1/healthcheck.go | 12 +- pkg/server/discovery/discover.go | 92 +++++++++++ pkg/server/discovery/discover_test.go | 118 ++++++++++++++ pkg/server/discovery/discovery.go | 143 +++++++++++++++++ pkg/server/discovery/discovery_test.go | 214 +++++++++++++++++++++++++ pkg/server/discovery/health.go | 128 +++++++++++++++ pkg/server/discovery/health_test.go | 171 ++++++++++++++++++++ pkg/skills/client/client.go | 76 ++++++++- 11 files changed, 1211 insertions(+), 9 deletions(-) create mode 100644 pkg/api/server_test.go create mode 100644 pkg/server/discovery/discover.go create mode 100644 pkg/server/discovery/discover_test.go create mode 100644 pkg/server/discovery/discovery.go create mode 100644 pkg/server/discovery/discovery_test.go create mode 100644 pkg/server/discovery/health.go create mode 100644 pkg/server/discovery/health_test.go diff --git a/pkg/api/server.go b/pkg/api/server.go index 23953f3966..08a96b3958 100644 --- a/pkg/api/server.go +++ b/pkg/api/server.go @@ -17,6 +17,8 @@ package api import ( "context" + "crypto/rand" + "encoding/hex" "errors" "fmt" "io" @@ -42,6 +44,7 @@ import ( "github.com/stacklok/toolhive/pkg/groups" "github.com/stacklok/toolhive/pkg/recovery" "github.com/stacklok/toolhive/pkg/registry" + "github.com/stacklok/toolhive/pkg/server/discovery" "github.com/stacklok/toolhive/pkg/skills" "github.com/stacklok/toolhive/pkg/skills/gitresolver" "github.com/stacklok/toolhive/pkg/skills/skillsvc" @@ -55,6 +58,7 @@ const ( middlewareTimeout = 60 * time.Second readHeaderTimeout = 10 * time.Second shutdownTimeout = 30 * time.Second + nonceBytes = 16 socketPermissions = 0660 // Socket file permissions (owner/group read-write) maxRequestBodySize = 1 << 20 // 1MB - Maximum request body size ) @@ -65,6 +69,7 @@ type ServerBuilder struct { isUnixSocket bool debugMode bool enableDocs bool + nonce string oidcConfig *auth.TokenValidatorConfig middlewares []func(http.Handler) http.Handler customRoutes map[string]http.Handler @@ -108,6 +113,14 @@ func (b *ServerBuilder) WithDocs(enableDocs bool) *ServerBuilder { return b } +// WithNonce sets the server instance nonce used for discovery verification. +// When non-empty, the server writes a discovery file on startup and returns +// the nonce in the X-Toolhive-Nonce health check header. +func (b *ServerBuilder) WithNonce(nonce string) *ServerBuilder { + b.nonce = nonce + return b +} + // WithOIDCConfig sets the OIDC configuration func (b *ServerBuilder) WithOIDCConfig(oidcConfig *auth.TokenValidatorConfig) *ServerBuilder { b.oidcConfig = oidcConfig @@ -297,7 +310,7 @@ func (b *ServerBuilder) setupDefaultRoutes(r *chi.Mux) { // All other routes get standard timeout standardRouters := map[string]http.Handler{ - "/health": v1.HealthcheckRouter(b.containerRuntime), + "/health": v1.HealthcheckRouter(b.containerRuntime, b.nonce), "/api/v1beta/version": v1.VersionRouter(), "/api/v1beta/registry": v1.RegistryRouter(true), "/api/v1beta/discovery": v1.DiscoveryRouter(), @@ -504,6 +517,7 @@ type Server struct { address string isUnixSocket bool addrType string + nonce string storeCloser io.Closer } @@ -532,14 +546,29 @@ func NewServer(ctx context.Context, builder *ServerBuilder) (*Server, error) { address: builder.address, isUnixSocket: builder.isUnixSocket, addrType: addrType, + nonce: builder.nonce, storeCloser: builder.skillStoreCloser, }, nil } +// ListenURL returns the URL where the server is listening, using the actual +// bound address from the listener (important when binding to port 0). +func (s *Server) ListenURL() string { + if s.isUnixSocket { + return fmt.Sprintf("unix://%s", s.address) + } + return fmt.Sprintf("http://%s", s.listener.Addr().String()) +} + // Start starts the server and blocks until the context is cancelled func (s *Server) Start(ctx context.Context) error { slog.Info("starting server", "type", s.addrType, "address", s.address) + // Write server discovery file so clients can find this instance. + if err := s.writeDiscoveryFile(ctx); err != nil { + return err + } + // Start server in a goroutine serverErr := make(chan error, 1) go func() { @@ -562,6 +591,50 @@ func (s *Server) Start(ctx context.Context) error { } } +// writeDiscoveryFile writes the server discovery file if a nonce is configured. +// It checks for an existing healthy server first to prevent silent orphaning. +func (s *Server) writeDiscoveryFile(ctx context.Context) error { + if s.nonce == "" { + return nil + } + + // Guard against overwriting another server's discovery file. + result, err := discovery.Discover(ctx) + if err != nil { + slog.Debug("discovery check failed, proceeding with startup", "error", err) + } else { + switch result.State { + case discovery.StateRunning: + return fmt.Errorf("another ToolHive server is already running at %s (PID %d)", result.Info.URL, result.Info.PID) + case discovery.StateStale: + slog.Debug("cleaning up stale discovery file", "pid", result.Info.PID) + if err := discovery.CleanupStale(); err != nil { + slog.Warn("failed to clean up stale discovery file", "error", err) + } + case discovery.StateUnhealthy: + // The process is alive but not responding to health checks. + // This can happen after a crash-restart where the old process + // is hung. We intentionally overwrite the discovery file so + // this new server becomes discoverable. + slog.Warn("existing server is unhealthy, overwriting discovery file", "pid", result.Info.PID) + case discovery.StateNotFound: + // No existing server, proceed normally. + } + } + + info := &discovery.ServerInfo{ + URL: s.ListenURL(), + PID: os.Getpid(), + Nonce: s.nonce, + StartedAt: time.Now().UTC(), + } + if err := discovery.WriteServerInfo(info); err != nil { + return fmt.Errorf("failed to write discovery file: %w", err) + } + slog.Debug("wrote discovery file", "url", info.URL, "pid", info.PID) + return nil +} + // shutdown gracefully shuts down the server func (s *Server) shutdown() error { shutdownCtx, cancel := context.WithTimeout(context.Background(), shutdownTimeout) @@ -579,6 +652,11 @@ func (s *Server) shutdown() error { // cleanup performs cleanup operations func (s *Server) cleanup() { + if s.nonce != "" { + if err := discovery.RemoveServerInfo(); err != nil { + slog.Warn("failed to remove discovery file", "error", err) + } + } if s.storeCloser != nil { if err := s.storeCloser.Close(); err != nil { slog.Warn("failed to close skill store", "error", err) @@ -653,6 +731,16 @@ func (a *clientPathAdapter) ListSkillSupportingClients() []string { return result } +// generateNonce creates a cryptographically random nonce for server instance +// identification. It returns a 32-character hex string (16 random bytes). +func generateNonce() (string, error) { + b := make([]byte, nonceBytes) + if _, err := rand.Read(b); err != nil { + return "", fmt.Errorf("failed to generate server nonce: %w", err) + } + return hex.EncodeToString(b), nil +} + // Serve starts the server on the given address and serves the API. // It is assumed that the caller sets up appropriate signal handling. // If isUnixSocket is true, address is treated as a UNIX socket path. @@ -666,11 +754,17 @@ func Serve( oidcConfig *auth.TokenValidatorConfig, middlewares ...func(http.Handler) http.Handler, ) error { + nonce, err := generateNonce() + if err != nil { + return err + } + builder := NewServerBuilder(). WithAddress(address). WithUnixSocket(isUnixSocket). WithDebugMode(debugMode). WithDocs(enableDocs). + WithNonce(nonce). WithOIDCConfig(oidcConfig). WithMiddleware(middlewares...) diff --git a/pkg/api/server_test.go b/pkg/api/server_test.go new file mode 100644 index 0000000000..d82c6c348f --- /dev/null +++ b/pkg/api/server_test.go @@ -0,0 +1,71 @@ +// SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc. +// SPDX-License-Identifier: Apache-2.0 + +package api + +import ( + "fmt" + "net" + "regexp" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestGenerateNonce(t *testing.T) { + t.Parallel() + + t.Run("returns valid 32-char hex string", func(t *testing.T) { + t.Parallel() + + nonce, err := generateNonce() + require.NoError(t, err) + + assert.Len(t, nonce, 32) + assert.Regexp(t, regexp.MustCompile(`^[0-9a-f]{32}$`), nonce) + }) + + t.Run("returns unique values on successive calls", func(t *testing.T) { + t.Parallel() + + nonce1, err := generateNonce() + require.NoError(t, err) + + nonce2, err := generateNonce() + require.NoError(t, err) + + assert.NotEqual(t, nonce1, nonce2) + }) +} + +func TestListenURL_TCP(t *testing.T) { + t.Parallel() + + listener, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + defer listener.Close() + + s := &Server{ + listener: listener, + isUnixSocket: false, + address: "127.0.0.1:0", + } + + got := s.ListenURL() + expected := fmt.Sprintf("http://%s", listener.Addr().String()) + assert.Equal(t, expected, got) +} + +func TestListenURL_UnixSocket(t *testing.T) { + t.Parallel() + + const sockPath = "/tmp/test.sock" + s := &Server{ + isUnixSocket: true, + address: sockPath, + } + + got := s.ListenURL() + assert.Equal(t, "unix:///tmp/test.sock", got) +} diff --git a/pkg/api/v1/healtcheck_test.go b/pkg/api/v1/healtcheck_test.go index 326864908c..72dc16cb09 100644 --- a/pkg/api/v1/healtcheck_test.go +++ b/pkg/api/v1/healtcheck_test.go @@ -13,6 +13,7 @@ import ( "go.uber.org/mock/gomock" "github.com/stacklok/toolhive/pkg/container/runtime/mocks" + "github.com/stacklok/toolhive/pkg/server/discovery" ) func TestGetHealthcheck(t *testing.T) { @@ -83,3 +84,101 @@ func TestGetHealthcheck(t *testing.T) { assert.Equal(t, expectedError.Error()+"\n", resp.Body.String()) }) } + +func TestGetHealthcheck_ReturnsNonceHeader(t *testing.T) { + t.Parallel() + + // Create a new gomock controller + ctrl := gomock.NewController(t) + t.Cleanup(func() { + ctrl.Finish() + }) + + // Create a mock runtime + mockRuntime := mocks.NewMockRuntime(ctrl) + + // Create healthcheck routes with a nonce value + routes := &healthcheckRoutes{containerRuntime: mockRuntime, nonce: "test-nonce-value"} + + // Setup mock to return nil (healthy) when IsRunning is called + mockRuntime.EXPECT(). + IsRunning(gomock.Any()). + Return(nil) + + // Create a test request and response recorder + req := httptest.NewRequest(http.MethodGet, "/health", nil).WithContext(t.Context()) + resp := httptest.NewRecorder() + + // Call the handler + routes.getHealthcheck(resp, req) + + // Assert the response status and nonce header + assert.Equal(t, http.StatusNoContent, resp.Code) + assert.Equal(t, "test-nonce-value", resp.Header().Get(discovery.NonceHeader)) +} + +func TestGetHealthcheck_OmitsNonceHeaderWhenEmpty(t *testing.T) { + t.Parallel() + + // Create a new gomock controller + ctrl := gomock.NewController(t) + t.Cleanup(func() { + ctrl.Finish() + }) + + // Create a mock runtime + mockRuntime := mocks.NewMockRuntime(ctrl) + + // Create healthcheck routes with an empty nonce + routes := &healthcheckRoutes{containerRuntime: mockRuntime, nonce: ""} + + // Setup mock to return nil (healthy) when IsRunning is called + mockRuntime.EXPECT(). + IsRunning(gomock.Any()). + Return(nil) + + // Create a test request and response recorder + req := httptest.NewRequest(http.MethodGet, "/health", nil).WithContext(t.Context()) + resp := httptest.NewRecorder() + + // Call the handler + routes.getHealthcheck(resp, req) + + // Assert the response status and absence of nonce header + assert.Equal(t, http.StatusNoContent, resp.Code) + assert.Empty(t, resp.Header().Get(discovery.NonceHeader)) + assert.Empty(t, resp.Header().Values(discovery.NonceHeader)) +} + +func TestGetHealthcheck_NoNonceOnUnhealthy(t *testing.T) { + t.Parallel() + + // Create a new gomock controller + ctrl := gomock.NewController(t) + t.Cleanup(func() { + ctrl.Finish() + }) + + // Create a mock runtime + mockRuntime := mocks.NewMockRuntime(ctrl) + + // Create healthcheck routes with a nonce value + routes := &healthcheckRoutes{containerRuntime: mockRuntime, nonce: "test-nonce"} + + // Setup mock to return an error (unhealthy) when IsRunning is called + mockRuntime.EXPECT(). + IsRunning(gomock.Any()). + Return(errors.New("runtime unavailable")) + + // Create a test request and response recorder + req := httptest.NewRequest(http.MethodGet, "/health", nil).WithContext(t.Context()) + resp := httptest.NewRecorder() + + // Call the handler + routes.getHealthcheck(resp, req) + + // Assert the response status and absence of nonce header + assert.Equal(t, http.StatusServiceUnavailable, resp.Code) + assert.Empty(t, resp.Header().Get(discovery.NonceHeader)) + assert.Empty(t, resp.Header().Values(discovery.NonceHeader)) +} diff --git a/pkg/api/v1/healthcheck.go b/pkg/api/v1/healthcheck.go index 5b065b1cd8..dde22364e1 100644 --- a/pkg/api/v1/healthcheck.go +++ b/pkg/api/v1/healthcheck.go @@ -9,11 +9,14 @@ import ( "github.com/go-chi/chi/v5" rt "github.com/stacklok/toolhive/pkg/container/runtime" + "github.com/stacklok/toolhive/pkg/server/discovery" ) // HealthcheckRouter sets up healthcheck route. -func HealthcheckRouter(containerRuntime rt.Runtime) http.Handler { - routes := &healthcheckRoutes{containerRuntime: containerRuntime} +// The nonce parameter, when non-empty, is returned via the X-Toolhive-Nonce +// header so clients can verify they are talking to the expected server instance. +func HealthcheckRouter(containerRuntime rt.Runtime, nonce string) http.Handler { + routes := &healthcheckRoutes{containerRuntime: containerRuntime, nonce: nonce} r := chi.NewRouter() r.Get("/", routes.getHealthcheck) return r @@ -21,6 +24,7 @@ func HealthcheckRouter(containerRuntime rt.Runtime) http.Handler { type healthcheckRoutes struct { containerRuntime rt.Runtime + nonce string } // getHealthcheck @@ -35,6 +39,10 @@ func (h *healthcheckRoutes) getHealthcheck(w http.ResponseWriter, r *http.Reques http.Error(w, err.Error(), http.StatusServiceUnavailable) return } + // Return the server nonce so clients can verify instance identity. + if h.nonce != "" { + w.Header().Set(discovery.NonceHeader, h.nonce) + } // If the container runtime is running, we consider the API healthy. w.WriteHeader(http.StatusNoContent) } diff --git a/pkg/server/discovery/discover.go b/pkg/server/discovery/discover.go new file mode 100644 index 0000000000..ff40ce4726 --- /dev/null +++ b/pkg/server/discovery/discover.go @@ -0,0 +1,92 @@ +// SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc. +// SPDX-License-Identifier: Apache-2.0 + +package discovery + +import ( + "context" + "errors" + "os" + + "github.com/stacklok/toolhive/pkg/process" +) + +// ServerState represents the state of a discovered server. +type ServerState int + +const ( + // StateNotFound means no discovery file exists. + StateNotFound ServerState = iota + // StateRunning means the server is healthy and responding. + StateRunning + // StateStale means the discovery file exists but the process is dead. + StateStale + // StateUnhealthy means the process is alive but the server is not responding. + StateUnhealthy +) + +// String returns a human-readable representation of the server state. +func (s ServerState) String() string { + switch s { + case StateNotFound: + return "not_found" + case StateRunning: + return "running" + case StateStale: + return "stale" + case StateUnhealthy: + return "unhealthy" + default: + return "unknown" + } +} + +// DiscoverResult holds the result of a server discovery attempt. +type DiscoverResult struct { + // State is the discovered server state. + State ServerState + // Info is the server information from the discovery file. + // It is nil when State is StateNotFound. + Info *ServerInfo +} + +// Discover attempts to find a running ToolHive server by reading the discovery +// file and verifying the server is healthy. +func Discover(ctx context.Context) (*DiscoverResult, error) { + return discover(ctx, defaultDiscoveryDir()) +} + +// discover is the internal implementation that accepts a directory for testability. +func discover(ctx context.Context, dir string) (*DiscoverResult, error) { + info, err := readServerInfoFrom(dir) + if err != nil { + if errors.Is(err, os.ErrNotExist) { + return &DiscoverResult{State: StateNotFound}, nil + } + return nil, err + } + + // Try health check with nonce verification + if err := CheckHealth(ctx, info.URL, info.Nonce); err == nil { + return &DiscoverResult{State: StateRunning, Info: info}, nil + } + + // Health check failed — check if the process is still alive + alive, err := process.FindProcess(info.PID) + if err != nil { + // Can't determine process state; treat as stale + return &DiscoverResult{State: StateStale, Info: info}, nil + } + + if !alive { + return &DiscoverResult{State: StateStale, Info: info}, nil + } + + return &DiscoverResult{State: StateUnhealthy, Info: info}, nil +} + +// CleanupStale removes a stale discovery file. Clients should call this +// when Discover returns StateStale. +func CleanupStale() error { + return RemoveServerInfo() +} diff --git a/pkg/server/discovery/discover_test.go b/pkg/server/discovery/discover_test.go new file mode 100644 index 0000000000..99b283de51 --- /dev/null +++ b/pkg/server/discovery/discover_test.go @@ -0,0 +1,118 @@ +// SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc. +// SPDX-License-Identifier: Apache-2.0 + +package discovery + +import ( + "context" + "net/http" + "net/http/httptest" + "os" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestDiscover_NotFound(t *testing.T) { + t.Parallel() + dir := t.TempDir() + + result, err := discover(context.Background(), dir) + require.NoError(t, err) + assert.Equal(t, StateNotFound, result.State) + assert.Nil(t, result.Info) +} + +func TestDiscover_Running(t *testing.T) { + t.Parallel() + dir := t.TempDir() + + nonce := "running-nonce" + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.Header().Set(NonceHeader, nonce) + w.WriteHeader(http.StatusNoContent) + })) + defer srv.Close() + + info := &ServerInfo{ + URL: srv.URL, + PID: os.Getpid(), + Nonce: nonce, + StartedAt: time.Now().UTC(), + } + require.NoError(t, writeServerInfoTo(dir, info)) + + result, err := discover(context.Background(), dir) + require.NoError(t, err) + assert.Equal(t, StateRunning, result.State) + assert.Equal(t, nonce, result.Info.Nonce) +} + +func TestDiscover_Stale_DeadProcess(t *testing.T) { + t.Parallel() + dir := t.TempDir() + + info := &ServerInfo{ + URL: "http://127.0.0.1:1", + PID: 999999999, + Nonce: "stale-nonce", + StartedAt: time.Now().UTC(), + } + require.NoError(t, writeServerInfoTo(dir, info)) + + result, err := discover(context.Background(), dir) + require.NoError(t, err) + assert.Equal(t, StateStale, result.State) + assert.NotNil(t, result.Info) +} + +func TestDiscover_Unhealthy_AliveButNotResponding(t *testing.T) { + t.Parallel() + dir := t.TempDir() + + // Server that returns 503 (unhealthy) — process is alive (our own PID) + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusServiceUnavailable) + })) + defer srv.Close() + + info := &ServerInfo{ + URL: srv.URL, + PID: os.Getpid(), + Nonce: "unhealthy-nonce", + StartedAt: time.Now().UTC(), + } + require.NoError(t, writeServerInfoTo(dir, info)) + + result, err := discover(context.Background(), dir) + require.NoError(t, err) + assert.Equal(t, StateUnhealthy, result.State) + assert.NotNil(t, result.Info) +} + +func TestDiscover_NonceMismatch_TreatedAsUnhealthy(t *testing.T) { + t.Parallel() + dir := t.TempDir() + + // Server returns wrong nonce — simulates PID reuse scenario + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.Header().Set(NonceHeader, "different-server-nonce") + w.WriteHeader(http.StatusNoContent) + })) + defer srv.Close() + + info := &ServerInfo{ + URL: srv.URL, + PID: os.Getpid(), + Nonce: "original-nonce", + StartedAt: time.Now().UTC(), + } + require.NoError(t, writeServerInfoTo(dir, info)) + + result, err := discover(context.Background(), dir) + require.NoError(t, err) + // Nonce mismatch means health check fails, but process is alive + assert.Equal(t, StateUnhealthy, result.State) +} diff --git a/pkg/server/discovery/discovery.go b/pkg/server/discovery/discovery.go new file mode 100644 index 0000000000..04790e19ca --- /dev/null +++ b/pkg/server/discovery/discovery.go @@ -0,0 +1,143 @@ +// SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc. +// SPDX-License-Identifier: Apache-2.0 + +// Package discovery provides server discovery file management for ToolHive. +// It writes, reads, and removes a JSON file that advertises a running server +// so clients (CLI, Studio) can find it without configuration. +package discovery + +import ( + "encoding/json" + "errors" + "fmt" + "os" + "path/filepath" + "time" + + "github.com/adrg/xdg" + + "github.com/stacklok/toolhive/pkg/fileutils" +) + +const ( + // dirPermissions is the permission mode for the discovery directory. + dirPermissions = 0700 + // filePermissions is the permission mode for the discovery file. + filePermissions = 0600 +) + +// ServerInfo contains the information advertised by a running ToolHive server. +type ServerInfo struct { + // URL is the address where the server is listening. + // For TCP: "http://127.0.0.1:52341" + // For Unix sockets: "unix:///path/to/thv.sock" + URL string `json:"url"` + + // PID is the process ID of the running server. + PID int `json:"pid"` + + // Nonce is a unique identifier generated at server startup. + // It solves PID reuse: clients verify the nonce via /health to confirm + // the discovery file refers to the expected server instance. + Nonce string `json:"nonce"` + + // StartedAt is the UTC timestamp when the server started. + StartedAt time.Time `json:"started_at"` +} + +// defaultDiscoveryDir returns the default directory for the discovery file +// based on the XDG Base Directory Specification. +func defaultDiscoveryDir() string { + return filepath.Join(xdg.StateHome, "toolhive", "server") +} + +// FilePath returns the full path to the server discovery file +// using the default XDG-based directory. +func FilePath() string { + return filepath.Join(defaultDiscoveryDir(), "server.json") +} + +// WriteServerInfo atomically writes the server discovery file. +// It creates the directory if needed, rejects symlinks at the target path, +// and writes with restricted permissions (0600). +func WriteServerInfo(info *ServerInfo) error { + return writeServerInfoTo(defaultDiscoveryDir(), info) +} + +// ReadServerInfo reads and parses the server discovery file. +// Returns os.ErrNotExist if the file does not exist. +func ReadServerInfo() (*ServerInfo, error) { + return readServerInfoFrom(defaultDiscoveryDir()) +} + +// RemoveServerInfo removes the server discovery file. +// It is a no-op if the file does not exist. +func RemoveServerInfo() error { + return removeServerInfoFrom(defaultDiscoveryDir()) +} + +// writeServerInfoTo writes the discovery file into the given directory. +func writeServerInfoTo(dir string, info *ServerInfo) error { + if err := os.MkdirAll(dir, dirPermissions); err != nil { + return fmt.Errorf("failed to create discovery directory: %w", err) + } + + // Tighten permissions on the directory in case it already existed with + // looser permissions. MkdirAll only applies mode to newly-created dirs. + if err := os.Chmod(dir, dirPermissions); err != nil { + return fmt.Errorf("failed to set discovery directory permissions: %w", err) + } + + path := filepath.Join(dir, "server.json") + + // Reject symlinks at the target path to prevent symlink attacks + if fi, err := os.Lstat(path); err == nil { + if fi.Mode()&os.ModeSymlink != 0 { + return fmt.Errorf("refusing to write discovery file: %s is a symlink", path) + } + } + + data, err := json.MarshalIndent(info, "", " ") + if err != nil { + return fmt.Errorf("failed to marshal server info: %w", err) + } + + if err := fileutils.AtomicWriteFile(path, data, filePermissions); err != nil { + return fmt.Errorf("failed to write discovery file: %w", err) + } + + return nil +} + +// readServerInfoFrom reads the discovery file from the given directory. +func readServerInfoFrom(dir string) (*ServerInfo, error) { + path := filepath.Join(dir, "server.json") + + // Reject symlinks on the read path, consistent with the write path. + if fi, err := os.Lstat(path); err == nil { + if fi.Mode()&os.ModeSymlink != 0 { + return nil, fmt.Errorf("refusing to read discovery file: %s is a symlink", path) + } + } + + data, err := os.ReadFile(path) // #nosec G304 -- path is constructed from a trusted XDG directory, not user input + if err != nil { + return nil, err + } + + var info ServerInfo + if err := json.Unmarshal(data, &info); err != nil { + return nil, fmt.Errorf("failed to parse discovery file: %w", err) + } + + return &info, nil +} + +// removeServerInfoFrom removes the discovery file from the given directory. +func removeServerInfoFrom(dir string) error { + err := os.Remove(filepath.Join(dir, "server.json")) + if err != nil && !errors.Is(err, os.ErrNotExist) { + return fmt.Errorf("failed to remove discovery file: %w", err) + } + return nil +} diff --git a/pkg/server/discovery/discovery_test.go b/pkg/server/discovery/discovery_test.go new file mode 100644 index 0000000000..1ca39a0ada --- /dev/null +++ b/pkg/server/discovery/discovery_test.go @@ -0,0 +1,214 @@ +// SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc. +// SPDX-License-Identifier: Apache-2.0 + +package discovery + +import ( + "os" + "path/filepath" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestWriteReadServerInfo_TCP(t *testing.T) { + t.Parallel() + dir := t.TempDir() + + info := &ServerInfo{ + URL: "http://127.0.0.1:52341", + PID: 12345, + Nonce: "test-nonce-tcp", + StartedAt: time.Date(2026, 3, 23, 10, 0, 0, 0, time.UTC), + } + + require.NoError(t, writeServerInfoTo(dir, info)) + + got, err := readServerInfoFrom(dir) + require.NoError(t, err) + assert.Equal(t, info.URL, got.URL) + assert.Equal(t, info.PID, got.PID) + assert.Equal(t, info.Nonce, got.Nonce) + assert.True(t, info.StartedAt.Equal(got.StartedAt)) +} + +func TestWriteReadServerInfo_UnixSocket(t *testing.T) { + t.Parallel() + dir := t.TempDir() + + info := &ServerInfo{ + URL: "unix:///tmp/thv-test.sock", + PID: 54321, + Nonce: "test-nonce-unix", + StartedAt: time.Date(2026, 3, 23, 11, 0, 0, 0, time.UTC), + } + + require.NoError(t, writeServerInfoTo(dir, info)) + + got, err := readServerInfoFrom(dir) + require.NoError(t, err) + assert.Equal(t, info.URL, got.URL) + assert.Equal(t, info.PID, got.PID) + assert.Equal(t, info.Nonce, got.Nonce) +} + +func TestReadServerInfo_NotFound(t *testing.T) { + t.Parallel() + dir := t.TempDir() + + _, err := readServerInfoFrom(dir) + require.ErrorIs(t, err, os.ErrNotExist) +} + +func TestRemoveServerInfo_Exists(t *testing.T) { + t.Parallel() + dir := t.TempDir() + + info := &ServerInfo{ + URL: "http://127.0.0.1:8080", + PID: 1, + Nonce: "nonce", + StartedAt: time.Now().UTC(), + } + require.NoError(t, writeServerInfoTo(dir, info)) + + require.NoError(t, removeServerInfoFrom(dir)) + + _, err := readServerInfoFrom(dir) + require.ErrorIs(t, err, os.ErrNotExist) +} + +func TestRemoveServerInfo_NotFound(t *testing.T) { + t.Parallel() + dir := t.TempDir() + + // Should not error when file doesn't exist + require.NoError(t, removeServerInfoFrom(dir)) +} + +func TestWriteServerInfo_FilePermissions(t *testing.T) { + t.Parallel() + dir := t.TempDir() + + info := &ServerInfo{ + URL: "http://127.0.0.1:8080", + PID: 1, + Nonce: "nonce", + StartedAt: time.Now().UTC(), + } + require.NoError(t, writeServerInfoTo(dir, info)) + + fi, err := os.Stat(filepath.Join(dir, "server.json")) + require.NoError(t, err) + assert.Equal(t, os.FileMode(filePermissions), fi.Mode().Perm()) +} + +func TestWriteServerInfo_CreatesDirectoryWithCorrectPermissions(t *testing.T) { + t.Parallel() + parent := t.TempDir() + dir := filepath.Join(parent, "nested", "server") + + info := &ServerInfo{ + URL: "http://127.0.0.1:8080", + PID: 1, + Nonce: "nonce", + StartedAt: time.Now().UTC(), + } + require.NoError(t, writeServerInfoTo(dir, info)) + + fi, err := os.Stat(dir) + require.NoError(t, err) + assert.Equal(t, os.FileMode(dirPermissions), fi.Mode().Perm()) +} + +func TestWriteServerInfo_RejectsSymlink(t *testing.T) { + t.Parallel() + dir := t.TempDir() + + // Create a symlink at the target path + target := filepath.Join(t.TempDir(), "evil.json") + require.NoError(t, os.WriteFile(target, []byte("{}"), 0600)) + require.NoError(t, os.Symlink(target, filepath.Join(dir, "server.json"))) + + info := &ServerInfo{ + URL: "http://127.0.0.1:8080", + PID: 1, + Nonce: "nonce", + StartedAt: time.Now().UTC(), + } + err := writeServerInfoTo(dir, info) + require.Error(t, err) + assert.Contains(t, err.Error(), "symlink") +} + +func TestReadServerInfo_RejectsSymlink(t *testing.T) { + t.Parallel() + + // Write a valid server.json in a real directory. + realDir := t.TempDir() + info := &ServerInfo{ + URL: "http://127.0.0.1:8080", + PID: 1, + Nonce: "real-nonce", + StartedAt: time.Now().UTC(), + } + require.NoError(t, writeServerInfoTo(realDir, info)) + + // Create a second directory with a symlink named server.json that + // points to the real file. + symlinkDir := t.TempDir() + realFile := filepath.Join(realDir, "server.json") + symlinkFile := filepath.Join(symlinkDir, "server.json") + require.NoError(t, os.Symlink(realFile, symlinkFile)) + + _, err := readServerInfoFrom(symlinkDir) + require.Error(t, err) + assert.Contains(t, err.Error(), "symlink") +} + +func TestWriteServerInfo_TightensExistingDirPermissions(t *testing.T) { + t.Parallel() + + // Create a directory with deliberately too-loose permissions. + dir := t.TempDir() + require.NoError(t, os.Chmod(dir, 0755)) + + info := &ServerInfo{ + URL: "http://127.0.0.1:8080", + PID: 1, + Nonce: "tighten-nonce", + StartedAt: time.Now().UTC(), + } + require.NoError(t, writeServerInfoTo(dir, info)) + + // Verify the directory permissions were tightened to 0700. + fi, err := os.Stat(dir) + require.NoError(t, err) + assert.Equal(t, os.FileMode(dirPermissions), fi.Mode().Perm()) +} + +func TestWriteServerInfo_OverwritesExistingFile(t *testing.T) { + t.Parallel() + dir := t.TempDir() + + first := &ServerInfo{ + URL: "http://127.0.0.1:8080", + PID: 1, + Nonce: "first", + } + require.NoError(t, writeServerInfoTo(dir, first)) + + second := &ServerInfo{ + URL: "http://127.0.0.1:9090", + PID: 2, + Nonce: "second", + } + require.NoError(t, writeServerInfoTo(dir, second)) + + got, err := readServerInfoFrom(dir) + require.NoError(t, err) + assert.Equal(t, "second", got.Nonce) + assert.Equal(t, "http://127.0.0.1:9090", got.URL) +} diff --git a/pkg/server/discovery/health.go b/pkg/server/discovery/health.go new file mode 100644 index 0000000000..754416c844 --- /dev/null +++ b/pkg/server/discovery/health.go @@ -0,0 +1,128 @@ +// SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc. +// SPDX-License-Identifier: Apache-2.0 + +package discovery + +import ( + "context" + "crypto/subtle" + "fmt" + "net" + "net/http" + "net/url" + "path/filepath" + "strings" + "time" +) + +const ( + // healthTimeout is the maximum time to wait for a health check response. + healthTimeout = 5 * time.Second + + // NonceHeader is the HTTP header used to return the server nonce. + NonceHeader = "X-Toolhive-Nonce" +) + +// CheckHealth verifies that a server at the given URL is healthy and optionally +// matches the expected nonce. It supports http:// and unix:// URL schemes. +func CheckHealth(ctx context.Context, serverURL string, expectedNonce string) error { + client, requestURL, err := buildHealthClient(serverURL) + if err != nil { + return err + } + + ctx, cancel := context.WithTimeout(ctx, healthTimeout) + defer cancel() + + req, err := http.NewRequestWithContext(ctx, http.MethodGet, requestURL, nil) + if err != nil { + return fmt.Errorf("failed to create health request: %w", err) + } + + resp, err := client.Do(req) + if err != nil { + return fmt.Errorf("health check failed: %w", err) + } + defer func() { _ = resp.Body.Close() }() + + if resp.StatusCode != http.StatusNoContent { + return fmt.Errorf("unexpected health status: %d", resp.StatusCode) + } + + if expectedNonce != "" { + actualNonce := resp.Header.Get(NonceHeader) + if subtle.ConstantTimeCompare([]byte(actualNonce), []byte(expectedNonce)) != 1 { + return fmt.Errorf("nonce mismatch: expected %q, got %q", expectedNonce, actualNonce) + } + } + + return nil +} + +// buildHealthClient returns an HTTP client and request URL appropriate for +// the given server URL scheme. +func buildHealthClient(serverURL string) (*http.Client, string, error) { + switch { + case strings.HasPrefix(serverURL, "unix://"): + socketPath, err := ParseUnixSocketPath(serverURL) + if err != nil { + return nil, "", err + } + client := &http.Client{ + Transport: &http.Transport{ + DialContext: func(_ context.Context, _, _ string) (net.Conn, error) { + return net.Dial("unix", socketPath) + }, + }, + } + return client, "http://localhost/health", nil + + case strings.HasPrefix(serverURL, "http://"): + if err := ValidateLoopbackURL(serverURL); err != nil { + return nil, "", err + } + return &http.Client{}, serverURL + "/health", nil + + default: + return nil, "", fmt.Errorf("unsupported URL scheme: %s", serverURL) + } +} + +// ValidateLoopbackURL checks that an http:// URL points to a loopback address. +func ValidateLoopbackURL(rawURL string) error { + u, err := url.Parse(rawURL) + if err != nil { + return fmt.Errorf("invalid URL: %w", err) + } + host := u.Hostname() + + ip := net.ParseIP(host) + if ip == nil { + return fmt.Errorf("invalid host in URL: %s", host) + } + if !ip.IsLoopback() { + return fmt.Errorf("refusing health check to non-loopback address: %s", host) + } + return nil +} + +// ParseUnixSocketPath extracts and validates the socket path from a unix:// URL. +func ParseUnixSocketPath(rawURL string) (string, error) { + path := strings.TrimPrefix(rawURL, "unix://") + if path == "" { + return "", fmt.Errorf("empty unix socket path") + } + + // Check for traversal before Clean resolves it away + if strings.Contains(path, "..") { + return "", fmt.Errorf("unix socket path must not contain '..': %s", path) + } + + path = filepath.Clean(path) + + if !filepath.IsAbs(path) { + return "", fmt.Errorf("unix socket path must be absolute: %s", path) + } + + return path, nil +} diff --git a/pkg/server/discovery/health_test.go b/pkg/server/discovery/health_test.go new file mode 100644 index 0000000000..3678935999 --- /dev/null +++ b/pkg/server/discovery/health_test.go @@ -0,0 +1,171 @@ +// SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc. +// SPDX-License-Identifier: Apache-2.0 + +package discovery + +import ( + "context" + "net" + "net/http" + "net/http/httptest" + "os" + "path/filepath" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestParseUnixSocketPath_Valid(t *testing.T) { + t.Parallel() + path, err := ParseUnixSocketPath("unix:///var/run/thv.sock") + require.NoError(t, err) + assert.Equal(t, "/var/run/thv.sock", path) +} + +func TestParseUnixSocketPath_RelativePathRejected(t *testing.T) { + t.Parallel() + _, err := ParseUnixSocketPath("unix://relative/path.sock") + require.Error(t, err) + assert.Contains(t, err.Error(), "absolute") +} + +func TestParseUnixSocketPath_DotDotRejected(t *testing.T) { + t.Parallel() + _, err := ParseUnixSocketPath("unix:///var/run/../etc/evil.sock") + require.Error(t, err) + assert.Contains(t, err.Error(), "..") +} + +func TestParseUnixSocketPath_Empty(t *testing.T) { + t.Parallel() + _, err := ParseUnixSocketPath("unix://") + require.Error(t, err) + assert.Contains(t, err.Error(), "empty") +} + +func TestCheckHealth_TCP_Success(t *testing.T) { + t.Parallel() + expectedNonce := "test-nonce-123" + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.Header().Set(NonceHeader, expectedNonce) + w.WriteHeader(http.StatusNoContent) + })) + defer srv.Close() + + err := CheckHealth(context.Background(), srv.URL, expectedNonce) + require.NoError(t, err) +} + +func TestCheckHealth_TCP_NonceMismatch(t *testing.T) { + t.Parallel() + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.Header().Set(NonceHeader, "wrong-nonce") + w.WriteHeader(http.StatusNoContent) + })) + defer srv.Close() + + err := CheckHealth(context.Background(), srv.URL, "expected-nonce") + require.Error(t, err) + assert.Contains(t, err.Error(), "nonce mismatch") +} + +func TestCheckHealth_TCP_NoNonceCheck(t *testing.T) { + t.Parallel() + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusNoContent) + })) + defer srv.Close() + + // Empty expectedNonce skips nonce check + err := CheckHealth(context.Background(), srv.URL, "") + require.NoError(t, err) +} + +func TestCheckHealth_UnixSocket_Success(t *testing.T) { + t.Parallel() + socketDir := t.TempDir() + socketPath := filepath.Join(socketDir, "test.sock") + + listener, err := net.Listen("unix", socketPath) + require.NoError(t, err) + defer listener.Close() + + expectedNonce := "unix-nonce" + mux := http.NewServeMux() + mux.HandleFunc("/health", func(w http.ResponseWriter, _ *http.Request) { + w.Header().Set(NonceHeader, expectedNonce) + w.WriteHeader(http.StatusNoContent) + }) + srv := &http.Server{Handler: mux} + go func() { _ = srv.Serve(listener) }() + defer srv.Close() + + err = CheckHealth(context.Background(), "unix://"+socketPath, expectedNonce) + require.NoError(t, err) +} + +func TestCheckHealth_Unreachable(t *testing.T) { + t.Parallel() + err := CheckHealth(context.Background(), "http://127.0.0.1:1", "") + require.Error(t, err) + assert.Contains(t, err.Error(), "health check failed") +} + +func TestCheckHealth_InvalidScheme(t *testing.T) { + t.Parallel() + err := CheckHealth(context.Background(), "ftp://localhost:21", "") + require.Error(t, err) + assert.Contains(t, err.Error(), "unsupported URL scheme") +} + +func TestCheckHealth_NonLoopbackRejected(t *testing.T) { + t.Parallel() + err := CheckHealth(context.Background(), "http://192.168.1.1:8080", "") + require.Error(t, err) + assert.Contains(t, err.Error(), "non-loopback") +} + +func TestCheckHealth_UnhealthyStatus(t *testing.T) { + t.Parallel() + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusServiceUnavailable) + })) + defer srv.Close() + + err := CheckHealth(context.Background(), srv.URL, "") + require.Error(t, err) + assert.Contains(t, err.Error(), "unexpected health status") +} + +func TestValidateLoopbackURL(t *testing.T) { + t.Parallel() + tests := []struct { + name string + url string + wantErr bool + }{ + {"IPv4 loopback", "http://127.0.0.1:8080", false}, + {"IPv6 loopback", "http://[::1]:8080", false}, + {"non-loopback", "http://192.168.1.1:8080", true}, + {"hostname", "http://example.com:8080", true}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + err := ValidateLoopbackURL(tt.url) + if tt.wantErr { + require.Error(t, err) + } else { + require.NoError(t, err) + } + }) + } +} + +func TestCheckHealth_UnixSocket_NotFound(t *testing.T) { + t.Parallel() + socketPath := filepath.Join(os.TempDir(), "nonexistent-test.sock") + err := CheckHealth(context.Background(), "unix://"+socketPath, "") + require.Error(t, err) +} diff --git a/pkg/skills/client/client.go b/pkg/skills/client/client.go index bd4adc5a7d..def2304668 100644 --- a/pkg/skills/client/client.go +++ b/pkg/skills/client/client.go @@ -11,6 +11,8 @@ import ( "errors" "fmt" "io" + "log/slog" + "net" "net/http" "net/url" "strings" @@ -18,6 +20,7 @@ import ( "github.com/stacklok/toolhive-core/env" "github.com/stacklok/toolhive-core/httperr" + "github.com/stacklok/toolhive/pkg/server/discovery" "github.com/stacklok/toolhive/pkg/skills" ) @@ -74,19 +77,80 @@ func NewClient(baseURL string, opts ...Option) *Client { return c } -// NewDefaultClient creates a Skills API client using the TOOLHIVE_API_URL -// environment variable, falling back to http://127.0.0.1:8080. +// NewDefaultClient creates a Skills API client by trying, in order: +// 1. The TOOLHIVE_API_URL environment variable (explicit override) +// 2. The server discovery file (auto-detected running server) +// 3. The default URL http://127.0.0.1:8080 func NewDefaultClient(opts ...Option) *Client { return newDefaultClientWithEnv(&env.OSReader{}, opts...) } // newDefaultClientWithEnv is the testable core of NewDefaultClient. func newDefaultClientWithEnv(envReader env.Reader, opts ...Option) *Client { - base := envReader.Getenv(envAPIURL) - if base == "" { - base = defaultBaseURL + // 1. Explicit env var override always wins. + if base := envReader.Getenv(envAPIURL); base != "" { + return NewClient(base, opts...) + } + + // 2. Try server discovery. + if base, httpOpts := resolveViaDiscovery(); base != "" { + merged := make([]Option, 0, len(httpOpts)+len(opts)) + merged = append(merged, httpOpts...) + merged = append(merged, opts...) + return NewClient(base, merged...) + } + + // 3. Fall back to the default URL. + return NewClient(defaultBaseURL, opts...) +} + +// resolveViaDiscovery attempts to find a running server via the discovery file. +// It returns the base URL and any additional options (e.g. a Unix socket transport). +// On failure it returns empty values and the caller falls back to the default. +func resolveViaDiscovery() (string, []Option) { + result, err := discovery.Discover(context.Background()) + if err != nil { + slog.Debug("server discovery failed", "error", err) + return "", nil + } + if result.State != discovery.StateRunning { + return "", nil + } + + serverURL := result.Info.URL + + // Validate and configure transport based on URL scheme. + switch { + case strings.HasPrefix(serverURL, "unix://"): + socketPath, err := discovery.ParseUnixSocketPath(serverURL) + if err != nil { + slog.Debug("invalid unix socket path in discovery file", "error", err) + return "", nil + } + // For Unix sockets, the base URL is http://localhost (the dialer handles routing) + // and we override the HTTP transport to dial the socket. + transport := &http.Transport{ + DialContext: func(_ context.Context, _, _ string) (net.Conn, error) { + return net.Dial("unix", socketPath) + }, + } + opt := WithHTTPClient(&http.Client{ + Timeout: defaultTimeout, + Transport: transport, + }) + return "http://localhost", []Option{opt} + + case strings.HasPrefix(serverURL, "http://"): + if err := discovery.ValidateLoopbackURL(serverURL); err != nil { + slog.Debug("discovery URL is not a loopback address", "url", serverURL, "error", err) + return "", nil + } + return serverURL, nil + + default: + slog.Debug("unsupported URL scheme in discovery file", "url", serverURL) + return "", nil } - return NewClient(base, opts...) } // --- SkillService implementation --- From dde283d02a4209f5e76aa60c370032290f19d5b4 Mon Sep 17 00:00:00 2001 From: Juan Antonio Osorio Date: Tue, 24 Mar 2026 11:04:44 +0000 Subject: [PATCH 2/4] Address review feedback on server discovery - Wrap writeDiscoveryFile check-then-write in WithFileLock to prevent TOCTOU race when two servers start simultaneously - Log FindProcess errors at Debug level instead of silently discarding - Consolidate ListenURL tests into a table-driven test - Rename healtcheck_test.go to healthcheck_test.go (fix typo) Co-Authored-By: Claude Opus 4.6 (1M context) Signed-off-by: Juan Antonio Osorio --- pkg/api/server.go | 69 ++++++++++--------- pkg/api/server_test.go | 65 ++++++++++------- ...healtcheck_test.go => healthcheck_test.go} | 0 pkg/server/discovery/discover.go | 3 +- 4 files changed, 80 insertions(+), 57 deletions(-) rename pkg/api/v1/{healtcheck_test.go => healthcheck_test.go} (100%) diff --git a/pkg/api/server.go b/pkg/api/server.go index 08a96b3958..6c56e3eca5 100644 --- a/pkg/api/server.go +++ b/pkg/api/server.go @@ -41,6 +41,7 @@ import ( "github.com/stacklok/toolhive/pkg/config" "github.com/stacklok/toolhive/pkg/container" "github.com/stacklok/toolhive/pkg/container/runtime" + "github.com/stacklok/toolhive/pkg/fileutils" "github.com/stacklok/toolhive/pkg/groups" "github.com/stacklok/toolhive/pkg/recovery" "github.com/stacklok/toolhive/pkg/registry" @@ -593,46 +594,50 @@ func (s *Server) Start(ctx context.Context) error { // writeDiscoveryFile writes the server discovery file if a nonce is configured. // It checks for an existing healthy server first to prevent silent orphaning. +// The entire check-then-write sequence is wrapped in a file lock to prevent +// TOCTOU races when two servers start simultaneously. func (s *Server) writeDiscoveryFile(ctx context.Context) error { if s.nonce == "" { return nil } - // Guard against overwriting another server's discovery file. - result, err := discovery.Discover(ctx) - if err != nil { - slog.Debug("discovery check failed, proceeding with startup", "error", err) - } else { - switch result.State { - case discovery.StateRunning: - return fmt.Errorf("another ToolHive server is already running at %s (PID %d)", result.Info.URL, result.Info.PID) - case discovery.StateStale: - slog.Debug("cleaning up stale discovery file", "pid", result.Info.PID) - if err := discovery.CleanupStale(); err != nil { - slog.Warn("failed to clean up stale discovery file", "error", err) + return fileutils.WithFileLock(discovery.FilePath(), func() error { + // Guard against overwriting another server's discovery file. + result, err := discovery.Discover(ctx) + if err != nil { + slog.Debug("discovery check failed, proceeding with startup", "error", err) + } else { + switch result.State { + case discovery.StateRunning: + return fmt.Errorf("another ToolHive server is already running at %s (PID %d)", result.Info.URL, result.Info.PID) + case discovery.StateStale: + slog.Debug("cleaning up stale discovery file", "pid", result.Info.PID) + if err := discovery.CleanupStale(); err != nil { + slog.Warn("failed to clean up stale discovery file", "error", err) + } + case discovery.StateUnhealthy: + // The process is alive but not responding to health checks. + // This can happen after a crash-restart where the old process + // is hung. We intentionally overwrite the discovery file so + // this new server becomes discoverable. + slog.Warn("existing server is unhealthy, overwriting discovery file", "pid", result.Info.PID) + case discovery.StateNotFound: + // No existing server, proceed normally. } - case discovery.StateUnhealthy: - // The process is alive but not responding to health checks. - // This can happen after a crash-restart where the old process - // is hung. We intentionally overwrite the discovery file so - // this new server becomes discoverable. - slog.Warn("existing server is unhealthy, overwriting discovery file", "pid", result.Info.PID) - case discovery.StateNotFound: - // No existing server, proceed normally. } - } - info := &discovery.ServerInfo{ - URL: s.ListenURL(), - PID: os.Getpid(), - Nonce: s.nonce, - StartedAt: time.Now().UTC(), - } - if err := discovery.WriteServerInfo(info); err != nil { - return fmt.Errorf("failed to write discovery file: %w", err) - } - slog.Debug("wrote discovery file", "url", info.URL, "pid", info.PID) - return nil + info := &discovery.ServerInfo{ + URL: s.ListenURL(), + PID: os.Getpid(), + Nonce: s.nonce, + StartedAt: time.Now().UTC(), + } + if err := discovery.WriteServerInfo(info); err != nil { + return fmt.Errorf("failed to write discovery file: %w", err) + } + slog.Debug("wrote discovery file", "url", info.URL, "pid", info.PID) + return nil + }) } // shutdown gracefully shuts down the server diff --git a/pkg/api/server_test.go b/pkg/api/server_test.go index d82c6c348f..f0c0923401 100644 --- a/pkg/api/server_test.go +++ b/pkg/api/server_test.go @@ -39,33 +39,50 @@ func TestGenerateNonce(t *testing.T) { }) } -func TestListenURL_TCP(t *testing.T) { +func TestListenURL(t *testing.T) { t.Parallel() - listener, err := net.Listen("tcp", "127.0.0.1:0") - require.NoError(t, err) - defer listener.Close() - - s := &Server{ - listener: listener, - isUnixSocket: false, - address: "127.0.0.1:0", + tests := []struct { + name string + server func(t *testing.T) *Server + expected func(s *Server) string + }{ + { + name: "TCP returns http URL with actual port", + server: func(t *testing.T) *Server { + t.Helper() + listener, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + t.Cleanup(func() { listener.Close() }) + return &Server{ + listener: listener, + isUnixSocket: false, + address: "127.0.0.1:0", + } + }, + expected: func(s *Server) string { + return fmt.Sprintf("http://%s", s.listener.Addr().String()) + }, + }, + { + name: "Unix socket returns unix URL", + server: func(_ *testing.T) *Server { + return &Server{ + isUnixSocket: true, + address: "/tmp/test.sock", + } + }, + expected: func(_ *Server) string { + return "unix:///tmp/test.sock" + }, + }, } - got := s.ListenURL() - expected := fmt.Sprintf("http://%s", listener.Addr().String()) - assert.Equal(t, expected, got) -} - -func TestListenURL_UnixSocket(t *testing.T) { - t.Parallel() - - const sockPath = "/tmp/test.sock" - s := &Server{ - isUnixSocket: true, - address: sockPath, + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + s := tt.server(t) + assert.Equal(t, tt.expected(s), s.ListenURL()) + }) } - - got := s.ListenURL() - assert.Equal(t, "unix:///tmp/test.sock", got) } diff --git a/pkg/api/v1/healtcheck_test.go b/pkg/api/v1/healthcheck_test.go similarity index 100% rename from pkg/api/v1/healtcheck_test.go rename to pkg/api/v1/healthcheck_test.go diff --git a/pkg/server/discovery/discover.go b/pkg/server/discovery/discover.go index ff40ce4726..c9d9fcbee4 100644 --- a/pkg/server/discovery/discover.go +++ b/pkg/server/discovery/discover.go @@ -6,6 +6,7 @@ package discovery import ( "context" "errors" + "log/slog" "os" "github.com/stacklok/toolhive/pkg/process" @@ -74,7 +75,7 @@ func discover(ctx context.Context, dir string) (*DiscoverResult, error) { // Health check failed — check if the process is still alive alive, err := process.FindProcess(info.PID) if err != nil { - // Can't determine process state; treat as stale + slog.Debug("cannot determine process state, treating as stale", "pid", info.PID, "error", err) return &DiscoverResult{State: StateStale, Info: info}, nil } From 5ea0d4dfe853ac296396d3cdbf1fe64ab9ceb009 Mon Sep 17 00:00:00 2001 From: Juan Antonio Osorio Date: Wed, 25 Mar 2026 13:02:40 +0000 Subject: [PATCH 3/4] Create discovery directory before acquiring lock file The discovery lock file is created in the same directory as server.json, but the directory may not exist on a fresh system. MkdirAll was called inside the lock callback (via WriteServerInfo), but the lock acquisition itself needs the directory to already exist. Create the directory before calling WithFileLock so the lock file can be written. Co-Authored-By: Claude Opus 4.6 (1M context) --- pkg/api/server.go | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/pkg/api/server.go b/pkg/api/server.go index 6c56e3eca5..43e689fc53 100644 --- a/pkg/api/server.go +++ b/pkg/api/server.go @@ -601,7 +601,14 @@ func (s *Server) writeDiscoveryFile(ctx context.Context) error { return nil } - return fileutils.WithFileLock(discovery.FilePath(), func() error { + // Ensure the discovery directory exists before acquiring the lock, + // since the lock file is created in the same directory. + discoveryPath := discovery.FilePath() + if err := os.MkdirAll(filepath.Dir(discoveryPath), 0700); err != nil { + return fmt.Errorf("failed to create discovery directory: %w", err) + } + + return fileutils.WithFileLock(discoveryPath, func() error { // Guard against overwriting another server's discovery file. result, err := discovery.Discover(ctx) if err != nil { From 3de135a0ff6755b5ebca73718712d055c10eb84a Mon Sep 17 00:00:00 2001 From: Juan Antonio Osorio Date: Wed, 25 Mar 2026 16:38:42 +0200 Subject: [PATCH 4/4] Address review feedback on discovery wiring - Extract shared HTTPClientForURL in the discovery package to deduplicate transport setup between health.go and the skills client - Propagate context.Context through NewDefaultClient and resolveViaDiscovery instead of using context.Background() - Add comment explaining intentional opts-shadowing order so caller-supplied options can override discovery defaults - Use url.JoinPath in buildHealthClient instead of string concatenation Co-Authored-By: Claude Opus 4.6 (1M context) --- cmd/thv/app/skill_build.go | 2 +- cmd/thv/app/skill_helpers.go | 8 +++-- cmd/thv/app/skill_info.go | 2 +- cmd/thv/app/skill_install.go | 2 +- cmd/thv/app/skill_list.go | 2 +- cmd/thv/app/skill_push.go | 2 +- cmd/thv/app/skill_uninstall.go | 2 +- cmd/thv/app/skill_validate.go | 2 +- pkg/server/discovery/health.go | 22 +++++++++++-- pkg/skills/client/client.go | 55 ++++++++++---------------------- pkg/skills/client/client_test.go | 6 ++-- 11 files changed, 51 insertions(+), 54 deletions(-) diff --git a/cmd/thv/app/skill_build.go b/cmd/thv/app/skill_build.go index 3838d9d223..4d58e75068 100644 --- a/cmd/thv/app/skill_build.go +++ b/cmd/thv/app/skill_build.go @@ -39,7 +39,7 @@ func skillBuildCmdFunc(cmd *cobra.Command, args []string) error { return fmt.Errorf("failed to resolve path: %w", err) } - c := newSkillClient() + c := newSkillClient(cmd.Context()) result, err := c.Build(cmd.Context(), skills.BuildOptions{ Path: absPath, diff --git a/cmd/thv/app/skill_helpers.go b/cmd/thv/app/skill_helpers.go index deab8034df..06f0bdc6f4 100644 --- a/cmd/thv/app/skill_helpers.go +++ b/cmd/thv/app/skill_helpers.go @@ -4,6 +4,7 @@ package app import ( + "context" "errors" "fmt" @@ -14,8 +15,9 @@ import ( ) // newSkillClient creates a new Skills API HTTP client using default settings. -func newSkillClient() *skillclient.Client { - return skillclient.NewDefaultClient() +// The context is used for server discovery; it is not stored. +func newSkillClient(ctx context.Context) *skillclient.Client { + return skillclient.NewDefaultClient(ctx) } // completeSkillNames provides shell completion for installed skill names. @@ -24,7 +26,7 @@ func completeSkillNames(cmd *cobra.Command, args []string, _ string) ([]string, return nil, cobra.ShellCompDirectiveNoFileComp } - c := newSkillClient() + c := newSkillClient(cmd.Context()) installed, err := c.List(cmd.Context(), skills.ListOptions{}) if err != nil { return nil, cobra.ShellCompDirectiveError diff --git a/cmd/thv/app/skill_info.go b/cmd/thv/app/skill_info.go index 5305143126..440429085d 100644 --- a/cmd/thv/app/skill_info.go +++ b/cmd/thv/app/skill_info.go @@ -43,7 +43,7 @@ func init() { } func skillInfoCmdFunc(cmd *cobra.Command, args []string) error { - c := newSkillClient() + c := newSkillClient(cmd.Context()) info, err := c.Info(cmd.Context(), skills.InfoOptions{ Name: args[0], diff --git a/cmd/thv/app/skill_install.go b/cmd/thv/app/skill_install.go index 592c1039ce..389ad10a3f 100644 --- a/cmd/thv/app/skill_install.go +++ b/cmd/thv/app/skill_install.go @@ -42,7 +42,7 @@ func init() { } func skillInstallCmdFunc(cmd *cobra.Command, args []string) error { - c := newSkillClient() + c := newSkillClient(cmd.Context()) _, err := c.Install(cmd.Context(), skills.InstallOptions{ Name: args[0], diff --git a/cmd/thv/app/skill_list.go b/cmd/thv/app/skill_list.go index e80fd5d4ec..0f749cc1ba 100644 --- a/cmd/thv/app/skill_list.go +++ b/cmd/thv/app/skill_list.go @@ -47,7 +47,7 @@ func init() { } func skillListCmdFunc(cmd *cobra.Command, _ []string) error { - c := newSkillClient() + c := newSkillClient(cmd.Context()) installed, err := c.List(cmd.Context(), skills.ListOptions{ Scope: skills.Scope(skillListScope), diff --git a/cmd/thv/app/skill_push.go b/cmd/thv/app/skill_push.go index 5eaa690720..9edd1c3af3 100644 --- a/cmd/thv/app/skill_push.go +++ b/cmd/thv/app/skill_push.go @@ -22,7 +22,7 @@ func init() { } func skillPushCmdFunc(cmd *cobra.Command, args []string) error { - c := newSkillClient() + c := newSkillClient(cmd.Context()) err := c.Push(cmd.Context(), skills.PushOptions{ Reference: args[0], diff --git a/cmd/thv/app/skill_uninstall.go b/cmd/thv/app/skill_uninstall.go index 20c76cd58e..9f5ff6a755 100644 --- a/cmd/thv/app/skill_uninstall.go +++ b/cmd/thv/app/skill_uninstall.go @@ -39,7 +39,7 @@ func init() { } func skillUninstallCmdFunc(cmd *cobra.Command, args []string) error { - c := newSkillClient() + c := newSkillClient(cmd.Context()) err := c.Uninstall(cmd.Context(), skills.UninstallOptions{ Name: args[0], diff --git a/cmd/thv/app/skill_validate.go b/cmd/thv/app/skill_validate.go index d696c433b1..49e7254a06 100644 --- a/cmd/thv/app/skill_validate.go +++ b/cmd/thv/app/skill_validate.go @@ -37,7 +37,7 @@ func skillValidateCmdFunc(cmd *cobra.Command, args []string) error { return fmt.Errorf("failed to resolve path: %w", err) } - c := newSkillClient() + c := newSkillClient(cmd.Context()) result, err := c.Validate(cmd.Context(), absPath) if err != nil { diff --git a/pkg/server/discovery/health.go b/pkg/server/discovery/health.go index 754416c844..cf80cb718d 100644 --- a/pkg/server/discovery/health.go +++ b/pkg/server/discovery/health.go @@ -62,6 +62,24 @@ func CheckHealth(ctx context.Context, serverURL string, expectedNonce string) er // buildHealthClient returns an HTTP client and request URL appropriate for // the given server URL scheme. func buildHealthClient(serverURL string) (*http.Client, string, error) { + client, baseURL, err := HTTPClientForURL(serverURL) + if err != nil { + return nil, "", err + } + healthURL, err := url.JoinPath(baseURL, "health") + if err != nil { + return nil, "", fmt.Errorf("failed to build health URL: %w", err) + } + return client, healthURL, nil +} + +// HTTPClientForURL returns an HTTP client configured for the given server URL +// and the base URL to use for requests. For unix:// URLs it creates a client +// with a Unix socket transport and returns "http://localhost" as the base URL. +// For http:// URLs it validates the host is a loopback address and returns a +// default client. The returned client has no timeout set; callers should apply +// their own timeout via context or client.Timeout. +func HTTPClientForURL(serverURL string) (*http.Client, string, error) { switch { case strings.HasPrefix(serverURL, "unix://"): socketPath, err := ParseUnixSocketPath(serverURL) @@ -75,13 +93,13 @@ func buildHealthClient(serverURL string) (*http.Client, string, error) { }, }, } - return client, "http://localhost/health", nil + return client, "http://localhost", nil case strings.HasPrefix(serverURL, "http://"): if err := ValidateLoopbackURL(serverURL); err != nil { return nil, "", err } - return &http.Client{}, serverURL + "/health", nil + return &http.Client{}, serverURL, nil default: return nil, "", fmt.Errorf("unsupported URL scheme: %s", serverURL) diff --git a/pkg/skills/client/client.go b/pkg/skills/client/client.go index def2304668..7878fb4222 100644 --- a/pkg/skills/client/client.go +++ b/pkg/skills/client/client.go @@ -12,7 +12,6 @@ import ( "fmt" "io" "log/slog" - "net" "net/http" "net/url" "strings" @@ -81,19 +80,23 @@ func NewClient(baseURL string, opts ...Option) *Client { // 1. The TOOLHIVE_API_URL environment variable (explicit override) // 2. The server discovery file (auto-detected running server) // 3. The default URL http://127.0.0.1:8080 -func NewDefaultClient(opts ...Option) *Client { - return newDefaultClientWithEnv(&env.OSReader{}, opts...) +// +// The context is used for the server discovery health check; it is not stored. +func NewDefaultClient(ctx context.Context, opts ...Option) *Client { + return newDefaultClientWithEnv(ctx, &env.OSReader{}, opts...) } // newDefaultClientWithEnv is the testable core of NewDefaultClient. -func newDefaultClientWithEnv(envReader env.Reader, opts ...Option) *Client { +func newDefaultClientWithEnv(ctx context.Context, envReader env.Reader, opts ...Option) *Client { // 1. Explicit env var override always wins. if base := envReader.Getenv(envAPIURL); base != "" { return NewClient(base, opts...) } // 2. Try server discovery. - if base, httpOpts := resolveViaDiscovery(); base != "" { + if base, httpOpts := resolveViaDiscovery(ctx); base != "" { + // Discovery opts go first so caller-supplied opts can override them + // (e.g. a caller-provided WithTimeout replaces the discovery default). merged := make([]Option, 0, len(httpOpts)+len(opts)) merged = append(merged, httpOpts...) merged = append(merged, opts...) @@ -107,8 +110,8 @@ func newDefaultClientWithEnv(envReader env.Reader, opts ...Option) *Client { // resolveViaDiscovery attempts to find a running server via the discovery file. // It returns the base URL and any additional options (e.g. a Unix socket transport). // On failure it returns empty values and the caller falls back to the default. -func resolveViaDiscovery() (string, []Option) { - result, err := discovery.Discover(context.Background()) +func resolveViaDiscovery(ctx context.Context) (string, []Option) { + result, err := discovery.Discover(ctx) if err != nil { slog.Debug("server discovery failed", "error", err) return "", nil @@ -117,40 +120,14 @@ func resolveViaDiscovery() (string, []Option) { return "", nil } - serverURL := result.Info.URL - - // Validate and configure transport based on URL scheme. - switch { - case strings.HasPrefix(serverURL, "unix://"): - socketPath, err := discovery.ParseUnixSocketPath(serverURL) - if err != nil { - slog.Debug("invalid unix socket path in discovery file", "error", err) - return "", nil - } - // For Unix sockets, the base URL is http://localhost (the dialer handles routing) - // and we override the HTTP transport to dial the socket. - transport := &http.Transport{ - DialContext: func(_ context.Context, _, _ string) (net.Conn, error) { - return net.Dial("unix", socketPath) - }, - } - opt := WithHTTPClient(&http.Client{ - Timeout: defaultTimeout, - Transport: transport, - }) - return "http://localhost", []Option{opt} - - case strings.HasPrefix(serverURL, "http://"): - if err := discovery.ValidateLoopbackURL(serverURL); err != nil { - slog.Debug("discovery URL is not a loopback address", "url", serverURL, "error", err) - return "", nil - } - return serverURL, nil - - default: - slog.Debug("unsupported URL scheme in discovery file", "url", serverURL) + client, baseURL, err := discovery.HTTPClientForURL(result.Info.URL) + if err != nil { + slog.Debug("invalid URL in discovery file", "url", result.Info.URL, "error", err) return "", nil } + client.Timeout = defaultTimeout + + return baseURL, []Option{WithHTTPClient(client)} } // --- SkillService implementation --- diff --git a/pkg/skills/client/client_test.go b/pkg/skills/client/client_test.go index dc0e7af04b..9dc2d92fd3 100644 --- a/pkg/skills/client/client_test.go +++ b/pkg/skills/client/client_test.go @@ -564,7 +564,7 @@ func TestNewDefaultClient(t *testing.T) { mockEnv := envmocks.NewMockReader(ctrl) mockEnv.EXPECT().Getenv(envAPIURL).Return("") - c := newDefaultClientWithEnv(mockEnv) + c := newDefaultClientWithEnv(t.Context(), mockEnv) assert.Equal(t, defaultBaseURL, c.baseURL) }) @@ -574,7 +574,7 @@ func TestNewDefaultClient(t *testing.T) { mockEnv := envmocks.NewMockReader(ctrl) mockEnv.EXPECT().Getenv(envAPIURL).Return("http://localhost:9999") - c := newDefaultClientWithEnv(mockEnv) + c := newDefaultClientWithEnv(t.Context(), mockEnv) assert.Equal(t, "http://localhost:9999", c.baseURL) }) @@ -584,7 +584,7 @@ func TestNewDefaultClient(t *testing.T) { mockEnv := envmocks.NewMockReader(ctrl) mockEnv.EXPECT().Getenv(envAPIURL).Return("") - c := newDefaultClientWithEnv(mockEnv, WithTimeout(5*time.Second)) + c := newDefaultClientWithEnv(t.Context(), mockEnv, WithTimeout(5*time.Second)) assert.Equal(t, 5*time.Second, c.httpClient.Timeout) }) }