From 615d448d8f5f23fdcf4d8c007c94f974b60ff8c5 Mon Sep 17 00:00:00 2001 From: kmontemayor Date: Fri, 27 Feb 2026 03:37:45 +0000 Subject: [PATCH 01/22] Try to parallelize the unit tests --- .github/actions/run-command-on-pr/action.yml | 30 +++- .../run-sharded-cloud-build/action.yml | 143 ++++++++++++++++++ .github/workflows/on-pr-comment.yml | 25 +-- .github/workflows/on-pr-merge.yml | 26 ++-- Makefile | 21 +++ gigl/common/utils/test_utils.py | 106 +++++++++++-- tests/integration/main.py | 15 +- tests/unit/common/utils/test_sharding_test.py | 126 +++++++++++++++ tests/unit/main.py | 15 +- 9 files changed, 468 insertions(+), 39 deletions(-) create mode 100644 .github/actions/run-sharded-cloud-build/action.yml create mode 100644 tests/unit/common/utils/test_sharding_test.py diff --git a/.github/actions/run-command-on-pr/action.yml b/.github/actions/run-command-on-pr/action.yml index c20bdb4d7..5f8781bb0 100644 --- a/.github/actions/run-command-on-pr/action.yml +++ b/.github/actions/run-command-on-pr/action.yml @@ -48,6 +48,24 @@ inputs: required: false default: "" + # Parameters for sharded Cloud Build execution + use_sharded_cloud_build: + description: "Whether to use sharded Cloud Build instead of a single Cloud Build job" + required: false + default: "false" + total_shards: + description: "Number of parallel test shards (only used when use_sharded_cloud_build is true)" + required: false + default: "4" + sharded_test_command: + description: "The Make target to run for each shard (e.g. 'unit_test_py_shard')" + required: false + default: "" + run_type_check: + description: "Whether to launch a separate type-check Cloud Build job (only used when use_sharded_cloud_build is true)" + required: false + default: "false" + runs: using: "composite" steps: @@ -92,13 +110,23 @@ runs: run: ${{ inputs.command }} - name: Run specified command on the PR branch using Cloud Run - if: ${{ inputs.use_cloud_run == 'true' }} + if: ${{ inputs.use_cloud_run == 'true' && inputs.use_sharded_cloud_build != 'true' }} uses: snapchat/gigl/.github/actions/run-cloud-run-command-on-active-checkout@main with: cmd: ${{ inputs.command }} service_account: ${{ inputs.gcp_service_account_email }} project: ${{ inputs.gcp_project_id }} + - name: Run sharded Cloud Build tests + if: ${{ inputs.use_sharded_cloud_build == 'true' }} + uses: ./.github/actions/run-sharded-cloud-build + with: + test_command: ${{ inputs.sharded_test_command }} + total_shards: ${{ inputs.total_shards }} + run_type_check: ${{ inputs.run_type_check }} + service_account: ${{ inputs.gcp_service_account_email }} + project: ${{ inputs.gcp_project_id }} + - name: Commment workflow succeeded if: ${{ inputs.should_leave_progress_comments == 'true' }} uses: snapchat/gigl/.github/actions/comment-on-pr@main diff --git a/.github/actions/run-sharded-cloud-build/action.yml b/.github/actions/run-sharded-cloud-build/action.yml new file mode 100644 index 000000000..ec5b2c511 --- /dev/null +++ b/.github/actions/run-sharded-cloud-build/action.yml @@ -0,0 +1,143 @@ +name: "Run Sharded Cloud Build" +description: "Launches N Cloud Build shard jobs (and optionally a type-check job) in parallel, then waits for all to complete." +inputs: + test_command: + description: "The Make target to run for each shard (e.g. 'unit_test_py_shard')" + required: true + total_shards: + description: "Number of parallel test shards" + required: true + run_type_check: + description: "Whether to launch a separate type-check Cloud Build job" + required: false + default: "false" + service_account: + description: "Service account email for Cloud Build" + required: true + project: + description: "Google Cloud Project ID" + required: true + machine_type: + description: "Machine type for Cloud Build jobs" + default: "e2-highcpu-32" + timeout: + description: "Timeout for each Cloud Build job (duration format, e.g. '3h')" + default: "3h" + +runs: + using: "composite" + steps: + - name: Launch sharded Cloud Build jobs + id: launch_builds + shell: bash + run: | + set -euo pipefail + + BUILD_IDS="" + BUILD_LABELS="" + TOTAL_SHARDS=${{ inputs.total_shards }} + + # Launch type-check job if enabled + if [[ "${{ inputs.run_type_check }}" == "true" ]]; then + echo "Launching type-check Cloud Build job..." + TYPE_CHECK_BUILD_ID=$(gcloud builds submit . \ + --config=.github/cloud_builder/run_command_on_active_checkout.yaml \ + --substitutions=_CMD="make type_check_only" \ + --service-account="projects/${{ inputs.project }}/serviceAccounts/${{ inputs.service_account }}" \ + --project="${{ inputs.project }}" \ + --machine-type="${{ inputs.machine_type }}" \ + --timeout="${{ inputs.timeout }}" \ + --async \ + --format='value(id)' 2>&1 | tail -1) + BUILD_IDS="${TYPE_CHECK_BUILD_ID}" + BUILD_LABELS="type_check" + echo " Type-check build ID: ${TYPE_CHECK_BUILD_ID}" + fi + + # Launch shard jobs + for i in $(seq 0 $(( TOTAL_SHARDS - 1 ))); do + echo "Launching shard ${i}/${TOTAL_SHARDS} Cloud Build job..." + SHARD_BUILD_ID=$(gcloud builds submit . \ + --config=.github/cloud_builder/run_command_on_active_checkout.yaml \ + --substitutions=_CMD="make ${{ inputs.test_command }} SHARD_INDEX=${i} TOTAL_SHARDS=${TOTAL_SHARDS}" \ + --service-account="projects/${{ inputs.project }}/serviceAccounts/${{ inputs.service_account }}" \ + --project="${{ inputs.project }}" \ + --machine-type="${{ inputs.machine_type }}" \ + --timeout="${{ inputs.timeout }}" \ + --async \ + --format='value(id)' 2>&1 | tail -1) + if [[ -n "${BUILD_IDS}" ]]; then + BUILD_IDS="${BUILD_IDS},${SHARD_BUILD_ID}" + BUILD_LABELS="${BUILD_LABELS},shard_${i}" + else + BUILD_IDS="${SHARD_BUILD_ID}" + BUILD_LABELS="shard_${i}" + fi + echo " Shard ${i} build ID: ${SHARD_BUILD_ID}" + done + + echo "build_ids=${BUILD_IDS}" >> $GITHUB_OUTPUT + echo "build_labels=${BUILD_LABELS}" >> $GITHUB_OUTPUT + echo "All Cloud Build jobs launched." + + - name: Wait for all Cloud Build jobs to complete + shell: bash + run: | + set -euo pipefail + + IFS=',' read -ra BUILD_IDS <<< "${{ steps.launch_builds.outputs.build_ids }}" + IFS=',' read -ra BUILD_LABELS <<< "${{ steps.launch_builds.outputs.build_labels }}" + + PIDS=() + TEMP_DIR=$(mktemp -d) + + # Stream logs for each build in background and record exit status + for idx in "${!BUILD_IDS[@]}"; do + BUILD_ID="${BUILD_IDS[$idx]}" + LABEL="${BUILD_LABELS[$idx]}" + ( + echo "=== Streaming logs for ${LABEL} (${BUILD_ID}) ===" + gcloud builds log --stream "${BUILD_ID}" --project="${{ inputs.project }}" 2>&1 | \ + sed "s/^/[${LABEL}] /" || true + + # Check final build status + STATUS=$(gcloud builds describe "${BUILD_ID}" \ + --project="${{ inputs.project }}" \ + --format='value(status)') + echo "=== ${LABEL} (${BUILD_ID}) finished with status: ${STATUS} ===" + if [[ "${STATUS}" != "SUCCESS" ]]; then + echo "FAILED" > "${TEMP_DIR}/${LABEL}" + else + echo "SUCCESS" > "${TEMP_DIR}/${LABEL}" + fi + ) & + PIDS+=($!) + done + + # Wait for all background processes + for PID in "${PIDS[@]}"; do + wait "${PID}" || true + done + + # Check results + FAILED=0 + for idx in "${!BUILD_LABELS[@]}"; do + LABEL="${BUILD_LABELS[$idx]}" + BUILD_ID="${BUILD_IDS[$idx]}" + RESULT=$(cat "${TEMP_DIR}/${LABEL}" 2>/dev/null || echo "UNKNOWN") + if [[ "${RESULT}" != "SUCCESS" ]]; then + echo "FAILED: ${LABEL} (${BUILD_ID}) - status: ${RESULT}" + FAILED=1 + else + echo "PASSED: ${LABEL} (${BUILD_ID})" + fi + done + + rm -rf "${TEMP_DIR}" + + if [[ "${FAILED}" -eq 1 ]]; then + echo "One or more Cloud Build jobs failed." + exit 1 + fi + + echo "All Cloud Build jobs completed successfully." diff --git a/.github/workflows/on-pr-comment.yml b/.github/workflows/on-pr-comment.yml index 870c415fe..49a7b9d68 100644 --- a/.github/workflows/on-pr-comment.yml +++ b/.github/workflows/on-pr-comment.yml @@ -3,6 +3,7 @@ name: On Demand Pr Comment Workflows on: issue_comment: types: [created] + push: permissions: # Needed for gcloud auth: https://github.com/google-github-actions/auth @@ -42,11 +43,10 @@ jobs: unit-test-python: if: ${{ github.event.issue.pull_request && (contains(github.event.comment.body, '/unit_test_py') || endsWith(github.event.comment.body, '/unit_test') || contains(github.event.comment.body, '/all_test')) }} runs-on: ubuntu-latest - # TODO(kmonte): Reduce this :( - timeout-minutes: 120 + timeout-minutes: 60 steps: - name: Run Python Unit Tests - uses: snapchat/gigl/.github/actions/run-command-on-pr@main + uses: ./.github/actions/run-sharded-cloud-build with: github-token: ${{ secrets.GITHUB_TOKEN }} pr_number: ${{ github.event.issue.number }} @@ -57,12 +57,13 @@ jobs: # using GFile library (a.k.a anything that does IO w/ Tensorflow). GFile does not understand # how to leverage Workload Identity Federation to read assets from GCS, et al. See: # https://github.com/tensorflow/tensorflow/issues/57104 - use_cloud_run: "true" gcp_project_id: ${{ vars.GCP_PROJECT_ID }} workload_identity_provider: ${{ secrets.WORKLOAD_IDENTITY_PROVIDER }} gcp_service_account_email: ${{ secrets.GCP_SERVICE_ACCOUNT_EMAIL }} - command: | - make unit_test_py + use_sharded_cloud_build: "true" + total_shards: "4" + sharded_test_command: "unit_test_py_shard" + run_type_check: "true" unit-test-scala: if: ${{ github.event.issue.pull_request && (contains(github.event.comment.body, '/unit_test_scala') || endsWith(github.event.comment.body, '/unit_test') || contains(github.event.comment.body, '/all_test')) }} @@ -87,23 +88,23 @@ jobs: integration-test: if: ${{ github.event.issue.pull_request && (contains(github.event.comment.body, '/integration_test') || contains(github.event.comment.body, '/all_test')) }} runs-on: ubuntu-latest - # TODO(kmonte): Reduce this :( - timeout-minutes: 120 + timeout-minutes: 60 steps: - name: Run Integration Tests - uses: snapchat/gigl/.github/actions/run-command-on-pr@main + uses: ./.github/actions/run-sharded-cloud-build with: github-token: ${{ secrets.GITHUB_TOKEN }} pr_number: ${{ github.event.issue.number }} should_leave_progress_comments: "true" descriptive_workflow_name: "Integration Test" setup_gcloud: "true" - use_cloud_run: "true" gcp_project_id: ${{ vars.GCP_PROJECT_ID }} workload_identity_provider: ${{ secrets.WORKLOAD_IDENTITY_PROVIDER }} gcp_service_account_email: ${{ secrets.GCP_SERVICE_ACCOUNT_EMAIL }} - command: | - make integration_test + use_sharded_cloud_build: "true" + total_shards: "4" + sharded_test_command: "integration_test_shard" + run_type_check: "false" integration-e2e-test: if: ${{ github.event.issue.pull_request && (contains(github.event.comment.body, '/e2e_test') || contains(github.event.comment.body, '/all_test')) }} diff --git a/.github/workflows/on-pr-merge.yml b/.github/workflows/on-pr-merge.yml index 0e1f9ddd0..02e7e75db 100644 --- a/.github/workflows/on-pr-merge.yml +++ b/.github/workflows/on-pr-merge.yml @@ -32,16 +32,18 @@ jobs: gcp_project_id: ${{ vars.GCP_PROJECT_ID }} workload_identity_provider: ${{ secrets.workload_identity_provider }} gcp_service_account_email: ${{ secrets.gcp_service_account_email }} - - name: Run Python Unit Tests - # We use cloud run here instead of using github hosted runners because of limitation of tests + - name: Run Python Unit Tests (sharded) + # We use Cloud Build instead of GitHub hosted runners because of limitation of tests # using GFile library (a.k.a anything that does IO w/ Tensorflow). GFile does not understand # how to leverage Workload Identity Federation to read assets from GCS, et al. See: # https://github.com/tensorflow/tensorflow/issues/57104 - uses: ./.github/actions/run-cloud-run-command-on-active-checkout + uses: ./.github/actions/run-sharded-cloud-build with: - cmd: "make unit_test_py" - service_account: ${{ secrets.gcp_service_account_email }} - project: ${{ vars.GCP_PROJECT_ID }} + test_command: "unit_test_py_shard" + total_shards: "4" + run_type_check: "true" + service_account: ${{ secrets.gcp_service_account_email }} + project: ${{ vars.GCP_PROJECT_ID }} ci-unit-test-scala: # Because of limitation discussed https://github.com/orgs/community/discussions/46757#discussioncomment-4912738 @@ -82,12 +84,14 @@ jobs: gcp_project_id: ${{ vars.GCP_PROJECT_ID }} workload_identity_provider: ${{ secrets.workload_identity_provider }} gcp_service_account_email: ${{ secrets.gcp_service_account_email }} - - name: Run Integration Tests - uses: ./.github/actions/run-cloud-run-command-on-active-checkout + - name: Run Integration Tests (sharded) + uses: ./.github/actions/run-sharded-cloud-build with: - cmd: "make integration_test" - service_account: ${{ secrets.gcp_service_account_email }} - project: ${{ vars.GCP_PROJECT_ID }} + test_command: "integration_test_shard" + total_shards: "4" + run_type_check: "false" + service_account: ${{ secrets.gcp_service_account_email }} + project: ${{ vars.GCP_PROJECT_ID }} ci-integration-e2e-test: if: github.event_name == 'merge_group' diff --git a/Makefile b/Makefile index e15a063f3..3eb13e25b 100644 --- a/Makefile +++ b/Makefile @@ -23,6 +23,8 @@ DOCKER_IMAGE_DEV_WORKBENCH_NAME_WITH_TAG?=${DOCKER_IMAGE_DEV_WORKBENCH_NAME}:${D PYTHON_DIRS:=.github/scripts examples gigl tests snapchat scripts PY_TEST_FILES?="*_test.py" +SHARD_INDEX?=0 +TOTAL_SHARDS?=0 # You can override GIGL_TEST_DEFAULT_RESOURCE_CONFIG by setting it in your environment i.e. # adding `export GIGL_TEST_DEFAULT_RESOURCE_CONFIG=your_resource_config` to your shell config (~/.bashrc, ~/.zshrc, etc.) GIGL_TEST_DEFAULT_RESOURCE_CONFIG?=${PWD}/deployment/configs/unittest_resource_config.yaml @@ -81,6 +83,17 @@ unit_test_py: clean_build_files_py type_check --resource_config_uri=${GIGL_TEST_DEFAULT_RESOURCE_CONFIG} \ --test_file_pattern=$(PY_TEST_FILES) \ +# Runs only the type checker without tests. Used as a standalone Cloud Build shard job. +type_check_only: clean_build_files_py type_check + +# Runs a single shard of the Python unit tests (no type checking). +# Usage: make unit_test_py_shard SHARD_INDEX=0 TOTAL_SHARDS=4 +unit_test_py_shard: clean_build_files_py + uv run python -m tests.unit.main \ + --env=test \ + --resource_config_uri=${GIGL_TEST_DEFAULT_RESOURCE_CONFIG} \ + --test_file_pattern=$(PY_TEST_FILES) \ + --shard_index=$(SHARD_INDEX) --total_shards=$(TOTAL_SHARDS) unit_test_scala: clean_build_files_scala ( cd scala; sbt test ) @@ -121,6 +134,14 @@ integration_test: --resource_config_uri=${GIGL_TEST_DEFAULT_RESOURCE_CONFIG} \ --test_file_pattern=$(PY_TEST_FILES) \ +# Runs a single shard of the integration tests. +# Usage: make integration_test_shard SHARD_INDEX=0 TOTAL_SHARDS=4 +integration_test_shard: clean_build_files_py + uv run python -m tests.integration.main \ + --env=test \ + --resource_config_uri=${GIGL_TEST_DEFAULT_RESOURCE_CONFIG} \ + --test_file_pattern=$(PY_TEST_FILES) \ + --shard_index=$(SHARD_INDEX) --total_shards=$(TOTAL_SHARDS) notebooks_test: RESOURCE_CONFIG_PATH=${GIGL_TEST_DEFAULT_RESOURCE_CONFIG} python -m tests.config_tests.notebooks_test diff --git a/gigl/common/utils/test_utils.py b/gigl/common/utils/test_utils.py index ff7c7dc14..4314d864b 100644 --- a/gigl/common/utils/test_utils.py +++ b/gigl/common/utils/test_utils.py @@ -1,4 +1,5 @@ import argparse +import hashlib import time import unittest from concurrent.futures import ProcessPoolExecutor @@ -14,18 +15,22 @@ @dataclass(frozen=True) class TestArgs: - """Container for CLI arguements to Python tests. + """Container for CLI arguments to Python tests. Attributes: - test_file_pattern (str): Glob pattern for filtering which test files to run. + test_file_pattern: Glob pattern for filtering which test files to run. See doc comment in `parse_args` for more details. + shard_index: Zero-based index of the current shard. + total_shards: Total number of shards. 0 means no sharding. """ test_file_pattern: str + shard_index: int = 0 + total_shards: int = 0 def parse_args() -> TestArgs: - """Parses test-exclusive CLI arguements.""" + """Parses test-exclusive CLI arguments.""" parser = argparse.ArgumentParser() parser.add_argument( "-tf", @@ -43,12 +48,75 @@ def parse_args() -> TestArgs: ``` """, ) + parser.add_argument( + "--shard_index", + type=int, + default=0, + help="Zero-based index of the current shard (used with --total_shards).", + ) + parser.add_argument( + "--total_shards", + type=int, + default=0, + help="Total number of shards. 0 or 1 means no sharding (run all tests).", + ) args, _ = parser.parse_known_args() - test_args = TestArgs(test_file_pattern=args.test_file_pattern) + test_args = TestArgs( + test_file_pattern=args.test_file_pattern, + shard_index=args.shard_index, + total_shards=args.total_shards, + ) logger.info(f"Test args: {test_args}") return test_args +def _filter_tests_by_shard( + suite: unittest.TestSuite, shard_index: int, total_shards: int +) -> unittest.TestSuite: + """Filters a test suite to only include tests belonging to the given shard. + + Sharding is done at the file (module) level so that setUpClass/tearDownClass + are not split across shards. Each top-level test group's module name is + SHA-256 hashed and assigned to a shard via ``hash % total_shards``. + + Args: + suite: The full test suite discovered by unittest. + shard_index: Zero-based index of the current shard. + total_shards: Total number of shards. If <= 1, the suite is returned + unchanged. + + Returns: + A new TestSuite containing only the tests assigned to this shard. + """ + if total_shards <= 1: + return suite + + filtered = unittest.TestSuite() + for test_group in suite: + module_name = _get_test_group_module_name(test_group) + hash_value = int(hashlib.sha256(module_name.encode()).hexdigest(), 16) + if hash_value % total_shards == shard_index: + filtered.addTest(test_group) + return filtered + + +def _get_test_group_module_name(test_group: unittest.TestSuite | TestCase) -> str: + """Extracts the module name from a test group for shard assignment. + + Args: + test_group: A test suite or individual test case. + + Returns: + The module name string used for hashing. + """ + if isinstance(test_group, unittest.TestSuite): + # Recurse into nested suites to find the first actual test case + for item in test_group: + return _get_test_group_module_name(item) + # TestCase instance — use its module + return type(test_group).__module__ + + def _run_individual_test(test: TestCase) -> Tuple[bool, int]: # If we don't have any test cases, we skip running the test. # This reduces some noise in the logs. @@ -64,15 +132,23 @@ def _run_individual_test(test: TestCase) -> Tuple[bool, int]: def run_tests( - start_dir: LocalUri, pattern: str, use_sequential_execution: bool = False + start_dir: LocalUri, + pattern: str, + use_sequential_execution: bool = False, + shard_index: int = 0, + total_shards: int = 0, ) -> bool: - """ + """Discovers and runs tests, optionally filtering by shard. + Args: - start_dir (LocalUri): Local Directory for running tests - pattern (str): file text pattern for running tests - use_sequential_execution (bool): Whether sequential exection should be used - Return: - bool: Whether all tests passed successfully + start_dir: Local directory for running tests. + pattern: File text pattern for running tests. + use_sequential_execution: Whether sequential execution should be used. + shard_index: Zero-based index of the current shard. + total_shards: Total number of shards. 0 or 1 means no sharding. + + Returns: + Whether all tests passed successfully. """ start = time.perf_counter() @@ -83,6 +159,14 @@ def run_tests( pattern=pattern, ) + total_discovered: int = suite.countTestCases() + suite = _filter_tests_by_shard(suite, shard_index, total_shards) + + if total_shards > 1: + logger.info( + f"Shard {shard_index}/{total_shards}: running {suite.countTestCases()}/{total_discovered} test cases" + ) + was_successful: bool total_num_test_cases: int = 0 diff --git a/tests/integration/main.py b/tests/integration/main.py index fd1765afd..ebadd5cab 100644 --- a/tests/integration/main.py +++ b/tests/integration/main.py @@ -7,7 +7,11 @@ from tests.test_assets.uri_constants import DEFAULT_NABLP_TASK_CONFIG_URI -def run(pattern: str = "*_test.py") -> bool: +def run( + pattern: str = "*_test.py", + shard_index: int = 0, + total_shards: int = 0, +) -> bool: initialize_metrics( task_config_uri=DEFAULT_NABLP_TASK_CONFIG_URI, service_name="integration_test" ) @@ -17,9 +21,16 @@ def run(pattern: str = "*_test.py") -> bool: ), pattern=pattern, use_sequential_execution=True, + shard_index=shard_index, + total_shards=total_shards, ) if __name__ == "__main__": - was_successful: bool = run(pattern=parse_args().test_file_pattern) + test_args = parse_args() + was_successful: bool = run( + pattern=test_args.test_file_pattern, + shard_index=test_args.shard_index, + total_shards=test_args.total_shards, + ) sys.exit(not was_successful) diff --git a/tests/unit/common/utils/test_sharding_test.py b/tests/unit/common/utils/test_sharding_test.py new file mode 100644 index 000000000..8a07ed75e --- /dev/null +++ b/tests/unit/common/utils/test_sharding_test.py @@ -0,0 +1,126 @@ +import unittest + +from gigl.common.utils.test_utils import _filter_tests_by_shard +from tests.test_assets.test_case import TestCase + + +def _make_test_suite_with_modules(module_names: list[str]) -> unittest.TestSuite: + """Creates a test suite where each top-level group simulates a different module. + + Each module gets a dynamically created TestCase subclass with one test method, + mirroring the structure produced by ``unittest.TestLoader.discover()``. + + Args: + module_names: List of module name strings to simulate. + + Returns: + A TestSuite containing one nested TestSuite per module name. + """ + outer_suite = unittest.TestSuite() + for module_name in module_names: + # Dynamically create a TestCase class with a unique module + test_class = type( + f"TestFor_{module_name.replace('.', '_')}", + (unittest.TestCase,), + { + "test_placeholder": lambda self: None, + "__module__": module_name, + }, + ) + inner_suite = unittest.TestSuite([test_class("test_placeholder")]) + outer_suite.addTest(inner_suite) + return outer_suite + + +class FilterTestsByShardTest(TestCase): + """Tests for the _filter_tests_by_shard function.""" + + MODULES: list[str] = [ + "tests.unit.module_a_test", + "tests.unit.module_b_test", + "tests.unit.module_c_test", + "tests.unit.module_d_test", + "tests.unit.module_e_test", + "tests.unit.module_f_test", + "tests.unit.module_g_test", + "tests.unit.module_h_test", + ] + + def test_no_sharding_when_total_shards_is_zero(self) -> None: + suite = _make_test_suite_with_modules(self.MODULES) + result = _filter_tests_by_shard(suite, shard_index=0, total_shards=0) + self.assertEqual(result.countTestCases(), suite.countTestCases()) + + def test_no_sharding_when_total_shards_is_one(self) -> None: + suite = _make_test_suite_with_modules(self.MODULES) + result = _filter_tests_by_shard(suite, shard_index=0, total_shards=1) + self.assertEqual(result.countTestCases(), suite.countTestCases()) + + def test_all_tests_covered_across_shards(self) -> None: + """Union of all shards must equal the full suite.""" + total_shards = 4 + all_test_counts: list[int] = [] + for shard_index in range(total_shards): + suite = _make_test_suite_with_modules(self.MODULES) + result = _filter_tests_by_shard(suite, shard_index, total_shards) + all_test_counts.append(result.countTestCases()) + + self.assertEqual( + sum(all_test_counts), + len(self.MODULES), + f"Total tests across shards ({sum(all_test_counts)}) != total modules ({len(self.MODULES)})", + ) + + def test_no_overlap_between_shards(self) -> None: + """Each module must appear in exactly one shard.""" + total_shards = 4 + seen_modules: list[str] = [] + for shard_index in range(total_shards): + suite = _make_test_suite_with_modules(self.MODULES) + result = _filter_tests_by_shard(suite, shard_index, total_shards) + for test_group in result: + assert isinstance(test_group, unittest.TestSuite) + for test_case in test_group: + module = type(test_case).__module__ + self.assertNotIn( + module, + seen_modules, + f"Module {module} appeared in multiple shards", + ) + seen_modules.append(module) + + def test_deterministic_assignment(self) -> None: + """Running the same shard twice must produce identical results.""" + total_shards = 3 + shard_index = 1 + suite1 = _make_test_suite_with_modules(self.MODULES) + result1 = _filter_tests_by_shard(suite1, shard_index, total_shards) + modules1 = [ + type(tc).__module__ + for tg in result1 + if isinstance(tg, unittest.TestSuite) + for tc in tg + ] + + suite2 = _make_test_suite_with_modules(self.MODULES) + result2 = _filter_tests_by_shard(suite2, shard_index, total_shards) + modules2 = [ + type(tc).__module__ + for tg in result2 + if isinstance(tg, unittest.TestSuite) + for tc in tg + ] + + self.assertEqual(modules1, modules2) + + def test_each_shard_gets_at_least_one_test_when_enough_modules(self) -> None: + """With enough modules, each shard should get at least one test.""" + total_shards = 3 + for shard_index in range(total_shards): + suite = _make_test_suite_with_modules(self.MODULES) + result = _filter_tests_by_shard(suite, shard_index, total_shards) + self.assertGreater( + result.countTestCases(), + 0, + f"Shard {shard_index} got no tests", + ) diff --git a/tests/unit/main.py b/tests/unit/main.py index 83b3b75d4..b22a8a430 100644 --- a/tests/unit/main.py +++ b/tests/unit/main.py @@ -5,16 +5,27 @@ from gigl.common.utils.test_utils import parse_args, run_tests -def run(pattern: str = "*_test.py") -> bool: +def run( + pattern: str = "*_test.py", + shard_index: int = 0, + total_shards: int = 0, +) -> bool: return run_tests( start_dir=LocalUri.join( local_fs_constants.get_project_root_directory(), "tests", "unit" ), pattern=pattern, use_sequential_execution=True, + shard_index=shard_index, + total_shards=total_shards, ) if __name__ == "__main__": - was_successful: bool = run(pattern=parse_args().test_file_pattern) + test_args = parse_args() + was_successful: bool = run( + pattern=test_args.test_file_pattern, + shard_index=test_args.shard_index, + total_shards=test_args.total_shards, + ) sys.exit(not was_successful) From a19cdd40bf591fb43af2437dfcc7e24c784a4970 Mon Sep 17 00:00:00 2001 From: kmontemayor Date: Fri, 27 Feb 2026 04:00:05 +0000 Subject: [PATCH 02/22] maybe --- .github/workflows/on-pr-comment.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/on-pr-comment.yml b/.github/workflows/on-pr-comment.yml index 49a7b9d68..2912e20a4 100644 --- a/.github/workflows/on-pr-comment.yml +++ b/.github/workflows/on-pr-comment.yml @@ -3,7 +3,7 @@ name: On Demand Pr Comment Workflows on: issue_comment: types: [created] - push: + push: # TODO(revert) permissions: # Needed for gcloud auth: https://github.com/google-github-actions/auth From 89dfcd30f09d3c24383c29d104aacf61c86e4fa1 Mon Sep 17 00:00:00 2001 From: kmontemayor Date: Fri, 27 Feb 2026 04:02:40 +0000 Subject: [PATCH 03/22] hmmm --- .github/workflows/on-pr-comment.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/on-pr-comment.yml b/.github/workflows/on-pr-comment.yml index 2912e20a4..298296e5c 100644 --- a/.github/workflows/on-pr-comment.yml +++ b/.github/workflows/on-pr-comment.yml @@ -41,7 +41,7 @@ jobs: message: ${{ steps.parse_commands.outputs.help_message }} unit-test-python: - if: ${{ github.event.issue.pull_request && (contains(github.event.comment.body, '/unit_test_py') || endsWith(github.event.comment.body, '/unit_test') || contains(github.event.comment.body, '/all_test')) }} + #if: ${{ github.event.issue.pull_request && (contains(github.event.comment.body, '/unit_test_py') || endsWith(github.event.comment.body, '/unit_test') || contains(github.event.comment.body, '/all_test')) }} runs-on: ubuntu-latest timeout-minutes: 60 steps: @@ -86,7 +86,7 @@ jobs: make unit_test_scala integration-test: - if: ${{ github.event.issue.pull_request && (contains(github.event.comment.body, '/integration_test') || contains(github.event.comment.body, '/all_test')) }} + #if: ${{ github.event.issue.pull_request && (contains(github.event.comment.body, '/integration_test') || contains(github.event.comment.body, '/all_test')) }} runs-on: ubuntu-latest timeout-minutes: 60 steps: From 34b797e6eae296fc195abe26c352d3a1bbeadb19 Mon Sep 17 00:00:00 2001 From: kmontemayor Date: Fri, 27 Feb 2026 04:06:25 +0000 Subject: [PATCH 04/22] test --- .github/workflows/on-pr-comment.yml | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/.github/workflows/on-pr-comment.yml b/.github/workflows/on-pr-comment.yml index 298296e5c..46a573311 100644 --- a/.github/workflows/on-pr-comment.yml +++ b/.github/workflows/on-pr-comment.yml @@ -45,6 +45,8 @@ jobs: runs-on: ubuntu-latest timeout-minutes: 60 steps: + - name: Checkout code + uses: actions/checkout@v4 - name: Run Python Unit Tests uses: ./.github/actions/run-sharded-cloud-build with: @@ -90,6 +92,8 @@ jobs: runs-on: ubuntu-latest timeout-minutes: 60 steps: + - name: Checkout code + uses: actions/checkout@v4 - name: Run Integration Tests uses: ./.github/actions/run-sharded-cloud-build with: From e7da1731b0cb59440763a9203cfb127ab5b96e2d Mon Sep 17 00:00:00 2001 From: kmontemayor Date: Fri, 27 Feb 2026 04:15:43 +0000 Subject: [PATCH 05/22] fix --- .github/workflows/on-pr-comment.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/on-pr-comment.yml b/.github/workflows/on-pr-comment.yml index 46a573311..c5d0bbcbb 100644 --- a/.github/workflows/on-pr-comment.yml +++ b/.github/workflows/on-pr-comment.yml @@ -48,7 +48,7 @@ jobs: - name: Checkout code uses: actions/checkout@v4 - name: Run Python Unit Tests - uses: ./.github/actions/run-sharded-cloud-build + uses: ./.github/actions/run-command-on-pr with: github-token: ${{ secrets.GITHUB_TOKEN }} pr_number: ${{ github.event.issue.number }} @@ -95,7 +95,7 @@ jobs: - name: Checkout code uses: actions/checkout@v4 - name: Run Integration Tests - uses: ./.github/actions/run-sharded-cloud-build + uses: ./.github/actions/run-command-on-pr with: github-token: ${{ secrets.GITHUB_TOKEN }} pr_number: ${{ github.event.issue.number }} From adce2cd9c9b539f56c8d83e2dd0b69a7eaa6b333 Mon Sep 17 00:00:00 2001 From: kmontemayor Date: Fri, 27 Feb 2026 04:52:22 +0000 Subject: [PATCH 06/22] bleh --- .github/workflows/on-pr-comment.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/on-pr-comment.yml b/.github/workflows/on-pr-comment.yml index c5d0bbcbb..40ef4e345 100644 --- a/.github/workflows/on-pr-comment.yml +++ b/.github/workflows/on-pr-comment.yml @@ -52,7 +52,7 @@ jobs: with: github-token: ${{ secrets.GITHUB_TOKEN }} pr_number: ${{ github.event.issue.number }} - should_leave_progress_comments: "true" + #should_leave_progress_comments: "true" descriptive_workflow_name: "Python Unit Test" setup_gcloud: "true" # We use cloud run here instead of using github hosted runners because of limitation of tests @@ -99,7 +99,7 @@ jobs: with: github-token: ${{ secrets.GITHUB_TOKEN }} pr_number: ${{ github.event.issue.number }} - should_leave_progress_comments: "true" + #should_leave_progress_comments: "true" descriptive_workflow_name: "Integration Test" setup_gcloud: "true" gcp_project_id: ${{ vars.GCP_PROJECT_ID }} From a6bdb8bfca72a96a99e3ce3b58fcfda6c19b5929 Mon Sep 17 00:00:00 2001 From: kmontemayor Date: Fri, 27 Feb 2026 05:06:54 +0000 Subject: [PATCH 07/22] no checkout --- .github/workflows/on-pr-comment.yml | 4 ---- 1 file changed, 4 deletions(-) diff --git a/.github/workflows/on-pr-comment.yml b/.github/workflows/on-pr-comment.yml index 40ef4e345..ae89d28e0 100644 --- a/.github/workflows/on-pr-comment.yml +++ b/.github/workflows/on-pr-comment.yml @@ -45,8 +45,6 @@ jobs: runs-on: ubuntu-latest timeout-minutes: 60 steps: - - name: Checkout code - uses: actions/checkout@v4 - name: Run Python Unit Tests uses: ./.github/actions/run-command-on-pr with: @@ -92,8 +90,6 @@ jobs: runs-on: ubuntu-latest timeout-minutes: 60 steps: - - name: Checkout code - uses: actions/checkout@v4 - name: Run Integration Tests uses: ./.github/actions/run-command-on-pr with: From b9ab9d6e62fb40b75b4c4cd3c2a28488ba80a4a1 Mon Sep 17 00:00:00 2001 From: kmontemayor Date: Fri, 27 Feb 2026 05:12:33 +0000 Subject: [PATCH 08/22] update --- .github/workflows/on-pr-comment.yml | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/.github/workflows/on-pr-comment.yml b/.github/workflows/on-pr-comment.yml index ae89d28e0..c9c45349a 100644 --- a/.github/workflows/on-pr-comment.yml +++ b/.github/workflows/on-pr-comment.yml @@ -45,6 +45,10 @@ jobs: runs-on: ubuntu-latest timeout-minutes: 60 steps: + - name: Checkout code + uses: actions/checkout@v4 + with: + ref: ${{ github.sha }} - name: Run Python Unit Tests uses: ./.github/actions/run-command-on-pr with: @@ -90,6 +94,10 @@ jobs: runs-on: ubuntu-latest timeout-minutes: 60 steps: + - name: Checkout code + uses: actions/checkout@v4 + with: + ref: ${{ github.sha }} - name: Run Integration Tests uses: ./.github/actions/run-command-on-pr with: From 5feab220c3e8b9fb5a1ddd6ede7128aaf694d48b Mon Sep 17 00:00:00 2001 From: kmontemayor Date: Fri, 27 Feb 2026 05:31:52 +0000 Subject: [PATCH 09/22] update --- .github/workflows/on-pr-comment.yml | 36 +++++++++++++---------------- 1 file changed, 16 insertions(+), 20 deletions(-) diff --git a/.github/workflows/on-pr-comment.yml b/.github/workflows/on-pr-comment.yml index c9c45349a..b45df2092 100644 --- a/.github/workflows/on-pr-comment.yml +++ b/.github/workflows/on-pr-comment.yml @@ -49,25 +49,21 @@ jobs: uses: actions/checkout@v4 with: ref: ${{ github.sha }} - - name: Run Python Unit Tests - uses: ./.github/actions/run-command-on-pr + - name: Setup development environment + uses: ./.github/actions/setup-python-tools with: - github-token: ${{ secrets.GITHUB_TOKEN }} - pr_number: ${{ github.event.issue.number }} - #should_leave_progress_comments: "true" - descriptive_workflow_name: "Python Unit Test" setup_gcloud: "true" - # We use cloud run here instead of using github hosted runners because of limitation of tests - # using GFile library (a.k.a anything that does IO w/ Tensorflow). GFile does not understand - # how to leverage Workload Identity Federation to read assets from GCS, et al. See: - # https://github.com/tensorflow/tensorflow/issues/57104 gcp_project_id: ${{ vars.GCP_PROJECT_ID }} workload_identity_provider: ${{ secrets.WORKLOAD_IDENTITY_PROVIDER }} gcp_service_account_email: ${{ secrets.GCP_SERVICE_ACCOUNT_EMAIL }} - use_sharded_cloud_build: "true" + - name: Run Python Unit Tests (sharded) + uses: ./.github/actions/run-sharded-cloud-build + with: + test_command: "unit_test_py_shard" total_shards: "4" - sharded_test_command: "unit_test_py_shard" run_type_check: "true" + service_account: ${{ secrets.GCP_SERVICE_ACCOUNT_EMAIL }} + project: ${{ vars.GCP_PROJECT_ID }} unit-test-scala: if: ${{ github.event.issue.pull_request && (contains(github.event.comment.body, '/unit_test_scala') || endsWith(github.event.comment.body, '/unit_test') || contains(github.event.comment.body, '/all_test')) }} @@ -98,21 +94,21 @@ jobs: uses: actions/checkout@v4 with: ref: ${{ github.sha }} - - name: Run Integration Tests - uses: ./.github/actions/run-command-on-pr + - name: Setup development environment + uses: ./.github/actions/setup-python-tools with: - github-token: ${{ secrets.GITHUB_TOKEN }} - pr_number: ${{ github.event.issue.number }} - #should_leave_progress_comments: "true" - descriptive_workflow_name: "Integration Test" setup_gcloud: "true" gcp_project_id: ${{ vars.GCP_PROJECT_ID }} workload_identity_provider: ${{ secrets.WORKLOAD_IDENTITY_PROVIDER }} gcp_service_account_email: ${{ secrets.GCP_SERVICE_ACCOUNT_EMAIL }} - use_sharded_cloud_build: "true" + - name: Run Integration Tests (sharded) + uses: ./.github/actions/run-sharded-cloud-build + with: + test_command: "integration_test_shard" total_shards: "4" - sharded_test_command: "integration_test_shard" run_type_check: "false" + service_account: ${{ secrets.GCP_SERVICE_ACCOUNT_EMAIL }} + project: ${{ vars.GCP_PROJECT_ID }} integration-e2e-test: if: ${{ github.event.issue.pull_request && (contains(github.event.comment.body, '/e2e_test') || contains(github.event.comment.body, '/all_test')) }} From 9c00855ea998c9cddf798665fb2fab796bca216b Mon Sep 17 00:00:00 2001 From: kmontemayor Date: Fri, 27 Feb 2026 06:19:11 +0000 Subject: [PATCH 10/22] better logs --- .../run-sharded-cloud-build/action.yml | 21 ++++++++++++------- 1 file changed, 13 insertions(+), 8 deletions(-) diff --git a/.github/actions/run-sharded-cloud-build/action.yml b/.github/actions/run-sharded-cloud-build/action.yml index ec5b2c511..c8a70422a 100644 --- a/.github/actions/run-sharded-cloud-build/action.yml +++ b/.github/actions/run-sharded-cloud-build/action.yml @@ -36,6 +36,8 @@ runs: BUILD_IDS="" BUILD_LABELS="" TOTAL_SHARDS=${{ inputs.total_shards }} + PROJECT="${{ inputs.project }}" + CLOUD_BUILD_URL_BASE="https://console.cloud.google.com/cloud-build/builds" # Launch type-check job if enabled if [[ "${{ inputs.run_type_check }}" == "true" ]]; then @@ -43,15 +45,15 @@ runs: TYPE_CHECK_BUILD_ID=$(gcloud builds submit . \ --config=.github/cloud_builder/run_command_on_active_checkout.yaml \ --substitutions=_CMD="make type_check_only" \ - --service-account="projects/${{ inputs.project }}/serviceAccounts/${{ inputs.service_account }}" \ - --project="${{ inputs.project }}" \ + --service-account="projects/${PROJECT}/serviceAccounts/${{ inputs.service_account }}" \ + --project="${PROJECT}" \ --machine-type="${{ inputs.machine_type }}" \ --timeout="${{ inputs.timeout }}" \ --async \ --format='value(id)' 2>&1 | tail -1) BUILD_IDS="${TYPE_CHECK_BUILD_ID}" BUILD_LABELS="type_check" - echo " Type-check build ID: ${TYPE_CHECK_BUILD_ID}" + echo " Type-check build: ${CLOUD_BUILD_URL_BASE}/${TYPE_CHECK_BUILD_ID}?project=${PROJECT}" fi # Launch shard jobs @@ -60,8 +62,8 @@ runs: SHARD_BUILD_ID=$(gcloud builds submit . \ --config=.github/cloud_builder/run_command_on_active_checkout.yaml \ --substitutions=_CMD="make ${{ inputs.test_command }} SHARD_INDEX=${i} TOTAL_SHARDS=${TOTAL_SHARDS}" \ - --service-account="projects/${{ inputs.project }}/serviceAccounts/${{ inputs.service_account }}" \ - --project="${{ inputs.project }}" \ + --service-account="projects/${PROJECT}/serviceAccounts/${{ inputs.service_account }}" \ + --project="${PROJECT}" \ --machine-type="${{ inputs.machine_type }}" \ --timeout="${{ inputs.timeout }}" \ --async \ @@ -73,7 +75,7 @@ runs: BUILD_IDS="${SHARD_BUILD_ID}" BUILD_LABELS="shard_${i}" fi - echo " Shard ${i} build ID: ${SHARD_BUILD_ID}" + echo " Shard ${i} build: ${CLOUD_BUILD_URL_BASE}/${SHARD_BUILD_ID}?project=${PROJECT}" done echo "build_ids=${BUILD_IDS}" >> $GITHUB_OUTPUT @@ -121,15 +123,18 @@ runs: # Check results FAILED=0 + PROJECT="${{ inputs.project }}" + CLOUD_BUILD_URL_BASE="https://console.cloud.google.com/cloud-build/builds" for idx in "${!BUILD_LABELS[@]}"; do LABEL="${BUILD_LABELS[$idx]}" BUILD_ID="${BUILD_IDS[$idx]}" + BUILD_URL="${CLOUD_BUILD_URL_BASE}/${BUILD_ID}?project=${PROJECT}" RESULT=$(cat "${TEMP_DIR}/${LABEL}" 2>/dev/null || echo "UNKNOWN") if [[ "${RESULT}" != "SUCCESS" ]]; then - echo "FAILED: ${LABEL} (${BUILD_ID}) - status: ${RESULT}" + echo "FAILED: ${LABEL} - ${BUILD_URL}" FAILED=1 else - echo "PASSED: ${LABEL} (${BUILD_ID})" + echo "PASSED: ${LABEL} - ${BUILD_URL}" fi done From c9a90518c3058bdb0c34dcc93d8b0953c1f28f3b Mon Sep 17 00:00:00 2001 From: kmontemayor Date: Fri, 27 Feb 2026 18:06:52 +0000 Subject: [PATCH 11/22] cleanup --- gigl/common/utils/test_utils.py | 57 ++++- tests/integration/main.py | 18 ++ tests/unit/common/utils/test_sharding_test.py | 226 +++++++++++++++++- tests/unit/main.py | 15 ++ 4 files changed, 305 insertions(+), 11 deletions(-) diff --git a/gigl/common/utils/test_utils.py b/gigl/common/utils/test_utils.py index 4314d864b..ee2791cb0 100644 --- a/gigl/common/utils/test_utils.py +++ b/gigl/common/utils/test_utils.py @@ -4,7 +4,7 @@ import unittest from concurrent.futures import ProcessPoolExecutor from dataclasses import dataclass -from typing import Iterator, Tuple +from typing import Iterator from unittest import TestCase from gigl.common import LocalUri @@ -70,20 +70,52 @@ def parse_args() -> TestArgs: return test_args +def _get_shard_for_module( + module_name: str, + total_shards: int, + pinned_modules: tuple[str, ...], +) -> int: + """Returns the shard index a module should be assigned to. + + Pinned modules use their position in ``pinned_modules`` to determine the + shard (``index % total_shards``). All other modules fall back to + SHA-256 hashing. + + Args: + module_name: Fully-qualified module name. + total_shards: Total number of shards (must be >= 2). + pinned_modules: Ordered tuple of module names with deterministic + position-based shard assignment. + + Returns: + Zero-based shard index for the module. + """ + if module_name in pinned_modules: + return pinned_modules.index(module_name) % total_shards + hash_value = int(hashlib.sha256(module_name.encode()).hexdigest(), 16) + return hash_value % total_shards + + def _filter_tests_by_shard( - suite: unittest.TestSuite, shard_index: int, total_shards: int + suite: unittest.TestSuite, + shard_index: int, + total_shards: int, + pinned_modules: tuple[str, ...] = (), ) -> unittest.TestSuite: """Filters a test suite to only include tests belonging to the given shard. Sharding is done at the file (module) level so that setUpClass/tearDownClass - are not split across shards. Each top-level test group's module name is - SHA-256 hashed and assigned to a shard via ``hash % total_shards``. + are not split across shards. Pinned modules are assigned by their position + in ``pinned_modules`` (``index % total_shards``); all other modules use + SHA-256 hashing. Args: suite: The full test suite discovered by unittest. shard_index: Zero-based index of the current shard. total_shards: Total number of shards. If <= 1, the suite is returned unchanged. + pinned_modules: Ordered tuple of module names with deterministic + position-based shard assignment. Returns: A new TestSuite containing only the tests assigned to this shard. @@ -94,8 +126,10 @@ def _filter_tests_by_shard( filtered = unittest.TestSuite() for test_group in suite: module_name = _get_test_group_module_name(test_group) - hash_value = int(hashlib.sha256(module_name.encode()).hexdigest(), 16) - if hash_value % total_shards == shard_index: + if ( + _get_shard_for_module(module_name, total_shards, pinned_modules) + == shard_index + ): filtered.addTest(test_group) return filtered @@ -117,7 +151,7 @@ def _get_test_group_module_name(test_group: unittest.TestSuite | TestCase) -> st return type(test_group).__module__ -def _run_individual_test(test: TestCase) -> Tuple[bool, int]: +def _run_individual_test(test: TestCase) -> tuple[bool, int]: # If we don't have any test cases, we skip running the test. # This reduces some noise in the logs. if test.countTestCases() == 0: @@ -137,6 +171,7 @@ def run_tests( use_sequential_execution: bool = False, shard_index: int = 0, total_shards: int = 0, + pinned_modules: tuple[str, ...] = (), ) -> bool: """Discovers and runs tests, optionally filtering by shard. @@ -146,6 +181,8 @@ def run_tests( use_sequential_execution: Whether sequential execution should be used. shard_index: Zero-based index of the current shard. total_shards: Total number of shards. 0 or 1 means no sharding. + pinned_modules: Ordered tuple of module names with deterministic + position-based shard assignment (``index % total_shards``). Returns: Whether all tests passed successfully. @@ -160,7 +197,7 @@ def run_tests( ) total_discovered: int = suite.countTestCases() - suite = _filter_tests_by_shard(suite, shard_index, total_shards) + suite = _filter_tests_by_shard(suite, shard_index, total_shards, pinned_modules) if total_shards > 1: logger.info( @@ -176,7 +213,7 @@ def run_tests( total_num_test_cases = suite.countTestCases() else: with ProcessPoolExecutor() as executor: - was_successful_iter: Iterator[Tuple[bool, int]] = executor.map( + was_successful_iter: Iterator[tuple[bool, int]] = executor.map( _run_individual_test, suite._tests ) was_successful = True @@ -186,5 +223,5 @@ def run_tests( logger.info(f"Ran {total_num_test_cases}/{suite.countTestCases()} test cases") finish = time.perf_counter() - logger.info(f"It took {finish-start: .2f} second(s) to run tests") + logger.info(f"It took {finish - start: .2f} second(s) to run tests") return was_successful diff --git a/tests/integration/main.py b/tests/integration/main.py index ebadd5cab..00a6638b3 100644 --- a/tests/integration/main.py +++ b/tests/integration/main.py @@ -1,4 +1,5 @@ import sys +from typing import Final import gigl.src.common.constants.local_fs as local_fs_constants from gigl.common import LocalUri @@ -6,6 +7,22 @@ from gigl.src.common.utils.metrics_service_provider import initialize_metrics from tests.test_assets.uri_constants import DEFAULT_NABLP_TASK_CONFIG_URI +# Slow test modules that must be spread across shards. Position in the tuple +# determines the shard: ``index % total_shards``. **Append-only** — never +# reorder existing entries, or every module's shard assignment will shift. +# +# Durations measured 2026-02-27 (unsharded CI run, 77.5 min total): +INTEGRATION_TEST_SHARD_PINNED_MODULES: Final[tuple[str, ...]] = ( + "tests.integration.distributed.distributed_dataset_test", # 14.5 min (18.7%) + "tests.integration.distributed.utils.networking_test", # 13.3 min (17.2%) + "tests.integration.distributed.graph_store.graph_store_integration_test", # 13.0 min (16.8%) + "tests.integration.pipeline.data_preprocessor.data_preprocessor_pipeline_test", # 11.7 min (15.1%) + "tests.integration.pipeline.subgraph_sampler.subgraph_sampler_test", # 8.8 min (11.4%) + "tests.integration.common.services.vertex_ai_test", # 6.5 min (8.4%) + "tests.integration.pipeline.split_generator.split_generator_pipeline_test", # 3.8 min (5.0%) + "tests.integration.pipeline.inferencer.inferencer_test", # 2.1 min (2.8%) +) + def run( pattern: str = "*_test.py", @@ -23,6 +40,7 @@ def run( use_sequential_execution=True, shard_index=shard_index, total_shards=total_shards, + pinned_modules=INTEGRATION_TEST_SHARD_PINNED_MODULES, ) diff --git a/tests/unit/common/utils/test_sharding_test.py b/tests/unit/common/utils/test_sharding_test.py index 8a07ed75e..4ea87c42a 100644 --- a/tests/unit/common/utils/test_sharding_test.py +++ b/tests/unit/common/utils/test_sharding_test.py @@ -1,7 +1,11 @@ import unittest -from gigl.common.utils.test_utils import _filter_tests_by_shard +import gigl.src.common.constants.local_fs as local_fs_constants +from gigl.common import LocalUri +from gigl.common.utils.test_utils import _filter_tests_by_shard, _get_shard_for_module +from tests.integration.main import INTEGRATION_TEST_SHARD_PINNED_MODULES from tests.test_assets.test_case import TestCase +from tests.unit.main import UNIT_TEST_SHARD_PINNED_MODULES def _make_test_suite_with_modules(module_names: list[str]) -> unittest.TestSuite: @@ -124,3 +128,223 @@ def test_each_shard_gets_at_least_one_test_when_enough_modules(self) -> None: 0, f"Shard {shard_index} got no tests", ) + + +class ShardPinningTest(TestCase): + """Tests for manual shard pinning via pinned_modules.""" + + PINNED: tuple[str, ...] = ( + "tests.unit.distributed.dist_ablp_neighborloader_test", + "tests.unit.distributed.distributed_dataset_test", + "tests.unit.distributed.distributed_neighborloader_test", + "tests.unit.distributed.distributed_partitioner_test", + "tests.unit.distributed.utils.networking_test", + ) + + UNPINNED: list[str] = [ + "tests.unit.module_a_test", + "tests.unit.module_b_test", + "tests.unit.module_c_test", + "tests.unit.module_d_test", + "tests.unit.module_e_test", + ] + + def test_pinned_modules_assigned_by_position(self) -> None: + """Pinned module at index i is assigned to shard i % total_shards.""" + total_shards = 4 + for index, module_name in enumerate(self.PINNED): + expected_shard = index % total_shards + actual_shard = _get_shard_for_module(module_name, total_shards, self.PINNED) + self.assertEqual( + actual_shard, + expected_shard, + f"Pinned module {module_name} (index {index}) expected shard " + f"{expected_shard}, got {actual_shard}", + ) + + def test_pinned_modules_use_all_shards_with_four_shards(self) -> None: + """With 5 pinned modules and 4 shards, every shard gets at least one.""" + total_shards = 4 + assigned_shards = { + _get_shard_for_module(m, total_shards, self.PINNED) for m in self.PINNED + } + self.assertEqual( + len(assigned_shards), + min(len(self.PINNED), total_shards), + f"Expected {min(len(self.PINNED), total_shards)} distinct shards, " + f"got {assigned_shards}", + ) + + def test_full_coverage_no_overlap_with_pinned_and_unpinned(self) -> None: + """All modules appear exactly once across all shards.""" + total_shards = 4 + all_modules = list(self.PINNED) + self.UNPINNED + + seen_modules: list[str] = [] + for shard_index in range(total_shards): + fresh_suite = _make_test_suite_with_modules(all_modules) + result = _filter_tests_by_shard( + fresh_suite, shard_index, total_shards, pinned_modules=self.PINNED + ) + for test_group in result: + assert isinstance(test_group, unittest.TestSuite) + for test_case in test_group: + module = type(test_case).__module__ + self.assertNotIn( + module, + seen_modules, + f"Module {module} appeared in multiple shards", + ) + seen_modules.append(module) + + self.assertEqual( + sorted(seen_modules), + sorted(all_modules), + "Not all modules were covered across shards", + ) + + def test_pinning_across_various_total_shards(self) -> None: + """Pinned modules land on expected shards for several shard counts.""" + for total_shards in (2, 3, 4, 5, 8): + for index, module_name in enumerate(self.PINNED): + expected_shard = index % total_shards + actual_shard = _get_shard_for_module( + module_name, total_shards, self.PINNED + ) + self.assertEqual( + actual_shard, + expected_shard, + f"total_shards={total_shards}: pinned module {module_name} " + f"(index {index}) expected shard {expected_shard}, got {actual_shard}", + ) + + def test_unpinned_modules_use_hash(self) -> None: + """Unpinned modules still use SHA-256 hashing, unaffected by pinned list.""" + import hashlib + + total_shards = 4 + for module_name in self.UNPINNED: + expected = ( + int(hashlib.sha256(module_name.encode()).hexdigest(), 16) % total_shards + ) + actual = _get_shard_for_module(module_name, total_shards, self.PINNED) + self.assertEqual( + actual, + expected, + f"Unpinned module {module_name} should use hash-based assignment", + ) + + def test_real_unit_test_pinned_modules_cover_all_shards(self) -> None: + """The actual UNIT_TEST_SHARD_PINNED_MODULES cover every shard with 4 shards.""" + total_shards = 4 + assigned_shards = { + _get_shard_for_module(m, total_shards, UNIT_TEST_SHARD_PINNED_MODULES) + for m in UNIT_TEST_SHARD_PINNED_MODULES + } + self.assertEqual( + assigned_shards, + set(range(total_shards)), + f"Expected all shards 0..{total_shards - 1} covered, got {assigned_shards}", + ) + + def test_real_integration_test_pinned_modules_cover_all_shards(self) -> None: + """The actual INTEGRATION_TEST_SHARD_PINNED_MODULES cover every shard with 4 shards.""" + total_shards = 4 + assigned_shards = { + _get_shard_for_module( + m, total_shards, INTEGRATION_TEST_SHARD_PINNED_MODULES + ) + for m in INTEGRATION_TEST_SHARD_PINNED_MODULES + } + self.assertEqual( + assigned_shards, + set(range(total_shards)), + f"Expected all shards 0..{total_shards - 1} covered, got {assigned_shards}", + ) + + +def _collect_test_ids(suite: unittest.TestSuite) -> set[str]: + """Recursively collects all individual test case IDs from a suite. + + Args: + suite: A (possibly nested) test suite. + + Returns: + Set of fully-qualified test IDs (e.g. ``module.Class.test_method``). + """ + ids: set[str] = set() + for item in suite: + if isinstance(item, unittest.TestSuite): + ids.update(_collect_test_ids(item)) + else: + ids.add(item.id()) + return ids + + +class RealDiscoveryShardingTest(TestCase): + """Discovers real unit tests and verifies sharding preserves them all.""" + + TOTAL_SHARDS: int = 4 + + @classmethod + def setUpClass(cls) -> None: + start_dir = LocalUri.join( + local_fs_constants.get_project_root_directory(), "tests", "unit" + ) + cls.start_dir = start_dir + full_suite = unittest.TestLoader().discover( + start_dir=start_dir.uri, pattern="*_test.py" + ) + cls.unsharded_test_ids = _collect_test_ids(full_suite) + + def test_sharded_test_count_equals_unsharded(self) -> None: + """Sum of test cases across all shards equals the unsharded total.""" + sharded_total = 0 + for shard_index in range(self.TOTAL_SHARDS): + suite = unittest.TestLoader().discover( + start_dir=self.start_dir.uri, pattern="*_test.py" + ) + filtered = _filter_tests_by_shard( + suite, + shard_index, + self.TOTAL_SHARDS, + UNIT_TEST_SHARD_PINNED_MODULES, + ) + sharded_total += filtered.countTestCases() + + self.assertEqual( + sharded_total, + len(self.unsharded_test_ids), + f"Sharded total ({sharded_total}) != unsharded total " + f"({len(self.unsharded_test_ids)})", + ) + + def test_sharded_tests_equal_unsharded(self) -> None: + """Union of test IDs across all shards equals the full unsharded set.""" + sharded_test_ids: set[str] = set() + for shard_index in range(self.TOTAL_SHARDS): + suite = unittest.TestLoader().discover( + start_dir=self.start_dir.uri, pattern="*_test.py" + ) + filtered = _filter_tests_by_shard( + suite, + shard_index, + self.TOTAL_SHARDS, + UNIT_TEST_SHARD_PINNED_MODULES, + ) + shard_ids = _collect_test_ids(filtered) + overlap = sharded_test_ids & shard_ids + self.assertEqual( + overlap, + set(), + f"Shard {shard_index} overlaps with previous shards: {overlap}", + ) + sharded_test_ids.update(shard_ids) + + self.assertEqual( + sharded_test_ids, + self.unsharded_test_ids, + f"Test ID mismatch.\n" + f" Only in sharded: {sharded_test_ids - self.unsharded_test_ids}\n" + f" Only in unsharded: {self.unsharded_test_ids - sharded_test_ids}", + ) diff --git a/tests/unit/main.py b/tests/unit/main.py index b22a8a430..28809710c 100644 --- a/tests/unit/main.py +++ b/tests/unit/main.py @@ -1,9 +1,23 @@ import sys +from typing import Final import gigl.src.common.constants.local_fs as local_fs_constants from gigl.common import LocalUri from gigl.common.utils.test_utils import parse_args, run_tests +# Slow test modules that must be spread across shards. Position in the tuple +# determines the shard: ``index % total_shards``. **Append-only** — never +# reorder existing entries, or every module's shard assignment will shift. +# +# Durations measured 2026-02-27 (unsharded CI run, 61.7 min total): +UNIT_TEST_SHARD_PINNED_MODULES: Final[tuple[str, ...]] = ( + "tests.unit.distributed.dist_ablp_neighborloader_test", # 24.7 min (40.0%) + "tests.unit.distributed.distributed_dataset_test", # 10.7 min (17.4%) + "tests.unit.distributed.distributed_neighborloader_test", # 9.6 min (15.5%) + "tests.unit.distributed.distributed_partitioner_test", # 6.5 min (10.5%) + "tests.unit.distributed.utils.networking_test", # 2.7 min (4.4%) +) + def run( pattern: str = "*_test.py", @@ -18,6 +32,7 @@ def run( use_sequential_execution=True, shard_index=shard_index, total_shards=total_shards, + pinned_modules=UNIT_TEST_SHARD_PINNED_MODULES, ) From 97feb7d02d5595b2945fcf0a8d92cf10a023e2e6 Mon Sep 17 00:00:00 2001 From: kmontemayor Date: Fri, 27 Feb 2026 19:00:12 +0000 Subject: [PATCH 12/22] maybe ready --- .github/workflows/on-pr-comment.yml | 52 ++++++++++++----------------- 1 file changed, 21 insertions(+), 31 deletions(-) diff --git a/.github/workflows/on-pr-comment.yml b/.github/workflows/on-pr-comment.yml index b45df2092..d6b8d2040 100644 --- a/.github/workflows/on-pr-comment.yml +++ b/.github/workflows/on-pr-comment.yml @@ -3,7 +3,6 @@ name: On Demand Pr Comment Workflows on: issue_comment: types: [created] - push: # TODO(revert) permissions: # Needed for gcloud auth: https://github.com/google-github-actions/auth @@ -41,29 +40,25 @@ jobs: message: ${{ steps.parse_commands.outputs.help_message }} unit-test-python: - #if: ${{ github.event.issue.pull_request && (contains(github.event.comment.body, '/unit_test_py') || endsWith(github.event.comment.body, '/unit_test') || contains(github.event.comment.body, '/all_test')) }} + if: ${{ github.event.issue.pull_request && (contains(github.event.comment.body, '/unit_test_py') || endsWith(github.event.comment.body, '/unit_test') || contains(github.event.comment.body, '/all_test')) }} runs-on: ubuntu-latest timeout-minutes: 60 steps: - - name: Checkout code - uses: actions/checkout@v4 - with: - ref: ${{ github.sha }} - - name: Setup development environment - uses: ./.github/actions/setup-python-tools + - name: Run Python Unit Tests (sharded) + uses: snapchat/gigl/.github/actions/run-command-on-pr@main with: + github-token: ${{ secrets.GITHUB_TOKEN }} + pr_number: ${{ github.event.issue.number }} + should_leave_progress_comments: "true" + descriptive_workflow_name: "Python Unit Test" setup_gcloud: "true" + use_sharded_cloud_build: "true" + sharded_test_command: "unit_test_py_shard" + total_shards: "4" + run_type_check: "true" gcp_project_id: ${{ vars.GCP_PROJECT_ID }} workload_identity_provider: ${{ secrets.WORKLOAD_IDENTITY_PROVIDER }} gcp_service_account_email: ${{ secrets.GCP_SERVICE_ACCOUNT_EMAIL }} - - name: Run Python Unit Tests (sharded) - uses: ./.github/actions/run-sharded-cloud-build - with: - test_command: "unit_test_py_shard" - total_shards: "4" - run_type_check: "true" - service_account: ${{ secrets.GCP_SERVICE_ACCOUNT_EMAIL }} - project: ${{ vars.GCP_PROJECT_ID }} unit-test-scala: if: ${{ github.event.issue.pull_request && (contains(github.event.comment.body, '/unit_test_scala') || endsWith(github.event.comment.body, '/unit_test') || contains(github.event.comment.body, '/all_test')) }} @@ -86,29 +81,24 @@ jobs: make unit_test_scala integration-test: - #if: ${{ github.event.issue.pull_request && (contains(github.event.comment.body, '/integration_test') || contains(github.event.comment.body, '/all_test')) }} + if: ${{ github.event.issue.pull_request && (contains(github.event.comment.body, '/integration_test') || contains(github.event.comment.body, '/all_test')) }} runs-on: ubuntu-latest timeout-minutes: 60 steps: - - name: Checkout code - uses: actions/checkout@v4 - with: - ref: ${{ github.sha }} - - name: Setup development environment - uses: ./.github/actions/setup-python-tools + - name: Run Integration Tests (sharded) + uses: snapchat/gigl/.github/actions/run-command-on-pr@main with: + github-token: ${{ secrets.GITHUB_TOKEN }} + pr_number: ${{ github.event.issue.number }} + should_leave_progress_comments: "true" + descriptive_workflow_name: "Integration Test" setup_gcloud: "true" + use_sharded_cloud_build: "true" + sharded_test_command: "integration_test_shard" + total_shards: "4" gcp_project_id: ${{ vars.GCP_PROJECT_ID }} workload_identity_provider: ${{ secrets.WORKLOAD_IDENTITY_PROVIDER }} gcp_service_account_email: ${{ secrets.GCP_SERVICE_ACCOUNT_EMAIL }} - - name: Run Integration Tests (sharded) - uses: ./.github/actions/run-sharded-cloud-build - with: - test_command: "integration_test_shard" - total_shards: "4" - run_type_check: "false" - service_account: ${{ secrets.GCP_SERVICE_ACCOUNT_EMAIL }} - project: ${{ vars.GCP_PROJECT_ID }} integration-e2e-test: if: ${{ github.event.issue.pull_request && (contains(github.event.comment.body, '/e2e_test') || contains(github.event.comment.body, '/all_test')) }} From 3eb6ea753dd34892b5cd530e8d99fcbd3d02be35 Mon Sep 17 00:00:00 2001 From: kmontemayor Date: Fri, 27 Feb 2026 23:35:15 +0000 Subject: [PATCH 13/22] update --- .github/workflows/on-pr-comment.yml | 8 +++ tests/unit/common/utils/test_sharding_test.py | 72 ++++++++----------- 2 files changed, 38 insertions(+), 42 deletions(-) diff --git a/.github/workflows/on-pr-comment.yml b/.github/workflows/on-pr-comment.yml index d6b8d2040..f1ba61a41 100644 --- a/.github/workflows/on-pr-comment.yml +++ b/.github/workflows/on-pr-comment.yml @@ -52,6 +52,10 @@ jobs: should_leave_progress_comments: "true" descriptive_workflow_name: "Python Unit Test" setup_gcloud: "true" + # We use cloud run here instead of using github hosted runners because of limitation of tests + # using GFile library (a.k.a anything that does IO w/ Tensorflow). GFile does not understand + # how to leverage Workload Identity Federation to read assets from GCS, et al. See: + # https://github.com/tensorflow/tensorflow/issues/57104 use_sharded_cloud_build: "true" sharded_test_command: "unit_test_py_shard" total_shards: "4" @@ -93,6 +97,10 @@ jobs: should_leave_progress_comments: "true" descriptive_workflow_name: "Integration Test" setup_gcloud: "true" + # We use cloud run here instead of using github hosted runners because of limitation of tests + # using GFile library (a.k.a anything that does IO w/ Tensorflow). GFile does not understand + # how to leverage Workload Identity Federation to read assets from GCS, et al. See: + # https://github.com/tensorflow/tensorflow/issues/57104 use_sharded_cloud_build: "true" sharded_test_command: "integration_test_shard" total_shards: "4" diff --git a/tests/unit/common/utils/test_sharding_test.py b/tests/unit/common/utils/test_sharding_test.py index 4ea87c42a..5b741e4eb 100644 --- a/tests/unit/common/utils/test_sharding_test.py +++ b/tests/unit/common/utils/test_sharding_test.py @@ -1,3 +1,4 @@ +import hashlib import unittest import gigl.src.common.constants.local_fs as local_fs_constants @@ -8,6 +9,27 @@ from tests.unit.main import UNIT_TEST_SHARD_PINNED_MODULES +def _extract_module_names(suite: unittest.TestSuite) -> list[str]: + """Extracts module names from a filtered test suite, preserving order. + + Assumes the suite has the two-level nesting produced by + ``_make_test_suite_with_modules``: outer suite → inner TestSuite per + module → individual TestCase(s). + + Args: + suite: A filtered test suite. + + Returns: + Ordered list of module name strings found in the suite. + """ + return [ + type(test_case).__module__ + for test_group in suite + if isinstance(test_group, unittest.TestSuite) + for test_case in test_group + ] + + def _make_test_suite_with_modules(module_names: list[str]) -> unittest.TestSuite: """Creates a test suite where each top-level group simulates a different module. @@ -78,7 +100,7 @@ def test_all_tests_covered_across_shards(self) -> None: def test_no_overlap_between_shards(self) -> None: """Each module must appear in exactly one shard.""" total_shards = 4 - seen_modules: list[str] = [] + seen_modules: set[str] = set() for shard_index in range(total_shards): suite = _make_test_suite_with_modules(self.MODULES) result = _filter_tests_by_shard(suite, shard_index, total_shards) @@ -91,7 +113,7 @@ def test_no_overlap_between_shards(self) -> None: seen_modules, f"Module {module} appeared in multiple shards", ) - seen_modules.append(module) + seen_modules.add(module) def test_deterministic_assignment(self) -> None: """Running the same shard twice must produce identical results.""" @@ -99,21 +121,11 @@ def test_deterministic_assignment(self) -> None: shard_index = 1 suite1 = _make_test_suite_with_modules(self.MODULES) result1 = _filter_tests_by_shard(suite1, shard_index, total_shards) - modules1 = [ - type(tc).__module__ - for tg in result1 - if isinstance(tg, unittest.TestSuite) - for tc in tg - ] + modules1 = _extract_module_names(result1) suite2 = _make_test_suite_with_modules(self.MODULES) result2 = _filter_tests_by_shard(suite2, shard_index, total_shards) - modules2 = [ - type(tc).__module__ - for tg in result2 - if isinstance(tg, unittest.TestSuite) - for tc in tg - ] + modules2 = _extract_module_names(result2) self.assertEqual(modules1, modules2) @@ -180,7 +192,7 @@ def test_full_coverage_no_overlap_with_pinned_and_unpinned(self) -> None: total_shards = 4 all_modules = list(self.PINNED) + self.UNPINNED - seen_modules: list[str] = [] + seen_modules: set[str] = set() for shard_index in range(total_shards): fresh_suite = _make_test_suite_with_modules(all_modules) result = _filter_tests_by_shard( @@ -195,11 +207,11 @@ def test_full_coverage_no_overlap_with_pinned_and_unpinned(self) -> None: seen_modules, f"Module {module} appeared in multiple shards", ) - seen_modules.append(module) + seen_modules.add(module) self.assertEqual( - sorted(seen_modules), - sorted(all_modules), + seen_modules, + set(all_modules), "Not all modules were covered across shards", ) @@ -220,8 +232,6 @@ def test_pinning_across_various_total_shards(self) -> None: def test_unpinned_modules_use_hash(self) -> None: """Unpinned modules still use SHA-256 hashing, unaffected by pinned list.""" - import hashlib - total_shards = 4 for module_name in self.UNPINNED: expected = ( @@ -297,28 +307,6 @@ def setUpClass(cls) -> None: ) cls.unsharded_test_ids = _collect_test_ids(full_suite) - def test_sharded_test_count_equals_unsharded(self) -> None: - """Sum of test cases across all shards equals the unsharded total.""" - sharded_total = 0 - for shard_index in range(self.TOTAL_SHARDS): - suite = unittest.TestLoader().discover( - start_dir=self.start_dir.uri, pattern="*_test.py" - ) - filtered = _filter_tests_by_shard( - suite, - shard_index, - self.TOTAL_SHARDS, - UNIT_TEST_SHARD_PINNED_MODULES, - ) - sharded_total += filtered.countTestCases() - - self.assertEqual( - sharded_total, - len(self.unsharded_test_ids), - f"Sharded total ({sharded_total}) != unsharded total " - f"({len(self.unsharded_test_ids)})", - ) - def test_sharded_tests_equal_unsharded(self) -> None: """Union of test IDs across all shards equals the full unsharded set.""" sharded_test_ids: set[str] = set() From bd57293ada674f789b96442323c7303c449f97c4 Mon Sep 17 00:00:00 2001 From: kmontemayor Date: Thu, 5 Mar 2026 18:42:10 +0000 Subject: [PATCH 14/22] update to matrix --- .github/workflows/on-pr-comment.yml | 50 +++++++++++++++++++++-------- Makefile | 3 -- 2 files changed, 37 insertions(+), 16 deletions(-) diff --git a/.github/workflows/on-pr-comment.yml b/.github/workflows/on-pr-comment.yml index f1ba61a41..1380054a2 100644 --- a/.github/workflows/on-pr-comment.yml +++ b/.github/workflows/on-pr-comment.yml @@ -3,6 +3,7 @@ name: On Demand Pr Comment Workflows on: issue_comment: types: [created] + push: permissions: # Needed for gcloud auth: https://github.com/google-github-actions/auth @@ -43,26 +44,38 @@ jobs: if: ${{ github.event.issue.pull_request && (contains(github.event.comment.body, '/unit_test_py') || endsWith(github.event.comment.body, '/unit_test') || contains(github.event.comment.body, '/all_test')) }} runs-on: ubuntu-latest timeout-minutes: 60 + strategy: + fail-fast: false + matrix: + include: + - name: "Type Check" + command: "make type_check" + - name: "Shard 0" + command: "make unit_test_py_shard SHARD_INDEX=0 TOTAL_SHARDS=4" + - name: "Shard 1" + command: "make unit_test_py_shard SHARD_INDEX=1 TOTAL_SHARDS=4" + - name: "Shard 2" + command: "make unit_test_py_shard SHARD_INDEX=2 TOTAL_SHARDS=4" + - name: "Shard 3" + command: "make unit_test_py_shard SHARD_INDEX=3 TOTAL_SHARDS=4" steps: - - name: Run Python Unit Tests (sharded) - uses: snapchat/gigl/.github/actions/run-command-on-pr@main + - name: Run Python Unit Tests (${{ matrix.name }}) + uses: ./.github/actions/run-command-on-pr with: github-token: ${{ secrets.GITHUB_TOKEN }} pr_number: ${{ github.event.issue.number }} should_leave_progress_comments: "true" - descriptive_workflow_name: "Python Unit Test" + descriptive_workflow_name: "Python Unit Test (${{ matrix.name }})" setup_gcloud: "true" # We use cloud run here instead of using github hosted runners because of limitation of tests # using GFile library (a.k.a anything that does IO w/ Tensorflow). GFile does not understand # how to leverage Workload Identity Federation to read assets from GCS, et al. See: # https://github.com/tensorflow/tensorflow/issues/57104 - use_sharded_cloud_build: "true" - sharded_test_command: "unit_test_py_shard" - total_shards: "4" - run_type_check: "true" + use_cloud_run: "true" gcp_project_id: ${{ vars.GCP_PROJECT_ID }} workload_identity_provider: ${{ secrets.WORKLOAD_IDENTITY_PROVIDER }} gcp_service_account_email: ${{ secrets.GCP_SERVICE_ACCOUNT_EMAIL }} + command: ${{ matrix.command }} unit-test-scala: if: ${{ github.event.issue.pull_request && (contains(github.event.comment.body, '/unit_test_scala') || endsWith(github.event.comment.body, '/unit_test') || contains(github.event.comment.body, '/all_test')) }} @@ -88,25 +101,36 @@ jobs: if: ${{ github.event.issue.pull_request && (contains(github.event.comment.body, '/integration_test') || contains(github.event.comment.body, '/all_test')) }} runs-on: ubuntu-latest timeout-minutes: 60 + strategy: + fail-fast: false + matrix: + include: + - name: "Shard 0" + command: "make integration_test_shard SHARD_INDEX=0 TOTAL_SHARDS=4" + - name: "Shard 1" + command: "make integration_test_shard SHARD_INDEX=1 TOTAL_SHARDS=4" + - name: "Shard 2" + command: "make integration_test_shard SHARD_INDEX=2 TOTAL_SHARDS=4" + - name: "Shard 3" + command: "make integration_test_shard SHARD_INDEX=3 TOTAL_SHARDS=4" steps: - - name: Run Integration Tests (sharded) - uses: snapchat/gigl/.github/actions/run-command-on-pr@main + - name: Run Integration Tests (${{ matrix.name }}) + uses: ./.github/actions/run-command-on-pr with: github-token: ${{ secrets.GITHUB_TOKEN }} pr_number: ${{ github.event.issue.number }} should_leave_progress_comments: "true" - descriptive_workflow_name: "Integration Test" + descriptive_workflow_name: "Integration Test (${{ matrix.name }})" setup_gcloud: "true" # We use cloud run here instead of using github hosted runners because of limitation of tests # using GFile library (a.k.a anything that does IO w/ Tensorflow). GFile does not understand # how to leverage Workload Identity Federation to read assets from GCS, et al. See: # https://github.com/tensorflow/tensorflow/issues/57104 - use_sharded_cloud_build: "true" - sharded_test_command: "integration_test_shard" - total_shards: "4" + use_cloud_run: "true" gcp_project_id: ${{ vars.GCP_PROJECT_ID }} workload_identity_provider: ${{ secrets.WORKLOAD_IDENTITY_PROVIDER }} gcp_service_account_email: ${{ secrets.GCP_SERVICE_ACCOUNT_EMAIL }} + command: ${{ matrix.command }} integration-e2e-test: if: ${{ github.event.issue.pull_request && (contains(github.event.comment.body, '/e2e_test') || contains(github.event.comment.body, '/all_test')) }} diff --git a/Makefile b/Makefile index 3eb13e25b..37c6b9706 100644 --- a/Makefile +++ b/Makefile @@ -83,9 +83,6 @@ unit_test_py: clean_build_files_py type_check --resource_config_uri=${GIGL_TEST_DEFAULT_RESOURCE_CONFIG} \ --test_file_pattern=$(PY_TEST_FILES) \ -# Runs only the type checker without tests. Used as a standalone Cloud Build shard job. -type_check_only: clean_build_files_py type_check - # Runs a single shard of the Python unit tests (no type checking). # Usage: make unit_test_py_shard SHARD_INDEX=0 TOTAL_SHARDS=4 unit_test_py_shard: clean_build_files_py From 8b4b29e8d154d77e68781577586d616d694ced02 Mon Sep 17 00:00:00 2001 From: kmontemayor Date: Thu, 5 Mar 2026 18:44:30 +0000 Subject: [PATCH 15/22] bump --- .github/workflows/on-pr-comment.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/on-pr-comment.yml b/.github/workflows/on-pr-comment.yml index 1380054a2..8f8c8a99f 100644 --- a/.github/workflows/on-pr-comment.yml +++ b/.github/workflows/on-pr-comment.yml @@ -3,7 +3,7 @@ name: On Demand Pr Comment Workflows on: issue_comment: types: [created] - push: + push: # TODO: revert permissions: # Needed for gcloud auth: https://github.com/google-github-actions/auth From 28e83a614576b5570eb56bfc686f10772a16f8c8 Mon Sep 17 00:00:00 2001 From: kmontemayor Date: Thu, 5 Mar 2026 18:45:00 +0000 Subject: [PATCH 16/22] enabel --- .github/workflows/on-pr-comment.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/on-pr-comment.yml b/.github/workflows/on-pr-comment.yml index 8f8c8a99f..d060e3d3b 100644 --- a/.github/workflows/on-pr-comment.yml +++ b/.github/workflows/on-pr-comment.yml @@ -41,7 +41,7 @@ jobs: message: ${{ steps.parse_commands.outputs.help_message }} unit-test-python: - if: ${{ github.event.issue.pull_request && (contains(github.event.comment.body, '/unit_test_py') || endsWith(github.event.comment.body, '/unit_test') || contains(github.event.comment.body, '/all_test')) }} + #if: ${{ github.event.issue.pull_request && (contains(github.event.comment.body, '/unit_test_py') || endsWith(github.event.comment.body, '/unit_test') || contains(github.event.comment.body, '/all_test')) }} runs-on: ubuntu-latest timeout-minutes: 60 strategy: @@ -98,7 +98,7 @@ jobs: make unit_test_scala integration-test: - if: ${{ github.event.issue.pull_request && (contains(github.event.comment.body, '/integration_test') || contains(github.event.comment.body, '/all_test')) }} + #if: ${{ github.event.issue.pull_request && (contains(github.event.comment.body, '/integration_test') || contains(github.event.comment.body, '/all_test')) }} runs-on: ubuntu-latest timeout-minutes: 60 strategy: From 9fc20962dfc983ce781309b06dad5a4b432e5046 Mon Sep 17 00:00:00 2001 From: kmontemayor Date: Thu, 5 Mar 2026 18:55:35 +0000 Subject: [PATCH 17/22] fix --- .github/workflows/on-pr-comment.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/on-pr-comment.yml b/.github/workflows/on-pr-comment.yml index d060e3d3b..14dcb8e99 100644 --- a/.github/workflows/on-pr-comment.yml +++ b/.github/workflows/on-pr-comment.yml @@ -60,7 +60,7 @@ jobs: command: "make unit_test_py_shard SHARD_INDEX=3 TOTAL_SHARDS=4" steps: - name: Run Python Unit Tests (${{ matrix.name }}) - uses: ./.github/actions/run-command-on-pr + uses: snapchat/gigl/.github/actions/run-command-on-pr@main with: github-token: ${{ secrets.GITHUB_TOKEN }} pr_number: ${{ github.event.issue.number }} @@ -115,7 +115,7 @@ jobs: command: "make integration_test_shard SHARD_INDEX=3 TOTAL_SHARDS=4" steps: - name: Run Integration Tests (${{ matrix.name }}) - uses: ./.github/actions/run-command-on-pr + uses: snapchat/gigl/.github/actions/run-command-on-pr@main with: github-token: ${{ secrets.GITHUB_TOKEN }} pr_number: ${{ github.event.issue.number }} From 96dd6eccaa592ca412d0ff566c589c29c7ac9f18 Mon Sep 17 00:00:00 2001 From: kmontemayor Date: Thu, 5 Mar 2026 18:57:52 +0000 Subject: [PATCH 18/22] update --- .github/workflows/on-pr-comment.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/on-pr-comment.yml b/.github/workflows/on-pr-comment.yml index 14dcb8e99..ad5737ea8 100644 --- a/.github/workflows/on-pr-comment.yml +++ b/.github/workflows/on-pr-comment.yml @@ -64,7 +64,7 @@ jobs: with: github-token: ${{ secrets.GITHUB_TOKEN }} pr_number: ${{ github.event.issue.number }} - should_leave_progress_comments: "true" + #should_leave_progress_comments: "true" descriptive_workflow_name: "Python Unit Test (${{ matrix.name }})" setup_gcloud: "true" # We use cloud run here instead of using github hosted runners because of limitation of tests @@ -119,7 +119,7 @@ jobs: with: github-token: ${{ secrets.GITHUB_TOKEN }} pr_number: ${{ github.event.issue.number }} - should_leave_progress_comments: "true" + # should_leave_progress_comments: "true" descriptive_workflow_name: "Integration Test (${{ matrix.name }})" setup_gcloud: "true" # We use cloud run here instead of using github hosted runners because of limitation of tests From d8b2f0f03bcfdecdcce3f668ad0b0c7a57707e2a Mon Sep 17 00:00:00 2001 From: kmontemayor Date: Thu, 5 Mar 2026 19:10:48 +0000 Subject: [PATCH 19/22] update --- .github/workflows/on-pr-comment.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/on-pr-comment.yml b/.github/workflows/on-pr-comment.yml index ad5737ea8..1e56746b4 100644 --- a/.github/workflows/on-pr-comment.yml +++ b/.github/workflows/on-pr-comment.yml @@ -63,7 +63,7 @@ jobs: uses: snapchat/gigl/.github/actions/run-command-on-pr@main with: github-token: ${{ secrets.GITHUB_TOKEN }} - pr_number: ${{ github.event.issue.number }} + pr_number: 520 #should_leave_progress_comments: "true" descriptive_workflow_name: "Python Unit Test (${{ matrix.name }})" setup_gcloud: "true" @@ -118,7 +118,7 @@ jobs: uses: snapchat/gigl/.github/actions/run-command-on-pr@main with: github-token: ${{ secrets.GITHUB_TOKEN }} - pr_number: ${{ github.event.issue.number }} + pr_number: 520 # should_leave_progress_comments: "true" descriptive_workflow_name: "Integration Test (${{ matrix.name }})" setup_gcloud: "true" From 99b18eef65b51583c9c5a5b7bb6c80f78b44266c Mon Sep 17 00:00:00 2001 From: kmontemayor Date: Thu, 5 Mar 2026 21:48:30 +0000 Subject: [PATCH 20/22] update --- .github/workflows/on-pr-comment.yml | 13 ++++---- .github/workflows/on-pr-merge.yml | 46 ++++++++++++++++++++++------- 2 files changed, 42 insertions(+), 17 deletions(-) diff --git a/.github/workflows/on-pr-comment.yml b/.github/workflows/on-pr-comment.yml index 1e56746b4..3b166d57d 100644 --- a/.github/workflows/on-pr-comment.yml +++ b/.github/workflows/on-pr-comment.yml @@ -3,7 +3,6 @@ name: On Demand Pr Comment Workflows on: issue_comment: types: [created] - push: # TODO: revert permissions: # Needed for gcloud auth: https://github.com/google-github-actions/auth @@ -41,7 +40,7 @@ jobs: message: ${{ steps.parse_commands.outputs.help_message }} unit-test-python: - #if: ${{ github.event.issue.pull_request && (contains(github.event.comment.body, '/unit_test_py') || endsWith(github.event.comment.body, '/unit_test') || contains(github.event.comment.body, '/all_test')) }} + if: ${{ github.event.issue.pull_request && (contains(github.event.comment.body, '/unit_test_py') || endsWith(github.event.comment.body, '/unit_test') || contains(github.event.comment.body, '/all_test')) }} runs-on: ubuntu-latest timeout-minutes: 60 strategy: @@ -63,8 +62,8 @@ jobs: uses: snapchat/gigl/.github/actions/run-command-on-pr@main with: github-token: ${{ secrets.GITHUB_TOKEN }} - pr_number: 520 - #should_leave_progress_comments: "true" + pr_number: ${{ github.event.issue.number }} + should_leave_progress_comments: "true" descriptive_workflow_name: "Python Unit Test (${{ matrix.name }})" setup_gcloud: "true" # We use cloud run here instead of using github hosted runners because of limitation of tests @@ -98,7 +97,7 @@ jobs: make unit_test_scala integration-test: - #if: ${{ github.event.issue.pull_request && (contains(github.event.comment.body, '/integration_test') || contains(github.event.comment.body, '/all_test')) }} + if: ${{ github.event.issue.pull_request && (contains(github.event.comment.body, '/integration_test') || contains(github.event.comment.body, '/all_test')) }} runs-on: ubuntu-latest timeout-minutes: 60 strategy: @@ -118,8 +117,8 @@ jobs: uses: snapchat/gigl/.github/actions/run-command-on-pr@main with: github-token: ${{ secrets.GITHUB_TOKEN }} - pr_number: 520 - # should_leave_progress_comments: "true" + pr_number: ${{ github.event.issue.number }} + should_leave_progress_comments: "true" descriptive_workflow_name: "Integration Test (${{ matrix.name }})" setup_gcloud: "true" # We use cloud run here instead of using github hosted runners because of limitation of tests diff --git a/.github/workflows/on-pr-merge.yml b/.github/workflows/on-pr-merge.yml index 02e7e75db..46037b4af 100644 --- a/.github/workflows/on-pr-merge.yml +++ b/.github/workflows/on-pr-merge.yml @@ -23,6 +23,20 @@ jobs: # Our tests take a long time to run, so this is not ideal. if: github.event_name == 'merge_group' runs-on: ubuntu-latest + strategy: + fail-fast: false + matrix: + include: + - name: "Type Check" + command: "make type_check" + - name: "Shard 0" + command: "make unit_test_py_shard SHARD_INDEX=0 TOTAL_SHARDS=4" + - name: "Shard 1" + command: "make unit_test_py_shard SHARD_INDEX=1 TOTAL_SHARDS=4" + - name: "Shard 2" + command: "make unit_test_py_shard SHARD_INDEX=2 TOTAL_SHARDS=4" + - name: "Shard 3" + command: "make unit_test_py_shard SHARD_INDEX=3 TOTAL_SHARDS=4" steps: - uses: actions/checkout@v4 - name: Setup development environment @@ -32,16 +46,14 @@ jobs: gcp_project_id: ${{ vars.GCP_PROJECT_ID }} workload_identity_provider: ${{ secrets.workload_identity_provider }} gcp_service_account_email: ${{ secrets.gcp_service_account_email }} - - name: Run Python Unit Tests (sharded) + - name: Run Python Unit Tests (${{ matrix.name }}) # We use Cloud Build instead of GitHub hosted runners because of limitation of tests # using GFile library (a.k.a anything that does IO w/ Tensorflow). GFile does not understand # how to leverage Workload Identity Federation to read assets from GCS, et al. See: # https://github.com/tensorflow/tensorflow/issues/57104 - uses: ./.github/actions/run-sharded-cloud-build + uses: ./.github/actions/run-cloud-run-command-on-active-checkout with: - test_command: "unit_test_py_shard" - total_shards: "4" - run_type_check: "true" + cmd: ${{ matrix.command }} service_account: ${{ secrets.gcp_service_account_email }} project: ${{ vars.GCP_PROJECT_ID }} @@ -75,6 +87,18 @@ jobs: ci-integration-test: if: github.event_name == 'merge_group' runs-on: ubuntu-latest + strategy: + fail-fast: false + matrix: + include: + - name: "Shard 0" + command: "make integration_test_shard SHARD_INDEX=0 TOTAL_SHARDS=4" + - name: "Shard 1" + command: "make integration_test_shard SHARD_INDEX=1 TOTAL_SHARDS=4" + - name: "Shard 2" + command: "make integration_test_shard SHARD_INDEX=2 TOTAL_SHARDS=4" + - name: "Shard 3" + command: "make integration_test_shard SHARD_INDEX=3 TOTAL_SHARDS=4" steps: - uses: actions/checkout@v4 - name: Setup development environment @@ -84,12 +108,14 @@ jobs: gcp_project_id: ${{ vars.GCP_PROJECT_ID }} workload_identity_provider: ${{ secrets.workload_identity_provider }} gcp_service_account_email: ${{ secrets.gcp_service_account_email }} - - name: Run Integration Tests (sharded) - uses: ./.github/actions/run-sharded-cloud-build + - name: Run Integration Tests (${{ matrix.name }}) + # We use Cloud Build instead of GitHub hosted runners because of limitation of tests + # using GFile library (a.k.a anything that does IO w/ Tensorflow). GFile does not understand + # how to leverage Workload Identity Federation to read assets from GCS, et al. See: + # https://github.com/tensorflow/tensorflow/issues/57104 + uses: ./.github/actions/run-cloud-run-command-on-active-checkout with: - test_command: "integration_test_shard" - total_shards: "4" - run_type_check: "false" + cmd: ${{ matrix.command }} service_account: ${{ secrets.gcp_service_account_email }} project: ${{ vars.GCP_PROJECT_ID }} From 304b8e402d8fa9f9eab8c46ac2cc493b3afd17eb Mon Sep 17 00:00:00 2001 From: kmontemayor Date: Thu, 5 Mar 2026 21:50:17 +0000 Subject: [PATCH 21/22] revert --- .github/actions/run-command-on-pr/action.yml | 30 +--- .../run-sharded-cloud-build/action.yml | 148 ------------------ 2 files changed, 1 insertion(+), 177 deletions(-) delete mode 100644 .github/actions/run-sharded-cloud-build/action.yml diff --git a/.github/actions/run-command-on-pr/action.yml b/.github/actions/run-command-on-pr/action.yml index 5f8781bb0..c20bdb4d7 100644 --- a/.github/actions/run-command-on-pr/action.yml +++ b/.github/actions/run-command-on-pr/action.yml @@ -48,24 +48,6 @@ inputs: required: false default: "" - # Parameters for sharded Cloud Build execution - use_sharded_cloud_build: - description: "Whether to use sharded Cloud Build instead of a single Cloud Build job" - required: false - default: "false" - total_shards: - description: "Number of parallel test shards (only used when use_sharded_cloud_build is true)" - required: false - default: "4" - sharded_test_command: - description: "The Make target to run for each shard (e.g. 'unit_test_py_shard')" - required: false - default: "" - run_type_check: - description: "Whether to launch a separate type-check Cloud Build job (only used when use_sharded_cloud_build is true)" - required: false - default: "false" - runs: using: "composite" steps: @@ -110,23 +92,13 @@ runs: run: ${{ inputs.command }} - name: Run specified command on the PR branch using Cloud Run - if: ${{ inputs.use_cloud_run == 'true' && inputs.use_sharded_cloud_build != 'true' }} + if: ${{ inputs.use_cloud_run == 'true' }} uses: snapchat/gigl/.github/actions/run-cloud-run-command-on-active-checkout@main with: cmd: ${{ inputs.command }} service_account: ${{ inputs.gcp_service_account_email }} project: ${{ inputs.gcp_project_id }} - - name: Run sharded Cloud Build tests - if: ${{ inputs.use_sharded_cloud_build == 'true' }} - uses: ./.github/actions/run-sharded-cloud-build - with: - test_command: ${{ inputs.sharded_test_command }} - total_shards: ${{ inputs.total_shards }} - run_type_check: ${{ inputs.run_type_check }} - service_account: ${{ inputs.gcp_service_account_email }} - project: ${{ inputs.gcp_project_id }} - - name: Commment workflow succeeded if: ${{ inputs.should_leave_progress_comments == 'true' }} uses: snapchat/gigl/.github/actions/comment-on-pr@main diff --git a/.github/actions/run-sharded-cloud-build/action.yml b/.github/actions/run-sharded-cloud-build/action.yml deleted file mode 100644 index c8a70422a..000000000 --- a/.github/actions/run-sharded-cloud-build/action.yml +++ /dev/null @@ -1,148 +0,0 @@ -name: "Run Sharded Cloud Build" -description: "Launches N Cloud Build shard jobs (and optionally a type-check job) in parallel, then waits for all to complete." -inputs: - test_command: - description: "The Make target to run for each shard (e.g. 'unit_test_py_shard')" - required: true - total_shards: - description: "Number of parallel test shards" - required: true - run_type_check: - description: "Whether to launch a separate type-check Cloud Build job" - required: false - default: "false" - service_account: - description: "Service account email for Cloud Build" - required: true - project: - description: "Google Cloud Project ID" - required: true - machine_type: - description: "Machine type for Cloud Build jobs" - default: "e2-highcpu-32" - timeout: - description: "Timeout for each Cloud Build job (duration format, e.g. '3h')" - default: "3h" - -runs: - using: "composite" - steps: - - name: Launch sharded Cloud Build jobs - id: launch_builds - shell: bash - run: | - set -euo pipefail - - BUILD_IDS="" - BUILD_LABELS="" - TOTAL_SHARDS=${{ inputs.total_shards }} - PROJECT="${{ inputs.project }}" - CLOUD_BUILD_URL_BASE="https://console.cloud.google.com/cloud-build/builds" - - # Launch type-check job if enabled - if [[ "${{ inputs.run_type_check }}" == "true" ]]; then - echo "Launching type-check Cloud Build job..." - TYPE_CHECK_BUILD_ID=$(gcloud builds submit . \ - --config=.github/cloud_builder/run_command_on_active_checkout.yaml \ - --substitutions=_CMD="make type_check_only" \ - --service-account="projects/${PROJECT}/serviceAccounts/${{ inputs.service_account }}" \ - --project="${PROJECT}" \ - --machine-type="${{ inputs.machine_type }}" \ - --timeout="${{ inputs.timeout }}" \ - --async \ - --format='value(id)' 2>&1 | tail -1) - BUILD_IDS="${TYPE_CHECK_BUILD_ID}" - BUILD_LABELS="type_check" - echo " Type-check build: ${CLOUD_BUILD_URL_BASE}/${TYPE_CHECK_BUILD_ID}?project=${PROJECT}" - fi - - # Launch shard jobs - for i in $(seq 0 $(( TOTAL_SHARDS - 1 ))); do - echo "Launching shard ${i}/${TOTAL_SHARDS} Cloud Build job..." - SHARD_BUILD_ID=$(gcloud builds submit . \ - --config=.github/cloud_builder/run_command_on_active_checkout.yaml \ - --substitutions=_CMD="make ${{ inputs.test_command }} SHARD_INDEX=${i} TOTAL_SHARDS=${TOTAL_SHARDS}" \ - --service-account="projects/${PROJECT}/serviceAccounts/${{ inputs.service_account }}" \ - --project="${PROJECT}" \ - --machine-type="${{ inputs.machine_type }}" \ - --timeout="${{ inputs.timeout }}" \ - --async \ - --format='value(id)' 2>&1 | tail -1) - if [[ -n "${BUILD_IDS}" ]]; then - BUILD_IDS="${BUILD_IDS},${SHARD_BUILD_ID}" - BUILD_LABELS="${BUILD_LABELS},shard_${i}" - else - BUILD_IDS="${SHARD_BUILD_ID}" - BUILD_LABELS="shard_${i}" - fi - echo " Shard ${i} build: ${CLOUD_BUILD_URL_BASE}/${SHARD_BUILD_ID}?project=${PROJECT}" - done - - echo "build_ids=${BUILD_IDS}" >> $GITHUB_OUTPUT - echo "build_labels=${BUILD_LABELS}" >> $GITHUB_OUTPUT - echo "All Cloud Build jobs launched." - - - name: Wait for all Cloud Build jobs to complete - shell: bash - run: | - set -euo pipefail - - IFS=',' read -ra BUILD_IDS <<< "${{ steps.launch_builds.outputs.build_ids }}" - IFS=',' read -ra BUILD_LABELS <<< "${{ steps.launch_builds.outputs.build_labels }}" - - PIDS=() - TEMP_DIR=$(mktemp -d) - - # Stream logs for each build in background and record exit status - for idx in "${!BUILD_IDS[@]}"; do - BUILD_ID="${BUILD_IDS[$idx]}" - LABEL="${BUILD_LABELS[$idx]}" - ( - echo "=== Streaming logs for ${LABEL} (${BUILD_ID}) ===" - gcloud builds log --stream "${BUILD_ID}" --project="${{ inputs.project }}" 2>&1 | \ - sed "s/^/[${LABEL}] /" || true - - # Check final build status - STATUS=$(gcloud builds describe "${BUILD_ID}" \ - --project="${{ inputs.project }}" \ - --format='value(status)') - echo "=== ${LABEL} (${BUILD_ID}) finished with status: ${STATUS} ===" - if [[ "${STATUS}" != "SUCCESS" ]]; then - echo "FAILED" > "${TEMP_DIR}/${LABEL}" - else - echo "SUCCESS" > "${TEMP_DIR}/${LABEL}" - fi - ) & - PIDS+=($!) - done - - # Wait for all background processes - for PID in "${PIDS[@]}"; do - wait "${PID}" || true - done - - # Check results - FAILED=0 - PROJECT="${{ inputs.project }}" - CLOUD_BUILD_URL_BASE="https://console.cloud.google.com/cloud-build/builds" - for idx in "${!BUILD_LABELS[@]}"; do - LABEL="${BUILD_LABELS[$idx]}" - BUILD_ID="${BUILD_IDS[$idx]}" - BUILD_URL="${CLOUD_BUILD_URL_BASE}/${BUILD_ID}?project=${PROJECT}" - RESULT=$(cat "${TEMP_DIR}/${LABEL}" 2>/dev/null || echo "UNKNOWN") - if [[ "${RESULT}" != "SUCCESS" ]]; then - echo "FAILED: ${LABEL} - ${BUILD_URL}" - FAILED=1 - else - echo "PASSED: ${LABEL} - ${BUILD_URL}" - fi - done - - rm -rf "${TEMP_DIR}" - - if [[ "${FAILED}" -eq 1 ]]; then - echo "One or more Cloud Build jobs failed." - exit 1 - fi - - echo "All Cloud Build jobs completed successfully." From 31fc53c30d541f4ba4541f3b265d0724c693fa6a Mon Sep 17 00:00:00 2001 From: kmontemayor Date: Fri, 6 Mar 2026 05:51:59 +0000 Subject: [PATCH 22/22] updatte --- .github/actions/run-command-on-pr/action.yml | 30 +--- .../run-sharded-cloud-build/action.yml | 148 ------------------ 2 files changed, 1 insertion(+), 177 deletions(-) delete mode 100644 .github/actions/run-sharded-cloud-build/action.yml diff --git a/.github/actions/run-command-on-pr/action.yml b/.github/actions/run-command-on-pr/action.yml index 5f8781bb0..c20bdb4d7 100644 --- a/.github/actions/run-command-on-pr/action.yml +++ b/.github/actions/run-command-on-pr/action.yml @@ -48,24 +48,6 @@ inputs: required: false default: "" - # Parameters for sharded Cloud Build execution - use_sharded_cloud_build: - description: "Whether to use sharded Cloud Build instead of a single Cloud Build job" - required: false - default: "false" - total_shards: - description: "Number of parallel test shards (only used when use_sharded_cloud_build is true)" - required: false - default: "4" - sharded_test_command: - description: "The Make target to run for each shard (e.g. 'unit_test_py_shard')" - required: false - default: "" - run_type_check: - description: "Whether to launch a separate type-check Cloud Build job (only used when use_sharded_cloud_build is true)" - required: false - default: "false" - runs: using: "composite" steps: @@ -110,23 +92,13 @@ runs: run: ${{ inputs.command }} - name: Run specified command on the PR branch using Cloud Run - if: ${{ inputs.use_cloud_run == 'true' && inputs.use_sharded_cloud_build != 'true' }} + if: ${{ inputs.use_cloud_run == 'true' }} uses: snapchat/gigl/.github/actions/run-cloud-run-command-on-active-checkout@main with: cmd: ${{ inputs.command }} service_account: ${{ inputs.gcp_service_account_email }} project: ${{ inputs.gcp_project_id }} - - name: Run sharded Cloud Build tests - if: ${{ inputs.use_sharded_cloud_build == 'true' }} - uses: ./.github/actions/run-sharded-cloud-build - with: - test_command: ${{ inputs.sharded_test_command }} - total_shards: ${{ inputs.total_shards }} - run_type_check: ${{ inputs.run_type_check }} - service_account: ${{ inputs.gcp_service_account_email }} - project: ${{ inputs.gcp_project_id }} - - name: Commment workflow succeeded if: ${{ inputs.should_leave_progress_comments == 'true' }} uses: snapchat/gigl/.github/actions/comment-on-pr@main diff --git a/.github/actions/run-sharded-cloud-build/action.yml b/.github/actions/run-sharded-cloud-build/action.yml deleted file mode 100644 index c8a70422a..000000000 --- a/.github/actions/run-sharded-cloud-build/action.yml +++ /dev/null @@ -1,148 +0,0 @@ -name: "Run Sharded Cloud Build" -description: "Launches N Cloud Build shard jobs (and optionally a type-check job) in parallel, then waits for all to complete." -inputs: - test_command: - description: "The Make target to run for each shard (e.g. 'unit_test_py_shard')" - required: true - total_shards: - description: "Number of parallel test shards" - required: true - run_type_check: - description: "Whether to launch a separate type-check Cloud Build job" - required: false - default: "false" - service_account: - description: "Service account email for Cloud Build" - required: true - project: - description: "Google Cloud Project ID" - required: true - machine_type: - description: "Machine type for Cloud Build jobs" - default: "e2-highcpu-32" - timeout: - description: "Timeout for each Cloud Build job (duration format, e.g. '3h')" - default: "3h" - -runs: - using: "composite" - steps: - - name: Launch sharded Cloud Build jobs - id: launch_builds - shell: bash - run: | - set -euo pipefail - - BUILD_IDS="" - BUILD_LABELS="" - TOTAL_SHARDS=${{ inputs.total_shards }} - PROJECT="${{ inputs.project }}" - CLOUD_BUILD_URL_BASE="https://console.cloud.google.com/cloud-build/builds" - - # Launch type-check job if enabled - if [[ "${{ inputs.run_type_check }}" == "true" ]]; then - echo "Launching type-check Cloud Build job..." - TYPE_CHECK_BUILD_ID=$(gcloud builds submit . \ - --config=.github/cloud_builder/run_command_on_active_checkout.yaml \ - --substitutions=_CMD="make type_check_only" \ - --service-account="projects/${PROJECT}/serviceAccounts/${{ inputs.service_account }}" \ - --project="${PROJECT}" \ - --machine-type="${{ inputs.machine_type }}" \ - --timeout="${{ inputs.timeout }}" \ - --async \ - --format='value(id)' 2>&1 | tail -1) - BUILD_IDS="${TYPE_CHECK_BUILD_ID}" - BUILD_LABELS="type_check" - echo " Type-check build: ${CLOUD_BUILD_URL_BASE}/${TYPE_CHECK_BUILD_ID}?project=${PROJECT}" - fi - - # Launch shard jobs - for i in $(seq 0 $(( TOTAL_SHARDS - 1 ))); do - echo "Launching shard ${i}/${TOTAL_SHARDS} Cloud Build job..." - SHARD_BUILD_ID=$(gcloud builds submit . \ - --config=.github/cloud_builder/run_command_on_active_checkout.yaml \ - --substitutions=_CMD="make ${{ inputs.test_command }} SHARD_INDEX=${i} TOTAL_SHARDS=${TOTAL_SHARDS}" \ - --service-account="projects/${PROJECT}/serviceAccounts/${{ inputs.service_account }}" \ - --project="${PROJECT}" \ - --machine-type="${{ inputs.machine_type }}" \ - --timeout="${{ inputs.timeout }}" \ - --async \ - --format='value(id)' 2>&1 | tail -1) - if [[ -n "${BUILD_IDS}" ]]; then - BUILD_IDS="${BUILD_IDS},${SHARD_BUILD_ID}" - BUILD_LABELS="${BUILD_LABELS},shard_${i}" - else - BUILD_IDS="${SHARD_BUILD_ID}" - BUILD_LABELS="shard_${i}" - fi - echo " Shard ${i} build: ${CLOUD_BUILD_URL_BASE}/${SHARD_BUILD_ID}?project=${PROJECT}" - done - - echo "build_ids=${BUILD_IDS}" >> $GITHUB_OUTPUT - echo "build_labels=${BUILD_LABELS}" >> $GITHUB_OUTPUT - echo "All Cloud Build jobs launched." - - - name: Wait for all Cloud Build jobs to complete - shell: bash - run: | - set -euo pipefail - - IFS=',' read -ra BUILD_IDS <<< "${{ steps.launch_builds.outputs.build_ids }}" - IFS=',' read -ra BUILD_LABELS <<< "${{ steps.launch_builds.outputs.build_labels }}" - - PIDS=() - TEMP_DIR=$(mktemp -d) - - # Stream logs for each build in background and record exit status - for idx in "${!BUILD_IDS[@]}"; do - BUILD_ID="${BUILD_IDS[$idx]}" - LABEL="${BUILD_LABELS[$idx]}" - ( - echo "=== Streaming logs for ${LABEL} (${BUILD_ID}) ===" - gcloud builds log --stream "${BUILD_ID}" --project="${{ inputs.project }}" 2>&1 | \ - sed "s/^/[${LABEL}] /" || true - - # Check final build status - STATUS=$(gcloud builds describe "${BUILD_ID}" \ - --project="${{ inputs.project }}" \ - --format='value(status)') - echo "=== ${LABEL} (${BUILD_ID}) finished with status: ${STATUS} ===" - if [[ "${STATUS}" != "SUCCESS" ]]; then - echo "FAILED" > "${TEMP_DIR}/${LABEL}" - else - echo "SUCCESS" > "${TEMP_DIR}/${LABEL}" - fi - ) & - PIDS+=($!) - done - - # Wait for all background processes - for PID in "${PIDS[@]}"; do - wait "${PID}" || true - done - - # Check results - FAILED=0 - PROJECT="${{ inputs.project }}" - CLOUD_BUILD_URL_BASE="https://console.cloud.google.com/cloud-build/builds" - for idx in "${!BUILD_LABELS[@]}"; do - LABEL="${BUILD_LABELS[$idx]}" - BUILD_ID="${BUILD_IDS[$idx]}" - BUILD_URL="${CLOUD_BUILD_URL_BASE}/${BUILD_ID}?project=${PROJECT}" - RESULT=$(cat "${TEMP_DIR}/${LABEL}" 2>/dev/null || echo "UNKNOWN") - if [[ "${RESULT}" != "SUCCESS" ]]; then - echo "FAILED: ${LABEL} - ${BUILD_URL}" - FAILED=1 - else - echo "PASSED: ${LABEL} - ${BUILD_URL}" - fi - done - - rm -rf "${TEMP_DIR}" - - if [[ "${FAILED}" -eq 1 ]]; then - echo "One or more Cloud Build jobs failed." - exit 1 - fi - - echo "All Cloud Build jobs completed successfully."