
Add support for On-The-Fly Dynamic SafeTensors loading.#4218

Open
copybara-service[bot] wants to merge 1 commit into main from test_935065289


@copybara-service

Contributor

Add support for On-The-Fly Dynamic SafeTensors loading.

@codecov

codecov Bot commented Jun 22, 2026


@copybara-service copybara-service Bot force-pushed the test_935065289 branch 3 times, most recently from 8eb1dd8 to 36bddba on June 23, 2026, 21:39
@github-actions


🤖 Hi @hengtaoguo, I've received your request, and I'm working on it now! You can track my progress in the logs for more details.

@github-actions github-actions Bot left a comment


## 📋 Review Summary

This pull request adds support for dynamic, on-the-fly SafeTensors loading during training and evaluation, bypassing the traditional ahead-of-time conversion step. The high-level architecture is sound and well integrated, and it provides a direct path for loading Hugging Face weights onto TPU meshes.

🔍 General Feedback

  • Great Architecture Design: The integration with JAX's maximal sharding API and Orbax checkpointing is robust and correctly handles distributed downloading across TPU VMs.
  • Code Duplication: There is substantial duplication of functions like _build_multi_axis_stacked_tensor, _build_single_axis_stacked_tensor, and _get_hf_loading_function across multiple files (to_maxtext.py, utils.py, and tensor_handling.py). Consolidating these helpers into tensor_handling.py and importing them elsewhere would significantly improve DRY compliance and long-term maintainability.
  • Robustness: Addressing the loop performance traps and the tied-weight KeyError will make the feature better suited to high-scale production workloads on very large model architectures.

Comment on lines +39 to +84
    tensor_getter_fn: Callable[[str], np.ndarray],
    hook_fns: Any,
    target_leaf: Any,
    config,
) -> np.ndarray:
  """Builds a MaxText tensor by stacking HF weights along two axes (experts and layers) directly in place on device."""
  if hasattr(target_leaf, "sharding"):
    target_shape = target_leaf.shape
    target_sharding = target_leaf.sharding
    target_dtype = target_leaf.dtype
  else:
    target_shape = target_leaf.shape if hasattr(target_leaf, "shape") else target_leaf
    target_sharding = None
    target_dtype = target_leaf.dtype if hasattr(target_leaf, "dtype") else np.float32

  mt_slice_shape = target_shape[2:]

  if target_sharding is not None:
    stacked_array = jax.jit(
        lambda: np.zeros(target_shape, dtype=target_dtype),
        out_shardings=target_sharding,
    )()
  else:
    stacked_array = onp.zeros(target_shape, dtype=target_dtype)

  # Outer loop iterates through experts
  for exp_idx, layer_keys_for_expert in enumerate(hf_source_keys):
    # Inner loop iterates through layers for the current expert
    for lyr_idx, hf_key_single in enumerate(layer_keys_for_expert):
      hf_tensor_numpy = tensor_getter_fn(hf_key_single)
      processed_hf_tensor = apply_hook_fns(hf_tensor_numpy, mt_slice_shape, hook_fns)

      if target_sharding is not None:
        exp_idx_device = jax.device_put(exp_idx)
        lyr_idx_device = jax.device_put(lyr_idx)
        if hasattr(target_sharding, "spec"):
          spec_list = list(target_sharding.spec)[2:]
          slice_sharding = jax.sharding.NamedSharding(target_sharding.mesh, jax.sharding.PartitionSpec(*spec_list))
        else:
          slice_sharding = target_sharding
        processed_hf_tensor = jax.device_put(processed_hf_tensor, slice_sharding)
        stacked_array = stacked_array.at[exp_idx_device, lyr_idx_device].set(processed_hf_tensor)
      else:
        stacked_array[exp_idx, lyr_idx] = processed_hf_tensor

  return stacked_array


🔴 Critical Performance Trap: Avoid repeated JAX array updates using `.at[...].set(...)` inside Python loops to prevent TPU host/device OOM.

Because JAX arrays are immutable, every .at[...].set(...) operation within this nested Python loop allocates a brand new full-sized target JAX array and copies all data into it. For models with large numbers of layers and experts (e.g., $32 \times 32 = 1024$ updates), this results in massive memory churn (allocating and discarding more than 200 GB of intermediate buffers), which can easily trigger a TPU host or device Out-Of-Memory (OOM) error.

Instead, accumulate the individual slices/expert tensors into a standard Python list, perform a single unified np.stack across the desired axes (which JAX/XLA compiles efficiently as a single fused allocation/concat operation), and then place the final stacked array onto the target sharding with a single jax.device_put call.
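A rough way to sanity-check the churn figure is sketched below. The tensor size is an assumption chosen purely for illustration (the review does not state one), and the model assumes each update materializes exactly one full-size copy of the target.

```python
# Back-of-the-envelope churn estimate. All sizes are illustrative assumptions.
NUM_EXPERTS = 32
NUM_LAYERS = 32
TARGET_BYTES = 200 * 1024**2  # assumed size of the full stacked tensor: 200 MiB


def copy_churn_bytes(num_updates: int, target_bytes: int) -> int:
    """Total bytes allocated if every update creates a fresh full-size copy."""
    return num_updates * target_bytes


num_updates = NUM_EXPERTS * NUM_LAYERS
churn = copy_churn_bytes(num_updates, TARGET_BYTES)
print(f"{num_updates} updates, {churn / 1024**3:.0f} GiB of intermediate buffers")
```

Under these assumptions the buffers are transient, so this number describes allocation traffic rather than peak memory; the peak depends on how quickly the old copies are released.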

Suggested change
def _build_multi_axis_stacked_tensor(
    hf_source_keys: List[List[str]],
    tensor_getter_fn: Callable[[str], np.ndarray],
    hook_fns: Any,
    target_leaf: Any,
    config,
) -> np.ndarray:
  """Builds a MaxText tensor by stacking HF weights along two axes (experts and layers) directly in place on device."""
  if hasattr(target_leaf, "sharding"):
    target_shape = target_leaf.shape
    target_sharding = target_leaf.sharding
    target_dtype = target_leaf.dtype
  else:
    target_shape = target_leaf.shape if hasattr(target_leaf, "shape") else target_leaf
    target_sharding = None
    target_dtype = target_leaf.dtype if hasattr(target_leaf, "dtype") else np.float32
  mt_slice_shape = target_shape[2:]
  all_expert_tensors = []
  # Outer loop iterates through experts
  for layer_keys_for_expert in hf_source_keys:
    layer_tensors_for_expert = []
    # Inner loop iterates through layers for the current expert
    for hf_key_single in layer_keys_for_expert:
      hf_tensor_numpy = tensor_getter_fn(hf_key_single)
      processed_hf_tensor = apply_hook_fns(hf_tensor_numpy, mt_slice_shape, hook_fns)
      layer_tensors_for_expert.append(processed_hf_tensor)
    all_expert_tensors.append(np.stack(layer_tensors_for_expert, axis=0))
  stacked_array = np.stack(all_expert_tensors, axis=0).astype(target_dtype)
  if target_sharding is not None:
    stacked_array = jax.device_put(stacked_array, target_sharding)
  return stacked_array
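The accumulate-then-stack pattern can be checked on toy data. The sketch below is not the PR's code: it uses plain NumPy (imported here as `np`) for host-side staging, omits the hook functions and the final `jax.device_put`, and the key names and the helper `stack_experts_and_layers` are made up for illustration.

```python
import numpy as np  # plain NumPy, used here for host-side staging only


def stack_experts_and_layers(hf_source_keys, tensor_getter_fn, target_dtype):
    """Stack per-(expert, layer) slices into shape (experts, layers, *slice_shape)."""
    per_expert = []
    for layer_keys in hf_source_keys:
        layers = [tensor_getter_fn(key) for key in layer_keys]
        per_expert.append(np.stack(layers, axis=0))
    return np.stack(per_expert, axis=0).astype(target_dtype)


# Toy checkpoint: 2 experts x 3 layers, each slice of shape (4, 5).
rng = np.random.default_rng(0)
toy_state = {f"e{e}.l{l}": rng.standard_normal((4, 5)) for e in range(2) for l in range(3)}
toy_keys = [[f"e{e}.l{l}" for l in range(3)] for e in range(2)]

stacked = stack_experts_and_layers(toy_keys, toy_state.__getitem__, np.float32)

# Reference result: fill a preallocated host buffer slice by slice.
reference = np.zeros((2, 3, 4, 5), dtype=np.float32)
for e, layer_keys in enumerate(toy_keys):
    for l, key in enumerate(layer_keys):
        reference[e, l] = toy_state[key]

print(stacked.shape, np.array_equal(stacked, reference))
```

One trade-off worth noting: while the final stack is built, the list of slices and the stacked result coexist on the host, so peak host memory is roughly twice the tensor size.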

Comment on lines +87 to +149
def _build_single_axis_stacked_tensor(
    hf_source_keys: List[str],
    tensor_getter_fn: Callable[[str], np.ndarray],
    hook_fns: Any,
    target_leaf: Any,
    config,
) -> np.ndarray:
  """Builds a MaxText tensor by stacking HF weights along a single axis directly in place on device."""
  if hasattr(target_leaf, "sharding"):
    target_shape = target_leaf.shape
    target_sharding = target_leaf.sharding
    target_dtype = target_leaf.dtype
  else:
    target_shape = target_leaf.shape if hasattr(target_leaf, "shape") else target_leaf
    target_sharding = None
    target_dtype = target_leaf.dtype if hasattr(target_leaf, "dtype") else np.float32

  if config.scan_layers:
    # If it's a standard scanned layer, we use the configured param_scan_axis.
    axis_to_stack = config.param_scan_axis
  else:
    # Otherwise, if an unscanned MoE layer, and we stack along the expert axis (0).
    axis_to_stack = 0

  # The hook function needs the shape of an individual slice, not the full stacked tensor.
  # We calculate it by removing the stacking dimension from the final target shape.
  mt_slice_shape_list = list(target_shape)
  del mt_slice_shape_list[axis_to_stack]
  mt_slice_shape = tuple(mt_slice_shape_list)

  if target_sharding is not None:
    stacked_array = jax.jit(
        lambda: np.zeros(target_shape, dtype=target_dtype),
        out_shardings=target_sharding,
    )()
  else:
    stacked_array = onp.zeros(target_shape, dtype=target_dtype)

  for i, hf_key_single in enumerate(hf_source_keys):
    hf_tensor_numpy = tensor_getter_fn(hf_key_single)
    processed_hf_tensor = apply_hook_fns(hf_tensor_numpy, mt_slice_shape, hook_fns)

    # Construct indexing tuple dynamically along axis_to_stack
    indexer = [slice(None)] * len(target_shape)

    if target_sharding is not None:
      idx = jax.device_put(i)
      if hasattr(target_sharding, "spec"):
        spec_list = list(target_sharding.spec)
        del spec_list[axis_to_stack]
        slice_sharding = jax.sharding.NamedSharding(target_sharding.mesh, jax.sharding.PartitionSpec(*spec_list))
      else:
        slice_sharding = target_sharding
      processed_hf_tensor = jax.device_put(processed_hf_tensor, slice_sharding)
      indexer[axis_to_stack] = idx
      stacked_array = stacked_array.at[tuple(indexer)].set(processed_hf_tensor)
    else:
      indexer[axis_to_stack] = i
      stacked_array[tuple(indexer)] = processed_hf_tensor

  return stacked_array



🔴 Critical Performance Trap: Avoid repeated JAX array updates using `.at[...].set(...)` inside Python loops to prevent TPU host/device OOM.

Just like the multi-axis stack function, updating the JAX array inside a Python loop with stacked_array = stacked_array.at[tuple(indexer)].set(...) creates copies of the entire array at each step, causing massive memory churn and TPU Host/Device OOMs.

By accumulating processed slices in a list and performing a single np.stack(..., axis=axis_to_stack), we optimize memory footprint and compile the stack as a single fused XLA allocation, followed by placing the result on the target sharding.

Suggested change
def _build_single_axis_stacked_tensor(
    hf_source_keys: List[str],
    tensor_getter_fn: Callable[[str], np.ndarray],
    hook_fns: Any,
    target_leaf: Any,
    config,
) -> np.ndarray:
  """Builds a MaxText tensor by stacking HF weights along a single axis directly in place on device."""
  if hasattr(target_leaf, "sharding"):
    target_shape = target_leaf.shape
    target_sharding = target_leaf.sharding
    target_dtype = target_leaf.dtype
  else:
    target_shape = target_leaf.shape if hasattr(target_leaf, "shape") else target_leaf
    target_sharding = None
    target_dtype = target_leaf.dtype if hasattr(target_leaf, "dtype") else np.float32
  if config.scan_layers:
    # If it's a standard scanned layer, we use the configured param_scan_axis.
    axis_to_stack = config.param_scan_axis
  else:
    # Otherwise, if an unscanned MoE layer, and we stack along the expert axis (0).
    axis_to_stack = 0
  # The hook function needs the shape of an individual slice, not the full stacked tensor.
  # We calculate it by removing the stacking dimension from the final target shape.
  mt_slice_shape_list = list(target_shape)
  del mt_slice_shape_list[axis_to_stack]
  mt_slice_shape = tuple(mt_slice_shape_list)
  tensors_to_stack = []
  for hf_key_single in hf_source_keys:
    hf_tensor_numpy = tensor_getter_fn(hf_key_single)
    processed_hf_tensor = apply_hook_fns(hf_tensor_numpy, mt_slice_shape, hook_fns)
    tensors_to_stack.append(processed_hf_tensor)
  stacked_array = np.stack(tensors_to_stack, axis=axis_to_stack).astype(target_dtype)
  if target_sharding is not None:
    stacked_array = jax.device_put(stacked_array, target_sharding)
  return stacked_array
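The relationship between the per-slice shape and a stack along an arbitrary axis is easy to verify in isolation. This is a minimal sketch with plain NumPy; the target shape, the scan axis of 1, and the helper names `slice_shape_for` and `stack_along_axis` are assumptions for illustration, not values taken from the PR.

```python
import numpy as np  # plain NumPy for a host-side illustration


def slice_shape_for(target_shape, axis_to_stack):
    """Shape of one slice: the target shape with the stacking axis removed."""
    shape = list(target_shape)
    del shape[axis_to_stack]
    return tuple(shape)


def stack_along_axis(slices, axis_to_stack, target_dtype):
    """Stack equally shaped slices along the given axis in a single call."""
    return np.stack(slices, axis=axis_to_stack).astype(target_dtype)


target_shape = (4, 3, 5)  # assumed layout with the stacked (layer) axis in the middle
axis_to_stack = 1         # stands in for config.param_scan_axis
slice_shape = slice_shape_for(target_shape, axis_to_stack)

# One constant-valued slice per layer so each layer is easy to identify afterwards.
slices = [np.full(slice_shape, i, dtype=np.float32) for i in range(target_shape[axis_to_stack])]
stacked = stack_along_axis(slices, axis_to_stack, np.float32)

print(slice_shape, stacked.shape)
```

Indexing the result with `stacked[:, i, :]` recovers slice `i`, which is the same position the loop-based version writes to through its `indexer` tuple.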

Comment on lines +280 to +282
else:
  raise ValueError("base_output_directory with gs:// prefix is required for " "huggingface downloads.")


🟡 Support local and single-host development and testing by falling back to standard Hugging Face snapshot downloads when GCS is not configured.

Strictly requiring base_output_directory to start with gs:// prevents single-host experiments, GPU-based workloads, and local unit/integration tests from functioning. In single-host environments (where jax.process_count() == 1), we can cleanly fall back to a standard Hugging Face snapshot download.

Suggested change
elif jax.process_count() == 1:
  # Single-host environment: download directly from HF Hub using snapshot_download
  max_logging.log("Single-host environment detected: downloading directly from HF Hub.")
  from huggingface_hub import snapshot_download
  path = snapshot_download(
      repo_id=repo_id,
      token=maxtext_config.hf_access_token,
      allow_patterns=["*.safetensors", "*.json"],
  )
else:
  raise ValueError("base_output_directory with gs:// prefix is required for " "huggingface downloads.")
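The branching proposed here can be exercised without any network access by separating the decision from the download. The function below is a hypothetical sketch: `pick_download_strategy` and its return labels are invented for illustration, and the first condition is an assumption, since the branch preceding the quoted `else` is not shown in this review.

```python
def pick_download_strategy(base_output_directory: str, process_count: int) -> str:
    """Return which download path to take; labels are illustrative only."""
    if base_output_directory.startswith("gs://"):
        # Assumed existing behavior: stage the weights through GCS.
        return "gcs"
    if process_count == 1:
        # Proposed fallback: single host, download straight from the HF Hub.
        return "hf_snapshot"
    raise ValueError("base_output_directory with gs:// prefix is required for huggingface downloads.")


print(pick_download_strategy("gs://bucket/run", 8))
print(pick_download_strategy("/tmp/run", 1))
```

Keeping the decision in a small pure function like this would make the fallback straightforward to cover in a local unit test, with `jax.process_count()` passed in by the caller.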

Comment on lines +132 to +134
t0 = time.time()

def tensor_getter(key):


🔴 Avoid popping keys from the loaded state to support architectures with tied/shared weights.

Popping keys from hf_state will cause a KeyError when attempting to load architectures with shared or tied weights (such as tie_word_embeddings = True), where the same Hugging Face key (e.g. model.embed_tokens.weight) is mapped to multiple MaxText parameters. Accessing the dictionary using standard bracket lookup or .get() avoids this crash and is completely safe since hf_state will be garbage collected after the function returns.

Suggested change
t0 = time.time()

def tensor_getter(key):
  return hf_state[key]
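The failure mode is reproducible with a plain dictionary. In this sketch the MaxText parameter names and the helper `load_all` are hypothetical; only the Hugging Face key comes from the comment above.

```python
# A toy state with one tensor that two parameters share (tied embeddings).
hf_state = {"model.embed_tokens.weight": [[0.1, 0.2], [0.3, 0.4]]}

# Hypothetical MaxText parameter names, both mapped to the same HF key.
param_to_hf_key = {
    "token_embedder": "model.embed_tokens.weight",
    "logits_dense": "model.embed_tokens.weight",
}


def load_all(getter):
    """Resolve every parameter through the given getter."""
    return {name: getter(key) for name, key in param_to_hf_key.items()}


# Popping: the first lookup removes the key, so the second one raises KeyError.
popping_state = dict(hf_state)
try:
    load_all(popping_state.pop)
    pop_failed = False
except KeyError:
    pop_failed = True

# Plain lookup: both parameters resolve to the same underlying object.
loaded = load_all(hf_state.__getitem__)
print(pop_failed, loaded["token_embedder"] is loaded["logits_dense"])
```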


2 participants