diff --git a/mcp_server/pkg/tools/internal/shared/features.go b/mcp_server/pkg/tools/internal/shared/features.go index 07f36c7f4..68ff29933 100644 --- a/mcp_server/pkg/tools/internal/shared/features.go +++ b/mcp_server/pkg/tools/internal/shared/features.go @@ -8,7 +8,10 @@ import ( "github.com/semaphoreio/semaphore/mcp_server/pkg/internalapi" ) -const readToolsFeatureFlag = "mcp_server_read_tools" +const ( + readToolsFeatureFlag = "mcp_server_read_tools" + writeToolsFeatureFlag = "mcp_server_write_tools" +) // EnsureReadToolsFeature verifies that the organization has the AI read tools feature enabled. func EnsureReadToolsFeature(ctx context.Context, api internalapi.Provider, orgID string) error { @@ -23,7 +26,26 @@ func EnsureReadToolsFeature(ctx context.Context, api internalapi.Provider, orgID } if state != feature.Enabled { - return fmt.Errorf("Semaphore MCP tools are disabled for this organization. Please contact support if you believe this is an error.") + return fmt.Errorf("Semaphore MCP read tools are disabled for this organization. Please contact support if you believe this is an error.") + } + + return nil +} + +// EnsureWriteToolsFeature verifies that the organization has the AI write tools feature enabled. +func EnsureWriteToolsFeature(ctx context.Context, api internalapi.Provider, orgID string) error { + featureClient := api.Features() + if featureClient == nil { + return fmt.Errorf("Semaphore MCP tools are temporarily unavailable. Please try again later.") + } + + state, err := featureClient.FeatureState(orgID, writeToolsFeatureFlag) + if err != nil { + return fmt.Errorf("We couldn't verify access to Semaphore MCP tools right now. Please try again in a few moments.") + } + + if state != feature.Enabled { + return fmt.Errorf("Semaphore MCP write tools are disabled for this organization. Please contact support if you believe this is an error.") } return nil diff --git a/mcp_server/pkg/tools/workflows/helpers.go b/mcp_server/pkg/tools/workflows/helpers.go new file mode 100644 index 000000000..d7aace3f5 --- /dev/null +++ b/mcp_server/pkg/tools/workflows/helpers.go @@ -0,0 +1,67 @@ +package workflows + +import "strings" + +// summary represents workflow metadata returned by workflows_search. +type summary struct { + ID string `json:"id"` + InitialPipeline string `json:"initialPipelineId,omitempty"` + ProjectID string `json:"projectId,omitempty"` + OrganizationID string `json:"organizationId,omitempty"` + Branch string `json:"branch,omitempty"` + CommitSHA string `json:"commitSha,omitempty"` + RequesterID string `json:"requesterId,omitempty"` + TriggeredBy string `json:"triggeredBy,omitempty"` + CreatedAt string `json:"createdAt,omitempty"` + RerunOf string `json:"rerunOf,omitempty"` + RepositoryID string `json:"repositoryId,omitempty"` +} + +type listResult struct { + Workflows []summary `json:"workflows"` + NextCursor string `json:"nextCursor,omitempty"` +} + +type runResult struct { + WorkflowID string `json:"workflowId"` + PipelineID string `json:"pipelineId"` + Reference string `json:"reference"` + CommitSHA string `json:"commitSha,omitempty"` + PipelineFile string `json:"pipelineFile"` +} + +type rerunResult struct { + WorkflowID string `json:"workflowId"` + PipelineID string `json:"pipelineId"` + RerunOf string `json:"rerunOf"` + ProjectID string `json:"projectId"` + OrgID string `json:"organizationId"` +} + +func humanizeTriggeredBy(value string) string { + value = strings.TrimSpace(value) + if value == "" { + return "Unspecified" + } + parts := strings.Split(value, "_") + for i, part := range parts { + if part == "" { + continue + } + part = strings.ToLower(part) + parts[i] = strings.ToUpper(part[:1]) + part[1:] + } + return strings.Join(parts, " ") +} + +func shortenCommit(sha string) string { + sha = strings.TrimSpace(sha) + if len(sha) > 12 { + return sha[:12] + } + return sha +} + +func normalizeID(value string) string { + return strings.ToLower(strings.TrimSpace(value)) +} diff --git a/mcp_server/pkg/tools/workflows/register.go b/mcp_server/pkg/tools/workflows/register.go new file mode 100644 index 000000000..585a5b845 --- /dev/null +++ b/mcp_server/pkg/tools/workflows/register.go @@ -0,0 +1,26 @@ +package workflows + +import ( + "github.com/mark3labs/mcp-go/server" + + "github.com/semaphoreio/semaphore/mcp_server/pkg/internalapi" +) + +const ( + searchToolName = "workflows_search" + runToolName = "workflows_run" + rerunToolName = "workflows_rerun" + defaultLimit = 20 + maxLimit = 100 + missingWorkflowError = "workflow gRPC endpoint is not configured" + projectViewPermission = "project.view" + projectRunPermission = "project.job.rerun" + defaultPipelineFile = ".semaphore/semaphore.yml" +) + +// Register wires workflow tooling into the MCP server. +func Register(s *server.MCPServer, api internalapi.Provider) { + s.AddTool(newSearchTool(searchToolName, searchFullDescription()), listHandler(api)) + s.AddTool(newRunTool(runToolName, runFullDescription()), runHandler(api)) + s.AddTool(newRerunTool(rerunToolName, rerunFullDescription()), rerunHandler(api)) +} diff --git a/mcp_server/pkg/tools/workflows/rerun_tool.go b/mcp_server/pkg/tools/workflows/rerun_tool.go new file mode 100644 index 000000000..3ccb72a45 --- /dev/null +++ b/mcp_server/pkg/tools/workflows/rerun_tool.go @@ -0,0 +1,195 @@ +package workflows + +import ( + "context" + "fmt" + "strings" + + "github.com/google/uuid" + "github.com/mark3labs/mcp-go/mcp" + "github.com/mark3labs/mcp-go/server" + "github.com/sirupsen/logrus" + + "github.com/semaphoreio/semaphore/mcp_server/pkg/authz" + workflowpb "github.com/semaphoreio/semaphore/mcp_server/pkg/internal_api/plumber_w_f.workflow" + "github.com/semaphoreio/semaphore/mcp_server/pkg/internalapi" + "github.com/semaphoreio/semaphore/mcp_server/pkg/logging" + "github.com/semaphoreio/semaphore/mcp_server/pkg/tools/internal/shared" +) + +func rerunFullDescription() string { + return `Rerun an existing workflow. + +Use this when you need to: +- Rerun a previously completed or failed workflow +- Restart a workflow without changing its parameters + +Required inputs: +- workflow_id: ID of the workflow to rerun + +The authenticated user must have permission to rerun workflows for the originating project.` +} + +func newRerunTool(name, description string) mcp.Tool { + return mcp.NewTool( + name, + mcp.WithDescription(description), + mcp.WithString( + "workflow_id", + mcp.Required(), + mcp.Description("Workflow ID to rerun."), + ), + mcp.WithIdempotentHintAnnotation(false), + ) +} + +func rerunHandler(api internalapi.Provider) server.ToolHandlerFunc { + return func(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { + workflowClient := api.Workflow() + if workflowClient == nil { + return mcp.NewToolResultError(missingWorkflowError), nil + } + + workflowIDRaw, err := req.RequireString("workflow_id") + if err != nil { + return mcp.NewToolResultError(`Missing required argument: workflow_id. Provide the workflow ID to rerun.`), nil + } + workflowID, err := sanitizeWorkflowID(workflowIDRaw, "workflow_id") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + + userID := strings.ToLower(strings.TrimSpace(req.Header.Get("X-Semaphore-User-ID"))) + if err := shared.ValidateUUID(userID, "x-semaphore-user-id header"); err != nil { + return mcp.NewToolResultError(fmt.Sprintf(`%v + +The authentication layer must inject the X-Semaphore-User-ID header so we can authorize workflow reruns.`, err)), nil + } + + describeCtx, cancelDescribe := context.WithTimeout(ctx, api.CallTimeout()) + defer cancelDescribe() + + describeResp, err := workflowClient.Describe(describeCtx, &workflowpb.DescribeRequest{WfId: workflowID}) + if err != nil { + logging.ForComponent("rpc"). + WithFields(logrus.Fields{ + "rpc": "workflow.Describe", + "wfId": workflowID, + }). + WithError(err). + Error("workflow describe RPC failed") + return mcp.NewToolResultError("Unable to load workflow details. Confirm the workflow exists and try again."), nil + } + + if err := shared.CheckStatus(describeResp.GetStatus()); err != nil { + return mcp.NewToolResultError(fmt.Sprintf("Unable to load workflow details: %v", err)), nil + } + + workflow := describeResp.GetWorkflow() + if workflow == nil { + return mcp.NewToolResultError("Workflow details are missing from the response. Please retry."), nil + } + + orgID := strings.TrimSpace(workflow.GetOrganizationId()) + if err := shared.ValidateUUID(orgID, "workflow organization_id"); err != nil { + return mcp.NewToolResultError("Unable to determine workflow organization. Please try again later."), nil + } + + projectID := strings.TrimSpace(workflow.GetProjectId()) + if err := shared.ValidateUUID(projectID, "workflow project_id"); err != nil { + return mcp.NewToolResultError("Unable to determine workflow project. Please try again later."), nil + } + + if err := shared.EnsureWriteToolsFeature(ctx, api, orgID); err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + + tracker := shared.TrackToolExecution(ctx, rerunToolName, orgID) + defer tracker.Cleanup() + + if err := authz.CheckProjectPermission(ctx, api, userID, orgID, projectID, projectRunPermission); err != nil { + return shared.ProjectAuthorizationError(err, orgID, projectID, projectRunPermission), nil + } + + rescheduleCtx, cancelReschedule := context.WithTimeout(ctx, api.CallTimeout()) + defer cancelReschedule() + + requestToken := uuid.NewString() + + rescheduleReq := &workflowpb.RescheduleRequest{ + WfId: workflowID, + RequesterId: userID, + RequestToken: requestToken, + } + + rescheduleResp, err := workflowClient.Reschedule(rescheduleCtx, rescheduleReq) + if err != nil { + logging.ForComponent("rpc"). + WithFields(logrus.Fields{ + "rpc": "workflow.Reschedule", + "wfId": workflowID, + "projectId": projectID, + "orgId": orgID, + }). + WithError(err). + Error("workflow reschedule RPC failed") + return mcp.NewToolResultError("Workflow rerun failed. Confirm the workflow exists and try again."), nil + } + + if err := shared.CheckStatus(rescheduleResp.GetStatus()); err != nil { + return mcp.NewToolResultError(fmt.Sprintf("Workflow rerun failed: %v", err)), nil + } + + result := rerunResult{ + WorkflowID: strings.TrimSpace(rescheduleResp.GetWfId()), + PipelineID: strings.TrimSpace(rescheduleResp.GetPplId()), + RerunOf: workflowID, + ProjectID: projectID, + OrgID: orgID, + } + + markdown := formatRerunMarkdown(result) + markdown = shared.TruncateResponse(markdown, shared.MaxResponseChars) + + tracker.MarkSuccess() + return &mcp.CallToolResult{ + Content: []mcp.Content{mcp.NewTextContent(markdown)}, + StructuredContent: result, + }, nil + } +} + +func sanitizeWorkflowID(raw, field string) (string, error) { + value := strings.TrimSpace(raw) + if value == "" { + return "", fmt.Errorf("%s is required", field) + } + if strings.ContainsAny(value, " \t\r\n") { + return "", fmt.Errorf("%s must not contain whitespace", field) + } + if len(value) > 128 { + return "", fmt.Errorf("%s must not exceed 128 characters", field) + } + return value, nil +} + +func formatRerunMarkdown(result rerunResult) string { + mb := shared.NewMarkdownBuilder() + mb.H1("Workflow Rerun Scheduled") + if result.WorkflowID != "" { + mb.KeyValue("Workflow ID", fmt.Sprintf("`%s`", result.WorkflowID)) + } + if result.PipelineID != "" { + mb.KeyValue("Initial Pipeline", fmt.Sprintf("`%s`", result.PipelineID)) + } + if result.RerunOf != "" { + mb.KeyValue("Rerun Of", fmt.Sprintf("`%s`", result.RerunOf)) + } + if result.ProjectID != "" { + mb.KeyValue("Project ID", fmt.Sprintf("`%s`", result.ProjectID)) + } + if result.OrgID != "" { + mb.KeyValue("Organization ID", fmt.Sprintf("`%s`", result.OrgID)) + } + return mb.String() +} diff --git a/mcp_server/pkg/tools/workflows/rerun_tool_test.go b/mcp_server/pkg/tools/workflows/rerun_tool_test.go new file mode 100644 index 000000000..ba96d36b6 --- /dev/null +++ b/mcp_server/pkg/tools/workflows/rerun_tool_test.go @@ -0,0 +1,164 @@ +package workflows + +import ( + "context" + "net/http" + "strings" + "testing" + "time" + + "github.com/mark3labs/mcp-go/mcp" + "github.com/semaphoreio/semaphore/mcp_server/pkg/feature" + workflowpb "github.com/semaphoreio/semaphore/mcp_server/pkg/internal_api/plumber_w_f.workflow" + projecthubpb "github.com/semaphoreio/semaphore/mcp_server/pkg/internal_api/projecthub" + statuspb "github.com/semaphoreio/semaphore/mcp_server/pkg/internal_api/status" + support "github.com/semaphoreio/semaphore/mcp_server/test/support" + + code "google.golang.org/genproto/googleapis/rpc/code" +) + +func TestRerunWorkflowSuccess(t *testing.T) { + orgID := "aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee" + projectID := "11111111-2222-3333-4444-555555555555" + workflowID := "wf-123" + userID := "99999999-aaaa-bbbb-cccc-dddddddddddd" + + workflowStub := &support.WorkflowClientStub{ + DescribeResp: &workflowpb.DescribeResponse{ + Status: &statuspb.Status{Code: code.Code_OK}, + Workflow: &workflowpb.WorkflowDetails{ + WfId: workflowID, + ProjectId: projectID, + OrganizationId: orgID, + }, + }, + RescheduleResp: &workflowpb.ScheduleResponse{ + Status: &statuspb.Status{Code: code.Code_OK}, + WfId: "wf-new", + PplId: "ppl-new", + }, + } + provider := &support.MockProvider{ + WorkflowClient: workflowStub, + ProjectClient: &support.ProjectClientStub{Response: support.NewProjectDescribeResponse(orgID, projectID, &projecthubpb.Project_Spec_Repository{})}, + Timeout: time.Second, + RBACClient: support.NewRBACStub(projectRunPermission), + } + + req := mcp.CallToolRequest{ + Params: mcp.CallToolParams{ + Arguments: map[string]any{ + "workflow_id": workflowID, + }, + }, + } + header := http.Header{} + header.Set("X-Semaphore-User-ID", userID) + req.Header = header + + res, err := rerunHandler(provider)(context.Background(), req) + if err != nil { + toFail(t, "handler error: %v", err) + } + result, ok := res.StructuredContent.(rerunResult) + if !ok { + toFail(t, "unexpected structured content type: %T", res.StructuredContent) + } + if result.WorkflowID != "wf-new" || result.PipelineID != "ppl-new" { + toFail(t, "unexpected rerun result: %+v", result) + } + if result.RerunOf != workflowID { + toFail(t, "expected rerunOf to match workflow id, got %s", result.RerunOf) + } + if workflowStub.LastDescribe == nil || workflowStub.LastDescribe.GetWfId() != workflowID { + toFail(t, "expected describe call for workflow") + } + if workflowStub.LastReschedule == nil { + toFail(t, "expected reschedule call to be recorded") + } + if got := workflowStub.LastReschedule.GetRequesterId(); got != userID { + toFail(t, "unexpected requester id: %s", got) + } +} + +func TestRerunWorkflowFeatureDisabled(t *testing.T) { + orgID := "aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee" + projectID := "11111111-2222-3333-4444-555555555555" + workflowStub := &support.WorkflowClientStub{ + DescribeResp: &workflowpb.DescribeResponse{ + Status: &statuspb.Status{Code: code.Code_OK}, + Workflow: &workflowpb.WorkflowDetails{ + WfId: "wf-123", + ProjectId: projectID, + OrganizationId: orgID, + }, + }, + } + provider := &support.MockProvider{ + WorkflowClient: workflowStub, + ProjectClient: &support.ProjectClientStub{Response: support.NewProjectDescribeResponse(orgID, projectID, &projecthubpb.Project_Spec_Repository{})}, + Timeout: time.Second, + RBACClient: support.NewRBACStub(projectRunPermission), + FeaturesService: support.FeatureClientStub{State: feature.Hidden}, + } + + req := mcp.CallToolRequest{ + Params: mcp.CallToolParams{Arguments: map[string]any{ + "workflow_id": "wf-123", + }}, + } + header := http.Header{} + header.Set("X-Semaphore-User-ID", "99999999-aaaa-bbbb-cccc-dddddddddddd") + req.Header = header + + res, err := rerunHandler(provider)(context.Background(), req) + if err != nil { + toFail(t, "handler error: %v", err) + } + msg := requireErrorText(t, res) + if !strings.Contains(strings.ToLower(msg), "disabled") { + toFail(t, "expected feature disabled message, got %q", msg) + } +} + +func TestRerunWorkflowPermissionDenied(t *testing.T) { + orgID := "aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee" + projectID := "11111111-2222-3333-4444-555555555555" + workflowStub := &support.WorkflowClientStub{ + DescribeResp: &workflowpb.DescribeResponse{ + Status: &statuspb.Status{Code: code.Code_OK}, + Workflow: &workflowpb.WorkflowDetails{ + WfId: "wf-123", + ProjectId: projectID, + OrganizationId: orgID, + }, + }, + } + provider := &support.MockProvider{ + WorkflowClient: workflowStub, + ProjectClient: &support.ProjectClientStub{Response: support.NewProjectDescribeResponse(orgID, projectID, &projecthubpb.Project_Spec_Repository{})}, + Timeout: time.Second, + RBACClient: support.NewRBACStub(), // no permissions + } + + req := mcp.CallToolRequest{ + Params: mcp.CallToolParams{Arguments: map[string]any{ + "workflow_id": "wf-123", + }}, + } + header := http.Header{} + header.Set("X-Semaphore-User-ID", "99999999-aaaa-bbbb-cccc-dddddddddddd") + req.Header = header + + res, err := rerunHandler(provider)(context.Background(), req) + if err != nil { + toFail(t, "handler error: %v", err) + } + msg := requireErrorText(t, res) + if !strings.Contains(msg, "Permission denied while accessing project") { + toFail(t, "expected permission denied message, got %q", msg) + } + if workflowStub.LastReschedule != nil { + toFail(t, "workflow reschedule should not have been invoked when permission is missing") + } +} diff --git a/mcp_server/pkg/tools/workflows/run_tool.go b/mcp_server/pkg/tools/workflows/run_tool.go new file mode 100644 index 000000000..5e93c8f36 --- /dev/null +++ b/mcp_server/pkg/tools/workflows/run_tool.go @@ -0,0 +1,547 @@ +package workflows + +import ( + "context" + "fmt" + "regexp" + "sort" + "strconv" + "strings" + "unicode/utf8" + + "github.com/google/uuid" + "github.com/mark3labs/mcp-go/mcp" + "github.com/mark3labs/mcp-go/server" + "github.com/sirupsen/logrus" + + "github.com/semaphoreio/semaphore/mcp_server/pkg/authz" + workflowpb "github.com/semaphoreio/semaphore/mcp_server/pkg/internal_api/plumber_w_f.workflow" + projecthubpb "github.com/semaphoreio/semaphore/mcp_server/pkg/internal_api/projecthub" + repopb "github.com/semaphoreio/semaphore/mcp_server/pkg/internal_api/repository_integrator" + "github.com/semaphoreio/semaphore/mcp_server/pkg/internalapi" + "github.com/semaphoreio/semaphore/mcp_server/pkg/logging" + "github.com/semaphoreio/semaphore/mcp_server/pkg/tools/internal/shared" +) + +func runFullDescription() string { + return `Schedule a new workflow run for a project. + +Use this when you need to: +- Kick off a pipeline with a specific branch, tag, or commit +- Trigger a workflow with custom parameters without using the UI + +Required inputs: +- organization_id: Organization UUID that owns the project +- project_id: Project UUID where the workflow should run +- reference: Git reference (branch, tag, or pull request), e.g. "refs/heads/main", "refs/tags/v1.0", or "refs/pull/42" + +Optional inputs: +- commit_sha: Pin the run to a specific commit +- pipeline_file: Override the pipeline definition path (defaults to the project's configured file) +- parameters: A key/value map of parameters to expose as environment variables (values convert to strings) + +The authenticated user must have permissions to run workflows in the specified project.` +} + +func newRunTool(name, description string) mcp.Tool { + return mcp.NewTool( + name, + mcp.WithDescription(description), + mcp.WithString( + "project_id", + mcp.Required(), + mcp.Description("Project UUID where the workflow should run."), + mcp.Pattern(`^[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}$`), + ), + mcp.WithString( + "organization_id", + mcp.Required(), + mcp.Description("Organization UUID that owns the project."), + mcp.Pattern(`^[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}$`), + ), + mcp.WithString( + "reference", + mcp.Required(), + mcp.Description("Git reference to run (branch, tag, or pull request, refs/... pattern)."), + ), + mcp.WithString( + "commit_sha", + mcp.Description("Optional commit SHA to pin the workflow run."), + ), + mcp.WithString( + "pipeline_file", + mcp.Description("Optional pipeline definition YAML file path within the repository."), + ), + mcp.WithObject( + "parameters", + mcp.Description("Optional key/value parameters exposed as environment variables."), + mcp.AdditionalProperties(map[string]any{ + "oneOf": []any{ + map[string]any{"type": "string"}, + map[string]any{"type": "number"}, + map[string]any{"type": "boolean"}, + map[string]any{"type": "null"}, + }, + }), + ), + mcp.WithIdempotentHintAnnotation(false), + ) +} + +func runHandler(api internalapi.Provider) server.ToolHandlerFunc { + return func(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { + orgIDRaw, err := req.RequireString("organization_id") + if err != nil { + return mcp.NewToolResultError(`Missing required argument: organization_id. Provide the organization UUID returned by organizations_list.`), nil + } + orgID := strings.TrimSpace(orgIDRaw) + if err := shared.ValidateUUID(orgID, "organization_id"); err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + + if err := shared.EnsureWriteToolsFeature(ctx, api, orgID); err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + + tracker := shared.TrackToolExecution(ctx, runToolName, orgID) + defer tracker.Cleanup() + + workflowClient := api.Workflow() + if workflowClient == nil { + return mcp.NewToolResultError(missingWorkflowError), nil + } + projectClient := api.Projects() + if projectClient == nil { + return mcp.NewToolResultError("project gRPC endpoint is not configured"), nil + } + + projectIDRaw, err := req.RequireString("project_id") + if err != nil { + return mcp.NewToolResultError(`Missing required argument: project_id. Provide the project UUID (xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx).`), nil + } + projectID := strings.TrimSpace(projectIDRaw) + if err := shared.ValidateUUID(projectID, "project_id"); err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + + userID := strings.ToLower(strings.TrimSpace(req.Header.Get("X-Semaphore-User-ID"))) + if err := shared.ValidateUUID(userID, "x-semaphore-user-id header"); err != nil { + return mcp.NewToolResultError(fmt.Sprintf(`%v + +The authentication layer must inject the X-Semaphore-User-ID header so we can authorize workflow runs.`, err)), nil + } + + if err := authz.CheckProjectPermission(ctx, api, userID, orgID, projectID, projectRunPermission); err != nil { + return shared.ProjectAuthorizationError(err, orgID, projectID, projectRunPermission), nil + } + + referenceRaw, err := req.RequireString("reference") + if err != nil { + return mcp.NewToolResultError(`Missing required argument: reference. Provide the branch or git ref to run.`), nil + } + reference, err := sanitizeGitReference(referenceRaw, "reference") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + + commitSHA := strings.TrimSpace(req.GetString("commit_sha", "")) + if err := validateCommitSHA(commitSHA, "commit_sha"); err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + + pipelineFileInput := strings.TrimSpace(req.GetString("pipeline_file", "")) + if err := validatePipelineFile(pipelineFileInput, "pipeline_file"); err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + + parameters, err := extractParameters(req.GetArguments()["parameters"]) + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + + describeReq := projectDescribeRequest(projectID, orgID, userID) + callCtx, cancel := context.WithTimeout(ctx, api.CallTimeout()) + defer cancel() + + describeResp, err := projectClient.Describe(callCtx, describeReq) + if err != nil { + logging.ForComponent("rpc"). + WithFields(logrus.Fields{ + "rpc": "project.Describe", + "projectId": projectID, + "orgId": orgID, + }). + WithError(err). + Error("project describe RPC failed") + return mcp.NewToolResultError("Unable to load project details. Please confirm the project exists and retry."), nil + } + + project, err := validateProjectDescribeResponse(describeResp, orgID, projectID) + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + + spec := project.GetSpec() + if spec == nil { + return mcp.NewToolResultError("Project specification is missing. Please try again once the project is fully initialized."), nil + } + + repo := spec.GetRepository() + if repo == nil { + return mcp.NewToolResultError("Project repository configuration is missing. Configure the repository before scheduling workflows."), nil + } + + pipelineFile := pipelineFileInput + if pipelineFile == "" { + pipelineFile = strings.TrimSpace(repo.GetPipelineFile()) + if pipelineFile == "" { + pipelineFile = defaultPipelineFile + } + } + if err := validatePipelineFile(pipelineFile, "pipeline_file"); err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + + envVars, err := buildEnvVars(parameters) + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + + requestToken := uuid.NewString() + branchName := branchNameFromReference(reference) + label := labelFromReference(reference) + gitReference := ensureGitReference(reference) + + serviceType, err := mapIntegrationType(repo.GetIntegrationType()) + if err != nil { + logging.ForComponent("workflows"). + WithFields(logrus.Fields{ + "projectId": projectID, + "orgId": orgID, + "integrationType": repo.GetIntegrationType(), + }). + WithError(err). + Error("unsupported repository integration type") + return mcp.NewToolResultError("Project repository integration type is not supported. Please contact support."), nil + } + + scheduleReq := &workflowpb.ScheduleRequest{ + ProjectId: projectID, + OrganizationId: orgID, + RequesterId: userID, + DefinitionFile: pipelineFile, + RequestToken: requestToken, + GitReference: gitReference, + Label: label, + TriggeredBy: workflowpb.TriggeredBy_API, + StartInConceivedState: true, + Service: serviceType, + EnvVars: envVars, + Repo: &workflowpb.ScheduleRequest_Repo{ + Owner: strings.TrimSpace(repo.GetOwner()), + RepoName: strings.TrimSpace(repo.GetName()), + BranchName: branchName, + CommitSha: commitSHA, + RepositoryId: strings.TrimSpace(repo.GetId()), + }, + } + scheduleCtx, cancelSchedule := context.WithTimeout(ctx, api.CallTimeout()) + defer cancelSchedule() + + scheduleResp, err := workflowClient.Schedule(scheduleCtx, scheduleReq) + if err != nil { + logging.ForComponent("rpc"). + WithFields(logrus.Fields{ + "rpc": "workflow.Schedule", + "projectId": projectID, + "orgId": orgID, + "reference": reference, + }). + WithError(err). + Error("workflow schedule RPC failed") + return mcp.NewToolResultError("Workflow schedule failed. Verify the repository settings and try again."), nil + } + + if err := shared.CheckStatus(scheduleResp.GetStatus()); err != nil { + return mcp.NewToolResultError(fmt.Sprintf("Workflow schedule failed: %v", err)), nil + } + + result := runResult{ + WorkflowID: strings.TrimSpace(scheduleResp.GetWfId()), + PipelineID: strings.TrimSpace(scheduleResp.GetPplId()), + Reference: gitReference, + CommitSHA: commitSHA, + PipelineFile: pipelineFile, + } + + markdown := formatRunMarkdown(result) + markdown = shared.TruncateResponse(markdown, shared.MaxResponseChars) + + tracker.MarkSuccess() + return &mcp.CallToolResult{ + Content: []mcp.Content{mcp.NewTextContent(markdown)}, + StructuredContent: result, + }, nil + } +} + +var ( + commitPattern = regexp.MustCompile(`^[0-9a-f]{7,64}$`) + parameterPattern = regexp.MustCompile(`^[A-Za-z_][A-Za-z0-9_]*$`) +) + +func sanitizeGitReference(raw, field string) (string, error) { + value := strings.TrimSpace(raw) + if value == "" { + return "", fmt.Errorf("%s is required", field) + } + return shared.SanitizeBranch(value, field) +} + +func validateCommitSHA(value, field string) error { + value = strings.TrimSpace(value) + if value == "" { + return nil + } + if len(value) > 64 { + return fmt.Errorf("%s must not exceed 64 characters", field) + } + if !commitPattern.MatchString(strings.ToLower(value)) { + return fmt.Errorf("%s must be a hexadecimal SHA (7-64 characters)", field) + } + return nil +} + +func validatePipelineFile(value, field string) error { + value = strings.TrimSpace(value) + if value == "" { + return nil + } + length := utf8.RuneCountInString(value) + if length > 512 { + return fmt.Errorf("%s must not exceed 512 characters", field) + } + for _, r := range value { + if r < 32 || r == 127 { + return fmt.Errorf("%s contains control characters", field) + } + if r == '\\' { + return fmt.Errorf("%s must not contain backslashes", field) + } + } + if strings.Contains(value, "..") { + return fmt.Errorf("%s must not contain '..' sequences", field) + } + if strings.HasPrefix(value, "/") { + return fmt.Errorf("%s must be a relative path", field) + } + return nil +} + +func extractParameters(raw any) (map[string]any, error) { + if raw == nil { + return nil, nil + } + params, ok := raw.(map[string]any) + if !ok { + return nil, fmt.Errorf("parameters must be a key/value map with string keys") + } + return params, nil +} + +func buildEnvVars(params map[string]any) ([]*workflowpb.ScheduleRequest_EnvVar, error) { + if len(params) == 0 { + return nil, nil + } + names := make([]string, 0, len(params)) + for name := range params { + names = append(names, name) + } + sort.Strings(names) + result := make([]*workflowpb.ScheduleRequest_EnvVar, 0, len(names)) + for _, name := range names { + clean := strings.TrimSpace(name) + if clean == "" { + return nil, fmt.Errorf("parameter names must not be empty") + } + if err := validateParameterName(clean); err != nil { + return nil, err + } + value, err := parameterValueToString(params[name]) + if err != nil { + return nil, err + } + result = append(result, &workflowpb.ScheduleRequest_EnvVar{Name: clean, Value: value}) + } + return result, nil +} + +func validateParameterName(name string) error { + if utf8.RuneCountInString(name) > 128 { + return fmt.Errorf("parameter names must not exceed 128 characters") + } + for _, r := range name { + if r < 32 || r == 127 { + return fmt.Errorf("parameter %q contains control characters", name) + } + } + if !parameterPattern.MatchString(name) { + return fmt.Errorf("parameter %q must start with a letter or underscore, followed by letters, digits, or underscores", name) + } + return nil +} + +func parameterValueToString(value any) (string, error) { + switch v := value.(type) { + case nil: + return "", nil + case string: + return v, nil + case bool: + if v { + return "true", nil + } + return "false", nil + case float64: + return strconv.FormatFloat(v, 'f', -1, 64), nil + case int: + return strconv.Itoa(v), nil + case int32: + return strconv.FormatInt(int64(v), 10), nil + case int64: + return strconv.FormatInt(v, 10), nil + case uint32: + return strconv.FormatUint(uint64(v), 10), nil + case uint64: + return strconv.FormatUint(v, 10), nil + default: + return "", fmt.Errorf("parameters values must be strings, numbers, booleans, or null") + } +} + +func projectDescribeRequest(projectID, orgID, userID string) *projecthubpb.DescribeRequest { + return &projecthubpb.DescribeRequest{ + Id: projectID, + Metadata: &projecthubpb.RequestMeta{ + ApiVersion: "v1alpha", + Kind: "Project", + OrgId: strings.TrimSpace(orgID), + UserId: strings.TrimSpace(userID), + ReqId: uuid.NewString(), + }, + } +} + +func validateProjectDescribeResponse(resp *projecthubpb.DescribeResponse, orgID, projectID string) (*projecthubpb.Project, error) { + if resp == nil { + return nil, fmt.Errorf("Project describe returned no data") + } + meta := resp.GetMetadata() + if meta == nil || meta.GetStatus() == nil { + return nil, fmt.Errorf("Project describe response is missing status information") + } + if meta.GetStatus().GetCode() != projecthubpb.ResponseMeta_OK { + message := strings.TrimSpace(meta.GetStatus().GetMessage()) + if message == "" { + message = "Project describe request failed" + } + return nil, fmt.Errorf("%s", message) + } + project := resp.GetProject() + if project == nil { + return nil, fmt.Errorf("Project describe response did not include project details") + } + projMeta := project.GetMetadata() + if projMeta == nil { + return nil, fmt.Errorf("Project metadata is missing") + } + if resourceOrg := strings.TrimSpace(projMeta.GetOrgId()); resourceOrg == "" || !strings.EqualFold(resourceOrg, orgID) { + shared.ReportScopeMismatch(shared.ScopeMismatchMetadata{ + Tool: runToolName, + ResourceType: "project", + ResourceID: projMeta.GetId(), + RequestOrgID: orgID, + ResourceOrgID: resourceOrg, + RequestProjectID: projectID, + ResourceProjectID: projMeta.GetId(), + }) + return nil, fmt.Errorf("Project %s does not belong to organization %s", projectID, orgID) + } + return project, nil +} + +// branchNameFromReference extracts the branch name from a git reference. +// Note: Tags intentionally return the full "refs/tags/*" path as required by the workflow service API. +func branchNameFromReference(ref string) string { + value := strings.TrimSpace(ref) + switch { + case strings.HasPrefix(value, "refs/heads/"): + return strings.TrimPrefix(value, "refs/heads/") + case strings.HasPrefix(value, "refs/tags/"): + return value // Workflow service expects full path for tags + case strings.HasPrefix(value, "refs/pull/"): + return "pull-request-" + strings.TrimPrefix(value, "refs/pull/") + default: + return value + } +} + +func labelFromReference(ref string) string { + value := strings.TrimSpace(ref) + switch { + case strings.HasPrefix(value, "refs/tags/"): + return strings.TrimPrefix(value, "refs/tags/") + case strings.HasPrefix(value, "refs/pull/"): + return strings.TrimPrefix(value, "refs/pull/") + case strings.HasPrefix(value, "refs/heads/"): + return strings.TrimPrefix(value, "refs/heads/") + default: + return value + } +} + +func ensureGitReference(ref string) string { + ref = strings.TrimSpace(ref) + if strings.HasPrefix(ref, "refs/") { + return ref + } + return "refs/heads/" + ref +} + +func mapIntegrationType(integration repopb.IntegrationType) (workflowpb.ScheduleRequest_ServiceType, error) { + switch integration { + case repopb.IntegrationType_GITHUB_OAUTH_TOKEN: + return workflowpb.ScheduleRequest_GIT_HUB, nil + case repopb.IntegrationType_GITHUB_APP: + return workflowpb.ScheduleRequest_GIT_HUB, nil + case repopb.IntegrationType_BITBUCKET: + return workflowpb.ScheduleRequest_BITBUCKET, nil + case repopb.IntegrationType_GITLAB: + return workflowpb.ScheduleRequest_GITLAB, nil + case repopb.IntegrationType_GIT: + return workflowpb.ScheduleRequest_GIT, nil + default: + return workflowpb.ScheduleRequest_GIT_HUB, fmt.Errorf("unsupported repository integration type: %v", integration) + } +} + +func formatRunMarkdown(result runResult) string { + mb := shared.NewMarkdownBuilder() + mb.H1("Workflow Scheduled") + if result.WorkflowID != "" { + mb.KeyValue("Workflow ID", fmt.Sprintf("`%s`", result.WorkflowID)) + } + if result.PipelineID != "" { + mb.KeyValue("Initial Pipeline", fmt.Sprintf("`%s`", result.PipelineID)) + } + if result.Reference != "" { + mb.KeyValue("Reference", result.Reference) + } + if result.CommitSHA != "" { + mb.KeyValue("Commit", shortenCommit(result.CommitSHA)) + } + if result.PipelineFile != "" { + mb.KeyValue("Pipeline File", result.PipelineFile) + } + return mb.String() +} diff --git a/mcp_server/pkg/tools/workflows/run_tool_test.go b/mcp_server/pkg/tools/workflows/run_tool_test.go new file mode 100644 index 000000000..08adf0d12 --- /dev/null +++ b/mcp_server/pkg/tools/workflows/run_tool_test.go @@ -0,0 +1,550 @@ +package workflows + +import ( + "context" + "fmt" + "net/http" + "strings" + "testing" + "time" + + "github.com/mark3labs/mcp-go/mcp" + "github.com/semaphoreio/semaphore/mcp_server/pkg/feature" + workflowpb "github.com/semaphoreio/semaphore/mcp_server/pkg/internal_api/plumber_w_f.workflow" + projecthubpb "github.com/semaphoreio/semaphore/mcp_server/pkg/internal_api/projecthub" + rbacpb "github.com/semaphoreio/semaphore/mcp_server/pkg/internal_api/rbac" + repopb "github.com/semaphoreio/semaphore/mcp_server/pkg/internal_api/repository_integrator" + statuspb "github.com/semaphoreio/semaphore/mcp_server/pkg/internal_api/status" + support "github.com/semaphoreio/semaphore/mcp_server/test/support" + + code "google.golang.org/genproto/googleapis/rpc/code" +) + +func TestRunWorkflowSuccess(t *testing.T) { + orgID := "aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee" + projectID := "11111111-2222-3333-4444-555555555555" + userID := "99999999-aaaa-bbbb-cccc-dddddddddddd" + repo := &projecthubpb.Project_Spec_Repository{ + Owner: "octo", + Name: "repo", + PipelineFile: ".semaphore/ci.yml", + IntegrationType: repopb.IntegrationType_GITHUB_APP, + } + projectStub := &support.ProjectClientStub{Response: support.NewProjectDescribeResponse(orgID, projectID, repo)} + workflowStub := &support.WorkflowClientStub{ + ScheduleResp: &workflowpb.ScheduleResponse{ + Status: &statuspb.Status{Code: code.Code_OK}, + WfId: "wf-001", + PplId: "ppl-001", + }, + } + provider := &support.MockProvider{ + WorkflowClient: workflowStub, + ProjectClient: projectStub, + Timeout: time.Second, + RBACClient: support.NewRBACStub(projectRunPermission), + } + + req := mcp.CallToolRequest{ + Params: mcp.CallToolParams{ + Arguments: map[string]any{ + "organization_id": orgID, + "project_id": projectID, + "reference": "refs/heads/feature/login", + "commit_sha": "abc1234", + "parameters": map[string]any{ + "DEPLOY_ENV": "staging", + }, + }, + }, + } + header := http.Header{} + header.Set("X-Semaphore-User-ID", userID) + req.Header = header + + res, err := runHandler(provider)(context.Background(), req) + if err != nil { + toFail(t, "handler error: %v", err) + } + result, ok := res.StructuredContent.(runResult) + if !ok { + toFail(t, "unexpected structured content type: %T", res.StructuredContent) + } + if result.WorkflowID != "wf-001" || result.PipelineID != "ppl-001" { + toFail(t, "unexpected run result: %+v", result) + } + + reqMsg := workflowStub.LastSchedule + if reqMsg == nil { + toFail(t, "expected schedule request to be recorded") + } + if reqMsg.GetProjectId() != projectID || reqMsg.GetOrganizationId() != orgID { + toFail(t, "unexpected schedule scope: %+v", reqMsg) + } + if reqMsg.GetDefinitionFile() != ".semaphore/ci.yml" { + toFail(t, "expected pipeline file from project, got %s", reqMsg.GetDefinitionFile()) + } + if reqMsg.GetRepo().GetBranchName() != "feature/login" { + toFail(t, "unexpected branch name: %s", reqMsg.GetRepo().GetBranchName()) + } + if reqMsg.GetRepo().GetCommitSha() != "abc1234" { + toFail(t, "unexpected commit sha: %s", reqMsg.GetRepo().GetCommitSha()) + } + if reqMsg.GetLabel() != "feature/login" { + toFail(t, "expected label to match branch name, got %s", reqMsg.GetLabel()) + } + if got := reqMsg.GetService(); got != workflowpb.ScheduleRequest_GIT_HUB { + toFail(t, "unexpected service type: %v", got) + } + if len(reqMsg.GetEnvVars()) != 1 || reqMsg.GetEnvVars()[0].GetName() != "DEPLOY_ENV" { + toFail(t, "unexpected env vars: %+v", reqMsg.GetEnvVars()) + } +} + +func TestRunWorkflowFeatureDisabled(t *testing.T) { + orgID := "aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee" + projectID := "11111111-2222-3333-4444-555555555555" + provider := &support.MockProvider{ + WorkflowClient: &support.WorkflowClientStub{}, + ProjectClient: &support.ProjectClientStub{Response: support.NewProjectDescribeResponse(orgID, projectID, &projecthubpb.Project_Spec_Repository{})}, + Timeout: time.Second, + RBACClient: support.NewRBACStub(projectRunPermission), + FeaturesService: support.FeatureClientStub{State: feature.Hidden}, + } + req := mcp.CallToolRequest{ + Params: mcp.CallToolParams{Arguments: map[string]any{ + "organization_id": orgID, + "project_id": projectID, + "reference": "main", + }}, + } + header := http.Header{} + header.Set("X-Semaphore-User-ID", "99999999-aaaa-bbbb-cccc-dddddddddddd") + req.Header = header + + res, err := runHandler(provider)(context.Background(), req) + if err != nil { + toFail(t, "handler error: %v", err) + } + msg := requireErrorText(t, res) + if !strings.Contains(strings.ToLower(msg), "disabled") { + toFail(t, "expected feature disabled message, got %q", msg) + } +} + +func TestRunWorkflowInvalidParameters(t *testing.T) { + orgID := "aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee" + projectID := "11111111-2222-3333-4444-555555555555" + provider := &support.MockProvider{ + WorkflowClient: &support.WorkflowClientStub{}, + ProjectClient: &support.ProjectClientStub{Response: support.NewProjectDescribeResponse(orgID, projectID, &projecthubpb.Project_Spec_Repository{})}, + Timeout: time.Second, + RBACClient: support.NewRBACStub(projectRunPermission), + } + req := mcp.CallToolRequest{ + Params: mcp.CallToolParams{Arguments: map[string]any{ + "organization_id": orgID, + "project_id": projectID, + "reference": "main", + "parameters": []string{"bad"}, + }}, + } + header := http.Header{} + header.Set("X-Semaphore-User-ID", "99999999-aaaa-bbbb-cccc-dddddddddddd") + req.Header = header + + res, err := runHandler(provider)(context.Background(), req) + if err != nil { + toFail(t, "handler error: %v", err) + } + msg := requireErrorText(t, res) + if !strings.Contains(strings.ToLower(msg), "parameters") { + toFail(t, "expected parameters error, got %q", msg) + } +} + +func TestRunWorkflowPermissionDenied(t *testing.T) { + orgID := "aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee" + projectID := "11111111-2222-3333-4444-555555555555" + workflowStub := &support.WorkflowClientStub{} + provider := &support.MockProvider{ + WorkflowClient: workflowStub, + ProjectClient: &support.ProjectClientStub{Response: support.NewProjectDescribeResponse(orgID, projectID, &projecthubpb.Project_Spec_Repository{})}, + Timeout: time.Second, + RBACClient: support.NewRBACStub(), // no permissions granted + } + req := mcp.CallToolRequest{ + Params: mcp.CallToolParams{Arguments: map[string]any{ + "organization_id": orgID, + "project_id": projectID, + "reference": "main", + }}, + } + header := http.Header{} + header.Set("X-Semaphore-User-ID", "99999999-aaaa-bbbb-cccc-dddddddddddd") + req.Header = header + + res, err := runHandler(provider)(context.Background(), req) + if err != nil { + toFail(t, "handler error: %v", err) + } + msg := requireErrorText(t, res) + if !strings.Contains(msg, "Permission denied while accessing project") { + toFail(t, "expected permission denied message, got %q", msg) + } + if workflowStub.LastSchedule != nil { + toFail(t, "workflow schedule should not have been invoked when permission is missing") + } +} + +func TestRunWorkflowInvalidCommitSHA(t *testing.T) { + orgID := "aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee" + projectID := "11111111-2222-3333-4444-555555555555" + userID := "99999999-aaaa-bbbb-cccc-dddddddddddd" + provider := &support.MockProvider{ + WorkflowClient: &support.WorkflowClientStub{}, + ProjectClient: &support.ProjectClientStub{Response: support.NewProjectDescribeResponse(orgID, projectID, &projecthubpb.Project_Spec_Repository{})}, + Timeout: time.Second, + RBACClient: support.NewRBACStub(projectRunPermission), + } + + req := mcp.CallToolRequest{ + Params: mcp.CallToolParams{Arguments: map[string]any{ + "organization_id": orgID, + "project_id": projectID, + "reference": "main", + "commit_sha": "INVALID", + }}, + } + header := http.Header{} + header.Set("X-Semaphore-User-ID", userID) + req.Header = header + + res, err := runHandler(provider)(context.Background(), req) + if err != nil { + toFail(t, "handler error: %v", err) + } + msg := requireErrorText(t, res) + if !strings.Contains(msg, "commit_sha must be a hexadecimal SHA") { + toFail(t, "expected invalid commit sha error, got %q", msg) + } +} + +func TestRunWorkflowInvalidPipelineFile(t *testing.T) { + orgID := "aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee" + projectID := "11111111-2222-3333-4444-555555555555" + userID := "99999999-aaaa-bbbb-cccc-dddddddddddd" + provider := &support.MockProvider{ + WorkflowClient: &support.WorkflowClientStub{}, + ProjectClient: &support.ProjectClientStub{Response: support.NewProjectDescribeResponse(orgID, projectID, &projecthubpb.Project_Spec_Repository{})}, + Timeout: time.Second, + RBACClient: support.NewRBACStub(projectRunPermission), + } + + req := mcp.CallToolRequest{ + Params: mcp.CallToolParams{Arguments: map[string]any{ + "organization_id": orgID, + "project_id": projectID, + "reference": "main", + "pipeline_file": "../outside.yml", + }}, + } + header := http.Header{} + header.Set("X-Semaphore-User-ID", userID) + req.Header = header + + res, err := runHandler(provider)(context.Background(), req) + if err != nil { + toFail(t, "handler error: %v", err) + } + msg := requireErrorText(t, res) + if !strings.Contains(msg, "pipeline_file must not contain '..' sequences") { + toFail(t, "expected invalid pipeline file error, got %q", msg) + } +} + +func TestRunWorkflowProjectDescribeFailure(t *testing.T) { + orgID := "aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee" + projectID := "11111111-2222-3333-4444-555555555555" + userID := "99999999-aaaa-bbbb-cccc-dddddddddddd" + provider := &support.MockProvider{ + WorkflowClient: &support.WorkflowClientStub{}, + ProjectClient: &support.ProjectClientStub{Err: fmt.Errorf("project describe failed")}, + Timeout: time.Second, + RBACClient: support.NewRBACStub(projectRunPermission), + } + + req := mcp.CallToolRequest{ + Params: mcp.CallToolParams{Arguments: map[string]any{ + "organization_id": orgID, + "project_id": projectID, + "reference": "main", + }}, + } + header := http.Header{} + header.Set("X-Semaphore-User-ID", userID) + req.Header = header + + res, err := runHandler(provider)(context.Background(), req) + if err != nil { + toFail(t, "handler error: %v", err) + } + msg := requireErrorText(t, res) + if !strings.Contains(msg, "Unable to load project details") { + toFail(t, "expected project load error, got %q", msg) + } +} + +func TestRunWorkflowScheduleFailure(t *testing.T) { + orgID := "aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee" + projectID := "11111111-2222-3333-4444-555555555555" + userID := "99999999-aaaa-bbbb-cccc-dddddddddddd" + repo := &projecthubpb.Project_Spec_Repository{ + Owner: "octo", + Name: "repo", + PipelineFile: ".semaphore/ci.yml", + IntegrationType: repopb.IntegrationType_GITHUB_APP, + } + provider := &support.MockProvider{ + WorkflowClient: &support.WorkflowClientStub{ScheduleErr: fmt.Errorf("scheduler unavailable")}, + ProjectClient: &support.ProjectClientStub{Response: support.NewProjectDescribeResponse(orgID, projectID, repo)}, + Timeout: time.Second, + RBACClient: support.NewRBACStub(projectRunPermission), + } + + req := mcp.CallToolRequest{ + Params: mcp.CallToolParams{Arguments: map[string]any{ + "organization_id": orgID, + "project_id": projectID, + "reference": "main", + }}, + } + header := http.Header{} + header.Set("X-Semaphore-User-ID", userID) + req.Header = header + + res, err := runHandler(provider)(context.Background(), req) + if err != nil { + toFail(t, "handler error: %v", err) + } + msg := requireErrorText(t, res) + if !strings.Contains(msg, "Workflow schedule failed") { + toFail(t, "expected schedule failure error, got %q", msg) + } +} + +func TestRunWorkflowScopeMismatch(t *testing.T) { + orgID := "aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee" + wrongOrgID := "bbbbbbbb-cccc-dddd-eeee-ffffffffffff" + projectID := "11111111-2222-3333-4444-555555555555" + userID := "99999999-aaaa-bbbb-cccc-dddddddddddd" + repo := &projecthubpb.Project_Spec_Repository{ + Owner: "octo", + Name: "repo", + IntegrationType: repopb.IntegrationType_GITHUB_APP, + } + response := support.NewProjectDescribeResponse(wrongOrgID, projectID, repo) + provider := &support.MockProvider{ + WorkflowClient: &support.WorkflowClientStub{}, + ProjectClient: &support.ProjectClientStub{Response: response}, + Timeout: time.Second, + RBACClient: support.NewRBACStub(projectRunPermission), + } + + req := mcp.CallToolRequest{ + Params: mcp.CallToolParams{Arguments: map[string]any{ + "organization_id": orgID, + "project_id": projectID, + "reference": "main", + }}, + } + header := http.Header{} + header.Set("X-Semaphore-User-ID", userID) + req.Header = header + + res, err := runHandler(provider)(context.Background(), req) + if err != nil { + toFail(t, "handler error: %v", err) + } + msg := requireErrorText(t, res) + if !strings.Contains(msg, "does not belong to organization") { + toFail(t, "expected scope mismatch error, got %q", msg) + } +} + +func TestRunWorkflowMissingRepository(t *testing.T) { + orgID := "aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee" + projectID := "11111111-2222-3333-4444-555555555555" + userID := "99999999-aaaa-bbbb-cccc-dddddddddddd" + response := &projecthubpb.DescribeResponse{ + Metadata: &projecthubpb.ResponseMeta{ + Status: &projecthubpb.ResponseMeta_Status{Code: projecthubpb.ResponseMeta_OK}, + }, + Project: &projecthubpb.Project{ + Metadata: &projecthubpb.Project_Metadata{Id: projectID, OrgId: orgID}, + Spec: &projecthubpb.Project_Spec{Repository: nil}, + }, + } + provider := &support.MockProvider{ + WorkflowClient: &support.WorkflowClientStub{}, + ProjectClient: &support.ProjectClientStub{Response: response}, + Timeout: time.Second, + RBACClient: support.NewRBACStub(projectRunPermission), + } + + req := mcp.CallToolRequest{ + Params: mcp.CallToolParams{Arguments: map[string]any{ + "organization_id": orgID, + "project_id": projectID, + "reference": "main", + }}, + } + header := http.Header{} + header.Set("X-Semaphore-User-ID", userID) + req.Header = header + + res, err := runHandler(provider)(context.Background(), req) + if err != nil { + toFail(t, "handler error: %v", err) + } + msg := requireErrorText(t, res) + if !strings.Contains(msg, "repository configuration is missing") { + toFail(t, "expected missing repository error, got %q", msg) + } +} + +func TestRunWorkflowInvalidParameterNames(t *testing.T) { + orgID := "aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee" + projectID := "11111111-2222-3333-4444-555555555555" + userID := "99999999-aaaa-bbbb-cccc-dddddddddddd" + provider := &support.MockProvider{ + WorkflowClient: &support.WorkflowClientStub{}, + ProjectClient: &support.ProjectClientStub{Response: support.NewProjectDescribeResponse(orgID, projectID, &projecthubpb.Project_Spec_Repository{})}, + Timeout: time.Second, + RBACClient: support.NewRBACStub(projectRunPermission), + } + + testCases := []struct { + name string + parameters map[string]any + expectErr string + }{ + {name: "empty parameter name", parameters: map[string]any{"": "value"}, expectErr: "parameter names must not be empty"}, + {name: "whitespace only", parameters: map[string]any{" ": "value"}, expectErr: "parameter names must not be empty"}, + {name: "starts with digit", parameters: map[string]any{"9VAR": "value"}, expectErr: "must start with a letter or underscore"}, + {name: "contains special chars", parameters: map[string]any{"MY-VAR": "value"}, expectErr: "must start with a letter or underscore"}, + {name: "contains space", parameters: map[string]any{"MY VAR": "value"}, expectErr: "must start with a letter or underscore"}, + {name: "contains control chars", parameters: map[string]any{"VAR\x00NAME": "value"}, expectErr: "contains control characters"}, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + req := mcp.CallToolRequest{ + Params: mcp.CallToolParams{Arguments: map[string]any{ + "organization_id": orgID, + "project_id": projectID, + "reference": "main", + "parameters": tc.parameters, + }}, + } + header := http.Header{} + header.Set("X-Semaphore-User-ID", userID) + req.Header = header + + res, err := runHandler(provider)(context.Background(), req) + if err != nil { + toFail(t, "handler error: %v", err) + } + msg := requireErrorText(t, res) + if !strings.Contains(msg, tc.expectErr) { + toFail(t, "expected error containing %q, got %q", tc.expectErr, msg) + } + }) + } +} + +func TestRunWorkflowInvalidParameterValues(t *testing.T) { + orgID := "aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee" + projectID := "11111111-2222-3333-4444-555555555555" + userID := "99999999-aaaa-bbbb-cccc-dddddddddddd" + provider := &support.MockProvider{ + WorkflowClient: &support.WorkflowClientStub{}, + ProjectClient: &support.ProjectClientStub{Response: support.NewProjectDescribeResponse(orgID, projectID, &projecthubpb.Project_Spec_Repository{})}, + Timeout: time.Second, + RBACClient: support.NewRBACStub(projectRunPermission), + } + + testCases := []struct { + name string + parameters map[string]any + expectErr string + }{ + {name: "array value", parameters: map[string]any{"TAGS": []string{"tag1", "tag2"}}, expectErr: "must be strings, numbers, booleans, or null"}, + {name: "nested object", parameters: map[string]any{"CONFIG": map[string]string{"key": "value"}}, expectErr: "must be strings, numbers, booleans, or null"}, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + req := mcp.CallToolRequest{ + Params: mcp.CallToolParams{Arguments: map[string]any{ + "organization_id": orgID, + "project_id": projectID, + "reference": "main", + "parameters": tc.parameters, + }}, + } + header := http.Header{} + header.Set("X-Semaphore-User-ID", userID) + req.Header = header + + res, err := runHandler(provider)(context.Background(), req) + if err != nil { + toFail(t, "handler error: %v", err) + } + msg := requireErrorText(t, res) + if !strings.Contains(msg, tc.expectErr) { + toFail(t, "expected error containing %q, got %q", tc.expectErr, msg) + } + }) + } +} + +func TestRunWorkflowUnsupportedIntegrationType(t *testing.T) { + orgID := "aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee" + projectID := "11111111-2222-3333-4444-555555555555" + userID := "99999999-aaaa-bbbb-cccc-dddddddddddd" + repo := &projecthubpb.Project_Spec_Repository{ + Owner: "octo", + Name: "repo", + IntegrationType: repopb.IntegrationType(999), + } + provider := &support.MockProvider{ + WorkflowClient: &support.WorkflowClientStub{}, + ProjectClient: &support.ProjectClientStub{Response: support.NewProjectDescribeResponse(orgID, projectID, repo)}, + Timeout: time.Second, + RBACClient: support.NewRBACStub(projectRunPermission), + } + + req := mcp.CallToolRequest{ + Params: mcp.CallToolParams{Arguments: map[string]any{ + "organization_id": orgID, + "project_id": projectID, + "reference": "main", + }}, + } + header := http.Header{} + header.Set("X-Semaphore-User-ID", userID) + req.Header = header + + res, err := runHandler(provider)(context.Background(), req) + if err != nil { + toFail(t, "handler error: %v", err) + } + msg := requireErrorText(t, res) + if !strings.Contains(msg, "integration type is not supported") { + toFail(t, "expected unsupported integration type error, got %q", msg) + } +} + +// Ensure RBAC stub implements the interface at compile time. +var _ rbacpb.RBACClient = (*support.RBACStub)(nil) diff --git a/mcp_server/pkg/tools/workflows/workflows.go b/mcp_server/pkg/tools/workflows/search_tool.go similarity index 89% rename from mcp_server/pkg/tools/workflows/workflows.go rename to mcp_server/pkg/tools/workflows/search_tool.go index 2c7666a86..81331124b 100644 --- a/mcp_server/pkg/tools/workflows/workflows.go +++ b/mcp_server/pkg/tools/workflows/search_tool.go @@ -7,8 +7,6 @@ import ( "github.com/mark3labs/mcp-go/mcp" "github.com/mark3labs/mcp-go/server" - "github.com/sirupsen/logrus" - "github.com/semaphoreio/semaphore/mcp_server/pkg/authz" workflowpb "github.com/semaphoreio/semaphore/mcp_server/pkg/internal_api/plumber_w_f.workflow" userpb "github.com/semaphoreio/semaphore/mcp_server/pkg/internal_api/user" @@ -16,14 +14,7 @@ import ( "github.com/semaphoreio/semaphore/mcp_server/pkg/logging" "github.com/semaphoreio/semaphore/mcp_server/pkg/tools/internal/shared" "github.com/semaphoreio/semaphore/mcp_server/pkg/utils" -) - -const ( - searchToolName = "workflows_search" - defaultLimit = 20 - maxLimit = 100 - missingWorkflowError = "workflow gRPC endpoint is not configured" - projectViewPermission = "project.view" + "github.com/sirupsen/logrus" ) func searchFullDescription() string { @@ -64,12 +55,7 @@ Next steps: - Use workflows_search(project_id="...", branch="main") regularly to monitor your own workflows` } -// Register wires the workflows tool into the MCP server. -func Register(s *server.MCPServer, api internalapi.Provider) { - s.AddTool(newTool(searchToolName, searchFullDescription()), listHandler(api)) -} - -func newTool(name, description string) mcp.Tool { +func newSearchTool(name, description string) mcp.Tool { return mcp.NewTool( name, mcp.WithDescription(description), @@ -113,25 +99,6 @@ func newTool(name, description string) mcp.Tool { ) } -type summary struct { - ID string `json:"id"` - InitialPipeline string `json:"initialPipelineId,omitempty"` - ProjectID string `json:"projectId,omitempty"` - OrganizationID string `json:"organizationId,omitempty"` - Branch string `json:"branch,omitempty"` - CommitSHA string `json:"commitSha,omitempty"` - RequesterID string `json:"requesterId,omitempty"` - TriggeredBy string `json:"triggeredBy,omitempty"` - CreatedAt string `json:"createdAt,omitempty"` - RerunOf string `json:"rerunOf,omitempty"` - RepositoryID string `json:"repositoryId,omitempty"` -} - -type listResult struct { - Workflows []summary `json:"workflows"` - NextCursor string `json:"nextCursor,omitempty"` -} - func listHandler(api internalapi.Provider) server.ToolHandlerFunc { return func(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { client := api.Workflow() @@ -418,34 +385,6 @@ func formatWorkflowsMarkdown(result listResult, mode, projectID, orgID, branch, return mb.String() } -func humanizeTriggeredBy(value string) string { - value = strings.TrimSpace(value) - if value == "" { - return "Unspecified" - } - parts := strings.Split(value, "_") - for i, part := range parts { - if part == "" { - continue - } - part = strings.ToLower(part) - parts[i] = strings.ToUpper(part[:1]) + part[1:] - } - return strings.Join(parts, " ") -} - -func shortenCommit(sha string) string { - sha = strings.TrimSpace(sha) - if len(sha) > 12 { - return sha[:12] - } - return sha -} - -func normalizeID(value string) string { - return strings.ToLower(strings.TrimSpace(value)) -} - func resolveRequesterID(ctx context.Context, api internalapi.Provider, raw string) (string, error) { candidate := strings.ToLower(strings.TrimSpace(raw)) if candidate == "" { diff --git a/mcp_server/pkg/tools/workflows/workflows_test.go b/mcp_server/pkg/tools/workflows/search_tool_test.go similarity index 54% rename from mcp_server/pkg/tools/workflows/workflows_test.go rename to mcp_server/pkg/tools/workflows/search_tool_test.go index d4f7f852b..4d5c23962 100644 --- a/mcp_server/pkg/tools/workflows/workflows_test.go +++ b/mcp_server/pkg/tools/workflows/search_tool_test.go @@ -10,13 +10,11 @@ import ( "github.com/mark3labs/mcp-go/mcp" "github.com/semaphoreio/semaphore/mcp_server/pkg/feature" workflowpb "github.com/semaphoreio/semaphore/mcp_server/pkg/internal_api/plumber_w_f.workflow" - rbacpb "github.com/semaphoreio/semaphore/mcp_server/pkg/internal_api/rbac" statuspb "github.com/semaphoreio/semaphore/mcp_server/pkg/internal_api/status" userpb "github.com/semaphoreio/semaphore/mcp_server/pkg/internal_api/user" support "github.com/semaphoreio/semaphore/mcp_server/test/support" "google.golang.org/genproto/googleapis/rpc/code" - "google.golang.org/grpc" "google.golang.org/protobuf/types/known/timestamppb" ) @@ -32,8 +30,8 @@ func TestListWorkflows_FeatureFlagDisabled(t *testing.T) { provider := &support.MockProvider{ FeaturesService: support.FeatureClientStub{State: feature.Hidden}, Timeout: time.Second, - WorkflowClient: &workflowClientStub{}, - RBACClient: newRBACStub("project.view"), + WorkflowClient: &support.WorkflowClientStub{}, + RBACClient: support.NewRBACStub("project.view"), } res, err := listHandler(provider)(context.Background(), req) @@ -48,8 +46,8 @@ func TestListWorkflows_FeatureFlagDisabled(t *testing.T) { func TestListWorkflows(t *testing.T) { projectID := "11111111-2222-3333-4444-555555555555" - client := &workflowClientStub{ - listResp: &workflowpb.ListKeysetResponse{ + client := &support.WorkflowClientStub{ + ListResp: &workflowpb.ListKeysetResponse{ Status: &statuspb.Status{Code: code.Code_OK}, Workflows: []*workflowpb.WorkflowDetails{ { @@ -69,7 +67,7 @@ func TestListWorkflows(t *testing.T) { provider := &support.MockProvider{ WorkflowClient: client, Timeout: time.Second, - RBACClient: newRBACStub("project.view"), + RBACClient: support.NewRBACStub("project.view"), } handler := listHandler(provider) @@ -109,14 +107,14 @@ func TestListWorkflows(t *testing.T) { toFail(t, "expected next cursor 'cursor', got %q", result.NextCursor) } - if client.lastList == nil { + if client.LastList == nil { toFail(t, "expected list request to be recorded") } - if got := client.lastList.GetRequesterId(); got != "99999999-aaaa-bbbb-cccc-dddddddddddd" { + if got := client.LastList.GetRequesterId(); got != "99999999-aaaa-bbbb-cccc-dddddddddddd" { toFail(t, "expected requester to default to user header, got %s", got) } - if got := client.lastList.GetPageSize(); got != 10 { + if got := client.LastList.GetPageSize(); got != 10 { toFail(t, "expected page size 10, got %d", got) } } @@ -124,21 +122,21 @@ func TestListWorkflows(t *testing.T) { func TestListWorkflowsWithRequesterOverride(t *testing.T) { projectID := "11111111-2222-3333-4444-555555555555" requester := "deploy-bot" - client := &workflowClientStub{ - listResp: &workflowpb.ListKeysetResponse{ + client := &support.WorkflowClientStub{ + ListResp: &workflowpb.ListKeysetResponse{ Status: &statuspb.Status{Code: code.Code_OK}, Workflows: []*workflowpb.WorkflowDetails{}, }, } - userClient := &userClientStub{ - response: &userpb.User{Id: "00000000-1111-2222-3333-444444444444"}, + userClient := &support.UserClientStub{ + Response: &userpb.User{Id: "00000000-1111-2222-3333-444444444444"}, } provider := &support.MockProvider{ WorkflowClient: client, UserClient: userClient, Timeout: time.Second, - RBACClient: newRBACStub("project.view"), + RBACClient: support.NewRBACStub("project.view"), } req := mcp.CallToolRequest{ @@ -160,24 +158,24 @@ func TestListWorkflowsWithRequesterOverride(t *testing.T) { toFail(t, "handler error: %v", err) } - if client.lastList == nil { + if client.LastList == nil { toFail(t, "expected list request to be recorded") } - if got := strings.TrimSpace(client.lastList.GetRequesterId()); got != "00000000-1111-2222-3333-444444444444" { + if got := strings.TrimSpace(client.LastList.GetRequesterId()); got != "00000000-1111-2222-3333-444444444444" { toFail(t, "expected requester override to propagate, got %s", got) } - if userClient.lastRequest == nil || userClient.lastRequest.GetProvider() == nil { + if userClient.LastRequest == nil || userClient.LastRequest.GetProvider() == nil { toFail(t, "expected user lookup to be recorded") } - if login := userClient.lastRequest.GetProvider().GetLogin(); login != requester { + if login := userClient.LastRequest.GetProvider().GetLogin(); login != requester { toFail(t, "expected user lookup login %s, got %s", requester, login) } } func TestListWorkflowsPermissionDenied(t *testing.T) { projectID := "11111111-2222-3333-4444-555555555555" - client := &workflowClientStub{} - rbac := newRBACStub() + client := &support.WorkflowClientStub{} + rbac := support.NewRBACStub() provider := &support.MockProvider{ WorkflowClient: client, @@ -206,17 +204,17 @@ func TestListWorkflowsPermissionDenied(t *testing.T) { if !strings.Contains(msg, `Permission denied while accessing project`) { toFail(t, "expected permission denied message, got %q", msg) } - if client.lastList != nil { - toFail(t, "expected no workflow RPC call, got %+v", client.lastList) + if client.LastList != nil { + toFail(t, "expected no workflow RPC call, got %+v", client.LastList) } - if len(rbac.lastRequests) != 1 { - toFail(t, "expected one RBAC request, got %d", len(rbac.lastRequests)) + if len(rbac.LastRequests) != 1 { + toFail(t, "expected one RBAC request, got %d", len(rbac.LastRequests)) } } func TestListWorkflowsRBACUnavailable(t *testing.T) { projectID := "11111111-2222-3333-4444-555555555555" - client := &workflowClientStub{} + client := &support.WorkflowClientStub{} provider := &support.MockProvider{ WorkflowClient: client, @@ -244,15 +242,15 @@ func TestListWorkflowsRBACUnavailable(t *testing.T) { if !strings.Contains(msg, "Authorization service is not configured") { toFail(t, "expected RBAC unavailable message, got %q", msg) } - if client.lastList != nil { - toFail(t, "expected no workflow RPC call, got %+v", client.lastList) + if client.LastList != nil { + toFail(t, "expected no workflow RPC call, got %+v", client.LastList) } } func TestListWorkflowsScopeMismatchOrganization(t *testing.T) { projectID := "11111111-2222-3333-4444-555555555555" - client := &workflowClientStub{ - listResp: &workflowpb.ListKeysetResponse{ + client := &support.WorkflowClientStub{ + ListResp: &workflowpb.ListKeysetResponse{ Status: &statuspb.Status{Code: code.Code_OK}, Workflows: []*workflowpb.WorkflowDetails{ { @@ -263,7 +261,7 @@ func TestListWorkflowsScopeMismatchOrganization(t *testing.T) { }, }, } - rbac := newRBACStub("project.view") + rbac := support.NewRBACStub("project.view") provider := &support.MockProvider{ WorkflowClient: client, @@ -296,8 +294,8 @@ func TestListWorkflowsScopeMismatchOrganization(t *testing.T) { func TestListWorkflowsScopeMismatchProject(t *testing.T) { projectID := "11111111-2222-3333-4444-555555555555" - client := &workflowClientStub{ - listResp: &workflowpb.ListKeysetResponse{ + client := &support.WorkflowClientStub{ + ListResp: &workflowpb.ListKeysetResponse{ Status: &statuspb.Status{Code: code.Code_OK}, Workflows: []*workflowpb.WorkflowDetails{ { @@ -308,7 +306,7 @@ func TestListWorkflowsScopeMismatchProject(t *testing.T) { }, }, } - rbac := newRBACStub("project.view") + rbac := support.NewRBACStub("project.view") provider := &support.MockProvider{ WorkflowClient: client, @@ -338,179 +336,3 @@ func TestListWorkflowsScopeMismatchProject(t *testing.T) { toFail(t, "expected project scope mismatch message, got %q", msg) } } - -type workflowClientStub struct { - workflowpb.WorkflowServiceClient - listResp *workflowpb.ListKeysetResponse - listErr error - lastList *workflowpb.ListKeysetRequest -} - -func requireErrorText(t *testing.T, res *mcp.CallToolResult) string { - t.Helper() - if res == nil { - t.Fatalf("expected tool result") - } - if !res.IsError { - t.Fatalf("expected error result, got success") - } - if len(res.Content) == 0 { - t.Fatalf("expected error content") - } - text, ok := res.Content[0].(mcp.TextContent) - if !ok { - t.Fatalf("expected text content, got %T", res.Content[0]) - } - return text.Text -} - -func newRBACStub(perms ...string) *rbacStub { - copied := append([]string(nil), perms...) - return &rbacStub{permissions: copied} -} - -type rbacStub struct { - rbacpb.RBACClient - - permissions []string - perProject map[string][]string - perOrg map[string][]string - err error - errorForProject map[string]error - errorForOrg map[string]error - lastRequests []*rbacpb.ListUserPermissionsRequest -} - -func (s *rbacStub) ListUserPermissions(ctx context.Context, in *rbacpb.ListUserPermissionsRequest, opts ...grpc.CallOption) (*rbacpb.ListUserPermissionsResponse, error) { - reqCopy := &rbacpb.ListUserPermissionsRequest{ - UserId: in.GetUserId(), - OrgId: in.GetOrgId(), - ProjectId: in.GetProjectId(), - } - s.lastRequests = append(s.lastRequests, reqCopy) - - if s.err != nil { - return nil, s.err - } - - projectKey := normalizeKey(in.GetProjectId()) - orgKey := normalizeKey(in.GetOrgId()) - - if projectKey != "" { - if err := s.errorForProject[projectKey]; err != nil { - return nil, err - } - } else if orgKey != "" { - if err := s.errorForOrg[orgKey]; err != nil { - return nil, err - } - } - - perms := s.permissions - if projectKey != "" { - if override, ok := s.perProject[projectKey]; ok { - perms = override - } - } else if orgKey != "" { - if override, ok := s.perOrg[orgKey]; ok { - perms = override - } - } - if perms == nil { - perms = []string{} - } - - return &rbacpb.ListUserPermissionsResponse{ - UserId: in.GetUserId(), - OrgId: in.GetOrgId(), - ProjectId: in.GetProjectId(), - Permissions: append([]string(nil), perms...), - }, nil -} - -func normalizeKey(value string) string { - return strings.ToLower(strings.TrimSpace(value)) -} - -func (s *workflowClientStub) Schedule(context.Context, *workflowpb.ScheduleRequest, ...grpc.CallOption) (*workflowpb.ScheduleResponse, error) { - panic("not implemented") -} - -func (s *workflowClientStub) GetPath(context.Context, *workflowpb.GetPathRequest, ...grpc.CallOption) (*workflowpb.GetPathResponse, error) { - panic("not implemented") -} - -func (s *workflowClientStub) List(context.Context, *workflowpb.ListRequest, ...grpc.CallOption) (*workflowpb.ListResponse, error) { - panic("not implemented") -} - -func (s *workflowClientStub) ListKeyset(ctx context.Context, in *workflowpb.ListKeysetRequest, opts ...grpc.CallOption) (*workflowpb.ListKeysetResponse, error) { - s.lastList = in - if s.listErr != nil { - return nil, s.listErr - } - return s.listResp, nil -} - -type userClientStub struct { - userpb.UserServiceClient - response *userpb.User - err error - lastRequest *userpb.DescribeByRepositoryProviderRequest -} - -func (u *userClientStub) DescribeByRepositoryProvider(ctx context.Context, in *userpb.DescribeByRepositoryProviderRequest, opts ...grpc.CallOption) (*userpb.User, error) { - u.lastRequest = in - if u.err != nil { - return nil, u.err - } - if u.response == nil { - u.response = &userpb.User{Id: "ffffffff-ffff-ffff-ffff-ffffffffffff"} - } - return u.response, nil -} - -func (s *workflowClientStub) ListGrouped(context.Context, *workflowpb.ListGroupedRequest, ...grpc.CallOption) (*workflowpb.ListGroupedResponse, error) { - panic("not implemented") -} - -func (s *workflowClientStub) ListGroupedKS(context.Context, *workflowpb.ListGroupedKSRequest, ...grpc.CallOption) (*workflowpb.ListGroupedKSResponse, error) { - panic("not implemented") -} - -func (s *workflowClientStub) ListLatestWorkflows(context.Context, *workflowpb.ListLatestWorkflowsRequest, ...grpc.CallOption) (*workflowpb.ListLatestWorkflowsResponse, error) { - panic("not implemented") -} - -func (s *workflowClientStub) Describe(context.Context, *workflowpb.DescribeRequest, ...grpc.CallOption) (*workflowpb.DescribeResponse, error) { - panic("not implemented") -} - -func (s *workflowClientStub) DescribeMany(context.Context, *workflowpb.DescribeManyRequest, ...grpc.CallOption) (*workflowpb.DescribeManyResponse, error) { - panic("not implemented") -} - -func (s *workflowClientStub) Terminate(context.Context, *workflowpb.TerminateRequest, ...grpc.CallOption) (*workflowpb.TerminateResponse, error) { - panic("not implemented") -} - -func (s *workflowClientStub) ListLabels(context.Context, *workflowpb.ListLabelsRequest, ...grpc.CallOption) (*workflowpb.ListLabelsResponse, error) { - panic("not implemented") -} - -func (s *workflowClientStub) Reschedule(context.Context, *workflowpb.RescheduleRequest, ...grpc.CallOption) (*workflowpb.ScheduleResponse, error) { - panic("not implemented") -} - -func (s *workflowClientStub) GetProjectId(context.Context, *workflowpb.GetProjectIdRequest, ...grpc.CallOption) (*workflowpb.GetProjectIdResponse, error) { - panic("not implemented") -} - -func (s *workflowClientStub) Create(context.Context, *workflowpb.CreateRequest, ...grpc.CallOption) (*workflowpb.CreateResponse, error) { - panic("not implemented") -} - -func toFail(t *testing.T, format string, args ...any) { - t.Helper() - t.Fatalf(format, args...) -} diff --git a/mcp_server/pkg/tools/workflows/test_helpers_test.go b/mcp_server/pkg/tools/workflows/test_helpers_test.go new file mode 100644 index 000000000..cc38108b9 --- /dev/null +++ b/mcp_server/pkg/tools/workflows/test_helpers_test.go @@ -0,0 +1,30 @@ +package workflows + +import ( + "testing" + + "github.com/mark3labs/mcp-go/mcp" +) + +func requireErrorText(t *testing.T, res *mcp.CallToolResult) string { + t.Helper() + if res == nil { + t.Fatalf("expected tool result") + } + if !res.IsError { + t.Fatalf("expected error result, got success") + } + if len(res.Content) == 0 { + t.Fatalf("expected error content") + } + text, ok := res.Content[0].(mcp.TextContent) + if !ok { + t.Fatalf("expected text content, got %T", res.Content[0]) + } + return text.Text +} + +func toFail(t *testing.T, format string, args ...any) { + t.Helper() + t.Fatalf(format, args...) +} diff --git a/mcp_server/test/support/stubs.go b/mcp_server/test/support/stubs.go index f53050069..609172782 100644 --- a/mcp_server/test/support/stubs.go +++ b/mcp_server/test/support/stubs.go @@ -32,15 +32,15 @@ import ( func New() internalapi.Provider { return &provider{ timeout: time.Second, - workflows: &workflowStub{}, + workflows: &WorkflowClientStub{}, organizations: &organizationStub{}, - projects: &projectStub{}, + projects: &ProjectClientStub{}, pipelines: &pipelineStub{}, jobs: &jobStub{}, loghub: &loghubStub{}, loghub2: &loghub2Stub{}, - users: &userStub{}, - rbac: &rbacStub{}, + users: &UserClientStub{}, + rbac: &RBACStub{}, features: &featureStub{}, } } @@ -81,31 +81,6 @@ func (p *provider) RBAC() rbacpb.RBACClient { return p.rbac } func (p *provider) Features() featuresvc.FeatureClient { return p.features } -// --- workflow stub --- - -type workflowStub struct { - workflowpb.WorkflowServiceClient -} - -func (w *workflowStub) ListKeyset(ctx context.Context, in *workflowpb.ListKeysetRequest, opts ...grpc.CallOption) (*workflowpb.ListKeysetResponse, error) { - return &workflowpb.ListKeysetResponse{ - Status: &statuspb.Status{Code: code.Code_OK}, - Workflows: []*workflowpb.WorkflowDetails{ - { - WfId: "wf-local", - InitialPplId: "ppl-local", - ProjectId: orDefault(in.GetProjectId(), "project-local"), - BranchName: "main", - CommitSha: "abcdef0", - CreatedAt: timestamppb.New(time.Unix(1_700_000_000, 0)), - TriggeredBy: workflowpb.TriggeredBy_MANUAL_RUN, - OrganizationId: "org-local", - }, - }, - NextPageToken: "", - }, nil -} - // --- pipeline stub --- type pipelineStub struct { @@ -206,38 +181,6 @@ func (l *loghub2Stub) GenerateToken(ctx context.Context, in *loghub2pb.GenerateT return &loghub2pb.GenerateTokenResponse{Token: "stub-token", Type: loghub2pb.TokenType_PULL}, nil } -// --- rbac stub --- - -type rbacStub struct { - rbacpb.RBACClient - orgIDs []string -} - -func (r *rbacStub) ListAccessibleOrgs(ctx context.Context, in *rbacpb.ListAccessibleOrgsRequest, opts ...grpc.CallOption) (*rbacpb.ListAccessibleOrgsResponse, error) { - ids := r.orgIDs - if len(ids) == 0 { - ids = []string{"org-local"} - } - return &rbacpb.ListAccessibleOrgsResponse{OrgIds: ids}, nil -} - -// --- user stub --- - -type userStub struct { - userpb.UserServiceClient -} - -func (u *userStub) DescribeByRepositoryProvider(ctx context.Context, in *userpb.DescribeByRepositoryProviderRequest, opts ...grpc.CallOption) (*userpb.User, error) { - login := "" - if in != nil && in.GetProvider() != nil { - login = strings.TrimSpace(in.GetProvider().GetLogin()) - } - if login == "" { - login = "stub-user" - } - return &userpb.User{Id: fmt.Sprintf("user-%s", login)}, nil -} - func orDefault(value, fallback string) string { if value != "" { return value @@ -299,44 +242,6 @@ func (o *organizationStub) DescribeMany(ctx context.Context, in *orgpb.DescribeM }, nil } -// --- project stub --- - -type projectStub struct { - projecthubpb.ProjectServiceClient -} - -func (p *projectStub) List(ctx context.Context, in *projecthubpb.ListRequest, opts ...grpc.CallOption) (*projecthubpb.ListResponse, error) { - return &projecthubpb.ListResponse{ - Metadata: &projecthubpb.ResponseMeta{ - Status: &projecthubpb.ResponseMeta_Status{Code: projecthubpb.ResponseMeta_OK}, - }, - Pagination: &projecthubpb.PaginationResponse{ - PageNumber: in.GetPagination().GetPage(), - PageSize: in.GetPagination().GetPageSize(), - TotalEntries: 1, - TotalPages: 1, - }, - Projects: []*projecthubpb.Project{ - { - Metadata: &projecthubpb.Project_Metadata{ - Id: "project-local", - Name: "Example Project", - OrgId: "org-local", - OwnerId: "user-local", - CreatedAt: timestamppb.New(time.Unix(1_700_000_000, 0)), - }, - Spec: &projecthubpb.Project_Spec{ - Repository: &projecthubpb.Project_Spec_Repository{ - Url: "https://github.com/example/project", - DefaultBranch: "main", - PipelineFile: ".semaphore/semaphore.yml", - }, - }, - }, - }, - }, nil -} - // --- features stub --- type featureStub struct { @@ -429,3 +334,307 @@ func (s *FeatureHubServiceStub) CallCount() int { defer s.mu.Unlock() return s.callCount } + +// --- workflow tool focused stubs --- + +// WorkflowClientStub records workflow RPC requests and returns configurable responses. +type WorkflowClientStub struct { + workflowpb.WorkflowServiceClient + + ListResp *workflowpb.ListKeysetResponse + ListErr error + LastList *workflowpb.ListKeysetRequest + ScheduleResp *workflowpb.ScheduleResponse + ScheduleErr error + LastSchedule *workflowpb.ScheduleRequest + RescheduleResp *workflowpb.ScheduleResponse + RescheduleErr error + LastReschedule *workflowpb.RescheduleRequest + GetProjectResp *workflowpb.GetProjectIdResponse + GetProjectErr error + LastGetProjectId *workflowpb.GetProjectIdRequest + DescribeResp *workflowpb.DescribeResponse + DescribeErr error + LastDescribe *workflowpb.DescribeRequest +} + +func (s *WorkflowClientStub) Schedule(ctx context.Context, in *workflowpb.ScheduleRequest, opts ...grpc.CallOption) (*workflowpb.ScheduleResponse, error) { + s.LastSchedule = in + if s.ScheduleErr != nil { + return nil, s.ScheduleErr + } + if s.ScheduleResp != nil { + return s.ScheduleResp, nil + } + return &workflowpb.ScheduleResponse{Status: &statuspb.Status{Code: code.Code_OK}}, nil +} + +func (s *WorkflowClientStub) ListKeyset(ctx context.Context, in *workflowpb.ListKeysetRequest, opts ...grpc.CallOption) (*workflowpb.ListKeysetResponse, error) { + s.LastList = in + if s.ListErr != nil { + return nil, s.ListErr + } + if s.ListResp != nil { + return s.ListResp, nil + } + projectID := "project-local" + if in != nil && strings.TrimSpace(in.GetProjectId()) != "" { + projectID = in.GetProjectId() + } + return &workflowpb.ListKeysetResponse{ + Status: &statuspb.Status{Code: code.Code_OK}, + Workflows: []*workflowpb.WorkflowDetails{ + { + WfId: "wf-local", + InitialPplId: "ppl-local", + ProjectId: projectID, + BranchName: "main", + CommitSha: "abcdef0", + CreatedAt: timestamppb.New(time.Unix(1_700_000_000, 0)), + TriggeredBy: workflowpb.TriggeredBy_MANUAL_RUN, + OrganizationId: "org-local", + }, + }, + NextPageToken: "", + }, nil +} + +func (s *WorkflowClientStub) Reschedule(ctx context.Context, in *workflowpb.RescheduleRequest, opts ...grpc.CallOption) (*workflowpb.ScheduleResponse, error) { + s.LastReschedule = in + if s.RescheduleErr != nil { + return nil, s.RescheduleErr + } + if s.RescheduleResp != nil { + return s.RescheduleResp, nil + } + return &workflowpb.ScheduleResponse{Status: &statuspb.Status{Code: code.Code_OK}, WfId: "wf-rerun", PplId: "ppl-rerun"}, nil +} + +func (s *WorkflowClientStub) GetProjectId(ctx context.Context, in *workflowpb.GetProjectIdRequest, opts ...grpc.CallOption) (*workflowpb.GetProjectIdResponse, error) { + s.LastGetProjectId = in + if s.GetProjectErr != nil { + return nil, s.GetProjectErr + } + if s.GetProjectResp != nil { + return s.GetProjectResp, nil + } + return &workflowpb.GetProjectIdResponse{ + Status: &statuspb.Status{Code: code.Code_OK}, + ProjectId: "project-local", + }, nil +} + +func (s *WorkflowClientStub) Describe(ctx context.Context, in *workflowpb.DescribeRequest, opts ...grpc.CallOption) (*workflowpb.DescribeResponse, error) { + s.LastDescribe = in + if s.DescribeErr != nil { + return nil, s.DescribeErr + } + if s.DescribeResp != nil { + return s.DescribeResp, nil + } + return &workflowpb.DescribeResponse{ + Status: &statuspb.Status{Code: code.Code_OK}, + Workflow: &workflowpb.WorkflowDetails{ + WfId: orDefault(in.GetWfId(), "wf-local"), + ProjectId: "project-local", + OrganizationId: "org-local", + }, + }, nil +} + +// ProjectClientStub records describe requests and returns configurable responses. +type ProjectClientStub struct { + projecthubpb.ProjectServiceClient + + Response *projecthubpb.DescribeResponse + Err error + LastDescribe *projecthubpb.DescribeRequest + ListResponse *projecthubpb.ListResponse + ListErr error +} + +func (s *ProjectClientStub) Describe(ctx context.Context, in *projecthubpb.DescribeRequest, opts ...grpc.CallOption) (*projecthubpb.DescribeResponse, error) { + s.LastDescribe = in + if s.Err != nil { + return nil, s.Err + } + if s.Response != nil { + return s.Response, nil + } + orgID := "org-local" + projectID := "project-local" + if in != nil { + if pid := strings.TrimSpace(in.GetId()); pid != "" { + projectID = pid + } + } + return NewProjectDescribeResponse(orgID, projectID, &projecthubpb.Project_Spec_Repository{ + PipelineFile: ".semaphore/semaphore.yml", + }), nil +} + +func (s *ProjectClientStub) List(ctx context.Context, in *projecthubpb.ListRequest, opts ...grpc.CallOption) (*projecthubpb.ListResponse, error) { + if s.ListErr != nil { + return nil, s.ListErr + } + if s.ListResponse != nil { + return s.ListResponse, nil + } + page := int32(0) + size := int32(0) + if in != nil && in.GetPagination() != nil { + page = in.GetPagination().GetPage() + size = in.GetPagination().GetPageSize() + } + return &projecthubpb.ListResponse{ + Metadata: &projecthubpb.ResponseMeta{ + Status: &projecthubpb.ResponseMeta_Status{Code: projecthubpb.ResponseMeta_OK}, + }, + Pagination: &projecthubpb.PaginationResponse{ + PageNumber: page, + PageSize: size, + TotalEntries: 1, + TotalPages: 1, + }, + Projects: []*projecthubpb.Project{ + { + Metadata: &projecthubpb.Project_Metadata{ + Id: "project-local", + Name: "Example Project", + OrgId: "org-local", + OwnerId: "user-local", + CreatedAt: timestamppb.New(time.Unix(1_700_000_000, 0)), + }, + Spec: &projecthubpb.Project_Spec{ + Repository: &projecthubpb.Project_Spec_Repository{ + Url: "https://github.com/example/project", + DefaultBranch: "main", + PipelineFile: ".semaphore/semaphore.yml", + }, + }, + }, + }, + }, nil +} + +// NewProjectDescribeResponse creates a minimal project describe response for tests. +func NewProjectDescribeResponse(orgID, projectID string, repo *projecthubpb.Project_Spec_Repository) *projecthubpb.DescribeResponse { + if repo == nil { + repo = &projecthubpb.Project_Spec_Repository{} + } + return &projecthubpb.DescribeResponse{ + Metadata: &projecthubpb.ResponseMeta{ + Status: &projecthubpb.ResponseMeta_Status{Code: projecthubpb.ResponseMeta_OK}, + }, + Project: &projecthubpb.Project{ + Metadata: &projecthubpb.Project_Metadata{Id: projectID, OrgId: orgID}, + Spec: &projecthubpb.Project_Spec{Repository: repo}, + }, + } +} + +// NewRBACStub returns an RBAC stub with optional default permissions. +func NewRBACStub(perms ...string) *RBACStub { + copied := append([]string(nil), perms...) + return &RBACStub{Permissions: copied} +} + +// RBACStub records authorization checks and returns configurable results. +type RBACStub struct { + rbacpb.RBACClient + + Permissions []string + PerProject map[string][]string + PerOrg map[string][]string + Err error + ErrorForProject map[string]error + ErrorForOrg map[string]error + LastRequests []*rbacpb.ListUserPermissionsRequest + AccessibleOrgIDs []string +} + +func (s *RBACStub) ListUserPermissions(ctx context.Context, in *rbacpb.ListUserPermissionsRequest, opts ...grpc.CallOption) (*rbacpb.ListUserPermissionsResponse, error) { + reqCopy := &rbacpb.ListUserPermissionsRequest{ + UserId: in.GetUserId(), + OrgId: in.GetOrgId(), + ProjectId: in.GetProjectId(), + } + s.LastRequests = append(s.LastRequests, reqCopy) + + if s.Err != nil { + return nil, s.Err + } + + projectKey := normalizeKey(in.GetProjectId()) + orgKey := normalizeKey(in.GetOrgId()) + + if projectKey != "" { + if err := s.ErrorForProject[projectKey]; err != nil { + return nil, err + } + } else if orgKey != "" { + if err := s.ErrorForOrg[orgKey]; err != nil { + return nil, err + } + } + + perms := s.Permissions + if projectKey != "" { + if override, ok := s.PerProject[projectKey]; ok { + perms = override + } + } else if orgKey != "" { + if override, ok := s.PerOrg[orgKey]; ok { + perms = override + } + } + if perms == nil { + perms = []string{} + } + + return &rbacpb.ListUserPermissionsResponse{ + UserId: in.GetUserId(), + OrgId: in.GetOrgId(), + ProjectId: in.GetProjectId(), + Permissions: append([]string(nil), perms...), + }, nil +} + +func normalizeKey(value string) string { + return strings.ToLower(strings.TrimSpace(value)) +} + +func (s *RBACStub) ListAccessibleOrgs(ctx context.Context, in *rbacpb.ListAccessibleOrgsRequest, opts ...grpc.CallOption) (*rbacpb.ListAccessibleOrgsResponse, error) { + ids := s.AccessibleOrgIDs + if len(ids) == 0 { + ids = []string{"org-local"} + } + return &rbacpb.ListAccessibleOrgsResponse{OrgIds: append([]string(nil), ids...)}, nil +} + +// UserClientStub captures user lookups and returns configurable responses. +type UserClientStub struct { + userpb.UserServiceClient + + Response *userpb.User + Err error + LastRequest *userpb.DescribeByRepositoryProviderRequest +} + +func (u *UserClientStub) DescribeByRepositoryProvider(ctx context.Context, in *userpb.DescribeByRepositoryProviderRequest, opts ...grpc.CallOption) (*userpb.User, error) { + u.LastRequest = in + if u.Err != nil { + return nil, u.Err + } + if u.Response != nil { + return u.Response, nil + } + login := "stub-user" + if in != nil && in.GetProvider() != nil { + candidate := strings.TrimSpace(in.GetProvider().GetLogin()) + if candidate != "" { + login = candidate + } + } + return &userpb.User{Id: fmt.Sprintf("user-%s", login)}, nil +}