Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
92 changes: 92 additions & 0 deletions actions/setup/js/create_pull_request.cjs
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ const { isStagedMode } = require("./safe_output_helpers.cjs");
const { withRetry, isTransientError } = require("./error_recovery.cjs");
const { tryEnforceArrayLimit } = require("./limit_enforcement_helpers.cjs");
const { findAgent, getIssueDetails, assignAgentToIssue } = require("./assign_agent_helpers.cjs");
const { globPatternToRegex } = require("./glob_pattern_helpers.cjs");

/**
* @typedef {import('./types/handler-factory').HandlerFactoryFunction} HandlerFactoryFunction
Expand Down Expand Up @@ -88,6 +89,50 @@ const LABEL_MAX_RETRIES = 3;
/** @type {number} Initial delay in ms before the first label retry (3 seconds) */
const LABEL_INITIAL_DELAY_MS = 3000;

/**
* Parse allowed base branch patterns from config value (array or comma-separated string)
* @param {string[]|string|undefined} allowedBaseBranchesValue
* @returns {Set<string>}
*/
function parseAllowedBaseBranches(allowedBaseBranchesValue) {
const set = new Set();
if (Array.isArray(allowedBaseBranchesValue)) {
allowedBaseBranchesValue
.map(branch => String(branch).trim())
.filter(Boolean)
.forEach(branch => set.add(branch));
} else if (typeof allowedBaseBranchesValue === "string") {
allowedBaseBranchesValue
.split(",")
.map(branch => branch.trim())
.filter(Boolean)
.forEach(branch => set.add(branch));
}
return set;
}

/**
* Check if a base branch matches an allowed pattern.
* Supports exact matches and "*" glob patterns (e.g. "release/*").
* @param {string} baseBranch
* @param {Set<string>} allowedBaseBranches
* @returns {boolean}
*/
function isBaseBranchAllowed(baseBranch, allowedBaseBranches) {
if (allowedBaseBranches.has(baseBranch)) {
return true;
}
for (const pattern of allowedBaseBranches) {
if (pattern === "*") {
return true;
}
if (pattern.includes("*") && globPatternToRegex(pattern, { pathMode: true, caseSensitive: true }).test(baseBranch)) {
return true;
}
}
return false;
}

/**
* Merges the required fallback label with any workflow-configured labels,
* deduplicating and filtering empty values.
Expand Down Expand Up @@ -250,6 +295,7 @@ async function main(config = {}) {
const maxCount = config.max || 1; // PRs are typically limited to 1
const maxSizeKb = config.max_patch_size ? parseInt(String(config.max_patch_size), 10) : 1024;
const { defaultTargetRepo, allowedRepos } = resolveTargetRepoConfig(config);
const allowedBaseBranches = parseAllowedBaseBranches(config.allowed_base_branches);
const githubClient = await createAuthenticatedGitHubClient(config);

// Check if copilot assignment is enabled for fallback issues
Expand Down Expand Up @@ -350,6 +396,9 @@ async function main(config = {}) {
if (allowedRepos.size > 0) {
core.info(`Allowed repos: ${Array.from(allowedRepos).join(", ")}`);
}
if (allowedBaseBranches.size > 0) {
core.info(`Allowed base branches: ${Array.from(allowedBaseBranches).join(", ")}`);
}
if (envLabels.length > 0) {
core.info(`Default labels: ${envLabels.join(", ")}`);
}
Expand Down Expand Up @@ -444,6 +493,49 @@ async function main(config = {}) {
// NOTE: Must be resolved before checkout so cross-repo checkout uses the correct branch
let baseBranch = configBaseBranch || (await getBaseBranch(repoParts));

// Optional agent-provided base branch override.
// This is only allowed when allowed_base_branches is configured.
if (typeof pullRequestItem.base === "string" && pullRequestItem.base.trim() !== "") {
const requestedBaseBranchRaw = pullRequestItem.base.trim();
const requestedBaseBranchForLog = JSON.stringify(requestedBaseBranchRaw);
core.info(`Base branch override requested: ${requestedBaseBranchForLog}`);
if (allowedBaseBranches.size === 0) {
core.warning(`Rejecting base branch override ${requestedBaseBranchForLog}: allowed-base-branches is not configured`);
return {
success: false,
error: "Base branch override is not allowed. Configure safe-outputs.create-pull-request.allowed-base-branches to allow per-run base overrides.",
};
}

const requestedBaseBranch = normalizeBranchName(requestedBaseBranchRaw);
if (!requestedBaseBranch) {
core.warning(`Rejecting base branch override ${requestedBaseBranchForLog}: sanitization resulted in empty branch name`);
return {
success: false,
error: `Invalid base branch override: sanitization resulted in empty string (original: "${requestedBaseBranchRaw}")`,
};
}
if (requestedBaseBranchRaw !== requestedBaseBranch) {
core.warning(`Rejecting base branch override ${requestedBaseBranchForLog}: sanitized value '${requestedBaseBranch}' does not match original`);
return {
success: false,
error: `Invalid base branch override: contains invalid characters (original: "${requestedBaseBranchRaw}", normalized: "${requestedBaseBranch}")`,
};
}
const requestedBaseBranchSafeForLog = JSON.stringify(requestedBaseBranch);
if (!isBaseBranchAllowed(requestedBaseBranch, allowedBaseBranches)) {
core.warning(`Rejecting base branch override ${requestedBaseBranchSafeForLog}: does not match allowed patterns (${Array.from(allowedBaseBranches).join(", ")})`);
return {
success: false,
error: `Base branch override '${requestedBaseBranch}' is not allowed. Allowed patterns: ${Array.from(allowedBaseBranches).join(", ")}`,
};
}

core.info(`Base branch override accepted: ${requestedBaseBranchSafeForLog}`);
baseBranch = requestedBaseBranch;
core.info(`Using agent-provided base branch override: ${baseBranch}`);
}

// Multi-repo support: Switch checkout to target repo if different from current
// This enables creating PRs in multiple repos from a single workflow run
if (checkoutManager && itemRepo) {
Expand Down
108 changes: 108 additions & 0 deletions actions/setup/js/create_pull_request.test.cjs
Original file line number Diff line number Diff line change
Expand Up @@ -1223,6 +1223,114 @@ describe("create_pull_request - wildcard target-repo", () => {
});
});

describe("create_pull_request - base branch override policy", () => {
let tempDir;
let originalEnv;

beforeEach(() => {
originalEnv = { ...process.env };
process.env.GH_AW_WORKFLOW_ID = "test-workflow";
process.env.GITHUB_REPOSITORY = "test-owner/test-repo";
process.env.GITHUB_BASE_REF = "main";
tempDir = fs.mkdtempSync(path.join(os.tmpdir(), "create-pr-base-override-test-"));

global.core = {
info: vi.fn(),
warning: vi.fn(),
error: vi.fn(),
debug: vi.fn(),
setFailed: vi.fn(),
setOutput: vi.fn(),
startGroup: vi.fn(),
endGroup: vi.fn(),
summary: {
addRaw: vi.fn().mockReturnThis(),
write: vi.fn().mockResolvedValue(undefined),
},
};
global.github = {
rest: {
pulls: {
create: vi.fn().mockResolvedValue({ data: { number: 100, html_url: "https://github.com/test-owner/test-repo/pull/100", node_id: "PR_100" } }),
requestReviewers: vi.fn().mockResolvedValue({}),
},
repos: {
get: vi.fn().mockResolvedValue({ data: { default_branch: "main" } }),
},
issues: {
addLabels: vi.fn().mockResolvedValue({}),
},
},
graphql: vi.fn(),
};
global.context = {
eventName: "workflow_dispatch",
repo: { owner: "test-owner", repo: "test-repo" },
payload: {},
};
global.exec = {
exec: vi.fn().mockResolvedValue(0),
getExecOutput: vi.fn().mockResolvedValue({ exitCode: 0, stdout: "", stderr: "" }),
};

delete require.cache[require.resolve("./create_pull_request.cjs")];
});

afterEach(() => {
for (const key of Object.keys(process.env)) {
if (!(key in originalEnv)) {
delete process.env[key];
}
}
Object.assign(process.env, originalEnv);

if (tempDir && fs.existsSync(tempDir)) {
fs.rmSync(tempDir, { recursive: true, force: true });
}

delete global.core;
delete global.github;
delete global.context;
delete global.exec;
vi.clearAllMocks();
});

it("should reject base override when allowed-base-branches is not configured", async () => {
const { main } = require("./create_pull_request.cjs");
const handler = await main({ allow_empty: true });

const result = await handler({ title: "Test PR", body: "Test body", base: "release/1.0" }, {});

expect(result.success).toBe(false);
expect(result.error).toContain("Base branch override is not allowed");
expect(global.core.warning).toHaveBeenCalledWith(expect.stringContaining("allowed-base-branches is not configured"));
});

it("should allow base override when it matches allowed-base-branches", async () => {
const { main } = require("./create_pull_request.cjs");
const handler = await main({ allow_empty: true, allowed_base_branches: ["release/*", "main"] });

const result = await handler({ title: "Test PR", body: "Test body", base: "release/1.0" }, {});

expect(result.success).toBe(true);
expect(global.github.rest.pulls.create).toHaveBeenCalledWith(expect.objectContaining({ base: "release/1.0" }));
expect(global.core.info).toHaveBeenCalledWith(expect.stringContaining('Base branch override requested: "release/1.0"'));
expect(global.core.info).toHaveBeenCalledWith(expect.stringContaining('Base branch override accepted: "release/1.0"'));
expect(global.core.info).toHaveBeenCalledWith(expect.stringContaining("Using agent-provided base branch override: release/1.0"));
});

it("should reject base override when it does not match allowed-base-branches", async () => {
const { main } = require("./create_pull_request.cjs");
const handler = await main({ allow_empty: true, allowed_base_branches: ["release/*"] });

const result = await handler({ title: "Test PR", body: "Test body", base: "main" }, {});

expect(result.success).toBe(false);
expect(result.error).toContain("Base branch override 'main' is not allowed");
expect(global.core.warning).toHaveBeenCalledWith(expect.stringContaining("does not match allowed patterns"));
});
});

describe("create_pull_request - patch apply fallback to original base commit", () => {
let tempDir;
let originalEnv;
Expand Down
4 changes: 4 additions & 0 deletions actions/setup/js/safe_outputs_tools.json
Original file line number Diff line number Diff line change
Expand Up @@ -279,6 +279,10 @@
"type": "string",
"description": "Source branch name containing the changes. If omitted, uses the current working branch."
},
"base": {
"type": "string",
"description": "Target base branch for the pull request. This override is only allowed when workflow configuration sets safe-outputs.create-pull-request.allowed-base-branches and the value matches one of those patterns."
},
"labels": {
"type": "array",
"items": {
Expand Down
2 changes: 2 additions & 0 deletions docs/src/content/docs/reference/safe-outputs-specification.md
Original file line number Diff line number Diff line change
Expand Up @@ -2148,6 +2148,7 @@ safe-outputs:
"title": {"type": "string"},
"body": {"type": "string"},
"branch": {"type": "string", "description": "Source branch (defaults to current)"},
"base": {"type": "string", "description": "Target base branch override (allowed only when configured by allowed-base-branches)"},
"labels": {"type": "array", "items": {"type": "string"}},
"draft": {"type": "boolean", "description": "Create as draft (default: true)"}
},
Expand All @@ -2168,6 +2169,7 @@ safe-outputs:

- `max`: Operation limit (default: 1)
- `base-branch`: Target branch
- `allowed-base-branches`: Allowed base-branch override patterns for per-run `base` tool input
- `draft`: Draft status
- `commit-changes`: Auto-commit workspace
- `reviewers`: Auto-request reviewers
Expand Down
7 changes: 7 additions & 0 deletions pkg/parser/schemas/main_workflow_schema.json
Original file line number Diff line number Diff line change
Expand Up @@ -5838,6 +5838,13 @@
"type": "string",
"description": "Base branch for the pull request. Defaults to the workflow's branch (github.ref_name) if not specified. Useful for cross-repository PRs targeting non-default branches (e.g., 'vnext', 'release/v1.0')."
},
"allowed-base-branches": {
"type": "array",
"items": {
"type": "string"
},
"description": "Optional list of allowed base branch patterns (glob syntax, e.g. 'main', 'release/*'). When configured, the agent may provide a `base` field in create_pull_request output to override base-branch for a single run, but only if it matches one of these patterns."
},
"footer": {
"type": "boolean",
"description": "Controls whether AI-generated footer is added to the pull request. When false, the visible footer content is omitted but XML markers (workflow-id, tracker-id, metadata) are still included for searchability. Defaults to true.",
Expand Down
36 changes: 31 additions & 5 deletions pkg/workflow/compiler_safe_outputs_config_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1384,10 +1384,13 @@ func TestAutoEnabledHandlers(t *testing.T) {
// TestCreatePullRequestBaseBranch tests the base-branch field configuration
func TestCreatePullRequestBaseBranch(t *testing.T) {
tests := []struct {
name string
baseBranch string
expectedBaseBranch string
shouldHaveBaseBranchKey bool
name string
baseBranch string
allowedBaseBranches []string
expectedBaseBranch string
shouldHaveBaseBranchKey bool
expectedAllowedBaseBranches []string
shouldHaveAllowedBaseBranchesKey bool
}{
{
name: "custom base branch",
Expand All @@ -1407,6 +1410,15 @@ func TestCreatePullRequestBaseBranch(t *testing.T) {
expectedBaseBranch: "release/v1.0",
shouldHaveBaseBranchKey: true,
},
{
name: "allowed base branches list",
baseBranch: "main",
allowedBaseBranches: []string{"release/*", "main"},
expectedBaseBranch: "main",
shouldHaveBaseBranchKey: true,
expectedAllowedBaseBranches: []string{"release/*", "main"},
shouldHaveAllowedBaseBranchesKey: true,
},
}

for _, tt := range tests {
Expand All @@ -1420,7 +1432,8 @@ func TestCreatePullRequestBaseBranch(t *testing.T) {
BaseSafeOutputConfig: BaseSafeOutputConfig{
Max: strPtr("1"),
},
BaseBranch: tt.baseBranch,
BaseBranch: tt.baseBranch,
AllowedBaseBranches: tt.allowedBaseBranches,
},
},
}
Expand Down Expand Up @@ -1453,6 +1466,19 @@ func TestCreatePullRequestBaseBranch(t *testing.T) {
} else {
require.False(t, ok, "base_branch should NOT be in config when no custom value set")
}

allowedBaseBranches, ok := prConfig["allowed_base_branches"]
if tt.shouldHaveAllowedBaseBranchesKey {
require.True(t, ok, "allowed_base_branches should be in config")
allowedSlice, ok := allowedBaseBranches.([]any)
require.True(t, ok, "allowed_base_branches should be an array")
require.Len(t, allowedSlice, len(tt.expectedAllowedBaseBranches), "allowed_base_branches length should match")
for i, expected := range tt.expectedAllowedBaseBranches {
assert.Equal(t, expected, allowedSlice[i], "allowed_base_branches element should match")
}
} else {
require.False(t, ok, "allowed_base_branches should NOT be in config when no values set")
}
}
}
}
Expand Down
1 change: 1 addition & 0 deletions pkg/workflow/compiler_safe_outputs_handlers.go
Original file line number Diff line number Diff line change
Expand Up @@ -376,6 +376,7 @@ var handlerRegistry = map[string]handlerBuilder{
AddIfPositive("expires", c.Expires).
AddIfNotEmpty("target-repo", c.TargetRepoSlug).
AddStringSlice("allowed_repos", c.AllowedRepos).
AddStringSlice("allowed_base_branches", c.AllowedBaseBranches).
AddDefault("max_patch_size", maxPatchSize).
AddIfNotEmpty("github-token", c.GitHubToken).
AddTemplatableBool("footer", getEffectiveFooterForTemplatable(c.Footer, cfg.Footer)).
Expand Down
1 change: 1 addition & 0 deletions pkg/workflow/create_pull_request.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ type CreatePullRequestsConfig struct {
AllowEmpty *string `yaml:"allow-empty,omitempty"` // Allow creating PR without patch file or with empty patch (useful for preparing feature branches)
TargetRepoSlug string `yaml:"target-repo,omitempty"` // Target repository in format "owner/repo" for cross-repository pull requests
AllowedRepos []string `yaml:"allowed-repos,omitempty"` // List of additional repositories that pull requests can be created in (additionally to the target-repo)
AllowedBaseBranches []string `yaml:"allowed-base-branches,omitempty"` // List of allowed base branch globs (e.g. "release/*"). Enables agent-provided `base` override when configured.
Expires int `yaml:"expires,omitempty"` // Hours until the pull request expires and should be automatically closed (only for same-repo PRs)
AutoMerge *string `yaml:"auto-merge,omitempty"` // Enable auto-merge for the pull request when all required checks pass
BaseBranch string `yaml:"base-branch,omitempty"` // Base branch for the pull request (defaults to github.ref_name if not specified)
Expand Down
4 changes: 4 additions & 0 deletions pkg/workflow/js/safe_outputs_tools.json
Original file line number Diff line number Diff line change
Expand Up @@ -316,6 +316,10 @@
"type": "string",
"description": "Source branch name containing the changes. If omitted, uses the current working branch."
},
"base": {
"type": "string",
"description": "Target base branch for the pull request. This override is only allowed when workflow configuration sets safe-outputs.create-pull-request.allowed-base-branches and the value matches one of those patterns."
},
"labels": {
"type": "array",
"items": {
Expand Down
16 changes: 16 additions & 0 deletions pkg/workflow/safe_output_validation_config_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -217,3 +217,19 @@ func TestValidationConfigConsistency(t *testing.T) {
}
}
}

func TestCreatePullRequestBaseValidationMaxLength(t *testing.T) {
config, ok := ValidationConfig["create_pull_request"]
if !ok {
t.Fatal("create_pull_request not found in ValidationConfig")
}

baseField, ok := config.Fields["base"]
if !ok {
t.Fatal("base field not found in create_pull_request validation config")
}

if baseField.MaxLength != 128 {
t.Errorf("base field MaxLength = %d, want 128", baseField.MaxLength)
}
}
Loading