diff --git a/cmd/thv/app/skill_build.go b/cmd/thv/app/skill_build.go index 3838d9d223..4d58e75068 100644 --- a/cmd/thv/app/skill_build.go +++ b/cmd/thv/app/skill_build.go @@ -39,7 +39,7 @@ func skillBuildCmdFunc(cmd *cobra.Command, args []string) error { return fmt.Errorf("failed to resolve path: %w", err) } - c := newSkillClient() + c := newSkillClient(cmd.Context()) result, err := c.Build(cmd.Context(), skills.BuildOptions{ Path: absPath, diff --git a/cmd/thv/app/skill_helpers.go b/cmd/thv/app/skill_helpers.go index deab8034df..06f0bdc6f4 100644 --- a/cmd/thv/app/skill_helpers.go +++ b/cmd/thv/app/skill_helpers.go @@ -4,6 +4,7 @@ package app import ( + "context" "errors" "fmt" @@ -14,8 +15,9 @@ import ( ) // newSkillClient creates a new Skills API HTTP client using default settings. -func newSkillClient() *skillclient.Client { - return skillclient.NewDefaultClient() +// The context is used for server discovery; it is not stored. +func newSkillClient(ctx context.Context) *skillclient.Client { + return skillclient.NewDefaultClient(ctx) } // completeSkillNames provides shell completion for installed skill names. 
@@ -24,7 +26,7 @@ func completeSkillNames(cmd *cobra.Command, args []string, _ string) ([]string, return nil, cobra.ShellCompDirectiveNoFileComp } - c := newSkillClient() + c := newSkillClient(cmd.Context()) installed, err := c.List(cmd.Context(), skills.ListOptions{}) if err != nil { return nil, cobra.ShellCompDirectiveError diff --git a/cmd/thv/app/skill_info.go b/cmd/thv/app/skill_info.go index 5305143126..440429085d 100644 --- a/cmd/thv/app/skill_info.go +++ b/cmd/thv/app/skill_info.go @@ -43,7 +43,7 @@ func init() { } func skillInfoCmdFunc(cmd *cobra.Command, args []string) error { - c := newSkillClient() + c := newSkillClient(cmd.Context()) info, err := c.Info(cmd.Context(), skills.InfoOptions{ Name: args[0], diff --git a/cmd/thv/app/skill_install.go b/cmd/thv/app/skill_install.go index 592c1039ce..389ad10a3f 100644 --- a/cmd/thv/app/skill_install.go +++ b/cmd/thv/app/skill_install.go @@ -42,7 +42,7 @@ func init() { } func skillInstallCmdFunc(cmd *cobra.Command, args []string) error { - c := newSkillClient() + c := newSkillClient(cmd.Context()) _, err := c.Install(cmd.Context(), skills.InstallOptions{ Name: args[0], diff --git a/cmd/thv/app/skill_list.go b/cmd/thv/app/skill_list.go index e80fd5d4ec..0f749cc1ba 100644 --- a/cmd/thv/app/skill_list.go +++ b/cmd/thv/app/skill_list.go @@ -47,7 +47,7 @@ func init() { } func skillListCmdFunc(cmd *cobra.Command, _ []string) error { - c := newSkillClient() + c := newSkillClient(cmd.Context()) installed, err := c.List(cmd.Context(), skills.ListOptions{ Scope: skills.Scope(skillListScope), diff --git a/cmd/thv/app/skill_push.go b/cmd/thv/app/skill_push.go index 5eaa690720..9edd1c3af3 100644 --- a/cmd/thv/app/skill_push.go +++ b/cmd/thv/app/skill_push.go @@ -22,7 +22,7 @@ func init() { } func skillPushCmdFunc(cmd *cobra.Command, args []string) error { - c := newSkillClient() + c := newSkillClient(cmd.Context()) err := c.Push(cmd.Context(), skills.PushOptions{ Reference: args[0], diff --git 
a/cmd/thv/app/skill_uninstall.go b/cmd/thv/app/skill_uninstall.go index 20c76cd58e..9f5ff6a755 100644 --- a/cmd/thv/app/skill_uninstall.go +++ b/cmd/thv/app/skill_uninstall.go @@ -39,7 +39,7 @@ func init() { } func skillUninstallCmdFunc(cmd *cobra.Command, args []string) error { - c := newSkillClient() + c := newSkillClient(cmd.Context()) err := c.Uninstall(cmd.Context(), skills.UninstallOptions{ Name: args[0], diff --git a/cmd/thv/app/skill_validate.go b/cmd/thv/app/skill_validate.go index d696c433b1..49e7254a06 100644 --- a/cmd/thv/app/skill_validate.go +++ b/cmd/thv/app/skill_validate.go @@ -37,7 +37,7 @@ func skillValidateCmdFunc(cmd *cobra.Command, args []string) error { return fmt.Errorf("failed to resolve path: %w", err) } - c := newSkillClient() + c := newSkillClient(cmd.Context()) result, err := c.Validate(cmd.Context(), absPath) if err != nil { diff --git a/pkg/api/server.go b/pkg/api/server.go index 23953f3966..43e689fc53 100644 --- a/pkg/api/server.go +++ b/pkg/api/server.go @@ -17,6 +17,8 @@ package api import ( "context" + "crypto/rand" + "encoding/hex" "errors" "fmt" "io" @@ -39,9 +41,11 @@ import ( "github.com/stacklok/toolhive/pkg/config" "github.com/stacklok/toolhive/pkg/container" "github.com/stacklok/toolhive/pkg/container/runtime" + "github.com/stacklok/toolhive/pkg/fileutils" "github.com/stacklok/toolhive/pkg/groups" "github.com/stacklok/toolhive/pkg/recovery" "github.com/stacklok/toolhive/pkg/registry" + "github.com/stacklok/toolhive/pkg/server/discovery" "github.com/stacklok/toolhive/pkg/skills" "github.com/stacklok/toolhive/pkg/skills/gitresolver" "github.com/stacklok/toolhive/pkg/skills/skillsvc" @@ -55,6 +59,7 @@ const ( middlewareTimeout = 60 * time.Second readHeaderTimeout = 10 * time.Second shutdownTimeout = 30 * time.Second + nonceBytes = 16 socketPermissions = 0660 // Socket file permissions (owner/group read-write) maxRequestBodySize = 1 << 20 // 1MB - Maximum request body size ) @@ -65,6 +70,7 @@ type ServerBuilder struct { 
isUnixSocket bool debugMode bool enableDocs bool + nonce string oidcConfig *auth.TokenValidatorConfig middlewares []func(http.Handler) http.Handler customRoutes map[string]http.Handler @@ -108,6 +114,14 @@ func (b *ServerBuilder) WithDocs(enableDocs bool) *ServerBuilder { return b } +// WithNonce sets the server instance nonce used for discovery verification. +// When non-empty, the server writes a discovery file on startup and returns +// the nonce in the X-Toolhive-Nonce health check header. +func (b *ServerBuilder) WithNonce(nonce string) *ServerBuilder { + b.nonce = nonce + return b +} + // WithOIDCConfig sets the OIDC configuration func (b *ServerBuilder) WithOIDCConfig(oidcConfig *auth.TokenValidatorConfig) *ServerBuilder { b.oidcConfig = oidcConfig @@ -297,7 +311,7 @@ func (b *ServerBuilder) setupDefaultRoutes(r *chi.Mux) { // All other routes get standard timeout standardRouters := map[string]http.Handler{ - "/health": v1.HealthcheckRouter(b.containerRuntime), + "/health": v1.HealthcheckRouter(b.containerRuntime, b.nonce), "/api/v1beta/version": v1.VersionRouter(), "/api/v1beta/registry": v1.RegistryRouter(true), "/api/v1beta/discovery": v1.DiscoveryRouter(), @@ -504,6 +518,7 @@ type Server struct { address string isUnixSocket bool addrType string + nonce string storeCloser io.Closer } @@ -532,14 +547,29 @@ func NewServer(ctx context.Context, builder *ServerBuilder) (*Server, error) { address: builder.address, isUnixSocket: builder.isUnixSocket, addrType: addrType, + nonce: builder.nonce, storeCloser: builder.skillStoreCloser, }, nil } +// ListenURL returns the URL where the server is listening, using the actual +// bound address from the listener (important when binding to port 0). 
+func (s *Server) ListenURL() string { + if s.isUnixSocket { + return fmt.Sprintf("unix://%s", s.address) + } + return fmt.Sprintf("http://%s", s.listener.Addr().String()) +} + // Start starts the server and blocks until the context is cancelled func (s *Server) Start(ctx context.Context) error { slog.Info("starting server", "type", s.addrType, "address", s.address) + // Write server discovery file so clients can find this instance. + if err := s.writeDiscoveryFile(ctx); err != nil { + return err + } + // Start server in a goroutine serverErr := make(chan error, 1) go func() { @@ -562,6 +592,61 @@ func (s *Server) Start(ctx context.Context) error { } } +// writeDiscoveryFile writes the server discovery file if a nonce is configured. +// It checks for an existing healthy server first to prevent silent orphaning. +// The entire check-then-write sequence is wrapped in a file lock to prevent +// TOCTOU races when two servers start simultaneously. +func (s *Server) writeDiscoveryFile(ctx context.Context) error { + if s.nonce == "" { + return nil + } + + // Ensure the discovery directory exists before acquiring the lock, + // since the lock file is created in the same directory. + discoveryPath := discovery.FilePath() + if err := os.MkdirAll(filepath.Dir(discoveryPath), 0700); err != nil { + return fmt.Errorf("failed to create discovery directory: %w", err) + } + + return fileutils.WithFileLock(discoveryPath, func() error { + // Guard against overwriting another server's discovery file. 
+ result, err := discovery.Discover(ctx) + if err != nil { + slog.Debug("discovery check failed, proceeding with startup", "error", err) + } else { + switch result.State { + case discovery.StateRunning: + return fmt.Errorf("another ToolHive server is already running at %s (PID %d)", result.Info.URL, result.Info.PID) + case discovery.StateStale: + slog.Debug("cleaning up stale discovery file", "pid", result.Info.PID) + if err := discovery.CleanupStale(); err != nil { + slog.Warn("failed to clean up stale discovery file", "error", err) + } + case discovery.StateUnhealthy: + // The process is alive but not responding to health checks. + // This can happen after a crash-restart where the old process + // is hung. We intentionally overwrite the discovery file so + // this new server becomes discoverable. + slog.Warn("existing server is unhealthy, overwriting discovery file", "pid", result.Info.PID) + case discovery.StateNotFound: + // No existing server, proceed normally. + } + } + + info := &discovery.ServerInfo{ + URL: s.ListenURL(), + PID: os.Getpid(), + Nonce: s.nonce, + StartedAt: time.Now().UTC(), + } + if err := discovery.WriteServerInfo(info); err != nil { + return fmt.Errorf("failed to write discovery file: %w", err) + } + slog.Debug("wrote discovery file", "url", info.URL, "pid", info.PID) + return nil + }) +} + // shutdown gracefully shuts down the server func (s *Server) shutdown() error { shutdownCtx, cancel := context.WithTimeout(context.Background(), shutdownTimeout) @@ -579,6 +664,11 @@ func (s *Server) shutdown() error { // cleanup performs cleanup operations func (s *Server) cleanup() { + if s.nonce != "" { + if err := discovery.RemoveServerInfo(); err != nil { + slog.Warn("failed to remove discovery file", "error", err) + } + } if s.storeCloser != nil { if err := s.storeCloser.Close(); err != nil { slog.Warn("failed to close skill store", "error", err) @@ -653,6 +743,16 @@ func (a *clientPathAdapter) ListSkillSupportingClients() []string { return 
result } +// generateNonce creates a cryptographically random nonce for server instance +// identification. It returns a 32-character hex string (16 random bytes). +func generateNonce() (string, error) { + b := make([]byte, nonceBytes) + if _, err := rand.Read(b); err != nil { + return "", fmt.Errorf("failed to generate server nonce: %w", err) + } + return hex.EncodeToString(b), nil +} + // Serve starts the server on the given address and serves the API. // It is assumed that the caller sets up appropriate signal handling. // If isUnixSocket is true, address is treated as a UNIX socket path. @@ -666,11 +766,17 @@ func Serve( oidcConfig *auth.TokenValidatorConfig, middlewares ...func(http.Handler) http.Handler, ) error { + nonce, err := generateNonce() + if err != nil { + return err + } + builder := NewServerBuilder(). WithAddress(address). WithUnixSocket(isUnixSocket). WithDebugMode(debugMode). WithDocs(enableDocs). + WithNonce(nonce). WithOIDCConfig(oidcConfig). WithMiddleware(middlewares...) diff --git a/pkg/api/server_test.go b/pkg/api/server_test.go new file mode 100644 index 0000000000..f0c0923401 --- /dev/null +++ b/pkg/api/server_test.go @@ -0,0 +1,88 @@ +// SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc. 
+// SPDX-License-Identifier: Apache-2.0 + +package api + +import ( + "fmt" + "net" + "regexp" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestGenerateNonce(t *testing.T) { + t.Parallel() + + t.Run("returns valid 32-char hex string", func(t *testing.T) { + t.Parallel() + + nonce, err := generateNonce() + require.NoError(t, err) + + assert.Len(t, nonce, 32) + assert.Regexp(t, regexp.MustCompile(`^[0-9a-f]{32}$`), nonce) + }) + + t.Run("returns unique values on successive calls", func(t *testing.T) { + t.Parallel() + + nonce1, err := generateNonce() + require.NoError(t, err) + + nonce2, err := generateNonce() + require.NoError(t, err) + + assert.NotEqual(t, nonce1, nonce2) + }) +} + +func TestListenURL(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + server func(t *testing.T) *Server + expected func(s *Server) string + }{ + { + name: "TCP returns http URL with actual port", + server: func(t *testing.T) *Server { + t.Helper() + listener, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + t.Cleanup(func() { listener.Close() }) + return &Server{ + listener: listener, + isUnixSocket: false, + address: "127.0.0.1:0", + } + }, + expected: func(s *Server) string { + return fmt.Sprintf("http://%s", s.listener.Addr().String()) + }, + }, + { + name: "Unix socket returns unix URL", + server: func(_ *testing.T) *Server { + return &Server{ + isUnixSocket: true, + address: "/tmp/test.sock", + } + }, + expected: func(_ *Server) string { + return "unix:///tmp/test.sock" + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + s := tt.server(t) + assert.Equal(t, tt.expected(s), s.ListenURL()) + }) + } +} diff --git a/pkg/api/v1/healtcheck_test.go b/pkg/api/v1/healtcheck_test.go deleted file mode 100644 index 326864908c..0000000000 --- a/pkg/api/v1/healtcheck_test.go +++ /dev/null @@ -1,85 +0,0 @@ -// SPDX-FileCopyrightText: Copyright 2025 
Stacklok, Inc. -// SPDX-License-Identifier: Apache-2.0 - -package v1 - -import ( - "errors" - "net/http" - "net/http/httptest" - "testing" - - "github.com/stretchr/testify/assert" - "go.uber.org/mock/gomock" - - "github.com/stacklok/toolhive/pkg/container/runtime/mocks" -) - -func TestGetHealthcheck(t *testing.T) { - t.Parallel() - - t.Run("returns 204 when runtime is running", func(t *testing.T) { - t.Parallel() - // Create a new gomock controller for this subtest - ctrl := gomock.NewController(t) - t.Cleanup(func() { - ctrl.Finish() - }) - - // Create a mock runtime - mockRuntime := mocks.NewMockRuntime(ctrl) - - // Create healthcheck routes with the mock runtime - routes := &healthcheckRoutes{containerRuntime: mockRuntime} - - // Setup mock to return nil (no error) when IsRunning is called - mockRuntime.EXPECT(). - IsRunning(gomock.Any()). - Return(nil) - - // Create a test request and response recorder - req := httptest.NewRequest(http.MethodGet, "/health", nil) - resp := httptest.NewRecorder() - - // Call the handler - routes.getHealthcheck(resp, req) - - // Assert the response - assert.Equal(t, http.StatusNoContent, resp.Code) - assert.Empty(t, resp.Body.String()) - }) - - t.Run("returns 503 when runtime is not running", func(t *testing.T) { - t.Parallel() - // Create a new gomock controller for this subtest - ctrl := gomock.NewController(t) - t.Cleanup(func() { - ctrl.Finish() - }) - - // Create a mock runtime - mockRuntime := mocks.NewMockRuntime(ctrl) - - // Create healthcheck routes with the mock runtime - routes := &healthcheckRoutes{containerRuntime: mockRuntime} - - // Create an error to return - expectedError := errors.New("container runtime is not available") - - // Setup mock to return an error when IsRunning is called - mockRuntime.EXPECT(). - IsRunning(gomock.Any()). 
- Return(expectedError) - - // Create a test request and response recorder - req := httptest.NewRequest(http.MethodGet, "/health", nil) - resp := httptest.NewRecorder() - - // Call the handler - routes.getHealthcheck(resp, req) - - // Assert the response - assert.Equal(t, http.StatusServiceUnavailable, resp.Code) - assert.Equal(t, expectedError.Error()+"\n", resp.Body.String()) - }) -} diff --git a/pkg/api/v1/healthcheck.go b/pkg/api/v1/healthcheck.go index 5b065b1cd8..dde22364e1 100644 --- a/pkg/api/v1/healthcheck.go +++ b/pkg/api/v1/healthcheck.go @@ -9,11 +9,14 @@ import ( "github.com/go-chi/chi/v5" rt "github.com/stacklok/toolhive/pkg/container/runtime" + "github.com/stacklok/toolhive/pkg/server/discovery" ) // HealthcheckRouter sets up healthcheck route. -func HealthcheckRouter(containerRuntime rt.Runtime) http.Handler { - routes := &healthcheckRoutes{containerRuntime: containerRuntime} +// The nonce parameter, when non-empty, is returned via the X-Toolhive-Nonce +// header so clients can verify they are talking to the expected server instance. +func HealthcheckRouter(containerRuntime rt.Runtime, nonce string) http.Handler { + routes := &healthcheckRoutes{containerRuntime: containerRuntime, nonce: nonce} r := chi.NewRouter() r.Get("/", routes.getHealthcheck) return r @@ -21,6 +24,7 @@ func HealthcheckRouter(containerRuntime rt.Runtime) http.Handler { type healthcheckRoutes struct { containerRuntime rt.Runtime + nonce string } // getHealthcheck @@ -35,6 +39,10 @@ func (h *healthcheckRoutes) getHealthcheck(w http.ResponseWriter, r *http.Reques http.Error(w, err.Error(), http.StatusServiceUnavailable) return } + // Return the server nonce so clients can verify instance identity. + if h.nonce != "" { + w.Header().Set(discovery.NonceHeader, h.nonce) + } // If the container runtime is running, we consider the API healthy. 
w.WriteHeader(http.StatusNoContent) } diff --git a/pkg/api/v1/healthcheck_test.go b/pkg/api/v1/healthcheck_test.go new file mode 100644 index 0000000000..72dc16cb09 --- /dev/null +++ b/pkg/api/v1/healthcheck_test.go @@ -0,0 +1,184 @@ +// SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc. +// SPDX-License-Identifier: Apache-2.0 + +package v1 + +import ( + "errors" + "net/http" + "net/http/httptest" + "testing" + + "github.com/stretchr/testify/assert" + "go.uber.org/mock/gomock" + + "github.com/stacklok/toolhive/pkg/container/runtime/mocks" + "github.com/stacklok/toolhive/pkg/server/discovery" +) + +func TestGetHealthcheck(t *testing.T) { + t.Parallel() + + t.Run("returns 204 when runtime is running", func(t *testing.T) { + t.Parallel() + // Create a new gomock controller for this subtest + ctrl := gomock.NewController(t) + t.Cleanup(func() { + ctrl.Finish() + }) + + // Create a mock runtime + mockRuntime := mocks.NewMockRuntime(ctrl) + + // Create healthcheck routes with the mock runtime + routes := &healthcheckRoutes{containerRuntime: mockRuntime} + + // Setup mock to return nil (no error) when IsRunning is called + mockRuntime.EXPECT(). + IsRunning(gomock.Any()). 
+ Return(nil) + + // Create a test request and response recorder + req := httptest.NewRequest(http.MethodGet, "/health", nil) + resp := httptest.NewRecorder() + + // Call the handler + routes.getHealthcheck(resp, req) + + // Assert the response + assert.Equal(t, http.StatusNoContent, resp.Code) + assert.Empty(t, resp.Body.String()) + }) + + t.Run("returns 503 when runtime is not running", func(t *testing.T) { + t.Parallel() + // Create a new gomock controller for this subtest + ctrl := gomock.NewController(t) + t.Cleanup(func() { + ctrl.Finish() + }) + + // Create a mock runtime + mockRuntime := mocks.NewMockRuntime(ctrl) + + // Create healthcheck routes with the mock runtime + routes := &healthcheckRoutes{containerRuntime: mockRuntime} + + // Create an error to return + expectedError := errors.New("container runtime is not available") + + // Setup mock to return an error when IsRunning is called + mockRuntime.EXPECT(). + IsRunning(gomock.Any()). + Return(expectedError) + + // Create a test request and response recorder + req := httptest.NewRequest(http.MethodGet, "/health", nil) + resp := httptest.NewRecorder() + + // Call the handler + routes.getHealthcheck(resp, req) + + // Assert the response + assert.Equal(t, http.StatusServiceUnavailable, resp.Code) + assert.Equal(t, expectedError.Error()+"\n", resp.Body.String()) + }) +} + +func TestGetHealthcheck_ReturnsNonceHeader(t *testing.T) { + t.Parallel() + + // Create a new gomock controller + ctrl := gomock.NewController(t) + t.Cleanup(func() { + ctrl.Finish() + }) + + // Create a mock runtime + mockRuntime := mocks.NewMockRuntime(ctrl) + + // Create healthcheck routes with a nonce value + routes := &healthcheckRoutes{containerRuntime: mockRuntime, nonce: "test-nonce-value"} + + // Setup mock to return nil (healthy) when IsRunning is called + mockRuntime.EXPECT(). + IsRunning(gomock.Any()). 
+ Return(nil) + + // Create a test request and response recorder + req := httptest.NewRequest(http.MethodGet, "/health", nil).WithContext(t.Context()) + resp := httptest.NewRecorder() + + // Call the handler + routes.getHealthcheck(resp, req) + + // Assert the response status and nonce header + assert.Equal(t, http.StatusNoContent, resp.Code) + assert.Equal(t, "test-nonce-value", resp.Header().Get(discovery.NonceHeader)) +} + +func TestGetHealthcheck_OmitsNonceHeaderWhenEmpty(t *testing.T) { + t.Parallel() + + // Create a new gomock controller + ctrl := gomock.NewController(t) + t.Cleanup(func() { + ctrl.Finish() + }) + + // Create a mock runtime + mockRuntime := mocks.NewMockRuntime(ctrl) + + // Create healthcheck routes with an empty nonce + routes := &healthcheckRoutes{containerRuntime: mockRuntime, nonce: ""} + + // Setup mock to return nil (healthy) when IsRunning is called + mockRuntime.EXPECT(). + IsRunning(gomock.Any()). + Return(nil) + + // Create a test request and response recorder + req := httptest.NewRequest(http.MethodGet, "/health", nil).WithContext(t.Context()) + resp := httptest.NewRecorder() + + // Call the handler + routes.getHealthcheck(resp, req) + + // Assert the response status and absence of nonce header + assert.Equal(t, http.StatusNoContent, resp.Code) + assert.Empty(t, resp.Header().Get(discovery.NonceHeader)) + assert.Empty(t, resp.Header().Values(discovery.NonceHeader)) +} + +func TestGetHealthcheck_NoNonceOnUnhealthy(t *testing.T) { + t.Parallel() + + // Create a new gomock controller + ctrl := gomock.NewController(t) + t.Cleanup(func() { + ctrl.Finish() + }) + + // Create a mock runtime + mockRuntime := mocks.NewMockRuntime(ctrl) + + // Create healthcheck routes with a nonce value + routes := &healthcheckRoutes{containerRuntime: mockRuntime, nonce: "test-nonce"} + + // Setup mock to return an error (unhealthy) when IsRunning is called + mockRuntime.EXPECT(). + IsRunning(gomock.Any()). 
// ServerState classifies what a discovery attempt learned about the server
// advertised by the discovery file.
type ServerState int

// Discovery outcomes, numbered from zero in declaration order.
const (
	// StateNotFound means no discovery file exists.
	StateNotFound ServerState = iota
	// StateRunning means the server is healthy and responding.
	StateRunning
	// StateStale means the discovery file exists but the process is dead.
	StateStale
	// StateUnhealthy means the process is alive but the server is not responding.
	StateUnhealthy
)

// String renders the state as a short snake_case label. Any value outside the
// declared constants is rendered as "unknown".
func (s ServerState) String() string {
	labels := [...]string{
		StateNotFound:  "not_found",
		StateRunning:   "running",
		StateStale:     "stale",
		StateUnhealthy: "unhealthy",
	}

	if s < 0 || int(s) >= len(labels) {
		return "unknown"
	}
	return labels[s]
}
+ Info *ServerInfo +} + +// Discover attempts to find a running ToolHive server by reading the discovery +// file and verifying the server is healthy. +func Discover(ctx context.Context) (*DiscoverResult, error) { + return discover(ctx, defaultDiscoveryDir()) +} + +// discover is the internal implementation that accepts a directory for testability. +func discover(ctx context.Context, dir string) (*DiscoverResult, error) { + info, err := readServerInfoFrom(dir) + if err != nil { + if errors.Is(err, os.ErrNotExist) { + return &DiscoverResult{State: StateNotFound}, nil + } + return nil, err + } + + // Try health check with nonce verification + if err := CheckHealth(ctx, info.URL, info.Nonce); err == nil { + return &DiscoverResult{State: StateRunning, Info: info}, nil + } + + // Health check failed — check if the process is still alive + alive, err := process.FindProcess(info.PID) + if err != nil { + slog.Debug("cannot determine process state, treating as stale", "pid", info.PID, "error", err) + return &DiscoverResult{State: StateStale, Info: info}, nil + } + + if !alive { + return &DiscoverResult{State: StateStale, Info: info}, nil + } + + return &DiscoverResult{State: StateUnhealthy, Info: info}, nil +} + +// CleanupStale removes a stale discovery file. Clients should call this +// when Discover returns StateStale. +func CleanupStale() error { + return RemoveServerInfo() +} diff --git a/pkg/server/discovery/discover_test.go b/pkg/server/discovery/discover_test.go new file mode 100644 index 0000000000..99b283de51 --- /dev/null +++ b/pkg/server/discovery/discover_test.go @@ -0,0 +1,118 @@ +// SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc. 
+// SPDX-License-Identifier: Apache-2.0 + +package discovery + +import ( + "context" + "net/http" + "net/http/httptest" + "os" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestDiscover_NotFound(t *testing.T) { + t.Parallel() + dir := t.TempDir() + + result, err := discover(context.Background(), dir) + require.NoError(t, err) + assert.Equal(t, StateNotFound, result.State) + assert.Nil(t, result.Info) +} + +func TestDiscover_Running(t *testing.T) { + t.Parallel() + dir := t.TempDir() + + nonce := "running-nonce" + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.Header().Set(NonceHeader, nonce) + w.WriteHeader(http.StatusNoContent) + })) + defer srv.Close() + + info := &ServerInfo{ + URL: srv.URL, + PID: os.Getpid(), + Nonce: nonce, + StartedAt: time.Now().UTC(), + } + require.NoError(t, writeServerInfoTo(dir, info)) + + result, err := discover(context.Background(), dir) + require.NoError(t, err) + assert.Equal(t, StateRunning, result.State) + assert.Equal(t, nonce, result.Info.Nonce) +} + +func TestDiscover_Stale_DeadProcess(t *testing.T) { + t.Parallel() + dir := t.TempDir() + + info := &ServerInfo{ + URL: "http://127.0.0.1:1", + PID: 999999999, + Nonce: "stale-nonce", + StartedAt: time.Now().UTC(), + } + require.NoError(t, writeServerInfoTo(dir, info)) + + result, err := discover(context.Background(), dir) + require.NoError(t, err) + assert.Equal(t, StateStale, result.State) + assert.NotNil(t, result.Info) +} + +func TestDiscover_Unhealthy_AliveButNotResponding(t *testing.T) { + t.Parallel() + dir := t.TempDir() + + // Server that returns 503 (unhealthy) — process is alive (our own PID) + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusServiceUnavailable) + })) + defer srv.Close() + + info := &ServerInfo{ + URL: srv.URL, + PID: os.Getpid(), + Nonce: "unhealthy-nonce", + StartedAt: 
time.Now().UTC(), + } + require.NoError(t, writeServerInfoTo(dir, info)) + + result, err := discover(context.Background(), dir) + require.NoError(t, err) + assert.Equal(t, StateUnhealthy, result.State) + assert.NotNil(t, result.Info) +} + +func TestDiscover_NonceMismatch_TreatedAsUnhealthy(t *testing.T) { + t.Parallel() + dir := t.TempDir() + + // Server returns wrong nonce — simulates PID reuse scenario + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.Header().Set(NonceHeader, "different-server-nonce") + w.WriteHeader(http.StatusNoContent) + })) + defer srv.Close() + + info := &ServerInfo{ + URL: srv.URL, + PID: os.Getpid(), + Nonce: "original-nonce", + StartedAt: time.Now().UTC(), + } + require.NoError(t, writeServerInfoTo(dir, info)) + + result, err := discover(context.Background(), dir) + require.NoError(t, err) + // Nonce mismatch means health check fails, but process is alive + assert.Equal(t, StateUnhealthy, result.State) +} diff --git a/pkg/server/discovery/discovery.go b/pkg/server/discovery/discovery.go new file mode 100644 index 0000000000..04790e19ca --- /dev/null +++ b/pkg/server/discovery/discovery.go @@ -0,0 +1,143 @@ +// SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc. +// SPDX-License-Identifier: Apache-2.0 + +// Package discovery provides server discovery file management for ToolHive. +// It writes, reads, and removes a JSON file that advertises a running server +// so clients (CLI, Studio) can find it without configuration. +package discovery + +import ( + "encoding/json" + "errors" + "fmt" + "os" + "path/filepath" + "time" + + "github.com/adrg/xdg" + + "github.com/stacklok/toolhive/pkg/fileutils" +) + +const ( + // dirPermissions is the permission mode for the discovery directory. + dirPermissions = 0700 + // filePermissions is the permission mode for the discovery file. + filePermissions = 0600 +) + +// ServerInfo contains the information advertised by a running ToolHive server. 
+type ServerInfo struct { + // URL is the address where the server is listening. + // For TCP: "http://127.0.0.1:52341" + // For Unix sockets: "unix:///path/to/thv.sock" + URL string `json:"url"` + + // PID is the process ID of the running server. + PID int `json:"pid"` + + // Nonce is a unique identifier generated at server startup. + // It solves PID reuse: clients verify the nonce via /health to confirm + // the discovery file refers to the expected server instance. + Nonce string `json:"nonce"` + + // StartedAt is the UTC timestamp when the server started. + StartedAt time.Time `json:"started_at"` +} + +// defaultDiscoveryDir returns the default directory for the discovery file +// based on the XDG Base Directory Specification. +func defaultDiscoveryDir() string { + return filepath.Join(xdg.StateHome, "toolhive", "server") +} + +// FilePath returns the full path to the server discovery file +// using the default XDG-based directory. +func FilePath() string { + return filepath.Join(defaultDiscoveryDir(), "server.json") +} + +// WriteServerInfo atomically writes the server discovery file. +// It creates the directory if needed, rejects symlinks at the target path, +// and writes with restricted permissions (0600). +func WriteServerInfo(info *ServerInfo) error { + return writeServerInfoTo(defaultDiscoveryDir(), info) +} + +// ReadServerInfo reads and parses the server discovery file. +// Returns os.ErrNotExist if the file does not exist. +func ReadServerInfo() (*ServerInfo, error) { + return readServerInfoFrom(defaultDiscoveryDir()) +} + +// RemoveServerInfo removes the server discovery file. +// It is a no-op if the file does not exist. +func RemoveServerInfo() error { + return removeServerInfoFrom(defaultDiscoveryDir()) +} + +// writeServerInfoTo writes the discovery file into the given directory. 
+func writeServerInfoTo(dir string, info *ServerInfo) error { + if err := os.MkdirAll(dir, dirPermissions); err != nil { + return fmt.Errorf("failed to create discovery directory: %w", err) + } + + // Tighten permissions on the directory in case it already existed with + // looser permissions. MkdirAll only applies mode to newly-created dirs. + if err := os.Chmod(dir, dirPermissions); err != nil { + return fmt.Errorf("failed to set discovery directory permissions: %w", err) + } + + path := filepath.Join(dir, "server.json") + + // Reject symlinks at the target path to prevent symlink attacks + if fi, err := os.Lstat(path); err == nil { + if fi.Mode()&os.ModeSymlink != 0 { + return fmt.Errorf("refusing to write discovery file: %s is a symlink", path) + } + } + + data, err := json.MarshalIndent(info, "", " ") + if err != nil { + return fmt.Errorf("failed to marshal server info: %w", err) + } + + if err := fileutils.AtomicWriteFile(path, data, filePermissions); err != nil { + return fmt.Errorf("failed to write discovery file: %w", err) + } + + return nil +} + +// readServerInfoFrom reads the discovery file from the given directory. +func readServerInfoFrom(dir string) (*ServerInfo, error) { + path := filepath.Join(dir, "server.json") + + // Reject symlinks on the read path, consistent with the write path. + if fi, err := os.Lstat(path); err == nil { + if fi.Mode()&os.ModeSymlink != 0 { + return nil, fmt.Errorf("refusing to read discovery file: %s is a symlink", path) + } + } + + data, err := os.ReadFile(path) // #nosec G304 -- path is constructed from a trusted XDG directory, not user input + if err != nil { + return nil, err + } + + var info ServerInfo + if err := json.Unmarshal(data, &info); err != nil { + return nil, fmt.Errorf("failed to parse discovery file: %w", err) + } + + return &info, nil +} + +// removeServerInfoFrom removes the discovery file from the given directory. 
+func removeServerInfoFrom(dir string) error { + err := os.Remove(filepath.Join(dir, "server.json")) + if err != nil && !errors.Is(err, os.ErrNotExist) { + return fmt.Errorf("failed to remove discovery file: %w", err) + } + return nil +} diff --git a/pkg/server/discovery/discovery_test.go b/pkg/server/discovery/discovery_test.go new file mode 100644 index 0000000000..1ca39a0ada --- /dev/null +++ b/pkg/server/discovery/discovery_test.go @@ -0,0 +1,214 @@ +// SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc. +// SPDX-License-Identifier: Apache-2.0 + +package discovery + +import ( + "os" + "path/filepath" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestWriteReadServerInfo_TCP(t *testing.T) { + t.Parallel() + dir := t.TempDir() + + info := &ServerInfo{ + URL: "http://127.0.0.1:52341", + PID: 12345, + Nonce: "test-nonce-tcp", + StartedAt: time.Date(2026, 3, 23, 10, 0, 0, 0, time.UTC), + } + + require.NoError(t, writeServerInfoTo(dir, info)) + + got, err := readServerInfoFrom(dir) + require.NoError(t, err) + assert.Equal(t, info.URL, got.URL) + assert.Equal(t, info.PID, got.PID) + assert.Equal(t, info.Nonce, got.Nonce) + assert.True(t, info.StartedAt.Equal(got.StartedAt)) +} + +func TestWriteReadServerInfo_UnixSocket(t *testing.T) { + t.Parallel() + dir := t.TempDir() + + info := &ServerInfo{ + URL: "unix:///tmp/thv-test.sock", + PID: 54321, + Nonce: "test-nonce-unix", + StartedAt: time.Date(2026, 3, 23, 11, 0, 0, 0, time.UTC), + } + + require.NoError(t, writeServerInfoTo(dir, info)) + + got, err := readServerInfoFrom(dir) + require.NoError(t, err) + assert.Equal(t, info.URL, got.URL) + assert.Equal(t, info.PID, got.PID) + assert.Equal(t, info.Nonce, got.Nonce) +} + +func TestReadServerInfo_NotFound(t *testing.T) { + t.Parallel() + dir := t.TempDir() + + _, err := readServerInfoFrom(dir) + require.ErrorIs(t, err, os.ErrNotExist) +} + +func TestRemoveServerInfo_Exists(t *testing.T) { + t.Parallel() + 
dir := t.TempDir() + + info := &ServerInfo{ + URL: "http://127.0.0.1:8080", + PID: 1, + Nonce: "nonce", + StartedAt: time.Now().UTC(), + } + require.NoError(t, writeServerInfoTo(dir, info)) + + require.NoError(t, removeServerInfoFrom(dir)) + + _, err := readServerInfoFrom(dir) + require.ErrorIs(t, err, os.ErrNotExist) +} + +func TestRemoveServerInfo_NotFound(t *testing.T) { + t.Parallel() + dir := t.TempDir() + + // Should not error when file doesn't exist + require.NoError(t, removeServerInfoFrom(dir)) +} + +func TestWriteServerInfo_FilePermissions(t *testing.T) { + t.Parallel() + dir := t.TempDir() + + info := &ServerInfo{ + URL: "http://127.0.0.1:8080", + PID: 1, + Nonce: "nonce", + StartedAt: time.Now().UTC(), + } + require.NoError(t, writeServerInfoTo(dir, info)) + + fi, err := os.Stat(filepath.Join(dir, "server.json")) + require.NoError(t, err) + assert.Equal(t, os.FileMode(filePermissions), fi.Mode().Perm()) +} + +func TestWriteServerInfo_CreatesDirectoryWithCorrectPermissions(t *testing.T) { + t.Parallel() + parent := t.TempDir() + dir := filepath.Join(parent, "nested", "server") + + info := &ServerInfo{ + URL: "http://127.0.0.1:8080", + PID: 1, + Nonce: "nonce", + StartedAt: time.Now().UTC(), + } + require.NoError(t, writeServerInfoTo(dir, info)) + + fi, err := os.Stat(dir) + require.NoError(t, err) + assert.Equal(t, os.FileMode(dirPermissions), fi.Mode().Perm()) +} + +func TestWriteServerInfo_RejectsSymlink(t *testing.T) { + t.Parallel() + dir := t.TempDir() + + // Create a symlink at the target path + target := filepath.Join(t.TempDir(), "evil.json") + require.NoError(t, os.WriteFile(target, []byte("{}"), 0600)) + require.NoError(t, os.Symlink(target, filepath.Join(dir, "server.json"))) + + info := &ServerInfo{ + URL: "http://127.0.0.1:8080", + PID: 1, + Nonce: "nonce", + StartedAt: time.Now().UTC(), + } + err := writeServerInfoTo(dir, info) + require.Error(t, err) + assert.Contains(t, err.Error(), "symlink") +} + +func 
TestReadServerInfo_RejectsSymlink(t *testing.T) { + t.Parallel() + + // Write a valid server.json in a real directory. + realDir := t.TempDir() + info := &ServerInfo{ + URL: "http://127.0.0.1:8080", + PID: 1, + Nonce: "real-nonce", + StartedAt: time.Now().UTC(), + } + require.NoError(t, writeServerInfoTo(realDir, info)) + + // Create a second directory with a symlink named server.json that + // points to the real file. + symlinkDir := t.TempDir() + realFile := filepath.Join(realDir, "server.json") + symlinkFile := filepath.Join(symlinkDir, "server.json") + require.NoError(t, os.Symlink(realFile, symlinkFile)) + + _, err := readServerInfoFrom(symlinkDir) + require.Error(t, err) + assert.Contains(t, err.Error(), "symlink") +} + +func TestWriteServerInfo_TightensExistingDirPermissions(t *testing.T) { + t.Parallel() + + // Create a directory with deliberately too-loose permissions. + dir := t.TempDir() + require.NoError(t, os.Chmod(dir, 0755)) + + info := &ServerInfo{ + URL: "http://127.0.0.1:8080", + PID: 1, + Nonce: "tighten-nonce", + StartedAt: time.Now().UTC(), + } + require.NoError(t, writeServerInfoTo(dir, info)) + + // Verify the directory permissions were tightened to 0700. 
+ fi, err := os.Stat(dir) + require.NoError(t, err) + assert.Equal(t, os.FileMode(dirPermissions), fi.Mode().Perm()) +} + +func TestWriteServerInfo_OverwritesExistingFile(t *testing.T) { + t.Parallel() + dir := t.TempDir() + + first := &ServerInfo{ + URL: "http://127.0.0.1:8080", + PID: 1, + Nonce: "first", + } + require.NoError(t, writeServerInfoTo(dir, first)) + + second := &ServerInfo{ + URL: "http://127.0.0.1:9090", + PID: 2, + Nonce: "second", + } + require.NoError(t, writeServerInfoTo(dir, second)) + + got, err := readServerInfoFrom(dir) + require.NoError(t, err) + assert.Equal(t, "second", got.Nonce) + assert.Equal(t, "http://127.0.0.1:9090", got.URL) +} diff --git a/pkg/server/discovery/health.go b/pkg/server/discovery/health.go new file mode 100644 index 0000000000..cf80cb718d --- /dev/null +++ b/pkg/server/discovery/health.go @@ -0,0 +1,146 @@ +// SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc. +// SPDX-License-Identifier: Apache-2.0 + +package discovery + +import ( + "context" + "crypto/subtle" + "fmt" + "net" + "net/http" + "net/url" + "path/filepath" + "strings" + "time" +) + +const ( + // healthTimeout is the maximum time to wait for a health check response. + healthTimeout = 5 * time.Second + + // NonceHeader is the HTTP header used to return the server nonce. + NonceHeader = "X-Toolhive-Nonce" +) + +// CheckHealth verifies that a server at the given URL is healthy and optionally +// matches the expected nonce. It supports http:// and unix:// URL schemes. 
+func CheckHealth(ctx context.Context, serverURL string, expectedNonce string) error { + client, requestURL, err := buildHealthClient(serverURL) + if err != nil { + return err + } + + ctx, cancel := context.WithTimeout(ctx, healthTimeout) + defer cancel() + + req, err := http.NewRequestWithContext(ctx, http.MethodGet, requestURL, nil) + if err != nil { + return fmt.Errorf("failed to create health request: %w", err) + } + + resp, err := client.Do(req) + if err != nil { + return fmt.Errorf("health check failed: %w", err) + } + defer func() { _ = resp.Body.Close() }() + + if resp.StatusCode != http.StatusNoContent { + return fmt.Errorf("unexpected health status: %d", resp.StatusCode) + } + + if expectedNonce != "" { + actualNonce := resp.Header.Get(NonceHeader) + if subtle.ConstantTimeCompare([]byte(actualNonce), []byte(expectedNonce)) != 1 { + return fmt.Errorf("nonce mismatch: expected %q, got %q", expectedNonce, actualNonce) + } + } + + return nil +} + +// buildHealthClient returns an HTTP client and request URL appropriate for +// the given server URL scheme. +func buildHealthClient(serverURL string) (*http.Client, string, error) { + client, baseURL, err := HTTPClientForURL(serverURL) + if err != nil { + return nil, "", err + } + healthURL, err := url.JoinPath(baseURL, "health") + if err != nil { + return nil, "", fmt.Errorf("failed to build health URL: %w", err) + } + return client, healthURL, nil +} + +// HTTPClientForURL returns an HTTP client configured for the given server URL +// and the base URL to use for requests. For unix:// URLs it creates a client +// with a Unix socket transport and returns "http://localhost" as the base URL. +// For http:// URLs it validates the host is a loopback address and returns a +// default client. The returned client has no timeout set; callers should apply +// their own timeout via context or client.Timeout. 
+func HTTPClientForURL(serverURL string) (*http.Client, string, error) { + switch { + case strings.HasPrefix(serverURL, "unix://"): + socketPath, err := ParseUnixSocketPath(serverURL) + if err != nil { + return nil, "", err + } + client := &http.Client{ + Transport: &http.Transport{ + DialContext: func(_ context.Context, _, _ string) (net.Conn, error) { + return net.Dial("unix", socketPath) + }, + }, + } + return client, "http://localhost", nil + + case strings.HasPrefix(serverURL, "http://"): + if err := ValidateLoopbackURL(serverURL); err != nil { + return nil, "", err + } + return &http.Client{}, serverURL, nil + + default: + return nil, "", fmt.Errorf("unsupported URL scheme: %s", serverURL) + } +} + +// ValidateLoopbackURL checks that an http:// URL points to a loopback address. +func ValidateLoopbackURL(rawURL string) error { + u, err := url.Parse(rawURL) + if err != nil { + return fmt.Errorf("invalid URL: %w", err) + } + host := u.Hostname() + + ip := net.ParseIP(host) + if ip == nil { + return fmt.Errorf("invalid host in URL: %s", host) + } + if !ip.IsLoopback() { + return fmt.Errorf("refusing health check to non-loopback address: %s", host) + } + return nil +} + +// ParseUnixSocketPath extracts and validates the socket path from a unix:// URL. 
+func ParseUnixSocketPath(rawURL string) (string, error) { + path := strings.TrimPrefix(rawURL, "unix://") + if path == "" { + return "", fmt.Errorf("empty unix socket path") + } + + // Check for traversal before Clean resolves it away + if strings.Contains(path, "..") { + return "", fmt.Errorf("unix socket path must not contain '..': %s", path) + } + + path = filepath.Clean(path) + + if !filepath.IsAbs(path) { + return "", fmt.Errorf("unix socket path must be absolute: %s", path) + } + + return path, nil +} diff --git a/pkg/server/discovery/health_test.go b/pkg/server/discovery/health_test.go new file mode 100644 index 0000000000..3678935999 --- /dev/null +++ b/pkg/server/discovery/health_test.go @@ -0,0 +1,171 @@ +// SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc. +// SPDX-License-Identifier: Apache-2.0 + +package discovery + +import ( + "context" + "net" + "net/http" + "net/http/httptest" + "os" + "path/filepath" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestParseUnixSocketPath_Valid(t *testing.T) { + t.Parallel() + path, err := ParseUnixSocketPath("unix:///var/run/thv.sock") + require.NoError(t, err) + assert.Equal(t, "/var/run/thv.sock", path) +} + +func TestParseUnixSocketPath_RelativePathRejected(t *testing.T) { + t.Parallel() + _, err := ParseUnixSocketPath("unix://relative/path.sock") + require.Error(t, err) + assert.Contains(t, err.Error(), "absolute") +} + +func TestParseUnixSocketPath_DotDotRejected(t *testing.T) { + t.Parallel() + _, err := ParseUnixSocketPath("unix:///var/run/../etc/evil.sock") + require.Error(t, err) + assert.Contains(t, err.Error(), "..") +} + +func TestParseUnixSocketPath_Empty(t *testing.T) { + t.Parallel() + _, err := ParseUnixSocketPath("unix://") + require.Error(t, err) + assert.Contains(t, err.Error(), "empty") +} + +func TestCheckHealth_TCP_Success(t *testing.T) { + t.Parallel() + expectedNonce := "test-nonce-123" + srv := 
httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.Header().Set(NonceHeader, expectedNonce) + w.WriteHeader(http.StatusNoContent) + })) + defer srv.Close() + + err := CheckHealth(context.Background(), srv.URL, expectedNonce) + require.NoError(t, err) +} + +func TestCheckHealth_TCP_NonceMismatch(t *testing.T) { + t.Parallel() + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.Header().Set(NonceHeader, "wrong-nonce") + w.WriteHeader(http.StatusNoContent) + })) + defer srv.Close() + + err := CheckHealth(context.Background(), srv.URL, "expected-nonce") + require.Error(t, err) + assert.Contains(t, err.Error(), "nonce mismatch") +} + +func TestCheckHealth_TCP_NoNonceCheck(t *testing.T) { + t.Parallel() + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusNoContent) + })) + defer srv.Close() + + // Empty expectedNonce skips nonce check + err := CheckHealth(context.Background(), srv.URL, "") + require.NoError(t, err) +} + +func TestCheckHealth_UnixSocket_Success(t *testing.T) { + t.Parallel() + socketDir := t.TempDir() + socketPath := filepath.Join(socketDir, "test.sock") + + listener, err := net.Listen("unix", socketPath) + require.NoError(t, err) + defer listener.Close() + + expectedNonce := "unix-nonce" + mux := http.NewServeMux() + mux.HandleFunc("/health", func(w http.ResponseWriter, _ *http.Request) { + w.Header().Set(NonceHeader, expectedNonce) + w.WriteHeader(http.StatusNoContent) + }) + srv := &http.Server{Handler: mux} + go func() { _ = srv.Serve(listener) }() + defer srv.Close() + + err = CheckHealth(context.Background(), "unix://"+socketPath, expectedNonce) + require.NoError(t, err) +} + +func TestCheckHealth_Unreachable(t *testing.T) { + t.Parallel() + err := CheckHealth(context.Background(), "http://127.0.0.1:1", "") + require.Error(t, err) + assert.Contains(t, err.Error(), "health check failed") +} + +func 
TestCheckHealth_InvalidScheme(t *testing.T) { + t.Parallel() + err := CheckHealth(context.Background(), "ftp://localhost:21", "") + require.Error(t, err) + assert.Contains(t, err.Error(), "unsupported URL scheme") +} + +func TestCheckHealth_NonLoopbackRejected(t *testing.T) { + t.Parallel() + err := CheckHealth(context.Background(), "http://192.168.1.1:8080", "") + require.Error(t, err) + assert.Contains(t, err.Error(), "non-loopback") +} + +func TestCheckHealth_UnhealthyStatus(t *testing.T) { + t.Parallel() + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusServiceUnavailable) + })) + defer srv.Close() + + err := CheckHealth(context.Background(), srv.URL, "") + require.Error(t, err) + assert.Contains(t, err.Error(), "unexpected health status") +} + +func TestValidateLoopbackURL(t *testing.T) { + t.Parallel() + tests := []struct { + name string + url string + wantErr bool + }{ + {"IPv4 loopback", "http://127.0.0.1:8080", false}, + {"IPv6 loopback", "http://[::1]:8080", false}, + {"non-loopback", "http://192.168.1.1:8080", true}, + {"hostname", "http://example.com:8080", true}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + err := ValidateLoopbackURL(tt.url) + if tt.wantErr { + require.Error(t, err) + } else { + require.NoError(t, err) + } + }) + } +} + +func TestCheckHealth_UnixSocket_NotFound(t *testing.T) { + t.Parallel() + socketPath := filepath.Join(os.TempDir(), "nonexistent-test.sock") + err := CheckHealth(context.Background(), "unix://"+socketPath, "") + require.Error(t, err) +} diff --git a/pkg/skills/client/client.go b/pkg/skills/client/client.go index bd4adc5a7d..7878fb4222 100644 --- a/pkg/skills/client/client.go +++ b/pkg/skills/client/client.go @@ -11,6 +11,7 @@ import ( "errors" "fmt" "io" + "log/slog" "net/http" "net/url" "strings" @@ -18,6 +19,7 @@ import ( "github.com/stacklok/toolhive-core/env" "github.com/stacklok/toolhive-core/httperr" + 
"github.com/stacklok/toolhive/pkg/server/discovery" "github.com/stacklok/toolhive/pkg/skills" ) @@ -74,19 +76,58 @@ func NewClient(baseURL string, opts ...Option) *Client { return c } -// NewDefaultClient creates a Skills API client using the TOOLHIVE_API_URL -// environment variable, falling back to http://127.0.0.1:8080. -func NewDefaultClient(opts ...Option) *Client { - return newDefaultClientWithEnv(&env.OSReader{}, opts...) +// NewDefaultClient creates a Skills API client by trying, in order: +// 1. The TOOLHIVE_API_URL environment variable (explicit override) +// 2. The server discovery file (auto-detected running server) +// 3. The default URL http://127.0.0.1:8080 +// +// The context is used for the server discovery health check; it is not stored. +func NewDefaultClient(ctx context.Context, opts ...Option) *Client { + return newDefaultClientWithEnv(ctx, &env.OSReader{}, opts...) } // newDefaultClientWithEnv is the testable core of NewDefaultClient. -func newDefaultClientWithEnv(envReader env.Reader, opts ...Option) *Client { - base := envReader.Getenv(envAPIURL) - if base == "" { - base = defaultBaseURL +func newDefaultClientWithEnv(ctx context.Context, envReader env.Reader, opts ...Option) *Client { + // 1. Explicit env var override always wins. + if base := envReader.Getenv(envAPIURL); base != "" { + return NewClient(base, opts...) } - return NewClient(base, opts...) + + // 2. Try server discovery. + if base, httpOpts := resolveViaDiscovery(ctx); base != "" { + // Discovery opts go first so caller-supplied opts can override them + // (e.g. a caller-provided WithTimeout replaces the discovery default). + merged := make([]Option, 0, len(httpOpts)+len(opts)) + merged = append(merged, httpOpts...) + merged = append(merged, opts...) + return NewClient(base, merged...) + } + + // 3. Fall back to the default URL. + return NewClient(defaultBaseURL, opts...) +} + +// resolveViaDiscovery attempts to find a running server via the discovery file. 
+// It returns the base URL and any additional options (e.g. a Unix socket transport). +// On failure it returns empty values and the caller falls back to the default. +func resolveViaDiscovery(ctx context.Context) (string, []Option) { + result, err := discovery.Discover(ctx) + if err != nil { + slog.Debug("server discovery failed", "error", err) + return "", nil + } + if result.State != discovery.StateRunning { + return "", nil + } + + client, baseURL, err := discovery.HTTPClientForURL(result.Info.URL) + if err != nil { + slog.Debug("invalid URL in discovery file", "url", result.Info.URL, "error", err) + return "", nil + } + client.Timeout = defaultTimeout + + return baseURL, []Option{WithHTTPClient(client)} } // --- SkillService implementation --- diff --git a/pkg/skills/client/client_test.go b/pkg/skills/client/client_test.go index dc0e7af04b..9dc2d92fd3 100644 --- a/pkg/skills/client/client_test.go +++ b/pkg/skills/client/client_test.go @@ -564,7 +564,7 @@ func TestNewDefaultClient(t *testing.T) { mockEnv := envmocks.NewMockReader(ctrl) mockEnv.EXPECT().Getenv(envAPIURL).Return("") - c := newDefaultClientWithEnv(mockEnv) + c := newDefaultClientWithEnv(t.Context(), mockEnv) assert.Equal(t, defaultBaseURL, c.baseURL) }) @@ -574,7 +574,7 @@ func TestNewDefaultClient(t *testing.T) { mockEnv := envmocks.NewMockReader(ctrl) mockEnv.EXPECT().Getenv(envAPIURL).Return("http://localhost:9999") - c := newDefaultClientWithEnv(mockEnv) + c := newDefaultClientWithEnv(t.Context(), mockEnv) assert.Equal(t, "http://localhost:9999", c.baseURL) }) @@ -584,7 +584,7 @@ func TestNewDefaultClient(t *testing.T) { mockEnv := envmocks.NewMockReader(ctrl) mockEnv.EXPECT().Getenv(envAPIURL).Return("") - c := newDefaultClientWithEnv(mockEnv, WithTimeout(5*time.Second)) + c := newDefaultClientWithEnv(t.Context(), mockEnv, WithTimeout(5*time.Second)) assert.Equal(t, 5*time.Second, c.httpClient.Timeout) }) }