From 01042e57d5cfba77743a61a81868b36aa5fef3ac Mon Sep 17 00:00:00 2001 From: Dmytro Rashko Date: Tue, 23 Jun 2026 16:05:16 +0200 Subject: [PATCH] test: raise pkg coverage above 80% and bump tool versions Add unit tests across pkg providers so every package in ./pkg exceeds 80% statement coverage: - utils: 6.5% -> 100% (shell tool, kubeconfig, RegisterTools) - cilium: 46.7% -> 85.5% (cilium-dbg and CLI handlers) - argo: 49.7% -> 88.3% (gateway plugin, list/check logs, RegisterTools) - istio: 61.6% -> 92.7% (waypoint apply/delete/status, ztunnel) - prometheus: 72.4% -> 81.6% (RegisterTools, validation/error paths) - k8s: 73.9% -> 82.2% (RegisterTools, NewK8sToolWithConfig) Extract the shell tool closure into a named handleShellTool for testability. Bump tool versions to latest releases (make check-releases): - cilium 0.19.2 -> 0.19.4 - istio 1.29.1 -> 1.30.1 - helm 4.1.3 -> 4.2.2 - kubectl 1.35.3 -> 1.36.2 Signed-off-by: Dmytro Rashko --- Makefile | 8 +- go.mod | 2 +- pkg/argo/argo_test.go | 156 ++++++++++++++++++++++++++++++ pkg/cilium/cilium_test.go | 144 +++++++++++++++++++++++++++ pkg/istio/istio_test.go | 140 +++++++++++++++++++++++++++ pkg/k8s/k8s_test.go | 17 ++++ pkg/prometheus/prometheus_test.go | 116 ++++++++++++++++++++++ pkg/utils/common.go | 29 +++--- pkg/utils/common_test.go | 123 +++++++++++++++++++++++ 9 files changed, 716 insertions(+), 19 deletions(-) create mode 100644 pkg/utils/common_test.go diff --git a/Makefile b/Makefile index dde2eea2..2c7dd490 100644 --- a/Makefile +++ b/Makefile @@ -136,11 +136,11 @@ DOCKER_BUILDER ?= docker buildx DOCKER_BUILD_ARGS ?= --pull --load --platform linux/$(LOCALARCH) --builder $(BUILDX_BUILDER_NAME) # tools image build args -TOOLS_ISTIO_VERSION ?= 1.29.1 +TOOLS_ISTIO_VERSION ?= 1.30.1 TOOLS_ARGO_ROLLOUTS_VERSION ?= 1.9.0 -TOOLS_KUBECTL_VERSION ?= 1.35.3 -TOOLS_HELM_VERSION ?= 4.1.3 -TOOLS_CILIUM_VERSION ?= 0.19.2 +TOOLS_KUBECTL_VERSION ?= 1.36.2 +TOOLS_HELM_VERSION ?= 4.2.2 +TOOLS_CILIUM_VERSION ?= 0.19.4 # build args TOOLS_IMAGE_BUILD_ARGS = --build-arg VERSION=$(VERSION) diff --git a/go.mod b/go.mod index 5f1e7439..7535dbd8 100644 --- a/go.mod +++ b/go.mod @@ -1,6 +1,6 @@ module github.com/kagent-dev/tools -go 1.26.1 +go 1.26.4 require ( github.com/joho/godotenv v1.5.1 diff --git a/pkg/argo/argo_test.go b/pkg/argo/argo_test.go index 3af620f2..ce00d7b8 100644 --- a/pkg/argo/argo_test.go +++ b/pkg/argo/argo_test.go @@ -7,10 +7,166 @@ import ( "github.com/kagent-dev/tools/internal/cmd" "github.com/mark3labs/mcp-go/mcp" + "github.com/mark3labs/mcp-go/server" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) +func TestRegisterTools(t *testing.T) { + t.Run("read-write", func(t *testing.T) { + s := server.NewMCPServer("test", "v0.0.1") + RegisterTools(s, false) + }) + t.Run("read-only", func(t *testing.T) { + s := server.NewMCPServer("test", "v0.0.1") + RegisterTools(s, true) + }) +} + +func TestHandleListRollouts(t *testing.T) { + t.Run("default namespace and type", func(t *testing.T) { + mock := cmd.NewMockShellExecutor() + mock.AddCommandString("kubectl", []string{"argo", "rollouts", "list", "rollouts", "-n", "argo-rollouts"}, "NAME STATUS\nmyapp Healthy", nil) + ctx := cmd.WithShellExecutor(context.Background(), mock) + + result, err := handleListRollouts(ctx, mcp.CallToolRequest{}) + assert.NoError(t, err) + assert.False(t, result.IsError) + assert.Contains(t, getResultText(result), "myapp") + }) + + t.Run("experiments type custom namespace", func(t *testing.T) { + mock := cmd.NewMockShellExecutor() + mock.AddCommandString("kubectl", []string{"argo", "rollouts", "list", "experiments", "-n", "prod"}, "NAME", nil) + ctx := cmd.WithShellExecutor(context.Background(), mock) + + req := mcp.CallToolRequest{} + req.Params.Arguments = map[string]interface{}{"type": "experiments", "namespace": "prod"} + result, err := handleListRollouts(ctx, req) + assert.NoError(t, err) + assert.False(t, result.IsError) + }) + + t.Run("command failure", func(t *testing.T) { + mock := cmd.NewMockShellExecutor() + mock.AddCommandString("kubectl", []string{"argo", "rollouts", "list", "rollouts", "-n", "argo-rollouts"}, "", assert.AnError) + ctx := cmd.WithShellExecutor(context.Background(), mock) + + result, err := handleListRollouts(ctx, mcp.CallToolRequest{}) + assert.NoError(t, err) + assert.True(t, result.IsError) + assert.Contains(t, getResultText(result), "Error listing rollouts") + }) +} + +func TestHandleCheckPluginLogs(t *testing.T) { + t.Run("plugin install found in logs", func(t *testing.T) { + mock := cmd.NewMockShellExecutor() + logs := `Downloading plugin argoproj-labs/gatewayAPI from: https://github.com/x/releases/download/v0.5.0/gatewayapi-plugin-linux-amd64" +Download complete, it took 1.5s` + mock.AddCommandString("kubectl", []string{"logs", "-n", "argo-rollouts", "-l", "app.kubernetes.io/name=argo-rollouts", "--tail", "100"}, logs, nil) + ctx := cmd.WithShellExecutor(context.Background(), mock) + + result, err := handleCheckPluginLogs(ctx, mcp.CallToolRequest{}) + assert.NoError(t, err) + assert.Contains(t, getResultText(result), "0.5.0") + assert.Contains(t, getResultText(result), `"installed": true`) + }) + + t.Run("plugin install not found", func(t *testing.T) { + mock := cmd.NewMockShellExecutor() + mock.AddCommandString("kubectl", []string{"logs", "-n", "argo-rollouts", "-l", "app.kubernetes.io/name=argo-rollouts", "--tail", "100"}, "no plugin here", nil) + ctx := cmd.WithShellExecutor(context.Background(), mock) + + result, err := handleCheckPluginLogs(ctx, mcp.CallToolRequest{}) + assert.NoError(t, err) + assert.Contains(t, getResultText(result), "Plugin installation not found") + }) + + t.Run("command failure", func(t *testing.T) { + mock := cmd.NewMockShellExecutor() + mock.AddCommandString("kubectl", []string{"logs", "-n", "argo-rollouts", "-l", "app.kubernetes.io/name=argo-rollouts", "--tail", "100"}, "", assert.AnError) + ctx := cmd.WithShellExecutor(context.Background(), mock) + + result, err := handleCheckPluginLogs(ctx, mcp.CallToolRequest{}) + assert.NoError(t, err) + assert.Contains(t, getResultText(result), `"installed": false`) + }) +} + +func TestConfigureGatewayPlugin(t *testing.T) { + t.Run("applies configmap successfully", func(t *testing.T) { + mock := cmd.NewMockShellExecutor() + mock.AddPartialMatcherString("kubectl", []string{"apply", "-f"}, "configmap/argo-rollouts-config created", nil) + ctx := cmd.WithShellExecutor(context.Background(), mock) + + status := configureGatewayPlugin(ctx, "0.5.0", "argo-rollouts") + assert.True(t, status.Installed) + assert.Equal(t, "0.5.0", status.Version) + assert.NotEmpty(t, status.Architecture) + }) + + t.Run("apply failure", func(t *testing.T) { + mock := cmd.NewMockShellExecutor() + mock.AddPartialMatcherString("kubectl", []string{"apply", "-f"}, "", assert.AnError) + ctx := cmd.WithShellExecutor(context.Background(), mock) + + status := configureGatewayPlugin(ctx, "0.5.0", "argo-rollouts") + assert.False(t, status.Installed) + assert.Contains(t, status.ErrorMessage, "Error applying Gateway API plugin config") + }) +} + +func TestHandleVerifyGatewayPluginAlreadyConfigured(t *testing.T) { + mock := cmd.NewMockShellExecutor() + mock.AddCommandString("kubectl", []string{"get", "configmap", "argo-rollouts-config", "-n", "argo-rollouts", "-o", "yaml"}, "data:\n trafficRouterPlugins: argoproj-labs/gatewayAPI", nil) + ctx := cmd.WithShellExecutor(context.Background(), mock) + + result, err := handleVerifyGatewayPlugin(ctx, mcp.CallToolRequest{}) + assert.NoError(t, err) + assert.Contains(t, getResultText(result), "already configured") +} + +func TestHandleVerifyArgoRolloutsControllerInstallStatuses(t *testing.T) { + baseCmd := []string{"get", "pods", "-n", "argo-rollouts", "-l", "app.kubernetes.io/component=rollouts-controller", "-o", "jsonpath={.items[*].status.phase}"} + + t.Run("all running", func(t *testing.T) { + mock := cmd.NewMockShellExecutor() + mock.AddCommandString("kubectl", baseCmd, "Running Running", nil) + ctx := cmd.WithShellExecutor(context.Background(), mock) + result, err := handleVerifyArgoRolloutsControllerInstall(ctx, mcp.CallToolRequest{}) + assert.NoError(t, err) + assert.Contains(t, getResultText(result), "All pods are running") + }) + + t.Run("not all running", func(t *testing.T) { + mock := cmd.NewMockShellExecutor() + mock.AddCommandString("kubectl", baseCmd, "Running Pending", nil) + ctx := cmd.WithShellExecutor(context.Background(), mock) + result, err := handleVerifyArgoRolloutsControllerInstall(ctx, mcp.CallToolRequest{}) + assert.NoError(t, err) + assert.Contains(t, getResultText(result), "Not all pods are running") + }) + + t.Run("no pods", func(t *testing.T) { + mock := cmd.NewMockShellExecutor() + mock.AddCommandString("kubectl", baseCmd, "", nil) + ctx := cmd.WithShellExecutor(context.Background(), mock) + result, err := handleVerifyArgoRolloutsControllerInstall(ctx, mcp.CallToolRequest{}) + assert.NoError(t, err) + assert.Contains(t, getResultText(result), "No pods found") + }) + + t.Run("command error", func(t *testing.T) { + mock := cmd.NewMockShellExecutor() + mock.AddCommandString("kubectl", baseCmd, "", assert.AnError) + ctx := cmd.WithShellExecutor(context.Background(), mock) + result, err := handleVerifyArgoRolloutsControllerInstall(ctx, mcp.CallToolRequest{}) + assert.NoError(t, err) + assert.True(t, result.IsError) + }) +} + // Helper function to extract text content from MCP result func getResultText(result *mcp.CallToolResult) string { if result == nil || len(result.Content) == 0 { diff --git a/pkg/cilium/cilium_test.go b/pkg/cilium/cilium_test.go index 50bbed6d..84313e77 100644 --- a/pkg/cilium/cilium_test.go +++ b/pkg/cilium/cilium_test.go @@ -619,3 +619,147 @@ func getResultText(r *mcp.CallToolResult) string { } return "" } + +type ciliumHandler func(context.Context, mcp.CallToolRequest) (*mcp.CallToolResult, error) + +// TestCiliumDbgHandlers exercises the success path of every cilium-dbg based handler. +func TestCiliumDbgHandlers(t *testing.T) { + cases := []struct { + name string + handler ciliumHandler + args map[string]any + dbgArgs []string + expect string + }{ + {"manage_endpoint_labels", handleManageEndpointLabels, map[string]any{"endpoint_id": "34", "labels": "key=val"}, []string{"endpoint", "labels", "34", "--add", "key=val"}, "ok"}, + {"manage_endpoint_configuration", handleManageEndpointConfiguration, map[string]any{"endpoint_id": "34", "config": "Debug=true"}, []string{"endpoint", "config", "34", "Debug=true"}, "ok"}, + {"disconnect_endpoint", handleDisconnectEndpoint, map[string]any{"endpoint_id": "34"}, []string{"endpoint", "disconnect", "34"}, "ok"}, + {"get_identity_details", handleGetIdentityDetails, map[string]any{"identity_id": "123"}, []string{"identity", "get", "123"}, "ok"}, + {"flush_ipsec_state", handleFlushIPsecState, map[string]any{}, []string{"encrypt", "flush", "-f"}, "ok"}, + {"list_envoy_config", handleListEnvoyConfig, map[string]any{"resource_name": "clusters"}, []string{"envoy", "admin", "clusters"}, "ok"}, + {"show_ipcache_cidr", handleShowIPCacheInformation, map[string]any{"cidr": "10.0.0.0/24"}, []string{"ip", "get", "10.0.0.0/24"}, "ok"}, + {"show_ipcache_labels", handleShowIPCacheInformation, map[string]any{"labels": "app=foo"}, []string{"ip", "get", "--labels", "app=foo"}, "ok"}, + {"delete_kvstore_key", handleDeleteKeyFromKVStore, map[string]any{"key": "foo"}, []string{"kvstore", "delete", "foo"}, "ok"}, + {"get_kvstore_key", handleGetKVStoreKey, map[string]any{"key": "foo"}, []string{"kvstore", "get", "foo"}, "ok"}, + {"set_kvstore_key", handleSetKVStoreKey, map[string]any{"key": "foo", "value": "bar"}, []string{"kvstore", "set", "foo=bar"}, "ok"}, + {"show_load_information", handleShowLoadInformation, map[string]any{}, []string{"loadinfo"}, "ok"}, + {"display_policy_node_info", handleDisplayPolicyNodeInformation, map[string]any{}, []string{"policy", "get"}, "ok"}, + {"display_policy_node_info_labels", handleDisplayPolicyNodeInformation, map[string]any{"labels": "k=v"}, []string{"policy", "get", "k=v"}, "ok"}, + {"delete_policy_rules_all", handleDeletePolicyRules, map[string]any{"all": "true"}, []string{"policy", "delete", "--all"}, "ok"}, + {"delete_policy_rules_labels", handleDeletePolicyRules, map[string]any{"labels": "k=v"}, []string{"policy", "delete", "k=v"}, "ok"}, + {"update_xdp_cidr", handleUpdateXDPCIDRFilters, map[string]any{"cidr_prefixes": "10.0.0.0/8"}, []string{"prefilter", "update", "--cidr", "10.0.0.0/8"}, "ok"}, + {"update_xdp_cidr_rev", handleUpdateXDPCIDRFilters, map[string]any{"cidr_prefixes": "10.0.0.0/8", "revision": "2"}, []string{"prefilter", "update", "--cidr", "10.0.0.0/8", "--revision", "2"}, "ok"}, + {"delete_xdp_cidr", handleDeleteXDPCIDRFilters, map[string]any{"cidr_prefixes": "10.0.0.0/8"}, []string{"prefilter", "delete", "--cidr", "10.0.0.0/8"}, "ok"}, + {"delete_xdp_cidr_rev", handleDeleteXDPCIDRFilters, map[string]any{"cidr_prefixes": "10.0.0.0/8", "revision": "2"}, []string{"prefilter", "delete", "--cidr", "10.0.0.0/8", "--revision", "2"}, "ok"}, + {"validate_cnp", handleValidateCiliumNetworkPolicies, map[string]any{"enable_k8s": "true", "enable_k8s_api_discovery": "true"}, []string{"preflight", "validate-cnp", "--enable-k8s", "--enable-k8s-api-discovery"}, "ok"}, + {"list_pcap_recorders", handleListPCAPRecorders, map[string]any{}, []string{"recorder", "list"}, "ok"}, + {"get_pcap_recorder", handleGetPCAPRecorder, map[string]any{"recorder_id": "1"}, []string{"recorder", "get", "1"}, "ok"}, + {"delete_pcap_recorder", handleDeletePCAPRecorder, map[string]any{"recorder_id": "1"}, []string{"recorder", "delete", "1"}, "ok"}, + {"update_pcap_recorder", handleUpdatePCAPRecorder, map[string]any{"recorder_id": "1", "filters": "f"}, []string{"recorder", "update", "1", "--filters", "f", "--caplen", "0", "--id", "0"}, "ok"}, + {"get_service_information", handleGetServiceInformation, map[string]any{"service_id": "5"}, []string{"service", "get", "5"}, "ok"}, + {"delete_service_all", handleDeleteService, map[string]any{"all": "true"}, []string{"service", "delete", "--all"}, "ok"}, + {"delete_service_id", handleDeleteService, map[string]any{"service_id": "5"}, []string{"service", "delete", "5"}, "ok"}, + {"update_service", handleUpdateService, map[string]any{"backends": "b", "frontend": "f", "id": "1"}, []string{"service", "update", "1", "--backends", "b", "--frontend", "f", "--protocol", "TCP", "--states", "active"}, "ok"}, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + mock := cmd.NewMockShellExecutor() + mockCiliumDbgCommand(mock, tc.dbgArgs, tc.expect, nil) + ctx := cmd.WithShellExecutor(context.Background(), mock) + + tc.args["node_name"] = "test-node" + result, err := tc.handler(ctx, newRequestWithArgs(tc.args)) + require.NoError(t, err) + assert.False(t, result.IsError, "handler returned error result: %s", getResultText(result)) + assert.Contains(t, getResultText(result), tc.expect) + }) + } +} + +// TestCiliumDbgHandlersMissingParams covers required-parameter validation branches. +func TestCiliumDbgHandlersMissingParams(t *testing.T) { + cases := []struct { + name string + handler ciliumHandler + args map[string]any + }{ + {"manage_endpoint_labels", handleManageEndpointLabels, map[string]any{}}, + {"manage_endpoint_configuration_no_id", handleManageEndpointConfiguration, map[string]any{}}, + {"manage_endpoint_configuration_no_config", handleManageEndpointConfiguration, map[string]any{"endpoint_id": "34"}}, + {"disconnect_endpoint", handleDisconnectEndpoint, map[string]any{}}, + {"get_identity_details", handleGetIdentityDetails, map[string]any{}}, + {"list_envoy_config", handleListEnvoyConfig, map[string]any{}}, + {"show_ipcache_none", handleShowIPCacheInformation, map[string]any{}}, + {"delete_kvstore_key", handleDeleteKeyFromKVStore, map[string]any{}}, + {"get_kvstore_key", handleGetKVStoreKey, map[string]any{}}, + {"set_kvstore_key", handleSetKVStoreKey, map[string]any{"key": "foo"}}, + {"delete_policy_rules_none", handleDeletePolicyRules, map[string]any{}}, + {"update_xdp_cidr", handleUpdateXDPCIDRFilters, map[string]any{}}, + {"delete_xdp_cidr", handleDeleteXDPCIDRFilters, map[string]any{}}, + {"get_pcap_recorder", handleGetPCAPRecorder, map[string]any{}}, + {"delete_pcap_recorder", handleDeletePCAPRecorder, map[string]any{}}, + {"update_pcap_recorder", handleUpdatePCAPRecorder, map[string]any{"recorder_id": "1"}}, + {"get_service_information", handleGetServiceInformation, map[string]any{}}, + {"delete_service_none", handleDeleteService, map[string]any{}}, + {"update_service", handleUpdateService, map[string]any{"backends": "b"}}, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + mock := cmd.NewMockShellExecutor() + ctx := cmd.WithShellExecutor(context.Background(), mock) + result, err := tc.handler(ctx, newRequestWithArgs(tc.args)) + require.NoError(t, err) + assert.True(t, result.IsError) + assert.Empty(t, mock.GetCallLog()) + }) + } +} + +// TestCiliumCliHandlers covers the cilium-CLI based handlers. +func TestCiliumCliHandlers(t *testing.T) { + cases := []struct { + name string + handler ciliumHandler + args map[string]any + cliArgs []string + }{ + {"show_cluster_mesh_status", handleShowClusterMeshStatus, map[string]any{}, []string{"clustermesh", "status"}}, + {"show_features_status", handleShowFeaturesStatus, map[string]any{}, []string{"features", "status"}}, + {"toggle_cluster_mesh_enable", handleToggleClusterMesh, map[string]any{"enable": "true"}, []string{"clustermesh", "enable"}}, + {"toggle_cluster_mesh_disable", handleToggleClusterMesh, map[string]any{"enable": "false"}, []string{"clustermesh", "disable"}}, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + mock := cmd.NewMockShellExecutor() + mock.AddCommandString("cilium", tc.cliArgs, "cli-ok", nil) + ctx := cmd.WithShellExecutor(context.Background(), mock) + result, err := tc.handler(ctx, newRequestWithArgs(tc.args)) + require.NoError(t, err) + assert.False(t, result.IsError) + assert.Contains(t, getResultText(result), "cli-ok") + }) + } +} + +func TestCiliumCliHandlersError(t *testing.T) { + mock := cmd.NewMockShellExecutor() + mock.AddCommandString("cilium", []string{"clustermesh", "status"}, "", assert.AnError) + ctx := cmd.WithShellExecutor(context.Background(), mock) + result, err := handleShowClusterMeshStatus(ctx, newRequestWithArgs(map[string]any{})) + require.NoError(t, err) + assert.True(t, result.IsError) + assert.Contains(t, getResultText(result), "Error getting cluster mesh status") +} + +// TestCiliumDbgHandlerError covers the error path shared by dbg handlers. +func TestCiliumDbgHandlerError(t *testing.T) { + mock := cmd.NewMockShellExecutor() + mockCiliumDbgCommand(mock, []string{"loadinfo"}, "", assert.AnError) + ctx := cmd.WithShellExecutor(context.Background(), mock) + result, err := handleShowLoadInformation(ctx, newRequestWithArgs(map[string]any{"node_name": "test-node"})) + require.NoError(t, err) + assert.True(t, result.IsError) +} diff --git a/pkg/istio/istio_test.go b/pkg/istio/istio_test.go index 36abeef9..4eacea90 100644 --- a/pkg/istio/istio_test.go +++ b/pkg/istio/istio_test.go @@ -355,3 +355,143 @@ func TestIstioErrorHandling(t *testing.T) { assert.True(t, result.IsError) }) } + +func TestHandleWaypointApply(t *testing.T) { + t.Run("basic apply", func(t *testing.T) { + mock := cmd.NewMockShellExecutor() + mock.AddCommandString("istioctl", []string{"waypoint", "apply", "-n", "default"}, "applied", nil) + ctx := cmd.WithShellExecutor(context.Background(), mock) + + req := mcp.CallToolRequest{} + req.Params.Arguments = map[string]interface{}{"namespace": "default"} + result, err := handleWaypointApply(ctx, req) + require.NoError(t, err) + assert.False(t, result.IsError) + }) + + t.Run("apply with enroll-namespace", func(t *testing.T) { + mock := cmd.NewMockShellExecutor() + mock.AddCommandString("istioctl", []string{"waypoint", "apply", "-n", "default", "--enroll-namespace"}, "applied", nil) + ctx := cmd.WithShellExecutor(context.Background(), mock) + + req := mcp.CallToolRequest{} + req.Params.Arguments = map[string]interface{}{"namespace": "default", "enroll_namespace": "true"} + result, err := handleWaypointApply(ctx, req) + require.NoError(t, err) + assert.False(t, result.IsError) + }) + + t.Run("missing namespace", func(t *testing.T) { + mock := cmd.NewMockShellExecutor() + ctx := cmd.WithShellExecutor(context.Background(), mock) + result, err := handleWaypointApply(ctx, mcp.CallToolRequest{}) + require.NoError(t, err) + assert.True(t, result.IsError) + }) + + t.Run("command failure", func(t *testing.T) { + mock := cmd.NewMockShellExecutor() + mock.AddCommandString("istioctl", []string{"waypoint", "apply", "-n", "default"}, "", assert.AnError) + ctx := cmd.WithShellExecutor(context.Background(), mock) + req := mcp.CallToolRequest{} + req.Params.Arguments = map[string]interface{}{"namespace": "default"} + result, err := handleWaypointApply(ctx, req) + require.NoError(t, err) + assert.True(t, result.IsError) + }) +} + +func TestHandleWaypointDelete(t *testing.T) { + t.Run("delete all", func(t *testing.T) { + mock := cmd.NewMockShellExecutor() + mock.AddCommandString("istioctl", []string{"waypoint", "delete", "--all", "-n", "default"}, "deleted", nil) + ctx := cmd.WithShellExecutor(context.Background(), mock) + req := mcp.CallToolRequest{} + req.Params.Arguments = map[string]interface{}{"namespace": "default", "all": "true"} + result, err := handleWaypointDelete(ctx, req) + require.NoError(t, err) + assert.False(t, result.IsError) + }) + + t.Run("delete by names", func(t *testing.T) { + mock := cmd.NewMockShellExecutor() + mock.AddCommandString("istioctl", []string{"waypoint", "delete", "wp1", "wp2", "-n", "default"}, "deleted", nil) + ctx := cmd.WithShellExecutor(context.Background(), mock) + req := mcp.CallToolRequest{} + req.Params.Arguments = map[string]interface{}{"namespace": "default", "names": "wp1, wp2"} + result, err := handleWaypointDelete(ctx, req) + require.NoError(t, err) + assert.False(t, result.IsError) + }) + + t.Run("missing namespace", func(t *testing.T) { + mock := cmd.NewMockShellExecutor() + ctx := cmd.WithShellExecutor(context.Background(), mock) + result, err := handleWaypointDelete(ctx, mcp.CallToolRequest{}) + require.NoError(t, err) + assert.True(t, result.IsError) + }) +} + +func TestHandleWaypointStatus(t *testing.T) { + t.Run("status with name", func(t *testing.T) { + mock := cmd.NewMockShellExecutor() + mock.AddCommandString("istioctl", []string{"waypoint", "status", "wp1", "-n", "default"}, "status", nil) + ctx := cmd.WithShellExecutor(context.Background(), mock) + req := mcp.CallToolRequest{} + req.Params.Arguments = map[string]interface{}{"namespace": "default", "name": "wp1"} + result, err := handleWaypointStatus(ctx, req) + require.NoError(t, err) + assert.False(t, result.IsError) + }) + + t.Run("status without name", func(t *testing.T) { + mock := cmd.NewMockShellExecutor() + mock.AddCommandString("istioctl", []string{"waypoint", "status", "-n", "default"}, "status", nil) + ctx := cmd.WithShellExecutor(context.Background(), mock) + req := mcp.CallToolRequest{} + req.Params.Arguments = map[string]interface{}{"namespace": "default"} + result, err := handleWaypointStatus(ctx, req) + require.NoError(t, err) + assert.False(t, result.IsError) + }) + + t.Run("missing namespace", func(t *testing.T) { + mock := cmd.NewMockShellExecutor() + ctx := cmd.WithShellExecutor(context.Background(), mock) + result, err := handleWaypointStatus(ctx, mcp.CallToolRequest{}) + require.NoError(t, err) + assert.True(t, result.IsError) + }) +} + +func TestHandleZtunnelConfig(t *testing.T) { + t.Run("default config type", func(t *testing.T) { + mock := cmd.NewMockShellExecutor() + mock.AddCommandString("istioctl", []string{"ztunnel", "config", "all"}, "ztunnel config", nil) + ctx := cmd.WithShellExecutor(context.Background(), mock) + result, err := handleZtunnelConfig(ctx, mcp.CallToolRequest{}) + require.NoError(t, err) + assert.False(t, result.IsError) + }) + + t.Run("with namespace and config type", func(t *testing.T) { + mock := cmd.NewMockShellExecutor() + mock.AddCommandString("istioctl", []string{"ztunnel", "config", "workloads", "-n", "istio-system"}, "ztunnel config", nil) + ctx := cmd.WithShellExecutor(context.Background(), mock) + req := mcp.CallToolRequest{} + req.Params.Arguments = map[string]interface{}{"config_type": "workloads", "namespace": "istio-system"} + result, err := handleZtunnelConfig(ctx, req) + require.NoError(t, err) + assert.False(t, result.IsError) + }) + + t.Run("command failure", func(t *testing.T) { + mock := cmd.NewMockShellExecutor() + mock.AddCommandString("istioctl", []string{"ztunnel", "config", "all"}, "", assert.AnError) + ctx := cmd.WithShellExecutor(context.Background(), mock) + result, err := handleZtunnelConfig(ctx, mcp.CallToolRequest{}) + require.NoError(t, err) + assert.True(t, result.IsError) + }) +} diff --git a/pkg/k8s/k8s_test.go b/pkg/k8s/k8s_test.go index 44df8a92..bc976010 100644 --- a/pkg/k8s/k8s_test.go +++ b/pkg/k8s/k8s_test.go @@ -7,11 +7,28 @@ import ( "github.com/kagent-dev/tools/internal/cmd" "github.com/mark3labs/mcp-go/mcp" + "github.com/mark3labs/mcp-go/server" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/tmc/langchaingo/llms" ) +func TestRegisterTools(t *testing.T) { + t.Run("read-write", func(t *testing.T) { + s := server.NewMCPServer("test", "v0.0.1") + RegisterTools(s, nil, "", false) + }) + t.Run("read-only", func(t *testing.T) { + s := server.NewMCPServer("test", "v0.0.1") + RegisterTools(s, nil, "/tmp/kubeconfig", true) + }) +} + +func TestNewK8sToolWithConfig(t *testing.T) { + tool := NewK8sToolWithConfig("/tmp/kc", nil) + assert.Equal(t, "/tmp/kc", tool.kubeconfig) +} + // Helper function to create a test K8sTool func newTestK8sTool() *K8sTool { return NewK8sTool(nil) diff --git a/pkg/prometheus/prometheus_test.go b/pkg/prometheus/prometheus_test.go index 647d1f39..1e8ffc49 100644 --- a/pkg/prometheus/prometheus_test.go +++ b/pkg/prometheus/prometheus_test.go @@ -8,9 +8,125 @@ import ( "testing" "github.com/mark3labs/mcp-go/mcp" + "github.com/mark3labs/mcp-go/server" "github.com/stretchr/testify/assert" ) +func TestRegisterTools(t *testing.T) { + t.Run("read-write", func(t *testing.T) { + s := server.NewMCPServer("test", "v0.0.1") + RegisterTools(s, false) + }) + t.Run("read-only", func(t *testing.T) { + s := server.NewMCPServer("test", "v0.0.1") + RegisterTools(s, true) + }) +} + +func TestPrometheusInputValidation(t *testing.T) { + ctx := context.Background() + + invalidURL := map[string]interface{}{"prometheus_url": "not a url", "query": "up"} + invalidQuery := map[string]interface{}{"prometheus_url": "http://localhost:9090", "query": "up; drop"} + + t.Run("query invalid url", func(t *testing.T) { + req := mcp.CallToolRequest{} + req.Params.Arguments = invalidURL + res, err := handlePrometheusQueryTool(ctx, req) + assert.NoError(t, err) + assert.True(t, res.IsError) + }) + + t.Run("range invalid url", func(t *testing.T) { + req := mcp.CallToolRequest{} + req.Params.Arguments = invalidURL + res, err := handlePrometheusRangeQueryTool(ctx, req) + assert.NoError(t, err) + assert.True(t, res.IsError) + }) + + t.Run("labels invalid url", func(t *testing.T) { + req := mcp.CallToolRequest{} + req.Params.Arguments = map[string]interface{}{"prometheus_url": "not a url"} + res, err := handlePrometheusLabelsQueryTool(ctx, req) + assert.NoError(t, err) + assert.True(t, res.IsError) + }) + + t.Run("targets invalid url", func(t *testing.T) { + req := mcp.CallToolRequest{} + req.Params.Arguments = map[string]interface{}{"prometheus_url": "not a url"} + res, err := handlePrometheusTargetsQueryTool(ctx, req) + assert.NoError(t, err) + assert.True(t, res.IsError) + }) + + t.Run("query invalid promql", func(t *testing.T) { + req := mcp.CallToolRequest{} + req.Params.Arguments = invalidQuery + res, err := handlePrometheusQueryTool(ctx, req) + assert.NoError(t, err) + assert.True(t, res.IsError) + }) + + t.Run("range invalid promql", func(t *testing.T) { + req := mcp.CallToolRequest{} + req.Params.Arguments = invalidQuery + res, err := handlePrometheusRangeQueryTool(ctx, req) + assert.NoError(t, err) + assert.True(t, res.IsError) + }) +} + +func TestPrometheusLabelsTargetsErrorPaths(t *testing.T) { + args := map[string]interface{}{"prometheus_url": "http://localhost:9090"} + + t.Run("labels client error", func(t *testing.T) { + ctx := contextWithMockClient(newTestClient(nil, assert.AnError)) + req := mcp.CallToolRequest{} + req.Params.Arguments = args + res, err := handlePrometheusLabelsQueryTool(ctx, req) + assert.NoError(t, err) + assert.True(t, res.IsError) + }) + + t.Run("labels malformed json", func(t *testing.T) { + ctx := contextWithMockClient(newTestClient(createMockResponse(200, "not json"), nil)) + req := mcp.CallToolRequest{} + req.Params.Arguments = args + res, err := handlePrometheusLabelsQueryTool(ctx, req) + assert.NoError(t, err) + assert.False(t, res.IsError) + assert.Contains(t, getResultText(res), "not json") + }) + + t.Run("targets client error", func(t *testing.T) { + ctx := contextWithMockClient(newTestClient(nil, assert.AnError)) + req := mcp.CallToolRequest{} + req.Params.Arguments = args + res, err := handlePrometheusTargetsQueryTool(ctx, req) + assert.NoError(t, err) + assert.True(t, res.IsError) + }) + + t.Run("targets malformed json", func(t *testing.T) { + ctx := contextWithMockClient(newTestClient(createMockResponse(200, "not json"), nil)) + req := mcp.CallToolRequest{} + req.Params.Arguments = args + res, err := handlePrometheusTargetsQueryTool(ctx, req) + assert.NoError(t, err) + assert.False(t, res.IsError) + assert.Contains(t, getResultText(res), "not json") + }) +} + +func TestGetHTTPClientDefault(t *testing.T) { + assert.Equal(t, http.DefaultClient, getHTTPClient(context.Background())) + custom := &http.Client{} + ctx := context.WithValue(context.Background(), clientKey{}, custom) + assert.Equal(t, custom, getHTTPClient(ctx)) +} + // mockRoundTripper is used to mock HTTP responses for testing type mockRoundTripper struct { response *http.Response diff --git a/pkg/utils/common.go b/pkg/utils/common.go index 23070889..f149d013 100644 --- a/pkg/utils/common.go +++ b/pkg/utils/common.go @@ -66,6 +66,20 @@ func shellTool(ctx context.Context, params shellParams) (string, error) { return commands.NewCommandBuilder(cmd).WithArgs(args...).Execute(ctx) } +func handleShellTool(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + command := mcp.ParseString(request, "command", "") + if command == "" { + return mcp.NewToolResultError("command parameter is required"), nil + } + + result, err := shellTool(ctx, shellParams{Command: command}) + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + + return mcp.NewToolResultText(result), nil +} + // handleGetCurrentDateTimeTool provides datetime functionality for both MCP and testing func handleGetCurrentDateTimeTool(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { // Returns the current date and time in ISO 8601 format (RFC3339) @@ -82,20 +96,7 @@ func RegisterTools(s *server.MCPServer, readOnly bool) { s.AddTool(mcp.NewTool("shell", mcp.WithDescription("Execute shell commands"), mcp.WithString("command", mcp.Description("The shell command to execute"), mcp.Required()), - ), func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - command := mcp.ParseString(request, "command", "") - if command == "" { - return mcp.NewToolResultError("command parameter is required"), nil - } - - params := shellParams{Command: command} - result, err := shellTool(ctx, params) - if err != nil { - return mcp.NewToolResultError(err.Error()), nil - } - - return mcp.NewToolResultText(result), nil - }) + ), handleShellTool) } // Register datetime tool diff --git a/pkg/utils/common_test.go b/pkg/utils/common_test.go new file mode 100644 index 00000000..6502225a --- /dev/null +++ b/pkg/utils/common_test.go @@ -0,0 +1,123 @@ +package utils + +import ( + "context" + "testing" + + "github.com/kagent-dev/tools/internal/cmd" + "github.com/mark3labs/mcp-go/mcp" + "github.com/mark3labs/mcp-go/server" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestKubeconfigManager(t *testing.T) { + // Preserve and restore global state to avoid cross-test interference. + original := GetKubeconfig() + t.Cleanup(func() { SetKubeconfig(original) }) + + t.Run("set and get", func(t *testing.T) { + SetKubeconfig("/tmp/my-kubeconfig") + assert.Equal(t, "/tmp/my-kubeconfig", GetKubeconfig()) + }) + + t.Run("AddKubeconfigArgs with path set", func(t *testing.T) { + SetKubeconfig("/tmp/kc") + got := AddKubeconfigArgs([]string{"get", "pods"}) + assert.Equal(t, []string{"--kubeconfig", "/tmp/kc", "get", "pods"}, got) + }) + + t.Run("AddKubeconfigArgs with empty path", func(t *testing.T) { + SetKubeconfig("") + got := AddKubeconfigArgs([]string{"get", "pods"}) + assert.Equal(t, []string{"get", "pods"}, got) + }) +} + +func TestShellTool(t *testing.T) { + t.Run("executes command", func(t *testing.T) { + mock := cmd.NewMockShellExecutor() + mock.AddCommandString("echo", []string{"hello"}, "hello\n", nil) + ctx := cmd.WithShellExecutor(context.Background(), mock) + + out, err := shellTool(ctx, shellParams{Command: "echo hello"}) + require.NoError(t, err) + assert.Equal(t, "hello\n", out) + + callLog := mock.GetCallLog() + require.Len(t, callLog, 1) + assert.Equal(t, "echo", callLog[0].Command) + assert.Equal(t, []string{"hello"}, callLog[0].Args) + }) + + t.Run("empty command", func(t *testing.T) { + _, err := shellTool(context.Background(), shellParams{Command: " "}) + require.Error(t, err) + assert.Contains(t, err.Error(), "empty command") + }) + + t.Run("command failure propagates", func(t *testing.T) { + mock := cmd.NewMockShellExecutor() + mock.AddCommandString("false", []string{}, "", assert.AnError) + ctx := cmd.WithShellExecutor(context.Background(), mock) + + _, err := shellTool(ctx, shellParams{Command: "false"}) + require.Error(t, err) + }) +} + +func TestRegisterTools(t *testing.T) { + t.Run("read-write registers shell", func(t *testing.T) { + s := server.NewMCPServer("test", "v0.0.1") + RegisterTools(s, false) + }) + + t.Run("read-only omits shell", func(t *testing.T) { + s := server.NewMCPServer("test", "v0.0.1") + RegisterTools(s, true) + }) +} + +func TestHandleShellTool(t *testing.T) { + mock := cmd.NewMockShellExecutor() + mock.AddCommandString("echo", []string{"hi"}, "hi\n", nil) + ctx := cmd.WithShellExecutor(context.Background(), mock) + + t.Run("success", func(t *testing.T) { + req := mcp.CallToolRequest{} + req.Params.Arguments = map[string]interface{}{"command": "echo hi"} + res, err := handleShellTool(ctx, req) + require.NoError(t, err) + assert.False(t, res.IsError) + }) + + t.Run("missing command", func(t *testing.T) { + req := mcp.CallToolRequest{} + req.Params.Arguments = map[string]interface{}{} + res, err := handleShellTool(ctx, req) + require.NoError(t, err) + assert.True(t, res.IsError) + assert.Contains(t, getResultText(res), "command parameter is required") + }) + + t.Run("command error", func(t *testing.T) { + m := cmd.NewMockShellExecutor() + m.AddCommandString("false", []string{}, "", assert.AnError) + errCtx := cmd.WithShellExecutor(context.Background(), m) + req := mcp.CallToolRequest{} + req.Params.Arguments = map[string]interface{}{"command": "false"} + res, err := handleShellTool(errCtx, req) + require.NoError(t, err) + assert.True(t, res.IsError) + }) +} + +func getResultText(result *mcp.CallToolResult) string { + if result == nil || len(result.Content) == 0 { + return "" + } + if textContent, ok := result.Content[0].(mcp.TextContent); ok { + return textContent.Text + } + return "" +}