
Conversation

@praveenhosdrug123 (Contributor) commented Oct 17, 2025

Summary

Applies a distributed initialization fix to the model backbone to resolve OOM errors when initializing 7B+ parameter models on 8GB TPU devices. This PR adds a helper function that distributes the initializers at instantiation time.

Issue

Token embedding initialization materializes the full weight array at creation time, placing all weights on a single device.
Combined with the forward passes performed during backbone initialization, this causes a 2x to 3x memory spike and triggers OOM on TPUs with limited HBM.

Solution

Implements a _distribute_initializer helper that wraps embedding initializers with an explicit TensorLayout, sharding the weights across TPU devices at instantiation time. Validated on an 8-device TPU: models that previously OOM'd during backbone initialization now load successfully.
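
For illustration, here is a minimal sketch of the underlying idea in plain JAX (hypothetical helper name, mesh, and shapes; the actual _distribute_initializer in this PR may differ): running the initializer under jax.jit with an output sharding means the array is created already sharded, rather than materialized on one device and resharded afterwards.

    import jax
    import jax.numpy as jnp
    import numpy as np
    from jax.sharding import Mesh, NamedSharding, PartitionSpec

    def sharded_init(init_fn, mesh, partition_spec):
        """Run init_fn so its output is born sharded across the mesh."""
        sharding = NamedSharding(mesh, partition_spec)
        # out_shardings makes each device materialize only its own shard, so no
        # single device ever holds the full embedding table.
        return jax.jit(init_fn, out_shardings=sharding)()

    devices = np.array(jax.devices())
    mesh = Mesh(devices, axis_names=("model",))
    key = jax.random.PRNGKey(0)
    # Illustrative (vocab, hidden) embedding shape; shard the vocabulary axis.
    embeddings = sharded_init(
        lambda: jax.random.normal(key, (256_000, 3072), dtype=jnp.bfloat16),
        mesh,
        PartitionSpec("model", None),
    )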

Reference

For memory profiling analysis, cache locality theory, validation logs and alternative solutions considered, refer to: Doc

Related PR: keras-team/keras-hub#2441

Issue: #21634

@gemini-code-assist (bot) commented

Summary of Changes

Hello @praveenhosdrug123, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request implements a crucial fix to prevent Out-Of-Memory errors during the initialization of very large Keras models, particularly on resource-constrained TPU devices. By introducing a mechanism to distribute the initialization of token embeddings, it ensures that model weights are sharded across available devices from the outset, significantly reducing memory consumption and enabling the successful loading of models that previously failed.

Highlights

  • OOM Resolution for Large Models: Addresses Out-Of-Memory (OOM) errors encountered when initializing 7B+ parameter models on 8GB TPU devices by distributing token embedding initialization.
  • New Helper Function _distribute_initializer: Introduces _distribute_initializer in keras/src/backend/jax/distribution_lib.py to enable distribution-aware token embedding initialization for the JAX backend.
  • Distributed Initialization Logic: The new helper function wraps JAX random initializers with explicit TensorLayout to properly shard weights across TPU devices during instantiation, preventing memory spikes.

@gemini-code-assist (bot) left a comment

Code Review

This pull request introduces a new function, _distribute_initializer, to handle the distribution of token embedding initializers in the JAX backend. This function aims to resolve OOM errors encountered during the initialization of large models on TPUs with limited HBM by sharding weights across TPU devices during instantiation. The code includes argument validation, sharding logic based on tensor layout, and application of mean/stddev for relevant distributions. The review focuses on error handling, code clarity, and adherence to the Keras API design guidelines.

@codecov-commenter commented Oct 17, 2025

Codecov Report

❌ Patch coverage is 88.00000% with 3 lines in your changes missing coverage. Please review.
✅ Project coverage is 76.30%. Comparing base (7631c1a) to head (115b566).
⚠️ Report is 2 commits behind head on master.

Files with missing lines | Patch % | Lines
keras/src/backend/jax/core.py | 86.36% | 3 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##           master   #21755      +/-   ##
==========================================
+ Coverage   76.27%   76.30%   +0.02%     
==========================================
  Files         579      579              
  Lines       59917    59971      +54     
  Branches     9403     9411       +8     
==========================================
+ Hits        45703    45762      +59     
+ Misses      11748    11741       -7     
- Partials     2466     2468       +2     
Flag | Coverage | Δ
keras | 76.17% <88.00%> | +0.02% ⬆️
keras-jax | 62.16% <76.00%> | -0.54% ⬇️
keras-numpy | 57.33% <28.00%> | -0.04% ⬇️
keras-openvino | 34.27% <28.00%> | -0.02% ⬇️
keras-torch | 63.25% <28.00%> | -0.04% ⬇️

Flags with carried forward coverage won't be shown. Click here to find out more.

☔ View full report in Codecov by Sentry.

@hertschuh (Collaborator) commented

@praveenhosdrug123

Thank you for the investigation. This is indeed an issue.

Somebody on the team is working on a fix that's generally applicable to all variables so that you don't have to explicitly use the fix that you provided here.

@praveenhosdrug123 (Contributor, Author) commented

@hertschuh - Thanks for the feedback and for taking the time to review the document.

I want to clarify the technical issue:
The OOM problem is about large contiguous memory allocation, not total parameter count. Token embeddings are the largest single array and exceed device memory during initialization, even when the full model would fit after sharding.
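(For scale, an illustrative back-of-the-envelope not taken from this PR: a Gemma-style embedding table of roughly 256k vocabulary × 3072 hidden units is about 0.79B parameters, i.e. roughly 3 GB as a single contiguous float32 array; materializing that unsharded on one device, on top of a 2-3x initialization spike, quickly exhausts 8GB of HBM.)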

Thank you for the context on the general solution. A few follow-up questions to help me understand the timeline:

  1. What's the expected completion date for the general fix?
  2. Will it handle the edge cases mentioned in the document (interleaving, quantization, LoRA)?
  3. Will it detect which variables actually need distribution?

The reason I ask: users are blocked on this today for 7B+ models on 8GB TPU devices.
If the general fix is months out, would it make sense to:

  • Merge this targeted fix as a stopgap
  • Mark it deprecated once the general solution ships
  • Remove it in a future release

Let me know if that's feasible.

@hertschuh (Collaborator) commented

@praveenhosdrug123

> Merge this targeted fix as a stopgap

I'm confused about how this particular fix addresses the issue though. The _distribute_initializer function is not used anywhere.

@praveenhosdrug123 (Contributor, Author) commented

@hertschuh - Thank you :)
The _distribute_initializer function is used in keras_hub/src/utils/dist_initializer.py as part of the following pull request: keras-team/keras-hub#2441.

@praveenhosdrug123 (Contributor, Author) commented Nov 18, 2025

@hertschuh and @amitsrivastava78 --
The PR is ready for review. Since the implementation revises the head and tail of the flow, there are some test failures in the distribution and random library tests in CI; these occur because the implementation changes the flow those tests expect. The same tests pass locally with the updated implementation.

The other backends are unaffected by the change.

I would greatly appreciate your review and guidance on the PR.

@hertschuh (Collaborator) left a comment

@praveenhosdrug123

Thank you for all the work on this!

I was thinking that as a first step, we would do a much narrower fix, just for JAX.

In jax/core.py, in JaxVariable, we would add this.

    def _initialize_with_initializer(self, initializer):
        from keras.src.distribution import TensorLayout

        distribution = global_state.get_global_attribute("distribution")
        layout = None
        if self._layout is not None:
            layout = self._layout
        elif distribution is not None:
            layout = distribution.get_variable_layout(self)

        if isinstance(layout, TensorLayout):
            layout = layout.backend_layout

        if layout is not None:
            # shape/dtype must be static under jit (the PR's implementation
            # uses static_argnames for the same reason).
            initializer = jax.jit(
                initializer, out_shardings=layout, static_argnames=("shape", "dtype")
            )

        value = self._convert_to_tensor(
            initializer(self._shape, dtype=self._dtype)
        )
        self._direct_assign(value)

Disclaimer: this is completely untested code that I wrote here, not in an IDE. This might need to be changed to add the correct RNG support.

Hopefully this would get most of the job done.
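
For reference, here is a minimal sketch of how the global distribution that this method reads would typically be configured (standard Keras 3 distribution API; the mesh shape and the variable-path key below are illustrative, not tied to any particular model):

    import keras

    devices = keras.distribution.list_devices()
    mesh = keras.distribution.DeviceMesh(
        shape=(1, len(devices)), axis_names=("batch", "model"), devices=devices
    )
    layout_map = keras.distribution.LayoutMap(mesh)
    # Shard the embedding's second axis across the "model" mesh dimension.
    layout_map["token_embedding/embeddings"] = (None, "model")

    keras.distribution.set_distribution(
        keras.distribution.ModelParallel(layout_map=layout_map, batch_dim_name="batch")
    )
    # Variables created after this point can resolve their layout via the
    # global "distribution" attribute, as in the sketch above.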

Now, I totally agree with your idea that:

  • random ops should have a layout argument
  • initializers should have a layout argument

I was thinking along those lines too.

But I'd like to do that as a separate step and PR. In fact, the code above would still be used for backwards compatibility with people who wrote a custom Initializer subclass that doesn't take the layout argument.

I also want to think carefully about the RNGs and how they will be passed. I think there might be a simpler approach using a StatelessScope.

Thanks!

@hertschuh (Collaborator) left a comment

Thanks a lot for the update!

Were you able to test this on a large model? Also, did you notice a slowdown in the initialization?

Comment on lines +417 to +423
if backend.backend() == "jax":
    raise NotImplementedError
else:
    value = self._convert_to_tensor(
        initializer(self._shape, dtype=self._dtype)
    )
    self._initialize(value)

Please undo this file. It's generally understood that methods can have overrides, and that if some overrides are missing, bad things happen. But we don't throw when a known override is missing.

Moreover, we want to support JaxVariable._initialize_with_initializer calling super()._initialize_with_initializer.

def _initialize(self, value):
    # Note that variable.shape is needed by distribution_lib
    self._shape = self._validate_shape(value.shape)
def set_tensor_layout(self):

Rename set_tensor_layout to _initialize_layout.

  • it's private
  • it's not really a setter, there is no value argument

        out_shardings=layout,
        static_argnames=["shape", "dtype"],
    )
    value = jax.device_put(

Do you need the device_put? Doesn't out_shardings put it on the device automatically?

Comment on lines +136 to +140
def _initialize_with_initializer(self, initializer):
    value = self._convert_to_tensor(
        initializer(self._shape, dtype=self._dtype)
    )
    self._initialize(value)

What's the issue with NNX? Why can't we do the sharding?

Comment on lines +318 to +325
def check_layout_spec_nnx(init_layout):
    nnx_enabled = config.is_nnx_enabled()
    is_named_sharding = isinstance(init_layout, jax.sharding.NamedSharding)
    # Check if PartitionSpec has any non-None values
    spec = getattr(init_layout, "spec", None)
    partition_spec = spec if spec is not None else ()
    is_partitioned = any(dim is not None for dim in partition_spec)
    return is_partitioned and is_named_sharding and not nnx_enabled

Can you explain this? What is the issue with NNX?

Also, check_layout_spec_nnx is only called in JaxVariable._initialize_with_initializer. However, NnxVariable overrides _initialize_with_initializer and doesn't call super.

So my reading of this is that it's not actually used.
