Extended fix OOM Issue #21634 on Keras side #21755
Conversation
Summary of Changes (Gemini Code Assist): This pull request implements a fix to prevent out-of-memory errors during the initialization of very large Keras models, particularly on resource-constrained TPU devices. By distributing the initialization of token embeddings, it ensures that model weights are sharded across the available devices from the outset, significantly reducing memory consumption and enabling models that previously failed to load.
Code Review
This pull request introduces a new function, _distribute_initializer, to handle the distribution of token embedding initializers in the JAX backend. This function aims to resolve OOM errors encountered during the initialization of large models on TPUs with limited HBM by sharding weights across TPU devices during instantiation. The code includes argument validation, sharding logic based on tensor layout, and application of mean/stddev for relevant distributions. The review focuses on error handling, code clarity, and adherence to the Keras API design guidelines.
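To make the mechanism described above concrete, here is a minimal sketch of what such a helper could look like on the JAX backend. The name, signature, and validation are illustrative assumptions rather than the PR's actual code, and seed/RNG handling is glossed over (the initializer is assumed to be stateless, e.g. already seeded).

import jax
from keras import distribution


def distribute_initializer_sketch(initializer, shape, dtype, layout):
    # Illustrative validation: this sketch expects a keras.distribution.TensorLayout.
    if not isinstance(layout, distribution.TensorLayout):
        raise ValueError(
            f"`layout` must be a keras.distribution.TensorLayout, got {type(layout)}"
        )
    # TensorLayout.backend_layout is the underlying JAX sharding. Jitting with
    # out_shardings makes the initializer output come back already sharded, so
    # the full array never has to sit on a single device.
    jitted = jax.jit(
        lambda: initializer(shape, dtype=dtype),
        out_shardings=layout.backend_layout,
    )
    return jitted()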
Codecov Report
Additional details and impacted files:
@@ Coverage Diff @@
## master #21755 +/- ##
==========================================
+ Coverage 76.27% 76.30% +0.02%
==========================================
Files 579 579
Lines 59917 59971 +54
Branches 9403 9411 +8
==========================================
+ Hits 45703 45762 +59
+ Misses 11748 11741 -7
- Partials 2466 2468 +2
Thank you for the investigation. This is indeed an issue. Somebody on the team is working on a fix that's generally applicable to all variables, so that you don't have to explicitly use the fix that you provided here.
@hertschuh - Thanks for the feedback and for taking the time to review. Thank you as well for the context on the general solution. A few follow-up questions to help me understand the timeline:
The reason I ask: users are blocked on this today for 7B+ models on 8GB TPU devices.
Let me know if that's feasible.
I'm confused about how this particular fix addresses the issue, though.
@hertschuh - Thank you :)
@hertschuh and @amitsrivastava78 -- The other backends are unaffected by the change. I would greatly appreciate your review and guidance on the PR now.
hertschuh left a comment:
Thank you for all the work on this!
I was thinking that as a first step, we would do a much narrower fix, just for JAX.
In jax/core.py, in JaxVariable, we would add this.
def _initialize_with_initializer(self, initializer):
    from keras.src.distribution import TensorLayout

    # Resolve the layout: an explicit per-variable layout wins, otherwise ask
    # the global distribution, converting a Keras TensorLayout to the
    # underlying JAX sharding.
    distribution = global_state.get_global_attribute("distribution")
    layout = None
    if self._layout is not None:
        layout = self._layout
    elif distribution is not None:
        layout = distribution.get_variable_layout(self)
    if isinstance(layout, TensorLayout):
        layout = layout.backend_layout
    if layout is not None:
        # Jit with out_shardings so the initializer output is created
        # already sharded across devices.
        initializer = jax.jit(initializer, out_shardings=layout)
    value = self._convert_to_tensor(
        initializer(self._shape, dtype=self._dtype)
    )
    self._direct_assign(value)

Disclaimer: this is completely untested code that I wrote here, not in an IDE. This might need to be changed to add the correct RNG support.
Hopefully this would get most of the job done.
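As a side note, the effect of the out_shardings wrapping in the snippet above can be checked outside Keras with a small, self-contained JAX example; the mesh axis name, array shape, and seed below are made up for illustration.

import jax
import jax.numpy as jnp
from jax.experimental import mesh_utils
from jax.sharding import Mesh, NamedSharding, PartitionSpec

# Build a 1D mesh over whatever devices are available (8 on the TPU in question).
devices = mesh_utils.create_device_mesh((len(jax.devices()),))
mesh = Mesh(devices, axis_names=("model",))
layout = NamedSharding(mesh, PartitionSpec("model", None))  # shard rows across devices


def initializer(shape, dtype=jnp.float32):
    return jax.random.normal(jax.random.PRNGKey(0), shape, dtype=dtype)


sharded_init = jax.jit(
    initializer, static_argnames=("shape", "dtype"), out_shardings=layout
)
value = sharded_init(shape=(1024, 512), dtype=jnp.float32)
print(value.sharding)  # already the requested NamedSharding; no extra device_put needed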
Now, I totally agree with your idea that:
- random ops should have a layout argument
- initializers should have a layout argument
I was thinking along those lines too.
But I'd like to do that as a separate step and PR. In fact, the code above would still be used for backwards compatibility for people who wrote a custom Initializer subclass that doesn't take the layout argument.
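For illustration only, a hypothetical layout-aware initializer along those lines might look like the following. None of these names exist in Keras today, and the sketch uses plain JAX RNG to sidestep the seed-handling question discussed next.

import jax
import jax.numpy as jnp


class HypotheticalRandomNormal:
    """Toy stand-in for an initializer that accepts an optional layout."""

    def __init__(self, mean=0.0, stddev=0.05, seed=0):
        self.mean, self.stddev, self.seed = mean, stddev, seed

    def __call__(self, shape, dtype=jnp.float32, layout=None):
        def init():
            key = jax.random.PRNGKey(self.seed)
            return self.mean + self.stddev * jax.random.normal(key, shape, dtype=dtype)

        if layout is None:
            return init()  # current behaviour: materialize on one device
        # With a layout, produce the array already sharded across the mesh.
        return jax.jit(init, out_shardings=layout)()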
I also want to think carefully about the RNGs and how they will be passed. I think there might be a simpler approach using a StatelessScope.
Thanks!
hertschuh left a comment:
Thanks a lot for the update!
Were you able to test this on a large model? Also, did you notice a slowdown in the initialization?
if backend.backend() == "jax":
    raise NotImplementedError
else:
    value = self._convert_to_tensor(
        initializer(self._shape, dtype=self._dtype)
    )
    self._initialize(value)
Please undo this file. It's generally understood that methods can have overrides, and that if some overrides are missing, bad things happen. But we don't throw when a known override is missing.
Moreover, we want to support JaxVariable._initialize_with_initializer calling super()._initialize_with_initializer.
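For illustration, the pattern being asked for is roughly the following (class and attribute names here are made up): the base class keeps a working default instead of raising, and the backend subclass extends it while still being able to call super().

class BaseVariableSketch:
    def __init__(self, shape, dtype="float32"):
        self._shape, self._dtype, self._value = shape, dtype, None

    def _initialize_with_initializer(self, initializer):
        # Shared default: just materialize the value, no distribution logic.
        self._value = initializer(self._shape, dtype=self._dtype)


class JaxVariableSketch(BaseVariableSketch):
    def _initialize_with_initializer(self, initializer):
        # Backend-specific work (e.g. wrapping the initializer so its output
        # is sharded) would go here; the shared default still does the assignment.
        super()._initialize_with_initializer(initializer)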
def _initialize(self, value):
    # Note that variable.shape is needed by distribution_lib
    self._shape = self._validate_shape(value.shape)

def set_tensor_layout(self):
Rename set_tensor_layout to _initialize_layout:
- it's private
- it's not really a setter, there is no value argument
    out_shardings=layout,
    static_argnames=["shape", "dtype"],
)
value = jax.device_put(
Do you need the device_put? Doesn't out_shardings put it on the device automatically?
def _initialize_with_initializer(self, initializer):
    value = self._convert_to_tensor(
        initializer(self._shape, dtype=self._dtype)
    )
    self._initialize(value)
What's the issue with NNX? Why can't we do the sharding?
def check_layout_spec_nnx(init_layout):
    nnx_enabled = config.is_nnx_enabled()
    is_named_sharding = isinstance(init_layout, jax.sharding.NamedSharding)
    # Check if PartitionSpec has any non-None values
    spec = getattr(init_layout, "spec", None)
    partition_spec = spec if spec is not None else ()
    is_partitioned = any(dim is not None for dim in partition_spec)
    return is_partitioned and is_named_sharding and not nnx_enabled
Can you explain this? What is the issue with NNX?
Also, check_layout_spec_nnx is only called in JaxVariable._initialize_with_initializer. However, NnxVariable overrides _initialize_with_initializer and doesn't call super.
So my reading of this is that it's not actually used.
Summary
Applies a distributed-initialization fix to the model backbone to resolve OOM errors during initialization of 7B+ parameter models on 8GB TPU devices. This PR adds a helper function that distributes the initializers at instantiation time.
Issue
Token embedding initialization creates large arrays up front, placing all of the weights on a single device.
Combined with the forward passes that run during backbone initialization, this causes a 2x-3x memory spike and triggers OOM on TPUs with limited HBM.
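A rough back-of-the-envelope check (assuming bfloat16 weights at 2 bytes per parameter) shows why a single device cannot hold the initialized weights but eight shards can:

# Approximate memory math for a 7B-parameter model in bfloat16 (2 bytes/param).
params = 7e9
full_copy_gib = params * 2 / 1024**3   # ~13 GiB if everything lands on one device
per_device_gib = full_copy_gib / 8     # ~1.6 GiB per device when sharded 8 ways
print(full_copy_gib, per_device_gib)   # well past 8 GiB of HBM vs. comfortably within it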
Solution
Implements a _distribute_initializer helper that wraps embedding initializers with an explicit TensorLayout, sharding the weights across TPU devices during instantiation. Validated on an 8-device TPU: models that previously OOM'd during backbone initialization now load successfully.
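For context, a typical way to request such a layout with the public Keras 3 distribution API looks roughly like the following; the mesh shape and the variable-path pattern are assumptions that depend on the model and device count.

import keras

devices = keras.distribution.list_devices()  # e.g. 8 TPU cores
mesh = keras.distribution.DeviceMesh(
    shape=(1, len(devices)), axis_names=("batch", "model"), devices=devices
)
layout_map = keras.distribution.LayoutMap(mesh)
# Assumed variable-path pattern; the real key depends on the model's layer names.
layout_map["token_embedding/embeddings"] = ("model", None)
keras.distribution.set_distribution(
    keras.distribution.ModelParallel(layout_map=layout_map, batch_dim_name="batch")
)
# Models built after this point get matching variables laid out along the "model"
# axis, which is what lets the initializer output be sharded instead of being
# created on a single device.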
Reference
For memory profiling analysis, cache locality theory, validation logs and alternative solutions considered, refer to: Doc
Related PR: keras-team/keras-hub#2441
Issue: #21634