Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
101 changes: 100 additions & 1 deletion pkg/api/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@ package api

import (
"context"
"crypto/rand"
"encoding/hex"
"errors"
"fmt"
"io"
Expand All @@ -39,9 +41,11 @@ import (
"github.com/stacklok/toolhive/pkg/config"
"github.com/stacklok/toolhive/pkg/container"
"github.com/stacklok/toolhive/pkg/container/runtime"
"github.com/stacklok/toolhive/pkg/fileutils"
"github.com/stacklok/toolhive/pkg/groups"
"github.com/stacklok/toolhive/pkg/recovery"
"github.com/stacklok/toolhive/pkg/registry"
"github.com/stacklok/toolhive/pkg/server/discovery"
"github.com/stacklok/toolhive/pkg/skills"
"github.com/stacklok/toolhive/pkg/skills/skillsvc"
"github.com/stacklok/toolhive/pkg/storage/sqlite"
Expand All @@ -54,6 +58,7 @@ const (
middlewareTimeout = 60 * time.Second
readHeaderTimeout = 10 * time.Second
shutdownTimeout = 30 * time.Second
nonceBytes = 16
socketPermissions = 0660 // Socket file permissions (owner/group read-write)
maxRequestBodySize = 1 << 20 // 1MB - Maximum request body size
)
Expand All @@ -64,6 +69,7 @@ type ServerBuilder struct {
isUnixSocket bool
debugMode bool
enableDocs bool
nonce string
oidcConfig *auth.TokenValidatorConfig
middlewares []func(http.Handler) http.Handler
customRoutes map[string]http.Handler
Expand Down Expand Up @@ -107,6 +113,14 @@ func (b *ServerBuilder) WithDocs(enableDocs bool) *ServerBuilder {
return b
}

// WithNonce records the per-instance nonce on the builder and returns the
// same builder so that calls can be chained.
//
// The nonce identifies this server instance for discovery verification.
// A non-empty value makes the server write a discovery file on startup and
// is the value the health check is expected to return in the
// X-Toolhive-Nonce header; an empty value leaves discovery switched off.
func (b *ServerBuilder) WithNonce(nonce string) *ServerBuilder {
	b.nonce = nonce
	return b
}

// WithOIDCConfig sets the OIDC configuration
func (b *ServerBuilder) WithOIDCConfig(oidcConfig *auth.TokenValidatorConfig) *ServerBuilder {
b.oidcConfig = oidcConfig
Expand Down Expand Up @@ -293,7 +307,7 @@ func (b *ServerBuilder) setupDefaultRoutes(r *chi.Mux) {

// All other routes get standard timeout
standardRouters := map[string]http.Handler{
"/health": v1.HealthcheckRouter(b.containerRuntime),
"/health": v1.HealthcheckRouter(b.containerRuntime, b.nonce),
"/api/v1beta/version": v1.VersionRouter(),
"/api/v1beta/registry": v1.RegistryRouter(true),
"/api/v1beta/discovery": v1.DiscoveryRouter(),
Expand Down Expand Up @@ -500,6 +514,7 @@ type Server struct {
address string
isUnixSocket bool
addrType string
nonce string
storeCloser io.Closer
}

Expand Down Expand Up @@ -528,14 +543,29 @@ func NewServer(ctx context.Context, builder *ServerBuilder) (*Server, error) {
address: builder.address,
isUnixSocket: builder.isUnixSocket,
addrType: addrType,
nonce: builder.nonce,
storeCloser: builder.skillStoreCloser,
}, nil
}

// ListenURL reports the URL on which this server accepts connections.
//
// For a Unix socket the result is "unix://" followed by the configured
// socket path. Otherwise it is "http://" followed by the address the
// listener is actually bound to, which differs from the configured address
// when the server was asked to bind to port 0.
func (s *Server) ListenURL() string {
	scheme, endpoint := "unix", s.address
	if !s.isUnixSocket {
		// Only consult the listener for TCP; the Unix-socket path never
		// touches it.
		scheme, endpoint = "http", s.listener.Addr().String()
	}
	return fmt.Sprintf("%s://%s", scheme, endpoint)
}

// Start starts the server and blocks until the context is cancelled
func (s *Server) Start(ctx context.Context) error {
slog.Info("starting server", "type", s.addrType, "address", s.address)

// Write server discovery file so clients can find this instance.
if err := s.writeDiscoveryFile(ctx); err != nil {
return err
}

// Start server in a goroutine
serverErr := make(chan error, 1)
go func() {
Expand All @@ -558,6 +588,54 @@ func (s *Server) Start(ctx context.Context) error {
}
}

// writeDiscoveryFile publishes this instance in the server discovery file so
// that clients can locate it. A server without a nonce skips publication and
// the call returns nil.
//
// The lookup of any existing entry and the write of the new one both run
// inside one fileutils.WithFileLock callback on discovery.FilePath(), so two
// servers starting at the same moment cannot interleave their
// check-then-write sequences. The lookup result decides what happens next:
//   - running: startup is refused with an error naming the other server
//   - stale: the old file is cleaned up (best effort) and then replaced
//   - unhealthy: the file is overwritten so this server becomes discoverable
//   - not found, or the lookup itself failed: the file is simply written
func (s *Server) writeDiscoveryFile(ctx context.Context) error {
	if s.nonce == "" {
		return nil
	}

	publish := func() error {
		// Guard against clobbering the entry of another live server.
		existing, discoverErr := discovery.Discover(ctx)
		if discoverErr == nil {
			switch existing.State {
			case discovery.StateRunning:
				return fmt.Errorf("another ToolHive server is already running at %s (PID %d)", existing.Info.URL, existing.Info.PID)
			case discovery.StateStale:
				slog.Debug("cleaning up stale discovery file", "pid", existing.Info.PID)
				// A failed cleanup is only logged; the write below replaces
				// the file anyway.
				if cleanupErr := discovery.CleanupStale(); cleanupErr != nil {
					slog.Warn("failed to clean up stale discovery file", "error", cleanupErr)
				}
			case discovery.StateUnhealthy:
				// The recorded process is alive but does not answer health
				// checks, e.g. a hung process left over from a
				// crash-restart. Overwriting is intentional: the new server
				// has to become the discoverable one.
				slog.Warn("existing server is unhealthy, overwriting discovery file", "pid", existing.Info.PID)
			case discovery.StateNotFound:
				// Nothing registered yet; continue to the write.
			}
		} else {
			// An unreadable discovery state must not block startup.
			slog.Debug("discovery check failed, proceeding with startup", "error", discoverErr)
		}

		self := &discovery.ServerInfo{
			URL:       s.ListenURL(),
			PID:       os.Getpid(),
			Nonce:     s.nonce,
			StartedAt: time.Now().UTC(),
		}
		if writeErr := discovery.WriteServerInfo(self); writeErr != nil {
			return fmt.Errorf("failed to write discovery file: %w", writeErr)
		}
		slog.Debug("wrote discovery file", "url", self.URL, "pid", self.PID)
		return nil
	}

	return fileutils.WithFileLock(discovery.FilePath(), publish)
}

// shutdown gracefully shuts down the server
func (s *Server) shutdown() error {
shutdownCtx, cancel := context.WithTimeout(context.Background(), shutdownTimeout)
Expand All @@ -575,6 +653,11 @@ func (s *Server) shutdown() error {

// cleanup performs cleanup operations
func (s *Server) cleanup() {
if s.nonce != "" {
if err := discovery.RemoveServerInfo(); err != nil {
slog.Warn("failed to remove discovery file", "error", err)
}
}
if s.storeCloser != nil {
if err := s.storeCloser.Close(); err != nil {
slog.Warn("failed to close skill store", "error", err)
Expand Down Expand Up @@ -649,6 +732,16 @@ func (a *clientPathAdapter) ListSkillSupportingClients() []string {
return result
}

// generateNonce creates a cryptographically random nonce for server instance
// identification. It returns a 32-character hex string (16 random bytes).
func generateNonce() (string, error) {
b := make([]byte, nonceBytes)
if _, err := rand.Read(b); err != nil {
return "", fmt.Errorf("failed to generate server nonce: %w", err)
}
return hex.EncodeToString(b), nil
}

// Serve starts the server on the given address and serves the API.
// It is assumed that the caller sets up appropriate signal handling.
// If isUnixSocket is true, address is treated as a UNIX socket path.
Expand All @@ -662,11 +755,17 @@ func Serve(
oidcConfig *auth.TokenValidatorConfig,
middlewares ...func(http.Handler) http.Handler,
) error {
nonce, err := generateNonce()
if err != nil {
return err
}

builder := NewServerBuilder().
WithAddress(address).
WithUnixSocket(isUnixSocket).
WithDebugMode(debugMode).
WithDocs(enableDocs).
WithNonce(nonce).
WithOIDCConfig(oidcConfig).
WithMiddleware(middlewares...)

Expand Down
88 changes: 88 additions & 0 deletions pkg/api/server_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
// SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc.
// SPDX-License-Identifier: Apache-2.0

package api

import (
"fmt"
"net"
"regexp"
"testing"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

// TestGenerateNonce verifies the shape of a generated nonce and that two
// successive calls do not produce the same value.
func TestGenerateNonce(t *testing.T) {
	t.Parallel()

	// A well-formed nonce is exactly 32 lowercase hexadecimal characters.
	wellFormed := regexp.MustCompile(`^[0-9a-f]{32}$`)

	t.Run("returns valid 32-char hex string", func(t *testing.T) {
		t.Parallel()

		got, err := generateNonce()
		require.NoError(t, err)

		assert.Len(t, got, 32)
		assert.Regexp(t, wellFormed, got)
	})

	t.Run("returns unique values on successive calls", func(t *testing.T) {
		t.Parallel()

		first, firstErr := generateNonce()
		require.NoError(t, firstErr)

		second, secondErr := generateNonce()
		require.NoError(t, secondErr)

		assert.NotEqual(t, first, second)
	})
}

// TestListenURL checks the URL reported by Server.ListenURL for a TCP
// listener bound to an ephemeral port and for a Unix socket path.
func TestListenURL(t *testing.T) {
	t.Parallel()

	// listenURLCase pairs a server factory with the URL expected from it.
	// Both are functions because the TCP expectation depends on the port
	// the kernel hands out at listen time.
	type listenURLCase struct {
		name  string
		build func(t *testing.T) *Server
		want  func(s *Server) string
	}

	cases := []listenURLCase{
		{
			name: "TCP returns http URL with actual port",
			build: func(t *testing.T) *Server {
				t.Helper()
				ln, err := net.Listen("tcp", "127.0.0.1:0")
				require.NoError(t, err)
				// The close error is irrelevant to the test outcome.
				t.Cleanup(func() { _ = ln.Close() })
				return &Server{
					listener:     ln,
					isUnixSocket: false,
					address:      "127.0.0.1:0",
				}
			},
			want: func(s *Server) string {
				return fmt.Sprintf("http://%s", s.listener.Addr().String())
			},
		},
		{
			name: "Unix socket returns unix URL",
			build: func(_ *testing.T) *Server {
				return &Server{
					isUnixSocket: true,
					address:      "/tmp/test.sock",
				}
			},
			want: func(_ *Server) string {
				return "unix:///tmp/test.sock"
			},
		},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			t.Parallel()
			srv := tc.build(t)
			assert.Equal(t, tc.want(srv), srv.ListenURL())
		})
	}
}
85 changes: 0 additions & 85 deletions pkg/api/v1/healtcheck_test.go

This file was deleted.

Loading
Loading