diff --git a/actions/setup/js/create_pull_request.cjs b/actions/setup/js/create_pull_request.cjs index 46b50eba205..1e505b9d450 100644 --- a/actions/setup/js/create_pull_request.cjs +++ b/actions/setup/js/create_pull_request.cjs @@ -31,6 +31,7 @@ const { isStagedMode } = require("./safe_output_helpers.cjs"); const { withRetry, isTransientError } = require("./error_recovery.cjs"); const { tryEnforceArrayLimit } = require("./limit_enforcement_helpers.cjs"); const { findAgent, getIssueDetails, assignAgentToIssue } = require("./assign_agent_helpers.cjs"); +const { globPatternToRegex } = require("./glob_pattern_helpers.cjs"); /** * @typedef {import('./types/handler-factory').HandlerFactoryFunction} HandlerFactoryFunction @@ -88,6 +89,50 @@ const LABEL_MAX_RETRIES = 3; /** @type {number} Initial delay in ms before the first label retry (3 seconds) */ const LABEL_INITIAL_DELAY_MS = 3000; +/** + * Parse allowed base branch patterns from config value (array or comma-separated string) + * @param {string[]|string|undefined} allowedBaseBranchesValue + * @returns {Set} + */ +function parseAllowedBaseBranches(allowedBaseBranchesValue) { + const set = new Set(); + if (Array.isArray(allowedBaseBranchesValue)) { + allowedBaseBranchesValue + .map(branch => String(branch).trim()) + .filter(Boolean) + .forEach(branch => set.add(branch)); + } else if (typeof allowedBaseBranchesValue === "string") { + allowedBaseBranchesValue + .split(",") + .map(branch => branch.trim()) + .filter(Boolean) + .forEach(branch => set.add(branch)); + } + return set; +} + +/** + * Check if a base branch matches an allowed pattern. + * Supports exact matches and "*" glob patterns (e.g. "release/*"). + * @param {string} baseBranch + * @param {Set} allowedBaseBranches + * @returns {boolean} + */ +function isBaseBranchAllowed(baseBranch, allowedBaseBranches) { + if (allowedBaseBranches.has(baseBranch)) { + return true; + } + for (const pattern of allowedBaseBranches) { + if (pattern === "*") { + return true; + } + if (pattern.includes("*") && globPatternToRegex(pattern, { pathMode: true, caseSensitive: true }).test(baseBranch)) { + return true; + } + } + return false; +} + /** * Merges the required fallback label with any workflow-configured labels, * deduplicating and filtering empty values. @@ -250,6 +295,7 @@ async function main(config = {}) { const maxCount = config.max || 1; // PRs are typically limited to 1 const maxSizeKb = config.max_patch_size ? parseInt(String(config.max_patch_size), 10) : 1024; const { defaultTargetRepo, allowedRepos } = resolveTargetRepoConfig(config); + const allowedBaseBranches = parseAllowedBaseBranches(config.allowed_base_branches); const githubClient = await createAuthenticatedGitHubClient(config); // Check if copilot assignment is enabled for fallback issues @@ -350,6 +396,9 @@ async function main(config = {}) { if (allowedRepos.size > 0) { core.info(`Allowed repos: ${Array.from(allowedRepos).join(", ")}`); } + if (allowedBaseBranches.size > 0) { + core.info(`Allowed base branches: ${Array.from(allowedBaseBranches).join(", ")}`); + } if (envLabels.length > 0) { core.info(`Default labels: ${envLabels.join(", ")}`); } @@ -444,6 +493,49 @@ async function main(config = {}) { // NOTE: Must be resolved before checkout so cross-repo checkout uses the correct branch let baseBranch = configBaseBranch || (await getBaseBranch(repoParts)); + // Optional agent-provided base branch override. + // This is only allowed when allowed_base_branches is configured. + if (typeof pullRequestItem.base === "string" && pullRequestItem.base.trim() !== "") { + const requestedBaseBranchRaw = pullRequestItem.base.trim(); + const requestedBaseBranchForLog = JSON.stringify(requestedBaseBranchRaw); + core.info(`Base branch override requested: ${requestedBaseBranchForLog}`); + if (allowedBaseBranches.size === 0) { + core.warning(`Rejecting base branch override ${requestedBaseBranchForLog}: allowed-base-branches is not configured`); + return { + success: false, + error: "Base branch override is not allowed. Configure safe-outputs.create-pull-request.allowed-base-branches to allow per-run base overrides.", + }; + } + + const requestedBaseBranch = normalizeBranchName(requestedBaseBranchRaw); + if (!requestedBaseBranch) { + core.warning(`Rejecting base branch override ${requestedBaseBranchForLog}: sanitization resulted in empty branch name`); + return { + success: false, + error: `Invalid base branch override: sanitization resulted in empty string (original: "${requestedBaseBranchRaw}")`, + }; + } + if (requestedBaseBranchRaw !== requestedBaseBranch) { + core.warning(`Rejecting base branch override ${requestedBaseBranchForLog}: sanitized value '${requestedBaseBranch}' does not match original`); + return { + success: false, + error: `Invalid base branch override: contains invalid characters (original: "${requestedBaseBranchRaw}", normalized: "${requestedBaseBranch}")`, + }; + } + const requestedBaseBranchSafeForLog = JSON.stringify(requestedBaseBranch); + if (!isBaseBranchAllowed(requestedBaseBranch, allowedBaseBranches)) { + core.warning(`Rejecting base branch override ${requestedBaseBranchSafeForLog}: does not match allowed patterns (${Array.from(allowedBaseBranches).join(", ")})`); + return { + success: false, + error: `Base branch override '${requestedBaseBranch}' is not allowed. Allowed patterns: ${Array.from(allowedBaseBranches).join(", ")}`, + }; + } + + core.info(`Base branch override accepted: ${requestedBaseBranchSafeForLog}`); + baseBranch = requestedBaseBranch; + core.info(`Using agent-provided base branch override: ${baseBranch}`); + } + // Multi-repo support: Switch checkout to target repo if different from current // This enables creating PRs in multiple repos from a single workflow run if (checkoutManager && itemRepo) { diff --git a/actions/setup/js/create_pull_request.test.cjs b/actions/setup/js/create_pull_request.test.cjs index e6840627545..54e9f634b95 100644 --- a/actions/setup/js/create_pull_request.test.cjs +++ b/actions/setup/js/create_pull_request.test.cjs @@ -1223,6 +1223,114 @@ describe("create_pull_request - wildcard target-repo", () => { }); }); +describe("create_pull_request - base branch override policy", () => { + let tempDir; + let originalEnv; + + beforeEach(() => { + originalEnv = { ...process.env }; + process.env.GH_AW_WORKFLOW_ID = "test-workflow"; + process.env.GITHUB_REPOSITORY = "test-owner/test-repo"; + process.env.GITHUB_BASE_REF = "main"; + tempDir = fs.mkdtempSync(path.join(os.tmpdir(), "create-pr-base-override-test-")); + + global.core = { + info: vi.fn(), + warning: vi.fn(), + error: vi.fn(), + debug: vi.fn(), + setFailed: vi.fn(), + setOutput: vi.fn(), + startGroup: vi.fn(), + endGroup: vi.fn(), + summary: { + addRaw: vi.fn().mockReturnThis(), + write: vi.fn().mockResolvedValue(undefined), + }, + }; + global.github = { + rest: { + pulls: { + create: vi.fn().mockResolvedValue({ data: { number: 100, html_url: "https://github.com/test-owner/test-repo/pull/100", node_id: "PR_100" } }), + requestReviewers: vi.fn().mockResolvedValue({}), + }, + repos: { + get: vi.fn().mockResolvedValue({ data: { default_branch: "main" } }), + }, + issues: { + addLabels: vi.fn().mockResolvedValue({}), + }, + }, + graphql: vi.fn(), + }; + global.context = { + eventName: "workflow_dispatch", + repo: { owner: "test-owner", repo: "test-repo" }, + payload: {}, + }; + global.exec = { + exec: vi.fn().mockResolvedValue(0), + getExecOutput: vi.fn().mockResolvedValue({ exitCode: 0, stdout: "", stderr: "" }), + }; + + delete require.cache[require.resolve("./create_pull_request.cjs")]; + }); + + afterEach(() => { + for (const key of Object.keys(process.env)) { + if (!(key in originalEnv)) { + delete process.env[key]; + } + } + Object.assign(process.env, originalEnv); + + if (tempDir && fs.existsSync(tempDir)) { + fs.rmSync(tempDir, { recursive: true, force: true }); + } + + delete global.core; + delete global.github; + delete global.context; + delete global.exec; + vi.clearAllMocks(); + }); + + it("should reject base override when allowed-base-branches is not configured", async () => { + const { main } = require("./create_pull_request.cjs"); + const handler = await main({ allow_empty: true }); + + const result = await handler({ title: "Test PR", body: "Test body", base: "release/1.0" }, {}); + + expect(result.success).toBe(false); + expect(result.error).toContain("Base branch override is not allowed"); + expect(global.core.warning).toHaveBeenCalledWith(expect.stringContaining("allowed-base-branches is not configured")); + }); + + it("should allow base override when it matches allowed-base-branches", async () => { + const { main } = require("./create_pull_request.cjs"); + const handler = await main({ allow_empty: true, allowed_base_branches: ["release/*", "main"] }); + + const result = await handler({ title: "Test PR", body: "Test body", base: "release/1.0" }, {}); + + expect(result.success).toBe(true); + expect(global.github.rest.pulls.create).toHaveBeenCalledWith(expect.objectContaining({ base: "release/1.0" })); + expect(global.core.info).toHaveBeenCalledWith(expect.stringContaining('Base branch override requested: "release/1.0"')); + expect(global.core.info).toHaveBeenCalledWith(expect.stringContaining('Base branch override accepted: "release/1.0"')); + expect(global.core.info).toHaveBeenCalledWith(expect.stringContaining("Using agent-provided base branch override: release/1.0")); + }); + + it("should reject base override when it does not match allowed-base-branches", async () => { + const { main } = require("./create_pull_request.cjs"); + const handler = await main({ allow_empty: true, allowed_base_branches: ["release/*"] }); + + const result = await handler({ title: "Test PR", body: "Test body", base: "main" }, {}); + + expect(result.success).toBe(false); + expect(result.error).toContain("Base branch override 'main' is not allowed"); + expect(global.core.warning).toHaveBeenCalledWith(expect.stringContaining("does not match allowed patterns")); + }); +}); + describe("create_pull_request - patch apply fallback to original base commit", () => { let tempDir; let originalEnv; diff --git a/actions/setup/js/safe_outputs_tools.json b/actions/setup/js/safe_outputs_tools.json index e914f7052b5..3d87db2ed1d 100644 --- a/actions/setup/js/safe_outputs_tools.json +++ b/actions/setup/js/safe_outputs_tools.json @@ -279,6 +279,10 @@ "type": "string", "description": "Source branch name containing the changes. If omitted, uses the current working branch." }, + "base": { + "type": "string", + "description": "Target base branch for the pull request. This override is only allowed when workflow configuration sets safe-outputs.create-pull-request.allowed-base-branches and the value matches one of those patterns." + }, "labels": { "type": "array", "items": { diff --git a/docs/src/content/docs/reference/safe-outputs-specification.md b/docs/src/content/docs/reference/safe-outputs-specification.md index fe623015ba5..9ce93b58a59 100644 --- a/docs/src/content/docs/reference/safe-outputs-specification.md +++ b/docs/src/content/docs/reference/safe-outputs-specification.md @@ -2148,6 +2148,7 @@ safe-outputs: "title": {"type": "string"}, "body": {"type": "string"}, "branch": {"type": "string", "description": "Source branch (defaults to current)"}, + "base": {"type": "string", "description": "Target base branch override (allowed only when configured by allowed-base-branches)"}, "labels": {"type": "array", "items": {"type": "string"}}, "draft": {"type": "boolean", "description": "Create as draft (default: true)"} }, @@ -2168,6 +2169,7 @@ safe-outputs: - `max`: Operation limit (default: 1) - `base-branch`: Target branch +- `allowed-base-branches`: Allowed base-branch override patterns for per-run `base` tool input - `draft`: Draft status - `commit-changes`: Auto-commit workspace - `reviewers`: Auto-request reviewers diff --git a/pkg/parser/schemas/main_workflow_schema.json b/pkg/parser/schemas/main_workflow_schema.json index 42abef8ec10..1c03ef1d411 100644 --- a/pkg/parser/schemas/main_workflow_schema.json +++ b/pkg/parser/schemas/main_workflow_schema.json @@ -5838,6 +5838,13 @@ "type": "string", "description": "Base branch for the pull request. Defaults to the workflow's branch (github.ref_name) if not specified. Useful for cross-repository PRs targeting non-default branches (e.g., 'vnext', 'release/v1.0')." }, + "allowed-base-branches": { + "type": "array", + "items": { + "type": "string" + }, + "description": "Optional list of allowed base branch patterns (glob syntax, e.g. 'main', 'release/*'). When configured, the agent may provide a `base` field in create_pull_request output to override base-branch for a single run, but only if it matches one of these patterns." + }, "footer": { "type": "boolean", "description": "Controls whether AI-generated footer is added to the pull request. When false, the visible footer content is omitted but XML markers (workflow-id, tracker-id, metadata) are still included for searchability. Defaults to true.", diff --git a/pkg/workflow/compiler_safe_outputs_config_test.go b/pkg/workflow/compiler_safe_outputs_config_test.go index 6c59b462b6f..ad796125ee1 100644 --- a/pkg/workflow/compiler_safe_outputs_config_test.go +++ b/pkg/workflow/compiler_safe_outputs_config_test.go @@ -1384,10 +1384,13 @@ func TestAutoEnabledHandlers(t *testing.T) { // TestCreatePullRequestBaseBranch tests the base-branch field configuration func TestCreatePullRequestBaseBranch(t *testing.T) { tests := []struct { - name string - baseBranch string - expectedBaseBranch string - shouldHaveBaseBranchKey bool + name string + baseBranch string + allowedBaseBranches []string + expectedBaseBranch string + shouldHaveBaseBranchKey bool + expectedAllowedBaseBranches []string + shouldHaveAllowedBaseBranchesKey bool }{ { name: "custom base branch", @@ -1407,6 +1410,15 @@ func TestCreatePullRequestBaseBranch(t *testing.T) { expectedBaseBranch: "release/v1.0", shouldHaveBaseBranchKey: true, }, + { + name: "allowed base branches list", + baseBranch: "main", + allowedBaseBranches: []string{"release/*", "main"}, + expectedBaseBranch: "main", + shouldHaveBaseBranchKey: true, + expectedAllowedBaseBranches: []string{"release/*", "main"}, + shouldHaveAllowedBaseBranchesKey: true, + }, } for _, tt := range tests { @@ -1420,7 +1432,8 @@ func TestCreatePullRequestBaseBranch(t *testing.T) { BaseSafeOutputConfig: BaseSafeOutputConfig{ Max: strPtr("1"), }, - BaseBranch: tt.baseBranch, + BaseBranch: tt.baseBranch, + AllowedBaseBranches: tt.allowedBaseBranches, }, }, } @@ -1453,6 +1466,19 @@ func TestCreatePullRequestBaseBranch(t *testing.T) { } else { require.False(t, ok, "base_branch should NOT be in config when no custom value set") } + + allowedBaseBranches, ok := prConfig["allowed_base_branches"] + if tt.shouldHaveAllowedBaseBranchesKey { + require.True(t, ok, "allowed_base_branches should be in config") + allowedSlice, ok := allowedBaseBranches.([]any) + require.True(t, ok, "allowed_base_branches should be an array") + require.Len(t, allowedSlice, len(tt.expectedAllowedBaseBranches), "allowed_base_branches length should match") + for i, expected := range tt.expectedAllowedBaseBranches { + assert.Equal(t, expected, allowedSlice[i], "allowed_base_branches element should match") + } + } else { + require.False(t, ok, "allowed_base_branches should NOT be in config when no values set") + } } } } diff --git a/pkg/workflow/compiler_safe_outputs_handlers.go b/pkg/workflow/compiler_safe_outputs_handlers.go index a6ec04e4cb0..2b91e93edd1 100644 --- a/pkg/workflow/compiler_safe_outputs_handlers.go +++ b/pkg/workflow/compiler_safe_outputs_handlers.go @@ -376,6 +376,7 @@ var handlerRegistry = map[string]handlerBuilder{ AddIfPositive("expires", c.Expires). AddIfNotEmpty("target-repo", c.TargetRepoSlug). AddStringSlice("allowed_repos", c.AllowedRepos). + AddStringSlice("allowed_base_branches", c.AllowedBaseBranches). AddDefault("max_patch_size", maxPatchSize). AddIfNotEmpty("github-token", c.GitHubToken). AddTemplatableBool("footer", getEffectiveFooterForTemplatable(c.Footer, cfg.Footer)). diff --git a/pkg/workflow/create_pull_request.go b/pkg/workflow/create_pull_request.go index 270da2522b4..caf8b1371a0 100644 --- a/pkg/workflow/create_pull_request.go +++ b/pkg/workflow/create_pull_request.go @@ -27,6 +27,7 @@ type CreatePullRequestsConfig struct { AllowEmpty *string `yaml:"allow-empty,omitempty"` // Allow creating PR without patch file or with empty patch (useful for preparing feature branches) TargetRepoSlug string `yaml:"target-repo,omitempty"` // Target repository in format "owner/repo" for cross-repository pull requests AllowedRepos []string `yaml:"allowed-repos,omitempty"` // List of additional repositories that pull requests can be created in (additionally to the target-repo) + AllowedBaseBranches []string `yaml:"allowed-base-branches,omitempty"` // List of allowed base branch globs (e.g. "release/*"). Enables agent-provided `base` override when configured. Expires int `yaml:"expires,omitempty"` // Hours until the pull request expires and should be automatically closed (only for same-repo PRs) AutoMerge *string `yaml:"auto-merge,omitempty"` // Enable auto-merge for the pull request when all required checks pass BaseBranch string `yaml:"base-branch,omitempty"` // Base branch for the pull request (defaults to github.ref_name if not specified) diff --git a/pkg/workflow/js/safe_outputs_tools.json b/pkg/workflow/js/safe_outputs_tools.json index 02bb7c0b5a9..cfdcb698220 100644 --- a/pkg/workflow/js/safe_outputs_tools.json +++ b/pkg/workflow/js/safe_outputs_tools.json @@ -316,6 +316,10 @@ "type": "string", "description": "Source branch name containing the changes. If omitted, uses the current working branch." }, + "base": { + "type": "string", + "description": "Target base branch for the pull request. This override is only allowed when workflow configuration sets safe-outputs.create-pull-request.allowed-base-branches and the value matches one of those patterns." + }, "labels": { "type": "array", "items": { diff --git a/pkg/workflow/safe_output_validation_config_test.go b/pkg/workflow/safe_output_validation_config_test.go index b0a0ca16aa2..6e08c8bf6a0 100644 --- a/pkg/workflow/safe_output_validation_config_test.go +++ b/pkg/workflow/safe_output_validation_config_test.go @@ -217,3 +217,19 @@ func TestValidationConfigConsistency(t *testing.T) { } } } + +func TestCreatePullRequestBaseValidationMaxLength(t *testing.T) { + config, ok := ValidationConfig["create_pull_request"] + if !ok { + t.Fatal("create_pull_request not found in ValidationConfig") + } + + baseField, ok := config.Fields["base"] + if !ok { + t.Fatal("base field not found in create_pull_request validation config") + } + + if baseField.MaxLength != 128 { + t.Errorf("base field MaxLength = %d, want 128", baseField.MaxLength) + } +} diff --git a/pkg/workflow/safe_outputs_validation_config.go b/pkg/workflow/safe_outputs_validation_config.go index e13a3871feb..dfed0a7b1c4 100644 --- a/pkg/workflow/safe_outputs_validation_config.go +++ b/pkg/workflow/safe_outputs_validation_config.go @@ -76,6 +76,7 @@ var ValidationConfig = map[string]TypeValidationConfig{ "title": {Required: true, Type: "string", Sanitize: true, MaxLength: 128}, "body": {Required: true, Type: "string", Sanitize: true, MaxLength: MaxBodyLength}, "branch": {Required: true, Type: "string", Sanitize: true, MaxLength: 256}, + "base": {Type: "string", Sanitize: true, MaxLength: 128}, "labels": {Type: "array", ItemType: "string", ItemSanitize: true, ItemMaxLength: 128}, "draft": {Type: "boolean"}, "repo": {Type: "string", MaxLength: 256}, // Optional: target repository in format "owner/repo"