diff --git a/pkg/runner/runner.go b/pkg/runner/runner.go index bca389862..54fc3fcd4 100644 --- a/pkg/runner/runner.go +++ b/pkg/runner/runner.go @@ -2,11 +2,13 @@ package runner import ( + "bytes" "context" "fmt" "net/http" "os" "os/signal" + "strings" "syscall" "time" @@ -272,6 +274,33 @@ func (r *Runner) Run(ctx context.Context) error { logger.Infof("MCP server %s started successfully", r.Config.ContainerName) + // Wait for the MCP server to accept initialize requests before updating client configurations. + // This prevents timing issues where clients try to connect before the server is fully ready. + // We repeatedly call initialize until it succeeds (up to 5 minutes). + // Note: We skip this check for pure STDIO transport because STDIO servers may reject + // multiple initialize calls (see #1982). + transportType := labels.GetTransportType(r.Config.ContainerLabels) + serverURL := transport.GenerateMCPServerURL( + transportType, + "localhost", + r.Config.Port, + r.Config.ContainerName, + r.Config.RemoteURL) + + // Only wait for initialization on non-STDIO transports + // STDIO servers communicate directly via stdin/stdout and calling initialize multiple times + // can cause issues as the behavior is not specified by the MCP spec + if transportType != "stdio" { + // Repeatedly try calling initialize until it succeeds (up to 5 minutes) + // Some servers (like mcp-optimizer) can take significant time to start up + if err := waitForInitializeSuccess(ctx, serverURL, transportType, 5*time.Minute); err != nil { + logger.Warnf("Warning: Initialize not successful, but continuing: %v", err) + // Continue anyway to maintain backward compatibility, but log a warning + } + } else { + logger.Debugf("Skipping initialize check for STDIO transport") + } + // Update client configurations with the MCP server URL. // Note that this function checks the configuration to determine which // clients should be updated, if any. 
@@ -279,14 +308,6 @@ func (r *Runner) Run(ctx context.Context) error {
 	if err != nil {
 		logger.Warnf("Warning: Failed to create client manager: %v", err)
 	} else {
-		transportType := labels.GetTransportType(r.Config.ContainerLabels)
-		serverURL := transport.GenerateMCPServerURL(
-			transportType,
-			"localhost",
-			r.Config.Port,
-			r.Config.ContainerName,
-			r.Config.RemoteURL)
-
 		if err := clientManager.AddServerToClients(ctx, r.Config.ContainerName, serverURL, transportType, r.Config.Group); err != nil {
 			logger.Warnf("Warning: Failed to add server to client configurations: %v", err)
 		}
@@ -448,3 +469,137 @@ func (r *Runner) Cleanup(ctx context.Context) error {
 
 	return lastErr
 }
+
+// waitForInitializeSuccess blocks until the MCP server at serverURL answers a readiness
+// probe with 200 OK, the probe request cannot be built, maxWaitTime elapses, or ctx is
+// cancelled. This prevents timing issues where clients try to connect before the server
+// is fully ready.
+//
+// The probe depends on transportType:
+//   - "streamable-http" / "streamable": POST a JSON-RPC initialize request to serverURL.
+//   - "sse": GET serverURL with any "#fragment" removed; only endpoint availability is
+//     checked because initialize would need a full SSE connection.
+//   - anything else: no HTTP check is made and nil is returned immediately.
+//
+// Failed attempts are retried with exponential backoff (100ms, doubling, capped at 2s).
+// Note: This function should not be called for STDIO transport.
+func waitForInitializeSuccess(ctx context.Context, serverURL, transportType string, maxWaitTime time.Duration) error {
+	// Determine the endpoint, method and payload to use based on transport type
+	var endpoint, method, payload string
+
+	switch transportType {
+	case "streamable-http", "streamable":
+		// For streamable-http, send initialize request to /mcp endpoint
+		// Format: http://localhost:port/mcp
+		endpoint = serverURL
+		method = http.MethodPost
+		payload = `{"jsonrpc":"2.0","method":"initialize","id":"toolhive-init-check",` +
+			`"params":{"protocolVersion":"2024-11-05","capabilities":{},` +
+			`"clientInfo":{"name":"toolhive","version":"1.0"}}}`
+	case "sse":
+		// For SSE, just check if the SSE endpoint is available
+		// We can't easily call initialize without establishing a full SSE connection,
+		// so we just verify the endpoint responds.
+		// Format: http://localhost:port/sse#container-name -> http://localhost:port/sse
+		endpoint = serverURL
+		// Remove fragment if present (everything after #)
+		if idx := strings.Index(endpoint, "#"); idx != -1 {
+			endpoint = endpoint[:idx]
+		}
+		method = http.MethodGet
+	default:
+		// For other transports, no HTTP check is needed
+		logger.Debugf("Skipping readiness check for transport type: %s", transportType)
+		return nil
+	}
+
+	// Setup retry logic with exponential backoff
+	startTime := time.Now()
+	delay := 100 * time.Millisecond
+	const maxDelay = 2 * time.Second // Cap at 2 seconds between retries
+
+	logger.Infof("Waiting for MCP server to be ready at %s (timeout: %v)", endpoint, maxWaitTime)
+
+	// Create HTTP client with a reasonable timeout for requests
+	httpClient := &http.Client{
+		Timeout: 10 * time.Second,
+	}
+
+	for attempt := 1; ; attempt++ {
+		req, err := newInitializeCheckRequest(ctx, method, endpoint, payload)
+		if err != nil {
+			// A request that cannot be built (e.g. an unparsable URL) fails the same
+			// way on every retry, so return now instead of waiting out the timeout.
+			return fmt.Errorf("creating readiness request for %s: %w", endpoint, err)
+		}
+
+		status, err := doInitializeCheck(httpClient, req)
+		switch {
+		case err != nil:
+			logger.Debugf("Failed to reach endpoint (attempt %d): %v", attempt, err)
+		case status == http.StatusOK:
+			// For GET (SSE) and POST (streamable-http) alike, 200 OK means ready
+			logger.Infof("MCP server is ready after %v (attempt %d)", time.Since(startTime), attempt)
+			return nil
+		default:
+			logger.Debugf("Server returned status %d (attempt %d)", status, attempt)
+		}
+
+		// Check if we've exceeded the maximum wait time
+		if elapsed := time.Since(startTime); elapsed >= maxWaitTime {
+			return fmt.Errorf("initialize not successful after %v (%d attempts)", elapsed, attempt)
+		}
+
+		// Wait before retrying
+		select {
+		case <-ctx.Done():
+			return fmt.Errorf("context cancelled while waiting for initialize: %w", ctx.Err())
+		case <-time.After(delay):
+			// Continue to next attempt
+		}
+
+		// Update delay for next iteration with exponential backoff
+		delay *= 2
+		if delay > maxDelay {
+			delay = maxDelay
+		}
+	}
+}
+
+// newInitializeCheckRequest builds one readiness probe request bound to ctx and,
+// for POST probes, sets the JSON-RPC content headers.
+func newInitializeCheckRequest(ctx context.Context, method, endpoint, payload string) (*http.Request, error) {
+	var req *http.Request
+	var err error
+	// Payload-less probes (GET) are sent without a request body.
+	if payload != "" {
+		req, err = http.NewRequestWithContext(ctx, method, endpoint, bytes.NewBufferString(payload))
+	} else {
+		req, err = http.NewRequestWithContext(ctx, method, endpoint, nil)
+	}
+	if err != nil {
+		return nil, err
+	}
+
+	if method == http.MethodPost {
+		req.Header.Set("Content-Type", "application/json")
+		req.Header.Set("Accept", "application/json, text/event-stream")
+		req.Header.Set("MCP-Protocol-Version", "2024-11-05")
+	}
+	return req, nil
+}
+
+// doInitializeCheck sends req and returns the HTTP status code of the response.
+// The response body is closed before returning, so a long retry loop never holds
+// more than one response open at a time.
+func doInitializeCheck(client *http.Client, req *http.Request) (int, error) {
+	resp, err := client.Do(req)
+	if err != nil {
+		return 0, err
+	}
+	// Only the status code is needed, so the body is closed unread.
+	//nolint:errcheck // Ignoring close error on response body
+	defer resp.Body.Close()
+
+	return resp.StatusCode, nil
+}
diff --git a/test/e2e/osv_mcp_server_test.go b/test/e2e/osv_mcp_server_test.go
index 9488a52f0..c84b691da 100644
--- a/test/e2e/osv_mcp_server_test.go
+++ b/test/e2e/osv_mcp_server_test.go
@@ -20,7 +20,7 @@ func generateUniqueServerName(prefix string) string {
 	return fmt.Sprintf("%s-%d-%d-%d", prefix, os.Getpid(), time.Now().UnixNano(), GinkgoRandomSeed())
 }
 
-var _ = Describe("OsvMcpServer", Label("mcp", "sse", "e2e"), Serial, func() {
+var _ = Describe("OsvMcpServer", Label("mcp", "streamable-http", "e2e"), Serial, func() {
 	var config *e2e.TestConfig
 
 	BeforeEach(func() {
@@ -31,7 +31,7 @@ var _ = Describe("OsvMcpServer", Label("mcp", "sse", "e2e"), Serial, func() {
 		Expect(err).ToNot(HaveOccurred(), "thv binary should be available")
 	})
 
-	Describe("Running OSV MCP server with SSE transport", func() {
+	Describe("Running OSV MCP server with streamable-http transport", func() {
 		Context("when starting the server from registry", func() {
 			var serverName string
 
@@ -47,11 +47,11 @@ var _ = Describe("OsvMcpServer", Label("mcp", "sse", "e2e"), Serial, func() {
 				}
 			})
 
-			It("should successfully start and be accessible via SSE [Serial]", func() {
-				By("Starting the OSV MCP server with SSE transport and audit enabled")
+			It("should successfully start and be accessible via streamable-http [Serial]", func() {
+				By("Starting the OSV MCP server with streamable-http transport and audit enabled")
 				stdout,
stderr := e2e.NewTHVCommand(config, "run", "--name", serverName, - "--transport", "sse", + "--transport", "streamable-http", "--enable-audit", "osv").ExpectSuccess() @@ -59,38 +59,38 @@ var _ = Describe("OsvMcpServer", Label("mcp", "sse", "e2e"), Serial, func() { Expect(stdout+stderr).To(ContainSubstring("osv"), "Output should mention the OSV server") By("Waiting for the server to be running") - err := e2e.WaitForMCPServer(config, serverName, 60*time.Second) - Expect(err).ToNot(HaveOccurred(), "Server should be running within 60 seconds") + err := e2e.WaitForMCPServer(config, serverName, 5*time.Minute) + Expect(err).ToNot(HaveOccurred(), "Server should be running within 5 minutes") - By("Verifying the server appears in the list with SSE transport") + By("Verifying the server appears in the list with streamable-http transport") stdout, _ = e2e.NewTHVCommand(config, "list").ExpectSuccess() Expect(stdout).To(ContainSubstring(serverName), "Server should appear in the list") Expect(stdout).To(ContainSubstring("running"), "Server should be in running state") - Expect(stdout).To(ContainSubstring("sse"), "Server should show SSE transport") + Expect(stdout).To(ContainSubstring("mcp"), "Server should show mcp endpoint") }) - It("should be accessible via HTTP SSE endpoint [Serial]", func() { + It("should be accessible via HTTP streamable-http endpoint [Serial]", func() { By("Starting the OSV MCP server with audit enabled") e2e.NewTHVCommand(config, "run", "--name", serverName, - "--transport", "sse", + "--transport", "streamable-http", "--enable-audit", "osv").ExpectSuccess() By("Waiting for the server to be running") - err := e2e.WaitForMCPServer(config, serverName, 60*time.Second) + err := e2e.WaitForMCPServer(config, serverName, 5*time.Minute) Expect(err).ToNot(HaveOccurred()) By("Getting the server URL") serverURL, err := e2e.GetMCPServerURL(config, serverName) Expect(err).ToNot(HaveOccurred(), "Should be able to get server URL") 
Expect(serverURL).To(ContainSubstring("http"), "URL should be HTTP-based") - Expect(serverURL).To(ContainSubstring("/sse"), "URL should contain SSE endpoint") + Expect(serverURL).To(ContainSubstring("/mcp"), "URL should contain MCP endpoint") By("Waiting before starting the HTTP request") time.Sleep(10 * time.Second) - By("Making an HTTP request to the SSE endpoint") + By("Making an HTTP request to the streamable-http endpoint") client := &http.Client{Timeout: 10 * time.Second} var resp *http.Response @@ -112,7 +112,7 @@ var _ = Describe("OsvMcpServer", Label("mcp", "sse", "e2e"), Serial, func() { time.Sleep(10 * time.Second) } - Expect(httpErr).ToNot(HaveOccurred(), "Should be able to connect to SSE endpoint") + Expect(httpErr).ToNot(HaveOccurred(), "Should be able to connect to streamable-http endpoint") Expect(resp).ToNot(BeNil(), "Response should not be nil") defer resp.Body.Close() @@ -124,11 +124,11 @@ var _ = Describe("OsvMcpServer", Label("mcp", "sse", "e2e"), Serial, func() { By("Starting the OSV MCP server") e2e.NewTHVCommand(config, "run", "--name", serverName, - "--transport", "sse", + "--transport", "streamable-http", "osv").ExpectSuccess() By("Waiting for the server to be running") - err := e2e.WaitForMCPServer(config, serverName, 60*time.Second) + err := e2e.WaitForMCPServer(config, serverName, 5*time.Minute) Expect(err).ToNot(HaveOccurred()) By("Getting the server URL") @@ -136,11 +136,11 @@ var _ = Describe("OsvMcpServer", Label("mcp", "sse", "e2e"), Serial, func() { Expect(err).ToNot(HaveOccurred()) By("Waiting for MCP server to be ready") - err = e2e.WaitForMCPServerReady(config, serverURL, "sse", 60*time.Second) + err = e2e.WaitForMCPServerReady(config, serverURL, "streamable-http", 5*time.Minute) Expect(err).ToNot(HaveOccurred(), "MCP server should be ready for protocol operations") By("Creating MCP client and initializing connection") - mcpClient, err := e2e.NewMCPClientForSSE(config, serverURL) + mcpClient, err := 
e2e.NewMCPClientForStreamableHTTP(config, serverURL) Expect(err).ToNot(HaveOccurred(), "Should be able to create MCP client") defer mcpClient.Close() @@ -179,23 +179,23 @@ var _ = Describe("OsvMcpServer", Label("mcp", "sse", "e2e"), Serial, func() { // Start ONE server for ALL OSV-specific tests e2e.NewTHVCommand(config, "run", "--name", serverName, - "--transport", "sse", + "--transport", "streamable-http", "osv").ExpectSuccess() - err := e2e.WaitForMCPServer(config, serverName, 60*time.Second) + err := e2e.WaitForMCPServer(config, serverName, 5*time.Minute) Expect(err).ToNot(HaveOccurred()) // Get server URL serverURL, err = e2e.GetMCPServerURL(config, serverName) Expect(err).ToNot(HaveOccurred()) - err = e2e.WaitForMCPServerReady(config, serverURL, "sse", 60*time.Second) + err = e2e.WaitForMCPServerReady(config, serverURL, "streamable-http", 5*time.Minute) Expect(err).ToNot(HaveOccurred()) }) BeforeEach(func() { // Create fresh MCP client for each test var err error - mcpClient, err = e2e.NewMCPClientForSSE(config, serverURL) + mcpClient, err = e2e.NewMCPClientForStreamableHTTP(config, serverURL) Expect(err).ToNot(HaveOccurred()) // Create context that will be cancelled in AfterEach @@ -325,9 +325,9 @@ var _ = Describe("OsvMcpServer", Label("mcp", "sse", "e2e"), Serial, func() { // Start a server for lifecycle tests e2e.NewTHVCommand(config, "run", "--name", serverName, - "--transport", "sse", + "--transport", "streamable-http", "osv").ExpectSuccess() - err := e2e.WaitForMCPServer(config, serverName, 60*time.Second) + err := e2e.WaitForMCPServer(config, serverName, 5*time.Minute) Expect(err).ToNot(HaveOccurred()) }) @@ -339,7 +339,7 @@ var _ = Describe("OsvMcpServer", Label("mcp", "sse", "e2e"), Serial, func() { } }) - It("should stop the SSE server successfully [Serial]", func() { + It("should stop the streamable-http server successfully [Serial]", func() { By("Stopping the server") stdout, _ := e2e.NewTHVCommand(config, "stop", serverName).ExpectSuccess() 
Expect(stdout).To(ContainSubstring(serverName), "Output should mention the server name") @@ -355,16 +355,16 @@ var _ = Describe("OsvMcpServer", Label("mcp", "sse", "e2e"), Serial, func() { ), "Server should be stopped (exited) or removed from list") }) - It("should restart the SSE server successfully [Serial]", func() { + It("should restart the streamable-http server successfully [Serial]", func() { By("Restarting the server") stdout, _ := e2e.NewTHVCommand(config, "restart", serverName).ExpectSuccess() Expect(stdout).To(ContainSubstring(serverName)) By("Waiting for the server to be running again") - err := e2e.WaitForMCPServer(config, serverName, 60*time.Second) + err := e2e.WaitForMCPServer(config, serverName, 5*time.Minute) Expect(err).ToNot(HaveOccurred()) - By("Verifying SSE endpoint is accessible again") + By("Verifying streamable-http endpoint is accessible again") serverURL, err := e2e.GetMCPServerURL(config, serverName) Expect(err).ToNot(HaveOccurred()) @@ -378,7 +378,7 @@ var _ = Describe("OsvMcpServer", Label("mcp", "sse", "e2e"), Serial, func() { }) }) - Describe("Error handling for SSE transport", func() { + Describe("Error handling for streamable-http transport", func() { Context("when providing invalid configuration", func() { var serverName string @@ -405,7 +405,7 @@ var _ = Describe("OsvMcpServer", Label("mcp", "sse", "e2e"), Serial, func() { // Check if the command succeeded or failed if err != nil { - // If it failed, that's expected for SSE-only servers + // If it failed, that's expected for streamable-http-only servers Expect(stderr).To(ContainSubstring("transport"), "Error should mention transport issue") } else { // If it succeeded, OSV supports both transports @@ -425,10 +425,10 @@ var _ = Describe("OsvMcpServer", Label("mcp", "sse", "e2e"), Serial, func() { By("Starting the first OSV MCP server") e2e.NewTHVCommand(config, "run", "--name", serverName, - "--transport", "sse", "osv").ExpectSuccess() + "--transport", "streamable-http", 
"osv").ExpectSuccess() // ensure it's actually up before attempting the duplicate - err := e2e.WaitForMCPServer(config, serverName, 60*time.Second) + err := e2e.WaitForMCPServer(config, serverName, 5*time.Minute) Expect(err).ToNot(HaveOccurred(), "first server should start") By("Attempting to start a second server with the same name") @@ -436,7 +436,7 @@ var _ = Describe("OsvMcpServer", Label("mcp", "sse", "e2e"), Serial, func() { // examine stdout/stderr stdout, stderr, runErr := e2e.NewTHVCommand(config, "run", "--name", serverName, - "--transport", "sse", + "--transport", "streamable-http", "osv").Run() // The second run must fail because the name already exists @@ -469,7 +469,7 @@ var _ = Describe("OsvMcpServer", Label("mcp", "sse", "e2e"), Serial, func() { out, errOut, _ := e2e.NewTHVCommand( config, "run", "--name", serverName, - "--transport", "sse", + "--transport", "streamable-http", "--foreground", "osv", ).RunWithTimeout(5 * time.Minute) @@ -489,7 +489,7 @@ var _ = Describe("OsvMcpServer", Label("mcp", "sse", "e2e"), Serial, func() { // 2) Wait until the server is reported as running. 
By("waiting for foreground server to be running") - err := e2e.WaitForMCPServer(config, serverName, 60*time.Second) + err := e2e.WaitForMCPServer(config, serverName, 5*time.Minute) Expect(err).ToNot(HaveOccurred(), "server should reach running state") // 3) Verify workload is running via workload manager @@ -516,7 +516,7 @@ var _ = Describe("OsvMcpServer", Label("mcp", "sse", "e2e"), Serial, func() { Expect(stdout).To(ContainSubstring("running"), "server should be running") if serverURL, gerr := e2e.GetMCPServerURL(config, serverName); gerr == nil { - rerr := e2e.WaitForMCPServerReady(config, serverURL, "sse", 15*time.Second) + rerr := e2e.WaitForMCPServerReady(config, serverURL, "streamable-http", 5*time.Minute) Expect(rerr).ToNot(HaveOccurred(), "server should be protocol-ready") } diff --git a/test/e2e/proxy_oauth_test.go b/test/e2e/proxy_oauth_test.go index e1fcc06b5..95572299b 100644 --- a/test/e2e/proxy_oauth_test.go +++ b/test/e2e/proxy_oauth_test.go @@ -82,7 +82,7 @@ var _ = Describe("Proxy OAuth Authentication E2E", Label("proxy", "oauth", "e2e" // Wait for OIDC server to be ready Eventually(func() error { return checkServerHealth(fmt.Sprintf("%s/.well-known/openid-configuration", mockOIDCBaseURL)) - }, 30*time.Second, 1*time.Second).Should(Succeed()) + }, 5*time.Minute, 1*time.Second).Should(Succeed()) // Start OSV MCP server that will be our target By("Starting OSV MCP server as target") @@ -92,7 +92,7 @@ var _ = Describe("Proxy OAuth Authentication E2E", Label("proxy", "oauth", "e2e" "osv").ExpectSuccess() // Wait for OSV server to be ready - err = e2e.WaitForMCPServer(config, osvServerName, 60*time.Second) + err = e2e.WaitForMCPServer(config, osvServerName, 5*time.Minute) Expect(err).ToNot(HaveOccurred()) }) @@ -376,7 +376,7 @@ var _ = Describe("Proxy OAuth Authentication E2E", Label("proxy", "oauth", "e2e" proxyURL := fmt.Sprintf("http://localhost:%d/mcp", proxyPort) // Wait for proxy to be ready for MCP connections - err = 
e2e.WaitForMCPServerReady(config, proxyURL, "streamable-http", 60*time.Second) + err = e2e.WaitForMCPServerReady(config, proxyURL, "streamable-http", 5*time.Minute) if err != nil { GinkgoWriter.Printf("MCP connection through proxy failed: %v\n", err) Skip("Skipping MCP test due to proxy not being ready") @@ -443,7 +443,7 @@ var _ = Describe("Proxy OAuth Authentication E2E", Label("proxy", "oauth", "e2e" By("Reconnecting via MCP to trigger token refresh") proxyURL := fmt.Sprintf("http://localhost:%d/mcp", proxyPort) - err = e2e.WaitForMCPServerReady(config, proxyURL, "streamable-http", 10*time.Second) + err = e2e.WaitForMCPServerReady(config, proxyURL, "streamable-http", 5*time.Minute) Expect(err).ToNot(HaveOccurred(), "MCP server not ready after token expiry") mcpClient, err := e2e.NewMCPClientForStreamableHTTP(config, proxyURL)