From ef66d5ae4ea184e15e386315df142431e636fc4e Mon Sep 17 00:00:00 2001 From: Dmytro Rashko Date: Wed, 24 Jun 2026 15:19:15 +0200 Subject: [PATCH] Migrate to GO SDK Signed-off-by: Dmytro Rashko --- Makefile | 4 +- cmd/main.go | 134 +- cmd/metrics_wrap_test.go | 127 -- cmd/testdata/tool_names_v0.2.1.txt | 132 ++ cmd/tools_regression_test.go | 81 ++ go.mod | 10 +- go.sum | 22 +- internal/cache/cache_test.go | 27 + internal/commands/builder_setters_test.go | 51 + internal/errors/tool_errors.go | 2 +- internal/errors/tool_errors_branches_test.go | 93 ++ internal/logger/logger_test.go | 15 + internal/mcp/mcp.go | 128 ++ internal/mcp/mcp_test.go | 131 ++ internal/telemetry/config_test.go | 12 + internal/telemetry/middleware.go | 78 - internal/telemetry/middleware_test.go | 677 +-------- pkg/argo/argo.go | 283 ++-- pkg/argo/argo_test.go | 165 +-- pkg/cilium/cilium.go | 1363 +++++++++--------- pkg/cilium/cilium_test.go | 457 +++--- pkg/helm/helm.go | 315 ++-- pkg/helm/helm_test.go | 189 +-- pkg/istio/istio.go | 416 +++--- pkg/istio/istio_test.go | 194 +-- pkg/k8s/k8s.go | 1016 +++++++------ pkg/k8s/k8s_test.go | 563 +++----- pkg/kubescape/kubescape.go | 438 +++--- pkg/kubescape/kubescape_test.go | 234 ++- pkg/prometheus/prometheus.go | 204 +-- pkg/prometheus/prometheus_test.go | 251 ++-- pkg/prometheus/promql.go | 20 +- pkg/utils/common.go | 44 +- pkg/utils/common_test.go | 21 +- pkg/utils/datetime_test.go | 18 +- test/e2e/helpers_test.go | 118 +- 36 files changed, 3773 insertions(+), 4260 deletions(-) delete mode 100644 cmd/metrics_wrap_test.go create mode 100644 cmd/testdata/tool_names_v0.2.1.txt create mode 100644 cmd/tools_regression_test.go create mode 100644 internal/commands/builder_setters_test.go create mode 100644 internal/errors/tool_errors_branches_test.go create mode 100644 internal/mcp/mcp.go create mode 100644 internal/mcp/mcp_test.go diff --git a/Makefile b/Makefile index 2c7dd490..d7305e14 100644 --- a/Makefile +++ b/Makefile @@ -58,11 +58,11 @@ tidy: ## Run go mod tidy to ensure dependencies are up to date. .PHONY: test test: build lint ## Run all tests with build, lint, and coverage - go test -tags=test -v -cover ./pkg/... ./internal/... + go test -tags=test -v -cover ./pkg/... ./internal/... ./cmd/... .PHONY: test-only test-only: ## Run tests only (without build/lint for faster iteration) - go test -tags=test -v -cover ./pkg/... ./internal/... + go test -tags=test -v -cover ./pkg/... ./internal/... ./cmd/... .PHONY: e2e e2e: test retag diff --git a/cmd/main.go b/cmd/main.go index 943b7db2..85e21e74 100644 --- a/cmd/main.go +++ b/cmd/main.go @@ -16,6 +16,7 @@ import ( "github.com/joho/godotenv" "github.com/kagent-dev/tools/internal/logger" + mcpserver "github.com/kagent-dev/tools/internal/mcp" "github.com/kagent-dev/tools/internal/metrics" "github.com/kagent-dev/tools/internal/telemetry" "github.com/kagent-dev/tools/internal/version" @@ -33,8 +34,7 @@ import ( "go.opentelemetry.io/otel/attribute" "go.opentelemetry.io/otel/codes" - "github.com/mark3labs/mcp-go/mcp" - "github.com/mark3labs/mcp-go/server" + sdkmcp "github.com/modelcontextprotocol/go-sdk/mcp" ) var ( @@ -140,16 +140,14 @@ func run(cmd *cobra.Command, args []string) { logger.Get().Info("Running in read-only mode - write operations are disabled") } - mcp := server.NewMCPServer( - Name, - Version, - ) + mcpSrv := sdkmcp.NewServer(&sdkmcp.Implementation{Name: Name, Version: Version}, nil) + + // Attach a single receiving middleware that instruments every tools/call + // with an OTel span and Prometheus invocation counters. Per-tool provider + // labels are recorded as each provider registers its tools. + mcpSrv.AddReceivingMiddleware(mcpserver.ToolMiddleware()) - // Register tools and wrap handlers with metrics instrumentation. - // registerMCP returns a map of tool_name -> tool_provider so that - // wrapToolHandlersWithMetrics knows which provider each tool belongs to. - toolProviders := registerMCP(mcp, tools, *kubeconfig, readOnly) - wrapToolHandlersWithMetrics(mcp, toolProviders) + registerMCP(mcpSrv, tools, *kubeconfig, readOnly) // Create wait group for server goroutines var wg sync.WaitGroup @@ -167,11 +165,12 @@ func run(cmd *cobra.Command, args []string) { if stdio { go func() { defer wg.Done() - runStdioServer(ctx, mcp) + runStdioServer(ctx, mcpSrv) }() } else { - sseServer := server.NewStreamableHTTPServer(mcp, - server.WithHeartbeatInterval(30*time.Second), + sseServer := sdkmcp.NewStreamableHTTPHandler( + func(*http.Request) *sdkmcp.Server { return mcpSrv }, + nil, ) // Create a mux to handle different routes @@ -293,29 +292,27 @@ func writeResponse(w http.ResponseWriter, data []byte) error { return err } -func runStdioServer(ctx context.Context, mcp *server.MCPServer) { +func runStdioServer(ctx context.Context, mcpSrv *sdkmcp.Server) { logger.Get().Info("Running KAgent Tools Server STDIO:", "tools", strings.Join(tools, ",")) - stdioServer := server.NewStdioServer(mcp) - if err := stdioServer.Listen(ctx, os.Stdin, os.Stdout); err != nil { + if err := mcpSrv.Run(ctx, &sdkmcp.StdioTransport{}); err != nil { logger.Get().Info("Stdio server stopped", "error", err) } } -// registerMCP registers tool providers with the MCP server and returns a mapping -// of tool_name -> tool_provider. This mapping is built using the ListTools() diff -// technique: we snapshot the tool list before and after each provider registers, -// so we know exactly which tools belong to which provider. -func registerMCP(mcp *server.MCPServer, enabledToolProviders []string, kubeconfig string, readOnly bool) map[string]string { - // A map to hold tool providers and their registration functions - toolProviderMap := map[string]func(*server.MCPServer){ - "argo": func(s *server.MCPServer) { argo.RegisterTools(s, readOnly) }, - "cilium": func(s *server.MCPServer) { cilium.RegisterTools(s, readOnly) }, - "helm": func(s *server.MCPServer) { helm.RegisterTools(s, readOnly) }, - "istio": func(s *server.MCPServer) { istio.RegisterTools(s, readOnly) }, - "k8s": func(s *server.MCPServer) { k8s.RegisterTools(s, nil, kubeconfig, readOnly) }, - "kubescape": func(s *server.MCPServer) { kubescape.RegisterTools(s, kubeconfig, readOnly) }, - "prometheus": func(s *server.MCPServer) { prometheus.RegisterTools(s, readOnly) }, - "utils": func(s *server.MCPServer) { utils.RegisterTools(s, readOnly) }, +// registerMCP registers the enabled tool providers with the MCP server. Each +// provider's RegisterTools call records tool->provider mappings and the tool +// inventory metric centrally (see internal/mcp.AddTool); invocation metrics and +// tracing are applied by the receiving middleware installed in run(). +func registerMCP(mcpSrv *sdkmcp.Server, enabledToolProviders []string, kubeconfig string, readOnly bool) { + toolProviderMap := map[string]func(*sdkmcp.Server){ + "argo": func(s *sdkmcp.Server) { argo.RegisterTools(s, readOnly) }, + "cilium": func(s *sdkmcp.Server) { cilium.RegisterTools(s, readOnly) }, + "helm": func(s *sdkmcp.Server) { helm.RegisterTools(s, readOnly) }, + "istio": func(s *sdkmcp.Server) { istio.RegisterTools(s, readOnly) }, + "k8s": func(s *sdkmcp.Server) { k8s.RegisterTools(s, nil, kubeconfig, readOnly) }, + "kubescape": func(s *sdkmcp.Server) { kubescape.RegisterTools(s, kubeconfig, readOnly) }, + "prometheus": func(s *sdkmcp.Server) { prometheus.RegisterTools(s, readOnly) }, + "utils": func(s *sdkmcp.Server) { utils.RegisterTools(s, readOnly) }, } // If no specific tools are specified, register all available tools. @@ -325,82 +322,11 @@ func registerMCP(mcp *server.MCPServer, enabledToolProviders []string, kubeconfi } } - // toolToProvider maps each tool name to its provider (e.g., "kubectl_get" -> "k8s"). - // This is used later by wrapToolHandlersWithMetrics to set the correct tool_provider label. - toolToProvider := make(map[string]string) - for _, toolProviderName := range enabledToolProviders { if registerFunc, ok := toolProviderMap[toolProviderName]; ok { - // Snapshot the tool list before this provider registers its tools. - // We need this because ListTools() returns ALL tools from ALL providers, - // so the only way to know which tools belong to THIS provider is to compare - // the list before and after registration. - toolsBefore := mcp.ListTools() - - registerFunc(mcp) - - // Determine which tools were just registered by this provider - // by finding tools that exist now but didn't exist before. - // Record each one in Prometheus so we can observe the full tool inventory. - for toolName := range mcp.ListTools() { - if _, existed := toolsBefore[toolName]; !existed { - metrics.KagentToolsMCPRegisteredTools.WithLabelValues(toolName, toolProviderName).Set(1) - toolToProvider[toolName] = toolProviderName - } - } + registerFunc(mcpSrv) } else { logger.Get().Error("Unknown tool specified", "provider", toolProviderName) } } - - return toolToProvider -} - -// wrapToolHandlersWithMetrics applies the wrapper/middleware pattern to instrument -// all registered MCP tool handlers with Prometheus invocation counters. -// -// How it works: -// 1. Grab all registered tools from the MCP server using ListTools() -// 2. For each tool, wrap its handler with a function that increments metrics -// 3. Replace all tools in the MCP server using SetTools() -// -// The wrapper function: -// - Increments kagent_tools_mcp_invocations_total on every call -// - Increments kagent_tools_mcp_invocations_failure_total when the handler returns a -// non-nil Go error OR when result.IsError is true (the MCP convention for tool-level -// failures - handlers return NewToolResultError(...), nil, not a Go error) -// - Calls the original handler unchanged - the tool's behaviour is not affected -// -// This uses the standard middleware/decorator pattern: the original handler and the -// wrapped handler have the same function signature, so they are interchangeable. -// No changes are required in any pkg/ file - all instrumentation happens centrally here. -func wrapToolHandlersWithMetrics(mcpServer *server.MCPServer, toolToProvider map[string]string) { - allTools := mcpServer.ListTools() - wrapped := make([]server.ServerTool, 0, len(allTools)) - - for name, st := range allTools { - originalHandler := st.Handler - toolName := name // capture for closure - provider := toolToProvider[toolName] - - wrapped = append(wrapped, server.ServerTool{ - Tool: st.Tool, - Handler: func(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { - metrics.KagentToolsMCPInvocationsTotal.WithLabelValues(toolName, provider).Inc() - - result, err := originalHandler(ctx, req) - - // Count as failure if the Go error is non-nil OR if the tool returned - // a result with IsError=true (the MCP convention for tool-level failures, - // which always return nil for the Go error). - if err != nil || (result != nil && result.IsError) { - metrics.KagentToolsMCPInvocationsFailureTotal.WithLabelValues(toolName, provider).Inc() - } - - return result, err - }, - }) - } - - mcpServer.SetTools(wrapped...) } diff --git a/cmd/metrics_wrap_test.go b/cmd/metrics_wrap_test.go deleted file mode 100644 index 0b8ca730..00000000 --- a/cmd/metrics_wrap_test.go +++ /dev/null @@ -1,127 +0,0 @@ -package main - -import ( - "context" - "fmt" - "testing" - - "github.com/kagent-dev/tools/internal/metrics" - "github.com/mark3labs/mcp-go/mcp" - "github.com/mark3labs/mcp-go/server" - promtest "github.com/prometheus/client_golang/prometheus/testutil" -) - -// newTestServer creates a fresh MCP server and resets the metric counters so -// tests do not interfere with each other. -func newTestServer() *server.MCPServer { - metrics.KagentToolsMCPInvocationsTotal.Reset() - metrics.KagentToolsMCPInvocationsFailureTotal.Reset() - return server.NewMCPServer("test-server", "test") -} - -// invokeWrapped registers handler on s, wraps all handlers with metrics, then -// calls the wrapped handler for toolName and returns its result. -func invokeWrapped(t *testing.T, s *server.MCPServer, toolName string, provider string, handler server.ToolHandlerFunc) (*mcp.CallToolResult, error) { - t.Helper() - s.AddTool(mcp.Tool{Name: toolName}, handler) - wrapToolHandlersWithMetrics(s, map[string]string{toolName: provider}) - st, ok := s.ListTools()[toolName] - if !ok { - t.Fatalf("tool %q not found after wrapping", toolName) - } - return st.Handler(context.Background(), mcp.CallToolRequest{}) -} - -// TestWrapToolHandlersWithMetrics_IsErrorIncrementsFailureCounter is the -// critical regression test for the bug identified in PR review: -// -// Handlers signal tool-level failures via NewToolResultError(...), nil -// (result.IsError=true, Go error=nil), so checking only `err != nil` would -// never count these as failures. -// -// To replicate manually: -// -// go test -v -run TestWrapToolHandlersWithMetrics_IsErrorIncrementsFailureCounter ./cmd/ -func TestWrapToolHandlersWithMetrics_IsErrorIncrementsFailureCounter(t *testing.T) { - s := newTestServer() - - result, err := invokeWrapped(t, s, "failing_tool", "test", - func(_ context.Context, _ mcp.CallToolRequest) (*mcp.CallToolResult, error) { - // This is the pattern used 214 times across pkg/ - returns a tool-level - // error with IsError=true but a nil Go error. - return mcp.NewToolResultError("kubectl: resource not found"), nil - }, - ) - - if err != nil { - t.Fatalf("expected nil Go error from handler, got: %v", err) - } - if !result.IsError { - t.Fatal("expected result.IsError=true") - } - - total := promtest.ToFloat64(metrics.KagentToolsMCPInvocationsTotal.WithLabelValues("failing_tool", "test")) - if total != 1 { - t.Errorf("invocations_total: expected 1, got %v", total) - } - - failures := promtest.ToFloat64(metrics.KagentToolsMCPInvocationsFailureTotal.WithLabelValues("failing_tool", "test")) - if failures != 1 { - t.Errorf("invocations_failure_total: expected 1, got %v (IsError=true was not counted as failure)", failures) - } -} - -// TestWrapToolHandlersWithMetrics_SuccessDoesNotIncrementFailureCounter verifies -// that a successful tool call does not touch the failure counter. -// -// To replicate manually: -// -// go test -v -run TestWrapToolHandlersWithMetrics_SuccessDoesNotIncrementFailureCounter ./cmd/ -func TestWrapToolHandlersWithMetrics_SuccessDoesNotIncrementFailureCounter(t *testing.T) { - s := newTestServer() - - _, err := invokeWrapped(t, s, "success_tool", "test", - func(_ context.Context, _ mcp.CallToolRequest) (*mcp.CallToolResult, error) { - return mcp.NewToolResultText("all good"), nil - }, - ) - - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - - total := promtest.ToFloat64(metrics.KagentToolsMCPInvocationsTotal.WithLabelValues("success_tool", "test")) - if total != 1 { - t.Errorf("invocations_total: expected 1, got %v", total) - } - - failures := promtest.ToFloat64(metrics.KagentToolsMCPInvocationsFailureTotal.WithLabelValues("success_tool", "test")) - if failures != 0 { - t.Errorf("invocations_failure_total: expected 0 for a successful call, got %v", failures) - } -} - -// TestWrapToolHandlersWithMetrics_GoErrorIncrementsFailureCounter verifies -// that a real Go error (e.g. infrastructure failure) is also counted. -// -// To replicate manually: -// -// go test -v -run TestWrapToolHandlersWithMetrics_GoErrorIncrementsFailureCounter ./cmd/ -func TestWrapToolHandlersWithMetrics_GoErrorIncrementsFailureCounter(t *testing.T) { - s := newTestServer() - - _, err := invokeWrapped(t, s, "broken_tool", "test", - func(_ context.Context, _ mcp.CallToolRequest) (*mcp.CallToolResult, error) { - return nil, fmt.Errorf("connection refused") - }, - ) - - if err == nil { - t.Fatal("expected a Go error, got nil") - } - - failures := promtest.ToFloat64(metrics.KagentToolsMCPInvocationsFailureTotal.WithLabelValues("broken_tool", "test")) - if failures != 1 { - t.Errorf("invocations_failure_total: expected 1 for Go error, got %v", failures) - } -} diff --git a/cmd/testdata/tool_names_v0.2.1.txt b/cmd/testdata/tool_names_v0.2.1.txt new file mode 100644 index 00000000..54a5c769 --- /dev/null +++ b/cmd/testdata/tool_names_v0.2.1.txt @@ -0,0 +1,132 @@ +# Tool names registered by the v0.2.1 release (pre go-sdk migration). +# Source of truth: `git grep 'mcp.NewTool("...")' v0.2.1 -- pkg/`. +# TestNoToolNameRegressions asserts every name below still exists in the +# current build so the SDK migration never silently renames/drops a tool. +# Add new tools freely; never remove a line without a deliberate, documented +# breaking change. +# Note: kubescape_get_sbom / kubescape_list_sboms were commented out (not +# registered) in v0.2.1, so they are intentionally absent here. +argo_check_plugin_logs +argo_pause_rollout +argo_promote_rollout +argo_rollouts_list +argo_set_rollout_image +argo_verify_argo_rollouts_controller_install +argo_verify_gateway_plugin +argo_verify_kubectl_plugin_install +cilium_connect_to_remote_cluster +cilium_delete_key_from_kv_store +cilium_delete_pcap_recorder +cilium_delete_policy_rules +cilium_delete_service +cilium_delete_xdp_cidr_filters +cilium_disconnect_endpoint +cilium_disconnect_remote_cluster +cilium_display_encryption_state +cilium_display_policy_node_information +cilium_display_selectors +cilium_flush_ipsec_state +cilium_fqdn_cache +cilium_get_bpf_map +cilium_get_daemon_status +cilium_get_endpoint_details +cilium_get_endpoint_health +cilium_get_endpoint_logs +cilium_get_endpoints_list +cilium_get_identity_details +cilium_get_kv_store_key +cilium_get_pcap_recorder +cilium_get_service_information +cilium_install_cilium +cilium_list_bgp_peers +cilium_list_bgp_routes +cilium_list_bpf_map_events +cilium_list_bpf_maps +cilium_list_cluster_nodes +cilium_list_envoy_config +cilium_list_identities +cilium_list_ip_addresses +cilium_list_local_redirect_policies +cilium_list_metrics +cilium_list_node_ids +cilium_list_pcap_recorders +cilium_list_services +cilium_list_xdp_cidr_filters +cilium_manage_endpoint_config +cilium_manage_endpoint_labels +cilium_request_debugging_information +cilium_set_kv_store_key +cilium_show_cluster_mesh_status +cilium_show_configuration_options +cilium_show_dns_names +cilium_show_features_status +cilium_show_ip_cache_information +cilium_show_load_information +cilium_status_and_version +cilium_toggle_cluster_mesh +cilium_toggle_configuration_option +cilium_toggle_hubble +cilium_uninstall_cilium +cilium_update_pcap_recorder +cilium_update_service +cilium_update_xdp_cidr_filters +cilium_upgrade_cilium +cilium_validate_cilium_network_policies +datetime_get_current_time +helm_get_release +helm_list_releases +helm_repo_add +helm_repo_update +helm_uninstall +helm_upgrade +istio_analyze_cluster_configuration +istio_apply_waypoint +istio_delete_waypoint +istio_generate_manifest +istio_generate_waypoint +istio_install_istio +istio_list_waypoints +istio_proxy_config +istio_proxy_status +istio_remote_clusters +istio_version +istio_waypoint_status +istio_ztunnel_config +k8s_annotate_resource +k8s_apply_manifest +k8s_check_service_connectivity +k8s_create_resource +k8s_create_resource_from_url +k8s_delete_resource +k8s_describe_resource +k8s_execute_command +k8s_generate_resource +k8s_get_available_api_resources +k8s_get_cluster_configuration +k8s_get_events +k8s_get_pod_logs +k8s_get_resource_yaml +k8s_get_resources +k8s_label_resource +k8s_patch_resource +k8s_patch_status +k8s_remove_annotation +k8s_remove_label +k8s_rollout +k8s_scale +kubescape_check_health +kubescape_get_application_profile +kubescape_get_configuration_scan +kubescape_get_network_neighborhood +kubescape_get_vulnerability_details +kubescape_list_application_profiles +kubescape_list_configuration_scans +kubescape_list_network_neighborhoods +kubescape_list_vulnerabilities +kubescape_list_vulnerability_manifests +prometheus_label_names_tool +prometheus_promql_tool +prometheus_query_range_tool +prometheus_query_tool +prometheus_targets_tool +shell diff --git a/cmd/tools_regression_test.go b/cmd/tools_regression_test.go new file mode 100644 index 00000000..83ae9b33 --- /dev/null +++ b/cmd/tools_regression_test.go @@ -0,0 +1,81 @@ +package main + +import ( + "bufio" + "context" + "os" + "sort" + "strings" + "testing" + + sdkmcp "github.com/modelcontextprotocol/go-sdk/mcp" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// registeredToolNames spins up an in-process MCP server with every provider +// registered (readOnly=false so mutating tools are included too), connects an +// in-memory client, and returns the set of advertised tool names — the same +// list a real MCP client would see over the wire. +func registeredToolNames(t *testing.T) map[string]bool { + t.Helper() + ctx := context.Background() + + srv := sdkmcp.NewServer(&sdkmcp.Implementation{Name: "regression", Version: "test"}, nil) + registerMCP(srv, nil, "", false) // nil providers => register them all + + serverT, clientT := sdkmcp.NewInMemoryTransports() + go func() { _ = srv.Run(ctx, serverT) }() + + client := sdkmcp.NewClient(&sdkmcp.Implementation{Name: "regression-client", Version: "test"}, nil) + session, err := client.Connect(ctx, clientT, nil) + require.NoError(t, err) + defer func() { _ = session.Close() }() + + names := make(map[string]bool) + for tool, err := range session.Tools(ctx, nil) { + require.NoError(t, err) + names[tool.Name] = true + } + require.NotEmpty(t, names, "expected the server to advertise tools") + return names +} + +// readGoldenToolNames loads the committed list of tool names, ignoring blank +// lines and '#' comments. +func readGoldenToolNames(t *testing.T, path string) []string { + t.Helper() + f, err := os.Open(path) + require.NoError(t, err) + defer func() { _ = f.Close() }() + + var names []string + sc := bufio.NewScanner(f) + for sc.Scan() { + line := strings.TrimSpace(sc.Text()) + if line == "" || strings.HasPrefix(line, "#") { + continue + } + names = append(names, line) + } + require.NoError(t, sc.Err()) + return names +} + +// TestNoToolNameRegressions guards the go-sdk migration: every tool name shipped +// in the v0.2.1 release must still be registered under the same name. New tools +// are allowed; renames or removals are caught here. +func TestNoToolNameRegressions(t *testing.T) { + current := registeredToolNames(t) + old := readGoldenToolNames(t, "testdata/tool_names_v0.2.1.txt") + + var missing []string + for _, name := range old { + if !current[name] { + missing = append(missing, name) + } + } + sort.Strings(missing) + + assert.Emptyf(t, missing, "%d tool(s) from v0.2.1 are missing/renamed in the current build: %v", len(missing), missing) +} diff --git a/go.mod b/go.mod index 7535dbd8..3bd83439 100644 --- a/go.mod +++ b/go.mod @@ -6,7 +6,7 @@ require ( github.com/joho/godotenv v1.5.1 github.com/kubescape/k8s-interface v0.0.203 github.com/kubescape/storage v0.0.239 - github.com/mark3labs/mcp-go v0.43.2 + github.com/modelcontextprotocol/go-sdk v1.6.1 github.com/onsi/ginkgo/v2 v2.27.2 github.com/onsi/gomega v1.38.2 github.com/prometheus/client_golang v1.23.2 @@ -38,13 +38,11 @@ require ( github.com/armosec/gojay v1.2.17 // indirect github.com/armosec/utils-go v0.0.58 // indirect github.com/armosec/utils-k8s-go v0.0.35 // indirect - github.com/bahlo/generic-list-go v0.2.0 // indirect github.com/becheran/wildmatch-go v1.0.0 // indirect github.com/beorn7/perks v1.0.1 // indirect github.com/blang/semver/v4 v4.0.0 // indirect github.com/bmatcuk/doublestar/v4 v4.9.1 // indirect github.com/briandowns/spinner v1.23.2 // indirect - github.com/buger/jsonparser v1.1.1 // indirect github.com/cenkalti/backoff/v4 v4.3.0 // indirect github.com/cenkalti/backoff/v5 v5.0.3 // indirect github.com/cespare/xxhash/v2 v2.3.0 // indirect @@ -99,6 +97,7 @@ require ( github.com/google/gnostic-models v0.7.1 // indirect github.com/google/go-cmp v0.7.0 // indirect github.com/google/go-containerregistry v0.20.6 // indirect + github.com/google/jsonschema-go v0.4.3 // indirect github.com/google/licensecheck v0.3.1 // indirect github.com/google/pprof v0.0.0-20251114195745-4902fdda35c8 // indirect github.com/google/uuid v1.6.0 // indirect @@ -106,14 +105,12 @@ require ( github.com/hashicorp/errwrap v1.1.0 // indirect github.com/hashicorp/go-multierror v1.1.1 // indirect github.com/inconshreveable/mousetrap v1.1.0 // indirect - github.com/invopop/jsonschema v0.13.0 // indirect github.com/jinzhu/copier v0.4.0 // indirect github.com/json-iterator/go v1.1.12 // indirect github.com/klauspost/compress v1.18.4 // indirect github.com/kubescape/go-logger v0.0.26 // indirect github.com/kylelemons/godebug v1.1.0 // indirect github.com/mackerelio/go-osstat v0.2.6 // indirect - github.com/mailru/easyjson v0.9.1 // indirect github.com/mattn/go-colorable v0.1.14 // indirect github.com/mattn/go-isatty v0.0.20 // indirect github.com/mitchellh/go-homedir v1.1.0 // indirect @@ -138,6 +135,8 @@ require ( github.com/sasha-s/go-deadlock v0.3.6 // indirect github.com/scylladb/go-set v1.0.3-0.20200225121959-cc7b2070d91e // indirect github.com/seccomp/libseccomp-golang v0.10.0 // indirect + github.com/segmentio/asm v1.1.3 // indirect + github.com/segmentio/encoding v0.5.4 // indirect github.com/sirupsen/logrus v1.9.4-0.20230606125235-dd1b4c2e81af // indirect github.com/spf13/afero v1.15.0 // indirect github.com/spf13/cast v1.10.0 // indirect @@ -155,7 +154,6 @@ require ( github.com/vishvananda/netns v0.0.5 // indirect github.com/wagoodman/go-partybus v0.0.0-20230516145632-8ccac152c651 // indirect github.com/wagoodman/go-progress v0.0.0-20230925121702-07e42b3cdba0 // indirect - github.com/wk8/go-ordered-map/v2 v2.1.8 // indirect github.com/x448/float16 v0.8.4 // indirect github.com/yl2chen/cidranger v1.0.2 // indirect github.com/yosida95/uritemplate/v3 v3.0.2 // indirect diff --git a/go.sum b/go.sum index 8e473845..145106da 100644 --- a/go.sum +++ b/go.sum @@ -100,8 +100,6 @@ github.com/armosec/utils-go v0.0.58 h1:g9RnRkxZAmzTfPe2ruMo2OXSYLwVSegQSkSavOfma github.com/armosec/utils-go v0.0.58/go.mod h1:CdqKHKruVJMCxGcZXYW9J+5P9FZou8dMzVpcB0Xt8pk= github.com/armosec/utils-k8s-go v0.0.35 h1:CliNObhAca5UYl84m5OQecOTm9ZfMFI8648pYhQJiu4= github.com/armosec/utils-k8s-go v0.0.35/go.mod h1:iHwR/KhMFtdd8Px1oYexLZYOHqmdknfGTZ8b7sZS0Ms= -github.com/bahlo/generic-list-go v0.2.0 h1:5sz/EEAK+ls5wF+NeqDpk5+iNdMDXrh3z3nPnH1Wvgk= -github.com/bahlo/generic-list-go v0.2.0/go.mod h1:2KvAjgMlE5NNynlg/5iLrrCCZ2+5xWbdbCW3pNTGyYg= github.com/becheran/wildmatch-go v1.0.0 h1:mE3dGGkTmpKtT4Z+88t8RStG40yN9T+kFEGj2PZFSzA= github.com/becheran/wildmatch-go v1.0.0/go.mod h1:gbMvj0NtVdJ15Mg/mH9uxk2R1QCistMyU7d9KFzroX4= github.com/beorn7/perks v0.0.0-20180321164747-3a771d992973/go.mod h1:Dwedo/Wpr24TaqPxmxbtue+5NUziq4I4S80YR8gNf3Q= @@ -117,8 +115,6 @@ github.com/bradfitz/go-smtpd v0.0.0-20170404230938-deb6d6237625/go.mod h1:HYsPBT github.com/briandowns/spinner v1.23.2 h1:Zc6ecUnI+YzLmJniCfDNaMbW0Wid1d5+qcTq4L2FW8w= github.com/briandowns/spinner v1.23.2/go.mod h1:LaZeM4wm2Ywy6vO571mvhQNRcWfRUnXOs0RcKV0wYKM= github.com/buger/jsonparser v0.0.0-20181115193947-bf1c66bbce23/go.mod h1:bbYlZJ7hK1yFx9hf58LP0zeX7UjIGs20ufpu3evjr+s= -github.com/buger/jsonparser v1.1.1 h1:2PnMjfWD7wBILjqQbt530v576A/cAbQvEW9gGIpYMUs= -github.com/buger/jsonparser v1.1.1/go.mod h1:6RYKKt7H4d4+iWqouImQ9R2FZql3VbhNgx27UK13J/0= github.com/cenkalti/backoff/v4 v4.3.0 h1:MyRJ/UdXutAwSAT+s3wNd7MfTIcy71VQueUuFK343L8= github.com/cenkalti/backoff/v4 v4.3.0/go.mod h1:Y3VNntkOUPxTVeUxJ/G5vcM//AlwfmyYozVcomhLiZE= github.com/cenkalti/backoff/v5 v5.0.3 h1:ZN+IMa753KfX5hd8vVaMixjnqRZ3y8CuJKRKj1xcsSM= @@ -305,6 +301,8 @@ github.com/gogo/protobuf v1.1.1/go.mod h1:r8qH/GZQm5c6nD/R0oafs1akxWv10x8SbQlK7a github.com/gogo/protobuf v1.3.2/go.mod h1:P1XiOD3dCwIKUDQYPy72D8LYyHL2YPYrpS2s69NZV8Q= github.com/gohugoio/hashstructure v0.5.0 h1:G2fjSBU36RdwEJBWJ+919ERvOVqAg9tfcYp47K9swqg= github.com/gohugoio/hashstructure v0.5.0/go.mod h1:Ser0TniXuu/eauYmrwM4o64EBvySxNzITEOLlm4igec= +github.com/golang-jwt/jwt/v5 v5.3.1 h1:kYf81DTWFe7t+1VvL7eS+jKFVWaUnK9cB1qbwn63YCY= +github.com/golang-jwt/jwt/v5 v5.3.1/go.mod h1:fxCRLWMO43lRc8nhHWY6LGqRcf+1gQWArsqaEUEa5bE= github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q= github.com/golang/groupcache v0.0.0-20190702054246-869f871628b6/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc= github.com/golang/groupcache v0.0.0-20191227052852-215e87163ea7/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc= @@ -363,6 +361,8 @@ github.com/google/go-containerregistry v0.20.6/go.mod h1:T0x8MuoAoKX/873bkeSfLD2 github.com/google/go-github v17.0.0+incompatible/go.mod h1:zLgOLi98H3fifZn+44m+umXrS52loVEgC2AApnigrVQ= github.com/google/go-querystring v1.0.0/go.mod h1:odCYkC5MyYFN7vkCjXpyrEuKhc/BUO6wN/zVPAxq5ck= github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= +github.com/google/jsonschema-go v0.4.3 h1:/DBOLZTfDow7pe2GmaJNhltueGTtDKICi8V8p+DQPd0= +github.com/google/jsonschema-go v0.4.3/go.mod h1:r5quNTdLOYEz95Ru18zA0ydNbBuYoo9tgaYcxEYhJVE= github.com/google/licensecheck v0.3.1 h1:QoxgoDkaeC4nFrtGN1jV7IPmDCHFNIVh54e5hSt6sPs= github.com/google/licensecheck v0.3.1/go.mod h1:ORkR35t/JjW+emNKtfJDII0zlciG9JgbT7SmsohlHmY= github.com/google/martian v2.1.0+incompatible/go.mod h1:9I4somxYTbIHy5NJKHRl3wXiIaQGbYVAs8BPL6v8lEs= @@ -446,8 +446,6 @@ github.com/ianlancetaylor/demangle v0.0.0-20200824232613-28f6c0f3b639/go.mod h1: github.com/inconshreveable/mousetrap v1.0.0/go.mod h1:PxqpIevigyE2G7u3NXJIT2ANytuPF1OarO4DADm73n8= github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8= github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw= -github.com/invopop/jsonschema v0.13.0 h1:KvpoAJWEjR3uD9Kbm2HWJmqsEaHt8lBUpd0qHcIi21E= -github.com/invopop/jsonschema v0.13.0/go.mod h1:ffZ5Km5SWWRAIN6wbDXItl95euhFz2uON45H2qjYt+0= github.com/jellevandenhooff/dkim v0.0.0-20150330215556-f50fe3d243e1/go.mod h1:E0B/fFc00Y+Rasa88328GlI/XbtyysCtTHZS8h7IrBU= github.com/jinzhu/copier v0.4.0 h1:w3ciUoD19shMCRargcpm0cm91ytaBhDvuRpz1ODO/U8= github.com/jinzhu/copier v0.4.0/go.mod h1:DfbEm0FYsaqBcKcFuvmOZb218JkPGtvSHsKg8S8hyyg= @@ -493,10 +491,6 @@ github.com/mackerelio/go-osstat v0.2.6 h1:gs4U8BZeS1tjrL08tt5VUliVvSWP26Ai2Ob8Lr github.com/mackerelio/go-osstat v0.2.6/go.mod h1:lRy8V9ZuHpuRVZh+vyTkODeDPl3/d5MgXHtLSaqG8bA= github.com/magiconair/properties v1.8.5/go.mod h1:y3VJvCyxH9uVvJTWEGAELF3aiYNyPKd5NZ3oSwXrF60= github.com/mailru/easyjson v0.0.0-20190312143242-1de009706dbe/go.mod h1:C1wdFJiN94OJF2b5HbByQZoLdCWB1Yqtg26g4irojpc= -github.com/mailru/easyjson v0.9.1 h1:LbtsOm5WAswyWbvTEOqhypdPeZzHavpZx96/n553mR8= -github.com/mailru/easyjson v0.9.1/go.mod h1:1+xMtQp2MRNVL/V1bOzuP3aP8VNwRW55fQUto+XFtTU= -github.com/mark3labs/mcp-go v0.43.2 h1:21PUSlWWiSbUPQwXIJ5WKlETixpFpq+WBpbMGDSVy/I= -github.com/mark3labs/mcp-go v0.43.2/go.mod h1:YnJfOL382MIWDx1kMY+2zsRHU/q78dBg9aFb8W6Thdw= github.com/maruel/natural v1.1.1 h1:Hja7XhhmvEFhcByqDoHz9QZbkWey+COd9xWfCfn1ioo= github.com/maruel/natural v1.1.1/go.mod h1:v+Rfd79xlw1AgVBjbO0BEQmptqb5HvL/k9GRHB7ZKEg= github.com/mattn/go-colorable v0.0.9/go.mod h1:9vuHe8Xs5qXnSaW/c/ABM9alt+Vo+STaOChaDxuIBZU= @@ -534,6 +528,8 @@ github.com/mitchellh/mapstructure v1.5.0 h1:jeMsZIYE/09sWLaz43PL7Gy6RuMjD2eJVyua github.com/mitchellh/mapstructure v1.5.0/go.mod h1:bFUtVrKA4DC2yAKiSyO/QUcy7e+RRV2QTWOzhPopBRo= github.com/moby/sys/mountinfo v0.7.2 h1:1shs6aH5s4o5H2zQLn796ADW1wMrIwHsyJ2v9KouLrg= github.com/moby/sys/mountinfo v0.7.2/go.mod h1:1YOa8w8Ih7uW0wALDUgT1dTTSBrZ+HiBLGws92L2RU4= +github.com/modelcontextprotocol/go-sdk v1.6.1 h1:0zOSupjKUxPKSocPT1Wtago+mUHU2/uZ4xSOY0FGReU= +github.com/modelcontextprotocol/go-sdk v1.6.1/go.mod h1:kzm3kzFL1/+AziGOE0nUs3gvPoNxMCvkxokMkuFapXQ= github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd h1:TRLaZ9cD/w8PVh93nsPXa1VrQ6jlwL5oN8l14QlcNfg= github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= @@ -630,6 +626,10 @@ github.com/scylladb/go-set v1.0.3-0.20200225121959-cc7b2070d91e/go.mod h1:DkpGd7 github.com/sean-/seed v0.0.0-20170313163322-e2103e2c3529/go.mod h1:DxrIzT+xaE7yg65j358z/aeFdxmN0P9QXhEzd20vsDc= github.com/seccomp/libseccomp-golang v0.10.0 h1:aA4bp+/Zzi0BnWZ2F1wgNBs5gTpm+na2rWM6M9YjLpY= github.com/seccomp/libseccomp-golang v0.10.0/go.mod h1:JA8cRccbGaA1s33RQf7Y1+q9gHmZX1yB/z9WDN1C6fg= +github.com/segmentio/asm v1.1.3 h1:WM03sfUOENvvKexOLp+pCqgb/WDjsi7EK8gIsICtzhc= +github.com/segmentio/asm v1.1.3/go.mod h1:Ld3L4ZXGNcSLRg4JBsZ3//1+f/TjYl0Mzen/DQy1EJg= +github.com/segmentio/encoding v0.5.4 h1:OW1VRern8Nw6ITAtwSZ7Idrl3MXCFwXHPgqESYfvNt0= +github.com/segmentio/encoding v0.5.4/go.mod h1:HS1ZKa3kSN32ZHVZ7ZLPLXWvOVIiZtyJnO1gPH1sKt0= github.com/sergi/go-diff v1.0.0/go.mod h1:0CfEIISq7TuYL3j771MWULgwwjU+GofnZX9QAmXWZgo= github.com/sergi/go-diff v1.1.0/go.mod h1:STckp+ISIX8hZLjrqAeVduY0gWCT9IjLuqbuNXdaHfM= github.com/sergi/go-diff v1.2.0/go.mod h1:STckp+ISIX8hZLjrqAeVduY0gWCT9IjLuqbuNXdaHfM= @@ -735,8 +735,6 @@ github.com/wagoodman/go-partybus v0.0.0-20230516145632-8ccac152c651 h1:jIVmlAFIq github.com/wagoodman/go-partybus v0.0.0-20230516145632-8ccac152c651/go.mod h1:b26F2tHLqaoRQf8DywqzVaV1MQ9yvjb0OMcNl7Nxu20= github.com/wagoodman/go-progress v0.0.0-20230925121702-07e42b3cdba0 h1:0KGbf+0SMg+UFy4e1A/CPVvXn21f1qtWdeJwxZFoQG8= github.com/wagoodman/go-progress v0.0.0-20230925121702-07e42b3cdba0/go.mod h1:jLXFoL31zFaHKAAyZUh+sxiTDFe1L1ZHrcK2T1itVKA= -github.com/wk8/go-ordered-map/v2 v2.1.8 h1:5h/BUHu93oj4gIdvHHHGsScSTMijfx5PeYkE/fJgbpc= -github.com/wk8/go-ordered-map/v2 v2.1.8/go.mod h1:5nJHM5DyteebpVlHnWMV0rPz6Zp7+xBAnxjb1X5vnTw= github.com/x448/float16 v0.8.4 h1:qLwI1I70+NjRFUR3zs1JPUCgaCXSh3SW62uAKT1mSBM= github.com/x448/float16 v0.8.4/go.mod h1:14CWIYCyZA/cWjXOioeEpHeN/83MdbZDRQHoFcYsOfg= github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e h1:JVG44RsyaB9T2KIHavMF/ppJZNG9ZpyihvCd0w101no= diff --git a/internal/cache/cache_test.go b/internal/cache/cache_test.go index cc7cf641..2840d618 100644 --- a/internal/cache/cache_test.go +++ b/internal/cache/cache_test.go @@ -8,6 +8,33 @@ import ( "github.com/stretchr/testify/assert" ) +func TestInvalidateHelpers(t *testing.T) { + // Seed each cache, then assert the type-specific invalidators clear it. + InitCaches() + for _, ct := range []CacheType{CacheTypeKubernetes, CacheTypeHelm, CacheTypeIstio, CacheTypeCommand} { + GetCacheByType(ct).Set("k", "v") + } + + InvalidateKubernetesCache() + InvalidateHelmCache() + InvalidateIstioCache() + InvalidateCommandCache() + + for _, ct := range []CacheType{CacheTypeKubernetes, CacheTypeHelm, CacheTypeIstio, CacheTypeCommand} { + if _, ok := GetCacheByType(ct).Get("k"); ok { + t.Errorf("expected %s cache to be invalidated", ct.String()) + } + } + + // Known command routes to its mapped cache; unknown falls back to command cache. + GetCacheByType(CacheTypeKubernetes).Set("k", "v") + InvalidateCacheForCommand("kubectl") + if _, ok := GetCacheByType(CacheTypeKubernetes).Get("k"); ok { + t.Error("expected kubectl command to invalidate kubernetes cache") + } + assert.NotPanics(t, func() { InvalidateCacheForCommand("totally-unknown-cmd") }) +} + func TestNewCache(t *testing.T) { cache := NewCache[string]("test-cache", 1*time.Minute, 100, 10*time.Second) diff --git a/internal/commands/builder_setters_test.go b/internal/commands/builder_setters_test.go new file mode 100644 index 00000000..1ed7ea84 --- /dev/null +++ b/internal/commands/builder_setters_test.go @@ -0,0 +1,51 @@ +package commands + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +// TestBuilderSettersValidation exercises both the accept and reject branches of +// the validating setters so invalid input is silently dropped, not applied. +func TestBuilderSettersValidation(t *testing.T) { + t.Run("WithToken", func(t *testing.T) { + cb := NewCommandBuilder("kubectl").WithToken("secret") + assert.Equal(t, "secret", cb.token) + // Empty token is a no-op and keeps the previous value. + cb.WithToken("") + assert.Equal(t, "secret", cb.token) + }) + + t.Run("WithContext", func(t *testing.T) { + cb := NewCommandBuilder("kubectl").WithContext("prod-cluster") + assert.Equal(t, "prod-cluster", cb.context) + // Injection attempt is rejected, leaving the prior value intact. + cb.WithContext("ctx; rm -rf /") + assert.Equal(t, "prod-cluster", cb.context) + }) + + t.Run("WithKubeconfig", func(t *testing.T) { + cb := NewCommandBuilder("kubectl").WithKubeconfig("/home/user/.kube/config") + assert.Equal(t, "/home/user/.kube/config", cb.kubeconfig) + // Path traversal is rejected. + cb.WithKubeconfig("../../etc/passwd") + assert.Equal(t, "/home/user/.kube/config", cb.kubeconfig) + }) + + t.Run("WithLabel", func(t *testing.T) { + cb := NewCommandBuilder("kubectl").WithLabel("app", "nginx") + assert.Equal(t, "nginx", cb.labels["app"]) + // Empty key is invalid and must not be stored. + cb.WithLabel("", "x") + assert.NotContains(t, cb.labels, "") + }) + + t.Run("WithAnnotation", func(t *testing.T) { + cb := NewCommandBuilder("kubectl").WithAnnotation("team", "sre") + assert.Equal(t, "sre", cb.annotations["team"]) + // Invalid key format is rejected. + cb.WithAnnotation("bad key!", "v") + assert.NotContains(t, cb.annotations, "bad key!") + }) +} diff --git a/internal/errors/tool_errors.go b/internal/errors/tool_errors.go index 12a7fd9c..5b67252b 100644 --- a/internal/errors/tool_errors.go +++ b/internal/errors/tool_errors.go @@ -5,7 +5,7 @@ import ( "strings" "time" - "github.com/mark3labs/mcp-go/mcp" + mcp "github.com/kagent-dev/tools/internal/mcp" ) // ToolError represents a structured error with context and recovery suggestions diff --git a/internal/errors/tool_errors_branches_test.go b/internal/errors/tool_errors_branches_test.go new file mode 100644 index 00000000..1f2ca86b --- /dev/null +++ b/internal/errors/tool_errors_branches_test.go @@ -0,0 +1,93 @@ +package errors + +import ( + "errors" + "testing" + + "github.com/stretchr/testify/assert" +) + +// errorCodeCase exercises one keyword-driven branch of a component error +// constructor and asserts the resulting error code and retryability. +type errorCodeCase struct { + cause string + expectedCode string + expectedRetry bool +} + +func runErrorCodeCases(t *testing.T, component string, ctor func(string, error) *ToolError, cases []errorCodeCase) { + t.Helper() + for _, c := range cases { + t.Run(c.expectedCode, func(t *testing.T) { + err := ctor("op", errors.New(c.cause)) + assert.Equal(t, component, err.Component) + assert.Equal(t, c.expectedCode, err.ErrorCode) + assert.Equal(t, c.expectedRetry, err.IsRetryable) + assert.NotEmpty(t, err.Suggestions) + }) + } +} + +func TestNewIstioErrorBranches(t *testing.T) { + runErrorCodeCases(t, "Istio", NewIstioError, []errorCodeCase{ + {"resource not found", "ISTIO_RESOURCE_NOT_FOUND", false}, + {"connection refused", "ISTIO_CONNECTION_ERROR", true}, + {"boom", "ISTIO_GENERIC_ERROR", true}, + }) +} + +func TestNewPrometheusErrorBranches(t *testing.T) { + runErrorCodeCases(t, "Prometheus", NewPrometheusError, []errorCodeCase{ + {"connection refused", "PROMETHEUS_CONNECTION_ERROR", true}, + {"parse error", "PROMETHEUS_QUERY_ERROR", false}, + {"boom", "PROMETHEUS_GENERIC_ERROR", true}, + }) +} + +func TestNewArgoErrorBranches(t *testing.T) { + runErrorCodeCases(t, "Argo Rollouts", NewArgoError, []errorCodeCase{ + {"rollout not found", "ARGO_ROLLOUT_NOT_FOUND", false}, + {"plugin missing", "ARGO_PLUGIN_ERROR", true}, + {"boom", "ARGO_GENERIC_ERROR", true}, + }) +} + +func TestNewCiliumErrorBranches(t *testing.T) { + runErrorCodeCases(t, "Cilium", NewCiliumError, []errorCodeCase{ + {"cilium not found", "CILIUM_NOT_FOUND", false}, + {"connection lost", "CILIUM_CONNECTION_ERROR", true}, + {"boom", "CILIUM_GENERIC_ERROR", true}, + }) +} + +func TestNewKubescapeErrorBranches(t *testing.T) { + // "not found" branches further specialize by operation keyword. + notFoundOps := []string{"vulnerability scan", "sbom build", "configuration scan", "application_profile get", "network_neighborhood get", "other op"} + for _, op := range notFoundOps { + t.Run("not_found/"+op, func(t *testing.T) { + err := NewKubescapeError(op, errors.New("resource not found")) + assert.Equal(t, "Kubescape", err.Component) + assert.Equal(t, "KUBESCAPE_RESOURCE_NOT_FOUND", err.ErrorCode) + assert.False(t, err.IsRetryable) + assert.NotEmpty(t, err.Suggestions) + }) + } + + runErrorCodeCases(t, "Kubescape", NewKubescapeError, []errorCodeCase{ + {"connection refused", "KUBESCAPE_CONNECTION_ERROR", true}, + {"timeout exceeded", "KUBESCAPE_CONNECTION_ERROR", true}, + {"forbidden", "KUBESCAPE_PERMISSION_ERROR", false}, + {"boom", "KUBESCAPE_GENERIC_ERROR", true}, + }) +} + +// TestToMCPResultRendersAllSections ensures the optional resource/context +// sections of ToMCPResult are exercised. +func TestToMCPResultRendersAllSections(t *testing.T) { + res := NewKubernetesError("op", errors.New("not found")). + WithResource("Pod", "web"). + WithContext("namespace", "default"). + ToMCPResult() + assert.True(t, res.IsError) + assert.NotEmpty(t, res.Content) +} diff --git a/internal/logger/logger_test.go b/internal/logger/logger_test.go index f6befc5c..fc532a54 100644 --- a/internal/logger/logger_test.go +++ b/internal/logger/logger_test.go @@ -9,9 +9,24 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "go.opentelemetry.io/otel/trace" "go.opentelemetry.io/otel/trace/noop" ) +func TestWithContext(t *testing.T) { + // Without a span the base logger is returned unchanged. + assert.NotNil(t, WithContext(context.Background())) + + // With a valid span context, trace_id/span_id are attached (exercises the branch). + sc := trace.NewSpanContext(trace.SpanContextConfig{ + TraceID: trace.TraceID{0x01}, + SpanID: trace.SpanID{0x02}, + TraceFlags: trace.FlagsSampled, + }) + ctx := trace.ContextWithSpanContext(context.Background(), sc) + assert.NotNil(t, WithContext(ctx)) +} + func TestRedactArgsForLog(t *testing.T) { t.Run("redacts token value", func(t *testing.T) { args := []string{"get", "pods", "--token", "secret-token-123", "-n", "default"} diff --git a/internal/mcp/mcp.go b/internal/mcp/mcp.go new file mode 100644 index 00000000..7b41e13f --- /dev/null +++ b/internal/mcp/mcp.go @@ -0,0 +1,128 @@ +// Package mcp adapts the modelcontextprotocol/go-sdk server to the kagent-tools +// providers. It re-exports the SDK types the providers need, supplies result +// constructors compatible with the previous mark3labs helpers, and centralizes +// tracing/metrics instrumentation as a single receiving middleware so provider +// packages register tools with one typed call and no per-tool wrapping. +package mcp + +import ( + "context" + "net/http" + "sync" + "time" + + "github.com/kagent-dev/tools/internal/metrics" + sdk "github.com/modelcontextprotocol/go-sdk/mcp" + "go.opentelemetry.io/otel" + "go.opentelemetry.io/otel/attribute" + "go.opentelemetry.io/otel/codes" +) + +// Re-exported SDK types so provider packages depend on a single import. +type ( + // Server is the MCP server tools are registered on. + Server = sdk.Server + // Tool describes a tool's name, description and (inferred) input schema. + Tool = sdk.Tool + // CallToolRequest is the server-side request passed to a tool handler. + CallToolRequest = sdk.CallToolRequest + // CallToolResult is the result returned by a tool handler. + CallToolResult = sdk.CallToolResult + // Implementation identifies the server to clients. + Implementation = sdk.Implementation + // Content is a single piece of tool result content. + Content = sdk.Content + // TextContent is textual tool result content. + TextContent = sdk.TextContent + // RequestExtra carries transport-level extras (e.g. HTTP headers) on a request. + RequestExtra = sdk.RequestExtra +) + +// NewServer constructs a new MCP server. +var NewServer = sdk.NewServer + +// NewToolResultText returns a successful text result. +func NewToolResultText(text string) *sdk.CallToolResult { + return &sdk.CallToolResult{Content: []sdk.Content{&sdk.TextContent{Text: text}}} +} + +// NewToolResultError returns a tool-level error result (IsError=true). Handlers +// return this together with a nil Go error, per MCP convention. +func NewToolResultError(message string) *sdk.CallToolResult { + return &sdk.CallToolResult{Content: []sdk.Content{&sdk.TextContent{Text: message}}, IsError: true} +} + +// Header returns the HTTP headers carried with the request, or nil for stdio / +// in-process calls. Used for bearer-token passthrough. +func Header(req *sdk.CallToolRequest) http.Header { + if req != nil && req.Extra != nil { + return req.Extra.Header + } + return nil +} + +// providerByTool maps a registered tool name to its provider for metric labels. +var providerByTool sync.Map + +// AddTool registers a typed tool and records its provider for metrics. The input +// schema is inferred from In's json/jsonschema struct tags by the SDK. +func AddTool[In, Out any](s *sdk.Server, provider string, t *sdk.Tool, h sdk.ToolHandlerFor[In, Out]) { + providerByTool.Store(t.Name, provider) + metrics.KagentToolsMCPRegisteredTools.WithLabelValues(t.Name, provider).Set(1) + sdk.AddTool(s, t, h) +} + +func providerOf(tool string) string { + if v, ok := providerByTool.Load(tool); ok { + return v.(string) + } + return "" +} + +// ToolMiddleware instruments every tools/call with an OTel span and Prometheus +// invocation counters. Register once via server.AddReceivingMiddleware. +func ToolMiddleware() sdk.Middleware { + return func(next sdk.MethodHandler) sdk.MethodHandler { + return func(ctx context.Context, method string, req sdk.Request) (sdk.Result, error) { + if method != "tools/call" { + return next(ctx, method, req) + } + + toolName := "" + if ctr, ok := req.(*sdk.CallToolRequest); ok && ctr.Params != nil { + toolName = ctr.Params.Name + } + provider := providerOf(toolName) + + tracer := otel.Tracer("kagent-tools/mcp") + ctx, span := tracer.Start(ctx, "mcp.tool."+toolName) + defer span.End() + span.SetAttributes( + attribute.String("mcp.tool.name", toolName), + attribute.String("mcp.tool.provider", provider), + ) + + metrics.KagentToolsMCPInvocationsTotal.WithLabelValues(toolName, provider).Inc() + start := time.Now() + + res, err := next(ctx, method, req) + + span.SetAttributes(attribute.Float64("mcp.tool.duration_seconds", time.Since(start).Seconds())) + + failed := err != nil + if ctres, ok := res.(*sdk.CallToolResult); ok && ctres != nil && ctres.IsError { + failed = true + } + if failed { + metrics.KagentToolsMCPInvocationsFailureTotal.WithLabelValues(toolName, provider).Inc() + if err != nil { + span.RecordError(err) + span.SetStatus(codes.Error, err.Error()) + } + } else { + span.SetStatus(codes.Ok, "ok") + } + return res, err + } + } +} diff --git a/internal/mcp/mcp_test.go b/internal/mcp/mcp_test.go new file mode 100644 index 00000000..2289c776 --- /dev/null +++ b/internal/mcp/mcp_test.go @@ -0,0 +1,131 @@ +package mcp + +import ( + "context" + "fmt" + "net/http" + "testing" + + "github.com/kagent-dev/tools/internal/metrics" + sdk "github.com/modelcontextprotocol/go-sdk/mcp" + promtest "github.com/prometheus/client_golang/prometheus/testutil" +) + +// invokeMiddleware runs ToolMiddleware around next for a tools/call to toolName +// (registered to provider) and returns the result/error. +func invokeMiddleware(toolName, provider string, next sdk.MethodHandler) (sdk.Result, error) { + metrics.KagentToolsMCPInvocationsTotal.Reset() + metrics.KagentToolsMCPInvocationsFailureTotal.Reset() + providerByTool.Store(toolName, provider) + + h := ToolMiddleware()(next) + req := &sdk.CallToolRequest{Params: &sdk.CallToolParamsRaw{Name: toolName}} + return h(context.Background(), "tools/call", req) +} + +func TestHeader(t *testing.T) { + assert := func(cond bool, msg string) { + if !cond { + t.Fatal(msg) + } + } + + // nil request and request without Extra yield no headers. + assert(Header(nil) == nil, "nil request should give nil header") + assert(Header(&sdk.CallToolRequest{}) == nil, "request without Extra should give nil header") + + h := http.Header{"Authorization": []string{"Bearer t"}} + req := &sdk.CallToolRequest{Extra: &sdk.RequestExtra{Header: h}} + if got := Header(req).Get("Authorization"); got != "Bearer t" { + t.Fatalf("expected passthrough header, got %q", got) + } +} + +func TestAddToolRecordsProvider(t *testing.T) { + metrics.KagentToolsMCPRegisteredTools.Reset() + s := NewServer(&Implementation{Name: "t", Version: "v"}, nil) + + type in struct { + Name string `json:"name"` + } + AddTool(s, "myprovider", &Tool{Name: "my_tool"}, func(_ context.Context, _ *CallToolRequest, _ in) (*CallToolResult, any, error) { + return NewToolResultText("ok"), nil, nil + }) + + if got := providerOf("my_tool"); got != "myprovider" { + t.Errorf("providerOf: expected myprovider, got %q", got) + } + if got := providerOf("unknown_tool"); got != "" { + t.Errorf("providerOf unknown: expected empty, got %q", got) + } + if v := promtest.ToFloat64(metrics.KagentToolsMCPRegisteredTools.WithLabelValues("my_tool", "myprovider")); v != 1 { + t.Errorf("registered_tools metric: expected 1, got %v", v) + } +} + +// TestToolMiddleware_IsErrorIncrementsFailureCounter is the regression test for +// the bug identified in PR review: handlers signal tool-level failures via +// NewToolResultError(...) (IsError=true, Go error=nil), so checking only +// `err != nil` would never count these as failures. +func TestToolMiddleware_IsErrorIncrementsFailureCounter(t *testing.T) { + result, err := invokeMiddleware("failing_tool", "test", + func(_ context.Context, _ string, _ sdk.Request) (sdk.Result, error) { + return NewToolResultError("kubectl: resource not found"), nil + }, + ) + if err != nil { + t.Fatalf("expected nil Go error, got: %v", err) + } + if ctr, ok := result.(*sdk.CallToolResult); !ok || !ctr.IsError { + t.Fatal("expected result.IsError=true") + } + + total := promtest.ToFloat64(metrics.KagentToolsMCPInvocationsTotal.WithLabelValues("failing_tool", "test")) + if total != 1 { + t.Errorf("invocations_total: expected 1, got %v", total) + } + failures := promtest.ToFloat64(metrics.KagentToolsMCPInvocationsFailureTotal.WithLabelValues("failing_tool", "test")) + if failures != 1 { + t.Errorf("invocations_failure_total: expected 1, got %v (IsError=true was not counted)", failures) + } +} + +// TestToolMiddleware_SuccessDoesNotIncrementFailureCounter verifies a successful +// call leaves the failure counter untouched. +func TestToolMiddleware_SuccessDoesNotIncrementFailureCounter(t *testing.T) { + _, err := invokeMiddleware("success_tool", "test", + func(_ context.Context, _ string, _ sdk.Request) (sdk.Result, error) { + return NewToolResultText("all good"), nil + }, + ) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + total := promtest.ToFloat64(metrics.KagentToolsMCPInvocationsTotal.WithLabelValues("success_tool", "test")) + if total != 1 { + t.Errorf("invocations_total: expected 1, got %v", total) + } + failures := promtest.ToFloat64(metrics.KagentToolsMCPInvocationsFailureTotal.WithLabelValues("success_tool", "test")) + if failures != 0 { + t.Errorf("invocations_failure_total: expected 0, got %v", failures) + } +} + +// TestToolMiddleware_GoErrorIncrementsFailureCounter verifies a real Go error is +// counted as a failure. +func TestToolMiddleware_GoErrorIncrementsFailureCounter(t *testing.T) { + _, err := invokeMiddleware("broken_tool", "test", + func(_ context.Context, _ string, _ sdk.Request) (sdk.Result, error) { + return nil, fmt.Errorf("connection refused") + }, + ) + if err == nil { + t.Fatal("expected a Go error, got nil") + } + + failures := promtest.ToFloat64(metrics.KagentToolsMCPInvocationsFailureTotal.WithLabelValues("broken_tool", "test")) + if failures != 1 { + t.Errorf("invocations_failure_total: expected 1, got %v", failures) + } +} diff --git a/internal/telemetry/config_test.go b/internal/telemetry/config_test.go index fe6454b5..e116a1e2 100644 --- a/internal/telemetry/config_test.go +++ b/internal/telemetry/config_test.go @@ -8,6 +8,18 @@ import ( "github.com/stretchr/testify/assert" ) +func TestGetEnvFloat(t *testing.T) { + const key = "KAGENT_TEST_ENV_FLOAT" + + assert.Equal(t, 1.5, getEnvFloat(key, 1.5)) // unset -> fallback + + t.Setenv(key, "0.25") + assert.Equal(t, 0.25, getEnvFloat(key, 1.5)) // parsed + + t.Setenv(key, "not-a-float") + assert.Equal(t, 1.5, getEnvFloat(key, 1.5)) // parse error -> fallback +} + func TestLoad(t *testing.T) { // Reset singleton for testing once = sync.Once{} diff --git a/internal/telemetry/middleware.go b/internal/telemetry/middleware.go index 720a99b8..3bc8f1a5 100644 --- a/internal/telemetry/middleware.go +++ b/internal/telemetry/middleware.go @@ -2,13 +2,8 @@ package telemetry import ( "context" - "encoding/json" - "fmt" "net/http" - "time" - "github.com/mark3labs/mcp-go/mcp" - "github.com/mark3labs/mcp-go/server" "go.opentelemetry.io/otel" "go.opentelemetry.io/otel/attribute" "go.opentelemetry.io/otel/codes" @@ -16,8 +11,6 @@ import ( "go.opentelemetry.io/otel/trace" ) -type ToolHandler func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) - // contextKey is used for storing HTTP context in the request context type contextKey string @@ -83,70 +76,6 @@ func ExtractTraceInfo(ctx context.Context) (traceID, spanID string) { return traceID, spanID } -func WithTracing(toolName string, handler ToolHandler) ToolHandler { - return func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - tracer := otel.Tracer("kagent-tools/mcp") - - spanName := fmt.Sprintf("mcp.tool.%s", toolName) - ctx, span := tracer.Start(ctx, spanName) - defer span.End() - - // Extract HTTP headers from context and add as span attributes - headers := ExtractHTTPHeaders(ctx) - for key, value := range headers { - span.SetAttributes(attribute.String(fmt.Sprintf("http.header.%s", key), value)) - } - - // Extract parent trace information - parentTraceID, parentSpanID := ExtractTraceInfo(ctx) - if parentTraceID != "" { - span.SetAttributes( - attribute.String("http.parent_trace_id", parentTraceID), - attribute.String("http.parent_span_id", parentSpanID), - ) - } - - span.SetAttributes( - attribute.String("mcp.tool.name", toolName), - attribute.String("mcp.request.id", request.Params.Name), - ) - - if request.Params.Arguments != nil { - if argsJSON, err := json.Marshal(request.Params.Arguments); err == nil { - span.SetAttributes(attribute.String("mcp.request.arguments", string(argsJSON))) - } - } - - span.AddEvent("tool.execution.start") - startTime := time.Now() - - result, err := handler(ctx, request) - - duration := time.Since(startTime) - span.SetAttributes(attribute.Float64("mcp.tool.duration_seconds", duration.Seconds())) - - if err != nil { - span.RecordError(err) - span.SetStatus(codes.Error, err.Error()) - span.AddEvent("tool.execution.error", trace.WithAttributes( - attribute.String("error.message", err.Error()), - )) - } else { - span.SetStatus(codes.Ok, "tool execution completed successfully") - span.AddEvent("tool.execution.success") - - if result != nil { - span.SetAttributes(attribute.Bool("mcp.result.is_error", result.IsError)) - if result.Content != nil { - span.SetAttributes(attribute.Int("mcp.result.content_count", len(result.Content))) - } - } - } - - return result, err - } -} - func StartSpan(ctx context.Context, operationName string, attrs ...attribute.KeyValue) (context.Context, trace.Span) { tracer := otel.Tracer("kagent-tools") ctx, span := tracer.Start(ctx, operationName) @@ -170,10 +99,3 @@ func RecordSuccess(span trace.Span, message string) { func AddEvent(span trace.Span, name string, attrs ...attribute.KeyValue) { span.AddEvent(name, trace.WithAttributes(attrs...)) } - -// AdaptToolHandler adapts a telemetry.ToolHandler to a server.ToolHandlerFunc. -func AdaptToolHandler(th ToolHandler) server.ToolHandlerFunc { - return func(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { - return th(ctx, req) - } -} diff --git a/internal/telemetry/middleware_test.go b/internal/telemetry/middleware_test.go index bcbf494c..9d1a4517 100644 --- a/internal/telemetry/middleware_test.go +++ b/internal/telemetry/middleware_test.go @@ -3,19 +3,52 @@ package telemetry import ( "context" "errors" + "net/http" + "net/http/httptest" "testing" - "time" - "github.com/mark3labs/mcp-go/mcp" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "go.opentelemetry.io/otel" "go.opentelemetry.io/otel/attribute" "go.opentelemetry.io/otel/codes" "go.opentelemetry.io/otel/sdk/trace" - "go.opentelemetry.io/otel/trace/noop" ) +func TestHTTPMiddleware(t *testing.T) { + provider, _ := setupTracing() + defer func() { _ = provider.Shutdown(context.Background()) }() + + var gotHeaders map[string]string + var gotTraceID, gotSpanID string + next := http.HandlerFunc(func(_ http.ResponseWriter, r *http.Request) { + gotHeaders = ExtractHTTPHeaders(r.Context()) + gotTraceID, gotSpanID = ExtractTraceInfo(r.Context()) + }) + + req := httptest.NewRequest(http.MethodGet, "/mcp", nil) + req.Header.Set("Authorization", "Bearer abc") + req.Header.Set("User-Agent", "agent/1.0") + req.Header.Set("X-Ignored", "nope") + + HTTPMiddleware(next).ServeHTTP(httptest.NewRecorder(), req) + + require.NotNil(t, gotHeaders) + assert.Equal(t, "Bearer abc", gotHeaders["Authorization"]) + assert.Equal(t, "agent/1.0", gotHeaders["User-Agent"]) + assert.NotContains(t, gotHeaders, "X-Ignored") + // No inbound trace context here, so trace/span IDs stay empty. + assert.Empty(t, gotTraceID) + assert.Empty(t, gotSpanID) +} + +func TestExtractHelpersDefaults(t *testing.T) { + assert.Empty(t, ExtractHTTPHeaders(context.Background())) + tid, sid := ExtractTraceInfo(context.Background()) + assert.Empty(t, tid) + assert.Empty(t, sid) +} + // InMemoryExporter is a simple in-memory exporter for testing type InMemoryExporter struct { spans []trace.ReadOnlySpan @@ -45,345 +78,7 @@ func setupTracing() (*trace.TracerProvider, *InMemoryExporter) { return provider, exporter } -func TestWithTracing(t *testing.T) { - // Initialize OpenTelemetry - provider, exporter := setupTracing() - defer func() { - if err := provider.Shutdown(context.Background()); err != nil { - t.Errorf("Failed to shutdown provider: %v", err) - } - }() - - // Create a test handler - testHandler := func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - textContent := mcp.NewTextContent("test response") - return &mcp.CallToolResult{ - IsError: false, - Content: []mcp.Content{textContent}, - }, nil - } - - // Wrap with tracing - tracedHandler := WithTracing("test-tool", testHandler) - - // Create test request - request := mcp.CallToolRequest{ - Params: mcp.CallToolParams{ - Name: "test-tool", - Arguments: map[string]interface{}{ - "param1": "value1", - "param2": 42, - }, - }, - } - - // Execute the handler - result, err := tracedHandler(context.Background(), request) - - // Force flush to ensure spans are exported - if err := provider.ForceFlush(context.Background()); err != nil { - t.Errorf("Failed to flush provider: %v", err) - } - - // Verify result - require.NoError(t, err) - assert.NotNil(t, result) - assert.False(t, result.IsError) - assert.Len(t, result.Content, 1) - textContent, ok := mcp.AsTextContent(result.Content[0]) - require.True(t, ok) - assert.Equal(t, "test response", textContent.Text) - - // Verify span was created - spans := exporter.GetSpans() - assert.Len(t, spans, 1) - - span := spans[0] - assert.Equal(t, "mcp.tool.test-tool", span.Name()) - assert.Equal(t, codes.Ok, span.Status().Code) - // Note: SDK may not preserve description in test environment - // assert.Equal(t, "tool execution completed successfully", span.Status().Description) - - // Verify attributes - attributes := span.Attributes() - hasToolName := false - hasRequestID := false - hasIsError := false - hasContentCount := false - - for _, attr := range attributes { - if attr.Key == "mcp.tool.name" && attr.Value.AsString() == "test-tool" { - hasToolName = true - } - if attr.Key == "mcp.request.id" && attr.Value.AsString() == "test-tool" { - hasRequestID = true - } - if attr.Key == "mcp.result.is_error" && attr.Value.AsBool() == false { - hasIsError = true - } - if attr.Key == "mcp.result.content_count" && attr.Value.AsInt64() == 1 { - hasContentCount = true - } - } - - assert.True(t, hasToolName) - assert.True(t, hasRequestID) - assert.True(t, hasIsError) - assert.True(t, hasContentCount) - - // Verify events - events := span.Events() - assert.Len(t, events, 2) - assert.Equal(t, "tool.execution.start", events[0].Name) - assert.Equal(t, "tool.execution.success", events[1].Name) -} - -func TestWithTracingError(t *testing.T) { - // Initialize OpenTelemetry - provider, exporter := setupTracing() - defer func() { - if err := provider.Shutdown(context.Background()); err != nil { - t.Errorf("Failed to shutdown provider: %v", err) - } - }() - - // Create a test handler that returns an error - testError := errors.New("test error") - testHandler := func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - return nil, testError - } - - // Wrap with tracing - tracedHandler := WithTracing("test-tool", testHandler) - - // Create test request - request := mcp.CallToolRequest{ - Params: mcp.CallToolParams{ - Name: "test-tool", - }, - } - - // Execute the handler - result, err := tracedHandler(context.Background(), request) - - // Force flush to ensure spans are exported - if err := provider.ForceFlush(context.Background()); err != nil { - t.Errorf("Failed to flush provider: %v", err) - } - - // Verify result - assert.Error(t, err) - assert.Equal(t, testError, err) - assert.Nil(t, result) - - // Verify span was created with error - spans := exporter.GetSpans() - assert.Len(t, spans, 1) - - span := spans[0] - assert.Equal(t, "mcp.tool.test-tool", span.Name()) - assert.Equal(t, codes.Error, span.Status().Code) - // Note: SDK may not preserve description in test environment - // assert.Equal(t, "test error", span.Status().Description) - - // Verify events - span.RecordError() adds an "exception" event, plus our custom events - events := span.Events() - assert.Len(t, events, 3) - assert.Equal(t, "tool.execution.start", events[0].Name) - assert.Equal(t, "exception", events[1].Name) // Added by span.RecordError() - assert.Equal(t, "tool.execution.error", events[2].Name) -} - -func TestWithTracingErrorResult(t *testing.T) { - // Initialize OpenTelemetry - provider, exporter := setupTracing() - defer func() { - if err := provider.Shutdown(context.Background()); err != nil { - t.Errorf("Failed to shutdown provider: %v", err) - } - }() - - // Create a test handler that returns an error result - testHandler := func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - textContent := mcp.NewTextContent("error occurred") - return &mcp.CallToolResult{ - IsError: true, - Content: []mcp.Content{textContent}, - }, nil - } - - // Wrap with tracing - tracedHandler := WithTracing("test-tool", testHandler) - - // Create test request - request := mcp.CallToolRequest{ - Params: mcp.CallToolParams{ - Name: "test-tool", - }, - } - - // Execute the handler - result, err := tracedHandler(context.Background(), request) - - // Force flush to ensure spans are exported - if err := provider.ForceFlush(context.Background()); err != nil { - t.Errorf("Failed to flush provider: %v", err) - } - - // Verify result - require.NoError(t, err) - assert.NotNil(t, result) - assert.True(t, result.IsError) - - // Verify span was created successfully (no error from handler) - spans := exporter.GetSpans() - assert.Len(t, spans, 1) - - span := spans[0] - assert.Equal(t, "mcp.tool.test-tool", span.Name()) - assert.Equal(t, codes.Ok, span.Status().Code) - - // Verify attributes - attributes := span.Attributes() - hasIsError := false - hasContentCount := false - - for _, attr := range attributes { - if attr.Key == "mcp.result.is_error" && attr.Value.AsBool() == true { - hasIsError = true - } - if attr.Key == "mcp.result.content_count" && attr.Value.AsInt64() == 1 { - hasContentCount = true - } - } - - assert.True(t, hasIsError) - assert.True(t, hasContentCount) -} - -func TestWithTracingWithArguments(t *testing.T) { - // Initialize OpenTelemetry - provider, exporter := setupTracing() - defer func() { - if err := provider.Shutdown(context.Background()); err != nil { - t.Errorf("Failed to shutdown provider: %v", err) - } - }() - - // Create a test handler - testHandler := func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - textContent := mcp.NewTextContent("test response") - return &mcp.CallToolResult{ - IsError: false, - Content: []mcp.Content{textContent}, - }, nil - } - - // Wrap with tracing - tracedHandler := WithTracing("test-tool", testHandler) - - // Create test request with arguments - request := mcp.CallToolRequest{ - Params: mcp.CallToolParams{ - Name: "test-tool", - Arguments: map[string]interface{}{ - "string_param": "hello", - "number_param": 42, - "bool_param": true, - "array_param": []interface{}{"a", "b", "c"}, - "object_param": map[string]interface{}{ - "nested": "value", - }, - }, - }, - } - - // Execute the handler - result, err := tracedHandler(context.Background(), request) - - // Force flush to ensure spans are exported - if err := provider.ForceFlush(context.Background()); err != nil { - t.Errorf("Failed to flush provider: %v", err) - } - - // Verify result - require.NoError(t, err) - assert.NotNil(t, result) - assert.False(t, result.IsError) - - // Verify span was created - spans := exporter.GetSpans() - assert.Len(t, spans, 1) - - span := spans[0] - assert.Equal(t, "mcp.tool.test-tool", span.Name()) - - // Verify that arguments were added as an attribute (they are JSON-encoded) - attributes := span.Attributes() - hasArguments := false - - for _, attr := range attributes { - if attr.Key == "mcp.request.arguments" { - hasArguments = true - // Arguments should be JSON-encoded - assert.NotEmpty(t, attr.Value.AsString()) - } - } - - assert.True(t, hasArguments) -} - -func TestWithTracingNilArguments(t *testing.T) { - // Initialize OpenTelemetry - provider, exporter := setupTracing() - defer func() { - if err := provider.Shutdown(context.Background()); err != nil { - t.Errorf("Failed to shutdown provider: %v", err) - } - }() - - // Create a test handler - testHandler := func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - textContent := mcp.NewTextContent("test response") - return &mcp.CallToolResult{ - IsError: false, - Content: []mcp.Content{textContent}, - }, nil - } - - // Wrap with tracing - tracedHandler := WithTracing("test-tool", testHandler) - - // Create test request without arguments - request := mcp.CallToolRequest{ - Params: mcp.CallToolParams{ - Name: "test-tool", - }, - } - - // Execute the handler - result, err := tracedHandler(context.Background(), request) - - // Force flush to ensure spans are exported - if err := provider.ForceFlush(context.Background()); err != nil { - t.Errorf("Failed to flush provider: %v", err) - } - - // Verify result - require.NoError(t, err) - assert.NotNil(t, result) - assert.False(t, result.IsError) - - // Verify span was created - spans := exporter.GetSpans() - assert.Len(t, spans, 1) - - span := spans[0] - assert.Equal(t, "mcp.tool.test-tool", span.Name()) -} - func TestStartSpan(t *testing.T) { - // Initialize OpenTelemetry provider, exporter := setupTracing() defer func() { if err := provider.Shutdown(context.Background()); err != nil { @@ -391,30 +86,22 @@ func TestStartSpan(t *testing.T) { } }() - // Start a span _, span := StartSpan(context.Background(), "test-span", attribute.String("key1", "value1"), attribute.Int("key2", 42), ) - - // End the span span.End() - // Force flush to ensure spans are exported if err := provider.ForceFlush(context.Background()); err != nil { t.Errorf("Failed to flush provider: %v", err) } - // Verify span was created spans := exporter.GetSpans() assert.Len(t, spans, 1) - - resultSpan := spans[0] - assert.Equal(t, "test-span", resultSpan.Name()) + assert.Equal(t, "test-span", spans[0].Name()) } func TestStartSpanNoAttributes(t *testing.T) { - // Initialize OpenTelemetry provider, exporter := setupTracing() defer func() { if err := provider.Shutdown(context.Background()); err != nil { @@ -422,27 +109,19 @@ func TestStartSpanNoAttributes(t *testing.T) { } }() - // Start a span without attributes _, span := StartSpan(context.Background(), "test-span") - - // End the span span.End() - // Force flush to ensure spans are exported if err := provider.ForceFlush(context.Background()); err != nil { t.Errorf("Failed to flush provider: %v", err) } - // Verify span was created spans := exporter.GetSpans() assert.Len(t, spans, 1) - - resultSpan := spans[0] - assert.Equal(t, "test-span", resultSpan.Name()) + assert.Equal(t, "test-span", spans[0].Name()) } func TestRecordError(t *testing.T) { - // Initialize OpenTelemetry provider, exporter := setupTracing() defer func() { if err := provider.Shutdown(context.Background()); err != nil { @@ -450,33 +129,22 @@ func TestRecordError(t *testing.T) { } }() - // Start a span _, span := StartSpan(context.Background(), "test-span") - - // Record an error - testError := errors.New("test error") - RecordError(span, testError, "test error") - - // End the span + RecordError(span, errors.New("test error"), "test error") span.End() - // Force flush to ensure spans are exported if err := provider.ForceFlush(context.Background()); err != nil { t.Errorf("Failed to flush provider: %v", err) } - // Verify span was created with error spans := exporter.GetSpans() assert.Len(t, spans, 1) - - resultSpan := spans[0] - assert.Equal(t, "test-span", resultSpan.Name()) - assert.Equal(t, codes.Error, resultSpan.Status().Code) - assert.Equal(t, "test error", resultSpan.Status().Description) + assert.Equal(t, "test-span", spans[0].Name()) + assert.Equal(t, codes.Error, spans[0].Status().Code) + assert.Equal(t, "test error", spans[0].Status().Description) } func TestRecordSuccess(t *testing.T) { - // Initialize OpenTelemetry provider, exporter := setupTracing() defer func() { if err := provider.Shutdown(context.Background()); err != nil { @@ -484,33 +152,21 @@ func TestRecordSuccess(t *testing.T) { } }() - // Start a span _, span := StartSpan(context.Background(), "test-span") - - // Record success RecordSuccess(span, "operation completed successfully") - - // End the span span.End() - // Force flush to ensure spans are exported if err := provider.ForceFlush(context.Background()); err != nil { t.Errorf("Failed to flush provider: %v", err) } - // Verify span was created with success spans := exporter.GetSpans() assert.Len(t, spans, 1) - - resultSpan := spans[0] - assert.Equal(t, "test-span", resultSpan.Name()) - assert.Equal(t, codes.Ok, resultSpan.Status().Code) - // Note: SDK may not preserve description in test environment - // assert.Equal(t, "operation completed successfully", resultSpan.Status().Description) + assert.Equal(t, "test-span", spans[0].Name()) + assert.Equal(t, codes.Ok, spans[0].Status().Code) } func TestAddEvent(t *testing.T) { - // Initialize OpenTelemetry provider, exporter := setupTracing() defer func() { if err := provider.Shutdown(context.Background()); err != nil { @@ -518,38 +174,25 @@ func TestAddEvent(t *testing.T) { } }() - // Start a span _, span := StartSpan(context.Background(), "test-span") - - // Add an event AddEvent(span, "test-event", attribute.String("event_key", "event_value"), attribute.Int("event_num", 123), ) - - // End the span span.End() - // Force flush to ensure spans are exported if err := provider.ForceFlush(context.Background()); err != nil { t.Errorf("Failed to flush provider: %v", err) } - // Verify span was created with event spans := exporter.GetSpans() assert.Len(t, spans, 1) - - resultSpan := spans[0] - assert.Equal(t, "test-span", resultSpan.Name()) - - // Verify event - events := resultSpan.Events() + events := spans[0].Events() assert.Len(t, events, 1) assert.Equal(t, "test-event", events[0].Name) } func TestAddEventNoAttributes(t *testing.T) { - // Initialize OpenTelemetry provider, exporter := setupTracing() defer func() { if err := provider.Shutdown(context.Background()); err != nil { @@ -557,245 +200,17 @@ func TestAddEventNoAttributes(t *testing.T) { } }() - // Start a span _, span := StartSpan(context.Background(), "test-span") - - // Add an event without attributes AddEvent(span, "test-event") - - // End the span span.End() - // Force flush to ensure spans are exported if err := provider.ForceFlush(context.Background()); err != nil { t.Errorf("Failed to flush provider: %v", err) } - // Verify span was created with event spans := exporter.GetSpans() assert.Len(t, spans, 1) - - resultSpan := spans[0] - assert.Equal(t, "test-span", resultSpan.Name()) - - // Verify event - events := resultSpan.Events() + events := spans[0].Events() assert.Len(t, events, 1) assert.Equal(t, "test-event", events[0].Name) } - -func TestAdaptToolHandler(t *testing.T) { - // Create a test handler - testHandler := func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - textContent := mcp.NewTextContent("test response") - return &mcp.CallToolResult{ - IsError: false, - Content: []mcp.Content{textContent}, - }, nil - } - - // Adapt the handler - adapted := AdaptToolHandler(testHandler) - - // Create test request - request := mcp.CallToolRequest{ - Params: mcp.CallToolParams{ - Name: "test-tool", - }, - } - - // Execute the adapted handler - result, err := adapted(context.Background(), request) - - // Verify result - require.NoError(t, err) - assert.NotNil(t, result) - assert.False(t, result.IsError) - assert.Len(t, result.Content, 1) - textContent, ok := mcp.AsTextContent(result.Content[0]) - require.True(t, ok) - assert.Equal(t, "test response", textContent.Text) -} - -func TestWithTracingNilResult(t *testing.T) { - // Initialize OpenTelemetry - provider, exporter := setupTracing() - defer func() { - if err := provider.Shutdown(context.Background()); err != nil { - t.Errorf("Failed to shutdown provider: %v", err) - } - }() - - // Create a test handler that returns nil result - testHandler := func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - return nil, nil - } - - // Wrap with tracing - tracedHandler := WithTracing("test-tool", testHandler) - - // Create test request - request := mcp.CallToolRequest{ - Params: mcp.CallToolParams{ - Name: "test-tool", - }, - } - - // Execute the handler - result, err := tracedHandler(context.Background(), request) - - // Force flush to ensure spans are exported - if err := provider.ForceFlush(context.Background()); err != nil { - t.Errorf("Failed to flush provider: %v", err) - } - - // Verify result - require.NoError(t, err) - assert.Nil(t, result) - - // Verify span was created - spans := exporter.GetSpans() - assert.Len(t, spans, 1) - - span := spans[0] - assert.Equal(t, "mcp.tool.test-tool", span.Name()) - assert.Equal(t, codes.Ok, span.Status().Code) -} - -func TestWithTracingNoContent(t *testing.T) { - // Initialize OpenTelemetry - provider, exporter := setupTracing() - defer func() { - if err := provider.Shutdown(context.Background()); err != nil { - t.Errorf("Failed to shutdown provider: %v", err) - } - }() - - // Create a test handler that returns result with no content - testHandler := func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - return &mcp.CallToolResult{ - IsError: false, - Content: []mcp.Content{}, - }, nil - } - - // Wrap with tracing - tracedHandler := WithTracing("test-tool", testHandler) - - // Create test request - request := mcp.CallToolRequest{ - Params: mcp.CallToolParams{ - Name: "test-tool", - }, - } - - // Execute the handler - result, err := tracedHandler(context.Background(), request) - - // Force flush to ensure spans are exported - if err := provider.ForceFlush(context.Background()); err != nil { - t.Errorf("Failed to flush provider: %v", err) - } - - // Verify result - require.NoError(t, err) - assert.NotNil(t, result) - assert.False(t, result.IsError) - assert.Len(t, result.Content, 0) - - // Verify span was created - spans := exporter.GetSpans() - assert.Len(t, spans, 1) - - span := spans[0] - assert.Equal(t, "mcp.tool.test-tool", span.Name()) - assert.Equal(t, codes.Ok, span.Status().Code) - - // Verify attributes - attributes := span.Attributes() - hasContentCount := false - - for _, attr := range attributes { - if attr.Key == "mcp.result.content_count" && attr.Value.AsInt64() == 0 { - hasContentCount = true - } - } - - assert.True(t, hasContentCount) -} - -func TestWithTracingNoopTracer(t *testing.T) { - // Set up noop tracer provider - otel.SetTracerProvider(noop.NewTracerProvider()) - - // Create a test handler - testHandler := func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - textContent := mcp.NewTextContent("test response") - return &mcp.CallToolResult{ - IsError: false, - Content: []mcp.Content{textContent}, - }, nil - } - - // Wrap with tracing - tracedHandler := WithTracing("test-tool", testHandler) - - // Create test request - request := mcp.CallToolRequest{ - Params: mcp.CallToolParams{ - Name: "test-tool", - }, - } - - // Execute the handler - result, err := tracedHandler(context.Background(), request) - - // Verify result (should work normally with noop tracer) - require.NoError(t, err) - assert.NotNil(t, result) - assert.False(t, result.IsError) - assert.Len(t, result.Content, 1) - textContent, ok := mcp.AsTextContent(result.Content[0]) - require.True(t, ok) - assert.Equal(t, "test response", textContent.Text) -} - -func TestWithTracingPerformance(t *testing.T) { - // Initialize OpenTelemetry - provider, _ := setupTracing() - defer func() { - if err := provider.Shutdown(context.Background()); err != nil { - t.Errorf("Failed to shutdown provider: %v", err) - } - }() - - // Create a test handler - testHandler := func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - textContent := mcp.NewTextContent("test response") - return &mcp.CallToolResult{ - IsError: false, - Content: []mcp.Content{textContent}, - }, nil - } - - // Wrap with tracing - tracedHandler := WithTracing("test-tool", testHandler) - - // Create test request - request := mcp.CallToolRequest{ - Params: mcp.CallToolParams{ - Name: "test-tool", - }, - } - - // Time execution - start := time.Now() - for i := 0; i < 100; i++ { - _, err := tracedHandler(context.Background(), request) - require.NoError(t, err) - } - duration := time.Since(start) - - // Verify performance is reasonable (should complete in less than 1 second) - assert.Less(t, duration, time.Second) -} diff --git a/pkg/argo/argo.go b/pkg/argo/argo.go index 758a4fb2..01aeaa0e 100644 --- a/pkg/argo/argo.go +++ b/pkg/argo/argo.go @@ -14,36 +14,43 @@ import ( "time" "github.com/kagent-dev/tools/internal/commands" - "github.com/kagent-dev/tools/internal/telemetry" + mcp "github.com/kagent-dev/tools/internal/mcp" "github.com/kagent-dev/tools/pkg/utils" - "github.com/mark3labs/mcp-go/mcp" - "github.com/mark3labs/mcp-go/server" ) -// Argo Rollouts tools +type verifyArgoRolloutsControllerInstallInput struct { + Namespace string `json:"namespace" jsonschema:"The namespace where Argo Rollouts is installed"` + Label string `json:"label" jsonschema:"The label of the Argo Rollouts controller pods"` +} -func handleVerifyArgoRolloutsControllerInstall(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - ns := mcp.ParseString(request, "namespace", "argo-rollouts") - label := mcp.ParseString(request, "label", "app.kubernetes.io/component=rollouts-controller") +func handleVerifyArgoRolloutsControllerInstall(ctx context.Context, request *mcp.CallToolRequest, in verifyArgoRolloutsControllerInstallInput) (*mcp.CallToolResult, any, error) { + ns := in.Namespace + if ns == "" { + ns = "argo-rollouts" + } + label := in.Label + if label == "" { + label = "app.kubernetes.io/component=rollouts-controller" + } cmd := []string{"get", "pods", "-n", ns, "-l", label, "-o", "jsonpath={.items[*].status.phase}"} output, err := runArgoRolloutCommand(ctx, cmd) if err != nil { - return mcp.NewToolResultError("Error: " + err.Error()), nil + return mcp.NewToolResultError("Error: " + err.Error()), nil, nil } output = strings.TrimSpace(output) if output == "" { - return mcp.NewToolResultText("Error: No pods found"), nil + return mcp.NewToolResultText("Error: No pods found"), nil, nil } if strings.HasPrefix(output, "Error") { - return mcp.NewToolResultText(output), nil + return mcp.NewToolResultText(output), nil, nil } podStatuses := strings.Fields(output) if len(podStatuses) == 0 { - return mcp.NewToolResultText("Error: No pod statuses returned"), nil + return mcp.NewToolResultText("Error: No pod statuses returned"), nil, nil } allRunning := true @@ -55,24 +62,25 @@ func handleVerifyArgoRolloutsControllerInstall(ctx context.Context, request mcp. } if allRunning { - return mcp.NewToolResultText("All pods are running"), nil - } else { - return mcp.NewToolResultText("Error: Not all pods are running (" + strings.Join(podStatuses, " ") + ")"), nil + return mcp.NewToolResultText("All pods are running"), nil, nil } + return mcp.NewToolResultText("Error: Not all pods are running (" + strings.Join(podStatuses, " ") + ")"), nil, nil } -func handleVerifyKubectlPluginInstall(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { +type verifyKubectlPluginInstallInput struct{} + +func handleVerifyKubectlPluginInstall(ctx context.Context, request *mcp.CallToolRequest, in verifyKubectlPluginInstallInput) (*mcp.CallToolResult, any, error) { args := []string{"argo", "rollouts", "version"} output, err := runArgoRolloutCommand(ctx, args) if err != nil { - return mcp.NewToolResultText("Kubectl Argo Rollouts plugin is not installed: " + err.Error()), nil + return mcp.NewToolResultText("Kubectl Argo Rollouts plugin is not installed: " + err.Error()), nil, nil } if strings.HasPrefix(output, "Error") { - return mcp.NewToolResultText("Kubectl Argo Rollouts plugin is not installed: " + output), nil + return mcp.NewToolResultText("Kubectl Argo Rollouts plugin is not installed: " + output), nil, nil } - return mcp.NewToolResultText(output), nil + return mcp.NewToolResultText(output), nil, nil } func runArgoRolloutCommand(ctx context.Context, args []string) (string, error) { @@ -83,81 +91,86 @@ func runArgoRolloutCommand(ctx context.Context, args []string) (string, error) { Execute(ctx) } -func handlePromoteRollout(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - rolloutName := mcp.ParseString(request, "rollout_name", "") - ns := mcp.ParseString(request, "namespace", "") - fullStr := mcp.ParseString(request, "full", "false") - full := fullStr == "true" +type promoteRolloutInput struct { + RolloutName string `json:"rollout_name" jsonschema:"The name of the rollout to promote"` + Namespace string `json:"namespace" jsonschema:"The namespace of the rollout"` + Full bool `json:"full" jsonschema:"Promote the rollout to the final step"` +} - if rolloutName == "" { - return mcp.NewToolResultError("rollout_name parameter is required"), nil +func handlePromoteRollout(ctx context.Context, request *mcp.CallToolRequest, in promoteRolloutInput) (*mcp.CallToolResult, any, error) { + if in.RolloutName == "" { + return mcp.NewToolResultError("rollout_name parameter is required"), nil, nil } cmd := []string{"argo", "rollouts", "promote"} - if ns != "" { - cmd = append(cmd, "-n", ns) + if in.Namespace != "" { + cmd = append(cmd, "-n", in.Namespace) } - cmd = append(cmd, rolloutName) - if full { + cmd = append(cmd, in.RolloutName) + if in.Full { cmd = append(cmd, "--full") } output, err := runArgoRolloutCommand(ctx, cmd) if err != nil { - return mcp.NewToolResultError("Error promoting rollout: " + err.Error()), nil + return mcp.NewToolResultError("Error promoting rollout: " + err.Error()), nil, nil } - return mcp.NewToolResultText(output), nil + return mcp.NewToolResultText(output), nil, nil } -func handlePauseRollout(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - rolloutName := mcp.ParseString(request, "rollout_name", "") - ns := mcp.ParseString(request, "namespace", "") +type pauseRolloutInput struct { + RolloutName string `json:"rollout_name" jsonschema:"The name of the rollout to pause"` + Namespace string `json:"namespace" jsonschema:"The namespace of the rollout"` +} - if rolloutName == "" { - return mcp.NewToolResultError("rollout_name parameter is required"), nil +func handlePauseRollout(ctx context.Context, request *mcp.CallToolRequest, in pauseRolloutInput) (*mcp.CallToolResult, any, error) { + if in.RolloutName == "" { + return mcp.NewToolResultError("rollout_name parameter is required"), nil, nil } cmd := []string{"argo", "rollouts", "pause"} - if ns != "" { - cmd = append(cmd, "-n", ns) + if in.Namespace != "" { + cmd = append(cmd, "-n", in.Namespace) } - cmd = append(cmd, rolloutName) + cmd = append(cmd, in.RolloutName) output, err := runArgoRolloutCommand(ctx, cmd) if err != nil { - return mcp.NewToolResultError("Error pausing rollout: " + err.Error()), nil + return mcp.NewToolResultError("Error pausing rollout: " + err.Error()), nil, nil } - return mcp.NewToolResultText(output), nil + return mcp.NewToolResultText(output), nil, nil } -func handleSetRolloutImage(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - rolloutName := mcp.ParseString(request, "rollout_name", "") - containerImage := mcp.ParseString(request, "container_image", "") - ns := mcp.ParseString(request, "namespace", "") +type setRolloutImageInput struct { + RolloutName string `json:"rollout_name" jsonschema:"The name of the rollout to set the image for"` + ContainerImage string `json:"container_image" jsonschema:"The container image to set for the rollout"` + Namespace string `json:"namespace" jsonschema:"The namespace of the rollout"` +} - if rolloutName == "" { - return mcp.NewToolResultError("rollout_name parameter is required"), nil +func handleSetRolloutImage(ctx context.Context, request *mcp.CallToolRequest, in setRolloutImageInput) (*mcp.CallToolResult, any, error) { + if in.RolloutName == "" { + return mcp.NewToolResultError("rollout_name parameter is required"), nil, nil } - if containerImage == "" { - return mcp.NewToolResultError("container_image parameter is required"), nil + if in.ContainerImage == "" { + return mcp.NewToolResultError("container_image parameter is required"), nil, nil } - cmd := []string{"argo", "rollouts", "set", "image", rolloutName, containerImage} - if ns != "" { - cmd = append(cmd, "-n", ns) + cmd := []string{"argo", "rollouts", "set", "image", in.RolloutName, in.ContainerImage} + if in.Namespace != "" { + cmd = append(cmd, "-n", in.Namespace) } output, err := runArgoRolloutCommand(ctx, cmd) if err != nil { - return mcp.NewToolResultError("Error setting rollout image: " + err.Error()), nil + return mcp.NewToolResultError("Error setting rollout image: " + err.Error()), nil, nil } - return mcp.NewToolResultText(output), nil + return mcp.NewToolResultText(output), nil, nil } -// Gateway Plugin Status struct +// GatewayPluginStatus struct type GatewayPluginStatus struct { Installed bool `json:"installed"` Version string `json:"version,omitempty"` @@ -284,11 +297,22 @@ data: } } -func handleVerifyGatewayPlugin(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - version := mcp.ParseString(request, "version", "") - namespace := mcp.ParseString(request, "namespace", "argo-rollouts") - shouldInstallStr := mcp.ParseString(request, "should_install", "true") - shouldInstall := shouldInstallStr == "true" +type verifyGatewayPluginInput struct { + Version string `json:"version" jsonschema:"The version of the plugin to check"` + Namespace string `json:"namespace" jsonschema:"The namespace for the plugin resources"` + ShouldInstall *bool `json:"should_install" jsonschema:"Whether to install the plugin if not found"` +} + +func handleVerifyGatewayPlugin(ctx context.Context, request *mcp.CallToolRequest, in verifyGatewayPluginInput) (*mcp.CallToolResult, any, error) { + version := in.Version + namespace := in.Namespace + if namespace == "" { + namespace = "argo-rollouts" + } + shouldInstall := true + if in.ShouldInstall != nil { + shouldInstall = *in.ShouldInstall + } // Check if ConfigMap exists and is configured cmd := []string{"get", "configmap", "argo-rollouts-config", "-n", namespace, "-o", "yaml"} @@ -298,7 +322,7 @@ func handleVerifyGatewayPlugin(ctx context.Context, request mcp.CallToolRequest) Installed: true, ErrorMessage: "Gateway API plugin is already configured", } - return mcp.NewToolResultText(status.String()), nil + return mcp.NewToolResultText(status.String()), nil, nil } if !shouldInstall { @@ -306,18 +330,29 @@ func handleVerifyGatewayPlugin(ctx context.Context, request mcp.CallToolRequest) Installed: false, ErrorMessage: "Gateway API plugin is not configured and installation is disabled", } - return mcp.NewToolResultText(status.String()), nil + return mcp.NewToolResultText(status.String()), nil, nil } // Configure plugin status := configureGatewayPlugin(ctx, version, namespace) - return mcp.NewToolResultText(status.String()), nil + return mcp.NewToolResultText(status.String()), nil, nil +} + +type checkPluginLogsInput struct { + Namespace string `json:"namespace" jsonschema:"The namespace of the plugin resources"` + Timeout int `json:"timeout" jsonschema:"Timeout for log collection in seconds"` } -func handleCheckPluginLogs(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - namespace := mcp.ParseString(request, "namespace", "argo-rollouts") +func handleCheckPluginLogs(ctx context.Context, request *mcp.CallToolRequest, in checkPluginLogsInput) (*mcp.CallToolResult, any, error) { + namespace := in.Namespace + if namespace == "" { + namespace = "argo-rollouts" + } // timeout parameter is parsed but not used currently - _ = mcp.ParseString(request, "timeout", "60") + if in.Timeout == 0 { + in.Timeout = 60 + } + _ = in.Timeout cmd := []string{"logs", "-n", namespace, "-l", "app.kubernetes.io/name=argo-rollouts", "--tail", "100"} output, err := runArgoRolloutCommand(ctx, cmd) @@ -326,7 +361,7 @@ func handleCheckPluginLogs(ctx context.Context, request mcp.CallToolRequest) (*m Installed: false, ErrorMessage: err.Error(), } - return mcp.NewToolResultText(status.String()), nil + return mcp.NewToolResultText(status.String()), nil, nil } // Parse download information @@ -344,19 +379,30 @@ func handleCheckPluginLogs(ctx context.Context, request mcp.CallToolRequest) (*m Architecture: versionMatches[2], DownloadTime: downloadTime, } - return mcp.NewToolResultText(status.String()), nil + return mcp.NewToolResultText(status.String()), nil, nil } status := GatewayPluginStatus{ Installed: false, ErrorMessage: "Plugin installation not found in logs", } - return mcp.NewToolResultText(status.String()), nil + return mcp.NewToolResultText(status.String()), nil, nil +} + +type listRolloutsInput struct { + Namespace string `json:"namespace" jsonschema:"The namespace of the rollout"` + Type string `json:"type" jsonschema:"What to list: rollouts or experiments"` } -func handleListRollouts(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - ns := mcp.ParseString(request, "namespace", "argo-rollouts") - tt := mcp.ParseString(request, "type", "rollouts") +func handleListRollouts(ctx context.Context, request *mcp.CallToolRequest, in listRolloutsInput) (*mcp.CallToolResult, any, error) { + ns := in.Namespace + if ns == "" { + ns = "argo-rollouts" + } + tt := in.Type + if tt == "" { + tt = "rollouts" + } cmd := []string{"argo", "rollouts", "list", tt} if ns != "" { @@ -365,67 +411,58 @@ func handleListRollouts(ctx context.Context, request mcp.CallToolRequest) (*mcp. output, err := runArgoRolloutCommand(ctx, cmd) if err != nil { - return mcp.NewToolResultError("Error listing rollouts: " + err.Error()), nil + return mcp.NewToolResultError("Error listing rollouts: " + err.Error()), nil, nil } if strings.HasPrefix(output, "Error") { - return mcp.NewToolResultText(output), nil + return mcp.NewToolResultText(output), nil, nil } - return mcp.NewToolResultText(output), nil + return mcp.NewToolResultText(output), nil, nil } -func RegisterTools(s *server.MCPServer, readOnly bool) { +func RegisterTools(s *mcp.Server, readOnly bool) { // Read-only tools - always registered - s.AddTool(mcp.NewTool("argo_verify_argo_rollouts_controller_install", - mcp.WithDescription("Verify that the Argo Rollouts controller is installed and running"), - mcp.WithString("namespace", mcp.Description("The namespace where Argo Rollouts is installed")), - mcp.WithString("label", mcp.Description("The label of the Argo Rollouts controller pods")), - ), telemetry.AdaptToolHandler(telemetry.WithTracing("argo_verify_argo_rollouts_controller_install", handleVerifyArgoRolloutsControllerInstall))) - - s.AddTool(mcp.NewTool("argo_verify_kubectl_plugin_install", - mcp.WithDescription("Verify that the kubectl Argo Rollouts plugin is installed"), - ), telemetry.AdaptToolHandler(telemetry.WithTracing("argo_verify_kubectl_plugin_install", handleVerifyKubectlPluginInstall))) - - s.AddTool(mcp.NewTool("argo_rollouts_list", - mcp.WithDescription("List rollouts or experiments"), - mcp.WithString("namespace", mcp.Description("The namespace of the rollout")), - mcp.WithString("type", mcp.Description("What to list: rollouts or experiments"), mcp.DefaultString("rollouts")), - ), telemetry.AdaptToolHandler(telemetry.WithTracing("argo_rollouts_list", handleListRollouts))) - - s.AddTool(mcp.NewTool("argo_check_plugin_logs", - mcp.WithDescription("Check the logs of the Argo Rollouts Gateway API plugin"), - mcp.WithString("namespace", mcp.Description("The namespace of the plugin resources")), - mcp.WithString("timeout", mcp.Description("Timeout for log collection in seconds")), - ), telemetry.AdaptToolHandler(telemetry.WithTracing("argo_check_plugin_logs", handleCheckPluginLogs))) + mcp.AddTool(s, "argo", &mcp.Tool{ + Name: "argo_verify_argo_rollouts_controller_install", + Description: "Verify that the Argo Rollouts controller is installed and running", + }, handleVerifyArgoRolloutsControllerInstall) + + mcp.AddTool(s, "argo", &mcp.Tool{ + Name: "argo_verify_kubectl_plugin_install", + Description: "Verify that the kubectl Argo Rollouts plugin is installed", + }, handleVerifyKubectlPluginInstall) + + mcp.AddTool(s, "argo", &mcp.Tool{ + Name: "argo_rollouts_list", + Description: "List rollouts or experiments", + }, handleListRollouts) + + mcp.AddTool(s, "argo", &mcp.Tool{ + Name: "argo_check_plugin_logs", + Description: "Check the logs of the Argo Rollouts Gateway API plugin", + }, handleCheckPluginLogs) // Write tools - only registered when not in read-only mode if !readOnly { - s.AddTool(mcp.NewTool("argo_promote_rollout", - mcp.WithDescription("Promote a paused rollout to the next step"), - mcp.WithString("rollout_name", mcp.Description("The name of the rollout to promote"), mcp.Required()), - mcp.WithString("namespace", mcp.Description("The namespace of the rollout")), - mcp.WithString("full", mcp.Description("Promote the rollout to the final step")), - ), telemetry.AdaptToolHandler(telemetry.WithTracing("argo_promote_rollout", handlePromoteRollout))) - - s.AddTool(mcp.NewTool("argo_pause_rollout", - mcp.WithDescription("Pause a rollout"), - mcp.WithString("rollout_name", mcp.Description("The name of the rollout to pause"), mcp.Required()), - mcp.WithString("namespace", mcp.Description("The namespace of the rollout")), - ), telemetry.AdaptToolHandler(telemetry.WithTracing("argo_pause_rollout", handlePauseRollout))) - - s.AddTool(mcp.NewTool("argo_set_rollout_image", - mcp.WithDescription("Set the image of a rollout"), - mcp.WithString("rollout_name", mcp.Description("The name of the rollout to set the image for"), mcp.Required()), - mcp.WithString("container_image", mcp.Description("The container image to set for the rollout"), mcp.Required()), - mcp.WithString("namespace", mcp.Description("The namespace of the rollout")), - ), telemetry.AdaptToolHandler(telemetry.WithTracing("argo_set_rollout_image", handleSetRolloutImage))) - - s.AddTool(mcp.NewTool("argo_verify_gateway_plugin", - mcp.WithDescription("Verify the installation status of the Argo Rollouts Gateway API plugin"), - mcp.WithString("version", mcp.Description("The version of the plugin to check")), - mcp.WithString("namespace", mcp.Description("The namespace for the plugin resources")), - mcp.WithString("should_install", mcp.Description("Whether to install the plugin if not found")), - ), telemetry.AdaptToolHandler(telemetry.WithTracing("argo_verify_gateway_plugin", handleVerifyGatewayPlugin))) + mcp.AddTool(s, "argo", &mcp.Tool{ + Name: "argo_promote_rollout", + Description: "Promote a paused rollout to the next step", + }, handlePromoteRollout) + + mcp.AddTool(s, "argo", &mcp.Tool{ + Name: "argo_pause_rollout", + Description: "Pause a rollout", + }, handlePauseRollout) + + mcp.AddTool(s, "argo", &mcp.Tool{ + Name: "argo_set_rollout_image", + Description: "Set the image of a rollout", + }, handleSetRolloutImage) + + mcp.AddTool(s, "argo", &mcp.Tool{ + Name: "argo_verify_gateway_plugin", + Description: "Verify the installation status of the Argo Rollouts Gateway API plugin", + }, handleVerifyGatewayPlugin) } } diff --git a/pkg/argo/argo_test.go b/pkg/argo/argo_test.go index ce00d7b8..148ef1b2 100644 --- a/pkg/argo/argo_test.go +++ b/pkg/argo/argo_test.go @@ -6,19 +6,18 @@ import ( "testing" "github.com/kagent-dev/tools/internal/cmd" - "github.com/mark3labs/mcp-go/mcp" - "github.com/mark3labs/mcp-go/server" + mcp "github.com/kagent-dev/tools/internal/mcp" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) func TestRegisterTools(t *testing.T) { t.Run("read-write", func(t *testing.T) { - s := server.NewMCPServer("test", "v0.0.1") + s := mcp.NewServer(&mcp.Implementation{Name: "test", Version: "v0.0.1"}, nil) RegisterTools(s, false) }) t.Run("read-only", func(t *testing.T) { - s := server.NewMCPServer("test", "v0.0.1") + s := mcp.NewServer(&mcp.Implementation{Name: "test", Version: "v0.0.1"}, nil) RegisterTools(s, true) }) } @@ -29,7 +28,7 @@ func TestHandleListRollouts(t *testing.T) { mock.AddCommandString("kubectl", []string{"argo", "rollouts", "list", "rollouts", "-n", "argo-rollouts"}, "NAME STATUS\nmyapp Healthy", nil) ctx := cmd.WithShellExecutor(context.Background(), mock) - result, err := handleListRollouts(ctx, mcp.CallToolRequest{}) + result, _, err := handleListRollouts(ctx, &mcp.CallToolRequest{}, listRolloutsInput{}) assert.NoError(t, err) assert.False(t, result.IsError) assert.Contains(t, getResultText(result), "myapp") @@ -40,9 +39,7 @@ func TestHandleListRollouts(t *testing.T) { mock.AddCommandString("kubectl", []string{"argo", "rollouts", "list", "experiments", "-n", "prod"}, "NAME", nil) ctx := cmd.WithShellExecutor(context.Background(), mock) - req := mcp.CallToolRequest{} - req.Params.Arguments = map[string]interface{}{"type": "experiments", "namespace": "prod"} - result, err := handleListRollouts(ctx, req) + result, _, err := handleListRollouts(ctx, &mcp.CallToolRequest{}, listRolloutsInput{Type: "experiments", Namespace: "prod"}) assert.NoError(t, err) assert.False(t, result.IsError) }) @@ -52,7 +49,7 @@ func TestHandleListRollouts(t *testing.T) { mock.AddCommandString("kubectl", []string{"argo", "rollouts", "list", "rollouts", "-n", "argo-rollouts"}, "", assert.AnError) ctx := cmd.WithShellExecutor(context.Background(), mock) - result, err := handleListRollouts(ctx, mcp.CallToolRequest{}) + result, _, err := handleListRollouts(ctx, &mcp.CallToolRequest{}, listRolloutsInput{}) assert.NoError(t, err) assert.True(t, result.IsError) assert.Contains(t, getResultText(result), "Error listing rollouts") @@ -67,7 +64,7 @@ Download complete, it took 1.5s` mock.AddCommandString("kubectl", []string{"logs", "-n", "argo-rollouts", "-l", "app.kubernetes.io/name=argo-rollouts", "--tail", "100"}, logs, nil) ctx := cmd.WithShellExecutor(context.Background(), mock) - result, err := handleCheckPluginLogs(ctx, mcp.CallToolRequest{}) + result, _, err := handleCheckPluginLogs(ctx, &mcp.CallToolRequest{}, checkPluginLogsInput{}) assert.NoError(t, err) assert.Contains(t, getResultText(result), "0.5.0") assert.Contains(t, getResultText(result), `"installed": true`) @@ -78,7 +75,7 @@ Download complete, it took 1.5s` mock.AddCommandString("kubectl", []string{"logs", "-n", "argo-rollouts", "-l", "app.kubernetes.io/name=argo-rollouts", "--tail", "100"}, "no plugin here", nil) ctx := cmd.WithShellExecutor(context.Background(), mock) - result, err := handleCheckPluginLogs(ctx, mcp.CallToolRequest{}) + result, _, err := handleCheckPluginLogs(ctx, &mcp.CallToolRequest{}, checkPluginLogsInput{}) assert.NoError(t, err) assert.Contains(t, getResultText(result), "Plugin installation not found") }) @@ -88,7 +85,7 @@ Download complete, it took 1.5s` mock.AddCommandString("kubectl", []string{"logs", "-n", "argo-rollouts", "-l", "app.kubernetes.io/name=argo-rollouts", "--tail", "100"}, "", assert.AnError) ctx := cmd.WithShellExecutor(context.Background(), mock) - result, err := handleCheckPluginLogs(ctx, mcp.CallToolRequest{}) + result, _, err := handleCheckPluginLogs(ctx, &mcp.CallToolRequest{}, checkPluginLogsInput{}) assert.NoError(t, err) assert.Contains(t, getResultText(result), `"installed": false`) }) @@ -122,7 +119,7 @@ func TestHandleVerifyGatewayPluginAlreadyConfigured(t *testing.T) { mock.AddCommandString("kubectl", []string{"get", "configmap", "argo-rollouts-config", "-n", "argo-rollouts", "-o", "yaml"}, "data:\n trafficRouterPlugins: argoproj-labs/gatewayAPI", nil) ctx := cmd.WithShellExecutor(context.Background(), mock) - result, err := handleVerifyGatewayPlugin(ctx, mcp.CallToolRequest{}) + result, _, err := handleVerifyGatewayPlugin(ctx, &mcp.CallToolRequest{}, verifyGatewayPluginInput{}) assert.NoError(t, err) assert.Contains(t, getResultText(result), "already configured") } @@ -134,7 +131,7 @@ func TestHandleVerifyArgoRolloutsControllerInstallStatuses(t *testing.T) { mock := cmd.NewMockShellExecutor() mock.AddCommandString("kubectl", baseCmd, "Running Running", nil) ctx := cmd.WithShellExecutor(context.Background(), mock) - result, err := handleVerifyArgoRolloutsControllerInstall(ctx, mcp.CallToolRequest{}) + result, _, err := handleVerifyArgoRolloutsControllerInstall(ctx, &mcp.CallToolRequest{}, verifyArgoRolloutsControllerInstallInput{}) assert.NoError(t, err) assert.Contains(t, getResultText(result), "All pods are running") }) @@ -143,7 +140,7 @@ func TestHandleVerifyArgoRolloutsControllerInstallStatuses(t *testing.T) { mock := cmd.NewMockShellExecutor() mock.AddCommandString("kubectl", baseCmd, "Running Pending", nil) ctx := cmd.WithShellExecutor(context.Background(), mock) - result, err := handleVerifyArgoRolloutsControllerInstall(ctx, mcp.CallToolRequest{}) + result, _, err := handleVerifyArgoRolloutsControllerInstall(ctx, &mcp.CallToolRequest{}, verifyArgoRolloutsControllerInstallInput{}) assert.NoError(t, err) assert.Contains(t, getResultText(result), "Not all pods are running") }) @@ -152,7 +149,7 @@ func TestHandleVerifyArgoRolloutsControllerInstallStatuses(t *testing.T) { mock := cmd.NewMockShellExecutor() mock.AddCommandString("kubectl", baseCmd, "", nil) ctx := cmd.WithShellExecutor(context.Background(), mock) - result, err := handleVerifyArgoRolloutsControllerInstall(ctx, mcp.CallToolRequest{}) + result, _, err := handleVerifyArgoRolloutsControllerInstall(ctx, &mcp.CallToolRequest{}, verifyArgoRolloutsControllerInstallInput{}) assert.NoError(t, err) assert.Contains(t, getResultText(result), "No pods found") }) @@ -161,7 +158,7 @@ func TestHandleVerifyArgoRolloutsControllerInstallStatuses(t *testing.T) { mock := cmd.NewMockShellExecutor() mock.AddCommandString("kubectl", baseCmd, "", assert.AnError) ctx := cmd.WithShellExecutor(context.Background(), mock) - result, err := handleVerifyArgoRolloutsControllerInstall(ctx, mcp.CallToolRequest{}) + result, _, err := handleVerifyArgoRolloutsControllerInstall(ctx, &mcp.CallToolRequest{}, verifyArgoRolloutsControllerInstallInput{}) assert.NoError(t, err) assert.True(t, result.IsError) }) @@ -172,7 +169,7 @@ func getResultText(result *mcp.CallToolResult) string { if result == nil || len(result.Content) == 0 { return "" } - if textContent, ok := result.Content[0].(mcp.TextContent); ok { + if textContent, ok := result.Content[0].(*mcp.TextContent); ok { return textContent.Text } return "" @@ -189,12 +186,7 @@ func TestHandlePromoteRollout(t *testing.T) { mock.AddCommandString("kubectl", []string{"argo", "rollouts", "promote", "myapp"}, expectedOutput, nil) ctx := cmd.WithShellExecutor(context.Background(), mock) - request := mcp.CallToolRequest{} - request.Params.Arguments = map[string]interface{}{ - "rollout_name": "myapp", - } - - result, err := handlePromoteRollout(ctx, request) + result, _, err := handlePromoteRollout(ctx, &mcp.CallToolRequest{}, promoteRolloutInput{RolloutName: "myapp"}) assert.NoError(t, err) assert.NotNil(t, result) @@ -215,13 +207,7 @@ func TestHandlePromoteRollout(t *testing.T) { mock.AddCommandString("kubectl", []string{"argo", "rollouts", "promote", "-n", "production", "myapp"}, expectedOutput, nil) ctx := cmd.WithShellExecutor(context.Background(), mock) - request := mcp.CallToolRequest{} - request.Params.Arguments = map[string]interface{}{ - "rollout_name": "myapp", - "namespace": "production", - } - - result, err := handlePromoteRollout(ctx, request) + result, _, err := handlePromoteRollout(ctx, &mcp.CallToolRequest{}, promoteRolloutInput{RolloutName: "myapp", Namespace: "production"}) assert.NoError(t, err) assert.False(t, result.IsError) @@ -240,13 +226,7 @@ func TestHandlePromoteRollout(t *testing.T) { mock.AddCommandString("kubectl", []string{"argo", "rollouts", "promote", "myapp", "--full"}, expectedOutput, nil) ctx := cmd.WithShellExecutor(context.Background(), mock) - request := mcp.CallToolRequest{} - request.Params.Arguments = map[string]interface{}{ - "rollout_name": "myapp", - "full": "true", - } - - result, err := handlePromoteRollout(ctx, request) + result, _, err := handlePromoteRollout(ctx, &mcp.CallToolRequest{}, promoteRolloutInput{RolloutName: "myapp", Full: true}) assert.NoError(t, err) assert.False(t, result.IsError) @@ -262,12 +242,7 @@ func TestHandlePromoteRollout(t *testing.T) { mock := cmd.NewMockShellExecutor() ctx := cmd.WithShellExecutor(context.Background(), mock) - request := mcp.CallToolRequest{} - request.Params.Arguments = map[string]interface{}{ - // Missing rollout_name - } - - result, err := handlePromoteRollout(ctx, request) + result, _, err := handlePromoteRollout(ctx, &mcp.CallToolRequest{}, promoteRolloutInput{}) assert.NoError(t, err) assert.True(t, result.IsError) assert.Contains(t, getResultText(result), "rollout_name parameter is required") @@ -282,12 +257,7 @@ func TestHandlePromoteRollout(t *testing.T) { mock.AddCommandString("kubectl", []string{"argo", "rollouts", "promote", "myapp"}, "", assert.AnError) ctx := cmd.WithShellExecutor(context.Background(), mock) - request := mcp.CallToolRequest{} - request.Params.Arguments = map[string]interface{}{ - "rollout_name": "myapp", - } - - result, err := handlePromoteRollout(ctx, request) + result, _, err := handlePromoteRollout(ctx, &mcp.CallToolRequest{}, promoteRolloutInput{RolloutName: "myapp"}) assert.NoError(t, err) // MCP handlers should not return Go errors assert.True(t, result.IsError) @@ -304,12 +274,7 @@ func TestHandlePauseRollout(t *testing.T) { mock.AddCommandString("kubectl", []string{"argo", "rollouts", "pause", "myapp"}, expectedOutput, nil) ctx := cmd.WithShellExecutor(context.Background(), mock) - request := mcp.CallToolRequest{} - request.Params.Arguments = map[string]interface{}{ - "rollout_name": "myapp", - } - - result, err := handlePauseRollout(ctx, request) + result, _, err := handlePauseRollout(ctx, &mcp.CallToolRequest{}, pauseRolloutInput{RolloutName: "myapp"}) assert.NoError(t, err) assert.NotNil(t, result) @@ -333,13 +298,7 @@ func TestHandlePauseRollout(t *testing.T) { mock.AddCommandString("kubectl", []string{"argo", "rollouts", "pause", "-n", "production", "myapp"}, expectedOutput, nil) ctx := cmd.WithShellExecutor(context.Background(), mock) - request := mcp.CallToolRequest{} - request.Params.Arguments = map[string]interface{}{ - "rollout_name": "myapp", - "namespace": "production", - } - - result, err := handlePauseRollout(ctx, request) + result, _, err := handlePauseRollout(ctx, &mcp.CallToolRequest{}, pauseRolloutInput{RolloutName: "myapp", Namespace: "production"}) assert.NoError(t, err) assert.False(t, result.IsError) @@ -355,12 +314,7 @@ func TestHandlePauseRollout(t *testing.T) { mock := cmd.NewMockShellExecutor() ctx := cmd.WithShellExecutor(context.Background(), mock) - request := mcp.CallToolRequest{} - request.Params.Arguments = map[string]interface{}{ - // Missing rollout_name - } - - result, err := handlePauseRollout(ctx, request) + result, _, err := handlePauseRollout(ctx, &mcp.CallToolRequest{}, pauseRolloutInput{}) assert.NoError(t, err) assert.True(t, result.IsError) assert.Contains(t, getResultText(result), "rollout_name parameter is required") @@ -380,13 +334,7 @@ func TestHandleSetRolloutImage(t *testing.T) { mock.AddCommandString("kubectl", []string{"argo", "rollouts", "set", "image", "myapp", "nginx:latest"}, expectedOutput, nil) ctx := cmd.WithShellExecutor(context.Background(), mock) - request := mcp.CallToolRequest{} - request.Params.Arguments = map[string]interface{}{ - "rollout_name": "myapp", - "container_image": "nginx:latest", - } - - result, err := handleSetRolloutImage(ctx, request) + result, _, err := handleSetRolloutImage(ctx, &mcp.CallToolRequest{}, setRolloutImageInput{RolloutName: "myapp", ContainerImage: "nginx:latest"}) assert.NoError(t, err) assert.NotNil(t, result) @@ -410,14 +358,7 @@ func TestHandleSetRolloutImage(t *testing.T) { mock.AddCommandString("kubectl", []string{"argo", "rollouts", "set", "image", "myapp", "nginx:1.20", "-n", "production"}, expectedOutput, nil) ctx := cmd.WithShellExecutor(context.Background(), mock) - request := mcp.CallToolRequest{} - request.Params.Arguments = map[string]interface{}{ - "rollout_name": "myapp", - "container_image": "nginx:1.20", - "namespace": "production", - } - - result, err := handleSetRolloutImage(ctx, request) + result, _, err := handleSetRolloutImage(ctx, &mcp.CallToolRequest{}, setRolloutImageInput{RolloutName: "myapp", ContainerImage: "nginx:1.20", Namespace: "production"}) assert.NoError(t, err) assert.False(t, result.IsError) @@ -433,13 +374,7 @@ func TestHandleSetRolloutImage(t *testing.T) { mock := cmd.NewMockShellExecutor() ctx := cmd.WithShellExecutor(context.Background(), mock) - request := mcp.CallToolRequest{} - request.Params.Arguments = map[string]interface{}{ - "container_image": "nginx:latest", - // Missing rollout_name - } - - result, err := handleSetRolloutImage(ctx, request) + result, _, err := handleSetRolloutImage(ctx, &mcp.CallToolRequest{}, setRolloutImageInput{ContainerImage: "nginx:latest"}) assert.NoError(t, err) assert.True(t, result.IsError) assert.Contains(t, getResultText(result), "rollout_name parameter is required") @@ -453,13 +388,7 @@ func TestHandleSetRolloutImage(t *testing.T) { mock := cmd.NewMockShellExecutor() ctx := cmd.WithShellExecutor(context.Background(), mock) - request := mcp.CallToolRequest{} - request.Params.Arguments = map[string]interface{}{ - "rollout_name": "myapp", - // Missing container_image - } - - result, err := handleSetRolloutImage(ctx, request) + result, _, err := handleSetRolloutImage(ctx, &mcp.CallToolRequest{}, setRolloutImageInput{RolloutName: "myapp"}) assert.NoError(t, err) assert.True(t, result.IsError) assert.Contains(t, getResultText(result), "container_image parameter is required") @@ -526,12 +455,8 @@ func TestHandleVerifyGatewayPlugin(t *testing.T) { mock.AddCommandString("kubectl", []string{"get", "configmap", "argo-rollouts-config", "-n", "argo-rollouts", "-o", "yaml"}, expectedOutput, nil) ctx := cmd.WithShellExecutor(context.Background(), mock) - request := mcp.CallToolRequest{} - request.Params.Arguments = map[string]interface{}{ - "should_install": "false", - } - - result, err := handleVerifyGatewayPlugin(ctx, request) + shouldInstall := false + result, _, err := handleVerifyGatewayPlugin(ctx, &mcp.CallToolRequest{}, verifyGatewayPluginInput{ShouldInstall: &shouldInstall}) assert.NoError(t, err) assert.NotNil(t, result) @@ -553,13 +478,8 @@ func TestHandleVerifyGatewayPlugin(t *testing.T) { mock.AddCommandString("kubectl", []string{"get", "configmap", "argo-rollouts-config", "-n", "custom-namespace", "-o", "yaml"}, expectedOutput, nil) ctx := cmd.WithShellExecutor(context.Background(), mock) - request := mcp.CallToolRequest{} - request.Params.Arguments = map[string]interface{}{ - "should_install": "false", - "namespace": "custom-namespace", - } - - result, err := handleVerifyGatewayPlugin(ctx, request) + shouldInstall := false + result, _, err := handleVerifyGatewayPlugin(ctx, &mcp.CallToolRequest{}, verifyGatewayPluginInput{ShouldInstall: &shouldInstall, Namespace: "custom-namespace"}) assert.NoError(t, err) assert.NotNil(t, result) @@ -582,8 +502,7 @@ func TestHandleVerifyArgoRolloutsControllerInstall(t *testing.T) { mock.AddCommandString("kubectl", []string{"get", "pods", "-l", "app.kubernetes.io/name=argo-rollouts", "-n", "argo-rollouts", "-o", "jsonpath={.items[*].metadata.name}"}, expectedOutput, nil) ctx := cmd.WithShellExecutor(context.Background(), mock) - request := mcp.CallToolRequest{} - result, err := handleVerifyArgoRolloutsControllerInstall(ctx, request) + result, _, err := handleVerifyArgoRolloutsControllerInstall(ctx, &mcp.CallToolRequest{}, verifyArgoRolloutsControllerInstallInput{}) assert.NoError(t, err) assert.NotNil(t, result) @@ -603,12 +522,7 @@ func TestHandleVerifyArgoRolloutsControllerInstall(t *testing.T) { mock.AddCommandString("kubectl", []string{"get", "pods", "-l", "app.kubernetes.io/name=argo-rollouts", "-n", "custom-argo", "-o", "jsonpath={.items[*].metadata.name}"}, expectedOutput, nil) ctx := cmd.WithShellExecutor(context.Background(), mock) - request := mcp.CallToolRequest{} - request.Params.Arguments = map[string]interface{}{ - "namespace": "custom-argo", - } - - result, err := handleVerifyArgoRolloutsControllerInstall(ctx, request) + result, _, err := handleVerifyArgoRolloutsControllerInstall(ctx, &mcp.CallToolRequest{}, verifyArgoRolloutsControllerInstallInput{Namespace: "custom-argo"}) assert.NoError(t, err) assert.NotNil(t, result) @@ -628,12 +542,7 @@ func TestHandleVerifyArgoRolloutsControllerInstall(t *testing.T) { mock.AddCommandString("kubectl", []string{"get", "pods", "-l", "app=custom-rollouts", "-n", "argo-rollouts", "-o", "jsonpath={.items[*].metadata.name}"}, expectedOutput, nil) ctx := cmd.WithShellExecutor(context.Background(), mock) - request := mcp.CallToolRequest{} - request.Params.Arguments = map[string]interface{}{ - "label": "app=custom-rollouts", - } - - result, err := handleVerifyArgoRolloutsControllerInstall(ctx, request) + result, _, err := handleVerifyArgoRolloutsControllerInstall(ctx, &mcp.CallToolRequest{}, verifyArgoRolloutsControllerInstallInput{Label: "app=custom-rollouts"}) assert.NoError(t, err) assert.NotNil(t, result) @@ -656,8 +565,7 @@ func TestHandleVerifyKubectlPluginInstall(t *testing.T) { mock.AddCommandString("kubectl", []string{"argo", "rollouts", "version"}, expectedOutput, nil) ctx := cmd.WithShellExecutor(context.Background(), mock) - request := mcp.CallToolRequest{} - result, err := handleVerifyKubectlPluginInstall(ctx, request) + result, _, err := handleVerifyKubectlPluginInstall(ctx, &mcp.CallToolRequest{}, verifyKubectlPluginInstallInput{}) assert.NoError(t, err) assert.False(t, result.IsError) @@ -674,8 +582,7 @@ func TestHandleVerifyKubectlPluginInstall(t *testing.T) { mock.AddCommandString("kubectl", []string{"plugin", "list"}, "", assert.AnError) ctx := cmd.WithShellExecutor(context.Background(), mock) - request := mcp.CallToolRequest{} - result, err := handleVerifyKubectlPluginInstall(ctx, request) + result, _, err := handleVerifyKubectlPluginInstall(ctx, &mcp.CallToolRequest{}, verifyKubectlPluginInstallInput{}) assert.NoError(t, err) // MCP handlers should not return Go errors assert.NotNil(t, result) diff --git a/pkg/cilium/cilium.go b/pkg/cilium/cilium.go index b92a8f1c..af146b31 100644 --- a/pkg/cilium/cilium.go +++ b/pkg/cilium/cilium.go @@ -6,13 +6,211 @@ import ( "strings" "github.com/kagent-dev/tools/internal/commands" - "github.com/kagent-dev/tools/internal/telemetry" + mcp "github.com/kagent-dev/tools/internal/mcp" "github.com/kagent-dev/tools/pkg/utils" - - "github.com/mark3labs/mcp-go/mcp" - "github.com/mark3labs/mcp-go/server" ) +type noInput struct{} + +type nodeNameInput struct { + NodeName string `json:"node_name" jsonschema:"The name of the node to run the command on"` +} + +type upgradeCiliumInput struct { + ClusterName string `json:"cluster_name" jsonschema:"The name of the cluster to upgrade Cilium on"` + DatapathMode string `json:"datapath_mode" jsonschema:"The datapath mode to use for Cilium (tunnel, native, aws-eni, gke, azure, aks-byocni)"` +} + +type installCiliumInput struct { + ClusterName string `json:"cluster_name" jsonschema:"The name of the cluster to install Cilium on"` + ClusterID string `json:"cluster_id" jsonschema:"The ID of the cluster to install Cilium on"` + DatapathMode string `json:"datapath_mode" jsonschema:"The datapath mode to use for Cilium (tunnel, native, aws-eni, gke, azure, aks-byocni)"` +} + +type connectToRemoteClusterInput struct { + ClusterName string `json:"cluster_name" jsonschema:"The name of the destination cluster"` + Context string `json:"context" jsonschema:"The kubectl context for the destination cluster"` +} + +type disconnectRemoteClusterInput struct { + ClusterName string `json:"cluster_name" jsonschema:"The name of the destination cluster"` +} + +type enableToggleInput struct { + Enable *bool `json:"enable" jsonschema:"Set to true to enable, false to disable"` +} + +type getDaemonStatusInput struct { + ShowAllAddresses bool `json:"show_all_addresses" jsonschema:"Whether to show all addresses"` + ShowAllClusters bool `json:"show_all_clusters" jsonschema:"Whether to show all clusters"` + ShowAllControllers bool `json:"show_all_controllers" jsonschema:"Whether to show all controllers"` + ShowHealth bool `json:"show_health" jsonschema:"Whether to show health"` + ShowAllNodes bool `json:"show_all_nodes" jsonschema:"Whether to show all nodes"` + ShowAllRedirects bool `json:"show_all_redirects" jsonschema:"Whether to show all redirects"` + Brief bool `json:"brief" jsonschema:"Whether to show a brief status"` + NodeName string `json:"node_name" jsonschema:"The name of the node to get the daemon status for"` +} + +type getEndpointDetailsInput struct { + EndpointID string `json:"endpoint_id" jsonschema:"The ID of the endpoint to get details for"` + Labels string `json:"labels" jsonschema:"The labels of the endpoint to get details for"` + OutputFormat string `json:"output_format" jsonschema:"The output format of the endpoint details (json, yaml, jsonpath)"` + NodeName string `json:"node_name" jsonschema:"The name of the node to get the endpoint details for"` +} + +type getEndpointLogsInput struct { + EndpointID string `json:"endpoint_id" jsonschema:"The ID of the endpoint to get logs for"` + NodeName string `json:"node_name" jsonschema:"The name of the node to get the endpoint logs for"` +} + +type getEndpointHealthInput struct { + EndpointID string `json:"endpoint_id" jsonschema:"The ID of the endpoint to get health for"` + NodeName string `json:"node_name" jsonschema:"The name of the node to get the endpoint health for"` +} + +type manageEndpointLabelsInput struct { + EndpointID string `json:"endpoint_id" jsonschema:"The ID of the endpoint to manage labels for"` + Labels string `json:"labels" jsonschema:"Space-separated labels to manage (e.g., 'key1=value1 key2=value2')"` + Action string `json:"action" jsonschema:"The action to perform on the labels (add or delete)"` + NodeName string `json:"node_name" jsonschema:"The name of the node to manage the endpoint labels on"` +} + +type manageEndpointConfigurationInput struct { + EndpointID string `json:"endpoint_id" jsonschema:"The ID of the endpoint to manage configuration for"` + Config string `json:"config" jsonschema:"The configuration to manage for the endpoint provided as a space-separated list of key-value pairs (e.g. 'DropNotification=false TraceNotification=false')"` + NodeName string `json:"node_name" jsonschema:"The name of the node to manage the endpoint configuration on"` +} + +type disconnectEndpointInput struct { + EndpointID string `json:"endpoint_id" jsonschema:"The ID of the endpoint to disconnect"` + NodeName string `json:"node_name" jsonschema:"The name of the node to disconnect the endpoint from"` +} + +type showConfigurationOptionsInput struct { + ListAll bool `json:"list_all" jsonschema:"Whether to list all configuration options"` + ListReadOnly bool `json:"list_read_only" jsonschema:"Whether to list read-only configuration options"` + ListOptions bool `json:"list_options" jsonschema:"Whether to list options"` + NodeName string `json:"node_name" jsonschema:"The name of the node to show the configuration options for"` +} + +type toggleConfigurationOptionInput struct { + Option string `json:"option" jsonschema:"The option to toggle"` + Value *bool `json:"value" jsonschema:"The value to set the option to (true/false)"` + NodeName string `json:"node_name" jsonschema:"The name of the node to toggle the configuration option for"` +} + +type getIdentityDetailsInput struct { + IdentityID string `json:"identity_id" jsonschema:"The ID of the identity to get details for"` + NodeName string `json:"node_name" jsonschema:"The name of the node to get the identity details for"` +} + +type listEnvoyConfigInput struct { + ResourceName string `json:"resource_name" jsonschema:"The name of the resource to get the Envoy configuration for"` + NodeName string `json:"node_name" jsonschema:"The name of the node to get the Envoy configuration for"` +} + +type fqdnCacheInput struct { + Command string `json:"command" jsonschema:"The command to perform on the FQDN cache (list, clean, or a specific command)"` + NodeName string `json:"node_name" jsonschema:"The name of the node to manage the FQDN cache for"` +} + +type showIPCacheInformationInput struct { + CIDR string `json:"cidr" jsonschema:"The CIDR of the IP to get cache information for"` + Labels string `json:"labels" jsonschema:"The labels of the IP to get cache information for"` + NodeName string `json:"node_name" jsonschema:"The name of the node to get the IP cache information for"` +} + +type kvStoreKeyInput struct { + Key string `json:"key" jsonschema:"The key in the kvstore"` + NodeName string `json:"node_name" jsonschema:"The name of the node to run the kvstore command on"` +} + +type setKVStoreKeyInput struct { + Key string `json:"key" jsonschema:"The key to set in the kvstore"` + Value string `json:"value" jsonschema:"The value to set in the kvstore"` + NodeName string `json:"node_name" jsonschema:"The name of the node to set the key in"` +} + +type bpfMapInput struct { + MapName string `json:"map_name" jsonschema:"The name of the BPF map"` + NodeName string `json:"node_name" jsonschema:"The name of the node to run the BPF map command on"` +} + +type listMetricsInput struct { + MatchPattern string `json:"match_pattern" jsonschema:"The match pattern to filter metrics by"` + NodeName string `json:"node_name" jsonschema:"The name of the node to get the metrics for"` +} + +type displayPolicyNodeInformationInput struct { + Labels string `json:"labels" jsonschema:"The labels to get policy node information for"` + NodeName string `json:"node_name" jsonschema:"The name of the node to get policy node information for"` +} + +type deletePolicyRulesInput struct { + Labels string `json:"labels" jsonschema:"The labels to delete policy rules for"` + All bool `json:"all" jsonschema:"Whether to delete all policy rules"` + NodeName string `json:"node_name" jsonschema:"The name of the node to delete policy rules for"` +} + +type xdpCIDRFiltersInput struct { + CIDRPrefixes string `json:"cidr_prefixes" jsonschema:"The CIDR prefixes for the XDP filters"` + Revision string `json:"revision" jsonschema:"The revision of the XDP filters"` + NodeName string `json:"node_name" jsonschema:"The name of the node to run the XDP filter command on"` +} + +type validateCiliumNetworkPoliciesInput struct { + EnableK8s bool `json:"enable_k8s" jsonschema:"Whether to enable k8s API discovery"` + EnableK8sAPIDiscovery bool `json:"enable_k8s_api_discovery" jsonschema:"Whether to enable k8s API discovery"` + NodeName string `json:"node_name" jsonschema:"The name of the node to validate the Cilium network policies for"` +} + +type pcapRecorderIDInput struct { + RecorderID string `json:"recorder_id" jsonschema:"The ID of the PCAP recorder"` + NodeName string `json:"node_name" jsonschema:"The name of the node to run the PCAP recorder command on"` +} + +type updatePCAPRecorderInput struct { + RecorderID string `json:"recorder_id" jsonschema:"The ID of the PCAP recorder to update"` + Filters string `json:"filters" jsonschema:"The filters to update the PCAP recorder with"` + Caplen string `json:"caplen" jsonschema:"The caplen to update the PCAP recorder with"` + ID string `json:"id" jsonschema:"The id to update the PCAP recorder with"` + NodeName string `json:"node_name" jsonschema:"The name of the node to update the PCAP recorder on"` +} + +type listServicesInput struct { + ShowClusterMeshAffinity bool `json:"show_cluster_mesh_affinity" jsonschema:"Whether to show cluster mesh affinity"` + NodeName string `json:"node_name" jsonschema:"The name of the node to get the services for"` +} + +type getServiceInformationInput struct { + ServiceID string `json:"service_id" jsonschema:"The ID of the service to get information about"` + NodeName string `json:"node_name" jsonschema:"The name of the node to get the service information for"` +} + +type deleteServiceInput struct { + ServiceID string `json:"service_id" jsonschema:"The ID of the service to delete"` + All bool `json:"all" jsonschema:"Whether to delete all services"` + NodeName string `json:"node_name" jsonschema:"The name of the node to delete the service from"` +} + +type updateServiceInput struct { + BackendWeights string `json:"backend_weights" jsonschema:"The backend weights to update the service with"` + Backends string `json:"backends" jsonschema:"The backends to update the service with"` + Frontend string `json:"frontend" jsonschema:"The frontend to update the service with"` + ID string `json:"id" jsonschema:"The ID of the service to update"` + K8sClusterInternal bool `json:"k8s_cluster_internal" jsonschema:"Whether to update the k8s cluster internal flag"` + K8sExtTrafficPolicy string `json:"k8s_ext_traffic_policy" jsonschema:"The k8s ext traffic policy to update the service with"` + K8sExternal bool `json:"k8s_external" jsonschema:"Whether to update the k8s external flag"` + K8sHostPort bool `json:"k8s_host_port" jsonschema:"Whether to update the k8s host port flag"` + K8sIntTrafficPolicy string `json:"k8s_int_traffic_policy" jsonschema:"The k8s int traffic policy to update the service with"` + K8sLoadBalancer bool `json:"k8s_load_balancer" jsonschema:"Whether to update the k8s load balancer flag"` + K8sNodePort bool `json:"k8s_node_port" jsonschema:"Whether to update the k8s node port flag"` + LocalRedirect bool `json:"local_redirect" jsonschema:"Whether to update the local redirect flag"` + Protocol string `json:"protocol" jsonschema:"The protocol to update the service with"` + States string `json:"states" jsonschema:"The states to update the service with"` + NodeName string `json:"node_name" jsonschema:"The name of the node to update the service on"` +} + func runCiliumCliWithContext(ctx context.Context, args ...string) (string, error) { kubeconfigPath := utils.GetKubeconfig() return commands.NewCommandBuilder("cilium"). @@ -21,24 +219,24 @@ func runCiliumCliWithContext(ctx context.Context, args ...string) (string, error Execute(ctx) } -func handleCiliumStatusAndVersion(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { +func handleCiliumStatusAndVersion(ctx context.Context, request *mcp.CallToolRequest, in noInput) (*mcp.CallToolResult, any, error) { status, err := runCiliumCliWithContext(ctx, "status") if err != nil { - return mcp.NewToolResultError("Error getting Cilium status: " + err.Error()), nil + return mcp.NewToolResultError("Error getting Cilium status: " + err.Error()), nil, nil } version, err := runCiliumCliWithContext(ctx, "version") if err != nil { - return mcp.NewToolResultError("Error getting Cilium version: " + err.Error()), nil + return mcp.NewToolResultError("Error getting Cilium version: " + err.Error()), nil, nil } result := status + "\n" + version - return mcp.NewToolResultText(result), nil + return mcp.NewToolResultText(result), nil, nil } -func handleUpgradeCilium(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - clusterName := mcp.ParseString(request, "cluster_name", "") - datapathMode := mcp.ParseString(request, "datapath_mode", "") +func handleUpgradeCilium(ctx context.Context, request *mcp.CallToolRequest, in upgradeCiliumInput) (*mcp.CallToolResult, any, error) { + clusterName := in.ClusterName + datapathMode := in.DatapathMode args := []string{"upgrade"} if clusterName != "" { @@ -50,16 +248,16 @@ func handleUpgradeCilium(ctx context.Context, request mcp.CallToolRequest) (*mcp output, err := runCiliumCliWithContext(ctx, args...) if err != nil { - return mcp.NewToolResultError("Error upgrading Cilium: " + err.Error()), nil + return mcp.NewToolResultError("Error upgrading Cilium: " + err.Error()), nil, nil } - return mcp.NewToolResultText(output), nil + return mcp.NewToolResultText(output), nil, nil } -func handleInstallCilium(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - clusterName := mcp.ParseString(request, "cluster_name", "") - clusterID := mcp.ParseString(request, "cluster_id", "") - datapathMode := mcp.ParseString(request, "datapath_mode", "") +func handleInstallCilium(ctx context.Context, request *mcp.CallToolRequest, in installCiliumInput) (*mcp.CallToolResult, any, error) { + clusterName := in.ClusterName + clusterID := in.ClusterID + datapathMode := in.DatapathMode args := []string{"install"} if clusterName != "" { @@ -74,99 +272,100 @@ func handleInstallCilium(ctx context.Context, request mcp.CallToolRequest) (*mcp output, err := runCiliumCliWithContext(ctx, args...) if err != nil { - return mcp.NewToolResultError("Error installing Cilium: " + err.Error()), nil + return mcp.NewToolResultError("Error installing Cilium: " + err.Error()), nil, nil } - return mcp.NewToolResultText(output), nil + return mcp.NewToolResultText(output), nil, nil } -func handleUninstallCilium(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { +func handleUninstallCilium(ctx context.Context, request *mcp.CallToolRequest, in noInput) (*mcp.CallToolResult, any, error) { output, err := runCiliumCliWithContext(ctx, "uninstall") if err != nil { - return mcp.NewToolResultError("Error uninstalling Cilium: " + err.Error()), nil + return mcp.NewToolResultError("Error uninstalling Cilium: " + err.Error()), nil, nil } - return mcp.NewToolResultText(output), nil + return mcp.NewToolResultText(output), nil, nil } -func handleConnectToRemoteCluster(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - clusterName := mcp.ParseString(request, "cluster_name", "") - context := mcp.ParseString(request, "context", "") +func handleConnectToRemoteCluster(ctx context.Context, request *mcp.CallToolRequest, in connectToRemoteClusterInput) (*mcp.CallToolResult, any, error) { + clusterName := in.ClusterName + destContext := in.Context if clusterName == "" { - return mcp.NewToolResultError("cluster_name parameter is required"), nil + return mcp.NewToolResultError("cluster_name parameter is required"), nil, nil } args := []string{"clustermesh", "connect", "--destination-cluster", clusterName} - if context != "" { - args = append(args, "--destination-context", context) + if destContext != "" { + args = append(args, "--destination-context", destContext) } output, err := runCiliumCliWithContext(ctx, args...) if err != nil { - return mcp.NewToolResultError("Error connecting to remote cluster: " + err.Error()), nil + return mcp.NewToolResultError("Error connecting to remote cluster: " + err.Error()), nil, nil } - return mcp.NewToolResultText(output), nil + return mcp.NewToolResultText(output), nil, nil } -func handleDisconnectRemoteCluster(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - clusterName := mcp.ParseString(request, "cluster_name", "") +func handleDisconnectRemoteCluster(ctx context.Context, request *mcp.CallToolRequest, in disconnectRemoteClusterInput) (*mcp.CallToolResult, any, error) { + clusterName := in.ClusterName if clusterName == "" { - return mcp.NewToolResultError("cluster_name parameter is required"), nil + return mcp.NewToolResultError("cluster_name parameter is required"), nil, nil } args := []string{"clustermesh", "disconnect", "--destination-cluster", clusterName} output, err := runCiliumCliWithContext(ctx, args...) if err != nil { - return mcp.NewToolResultError("Error disconnecting from remote cluster: " + err.Error()), nil + return mcp.NewToolResultError("Error disconnecting from remote cluster: " + err.Error()), nil, nil } - return mcp.NewToolResultText(output), nil + return mcp.NewToolResultText(output), nil, nil } -func handleListBGPPeers(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { +func handleListBGPPeers(ctx context.Context, request *mcp.CallToolRequest, in noInput) (*mcp.CallToolResult, any, error) { output, err := runCiliumCliWithContext(ctx, "bgp", "peers") if err != nil { - return mcp.NewToolResultError("Error listing BGP peers: " + err.Error()), nil + return mcp.NewToolResultError("Error listing BGP peers: " + err.Error()), nil, nil } - return mcp.NewToolResultText(output), nil + return mcp.NewToolResultText(output), nil, nil } -func handleListBGPRoutes(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { +func handleListBGPRoutes(ctx context.Context, request *mcp.CallToolRequest, in noInput) (*mcp.CallToolResult, any, error) { output, err := runCiliumCliWithContext(ctx, "bgp", "routes") if err != nil { - return mcp.NewToolResultError("Error listing BGP routes: " + err.Error()), nil + return mcp.NewToolResultError("Error listing BGP routes: " + err.Error()), nil, nil } - return mcp.NewToolResultText(output), nil + return mcp.NewToolResultText(output), nil, nil } -func handleShowClusterMeshStatus(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { +func handleShowClusterMeshStatus(ctx context.Context, request *mcp.CallToolRequest, in noInput) (*mcp.CallToolResult, any, error) { output, err := runCiliumCliWithContext(ctx, "clustermesh", "status") if err != nil { - return mcp.NewToolResultError("Error getting cluster mesh status: " + err.Error()), nil + return mcp.NewToolResultError("Error getting cluster mesh status: " + err.Error()), nil, nil } - return mcp.NewToolResultText(output), nil + return mcp.NewToolResultText(output), nil, nil } -func handleShowFeaturesStatus(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { +func handleShowFeaturesStatus(ctx context.Context, request *mcp.CallToolRequest, in noInput) (*mcp.CallToolResult, any, error) { output, err := runCiliumCliWithContext(ctx, "features", "status") if err != nil { - return mcp.NewToolResultError("Error getting features status: " + err.Error()), nil + return mcp.NewToolResultError("Error getting features status: " + err.Error()), nil, nil } - return mcp.NewToolResultText(output), nil + return mcp.NewToolResultText(output), nil, nil } -func handleToggleHubble(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - enableStr := mcp.ParseString(request, "enable", "true") - enable := enableStr == "true" - +func handleToggleHubble(ctx context.Context, request *mcp.CallToolRequest, in enableToggleInput) (*mcp.CallToolResult, any, error) { + enable := true + if in.Enable != nil { + enable = *in.Enable + } var action string if enable { action = "enable" @@ -176,16 +375,17 @@ func handleToggleHubble(ctx context.Context, request mcp.CallToolRequest) (*mcp. output, err := runCiliumCliWithContext(ctx, "hubble", action) if err != nil { - return mcp.NewToolResultError("Error toggling Hubble: " + err.Error()), nil + return mcp.NewToolResultError("Error toggling Hubble: " + err.Error()), nil, nil } - return mcp.NewToolResultText(output), nil + return mcp.NewToolResultText(output), nil, nil } -func handleToggleClusterMesh(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - enableStr := mcp.ParseString(request, "enable", "true") - enable := enableStr == "true" - +func handleToggleClusterMesh(ctx context.Context, request *mcp.CallToolRequest, in enableToggleInput) (*mcp.CallToolResult, any, error) { + enable := true + if in.Enable != nil { + enable = *in.Enable + } var action string if enable { action = "enable" @@ -195,408 +395,111 @@ func handleToggleClusterMesh(ctx context.Context, request mcp.CallToolRequest) ( output, err := runCiliumCliWithContext(ctx, "clustermesh", action) if err != nil { - return mcp.NewToolResultError("Error toggling cluster mesh: " + err.Error()), nil + return mcp.NewToolResultError("Error toggling cluster mesh: " + err.Error()), nil, nil } - return mcp.NewToolResultText(output), nil + return mcp.NewToolResultText(output), nil, nil } -func RegisterTools(s *server.MCPServer, readOnly bool) { +func RegisterTools(s *mcp.Server, readOnly bool) { // Read-only tools - always registered - s.AddTool(mcp.NewTool("cilium_status_and_version", - mcp.WithDescription("Get the status and version of Cilium installation"), - ), telemetry.AdaptToolHandler(telemetry.WithTracing("cilium_status_and_version", handleCiliumStatusAndVersion))) + mcp.AddTool(s, "cilium", &mcp.Tool{Name: "cilium_status_and_version", Description: "Get the status and version of Cilium installation"}, handleCiliumStatusAndVersion) + mcp.AddTool(s, "cilium", &mcp.Tool{Name: "cilium_list_bgp_peers", Description: "List BGP peers"}, handleListBGPPeers) + mcp.AddTool(s, "cilium", &mcp.Tool{Name: "cilium_list_bgp_routes", Description: "List BGP routes"}, handleListBGPRoutes) + mcp.AddTool(s, "cilium", &mcp.Tool{Name: "cilium_show_cluster_mesh_status", Description: "Show cluster mesh status"}, handleShowClusterMeshStatus) + mcp.AddTool(s, "cilium", &mcp.Tool{Name: "cilium_show_features_status", Description: "Show Cilium features status"}, handleShowFeaturesStatus) - s.AddTool(mcp.NewTool("cilium_list_bgp_peers", - mcp.WithDescription("List BGP peers"), - ), telemetry.AdaptToolHandler(telemetry.WithTracing("cilium_list_bgp_peers", handleListBGPPeers))) + if !readOnly { + mcp.AddTool(s, "cilium", &mcp.Tool{Name: "cilium_upgrade_cilium", Description: "Upgrade Cilium on the cluster"}, handleUpgradeCilium) + mcp.AddTool(s, "cilium", &mcp.Tool{Name: "cilium_install_cilium", Description: "Install Cilium on the cluster"}, handleInstallCilium) + mcp.AddTool(s, "cilium", &mcp.Tool{Name: "cilium_uninstall_cilium", Description: "Uninstall Cilium from the cluster"}, handleUninstallCilium) + mcp.AddTool(s, "cilium", &mcp.Tool{Name: "cilium_connect_to_remote_cluster", Description: "Connect to a remote cluster for cluster mesh"}, handleConnectToRemoteCluster) + mcp.AddTool(s, "cilium", &mcp.Tool{Name: "cilium_disconnect_remote_cluster", Description: "Disconnect from a remote cluster"}, handleDisconnectRemoteCluster) + mcp.AddTool(s, "cilium", &mcp.Tool{Name: "cilium_toggle_hubble", Description: "Enable or disable Hubble"}, handleToggleHubble) + mcp.AddTool(s, "cilium", &mcp.Tool{Name: "cilium_toggle_cluster_mesh", Description: "Enable or disable cluster mesh"}, handleToggleClusterMesh) + } - s.AddTool(mcp.NewTool("cilium_list_bgp_routes", - mcp.WithDescription("List BGP routes"), - ), telemetry.AdaptToolHandler(telemetry.WithTracing("cilium_list_bgp_routes", handleListBGPRoutes))) + mcp.AddTool(s, "cilium", &mcp.Tool{Name: "cilium_get_daemon_status", Description: "Get the status of the Cilium daemon for the cluster"}, handleGetDaemonStatus) + mcp.AddTool(s, "cilium", &mcp.Tool{Name: "cilium_get_endpoints_list", Description: "Get the list of all endpoints in the cluster"}, handleGetEndpointsList) + mcp.AddTool(s, "cilium", &mcp.Tool{Name: "cilium_get_endpoint_details", Description: "List the details of an endpoint in the cluster"}, handleGetEndpointDetails) + mcp.AddTool(s, "cilium", &mcp.Tool{Name: "cilium_show_configuration_options", Description: "Show Cilium configuration options"}, handleShowConfigurationOptions) - s.AddTool(mcp.NewTool("cilium_show_cluster_mesh_status", - mcp.WithDescription("Show cluster mesh status"), - ), telemetry.AdaptToolHandler(telemetry.WithTracing("cilium_show_cluster_mesh_status", handleShowClusterMeshStatus))) + if !readOnly { + mcp.AddTool(s, "cilium", &mcp.Tool{Name: "cilium_toggle_configuration_option", Description: "Toggle a Cilium configuration option"}, handleToggleConfigurationOption) + } - s.AddTool(mcp.NewTool("cilium_show_features_status", - mcp.WithDescription("Show Cilium features status"), - ), telemetry.AdaptToolHandler(telemetry.WithTracing("cilium_show_features_status", handleShowFeaturesStatus))) + mcp.AddTool(s, "cilium", &mcp.Tool{Name: "cilium_list_services", Description: "List services for the cluster"}, handleListServices) + mcp.AddTool(s, "cilium", &mcp.Tool{Name: "cilium_get_service_information", Description: "Get information about a service in the cluster"}, handleGetServiceInformation) - // Write tools - only registered when write operations are enabled - if !readOnly { - s.AddTool(mcp.NewTool("cilium_upgrade_cilium", - mcp.WithDescription("Upgrade Cilium on the cluster"), - mcp.WithString("cluster_name", mcp.Description("The name of the cluster to upgrade Cilium on")), - mcp.WithString("datapath_mode", mcp.Description("The datapath mode to use for Cilium (tunnel, native, aws-eni, gke, azure, aks-byocni)")), - ), telemetry.AdaptToolHandler(telemetry.WithTracing("cilium_upgrade_cilium", handleUpgradeCilium))) - - s.AddTool(mcp.NewTool("cilium_install_cilium", - mcp.WithDescription("Install Cilium on the cluster"), - mcp.WithString("cluster_name", mcp.Description("The name of the cluster to install Cilium on")), - mcp.WithString("cluster_id", mcp.Description("The ID of the cluster to install Cilium on")), - mcp.WithString("datapath_mode", mcp.Description("The datapath mode to use for Cilium (tunnel, native, aws-eni, gke, azure, aks-byocni)")), - ), telemetry.AdaptToolHandler(telemetry.WithTracing("cilium_install_cilium", handleInstallCilium))) - - s.AddTool(mcp.NewTool("cilium_uninstall_cilium", - mcp.WithDescription("Uninstall Cilium from the cluster"), - ), telemetry.AdaptToolHandler(telemetry.WithTracing("cilium_uninstall_cilium", handleUninstallCilium))) - - s.AddTool(mcp.NewTool("cilium_connect_to_remote_cluster", - mcp.WithDescription("Connect to a remote cluster for cluster mesh"), - mcp.WithString("cluster_name", mcp.Description("The name of the destination cluster"), mcp.Required()), - mcp.WithString("context", mcp.Description("The kubectl context for the destination cluster")), - ), telemetry.AdaptToolHandler(telemetry.WithTracing("cilium_connect_to_remote_cluster", handleConnectToRemoteCluster))) - - s.AddTool(mcp.NewTool("cilium_disconnect_remote_cluster", - mcp.WithDescription("Disconnect from a remote cluster"), - mcp.WithString("cluster_name", mcp.Description("The name of the destination cluster"), mcp.Required()), - ), telemetry.AdaptToolHandler(telemetry.WithTracing("cilium_disconnect_remote_cluster", handleDisconnectRemoteCluster))) - - s.AddTool(mcp.NewTool("cilium_toggle_hubble", - mcp.WithDescription("Enable or disable Hubble"), - mcp.WithString("enable", mcp.Description("Set to 'true' to enable, 'false' to disable")), - ), telemetry.AdaptToolHandler(telemetry.WithTracing("cilium_toggle_hubble", handleToggleHubble))) - - s.AddTool(mcp.NewTool("cilium_toggle_cluster_mesh", - mcp.WithDescription("Enable or disable cluster mesh"), - mcp.WithString("enable", mcp.Description("Set to 'true' to enable, 'false' to disable")), - ), telemetry.AdaptToolHandler(telemetry.WithTracing("cilium_toggle_cluster_mesh", handleToggleClusterMesh))) - } - - // Add tools that are also needed by cilium-manager agent - s.AddTool(mcp.NewTool("cilium_get_daemon_status", - mcp.WithDescription("Get the status of the Cilium daemon for the cluster"), - mcp.WithString("show_all_addresses", mcp.Description("Whether to show all addresses")), - mcp.WithString("show_all_clusters", mcp.Description("Whether to show all clusters")), - mcp.WithString("show_all_controllers", mcp.Description("Whether to show all controllers")), - mcp.WithString("show_health", mcp.Description("Whether to show health")), - mcp.WithString("show_all_nodes", mcp.Description("Whether to show all nodes")), - mcp.WithString("show_all_redirects", mcp.Description("Whether to show all redirects")), - mcp.WithString("brief", mcp.Description("Whether to show a brief status")), - mcp.WithString("node_name", mcp.Description("The name of the node to get the daemon status for")), - ), telemetry.AdaptToolHandler(telemetry.WithTracing("cilium_get_daemon_status", handleGetDaemonStatus))) - - s.AddTool(mcp.NewTool("cilium_get_endpoints_list", - mcp.WithDescription("Get the list of all endpoints in the cluster"), - mcp.WithString("node_name", mcp.Description("The name of the node to get the endpoints list for")), - ), telemetry.AdaptToolHandler(telemetry.WithTracing("cilium_get_endpoints_list", handleGetEndpointsList))) - - s.AddTool(mcp.NewTool("cilium_get_endpoint_details", - mcp.WithDescription("List the details of an endpoint in the cluster"), - mcp.WithString("endpoint_id", mcp.Description("The ID of the endpoint to get details for")), - mcp.WithString("labels", mcp.Description("The labels of the endpoint to get details for")), - mcp.WithString("output_format", mcp.Description("The output format of the endpoint details (json, yaml, jsonpath)")), - mcp.WithString("node_name", mcp.Description("The name of the node to get the endpoint details for")), - ), telemetry.AdaptToolHandler(telemetry.WithTracing("cilium_get_endpoint_details", handleGetEndpointDetails))) - - s.AddTool(mcp.NewTool("cilium_show_configuration_options", - mcp.WithDescription("Show Cilium configuration options"), - mcp.WithString("list_all", mcp.Description("Whether to list all configuration options")), - mcp.WithString("list_read_only", mcp.Description("Whether to list read-only configuration options")), - mcp.WithString("list_options", mcp.Description("Whether to list options")), - mcp.WithString("node_name", mcp.Description("The name of the node to show the configuration options for")), - ), telemetry.AdaptToolHandler(telemetry.WithTracing("cilium_show_configuration_options", handleShowConfigurationOptions))) - - // Write tool - toggle_configuration_option - if !readOnly { - s.AddTool(mcp.NewTool("cilium_toggle_configuration_option", - mcp.WithDescription("Toggle a Cilium configuration option"), - mcp.WithString("option", mcp.Description("The option to toggle"), mcp.Required()), - mcp.WithString("value", mcp.Description("The value to set the option to (true/false)"), mcp.Required()), - mcp.WithString("node_name", mcp.Description("The name of the node to toggle the configuration option for")), - ), telemetry.AdaptToolHandler(telemetry.WithTracing("cilium_toggle_configuration_option", handleToggleConfigurationOption))) - } - - s.AddTool(mcp.NewTool("cilium_list_services", - mcp.WithDescription("List services for the cluster"), - mcp.WithString("show_cluster_mesh_affinity", mcp.Description("Whether to show cluster mesh affinity")), - mcp.WithString("node_name", mcp.Description("The name of the node to get the services for")), - ), telemetry.AdaptToolHandler(telemetry.WithTracing("cilium_list_services", handleListServices))) - - s.AddTool(mcp.NewTool("cilium_get_service_information", - mcp.WithDescription("Get information about a service in the cluster"), - mcp.WithString("service_id", mcp.Description("The ID of the service to get information about"), mcp.Required()), - mcp.WithString("node_name", mcp.Description("The name of the node to get the service information for")), - ), telemetry.AdaptToolHandler(telemetry.WithTracing("cilium_get_service_information", handleGetServiceInformation))) - - // Write tools - service management if !readOnly { - s.AddTool(mcp.NewTool("cilium_update_service", - mcp.WithDescription("Update a service in the cluster"), - mcp.WithString("backend_weights", mcp.Description("The backend weights to update the service with")), - mcp.WithString("backends", mcp.Description("The backends to update the service with"), mcp.Required()), - mcp.WithString("frontend", mcp.Description("The frontend to update the service with"), mcp.Required()), - mcp.WithString("id", mcp.Description("The ID of the service to update"), mcp.Required()), - mcp.WithString("k8s_cluster_internal", mcp.Description("Whether to update the k8s cluster internal flag")), - mcp.WithString("k8s_ext_traffic_policy", mcp.Description("The k8s ext traffic policy to update the service with")), - mcp.WithString("k8s_external", mcp.Description("Whether to update the k8s external flag")), - mcp.WithString("k8s_host_port", mcp.Description("Whether to update the k8s host port flag")), - mcp.WithString("k8s_int_traffic_policy", mcp.Description("The k8s int traffic policy to update the service with")), - mcp.WithString("k8s_load_balancer", mcp.Description("Whether to update the k8s load balancer flag")), - mcp.WithString("k8s_node_port", mcp.Description("Whether to update the k8s node port flag")), - mcp.WithString("local_redirect", mcp.Description("Whether to update the local redirect flag")), - mcp.WithString("protocol", mcp.Description("The protocol to update the service with")), - mcp.WithString("states", mcp.Description("The states to update the service with")), - mcp.WithString("node_name", mcp.Description("The name of the node to update the service on")), - ), telemetry.AdaptToolHandler(telemetry.WithTracing("cilium_update_service", handleUpdateService))) - - s.AddTool(mcp.NewTool("cilium_delete_service", - mcp.WithDescription("Delete a service from the cluster"), - mcp.WithString("service_id", mcp.Description("The ID of the service to delete")), - mcp.WithString("all", mcp.Description("Whether to delete all services (true/false)")), - mcp.WithString("node_name", mcp.Description("The name of the node to delete the service from")), - ), telemetry.AdaptToolHandler(telemetry.WithTracing("cilium_delete_service", handleDeleteService))) - } - - // Debug tools (previously in RegisterCiliumDbgTools) - s.AddTool(mcp.NewTool("cilium_get_endpoint_details", - mcp.WithDescription("List the details of an endpoint in the cluster"), - mcp.WithString("endpoint_id", mcp.Description("The ID of the endpoint to get details for")), - mcp.WithString("labels", mcp.Description("The labels of the endpoint to get details for")), - mcp.WithString("output_format", mcp.Description("The output format of the endpoint details (json, yaml, jsonpath)")), - mcp.WithString("node_name", mcp.Description("The name of the node to get the endpoint details for")), - ), telemetry.AdaptToolHandler(telemetry.WithTracing("cilium_get_endpoint_details", handleGetEndpointDetails))) - - s.AddTool(mcp.NewTool("cilium_get_endpoint_logs", - mcp.WithDescription("Get the logs of an endpoint in the cluster"), - mcp.WithString("endpoint_id", mcp.Description("The ID of the endpoint to get logs for"), mcp.Required()), - mcp.WithString("node_name", mcp.Description("The name of the node to get the endpoint logs for")), - ), telemetry.AdaptToolHandler(telemetry.WithTracing("cilium_get_endpoint_logs", handleGetEndpointLogs))) - - s.AddTool(mcp.NewTool("cilium_get_endpoint_health", - mcp.WithDescription("Get the health of an endpoint in the cluster"), - mcp.WithString("endpoint_id", mcp.Description("The ID of the endpoint to get health for"), mcp.Required()), - mcp.WithString("node_name", mcp.Description("The name of the node to get the endpoint health for")), - ), telemetry.AdaptToolHandler(telemetry.WithTracing("cilium_get_endpoint_health", handleGetEndpointHealth))) - - // Write tools - endpoint management - if !readOnly { - s.AddTool(mcp.NewTool("cilium_manage_endpoint_labels", - mcp.WithDescription("Manage the labels (add or delete) of an endpoint in the cluster"), - mcp.WithString("endpoint_id", mcp.Description("The ID of the endpoint to manage labels for"), mcp.Required()), - mcp.WithString("labels", mcp.Description("Space-separated labels to manage (e.g., 'key1=value1 key2=value2')"), mcp.Required()), - mcp.WithString("action", mcp.Description("The action to perform on the labels (add or delete)"), mcp.Required()), - mcp.WithString("node_name", mcp.Description("The name of the node to manage the endpoint labels on")), - ), telemetry.AdaptToolHandler(telemetry.WithTracing("cilium_manage_endpoint_labels", handleManageEndpointLabels))) - - s.AddTool(mcp.NewTool("cilium_manage_endpoint_config", - mcp.WithDescription("Manage the configuration of an endpoint in the cluster"), - mcp.WithString("endpoint_id", mcp.Description("The ID of the endpoint to manage configuration for"), mcp.Required()), - mcp.WithString("config", mcp.Description("The configuration to manage for the endpoint provided as a space-separated list of key-value pairs (e.g. 'DropNotification=false TraceNotification=false')"), mcp.Required()), - mcp.WithString("node_name", mcp.Description("The name of the node to manage the endpoint configuration on")), - ), telemetry.AdaptToolHandler(telemetry.WithTracing("cilium_manage_endpoint_config", handleManageEndpointConfiguration))) - - s.AddTool(mcp.NewTool("cilium_disconnect_endpoint", - mcp.WithDescription("Disconnect an endpoint from the network"), - mcp.WithString("endpoint_id", mcp.Description("The ID of the endpoint to disconnect"), mcp.Required()), - mcp.WithString("node_name", mcp.Description("The name of the node to disconnect the endpoint from")), - ), telemetry.AdaptToolHandler(telemetry.WithTracing("cilium_disconnect_endpoint", handleDisconnectEndpoint))) - } - - s.AddTool(mcp.NewTool("cilium_list_identities", - mcp.WithDescription("List all identities in the cluster"), - mcp.WithString("node_name", mcp.Description("The name of the node to list the identities for")), - ), telemetry.AdaptToolHandler(telemetry.WithTracing("cilium_list_identities", handleListIdentities))) - - s.AddTool(mcp.NewTool("cilium_get_identity_details", - mcp.WithDescription("Get the details of an identity in the cluster"), - mcp.WithString("identity_id", mcp.Description("The ID of the identity to get details for"), mcp.Required()), - mcp.WithString("node_name", mcp.Description("The name of the node to get the identity details for")), - ), telemetry.AdaptToolHandler(telemetry.WithTracing("cilium_get_identity_details", handleGetIdentityDetails))) - - s.AddTool(mcp.NewTool("cilium_request_debugging_information", - mcp.WithDescription("Request debugging information for the cluster"), - mcp.WithString("node_name", mcp.Description("The name of the node to get the debugging information for")), - ), telemetry.AdaptToolHandler(telemetry.WithTracing("cilium_request_debugging_information", handleRequestDebuggingInformation))) - - s.AddTool(mcp.NewTool("cilium_display_encryption_state", - mcp.WithDescription("Display the encryption state for the cluster"), - mcp.WithString("node_name", mcp.Description("The name of the node to get the encryption state for")), - ), telemetry.AdaptToolHandler(telemetry.WithTracing("cilium_display_encryption_state", handleDisplayEncryptionState))) - - // Write tool - flush_ipsec_state - if !readOnly { - s.AddTool(mcp.NewTool("cilium_flush_ipsec_state", - mcp.WithDescription("Flush the IPsec state for the cluster"), - mcp.WithString("node_name", mcp.Description("The name of the node to flush the IPsec state for")), - ), telemetry.AdaptToolHandler(telemetry.WithTracing("cilium_flush_ipsec_state", handleFlushIPsecState))) - } - - s.AddTool(mcp.NewTool("cilium_list_envoy_config", - mcp.WithDescription("List the Envoy configuration for a resource in the cluster"), - mcp.WithString("resource_name", mcp.Description("The name of the resource to get the Envoy configuration for"), mcp.Required()), - mcp.WithString("node_name", mcp.Description("The name of the node to get the Envoy configuration for")), - ), telemetry.AdaptToolHandler(telemetry.WithTracing("cilium_list_envoy_config", handleListEnvoyConfig))) - - s.AddTool(mcp.NewTool("cilium_fqdn_cache", - mcp.WithDescription("Manage the FQDN cache for the cluster"), - mcp.WithString("command", mcp.Description("The command to perform on the FQDN cache (list, clean, or a specific command)"), mcp.Required()), - mcp.WithString("node_name", mcp.Description("The name of the node to manage the FQDN cache for")), - ), telemetry.AdaptToolHandler(telemetry.WithTracing("cilium_fqdn_cache", handleFQDNCache))) - - s.AddTool(mcp.NewTool("cilium_show_dns_names", - mcp.WithDescription("Show the DNS names for the cluster"), - mcp.WithString("node_name", mcp.Description("The name of the node to get the DNS names for")), - ), telemetry.AdaptToolHandler(telemetry.WithTracing("cilium_show_dns_names", handleShowDNSNames))) - - s.AddTool(mcp.NewTool("cilium_list_ip_addresses", - mcp.WithDescription("List the IP addresses for the cluster"), - mcp.WithString("node_name", mcp.Description("The name of the node to get the IP addresses for")), - ), telemetry.AdaptToolHandler(telemetry.WithTracing("cilium_list_ip_addresses", handleListIPAddresses))) - - s.AddTool(mcp.NewTool("cilium_show_ip_cache_information", - mcp.WithDescription("Show the IP cache information for the cluster"), - mcp.WithString("cidr", mcp.Description("The CIDR of the IP to get cache information for")), - mcp.WithString("labels", mcp.Description("The labels of the IP to get cache information for")), - mcp.WithString("node_name", mcp.Description("The name of the node to get the IP cache information for")), - ), telemetry.AdaptToolHandler(telemetry.WithTracing("cilium_show_ip_cache_information", handleShowIPCacheInformation))) - - // Write tool - delete_key_from_kv_store + mcp.AddTool(s, "cilium", &mcp.Tool{Name: "cilium_update_service", Description: "Update a service in the cluster"}, handleUpdateService) + mcp.AddTool(s, "cilium", &mcp.Tool{Name: "cilium_delete_service", Description: "Delete a service from the cluster"}, handleDeleteService) + } + + mcp.AddTool(s, "cilium", &mcp.Tool{Name: "cilium_get_endpoint_details", Description: "List the details of an endpoint in the cluster"}, handleGetEndpointDetails) + mcp.AddTool(s, "cilium", &mcp.Tool{Name: "cilium_get_endpoint_logs", Description: "Get the logs of an endpoint in the cluster"}, handleGetEndpointLogs) + mcp.AddTool(s, "cilium", &mcp.Tool{Name: "cilium_get_endpoint_health", Description: "Get the health of an endpoint in the cluster"}, handleGetEndpointHealth) + if !readOnly { - s.AddTool(mcp.NewTool("cilium_delete_key_from_kv_store", - mcp.WithDescription("Delete a key from the kvstore for the cluster"), - mcp.WithString("key", mcp.Description("The key to delete from the kvstore"), mcp.Required()), - mcp.WithString("node_name", mcp.Description("The name of the node to delete the key from")), - ), telemetry.AdaptToolHandler(telemetry.WithTracing("cilium_delete_key_from_kv_store", handleDeleteKeyFromKVStore))) + mcp.AddTool(s, "cilium", &mcp.Tool{Name: "cilium_manage_endpoint_labels", Description: "Manage the labels (add or delete) of an endpoint in the cluster"}, handleManageEndpointLabels) + mcp.AddTool(s, "cilium", &mcp.Tool{Name: "cilium_manage_endpoint_config", Description: "Manage the configuration of an endpoint in the cluster"}, handleManageEndpointConfiguration) + mcp.AddTool(s, "cilium", &mcp.Tool{Name: "cilium_disconnect_endpoint", Description: "Disconnect an endpoint from the network"}, handleDisconnectEndpoint) } - s.AddTool(mcp.NewTool("cilium_get_kv_store_key", - mcp.WithDescription("Get a key from the kvstore for the cluster"), - mcp.WithString("key", mcp.Description("The key to get from the kvstore"), mcp.Required()), - mcp.WithString("node_name", mcp.Description("The name of the node to get the key from")), - ), telemetry.AdaptToolHandler(telemetry.WithTracing("cilium_get_kv_store_key", handleGetKVStoreKey))) + mcp.AddTool(s, "cilium", &mcp.Tool{Name: "cilium_list_identities", Description: "List all identities in the cluster"}, handleListIdentities) + mcp.AddTool(s, "cilium", &mcp.Tool{Name: "cilium_get_identity_details", Description: "Get the details of an identity in the cluster"}, handleGetIdentityDetails) + mcp.AddTool(s, "cilium", &mcp.Tool{Name: "cilium_request_debugging_information", Description: "Request debugging information for the cluster"}, handleRequestDebuggingInformation) + mcp.AddTool(s, "cilium", &mcp.Tool{Name: "cilium_display_encryption_state", Description: "Display the encryption state for the cluster"}, handleDisplayEncryptionState) - // Write tool - set_kv_store_key if !readOnly { - s.AddTool(mcp.NewTool("cilium_set_kv_store_key", - mcp.WithDescription("Set a key in the kvstore for the cluster"), - mcp.WithString("key", mcp.Description("The key to set in the kvstore"), mcp.Required()), - mcp.WithString("value", mcp.Description("The value to set in the kvstore"), mcp.Required()), - mcp.WithString("node_name", mcp.Description("The name of the node to set the key in")), - ), telemetry.AdaptToolHandler(telemetry.WithTracing("cilium_set_kv_store_key", handleSetKVStoreKey))) - } - - s.AddTool(mcp.NewTool("cilium_show_load_information", - mcp.WithDescription("Show load information for the cluster"), - mcp.WithString("node_name", mcp.Description("The name of the node to get the load information for")), - ), telemetry.AdaptToolHandler(telemetry.WithTracing("cilium_show_load_information", handleShowLoadInformation))) - - s.AddTool(mcp.NewTool("cilium_list_local_redirect_policies", - mcp.WithDescription("List local redirect policies for the cluster"), - mcp.WithString("node_name", mcp.Description("The name of the node to get the local redirect policies for")), - ), telemetry.AdaptToolHandler(telemetry.WithTracing("cilium_list_local_redirect_policies", handleListLocalRedirectPolicies))) - - s.AddTool(mcp.NewTool("cilium_list_bpf_map_events", - mcp.WithDescription("List BPF map events for the cluster"), - mcp.WithString("map_name", mcp.Description("The name of the BPF map to get events for"), mcp.Required()), - mcp.WithString("node_name", mcp.Description("The name of the node to get the BPF map events for")), - ), telemetry.AdaptToolHandler(telemetry.WithTracing("cilium_list_bpf_map_events", handleListBPFMapEvents))) - - s.AddTool(mcp.NewTool("cilium_get_bpf_map", - mcp.WithDescription("Get BPF map for the cluster"), - mcp.WithString("map_name", mcp.Description("The name of the BPF map to get"), mcp.Required()), - mcp.WithString("node_name", mcp.Description("The name of the node to get the BPF map for")), - ), telemetry.AdaptToolHandler(telemetry.WithTracing("cilium_get_bpf_map", handleGetBPFMap))) - - s.AddTool(mcp.NewTool("cilium_list_bpf_maps", - mcp.WithDescription("List BPF maps for the cluster"), - mcp.WithString("node_name", mcp.Description("The name of the node to get the BPF maps for")), - ), telemetry.AdaptToolHandler(telemetry.WithTracing("cilium_list_bpf_maps", handleListBPFMaps))) - - s.AddTool(mcp.NewTool("cilium_list_metrics", - mcp.WithDescription("List metrics for the cluster"), - mcp.WithString("match_pattern", mcp.Description("The match pattern to filter metrics by")), - mcp.WithString("node_name", mcp.Description("The name of the node to get the metrics for")), - ), telemetry.AdaptToolHandler(telemetry.WithTracing("cilium_list_metrics", handleListMetrics))) - - s.AddTool(mcp.NewTool("cilium_list_cluster_nodes", - mcp.WithDescription("List cluster nodes for the cluster"), - mcp.WithString("node_name", mcp.Description("The name of the node to get the cluster nodes for")), - ), telemetry.AdaptToolHandler(telemetry.WithTracing("cilium_list_cluster_nodes", handleListClusterNodes))) - - s.AddTool(mcp.NewTool("cilium_list_node_ids", - mcp.WithDescription("List node IDs for the cluster"), - mcp.WithString("node_name", mcp.Description("The name of the node to get the node IDs for")), - ), telemetry.AdaptToolHandler(telemetry.WithTracing("cilium_list_node_ids", handleListNodeIds))) - - s.AddTool(mcp.NewTool("cilium_display_policy_node_information", - mcp.WithDescription("Display policy node information for the cluster"), - mcp.WithString("labels", mcp.Description("The labels to get policy node information for")), - mcp.WithString("node_name", mcp.Description("The name of the node to get policy node information for")), - ), telemetry.AdaptToolHandler(telemetry.WithTracing("cilium_display_policy_node_information", handleDisplayPolicyNodeInformation))) - - // Write tool - delete_policy_rules + mcp.AddTool(s, "cilium", &mcp.Tool{Name: "cilium_flush_ipsec_state", Description: "Flush the IPsec state for the cluster"}, handleFlushIPsecState) + } + + mcp.AddTool(s, "cilium", &mcp.Tool{Name: "cilium_list_envoy_config", Description: "List the Envoy configuration for a resource in the cluster"}, handleListEnvoyConfig) + mcp.AddTool(s, "cilium", &mcp.Tool{Name: "cilium_fqdn_cache", Description: "Manage the FQDN cache for the cluster"}, handleFQDNCache) + mcp.AddTool(s, "cilium", &mcp.Tool{Name: "cilium_show_dns_names", Description: "Show the DNS names for the cluster"}, handleShowDNSNames) + mcp.AddTool(s, "cilium", &mcp.Tool{Name: "cilium_list_ip_addresses", Description: "List the IP addresses for the cluster"}, handleListIPAddresses) + mcp.AddTool(s, "cilium", &mcp.Tool{Name: "cilium_show_ip_cache_information", Description: "Show the IP cache information for the cluster"}, handleShowIPCacheInformation) + if !readOnly { - s.AddTool(mcp.NewTool("cilium_delete_policy_rules", - mcp.WithDescription("Delete policy rules for the cluster"), - mcp.WithString("labels", mcp.Description("The labels to delete policy rules for")), - mcp.WithString("all", mcp.Description("Whether to delete all policy rules")), - mcp.WithString("node_name", mcp.Description("The name of the node to delete policy rules for")), - ), telemetry.AdaptToolHandler(telemetry.WithTracing("cilium_delete_policy_rules", handleDeletePolicyRules))) + mcp.AddTool(s, "cilium", &mcp.Tool{Name: "cilium_delete_key_from_kv_store", Description: "Delete a key from the kvstore for the cluster"}, handleDeleteKeyFromKVStore) } - s.AddTool(mcp.NewTool("cilium_display_selectors", - mcp.WithDescription("Display selectors for the cluster"), - mcp.WithString("node_name", mcp.Description("The name of the node to get selectors for")), - ), telemetry.AdaptToolHandler(telemetry.WithTracing("cilium_display_selectors", handleDisplaySelectors))) + mcp.AddTool(s, "cilium", &mcp.Tool{Name: "cilium_get_kv_store_key", Description: "Get a key from the kvstore for the cluster"}, handleGetKVStoreKey) + + if !readOnly { + mcp.AddTool(s, "cilium", &mcp.Tool{Name: "cilium_set_kv_store_key", Description: "Set a key in the kvstore for the cluster"}, handleSetKVStoreKey) + } - s.AddTool(mcp.NewTool("cilium_list_xdp_cidr_filters", - mcp.WithDescription("List XDP CIDR filters for the cluster"), - mcp.WithString("node_name", mcp.Description("The name of the node to get the XDP CIDR filters for")), - ), telemetry.AdaptToolHandler(telemetry.WithTracing("cilium_list_xdp_cidr_filters", handleListXDPCIDRFilters))) + mcp.AddTool(s, "cilium", &mcp.Tool{Name: "cilium_show_load_information", Description: "Show load information for the cluster"}, handleShowLoadInformation) + mcp.AddTool(s, "cilium", &mcp.Tool{Name: "cilium_list_local_redirect_policies", Description: "List local redirect policies for the cluster"}, handleListLocalRedirectPolicies) + mcp.AddTool(s, "cilium", &mcp.Tool{Name: "cilium_list_bpf_map_events", Description: "List BPF map events for the cluster"}, handleListBPFMapEvents) + mcp.AddTool(s, "cilium", &mcp.Tool{Name: "cilium_get_bpf_map", Description: "Get BPF map for the cluster"}, handleGetBPFMap) + mcp.AddTool(s, "cilium", &mcp.Tool{Name: "cilium_list_bpf_maps", Description: "List BPF maps for the cluster"}, handleListBPFMaps) + mcp.AddTool(s, "cilium", &mcp.Tool{Name: "cilium_list_metrics", Description: "List metrics for the cluster"}, handleListMetrics) + mcp.AddTool(s, "cilium", &mcp.Tool{Name: "cilium_list_cluster_nodes", Description: "List cluster nodes for the cluster"}, handleListClusterNodes) + mcp.AddTool(s, "cilium", &mcp.Tool{Name: "cilium_list_node_ids", Description: "List node IDs for the cluster"}, handleListNodeIds) + mcp.AddTool(s, "cilium", &mcp.Tool{Name: "cilium_display_policy_node_information", Description: "Display policy node information for the cluster"}, handleDisplayPolicyNodeInformation) - // Write tools - XDP CIDR filters if !readOnly { - s.AddTool(mcp.NewTool("cilium_update_xdp_cidr_filters", - mcp.WithDescription("Update XDP CIDR filters for the cluster"), - mcp.WithString("cidr_prefixes", mcp.Description("The CIDR prefixes to update the XDP filters for"), mcp.Required()), - mcp.WithString("revision", mcp.Description("The revision of the XDP filters to update")), - mcp.WithString("node_name", mcp.Description("The name of the node to update the XDP filters for")), - ), telemetry.AdaptToolHandler(telemetry.WithTracing("cilium_update_xdp_cidr_filters", handleUpdateXDPCIDRFilters))) - - s.AddTool(mcp.NewTool("cilium_delete_xdp_cidr_filters", - mcp.WithDescription("Delete XDP CIDR filters for the cluster"), - mcp.WithString("cidr_prefixes", mcp.Description("The CIDR prefixes to delete the XDP filters for"), mcp.Required()), - mcp.WithString("revision", mcp.Description("The revision of the XDP filters to delete")), - mcp.WithString("node_name", mcp.Description("The name of the node to delete the XDP filters for")), - ), telemetry.AdaptToolHandler(telemetry.WithTracing("cilium_delete_xdp_cidr_filters", handleDeleteXDPCIDRFilters))) - } - - s.AddTool(mcp.NewTool("cilium_validate_cilium_network_policies", - mcp.WithDescription("Validate Cilium network policies for the cluster"), - mcp.WithString("enable_k8s", mcp.Description("Whether to enable k8s API discovery")), - mcp.WithString("enable_k8s_api_discovery", mcp.Description("Whether to enable k8s API discovery")), - mcp.WithString("node_name", mcp.Description("The name of the node to validate the Cilium network policies for")), - ), telemetry.AdaptToolHandler(telemetry.WithTracing("cilium_validate_cilium_network_policies", handleValidateCiliumNetworkPolicies))) - - s.AddTool(mcp.NewTool("cilium_list_pcap_recorders", - mcp.WithDescription("List PCAP recorders for the cluster"), - mcp.WithString("node_name", mcp.Description("The name of the node to get the PCAP recorders for")), - ), telemetry.AdaptToolHandler(telemetry.WithTracing("cilium_list_pcap_recorders", handleListPCAPRecorders))) - - s.AddTool(mcp.NewTool("cilium_get_pcap_recorder", - mcp.WithDescription("Get a PCAP recorder for the cluster"), - mcp.WithString("recorder_id", mcp.Description("The ID of the PCAP recorder to get"), mcp.Required()), - mcp.WithString("node_name", mcp.Description("The name of the node to get the PCAP recorder for")), - ), telemetry.AdaptToolHandler(telemetry.WithTracing("cilium_get_pcap_recorder", handleGetPCAPRecorder))) - - // Write tools - PCAP recorder management + mcp.AddTool(s, "cilium", &mcp.Tool{Name: "cilium_delete_policy_rules", Description: "Delete policy rules for the cluster"}, handleDeletePolicyRules) + } + + mcp.AddTool(s, "cilium", &mcp.Tool{Name: "cilium_display_selectors", Description: "Display selectors for the cluster"}, handleDisplaySelectors) + mcp.AddTool(s, "cilium", &mcp.Tool{Name: "cilium_list_xdp_cidr_filters", Description: "List XDP CIDR filters for the cluster"}, handleListXDPCIDRFilters) + if !readOnly { - s.AddTool(mcp.NewTool("cilium_delete_pcap_recorder", - mcp.WithDescription("Delete a PCAP recorder for the cluster"), - mcp.WithString("recorder_id", mcp.Description("The ID of the PCAP recorder to delete"), mcp.Required()), - mcp.WithString("node_name", mcp.Description("The name of the node to delete the PCAP recorder from")), - ), telemetry.AdaptToolHandler(telemetry.WithTracing("cilium_delete_pcap_recorder", handleDeletePCAPRecorder))) + mcp.AddTool(s, "cilium", &mcp.Tool{Name: "cilium_update_xdp_cidr_filters", Description: "Update XDP CIDR filters for the cluster"}, handleUpdateXDPCIDRFilters) + mcp.AddTool(s, "cilium", &mcp.Tool{Name: "cilium_delete_xdp_cidr_filters", Description: "Delete XDP CIDR filters for the cluster"}, handleDeleteXDPCIDRFilters) + } - s.AddTool(mcp.NewTool("cilium_update_pcap_recorder", - mcp.WithDescription("Update a PCAP recorder for the cluster"), - mcp.WithString("recorder_id", mcp.Description("The ID of the PCAP recorder to update"), mcp.Required()), - mcp.WithString("filters", mcp.Description("The filters to update the PCAP recorder with"), mcp.Required()), - mcp.WithString("caplen", mcp.Description("The caplen to update the PCAP recorder with")), - mcp.WithString("id", mcp.Description("The id to update the PCAP recorder with")), - mcp.WithString("node_name", mcp.Description("The name of the node to update the PCAP recorder on")), - ), telemetry.AdaptToolHandler(telemetry.WithTracing("cilium_update_pcap_recorder", handleUpdatePCAPRecorder))) + mcp.AddTool(s, "cilium", &mcp.Tool{Name: "cilium_validate_cilium_network_policies", Description: "Validate Cilium network policies for the cluster"}, handleValidateCiliumNetworkPolicies) + mcp.AddTool(s, "cilium", &mcp.Tool{Name: "cilium_list_pcap_recorders", Description: "List PCAP recorders for the cluster"}, handleListPCAPRecorders) + mcp.AddTool(s, "cilium", &mcp.Tool{Name: "cilium_get_pcap_recorder", Description: "Get a PCAP recorder for the cluster"}, handleGetPCAPRecorder) + + if !readOnly { + mcp.AddTool(s, "cilium", &mcp.Tool{Name: "cilium_delete_pcap_recorder", Description: "Delete a PCAP recorder for the cluster"}, handleDeletePCAPRecorder) + mcp.AddTool(s, "cilium", &mcp.Tool{Name: "cilium_update_pcap_recorder", Description: "Update a PCAP recorder for the cluster"}, handleUpdatePCAPRecorder) } } @@ -629,11 +532,14 @@ func runCiliumDbgCommandWithContext(ctx context.Context, command, nodeName strin Execute(ctx) } -func handleGetEndpointDetails(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - endpointID := mcp.ParseString(request, "endpoint_id", "") - labels := mcp.ParseString(request, "labels", "") - outputFormat := mcp.ParseString(request, "output_format", "json") - nodeName := mcp.ParseString(request, "node_name", "") +func handleGetEndpointDetails(ctx context.Context, request *mcp.CallToolRequest, in getEndpointDetailsInput) (*mcp.CallToolResult, any, error) { + if in.OutputFormat == "" { + in.OutputFormat = "json" + } + endpointID := in.EndpointID + labels := in.Labels + outputFormat := in.OutputFormat + nodeName := in.NodeName var cmd string if labels != "" { @@ -641,144 +547,147 @@ func handleGetEndpointDetails(ctx context.Context, request mcp.CallToolRequest) } else if endpointID != "" { cmd = fmt.Sprintf("endpoint get %s -o %s", endpointID, outputFormat) } else { - return mcp.NewToolResultError("either endpoint_id or labels must be provided"), nil + return mcp.NewToolResultError("either endpoint_id or labels must be provided"), nil, nil } output, err := runCiliumDbgCommand(ctx, cmd, nodeName) if err != nil { - return mcp.NewToolResultError(fmt.Sprintf("Failed to get endpoint details: %v", err)), nil + return mcp.NewToolResultError(fmt.Sprintf("Failed to get endpoint details: %v", err)), nil, nil } - return mcp.NewToolResultText(output), nil + return mcp.NewToolResultText(output), nil, nil } -func handleGetEndpointLogs(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - endpointID := mcp.ParseString(request, "endpoint_id", "") - nodeName := mcp.ParseString(request, "node_name", "") +func handleGetEndpointLogs(ctx context.Context, request *mcp.CallToolRequest, in getEndpointLogsInput) (*mcp.CallToolResult, any, error) { + endpointID := in.EndpointID + nodeName := in.NodeName if endpointID == "" { - return mcp.NewToolResultError("endpoint_id parameter is required"), nil + return mcp.NewToolResultError("endpoint_id parameter is required"), nil, nil } cmd := fmt.Sprintf("endpoint logs %s", endpointID) output, err := runCiliumDbgCommand(ctx, cmd, nodeName) if err != nil { - return mcp.NewToolResultError(fmt.Sprintf("Failed to get endpoint logs: %v", err)), nil + return mcp.NewToolResultError(fmt.Sprintf("Failed to get endpoint logs: %v", err)), nil, nil } - return mcp.NewToolResultText(output), nil + return mcp.NewToolResultText(output), nil, nil } -func handleGetEndpointHealth(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - endpointID := mcp.ParseString(request, "endpoint_id", "") - nodeName := mcp.ParseString(request, "node_name", "") +func handleGetEndpointHealth(ctx context.Context, request *mcp.CallToolRequest, in getEndpointHealthInput) (*mcp.CallToolResult, any, error) { + endpointID := in.EndpointID + nodeName := in.NodeName if endpointID == "" { - return mcp.NewToolResultError("endpoint_id parameter is required"), nil + return mcp.NewToolResultError("endpoint_id parameter is required"), nil, nil } cmd := fmt.Sprintf("endpoint health %s", endpointID) output, err := runCiliumDbgCommand(ctx, cmd, nodeName) if err != nil { - return mcp.NewToolResultError(fmt.Sprintf("Failed to get endpoint health: %v", err)), nil + return mcp.NewToolResultError(fmt.Sprintf("Failed to get endpoint health: %v", err)), nil, nil } - return mcp.NewToolResultText(output), nil + return mcp.NewToolResultText(output), nil, nil } -func handleManageEndpointLabels(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - endpointID := mcp.ParseString(request, "endpoint_id", "") - labels := mcp.ParseString(request, "labels", "") - action := mcp.ParseString(request, "action", "add") // Default to add - nodeName := mcp.ParseString(request, "node_name", "") +func handleManageEndpointLabels(ctx context.Context, request *mcp.CallToolRequest, in manageEndpointLabelsInput) (*mcp.CallToolResult, any, error) { + if in.Action == "" { + in.Action = "add" + } + endpointID := in.EndpointID + labels := in.Labels + action := in.Action + nodeName := in.NodeName if endpointID == "" || labels == "" { - return mcp.NewToolResultError("endpoint_id and labels parameters are required"), nil + return mcp.NewToolResultError("endpoint_id and labels parameters are required"), nil, nil } cmd := fmt.Sprintf("endpoint labels %s --%s %s", endpointID, action, labels) output, err := runCiliumDbgCommand(ctx, cmd, nodeName) if err != nil { - return mcp.NewToolResultError(fmt.Sprintf("Failed to manage endpoint labels: %v", err)), nil + return mcp.NewToolResultError(fmt.Sprintf("Failed to manage endpoint labels: %v", err)), nil, nil } - return mcp.NewToolResultText(output), nil + return mcp.NewToolResultText(output), nil, nil } -func handleManageEndpointConfiguration(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - endpointID := mcp.ParseString(request, "endpoint_id", "") - config := mcp.ParseString(request, "config", "") - nodeName := mcp.ParseString(request, "node_name", "") +func handleManageEndpointConfiguration(ctx context.Context, request *mcp.CallToolRequest, in manageEndpointConfigurationInput) (*mcp.CallToolResult, any, error) { + endpointID := in.EndpointID + config := in.Config + nodeName := in.NodeName if endpointID == "" { - return mcp.NewToolResultError("endpoint_id parameter is required"), nil + return mcp.NewToolResultError("endpoint_id parameter is required"), nil, nil } if config == "" { - return mcp.NewToolResultError("config parameter is required"), nil + return mcp.NewToolResultError("config parameter is required"), nil, nil } command := fmt.Sprintf("endpoint config %s %s", endpointID, config) output, err := runCiliumDbgCommand(ctx, command, nodeName) if err != nil { - return mcp.NewToolResultError("Error managing endpoint configuration: " + err.Error()), nil + return mcp.NewToolResultError("Error managing endpoint configuration: " + err.Error()), nil, nil } - return mcp.NewToolResultText(output), nil + return mcp.NewToolResultText(output), nil, nil } -func handleDisconnectEndpoint(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - endpointID := mcp.ParseString(request, "endpoint_id", "") - nodeName := mcp.ParseString(request, "node_name", "") +func handleDisconnectEndpoint(ctx context.Context, request *mcp.CallToolRequest, in disconnectEndpointInput) (*mcp.CallToolResult, any, error) { + endpointID := in.EndpointID + nodeName := in.NodeName if endpointID == "" { - return mcp.NewToolResultError("endpoint_id parameter is required"), nil + return mcp.NewToolResultError("endpoint_id parameter is required"), nil, nil } cmd := fmt.Sprintf("endpoint disconnect %s", endpointID) output, err := runCiliumDbgCommand(ctx, cmd, nodeName) if err != nil { - return mcp.NewToolResultError(fmt.Sprintf("Failed to disconnect endpoint: %v", err)), nil + return mcp.NewToolResultError(fmt.Sprintf("Failed to disconnect endpoint: %v", err)), nil, nil } - return mcp.NewToolResultText(output), nil + return mcp.NewToolResultText(output), nil, nil } -func handleGetEndpointsList(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - nodeName := mcp.ParseString(request, "node_name", "") +func handleGetEndpointsList(ctx context.Context, request *mcp.CallToolRequest, in nodeNameInput) (*mcp.CallToolResult, any, error) { + nodeName := in.NodeName output, err := runCiliumDbgCommand(ctx, "endpoint list", nodeName) if err != nil { - return mcp.NewToolResultError(fmt.Sprintf("Failed to get endpoints list: %v", err)), nil + return mcp.NewToolResultError(fmt.Sprintf("Failed to get endpoints list: %v", err)), nil, nil } - return mcp.NewToolResultText(output), nil + return mcp.NewToolResultText(output), nil, nil } -func handleListIdentities(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - nodeName := mcp.ParseString(request, "node_name", "") +func handleListIdentities(ctx context.Context, request *mcp.CallToolRequest, in nodeNameInput) (*mcp.CallToolResult, any, error) { + nodeName := in.NodeName output, err := runCiliumDbgCommand(ctx, "identity list", nodeName) if err != nil { - return mcp.NewToolResultError(fmt.Sprintf("Failed to list identities: %v", err)), nil + return mcp.NewToolResultError(fmt.Sprintf("Failed to list identities: %v", err)), nil, nil } - return mcp.NewToolResultText(output), nil + return mcp.NewToolResultText(output), nil, nil } -func handleGetIdentityDetails(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - identityID := mcp.ParseString(request, "identity_id", "") - nodeName := mcp.ParseString(request, "node_name", "") +func handleGetIdentityDetails(ctx context.Context, request *mcp.CallToolRequest, in getIdentityDetailsInput) (*mcp.CallToolResult, any, error) { + identityID := in.IdentityID + nodeName := in.NodeName if identityID == "" { - return mcp.NewToolResultError("identity_id parameter is required"), nil + return mcp.NewToolResultError("identity_id parameter is required"), nil, nil } cmd := fmt.Sprintf("identity get %s", identityID) output, err := runCiliumDbgCommand(ctx, cmd, nodeName) if err != nil { - return mcp.NewToolResultError(fmt.Sprintf("Failed to get identity details: %v", err)), nil + return mcp.NewToolResultError(fmt.Sprintf("Failed to get identity details: %v", err)), nil, nil } - return mcp.NewToolResultText(output), nil + return mcp.NewToolResultText(output), nil, nil } -func handleShowConfigurationOptions(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - listAll := mcp.ParseString(request, "list_all", "") == "true" - listReadOnly := mcp.ParseString(request, "list_read_only", "") == "true" - listOptions := mcp.ParseString(request, "list_options", "") == "true" - nodeName := mcp.ParseString(request, "node_name", "") +func handleShowConfigurationOptions(ctx context.Context, request *mcp.CallToolRequest, in showConfigurationOptionsInput) (*mcp.CallToolResult, any, error) { + listAll := in.ListAll + listReadOnly := in.ListReadOnly + listOptions := in.ListOptions + nodeName := in.NodeName var cmd string if listAll { @@ -793,18 +702,21 @@ func handleShowConfigurationOptions(ctx context.Context, request mcp.CallToolReq output, err := runCiliumDbgCommand(ctx, cmd, nodeName) if err != nil { - return mcp.NewToolResultError(fmt.Sprintf("Failed to show configuration options: %v", err)), nil + return mcp.NewToolResultError(fmt.Sprintf("Failed to show configuration options: %v", err)), nil, nil } - return mcp.NewToolResultText(output), nil + return mcp.NewToolResultText(output), nil, nil } -func handleToggleConfigurationOption(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - option := mcp.ParseString(request, "option", "") - value := mcp.ParseString(request, "value", "true") == "true" - nodeName := mcp.ParseString(request, "node_name", "") +func handleToggleConfigurationOption(ctx context.Context, request *mcp.CallToolRequest, in toggleConfigurationOptionInput) (*mcp.CallToolResult, any, error) { + option := in.Option + value := true + if in.Value != nil { + value = *in.Value + } + nodeName := in.NodeName if option == "" { - return mcp.NewToolResultError("option parameter is required"), nil + return mcp.NewToolResultError("option parameter is required"), nil, nil } valueStr := "enable" @@ -815,60 +727,63 @@ func handleToggleConfigurationOption(ctx context.Context, request mcp.CallToolRe cmd := fmt.Sprintf("config %s=%s", option, valueStr) output, err := runCiliumDbgCommand(ctx, cmd, nodeName) if err != nil { - return mcp.NewToolResultError(fmt.Sprintf("Failed to toggle configuration option: %v", err)), nil + return mcp.NewToolResultError(fmt.Sprintf("Failed to toggle configuration option: %v", err)), nil, nil } - return mcp.NewToolResultText(output), nil + return mcp.NewToolResultText(output), nil, nil } -func handleRequestDebuggingInformation(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - nodeName := mcp.ParseString(request, "node_name", "") +func handleRequestDebuggingInformation(ctx context.Context, request *mcp.CallToolRequest, in nodeNameInput) (*mcp.CallToolResult, any, error) { + nodeName := in.NodeName output, err := runCiliumDbgCommand(ctx, "debuginfo", nodeName) if err != nil { - return mcp.NewToolResultError(fmt.Sprintf("Failed to request debugging information: %v", err)), nil + return mcp.NewToolResultError(fmt.Sprintf("Failed to request debugging information: %v", err)), nil, nil } - return mcp.NewToolResultText(output), nil + return mcp.NewToolResultText(output), nil, nil } -func handleDisplayEncryptionState(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - nodeName := mcp.ParseString(request, "node_name", "") +func handleDisplayEncryptionState(ctx context.Context, request *mcp.CallToolRequest, in nodeNameInput) (*mcp.CallToolResult, any, error) { + nodeName := in.NodeName output, err := runCiliumDbgCommand(ctx, "encrypt status", nodeName) if err != nil { - return mcp.NewToolResultError(fmt.Sprintf("Failed to display encryption state: %v", err)), nil + return mcp.NewToolResultError(fmt.Sprintf("Failed to display encryption state: %v", err)), nil, nil } - return mcp.NewToolResultText(output), nil + return mcp.NewToolResultText(output), nil, nil } -func handleFlushIPsecState(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - nodeName := mcp.ParseString(request, "node_name", "") +func handleFlushIPsecState(ctx context.Context, request *mcp.CallToolRequest, in nodeNameInput) (*mcp.CallToolResult, any, error) { + nodeName := in.NodeName output, err := runCiliumDbgCommand(ctx, "encrypt flush -f", nodeName) if err != nil { - return mcp.NewToolResultError(fmt.Sprintf("Failed to flush IPsec state: %v", err)), nil + return mcp.NewToolResultError(fmt.Sprintf("Failed to flush IPsec state: %v", err)), nil, nil } - return mcp.NewToolResultText(output), nil + return mcp.NewToolResultText(output), nil, nil } -func handleListEnvoyConfig(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - resourceName := mcp.ParseString(request, "resource_name", "") - nodeName := mcp.ParseString(request, "node_name", "") +func handleListEnvoyConfig(ctx context.Context, request *mcp.CallToolRequest, in listEnvoyConfigInput) (*mcp.CallToolResult, any, error) { + resourceName := in.ResourceName + nodeName := in.NodeName if resourceName == "" { - return mcp.NewToolResultError("resource_name parameter is required"), nil + return mcp.NewToolResultError("resource_name parameter is required"), nil, nil } cmd := fmt.Sprintf("envoy admin %s", resourceName) output, err := runCiliumDbgCommand(ctx, cmd, nodeName) if err != nil { - return mcp.NewToolResultError(fmt.Sprintf("Failed to list Envoy config: %v", err)), nil + return mcp.NewToolResultError(fmt.Sprintf("Failed to list Envoy config: %v", err)), nil, nil } - return mcp.NewToolResultText(output), nil + return mcp.NewToolResultText(output), nil, nil } -func handleFQDNCache(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - command := mcp.ParseString(request, "command", "list") - nodeName := mcp.ParseString(request, "node_name", "") +func handleFQDNCache(ctx context.Context, request *mcp.CallToolRequest, in fqdnCacheInput) (*mcp.CallToolResult, any, error) { + if in.Command == "" { + in.Command = "list" + } + command := in.Command + nodeName := in.NodeName var cmd string if command == "clean" { @@ -879,35 +794,35 @@ func handleFQDNCache(ctx context.Context, request mcp.CallToolRequest) (*mcp.Cal output, err := runCiliumDbgCommand(ctx, cmd, nodeName) if err != nil { - return mcp.NewToolResultError(fmt.Sprintf("Failed to manage FQDN cache: %v", err)), nil + return mcp.NewToolResultError(fmt.Sprintf("Failed to manage FQDN cache: %v", err)), nil, nil } - return mcp.NewToolResultText(output), nil + return mcp.NewToolResultText(output), nil, nil } -func handleShowDNSNames(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - nodeName := mcp.ParseString(request, "node_name", "") +func handleShowDNSNames(ctx context.Context, request *mcp.CallToolRequest, in nodeNameInput) (*mcp.CallToolResult, any, error) { + nodeName := in.NodeName output, err := runCiliumDbgCommand(ctx, "fqdn names", nodeName) if err != nil { - return mcp.NewToolResultError(fmt.Sprintf("Failed to show DNS names: %v", err)), nil + return mcp.NewToolResultError(fmt.Sprintf("Failed to show DNS names: %v", err)), nil, nil } - return mcp.NewToolResultText(output), nil + return mcp.NewToolResultText(output), nil, nil } -func handleListIPAddresses(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - nodeName := mcp.ParseString(request, "node_name", "") +func handleListIPAddresses(ctx context.Context, request *mcp.CallToolRequest, in nodeNameInput) (*mcp.CallToolResult, any, error) { + nodeName := in.NodeName output, err := runCiliumDbgCommand(ctx, "ip list", nodeName) if err != nil { - return mcp.NewToolResultError(fmt.Sprintf("Failed to list IP addresses: %v", err)), nil + return mcp.NewToolResultError(fmt.Sprintf("Failed to list IP addresses: %v", err)), nil, nil } - return mcp.NewToolResultText(output), nil + return mcp.NewToolResultText(output), nil, nil } -func handleShowIPCacheInformation(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - cidr := mcp.ParseString(request, "cidr", "") - labels := mcp.ParseString(request, "labels", "") - nodeName := mcp.ParseString(request, "node_name", "") +func handleShowIPCacheInformation(ctx context.Context, request *mcp.CallToolRequest, in showIPCacheInformationInput) (*mcp.CallToolResult, any, error) { + cidr := in.CIDR + labels := in.Labels + nodeName := in.NodeName var cmd string if labels != "" { @@ -915,130 +830,130 @@ func handleShowIPCacheInformation(ctx context.Context, request mcp.CallToolReque } else if cidr != "" { cmd = fmt.Sprintf("ip get %s", cidr) } else { - return mcp.NewToolResultError("either cidr or labels must be provided"), nil + return mcp.NewToolResultError("either cidr or labels must be provided"), nil, nil } output, err := runCiliumDbgCommand(ctx, cmd, nodeName) if err != nil { - return mcp.NewToolResultError(fmt.Sprintf("Failed to show IP cache information: %v", err)), nil + return mcp.NewToolResultError(fmt.Sprintf("Failed to show IP cache information: %v", err)), nil, nil } - return mcp.NewToolResultText(output), nil + return mcp.NewToolResultText(output), nil, nil } -func handleDeleteKeyFromKVStore(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - key := mcp.ParseString(request, "key", "") - nodeName := mcp.ParseString(request, "node_name", "") +func handleDeleteKeyFromKVStore(ctx context.Context, request *mcp.CallToolRequest, in kvStoreKeyInput) (*mcp.CallToolResult, any, error) { + key := in.Key + nodeName := in.NodeName if key == "" { - return mcp.NewToolResultError("key parameter is required"), nil + return mcp.NewToolResultError("key parameter is required"), nil, nil } cmd := fmt.Sprintf("kvstore delete %s", key) output, err := runCiliumDbgCommand(ctx, cmd, nodeName) if err != nil { - return mcp.NewToolResultError(fmt.Sprintf("Failed to delete key from kvstore: %v", err)), nil + return mcp.NewToolResultError(fmt.Sprintf("Failed to delete key from kvstore: %v", err)), nil, nil } - return mcp.NewToolResultText(output), nil + return mcp.NewToolResultText(output), nil, nil } -func handleGetKVStoreKey(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - key := mcp.ParseString(request, "key", "") - nodeName := mcp.ParseString(request, "node_name", "") +func handleGetKVStoreKey(ctx context.Context, request *mcp.CallToolRequest, in kvStoreKeyInput) (*mcp.CallToolResult, any, error) { + key := in.Key + nodeName := in.NodeName if key == "" { - return mcp.NewToolResultError("key parameter is required"), nil + return mcp.NewToolResultError("key parameter is required"), nil, nil } cmd := fmt.Sprintf("kvstore get %s", key) output, err := runCiliumDbgCommand(ctx, cmd, nodeName) if err != nil { - return mcp.NewToolResultError(fmt.Sprintf("Failed to get key from kvstore: %v", err)), nil + return mcp.NewToolResultError(fmt.Sprintf("Failed to get key from kvstore: %v", err)), nil, nil } - return mcp.NewToolResultText(output), nil + return mcp.NewToolResultText(output), nil, nil } -func handleSetKVStoreKey(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - key := mcp.ParseString(request, "key", "") - value := mcp.ParseString(request, "value", "") - nodeName := mcp.ParseString(request, "node_name", "") +func handleSetKVStoreKey(ctx context.Context, request *mcp.CallToolRequest, in setKVStoreKeyInput) (*mcp.CallToolResult, any, error) { + key := in.Key + value := in.Value + nodeName := in.NodeName if key == "" || value == "" { - return mcp.NewToolResultError("key and value parameters are required"), nil + return mcp.NewToolResultError("key and value parameters are required"), nil, nil } cmd := fmt.Sprintf("kvstore set %s=%s", key, value) output, err := runCiliumDbgCommand(ctx, cmd, nodeName) if err != nil { - return mcp.NewToolResultError(fmt.Sprintf("Failed to set key in kvstore: %v", err)), nil + return mcp.NewToolResultError(fmt.Sprintf("Failed to set key in kvstore: %v", err)), nil, nil } - return mcp.NewToolResultText(output), nil + return mcp.NewToolResultText(output), nil, nil } -func handleShowLoadInformation(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - nodeName := mcp.ParseString(request, "node_name", "") +func handleShowLoadInformation(ctx context.Context, request *mcp.CallToolRequest, in nodeNameInput) (*mcp.CallToolResult, any, error) { + nodeName := in.NodeName output, err := runCiliumDbgCommand(ctx, "loadinfo", nodeName) if err != nil { - return mcp.NewToolResultError(fmt.Sprintf("Failed to show load information: %v", err)), nil + return mcp.NewToolResultError(fmt.Sprintf("Failed to show load information: %v", err)), nil, nil } - return mcp.NewToolResultText(output), nil + return mcp.NewToolResultText(output), nil, nil } -func handleListLocalRedirectPolicies(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - nodeName := mcp.ParseString(request, "node_name", "") +func handleListLocalRedirectPolicies(ctx context.Context, request *mcp.CallToolRequest, in nodeNameInput) (*mcp.CallToolResult, any, error) { + nodeName := in.NodeName output, err := runCiliumDbgCommand(ctx, "lrp list", nodeName) if err != nil { - return mcp.NewToolResultError(fmt.Sprintf("Failed to list local redirect policies: %v", err)), nil + return mcp.NewToolResultError(fmt.Sprintf("Failed to list local redirect policies: %v", err)), nil, nil } - return mcp.NewToolResultText(output), nil + return mcp.NewToolResultText(output), nil, nil } -func handleListBPFMapEvents(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - mapName := mcp.ParseString(request, "map_name", "") - nodeName := mcp.ParseString(request, "node_name", "") +func handleListBPFMapEvents(ctx context.Context, request *mcp.CallToolRequest, in bpfMapInput) (*mcp.CallToolResult, any, error) { + mapName := in.MapName + nodeName := in.NodeName if mapName == "" { - return mcp.NewToolResultError("map_name parameter is required"), nil + return mcp.NewToolResultError("map_name parameter is required"), nil, nil } cmd := fmt.Sprintf("map events %s", mapName) output, err := runCiliumDbgCommand(ctx, cmd, nodeName) if err != nil { - return mcp.NewToolResultError(fmt.Sprintf("Failed to list BPF map events: %v", err)), nil + return mcp.NewToolResultError(fmt.Sprintf("Failed to list BPF map events: %v", err)), nil, nil } - return mcp.NewToolResultText(output), nil + return mcp.NewToolResultText(output), nil, nil } -func handleGetBPFMap(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - mapName := mcp.ParseString(request, "map_name", "") - nodeName := mcp.ParseString(request, "node_name", "") +func handleGetBPFMap(ctx context.Context, request *mcp.CallToolRequest, in bpfMapInput) (*mcp.CallToolResult, any, error) { + mapName := in.MapName + nodeName := in.NodeName if mapName == "" { - return mcp.NewToolResultError("map_name parameter is required"), nil + return mcp.NewToolResultError("map_name parameter is required"), nil, nil } cmd := fmt.Sprintf("map get %s", mapName) output, err := runCiliumDbgCommand(ctx, cmd, nodeName) if err != nil { - return mcp.NewToolResultError(fmt.Sprintf("Failed to get BPF map: %v", err)), nil + return mcp.NewToolResultError(fmt.Sprintf("Failed to get BPF map: %v", err)), nil, nil } - return mcp.NewToolResultText(output), nil + return mcp.NewToolResultText(output), nil, nil } -func handleListBPFMaps(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - nodeName := mcp.ParseString(request, "node_name", "") +func handleListBPFMaps(ctx context.Context, request *mcp.CallToolRequest, in nodeNameInput) (*mcp.CallToolResult, any, error) { + nodeName := in.NodeName output, err := runCiliumDbgCommand(ctx, "map list", nodeName) if err != nil { - return mcp.NewToolResultError(fmt.Sprintf("Failed to list BPF maps: %v", err)), nil + return mcp.NewToolResultError(fmt.Sprintf("Failed to list BPF maps: %v", err)), nil, nil } - return mcp.NewToolResultText(output), nil + return mcp.NewToolResultText(output), nil, nil } -func handleListMetrics(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - matchPattern := mcp.ParseString(request, "match_pattern", "") - nodeName := mcp.ParseString(request, "node_name", "") +func handleListMetrics(ctx context.Context, request *mcp.CallToolRequest, in listMetricsInput) (*mcp.CallToolResult, any, error) { + matchPattern := in.MatchPattern + nodeName := in.NodeName var cmd string if matchPattern != "" { @@ -1049,34 +964,34 @@ func handleListMetrics(ctx context.Context, request mcp.CallToolRequest) (*mcp.C output, err := runCiliumDbgCommand(ctx, cmd, nodeName) if err != nil { - return mcp.NewToolResultError(fmt.Sprintf("Failed to list metrics: %v", err)), nil + return mcp.NewToolResultError(fmt.Sprintf("Failed to list metrics: %v", err)), nil, nil } - return mcp.NewToolResultText(output), nil + return mcp.NewToolResultText(output), nil, nil } -func handleListClusterNodes(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - nodeName := mcp.ParseString(request, "node_name", "") +func handleListClusterNodes(ctx context.Context, request *mcp.CallToolRequest, in nodeNameInput) (*mcp.CallToolResult, any, error) { + nodeName := in.NodeName output, err := runCiliumDbgCommand(ctx, "node list", nodeName) if err != nil { - return mcp.NewToolResultError(fmt.Sprintf("Failed to list cluster nodes: %v", err)), nil + return mcp.NewToolResultError(fmt.Sprintf("Failed to list cluster nodes: %v", err)), nil, nil } - return mcp.NewToolResultText(output), nil + return mcp.NewToolResultText(output), nil, nil } -func handleListNodeIds(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - nodeName := mcp.ParseString(request, "node_name", "") +func handleListNodeIds(ctx context.Context, request *mcp.CallToolRequest, in nodeNameInput) (*mcp.CallToolResult, any, error) { + nodeName := in.NodeName output, err := runCiliumDbgCommand(ctx, "nodeid list", nodeName) if err != nil { - return mcp.NewToolResultError(fmt.Sprintf("Failed to list node IDs: %v", err)), nil + return mcp.NewToolResultError(fmt.Sprintf("Failed to list node IDs: %v", err)), nil, nil } - return mcp.NewToolResultText(output), nil + return mcp.NewToolResultText(output), nil, nil } -func handleDisplayPolicyNodeInformation(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - labels := mcp.ParseString(request, "labels", "") - nodeName := mcp.ParseString(request, "node_name", "") +func handleDisplayPolicyNodeInformation(ctx context.Context, request *mcp.CallToolRequest, in displayPolicyNodeInformationInput) (*mcp.CallToolResult, any, error) { + labels := in.Labels + nodeName := in.NodeName var cmd string if labels != "" { @@ -1087,15 +1002,15 @@ func handleDisplayPolicyNodeInformation(ctx context.Context, request mcp.CallToo output, err := runCiliumDbgCommand(ctx, cmd, nodeName) if err != nil { - return mcp.NewToolResultError(fmt.Sprintf("Failed to display policy node information: %v", err)), nil + return mcp.NewToolResultError(fmt.Sprintf("Failed to display policy node information: %v", err)), nil, nil } - return mcp.NewToolResultText(output), nil + return mcp.NewToolResultText(output), nil, nil } -func handleDeletePolicyRules(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - labels := mcp.ParseString(request, "labels", "") - all := mcp.ParseString(request, "all", "") == "true" - nodeName := mcp.ParseString(request, "node_name", "") +func handleDeletePolicyRules(ctx context.Context, request *mcp.CallToolRequest, in deletePolicyRulesInput) (*mcp.CallToolResult, any, error) { + labels := in.Labels + all := in.All + nodeName := in.NodeName var cmd string if all { @@ -1103,43 +1018,43 @@ func handleDeletePolicyRules(ctx context.Context, request mcp.CallToolRequest) ( } else if labels != "" { cmd = fmt.Sprintf("policy delete %s", labels) } else { - return mcp.NewToolResultError("either labels or all=true must be provided"), nil + return mcp.NewToolResultError("either labels or all=true must be provided"), nil, nil } output, err := runCiliumDbgCommand(ctx, cmd, nodeName) if err != nil { - return mcp.NewToolResultError(fmt.Sprintf("Failed to delete policy rules: %v", err)), nil + return mcp.NewToolResultError(fmt.Sprintf("Failed to delete policy rules: %v", err)), nil, nil } - return mcp.NewToolResultText(output), nil + return mcp.NewToolResultText(output), nil, nil } -func handleDisplaySelectors(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - nodeName := mcp.ParseString(request, "node_name", "") +func handleDisplaySelectors(ctx context.Context, request *mcp.CallToolRequest, in nodeNameInput) (*mcp.CallToolResult, any, error) { + nodeName := in.NodeName output, err := runCiliumDbgCommand(ctx, "policy selectors", nodeName) if err != nil { - return mcp.NewToolResultError(fmt.Sprintf("Failed to display selectors: %v", err)), nil + return mcp.NewToolResultError(fmt.Sprintf("Failed to display selectors: %v", err)), nil, nil } - return mcp.NewToolResultText(output), nil + return mcp.NewToolResultText(output), nil, nil } -func handleListXDPCIDRFilters(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - nodeName := mcp.ParseString(request, "node_name", "") +func handleListXDPCIDRFilters(ctx context.Context, request *mcp.CallToolRequest, in nodeNameInput) (*mcp.CallToolResult, any, error) { + nodeName := in.NodeName output, err := runCiliumDbgCommand(ctx, "prefilter list", nodeName) if err != nil { - return mcp.NewToolResultError(fmt.Sprintf("Failed to list XDP CIDR filters: %v", err)), nil + return mcp.NewToolResultError(fmt.Sprintf("Failed to list XDP CIDR filters: %v", err)), nil, nil } - return mcp.NewToolResultText(output), nil + return mcp.NewToolResultText(output), nil, nil } -func handleUpdateXDPCIDRFilters(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - cidrPrefixes := mcp.ParseString(request, "cidr_prefixes", "") - revision := mcp.ParseString(request, "revision", "") - nodeName := mcp.ParseString(request, "node_name", "") +func handleUpdateXDPCIDRFilters(ctx context.Context, request *mcp.CallToolRequest, in xdpCIDRFiltersInput) (*mcp.CallToolResult, any, error) { + cidrPrefixes := in.CIDRPrefixes + revision := in.Revision + nodeName := in.NodeName if cidrPrefixes == "" { - return mcp.NewToolResultError("cidr_prefixes parameter is required"), nil + return mcp.NewToolResultError("cidr_prefixes parameter is required"), nil, nil } var cmd string @@ -1151,18 +1066,18 @@ func handleUpdateXDPCIDRFilters(ctx context.Context, request mcp.CallToolRequest output, err := runCiliumDbgCommand(ctx, cmd, nodeName) if err != nil { - return mcp.NewToolResultError(fmt.Sprintf("Failed to update XDP CIDR filters: %v", err)), nil + return mcp.NewToolResultError(fmt.Sprintf("Failed to update XDP CIDR filters: %v", err)), nil, nil } - return mcp.NewToolResultText(output), nil + return mcp.NewToolResultText(output), nil, nil } -func handleDeleteXDPCIDRFilters(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - cidrPrefixes := mcp.ParseString(request, "cidr_prefixes", "") - revision := mcp.ParseString(request, "revision", "") - nodeName := mcp.ParseString(request, "node_name", "") +func handleDeleteXDPCIDRFilters(ctx context.Context, request *mcp.CallToolRequest, in xdpCIDRFiltersInput) (*mcp.CallToolResult, any, error) { + cidrPrefixes := in.CIDRPrefixes + revision := in.Revision + nodeName := in.NodeName if cidrPrefixes == "" { - return mcp.NewToolResultError("cidr_prefixes parameter is required"), nil + return mcp.NewToolResultError("cidr_prefixes parameter is required"), nil, nil } var cmd string @@ -1174,15 +1089,15 @@ func handleDeleteXDPCIDRFilters(ctx context.Context, request mcp.CallToolRequest output, err := runCiliumDbgCommand(ctx, cmd, nodeName) if err != nil { - return mcp.NewToolResultError(fmt.Sprintf("Failed to delete XDP CIDR filters: %v", err)), nil + return mcp.NewToolResultError(fmt.Sprintf("Failed to delete XDP CIDR filters: %v", err)), nil, nil } - return mcp.NewToolResultText(output), nil + return mcp.NewToolResultText(output), nil, nil } -func handleValidateCiliumNetworkPolicies(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - enableK8s := mcp.ParseString(request, "enable_k8s", "") == "true" - enableK8sAPIDiscovery := mcp.ParseString(request, "enable_k8s_api_discovery", "") == "true" - nodeName := mcp.ParseString(request, "node_name", "") +func handleValidateCiliumNetworkPolicies(ctx context.Context, request *mcp.CallToolRequest, in validateCiliumNetworkPoliciesInput) (*mcp.CallToolResult, any, error) { + enableK8s := in.EnableK8s + enableK8sAPIDiscovery := in.EnableK8sAPIDiscovery + nodeName := in.NodeName cmd := "preflight validate-cnp" if enableK8s { @@ -1194,75 +1109,81 @@ func handleValidateCiliumNetworkPolicies(ctx context.Context, request mcp.CallTo output, err := runCiliumDbgCommand(ctx, cmd, nodeName) if err != nil { - return mcp.NewToolResultError(fmt.Sprintf("Failed to validate Cilium network policies: %v", err)), nil + return mcp.NewToolResultError(fmt.Sprintf("Failed to validate Cilium network policies: %v", err)), nil, nil } - return mcp.NewToolResultText(output), nil + return mcp.NewToolResultText(output), nil, nil } -func handleListPCAPRecorders(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - nodeName := mcp.ParseString(request, "node_name", "") +func handleListPCAPRecorders(ctx context.Context, request *mcp.CallToolRequest, in nodeNameInput) (*mcp.CallToolResult, any, error) { + nodeName := in.NodeName output, err := runCiliumDbgCommand(ctx, "recorder list", nodeName) if err != nil { - return mcp.NewToolResultError(fmt.Sprintf("Failed to list PCAP recorders: %v", err)), nil + return mcp.NewToolResultError(fmt.Sprintf("Failed to list PCAP recorders: %v", err)), nil, nil } - return mcp.NewToolResultText(output), nil + return mcp.NewToolResultText(output), nil, nil } -func handleGetPCAPRecorder(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - recorderID := mcp.ParseString(request, "recorder_id", "") - nodeName := mcp.ParseString(request, "node_name", "") +func handleGetPCAPRecorder(ctx context.Context, request *mcp.CallToolRequest, in pcapRecorderIDInput) (*mcp.CallToolResult, any, error) { + recorderID := in.RecorderID + nodeName := in.NodeName if recorderID == "" { - return mcp.NewToolResultError("recorder_id parameter is required"), nil + return mcp.NewToolResultError("recorder_id parameter is required"), nil, nil } cmd := fmt.Sprintf("recorder get %s", recorderID) output, err := runCiliumDbgCommand(ctx, cmd, nodeName) if err != nil { - return mcp.NewToolResultError(fmt.Sprintf("Failed to get PCAP recorder: %v", err)), nil + return mcp.NewToolResultError(fmt.Sprintf("Failed to get PCAP recorder: %v", err)), nil, nil } - return mcp.NewToolResultText(output), nil + return mcp.NewToolResultText(output), nil, nil } -func handleDeletePCAPRecorder(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - recorderID := mcp.ParseString(request, "recorder_id", "") - nodeName := mcp.ParseString(request, "node_name", "") +func handleDeletePCAPRecorder(ctx context.Context, request *mcp.CallToolRequest, in pcapRecorderIDInput) (*mcp.CallToolResult, any, error) { + recorderID := in.RecorderID + nodeName := in.NodeName if recorderID == "" { - return mcp.NewToolResultError("recorder_id parameter is required"), nil + return mcp.NewToolResultError("recorder_id parameter is required"), nil, nil } cmd := fmt.Sprintf("recorder delete %s", recorderID) output, err := runCiliumDbgCommand(ctx, cmd, nodeName) if err != nil { - return mcp.NewToolResultError(fmt.Sprintf("Failed to delete PCAP recorder: %v", err)), nil + return mcp.NewToolResultError(fmt.Sprintf("Failed to delete PCAP recorder: %v", err)), nil, nil } - return mcp.NewToolResultText(output), nil + return mcp.NewToolResultText(output), nil, nil } -func handleUpdatePCAPRecorder(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - recorderID := mcp.ParseString(request, "recorder_id", "") - filters := mcp.ParseString(request, "filters", "") - caplen := mcp.ParseString(request, "caplen", "0") - id := mcp.ParseString(request, "id", "0") - nodeName := mcp.ParseString(request, "node_name", "") +func handleUpdatePCAPRecorder(ctx context.Context, request *mcp.CallToolRequest, in updatePCAPRecorderInput) (*mcp.CallToolResult, any, error) { + if in.Caplen == "" { + in.Caplen = "0" + } + if in.ID == "" { + in.ID = "0" + } + recorderID := in.RecorderID + filters := in.Filters + caplen := in.Caplen + id := in.ID + nodeName := in.NodeName if recorderID == "" || filters == "" { - return mcp.NewToolResultError("recorder_id and filters parameters are required"), nil + return mcp.NewToolResultError("recorder_id and filters parameters are required"), nil, nil } cmd := fmt.Sprintf("recorder update %s --filters %s --caplen %s --id %s", recorderID, filters, caplen, id) output, err := runCiliumDbgCommand(ctx, cmd, nodeName) if err != nil { - return mcp.NewToolResultError(fmt.Sprintf("Failed to update PCAP recorder: %v", err)), nil + return mcp.NewToolResultError(fmt.Sprintf("Failed to update PCAP recorder: %v", err)), nil, nil } - return mcp.NewToolResultText(output), nil + return mcp.NewToolResultText(output), nil, nil } -func handleListServices(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - showClusterMeshAffinity := mcp.ParseString(request, "show_cluster_mesh_affinity", "") == "true" - nodeName := mcp.ParseString(request, "node_name", "") +func handleListServices(ctx context.Context, request *mcp.CallToolRequest, in listServicesInput) (*mcp.CallToolResult, any, error) { + showClusterMeshAffinity := in.ShowClusterMeshAffinity + nodeName := in.NodeName var cmd string if showClusterMeshAffinity { @@ -1273,31 +1194,31 @@ func handleListServices(ctx context.Context, request mcp.CallToolRequest) (*mcp. output, err := runCiliumDbgCommand(ctx, cmd, nodeName) if err != nil { - return mcp.NewToolResultError(fmt.Sprintf("Failed to list services: %v", err)), nil + return mcp.NewToolResultError(fmt.Sprintf("Failed to list services: %v", err)), nil, nil } - return mcp.NewToolResultText(output), nil + return mcp.NewToolResultText(output), nil, nil } -func handleGetServiceInformation(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - serviceID := mcp.ParseString(request, "service_id", "") - nodeName := mcp.ParseString(request, "node_name", "") +func handleGetServiceInformation(ctx context.Context, request *mcp.CallToolRequest, in getServiceInformationInput) (*mcp.CallToolResult, any, error) { + serviceID := in.ServiceID + nodeName := in.NodeName if serviceID == "" { - return mcp.NewToolResultError("service_id parameter is required"), nil + return mcp.NewToolResultError("service_id parameter is required"), nil, nil } cmd := fmt.Sprintf("service get %s", serviceID) output, err := runCiliumDbgCommand(ctx, cmd, nodeName) if err != nil { - return mcp.NewToolResultError(fmt.Sprintf("Failed to get service information: %v", err)), nil + return mcp.NewToolResultError(fmt.Sprintf("Failed to get service information: %v", err)), nil, nil } - return mcp.NewToolResultText(output), nil + return mcp.NewToolResultText(output), nil, nil } -func handleDeleteService(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - serviceID := mcp.ParseString(request, "service_id", "") - all := mcp.ParseString(request, "all", "") == "true" - nodeName := mcp.ParseString(request, "node_name", "") +func handleDeleteService(ctx context.Context, request *mcp.CallToolRequest, in deleteServiceInput) (*mcp.CallToolResult, any, error) { + serviceID := in.ServiceID + all := in.All + nodeName := in.NodeName var cmd string if all { @@ -1305,35 +1226,47 @@ func handleDeleteService(ctx context.Context, request mcp.CallToolRequest) (*mcp } else if serviceID != "" { cmd = fmt.Sprintf("service delete %s", serviceID) } else { - return mcp.NewToolResultError("either service_id or all=true must be provided"), nil + return mcp.NewToolResultError("either service_id or all=true must be provided"), nil, nil } output, err := runCiliumDbgCommand(ctx, cmd, nodeName) if err != nil { - return mcp.NewToolResultError(fmt.Sprintf("Failed to delete service: %v", err)), nil - } - return mcp.NewToolResultText(output), nil -} - -func handleUpdateService(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - backendWeights := mcp.ParseString(request, "backend_weights", "") - backends := mcp.ParseString(request, "backends", "") - frontend := mcp.ParseString(request, "frontend", "") - id := mcp.ParseString(request, "id", "") - k8sClusterInternal := mcp.ParseString(request, "k8s_cluster_internal", "") == "true" - k8sExtTrafficPolicy := mcp.ParseString(request, "k8s_ext_traffic_policy", "Cluster") - k8sExternal := mcp.ParseString(request, "k8s_external", "") == "true" - k8sHostPort := mcp.ParseString(request, "k8s_host_port", "") == "true" - k8sIntTrafficPolicy := mcp.ParseString(request, "k8s_int_traffic_policy", "Cluster") - k8sLoadBalancer := mcp.ParseString(request, "k8s_load_balancer", "") == "true" - k8sNodePort := mcp.ParseString(request, "k8s_node_port", "") == "true" - localRedirect := mcp.ParseString(request, "local_redirect", "") == "true" - protocol := mcp.ParseString(request, "protocol", "TCP") - states := mcp.ParseString(request, "states", "active") - nodeName := mcp.ParseString(request, "node_name", "") + return mcp.NewToolResultError(fmt.Sprintf("Failed to delete service: %v", err)), nil, nil + } + return mcp.NewToolResultText(output), nil, nil +} + +func handleUpdateService(ctx context.Context, request *mcp.CallToolRequest, in updateServiceInput) (*mcp.CallToolResult, any, error) { + if in.K8sExtTrafficPolicy == "" { + in.K8sExtTrafficPolicy = "Cluster" + } + if in.K8sIntTrafficPolicy == "" { + in.K8sIntTrafficPolicy = "Cluster" + } + if in.Protocol == "" { + in.Protocol = "TCP" + } + if in.States == "" { + in.States = "active" + } + backendWeights := in.BackendWeights + backends := in.Backends + frontend := in.Frontend + id := in.ID + k8sClusterInternal := in.K8sClusterInternal + k8sExtTrafficPolicy := in.K8sExtTrafficPolicy + k8sExternal := in.K8sExternal + k8sHostPort := in.K8sHostPort + k8sIntTrafficPolicy := in.K8sIntTrafficPolicy + k8sLoadBalancer := in.K8sLoadBalancer + k8sNodePort := in.K8sNodePort + localRedirect := in.LocalRedirect + protocol := in.Protocol + states := in.States + nodeName := in.NodeName if backends == "" || frontend == "" || id == "" { - return mcp.NewToolResultError("backends, frontend, and id parameters are required"), nil + return mcp.NewToolResultError("backends, frontend, and id parameters are required"), nil, nil } cmd := fmt.Sprintf("service update %s --backends %s --frontend %s --protocol %s --states %s", @@ -1369,20 +1302,20 @@ func handleUpdateService(ctx context.Context, request mcp.CallToolRequest) (*mcp output, err := runCiliumDbgCommand(ctx, cmd, nodeName) if err != nil { - return mcp.NewToolResultError(fmt.Sprintf("Failed to update service: %v", err)), nil + return mcp.NewToolResultError(fmt.Sprintf("Failed to update service: %v", err)), nil, nil } - return mcp.NewToolResultText(output), nil + return mcp.NewToolResultText(output), nil, nil } -func handleGetDaemonStatus(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - showAllAddresses := mcp.ParseString(request, "show_all_addresses", "") == "true" - showAllClusters := mcp.ParseString(request, "show_all_clusters", "") == "true" - showAllControllers := mcp.ParseString(request, "show_all_controllers", "") == "true" - showHealth := mcp.ParseString(request, "show_health", "") == "true" - showAllNodes := mcp.ParseString(request, "show_all_nodes", "") == "true" - showAllRedirects := mcp.ParseString(request, "show_all_redirects", "") == "true" - brief := mcp.ParseString(request, "brief", "") == "true" - nodeName := mcp.ParseString(request, "node_name", "") +func handleGetDaemonStatus(ctx context.Context, request *mcp.CallToolRequest, in getDaemonStatusInput) (*mcp.CallToolResult, any, error) { + showAllAddresses := in.ShowAllAddresses + showAllClusters := in.ShowAllClusters + showAllControllers := in.ShowAllControllers + showHealth := in.ShowHealth + showAllNodes := in.ShowAllNodes + showAllRedirects := in.ShowAllRedirects + brief := in.Brief + nodeName := in.NodeName cmd := "status" if showAllAddresses { @@ -1409,7 +1342,7 @@ func handleGetDaemonStatus(ctx context.Context, request mcp.CallToolRequest) (*m output, err := runCiliumDbgCommand(ctx, cmd, nodeName) if err != nil { - return mcp.NewToolResultError(fmt.Sprintf("Failed to get daemon status: %v", err)), nil + return mcp.NewToolResultError(fmt.Sprintf("Failed to get daemon status: %v", err)), nil, nil } - return mcp.NewToolResultText(output), nil + return mcp.NewToolResultText(output), nil, nil } diff --git a/pkg/cilium/cilium_test.go b/pkg/cilium/cilium_test.go index 84313e77..bde7ce68 100644 --- a/pkg/cilium/cilium_test.go +++ b/pkg/cilium/cilium_test.go @@ -8,14 +8,15 @@ import ( "testing" "github.com/kagent-dev/tools/internal/cmd" - "github.com/mark3labs/mcp-go/mcp" - "github.com/mark3labs/mcp-go/server" + mcp "github.com/kagent-dev/tools/internal/mcp" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) +func boolPtr(b bool) *bool { return &b } + func TestRegisterCiliumTools(t *testing.T) { - s := server.NewMCPServer("test-server", "v0.0.1") + s := mcp.NewServer(&mcp.Implementation{Name: "test-server", Version: "v0.0.1"}, nil) RegisterTools(s, false) // false = enable all tools including write operations // We can't directly check the tools, but we can ensure the call doesn't panic } @@ -28,22 +29,14 @@ func TestHandleCiliumStatusAndVersion(t *testing.T) { ctx = cmd.WithShellExecutor(ctx, mock) - result, err := handleCiliumStatusAndVersion(ctx, mcp.CallToolRequest{}) + result, _, err := handleCiliumStatusAndVersion(ctx, &mcp.CallToolRequest{}, noInput{}) require.NoError(t, err) assert.NotNil(t, result) assert.False(t, result.IsError) - var textContent mcp.TextContent - var ok bool - for _, content := range result.Content { - if textContent, ok = content.(mcp.TextContent); ok { - break - } - } - require.True(t, ok, "no text content in result") - - assert.Contains(t, textContent.Text, "Cilium status: OK") - assert.Contains(t, textContent.Text, "cilium version 1.14.0") + text := getResultText(result) + assert.Contains(t, text, "Cilium status: OK") + assert.Contains(t, text, "cilium version 1.14.0") } func TestHandleCiliumStatusAndVersionError(t *testing.T) { @@ -54,7 +47,7 @@ func TestHandleCiliumStatusAndVersionError(t *testing.T) { ctx = cmd.WithShellExecutor(ctx, mock) - result, err := handleCiliumStatusAndVersion(ctx, mcp.CallToolRequest{}) + result, _, err := handleCiliumStatusAndVersion(ctx, &mcp.CallToolRequest{}, noInput{}) require.NoError(t, err) assert.NotNil(t, result) assert.True(t, result.IsError) @@ -68,7 +61,7 @@ func TestHandleInstallCilium(t *testing.T) { ctx = cmd.WithShellExecutor(ctx, mock) - result, err := handleInstallCilium(ctx, mcp.CallToolRequest{}) + result, _, err := handleInstallCilium(ctx, &mcp.CallToolRequest{}, installCiliumInput{}) require.NoError(t, err) assert.NotNil(t, result) assert.False(t, result.IsError) @@ -82,7 +75,7 @@ func TestHandleUninstallCilium(t *testing.T) { ctx = cmd.WithShellExecutor(ctx, mock) - result, err := handleUninstallCilium(ctx, mcp.CallToolRequest{}) + result, _, err := handleUninstallCilium(ctx, &mcp.CallToolRequest{}, noInput{}) require.NoError(t, err) assert.NotNil(t, result) assert.False(t, result.IsError) @@ -96,7 +89,7 @@ func TestHandleUpgradeCilium(t *testing.T) { ctx = cmd.WithShellExecutor(ctx, mock) - result, err := handleUpgradeCilium(ctx, mcp.CallToolRequest{}) + result, _, err := handleUpgradeCilium(ctx, &mcp.CallToolRequest{}, upgradeCiliumInput{}) require.NoError(t, err) assert.NotNil(t, result) assert.False(t, result.IsError) @@ -110,15 +103,7 @@ func TestHandleConnectToRemoteCluster(t *testing.T) { mock := cmd.NewMockShellExecutor() mock.AddCommandString("cilium", []string{"clustermesh", "connect", "--destination-cluster", "my-cluster"}, "✓ Connected to cluster my-cluster!", nil) ctx = cmd.WithShellExecutor(ctx, mock) - req := mcp.CallToolRequest{ - Params: mcp.CallToolParams{ - Arguments: map[string]any{ - "cluster_name": "my-cluster", - }, - }, - } - - result, err := handleConnectToRemoteCluster(ctx, req) + result, _, err := handleConnectToRemoteCluster(ctx, &mcp.CallToolRequest{}, connectToRemoteClusterInput{ClusterName: "my-cluster"}) require.NoError(t, err) assert.NotNil(t, result) assert.False(t, result.IsError) @@ -126,12 +111,7 @@ func TestHandleConnectToRemoteCluster(t *testing.T) { }) t.Run("missing cluster_name", func(t *testing.T) { - req := mcp.CallToolRequest{ - Params: mcp.CallToolParams{ - Arguments: map[string]any{}, - }, - } - result, err := handleConnectToRemoteCluster(ctx, req) + result, _, err := handleConnectToRemoteCluster(ctx, &mcp.CallToolRequest{}, connectToRemoteClusterInput{}) require.NoError(t, err) assert.NotNil(t, result) assert.True(t, result.IsError) @@ -146,15 +126,7 @@ func TestHandleDisconnectFromRemoteCluster(t *testing.T) { mock := cmd.NewMockShellExecutor() mock.AddCommandString("cilium", []string{"clustermesh", "disconnect", "--destination-cluster", "my-cluster"}, "✓ Disconnected from cluster my-cluster!", nil) ctx = cmd.WithShellExecutor(ctx, mock) - req := mcp.CallToolRequest{ - Params: mcp.CallToolParams{ - Arguments: map[string]any{ - "cluster_name": "my-cluster", - }, - }, - } - - result, err := handleDisconnectRemoteCluster(ctx, req) + result, _, err := handleDisconnectRemoteCluster(ctx, &mcp.CallToolRequest{}, disconnectRemoteClusterInput{ClusterName: "my-cluster"}) require.NoError(t, err) assert.NotNil(t, result) assert.False(t, result.IsError) @@ -162,12 +134,7 @@ func TestHandleDisconnectFromRemoteCluster(t *testing.T) { }) t.Run("missing cluster_name", func(t *testing.T) { - req := mcp.CallToolRequest{ - Params: mcp.CallToolParams{ - Arguments: map[string]any{}, - }, - } - result, err := handleDisconnectRemoteCluster(ctx, req) + result, _, err := handleDisconnectRemoteCluster(ctx, &mcp.CallToolRequest{}, disconnectRemoteClusterInput{}) require.NoError(t, err) assert.NotNil(t, result) assert.True(t, result.IsError) @@ -180,15 +147,7 @@ func TestHandleEnableHubble(t *testing.T) { mock := cmd.NewMockShellExecutor() mock.AddCommandString("cilium", []string{"hubble", "enable"}, "✓ Hubble was successfully enabled!", nil) ctx = cmd.WithShellExecutor(ctx, mock) - req := mcp.CallToolRequest{ - Params: mcp.CallToolParams{ - Arguments: map[string]any{ - "enable": true, - }, - }, - } - - result, err := handleToggleHubble(ctx, req) + result, _, err := handleToggleHubble(ctx, &mcp.CallToolRequest{}, enableToggleInput{Enable: boolPtr(true)}) require.NoError(t, err) assert.NotNil(t, result) assert.False(t, result.IsError) @@ -200,14 +159,7 @@ func TestHandleDisableHubble(t *testing.T) { mock := cmd.NewMockShellExecutor() mock.AddCommandString("cilium", []string{"hubble", "disable"}, "✓ Hubble was successfully disabled!", nil) ctx = cmd.WithShellExecutor(ctx, mock) - req := mcp.CallToolRequest{ - Params: mcp.CallToolParams{ - Arguments: map[string]any{ - "enable": false, - }, - }, - } - result, err := handleToggleHubble(ctx, req) + result, _, err := handleToggleHubble(ctx, &mcp.CallToolRequest{}, enableToggleInput{Enable: boolPtr(false)}) require.NoError(t, err) assert.NotNil(t, result) assert.False(t, result.IsError) @@ -219,7 +171,7 @@ func TestHandleListBGPPeers(t *testing.T) { mock := cmd.NewMockShellExecutor() mock.AddCommandString("cilium", []string{"bgp", "peers"}, "listing BGP peers", nil) ctx = cmd.WithShellExecutor(ctx, mock) - result, err := handleListBGPPeers(ctx, mcp.CallToolRequest{}) + result, _, err := handleListBGPPeers(ctx, &mcp.CallToolRequest{}, noInput{}) require.NoError(t, err) assert.NotNil(t, result) assert.False(t, result.IsError) @@ -231,7 +183,7 @@ func TestHandleListBGPRoutes(t *testing.T) { mock := cmd.NewMockShellExecutor() mock.AddCommandString("cilium", []string{"bgp", "routes"}, "listing BGP routes", nil) ctx = cmd.WithShellExecutor(ctx, mock) - result, err := handleListBGPRoutes(ctx, mcp.CallToolRequest{}) + result, _, err := handleListBGPRoutes(ctx, &mcp.CallToolRequest{}, noInput{}) require.NoError(t, err) assert.NotNil(t, result) assert.False(t, result.IsError) @@ -275,22 +227,13 @@ func mockCiliumDbgCommand(mock *cmd.MockShellExecutor, dbgArgs []string, output mock.AddCommandString("kubectl", execArgs, output, err) } -func newRequestWithArgs(args map[string]any) mcp.CallToolRequest { - return mcp.CallToolRequest{ - Params: mcp.CallToolParams{ - Arguments: args, - }, - } -} - func TestHandleGetEndpointsList(t *testing.T) { ctx := context.Background() mock := cmd.NewMockShellExecutor() mockCiliumDbgCommand(mock, []string{"endpoint", "list"}, "ENDPOINT POLICY\n34 Disabled", nil) ctx = cmd.WithShellExecutor(ctx, mock) - req := newRequestWithArgs(map[string]any{"node_name": "test-node"}) - result, err := handleGetEndpointsList(ctx, req) + result, _, err := handleGetEndpointsList(ctx, &mcp.CallToolRequest{}, nodeNameInput{NodeName: "test-node"}) require.NoError(t, err) assert.False(t, result.IsError) assert.Contains(t, getResultText(result), "ENDPOINT") @@ -302,8 +245,7 @@ func TestHandleGetEndpointDetails(t *testing.T) { mockCiliumDbgCommand(mock, []string{"endpoint", "get", "34", "-o", "json"}, `{"id": 34}`, nil) ctx = cmd.WithShellExecutor(ctx, mock) - req := newRequestWithArgs(map[string]any{"endpoint_id": "34", "node_name": "test-node"}) - result, err := handleGetEndpointDetails(ctx, req) + result, _, err := handleGetEndpointDetails(ctx, &mcp.CallToolRequest{}, getEndpointDetailsInput{EndpointID: "34", NodeName: "test-node"}) require.NoError(t, err) assert.False(t, result.IsError) assert.Contains(t, getResultText(result), `"id": 34`) @@ -315,8 +257,7 @@ func TestHandleGetEndpointLogs(t *testing.T) { mockCiliumDbgCommand(mock, []string{"endpoint", "logs", "34"}, "endpoint log output", nil) ctx = cmd.WithShellExecutor(ctx, mock) - req := newRequestWithArgs(map[string]any{"endpoint_id": "34", "node_name": "test-node"}) - result, err := handleGetEndpointLogs(ctx, req) + result, _, err := handleGetEndpointLogs(ctx, &mcp.CallToolRequest{}, getEndpointLogsInput{EndpointID: "34", NodeName: "test-node"}) require.NoError(t, err) assert.False(t, result.IsError) assert.Contains(t, getResultText(result), "endpoint log output") @@ -328,8 +269,7 @@ func TestHandleGetEndpointHealth(t *testing.T) { mockCiliumDbgCommand(mock, []string{"endpoint", "health", "34"}, "endpoint health OK", nil) ctx = cmd.WithShellExecutor(ctx, mock) - req := newRequestWithArgs(map[string]any{"endpoint_id": "34", "node_name": "test-node"}) - result, err := handleGetEndpointHealth(ctx, req) + result, _, err := handleGetEndpointHealth(ctx, &mcp.CallToolRequest{}, getEndpointHealthInput{EndpointID: "34", NodeName: "test-node"}) require.NoError(t, err) assert.False(t, result.IsError) assert.Contains(t, getResultText(result), "endpoint health OK") @@ -342,8 +282,7 @@ func TestHandleShowConfigurationOptions(t *testing.T) { mockCiliumDbgCommand(mock, []string{"config"}, "PolicyEnforcement=default", nil) ctx = cmd.WithShellExecutor(ctx, mock) - req := newRequestWithArgs(map[string]any{"node_name": "test-node"}) - result, err := handleShowConfigurationOptions(ctx, req) + result, _, err := handleShowConfigurationOptions(ctx, &mcp.CallToolRequest{}, showConfigurationOptionsInput{NodeName: "test-node"}) require.NoError(t, err) assert.False(t, result.IsError) assert.Contains(t, getResultText(result), "PolicyEnforcement") @@ -355,8 +294,7 @@ func TestHandleShowConfigurationOptions(t *testing.T) { mockCiliumDbgCommand(mock, []string{"config", "--all"}, "all config options", nil) ctx = cmd.WithShellExecutor(ctx, mock) - req := newRequestWithArgs(map[string]any{"node_name": "test-node", "list_all": "true"}) - result, err := handleShowConfigurationOptions(ctx, req) + result, _, err := handleShowConfigurationOptions(ctx, &mcp.CallToolRequest{}, showConfigurationOptionsInput{NodeName: "test-node", ListAll: true}) require.NoError(t, err) assert.False(t, result.IsError) assert.Contains(t, getResultText(result), "all config options") @@ -368,8 +306,7 @@ func TestHandleShowConfigurationOptions(t *testing.T) { mockCiliumDbgCommand(mock, []string{"config", "-r"}, "read only config", nil) ctx = cmd.WithShellExecutor(ctx, mock) - req := newRequestWithArgs(map[string]any{"node_name": "test-node", "list_read_only": "true"}) - result, err := handleShowConfigurationOptions(ctx, req) + result, _, err := handleShowConfigurationOptions(ctx, &mcp.CallToolRequest{}, showConfigurationOptionsInput{NodeName: "test-node", ListReadOnly: true}) require.NoError(t, err) assert.False(t, result.IsError) assert.Contains(t, getResultText(result), "read only config") @@ -382,8 +319,7 @@ func TestHandleToggleConfigurationOption(t *testing.T) { mockCiliumDbgCommand(mock, []string{"config", "PolicyEnforcement=enable"}, "option toggled", nil) ctx = cmd.WithShellExecutor(ctx, mock) - req := newRequestWithArgs(map[string]any{"option": "PolicyEnforcement", "value": "true", "node_name": "test-node"}) - result, err := handleToggleConfigurationOption(ctx, req) + result, _, err := handleToggleConfigurationOption(ctx, &mcp.CallToolRequest{}, toggleConfigurationOptionInput{Option: "PolicyEnforcement", Value: boolPtr(true), NodeName: "test-node"}) require.NoError(t, err) assert.False(t, result.IsError) assert.Contains(t, getResultText(result), "option toggled") @@ -395,8 +331,7 @@ func TestHandleListIdentities(t *testing.T) { mockCiliumDbgCommand(mock, []string{"identity", "list"}, "ID LABELS\n1 reserved:host", nil) ctx = cmd.WithShellExecutor(ctx, mock) - req := newRequestWithArgs(map[string]any{"node_name": "test-node"}) - result, err := handleListIdentities(ctx, req) + result, _, err := handleListIdentities(ctx, &mcp.CallToolRequest{}, nodeNameInput{NodeName: "test-node"}) require.NoError(t, err) assert.False(t, result.IsError) assert.Contains(t, getResultText(result), "reserved:host") @@ -408,8 +343,7 @@ func TestHandleGetDaemonStatus(t *testing.T) { mockCiliumDbgCommand(mock, []string{"status"}, "KVStore: Ok\nKubernetes: Ok", nil) ctx = cmd.WithShellExecutor(ctx, mock) - req := newRequestWithArgs(map[string]any{"node_name": "test-node"}) - result, err := handleGetDaemonStatus(ctx, req) + result, _, err := handleGetDaemonStatus(ctx, &mcp.CallToolRequest{}, getDaemonStatusInput{NodeName: "test-node"}) require.NoError(t, err) assert.False(t, result.IsError) assert.Contains(t, getResultText(result), "KVStore: Ok") @@ -421,8 +355,7 @@ func TestHandleDisplayEncryptionState(t *testing.T) { mockCiliumDbgCommand(mock, []string{"encrypt", "status"}, "Encryption: Disabled", nil) ctx = cmd.WithShellExecutor(ctx, mock) - req := newRequestWithArgs(map[string]any{"node_name": "test-node"}) - result, err := handleDisplayEncryptionState(ctx, req) + result, _, err := handleDisplayEncryptionState(ctx, &mcp.CallToolRequest{}, nodeNameInput{NodeName: "test-node"}) require.NoError(t, err) assert.False(t, result.IsError) assert.Contains(t, getResultText(result), "Encryption: Disabled") @@ -434,8 +367,7 @@ func TestHandleShowDNSNames(t *testing.T) { mockCiliumDbgCommand(mock, []string{"fqdn", "names"}, "DNS names output", nil) ctx = cmd.WithShellExecutor(ctx, mock) - req := newRequestWithArgs(map[string]any{"node_name": "test-node"}) - result, err := handleShowDNSNames(ctx, req) + result, _, err := handleShowDNSNames(ctx, &mcp.CallToolRequest{}, nodeNameInput{NodeName: "test-node"}) require.NoError(t, err) assert.False(t, result.IsError) assert.Contains(t, getResultText(result), "DNS names output") @@ -447,8 +379,7 @@ func TestHandleFQDNCache(t *testing.T) { mockCiliumDbgCommand(mock, []string{"fqdn", "cache", "list"}, "FQDN cache entries", nil) ctx = cmd.WithShellExecutor(ctx, mock) - req := newRequestWithArgs(map[string]any{"node_name": "test-node"}) - result, err := handleFQDNCache(ctx, req) + result, _, err := handleFQDNCache(ctx, &mcp.CallToolRequest{}, fqdnCacheInput{NodeName: "test-node"}) require.NoError(t, err) assert.False(t, result.IsError) assert.Contains(t, getResultText(result), "FQDN cache entries") @@ -460,8 +391,7 @@ func TestHandleListClusterNodes(t *testing.T) { mockCiliumDbgCommand(mock, []string{"node", "list"}, "Name IPv4 Address\nnode1 10.0.0.1", nil) ctx = cmd.WithShellExecutor(ctx, mock) - req := newRequestWithArgs(map[string]any{"node_name": "test-node"}) - result, err := handleListClusterNodes(ctx, req) + result, _, err := handleListClusterNodes(ctx, &mcp.CallToolRequest{}, nodeNameInput{NodeName: "test-node"}) require.NoError(t, err) assert.False(t, result.IsError) assert.Contains(t, getResultText(result), "node1") @@ -473,8 +403,7 @@ func TestHandleListNodeIds(t *testing.T) { mockCiliumDbgCommand(mock, []string{"nodeid", "list"}, "ID IP\n1 10.0.0.1", nil) ctx = cmd.WithShellExecutor(ctx, mock) - req := newRequestWithArgs(map[string]any{"node_name": "test-node"}) - result, err := handleListNodeIds(ctx, req) + result, _, err := handleListNodeIds(ctx, &mcp.CallToolRequest{}, nodeNameInput{NodeName: "test-node"}) require.NoError(t, err) assert.False(t, result.IsError) assert.Contains(t, getResultText(result), "10.0.0.1") @@ -486,8 +415,7 @@ func TestHandleListBPFMaps(t *testing.T) { mockCiliumDbgCommand(mock, []string{"map", "list"}, "Name Num entries\ncilium_lb4 22", nil) ctx = cmd.WithShellExecutor(ctx, mock) - req := newRequestWithArgs(map[string]any{"node_name": "test-node"}) - result, err := handleListBPFMaps(ctx, req) + result, _, err := handleListBPFMaps(ctx, &mcp.CallToolRequest{}, nodeNameInput{NodeName: "test-node"}) require.NoError(t, err) assert.False(t, result.IsError) assert.Contains(t, getResultText(result), "cilium_lb4") @@ -499,8 +427,7 @@ func TestHandleGetBPFMap(t *testing.T) { mockCiliumDbgCommand(mock, []string{"map", "get", "cilium_lb4"}, "map contents", nil) ctx = cmd.WithShellExecutor(ctx, mock) - req := newRequestWithArgs(map[string]any{"map_name": "cilium_lb4", "node_name": "test-node"}) - result, err := handleGetBPFMap(ctx, req) + result, _, err := handleGetBPFMap(ctx, &mcp.CallToolRequest{}, bpfMapInput{MapName: "cilium_lb4", NodeName: "test-node"}) require.NoError(t, err) assert.False(t, result.IsError) assert.Contains(t, getResultText(result), "map contents") @@ -512,8 +439,7 @@ func TestHandleListBPFMapEvents(t *testing.T) { mockCiliumDbgCommand(mock, []string{"map", "events", "cilium_lb4"}, "map events", nil) ctx = cmd.WithShellExecutor(ctx, mock) - req := newRequestWithArgs(map[string]any{"map_name": "cilium_lb4", "node_name": "test-node"}) - result, err := handleListBPFMapEvents(ctx, req) + result, _, err := handleListBPFMapEvents(ctx, &mcp.CallToolRequest{}, bpfMapInput{MapName: "cilium_lb4", NodeName: "test-node"}) require.NoError(t, err) assert.False(t, result.IsError) assert.Contains(t, getResultText(result), "map events") @@ -525,8 +451,7 @@ func TestHandleListMetrics(t *testing.T) { mockCiliumDbgCommand(mock, []string{"metrics", "list"}, "Metric Value\ncilium_endpoint_count 4", nil) ctx = cmd.WithShellExecutor(ctx, mock) - req := newRequestWithArgs(map[string]any{"node_name": "test-node"}) - result, err := handleListMetrics(ctx, req) + result, _, err := handleListMetrics(ctx, &mcp.CallToolRequest{}, listMetricsInput{NodeName: "test-node"}) require.NoError(t, err) assert.False(t, result.IsError) assert.Contains(t, getResultText(result), "cilium_endpoint_count") @@ -538,8 +463,7 @@ func TestHandleListServices(t *testing.T) { mockCiliumDbgCommand(mock, []string{"service", "list"}, "ID Frontend\n1 10.96.0.1:443", nil) ctx = cmd.WithShellExecutor(ctx, mock) - req := newRequestWithArgs(map[string]any{"node_name": "test-node"}) - result, err := handleListServices(ctx, req) + result, _, err := handleListServices(ctx, &mcp.CallToolRequest{}, listServicesInput{NodeName: "test-node"}) require.NoError(t, err) assert.False(t, result.IsError) assert.Contains(t, getResultText(result), "10.96.0.1") @@ -551,8 +475,7 @@ func TestHandleListIPAddresses(t *testing.T) { mockCiliumDbgCommand(mock, []string{"ip", "list"}, "IP Identity\n10.0.0.1 1", nil) ctx = cmd.WithShellExecutor(ctx, mock) - req := newRequestWithArgs(map[string]any{"node_name": "test-node"}) - result, err := handleListIPAddresses(ctx, req) + result, _, err := handleListIPAddresses(ctx, &mcp.CallToolRequest{}, nodeNameInput{NodeName: "test-node"}) require.NoError(t, err) assert.False(t, result.IsError) assert.Contains(t, getResultText(result), "10.0.0.1") @@ -564,8 +487,7 @@ func TestHandleDisplaySelectors(t *testing.T) { mockCiliumDbgCommand(mock, []string{"policy", "selectors"}, "SELECTOR IDENTITIES", nil) ctx = cmd.WithShellExecutor(ctx, mock) - req := newRequestWithArgs(map[string]any{"node_name": "test-node"}) - result, err := handleDisplaySelectors(ctx, req) + result, _, err := handleDisplaySelectors(ctx, &mcp.CallToolRequest{}, nodeNameInput{NodeName: "test-node"}) require.NoError(t, err) assert.False(t, result.IsError) assert.Contains(t, getResultText(result), "SELECTOR") @@ -577,8 +499,7 @@ func TestHandleListLocalRedirectPolicies(t *testing.T) { mockCiliumDbgCommand(mock, []string{"lrp", "list"}, "No local redirect policies", nil) ctx = cmd.WithShellExecutor(ctx, mock) - req := newRequestWithArgs(map[string]any{"node_name": "test-node"}) - result, err := handleListLocalRedirectPolicies(ctx, req) + result, _, err := handleListLocalRedirectPolicies(ctx, &mcp.CallToolRequest{}, nodeNameInput{NodeName: "test-node"}) require.NoError(t, err) assert.False(t, result.IsError) assert.Contains(t, getResultText(result), "No local redirect policies") @@ -590,8 +511,7 @@ func TestHandleRequestDebuggingInformation(t *testing.T) { mockCiliumDbgCommand(mock, []string{"debuginfo"}, "debug info output", nil) ctx = cmd.WithShellExecutor(ctx, mock) - req := newRequestWithArgs(map[string]any{"node_name": "test-node"}) - result, err := handleRequestDebuggingInformation(ctx, req) + result, _, err := handleRequestDebuggingInformation(ctx, &mcp.CallToolRequest{}, nodeNameInput{NodeName: "test-node"}) require.NoError(t, err) assert.False(t, result.IsError) assert.Contains(t, getResultText(result), "debug info output") @@ -603,8 +523,7 @@ func TestHandleListXDPCIDRFilters(t *testing.T) { mockCiliumDbgCommand(mock, []string{"prefilter", "list"}, "CIDR filters", nil) ctx = cmd.WithShellExecutor(ctx, mock) - req := newRequestWithArgs(map[string]any{"node_name": "test-node"}) - result, err := handleListXDPCIDRFilters(ctx, req) + result, _, err := handleListXDPCIDRFilters(ctx, &mcp.CallToolRequest{}, nodeNameInput{NodeName: "test-node"}) require.NoError(t, err) assert.False(t, result.IsError) assert.Contains(t, getResultText(result), "CIDR filters") @@ -614,52 +533,136 @@ func getResultText(r *mcp.CallToolResult) string { if r == nil || len(r.Content) == 0 { return "" } - if textContent, ok := r.Content[0].(mcp.TextContent); ok { + if textContent, ok := r.Content[0].(*mcp.TextContent); ok { return strings.TrimSpace(textContent.Text) } return "" } -type ciliumHandler func(context.Context, mcp.CallToolRequest) (*mcp.CallToolResult, error) - // TestCiliumDbgHandlers exercises the success path of every cilium-dbg based handler. func TestCiliumDbgHandlers(t *testing.T) { cases := []struct { name string - handler ciliumHandler - args map[string]any dbgArgs []string expect string + run func(context.Context) (*mcp.CallToolResult, error) }{ - {"manage_endpoint_labels", handleManageEndpointLabels, map[string]any{"endpoint_id": "34", "labels": "key=val"}, []string{"endpoint", "labels", "34", "--add", "key=val"}, "ok"}, - {"manage_endpoint_configuration", handleManageEndpointConfiguration, map[string]any{"endpoint_id": "34", "config": "Debug=true"}, []string{"endpoint", "config", "34", "Debug=true"}, "ok"}, - {"disconnect_endpoint", handleDisconnectEndpoint, map[string]any{"endpoint_id": "34"}, []string{"endpoint", "disconnect", "34"}, "ok"}, - {"get_identity_details", handleGetIdentityDetails, map[string]any{"identity_id": "123"}, []string{"identity", "get", "123"}, "ok"}, - {"flush_ipsec_state", handleFlushIPsecState, map[string]any{}, []string{"encrypt", "flush", "-f"}, "ok"}, - {"list_envoy_config", handleListEnvoyConfig, map[string]any{"resource_name": "clusters"}, []string{"envoy", "admin", "clusters"}, "ok"}, - {"show_ipcache_cidr", handleShowIPCacheInformation, map[string]any{"cidr": "10.0.0.0/24"}, []string{"ip", "get", "10.0.0.0/24"}, "ok"}, - {"show_ipcache_labels", handleShowIPCacheInformation, map[string]any{"labels": "app=foo"}, []string{"ip", "get", "--labels", "app=foo"}, "ok"}, - {"delete_kvstore_key", handleDeleteKeyFromKVStore, map[string]any{"key": "foo"}, []string{"kvstore", "delete", "foo"}, "ok"}, - {"get_kvstore_key", handleGetKVStoreKey, map[string]any{"key": "foo"}, []string{"kvstore", "get", "foo"}, "ok"}, - {"set_kvstore_key", handleSetKVStoreKey, map[string]any{"key": "foo", "value": "bar"}, []string{"kvstore", "set", "foo=bar"}, "ok"}, - {"show_load_information", handleShowLoadInformation, map[string]any{}, []string{"loadinfo"}, "ok"}, - {"display_policy_node_info", handleDisplayPolicyNodeInformation, map[string]any{}, []string{"policy", "get"}, "ok"}, - {"display_policy_node_info_labels", handleDisplayPolicyNodeInformation, map[string]any{"labels": "k=v"}, []string{"policy", "get", "k=v"}, "ok"}, - {"delete_policy_rules_all", handleDeletePolicyRules, map[string]any{"all": "true"}, []string{"policy", "delete", "--all"}, "ok"}, - {"delete_policy_rules_labels", handleDeletePolicyRules, map[string]any{"labels": "k=v"}, []string{"policy", "delete", "k=v"}, "ok"}, - {"update_xdp_cidr", handleUpdateXDPCIDRFilters, map[string]any{"cidr_prefixes": "10.0.0.0/8"}, []string{"prefilter", "update", "--cidr", "10.0.0.0/8"}, "ok"}, - {"update_xdp_cidr_rev", handleUpdateXDPCIDRFilters, map[string]any{"cidr_prefixes": "10.0.0.0/8", "revision": "2"}, []string{"prefilter", "update", "--cidr", "10.0.0.0/8", "--revision", "2"}, "ok"}, - {"delete_xdp_cidr", handleDeleteXDPCIDRFilters, map[string]any{"cidr_prefixes": "10.0.0.0/8"}, []string{"prefilter", "delete", "--cidr", "10.0.0.0/8"}, "ok"}, - {"delete_xdp_cidr_rev", handleDeleteXDPCIDRFilters, map[string]any{"cidr_prefixes": "10.0.0.0/8", "revision": "2"}, []string{"prefilter", "delete", "--cidr", "10.0.0.0/8", "--revision", "2"}, "ok"}, - {"validate_cnp", handleValidateCiliumNetworkPolicies, map[string]any{"enable_k8s": "true", "enable_k8s_api_discovery": "true"}, []string{"preflight", "validate-cnp", "--enable-k8s", "--enable-k8s-api-discovery"}, "ok"}, - {"list_pcap_recorders", handleListPCAPRecorders, map[string]any{}, []string{"recorder", "list"}, "ok"}, - {"get_pcap_recorder", handleGetPCAPRecorder, map[string]any{"recorder_id": "1"}, []string{"recorder", "get", "1"}, "ok"}, - {"delete_pcap_recorder", handleDeletePCAPRecorder, map[string]any{"recorder_id": "1"}, []string{"recorder", "delete", "1"}, "ok"}, - {"update_pcap_recorder", handleUpdatePCAPRecorder, map[string]any{"recorder_id": "1", "filters": "f"}, []string{"recorder", "update", "1", "--filters", "f", "--caplen", "0", "--id", "0"}, "ok"}, - {"get_service_information", handleGetServiceInformation, map[string]any{"service_id": "5"}, []string{"service", "get", "5"}, "ok"}, - {"delete_service_all", handleDeleteService, map[string]any{"all": "true"}, []string{"service", "delete", "--all"}, "ok"}, - {"delete_service_id", handleDeleteService, map[string]any{"service_id": "5"}, []string{"service", "delete", "5"}, "ok"}, - {"update_service", handleUpdateService, map[string]any{"backends": "b", "frontend": "f", "id": "1"}, []string{"service", "update", "1", "--backends", "b", "--frontend", "f", "--protocol", "TCP", "--states", "active"}, "ok"}, + {"manage_endpoint_labels", []string{"endpoint", "labels", "34", "--add", "key=val"}, "ok", func(ctx context.Context) (*mcp.CallToolResult, error) { + r, _, err := handleManageEndpointLabels(ctx, &mcp.CallToolRequest{}, manageEndpointLabelsInput{EndpointID: "34", Labels: "key=val", NodeName: "test-node"}) + return r, err + }}, + {"manage_endpoint_configuration", []string{"endpoint", "config", "34", "Debug=true"}, "ok", func(ctx context.Context) (*mcp.CallToolResult, error) { + r, _, err := handleManageEndpointConfiguration(ctx, &mcp.CallToolRequest{}, manageEndpointConfigurationInput{EndpointID: "34", Config: "Debug=true", NodeName: "test-node"}) + return r, err + }}, + {"disconnect_endpoint", []string{"endpoint", "disconnect", "34"}, "ok", func(ctx context.Context) (*mcp.CallToolResult, error) { + r, _, err := handleDisconnectEndpoint(ctx, &mcp.CallToolRequest{}, disconnectEndpointInput{EndpointID: "34", NodeName: "test-node"}) + return r, err + }}, + {"get_identity_details", []string{"identity", "get", "123"}, "ok", func(ctx context.Context) (*mcp.CallToolResult, error) { + r, _, err := handleGetIdentityDetails(ctx, &mcp.CallToolRequest{}, getIdentityDetailsInput{IdentityID: "123", NodeName: "test-node"}) + return r, err + }}, + {"flush_ipsec_state", []string{"encrypt", "flush", "-f"}, "ok", func(ctx context.Context) (*mcp.CallToolResult, error) { + r, _, err := handleFlushIPsecState(ctx, &mcp.CallToolRequest{}, nodeNameInput{NodeName: "test-node"}) + return r, err + }}, + {"list_envoy_config", []string{"envoy", "admin", "clusters"}, "ok", func(ctx context.Context) (*mcp.CallToolResult, error) { + r, _, err := handleListEnvoyConfig(ctx, &mcp.CallToolRequest{}, listEnvoyConfigInput{ResourceName: "clusters", NodeName: "test-node"}) + return r, err + }}, + {"show_ipcache_cidr", []string{"ip", "get", "10.0.0.0/24"}, "ok", func(ctx context.Context) (*mcp.CallToolResult, error) { + r, _, err := handleShowIPCacheInformation(ctx, &mcp.CallToolRequest{}, showIPCacheInformationInput{CIDR: "10.0.0.0/24", NodeName: "test-node"}) + return r, err + }}, + {"show_ipcache_labels", []string{"ip", "get", "--labels", "app=foo"}, "ok", func(ctx context.Context) (*mcp.CallToolResult, error) { + r, _, err := handleShowIPCacheInformation(ctx, &mcp.CallToolRequest{}, showIPCacheInformationInput{Labels: "app=foo", NodeName: "test-node"}) + return r, err + }}, + {"delete_kvstore_key", []string{"kvstore", "delete", "foo"}, "ok", func(ctx context.Context) (*mcp.CallToolResult, error) { + r, _, err := handleDeleteKeyFromKVStore(ctx, &mcp.CallToolRequest{}, kvStoreKeyInput{Key: "foo", NodeName: "test-node"}) + return r, err + }}, + {"get_kvstore_key", []string{"kvstore", "get", "foo"}, "ok", func(ctx context.Context) (*mcp.CallToolResult, error) { + r, _, err := handleGetKVStoreKey(ctx, &mcp.CallToolRequest{}, kvStoreKeyInput{Key: "foo", NodeName: "test-node"}) + return r, err + }}, + {"set_kvstore_key", []string{"kvstore", "set", "foo=bar"}, "ok", func(ctx context.Context) (*mcp.CallToolResult, error) { + r, _, err := handleSetKVStoreKey(ctx, &mcp.CallToolRequest{}, setKVStoreKeyInput{Key: "foo", Value: "bar", NodeName: "test-node"}) + return r, err + }}, + {"show_load_information", []string{"loadinfo"}, "ok", func(ctx context.Context) (*mcp.CallToolResult, error) { + r, _, err := handleShowLoadInformation(ctx, &mcp.CallToolRequest{}, nodeNameInput{NodeName: "test-node"}) + return r, err + }}, + {"display_policy_node_info", []string{"policy", "get"}, "ok", func(ctx context.Context) (*mcp.CallToolResult, error) { + r, _, err := handleDisplayPolicyNodeInformation(ctx, &mcp.CallToolRequest{}, displayPolicyNodeInformationInput{NodeName: "test-node"}) + return r, err + }}, + {"display_policy_node_info_labels", []string{"policy", "get", "k=v"}, "ok", func(ctx context.Context) (*mcp.CallToolResult, error) { + r, _, err := handleDisplayPolicyNodeInformation(ctx, &mcp.CallToolRequest{}, displayPolicyNodeInformationInput{Labels: "k=v", NodeName: "test-node"}) + return r, err + }}, + {"delete_policy_rules_all", []string{"policy", "delete", "--all"}, "ok", func(ctx context.Context) (*mcp.CallToolResult, error) { + r, _, err := handleDeletePolicyRules(ctx, &mcp.CallToolRequest{}, deletePolicyRulesInput{All: true, NodeName: "test-node"}) + return r, err + }}, + {"delete_policy_rules_labels", []string{"policy", "delete", "k=v"}, "ok", func(ctx context.Context) (*mcp.CallToolResult, error) { + r, _, err := handleDeletePolicyRules(ctx, &mcp.CallToolRequest{}, deletePolicyRulesInput{Labels: "k=v", NodeName: "test-node"}) + return r, err + }}, + {"update_xdp_cidr", []string{"prefilter", "update", "--cidr", "10.0.0.0/8"}, "ok", func(ctx context.Context) (*mcp.CallToolResult, error) { + r, _, err := handleUpdateXDPCIDRFilters(ctx, &mcp.CallToolRequest{}, xdpCIDRFiltersInput{CIDRPrefixes: "10.0.0.0/8", NodeName: "test-node"}) + return r, err + }}, + {"update_xdp_cidr_rev", []string{"prefilter", "update", "--cidr", "10.0.0.0/8", "--revision", "2"}, "ok", func(ctx context.Context) (*mcp.CallToolResult, error) { + r, _, err := handleUpdateXDPCIDRFilters(ctx, &mcp.CallToolRequest{}, xdpCIDRFiltersInput{CIDRPrefixes: "10.0.0.0/8", Revision: "2", NodeName: "test-node"}) + return r, err + }}, + {"delete_xdp_cidr", []string{"prefilter", "delete", "--cidr", "10.0.0.0/8"}, "ok", func(ctx context.Context) (*mcp.CallToolResult, error) { + r, _, err := handleDeleteXDPCIDRFilters(ctx, &mcp.CallToolRequest{}, xdpCIDRFiltersInput{CIDRPrefixes: "10.0.0.0/8", NodeName: "test-node"}) + return r, err + }}, + {"delete_xdp_cidr_rev", []string{"prefilter", "delete", "--cidr", "10.0.0.0/8", "--revision", "2"}, "ok", func(ctx context.Context) (*mcp.CallToolResult, error) { + r, _, err := handleDeleteXDPCIDRFilters(ctx, &mcp.CallToolRequest{}, xdpCIDRFiltersInput{CIDRPrefixes: "10.0.0.0/8", Revision: "2", NodeName: "test-node"}) + return r, err + }}, + {"validate_cnp", []string{"preflight", "validate-cnp", "--enable-k8s", "--enable-k8s-api-discovery"}, "ok", func(ctx context.Context) (*mcp.CallToolResult, error) { + r, _, err := handleValidateCiliumNetworkPolicies(ctx, &mcp.CallToolRequest{}, validateCiliumNetworkPoliciesInput{EnableK8s: true, EnableK8sAPIDiscovery: true, NodeName: "test-node"}) + return r, err + }}, + {"list_pcap_recorders", []string{"recorder", "list"}, "ok", func(ctx context.Context) (*mcp.CallToolResult, error) { + r, _, err := handleListPCAPRecorders(ctx, &mcp.CallToolRequest{}, nodeNameInput{NodeName: "test-node"}) + return r, err + }}, + {"get_pcap_recorder", []string{"recorder", "get", "1"}, "ok", func(ctx context.Context) (*mcp.CallToolResult, error) { + r, _, err := handleGetPCAPRecorder(ctx, &mcp.CallToolRequest{}, pcapRecorderIDInput{RecorderID: "1", NodeName: "test-node"}) + return r, err + }}, + {"delete_pcap_recorder", []string{"recorder", "delete", "1"}, "ok", func(ctx context.Context) (*mcp.CallToolResult, error) { + r, _, err := handleDeletePCAPRecorder(ctx, &mcp.CallToolRequest{}, pcapRecorderIDInput{RecorderID: "1", NodeName: "test-node"}) + return r, err + }}, + {"update_pcap_recorder", []string{"recorder", "update", "1", "--filters", "f", "--caplen", "0", "--id", "0"}, "ok", func(ctx context.Context) (*mcp.CallToolResult, error) { + r, _, err := handleUpdatePCAPRecorder(ctx, &mcp.CallToolRequest{}, updatePCAPRecorderInput{RecorderID: "1", Filters: "f", NodeName: "test-node"}) + return r, err + }}, + {"get_service_information", []string{"service", "get", "5"}, "ok", func(ctx context.Context) (*mcp.CallToolResult, error) { + r, _, err := handleGetServiceInformation(ctx, &mcp.CallToolRequest{}, getServiceInformationInput{ServiceID: "5", NodeName: "test-node"}) + return r, err + }}, + {"delete_service_all", []string{"service", "delete", "--all"}, "ok", func(ctx context.Context) (*mcp.CallToolResult, error) { + r, _, err := handleDeleteService(ctx, &mcp.CallToolRequest{}, deleteServiceInput{All: true, NodeName: "test-node"}) + return r, err + }}, + {"delete_service_id", []string{"service", "delete", "5"}, "ok", func(ctx context.Context) (*mcp.CallToolResult, error) { + r, _, err := handleDeleteService(ctx, &mcp.CallToolRequest{}, deleteServiceInput{ServiceID: "5", NodeName: "test-node"}) + return r, err + }}, + {"update_service", []string{"service", "update", "1", "--backends", "b", "--frontend", "f", "--protocol", "TCP", "--states", "active"}, "ok", func(ctx context.Context) (*mcp.CallToolResult, error) { + r, _, err := handleUpdateService(ctx, &mcp.CallToolRequest{}, updateServiceInput{Backends: "b", Frontend: "f", ID: "1", NodeName: "test-node"}) + return r, err + }}, } for _, tc := range cases { @@ -668,8 +671,7 @@ func TestCiliumDbgHandlers(t *testing.T) { mockCiliumDbgCommand(mock, tc.dbgArgs, tc.expect, nil) ctx := cmd.WithShellExecutor(context.Background(), mock) - tc.args["node_name"] = "test-node" - result, err := tc.handler(ctx, newRequestWithArgs(tc.args)) + result, err := tc.run(ctx) require.NoError(t, err) assert.False(t, result.IsError, "handler returned error result: %s", getResultText(result)) assert.Contains(t, getResultText(result), tc.expect) @@ -680,36 +682,92 @@ func TestCiliumDbgHandlers(t *testing.T) { // TestCiliumDbgHandlersMissingParams covers required-parameter validation branches. func TestCiliumDbgHandlersMissingParams(t *testing.T) { cases := []struct { - name string - handler ciliumHandler - args map[string]any + name string + run func(context.Context) (*mcp.CallToolResult, error) }{ - {"manage_endpoint_labels", handleManageEndpointLabels, map[string]any{}}, - {"manage_endpoint_configuration_no_id", handleManageEndpointConfiguration, map[string]any{}}, - {"manage_endpoint_configuration_no_config", handleManageEndpointConfiguration, map[string]any{"endpoint_id": "34"}}, - {"disconnect_endpoint", handleDisconnectEndpoint, map[string]any{}}, - {"get_identity_details", handleGetIdentityDetails, map[string]any{}}, - {"list_envoy_config", handleListEnvoyConfig, map[string]any{}}, - {"show_ipcache_none", handleShowIPCacheInformation, map[string]any{}}, - {"delete_kvstore_key", handleDeleteKeyFromKVStore, map[string]any{}}, - {"get_kvstore_key", handleGetKVStoreKey, map[string]any{}}, - {"set_kvstore_key", handleSetKVStoreKey, map[string]any{"key": "foo"}}, - {"delete_policy_rules_none", handleDeletePolicyRules, map[string]any{}}, - {"update_xdp_cidr", handleUpdateXDPCIDRFilters, map[string]any{}}, - {"delete_xdp_cidr", handleDeleteXDPCIDRFilters, map[string]any{}}, - {"get_pcap_recorder", handleGetPCAPRecorder, map[string]any{}}, - {"delete_pcap_recorder", handleDeletePCAPRecorder, map[string]any{}}, - {"update_pcap_recorder", handleUpdatePCAPRecorder, map[string]any{"recorder_id": "1"}}, - {"get_service_information", handleGetServiceInformation, map[string]any{}}, - {"delete_service_none", handleDeleteService, map[string]any{}}, - {"update_service", handleUpdateService, map[string]any{"backends": "b"}}, + {"manage_endpoint_labels", func(ctx context.Context) (*mcp.CallToolResult, error) { + r, _, err := handleManageEndpointLabels(ctx, &mcp.CallToolRequest{}, manageEndpointLabelsInput{}) + return r, err + }}, + {"manage_endpoint_configuration_no_id", func(ctx context.Context) (*mcp.CallToolResult, error) { + r, _, err := handleManageEndpointConfiguration(ctx, &mcp.CallToolRequest{}, manageEndpointConfigurationInput{}) + return r, err + }}, + {"manage_endpoint_configuration_no_config", func(ctx context.Context) (*mcp.CallToolResult, error) { + r, _, err := handleManageEndpointConfiguration(ctx, &mcp.CallToolRequest{}, manageEndpointConfigurationInput{EndpointID: "34"}) + return r, err + }}, + {"disconnect_endpoint", func(ctx context.Context) (*mcp.CallToolResult, error) { + r, _, err := handleDisconnectEndpoint(ctx, &mcp.CallToolRequest{}, disconnectEndpointInput{}) + return r, err + }}, + {"get_identity_details", func(ctx context.Context) (*mcp.CallToolResult, error) { + r, _, err := handleGetIdentityDetails(ctx, &mcp.CallToolRequest{}, getIdentityDetailsInput{}) + return r, err + }}, + {"list_envoy_config", func(ctx context.Context) (*mcp.CallToolResult, error) { + r, _, err := handleListEnvoyConfig(ctx, &mcp.CallToolRequest{}, listEnvoyConfigInput{}) + return r, err + }}, + {"show_ipcache_none", func(ctx context.Context) (*mcp.CallToolResult, error) { + r, _, err := handleShowIPCacheInformation(ctx, &mcp.CallToolRequest{}, showIPCacheInformationInput{}) + return r, err + }}, + {"delete_kvstore_key", func(ctx context.Context) (*mcp.CallToolResult, error) { + r, _, err := handleDeleteKeyFromKVStore(ctx, &mcp.CallToolRequest{}, kvStoreKeyInput{}) + return r, err + }}, + {"get_kvstore_key", func(ctx context.Context) (*mcp.CallToolResult, error) { + r, _, err := handleGetKVStoreKey(ctx, &mcp.CallToolRequest{}, kvStoreKeyInput{}) + return r, err + }}, + {"set_kvstore_key", func(ctx context.Context) (*mcp.CallToolResult, error) { + r, _, err := handleSetKVStoreKey(ctx, &mcp.CallToolRequest{}, setKVStoreKeyInput{Key: "foo"}) + return r, err + }}, + {"delete_policy_rules_none", func(ctx context.Context) (*mcp.CallToolResult, error) { + r, _, err := handleDeletePolicyRules(ctx, &mcp.CallToolRequest{}, deletePolicyRulesInput{}) + return r, err + }}, + {"update_xdp_cidr", func(ctx context.Context) (*mcp.CallToolResult, error) { + r, _, err := handleUpdateXDPCIDRFilters(ctx, &mcp.CallToolRequest{}, xdpCIDRFiltersInput{}) + return r, err + }}, + {"delete_xdp_cidr", func(ctx context.Context) (*mcp.CallToolResult, error) { + r, _, err := handleDeleteXDPCIDRFilters(ctx, &mcp.CallToolRequest{}, xdpCIDRFiltersInput{}) + return r, err + }}, + {"get_pcap_recorder", func(ctx context.Context) (*mcp.CallToolResult, error) { + r, _, err := handleGetPCAPRecorder(ctx, &mcp.CallToolRequest{}, pcapRecorderIDInput{}) + return r, err + }}, + {"delete_pcap_recorder", func(ctx context.Context) (*mcp.CallToolResult, error) { + r, _, err := handleDeletePCAPRecorder(ctx, &mcp.CallToolRequest{}, pcapRecorderIDInput{}) + return r, err + }}, + {"update_pcap_recorder", func(ctx context.Context) (*mcp.CallToolResult, error) { + r, _, err := handleUpdatePCAPRecorder(ctx, &mcp.CallToolRequest{}, updatePCAPRecorderInput{RecorderID: "1"}) + return r, err + }}, + {"get_service_information", func(ctx context.Context) (*mcp.CallToolResult, error) { + r, _, err := handleGetServiceInformation(ctx, &mcp.CallToolRequest{}, getServiceInformationInput{}) + return r, err + }}, + {"delete_service_none", func(ctx context.Context) (*mcp.CallToolResult, error) { + r, _, err := handleDeleteService(ctx, &mcp.CallToolRequest{}, deleteServiceInput{}) + return r, err + }}, + {"update_service", func(ctx context.Context) (*mcp.CallToolResult, error) { + r, _, err := handleUpdateService(ctx, &mcp.CallToolRequest{}, updateServiceInput{Backends: "b"}) + return r, err + }}, } for _, tc := range cases { t.Run(tc.name, func(t *testing.T) { mock := cmd.NewMockShellExecutor() ctx := cmd.WithShellExecutor(context.Background(), mock) - result, err := tc.handler(ctx, newRequestWithArgs(tc.args)) + result, err := tc.run(ctx) require.NoError(t, err) assert.True(t, result.IsError) assert.Empty(t, mock.GetCallLog()) @@ -721,14 +779,25 @@ func TestCiliumDbgHandlersMissingParams(t *testing.T) { func TestCiliumCliHandlers(t *testing.T) { cases := []struct { name string - handler ciliumHandler - args map[string]any cliArgs []string + run func(context.Context) (*mcp.CallToolResult, error) }{ - {"show_cluster_mesh_status", handleShowClusterMeshStatus, map[string]any{}, []string{"clustermesh", "status"}}, - {"show_features_status", handleShowFeaturesStatus, map[string]any{}, []string{"features", "status"}}, - {"toggle_cluster_mesh_enable", handleToggleClusterMesh, map[string]any{"enable": "true"}, []string{"clustermesh", "enable"}}, - {"toggle_cluster_mesh_disable", handleToggleClusterMesh, map[string]any{"enable": "false"}, []string{"clustermesh", "disable"}}, + {"show_cluster_mesh_status", []string{"clustermesh", "status"}, func(ctx context.Context) (*mcp.CallToolResult, error) { + r, _, err := handleShowClusterMeshStatus(ctx, &mcp.CallToolRequest{}, noInput{}) + return r, err + }}, + {"show_features_status", []string{"features", "status"}, func(ctx context.Context) (*mcp.CallToolResult, error) { + r, _, err := handleShowFeaturesStatus(ctx, &mcp.CallToolRequest{}, noInput{}) + return r, err + }}, + {"toggle_cluster_mesh_enable", []string{"clustermesh", "enable"}, func(ctx context.Context) (*mcp.CallToolResult, error) { + r, _, err := handleToggleClusterMesh(ctx, &mcp.CallToolRequest{}, enableToggleInput{Enable: boolPtr(true)}) + return r, err + }}, + {"toggle_cluster_mesh_disable", []string{"clustermesh", "disable"}, func(ctx context.Context) (*mcp.CallToolResult, error) { + r, _, err := handleToggleClusterMesh(ctx, &mcp.CallToolRequest{}, enableToggleInput{Enable: boolPtr(false)}) + return r, err + }}, } for _, tc := range cases { @@ -736,7 +805,7 @@ func TestCiliumCliHandlers(t *testing.T) { mock := cmd.NewMockShellExecutor() mock.AddCommandString("cilium", tc.cliArgs, "cli-ok", nil) ctx := cmd.WithShellExecutor(context.Background(), mock) - result, err := tc.handler(ctx, newRequestWithArgs(tc.args)) + result, err := tc.run(ctx) require.NoError(t, err) assert.False(t, result.IsError) assert.Contains(t, getResultText(result), "cli-ok") @@ -748,7 +817,7 @@ func TestCiliumCliHandlersError(t *testing.T) { mock := cmd.NewMockShellExecutor() mock.AddCommandString("cilium", []string{"clustermesh", "status"}, "", assert.AnError) ctx := cmd.WithShellExecutor(context.Background(), mock) - result, err := handleShowClusterMeshStatus(ctx, newRequestWithArgs(map[string]any{})) + result, _, err := handleShowClusterMeshStatus(ctx, &mcp.CallToolRequest{}, noInput{}) require.NoError(t, err) assert.True(t, result.IsError) assert.Contains(t, getResultText(result), "Error getting cluster mesh status") @@ -759,7 +828,7 @@ func TestCiliumDbgHandlerError(t *testing.T) { mock := cmd.NewMockShellExecutor() mockCiliumDbgCommand(mock, []string{"loadinfo"}, "", assert.AnError) ctx := cmd.WithShellExecutor(context.Background(), mock) - result, err := handleShowLoadInformation(ctx, newRequestWithArgs(map[string]any{"node_name": "test-node"})) + result, _, err := handleShowLoadInformation(ctx, &mcp.CallToolRequest{}, nodeNameInput{NodeName: "test-node"}) require.NoError(t, err) assert.True(t, result.IsError) } diff --git a/pkg/helm/helm.go b/pkg/helm/helm.go index c8a6b917..07ea602f 100644 --- a/pkg/helm/helm.go +++ b/pkg/helm/helm.go @@ -8,66 +8,71 @@ import ( "github.com/kagent-dev/tools/internal/commands" "github.com/kagent-dev/tools/internal/errors" + mcp "github.com/kagent-dev/tools/internal/mcp" "github.com/kagent-dev/tools/internal/security" - "github.com/kagent-dev/tools/internal/telemetry" "github.com/kagent-dev/tools/pkg/utils" - "github.com/mark3labs/mcp-go/mcp" - "github.com/mark3labs/mcp-go/server" ) -// Helm list releases -func handleHelmListReleases(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - namespace := mcp.ParseString(request, "namespace", "") - allNamespaces := mcp.ParseString(request, "all_namespaces", "") == "true" - all := mcp.ParseString(request, "all", "") == "true" - uninstalled := mcp.ParseString(request, "uninstalled", "") == "true" - uninstalling := mcp.ParseString(request, "uninstalling", "") == "true" - failed := mcp.ParseString(request, "failed", "") == "true" - deployed := mcp.ParseString(request, "deployed", "") == "true" - pending := mcp.ParseString(request, "pending", "") == "true" - filter := mcp.ParseString(request, "filter", "") - output := mcp.ParseString(request, "output", "") +// toolErrorResult formats a ToolError as an MCP error result. +func toolErrorResult(toolErr *errors.ToolError) *mcp.CallToolResult { + return toolErr.ToMCPResult() +} +type helmListReleasesInput struct { + Namespace string `json:"namespace" jsonschema:"The namespace to list releases from"` + AllNamespaces bool `json:"all_namespaces" jsonschema:"List releases from all namespaces"` + All bool `json:"all" jsonschema:"Show all releases without any filter applied"` + Uninstalled bool `json:"uninstalled" jsonschema:"List uninstalled releases"` + Uninstalling bool `json:"uninstalling" jsonschema:"List uninstalling releases"` + Failed bool `json:"failed" jsonschema:"List failed releases"` + Deployed bool `json:"deployed" jsonschema:"List deployed releases"` + Pending bool `json:"pending" jsonschema:"List pending releases"` + Filter string `json:"filter" jsonschema:"A regular expression to filter releases by"` + Output string `json:"output" jsonschema:"The output format (e.g., 'json', 'yaml', 'table')"` +} + +// Helm list releases +func handleHelmListReleases(ctx context.Context, request *mcp.CallToolRequest, in helmListReleasesInput) (*mcp.CallToolResult, any, error) { args := []string{"list"} - if namespace != "" { - args = append(args, "-n", namespace) + if in.Namespace != "" { + args = append(args, "-n", in.Namespace) } - if allNamespaces { + if in.AllNamespaces { args = append(args, "-A") } - if all { + if in.All { args = append(args, "-a") } - if uninstalled { + if in.Uninstalled { args = append(args, "--uninstalled") } - if uninstalling { + if in.Uninstalling { args = append(args, "--uninstalling") } - if failed { + if in.Failed { args = append(args, "--failed") } - if deployed { + if in.Deployed { args = append(args, "--deployed") } - if pending { + if in.Pending { args = append(args, "--pending") } - if filter != "" { - args = append(args, "-f", filter) + if in.Filter != "" { + args = append(args, "-f", in.Filter) } - if output != "" { - args = append(args, "-o", output) + if in.Output != "" { + args = append(args, "-o", in.Output) } result, err := runHelmCommand(ctx, args) @@ -75,16 +80,16 @@ func handleHelmListReleases(ctx context.Context, request mcp.CallToolRequest) (* // Check if it's a structured error if toolErr, ok := err.(*errors.ToolError); ok { // Add namespace context if provided - if namespace != "" { - toolErr = toolErr.WithContext("namespace", namespace) + if in.Namespace != "" { + toolErr = toolErr.WithContext("namespace", in.Namespace) } - return toolErr.ToMCPResult(), nil + return toolErrorResult(toolErr), nil, nil } // Fallback for non-structured errors - return mcp.NewToolResultError(fmt.Sprintf("Helm list command failed: %v", err)), nil + return mcp.NewToolResultError(fmt.Sprintf("Helm list command failed: %v", err)), nil, nil } - return mcp.NewToolResultText(result), nil + return mcp.NewToolResultText(result), nil, nil } func runHelmCommand(ctx context.Context, args []string) (string, error) { @@ -116,232 +121,224 @@ func runHelmCommand(ctx context.Context, args []string) (string, error) { return result, nil } +type helmGetReleaseInput struct { + Name string `json:"name" jsonschema:"The name of the release"` + Namespace string `json:"namespace" jsonschema:"The namespace of the release"` + Resource string `json:"resource" jsonschema:"The resource to get (all, hooks, manifest, notes, values)"` +} + // Helm get release -func handleHelmGetRelease(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - name := mcp.ParseString(request, "name", "") - namespace := mcp.ParseString(request, "namespace", "") - resource := mcp.ParseString(request, "resource", "all") +func handleHelmGetRelease(ctx context.Context, request *mcp.CallToolRequest, in helmGetReleaseInput) (*mcp.CallToolResult, any, error) { + if in.Resource == "" { + in.Resource = "all" + } - if name == "" { - return mcp.NewToolResultError("name parameter is required"), nil + if in.Name == "" { + return mcp.NewToolResultError("name parameter is required"), nil, nil } - if namespace == "" { - return mcp.NewToolResultError("namespace parameter is required"), nil + if in.Namespace == "" { + return mcp.NewToolResultError("namespace parameter is required"), nil, nil } - args := []string{"get", resource, name, "-n", namespace} + args := []string{"get", in.Resource, in.Name, "-n", in.Namespace} result, err := runHelmCommand(ctx, args) if err != nil { - return mcp.NewToolResultError(fmt.Sprintf("Helm get command failed: %v", err)), nil + return mcp.NewToolResultError(fmt.Sprintf("Helm get command failed: %v", err)), nil, nil } - return mcp.NewToolResultText(result), nil + return mcp.NewToolResultText(result), nil, nil } -// Helm upgrade release -func handleHelmUpgradeRelease(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - name := mcp.ParseString(request, "name", "") - chart := mcp.ParseString(request, "chart", "") - namespace := mcp.ParseString(request, "namespace", "") - version := mcp.ParseString(request, "version", "") - values := mcp.ParseString(request, "values", "") - setValues := mcp.ParseString(request, "set", "") - install := mcp.ParseString(request, "install", "") == "true" - dryRun := mcp.ParseString(request, "dry_run", "") == "true" - wait := mcp.ParseString(request, "wait", "") == "true" +type helmUpgradeReleaseInput struct { + Name string `json:"name" jsonschema:"The name of the release"` + Chart string `json:"chart" jsonschema:"The chart to install or upgrade to"` + Namespace string `json:"namespace" jsonschema:"The namespace of the release"` + Version string `json:"version" jsonschema:"The version of the chart to upgrade to"` + Values string `json:"values" jsonschema:"Path to a values file"` + Set string `json:"set" jsonschema:"Set values on the command line (e.g., 'key1=val1,key2=val2')"` + Install bool `json:"install" jsonschema:"Run an install if the release is not present"` + DryRun bool `json:"dry_run" jsonschema:"Simulate an upgrade"` + Wait bool `json:"wait" jsonschema:"Wait for the upgrade to complete"` +} - if name == "" || chart == "" { - return mcp.NewToolResultError("name and chart parameters are required"), nil +// Helm upgrade release +func handleHelmUpgradeRelease(ctx context.Context, request *mcp.CallToolRequest, in helmUpgradeReleaseInput) (*mcp.CallToolResult, any, error) { + if in.Name == "" || in.Chart == "" { + return mcp.NewToolResultError("name and chart parameters are required"), nil, nil } // Validate release name - if err := security.ValidateHelmReleaseName(name); err != nil { - return mcp.NewToolResultError(fmt.Sprintf("Invalid release name: %v", err)), nil + if err := security.ValidateHelmReleaseName(in.Name); err != nil { + return mcp.NewToolResultError(fmt.Sprintf("Invalid release name: %v", err)), nil, nil } // Validate namespace if provided - if namespace != "" { - if err := security.ValidateNamespace(namespace); err != nil { - return mcp.NewToolResultError(fmt.Sprintf("Invalid namespace: %v", err)), nil + if in.Namespace != "" { + if err := security.ValidateNamespace(in.Namespace); err != nil { + return mcp.NewToolResultError(fmt.Sprintf("Invalid namespace: %v", err)), nil, nil } } // Validate values file path if provided - if values != "" { - if err := security.ValidateFilePath(values); err != nil { - return mcp.NewToolResultError(fmt.Sprintf("Invalid values file path: %v", err)), nil + if in.Values != "" { + if err := security.ValidateFilePath(in.Values); err != nil { + return mcp.NewToolResultError(fmt.Sprintf("Invalid values file path: %v", err)), nil, nil } } - args := []string{"upgrade", name, chart} + args := []string{"upgrade", in.Name, in.Chart} - if namespace != "" { - args = append(args, "-n", namespace) + if in.Namespace != "" { + args = append(args, "-n", in.Namespace) } - if version != "" { - args = append(args, "--version", version) + if in.Version != "" { + args = append(args, "--version", in.Version) } - if values != "" { - args = append(args, "-f", values) + if in.Values != "" { + args = append(args, "-f", in.Values) } - if setValues != "" { + if in.Set != "" { // Split multiple set values by comma - setValuesList := strings.Split(setValues, ",") + setValuesList := strings.Split(in.Set, ",") for _, setValue := range setValuesList { args = append(args, "--set", strings.TrimSpace(setValue)) } } - if install { + if in.Install { args = append(args, "--install") } - if dryRun { + if in.DryRun { args = append(args, "--dry-run") } - if wait { + if in.Wait { args = append(args, "--wait") } result, err := runHelmCommand(ctx, args) if err != nil { - return mcp.NewToolResultError(fmt.Sprintf("Helm upgrade command failed: %v", err)), nil + return mcp.NewToolResultError(fmt.Sprintf("Helm upgrade command failed: %v", err)), nil, nil } - return mcp.NewToolResultText(result), nil + return mcp.NewToolResultText(result), nil, nil } -// Helm uninstall release -func handleHelmUninstall(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - name := mcp.ParseString(request, "name", "") - namespace := mcp.ParseString(request, "namespace", "") - dryRun := mcp.ParseString(request, "dry_run", "") == "true" - wait := mcp.ParseString(request, "wait", "") == "true" +type helmUninstallInput struct { + Name string `json:"name" jsonschema:"The name of the release to uninstall"` + Namespace string `json:"namespace" jsonschema:"The namespace of the release"` + DryRun bool `json:"dry_run" jsonschema:"Simulate an uninstall"` + Wait bool `json:"wait" jsonschema:"Wait for the uninstall to complete"` +} - if name == "" || namespace == "" { - return mcp.NewToolResultError("name and namespace parameters are required"), nil +// Helm uninstall release +func handleHelmUninstall(ctx context.Context, request *mcp.CallToolRequest, in helmUninstallInput) (*mcp.CallToolResult, any, error) { + if in.Name == "" || in.Namespace == "" { + return mcp.NewToolResultError("name and namespace parameters are required"), nil, nil } - args := []string{"uninstall", name, "-n", namespace} + args := []string{"uninstall", in.Name, "-n", in.Namespace} - if dryRun { + if in.DryRun { args = append(args, "--dry-run") } - if wait { + if in.Wait { args = append(args, "--wait") } result, err := runHelmCommand(ctx, args) if err != nil { - return mcp.NewToolResultError(fmt.Sprintf("Helm uninstall command failed: %v", err)), nil + return mcp.NewToolResultError(fmt.Sprintf("Helm uninstall command failed: %v", err)), nil, nil } - return mcp.NewToolResultText(result), nil + return mcp.NewToolResultText(result), nil, nil } -// Helm repo add -func handleHelmRepoAdd(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - name := mcp.ParseString(request, "name", "") - url := mcp.ParseString(request, "url", "") +type helmRepoAddInput struct { + Name string `json:"name" jsonschema:"The name of the repository"` + URL string `json:"url" jsonschema:"The URL of the repository"` +} - if name == "" || url == "" { - return mcp.NewToolResultError("name and url parameters are required"), nil +// Helm repo add +func handleHelmRepoAdd(ctx context.Context, request *mcp.CallToolRequest, in helmRepoAddInput) (*mcp.CallToolResult, any, error) { + if in.Name == "" || in.URL == "" { + return mcp.NewToolResultError("name and url parameters are required"), nil, nil } // Validate repository name - if err := security.ValidateHelmReleaseName(name); err != nil { - return mcp.NewToolResultError(fmt.Sprintf("Invalid repository name: %v", err)), nil + if err := security.ValidateHelmReleaseName(in.Name); err != nil { + return mcp.NewToolResultError(fmt.Sprintf("Invalid repository name: %v", err)), nil, nil } // Validate repository URL - if err := security.ValidateURL(url); err != nil { - return mcp.NewToolResultError(fmt.Sprintf("Invalid repository URL: %v", err)), nil + if err := security.ValidateURL(in.URL); err != nil { + return mcp.NewToolResultError(fmt.Sprintf("Invalid repository URL: %v", err)), nil, nil } - args := []string{"repo", "add", name, url} + args := []string{"repo", "add", in.Name, in.URL} result, err := runHelmCommand(ctx, args) if err != nil { - return mcp.NewToolResultError(fmt.Sprintf("Helm repo add command failed: %v", err)), nil + return mcp.NewToolResultError(fmt.Sprintf("Helm repo add command failed: %v", err)), nil, nil } - return mcp.NewToolResultText(result), nil + return mcp.NewToolResultText(result), nil, nil } +type helmRepoUpdateInput struct{} + // Helm repo update -func handleHelmRepoUpdate(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { +func handleHelmRepoUpdate(ctx context.Context, request *mcp.CallToolRequest, in helmRepoUpdateInput) (*mcp.CallToolResult, any, error) { args := []string{"repo", "update"} result, err := runHelmCommand(ctx, args) if err != nil { - return mcp.NewToolResultError(fmt.Sprintf("Helm repo update command failed: %v", err)), nil + return mcp.NewToolResultError(fmt.Sprintf("Helm repo update command failed: %v", err)), nil, nil } - return mcp.NewToolResultText(result), nil + return mcp.NewToolResultText(result), nil, nil } // Register Helm tools -func RegisterTools(s *server.MCPServer, readOnly bool) { +func RegisterTools(s *mcp.Server, readOnly bool) { // Read-only tools - always registered - s.AddTool(mcp.NewTool("helm_list_releases", - mcp.WithDescription("List Helm releases in a namespace"), - mcp.WithString("namespace", mcp.Description("The namespace to list releases from")), - mcp.WithString("all_namespaces", mcp.Description("List releases from all namespaces")), - mcp.WithString("all", mcp.Description("Show all releases without any filter applied")), - mcp.WithString("uninstalled", mcp.Description("List uninstalled releases")), - mcp.WithString("uninstalling", mcp.Description("List uninstalling releases")), - mcp.WithString("failed", mcp.Description("List failed releases")), - mcp.WithString("deployed", mcp.Description("List deployed releases")), - mcp.WithString("pending", mcp.Description("List pending releases")), - mcp.WithString("filter", mcp.Description("A regular expression to filter releases by")), - mcp.WithString("output", mcp.Description("The output format (e.g., 'json', 'yaml', 'table')")), - ), telemetry.AdaptToolHandler(telemetry.WithTracing("helm_list_releases", handleHelmListReleases))) - - s.AddTool(mcp.NewTool("helm_get_release", - mcp.WithDescription("Get extended information about a Helm release"), - mcp.WithString("name", mcp.Description("The name of the release"), mcp.Required()), - mcp.WithString("namespace", mcp.Description("The namespace of the release"), mcp.Required()), - mcp.WithString("resource", mcp.Description("The resource to get (all, hooks, manifest, notes, values)")), - ), telemetry.AdaptToolHandler(telemetry.WithTracing("helm_get_release", handleHelmGetRelease))) - - s.AddTool(mcp.NewTool("helm_repo_update", - mcp.WithDescription("Update information of available charts locally from chart repositories"), - ), telemetry.AdaptToolHandler(telemetry.WithTracing("helm_repo_update", handleHelmRepoUpdate))) + mcp.AddTool(s, "helm", &mcp.Tool{ + Name: "helm_list_releases", + Description: "List Helm releases in a namespace", + }, handleHelmListReleases) + + mcp.AddTool(s, "helm", &mcp.Tool{ + Name: "helm_get_release", + Description: "Get extended information about a Helm release", + }, handleHelmGetRelease) + + mcp.AddTool(s, "helm", &mcp.Tool{ + Name: "helm_repo_update", + Description: "Update information of available charts locally from chart repositories", + }, handleHelmRepoUpdate) // Write tools - only registered when not in read-only mode if !readOnly { - s.AddTool(mcp.NewTool("helm_upgrade", - mcp.WithDescription("Upgrade or install a Helm release"), - mcp.WithString("name", mcp.Description("The name of the release"), mcp.Required()), - mcp.WithString("chart", mcp.Description("The chart to install or upgrade to"), mcp.Required()), - mcp.WithString("namespace", mcp.Description("The namespace of the release")), - mcp.WithString("version", mcp.Description("The version of the chart to upgrade to")), - mcp.WithString("values", mcp.Description("Path to a values file")), - mcp.WithString("set", mcp.Description("Set values on the command line (e.g., 'key1=val1,key2=val2')")), - mcp.WithString("install", mcp.Description("Run an install if the release is not present")), - mcp.WithString("dry_run", mcp.Description("Simulate an upgrade")), - mcp.WithString("wait", mcp.Description("Wait for the upgrade to complete")), - ), telemetry.AdaptToolHandler(telemetry.WithTracing("helm_upgrade", handleHelmUpgradeRelease))) - - s.AddTool(mcp.NewTool("helm_uninstall", - mcp.WithDescription("Uninstall a Helm release"), - mcp.WithString("name", mcp.Description("The name of the release to uninstall"), mcp.Required()), - mcp.WithString("namespace", mcp.Description("The namespace of the release"), mcp.Required()), - mcp.WithString("dry_run", mcp.Description("Simulate an uninstall")), - mcp.WithString("wait", mcp.Description("Wait for the uninstall to complete")), - ), telemetry.AdaptToolHandler(telemetry.WithTracing("helm_uninstall", handleHelmUninstall))) - - s.AddTool(mcp.NewTool("helm_repo_add", - mcp.WithDescription("Add a Helm repository"), - mcp.WithString("name", mcp.Description("The name of the repository"), mcp.Required()), - mcp.WithString("url", mcp.Description("The URL of the repository"), mcp.Required()), - ), telemetry.AdaptToolHandler(telemetry.WithTracing("helm_repo_add", handleHelmRepoAdd))) + mcp.AddTool(s, "helm", &mcp.Tool{ + Name: "helm_upgrade", + Description: "Upgrade or install a Helm release", + }, handleHelmUpgradeRelease) + + mcp.AddTool(s, "helm", &mcp.Tool{ + Name: "helm_uninstall", + Description: "Uninstall a Helm release", + }, handleHelmUninstall) + + mcp.AddTool(s, "helm", &mcp.Tool{ + Name: "helm_repo_add", + Description: "Add a Helm repository", + }, handleHelmRepoAdd) } } diff --git a/pkg/helm/helm_test.go b/pkg/helm/helm_test.go index 9e5b26ca..d9665c6f 100644 --- a/pkg/helm/helm_test.go +++ b/pkg/helm/helm_test.go @@ -5,14 +5,13 @@ import ( "testing" "github.com/kagent-dev/tools/internal/cmd" - "github.com/mark3labs/mcp-go/mcp" - "github.com/mark3labs/mcp-go/server" + mcp "github.com/kagent-dev/tools/internal/mcp" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) func TestRegisterTools(t *testing.T) { - s := server.NewMCPServer("test-server", "v0.0.1") + s := mcp.NewServer(&mcp.Implementation{Name: "test-server", Version: "v0.0.1"}, nil) RegisterTools(s, false) // false = enable all tools including write operations } @@ -20,14 +19,14 @@ func TestRegisterTools(t *testing.T) { func TestHandleHelmListReleases(t *testing.T) { tests := []struct { name string - args map[string]interface{} + input helmListReleasesInput expectedArgs []string expectedOutput string expectError bool }{ { name: "basic_list_releases", - args: map[string]interface{}{}, + input: helmListReleasesInput{}, expectedArgs: []string{"list"}, expectedOutput: `NAME NAMESPACE REVISION STATUS CHART app1 default 1 deployed my-chart-1.0.0 @@ -36,8 +35,8 @@ app2 default 2 deployed my-chart-2.0.0`, }, { name: "list_releases_with_namespace", - args: map[string]interface{}{ - "namespace": "production", + input: helmListReleasesInput{ + Namespace: "production", }, expectedArgs: []string{"list", "-n", "production"}, expectedOutput: `NAME NAMESPACE REVISION STATUS CHART @@ -46,8 +45,8 @@ prod-app production 1 deployed my-chart-1.0.0`, }, { name: "list_releases_with_all_namespaces", - args: map[string]interface{}{ - "all_namespaces": "true", + input: helmListReleasesInput{ + AllNamespaces: true, }, expectedArgs: []string{"list", "-A"}, expectedOutput: `NAME NAMESPACE REVISION STATUS CHART @@ -57,11 +56,11 @@ prod-app production 1 deployed my-chart-1.0.0`, }, { name: "list_releases_with_multiple_flags", - args: map[string]interface{}{ - "all_namespaces": "true", - "all": "true", - "failed": "true", - "output": "json", + input: helmListReleasesInput{ + AllNamespaces: true, + All: true, + Failed: true, + Output: "json", }, expectedArgs: []string{"list", "-A", "-a", "--failed", "-o", "json"}, expectedOutput: `[ @@ -82,10 +81,7 @@ prod-app production 1 deployed my-chart-1.0.0`, mock.AddCommandString("helm", tt.expectedArgs, tt.expectedOutput, nil) ctx := cmd.WithShellExecutor(context.Background(), mock) - request := mcp.CallToolRequest{} - request.Params.Arguments = tt.args - - result, err := handleHelmListReleases(ctx, request) + result, _, err := handleHelmListReleases(ctx, &mcp.CallToolRequest{}, tt.input) assert.NoError(t, err) assert.False(t, result.IsError) @@ -119,8 +115,7 @@ prod-app production 1 deployed my-chart-1.0.0`, mock.AddCommandString("helm", []string{"list"}, "", assert.AnError) ctx := cmd.WithShellExecutor(context.Background(), mock) - request := mcp.CallToolRequest{} - result, err := handleHelmListReleases(ctx, request) + result, _, err := handleHelmListReleases(ctx, &mcp.CallToolRequest{}, helmListReleasesInput{}) assert.NoError(t, err) // MCP handlers should not return Go errors assert.True(t, result.IsError) @@ -141,13 +136,10 @@ replicaCount: 3` mock.AddCommandString("helm", []string{"get", "all", "myapp", "-n", "default"}, expectedOutput, nil) ctx := cmd.WithShellExecutor(context.Background(), mock) - request := mcp.CallToolRequest{} - request.Params.Arguments = map[string]interface{}{ - "name": "myapp", - "namespace": "default", - } - - result, err := handleHelmGetRelease(ctx, request) + result, _, err := handleHelmGetRelease(ctx, &mcp.CallToolRequest{}, helmGetReleaseInput{ + Name: "myapp", + Namespace: "default", + }) assert.NoError(t, err) assert.False(t, result.IsError) @@ -165,14 +157,11 @@ replicaCount: 3` mock.AddCommandString("helm", []string{"get", "values", "myapp", "-n", "default"}, "replicaCount: 3", nil) ctx := cmd.WithShellExecutor(context.Background(), mock) - request := mcp.CallToolRequest{} - request.Params.Arguments = map[string]interface{}{ - "name": "myapp", - "namespace": "default", - "resource": "values", - } - - result, err := handleHelmGetRelease(ctx, request) + result, _, err := handleHelmGetRelease(ctx, &mcp.CallToolRequest{}, helmGetReleaseInput{ + Name: "myapp", + Namespace: "default", + Resource: "values", + }) assert.NoError(t, err) assert.False(t, result.IsError) @@ -189,22 +178,17 @@ replicaCount: 3` ctx := cmd.WithShellExecutor(context.Background(), mock) // Test missing name - request := mcp.CallToolRequest{} - request.Params.Arguments = map[string]interface{}{ - "namespace": "default", - } - - result, err := handleHelmGetRelease(ctx, request) + result, _, err := handleHelmGetRelease(ctx, &mcp.CallToolRequest{}, helmGetReleaseInput{ + Namespace: "default", + }) assert.NoError(t, err) assert.True(t, result.IsError) assert.Contains(t, getResultText(result), "name parameter is required") // Test missing namespace - request.Params.Arguments = map[string]interface{}{ - "name": "myapp", - } - - result, err = handleHelmGetRelease(ctx, request) + result, _, err = handleHelmGetRelease(ctx, &mcp.CallToolRequest{}, helmGetReleaseInput{ + Name: "myapp", + }) assert.NoError(t, err) assert.True(t, result.IsError) assert.Contains(t, getResultText(result), "namespace parameter is required") @@ -229,13 +213,10 @@ REVISION: 2` mock.AddCommandString("helm", []string{"upgrade", "myapp", "stable/myapp", "--timeout", "30s"}, expectedOutput, nil) ctx := cmd.WithShellExecutor(context.Background(), mock) - request := mcp.CallToolRequest{} - request.Params.Arguments = map[string]interface{}{ - "name": "myapp", - "chart": "stable/myapp", - } - - result, err := handleHelmUpgradeRelease(ctx, request) + result, _, err := handleHelmUpgradeRelease(ctx, &mcp.CallToolRequest{}, helmUpgradeReleaseInput{ + Name: "myapp", + Chart: "stable/myapp", + }) assert.NoError(t, err) assert.False(t, result.IsError) @@ -265,20 +246,17 @@ REVISION: 2` mock.AddCommandString("helm", expectedArgs, "Upgraded with options", nil) ctx := cmd.WithShellExecutor(context.Background(), mock) - request := mcp.CallToolRequest{} - request.Params.Arguments = map[string]interface{}{ - "name": "myapp", - "chart": "stable/myapp", - "namespace": "production", - "version": "1.2.0", - "values": "values.yaml", - "set": "replicas=5,image.tag=v1.2.0", - "install": "true", - "dry_run": "true", - "wait": "true", - } - - result, err := handleHelmUpgradeRelease(ctx, request) + result, _, err := handleHelmUpgradeRelease(ctx, &mcp.CallToolRequest{}, helmUpgradeReleaseInput{ + Name: "myapp", + Chart: "stable/myapp", + Namespace: "production", + Version: "1.2.0", + Values: "values.yaml", + Set: "replicas=5,image.tag=v1.2.0", + Install: true, + DryRun: true, + Wait: true, + }) assert.NoError(t, err) assert.False(t, result.IsError) @@ -295,12 +273,9 @@ REVISION: 2` ctx := cmd.WithShellExecutor(context.Background(), mock) // Test missing chart - request := mcp.CallToolRequest{} - request.Params.Arguments = map[string]interface{}{ - "name": "myapp", - } - - result, err := handleHelmUpgradeRelease(ctx, request) + result, _, err := handleHelmUpgradeRelease(ctx, &mcp.CallToolRequest{}, helmUpgradeReleaseInput{ + Name: "myapp", + }) assert.NoError(t, err) assert.True(t, result.IsError) assert.Contains(t, getResultText(result), "name and chart parameters are required") @@ -320,13 +295,10 @@ func TestHandleHelmUninstall(t *testing.T) { mock.AddCommandString("helm", []string{"uninstall", "myapp", "-n", "default"}, expectedOutput, nil) ctx := cmd.WithShellExecutor(context.Background(), mock) - request := mcp.CallToolRequest{} - request.Params.Arguments = map[string]interface{}{ - "name": "myapp", - "namespace": "default", - } - - result, err := handleHelmUninstall(ctx, request) + result, _, err := handleHelmUninstall(ctx, &mcp.CallToolRequest{}, helmUninstallInput{ + Name: "myapp", + Namespace: "default", + }) assert.NoError(t, err) assert.NotNil(t, result) @@ -347,15 +319,12 @@ func TestHandleHelmUninstall(t *testing.T) { mock.AddCommandString("helm", []string{"uninstall", "myapp", "-n", "production", "--dry-run", "--wait"}, expectedOutput, nil) ctx := cmd.WithShellExecutor(context.Background(), mock) - request := mcp.CallToolRequest{} - request.Params.Arguments = map[string]interface{}{ - "name": "myapp", - "namespace": "production", - "dry_run": "true", - "wait": "true", - } - - result, err := handleHelmUninstall(ctx, request) + result, _, err := handleHelmUninstall(ctx, &mcp.CallToolRequest{}, helmUninstallInput{ + Name: "myapp", + Namespace: "production", + DryRun: true, + Wait: true, + }) assert.NoError(t, err) assert.False(t, result.IsError) @@ -372,22 +341,17 @@ func TestHandleHelmUninstall(t *testing.T) { ctx := cmd.WithShellExecutor(context.Background(), mock) // Test missing name - request := mcp.CallToolRequest{} - request.Params.Arguments = map[string]interface{}{ - "namespace": "default", - } - - result, err := handleHelmUninstall(ctx, request) + result, _, err := handleHelmUninstall(ctx, &mcp.CallToolRequest{}, helmUninstallInput{ + Namespace: "default", + }) assert.NoError(t, err) assert.True(t, result.IsError) assert.Contains(t, getResultText(result), "name and namespace parameters are required") // Test missing namespace - request.Params.Arguments = map[string]interface{}{ - "name": "myapp", - } - - result, err = handleHelmUninstall(ctx, request) + result, _, err = handleHelmUninstall(ctx, &mcp.CallToolRequest{}, helmUninstallInput{ + Name: "myapp", + }) assert.NoError(t, err) assert.True(t, result.IsError) assert.Contains(t, getResultText(result), "name and namespace parameters are required") @@ -407,13 +371,10 @@ func TestHandleHelmRepoAdd(t *testing.T) { mock.AddCommandString("helm", []string{"repo", "add", "my-repo", "https://charts.example.com/"}, expectedOutput, nil) ctx := cmd.WithShellExecutor(context.Background(), mock) - request := mcp.CallToolRequest{} - request.Params.Arguments = map[string]interface{}{ - "name": "my-repo", - "url": "https://charts.example.com/", - } - - result, err := handleHelmRepoAdd(ctx, request) + result, _, err := handleHelmRepoAdd(ctx, &mcp.CallToolRequest{}, helmRepoAddInput{ + Name: "my-repo", + URL: "https://charts.example.com/", + }) assert.NoError(t, err) assert.False(t, result.IsError) @@ -431,12 +392,9 @@ func TestHandleHelmRepoAdd(t *testing.T) { ctx := cmd.WithShellExecutor(context.Background(), mock) // Test missing name - request := mcp.CallToolRequest{} - request.Params.Arguments = map[string]interface{}{ - "url": "https://charts.example.com/", - } - - result, err := handleHelmRepoAdd(ctx, request) + result, _, err := handleHelmRepoAdd(ctx, &mcp.CallToolRequest{}, helmRepoAddInput{ + URL: "https://charts.example.com/", + }) assert.NoError(t, err) assert.True(t, result.IsError) assert.Contains(t, getResultText(result), "name and url parameters are required") @@ -458,8 +416,7 @@ Update Complete. ⎈Happy Helming!⎈` mock.AddCommandString("helm", []string{"repo", "update"}, expectedOutput, nil) ctx := cmd.WithShellExecutor(context.Background(), mock) - request := mcp.CallToolRequest{} - result, err := handleHelmRepoUpdate(ctx, request) + result, _, err := handleHelmRepoUpdate(ctx, &mcp.CallToolRequest{}, helmRepoUpdateInput{}) assert.NoError(t, err) assert.False(t, result.IsError) @@ -478,7 +435,7 @@ func getResultText(result *mcp.CallToolResult) string { if result == nil || len(result.Content) == 0 { return "" } - if textContent, ok := result.Content[0].(mcp.TextContent); ok { + if textContent, ok := result.Content[0].(*mcp.TextContent); ok { return textContent.Text } return "" diff --git a/pkg/istio/istio.go b/pkg/istio/istio.go index dd1958c9..6a0da8e0 100644 --- a/pkg/istio/istio.go +++ b/pkg/istio/istio.go @@ -6,33 +6,33 @@ import ( "strings" "github.com/kagent-dev/tools/internal/commands" - "github.com/kagent-dev/tools/internal/telemetry" + mcp "github.com/kagent-dev/tools/internal/mcp" "github.com/kagent-dev/tools/pkg/utils" - "github.com/mark3labs/mcp-go/mcp" - "github.com/mark3labs/mcp-go/server" ) -// Istio proxy status -func handleIstioProxyStatus(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - podName := mcp.ParseString(request, "pod_name", "") - namespace := mcp.ParseString(request, "namespace", "") +type istioProxyStatusInput struct { + PodName string `json:"pod_name" jsonschema:"Name of the pod to get proxy status for"` + Namespace string `json:"namespace" jsonschema:"Namespace of the pod"` +} +// Istio proxy status +func handleIstioProxyStatus(ctx context.Context, request *mcp.CallToolRequest, in istioProxyStatusInput) (*mcp.CallToolResult, any, error) { args := []string{"proxy-status"} - if namespace != "" { - args = append(args, "-n", namespace) + if in.Namespace != "" { + args = append(args, "-n", in.Namespace) } - if podName != "" { - args = append(args, podName) + if in.PodName != "" { + args = append(args, in.PodName) } result, err := runIstioCtl(ctx, args) if err != nil { - return mcp.NewToolResultError(fmt.Sprintf("istioctl proxy-status failed: %v", err)), nil + return mcp.NewToolResultError(fmt.Sprintf("istioctl proxy-status failed: %v", err)), nil, nil } - return mcp.NewToolResultText(result), nil + return mcp.NewToolResultText(result), nil, nil } func runIstioCtl(ctx context.Context, args []string) (string, error) { @@ -43,336 +43,376 @@ func runIstioCtl(ctx context.Context, args []string) (string, error) { Execute(ctx) } +type istioProxyConfigInput struct { + PodName string `json:"pod_name" jsonschema:"Name of the pod to get proxy configuration for"` + Namespace string `json:"namespace" jsonschema:"Namespace of the pod"` + ConfigType string `json:"config_type" jsonschema:"Type of configuration (all, bootstrap, cluster, ecds, listener, log, route, secret)"` +} + // Istio proxy config -func handleIstioProxyConfig(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - podName := mcp.ParseString(request, "pod_name", "") - namespace := mcp.ParseString(request, "namespace", "") - configType := mcp.ParseString(request, "config_type", "all") +func handleIstioProxyConfig(ctx context.Context, request *mcp.CallToolRequest, in istioProxyConfigInput) (*mcp.CallToolResult, any, error) { + if in.ConfigType == "" { + in.ConfigType = "all" + } - if podName == "" { - return mcp.NewToolResultError("pod_name parameter is required"), nil + if in.PodName == "" { + return mcp.NewToolResultError("pod_name parameter is required"), nil, nil } - args := []string{"proxy-config", configType} + args := []string{"proxy-config", in.ConfigType} - if namespace != "" { - args = append(args, fmt.Sprintf("%s.%s", podName, namespace)) + if in.Namespace != "" { + args = append(args, fmt.Sprintf("%s.%s", in.PodName, in.Namespace)) } else { - args = append(args, podName) + args = append(args, in.PodName) } result, err := runIstioCtl(ctx, args) if err != nil { - return mcp.NewToolResultError(fmt.Sprintf("istioctl proxy-config failed: %v", err)), nil + return mcp.NewToolResultError(fmt.Sprintf("istioctl proxy-config failed: %v", err)), nil, nil } - return mcp.NewToolResultText(result), nil + return mcp.NewToolResultText(result), nil, nil +} + +type istioInstallInput struct { + Profile string `json:"profile" jsonschema:"Istio configuration profile (ambient, default, demo, minimal, empty)"` } // Istio install -func handleIstioInstall(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - profile := mcp.ParseString(request, "profile", "default") +func handleIstioInstall(ctx context.Context, request *mcp.CallToolRequest, in istioInstallInput) (*mcp.CallToolResult, any, error) { + if in.Profile == "" { + in.Profile = "default" + } - args := []string{"install", "--set", fmt.Sprintf("profile=%s", profile), "-y"} + args := []string{"install", "--set", fmt.Sprintf("profile=%s", in.Profile), "-y"} result, err := runIstioCtl(ctx, args) if err != nil { - return mcp.NewToolResultError(fmt.Sprintf("istioctl install failed: %v", err)), nil + return mcp.NewToolResultError(fmt.Sprintf("istioctl install failed: %v", err)), nil, nil } - return mcp.NewToolResultText(result), nil + return mcp.NewToolResultText(result), nil, nil +} + +type istioGenerateManifestInput struct { + Profile string `json:"profile" jsonschema:"Istio configuration profile (ambient, default, demo, minimal, empty)"` } // Istio generate manifest -func handleIstioGenerateManifest(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - profile := mcp.ParseString(request, "profile", "default") +func handleIstioGenerateManifest(ctx context.Context, request *mcp.CallToolRequest, in istioGenerateManifestInput) (*mcp.CallToolResult, any, error) { + if in.Profile == "" { + in.Profile = "default" + } - args := []string{"manifest", "generate", "--set", fmt.Sprintf("profile=%s", profile)} + args := []string{"manifest", "generate", "--set", fmt.Sprintf("profile=%s", in.Profile)} result, err := runIstioCtl(ctx, args) if err != nil { - return mcp.NewToolResultError(fmt.Sprintf("istioctl manifest generate failed: %v", err)), nil + return mcp.NewToolResultError(fmt.Sprintf("istioctl manifest generate failed: %v", err)), nil, nil } - return mcp.NewToolResultText(result), nil + return mcp.NewToolResultText(result), nil, nil } -// Istio analyze -func handleIstioAnalyzeClusterConfiguration(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - namespace := mcp.ParseString(request, "namespace", "") - allNamespaces := mcp.ParseString(request, "all_namespaces", "") == "true" +type istioAnalyzeClusterConfigurationInput struct { + Namespace string `json:"namespace" jsonschema:"Namespace to analyze"` + AllNamespaces bool `json:"all_namespaces" jsonschema:"Analyze all namespaces"` +} +// Istio analyze +func handleIstioAnalyzeClusterConfiguration(ctx context.Context, request *mcp.CallToolRequest, in istioAnalyzeClusterConfigurationInput) (*mcp.CallToolResult, any, error) { args := []string{"analyze"} - if allNamespaces { + if in.AllNamespaces { args = append(args, "-A") - } else if namespace != "" { - args = append(args, "-n", namespace) + } else if in.Namespace != "" { + args = append(args, "-n", in.Namespace) } result, err := runIstioCtl(ctx, args) if err != nil { - return mcp.NewToolResultError(fmt.Sprintf("istioctl analyze failed: %v", err)), nil + return mcp.NewToolResultError(fmt.Sprintf("istioctl analyze failed: %v", err)), nil, nil } - return mcp.NewToolResultText(result), nil + return mcp.NewToolResultText(result), nil, nil } -// Istio version -func handleIstioVersion(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - short := mcp.ParseString(request, "short", "") == "true" +type istioVersionInput struct { + Short bool `json:"short" jsonschema:"Return short version output"` +} +// Istio version +func handleIstioVersion(ctx context.Context, request *mcp.CallToolRequest, in istioVersionInput) (*mcp.CallToolResult, any, error) { args := []string{"version"} - if short { + if in.Short { args = append(args, "--short") } result, err := runIstioCtl(ctx, args) if err != nil { - return mcp.NewToolResultError(fmt.Sprintf("istioctl version failed: %v", err)), nil + return mcp.NewToolResultError(fmt.Sprintf("istioctl version failed: %v", err)), nil, nil } - return mcp.NewToolResultText(result), nil + return mcp.NewToolResultText(result), nil, nil } +type istioRemoteClustersInput struct{} + // Istio remote clusters -func handleIstioRemoteClusters(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { +func handleIstioRemoteClusters(ctx context.Context, request *mcp.CallToolRequest, in istioRemoteClustersInput) (*mcp.CallToolResult, any, error) { args := []string{"remote-clusters"} result, err := runIstioCtl(ctx, args) if err != nil { - return mcp.NewToolResultError(fmt.Sprintf("istioctl remote-clusters failed: %v", err)), nil + return mcp.NewToolResultError(fmt.Sprintf("istioctl remote-clusters failed: %v", err)), nil, nil } - return mcp.NewToolResultText(result), nil + return mcp.NewToolResultText(result), nil, nil } -// Waypoint list -func handleWaypointList(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - namespace := mcp.ParseString(request, "namespace", "") - allNamespaces := mcp.ParseString(request, "all_namespaces", "") == "true" +type waypointListInput struct { + Namespace string `json:"namespace" jsonschema:"Namespace to list waypoints in"` + AllNamespaces bool `json:"all_namespaces" jsonschema:"List waypoints in all namespaces"` +} +// Waypoint list +func handleWaypointList(ctx context.Context, request *mcp.CallToolRequest, in waypointListInput) (*mcp.CallToolResult, any, error) { args := []string{"waypoint", "list"} - if allNamespaces { + if in.AllNamespaces { args = append(args, "-A") - } else if namespace != "" { - args = append(args, "-n", namespace) + } else if in.Namespace != "" { + args = append(args, "-n", in.Namespace) } result, err := runIstioCtl(ctx, args) if err != nil { - return mcp.NewToolResultError(fmt.Sprintf("istioctl waypoint list failed: %v", err)), nil + return mcp.NewToolResultError(fmt.Sprintf("istioctl waypoint list failed: %v", err)), nil, nil } - return mcp.NewToolResultText(result), nil + return mcp.NewToolResultText(result), nil, nil +} + +type waypointGenerateInput struct { + Namespace string `json:"namespace" jsonschema:"Namespace for the waypoint resource"` + Name string `json:"name" jsonschema:"Name of the waypoint resource"` + TrafficType string `json:"traffic_type" jsonschema:"Traffic type for the waypoint (all, service, workload)"` } // Waypoint generate -func handleWaypointGenerate(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - namespace := mcp.ParseString(request, "namespace", "") - name := mcp.ParseString(request, "name", "waypoint") - trafficType := mcp.ParseString(request, "traffic_type", "all") +func handleWaypointGenerate(ctx context.Context, request *mcp.CallToolRequest, in waypointGenerateInput) (*mcp.CallToolResult, any, error) { + if in.Name == "" { + in.Name = "waypoint" + } + if in.TrafficType == "" { + in.TrafficType = "all" + } - if namespace == "" { - return mcp.NewToolResultError("namespace parameter is required"), nil + if in.Namespace == "" { + return mcp.NewToolResultError("namespace parameter is required"), nil, nil } args := []string{"waypoint", "generate"} - if name != "" { - args = append(args, name) + if in.Name != "" { + args = append(args, in.Name) } - args = append(args, "-n", namespace) + args = append(args, "-n", in.Namespace) - if trafficType != "" { - args = append(args, "--for", trafficType) + if in.TrafficType != "" { + args = append(args, "--for", in.TrafficType) } result, err := runIstioCtl(ctx, args) if err != nil { - return mcp.NewToolResultError(fmt.Sprintf("istioctl waypoint generate failed: %v", err)), nil + return mcp.NewToolResultError(fmt.Sprintf("istioctl waypoint generate failed: %v", err)), nil, nil } - return mcp.NewToolResultText(result), nil + return mcp.NewToolResultText(result), nil, nil } -// Waypoint apply -func handleWaypointApply(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - namespace := mcp.ParseString(request, "namespace", "") - enrollNamespace := mcp.ParseString(request, "enroll_namespace", "") == "true" +type waypointApplyInput struct { + Namespace string `json:"namespace" jsonschema:"Namespace to apply the waypoint in"` + EnrollNamespace bool `json:"enroll_namespace" jsonschema:"Enroll the namespace in the ambient mesh"` +} - if namespace == "" { - return mcp.NewToolResultError("namespace parameter is required"), nil +// Waypoint apply +func handleWaypointApply(ctx context.Context, request *mcp.CallToolRequest, in waypointApplyInput) (*mcp.CallToolResult, any, error) { + if in.Namespace == "" { + return mcp.NewToolResultError("namespace parameter is required"), nil, nil } - args := []string{"waypoint", "apply", "-n", namespace} + args := []string{"waypoint", "apply", "-n", in.Namespace} - if enrollNamespace { + if in.EnrollNamespace { args = append(args, "--enroll-namespace") } result, err := runIstioCtl(ctx, args) if err != nil { - return mcp.NewToolResultError(fmt.Sprintf("istioctl waypoint apply failed: %v", err)), nil + return mcp.NewToolResultError(fmt.Sprintf("istioctl waypoint apply failed: %v", err)), nil, nil } - return mcp.NewToolResultText(result), nil + return mcp.NewToolResultText(result), nil, nil } -// Waypoint delete -func handleWaypointDelete(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - namespace := mcp.ParseString(request, "namespace", "") - names := mcp.ParseString(request, "names", "") - all := mcp.ParseString(request, "all", "") == "true" +type waypointDeleteInput struct { + Namespace string `json:"namespace" jsonschema:"Namespace containing the waypoints to delete"` + Names string `json:"names" jsonschema:"Comma-separated list of waypoint names to delete"` + All bool `json:"all" jsonschema:"Delete all waypoints in the namespace"` +} - if namespace == "" { - return mcp.NewToolResultError("namespace parameter is required"), nil +// Waypoint delete +func handleWaypointDelete(ctx context.Context, request *mcp.CallToolRequest, in waypointDeleteInput) (*mcp.CallToolResult, any, error) { + if in.Namespace == "" { + return mcp.NewToolResultError("namespace parameter is required"), nil, nil } args := []string{"waypoint", "delete"} - if all { + if in.All { args = append(args, "--all") - } else if names != "" { - namesList := strings.Split(names, ",") + } else if in.Names != "" { + namesList := strings.Split(in.Names, ",") for _, name := range namesList { args = append(args, strings.TrimSpace(name)) } } - args = append(args, "-n", namespace) + args = append(args, "-n", in.Namespace) result, err := runIstioCtl(ctx, args) if err != nil { - return mcp.NewToolResultError(fmt.Sprintf("istioctl waypoint delete failed: %v", err)), nil + return mcp.NewToolResultError(fmt.Sprintf("istioctl waypoint delete failed: %v", err)), nil, nil } - return mcp.NewToolResultText(result), nil + return mcp.NewToolResultText(result), nil, nil } -// Waypoint status -func handleWaypointStatus(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - namespace := mcp.ParseString(request, "namespace", "") - name := mcp.ParseString(request, "name", "") +type waypointStatusInput struct { + Namespace string `json:"namespace" jsonschema:"Namespace of the waypoint"` + Name string `json:"name" jsonschema:"Name of the waypoint resource"` +} - if namespace == "" { - return mcp.NewToolResultError("namespace parameter is required"), nil +// Waypoint status +func handleWaypointStatus(ctx context.Context, request *mcp.CallToolRequest, in waypointStatusInput) (*mcp.CallToolResult, any, error) { + if in.Namespace == "" { + return mcp.NewToolResultError("namespace parameter is required"), nil, nil } args := []string{"waypoint", "status"} - if name != "" { - args = append(args, name) + if in.Name != "" { + args = append(args, in.Name) } - args = append(args, "-n", namespace) + args = append(args, "-n", in.Namespace) result, err := runIstioCtl(ctx, args) if err != nil { - return mcp.NewToolResultError(fmt.Sprintf("istioctl waypoint status failed: %v", err)), nil + return mcp.NewToolResultError(fmt.Sprintf("istioctl waypoint status failed: %v", err)), nil, nil } - return mcp.NewToolResultText(result), nil + return mcp.NewToolResultText(result), nil, nil +} + +type ztunnelConfigInput struct { + Namespace string `json:"namespace" jsonschema:"Namespace to get ztunnel configuration for"` + ConfigType string `json:"config_type" jsonschema:"Type of ztunnel configuration (all, workloads, services)"` } // Ztunnel config -func handleZtunnelConfig(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - namespace := mcp.ParseString(request, "namespace", "") - configType := mcp.ParseString(request, "config_type", "all") +func handleZtunnelConfig(ctx context.Context, request *mcp.CallToolRequest, in ztunnelConfigInput) (*mcp.CallToolResult, any, error) { + if in.ConfigType == "" { + in.ConfigType = "all" + } - args := []string{"ztunnel", "config", configType} + args := []string{"ztunnel", "config", in.ConfigType} - if namespace != "" { - args = append(args, "-n", namespace) + if in.Namespace != "" { + args = append(args, "-n", in.Namespace) } result, err := runIstioCtl(ctx, args) if err != nil { - return mcp.NewToolResultError(fmt.Sprintf("istioctl ztunnel config failed: %v", err)), nil + return mcp.NewToolResultError(fmt.Sprintf("istioctl ztunnel config failed: %v", err)), nil, nil } - return mcp.NewToolResultText(result), nil + return mcp.NewToolResultText(result), nil, nil } // Register Istio tools -func RegisterTools(s *server.MCPServer, readOnly bool) { +func RegisterTools(s *mcp.Server, readOnly bool) { // Read-only tools - always registered - // Istio proxy status - s.AddTool(mcp.NewTool("istio_proxy_status", - mcp.WithDescription("Get Envoy proxy status for pods, retrieves last sent and acknowledged xDS sync from Istiod to each Envoy in the mesh"), - mcp.WithString("pod_name", mcp.Description("Name of the pod to get proxy status for")), - mcp.WithString("namespace", mcp.Description("Namespace of the pod")), - ), telemetry.AdaptToolHandler(telemetry.WithTracing("istio_proxy_status", handleIstioProxyStatus))) - - // Istio proxy config - s.AddTool(mcp.NewTool("istio_proxy_config", - mcp.WithDescription("Get specific proxy configuration for a single pod"), - mcp.WithString("pod_name", mcp.Description("Name of the pod to get proxy configuration for"), mcp.Required()), - mcp.WithString("namespace", mcp.Description("Namespace of the pod")), - mcp.WithString("config_type", mcp.Description("Type of configuration (all, bootstrap, cluster, ecds, listener, log, route, secret)")), - ), telemetry.AdaptToolHandler(telemetry.WithTracing("istio_proxy_config", handleIstioProxyConfig))) - - // Istio generate manifest (read-only - just generates YAML, doesn't apply) - s.AddTool(mcp.NewTool("istio_generate_manifest", - mcp.WithDescription("Generate Istio manifest for a given profile"), - mcp.WithString("profile", mcp.Description("Istio configuration profile (ambient, default, demo, minimal, empty)")), - ), telemetry.AdaptToolHandler(telemetry.WithTracing("istio_generate_manifest", handleIstioGenerateManifest))) - - // Istio analyze - s.AddTool(mcp.NewTool("istio_analyze_cluster_configuration", - mcp.WithDescription("Analyze Istio cluster configuration for issues"), - ), telemetry.AdaptToolHandler(telemetry.WithTracing("istio_analyze_cluster_configuration", handleIstioAnalyzeClusterConfiguration))) - - // Istio version - s.AddTool(mcp.NewTool("istio_version", - mcp.WithDescription("Get Istio version information"), - ), telemetry.AdaptToolHandler(telemetry.WithTracing("istio_version", handleIstioVersion))) - - // Istio remote clusters - s.AddTool(mcp.NewTool("istio_remote_clusters", - mcp.WithDescription("List remote clusters registered with Istio"), - ), telemetry.AdaptToolHandler(telemetry.WithTracing("istio_remote_clusters", handleIstioRemoteClusters))) - - // Waypoint list - s.AddTool(mcp.NewTool("istio_list_waypoints", - mcp.WithDescription("List all waypoints in the mesh"), - ), telemetry.AdaptToolHandler(telemetry.WithTracing("istio_list_waypoints", handleWaypointList))) - - // Waypoint generate (read-only - just generates YAML, doesn't apply) - s.AddTool(mcp.NewTool("istio_generate_waypoint", - mcp.WithDescription("Generate a waypoint resource YAML"), - ), telemetry.AdaptToolHandler(telemetry.WithTracing("istio_generate_waypoint", handleWaypointGenerate))) - - // Waypoint status - s.AddTool(mcp.NewTool("istio_waypoint_status", - mcp.WithDescription("Get the status of a waypoint resource"), - ), telemetry.AdaptToolHandler(telemetry.WithTracing("istio_waypoint_status", handleWaypointStatus))) - - // Ztunnel config - s.AddTool(mcp.NewTool("istio_ztunnel_config", - mcp.WithDescription("Get the ztunnel configuration for a namespace"), - ), telemetry.AdaptToolHandler(telemetry.WithTracing("istio_ztunnel_config", handleZtunnelConfig))) + mcp.AddTool(s, "istio", &mcp.Tool{ + Name: "istio_proxy_status", + Description: "Get Envoy proxy status for pods, retrieves last sent and acknowledged xDS sync from Istiod to each Envoy in the mesh", + }, handleIstioProxyStatus) + + mcp.AddTool(s, "istio", &mcp.Tool{ + Name: "istio_proxy_config", + Description: "Get specific proxy configuration for a single pod", + }, handleIstioProxyConfig) + + mcp.AddTool(s, "istio", &mcp.Tool{ + Name: "istio_generate_manifest", + Description: "Generate Istio manifest for a given profile", + }, handleIstioGenerateManifest) + + mcp.AddTool(s, "istio", &mcp.Tool{ + Name: "istio_analyze_cluster_configuration", + Description: "Analyze Istio cluster configuration for issues", + }, handleIstioAnalyzeClusterConfiguration) + + mcp.AddTool(s, "istio", &mcp.Tool{ + Name: "istio_version", + Description: "Get Istio version information", + }, handleIstioVersion) + + mcp.AddTool(s, "istio", &mcp.Tool{ + Name: "istio_remote_clusters", + Description: "List remote clusters registered with Istio", + }, handleIstioRemoteClusters) + + mcp.AddTool(s, "istio", &mcp.Tool{ + Name: "istio_list_waypoints", + Description: "List all waypoints in the mesh", + }, handleWaypointList) + + mcp.AddTool(s, "istio", &mcp.Tool{ + Name: "istio_generate_waypoint", + Description: "Generate a waypoint resource YAML", + }, handleWaypointGenerate) + + mcp.AddTool(s, "istio", &mcp.Tool{ + Name: "istio_waypoint_status", + Description: "Get the status of a waypoint resource", + }, handleWaypointStatus) + + mcp.AddTool(s, "istio", &mcp.Tool{ + Name: "istio_ztunnel_config", + Description: "Get the ztunnel configuration for a namespace", + }, handleZtunnelConfig) // Write tools - only registered when write operations are enabled if !readOnly { - // Istio install - s.AddTool(mcp.NewTool("istio_install_istio", - mcp.WithDescription("Install Istio with a specified configuration profile"), - mcp.WithString("profile", mcp.Description("Istio configuration profile (ambient, default, demo, minimal, empty)")), - ), telemetry.AdaptToolHandler(telemetry.WithTracing("istio_install_istio", handleIstioInstall))) - - // Waypoint apply - s.AddTool(mcp.NewTool("istio_apply_waypoint", - mcp.WithDescription("Apply a waypoint resource to the cluster"), - ), telemetry.AdaptToolHandler(telemetry.WithTracing("istio_apply_waypoint", handleWaypointApply))) - - // Waypoint delete - s.AddTool(mcp.NewTool("istio_delete_waypoint", - mcp.WithDescription("Delete a waypoint resource from the cluster"), - ), telemetry.AdaptToolHandler(telemetry.WithTracing("istio_delete_waypoint", handleWaypointDelete))) + mcp.AddTool(s, "istio", &mcp.Tool{ + Name: "istio_install_istio", + Description: "Install Istio with a specified configuration profile", + }, handleIstioInstall) + + mcp.AddTool(s, "istio", &mcp.Tool{ + Name: "istio_apply_waypoint", + Description: "Apply a waypoint resource to the cluster", + }, handleWaypointApply) + + mcp.AddTool(s, "istio", &mcp.Tool{ + Name: "istio_delete_waypoint", + Description: "Delete a waypoint resource from the cluster", + }, handleWaypointDelete) } } diff --git a/pkg/istio/istio_test.go b/pkg/istio/istio_test.go index 4eacea90..e57f4c1d 100644 --- a/pkg/istio/istio_test.go +++ b/pkg/istio/istio_test.go @@ -5,14 +5,13 @@ import ( "testing" "github.com/kagent-dev/tools/internal/cmd" - "github.com/mark3labs/mcp-go/mcp" - "github.com/mark3labs/mcp-go/server" + mcp "github.com/kagent-dev/tools/internal/mcp" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) func TestRegisterTools(t *testing.T) { - s := server.NewMCPServer("test-server", "v0.0.1") + s := mcp.NewServer(&mcp.Implementation{Name: "test-server", Version: "v0.0.1"}, nil) RegisterTools(s, false) // false = enable all tools including write operations } @@ -25,7 +24,7 @@ func TestHandleIstioProxyStatus(t *testing.T) { ctx = cmd.WithShellExecutor(ctx, mock) - result, err := handleIstioProxyStatus(ctx, mcp.CallToolRequest{}) + result, _, err := handleIstioProxyStatus(ctx, &mcp.CallToolRequest{}, istioProxyStatusInput{}) require.NoError(t, err) assert.NotNil(t, result) @@ -38,12 +37,9 @@ func TestHandleIstioProxyStatus(t *testing.T) { ctx = cmd.WithShellExecutor(ctx, mock) - request := mcp.CallToolRequest{} - request.Params.Arguments = map[string]interface{}{ - "namespace": "istio-system", - } - - result, err := handleIstioProxyStatus(ctx, request) + result, _, err := handleIstioProxyStatus(ctx, &mcp.CallToolRequest{}, istioProxyStatusInput{ + Namespace: "istio-system", + }) require.NoError(t, err) assert.NotNil(t, result) @@ -56,13 +52,10 @@ func TestHandleIstioProxyStatus(t *testing.T) { ctx = cmd.WithShellExecutor(ctx, mock) - request := mcp.CallToolRequest{} - request.Params.Arguments = map[string]interface{}{ - "pod_name": "test-pod", - "namespace": "default", - } - - result, err := handleIstioProxyStatus(ctx, request) + result, _, err := handleIstioProxyStatus(ctx, &mcp.CallToolRequest{}, istioProxyStatusInput{ + PodName: "test-pod", + Namespace: "default", + }) require.NoError(t, err) assert.NotNil(t, result) @@ -74,7 +67,7 @@ func TestHandleIstioProxyConfig(t *testing.T) { ctx := context.Background() t.Run("missing pod_name parameter", func(t *testing.T) { - result, err := handleIstioProxyConfig(ctx, mcp.CallToolRequest{}) + result, _, err := handleIstioProxyConfig(ctx, &mcp.CallToolRequest{}, istioProxyConfigInput{}) require.NoError(t, err) assert.NotNil(t, result) @@ -87,12 +80,9 @@ func TestHandleIstioProxyConfig(t *testing.T) { ctx = cmd.WithShellExecutor(ctx, mock) - request := mcp.CallToolRequest{} - request.Params.Arguments = map[string]interface{}{ - "pod_name": "test-pod", - } - - result, err := handleIstioProxyConfig(ctx, request) + result, _, err := handleIstioProxyConfig(ctx, &mcp.CallToolRequest{}, istioProxyConfigInput{ + PodName: "test-pod", + }) require.NoError(t, err) assert.NotNil(t, result) @@ -105,14 +95,11 @@ func TestHandleIstioProxyConfig(t *testing.T) { ctx = cmd.WithShellExecutor(ctx, mock) - request := mcp.CallToolRequest{} - request.Params.Arguments = map[string]interface{}{ - "pod_name": "test-pod", - "namespace": "default", - "config_type": "cluster", - } - - result, err := handleIstioProxyConfig(ctx, request) + result, _, err := handleIstioProxyConfig(ctx, &mcp.CallToolRequest{}, istioProxyConfigInput{ + PodName: "test-pod", + Namespace: "default", + ConfigType: "cluster", + }) require.NoError(t, err) assert.NotNil(t, result) @@ -129,7 +116,7 @@ func TestHandleIstioInstall(t *testing.T) { ctx = cmd.WithShellExecutor(ctx, mock) - result, err := handleIstioInstall(ctx, mcp.CallToolRequest{}) + result, _, err := handleIstioInstall(ctx, &mcp.CallToolRequest{}, istioInstallInput{}) require.NoError(t, err) assert.NotNil(t, result) @@ -142,12 +129,9 @@ func TestHandleIstioInstall(t *testing.T) { ctx = cmd.WithShellExecutor(ctx, mock) - request := mcp.CallToolRequest{} - request.Params.Arguments = map[string]interface{}{ - "profile": "demo", - } - - result, err := handleIstioInstall(ctx, request) + result, _, err := handleIstioInstall(ctx, &mcp.CallToolRequest{}, istioInstallInput{ + Profile: "demo", + }) require.NoError(t, err) assert.NotNil(t, result) @@ -163,12 +147,9 @@ func TestHandleIstioGenerateManifest(t *testing.T) { ctx = cmd.WithShellExecutor(ctx, mock) - request := mcp.CallToolRequest{} - request.Params.Arguments = map[string]interface{}{ - "profile": "minimal", - } - - result, err := handleIstioGenerateManifest(ctx, request) + result, _, err := handleIstioGenerateManifest(ctx, &mcp.CallToolRequest{}, istioGenerateManifestInput{ + Profile: "minimal", + }) require.NoError(t, err) assert.NotNil(t, result) @@ -184,12 +165,9 @@ func TestHandleIstioAnalyzeClusterConfiguration(t *testing.T) { ctx = cmd.WithShellExecutor(ctx, mock) - request := mcp.CallToolRequest{} - request.Params.Arguments = map[string]interface{}{ - "all_namespaces": "true", - } - - result, err := handleIstioAnalyzeClusterConfiguration(ctx, request) + result, _, err := handleIstioAnalyzeClusterConfiguration(ctx, &mcp.CallToolRequest{}, istioAnalyzeClusterConfigurationInput{ + AllNamespaces: true, + }) require.NoError(t, err) assert.NotNil(t, result) @@ -202,12 +180,9 @@ func TestHandleIstioAnalyzeClusterConfiguration(t *testing.T) { ctx = cmd.WithShellExecutor(ctx, mock) - request := mcp.CallToolRequest{} - request.Params.Arguments = map[string]interface{}{ - "namespace": "default", - } - - result, err := handleIstioAnalyzeClusterConfiguration(ctx, request) + result, _, err := handleIstioAnalyzeClusterConfiguration(ctx, &mcp.CallToolRequest{}, istioAnalyzeClusterConfigurationInput{ + Namespace: "default", + }) require.NoError(t, err) assert.NotNil(t, result) @@ -224,7 +199,7 @@ func TestHandleIstioVersion(t *testing.T) { ctx = cmd.WithShellExecutor(ctx, mock) - result, err := handleIstioVersion(ctx, mcp.CallToolRequest{}) + result, _, err := handleIstioVersion(ctx, &mcp.CallToolRequest{}, istioVersionInput{}) require.NoError(t, err) assert.NotNil(t, result) @@ -237,12 +212,9 @@ func TestHandleIstioVersion(t *testing.T) { ctx = cmd.WithShellExecutor(ctx, mock) - request := mcp.CallToolRequest{} - request.Params.Arguments = map[string]interface{}{ - "short": "true", - } - - result, err := handleIstioVersion(ctx, request) + result, _, err := handleIstioVersion(ctx, &mcp.CallToolRequest{}, istioVersionInput{ + Short: true, + }) require.NoError(t, err) assert.NotNil(t, result) @@ -258,7 +230,7 @@ func TestHandleIstioRemoteClusters(t *testing.T) { ctx = cmd.WithShellExecutor(ctx, mock) - result, err := handleIstioRemoteClusters(ctx, mcp.CallToolRequest{}) + result, _, err := handleIstioRemoteClusters(ctx, &mcp.CallToolRequest{}, istioRemoteClustersInput{}) require.NoError(t, err) assert.NotNil(t, result) @@ -274,12 +246,9 @@ func TestHandleWaypointList(t *testing.T) { ctx = cmd.WithShellExecutor(ctx, mock) - request := mcp.CallToolRequest{} - request.Params.Arguments = map[string]interface{}{ - "all_namespaces": "true", - } - - result, err := handleWaypointList(ctx, request) + result, _, err := handleWaypointList(ctx, &mcp.CallToolRequest{}, waypointListInput{ + AllNamespaces: true, + }) require.NoError(t, err) assert.NotNil(t, result) @@ -292,12 +261,9 @@ func TestHandleWaypointList(t *testing.T) { ctx = cmd.WithShellExecutor(ctx, mock) - request := mcp.CallToolRequest{} - request.Params.Arguments = map[string]interface{}{ - "namespace": "default", - } - - result, err := handleWaypointList(ctx, request) + result, _, err := handleWaypointList(ctx, &mcp.CallToolRequest{}, waypointListInput{ + Namespace: "default", + }) require.NoError(t, err) assert.NotNil(t, result) @@ -314,14 +280,11 @@ func TestHandleWaypointGenerate(t *testing.T) { ctx = cmd.WithShellExecutor(ctx, mock) - request := mcp.CallToolRequest{} - request.Params.Arguments = map[string]interface{}{ - "namespace": "default", - "name": "waypoint", - "traffic_type": "all", - } - - result, err := handleWaypointGenerate(ctx, request) + result, _, err := handleWaypointGenerate(ctx, &mcp.CallToolRequest{}, waypointGenerateInput{ + Namespace: "default", + Name: "waypoint", + TrafficType: "all", + }) require.NoError(t, err) assert.NotNil(t, result) @@ -348,7 +311,7 @@ func TestIstioErrorHandling(t *testing.T) { mock.AddCommandString("istioctl", []string{"proxy-status"}, "", assert.AnError) ctx := cmd.WithShellExecutor(context.Background(), mock) - result, err := handleIstioProxyStatus(ctx, mcp.CallToolRequest{}) + result, _, err := handleIstioProxyStatus(ctx, &mcp.CallToolRequest{}, istioProxyStatusInput{}) require.NoError(t, err) assert.NotNil(t, result) @@ -362,9 +325,7 @@ func TestHandleWaypointApply(t *testing.T) { mock.AddCommandString("istioctl", []string{"waypoint", "apply", "-n", "default"}, "applied", nil) ctx := cmd.WithShellExecutor(context.Background(), mock) - req := mcp.CallToolRequest{} - req.Params.Arguments = map[string]interface{}{"namespace": "default"} - result, err := handleWaypointApply(ctx, req) + result, _, err := handleWaypointApply(ctx, &mcp.CallToolRequest{}, waypointApplyInput{Namespace: "default"}) require.NoError(t, err) assert.False(t, result.IsError) }) @@ -374,9 +335,10 @@ func TestHandleWaypointApply(t *testing.T) { mock.AddCommandString("istioctl", []string{"waypoint", "apply", "-n", "default", "--enroll-namespace"}, "applied", nil) ctx := cmd.WithShellExecutor(context.Background(), mock) - req := mcp.CallToolRequest{} - req.Params.Arguments = map[string]interface{}{"namespace": "default", "enroll_namespace": "true"} - result, err := handleWaypointApply(ctx, req) + result, _, err := handleWaypointApply(ctx, &mcp.CallToolRequest{}, waypointApplyInput{ + Namespace: "default", + EnrollNamespace: true, + }) require.NoError(t, err) assert.False(t, result.IsError) }) @@ -384,7 +346,7 @@ func TestHandleWaypointApply(t *testing.T) { t.Run("missing namespace", func(t *testing.T) { mock := cmd.NewMockShellExecutor() ctx := cmd.WithShellExecutor(context.Background(), mock) - result, err := handleWaypointApply(ctx, mcp.CallToolRequest{}) + result, _, err := handleWaypointApply(ctx, &mcp.CallToolRequest{}, waypointApplyInput{}) require.NoError(t, err) assert.True(t, result.IsError) }) @@ -393,9 +355,7 @@ func TestHandleWaypointApply(t *testing.T) { mock := cmd.NewMockShellExecutor() mock.AddCommandString("istioctl", []string{"waypoint", "apply", "-n", "default"}, "", assert.AnError) ctx := cmd.WithShellExecutor(context.Background(), mock) - req := mcp.CallToolRequest{} - req.Params.Arguments = map[string]interface{}{"namespace": "default"} - result, err := handleWaypointApply(ctx, req) + result, _, err := handleWaypointApply(ctx, &mcp.CallToolRequest{}, waypointApplyInput{Namespace: "default"}) require.NoError(t, err) assert.True(t, result.IsError) }) @@ -406,9 +366,10 @@ func TestHandleWaypointDelete(t *testing.T) { mock := cmd.NewMockShellExecutor() mock.AddCommandString("istioctl", []string{"waypoint", "delete", "--all", "-n", "default"}, "deleted", nil) ctx := cmd.WithShellExecutor(context.Background(), mock) - req := mcp.CallToolRequest{} - req.Params.Arguments = map[string]interface{}{"namespace": "default", "all": "true"} - result, err := handleWaypointDelete(ctx, req) + result, _, err := handleWaypointDelete(ctx, &mcp.CallToolRequest{}, waypointDeleteInput{ + Namespace: "default", + All: true, + }) require.NoError(t, err) assert.False(t, result.IsError) }) @@ -417,9 +378,10 @@ func TestHandleWaypointDelete(t *testing.T) { mock := cmd.NewMockShellExecutor() mock.AddCommandString("istioctl", []string{"waypoint", "delete", "wp1", "wp2", "-n", "default"}, "deleted", nil) ctx := cmd.WithShellExecutor(context.Background(), mock) - req := mcp.CallToolRequest{} - req.Params.Arguments = map[string]interface{}{"namespace": "default", "names": "wp1, wp2"} - result, err := handleWaypointDelete(ctx, req) + result, _, err := handleWaypointDelete(ctx, &mcp.CallToolRequest{}, waypointDeleteInput{ + Namespace: "default", + Names: "wp1, wp2", + }) require.NoError(t, err) assert.False(t, result.IsError) }) @@ -427,7 +389,7 @@ func TestHandleWaypointDelete(t *testing.T) { t.Run("missing namespace", func(t *testing.T) { mock := cmd.NewMockShellExecutor() ctx := cmd.WithShellExecutor(context.Background(), mock) - result, err := handleWaypointDelete(ctx, mcp.CallToolRequest{}) + result, _, err := handleWaypointDelete(ctx, &mcp.CallToolRequest{}, waypointDeleteInput{}) require.NoError(t, err) assert.True(t, result.IsError) }) @@ -438,9 +400,10 @@ func TestHandleWaypointStatus(t *testing.T) { mock := cmd.NewMockShellExecutor() mock.AddCommandString("istioctl", []string{"waypoint", "status", "wp1", "-n", "default"}, "status", nil) ctx := cmd.WithShellExecutor(context.Background(), mock) - req := mcp.CallToolRequest{} - req.Params.Arguments = map[string]interface{}{"namespace": "default", "name": "wp1"} - result, err := handleWaypointStatus(ctx, req) + result, _, err := handleWaypointStatus(ctx, &mcp.CallToolRequest{}, waypointStatusInput{ + Namespace: "default", + Name: "wp1", + }) require.NoError(t, err) assert.False(t, result.IsError) }) @@ -449,9 +412,9 @@ func TestHandleWaypointStatus(t *testing.T) { mock := cmd.NewMockShellExecutor() mock.AddCommandString("istioctl", []string{"waypoint", "status", "-n", "default"}, "status", nil) ctx := cmd.WithShellExecutor(context.Background(), mock) - req := mcp.CallToolRequest{} - req.Params.Arguments = map[string]interface{}{"namespace": "default"} - result, err := handleWaypointStatus(ctx, req) + result, _, err := handleWaypointStatus(ctx, &mcp.CallToolRequest{}, waypointStatusInput{ + Namespace: "default", + }) require.NoError(t, err) assert.False(t, result.IsError) }) @@ -459,7 +422,7 @@ func TestHandleWaypointStatus(t *testing.T) { t.Run("missing namespace", func(t *testing.T) { mock := cmd.NewMockShellExecutor() ctx := cmd.WithShellExecutor(context.Background(), mock) - result, err := handleWaypointStatus(ctx, mcp.CallToolRequest{}) + result, _, err := handleWaypointStatus(ctx, &mcp.CallToolRequest{}, waypointStatusInput{}) require.NoError(t, err) assert.True(t, result.IsError) }) @@ -470,7 +433,7 @@ func TestHandleZtunnelConfig(t *testing.T) { mock := cmd.NewMockShellExecutor() mock.AddCommandString("istioctl", []string{"ztunnel", "config", "all"}, "ztunnel config", nil) ctx := cmd.WithShellExecutor(context.Background(), mock) - result, err := handleZtunnelConfig(ctx, mcp.CallToolRequest{}) + result, _, err := handleZtunnelConfig(ctx, &mcp.CallToolRequest{}, ztunnelConfigInput{}) require.NoError(t, err) assert.False(t, result.IsError) }) @@ -479,9 +442,10 @@ func TestHandleZtunnelConfig(t *testing.T) { mock := cmd.NewMockShellExecutor() mock.AddCommandString("istioctl", []string{"ztunnel", "config", "workloads", "-n", "istio-system"}, "ztunnel config", nil) ctx := cmd.WithShellExecutor(context.Background(), mock) - req := mcp.CallToolRequest{} - req.Params.Arguments = map[string]interface{}{"config_type": "workloads", "namespace": "istio-system"} - result, err := handleZtunnelConfig(ctx, req) + result, _, err := handleZtunnelConfig(ctx, &mcp.CallToolRequest{}, ztunnelConfigInput{ + ConfigType: "workloads", + Namespace: "istio-system", + }) require.NoError(t, err) assert.False(t, result.IsError) }) @@ -490,7 +454,7 @@ func TestHandleZtunnelConfig(t *testing.T) { mock := cmd.NewMockShellExecutor() mock.AddCommandString("istioctl", []string{"ztunnel", "config", "all"}, "", assert.AnError) ctx := cmd.WithShellExecutor(context.Background(), mock) - result, err := handleZtunnelConfig(ctx, mcp.CallToolRequest{}) + result, _, err := handleZtunnelConfig(ctx, &mcp.CallToolRequest{}, ztunnelConfigInput{}) require.NoError(t, err) assert.True(t, result.IsError) }) diff --git a/pkg/k8s/k8s.go b/pkg/k8s/k8s.go index 6def2f2a..abb40ac0 100644 --- a/pkg/k8s/k8s.go +++ b/pkg/k8s/k8s.go @@ -12,15 +12,13 @@ import ( "strings" "time" - "github.com/mark3labs/mcp-go/mcp" - "github.com/mark3labs/mcp-go/server" "github.com/tmc/langchaingo/llms" "github.com/kagent-dev/tools/internal/cache" "github.com/kagent-dev/tools/internal/commands" "github.com/kagent-dev/tools/internal/logger" + mcp "github.com/kagent-dev/tools/internal/mcp" "github.com/kagent-dev/tools/internal/security" - "github.com/kagent-dev/tools/internal/telemetry" ) // K8sTool struct to hold the LLM model @@ -54,436 +52,634 @@ func (k *K8sTool) runKubectlCommandWithCacheInvalidation(ctx context.Context, he return result, err } -// Enhanced kubectl get -func (k *K8sTool) handleKubectlGetEnhanced(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - resourceType := mcp.ParseString(request, "resource_type", "") - resourceName := mcp.ParseString(request, "resource_name", "") - namespace := mcp.ParseString(request, "namespace", "") - allNamespaces := mcp.ParseString(request, "all_namespaces", "") == "true" - output := mcp.ParseString(request, "output", "wide") +// getResourcesInput is the typed input for k8s_get_resources. +type getResourcesInput struct { + ResourceType string `json:"resource_type" jsonschema:"Type of resource (pod, service, deployment, etc.)"` + ResourceName string `json:"resource_name" jsonschema:"Name of specific resource (optional)"` + Namespace string `json:"namespace" jsonschema:"Namespace to query (optional)"` + AllNamespaces bool `json:"all_namespaces" jsonschema:"Query all namespaces"` + Output string `json:"output" jsonschema:"Output format (json, yaml, wide)"` +} - if resourceType == "" { - return mcp.NewToolResultError("resource_type parameter is required"), nil +// Enhanced kubectl get +func (k *K8sTool) handleKubectlGetEnhanced(ctx context.Context, request *mcp.CallToolRequest, in getResourcesInput) (*mcp.CallToolResult, any, error) { + if in.ResourceType == "" { + return mcp.NewToolResultError("resource_type parameter is required"), nil, nil + } + if in.Output == "" { + in.Output = "wide" } - args := []string{"get", resourceType} + args := []string{"get", in.ResourceType} - if resourceName != "" { - args = append(args, resourceName) + if in.ResourceName != "" { + args = append(args, in.ResourceName) } - if allNamespaces { + if in.AllNamespaces { args = append(args, "--all-namespaces") - } else if namespace != "" { - args = append(args, "-n", namespace) + } else if in.Namespace != "" { + args = append(args, "-n", in.Namespace) } - if output != "" { - args = append(args, "-o", output) - } else { - args = append(args, "-o", "json") - } + args = append(args, "-o", in.Output) - return k.runKubectlCommand(ctx, request.Header, args...) + res, err := k.runKubectlCommand(ctx, mcp.Header(request), args...) + return res, nil, err } -// Get pod logs -func (k *K8sTool) handleKubectlLogsEnhanced(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - podName := mcp.ParseString(request, "pod_name", "") - namespace := mcp.ParseString(request, "namespace", "default") - container := mcp.ParseString(request, "container", "") - tailLines := mcp.ParseInt(request, "tail_lines", 50) +// logsInput is the typed input for k8s_get_pod_logs. +type logsInput struct { + PodName string `json:"pod_name" jsonschema:"Name of the pod"` + Namespace string `json:"namespace" jsonschema:"Namespace of the pod (default: default)"` + Container string `json:"container" jsonschema:"Container name (for multi-container pods)"` + TailLines int `json:"tail_lines" jsonschema:"Number of lines to show from the end (default: 50)"` +} - if podName == "" { - return mcp.NewToolResultError("pod_name parameter is required"), nil +// Get pod logs +func (k *K8sTool) handleKubectlLogsEnhanced(ctx context.Context, request *mcp.CallToolRequest, in logsInput) (*mcp.CallToolResult, any, error) { + if in.PodName == "" { + return mcp.NewToolResultError("pod_name parameter is required"), nil, nil + } + if in.Namespace == "" { + in.Namespace = "default" + } + if in.TailLines == 0 { + in.TailLines = 50 } - args := []string{"logs", podName, "-n", namespace} + args := []string{"logs", in.PodName, "-n", in.Namespace} - if container != "" { - args = append(args, "-c", container) + if in.Container != "" { + args = append(args, "-c", in.Container) } - if tailLines > 0 { - args = append(args, "--tail", fmt.Sprintf("%d", tailLines)) + if in.TailLines > 0 { + args = append(args, "--tail", fmt.Sprintf("%d", in.TailLines)) } - return k.runKubectlCommand(ctx, request.Header, args...) + res, err := k.runKubectlCommand(ctx, mcp.Header(request), args...) + return res, nil, err } -// Scale deployment -func (k *K8sTool) handleScaleDeployment(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - deploymentName := mcp.ParseString(request, "name", "") - namespace := mcp.ParseString(request, "namespace", "default") - replicas := mcp.ParseInt(request, "replicas", 1) +// scaleInput is the typed input for k8s_scale. +type scaleInput struct { + Name string `json:"name" jsonschema:"Name of the deployment"` + Namespace string `json:"namespace" jsonschema:"Namespace of the deployment (default: default)"` + Replicas int `json:"replicas" jsonschema:"Number of replicas"` +} - if deploymentName == "" { - return mcp.NewToolResultError("name parameter is required"), nil +// Scale deployment +func (k *K8sTool) handleScaleDeployment(ctx context.Context, request *mcp.CallToolRequest, in scaleInput) (*mcp.CallToolResult, any, error) { + if in.Name == "" { + return mcp.NewToolResultError("name parameter is required"), nil, nil + } + if in.Namespace == "" { + in.Namespace = "default" } + if in.Replicas == 0 { + in.Replicas = 1 + } + + args := []string{"scale", "deployment", in.Name, "--replicas", fmt.Sprintf("%d", in.Replicas), "-n", in.Namespace} - args := []string{"scale", "deployment", deploymentName, "--replicas", fmt.Sprintf("%d", replicas), "-n", namespace} + res, err := k.runKubectlCommandWithCacheInvalidation(ctx, mcp.Header(request), args...) + return res, nil, err +} - return k.runKubectlCommandWithCacheInvalidation(ctx, request.Header, args...) +// patchResourceInput is the typed input for k8s_patch_resource. +type patchResourceInput struct { + ResourceType string `json:"resource_type" jsonschema:"Type of resource (deployment, service, etc.)"` + ResourceName string `json:"resource_name" jsonschema:"Name of the resource"` + Patch string `json:"patch" jsonschema:"JSON patch to apply"` + PatchType string `json:"patch_type" jsonschema:"Patch strategy: \"strategic\" (default; built-in Kubernetes types only), \"merge\" (RFC 7386 JSON merge patch; required for CustomResources/CRDs), or \"json\" (RFC 6902 JSON patch)."` + Namespace string `json:"namespace" jsonschema:"Namespace of the resource (default: default)"` } // Patch resource -func (k *K8sTool) handlePatchResource(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - resourceType := mcp.ParseString(request, "resource_type", "") - resourceName := mcp.ParseString(request, "resource_name", "") - patch := mcp.ParseString(request, "patch", "") - namespace := mcp.ParseString(request, "namespace", "default") - patchType := mcp.ParseString(request, "patch_type", "strategic") +func (k *K8sTool) handlePatchResource(ctx context.Context, request *mcp.CallToolRequest, in patchResourceInput) (*mcp.CallToolResult, any, error) { + if in.Namespace == "" { + in.Namespace = "default" + } + if in.PatchType == "" { + in.PatchType = "strategic" + } - if resourceType == "" || resourceName == "" || patch == "" { - return mcp.NewToolResultError("resource_type, resource_name, and patch parameters are required"), nil + if in.ResourceType == "" || in.ResourceName == "" || in.Patch == "" { + return mcp.NewToolResultError("resource_type, resource_name, and patch parameters are required"), nil, nil } // Validate patch type. "strategic" is only implemented for built-in Kubernetes // types; CustomResources (CRDs) reject it and require "merge" or "json". - switch patchType { + switch in.PatchType { case "strategic", "merge", "json": default: - return mcp.NewToolResultError(fmt.Sprintf("Invalid patch_type %q: must be one of strategic, merge, json", patchType)), nil + return mcp.NewToolResultError(fmt.Sprintf("Invalid patch_type %q: must be one of strategic, merge, json", in.PatchType)), nil, nil } - // Validate resource name for security - if err := security.ValidateK8sResourceName(resourceName); err != nil { - return mcp.NewToolResultError(fmt.Sprintf("Invalid resource name: %v", err)), nil + if err := security.ValidateK8sResourceName(in.ResourceName); err != nil { + return mcp.NewToolResultError(fmt.Sprintf("Invalid resource name: %v", err)), nil, nil } - // Validate namespace for security - if err := security.ValidateNamespace(namespace); err != nil { - return mcp.NewToolResultError(fmt.Sprintf("Invalid namespace: %v", err)), nil + if err := security.ValidateNamespace(in.Namespace); err != nil { + return mcp.NewToolResultError(fmt.Sprintf("Invalid namespace: %v", err)), nil, nil } - // Validate patch content as JSON/YAML - if err := security.ValidateYAMLContent(patch); err != nil { - return mcp.NewToolResultError(fmt.Sprintf("Invalid patch content: %v", err)), nil + if err := security.ValidateYAMLContent(in.Patch); err != nil { + return mcp.NewToolResultError(fmt.Sprintf("Invalid patch content: %v", err)), nil, nil } - args := []string{"patch", resourceType, resourceName, "--type=" + patchType, "-p", patch, "-n", namespace} + args := []string{"patch", in.ResourceType, in.ResourceName, "--type=" + in.PatchType, "-p", in.Patch, "-n", in.Namespace} + + res, err := k.runKubectlCommandWithCacheInvalidation(ctx, mcp.Header(request), args...) + return res, nil, err +} - return k.runKubectlCommandWithCacheInvalidation(ctx, request.Header, args...) +// patchStatusInput is the typed input for k8s_patch_status. +type patchStatusInput struct { + ResourceType string `json:"resource_type" jsonschema:"Type of resource (deployment, service, etc.)"` + ResourceName string `json:"resource_name" jsonschema:"Name of the resource"` + Patch string `json:"patch" jsonschema:"JSON/YAML status patch"` + Namespace string `json:"namespace" jsonschema:"Namespace of the resource (default: default)"` } // Patch resource status -func (k *K8sTool) handlePatchStatus(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - resourceType := mcp.ParseString(request, "resource_type", "") - resourceName := mcp.ParseString(request, "resource_name", "") - patch := mcp.ParseString(request, "patch", "") - namespace := mcp.ParseString(request, "namespace", "default") +func (k *K8sTool) handlePatchStatus(ctx context.Context, request *mcp.CallToolRequest, in patchStatusInput) (*mcp.CallToolResult, any, error) { + if in.Namespace == "" { + in.Namespace = "default" + } - if resourceType == "" || resourceName == "" || patch == "" { - return mcp.NewToolResultError("resource_type, resource_name, and patch parameters are required"), nil + if in.ResourceType == "" || in.ResourceName == "" || in.Patch == "" { + return mcp.NewToolResultError("resource_type, resource_name, and patch parameters are required"), nil, nil } - // Validate resource name for security - if err := security.ValidateK8sResourceName(resourceName); err != nil { - return mcp.NewToolResultError(fmt.Sprintf("Invalid resource name: %v", err)), nil + if err := security.ValidateK8sResourceName(in.ResourceName); err != nil { + return mcp.NewToolResultError(fmt.Sprintf("Invalid resource name: %v", err)), nil, nil } - // Validate namespace for security - if err := security.ValidateNamespace(namespace); err != nil { - return mcp.NewToolResultError(fmt.Sprintf("Invalid namespace: %v", err)), nil + if err := security.ValidateNamespace(in.Namespace); err != nil { + return mcp.NewToolResultError(fmt.Sprintf("Invalid namespace: %v", err)), nil, nil } - // Validate patch content as JSON/YAML - if err := security.ValidateYAMLContent(patch); err != nil { - return mcp.NewToolResultError(fmt.Sprintf("Invalid patch content: %v", err)), nil + if err := security.ValidateYAMLContent(in.Patch); err != nil { + return mcp.NewToolResultError(fmt.Sprintf("Invalid patch content: %v", err)), nil, nil } args := []string{ "patch", - resourceType, - resourceName, + in.ResourceType, + in.ResourceName, "--subresource=status", "--type=merge", "-p", - patch, + in.Patch, "-n", - namespace, + in.Namespace, } - return k.runKubectlCommandWithCacheInvalidation(ctx, request.Header, args...) + res, err := k.runKubectlCommandWithCacheInvalidation(ctx, mcp.Header(request), args...) + return res, nil, err } -// Apply manifest from content -func (k *K8sTool) handleApplyManifest(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - manifest := mcp.ParseString(request, "manifest", "") +// applyManifestInput is the typed input for k8s_apply_manifest. +type applyManifestInput struct { + Manifest string `json:"manifest" jsonschema:"YAML manifest content"` +} - if manifest == "" { - return mcp.NewToolResultError("manifest parameter is required"), nil +// Apply manifest from content +func (k *K8sTool) handleApplyManifest(ctx context.Context, request *mcp.CallToolRequest, in applyManifestInput) (*mcp.CallToolResult, any, error) { + if in.Manifest == "" { + return mcp.NewToolResultError("manifest parameter is required"), nil, nil } - // Validate YAML content for security - if err := security.ValidateYAMLContent(manifest); err != nil { - return mcp.NewToolResultError(fmt.Sprintf("Invalid manifest content: %v", err)), nil + if err := security.ValidateYAMLContent(in.Manifest); err != nil { + return mcp.NewToolResultError(fmt.Sprintf("Invalid manifest content: %v", err)), nil, nil } - // Create temporary file with secure permissions tmpFile, err := os.CreateTemp("", "k8s-manifest-*.yaml") if err != nil { - return mcp.NewToolResultError(fmt.Sprintf("Failed to create temp file: %v", err)), nil + return mcp.NewToolResultError(fmt.Sprintf("Failed to create temp file: %v", err)), nil, nil } - // Ensure file is removed regardless of execution path defer func() { if removeErr := os.Remove(tmpFile.Name()); removeErr != nil { logger.Get().Error("Failed to remove temporary file", "error", removeErr, "file", tmpFile.Name()) } }() - // Set secure file permissions (readable/writable by owner only) if err := os.Chmod(tmpFile.Name(), 0600); err != nil { - return mcp.NewToolResultError(fmt.Sprintf("Failed to set file permissions: %v", err)), nil + return mcp.NewToolResultError(fmt.Sprintf("Failed to set file permissions: %v", err)), nil, nil } - // Write manifest content to temporary file - if _, err := tmpFile.WriteString(manifest); err != nil { + if _, err := tmpFile.WriteString(in.Manifest); err != nil { tmpFile.Close() - return mcp.NewToolResultError(fmt.Sprintf("Failed to write to temp file: %v", err)), nil + return mcp.NewToolResultError(fmt.Sprintf("Failed to write to temp file: %v", err)), nil, nil } - // Close the file before passing to kubectl if err := tmpFile.Close(); err != nil { - return mcp.NewToolResultError(fmt.Sprintf("Failed to close temp file: %v", err)), nil + return mcp.NewToolResultError(fmt.Sprintf("Failed to close temp file: %v", err)), nil, nil } - return k.runKubectlCommandWithCacheInvalidation(ctx, request.Header, "apply", "-f", tmpFile.Name()) + res, err := k.runKubectlCommandWithCacheInvalidation(ctx, mcp.Header(request), "apply", "-f", tmpFile.Name()) + return res, nil, err +} + +// deleteResourceInput is the typed input for k8s_delete_resource. +type deleteResourceInput struct { + ResourceType string `json:"resource_type" jsonschema:"Type of resource (pod, service, deployment, etc.)"` + ResourceName string `json:"resource_name" jsonschema:"Name of the resource"` + Namespace string `json:"namespace" jsonschema:"Namespace of the resource (default: default)"` } // Delete resource -func (k *K8sTool) handleDeleteResource(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - resourceType := mcp.ParseString(request, "resource_type", "") - resourceName := mcp.ParseString(request, "resource_name", "") - namespace := mcp.ParseString(request, "namespace", "default") +func (k *K8sTool) handleDeleteResource(ctx context.Context, request *mcp.CallToolRequest, in deleteResourceInput) (*mcp.CallToolResult, any, error) { + if in.Namespace == "" { + in.Namespace = "default" + } - if resourceType == "" || resourceName == "" { - return mcp.NewToolResultError("resource_type and resource_name parameters are required"), nil + if in.ResourceType == "" || in.ResourceName == "" { + return mcp.NewToolResultError("resource_type and resource_name parameters are required"), nil, nil } - args := []string{"delete", resourceType, resourceName, "-n", namespace} + args := []string{"delete", in.ResourceType, in.ResourceName, "-n", in.Namespace} - return k.runKubectlCommandWithCacheInvalidation(ctx, request.Header, args...) + res, err := k.runKubectlCommandWithCacheInvalidation(ctx, mcp.Header(request), args...) + return res, nil, err } -// Check service connectivity -func (k *K8sTool) handleCheckServiceConnectivity(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - serviceName := mcp.ParseString(request, "service_name", "") - namespace := mcp.ParseString(request, "namespace", "default") +// waitInput is the typed input for k8s_wait. +type waitInput struct { + ResourceType string `json:"resource_type" jsonschema:"Type of resource (pod, deployment, job, etc.)"` + Condition string `json:"condition" jsonschema:"Condition to wait for, passed to --for. Examples: 'condition=Ready', 'condition=Available', 'delete', 'create', \"jsonpath={.status.phase}=Running\""` + ResourceName string `json:"resource_name" jsonschema:"Name of a specific resource. Omit to target by selector or all"` + Selector string `json:"selector" jsonschema:"Label selector to target resources, e.g. 'app=nginx'"` + All bool `json:"all" jsonschema:"Wait on all resources of the type in the namespace"` + Namespace string `json:"namespace" jsonschema:"Namespace of the resource (default: default)"` + Timeout string `json:"timeout" jsonschema:"Max wait duration, e.g. '30s', '5m'. 0 waits forever (default: 30s)"` +} + +// Wait for a condition on one or more resources (kubectl wait) +func (k *K8sTool) handleKubectlWait(ctx context.Context, request *mcp.CallToolRequest, in waitInput) (*mcp.CallToolResult, any, error) { + if in.Namespace == "" { + in.Namespace = "default" + } + if in.Timeout == "" { + in.Timeout = "30s" + } + + if in.ResourceType == "" || in.Condition == "" { + return mcp.NewToolResultError("resource_type and condition parameters are required"), nil, nil + } + if in.ResourceName == "" && in.Selector == "" && !in.All { + return mcp.NewToolResultError("one of resource_name, selector, or all=true must be provided"), nil, nil + } + + if err := security.ValidateNamespace(in.Namespace); err != nil { + return mcp.NewToolResultError(fmt.Sprintf("Invalid namespace: %v", err)), nil, nil + } + + target := in.ResourceType + if in.ResourceName != "" { + if err := security.ValidateK8sResourceName(in.ResourceName); err != nil { + return mcp.NewToolResultError(fmt.Sprintf("Invalid resource name: %v", err)), nil, nil + } + target = fmt.Sprintf("%s/%s", in.ResourceType, in.ResourceName) + } + + args := []string{"wait", target, "--for=" + in.Condition, "--timeout", in.Timeout, "-n", in.Namespace} + if in.Selector != "" { + args = append(args, "-l", in.Selector) + } + if in.All { + args = append(args, "--all") + } + + res, err := k.runKubectlCommand(ctx, mcp.Header(request), args...) + return res, nil, err +} - if serviceName == "" { - return mcp.NewToolResultError("service_name parameter is required"), nil +// serviceConnectivityInput is the typed input for k8s_check_service_connectivity. +type serviceConnectivityInput struct { + ServiceName string `json:"service_name" jsonschema:"Service name to test (e.g., my-service.my-namespace.svc.cluster.local:80)"` + Namespace string `json:"namespace" jsonschema:"Namespace to run the check from (default: default)"` +} + +// Check service connectivity +func (k *K8sTool) handleCheckServiceConnectivity(ctx context.Context, request *mcp.CallToolRequest, in serviceConnectivityInput) (*mcp.CallToolResult, any, error) { + if in.Namespace == "" { + in.Namespace = "default" } + if in.ServiceName == "" { + return mcp.NewToolResultError("service_name parameter is required"), nil, nil + } + + headers := mcp.Header(request) // Create a temporary curl pod for connectivity check podName := fmt.Sprintf("curl-test-%d", rand.Intn(10000)) defer func() { - _, _ = k.runKubectlCommand(ctx, request.Header, "delete", "pod", podName, "-n", namespace, "--ignore-not-found") + _, _ = k.runKubectlCommand(ctx, headers, "delete", "pod", podName, "-n", in.Namespace, "--ignore-not-found") }() // Create the curl pod - _, err := k.runKubectlCommand(ctx, request.Header, "run", podName, "--image=curlimages/curl", "-n", namespace, "--restart=Never", "--", "sleep", "3600") + _, err := k.runKubectlCommand(ctx, headers, "run", podName, "--image=curlimages/curl", "-n", in.Namespace, "--restart=Never", "--", "sleep", "3600") if err != nil { - return mcp.NewToolResultError(fmt.Sprintf("Failed to create curl pod: %v", err)), nil + return mcp.NewToolResultError(fmt.Sprintf("Failed to create curl pod: %v", err)), nil, nil } // Wait for pod to be ready - _, err = k.runKubectlCommandWithTimeout(ctx, request.Header, 60*time.Second, "wait", "--for=condition=ready", "pod/"+podName, "-n", namespace) + _, err = k.runKubectlCommandWithTimeout(ctx, headers, 60*time.Second, "wait", "--for=condition=ready", "pod/"+podName, "-n", in.Namespace) if err != nil { - return mcp.NewToolResultError(fmt.Sprintf("Failed to wait for curl pod: %v", err)), nil + return mcp.NewToolResultError(fmt.Sprintf("Failed to wait for curl pod: %v", err)), nil, nil } // Execute kubectl command - return k.runKubectlCommand(ctx, request.Header, "exec", podName, "-n", namespace, "--", "curl", "-s", serviceName) + res, err := k.runKubectlCommand(ctx, headers, "exec", podName, "-n", in.Namespace, "--", "curl", "-s", in.ServiceName) + return res, nil, err } -// Get cluster events -func (k *K8sTool) handleGetEvents(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - namespace := mcp.ParseString(request, "namespace", "") +// eventsInput is the typed input for k8s_get_events. +type eventsInput struct { + Namespace string `json:"namespace" jsonschema:"Namespace to get events from (default: default)"` +} +// Get cluster events +func (k *K8sTool) handleGetEvents(ctx context.Context, request *mcp.CallToolRequest, in eventsInput) (*mcp.CallToolResult, any, error) { args := []string{"get", "events", "-o", "json"} - if namespace != "" { - args = append(args, "-n", namespace) + if in.Namespace != "" { + args = append(args, "-n", in.Namespace) } else { args = append(args, "--all-namespaces") } - return k.runKubectlCommand(ctx, request.Header, args...) + res, err := k.runKubectlCommand(ctx, mcp.Header(request), args...) + return res, nil, err } -// Execute command in pod -func (k *K8sTool) handleExecCommand(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - podName := mcp.ParseString(request, "pod_name", "") - namespace := mcp.ParseString(request, "namespace", "default") - command := mcp.ParseString(request, "command", "") +// execCommandInput is the typed input for k8s_execute_command. +type execCommandInput struct { + PodName string `json:"pod_name" jsonschema:"Name of the pod to execute in"` + Namespace string `json:"namespace" jsonschema:"Namespace of the pod (default: default)"` + Container string `json:"container" jsonschema:"Container name (for multi-container pods)"` + Command string `json:"command" jsonschema:"Command to execute"` +} - if podName == "" || command == "" { - return mcp.NewToolResultError("pod_name and command parameters are required"), nil +// Execute command in pod +func (k *K8sTool) handleExecCommand(ctx context.Context, request *mcp.CallToolRequest, in execCommandInput) (*mcp.CallToolResult, any, error) { + if in.Namespace == "" { + in.Namespace = "default" + } + if in.PodName == "" || in.Command == "" { + return mcp.NewToolResultError("pod_name and command parameters are required"), nil, nil } - // Validate pod name for security - if err := security.ValidateK8sResourceName(podName); err != nil { - return mcp.NewToolResultError(fmt.Sprintf("Invalid pod name: %v", err)), nil + if err := security.ValidateK8sResourceName(in.PodName); err != nil { + return mcp.NewToolResultError(fmt.Sprintf("Invalid pod name: %v", err)), nil, nil } - // Validate namespace for security - if err := security.ValidateNamespace(namespace); err != nil { - return mcp.NewToolResultError(fmt.Sprintf("Invalid namespace: %v", err)), nil + if err := security.ValidateNamespace(in.Namespace); err != nil { + return mcp.NewToolResultError(fmt.Sprintf("Invalid namespace: %v", err)), nil, nil } - // Validate command input for security - if err := security.ValidateCommandInput(command); err != nil { - return mcp.NewToolResultError(fmt.Sprintf("Invalid command: %v", err)), nil + if err := security.ValidateCommandInput(in.Command); err != nil { + return mcp.NewToolResultError(fmt.Sprintf("Invalid command: %v", err)), nil, nil } - args := []string{"exec", podName, "-n", namespace, "--", command} + args := []string{"exec", in.PodName, "-n", in.Namespace, "--", in.Command} - return k.runKubectlCommand(ctx, request.Header, args...) + res, err := k.runKubectlCommand(ctx, mcp.Header(request), args...) + return res, nil, err } +// noInput is the typed input for tools that take no arguments. +type noInput struct{} + // Get available API resources -func (k *K8sTool) handleGetAvailableAPIResources(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - return k.runKubectlCommand(ctx, request.Header, "api-resources") +func (k *K8sTool) handleGetAvailableAPIResources(ctx context.Context, request *mcp.CallToolRequest, _ noInput) (*mcp.CallToolResult, any, error) { + res, err := k.runKubectlCommand(ctx, mcp.Header(request), "api-resources") + return res, nil, err } -// Kubectl describe tool -func (k *K8sTool) handleKubectlDescribeTool(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - resourceType := mcp.ParseString(request, "resource_type", "") - resourceName := mcp.ParseString(request, "resource_name", "") - namespace := mcp.ParseString(request, "namespace", "") +// describeInput is the typed input for k8s_describe_resource. +type describeInput struct { + ResourceType string `json:"resource_type" jsonschema:"Type of resource (deployment, service, pod, node, etc.)"` + ResourceName string `json:"resource_name" jsonschema:"Name of the resource"` + Namespace string `json:"namespace" jsonschema:"Namespace of the resource (optional)"` +} - if resourceType == "" || resourceName == "" { - return mcp.NewToolResultError("resource_type and resource_name parameters are required"), nil +// Kubectl describe tool +func (k *K8sTool) handleKubectlDescribeTool(ctx context.Context, request *mcp.CallToolRequest, in describeInput) (*mcp.CallToolResult, any, error) { + if in.ResourceType == "" || in.ResourceName == "" { + return mcp.NewToolResultError("resource_type and resource_name parameters are required"), nil, nil } - args := []string{"describe", resourceType, resourceName} - if namespace != "" { - args = append(args, "-n", namespace) + args := []string{"describe", in.ResourceType, in.ResourceName} + if in.Namespace != "" { + args = append(args, "-n", in.Namespace) } - return k.runKubectlCommand(ctx, request.Header, args...) + res, err := k.runKubectlCommand(ctx, mcp.Header(request), args...) + return res, nil, err } -// Rollout operations -func (k *K8sTool) handleRollout(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - action := mcp.ParseString(request, "action", "") - resourceType := mcp.ParseString(request, "resource_type", "") - resourceName := mcp.ParseString(request, "resource_name", "") - namespace := mcp.ParseString(request, "namespace", "") +// rolloutInput is the typed input for k8s_rollout. +type rolloutInput struct { + Action string `json:"action" jsonschema:"The rollout action to perform"` + ResourceType string `json:"resource_type" jsonschema:"The type of resource to rollout (e.g., deployment)"` + ResourceName string `json:"resource_name" jsonschema:"The name of the resource to rollout"` + Namespace string `json:"namespace" jsonschema:"The namespace of the resource"` +} - if action == "" || resourceType == "" || resourceName == "" { - return mcp.NewToolResultError("action, resource_type, and resource_name parameters are required"), nil +// Rollout operations +func (k *K8sTool) handleRollout(ctx context.Context, request *mcp.CallToolRequest, in rolloutInput) (*mcp.CallToolResult, any, error) { + if in.Action == "" || in.ResourceType == "" || in.ResourceName == "" { + return mcp.NewToolResultError("action, resource_type, and resource_name parameters are required"), nil, nil } - args := []string{"rollout", action, fmt.Sprintf("%s/%s", resourceType, resourceName)} - if namespace != "" { - args = append(args, "-n", namespace) + args := []string{"rollout", in.Action, fmt.Sprintf("%s/%s", in.ResourceType, in.ResourceName)} + if in.Namespace != "" { + args = append(args, "-n", in.Namespace) } - return k.runKubectlCommand(ctx, request.Header, args...) + res, err := k.runKubectlCommand(ctx, mcp.Header(request), args...) + return res, nil, err } // Get cluster configuration -func (k *K8sTool) handleGetClusterConfiguration(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - return k.runKubectlCommand(ctx, request.Header, "config", "view", "-o", "json") +func (k *K8sTool) handleGetClusterConfiguration(ctx context.Context, request *mcp.CallToolRequest, _ noInput) (*mcp.CallToolResult, any, error) { + res, err := k.runKubectlCommand(ctx, mcp.Header(request), "config", "view", "-o", "json") + return res, nil, err } -// Remove annotation -func (k *K8sTool) handleRemoveAnnotation(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - resourceType := mcp.ParseString(request, "resource_type", "") - resourceName := mcp.ParseString(request, "resource_name", "") - annotationKey := mcp.ParseString(request, "annotation_key", "") - namespace := mcp.ParseString(request, "namespace", "") +// removeAnnotationInput is the typed input for k8s_remove_annotation. +type removeAnnotationInput struct { + ResourceType string `json:"resource_type" jsonschema:"The type of resource"` + ResourceName string `json:"resource_name" jsonschema:"The name of the resource"` + AnnotationKey string `json:"annotation_key" jsonschema:"The key of the annotation to remove"` + Namespace string `json:"namespace" jsonschema:"The namespace of the resource"` +} - if resourceType == "" || resourceName == "" || annotationKey == "" { - return mcp.NewToolResultError("resource_type, resource_name, and annotation_key parameters are required"), nil +// Remove annotation +func (k *K8sTool) handleRemoveAnnotation(ctx context.Context, request *mcp.CallToolRequest, in removeAnnotationInput) (*mcp.CallToolResult, any, error) { + if in.ResourceType == "" || in.ResourceName == "" || in.AnnotationKey == "" { + return mcp.NewToolResultError("resource_type, resource_name, and annotation_key parameters are required"), nil, nil } - args := []string{"annotate", resourceType, resourceName, annotationKey + "-"} - if namespace != "" { - args = append(args, "-n", namespace) + args := []string{"annotate", in.ResourceType, in.ResourceName, in.AnnotationKey + "-"} + if in.Namespace != "" { + args = append(args, "-n", in.Namespace) } - return k.runKubectlCommand(ctx, request.Header, args...) + res, err := k.runKubectlCommand(ctx, mcp.Header(request), args...) + return res, nil, err } -// Remove label -func (k *K8sTool) handleRemoveLabel(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - resourceType := mcp.ParseString(request, "resource_type", "") - resourceName := mcp.ParseString(request, "resource_name", "") - labelKey := mcp.ParseString(request, "label_key", "") - namespace := mcp.ParseString(request, "namespace", "") +// removeLabelInput is the typed input for k8s_remove_label. +type removeLabelInput struct { + ResourceType string `json:"resource_type" jsonschema:"The type of resource"` + ResourceName string `json:"resource_name" jsonschema:"The name of the resource"` + LabelKey string `json:"label_key" jsonschema:"The key of the label to remove"` + Namespace string `json:"namespace" jsonschema:"The namespace of the resource"` +} - if resourceType == "" || resourceName == "" || labelKey == "" { - return mcp.NewToolResultError("resource_type, resource_name, and label_key parameters are required"), nil +// Remove label +func (k *K8sTool) handleRemoveLabel(ctx context.Context, request *mcp.CallToolRequest, in removeLabelInput) (*mcp.CallToolResult, any, error) { + if in.ResourceType == "" || in.ResourceName == "" || in.LabelKey == "" { + return mcp.NewToolResultError("resource_type, resource_name, and label_key parameters are required"), nil, nil } - args := []string{"label", resourceType, resourceName, labelKey + "-"} - if namespace != "" { - args = append(args, "-n", namespace) + args := []string{"label", in.ResourceType, in.ResourceName, in.LabelKey + "-"} + if in.Namespace != "" { + args = append(args, "-n", in.Namespace) } - return k.runKubectlCommand(ctx, request.Header, args...) + res, err := k.runKubectlCommand(ctx, mcp.Header(request), args...) + return res, nil, err } -// Annotate resource -func (k *K8sTool) handleAnnotateResource(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - resourceType := mcp.ParseString(request, "resource_type", "") - resourceName := mcp.ParseString(request, "resource_name", "") - annotations := mcp.ParseString(request, "annotations", "") - namespace := mcp.ParseString(request, "namespace", "") +// annotateInput is the typed input for k8s_annotate_resource. +type annotateInput struct { + ResourceType string `json:"resource_type" jsonschema:"The type of resource"` + ResourceName string `json:"resource_name" jsonschema:"The name of the resource"` + Annotations string `json:"annotations" jsonschema:"Space-separated key=value pairs for annotations"` + Namespace string `json:"namespace" jsonschema:"The namespace of the resource"` +} - if resourceType == "" || resourceName == "" || annotations == "" { - return mcp.NewToolResultError("resource_type, resource_name, and annotations parameters are required"), nil +// Annotate resource +func (k *K8sTool) handleAnnotateResource(ctx context.Context, request *mcp.CallToolRequest, in annotateInput) (*mcp.CallToolResult, any, error) { + if in.ResourceType == "" || in.ResourceName == "" || in.Annotations == "" { + return mcp.NewToolResultError("resource_type, resource_name, and annotations parameters are required"), nil, nil } - args := []string{"annotate", resourceType, resourceName} - args = append(args, strings.Fields(annotations)...) + args := []string{"annotate", in.ResourceType, in.ResourceName} + args = append(args, strings.Fields(in.Annotations)...) - if namespace != "" { - args = append(args, "-n", namespace) + if in.Namespace != "" { + args = append(args, "-n", in.Namespace) } - return k.runKubectlCommand(ctx, request.Header, args...) + res, err := k.runKubectlCommand(ctx, mcp.Header(request), args...) + return res, nil, err } -// Label resource -func (k *K8sTool) handleLabelResource(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - resourceType := mcp.ParseString(request, "resource_type", "") - resourceName := mcp.ParseString(request, "resource_name", "") - labels := mcp.ParseString(request, "labels", "") - namespace := mcp.ParseString(request, "namespace", "") +// labelInput is the typed input for k8s_label_resource. +type labelInput struct { + ResourceType string `json:"resource_type" jsonschema:"The type of resource"` + ResourceName string `json:"resource_name" jsonschema:"The name of the resource"` + Labels string `json:"labels" jsonschema:"Space-separated key=value pairs for labels"` + Namespace string `json:"namespace" jsonschema:"The namespace of the resource"` +} - if resourceType == "" || resourceName == "" || labels == "" { - return mcp.NewToolResultError("resource_type, resource_name, and labels parameters are required"), nil +// Label resource +func (k *K8sTool) handleLabelResource(ctx context.Context, request *mcp.CallToolRequest, in labelInput) (*mcp.CallToolResult, any, error) { + if in.ResourceType == "" || in.ResourceName == "" || in.Labels == "" { + return mcp.NewToolResultError("resource_type, resource_name, and labels parameters are required"), nil, nil } - args := []string{"label", resourceType, resourceName} - args = append(args, strings.Fields(labels)...) + args := []string{"label", in.ResourceType, in.ResourceName} + args = append(args, strings.Fields(in.Labels)...) - if namespace != "" { - args = append(args, "-n", namespace) + if in.Namespace != "" { + args = append(args, "-n", in.Namespace) } - return k.runKubectlCommand(ctx, request.Header, args...) + res, err := k.runKubectlCommand(ctx, mcp.Header(request), args...) + return res, nil, err +} + +// createFromURLInput is the typed input for k8s_create_resource_from_url. +type createFromURLInput struct { + URL string `json:"url" jsonschema:"The URL of the manifest"` + Namespace string `json:"namespace" jsonschema:"The namespace to create the resource in"` } // Create resource from URL -func (k *K8sTool) handleCreateResourceFromURL(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - url := mcp.ParseString(request, "url", "") - namespace := mcp.ParseString(request, "namespace", "") +func (k *K8sTool) handleCreateResourceFromURL(ctx context.Context, request *mcp.CallToolRequest, in createFromURLInput) (*mcp.CallToolResult, any, error) { + if in.URL == "" { + return mcp.NewToolResultError("url parameter is required"), nil, nil + } + + args := []string{"create", "-f", in.URL} + if in.Namespace != "" { + args = append(args, "-n", in.Namespace) + } + + res, err := k.runKubectlCommand(ctx, mcp.Header(request), args...) + return res, nil, err +} - if url == "" { - return mcp.NewToolResultError("url parameter is required"), nil +// getResourceYAMLInput is the typed input for k8s_get_resource_yaml. +type getResourceYAMLInput struct { + ResourceType string `json:"resource_type" jsonschema:"Type of resource"` + ResourceName string `json:"resource_name" jsonschema:"Name of the resource"` + Namespace string `json:"namespace" jsonschema:"Namespace of the resource (optional)"` +} + +// Get resource YAML +func (k *K8sTool) handleGetResourceYAML(ctx context.Context, request *mcp.CallToolRequest, in getResourceYAMLInput) (*mcp.CallToolResult, any, error) { + if in.ResourceType == "" || in.ResourceName == "" { + return mcp.NewToolResultError("resource_type and resource_name are required"), nil, nil + } + + args := []string{"get", in.ResourceType, in.ResourceName, "-o", "yaml"} + if in.Namespace != "" { + args = append(args, "-n", in.Namespace) + } + + res, err := k.runKubectlCommand(ctx, mcp.Header(request), args...) + if err != nil { + return mcp.NewToolResultError(fmt.Sprintf("Get YAML command failed: %v", err)), nil, nil + } + return res, nil, nil +} + +// createResourceInput is the typed input for k8s_create_resource. +type createResourceInput struct { + YAMLContent string `json:"yaml_content" jsonschema:"YAML content of the resource"` +} + +// Create resource from YAML content +func (k *K8sTool) handleCreateResource(ctx context.Context, request *mcp.CallToolRequest, in createResourceInput) (*mcp.CallToolResult, any, error) { + if in.YAMLContent == "" { + return mcp.NewToolResultError("yaml_content is required"), nil, nil + } + + tmpFile, err := os.CreateTemp("", "k8s-resource-*.yaml") + if err != nil { + return mcp.NewToolResultError(fmt.Sprintf("Failed to create temp file: %v", err)), nil, nil } + defer os.Remove(tmpFile.Name()) - args := []string{"create", "-f", url} - if namespace != "" { - args = append(args, "-n", namespace) + if _, err := tmpFile.WriteString(in.YAMLContent); err != nil { + return mcp.NewToolResultError(fmt.Sprintf("Failed to write to temp file: %v", err)), nil, nil } + tmpFile.Close() - return k.runKubectlCommand(ctx, request.Header, args...) + res, err := k.runKubectlCommand(ctx, mcp.Header(request), "create", "-f", tmpFile.Name()) + if err != nil { + return mcp.NewToolResultError(fmt.Sprintf("Create command failed: %v", err)), nil, nil + } + return res, nil, nil } // Resource generation embeddings @@ -530,23 +726,25 @@ var ( resourceTypes = maps.Keys(resourceMap) ) -// Generate resource using LLM -func (k *K8sTool) handleGenerateResource(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - resourceType := mcp.ParseString(request, "resource_type", "") - resourceDescription := mcp.ParseString(request, "resource_description", "") +// generateResourceInput is the typed input for k8s_generate_resource. +type generateResourceInput struct { + ResourceDescription string `json:"resource_description" jsonschema:"Detailed description of the resource to generate"` + ResourceType string `json:"resource_type" jsonschema:"Type of resource to generate"` +} - if resourceType == "" || resourceDescription == "" { - return mcp.NewToolResultError("resource_type and resource_description parameters are required"), nil +// Generate resource using LLM +func (k *K8sTool) handleGenerateResource(ctx context.Context, request *mcp.CallToolRequest, in generateResourceInput) (*mcp.CallToolResult, any, error) { + if in.ResourceType == "" || in.ResourceDescription == "" { + return mcp.NewToolResultError("resource_type and resource_description parameters are required"), nil, nil } - systemPrompt, ok := resourceMap[resourceType] + systemPrompt, ok := resourceMap[in.ResourceType] if !ok { - return mcp.NewToolResultError(fmt.Sprintf("resource type %s not found", resourceType)), nil + return mcp.NewToolResultError(fmt.Sprintf("resource type %s not found", in.ResourceType)), nil, nil } - // Use the injected LLM model if available, otherwise create a new OpenAI instance if k.llmModel == nil { - return mcp.NewToolResultError("No LLM client present, can't generate resource"), nil + return mcp.NewToolResultError("No LLM client present, can't generate resource"), nil, nil } llm := k.llmModel @@ -560,24 +758,23 @@ func (k *K8sTool) handleGenerateResource(ctx context.Context, request mcp.CallTo { Role: llms.ChatMessageTypeHuman, Parts: []llms.ContentPart{ - llms.TextContent{Text: resourceDescription}, + llms.TextContent{Text: in.ResourceDescription}, }, }, } resp, err := llm.GenerateContent(ctx, contents, llms.WithModel("gpt-4o-mini")) if err != nil { - return mcp.NewToolResultError("failed to generate content: " + err.Error()), nil + return mcp.NewToolResultError("failed to generate content: " + err.Error()), nil, nil } choices := resp.Choices if len(choices) < 1 { - return mcp.NewToolResultError("empty response from model"), nil + return mcp.NewToolResultError("empty response from model"), nil, nil } - c1 := choices[0] - responseText := c1.Content + responseText := choices[0].Content - return mcp.NewToolResultText(responseText), nil + return mcp.NewToolResultText(responseText), nil, nil } // extractBearerToken extracts the Bearer token from the Authorization header @@ -641,207 +838,126 @@ func (k *K8sTool) runKubectlCommandWithTimeout(ctx context.Context, headers http return mcp.NewToolResultText(output), nil } -// RegisterK8sTools registers all k8s tools with the MCP server -func RegisterTools(s *server.MCPServer, llm llms.Model, kubeconfig string, readOnly bool) { +// RegisterTools registers all k8s tools with the MCP server +func RegisterTools(s *mcp.Server, llm llms.Model, kubeconfig string, readOnly bool) { k8sTool := NewK8sToolWithConfig(kubeconfig, llm) // Read-only tools - always registered - s.AddTool(mcp.NewTool("k8s_get_resources", - mcp.WithDescription("Get Kubernetes resources using kubectl"), - mcp.WithString("resource_type", mcp.Description("Type of resource (pod, service, deployment, etc.)"), mcp.Required()), - mcp.WithString("resource_name", mcp.Description("Name of specific resource (optional)")), - mcp.WithString("namespace", mcp.Description("Namespace to query (optional)")), - mcp.WithString("all_namespaces", mcp.Description("Query all namespaces (true/false)")), - mcp.WithString("output", mcp.Description("Output format (json, yaml, wide)"), mcp.DefaultString("wide")), - ), telemetry.AdaptToolHandler(telemetry.WithTracing("k8s_get_resources", k8sTool.handleKubectlGetEnhanced))) - - s.AddTool(mcp.NewTool("k8s_get_pod_logs", - mcp.WithDescription("Get logs from a Kubernetes pod"), - mcp.WithString("pod_name", mcp.Description("Name of the pod"), mcp.Required()), - mcp.WithString("namespace", mcp.Description("Namespace of the pod (default: default)")), - mcp.WithString("container", mcp.Description("Container name (for multi-container pods)")), - mcp.WithNumber("tail_lines", mcp.Description("Number of lines to show from the end (default: 50)")), - ), telemetry.AdaptToolHandler(telemetry.WithTracing("k8s_get_pod_logs", k8sTool.handleKubectlLogsEnhanced))) - - s.AddTool(mcp.NewTool("k8s_get_events", - mcp.WithDescription("Get events from a Kubernetes namespace"), - mcp.WithString("namespace", mcp.Description("Namespace to get events from (default: default)")), - ), telemetry.AdaptToolHandler(telemetry.WithTracing("k8s_get_events", k8sTool.handleGetEvents))) - - s.AddTool(mcp.NewTool("k8s_get_available_api_resources", - mcp.WithDescription("Get available Kubernetes API resources"), - ), telemetry.AdaptToolHandler(telemetry.WithTracing("k8s_get_available_api_resources", k8sTool.handleGetAvailableAPIResources))) - - s.AddTool(mcp.NewTool("k8s_get_cluster_configuration", - mcp.WithDescription("Get cluster configuration details"), - ), telemetry.AdaptToolHandler(telemetry.WithTracing("k8s_get_cluster_configuration", k8sTool.handleGetClusterConfiguration))) - - s.AddTool(mcp.NewTool("k8s_get_resource_yaml", - mcp.WithDescription("Get the YAML representation of a Kubernetes resource"), - mcp.WithString("resource_type", mcp.Description("Type of resource"), mcp.Required()), - mcp.WithString("resource_name", mcp.Description("Name of the resource"), mcp.Required()), - mcp.WithString("namespace", mcp.Description("Namespace of the resource (optional)")), - ), telemetry.AdaptToolHandler(telemetry.WithTracing("k8s_get_resource_yaml", func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - resourceType := mcp.ParseString(request, "resource_type", "") - resourceName := mcp.ParseString(request, "resource_name", "") - namespace := mcp.ParseString(request, "namespace", "") - - if resourceType == "" || resourceName == "" { - return mcp.NewToolResultError("resource_type and resource_name are required"), nil - } - - args := []string{"get", resourceType, resourceName, "-o", "yaml"} - if namespace != "" { - args = append(args, "-n", namespace) - } - - result, err := k8sTool.runKubectlCommand(ctx, request.Header, args...) - if err != nil { - return mcp.NewToolResultError(fmt.Sprintf("Get YAML command failed: %v", err)), nil - } - - return result, nil - }))) - - s.AddTool(mcp.NewTool("k8s_describe_resource", - mcp.WithDescription("Describe a Kubernetes resource in detail"), - mcp.WithString("resource_type", mcp.Description("Type of resource (deployment, service, pod, node, etc.)"), mcp.Required()), - mcp.WithString("resource_name", mcp.Description("Name of the resource"), mcp.Required()), - mcp.WithString("namespace", mcp.Description("Namespace of the resource (optional)")), - ), telemetry.AdaptToolHandler(telemetry.WithTracing("k8s_describe_resource", k8sTool.handleKubectlDescribeTool))) - - s.AddTool(mcp.NewTool("k8s_generate_resource", - mcp.WithDescription("Generate a Kubernetes resource YAML from a description"), - mcp.WithString("resource_description", mcp.Description("Detailed description of the resource to generate"), mcp.Required()), - mcp.WithString("resource_type", mcp.Description(fmt.Sprintf("Type of resource to generate (%s)", strings.Join(slices.Collect(resourceTypes), ", "))), mcp.Required()), - ), telemetry.AdaptToolHandler(telemetry.WithTracing("k8s_generate_resource", k8sTool.handleGenerateResource))) + mcp.AddTool(s, "k8s", &mcp.Tool{ + Name: "k8s_get_resources", + Description: "Get Kubernetes resources using kubectl", + }, k8sTool.handleKubectlGetEnhanced) + + mcp.AddTool(s, "k8s", &mcp.Tool{ + Name: "k8s_get_pod_logs", + Description: "Get logs from a Kubernetes pod", + }, k8sTool.handleKubectlLogsEnhanced) + + mcp.AddTool(s, "k8s", &mcp.Tool{ + Name: "k8s_get_events", + Description: "Get events from a Kubernetes namespace", + }, k8sTool.handleGetEvents) + + mcp.AddTool(s, "k8s", &mcp.Tool{ + Name: "k8s_get_available_api_resources", + Description: "Get available Kubernetes API resources", + }, k8sTool.handleGetAvailableAPIResources) + + mcp.AddTool(s, "k8s", &mcp.Tool{ + Name: "k8s_get_cluster_configuration", + Description: "Get cluster configuration details", + }, k8sTool.handleGetClusterConfiguration) + + mcp.AddTool(s, "k8s", &mcp.Tool{ + Name: "k8s_get_resource_yaml", + Description: "Get the YAML representation of a Kubernetes resource", + }, k8sTool.handleGetResourceYAML) + + mcp.AddTool(s, "k8s", &mcp.Tool{ + Name: "k8s_describe_resource", + Description: "Describe a Kubernetes resource in detail", + }, k8sTool.handleKubectlDescribeTool) + + mcp.AddTool(s, "k8s", &mcp.Tool{ + Name: "k8s_wait", + Description: "Wait for a condition on Kubernetes resources (kubectl wait). Blocks until the condition is met or the timeout elapses.", + }, k8sTool.handleKubectlWait) + + mcp.AddTool(s, "k8s", &mcp.Tool{ + Name: "k8s_generate_resource", + Description: fmt.Sprintf("Generate a Kubernetes resource YAML from a description. Supported resource_type values: %s", strings.Join(slices.Collect(resourceTypes), ", ")), + }, k8sTool.handleGenerateResource) // Write tools - only registered when write operations are enabled if !readOnly { - s.AddTool(mcp.NewTool("k8s_scale", - mcp.WithDescription("Scale a Kubernetes deployment"), - mcp.WithString("name", mcp.Description("Name of the deployment"), mcp.Required()), - mcp.WithString("namespace", mcp.Description("Namespace of the deployment (default: default)")), - mcp.WithNumber("replicas", mcp.Description("Number of replicas"), mcp.Required()), - ), telemetry.AdaptToolHandler(telemetry.WithTracing("k8s_scale", k8sTool.handleScaleDeployment))) - - s.AddTool(mcp.NewTool("k8s_patch_resource", - mcp.WithDescription("Patch a Kubernetes resource. Defaults to a strategic merge patch, which is only supported for built-in types; set patch_type to \"merge\" (or \"json\") to patch a CustomResource/CRD."), - mcp.WithString("resource_type", mcp.Description("Type of resource (deployment, service, etc.)"), mcp.Required()), - mcp.WithString("resource_name", mcp.Description("Name of the resource"), mcp.Required()), - mcp.WithString("patch", mcp.Description("JSON patch to apply"), mcp.Required()), - mcp.WithString("patch_type", mcp.Description("Patch strategy: \"strategic\" (default; built-in Kubernetes types only), \"merge\" (RFC 7386 JSON merge patch; required for CustomResources/CRDs), or \"json\" (RFC 6902 JSON patch).")), - mcp.WithString("namespace", mcp.Description("Namespace of the resource (default: default)")), - ), telemetry.AdaptToolHandler(telemetry.WithTracing("k8s_patch_resource", k8sTool.handlePatchResource))) - - s.AddTool(mcp.NewTool("k8s_patch_status", - mcp.WithDescription("Patch the status of a Kubernetes resource"), - mcp.WithString("resource_type", mcp.Description("Type of resource (deployment, service, etc.)"), mcp.Required()), - mcp.WithString("resource_name", mcp.Description("Name of the resource"), mcp.Required()), - mcp.WithString("patch", mcp.Description("JSON/YAML status patch"), mcp.Required()), - mcp.WithString("namespace", mcp.Description("Namespace of the resource (default: default)")), - ), telemetry.AdaptToolHandler(telemetry.WithTracing("k8s_patch_status", k8sTool.handlePatchStatus))) - - s.AddTool(mcp.NewTool("k8s_apply_manifest", - mcp.WithDescription("Apply a YAML manifest to the Kubernetes cluster"), - mcp.WithString("manifest", mcp.Description("YAML manifest content"), mcp.Required()), - ), telemetry.AdaptToolHandler(telemetry.WithTracing("k8s_apply_manifest", k8sTool.handleApplyManifest))) - - s.AddTool(mcp.NewTool("k8s_delete_resource", - mcp.WithDescription("Delete a Kubernetes resource"), - mcp.WithString("resource_type", mcp.Description("Type of resource (pod, service, deployment, etc.)"), mcp.Required()), - mcp.WithString("resource_name", mcp.Description("Name of the resource"), mcp.Required()), - mcp.WithString("namespace", mcp.Description("Namespace of the resource (default: default)")), - ), telemetry.AdaptToolHandler(telemetry.WithTracing("k8s_delete_resource", k8sTool.handleDeleteResource))) - - s.AddTool(mcp.NewTool("k8s_check_service_connectivity", - mcp.WithDescription("Check connectivity to a service using a temporary curl pod"), - mcp.WithString("service_name", mcp.Description("Service name to test (e.g., my-service.my-namespace.svc.cluster.local:80)"), mcp.Required()), - mcp.WithString("namespace", mcp.Description("Namespace to run the check from (default: default)")), - ), telemetry.AdaptToolHandler(telemetry.WithTracing("k8s_check_service_connectivity", k8sTool.handleCheckServiceConnectivity))) - - s.AddTool(mcp.NewTool("k8s_execute_command", - mcp.WithDescription("Execute a command in a Kubernetes pod"), - mcp.WithString("pod_name", mcp.Description("Name of the pod to execute in"), mcp.Required()), - mcp.WithString("namespace", mcp.Description("Namespace of the pod (default: default)")), - mcp.WithString("container", mcp.Description("Container name (for multi-container pods)")), - mcp.WithString("command", mcp.Description("Command to execute"), mcp.Required()), - ), telemetry.AdaptToolHandler(telemetry.WithTracing("k8s_execute_command", k8sTool.handleExecCommand))) - - s.AddTool(mcp.NewTool("k8s_rollout", - mcp.WithDescription("Perform rollout operations on Kubernetes resources (history, pause, restart, resume, status, undo)"), - mcp.WithString("action", mcp.Description("The rollout action to perform"), mcp.Required()), - mcp.WithString("resource_type", mcp.Description("The type of resource to rollout (e.g., deployment)"), mcp.Required()), - mcp.WithString("resource_name", mcp.Description("The name of the resource to rollout"), mcp.Required()), - mcp.WithString("namespace", mcp.Description("The namespace of the resource")), - ), telemetry.AdaptToolHandler(telemetry.WithTracing("k8s_rollout", k8sTool.handleRollout))) - - s.AddTool(mcp.NewTool("k8s_label_resource", - mcp.WithDescription("Add or update labels on a Kubernetes resource"), - mcp.WithString("resource_type", mcp.Description("The type of resource"), mcp.Required()), - mcp.WithString("resource_name", mcp.Description("The name of the resource"), mcp.Required()), - mcp.WithString("labels", mcp.Description("Space-separated key=value pairs for labels"), mcp.Required()), - mcp.WithString("namespace", mcp.Description("The namespace of the resource")), - ), telemetry.AdaptToolHandler(telemetry.WithTracing("k8s_label_resource", k8sTool.handleLabelResource))) - - s.AddTool(mcp.NewTool("k8s_annotate_resource", - mcp.WithDescription("Add or update annotations on a Kubernetes resource"), - mcp.WithString("resource_type", mcp.Description("The type of resource"), mcp.Required()), - mcp.WithString("resource_name", mcp.Description("The name of the resource"), mcp.Required()), - mcp.WithString("annotations", mcp.Description("Space-separated key=value pairs for annotations"), mcp.Required()), - mcp.WithString("namespace", mcp.Description("The namespace of the resource")), - ), telemetry.AdaptToolHandler(telemetry.WithTracing("k8s_annotate_resource", k8sTool.handleAnnotateResource))) - - s.AddTool(mcp.NewTool("k8s_remove_annotation", - mcp.WithDescription("Remove an annotation from a Kubernetes resource"), - mcp.WithString("resource_type", mcp.Description("The type of resource"), mcp.Required()), - mcp.WithString("resource_name", mcp.Description("The name of the resource"), mcp.Required()), - mcp.WithString("annotation_key", mcp.Description("The key of the annotation to remove"), mcp.Required()), - mcp.WithString("namespace", mcp.Description("The namespace of the resource")), - ), telemetry.AdaptToolHandler(telemetry.WithTracing("k8s_remove_annotation", k8sTool.handleRemoveAnnotation))) - - s.AddTool(mcp.NewTool("k8s_remove_label", - mcp.WithDescription("Remove a label from a Kubernetes resource"), - mcp.WithString("resource_type", mcp.Description("The type of resource"), mcp.Required()), - mcp.WithString("resource_name", mcp.Description("The name of the resource"), mcp.Required()), - mcp.WithString("label_key", mcp.Description("The key of the label to remove"), mcp.Required()), - mcp.WithString("namespace", mcp.Description("The namespace of the resource")), - ), telemetry.AdaptToolHandler(telemetry.WithTracing("k8s_remove_label", k8sTool.handleRemoveLabel))) - - s.AddTool(mcp.NewTool("k8s_create_resource", - mcp.WithDescription("Create a Kubernetes resource from YAML content"), - mcp.WithString("yaml_content", mcp.Description("YAML content of the resource"), mcp.Required()), - ), telemetry.AdaptToolHandler(telemetry.WithTracing("k8s_create_resource", func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - yamlContent := mcp.ParseString(request, "yaml_content", "") - - if yamlContent == "" { - return mcp.NewToolResultError("yaml_content is required"), nil - } - - // Create temporary file - tmpFile, err := os.CreateTemp("", "k8s-resource-*.yaml") - if err != nil { - return mcp.NewToolResultError(fmt.Sprintf("Failed to create temp file: %v", err)), nil - } - defer os.Remove(tmpFile.Name()) - - if _, err := tmpFile.WriteString(yamlContent); err != nil { - return mcp.NewToolResultError(fmt.Sprintf("Failed to write to temp file: %v", err)), nil - } - tmpFile.Close() - - result, err := k8sTool.runKubectlCommand(ctx, request.Header, "create", "-f", tmpFile.Name()) - if err != nil { - return mcp.NewToolResultError(fmt.Sprintf("Create command failed: %v", err)), nil - } - - return result, nil - }))) - - s.AddTool(mcp.NewTool("k8s_create_resource_from_url", - mcp.WithDescription("Create a Kubernetes resource from a URL pointing to a YAML manifest"), - mcp.WithString("url", mcp.Description("The URL of the manifest"), mcp.Required()), - mcp.WithString("namespace", mcp.Description("The namespace to create the resource in")), - ), telemetry.AdaptToolHandler(telemetry.WithTracing("k8s_create_resource_from_url", k8sTool.handleCreateResourceFromURL))) + mcp.AddTool(s, "k8s", &mcp.Tool{ + Name: "k8s_scale", + Description: "Scale a Kubernetes deployment", + }, k8sTool.handleScaleDeployment) + + mcp.AddTool(s, "k8s", &mcp.Tool{ + Name: "k8s_patch_resource", + Description: "Patch a Kubernetes resource. Defaults to a strategic merge patch, which is only supported for built-in types; set patch_type to \"merge\" (or \"json\") to patch a CustomResource/CRD.", + }, k8sTool.handlePatchResource) + + mcp.AddTool(s, "k8s", &mcp.Tool{ + Name: "k8s_patch_status", + Description: "Patch the status of a Kubernetes resource", + }, k8sTool.handlePatchStatus) + + mcp.AddTool(s, "k8s", &mcp.Tool{ + Name: "k8s_apply_manifest", + Description: "Apply a YAML manifest to the Kubernetes cluster", + }, k8sTool.handleApplyManifest) + + mcp.AddTool(s, "k8s", &mcp.Tool{ + Name: "k8s_delete_resource", + Description: "Delete a Kubernetes resource", + }, k8sTool.handleDeleteResource) + + mcp.AddTool(s, "k8s", &mcp.Tool{ + Name: "k8s_check_service_connectivity", + Description: "Check connectivity to a service using a temporary curl pod", + }, k8sTool.handleCheckServiceConnectivity) + + mcp.AddTool(s, "k8s", &mcp.Tool{ + Name: "k8s_execute_command", + Description: "Execute a command in a Kubernetes pod", + }, k8sTool.handleExecCommand) + + mcp.AddTool(s, "k8s", &mcp.Tool{ + Name: "k8s_rollout", + Description: "Perform rollout operations on Kubernetes resources (history, pause, restart, resume, status, undo)", + }, k8sTool.handleRollout) + + mcp.AddTool(s, "k8s", &mcp.Tool{ + Name: "k8s_label_resource", + Description: "Add or update labels on a Kubernetes resource", + }, k8sTool.handleLabelResource) + + mcp.AddTool(s, "k8s", &mcp.Tool{ + Name: "k8s_annotate_resource", + Description: "Add or update annotations on a Kubernetes resource", + }, k8sTool.handleAnnotateResource) + + mcp.AddTool(s, "k8s", &mcp.Tool{ + Name: "k8s_remove_annotation", + Description: "Remove an annotation from a Kubernetes resource", + }, k8sTool.handleRemoveAnnotation) + + mcp.AddTool(s, "k8s", &mcp.Tool{ + Name: "k8s_remove_label", + Description: "Remove a label from a Kubernetes resource", + }, k8sTool.handleRemoveLabel) + + mcp.AddTool(s, "k8s", &mcp.Tool{ + Name: "k8s_create_resource", + Description: "Create a Kubernetes resource from YAML content", + }, k8sTool.handleCreateResource) + + mcp.AddTool(s, "k8s", &mcp.Tool{ + Name: "k8s_create_resource_from_url", + Description: "Create a Kubernetes resource from a URL pointing to a YAML manifest", + }, k8sTool.handleCreateResourceFromURL) } } diff --git a/pkg/k8s/k8s_test.go b/pkg/k8s/k8s_test.go index 6c008a20..c292cb36 100644 --- a/pkg/k8s/k8s_test.go +++ b/pkg/k8s/k8s_test.go @@ -6,8 +6,7 @@ import ( "testing" "github.com/kagent-dev/tools/internal/cmd" - "github.com/mark3labs/mcp-go/mcp" - "github.com/mark3labs/mcp-go/server" + mcp "github.com/kagent-dev/tools/internal/mcp" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/tmc/langchaingo/llms" @@ -15,11 +14,11 @@ import ( func TestRegisterTools(t *testing.T) { t.Run("read-write", func(t *testing.T) { - s := server.NewMCPServer("test", "v0.0.1") + s := mcp.NewServer(&mcp.Implementation{Name: "test", Version: "v0.0.1"}, nil) RegisterTools(s, nil, "", false) }) t.Run("read-only", func(t *testing.T) { - s := server.NewMCPServer("test", "v0.0.1") + s := mcp.NewServer(&mcp.Implementation{Name: "test", Version: "v0.0.1"}, nil) RegisterTools(s, nil, "/tmp/kubeconfig", true) }) } @@ -51,7 +50,7 @@ func getResultText(result *mcp.CallToolResult) string { if result == nil || len(result.Content) == 0 { return "" } - if textContent, ok := result.Content[0].(mcp.TextContent); ok { + if textContent, ok := result.Content[0].(*mcp.TextContent); ok { return textContent.Text } return "" @@ -65,11 +64,8 @@ func headerWithBearerToken(token string) http.Header { } // Helper function to create a CallToolRequest with Bearer token -func requestWithBearerToken(token string, args map[string]interface{}) mcp.CallToolRequest { - req := mcp.CallToolRequest{} - req.Header = headerWithBearerToken(token) - req.Params.Arguments = args - return req +func requestWithBearerToken(token string) *mcp.CallToolRequest { + return &mcp.CallToolRequest{Extra: &mcp.RequestExtra{Header: headerWithBearerToken(token)}} } func TestHandleGetAvailableAPIResources(t *testing.T) { @@ -85,8 +81,8 @@ services svc v1 k8sTool := newTestK8sTool() - req := mcp.CallToolRequest{} - result, err := k8sTool.handleGetAvailableAPIResources(ctx, req) + req := &mcp.CallToolRequest{} + result, _, err := k8sTool.handleGetAvailableAPIResources(ctx, req, noInput{}) assert.NoError(t, err) assert.NotNil(t, result) assert.False(t, result.IsError) @@ -103,8 +99,8 @@ services svc v1 k8sTool := newTestK8sTool() - req := mcp.CallToolRequest{} - result, err := k8sTool.handleGetAvailableAPIResources(ctx, req) + req := &mcp.CallToolRequest{} + result, _, err := k8sTool.handleGetAvailableAPIResources(ctx, req, noInput{}) assert.NoError(t, err) // MCP handlers should not return Go errors assert.NotNil(t, result) assert.True(t, result.IsError) @@ -122,13 +118,8 @@ func TestHandleScaleDeployment(t *testing.T) { k8sTool := newTestK8sTool() - req := mcp.CallToolRequest{} - req.Params.Arguments = map[string]interface{}{ - "name": "test-deployment", - "replicas": float64(5), // JSON numbers come as float64 - } - - result, err := k8sTool.handleScaleDeployment(ctx, req) + req := &mcp.CallToolRequest{} + result, _, err := k8sTool.handleScaleDeployment(ctx, req, scaleInput{Name: "test-deployment", Replicas: 5}) assert.NoError(t, err) assert.NotNil(t, result) assert.False(t, result.IsError) @@ -144,13 +135,8 @@ func TestHandleScaleDeployment(t *testing.T) { k8sTool := newTestK8sTool() - req := mcp.CallToolRequest{} - req.Params.Arguments = map[string]interface{}{ - // Missing name parameter (this is the required one) - "replicas": float64(3), - } - - result, err := k8sTool.handleScaleDeployment(ctx, req) + req := &mcp.CallToolRequest{} + result, _, err := k8sTool.handleScaleDeployment(ctx, req, scaleInput{Replicas: 3}) assert.NoError(t, err) assert.NotNil(t, result) assert.True(t, result.IsError) @@ -169,12 +155,8 @@ func TestHandleScaleDeployment(t *testing.T) { k8sTool := newTestK8sTool() - req := mcp.CallToolRequest{} - req.Params.Arguments = map[string]interface{}{ - "name": "test-deployment", - } - - result, err := k8sTool.handleScaleDeployment(ctx, req) + req := &mcp.CallToolRequest{} + result, _, err := k8sTool.handleScaleDeployment(ctx, req, scaleInput{Name: "test-deployment"}) assert.NoError(t, err) assert.NotNil(t, result) assert.False(t, result.IsError) @@ -201,8 +183,8 @@ func TestHandleGetEvents(t *testing.T) { k8sTool := newTestK8sTool() - req := mcp.CallToolRequest{} - result, err := k8sTool.handleGetEvents(ctx, req) + req := &mcp.CallToolRequest{} + result, _, err := k8sTool.handleGetEvents(ctx, req, eventsInput{}) assert.NoError(t, err) assert.NotNil(t, result) assert.False(t, result.IsError) @@ -219,12 +201,8 @@ func TestHandleGetEvents(t *testing.T) { k8sTool := newTestK8sTool() - req := mcp.CallToolRequest{} - req.Params.Arguments = map[string]interface{}{ - "namespace": "custom-namespace", - } - - result, err := k8sTool.handleGetEvents(ctx, req) + req := &mcp.CallToolRequest{} + result, _, err := k8sTool.handleGetEvents(ctx, req, eventsInput{Namespace: "custom-namespace"}) assert.NoError(t, err) assert.NotNil(t, result) assert.False(t, result.IsError) @@ -240,13 +218,8 @@ func TestHandlePatchResource(t *testing.T) { k8sTool := newTestK8sTool() - req := mcp.CallToolRequest{} - req.Params.Arguments = map[string]interface{}{ - "resource_type": "deployment", - // Missing resource_name and patch - } - - result, err := k8sTool.handlePatchResource(ctx, req) + req := &mcp.CallToolRequest{} + result, _, err := k8sTool.handlePatchResource(ctx, req, patchResourceInput{ResourceType: "deployment"}) assert.NoError(t, err) assert.NotNil(t, result) assert.True(t, result.IsError) @@ -264,14 +237,8 @@ func TestHandlePatchResource(t *testing.T) { k8sTool := newTestK8sTool() - req := mcp.CallToolRequest{} - req.Params.Arguments = map[string]interface{}{ - "resource_type": "deployment", - "resource_name": "test-deployment", - "patch": `{"spec":{"replicas":5}}`, - } - - result, err := k8sTool.handlePatchResource(ctx, req) + req := &mcp.CallToolRequest{} + result, _, err := k8sTool.handlePatchResource(ctx, req, patchResourceInput{ResourceType: "deployment", ResourceName: "test-deployment", Patch: `{"spec":{"replicas":5}}`}) assert.NoError(t, err) assert.NotNil(t, result) assert.False(t, result.IsError) @@ -288,16 +255,14 @@ func TestHandlePatchResource(t *testing.T) { k8sTool := newTestK8sTool() - req := mcp.CallToolRequest{} - req.Params.Arguments = map[string]interface{}{ - "resource_type": "installers.composition.krateo.io", - "resource_name": "installer", - "patch": `{"spec":{"features":{"composableportal":true}}}`, - "patch_type": "merge", - "namespace": "krateo-system", - } - - result, err := k8sTool.handlePatchResource(ctx, req) + req := &mcp.CallToolRequest{} + result, _, err := k8sTool.handlePatchResource(ctx, req, patchResourceInput{ + ResourceType: "installers.composition.krateo.io", + ResourceName: "installer", + Patch: `{"spec":{"features":{"composableportal":true}}}`, + PatchType: "merge", + Namespace: "krateo-system", + }) assert.NoError(t, err) assert.NotNil(t, result) assert.False(t, result.IsError) @@ -312,15 +277,13 @@ func TestHandlePatchResource(t *testing.T) { k8sTool := newTestK8sTool() - req := mcp.CallToolRequest{} - req.Params.Arguments = map[string]interface{}{ - "resource_type": "deployment", - "resource_name": "test-deployment", - "patch": `{"spec":{"replicas":5}}`, - "patch_type": "bogus", - } - - result, err := k8sTool.handlePatchResource(ctx, req) + req := &mcp.CallToolRequest{} + result, _, err := k8sTool.handlePatchResource(ctx, req, patchResourceInput{ + ResourceType: "deployment", + ResourceName: "test-deployment", + Patch: `{"spec":{"replicas":5}}`, + PatchType: "bogus", + }) assert.NoError(t, err) assert.NotNil(t, result) assert.True(t, result.IsError) @@ -339,13 +302,8 @@ func TestHandlePatchStatus(t *testing.T) { k8sTool := newTestK8sTool() - req := mcp.CallToolRequest{} - req.Params.Arguments = map[string]interface{}{ - "resource_type": "customresource", - // Missing resource_name and patch - } - - result, err := k8sTool.handlePatchStatus(ctx, req) + req := &mcp.CallToolRequest{} + result, _, err := k8sTool.handlePatchStatus(ctx, req, patchStatusInput{ResourceType: "customresource"}) assert.NoError(t, err) assert.NotNil(t, result) assert.True(t, result.IsError) @@ -363,14 +321,8 @@ func TestHandlePatchStatus(t *testing.T) { k8sTool := newTestK8sTool() - req := mcp.CallToolRequest{} - req.Params.Arguments = map[string]interface{}{ - "resource_type": "customresource", - "resource_name": "test-resource", - "patch": `{"status":{"phase":"Ready"}}`, - } - - result, err := k8sTool.handlePatchStatus(ctx, req) + req := &mcp.CallToolRequest{} + result, _, err := k8sTool.handlePatchStatus(ctx, req, patchStatusInput{ResourceType: "customresource", ResourceName: "test-resource", Patch: `{"status":{"phase":"Ready"}}`}) assert.NoError(t, err) assert.NotNil(t, result) assert.False(t, result.IsError) @@ -389,13 +341,8 @@ func TestHandleDeleteResource(t *testing.T) { k8sTool := newTestK8sTool() - req := mcp.CallToolRequest{} - req.Params.Arguments = map[string]interface{}{ - "resource_type": "pod", - // Missing resource_name - } - - result, err := k8sTool.handleDeleteResource(ctx, req) + req := &mcp.CallToolRequest{} + result, _, err := k8sTool.handleDeleteResource(ctx, req, deleteResourceInput{ResourceType: "pod"}) assert.NoError(t, err) assert.NotNil(t, result) assert.True(t, result.IsError) @@ -413,13 +360,8 @@ func TestHandleDeleteResource(t *testing.T) { k8sTool := newTestK8sTool() - req := mcp.CallToolRequest{} - req.Params.Arguments = map[string]interface{}{ - "resource_type": "deployment", - "resource_name": "test-deployment", - } - - result, err := k8sTool.handleDeleteResource(ctx, req) + req := &mcp.CallToolRequest{} + result, _, err := k8sTool.handleDeleteResource(ctx, req, deleteResourceInput{ResourceType: "deployment", ResourceName: "test-deployment"}) assert.NoError(t, err) assert.NotNil(t, result) assert.False(t, result.IsError) @@ -438,10 +380,8 @@ func TestHandleCheckServiceConnectivity(t *testing.T) { k8sTool := newTestK8sTool() - req := mcp.CallToolRequest{} - req.Params.Arguments = map[string]interface{}{} - - result, err := k8sTool.handleCheckServiceConnectivity(ctx, req) + req := &mcp.CallToolRequest{} + result, _, err := k8sTool.handleCheckServiceConnectivity(ctx, req, serviceConnectivityInput{}) assert.NoError(t, err) assert.NotNil(t, result) assert.True(t, result.IsError) @@ -464,12 +404,8 @@ func TestHandleCheckServiceConnectivity(t *testing.T) { k8sTool := newTestK8sTool() - req := mcp.CallToolRequest{} - req.Params.Arguments = map[string]interface{}{ - "service_name": "test-service.default.svc.cluster.local:80", - } - - result, err := k8sTool.handleCheckServiceConnectivity(ctx, req) + req := &mcp.CallToolRequest{} + result, _, err := k8sTool.handleCheckServiceConnectivity(ctx, req, serviceConnectivityInput{ServiceName: "test-service.default.svc.cluster.local:80"}) assert.NoError(t, err) assert.NotNil(t, result) // Should attempt connectivity check (may succeed or fail but validates params) @@ -485,13 +421,8 @@ func TestHandleKubectlDescribeTool(t *testing.T) { k8sTool := newTestK8sTool() - req := mcp.CallToolRequest{} - req.Params.Arguments = map[string]interface{}{ - "resource_type": "deployment", - // Missing resource_name - } - - result, err := k8sTool.handleKubectlDescribeTool(ctx, req) + req := &mcp.CallToolRequest{} + result, _, err := k8sTool.handleKubectlDescribeTool(ctx, req, describeInput{ResourceType: "deployment"}) assert.NoError(t, err) assert.NotNil(t, result) assert.True(t, result.IsError) @@ -511,14 +442,8 @@ Labels: app=test` k8sTool := newTestK8sTool() - req := mcp.CallToolRequest{} - req.Params.Arguments = map[string]interface{}{ - "resource_type": "deployment", - "resource_name": "test-deployment", - "namespace": "default", - } - - result, err := k8sTool.handleKubectlDescribeTool(ctx, req) + req := &mcp.CallToolRequest{} + result, _, err := k8sTool.handleKubectlDescribeTool(ctx, req, describeInput{ResourceType: "deployment", ResourceName: "test-deployment", Namespace: "default"}) assert.NoError(t, err) assert.NotNil(t, result) assert.False(t, result.IsError) @@ -536,8 +461,8 @@ func TestHandleKubectlGetEnhanced(t *testing.T) { ctx := cmd.WithShellExecutor(context.Background(), mock) k8sTool := newTestK8sTool() - req := mcp.CallToolRequest{} - result, err := k8sTool.handleKubectlGetEnhanced(ctx, req) + req := &mcp.CallToolRequest{} + result, _, err := k8sTool.handleKubectlGetEnhanced(ctx, req, getResourcesInput{}) assert.NoError(t, err) assert.NotNil(t, result) assert.True(t, result.IsError) @@ -554,9 +479,8 @@ func TestHandleKubectlGetEnhanced(t *testing.T) { ctx := cmd.WithShellExecutor(ctx, mock) k8sTool := newTestK8sTool() - req := mcp.CallToolRequest{} - req.Params.Arguments = map[string]interface{}{"resource_type": "pods"} - result, err := k8sTool.handleKubectlGetEnhanced(ctx, req) + req := &mcp.CallToolRequest{} + result, _, err := k8sTool.handleKubectlGetEnhanced(ctx, req, getResourcesInput{ResourceType: "pods"}) assert.NoError(t, err) assert.NotNil(t, result) assert.False(t, result.IsError) @@ -571,8 +495,8 @@ func TestHandleKubectlLogsEnhanced(t *testing.T) { ctx := cmd.WithShellExecutor(context.Background(), mock) k8sTool := newTestK8sTool() - req := mcp.CallToolRequest{} - result, err := k8sTool.handleKubectlLogsEnhanced(ctx, req) + req := &mcp.CallToolRequest{} + result, _, err := k8sTool.handleKubectlLogsEnhanced(ctx, req, logsInput{}) assert.NoError(t, err) assert.NotNil(t, result) assert.True(t, result.IsError) @@ -590,9 +514,8 @@ log line 2` ctx := cmd.WithShellExecutor(ctx, mock) k8sTool := newTestK8sTool() - req := mcp.CallToolRequest{} - req.Params.Arguments = map[string]interface{}{"pod_name": "test-pod"} - result, err := k8sTool.handleKubectlLogsEnhanced(ctx, req) + req := &mcp.CallToolRequest{} + result, _, err := k8sTool.handleKubectlLogsEnhanced(ctx, req, logsInput{PodName: "test-pod"}) assert.NoError(t, err) assert.NotNil(t, result) assert.False(t, result.IsError) @@ -619,12 +542,8 @@ spec: k8sTool := newTestK8sTool() - req := mcp.CallToolRequest{} - req.Params.Arguments = map[string]interface{}{ - "manifest": manifest, - } - - result, err := k8sTool.handleApplyManifest(ctx, req) + req := &mcp.CallToolRequest{} + result, _, err := k8sTool.handleApplyManifest(ctx, req, applyManifestInput{Manifest: manifest}) assert.NoError(t, err) assert.NotNil(t, result) assert.False(t, result.IsError) @@ -650,12 +569,8 @@ spec: k8sTool := newTestK8sTool() - req := mcp.CallToolRequest{} - req.Params.Arguments = map[string]interface{}{ - // Missing manifest parameter - } - - result, err := k8sTool.handleApplyManifest(ctx, req) + req := &mcp.CallToolRequest{} + result, _, err := k8sTool.handleApplyManifest(ctx, req, applyManifestInput{}) assert.NoError(t, err) assert.NotNil(t, result) assert.True(t, result.IsError) @@ -681,14 +596,8 @@ drwxr-xr-x 1 root root 4096 Jan 1 12:00 ..` k8sTool := newTestK8sTool() - req := mcp.CallToolRequest{} - req.Params.Arguments = map[string]interface{}{ - "pod_name": "mypod", - "namespace": "default", - "command": "ls -la", - } - - result, err := k8sTool.handleExecCommand(ctx, req) + req := &mcp.CallToolRequest{} + result, _, err := k8sTool.handleExecCommand(ctx, req, execCommandInput{PodName: "mypod", Namespace: "default", Command: "ls -la"}) assert.NoError(t, err) assert.NotNil(t, result) assert.False(t, result.IsError) @@ -710,13 +619,8 @@ drwxr-xr-x 1 root root 4096 Jan 1 12:00 ..` k8sTool := newTestK8sTool() - req := mcp.CallToolRequest{} - req.Params.Arguments = map[string]interface{}{ - "pod_name": "mypod", - // Missing command parameter - } - - result, err := k8sTool.handleExecCommand(ctx, req) + req := &mcp.CallToolRequest{} + result, _, err := k8sTool.handleExecCommand(ctx, req, execCommandInput{PodName: "mypod"}) assert.NoError(t, err) assert.NotNil(t, result) assert.True(t, result.IsError) @@ -739,15 +643,13 @@ func TestHandleRollout(t *testing.T) { k8sTool := newTestK8sTool() - req := mcp.CallToolRequest{} - req.Params.Arguments = map[string]interface{}{ - "action": "restart", - "resource_type": "deployment", - "resource_name": "myapp", - "namespace": "default", - } - - result, err := k8sTool.handleRollout(ctx, req) + req := &mcp.CallToolRequest{} + result, _, err := k8sTool.handleRollout(ctx, req, rolloutInput{ + Action: "restart", + ResourceType: "deployment", + ResourceName: "myapp", + Namespace: "default", + }) assert.NoError(t, err) assert.NotNil(t, result) assert.False(t, result.IsError) @@ -769,13 +671,8 @@ func TestHandleRollout(t *testing.T) { k8sTool := newTestK8sTool() - req := mcp.CallToolRequest{} - req.Params.Arguments = map[string]interface{}{ - "action": "restart", - // Missing resource_type and resource_name - } - - result, err := k8sTool.handleRollout(ctx, req) + req := &mcp.CallToolRequest{} + result, _, err := k8sTool.handleRollout(ctx, req, rolloutInput{Action: "restart"}) assert.NoError(t, err) assert.NotNil(t, result) assert.True(t, result.IsError) @@ -831,13 +728,8 @@ spec: k8sTool := newTestK8sToolWithLLM(mockLLM) - req := mcp.CallToolRequest{} - req.Params.Arguments = map[string]interface{}{ - "resource_type": "istio_auth_policy", - "resource_description": "A peer authentication policy for strict mTLS", - } - - result, err := k8sTool.handleGenerateResource(ctx, req) + req := &mcp.CallToolRequest{} + result, _, err := k8sTool.handleGenerateResource(ctx, req, generateResourceInput{ResourceType: "istio_auth_policy", ResourceDescription: "A peer authentication policy for strict mTLS"}) assert.NoError(t, err) assert.NotNil(t, result) assert.False(t, result.IsError) @@ -853,13 +745,8 @@ spec: t.Run("missing parameters", func(t *testing.T) { k8sTool := newTestK8sTool() - req := mcp.CallToolRequest{} - req.Params.Arguments = map[string]interface{}{ - "resource_type": "istio_auth_policy", - // Missing resource_description - } - - result, err := k8sTool.handleGenerateResource(ctx, req) + req := &mcp.CallToolRequest{} + result, _, err := k8sTool.handleGenerateResource(ctx, req, generateResourceInput{ResourceType: "istio_auth_policy"}) assert.NoError(t, err) assert.NotNil(t, result) assert.True(t, result.IsError) @@ -869,13 +756,8 @@ spec: t.Run("no LLM model", func(t *testing.T) { k8sTool := newTestK8sTool() // No LLM model - req := mcp.CallToolRequest{} - req.Params.Arguments = map[string]interface{}{ - "resource_type": "istio_auth_policy", - "resource_description": "A peer authentication policy for strict mTLS", - } - - result, err := k8sTool.handleGenerateResource(ctx, req) + req := &mcp.CallToolRequest{} + result, _, err := k8sTool.handleGenerateResource(ctx, req, generateResourceInput{ResourceType: "istio_auth_policy", ResourceDescription: "A peer authentication policy for strict mTLS"}) assert.NoError(t, err) assert.NotNil(t, result) assert.True(t, result.IsError) @@ -891,13 +773,8 @@ spec: k8sTool := newTestK8sToolWithLLM(mockLLM) - req := mcp.CallToolRequest{} - req.Params.Arguments = map[string]interface{}{ - "resource_type": "invalid_resource_type", - "resource_description": "A test resource", - } - - result, err := k8sTool.handleGenerateResource(ctx, req) + req := &mcp.CallToolRequest{} + result, _, err := k8sTool.handleGenerateResource(ctx, req, generateResourceInput{ResourceType: "invalid_resource_type", ResourceDescription: "A test resource"}) assert.NoError(t, err) assert.NotNil(t, result) assert.True(t, result.IsError) @@ -920,15 +797,13 @@ func TestHandleAnnotateResource(t *testing.T) { k8sTool := newTestK8sTool() - req := mcp.CallToolRequest{} - req.Params.Arguments = map[string]interface{}{ - "resource_type": "deployment", - "resource_name": "test-deployment", - "annotations": "key1=value1 key2=value2", - "namespace": "default", - } - - result, err := k8sTool.handleAnnotateResource(ctx, req) + req := &mcp.CallToolRequest{} + result, _, err := k8sTool.handleAnnotateResource(ctx, req, annotateInput{ + ResourceType: "deployment", + ResourceName: "test-deployment", + Annotations: "key1=value1 key2=value2", + Namespace: "default", + }) assert.NoError(t, err) assert.NotNil(t, result) assert.False(t, result.IsError) @@ -943,13 +818,8 @@ func TestHandleAnnotateResource(t *testing.T) { k8sTool := newTestK8sTool() - req := mcp.CallToolRequest{} - req.Params.Arguments = map[string]interface{}{ - "resource_type": "deployment", - // Missing resource_name and annotations - } - - result, err := k8sTool.handleAnnotateResource(ctx, req) + req := &mcp.CallToolRequest{} + result, _, err := k8sTool.handleAnnotateResource(ctx, req, annotateInput{ResourceType: "deployment"}) assert.NoError(t, err) assert.NotNil(t, result) assert.True(t, result.IsError) @@ -972,15 +842,13 @@ func TestHandleLabelResource(t *testing.T) { k8sTool := newTestK8sTool() - req := mcp.CallToolRequest{} - req.Params.Arguments = map[string]interface{}{ - "resource_type": "deployment", - "resource_name": "test-deployment", - "labels": "env=prod version=1.0", - "namespace": "default", - } - - result, err := k8sTool.handleLabelResource(ctx, req) + req := &mcp.CallToolRequest{} + result, _, err := k8sTool.handleLabelResource(ctx, req, labelInput{ + ResourceType: "deployment", + ResourceName: "test-deployment", + Labels: "env=prod version=1.0", + Namespace: "default", + }) assert.NoError(t, err) assert.NotNil(t, result) assert.False(t, result.IsError) @@ -995,13 +863,8 @@ func TestHandleLabelResource(t *testing.T) { k8sTool := newTestK8sTool() - req := mcp.CallToolRequest{} - req.Params.Arguments = map[string]interface{}{ - "resource_type": "deployment", - // Missing resource_name and labels - } - - result, err := k8sTool.handleLabelResource(ctx, req) + req := &mcp.CallToolRequest{} + result, _, err := k8sTool.handleLabelResource(ctx, req, labelInput{ResourceType: "deployment"}) assert.NoError(t, err) assert.NotNil(t, result) assert.True(t, result.IsError) @@ -1024,15 +887,13 @@ func TestHandleRemoveAnnotation(t *testing.T) { k8sTool := newTestK8sTool() - req := mcp.CallToolRequest{} - req.Params.Arguments = map[string]interface{}{ - "resource_type": "deployment", - "resource_name": "test-deployment", - "annotation_key": "key1", - "namespace": "default", - } - - result, err := k8sTool.handleRemoveAnnotation(ctx, req) + req := &mcp.CallToolRequest{} + result, _, err := k8sTool.handleRemoveAnnotation(ctx, req, removeAnnotationInput{ + ResourceType: "deployment", + ResourceName: "test-deployment", + AnnotationKey: "key1", + Namespace: "default", + }) assert.NoError(t, err) assert.NotNil(t, result) assert.False(t, result.IsError) @@ -1047,13 +908,8 @@ func TestHandleRemoveAnnotation(t *testing.T) { k8sTool := newTestK8sTool() - req := mcp.CallToolRequest{} - req.Params.Arguments = map[string]interface{}{ - "resource_type": "deployment", - // Missing resource_name and annotation_key - } - - result, err := k8sTool.handleRemoveAnnotation(ctx, req) + req := &mcp.CallToolRequest{} + result, _, err := k8sTool.handleRemoveAnnotation(ctx, req, removeAnnotationInput{ResourceType: "deployment"}) assert.NoError(t, err) assert.NotNil(t, result) assert.True(t, result.IsError) @@ -1076,15 +932,13 @@ func TestHandleRemoveLabel(t *testing.T) { k8sTool := newTestK8sTool() - req := mcp.CallToolRequest{} - req.Params.Arguments = map[string]interface{}{ - "resource_type": "deployment", - "resource_name": "test-deployment", - "label_key": "env", - "namespace": "default", - } - - result, err := k8sTool.handleRemoveLabel(ctx, req) + req := &mcp.CallToolRequest{} + result, _, err := k8sTool.handleRemoveLabel(ctx, req, removeLabelInput{ + ResourceType: "deployment", + ResourceName: "test-deployment", + LabelKey: "env", + Namespace: "default", + }) assert.NoError(t, err) assert.NotNil(t, result) assert.False(t, result.IsError) @@ -1099,13 +953,8 @@ func TestHandleRemoveLabel(t *testing.T) { k8sTool := newTestK8sTool() - req := mcp.CallToolRequest{} - req.Params.Arguments = map[string]interface{}{ - "resource_type": "deployment", - // Missing resource_name and label_key - } - - result, err := k8sTool.handleRemoveLabel(ctx, req) + req := &mcp.CallToolRequest{} + result, _, err := k8sTool.handleRemoveLabel(ctx, req, removeLabelInput{ResourceType: "deployment"}) assert.NoError(t, err) assert.NotNil(t, result) assert.True(t, result.IsError) @@ -1128,13 +977,8 @@ func TestHandleCreateResourceFromURL(t *testing.T) { k8sTool := newTestK8sTool() - req := mcp.CallToolRequest{} - req.Params.Arguments = map[string]interface{}{ - "url": "https://example.com/manifest.yaml", - "namespace": "default", - } - - result, err := k8sTool.handleCreateResourceFromURL(ctx, req) + req := &mcp.CallToolRequest{} + result, _, err := k8sTool.handleCreateResourceFromURL(ctx, req, createFromURLInput{URL: "https://example.com/manifest.yaml", Namespace: "default"}) assert.NoError(t, err) assert.NotNil(t, result) assert.False(t, result.IsError) @@ -1149,12 +993,8 @@ func TestHandleCreateResourceFromURL(t *testing.T) { k8sTool := newTestK8sTool() - req := mcp.CallToolRequest{} - req.Params.Arguments = map[string]interface{}{ - // Missing url parameter - } - - result, err := k8sTool.handleCreateResourceFromURL(ctx, req) + req := &mcp.CallToolRequest{} + result, _, err := k8sTool.handleCreateResourceFromURL(ctx, req, createFromURLInput{}) assert.NoError(t, err) assert.NotNil(t, result) assert.True(t, result.IsError) @@ -1191,8 +1031,8 @@ users: k8sTool := newTestK8sTool() - req := mcp.CallToolRequest{} - result, err := k8sTool.handleGetClusterConfiguration(ctx, req) + req := &mcp.CallToolRequest{} + result, _, err := k8sTool.handleGetClusterConfiguration(ctx, req, noInput{}) assert.NoError(t, err) assert.NotNil(t, result) assert.False(t, result.IsError) @@ -1214,8 +1054,8 @@ func TestBearerTokenPassthrough(t *testing.T) { ctx := cmd.WithShellExecutor(ctx, mock) k8sTool := newTestK8sToolWithPassthrough(true) - req := requestWithBearerToken("test-token-123", map[string]interface{}{"resource_type": "pods"}) - result, err := k8sTool.handleKubectlGetEnhanced(ctx, req) + req := requestWithBearerToken("test-token-123") + result, _, err := k8sTool.handleKubectlGetEnhanced(ctx, req, getResourcesInput{ResourceType: "pods"}) assert.NoError(t, err) assert.NotNil(t, result) assert.False(t, result.IsError) @@ -1235,12 +1075,8 @@ func TestBearerTokenPassthrough(t *testing.T) { ctx := cmd.WithShellExecutor(ctx, mock) k8sTool := newTestK8sToolWithPassthrough(true) - req := requestWithBearerToken("my-auth-token", map[string]interface{}{ - "name": "test-deployment", - "replicas": float64(5), - }) - - result, err := k8sTool.handleScaleDeployment(ctx, req) + req := requestWithBearerToken("my-auth-token") + result, _, err := k8sTool.handleScaleDeployment(ctx, req, scaleInput{Name: "test-deployment", Replicas: 5}) assert.NoError(t, err) assert.NotNil(t, result) assert.False(t, result.IsError) @@ -1260,8 +1096,8 @@ log line 2` ctx := cmd.WithShellExecutor(ctx, mock) k8sTool := newTestK8sToolWithPassthrough(true) - req := requestWithBearerToken("logs-token", map[string]interface{}{"pod_name": "test-pod"}) - result, err := k8sTool.handleKubectlLogsEnhanced(ctx, req) + req := requestWithBearerToken("logs-token") + result, _, err := k8sTool.handleKubectlLogsEnhanced(ctx, req, logsInput{PodName: "test-pod"}) assert.NoError(t, err) assert.NotNil(t, result) assert.False(t, result.IsError) @@ -1279,12 +1115,8 @@ log line 2` ctx := cmd.WithShellExecutor(ctx, mock) k8sTool := newTestK8sToolWithPassthrough(true) - req := requestWithBearerToken("delete-token", map[string]interface{}{ - "resource_type": "deployment", - "resource_name": "test-deployment", - }) - - result, err := k8sTool.handleDeleteResource(ctx, req) + req := requestWithBearerToken("delete-token") + result, _, err := k8sTool.handleDeleteResource(ctx, req, deleteResourceInput{ResourceType: "deployment", ResourceName: "test-deployment"}) assert.NoError(t, err) assert.NotNil(t, result) assert.False(t, result.IsError) @@ -1302,13 +1134,8 @@ log line 2` ctx := cmd.WithShellExecutor(ctx, mock) k8sTool := newTestK8sToolWithPassthrough(true) - req := requestWithBearerToken("patch-token", map[string]interface{}{ - "resource_type": "deployment", - "resource_name": "test-deployment", - "patch": `{"spec":{"replicas":5}}`, - }) - - result, err := k8sTool.handlePatchResource(ctx, req) + req := requestWithBearerToken("patch-token") + result, _, err := k8sTool.handlePatchResource(ctx, req, patchResourceInput{ResourceType: "deployment", ResourceName: "test-deployment", Patch: `{"spec":{"replicas":5}}`}) assert.NoError(t, err) assert.NotNil(t, result) assert.False(t, result.IsError) @@ -1326,13 +1153,8 @@ log line 2` ctx := cmd.WithShellExecutor(ctx, mock) k8sTool := newTestK8sToolWithPassthrough(true) - req := requestWithBearerToken("describe-token", map[string]interface{}{ - "resource_type": "deployment", - "resource_name": "test-deployment", - "namespace": "default", - }) - - result, err := k8sTool.handleKubectlDescribeTool(ctx, req) + req := requestWithBearerToken("describe-token") + result, _, err := k8sTool.handleKubectlDescribeTool(ctx, req, describeInput{ResourceType: "deployment", ResourceName: "test-deployment", Namespace: "default"}) assert.NoError(t, err) assert.NotNil(t, result) assert.False(t, result.IsError) @@ -1350,14 +1172,8 @@ log line 2` ctx := cmd.WithShellExecutor(ctx, mock) k8sTool := newTestK8sToolWithPassthrough(true) - req := requestWithBearerToken("rollout-token", map[string]interface{}{ - "action": "restart", - "resource_type": "deployment", - "resource_name": "myapp", - "namespace": "default", - }) - - result, err := k8sTool.handleRollout(ctx, req) + req := requestWithBearerToken("rollout-token") + result, _, err := k8sTool.handleRollout(ctx, req, rolloutInput{Action: "restart", ResourceType: "deployment", ResourceName: "myapp", Namespace: "default"}) assert.NoError(t, err) assert.NotNil(t, result) assert.False(t, result.IsError) @@ -1375,8 +1191,8 @@ log line 2` ctx := cmd.WithShellExecutor(ctx, mock) k8sTool := newTestK8sToolWithPassthrough(true) - req := requestWithBearerToken("events-token", nil) - result, err := k8sTool.handleGetEvents(ctx, req) + req := requestWithBearerToken("events-token") + result, _, err := k8sTool.handleGetEvents(ctx, req, eventsInput{}) assert.NoError(t, err) assert.NotNil(t, result) assert.False(t, result.IsError) @@ -1394,13 +1210,8 @@ log line 2` ctx := cmd.WithShellExecutor(ctx, mock) k8sTool := newTestK8sToolWithPassthrough(true) - req := requestWithBearerToken("exec-token", map[string]interface{}{ - "pod_name": "mypod", - "namespace": "default", - "command": "ls -la", - }) - - result, err := k8sTool.handleExecCommand(ctx, req) + req := requestWithBearerToken("exec-token") + result, _, err := k8sTool.handleExecCommand(ctx, req, execCommandInput{PodName: "mypod", Namespace: "default", Command: "ls -la"}) assert.NoError(t, err) assert.NotNil(t, result) assert.False(t, result.IsError) @@ -1418,13 +1229,8 @@ log line 2` ctx := cmd.WithShellExecutor(ctx, mock) k8sTool := newTestK8sToolWithPassthrough(true) - req := requestWithBearerToken("annotate-token", map[string]interface{}{ - "resource_type": "deployment", - "resource_name": "test-deployment", - "annotations": "key1=value1", - }) - - result, err := k8sTool.handleAnnotateResource(ctx, req) + req := requestWithBearerToken("annotate-token") + result, _, err := k8sTool.handleAnnotateResource(ctx, req, annotateInput{ResourceType: "deployment", ResourceName: "test-deployment", Annotations: "key1=value1"}) assert.NoError(t, err) assert.NotNil(t, result) assert.False(t, result.IsError) @@ -1442,13 +1248,8 @@ log line 2` ctx := cmd.WithShellExecutor(ctx, mock) k8sTool := newTestK8sToolWithPassthrough(true) - req := requestWithBearerToken("label-token", map[string]interface{}{ - "resource_type": "deployment", - "resource_name": "test-deployment", - "labels": "env=prod", - }) - - result, err := k8sTool.handleLabelResource(ctx, req) + req := requestWithBearerToken("label-token") + result, _, err := k8sTool.handleLabelResource(ctx, req, labelInput{ResourceType: "deployment", ResourceName: "test-deployment", Labels: "env=prod"}) assert.NoError(t, err) assert.NotNil(t, result) assert.False(t, result.IsError) @@ -1466,8 +1267,8 @@ log line 2` ctx := cmd.WithShellExecutor(ctx, mock) k8sTool := newTestK8sToolWithPassthrough(true) - req := requestWithBearerToken("api-token", nil) - result, err := k8sTool.handleGetAvailableAPIResources(ctx, req) + req := requestWithBearerToken("api-token") + result, _, err := k8sTool.handleGetAvailableAPIResources(ctx, req, noInput{}) assert.NoError(t, err) assert.NotNil(t, result) assert.False(t, result.IsError) @@ -1485,8 +1286,8 @@ log line 2` ctx := cmd.WithShellExecutor(ctx, mock) k8sTool := newTestK8sToolWithPassthrough(true) - req := requestWithBearerToken("config-token", nil) - result, err := k8sTool.handleGetClusterConfiguration(ctx, req) + req := requestWithBearerToken("config-token") + result, _, err := k8sTool.handleGetClusterConfiguration(ctx, req, noInput{}) assert.NoError(t, err) assert.NotNil(t, result) assert.False(t, result.IsError) @@ -1504,13 +1305,8 @@ log line 2` ctx := cmd.WithShellExecutor(ctx, mock) k8sTool := newTestK8sToolWithPassthrough(true) - req := requestWithBearerToken("remove-anno-token", map[string]interface{}{ - "resource_type": "deployment", - "resource_name": "test-deployment", - "annotation_key": "key1", - }) - - result, err := k8sTool.handleRemoveAnnotation(ctx, req) + req := requestWithBearerToken("remove-anno-token") + result, _, err := k8sTool.handleRemoveAnnotation(ctx, req, removeAnnotationInput{ResourceType: "deployment", ResourceName: "test-deployment", AnnotationKey: "key1"}) assert.NoError(t, err) assert.NotNil(t, result) assert.False(t, result.IsError) @@ -1528,13 +1324,8 @@ log line 2` ctx := cmd.WithShellExecutor(ctx, mock) k8sTool := newTestK8sToolWithPassthrough(true) - req := requestWithBearerToken("remove-label-token", map[string]interface{}{ - "resource_type": "deployment", - "resource_name": "test-deployment", - "label_key": "env", - }) - - result, err := k8sTool.handleRemoveLabel(ctx, req) + req := requestWithBearerToken("remove-label-token") + result, _, err := k8sTool.handleRemoveLabel(ctx, req, removeLabelInput{ResourceType: "deployment", ResourceName: "test-deployment", LabelKey: "env"}) assert.NoError(t, err) assert.NotNil(t, result) assert.False(t, result.IsError) @@ -1552,12 +1343,8 @@ log line 2` ctx := cmd.WithShellExecutor(ctx, mock) k8sTool := newTestK8sToolWithPassthrough(true) - req := requestWithBearerToken("url-token", map[string]interface{}{ - "url": "https://example.com/manifest.yaml", - "namespace": "default", - }) - - result, err := k8sTool.handleCreateResourceFromURL(ctx, req) + req := requestWithBearerToken("url-token") + result, _, err := k8sTool.handleCreateResourceFromURL(ctx, req, createFromURLInput{URL: "https://example.com/manifest.yaml", Namespace: "default"}) assert.NoError(t, err) assert.NotNil(t, result) assert.False(t, result.IsError) @@ -1580,11 +1367,8 @@ metadata: ctx := cmd.WithShellExecutor(ctx, mock) k8sTool := newTestK8sToolWithPassthrough(true) - req := requestWithBearerToken("apply-token", map[string]interface{}{ - "manifest": manifest, - }) - - result, err := k8sTool.handleApplyManifest(ctx, req) + req := requestWithBearerToken("apply-token") + result, _, err := k8sTool.handleApplyManifest(ctx, req, applyManifestInput{Manifest: manifest}) assert.NoError(t, err) assert.NotNil(t, result) assert.False(t, result.IsError) @@ -1597,9 +1381,8 @@ metadata: t.Run("returns error when passthrough true and authorization header missing", func(t *testing.T) { k8sTool := newTestK8sToolWithPassthrough(true) - req := mcp.CallToolRequest{} - req.Params.Arguments = map[string]interface{}{"resource_type": "pods"} - result, err := k8sTool.handleKubectlGetEnhanced(ctx, req) + req := &mcp.CallToolRequest{} + result, _, err := k8sTool.handleKubectlGetEnhanced(ctx, req, getResourcesInput{ResourceType: "pods"}) assert.NoError(t, err) assert.NotNil(t, result) assert.True(t, result.IsError) @@ -1614,10 +1397,8 @@ metadata: ctx := cmd.WithShellExecutor(ctx, mock) k8sTool := newTestK8sToolWithPassthrough(false) - req := mcp.CallToolRequest{} - req.Params.Arguments = map[string]interface{}{"resource_type": "pods"} - // No Header set on request - result, err := k8sTool.handleKubectlGetEnhanced(ctx, req) + req := &mcp.CallToolRequest{} + result, _, err := k8sTool.handleKubectlGetEnhanced(ctx, req, getResourcesInput{ResourceType: "pods"}) assert.NoError(t, err) assert.NotNil(t, result) assert.False(t, result.IsError) @@ -1636,11 +1417,9 @@ metadata: ctx := cmd.WithShellExecutor(ctx, mock) k8sTool := newTestK8sToolWithPassthrough(false) - req := mcp.CallToolRequest{} - req.Header = http.Header{} - req.Header.Set("Authorization", "Basic dXNlcjpwYXNz") - req.Params.Arguments = map[string]interface{}{"resource_type": "pods"} - result, err := k8sTool.handleKubectlGetEnhanced(ctx, req) + req := &mcp.CallToolRequest{Extra: &mcp.RequestExtra{Header: http.Header{}}} + req.Extra.Header.Set("Authorization", "Basic dXNlcjpwYXNz") + result, _, err := k8sTool.handleKubectlGetEnhanced(ctx, req, getResourcesInput{ResourceType: "pods"}) assert.NoError(t, err) assert.NotNil(t, result) assert.False(t, result.IsError) diff --git a/pkg/kubescape/kubescape.go b/pkg/kubescape/kubescape.go index e2227fa4..cf349dcc 100644 --- a/pkg/kubescape/kubescape.go +++ b/pkg/kubescape/kubescape.go @@ -8,12 +8,10 @@ import ( "time" "github.com/kagent-dev/tools/internal/errors" - "github.com/kagent-dev/tools/internal/telemetry" + mcp "github.com/kagent-dev/tools/internal/mcp" helpersv1 "github.com/kubescape/k8s-interface/instanceidhandler/v1/helpers" "github.com/kubescape/storage/pkg/apis/softwarecomposition/v1beta1" spdxv1beta1 "github.com/kubescape/storage/pkg/generated/clientset/versioned/typed/softwarecomposition/v1beta1" - "github.com/mark3labs/mcp-go/mcp" - "github.com/mark3labs/mcp-go/server" corev1 "k8s.io/api/core/v1" apiextensionsclientset "k8s.io/apiextensions-apiserver/pkg/client/clientset/clientset" k8serrors "k8s.io/apimachinery/pkg/api/errors" @@ -38,6 +36,11 @@ const ( storagePodLabel = "app.kubernetes.io/name=storage" ) +// kubescapeErrResult adapts ToolError to an MCP error result. +func kubescapeErrResult(toolErr *errors.ToolError) *mcp.CallToolResult { + return toolErr.ToMCPResult() +} + // KubescapeTool holds the clients for Kubescape and Kubernetes APIs type KubescapeTool struct { spdxClient spdxv1beta1.SpdxV1beta1Interface @@ -114,14 +117,64 @@ type CheckStatus struct { Details interface{} `json:"details,omitempty"` } +type checkHealthInput struct { + Namespace string `json:"namespace" jsonschema:"Namespace to check (default: kubescape)"` +} + +type listVulnerabilityManifestsInput struct { + Namespace string `json:"namespace" jsonschema:"Filter by namespace (optional, defaults to all namespaces)"` + Level string `json:"level" jsonschema:"Type of manifests to list: 'image', 'workload', or 'both' (default: both)"` +} + +type listVulnerabilitiesInManifestInput struct { + Namespace string `json:"namespace" jsonschema:"Namespace of the manifest (default: kubescape)"` + ManifestName string `json:"manifest_name" jsonschema:"Name of the vulnerability manifest"` +} + +type getVulnerabilityDetailsInput struct { + Namespace string `json:"namespace" jsonschema:"Namespace of the manifest (default: kubescape)"` + ManifestName string `json:"manifest_name" jsonschema:"Name of the vulnerability manifest"` + CveID string `json:"cve_id" jsonschema:"CVE identifier (e.g., CVE-2023-12345)"` +} + +type listConfigurationScansInput struct { + Namespace string `json:"namespace" jsonschema:"Filter by namespace (optional, defaults to all namespaces)"` +} + +type getConfigurationScanInput struct { + Namespace string `json:"namespace" jsonschema:"Namespace of the scan (default: kubescape)"` + ManifestName string `json:"manifest_name" jsonschema:"Name of the configuration scan manifest"` +} + +type listApplicationProfilesInput struct { + Namespace string `json:"namespace" jsonschema:"Filter by namespace (optional, defaults to all namespaces)"` +} + +type getApplicationProfileInput struct { + Namespace string `json:"namespace" jsonschema:"Namespace of the profile"` + Name string `json:"name" jsonschema:"Name of the application profile"` +} + +type listNetworkNeighborhoodsInput struct { + Namespace string `json:"namespace" jsonschema:"Filter by namespace (optional, defaults to all namespaces)"` +} + +type getNetworkNeighborhoodInput struct { + Namespace string `json:"namespace" jsonschema:"Namespace of the network neighborhood"` + Name string `json:"name" jsonschema:"Name of the network neighborhood"` +} + // handleCheckHealth verifies Kubescape operator installation and readiness -func (k *KubescapeTool) handleCheckHealth(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { +func (k *KubescapeTool) handleCheckHealth(ctx context.Context, request *mcp.CallToolRequest, in checkHealthInput) (*mcp.CallToolResult, any, error) { if k.initError != nil { toolErr := errors.NewKubescapeError("check_health", k.initError) - return toolErr.ToMCPResult(), nil + return kubescapeErrResult(toolErr), nil, nil } - namespace := mcp.ParseString(request, "namespace", defaultKubescapeNamespace) + namespace := in.Namespace + if namespace == "" { + namespace = defaultKubescapeNamespace + } result := HealthCheckResult{ Healthy: true, @@ -456,21 +509,24 @@ func (k *KubescapeTool) handleCheckHealth(ctx context.Context, request mcp.CallT content, err := json.MarshalIndent(result, "", " ") if err != nil { - return mcp.NewToolResultError(fmt.Sprintf("failed to marshal result: %v", err)), nil + return mcp.NewToolResultError(fmt.Sprintf("failed to marshal result: %v", err)), nil, nil } - return mcp.NewToolResultText(string(content)), nil + return mcp.NewToolResultText(string(content)), nil, nil } // handleListVulnerabilityManifests lists vulnerability manifests at image and workload levels -func (k *KubescapeTool) handleListVulnerabilityManifests(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { +func (k *KubescapeTool) handleListVulnerabilityManifests(ctx context.Context, request *mcp.CallToolRequest, in listVulnerabilityManifestsInput) (*mcp.CallToolResult, any, error) { if k.initError != nil { toolErr := errors.NewKubescapeError("list_vulnerability_manifests", k.initError) - return toolErr.ToMCPResult(), nil + return kubescapeErrResult(toolErr), nil, nil } - namespace := mcp.ParseString(request, "namespace", "") - level := mcp.ParseString(request, "level", "both") + namespace := in.Namespace + level := in.Level + if level == "" { + level = "both" + } // Build label selector based on level labelSelector := "" @@ -498,7 +554,7 @@ func (k *KubescapeTool) handleListVulnerabilityManifests(ctx context.Context, re toolErr := errors.NewKubescapeError("list_vulnerability_manifests", err). WithContext("namespace", namespace). WithContext("level", level) - return toolErr.ToMCPResult(), nil + return kubescapeErrResult(toolErr), nil, nil } // Build response @@ -526,24 +582,27 @@ func (k *KubescapeTool) handleListVulnerabilityManifests(ctx context.Context, re content, err := json.MarshalIndent(result, "", " ") if err != nil { - return mcp.NewToolResultError(fmt.Sprintf("failed to marshal result: %v", err)), nil + return mcp.NewToolResultError(fmt.Sprintf("failed to marshal result: %v", err)), nil, nil } - return mcp.NewToolResultText(string(content)), nil + return mcp.NewToolResultText(string(content)), nil, nil } // handleListVulnerabilitiesInManifest lists all CVEs in a specific manifest -func (k *KubescapeTool) handleListVulnerabilitiesInManifest(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { +func (k *KubescapeTool) handleListVulnerabilitiesInManifest(ctx context.Context, request *mcp.CallToolRequest, in listVulnerabilitiesInManifestInput) (*mcp.CallToolResult, any, error) { if k.initError != nil { toolErr := errors.NewKubescapeError("list_vulnerabilities", k.initError) - return toolErr.ToMCPResult(), nil + return kubescapeErrResult(toolErr), nil, nil } - namespace := mcp.ParseString(request, "namespace", defaultKubescapeNamespace) - manifestName := mcp.ParseString(request, "manifest_name", "") + namespace := in.Namespace + if namespace == "" { + namespace = defaultKubescapeNamespace + } + manifestName := in.ManifestName if manifestName == "" { - return mcp.NewToolResultError("manifest_name parameter is required"), nil + return mcp.NewToolResultError("manifest_name parameter is required"), nil, nil } manifest, err := k.spdxClient.VulnerabilityManifests(namespace).Get(ctx, manifestName, metav1.GetOptions{}) @@ -551,7 +610,7 @@ func (k *KubescapeTool) handleListVulnerabilitiesInManifest(ctx context.Context, toolErr := errors.NewKubescapeError("get_vulnerability_manifest", err). WithContext("namespace", namespace). WithContext("manifest_name", manifestName) - return toolErr.ToMCPResult(), nil + return kubescapeErrResult(toolErr), nil, nil } // Extract vulnerabilities with summary info @@ -598,28 +657,31 @@ func (k *KubescapeTool) handleListVulnerabilitiesInManifest(ctx context.Context, content, err := json.MarshalIndent(result, "", " ") if err != nil { - return mcp.NewToolResultError(fmt.Sprintf("failed to marshal result: %v", err)), nil + return mcp.NewToolResultError(fmt.Sprintf("failed to marshal result: %v", err)), nil, nil } - return mcp.NewToolResultText(string(content)), nil + return mcp.NewToolResultText(string(content)), nil, nil } // handleGetVulnerabilityDetails gets detailed info about a specific CVE in a manifest -func (k *KubescapeTool) handleGetVulnerabilityDetails(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { +func (k *KubescapeTool) handleGetVulnerabilityDetails(ctx context.Context, request *mcp.CallToolRequest, in getVulnerabilityDetailsInput) (*mcp.CallToolResult, any, error) { if k.initError != nil { toolErr := errors.NewKubescapeError("get_vulnerability_details", k.initError) - return toolErr.ToMCPResult(), nil + return kubescapeErrResult(toolErr), nil, nil } - namespace := mcp.ParseString(request, "namespace", defaultKubescapeNamespace) - manifestName := mcp.ParseString(request, "manifest_name", "") - cveID := mcp.ParseString(request, "cve_id", "") + namespace := in.Namespace + if namespace == "" { + namespace = defaultKubescapeNamespace + } + manifestName := in.ManifestName + cveID := in.CveID if manifestName == "" { - return mcp.NewToolResultError("manifest_name parameter is required"), nil + return mcp.NewToolResultError("manifest_name parameter is required"), nil, nil } if cveID == "" { - return mcp.NewToolResultError("cve_id parameter is required"), nil + return mcp.NewToolResultError("cve_id parameter is required"), nil, nil } manifest, err := k.spdxClient.VulnerabilityManifests(namespace).Get(ctx, manifestName, metav1.GetOptions{}) @@ -627,7 +689,7 @@ func (k *KubescapeTool) handleGetVulnerabilityDetails(ctx context.Context, reque toolErr := errors.NewKubescapeError("get_vulnerability_manifest", err). WithContext("namespace", namespace). WithContext("manifest_name", manifestName) - return toolErr.ToMCPResult(), nil + return kubescapeErrResult(toolErr), nil, nil } // Find matching CVE entries @@ -639,25 +701,25 @@ func (k *KubescapeTool) handleGetVulnerabilityDetails(ctx context.Context, reque } if len(matches) == 0 { - return mcp.NewToolResultError(fmt.Sprintf("CVE %s not found in manifest %s", cveID, manifestName)), nil + return mcp.NewToolResultError(fmt.Sprintf("CVE %s not found in manifest %s", cveID, manifestName)), nil, nil } content, err := json.MarshalIndent(matches, "", " ") if err != nil { - return mcp.NewToolResultError(fmt.Sprintf("failed to marshal result: %v", err)), nil + return mcp.NewToolResultError(fmt.Sprintf("failed to marshal result: %v", err)), nil, nil } - return mcp.NewToolResultText(string(content)), nil + return mcp.NewToolResultText(string(content)), nil, nil } // handleListConfigurationScans lists configuration security scan results -func (k *KubescapeTool) handleListConfigurationScans(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { +func (k *KubescapeTool) handleListConfigurationScans(ctx context.Context, request *mcp.CallToolRequest, in listConfigurationScansInput) (*mcp.CallToolResult, any, error) { if k.initError != nil { toolErr := errors.NewKubescapeError("list_configuration_scans", k.initError) - return toolErr.ToMCPResult(), nil + return kubescapeErrResult(toolErr), nil, nil } - namespace := mcp.ParseString(request, "namespace", "") + namespace := in.Namespace queryNamespace := metav1.NamespaceAll if namespace != "" { @@ -668,7 +730,7 @@ func (k *KubescapeTool) handleListConfigurationScans(ctx context.Context, reques if err != nil { toolErr := errors.NewKubescapeError("list_configuration_scans", err). WithContext("namespace", namespace) - return toolErr.ToMCPResult(), nil + return kubescapeErrResult(toolErr), nil, nil } configManifests := []map[string]interface{}{} @@ -688,24 +750,27 @@ func (k *KubescapeTool) handleListConfigurationScans(ctx context.Context, reques content, err := json.MarshalIndent(result, "", " ") if err != nil { - return mcp.NewToolResultError(fmt.Sprintf("failed to marshal result: %v", err)), nil + return mcp.NewToolResultError(fmt.Sprintf("failed to marshal result: %v", err)), nil, nil } - return mcp.NewToolResultText(string(content)), nil + return mcp.NewToolResultText(string(content)), nil, nil } // handleGetConfigurationScan gets details of a specific configuration scan -func (k *KubescapeTool) handleGetConfigurationScan(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { +func (k *KubescapeTool) handleGetConfigurationScan(ctx context.Context, request *mcp.CallToolRequest, in getConfigurationScanInput) (*mcp.CallToolResult, any, error) { if k.initError != nil { toolErr := errors.NewKubescapeError("get_configuration_scan", k.initError) - return toolErr.ToMCPResult(), nil + return kubescapeErrResult(toolErr), nil, nil } - namespace := mcp.ParseString(request, "namespace", defaultKubescapeNamespace) - manifestName := mcp.ParseString(request, "manifest_name", "") + namespace := in.Namespace + if namespace == "" { + namespace = defaultKubescapeNamespace + } + manifestName := in.ManifestName if manifestName == "" { - return mcp.NewToolResultError("manifest_name parameter is required"), nil + return mcp.NewToolResultError("manifest_name parameter is required"), nil, nil } manifest, err := k.spdxClient.WorkloadConfigurationScans(namespace).Get(ctx, manifestName, metav1.GetOptions{}) @@ -713,25 +778,25 @@ func (k *KubescapeTool) handleGetConfigurationScan(ctx context.Context, request toolErr := errors.NewKubescapeError("get_configuration_scan", err). WithContext("namespace", namespace). WithContext("manifest_name", manifestName) - return toolErr.ToMCPResult(), nil + return kubescapeErrResult(toolErr), nil, nil } content, err := json.MarshalIndent(manifest, "", " ") if err != nil { - return mcp.NewToolResultError(fmt.Sprintf("failed to marshal result: %v", err)), nil + return mcp.NewToolResultError(fmt.Sprintf("failed to marshal result: %v", err)), nil, nil } - return mcp.NewToolResultText(string(content)), nil + return mcp.NewToolResultText(string(content)), nil, nil } // handleListApplicationProfiles lists application profiles showing runtime behavior data -func (k *KubescapeTool) handleListApplicationProfiles(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { +func (k *KubescapeTool) handleListApplicationProfiles(ctx context.Context, request *mcp.CallToolRequest, in listApplicationProfilesInput) (*mcp.CallToolResult, any, error) { if k.initError != nil { toolErr := errors.NewKubescapeError("list_application_profiles", k.initError) - return toolErr.ToMCPResult(), nil + return kubescapeErrResult(toolErr), nil, nil } - namespace := mcp.ParseString(request, "namespace", "") + namespace := in.Namespace queryNamespace := metav1.NamespaceAll if namespace != "" { @@ -742,7 +807,7 @@ func (k *KubescapeTool) handleListApplicationProfiles(ctx context.Context, reque if err != nil { toolErr := errors.NewKubescapeError("list_application_profiles", err). WithContext("namespace", namespace) - return toolErr.ToMCPResult(), nil + return kubescapeErrResult(toolErr), nil, nil } profileList := []map[string]interface{}{} @@ -793,27 +858,27 @@ func (k *KubescapeTool) handleListApplicationProfiles(ctx context.Context, reque content, err := json.MarshalIndent(result, "", " ") if err != nil { - return mcp.NewToolResultError(fmt.Sprintf("failed to marshal result: %v", err)), nil + return mcp.NewToolResultError(fmt.Sprintf("failed to marshal result: %v", err)), nil, nil } - return mcp.NewToolResultText(string(content)), nil + return mcp.NewToolResultText(string(content)), nil, nil } // handleGetApplicationProfile gets detailed runtime behavior for a specific workload -func (k *KubescapeTool) handleGetApplicationProfile(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { +func (k *KubescapeTool) handleGetApplicationProfile(ctx context.Context, request *mcp.CallToolRequest, in getApplicationProfileInput) (*mcp.CallToolResult, any, error) { if k.initError != nil { toolErr := errors.NewKubescapeError("get_application_profile", k.initError) - return toolErr.ToMCPResult(), nil + return kubescapeErrResult(toolErr), nil, nil } - namespace := mcp.ParseString(request, "namespace", "") - name := mcp.ParseString(request, "name", "") + namespace := in.Namespace + name := in.Name if name == "" { - return mcp.NewToolResultError("name parameter is required"), nil + return mcp.NewToolResultError("name parameter is required"), nil, nil } if namespace == "" { - return mcp.NewToolResultError("namespace parameter is required"), nil + return mcp.NewToolResultError("namespace parameter is required"), nil, nil } profile, err := k.spdxClient.ApplicationProfiles(namespace).Get(ctx, name, metav1.GetOptions{}) @@ -821,7 +886,7 @@ func (k *KubescapeTool) handleGetApplicationProfile(ctx context.Context, request toolErr := errors.NewKubescapeError("get_application_profile", err). WithContext("namespace", namespace). WithContext("name", name) - return toolErr.ToMCPResult(), nil + return kubescapeErrResult(toolErr), nil, nil } // Build detailed response with container behaviors @@ -869,20 +934,20 @@ func (k *KubescapeTool) handleGetApplicationProfile(ctx context.Context, request content, err := json.MarshalIndent(result, "", " ") if err != nil { - return mcp.NewToolResultError(fmt.Sprintf("failed to marshal result: %v", err)), nil + return mcp.NewToolResultError(fmt.Sprintf("failed to marshal result: %v", err)), nil, nil } - return mcp.NewToolResultText(string(content)), nil + return mcp.NewToolResultText(string(content)), nil, nil } // handleListNetworkNeighborhoods lists network communication patterns for workloads -func (k *KubescapeTool) handleListNetworkNeighborhoods(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { +func (k *KubescapeTool) handleListNetworkNeighborhoods(ctx context.Context, request *mcp.CallToolRequest, in listNetworkNeighborhoodsInput) (*mcp.CallToolResult, any, error) { if k.initError != nil { toolErr := errors.NewKubescapeError("list_network_neighborhoods", k.initError) - return toolErr.ToMCPResult(), nil + return kubescapeErrResult(toolErr), nil, nil } - namespace := mcp.ParseString(request, "namespace", "") + namespace := in.Namespace queryNamespace := metav1.NamespaceAll if namespace != "" { @@ -893,7 +958,7 @@ func (k *KubescapeTool) handleListNetworkNeighborhoods(ctx context.Context, requ if err != nil { toolErr := errors.NewKubescapeError("list_network_neighborhoods", err). WithContext("namespace", namespace) - return toolErr.ToMCPResult(), nil + return kubescapeErrResult(toolErr), nil, nil } neighborhoodList := []map[string]interface{}{} @@ -927,27 +992,27 @@ func (k *KubescapeTool) handleListNetworkNeighborhoods(ctx context.Context, requ content, err := json.MarshalIndent(result, "", " ") if err != nil { - return mcp.NewToolResultError(fmt.Sprintf("failed to marshal result: %v", err)), nil + return mcp.NewToolResultError(fmt.Sprintf("failed to marshal result: %v", err)), nil, nil } - return mcp.NewToolResultText(string(content)), nil + return mcp.NewToolResultText(string(content)), nil, nil } // handleGetNetworkNeighborhood gets detailed network connections for a specific workload -func (k *KubescapeTool) handleGetNetworkNeighborhood(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { +func (k *KubescapeTool) handleGetNetworkNeighborhood(ctx context.Context, request *mcp.CallToolRequest, in getNetworkNeighborhoodInput) (*mcp.CallToolResult, any, error) { if k.initError != nil { toolErr := errors.NewKubescapeError("get_network_neighborhood", k.initError) - return toolErr.ToMCPResult(), nil + return kubescapeErrResult(toolErr), nil, nil } - namespace := mcp.ParseString(request, "namespace", "") - name := mcp.ParseString(request, "name", "") + namespace := in.Namespace + name := in.Name if name == "" { - return mcp.NewToolResultError("name parameter is required"), nil + return mcp.NewToolResultError("name parameter is required"), nil, nil } if namespace == "" { - return mcp.NewToolResultError("namespace parameter is required"), nil + return mcp.NewToolResultError("namespace parameter is required"), nil, nil } nn, err := k.spdxClient.NetworkNeighborhoods(namespace).Get(ctx, name, metav1.GetOptions{}) @@ -955,7 +1020,7 @@ func (k *KubescapeTool) handleGetNetworkNeighborhood(ctx context.Context, reques toolErr := errors.NewKubescapeError("get_network_neighborhood", err). WithContext("namespace", namespace). WithContext("name", name) - return toolErr.ToMCPResult(), nil + return kubescapeErrResult(toolErr), nil, nil } // Build detailed response with container network data @@ -1032,10 +1097,10 @@ func (k *KubescapeTool) handleGetNetworkNeighborhood(ctx context.Context, reques content, err := json.MarshalIndent(result, "", " ") if err != nil { - return mcp.NewToolResultError(fmt.Sprintf("failed to marshal result: %v", err)), nil + return mcp.NewToolResultError(fmt.Sprintf("failed to marshal result: %v", err)), nil, nil } - return mcp.NewToolResultText(string(content)), nil + return mcp.NewToolResultText(string(content)), nil, nil } // Helper function to truncate strings @@ -1047,160 +1112,127 @@ func truncateString(s string, maxLen int) string { } // RegisterTools registers all Kubescape tools with the MCP server -func RegisterTools(s *server.MCPServer, kubeconfig string, readOnly bool) { +func RegisterTools(s *mcp.Server, kubeconfig string, readOnly bool) { tool := NewKubescapeTool(kubeconfig) - - // Health check tool - s.AddTool(mcp.NewTool("kubescape_check_health", - mcp.WithDescription("Check if Kubescape operator is installed and operational. Verifies namespace, operator pods, storage pods, CRDs, and scan data availability."), - mcp.WithString("namespace", mcp.Description("Namespace to check (default: kubescape)")), - ), telemetry.AdaptToolHandler(telemetry.WithTracing("kubescape_check_health", tool.handleCheckHealth))) - - // List vulnerability manifests - s.AddTool(mcp.NewTool("kubescape_list_vulnerability_manifests", - mcp.WithDescription("List vulnerability manifests from Kubescape operator. Returns vulnerability scan results at image or workload level."), - mcp.WithString("namespace", mcp.Description("Filter by namespace (optional, defaults to all namespaces)")), - mcp.WithString("level", mcp.Description("Type of manifests to list: 'image', 'workload', or 'both' (default: both)")), - ), telemetry.AdaptToolHandler(telemetry.WithTracing("kubescape_list_vulnerability_manifests", tool.handleListVulnerabilityManifests))) - - // List vulnerabilities in a manifest - s.AddTool(mcp.NewTool("kubescape_list_vulnerabilities", - mcp.WithDescription("List all CVEs/vulnerabilities found in a specific vulnerability manifest. Returns severity summary and vulnerability details."), - mcp.WithString("namespace", mcp.Description("Namespace of the manifest (default: kubescape)")), - mcp.WithString("manifest_name", mcp.Description("Name of the vulnerability manifest"), mcp.Required()), - ), telemetry.AdaptToolHandler(telemetry.WithTracing("kubescape_list_vulnerabilities", tool.handleListVulnerabilitiesInManifest))) - - // Get detailed vulnerability info - s.AddTool(mcp.NewTool("kubescape_get_vulnerability_details", - mcp.WithDescription("Get detailed information about a specific CVE in a vulnerability manifest, including affected packages and fix information."), - mcp.WithString("namespace", mcp.Description("Namespace of the manifest (default: kubescape)")), - mcp.WithString("manifest_name", mcp.Description("Name of the vulnerability manifest"), mcp.Required()), - mcp.WithString("cve_id", mcp.Description("CVE identifier (e.g., CVE-2023-12345)"), mcp.Required()), - ), telemetry.AdaptToolHandler(telemetry.WithTracing("kubescape_get_vulnerability_details", tool.handleGetVulnerabilityDetails))) - - // List configuration scans - s.AddTool(mcp.NewTool("kubescape_list_configuration_scans", - mcp.WithDescription("List configuration security scan results from Kubescape operator. Shows workloads that have been scanned for security misconfigurations."), - mcp.WithString("namespace", mcp.Description("Filter by namespace (optional, defaults to all namespaces)")), - ), telemetry.AdaptToolHandler(telemetry.WithTracing("kubescape_list_configuration_scans", tool.handleListConfigurationScans))) - - // Get configuration scan details - s.AddTool(mcp.NewTool("kubescape_get_configuration_scan", - mcp.WithDescription("Get detailed configuration security scan results for a specific workload, including failed controls and remediation guidance."), - mcp.WithString("namespace", mcp.Description("Namespace of the scan (default: kubescape)")), - mcp.WithString("manifest_name", mcp.Description("Name of the configuration scan manifest"), mcp.Required()), - ), telemetry.AdaptToolHandler(telemetry.WithTracing("kubescape_get_configuration_scan", tool.handleGetConfigurationScan))) - - // List application profiles (runtime observability) - s.AddTool(mcp.NewTool("kubescape_list_application_profiles", - mcp.WithDescription("List ApplicationProfiles showing runtime behavior of workloads. These profiles capture: "+ - "executed processes (Execs), file access patterns (Opens), system calls (Syscalls), Linux capabilities used, and HTTP endpoints. "+ - "Use this data to prioritize vulnerability findings - a CVE in an unused package is lower priority than one in an actively running process. "+ - "Requires 'capabilities.runtimeObservability=enable' in Kubescape Helm chart."), - mcp.WithString("namespace", mcp.Description("Filter by namespace (optional, defaults to all namespaces)")), - ), telemetry.AdaptToolHandler(telemetry.WithTracing("kubescape_list_application_profiles", tool.handleListApplicationProfiles))) - - // Get application profile details - s.AddTool(mcp.NewTool("kubescape_get_application_profile", - mcp.WithDescription("Get detailed runtime behavior profile for a specific workload. Shows what processes run, what files are accessed, "+ - "what system calls are made, and what capabilities are used per container. "+ - "Compare with CVE findings to prioritize remediation - focus on vulnerabilities affecting actively used components."), - mcp.WithString("namespace", mcp.Description("Namespace of the profile"), mcp.Required()), - mcp.WithString("name", mcp.Description("Name of the application profile"), mcp.Required()), - ), telemetry.AdaptToolHandler(telemetry.WithTracing("kubescape_get_application_profile", tool.handleGetApplicationProfile))) - - // List network neighborhoods (runtime observability) - s.AddTool(mcp.NewTool("kubescape_list_network_neighborhoods", - mcp.WithDescription("List NetworkNeighborhoods showing actual network communication patterns of workloads. "+ - "These capture: ingress connections (who talks TO the workload), egress connections (who the workload talks TO), "+ - "including DNS names, IP addresses, ports, and protocols. "+ - "Use this to understand attack surface and prioritize network-related security findings. "+ - "Requires 'capabilities.runtimeObservability=enable' in Kubescape Helm chart."), - mcp.WithString("namespace", mcp.Description("Filter by namespace (optional, defaults to all namespaces)")), - ), telemetry.AdaptToolHandler(telemetry.WithTracing("kubescape_list_network_neighborhoods", tool.handleListNetworkNeighborhoods))) - - // Get network neighborhood details - s.AddTool(mcp.NewTool("kubescape_get_network_neighborhood", - mcp.WithDescription("Get detailed network connections for a specific workload. Shows all observed ingress and egress traffic "+ - "with DNS names, IPs, ports, and protocols. Use this to verify if a workload with a vulnerability is actually exposed to the network."), - mcp.WithString("namespace", mcp.Description("Namespace of the network neighborhood"), mcp.Required()), - mcp.WithString("name", mcp.Description("Name of the network neighborhood"), mcp.Required()), - ), telemetry.AdaptToolHandler(telemetry.WithTracing("kubescape_get_network_neighborhood", tool.handleGetNetworkNeighborhood))) + _ = readOnly // all kubescape tools are read-only + + mcp.AddTool(s, "kubescape", &mcp.Tool{ + Name: "kubescape_check_health", + Description: "Check if Kubescape operator is installed and operational. Verifies namespace, operator pods, storage pods, CRDs, and scan data availability.", + }, tool.handleCheckHealth) + + mcp.AddTool(s, "kubescape", &mcp.Tool{ + Name: "kubescape_list_vulnerability_manifests", + Description: "List vulnerability manifests from Kubescape operator. Returns vulnerability scan results at image or workload level.", + }, tool.handleListVulnerabilityManifests) + + mcp.AddTool(s, "kubescape", &mcp.Tool{ + Name: "kubescape_list_vulnerabilities", + Description: "List all CVEs/vulnerabilities found in a specific vulnerability manifest. Returns severity summary and vulnerability details.", + }, tool.handleListVulnerabilitiesInManifest) + + mcp.AddTool(s, "kubescape", &mcp.Tool{ + Name: "kubescape_get_vulnerability_details", + Description: "Get detailed information about a specific CVE in a vulnerability manifest, including affected packages and fix information.", + }, tool.handleGetVulnerabilityDetails) + + mcp.AddTool(s, "kubescape", &mcp.Tool{ + Name: "kubescape_list_configuration_scans", + Description: "List configuration security scan results from Kubescape operator. Shows workloads that have been scanned for security misconfigurations.", + }, tool.handleListConfigurationScans) + + mcp.AddTool(s, "kubescape", &mcp.Tool{ + Name: "kubescape_get_configuration_scan", + Description: "Get detailed configuration security scan results for a specific workload, including failed controls and remediation guidance.", + }, tool.handleGetConfigurationScan) + + mcp.AddTool(s, "kubescape", &mcp.Tool{ + Name: "kubescape_list_application_profiles", + Description: "List ApplicationProfiles showing runtime behavior of workloads. These profiles capture: " + + "executed processes (Execs), file access patterns (Opens), system calls (Syscalls), Linux capabilities used, and HTTP endpoints. " + + "Use this data to prioritize vulnerability findings - a CVE in an unused package is lower priority than one in an actively running process. " + + "Requires 'capabilities.runtimeObservability=enable' in Kubescape Helm chart.", + }, tool.handleListApplicationProfiles) + + mcp.AddTool(s, "kubescape", &mcp.Tool{ + Name: "kubescape_get_application_profile", + Description: "Get detailed runtime behavior profile for a specific workload. Shows what processes run, what files are accessed, " + + "what system calls are made, and what capabilities are used per container. " + + "Compare with CVE findings to prioritize remediation - focus on vulnerabilities affecting actively used components.", + }, tool.handleGetApplicationProfile) + + mcp.AddTool(s, "kubescape", &mcp.Tool{ + Name: "kubescape_list_network_neighborhoods", + Description: "List NetworkNeighborhoods showing actual network communication patterns of workloads. " + + "These capture: ingress connections (who talks TO the workload), egress connections (who the workload talks TO), " + + "including DNS names, IP addresses, ports, and protocols. " + + "Use this to understand attack surface and prioritize network-related security findings. " + + "Requires 'capabilities.runtimeObservability=enable' in Kubescape Helm chart.", + }, tool.handleListNetworkNeighborhoods) + + mcp.AddTool(s, "kubescape", &mcp.Tool{ + Name: "kubescape_get_network_neighborhood", + Description: "Get detailed network connections for a specific workload. Shows all observed ingress and egress traffic " + + "with DNS names, IPs, ports, and protocols. Use this to verify if a workload with a vulnerability is actually exposed to the network.", + }, tool.handleGetNetworkNeighborhood) // NOTE: SBOM tools are disabled as they return too much data for LLM context windows. - // SBOMs contain detailed package information that can be very large. - // To enable in the future, uncomment the handlers and tool registrations below. - // - // s.AddTool(mcp.NewTool("kubescape_list_sboms", ...)) - // s.AddTool(mcp.NewTool("kubescape_get_sbom", ...)) } // Interfaces for testing - allows mocking the Kubernetes clients type KubescapeToolInterface interface { - HandleCheckHealth(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) - HandleListVulnerabilityManifests(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) - HandleListVulnerabilitiesInManifest(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) - HandleGetVulnerabilityDetails(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) - HandleListConfigurationScans(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) - HandleGetConfigurationScan(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) - HandleListApplicationProfiles(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) - HandleGetApplicationProfile(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) - HandleListNetworkNeighborhoods(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) - HandleGetNetworkNeighborhood(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) - // NOTE: SBOM handlers are disabled as they return too much data for LLM context - // HandleListSBOMs(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) - // HandleGetSBOM(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) + HandleCheckHealth(ctx context.Context, in checkHealthInput) (*mcp.CallToolResult, any, error) + HandleListVulnerabilityManifests(ctx context.Context, in listVulnerabilityManifestsInput) (*mcp.CallToolResult, any, error) + HandleListVulnerabilitiesInManifest(ctx context.Context, in listVulnerabilitiesInManifestInput) (*mcp.CallToolResult, any, error) + HandleGetVulnerabilityDetails(ctx context.Context, in getVulnerabilityDetailsInput) (*mcp.CallToolResult, any, error) + HandleListConfigurationScans(ctx context.Context, in listConfigurationScansInput) (*mcp.CallToolResult, any, error) + HandleGetConfigurationScan(ctx context.Context, in getConfigurationScanInput) (*mcp.CallToolResult, any, error) + HandleListApplicationProfiles(ctx context.Context, in listApplicationProfilesInput) (*mcp.CallToolResult, any, error) + HandleGetApplicationProfile(ctx context.Context, in getApplicationProfileInput) (*mcp.CallToolResult, any, error) + HandleListNetworkNeighborhoods(ctx context.Context, in listNetworkNeighborhoodsInput) (*mcp.CallToolResult, any, error) + HandleGetNetworkNeighborhood(ctx context.Context, in getNetworkNeighborhoodInput) (*mcp.CallToolResult, any, error) } // Ensure KubescapeTool implements the interface var _ KubescapeToolInterface = (*KubescapeTool)(nil) // Export handler methods for testing -func (k *KubescapeTool) HandleCheckHealth(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - return k.handleCheckHealth(ctx, request) +func (k *KubescapeTool) HandleCheckHealth(ctx context.Context, in checkHealthInput) (*mcp.CallToolResult, any, error) { + return k.handleCheckHealth(ctx, &mcp.CallToolRequest{}, in) } -func (k *KubescapeTool) HandleListVulnerabilityManifests(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - return k.handleListVulnerabilityManifests(ctx, request) +func (k *KubescapeTool) HandleListVulnerabilityManifests(ctx context.Context, in listVulnerabilityManifestsInput) (*mcp.CallToolResult, any, error) { + return k.handleListVulnerabilityManifests(ctx, &mcp.CallToolRequest{}, in) } -func (k *KubescapeTool) HandleListVulnerabilitiesInManifest(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - return k.handleListVulnerabilitiesInManifest(ctx, request) +func (k *KubescapeTool) HandleListVulnerabilitiesInManifest(ctx context.Context, in listVulnerabilitiesInManifestInput) (*mcp.CallToolResult, any, error) { + return k.handleListVulnerabilitiesInManifest(ctx, &mcp.CallToolRequest{}, in) } -func (k *KubescapeTool) HandleGetVulnerabilityDetails(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - return k.handleGetVulnerabilityDetails(ctx, request) +func (k *KubescapeTool) HandleGetVulnerabilityDetails(ctx context.Context, in getVulnerabilityDetailsInput) (*mcp.CallToolResult, any, error) { + return k.handleGetVulnerabilityDetails(ctx, &mcp.CallToolRequest{}, in) } -func (k *KubescapeTool) HandleListConfigurationScans(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - return k.handleListConfigurationScans(ctx, request) +func (k *KubescapeTool) HandleListConfigurationScans(ctx context.Context, in listConfigurationScansInput) (*mcp.CallToolResult, any, error) { + return k.handleListConfigurationScans(ctx, &mcp.CallToolRequest{}, in) } -func (k *KubescapeTool) HandleGetConfigurationScan(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - return k.handleGetConfigurationScan(ctx, request) +func (k *KubescapeTool) HandleGetConfigurationScan(ctx context.Context, in getConfigurationScanInput) (*mcp.CallToolResult, any, error) { + return k.handleGetConfigurationScan(ctx, &mcp.CallToolRequest{}, in) } -func (k *KubescapeTool) HandleListApplicationProfiles(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - return k.handleListApplicationProfiles(ctx, request) +func (k *KubescapeTool) HandleListApplicationProfiles(ctx context.Context, in listApplicationProfilesInput) (*mcp.CallToolResult, any, error) { + return k.handleListApplicationProfiles(ctx, &mcp.CallToolRequest{}, in) } -func (k *KubescapeTool) HandleGetApplicationProfile(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - return k.handleGetApplicationProfile(ctx, request) +func (k *KubescapeTool) HandleGetApplicationProfile(ctx context.Context, in getApplicationProfileInput) (*mcp.CallToolResult, any, error) { + return k.handleGetApplicationProfile(ctx, &mcp.CallToolRequest{}, in) } -func (k *KubescapeTool) HandleListNetworkNeighborhoods(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - return k.handleListNetworkNeighborhoods(ctx, request) +func (k *KubescapeTool) HandleListNetworkNeighborhoods(ctx context.Context, in listNetworkNeighborhoodsInput) (*mcp.CallToolResult, any, error) { + return k.handleListNetworkNeighborhoods(ctx, &mcp.CallToolRequest{}, in) } -func (k *KubescapeTool) HandleGetNetworkNeighborhood(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - return k.handleGetNetworkNeighborhood(ctx, request) +func (k *KubescapeTool) HandleGetNetworkNeighborhood(ctx context.Context, in getNetworkNeighborhoodInput) (*mcp.CallToolResult, any, error) { + return k.handleGetNetworkNeighborhood(ctx, &mcp.CallToolRequest{}, in) } - -// NOTE: SBOM handlers are disabled as they return too much data for LLM context -// func (k *KubescapeTool) HandleListSBOMs(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { -// return k.handleListSBOMs(ctx, request) -// } -// -// func (k *KubescapeTool) HandleGetSBOM(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { -// return k.handleGetSBOM(ctx, request) -// } diff --git a/pkg/kubescape/kubescape_test.go b/pkg/kubescape/kubescape_test.go index 2b0bcafe..9b331fa9 100644 --- a/pkg/kubescape/kubescape_test.go +++ b/pkg/kubescape/kubescape_test.go @@ -6,10 +6,9 @@ import ( "errors" "testing" + mcp "github.com/kagent-dev/tools/internal/mcp" "github.com/kubescape/storage/pkg/apis/softwarecomposition/v1beta1" kubescapefake "github.com/kubescape/storage/pkg/generated/clientset/versioned/fake" - "github.com/mark3labs/mcp-go/mcp" - "github.com/mark3labs/mcp-go/server" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" corev1 "k8s.io/api/core/v1" @@ -19,62 +18,22 @@ import ( kubefake "k8s.io/client-go/kubernetes/fake" ) -// Helper function to create a CallToolRequest with arguments -func makeRequest(args map[string]interface{}) mcp.CallToolRequest { - request := mcp.CallToolRequest{} - request.Params.Arguments = args - return request -} - // Helper function to extract text content from MCP result func getResultText(result *mcp.CallToolResult) string { if result == nil || len(result.Content) == 0 { return "" } - if textContent, ok := result.Content[0].(mcp.TextContent); ok { + if textContent, ok := result.Content[0].(*mcp.TextContent); ok { return textContent.Text } return "" } func TestRegisterTools(t *testing.T) { - s := server.NewMCPServer("test", "1.0.0") - - // Should not panic + s := mcp.NewServer(&mcp.Implementation{Name: "test", Version: "1.0.0"}, nil) assert.NotPanics(t, func() { RegisterTools(s, "", false) }) - - // Verify tools are registered by checking the server has tools - // NOTE: SBOM tools are disabled (too large for LLM context), so we expect 10 tools - tools := s.ListTools() - assert.Len(t, tools, 10) - - expectedTools := map[string]bool{ - "kubescape_check_health": false, - "kubescape_list_vulnerability_manifests": false, - "kubescape_list_vulnerabilities": false, - "kubescape_get_vulnerability_details": false, - "kubescape_list_configuration_scans": false, - "kubescape_get_configuration_scan": false, - "kubescape_list_application_profiles": false, - "kubescape_get_application_profile": false, - "kubescape_list_network_neighborhoods": false, - "kubescape_get_network_neighborhood": false, - // NOTE: SBOM tools disabled - too large for LLM context - // "kubescape_list_sboms": false, - // "kubescape_get_sbom": false, - } - - for name := range tools { - if _, exists := expectedTools[name]; exists { - expectedTools[name] = true - } - } - - for name, found := range expectedTools { - assert.True(t, found, "Tool %s not found", name) - } } func TestHandleCheckHealth_AllComponentsHealthy(t *testing.T) { @@ -150,7 +109,7 @@ func TestHandleCheckHealth_AllComponentsHealthy(t *testing.T) { tool := NewKubescapeToolWithClients(k8sClient, apiExtClient, spdxClient.SpdxV1beta1()) - result, err := tool.HandleCheckHealth(context.Background(), makeRequest(nil)) + result, _, err := tool.HandleCheckHealth(context.Background(), checkHealthInput{}) require.NoError(t, err) require.NotNil(t, result) @@ -186,7 +145,7 @@ func TestHandleCheckHealth_NamespaceNotFound(t *testing.T) { tool := NewKubescapeToolWithClients(k8sClient, apiExtClient, spdxClient.SpdxV1beta1()) - result, err := tool.HandleCheckHealth(context.Background(), makeRequest(nil)) + result, _, err := tool.HandleCheckHealth(context.Background(), checkHealthInput{}) require.NoError(t, err) require.NotNil(t, result) @@ -211,7 +170,7 @@ func TestHandleCheckHealth_OperatorPodsNotRunning(t *testing.T) { tool := NewKubescapeToolWithClients(k8sClient, apiExtClient, spdxClient.SpdxV1beta1()) - result, err := tool.HandleCheckHealth(context.Background(), makeRequest(nil)) + result, _, err := tool.HandleCheckHealth(context.Background(), checkHealthInput{}) require.NoError(t, err) require.NotNil(t, result) @@ -244,7 +203,7 @@ func TestHandleCheckHealth_OperatorPodsUnhealthy(t *testing.T) { tool := NewKubescapeToolWithClients(k8sClient, apiExtClient, spdxClient.SpdxV1beta1()) - result, err := tool.HandleCheckHealth(context.Background(), makeRequest(nil)) + result, _, err := tool.HandleCheckHealth(context.Background(), checkHealthInput{}) require.NoError(t, err) require.NotNil(t, result) @@ -268,7 +227,7 @@ func TestHandleCheckHealth_VulnerabilityCRDMissing(t *testing.T) { tool := NewKubescapeToolWithClients(k8sClient, apiExtClient, spdxClient.SpdxV1beta1()) - result, err := tool.HandleCheckHealth(context.Background(), makeRequest(nil)) + result, _, err := tool.HandleCheckHealth(context.Background(), checkHealthInput{}) require.NoError(t, err) require.NotNil(t, result) @@ -299,7 +258,7 @@ func TestHandleCheckHealth_NoScanData(t *testing.T) { tool := NewKubescapeToolWithClients(k8sClient, apiExtClient, spdxClient.SpdxV1beta1()) - result, err := tool.HandleCheckHealth(context.Background(), makeRequest(nil)) + result, _, err := tool.HandleCheckHealth(context.Background(), checkHealthInput{}) require.NoError(t, err) require.NotNil(t, result) @@ -331,7 +290,7 @@ func TestHandleCheckHealth_RuntimeObservabilityCRDsMissing(t *testing.T) { tool := NewKubescapeToolWithClients(k8sClient, apiExtClient, spdxClient.SpdxV1beta1()) - result, err := tool.HandleCheckHealth(context.Background(), makeRequest(nil)) + result, _, err := tool.HandleCheckHealth(context.Background(), checkHealthInput{}) require.NoError(t, err) require.NotNil(t, result) @@ -381,9 +340,9 @@ func TestHandleCheckHealth_CustomNamespace(t *testing.T) { tool := NewKubescapeToolWithClients(k8sClient, apiExtClient, spdxClient.SpdxV1beta1()) - result, err := tool.HandleCheckHealth(context.Background(), makeRequest(map[string]interface{}{ - "namespace": "custom-ns", - })) + result, _, err := tool.HandleCheckHealth(context.Background(), checkHealthInput{ + Namespace: "custom-ns", + }) require.NoError(t, err) require.NotNil(t, result) @@ -398,7 +357,7 @@ func TestHandleCheckHealth_CustomNamespace(t *testing.T) { func TestHandleCheckHealth_InitError(t *testing.T) { tool := NewKubescapeToolWithError(errors.New("failed to connect")) - result, err := tool.HandleCheckHealth(context.Background(), makeRequest(nil)) + result, _, err := tool.HandleCheckHealth(context.Background(), checkHealthInput{}) require.NoError(t, err) require.NotNil(t, result) assert.True(t, result.IsError) @@ -433,7 +392,7 @@ func TestHandleListVulnerabilityManifests_Success(t *testing.T) { tool := NewKubescapeToolWithClients(nil, nil, spdxClient.SpdxV1beta1()) - result, err := tool.HandleListVulnerabilityManifests(context.Background(), makeRequest(nil)) + result, _, err := tool.HandleListVulnerabilityManifests(context.Background(), listVulnerabilityManifestsInput{}) require.NoError(t, err) require.NotNil(t, result) assert.False(t, result.IsError) @@ -459,9 +418,9 @@ func TestHandleListVulnerabilityManifests_FilterByNamespace(t *testing.T) { tool := NewKubescapeToolWithClients(nil, nil, spdxClient.SpdxV1beta1()) - result, err := tool.HandleListVulnerabilityManifests(context.Background(), makeRequest(map[string]interface{}{ - "namespace": "default", - })) + result, _, err := tool.HandleListVulnerabilityManifests(context.Background(), listVulnerabilityManifestsInput{ + Namespace: "default", + }) require.NoError(t, err) require.NotNil(t, result) @@ -476,7 +435,7 @@ func TestHandleListVulnerabilityManifests_EmptyResults(t *testing.T) { spdxClient := kubescapefake.NewClientset() tool := NewKubescapeToolWithClients(nil, nil, spdxClient.SpdxV1beta1()) - result, err := tool.HandleListVulnerabilityManifests(context.Background(), makeRequest(nil)) + result, _, err := tool.HandleListVulnerabilityManifests(context.Background(), listVulnerabilityManifestsInput{}) require.NoError(t, err) require.NotNil(t, result) @@ -490,7 +449,7 @@ func TestHandleListVulnerabilityManifests_EmptyResults(t *testing.T) { func TestHandleListVulnerabilityManifests_InitError(t *testing.T) { tool := NewKubescapeToolWithError(errors.New("failed to connect")) - result, err := tool.HandleListVulnerabilityManifests(context.Background(), makeRequest(nil)) + result, _, err := tool.HandleListVulnerabilityManifests(context.Background(), listVulnerabilityManifestsInput{}) require.NoError(t, err) require.NotNil(t, result) assert.True(t, result.IsError) @@ -531,9 +490,9 @@ func TestHandleListVulnerabilitiesInManifest_Success(t *testing.T) { tool := NewKubescapeToolWithClients(nil, nil, spdxClient.SpdxV1beta1()) - result, err := tool.HandleListVulnerabilitiesInManifest(context.Background(), makeRequest(map[string]interface{}{ - "manifest_name": "test-manifest", - })) + result, _, err := tool.HandleListVulnerabilitiesInManifest(context.Background(), listVulnerabilitiesInManifestInput{ + ManifestName: "test-manifest", + }) require.NoError(t, err) require.NotNil(t, result) assert.False(t, result.IsError) @@ -552,7 +511,7 @@ func TestHandleListVulnerabilitiesInManifest_MissingManifestName(t *testing.T) { spdxClient := kubescapefake.NewClientset() tool := NewKubescapeToolWithClients(nil, nil, spdxClient.SpdxV1beta1()) - result, err := tool.HandleListVulnerabilitiesInManifest(context.Background(), makeRequest(nil)) + result, _, err := tool.HandleListVulnerabilitiesInManifest(context.Background(), listVulnerabilitiesInManifestInput{}) require.NoError(t, err) require.NotNil(t, result) assert.True(t, result.IsError) @@ -563,9 +522,9 @@ func TestHandleListVulnerabilitiesInManifest_ManifestNotFound(t *testing.T) { spdxClient := kubescapefake.NewClientset() tool := NewKubescapeToolWithClients(nil, nil, spdxClient.SpdxV1beta1()) - result, err := tool.HandleListVulnerabilitiesInManifest(context.Background(), makeRequest(map[string]interface{}{ - "manifest_name": "nonexistent", - })) + result, _, err := tool.HandleListVulnerabilitiesInManifest(context.Background(), listVulnerabilitiesInManifestInput{ + ManifestName: "nonexistent", + }) require.NoError(t, err) require.NotNil(t, result) assert.True(t, result.IsError) @@ -602,10 +561,10 @@ func TestHandleGetVulnerabilityDetails_Success(t *testing.T) { tool := NewKubescapeToolWithClients(nil, nil, spdxClient.SpdxV1beta1()) - result, err := tool.HandleGetVulnerabilityDetails(context.Background(), makeRequest(map[string]interface{}{ - "manifest_name": "test-manifest", - "cve_id": "CVE-2021-1234", - })) + result, _, err := tool.HandleGetVulnerabilityDetails(context.Background(), getVulnerabilityDetailsInput{ + ManifestName: "test-manifest", + CveID: "CVE-2021-1234", + }) require.NoError(t, err) require.NotNil(t, result) assert.False(t, result.IsError) @@ -622,9 +581,9 @@ func TestHandleGetVulnerabilityDetails_MissingManifestName(t *testing.T) { spdxClient := kubescapefake.NewClientset() tool := NewKubescapeToolWithClients(nil, nil, spdxClient.SpdxV1beta1()) - result, err := tool.HandleGetVulnerabilityDetails(context.Background(), makeRequest(map[string]interface{}{ - "cve_id": "CVE-2021-1234", - })) + result, _, err := tool.HandleGetVulnerabilityDetails(context.Background(), getVulnerabilityDetailsInput{ + CveID: "CVE-2021-1234", + }) require.NoError(t, err) require.NotNil(t, result) assert.True(t, result.IsError) @@ -635,9 +594,9 @@ func TestHandleGetVulnerabilityDetails_MissingCveId(t *testing.T) { spdxClient := kubescapefake.NewClientset() tool := NewKubescapeToolWithClients(nil, nil, spdxClient.SpdxV1beta1()) - result, err := tool.HandleGetVulnerabilityDetails(context.Background(), makeRequest(map[string]interface{}{ - "manifest_name": "test-manifest", - })) + result, _, err := tool.HandleGetVulnerabilityDetails(context.Background(), getVulnerabilityDetailsInput{ + ManifestName: "test-manifest", + }) require.NoError(t, err) require.NotNil(t, result) assert.True(t, result.IsError) @@ -661,10 +620,10 @@ func TestHandleGetVulnerabilityDetails_CveNotFound(t *testing.T) { tool := NewKubescapeToolWithClients(nil, nil, spdxClient.SpdxV1beta1()) - result, err := tool.HandleGetVulnerabilityDetails(context.Background(), makeRequest(map[string]interface{}{ - "manifest_name": "test-manifest", - "cve_id": "CVE-2021-1234", - })) + result, _, err := tool.HandleGetVulnerabilityDetails(context.Background(), getVulnerabilityDetailsInput{ + ManifestName: "test-manifest", + CveID: "CVE-2021-1234", + }) require.NoError(t, err) require.NotNil(t, result) assert.True(t, result.IsError) @@ -689,7 +648,7 @@ func TestHandleListConfigurationScans_Success(t *testing.T) { tool := NewKubescapeToolWithClients(nil, nil, spdxClient.SpdxV1beta1()) - result, err := tool.HandleListConfigurationScans(context.Background(), makeRequest(nil)) + result, _, err := tool.HandleListConfigurationScans(context.Background(), listConfigurationScansInput{}) require.NoError(t, err) require.NotNil(t, result) assert.False(t, result.IsError) @@ -713,9 +672,9 @@ func TestHandleListConfigurationScans_FilterByNamespace(t *testing.T) { tool := NewKubescapeToolWithClients(nil, nil, spdxClient.SpdxV1beta1()) - result, err := tool.HandleListConfigurationScans(context.Background(), makeRequest(map[string]interface{}{ - "namespace": "default", - })) + result, _, err := tool.HandleListConfigurationScans(context.Background(), listConfigurationScansInput{ + Namespace: "default", + }) require.NoError(t, err) require.NotNil(t, result) @@ -730,7 +689,7 @@ func TestHandleListConfigurationScans_EmptyResults(t *testing.T) { spdxClient := kubescapefake.NewClientset() tool := NewKubescapeToolWithClients(nil, nil, spdxClient.SpdxV1beta1()) - result, err := tool.HandleListConfigurationScans(context.Background(), makeRequest(nil)) + result, _, err := tool.HandleListConfigurationScans(context.Background(), listConfigurationScansInput{}) require.NoError(t, err) require.NotNil(t, result) @@ -753,9 +712,9 @@ func TestHandleGetConfigurationScan_Success(t *testing.T) { tool := NewKubescapeToolWithClients(nil, nil, spdxClient.SpdxV1beta1()) - result, err := tool.HandleGetConfigurationScan(context.Background(), makeRequest(map[string]interface{}{ - "manifest_name": "test-scan", - })) + result, _, err := tool.HandleGetConfigurationScan(context.Background(), getConfigurationScanInput{ + ManifestName: "test-scan", + }) require.NoError(t, err) require.NotNil(t, result) assert.False(t, result.IsError) @@ -765,7 +724,7 @@ func TestHandleGetConfigurationScan_MissingManifestName(t *testing.T) { spdxClient := kubescapefake.NewClientset() tool := NewKubescapeToolWithClients(nil, nil, spdxClient.SpdxV1beta1()) - result, err := tool.HandleGetConfigurationScan(context.Background(), makeRequest(nil)) + result, _, err := tool.HandleGetConfigurationScan(context.Background(), getConfigurationScanInput{}) require.NoError(t, err) require.NotNil(t, result) assert.True(t, result.IsError) @@ -776,9 +735,9 @@ func TestHandleGetConfigurationScan_NotFound(t *testing.T) { spdxClient := kubescapefake.NewClientset() tool := NewKubescapeToolWithClients(nil, nil, spdxClient.SpdxV1beta1()) - result, err := tool.HandleGetConfigurationScan(context.Background(), makeRequest(map[string]interface{}{ - "manifest_name": "nonexistent", - })) + result, _, err := tool.HandleGetConfigurationScan(context.Background(), getConfigurationScanInput{ + ManifestName: "nonexistent", + }) require.NoError(t, err) require.NotNil(t, result) assert.True(t, result.IsError) @@ -809,11 +768,8 @@ func TestNilArgumentsHandling(t *testing.T) { spdxClient := kubescapefake.NewClientset() tool := NewKubescapeToolWithClients(nil, nil, spdxClient.SpdxV1beta1()) - // Test with nil arguments map - should use defaults - request := mcp.CallToolRequest{} - request.Params.Arguments = nil - - result, err := tool.HandleListVulnerabilityManifests(context.Background(), request) + // Empty input should use defaults + result, _, err := tool.HandleListVulnerabilityManifests(context.Background(), listVulnerabilityManifestsInput{}) require.NoError(t, err) require.NotNil(t, result) assert.False(t, result.IsError) @@ -853,7 +809,7 @@ func TestHandleListApplicationProfiles_Success(t *testing.T) { tool := NewKubescapeToolWithClients(nil, nil, spdxClient.SpdxV1beta1()) - result, err := tool.HandleListApplicationProfiles(context.Background(), makeRequest(nil)) + result, _, err := tool.HandleListApplicationProfiles(context.Background(), listApplicationProfilesInput{}) require.NoError(t, err) require.NotNil(t, result) assert.False(t, result.IsError) @@ -880,9 +836,9 @@ func TestHandleListApplicationProfiles_FilterByNamespace(t *testing.T) { tool := NewKubescapeToolWithClients(nil, nil, spdxClient.SpdxV1beta1()) - result, err := tool.HandleListApplicationProfiles(context.Background(), makeRequest(map[string]interface{}{ - "namespace": "default", - })) + result, _, err := tool.HandleListApplicationProfiles(context.Background(), listApplicationProfilesInput{ + Namespace: "default", + }) require.NoError(t, err) require.NotNil(t, result) @@ -897,7 +853,7 @@ func TestHandleListApplicationProfiles_EmptyResults(t *testing.T) { spdxClient := kubescapefake.NewClientset() tool := NewKubescapeToolWithClients(nil, nil, spdxClient.SpdxV1beta1()) - result, err := tool.HandleListApplicationProfiles(context.Background(), makeRequest(nil)) + result, _, err := tool.HandleListApplicationProfiles(context.Background(), listApplicationProfilesInput{}) require.NoError(t, err) require.NotNil(t, result) @@ -911,7 +867,7 @@ func TestHandleListApplicationProfiles_EmptyResults(t *testing.T) { func TestHandleListApplicationProfiles_InitError(t *testing.T) { tool := NewKubescapeToolWithError(errors.New("failed to connect")) - result, err := tool.HandleListApplicationProfiles(context.Background(), makeRequest(nil)) + result, _, err := tool.HandleListApplicationProfiles(context.Background(), listApplicationProfilesInput{}) require.NoError(t, err) require.NotNil(t, result) assert.True(t, result.IsError) @@ -944,10 +900,10 @@ func TestHandleGetApplicationProfile_Success(t *testing.T) { tool := NewKubescapeToolWithClients(nil, nil, spdxClient.SpdxV1beta1()) - result, err := tool.HandleGetApplicationProfile(context.Background(), makeRequest(map[string]interface{}{ - "namespace": "default", - "name": "test-profile", - })) + result, _, err := tool.HandleGetApplicationProfile(context.Background(), getApplicationProfileInput{ + Namespace: "default", + Name: "test-profile", + }) require.NoError(t, err) require.NotNil(t, result) assert.False(t, result.IsError) @@ -965,9 +921,9 @@ func TestHandleGetApplicationProfile_MissingName(t *testing.T) { spdxClient := kubescapefake.NewClientset() tool := NewKubescapeToolWithClients(nil, nil, spdxClient.SpdxV1beta1()) - result, err := tool.HandleGetApplicationProfile(context.Background(), makeRequest(map[string]interface{}{ - "namespace": "default", - })) + result, _, err := tool.HandleGetApplicationProfile(context.Background(), getApplicationProfileInput{ + Namespace: "default", + }) require.NoError(t, err) require.NotNil(t, result) assert.True(t, result.IsError) @@ -978,9 +934,9 @@ func TestHandleGetApplicationProfile_MissingNamespace(t *testing.T) { spdxClient := kubescapefake.NewClientset() tool := NewKubescapeToolWithClients(nil, nil, spdxClient.SpdxV1beta1()) - result, err := tool.HandleGetApplicationProfile(context.Background(), makeRequest(map[string]interface{}{ - "name": "test-profile", - })) + result, _, err := tool.HandleGetApplicationProfile(context.Background(), getApplicationProfileInput{ + Name: "test-profile", + }) require.NoError(t, err) require.NotNil(t, result) assert.True(t, result.IsError) @@ -991,10 +947,10 @@ func TestHandleGetApplicationProfile_NotFound(t *testing.T) { spdxClient := kubescapefake.NewClientset() tool := NewKubescapeToolWithClients(nil, nil, spdxClient.SpdxV1beta1()) - result, err := tool.HandleGetApplicationProfile(context.Background(), makeRequest(map[string]interface{}{ - "namespace": "default", - "name": "nonexistent", - })) + result, _, err := tool.HandleGetApplicationProfile(context.Background(), getApplicationProfileInput{ + Namespace: "default", + Name: "nonexistent", + }) require.NoError(t, err) require.NotNil(t, result) assert.True(t, result.IsError) @@ -1033,7 +989,7 @@ func TestHandleListNetworkNeighborhoods_Success(t *testing.T) { tool := NewKubescapeToolWithClients(nil, nil, spdxClient.SpdxV1beta1()) - result, err := tool.HandleListNetworkNeighborhoods(context.Background(), makeRequest(nil)) + result, _, err := tool.HandleListNetworkNeighborhoods(context.Background(), listNetworkNeighborhoodsInput{}) require.NoError(t, err) require.NotNil(t, result) assert.False(t, result.IsError) @@ -1060,9 +1016,9 @@ func TestHandleListNetworkNeighborhoods_FilterByNamespace(t *testing.T) { tool := NewKubescapeToolWithClients(nil, nil, spdxClient.SpdxV1beta1()) - result, err := tool.HandleListNetworkNeighborhoods(context.Background(), makeRequest(map[string]interface{}{ - "namespace": "default", - })) + result, _, err := tool.HandleListNetworkNeighborhoods(context.Background(), listNetworkNeighborhoodsInput{ + Namespace: "default", + }) require.NoError(t, err) require.NotNil(t, result) @@ -1077,7 +1033,7 @@ func TestHandleListNetworkNeighborhoods_EmptyResults(t *testing.T) { spdxClient := kubescapefake.NewClientset() tool := NewKubescapeToolWithClients(nil, nil, spdxClient.SpdxV1beta1()) - result, err := tool.HandleListNetworkNeighborhoods(context.Background(), makeRequest(nil)) + result, _, err := tool.HandleListNetworkNeighborhoods(context.Background(), listNetworkNeighborhoodsInput{}) require.NoError(t, err) require.NotNil(t, result) @@ -1091,7 +1047,7 @@ func TestHandleListNetworkNeighborhoods_EmptyResults(t *testing.T) { func TestHandleListNetworkNeighborhoods_InitError(t *testing.T) { tool := NewKubescapeToolWithError(errors.New("failed to connect")) - result, err := tool.HandleListNetworkNeighborhoods(context.Background(), makeRequest(nil)) + result, _, err := tool.HandleListNetworkNeighborhoods(context.Background(), listNetworkNeighborhoodsInput{}) require.NoError(t, err) require.NotNil(t, result) assert.True(t, result.IsError) @@ -1122,10 +1078,10 @@ func TestHandleGetNetworkNeighborhood_Success(t *testing.T) { tool := NewKubescapeToolWithClients(nil, nil, spdxClient.SpdxV1beta1()) - result, err := tool.HandleGetNetworkNeighborhood(context.Background(), makeRequest(map[string]interface{}{ - "namespace": "default", - "name": "test-nn", - })) + result, _, err := tool.HandleGetNetworkNeighborhood(context.Background(), getNetworkNeighborhoodInput{ + Namespace: "default", + Name: "test-nn", + }) require.NoError(t, err) require.NotNil(t, result) assert.False(t, result.IsError) @@ -1143,9 +1099,9 @@ func TestHandleGetNetworkNeighborhood_MissingName(t *testing.T) { spdxClient := kubescapefake.NewClientset() tool := NewKubescapeToolWithClients(nil, nil, spdxClient.SpdxV1beta1()) - result, err := tool.HandleGetNetworkNeighborhood(context.Background(), makeRequest(map[string]interface{}{ - "namespace": "default", - })) + result, _, err := tool.HandleGetNetworkNeighborhood(context.Background(), getNetworkNeighborhoodInput{ + Namespace: "default", + }) require.NoError(t, err) require.NotNil(t, result) assert.True(t, result.IsError) @@ -1156,9 +1112,9 @@ func TestHandleGetNetworkNeighborhood_MissingNamespace(t *testing.T) { spdxClient := kubescapefake.NewClientset() tool := NewKubescapeToolWithClients(nil, nil, spdxClient.SpdxV1beta1()) - result, err := tool.HandleGetNetworkNeighborhood(context.Background(), makeRequest(map[string]interface{}{ - "name": "test-nn", - })) + result, _, err := tool.HandleGetNetworkNeighborhood(context.Background(), getNetworkNeighborhoodInput{ + Name: "test-nn", + }) require.NoError(t, err) require.NotNil(t, result) assert.True(t, result.IsError) @@ -1169,10 +1125,10 @@ func TestHandleGetNetworkNeighborhood_NotFound(t *testing.T) { spdxClient := kubescapefake.NewClientset() tool := NewKubescapeToolWithClients(nil, nil, spdxClient.SpdxV1beta1()) - result, err := tool.HandleGetNetworkNeighborhood(context.Background(), makeRequest(map[string]interface{}{ - "namespace": "default", - "name": "nonexistent", - })) + result, _, err := tool.HandleGetNetworkNeighborhood(context.Background(), getNetworkNeighborhoodInput{ + Namespace: "default", + Name: "nonexistent", + }) require.NoError(t, err) require.NotNil(t, result) assert.True(t, result.IsError) diff --git a/pkg/prometheus/prometheus.go b/pkg/prometheus/prometheus.go index c77e23d4..0b73e0c5 100644 --- a/pkg/prometheus/prometheus.go +++ b/pkg/prometheus/prometheus.go @@ -10,10 +10,8 @@ import ( "time" "github.com/kagent-dev/tools/internal/errors" + mcp "github.com/kagent-dev/tools/internal/mcp" "github.com/kagent-dev/tools/internal/security" - "github.com/kagent-dev/tools/internal/telemetry" - "github.com/mark3labs/mcp-go/mcp" - "github.com/mark3labs/mcp-go/server" ) // clientKey is the context key for the http client. @@ -26,24 +24,35 @@ func getHTTPClient(ctx context.Context) *http.Client { return http.DefaultClient } -// Prometheus tools using direct HTTP API calls +// prometheusErrResult adapts ToolError to an MCP error result. +func prometheusErrResult(toolErr *errors.ToolError) *mcp.CallToolResult { + return toolErr.ToMCPResult() +} + +type prometheusQueryInput struct { + Query string `json:"query" jsonschema:"PromQL query to execute"` + PrometheusURL string `json:"prometheus_url" jsonschema:"Prometheus server URL (default: http://localhost:9090)"` +} -func handlePrometheusQueryTool(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - prometheusURL := mcp.ParseString(request, "prometheus_url", "http://localhost:9090") - query := mcp.ParseString(request, "query", "") +func handlePrometheusQueryTool(ctx context.Context, request *mcp.CallToolRequest, in prometheusQueryInput) (*mcp.CallToolResult, any, error) { + prometheusURL := in.PrometheusURL + if prometheusURL == "" { + prometheusURL = "http://localhost:9090" + } + query := in.Query if query == "" { - return mcp.NewToolResultError("query parameter is required"), nil + return mcp.NewToolResultError("query parameter is required"), nil, nil } // Validate prometheus URL if err := security.ValidateURL(prometheusURL); err != nil { - return mcp.NewToolResultError(fmt.Sprintf("Invalid Prometheus URL: %v", err)), nil + return mcp.NewToolResultError(fmt.Sprintf("Invalid Prometheus URL: %v", err)), nil, nil } // Validate PromQL query if err := security.ValidatePromQLQuery(query); err != nil { - return mcp.NewToolResultError(fmt.Sprintf("Invalid PromQL query: %v", err)), nil + return mcp.NewToolResultError(fmt.Sprintf("Invalid PromQL query: %v", err)), nil, nil } // Make request to Prometheus API @@ -60,7 +69,7 @@ func handlePrometheusQueryTool(ctx context.Context, request mcp.CallToolRequest) toolErr := errors.NewPrometheusError("create_request", err). WithContext("prometheus_url", prometheusURL). WithContext("query", query) - return toolErr.ToMCPResult(), nil + return prometheusErrResult(toolErr), nil, nil } resp, err := client.Do(req) @@ -69,7 +78,7 @@ func handlePrometheusQueryTool(ctx context.Context, request mcp.CallToolRequest) WithContext("prometheus_url", prometheusURL). WithContext("query", query). WithContext("api_url", apiURL) - return toolErr.ToMCPResult(), nil + return prometheusErrResult(toolErr), nil, nil } defer resp.Body.Close() @@ -79,7 +88,7 @@ func handlePrometheusQueryTool(ctx context.Context, request mcp.CallToolRequest) WithContext("prometheus_url", prometheusURL). WithContext("query", query). WithContext("status_code", resp.StatusCode) - return toolErr.ToMCPResult(), nil + return prometheusErrResult(toolErr), nil, nil } if resp.StatusCode != http.StatusOK { @@ -88,58 +97,72 @@ func handlePrometheusQueryTool(ctx context.Context, request mcp.CallToolRequest) WithContext("query", query). WithContext("status_code", resp.StatusCode). WithContext("response_body", string(body)) - return toolErr.ToMCPResult(), nil + return prometheusErrResult(toolErr), nil, nil } // Parse the JSON response to pretty-print it var result interface{} if err := json.Unmarshal(body, &result); err != nil { - return mcp.NewToolResultText(string(body)), nil + return mcp.NewToolResultText(string(body)), nil, nil } prettyJSON, err := json.MarshalIndent(result, "", " ") if err != nil { - return mcp.NewToolResultText(string(body)), nil + return mcp.NewToolResultText(string(body)), nil, nil } - return mcp.NewToolResultText(string(prettyJSON)), nil + return mcp.NewToolResultText(string(prettyJSON)), nil, nil } -func handlePrometheusRangeQueryTool(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - prometheusURL := mcp.ParseString(request, "prometheus_url", "http://localhost:9090") - query := mcp.ParseString(request, "query", "") - start := mcp.ParseString(request, "start", "") - end := mcp.ParseString(request, "end", "") - step := mcp.ParseString(request, "step", "15s") +type prometheusRangeQueryInput struct { + Query string `json:"query" jsonschema:"PromQL query to execute"` + Start string `json:"start" jsonschema:"Start time (Unix timestamp or relative time)"` + End string `json:"end" jsonschema:"End time (Unix timestamp or relative time)"` + Step string `json:"step" jsonschema:"Query resolution step (default: 15s)"` + PrometheusURL string `json:"prometheus_url" jsonschema:"Prometheus server URL (default: http://localhost:9090)"` +} + +func handlePrometheusRangeQueryTool(ctx context.Context, request *mcp.CallToolRequest, in prometheusRangeQueryInput) (*mcp.CallToolResult, any, error) { + prometheusURL := in.PrometheusURL + if prometheusURL == "" { + prometheusURL = "http://localhost:9090" + } + query := in.Query + start := in.Start + end := in.End + step := in.Step + if step == "" { + step = "15s" + } if query == "" { - return mcp.NewToolResultError("query parameter is required"), nil + return mcp.NewToolResultError("query parameter is required"), nil, nil } // Validate prometheus URL if err := security.ValidateURL(prometheusURL); err != nil { - return mcp.NewToolResultError(fmt.Sprintf("Invalid Prometheus URL: %v", err)), nil + return mcp.NewToolResultError(fmt.Sprintf("Invalid Prometheus URL: %v", err)), nil, nil } // Validate PromQL query if err := security.ValidatePromQLQuery(query); err != nil { - return mcp.NewToolResultError(fmt.Sprintf("Invalid PromQL query: %v", err)), nil + return mcp.NewToolResultError(fmt.Sprintf("Invalid PromQL query: %v", err)), nil, nil } // Validate time parameters if provided if start != "" { if err := security.ValidateCommandInput(start); err != nil { - return mcp.NewToolResultError(fmt.Sprintf("Invalid start time: %v", err)), nil + return mcp.NewToolResultError(fmt.Sprintf("Invalid start time: %v", err)), nil, nil } } if end != "" { if err := security.ValidateCommandInput(end); err != nil { - return mcp.NewToolResultError(fmt.Sprintf("Invalid end time: %v", err)), nil + return mcp.NewToolResultError(fmt.Sprintf("Invalid end time: %v", err)), nil, nil } } if step != "" { if err := security.ValidateCommandInput(step); err != nil { - return mcp.NewToolResultError(fmt.Sprintf("Invalid step parameter: %v", err)), nil + return mcp.NewToolResultError(fmt.Sprintf("Invalid step parameter: %v", err)), nil, nil } } @@ -164,44 +187,51 @@ func handlePrometheusRangeQueryTool(ctx context.Context, request mcp.CallToolReq client := getHTTPClient(ctx) req, err := http.NewRequestWithContext(ctx, "GET", fullURL, nil) if err != nil { - return mcp.NewToolResultError("failed to create request: " + err.Error()), nil + return mcp.NewToolResultError("failed to create request: " + err.Error()), nil, nil } resp, err := client.Do(req) if err != nil { - return mcp.NewToolResultError("failed to query Prometheus: " + err.Error()), nil + return mcp.NewToolResultError("failed to query Prometheus: " + err.Error()), nil, nil } defer resp.Body.Close() body, err := io.ReadAll(resp.Body) if err != nil { - return mcp.NewToolResultError("failed to read response: " + err.Error()), nil + return mcp.NewToolResultError("failed to read response: " + err.Error()), nil, nil } if resp.StatusCode != http.StatusOK { - return mcp.NewToolResultError(fmt.Sprintf("Prometheus API error (%d): %s", resp.StatusCode, string(body))), nil + return mcp.NewToolResultError(fmt.Sprintf("Prometheus API error (%d): %s", resp.StatusCode, string(body))), nil, nil } // Parse the JSON response to pretty-print it var result interface{} if err := json.Unmarshal(body, &result); err != nil { - return mcp.NewToolResultText(string(body)), nil + return mcp.NewToolResultText(string(body)), nil, nil } prettyJSON, err := json.MarshalIndent(result, "", " ") if err != nil { - return mcp.NewToolResultText(string(body)), nil + return mcp.NewToolResultText(string(body)), nil, nil } - return mcp.NewToolResultText(string(prettyJSON)), nil + return mcp.NewToolResultText(string(prettyJSON)), nil, nil +} + +type prometheusLabelsInput struct { + PrometheusURL string `json:"prometheus_url" jsonschema:"Prometheus server URL (default: http://localhost:9090)"` } -func handlePrometheusLabelsQueryTool(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - prometheusURL := mcp.ParseString(request, "prometheus_url", "http://localhost:9090") +func handlePrometheusLabelsQueryTool(ctx context.Context, request *mcp.CallToolRequest, in prometheusLabelsInput) (*mcp.CallToolResult, any, error) { + prometheusURL := in.PrometheusURL + if prometheusURL == "" { + prometheusURL = "http://localhost:9090" + } // Validate prometheus URL if err := security.ValidateURL(prometheusURL); err != nil { - return mcp.NewToolResultError(fmt.Sprintf("Invalid Prometheus URL: %v", err)), nil + return mcp.NewToolResultError(fmt.Sprintf("Invalid Prometheus URL: %v", err)), nil, nil } // Make request to Prometheus API for labels @@ -213,7 +243,7 @@ func handlePrometheusLabelsQueryTool(ctx context.Context, request mcp.CallToolRe toolErr := errors.NewPrometheusError("create_request", err). WithContext("prometheus_url", prometheusURL). WithContext("api_url", apiURL) - return toolErr.ToMCPResult(), nil + return prometheusErrResult(toolErr), nil, nil } resp, err := client.Do(req) @@ -221,7 +251,7 @@ func handlePrometheusLabelsQueryTool(ctx context.Context, request mcp.CallToolRe toolErr := errors.NewPrometheusError("query_execution", err). WithContext("prometheus_url", prometheusURL). WithContext("api_url", apiURL) - return toolErr.ToMCPResult(), nil + return prometheusErrResult(toolErr), nil, nil } defer resp.Body.Close() @@ -231,7 +261,7 @@ func handlePrometheusLabelsQueryTool(ctx context.Context, request mcp.CallToolRe WithContext("prometheus_url", prometheusURL). WithContext("api_url", apiURL). WithContext("status_code", resp.StatusCode) - return toolErr.ToMCPResult(), nil + return prometheusErrResult(toolErr), nil, nil } if resp.StatusCode != http.StatusOK { @@ -240,29 +270,36 @@ func handlePrometheusLabelsQueryTool(ctx context.Context, request mcp.CallToolRe WithContext("api_url", apiURL). WithContext("status_code", resp.StatusCode). WithContext("response_body", string(body)) - return toolErr.ToMCPResult(), nil + return prometheusErrResult(toolErr), nil, nil } // Parse the JSON response to pretty-print it var result interface{} if err := json.Unmarshal(body, &result); err != nil { - return mcp.NewToolResultText(string(body)), nil + return mcp.NewToolResultText(string(body)), nil, nil } prettyJSON, err := json.MarshalIndent(result, "", " ") if err != nil { - return mcp.NewToolResultText(string(body)), nil + return mcp.NewToolResultText(string(body)), nil, nil } - return mcp.NewToolResultText(string(prettyJSON)), nil + return mcp.NewToolResultText(string(prettyJSON)), nil, nil } -func handlePrometheusTargetsQueryTool(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - prometheusURL := mcp.ParseString(request, "prometheus_url", "http://localhost:9090") +type prometheusTargetsInput struct { + PrometheusURL string `json:"prometheus_url" jsonschema:"Prometheus server URL (default: http://localhost:9090)"` +} + +func handlePrometheusTargetsQueryTool(ctx context.Context, request *mcp.CallToolRequest, in prometheusTargetsInput) (*mcp.CallToolResult, any, error) { + prometheusURL := in.PrometheusURL + if prometheusURL == "" { + prometheusURL = "http://localhost:9090" + } // Validate prometheus URL if err := security.ValidateURL(prometheusURL); err != nil { - return mcp.NewToolResultError(fmt.Sprintf("Invalid Prometheus URL: %v", err)), nil + return mcp.NewToolResultError(fmt.Sprintf("Invalid Prometheus URL: %v", err)), nil, nil } // Make request to Prometheus API for targets @@ -271,66 +308,61 @@ func handlePrometheusTargetsQueryTool(ctx context.Context, request mcp.CallToolR client := getHTTPClient(ctx) req, err := http.NewRequestWithContext(ctx, "GET", apiURL, nil) if err != nil { - return mcp.NewToolResultError("failed to create request: " + err.Error()), nil + return mcp.NewToolResultError("failed to create request: " + err.Error()), nil, nil } resp, err := client.Do(req) if err != nil { - return mcp.NewToolResultError("failed to query Prometheus: " + err.Error()), nil + return mcp.NewToolResultError("failed to query Prometheus: " + err.Error()), nil, nil } defer resp.Body.Close() body, err := io.ReadAll(resp.Body) if err != nil { - return mcp.NewToolResultError("failed to read response: " + err.Error()), nil + return mcp.NewToolResultError("failed to read response: " + err.Error()), nil, nil } if resp.StatusCode != http.StatusOK { - return mcp.NewToolResultError(fmt.Sprintf("Prometheus API error (%d): %s", resp.StatusCode, string(body))), nil + return mcp.NewToolResultError(fmt.Sprintf("Prometheus API error (%d): %s", resp.StatusCode, string(body))), nil, nil } // Parse the JSON response to pretty-print it var result interface{} if err := json.Unmarshal(body, &result); err != nil { - return mcp.NewToolResultText(string(body)), nil + return mcp.NewToolResultText(string(body)), nil, nil } prettyJSON, err := json.MarshalIndent(result, "", " ") if err != nil { - return mcp.NewToolResultText(string(body)), nil + return mcp.NewToolResultText(string(body)), nil, nil } - return mcp.NewToolResultText(string(prettyJSON)), nil + return mcp.NewToolResultText(string(prettyJSON)), nil, nil } -func RegisterTools(s *server.MCPServer, readOnly bool) { - s.AddTool(mcp.NewTool("prometheus_query_tool", - mcp.WithDescription("Execute a PromQL query against Prometheus"), - mcp.WithString("query", mcp.Description("PromQL query to execute"), mcp.Required()), - mcp.WithString("prometheus_url", mcp.Description("Prometheus server URL (default: http://localhost:9090)")), - ), telemetry.AdaptToolHandler(telemetry.WithTracing("prometheus_query_tool", handlePrometheusQueryTool))) - - s.AddTool(mcp.NewTool("prometheus_query_range_tool", - mcp.WithDescription("Execute a PromQL range query against Prometheus"), - mcp.WithString("query", mcp.Description("PromQL query to execute"), mcp.Required()), - mcp.WithString("start", mcp.Description("Start time (Unix timestamp or relative time)")), - mcp.WithString("end", mcp.Description("End time (Unix timestamp or relative time)")), - mcp.WithString("step", mcp.Description("Query resolution step (default: 15s)")), - mcp.WithString("prometheus_url", mcp.Description("Prometheus server URL (default: http://localhost:9090)")), - ), telemetry.AdaptToolHandler(telemetry.WithTracing("prometheus_query_range_tool", handlePrometheusRangeQueryTool))) - - s.AddTool(mcp.NewTool("prometheus_label_names_tool", - mcp.WithDescription("Get all available labels from Prometheus"), - mcp.WithString("prometheus_url", mcp.Description("Prometheus server URL (default: http://localhost:9090)")), - ), telemetry.AdaptToolHandler(telemetry.WithTracing("prometheus_label_names_tool", handlePrometheusLabelsQueryTool))) - - s.AddTool(mcp.NewTool("prometheus_targets_tool", - mcp.WithDescription("Get all Prometheus targets and their status"), - mcp.WithString("prometheus_url", mcp.Description("Prometheus server URL (default: http://localhost:9090)")), - ), telemetry.AdaptToolHandler(telemetry.WithTracing("prometheus_targets_tool", handlePrometheusTargetsQueryTool))) - - s.AddTool(mcp.NewTool("prometheus_promql_tool", - mcp.WithDescription("Generate a PromQL query"), - mcp.WithString("query_description", mcp.Description("A string describing the query to generate"), mcp.Required()), - ), telemetry.AdaptToolHandler(telemetry.WithTracing("prometheus_promql_tool", handlePromql))) +func RegisterTools(s *mcp.Server, readOnly bool) { + mcp.AddTool(s, "prometheus", &mcp.Tool{ + Name: "prometheus_query_tool", + Description: "Execute a PromQL query against Prometheus", + }, handlePrometheusQueryTool) + + mcp.AddTool(s, "prometheus", &mcp.Tool{ + Name: "prometheus_query_range_tool", + Description: "Execute a PromQL range query against Prometheus", + }, handlePrometheusRangeQueryTool) + + mcp.AddTool(s, "prometheus", &mcp.Tool{ + Name: "prometheus_label_names_tool", + Description: "Get all available labels from Prometheus", + }, handlePrometheusLabelsQueryTool) + + mcp.AddTool(s, "prometheus", &mcp.Tool{ + Name: "prometheus_targets_tool", + Description: "Get all Prometheus targets and their status", + }, handlePrometheusTargetsQueryTool) + + mcp.AddTool(s, "prometheus", &mcp.Tool{ + Name: "prometheus_promql_tool", + Description: "Generate a PromQL query", + }, handlePromql) } diff --git a/pkg/prometheus/prometheus_test.go b/pkg/prometheus/prometheus_test.go index 1e8ffc49..792fad20 100644 --- a/pkg/prometheus/prometheus_test.go +++ b/pkg/prometheus/prometheus_test.go @@ -7,18 +7,17 @@ import ( "strings" "testing" - "github.com/mark3labs/mcp-go/mcp" - "github.com/mark3labs/mcp-go/server" + mcp "github.com/kagent-dev/tools/internal/mcp" "github.com/stretchr/testify/assert" ) func TestRegisterTools(t *testing.T) { t.Run("read-write", func(t *testing.T) { - s := server.NewMCPServer("test", "v0.0.1") + s := mcp.NewServer(&mcp.Implementation{Name: "test", Version: "v0.0.1"}, nil) RegisterTools(s, false) }) t.Run("read-only", func(t *testing.T) { - s := server.NewMCPServer("test", "v0.0.1") + s := mcp.NewServer(&mcp.Implementation{Name: "test", Version: "v0.0.1"}, nil) RegisterTools(s, true) }) } @@ -26,75 +25,74 @@ func TestRegisterTools(t *testing.T) { func TestPrometheusInputValidation(t *testing.T) { ctx := context.Background() - invalidURL := map[string]interface{}{"prometheus_url": "not a url", "query": "up"} - invalidQuery := map[string]interface{}{"prometheus_url": "http://localhost:9090", "query": "up; drop"} - t.Run("query invalid url", func(t *testing.T) { - req := mcp.CallToolRequest{} - req.Params.Arguments = invalidURL - res, err := handlePrometheusQueryTool(ctx, req) + res, _, err := handlePrometheusQueryTool(ctx, &mcp.CallToolRequest{}, prometheusQueryInput{ + PrometheusURL: "not a url", + Query: "up", + }) assert.NoError(t, err) assert.True(t, res.IsError) }) t.Run("range invalid url", func(t *testing.T) { - req := mcp.CallToolRequest{} - req.Params.Arguments = invalidURL - res, err := handlePrometheusRangeQueryTool(ctx, req) + res, _, err := handlePrometheusRangeQueryTool(ctx, &mcp.CallToolRequest{}, prometheusRangeQueryInput{ + PrometheusURL: "not a url", + Query: "up", + }) assert.NoError(t, err) assert.True(t, res.IsError) }) t.Run("labels invalid url", func(t *testing.T) { - req := mcp.CallToolRequest{} - req.Params.Arguments = map[string]interface{}{"prometheus_url": "not a url"} - res, err := handlePrometheusLabelsQueryTool(ctx, req) + res, _, err := handlePrometheusLabelsQueryTool(ctx, &mcp.CallToolRequest{}, prometheusLabelsInput{ + PrometheusURL: "not a url", + }) assert.NoError(t, err) assert.True(t, res.IsError) }) t.Run("targets invalid url", func(t *testing.T) { - req := mcp.CallToolRequest{} - req.Params.Arguments = map[string]interface{}{"prometheus_url": "not a url"} - res, err := handlePrometheusTargetsQueryTool(ctx, req) + res, _, err := handlePrometheusTargetsQueryTool(ctx, &mcp.CallToolRequest{}, prometheusTargetsInput{ + PrometheusURL: "not a url", + }) assert.NoError(t, err) assert.True(t, res.IsError) }) t.Run("query invalid promql", func(t *testing.T) { - req := mcp.CallToolRequest{} - req.Params.Arguments = invalidQuery - res, err := handlePrometheusQueryTool(ctx, req) + res, _, err := handlePrometheusQueryTool(ctx, &mcp.CallToolRequest{}, prometheusQueryInput{ + PrometheusURL: "http://localhost:9090", + Query: "up; drop", + }) assert.NoError(t, err) assert.True(t, res.IsError) }) t.Run("range invalid promql", func(t *testing.T) { - req := mcp.CallToolRequest{} - req.Params.Arguments = invalidQuery - res, err := handlePrometheusRangeQueryTool(ctx, req) + res, _, err := handlePrometheusRangeQueryTool(ctx, &mcp.CallToolRequest{}, prometheusRangeQueryInput{ + PrometheusURL: "http://localhost:9090", + Query: "up; drop", + }) assert.NoError(t, err) assert.True(t, res.IsError) }) } func TestPrometheusLabelsTargetsErrorPaths(t *testing.T) { - args := map[string]interface{}{"prometheus_url": "http://localhost:9090"} - t.Run("labels client error", func(t *testing.T) { ctx := contextWithMockClient(newTestClient(nil, assert.AnError)) - req := mcp.CallToolRequest{} - req.Params.Arguments = args - res, err := handlePrometheusLabelsQueryTool(ctx, req) + res, _, err := handlePrometheusLabelsQueryTool(ctx, &mcp.CallToolRequest{}, prometheusLabelsInput{ + PrometheusURL: "http://localhost:9090", + }) assert.NoError(t, err) assert.True(t, res.IsError) }) t.Run("labels malformed json", func(t *testing.T) { ctx := contextWithMockClient(newTestClient(createMockResponse(200, "not json"), nil)) - req := mcp.CallToolRequest{} - req.Params.Arguments = args - res, err := handlePrometheusLabelsQueryTool(ctx, req) + res, _, err := handlePrometheusLabelsQueryTool(ctx, &mcp.CallToolRequest{}, prometheusLabelsInput{ + PrometheusURL: "http://localhost:9090", + }) assert.NoError(t, err) assert.False(t, res.IsError) assert.Contains(t, getResultText(res), "not json") @@ -102,18 +100,18 @@ func TestPrometheusLabelsTargetsErrorPaths(t *testing.T) { t.Run("targets client error", func(t *testing.T) { ctx := contextWithMockClient(newTestClient(nil, assert.AnError)) - req := mcp.CallToolRequest{} - req.Params.Arguments = args - res, err := handlePrometheusTargetsQueryTool(ctx, req) + res, _, err := handlePrometheusTargetsQueryTool(ctx, &mcp.CallToolRequest{}, prometheusTargetsInput{ + PrometheusURL: "http://localhost:9090", + }) assert.NoError(t, err) assert.True(t, res.IsError) }) t.Run("targets malformed json", func(t *testing.T) { ctx := contextWithMockClient(newTestClient(createMockResponse(200, "not json"), nil)) - req := mcp.CallToolRequest{} - req.Params.Arguments = args - res, err := handlePrometheusTargetsQueryTool(ctx, req) + res, _, err := handlePrometheusTargetsQueryTool(ctx, &mcp.CallToolRequest{}, prometheusTargetsInput{ + PrometheusURL: "http://localhost:9090", + }) assert.NoError(t, err) assert.False(t, res.IsError) assert.Contains(t, getResultText(res), "not json") @@ -154,7 +152,7 @@ func getResultText(result *mcp.CallToolResult) string { if result == nil || len(result.Content) == 0 { return "" } - if textContent, ok := result.Content[0].(mcp.TextContent); ok { + if textContent, ok := result.Content[0].(*mcp.TextContent); ok { return textContent.Text } return "" @@ -192,13 +190,10 @@ func TestHandlePrometheusQueryTool(t *testing.T) { client := newTestClient(createMockResponse(200, mockResponse), nil) ctx := contextWithMockClient(client) - request := mcp.CallToolRequest{} - request.Params.Arguments = map[string]interface{}{ - "query": "up", - "prometheus_url": "http://localhost:9090", - } - - result, err := handlePrometheusQueryTool(ctx, request) + result, _, err := handlePrometheusQueryTool(ctx, &mcp.CallToolRequest{}, prometheusQueryInput{ + Query: "up", + PrometheusURL: "http://localhost:9090", + }) assert.NoError(t, err) assert.NotNil(t, result) @@ -211,12 +206,9 @@ func TestHandlePrometheusQueryTool(t *testing.T) { t.Run("missing query parameter", func(t *testing.T) { ctx := context.Background() - request := mcp.CallToolRequest{} - request.Params.Arguments = map[string]interface{}{ - "prometheus_url": "http://localhost:9090", - } - - result, err := handlePrometheusQueryTool(ctx, request) + result, _, err := handlePrometheusQueryTool(ctx, &mcp.CallToolRequest{}, prometheusQueryInput{ + PrometheusURL: "http://localhost:9090", + }) assert.NoError(t, err) assert.NotNil(t, result) @@ -228,12 +220,9 @@ func TestHandlePrometheusQueryTool(t *testing.T) { client := newTestClient(nil, assert.AnError) ctx := contextWithMockClient(client) - request := mcp.CallToolRequest{} - request.Params.Arguments = map[string]interface{}{ - "query": "up", - } - - result, err := handlePrometheusQueryTool(ctx, request) + result, _, err := handlePrometheusQueryTool(ctx, &mcp.CallToolRequest{}, prometheusQueryInput{ + Query: "up", + }) assert.NoError(t, err) assert.NotNil(t, result) @@ -245,12 +234,9 @@ func TestHandlePrometheusQueryTool(t *testing.T) { client := newTestClient(createMockResponse(500, "Internal Server Error"), nil) ctx := contextWithMockClient(client) - request := mcp.CallToolRequest{} - request.Params.Arguments = map[string]interface{}{ - "query": "up", - } - - result, err := handlePrometheusQueryTool(ctx, request) + result, _, err := handlePrometheusQueryTool(ctx, &mcp.CallToolRequest{}, prometheusQueryInput{ + Query: "up", + }) assert.NoError(t, err) assert.NotNil(t, result) @@ -262,12 +248,9 @@ func TestHandlePrometheusQueryTool(t *testing.T) { client := newTestClient(createMockResponse(200, "invalid json {"), nil) ctx := contextWithMockClient(client) - request := mcp.CallToolRequest{} - request.Params.Arguments = map[string]interface{}{ - "query": "up", - } - - result, err := handlePrometheusQueryTool(ctx, request) + result, _, err := handlePrometheusQueryTool(ctx, &mcp.CallToolRequest{}, prometheusQueryInput{ + Query: "up", + }) assert.NoError(t, err) assert.NotNil(t, result) @@ -281,12 +264,9 @@ func TestHandlePrometheusQueryTool(t *testing.T) { client := newTestClient(createMockResponse(200, mockResponse), nil) ctx := contextWithMockClient(client) - request := mcp.CallToolRequest{} - request.Params.Arguments = map[string]interface{}{ - "query": "up", - } - - result, err := handlePrometheusQueryTool(ctx, request) + result, _, err := handlePrometheusQueryTool(ctx, &mcp.CallToolRequest{}, prometheusQueryInput{ + Query: "up", + }) assert.NoError(t, err) assert.NotNil(t, result) @@ -312,15 +292,12 @@ func TestHandlePrometheusRangeQueryTool(t *testing.T) { client := newTestClient(createMockResponse(200, mockResponse), nil) ctx := contextWithMockClient(client) - request := mcp.CallToolRequest{} - request.Params.Arguments = map[string]interface{}{ - "query": "up", - "start": "1609459200", - "end": "1609459260", - "step": "60s", - } - - result, err := handlePrometheusRangeQueryTool(ctx, request) + result, _, err := handlePrometheusRangeQueryTool(ctx, &mcp.CallToolRequest{}, prometheusRangeQueryInput{ + Query: "up", + Start: "1609459200", + End: "1609459260", + Step: "60s", + }) assert.NoError(t, err) assert.NotNil(t, result) @@ -333,10 +310,7 @@ func TestHandlePrometheusRangeQueryTool(t *testing.T) { t.Run("missing query parameter", func(t *testing.T) { ctx := context.Background() - request := mcp.CallToolRequest{} - request.Params.Arguments = map[string]interface{}{} - - result, err := handlePrometheusRangeQueryTool(ctx, request) + result, _, err := handlePrometheusRangeQueryTool(ctx, &mcp.CallToolRequest{}, prometheusRangeQueryInput{}) assert.NoError(t, err) assert.NotNil(t, result) @@ -349,12 +323,9 @@ func TestHandlePrometheusRangeQueryTool(t *testing.T) { client := newTestClient(createMockResponse(200, mockResponse), nil) ctx := contextWithMockClient(client) - request := mcp.CallToolRequest{} - request.Params.Arguments = map[string]interface{}{ - "query": "up", - } - - result, err := handlePrometheusRangeQueryTool(ctx, request) + result, _, err := handlePrometheusRangeQueryTool(ctx, &mcp.CallToolRequest{}, prometheusRangeQueryInput{ + Query: "up", + }) assert.NoError(t, err) assert.NotNil(t, result) @@ -372,10 +343,7 @@ func TestHandlePrometheusLabelsQueryTool(t *testing.T) { client := newTestClient(createMockResponse(200, mockResponse), nil) ctx := contextWithMockClient(client) - request := mcp.CallToolRequest{} - request.Params.Arguments = map[string]interface{}{} - - result, err := handlePrometheusLabelsQueryTool(ctx, request) + result, _, err := handlePrometheusLabelsQueryTool(ctx, &mcp.CallToolRequest{}, prometheusLabelsInput{}) assert.NoError(t, err) assert.NotNil(t, result) @@ -391,10 +359,7 @@ func TestHandlePrometheusLabelsQueryTool(t *testing.T) { client := newTestClient(nil, assert.AnError) ctx := contextWithMockClient(client) - request := mcp.CallToolRequest{} - request.Params.Arguments = map[string]interface{}{} - - result, err := handlePrometheusLabelsQueryTool(ctx, request) + result, _, err := handlePrometheusLabelsQueryTool(ctx, &mcp.CallToolRequest{}, prometheusLabelsInput{}) assert.NoError(t, err) assert.NotNil(t, result) @@ -407,12 +372,9 @@ func TestHandlePrometheusLabelsQueryTool(t *testing.T) { client := newTestClient(createMockResponse(200, mockResponse), nil) ctx := contextWithMockClient(client) - request := mcp.CallToolRequest{} - request.Params.Arguments = map[string]interface{}{ - "prometheus_url": "http://custom:9090", - } - - result, err := handlePrometheusLabelsQueryTool(ctx, request) + result, _, err := handlePrometheusLabelsQueryTool(ctx, &mcp.CallToolRequest{}, prometheusLabelsInput{ + PrometheusURL: "http://custom:9090", + }) assert.NoError(t, err) assert.NotNil(t, result) @@ -440,10 +402,7 @@ func TestHandlePrometheusTargetsQueryTool(t *testing.T) { client := newTestClient(createMockResponse(200, mockResponse), nil) ctx := contextWithMockClient(client) - request := mcp.CallToolRequest{} - request.Params.Arguments = map[string]interface{}{} - - result, err := handlePrometheusTargetsQueryTool(ctx, request) + result, _, err := handlePrometheusTargetsQueryTool(ctx, &mcp.CallToolRequest{}, prometheusTargetsInput{}) assert.NoError(t, err) assert.NotNil(t, result) @@ -459,10 +418,7 @@ func TestHandlePrometheusTargetsQueryTool(t *testing.T) { client := newTestClient(createMockResponse(404, "Not Found"), nil) ctx := contextWithMockClient(client) - request := mcp.CallToolRequest{} - request.Params.Arguments = map[string]interface{}{} - - result, err := handlePrometheusTargetsQueryTool(ctx, request) + result, _, err := handlePrometheusTargetsQueryTool(ctx, &mcp.CallToolRequest{}, prometheusTargetsInput{}) assert.NoError(t, err) assert.NotNil(t, result) @@ -474,10 +430,7 @@ func TestHandlePrometheusTargetsQueryTool(t *testing.T) { func TestHandlePromql(t *testing.T) { t.Run("missing query description", func(t *testing.T) { ctx := context.Background() - request := mcp.CallToolRequest{} - request.Params.Arguments = map[string]interface{}{} - - result, err := handlePromql(ctx, request) + result, _, err := handlePromql(ctx, &mcp.CallToolRequest{}, promqlInput{}) assert.NoError(t, err) assert.NotNil(t, result) @@ -487,12 +440,9 @@ func TestHandlePromql(t *testing.T) { t.Run("with query description", func(t *testing.T) { ctx := context.Background() - request := mcp.CallToolRequest{} - request.Params.Arguments = map[string]interface{}{ - "query_description": "CPU usage percentage", - } - - result, err := handlePromql(ctx, request) + result, _, err := handlePromql(ctx, &mcp.CallToolRequest{}, promqlInput{ + QueryDescription: "CPU usage percentage", + }) assert.NoError(t, err) assert.NotNil(t, result) @@ -522,12 +472,9 @@ func TestPrometheusToolsContextCancellation(t *testing.T) { ctx := contextWithMockClient(client) _ = cancelCtx - request := mcp.CallToolRequest{} - request.Params.Arguments = map[string]interface{}{ - "query": "up", - } - - result, err := handlePrometheusQueryTool(ctx, request) + result, _, err := handlePrometheusQueryTool(ctx, &mcp.CallToolRequest{}, prometheusQueryInput{ + Query: "up", + }) // Should handle cancellation gracefully assert.NoError(t, err) @@ -551,12 +498,9 @@ func TestPrometheusToolsEdgeCases(t *testing.T) { client := newTestClient(createMockResponse(200, largeResponse), nil) ctx := contextWithMockClient(client) - request := mcp.CallToolRequest{} - request.Params.Arguments = map[string]interface{}{ - "query": "up", - } - - result, err := handlePrometheusQueryTool(ctx, request) + result, _, err := handlePrometheusQueryTool(ctx, &mcp.CallToolRequest{}, prometheusQueryInput{ + Query: "up", + }) assert.NoError(t, err) assert.NotNil(t, result) @@ -571,12 +515,9 @@ func TestPrometheusToolsEdgeCases(t *testing.T) { client := newTestClient(createMockResponse(200, mockResponse), nil) ctx := contextWithMockClient(client) - request := mcp.CallToolRequest{} - request.Params.Arguments = map[string]interface{}{ - "query": `up{instance=~".*:9090"}`, - } - - result, err := handlePrometheusQueryTool(ctx, request) + result, _, err := handlePrometheusQueryTool(ctx, &mcp.CallToolRequest{}, prometheusQueryInput{ + Query: `up{instance=~".*:9090"}`, + }) assert.NoError(t, err) assert.NotNil(t, result) @@ -587,12 +528,9 @@ func TestPrometheusToolsEdgeCases(t *testing.T) { client := newTestClient(createMockResponse(200, ""), nil) ctx := contextWithMockClient(client) - request := mcp.CallToolRequest{} - request.Params.Arguments = map[string]interface{}{ - "query": "up", - } - - result, err := handlePrometheusQueryTool(ctx, request) + result, _, err := handlePrometheusQueryTool(ctx, &mcp.CallToolRequest{}, prometheusQueryInput{ + Query: "up", + }) assert.NoError(t, err) assert.NotNil(t, result) @@ -607,12 +545,9 @@ func TestPrometheusURLEncoding(t *testing.T) { client := newTestClient(createMockResponse(200, mockResponse), nil) ctx := contextWithMockClient(client) - request := mcp.CallToolRequest{} - request.Params.Arguments = map[string]interface{}{ - "query": "up{job=\"test service\"}", - } - - result, err := handlePrometheusQueryTool(ctx, request) + result, _, err := handlePrometheusQueryTool(ctx, &mcp.CallToolRequest{}, prometheusQueryInput{ + Query: `up{job="test service"}`, + }) assert.NoError(t, err) assert.NotNil(t, result) diff --git a/pkg/prometheus/promql.go b/pkg/prometheus/promql.go index dd2b7460..d95b3c1b 100644 --- a/pkg/prometheus/promql.go +++ b/pkg/prometheus/promql.go @@ -4,7 +4,7 @@ import ( "context" _ "embed" - "github.com/mark3labs/mcp-go/mcp" + mcp "github.com/kagent-dev/tools/internal/mcp" "github.com/tmc/langchaingo/llms" "github.com/tmc/langchaingo/llms/openai" ) @@ -12,15 +12,19 @@ import ( //go:embed promql_prompt.md var promqlPrompt string -func handlePromql(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - queryDescription := mcp.ParseString(request, "query_description", "") +type promqlInput struct { + QueryDescription string `json:"query_description" jsonschema:"A string describing the query to generate"` +} + +func handlePromql(ctx context.Context, request *mcp.CallToolRequest, in promqlInput) (*mcp.CallToolResult, any, error) { + queryDescription := in.QueryDescription if queryDescription == "" { - return mcp.NewToolResultError("query_description is required"), nil + return mcp.NewToolResultError("query_description is required"), nil, nil } llm, err := openai.New() if err != nil { - return mcp.NewToolResultError("failed to create LLM client: " + err.Error()), nil + return mcp.NewToolResultError("failed to create LLM client: " + err.Error()), nil, nil } contents := []llms.MessageContent{ @@ -41,13 +45,13 @@ func handlePromql(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallTo resp, err := llm.GenerateContent(ctx, contents, llms.WithModel("gpt-4o-mini")) if err != nil { - return mcp.NewToolResultError("failed to generate content: " + err.Error()), nil + return mcp.NewToolResultError("failed to generate content: " + err.Error()), nil, nil } choices := resp.Choices if len(choices) < 1 { - return mcp.NewToolResultError("empty response from model"), nil + return mcp.NewToolResultError("empty response from model"), nil, nil } c1 := choices[0] - return mcp.NewToolResultText(c1.Content), nil + return mcp.NewToolResultText(c1.Content), nil, nil } diff --git a/pkg/utils/common.go b/pkg/utils/common.go index f149d013..03c84b00 100644 --- a/pkg/utils/common.go +++ b/pkg/utils/common.go @@ -9,8 +9,7 @@ import ( "github.com/kagent-dev/tools/internal/commands" "github.com/kagent-dev/tools/internal/logger" - "github.com/mark3labs/mcp-go/mcp" - "github.com/mark3labs/mcp-go/server" + mcp "github.com/kagent-dev/tools/internal/mcp" ) // KubeConfigManager manages kubeconfig path with thread safety @@ -48,9 +47,9 @@ func AddKubeconfigArgs(args []string) []string { return args } -// shellTool provides shell command execution functionality +// shellParams is the typed input for the shell tool. type shellParams struct { - Command string `json:"command" description:"The shell command to execute"` + Command string `json:"command" jsonschema:"The shell command to execute"` } func shellTool(ctx context.Context, params shellParams) (string, error) { @@ -66,43 +65,46 @@ func shellTool(ctx context.Context, params shellParams) (string, error) { return commands.NewCommandBuilder(cmd).WithArgs(args...).Execute(ctx) } -func handleShellTool(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - command := mcp.ParseString(request, "command", "") - if command == "" { - return mcp.NewToolResultError("command parameter is required"), nil +func handleShellTool(ctx context.Context, request *mcp.CallToolRequest, in shellParams) (*mcp.CallToolResult, any, error) { + if in.Command == "" { + return mcp.NewToolResultError("command parameter is required"), nil, nil } - result, err := shellTool(ctx, shellParams{Command: command}) + result, err := shellTool(ctx, in) if err != nil { - return mcp.NewToolResultError(err.Error()), nil + return mcp.NewToolResultError(err.Error()), nil, nil } - return mcp.NewToolResultText(result), nil + return mcp.NewToolResultText(result), nil, nil } +// datetimeInput is the (empty) typed input for the datetime tool. +type datetimeInput struct{} + // handleGetCurrentDateTimeTool provides datetime functionality for both MCP and testing -func handleGetCurrentDateTimeTool(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { +func handleGetCurrentDateTimeTool(ctx context.Context, request *mcp.CallToolRequest, in datetimeInput) (*mcp.CallToolResult, any, error) { // Returns the current date and time in ISO 8601 format (RFC3339) // This matches the Python implementation: datetime.datetime.now().isoformat() now := time.Now() - return mcp.NewToolResultText(now.Format(time.RFC3339)), nil + return mcp.NewToolResultText(now.Format(time.RFC3339)), nil, nil } -func RegisterTools(s *server.MCPServer, readOnly bool) { +func RegisterTools(s *mcp.Server, readOnly bool) { logger.Get().Info("RegisterTools initialized") // Register shell tool - disabled in read-only mode as it allows arbitrary command execution if !readOnly { - s.AddTool(mcp.NewTool("shell", - mcp.WithDescription("Execute shell commands"), - mcp.WithString("command", mcp.Description("The shell command to execute"), mcp.Required()), - ), handleShellTool) + mcp.AddTool(s, "utils", &mcp.Tool{ + Name: "shell", + Description: "Execute shell commands", + }, handleShellTool) } // Register datetime tool - s.AddTool(mcp.NewTool("datetime_get_current_time", - mcp.WithDescription("Returns the current date and time in ISO 8601 format."), - ), handleGetCurrentDateTimeTool) + mcp.AddTool(s, "utils", &mcp.Tool{ + Name: "datetime_get_current_time", + Description: "Returns the current date and time in ISO 8601 format.", + }, handleGetCurrentDateTimeTool) // Note: LLM Tool implementation would go here if needed } diff --git a/pkg/utils/common_test.go b/pkg/utils/common_test.go index 6502225a..72c1b641 100644 --- a/pkg/utils/common_test.go +++ b/pkg/utils/common_test.go @@ -5,8 +5,7 @@ import ( "testing" "github.com/kagent-dev/tools/internal/cmd" - "github.com/mark3labs/mcp-go/mcp" - "github.com/mark3labs/mcp-go/server" + mcp "github.com/kagent-dev/tools/internal/mcp" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -68,12 +67,12 @@ func TestShellTool(t *testing.T) { func TestRegisterTools(t *testing.T) { t.Run("read-write registers shell", func(t *testing.T) { - s := server.NewMCPServer("test", "v0.0.1") + s := mcp.NewServer(&mcp.Implementation{Name: "test", Version: "v0.0.1"}, nil) RegisterTools(s, false) }) t.Run("read-only omits shell", func(t *testing.T) { - s := server.NewMCPServer("test", "v0.0.1") + s := mcp.NewServer(&mcp.Implementation{Name: "test", Version: "v0.0.1"}, nil) RegisterTools(s, true) }) } @@ -84,17 +83,13 @@ func TestHandleShellTool(t *testing.T) { ctx := cmd.WithShellExecutor(context.Background(), mock) t.Run("success", func(t *testing.T) { - req := mcp.CallToolRequest{} - req.Params.Arguments = map[string]interface{}{"command": "echo hi"} - res, err := handleShellTool(ctx, req) + res, _, err := handleShellTool(ctx, &mcp.CallToolRequest{}, shellParams{Command: "echo hi"}) require.NoError(t, err) assert.False(t, res.IsError) }) t.Run("missing command", func(t *testing.T) { - req := mcp.CallToolRequest{} - req.Params.Arguments = map[string]interface{}{} - res, err := handleShellTool(ctx, req) + res, _, err := handleShellTool(ctx, &mcp.CallToolRequest{}, shellParams{}) require.NoError(t, err) assert.True(t, res.IsError) assert.Contains(t, getResultText(res), "command parameter is required") @@ -104,9 +99,7 @@ func TestHandleShellTool(t *testing.T) { m := cmd.NewMockShellExecutor() m.AddCommandString("false", []string{}, "", assert.AnError) errCtx := cmd.WithShellExecutor(context.Background(), m) - req := mcp.CallToolRequest{} - req.Params.Arguments = map[string]interface{}{"command": "false"} - res, err := handleShellTool(errCtx, req) + res, _, err := handleShellTool(errCtx, &mcp.CallToolRequest{}, shellParams{Command: "false"}) require.NoError(t, err) assert.True(t, res.IsError) }) @@ -116,7 +109,7 @@ func getResultText(result *mcp.CallToolResult) string { if result == nil || len(result.Content) == 0 { return "" } - if textContent, ok := result.Content[0].(mcp.TextContent); ok { + if textContent, ok := result.Content[0].(*mcp.TextContent); ok { return textContent.Text } return "" diff --git a/pkg/utils/datetime_test.go b/pkg/utils/datetime_test.go index 8f1cd641..1d105ee1 100644 --- a/pkg/utils/datetime_test.go +++ b/pkg/utils/datetime_test.go @@ -5,7 +5,7 @@ import ( "testing" "time" - "github.com/mark3labs/mcp-go/mcp" + mcp "github.com/kagent-dev/tools/internal/mcp" ) // Test the actual MCP tool handler functions @@ -13,9 +13,8 @@ import ( func TestHandleGetCurrentDateTimeTool(t *testing.T) { ctx := context.Background() - request := mcp.CallToolRequest{} - result, err := handleGetCurrentDateTimeTool(ctx, request) + result, _, err := handleGetCurrentDateTimeTool(ctx, &mcp.CallToolRequest{}, datetimeInput{}) if err != nil { t.Fatalf("handleGetCurrentDateTimeTool failed: %v", err) } @@ -30,7 +29,7 @@ func TestHandleGetCurrentDateTimeTool(t *testing.T) { // Verify the result is a valid RFC3339 timestamp (ISO 8601 format) if len(result.Content) > 0 { - if textContent, ok := result.Content[0].(mcp.TextContent); ok { + if textContent, ok := result.Content[0].(*mcp.TextContent); ok { _, err := time.Parse(time.RFC3339, textContent.Text) if err != nil { t.Errorf("Result is not valid RFC3339 timestamp: %v", err) @@ -51,10 +50,8 @@ func TestHandleGetCurrentDateTimeTool(t *testing.T) { func TestHandleGetCurrentDateTimeToolNoParameters(t *testing.T) { // Test that the tool works without any parameters (as per Python implementation) ctx := context.Background() - request := mcp.CallToolRequest{} - request.Params.Arguments = map[string]interface{}{} // Empty arguments - result, err := handleGetCurrentDateTimeTool(ctx, request) + result, _, err := handleGetCurrentDateTimeTool(ctx, &mcp.CallToolRequest{}, datetimeInput{}) if err != nil { t.Fatalf("handleGetCurrentDateTimeTool failed with empty args: %v", err) } @@ -69,7 +66,7 @@ func TestHandleGetCurrentDateTimeToolNoParameters(t *testing.T) { // Verify we get a valid timestamp if len(result.Content) > 0 { - if textContent, ok := result.Content[0].(mcp.TextContent); ok { + if textContent, ok := result.Content[0].(*mcp.TextContent); ok { _, err := time.Parse(time.RFC3339, textContent.Text) if err != nil { t.Errorf("Result is not valid RFC3339 timestamp: %v", err) @@ -85,15 +82,14 @@ func TestHandleGetCurrentDateTimeToolNoParameters(t *testing.T) { func TestDateTimeFormatConsistency(t *testing.T) { // Test that our Go implementation produces ISO 8601 format consistent with Python ctx := context.Background() - request := mcp.CallToolRequest{} - result, err := handleGetCurrentDateTimeTool(ctx, request) + result, _, err := handleGetCurrentDateTimeTool(ctx, &mcp.CallToolRequest{}, datetimeInput{}) if err != nil { t.Fatalf("handleGetCurrentDateTimeTool failed: %v", err) } if len(result.Content) > 0 { - if textContent, ok := result.Content[0].(mcp.TextContent); ok { + if textContent, ok := result.Content[0].(*mcp.TextContent); ok { timestamp := textContent.Text // Check that it follows RFC3339 format (which is ISO 8601 compliant) diff --git a/test/e2e/helpers_test.go b/test/e2e/helpers_test.go index 8f6d5221..70da3e46 100644 --- a/test/e2e/helpers_test.go +++ b/test/e2e/helpers_test.go @@ -15,9 +15,7 @@ import ( "time" "github.com/kagent-dev/tools/internal/commands" - "github.com/mark3labs/mcp-go/client" - "github.com/mark3labs/mcp-go/client/transport" - "github.com/mark3labs/mcp-go/mcp" + mcp "github.com/modelcontextprotocol/go-sdk/mcp" . "github.com/onsi/ginkgo/v2" . "github.com/onsi/gomega" ) @@ -164,10 +162,10 @@ func (ts *TestServer) Stop() error { return nil } -// MCPClient represents a client for communicating with the MCP server using the official mcp-go client +// MCPClient represents a client for communicating with the MCP server using the official go-sdk client type MCPClient struct { - client *client.Client - log *slog.Logger + session *mcp.ClientSession + log *slog.Logger } // InstallKAgentTools installs KAgent Tools using helm in the specified namespace @@ -247,42 +245,28 @@ func InstallKAgentTools(namespace string, releaseName string) { Expect(nodePort).To(Equal("30885")) } -// GetMCPClient creates a new MCP client configured for the e2e test environment using the official mcp-go client +// GetMCPClient creates a new MCP client configured for the e2e test environment using the official go-sdk client func GetMCPClient() (*MCPClient, error) { - // Create HTTP transport for the MCP server with timeout long enough for operations like Istio installation - httpTransport, err := transport.NewStreamableHTTP("http://127.0.0.1:30885/mcp", transport.WithHTTPTimeout(180*time.Second)) - if err != nil { - return nil, fmt.Errorf("failed to create HTTP transport: %w", err) + // HTTP timeout long enough for operations like Istio installation. + httpTransport := &mcp.StreamableClientTransport{ + Endpoint: "http://127.0.0.1:30885/mcp", + HTTPClient: &http.Client{Timeout: 180 * time.Second}, } - // Create the official MCP client - mcpClient := client.NewClient(httpTransport) + mcpClient := mcp.NewClient(&mcp.Implementation{Name: "e2e-test-client", Version: "1.0.0"}, nil) - // Start the client + // Connect performs the initialization handshake. ctx, cancel := context.WithTimeout(context.Background(), 60*time.Second) defer cancel() - if err := mcpClient.Start(ctx); err != nil { - return nil, fmt.Errorf("failed to start MCP client: %w", err) - } - - // Initialize the client - initRequest := mcp.InitializeRequest{} - initRequest.Params.ProtocolVersion = mcp.LATEST_PROTOCOL_VERSION - initRequest.Params.ClientInfo = mcp.Implementation{ - Name: "e2e-test-client", - Version: "1.0.0", - } - initRequest.Params.Capabilities = mcp.ClientCapabilities{} - - _, err = mcpClient.Initialize(ctx, initRequest) + session, err := mcpClient.Connect(ctx, httpTransport, nil) if err != nil { - return nil, fmt.Errorf("failed to initialize MCP client: %w", err) + return nil, fmt.Errorf("failed to connect MCP client: %w", err) } mcpHelper := &MCPClient{ - client: mcpClient, - log: slog.Default(), + session: session, + log: slog.Default(), } // Validate connection by listing tools @@ -299,13 +283,11 @@ func (c *MCPClient) listTools() ([]interface{}, error) { ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) defer cancel() - request := mcp.ListToolsRequest{} - result, err := c.client.ListTools(ctx, request) + result, err := c.session.ListTools(ctx, &mcp.ListToolsParams{}) if err != nil { return nil, err } - // Convert tools to interface{} slice for compatibility tools := make([]interface{}, len(result.Tools)) for i, tool := range result.Tools { tools[i] = tool @@ -329,19 +311,15 @@ func (c *MCPClient) k8sListResources(resourceType string) (interface{}, error) { Output: "json", } - request := mcp.CallToolRequest{ - Params: mcp.CallToolParams{ - Name: "k8s_get_resources", - Arguments: arguments, - }, - } - - result, err := c.client.CallTool(ctx, request) + result, err := c.session.CallTool(ctx, &mcp.CallToolParams{ + Name: "k8s_get_resources", + Arguments: arguments, + }) if err != nil { return nil, err } if result.IsError { - return nil, fmt.Errorf("tool call failed: %s", result.Content) + return nil, fmt.Errorf("tool call failed: %v", result.Content) } return result, nil } @@ -361,19 +339,15 @@ func (c *MCPClient) helmListReleases() (interface{}, error) { Output: "json", } - request := mcp.CallToolRequest{ - Params: mcp.CallToolParams{ - Name: "helm_list_releases", - Arguments: arguments, - }, - } - - result, err := c.client.CallTool(ctx, request) + result, err := c.session.CallTool(ctx, &mcp.CallToolParams{ + Name: "helm_list_releases", + Arguments: arguments, + }) if err != nil { return nil, err } if result.IsError { - return nil, fmt.Errorf("tool call failed: %s", result.Content) + return nil, fmt.Errorf("tool call failed: %v", result.Content) } return result, nil } @@ -391,19 +365,15 @@ func (c *MCPClient) istioInstall(profile string) (interface{}, error) { Profile: profile, } - request := mcp.CallToolRequest{ - Params: mcp.CallToolParams{ - Name: "istio_install_istio", - Arguments: arguments, - }, - } - - result, err := c.client.CallTool(ctx, request) + result, err := c.session.CallTool(ctx, &mcp.CallToolParams{ + Name: "istio_install_istio", + Arguments: arguments, + }) if err != nil { return nil, err } if result.IsError { - return nil, fmt.Errorf("tool call failed: %s", result.Content) + return nil, fmt.Errorf("tool call failed: %v", result.Content) } return result, nil } @@ -423,19 +393,15 @@ func (c *MCPClient) argoRolloutsList(namespace string) (interface{}, error) { Output: "json", } - request := mcp.CallToolRequest{ - Params: mcp.CallToolParams{ - Name: "argo_rollouts_list", - Arguments: arguments, - }, - } - - result, err := c.client.CallTool(ctx, request) + result, err := c.session.CallTool(ctx, &mcp.CallToolParams{ + Name: "argo_rollouts_list", + Arguments: arguments, + }) if err != nil { return nil, err } if result.IsError { - return nil, fmt.Errorf("tool call failed: %s", result.Content) + return nil, fmt.Errorf("tool call failed: %v", result.Content) } return result, nil } @@ -445,14 +411,10 @@ func (c *MCPClient) ciliumStatus() (interface{}, error) { ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) defer cancel() - request := mcp.CallToolRequest{ - Params: mcp.CallToolParams{ - Name: "cilium_status_and_version", - Arguments: nil, - }, - } - - result, err := c.client.CallTool(ctx, request) + result, err := c.session.CallTool(ctx, &mcp.CallToolParams{ + Name: "cilium_status_and_version", + Arguments: nil, + }) if err != nil { return nil, err }