diff --git a/.github/workflows/UnitTests.yml b/.github/workflows/UnitTests.yml
index 2c588b43..f3d579de 100644
--- a/.github/workflows/UnitTests.yml
+++ b/.github/workflows/UnitTests.yml
@@ -22,47 +22,38 @@ on:
   push:
     branches: [ "main" ]
   workflow_dispatch:
-  schedule:
-    # Run the job every 12 hours
-    - cron: '0 */12 * * *'
 
 jobs:
-  build:
-    strategy:
-      fail-fast: false
-      matrix:
-        tpu-type: ["v5p-8"]
-    name: "TPU test (${{ matrix.tpu-type }})"
-    runs-on: ["self-hosted","${{ matrix.tpu-type }}"]
+  maxtext_workload:
+    name: "Run MaxText Workload"
+    # IMPORTANT: Replace with the label of your self-hosted GPU runner
+    runs-on: ["linux-x86-a3-megagpu-h100-8gpu"]
+    container:
+      image: us-docker.pkg.dev/tpu-prod-env-multipod/jax-stable-stack/maxtext-gpu-custom:latest
     steps:
-    - uses: actions/checkout@v4
-    - name: Set up Python 3.12
-      uses: actions/setup-python@v5
-      with:
-        python-version: '3.12'
-    - name: Install dependencies
-      run: |
-        pip install -e .
-        pip uninstall jax jaxlib libtpu-nightly libtpu -y
-        bash setup.sh MODE=stable
-        export PATH=$PATH:$HOME/.local/bin
-        pip install ruff
-        pip install isort
-        pip install pytest
-    - name: Analysing the code with ruff
-      run: |
-        ruff check . 
-    - name: version check
-      run: |
-        python --version
-        pip show jax jaxlib flax transformers datasets tensorflow tensorflow_datasets
-    - name: PyTest
-      run: | #--deselect=src/maxdiffusion/tests/input_pipeline_interface_test.py
-        HF_HUB_CACHE=/mnt/disks/github-runner-disk/ HF_HOME=/mnt/disks/github-runner-disk/ TOKENIZERS_PARALLELISM=false python3 -m pytest --deselect=src/maxdiffusion/tests/ltx_transformer_step_test.py -x
-# add_pull_ready:
-#   if: github.ref != 'refs/heads/main'
-#   permissions:
-#     checks: read
-#     pull-requests: write
-#   needs: build
-#   uses: ./.github/workflows/AddLabel.yml
+    - name: Checkout MaxText Repo
+      uses: actions/checkout@v4
+      with:
+        repository: AI-Hypercomputer/maxtext
+        path: maxtext
+
+    - name: Print dependencies
+      run: |
+        pip freeze
+
+    - name: Run MaxText Training
+      run: |
+        # Adapted from the Airflow DAG for a single-slice configuration.
+        cd maxtext
+
+        export XLA_PYTHON_CLIENT_MEM_FRACTION=0.65
+        export TF_FORCE_GPU_ALLOW_GROWTH=true
+        export NVTE_FUSED_ATTN=1
+
+        python3 -m MaxText.train MaxText/configs/base.yml \
+          steps=5 \
+          enable_checkpointing=false \
+          attention=cudnn_flash_te \
+          dataset_type=synthetic \
+          run_name=rbierneni-test-maxtext-gpu \
+          base_output_directory=gs://rbierneni-multipod-dev/maxtext/${{ github.run_id }}
diff --git a/maxdiffusion_jax_ai_image_tpu.Dockerfile b/maxdiffusion_jax_ai_image_tpu.Dockerfile
index cab50fee..301f9b88 100644
--- a/maxdiffusion_jax_ai_image_tpu.Dockerfile
+++ b/maxdiffusion_jax_ai_image_tpu.Dockerfile
@@ -19,4 +19,4 @@ COPY . .
 RUN pip install -r /deps/requirements_with_jax_ai_image.txt
 
 # Run the script available in JAX-AI-Image base image to generate the manifest file
-RUN bash /jax-stable-stack/generate_manifest.sh PREFIX=maxdiffusion COMMIT_HASH=$COMMIT_HASH
\ No newline at end of file
+RUN bash /jax-ai-image/generate_manifest.sh PREFIX=maxdiffusion COMMIT_HASH=$COMMIT_HASH
\ No newline at end of file
diff --git a/verify_conflict.py b/verify_conflict.py
new file mode 100644
index 00000000..54454f7c
--- /dev/null
+++ b/verify_conflict.py
@@ -0,0 +1,22 @@
+print("--- PyTorch vs. JAX Conflict Test ---")
+
+print("\nStep 1: Attempting to import torch...")
+try:
+    import torch
+    print(f"Successfully imported torch version: {torch.__version__}")
+    # This check will confirm you have the CPU-only version
+    print(f"Is PyTorch using CUDA? -> {torch.cuda.is_available()}")
+except Exception as e:
+    print(f"Failed to import torch: {e}")
+
+
+print("\nStep 2: Now, attempting to initialize JAX...")
+try:
+    import jax
+    devices = jax.devices()
+    print("\n--- RESULT: SUCCESS ---")
+    print(f"JAX initialized correctly and found devices: {devices}")
+except Exception as e:
+    print("\n--- RESULT: FAILURE ---")
+    print("JAX failed to initialize after PyTorch was imported.")
+    print(f"JAX Error: {e}")
\ No newline at end of file