Add support for On-The-Fly Dynamic SafeTensors loading. #4218
copybara-service[bot] wants to merge 1 commit into
PiperOrigin-RevId: 935065289
🤖 Hi @hengtaoguo, I've received your request, and I'm working on it now! You can track my progress in the logs for more details.
This Pull Request introduces excellent support for dynamic, on-the-fly SafeTensors loading during training and evaluation, bypassing traditional ahead-of-time conversion bottlenecks. The high-level architecture is solid, well-integrated, and provides an elegant pathway to direct-load Hugging Face weights onto TPU meshes.
🔍 General Feedback
- Great Architecture Design: The integration with JAX's maximal sharding API and Orbax checkpointing is robust and correctly handles distributed downloading across TPU VMs.
- Code Duplication: There is substantial duplication of functions like _build_multi_axis_stacked_tensor, _build_single_axis_stacked_tensor, and _get_hf_loading_function across multiple files (to_maxtext.py, utils.py, and tensor_handling.py). Consolidating these helpers into tensor_handling.py and importing them elsewhere would significantly improve DRY compliance and long-term maintainability.
- Robustness: Addressing the loop performance traps and tied weight KeyError will ensure the feature is ready for high-scale production workloads on massive model architectures.

        tensor_getter_fn: Callable[[str], np.ndarray],
        hook_fns: Any,
        target_leaf: Any,
        config,
    ) -> np.ndarray:
      """Builds a MaxText tensor by stacking HF weights along two axes (experts and layers) directly in place on device."""
      if hasattr(target_leaf, "sharding"):
        target_shape = target_leaf.shape
        target_sharding = target_leaf.sharding
        target_dtype = target_leaf.dtype
      else:
        target_shape = target_leaf.shape if hasattr(target_leaf, "shape") else target_leaf
        target_sharding = None
        target_dtype = target_leaf.dtype if hasattr(target_leaf, "dtype") else np.float32

      mt_slice_shape = target_shape[2:]

      if target_sharding is not None:
        stacked_array = jax.jit(
            lambda: np.zeros(target_shape, dtype=target_dtype),
            out_shardings=target_sharding,
        )()
      else:
        stacked_array = onp.zeros(target_shape, dtype=target_dtype)

      # Outer loop iterates through experts
      for exp_idx, layer_keys_for_expert in enumerate(hf_source_keys):
        # Inner loop iterates through layers for the current expert
        for lyr_idx, hf_key_single in enumerate(layer_keys_for_expert):
          hf_tensor_numpy = tensor_getter_fn(hf_key_single)
          processed_hf_tensor = apply_hook_fns(hf_tensor_numpy, mt_slice_shape, hook_fns)

          if target_sharding is not None:
            exp_idx_device = jax.device_put(exp_idx)
            lyr_idx_device = jax.device_put(lyr_idx)
            if hasattr(target_sharding, "spec"):
              spec_list = list(target_sharding.spec)[2:]
              slice_sharding = jax.sharding.NamedSharding(target_sharding.mesh, jax.sharding.PartitionSpec(*spec_list))
            else:
              slice_sharding = target_sharding
            processed_hf_tensor = jax.device_put(processed_hf_tensor, slice_sharding)
            stacked_array = stacked_array.at[exp_idx_device, lyr_idx_device].set(processed_hf_tensor)
          else:
            stacked_array[exp_idx, lyr_idx] = processed_hf_tensor

      return stacked_array

Because JAX arrays are immutable, every .at[...].set(...) operation within this nested Python loop allocates a brand new full-sized target JAX array and copies all data into it. For models with large numbers of layers/experts (e.g.,
Instead, accumulate the individual slices/expert tensors into a standard Python list, perform a single unified np.stack across the desired axes (which JAX/XLA compiles efficiently as a single fused allocation/concat operation), and then place the final stacked array onto the target sharding with a single jax.device_put call.
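To put a rough number on the concern above, here is a small back-of-the-envelope sketch in plain Python. It assumes the worst case the comment describes: every out-of-place update rewrites the whole target array, with no in-place optimization by the runtime. The function name and the example sizes are illustrative assumptions, not values from the PR.

```python
def copy_traffic_bytes(num_experts, num_layers, slice_elems, bytes_per_elem=2):
    """Return (looped_total, stacked_total) bytes written to the target array.

    Hypothetical helper for illustration only; assumes each out-of-place
    update copies the entire stacked array (worst case).
    """
    num_slices = num_experts * num_layers
    full_bytes = num_slices * slice_elems * bytes_per_elem
    # Per-slice out-of-place update: each of the num_slices updates rewrites everything.
    looped_total = num_slices * full_bytes
    # Accumulate-then-stack: the full array is materialized once.
    stacked_total = full_bytes
    return looped_total, stacked_total


# Illustrative sizes: 8 experts x 32 layers, 4096 x 1024 two-byte elements per slice.
looped, stacked = copy_traffic_bytes(num_experts=8, num_layers=32, slice_elems=4096 * 1024)
ratio = looped // stacked
print(f"stacked once: {stacked / 2**30:.0f} GiB, looped updates: {looped / 2**30:.0f} GiB, ratio: {ratio}x")
```

Under these assumptions the write traffic grows with the square of the slice count, so the ratio equals the number of slices.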

    def _build_multi_axis_stacked_tensor(
        hf_source_keys: List[List[str]],
        tensor_getter_fn: Callable[[str], np.ndarray],
        hook_fns: Any,
        target_leaf: Any,
        config,
    ) -> np.ndarray:
      """Builds a MaxText tensor by stacking HF weights along two axes (experts and layers) directly in place on device."""
      if hasattr(target_leaf, "sharding"):
        target_shape = target_leaf.shape
        target_sharding = target_leaf.sharding
        target_dtype = target_leaf.dtype
      else:
        target_shape = target_leaf.shape if hasattr(target_leaf, "shape") else target_leaf
        target_sharding = None
        target_dtype = target_leaf.dtype if hasattr(target_leaf, "dtype") else np.float32
      mt_slice_shape = target_shape[2:]
      all_expert_tensors = []
      # Outer loop iterates through experts
      for layer_keys_for_expert in hf_source_keys:
        layer_tensors_for_expert = []
        # Inner loop iterates through layers for the current expert
        for hf_key_single in layer_keys_for_expert:
          hf_tensor_numpy = tensor_getter_fn(hf_key_single)
          processed_hf_tensor = apply_hook_fns(hf_tensor_numpy, mt_slice_shape, hook_fns)
          layer_tensors_for_expert.append(processed_hf_tensor)
        all_expert_tensors.append(np.stack(layer_tensors_for_expert, axis=0))
      stacked_array = np.stack(all_expert_tensors, axis=0).astype(target_dtype)
      if target_sharding is not None:
        stacked_array = jax.device_put(stacked_array, target_sharding)
      return stacked_array

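As a sanity check on the accumulate-then-stack idea, the following host-side NumPy sketch compares it with an indexed fill on a toy checkpoint. The key names, shapes, and helper functions are hypothetical, and the sketch uses plain NumPy only; it does not exercise JAX sharding.

```python
import numpy as np


def stack_experts_and_layers(keys, getter):
    """Stack tensors for keys[expert][layer] into shape (experts, layers, *slice)."""
    per_expert = [np.stack([getter(k) for k in layer_keys], axis=0) for layer_keys in keys]
    return np.stack(per_expert, axis=0)


def fill_by_index(keys, getter, slice_shape):
    """Reference result: preallocate and assign each [expert, layer] slice."""
    out = np.zeros((len(keys), len(keys[0])) + slice_shape, dtype=np.float32)
    for exp_idx, layer_keys in enumerate(keys):
        for lyr_idx, key in enumerate(layer_keys):
            out[exp_idx, lyr_idx] = getter(key)
    return out


# Hypothetical checkpoint: 2 experts x 3 layers, each weight of shape (4, 5).
rng = np.random.default_rng(0)
keys = [[f"expert{e}.layer{n}.weight" for n in range(3)] for e in range(2)]
fake_state = {k: rng.standard_normal((4, 5)).astype(np.float32) for row in keys for k in row}

stacked = stack_experts_and_layers(keys, fake_state.__getitem__)
filled = fill_by_index(keys, fake_state.__getitem__, (4, 5))
print(stacked.shape, np.array_equal(stacked, filled))
```

One thing worth confirming in the PR itself: the `onp.zeros` call suggests `np` is bound to `jax.numpy` there, in which case the suggested `np.stack` would materialize the full unsharded array on the default device before `jax.device_put` reshards it. Stacking with host NumPy, as in this sketch, keeps that intermediate in host memory.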

    def _build_single_axis_stacked_tensor(
        hf_source_keys: List[str],
        tensor_getter_fn: Callable[[str], np.ndarray],
        hook_fns: Any,
        target_leaf: Any,
        config,
    ) -> np.ndarray:
      """Builds a MaxText tensor by stacking HF weights along a single axis directly in place on device."""
      if hasattr(target_leaf, "sharding"):
        target_shape = target_leaf.shape
        target_sharding = target_leaf.sharding
        target_dtype = target_leaf.dtype
      else:
        target_shape = target_leaf.shape if hasattr(target_leaf, "shape") else target_leaf
        target_sharding = None
        target_dtype = target_leaf.dtype if hasattr(target_leaf, "dtype") else np.float32

      if config.scan_layers:
        # If it's a standard scanned layer, we use the configured param_scan_axis.
        axis_to_stack = config.param_scan_axis
      else:
        # Otherwise, if an unscanned MoE layer, and we stack along the expert axis (0).
        axis_to_stack = 0

      # The hook function needs the shape of an individual slice, not the full stacked tensor.
      # We calculate it by removing the stacking dimension from the final target shape.
      mt_slice_shape_list = list(target_shape)
      del mt_slice_shape_list[axis_to_stack]
      mt_slice_shape = tuple(mt_slice_shape_list)

      if target_sharding is not None:
        stacked_array = jax.jit(
            lambda: np.zeros(target_shape, dtype=target_dtype),
            out_shardings=target_sharding,
        )()
      else:
        stacked_array = onp.zeros(target_shape, dtype=target_dtype)

      for i, hf_key_single in enumerate(hf_source_keys):
        hf_tensor_numpy = tensor_getter_fn(hf_key_single)
        processed_hf_tensor = apply_hook_fns(hf_tensor_numpy, mt_slice_shape, hook_fns)

        # Construct indexing tuple dynamically along axis_to_stack
        indexer = [slice(None)] * len(target_shape)

        if target_sharding is not None:
          idx = jax.device_put(i)
          if hasattr(target_sharding, "spec"):
            spec_list = list(target_sharding.spec)
            del spec_list[axis_to_stack]
            slice_sharding = jax.sharding.NamedSharding(target_sharding.mesh, jax.sharding.PartitionSpec(*spec_list))
          else:
            slice_sharding = target_sharding
          processed_hf_tensor = jax.device_put(processed_hf_tensor, slice_sharding)
          indexer[axis_to_stack] = idx
          stacked_array = stacked_array.at[tuple(indexer)].set(processed_hf_tensor)
        else:
          indexer[axis_to_stack] = i
          stacked_array[tuple(indexer)] = processed_hf_tensor

      return stacked_array

Just like the multi-axis stack function, updating the JAX array inside a Python loop with stacked_array = stacked_array.at[tuple(indexer)].set(...) creates copies of the entire array at each step, causing massive memory churn and TPU Host/Device OOMs.
By accumulating processed slices in a list and performing a single np.stack(..., axis=axis_to_stack), we optimize memory footprint and compile the stack as a single fused XLA allocation, followed by placing the result on the target sharding.

    def _build_single_axis_stacked_tensor(
        hf_source_keys: List[str],
        tensor_getter_fn: Callable[[str], np.ndarray],
        hook_fns: Any,
        target_leaf: Any,
        config,
    ) -> np.ndarray:
      """Builds a MaxText tensor by stacking HF weights along a single axis directly in place on device."""
      if hasattr(target_leaf, "sharding"):
        target_shape = target_leaf.shape
        target_sharding = target_leaf.sharding
        target_dtype = target_leaf.dtype
      else:
        target_shape = target_leaf.shape if hasattr(target_leaf, "shape") else target_leaf
        target_sharding = None
        target_dtype = target_leaf.dtype if hasattr(target_leaf, "dtype") else np.float32
      if config.scan_layers:
        # If it's a standard scanned layer, we use the configured param_scan_axis.
        axis_to_stack = config.param_scan_axis
      else:
        # Otherwise, if an unscanned MoE layer, and we stack along the expert axis (0).
        axis_to_stack = 0
      # The hook function needs the shape of an individual slice, not the full stacked tensor.
      # We calculate it by removing the stacking dimension from the final target shape.
      mt_slice_shape_list = list(target_shape)
      del mt_slice_shape_list[axis_to_stack]
      mt_slice_shape = tuple(mt_slice_shape_list)
      tensors_to_stack = []
      for hf_key_single in hf_source_keys:
        hf_tensor_numpy = tensor_getter_fn(hf_key_single)
        processed_hf_tensor = apply_hook_fns(hf_tensor_numpy, mt_slice_shape, hook_fns)
        tensors_to_stack.append(processed_hf_tensor)
      stacked_array = np.stack(tensors_to_stack, axis=axis_to_stack).astype(target_dtype)
      if target_sharding is not None:
        stacked_array = jax.device_put(stacked_array, target_sharding)
      return stacked_array

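The single-axis case relies on `stack(..., axis=k)` producing the same layout as the dynamic indexer in the original loop. A minimal NumPy sketch of that equivalence for a non-zero stacking axis follows; the helper names and toy shapes are assumptions for illustration.

```python
import numpy as np


def stack_along_axis(slices, axis):
    """Accumulate-then-stack: one call builds the full array."""
    return np.stack(slices, axis=axis)


def fill_along_axis(slices, axis):
    """Reference result: preallocate, then assign slice i at position i along `axis`."""
    target_shape = list(slices[0].shape)
    target_shape.insert(axis, len(slices))
    out = np.zeros(target_shape, dtype=slices[0].dtype)
    for i, piece in enumerate(slices):
        indexer = [slice(None)] * len(target_shape)
        indexer[axis] = i
        out[tuple(indexer)] = piece
    return out


# Three toy layer weights of shape (4, 6), stacked along axis 1.
slices = [np.full((4, 6), float(i), dtype=np.float32) for i in range(3)]
stacked = stack_along_axis(slices, axis=1)
filled = fill_along_axis(slices, axis=1)
print(stacked.shape, np.array_equal(stacked, filled))
```

The stacked result has the layer dimension inserted at position 1, matching what the indexer-based fill produces for the same `axis_to_stack`.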

    else:
      raise ValueError("base_output_directory with gs:// prefix is required for " "huggingface downloads.")

Strictly requiring base_output_directory to start with gs:// prevents single-host experiments, GPU-based workloads, and local unit/integration tests from functioning. In single-host environments (where jax.process_count() == 1), we can fall back cleanly to standard Hugging Face snapshot downloading.

    elif jax.process_count() == 1:
      # Single-host environment: download directly from HF Hub using snapshot_download
      max_logging.log("Single-host environment detected: downloading directly from HF Hub.")
      from huggingface_hub import snapshot_download
      path = snapshot_download(
          repo_id=repo_id,
          token=maxtext_config.hf_access_token,
          allow_patterns=["*.safetensors", "*.json"],
      )
    else:
      raise ValueError("base_output_directory with gs:// prefix is required for " "huggingface downloads.")

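The branch order proposed here can be isolated into a small, testable decision function. The sketch below is an assumption-laden illustration: the `if` condition that precedes the `else` is not shown in this thread, so a `gs://` prefix check is assumed, and `choose_download_strategy` and its return labels are hypothetical names.

```python
def choose_download_strategy(base_output_directory: str, process_count: int) -> str:
    """Return a label for how weights would be fetched.

    Hypothetical helper: the gs:// prefix check is an assumption about the
    branch that precedes the `else` in the PR.
    """
    if base_output_directory.startswith("gs://"):
        # Multi-host friendly path: stage files through the shared bucket.
        return "gcs_staged"
    if process_count == 1:
        # Single host: a direct Hub snapshot download is sufficient.
        return "hub_snapshot"
    raise ValueError("base_output_directory with gs:// prefix is required for huggingface downloads.")


print(choose_download_strategy("gs://my-bucket/run", process_count=4))
print(choose_download_strategy("/tmp/run", process_count=1))
```

Keeping the decision separate from the download calls would let a unit test cover the multi-host error path without network access.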

    t0 = time.time()

    def tensor_getter(key):

Popping keys from hf_state will cause a KeyError when attempting to load architectures with shared or tied weights (such as tie_word_embeddings = True), where the same Hugging Face key (e.g. model.embed_tokens.weight) is mapped to multiple MaxText parameters. Accessing the dictionary using standard bracket lookup or .get() avoids this crash and is completely safe since hf_state will be garbage collected after the function returns.

    def tensor_getter(key):
      return hf_state[key]

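The tied-weights failure is easy to reproduce with a plain dictionary. In this sketch the MaxText parameter names and the mapping are hypothetical; only the Hugging Face key `model.embed_tokens.weight` comes from the comment above.

```python
# Toy state dict with a single tensor (a list stands in for the array).
hf_state = {"model.embed_tokens.weight": [0.1, 0.2]}

# Hypothetical mapping: two MaxText parameters tied to the same HF key.
param_to_hf_key = {
    "token_embedder.embedding": "model.embed_tokens.weight",
    "decoder.logits_dense.kernel": "model.embed_tokens.weight",
}


def load_all(getter):
    """Resolve every MaxText parameter through the supplied getter."""
    return {param: getter(key) for param, key in param_to_hf_key.items()}


# Getter that pops: the second lookup of the tied key fails.
popping_state = dict(hf_state)
try:
    load_all(popping_state.pop)
    pop_error = None
except KeyError as err:
    pop_error = err

# Getter that indexes: both parameters resolve to the same tensor.
loaded = load_all(hf_state.__getitem__)
print(type(pop_error).__name__, len(loaded))
```

The trade-off is that bracket lookup keeps every tensor referenced until `hf_state` itself is released, whereas popping frees each entry as soon as it is consumed. If peak host memory matters, one option is to pop a key only after its last consumer has read it.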