build(jax): split pinned cpu and gpu groups#5436
build(jax): split pinned cpu and gpu groups#5436njzjz-bot wants to merge 1 commit intodeepmodeling:masterfrom
Conversation
📝 WalkthroughWalkthroughThe PR splits the unified ChangesJAX Dependency CPU/GPU Split
Estimated code review effort🎯 2 (Simple) | ⏱️ ~10 minutes Suggested labels
Suggested reviewers
🚥 Pre-merge checks | ✅ 5✅ Passed checks (5 passed)
✏️ Tip: You can configure your own custom pre-merge checks in the settings. ✨ Finishing Touches🧪 Generate unit tests (beta)
Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out. Comment |
Codecov Report✅ All modified and coverable lines are covered by tests. Additional details and impacted files@@ Coverage Diff @@
## master #5436 +/- ##
=======================================
Coverage 82.47% 82.47%
=======================================
Files 825 825
Lines 87721 87721
Branches 4206 4206
=======================================
+ Hits 72344 72345 +1
+ Misses 14094 14093 -1
Partials 1283 1283 ☔ View full report in Codecov by Sentry. 🚀 New features to boost your workflow:
|
Keep the public jax extra unchanged and only split the pinned dependency groups used by CPU and CUDA CI. CPU uses plain jax, while GPU uses jax[cuda12]. Authored by OpenClaw (model: gpt-5.5)
6bbe50b to
2d024de
Compare
There was a problem hiding this comment.
Actionable comments posted: 1
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Inline comments:
In `@pyproject.toml`:
- Around line 178-183: Add explicit user-facing extras "jax-cpu" and "jax-gpu"
under [tool.deepmd_build_backend.optional-dependencies], mapping "jax-cpu" to
the pin_jax_cpu group and "jax-gpu" to the pin_jax_gpu group (so CI pinning and
user extras stay aligned); remove the duplicated entries from the existing "jax"
extra and make "jax" a single, deduplicated reference (or a thin alias) to the
CPU variant as appropriate to avoid the uv#8601 workaround duplication. Ensure
the unique identifiers pin_jax_cpu, pin_jax_gpu and the existing jax extra are
updated so that the new jax-cpu and jax-gpu extras appear in package metadata.
🪄 Autofix (Beta)
Fix all unresolved CodeRabbit comments on this PR:
- Push a commit to this branch (recommended)
- Create a new PR with the fixes
ℹ️ Review info
⚙️ Run configuration
Configuration used: Repository UI
Review profile: CHILL
Plan: Pro
Run ID: 358a2778-1d17-4e4e-ae3e-2d49ac2b28fa
📒 Files selected for processing (4)
.github/workflows/test_cc.yml.github/workflows/test_cuda.yml.github/workflows/test_python.ymlpyproject.toml
🚧 Files skipped from review as they are similar to previous changes (3)
- .github/workflows/test_python.yml
- .github/workflows/test_cuda.yml
- .github/workflows/test_cc.yml
There was a problem hiding this comment.
Pull request overview
Splits the CI JAX dependency pins into separate CPU and CUDA groups so CPU and CUDA workflows can each install the appropriate pinned JAX variant while keeping the public jax extra unchanged.
Changes:
- Replaced the single
pin_jaxdependency group withpin_jax_cpuandpin_jax_gpuinpyproject.toml. - Updated CPU workflows to install
--group pin_jax_cpu. - Updated the CUDA workflow to install
--group pin_jax_gpu(pinningjax[cuda12]).
Reviewed changes
Copilot reviewed 4 out of 4 changed files in this pull request and generated no comments.
| File | Description |
|---|---|
| pyproject.toml | Splits the pinned JAX dependency group into CPU/GPU variants. |
| .github/workflows/test_python.yml | Uses pin_jax_cpu for CPU Python test jobs. |
| .github/workflows/test_cuda.yml | Uses pin_jax_gpu for CUDA jobs (pins jax[cuda12]). |
| .github/workflows/test_cc.yml | Uses pin_jax_cpu for CPU C++/Python dependency install step. |
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
Problem
jaxextra should stay unchanged.Change
jaxoptional dependency as plainjax.pin_jax_cpuandpin_jax_gpu.pin_jax_cpuin CPU jobs andpin_jax_gpuin the CUDA job.Validation
git diff --checkuv pip compile pyproject.toml --group pin_jax_cpu --python-version 3.10uv pip compile pyproject.toml --group pin_jax_gpu --python-version 3.10Authored by OpenClaw (model: gpt-5.5)
Summary by CodeRabbit