From 953cfeae03ef99a7fef5b6949472dca0c6be0cfd Mon Sep 17 00:00:00 2001 From: Abhinav Singh Date: Fri, 19 Jun 2026 15:42:40 -0700 Subject: [PATCH] Add support for On-The-Fly Dynamic SafeTensors loading. PiperOrigin-RevId: 935065289 --- .../checkpoint_conversion/to_maxtext.py | 11 +- .../utils/load_dynamic.py | 304 ++++++++++++++++++ .../utils/tensor_handling.py | 190 +++++++++++ .../checkpoint_conversion/utils/utils.py | 181 +++++++++-- src/maxtext/common/checkpointing.py | 39 +-- src/maxtext/configs/types.py | 2 +- src/maxtext/layers/quantizations.py | 9 +- src/maxtext/utils/maxtext_utils.py | 1 + 8 files changed, 690 insertions(+), 47 deletions(-) create mode 100644 src/maxtext/checkpoint_conversion/utils/load_dynamic.py create mode 100644 src/maxtext/checkpoint_conversion/utils/tensor_handling.py diff --git a/src/maxtext/checkpoint_conversion/to_maxtext.py b/src/maxtext/checkpoint_conversion/to_maxtext.py index 4245201b4e..ef7c0a4e82 100644 --- a/src/maxtext/checkpoint_conversion/to_maxtext.py +++ b/src/maxtext/checkpoint_conversion/to_maxtext.py @@ -58,21 +58,22 @@ import time from typing import Any, Callable, List, Sequence import absl -import ml_dtypes import flax.linen as nn from huggingface_hub import hf_hub_download, list_repo_files import jax -from maxtext.configs import pyconfig -from maxtext.configs.types import DType -from maxtext.common.common_types import MODEL_MODE_TRAIN from maxtext.checkpoint_conversion.utils.hf_model_configs import HF_MODEL_CONFIGS from maxtext.checkpoint_conversion.utils.param_mapping import HOOK_FNS, PARAM_MAPPING -from maxtext.checkpoint_conversion.utils.utils import MemoryMonitorTqdm, apply_hook_fns, load_hf_dict_from_transformers, load_hf_dict_from_safetensors, param_key_parts_from_path, print_peak_memory, print_ram_usage, save_weights_to_checkpoint, validate_and_filter_param_map_keys +from maxtext.checkpoint_conversion.utils.tensor_handling import apply_hook_fns +from maxtext.checkpoint_conversion.utils.utils import MemoryMonitorTqdm, load_hf_dict_from_safetensors, load_hf_dict_from_transformers, param_key_parts_from_path, print_peak_memory, print_ram_usage, save_weights_to_checkpoint, validate_and_filter_param_map_keys +from maxtext.common.common_types import MODEL_MODE_TRAIN +from maxtext.configs import pyconfig +from maxtext.configs.types import DType from maxtext.inference.inference_utils import str2bool from maxtext.layers import quantizations from maxtext.models import models from maxtext.utils import max_logging, max_utils, maxtext_utils from maxtext.utils.globals import HF_IDS +import ml_dtypes import numpy as np from orbax.checkpoint import type_handlers from safetensors import safe_open diff --git a/src/maxtext/checkpoint_conversion/utils/load_dynamic.py b/src/maxtext/checkpoint_conversion/utils/load_dynamic.py new file mode 100644 index 0000000000..b28341912b --- /dev/null +++ b/src/maxtext/checkpoint_conversion/utils/load_dynamic.py @@ -0,0 +1,304 @@ +# Copyright 2023–2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Dynamic loading of HuggingFace checkpoints during training/eval workloads directly in the target format.""" + +import concurrent.futures +import multiprocessing +import os +import random +import time + +from flax import nnx +import flax.traverse_util +from google.cloud import storage +from huggingface_hub import HfFileSystem +import jax +from maxtext.checkpoint_conversion.utils import hf_model_configs +from maxtext.checkpoint_conversion.utils import param_mapping +from maxtext.checkpoint_conversion.utils import tensor_handling +from maxtext.utils import globals as maxtext_globals +from maxtext.utils import max_logging +from orbax.checkpoint import v1 as ocp_v1 +from orbax.checkpoint._src.arrays import sharding as sharding_utils + +HF_MODEL_CONFIGS = hf_model_configs.HF_MODEL_CONFIGS +get_hf_loading_function = tensor_handling.get_hf_loading_function + + +def build_gcs_cache_worker(fpath, gcs_cache_dir, hf_access_token): + """Caches a file from Hugging Face to a GCS bucket cache directory. + + Args: + fpath: The path of the file on Hugging Face. + gcs_cache_dir: The destination directory in GCS. + hf_access_token: The access token for Hugging Face. + """ + fs = HfFileSystem(token=hf_access_token) + time.sleep(random.uniform(0.0, 5.0)) + + bucket_name = gcs_cache_dir.replace("gs://", "").split("/")[0] + blob_prefix = gcs_cache_dir.replace("gs://", "").split("/", 1)[1] if "/" in gcs_cache_dir.replace("gs://", "") else "" + blob_name = os.path.join(blob_prefix, os.path.basename(fpath)) + + storage_client = storage.Client() + bucket = storage_client.bucket(bucket_name) + blob = bucket.blob(blob_name) + + if blob.exists(): + max_logging.log(f"[Worker] Cache hit for {os.path.basename(fpath)}.") + return + + t0 = time.time() + max_retries = 5 + for attempt in range(max_retries): + try: + with fs.open(fpath, "rb") as remote_f: + blob.chunk_size = 1024 * 1024 * 32 # 32MB chunks + blob.upload_from_file(remote_f, client=storage_client) + print( + f"[Worker] Cached {os.path.basename(fpath)} in" f" {time.time() - t0:.1f}s", + flush=True, + ) + break + except Exception as e: # pylint: disable=broad-exception-caught + if attempt < max_retries - 1: + max_logging.log( + f"Error fetching {fpath} to GCS: {e}. Retrying in 15 seconds..." f" (Attempt {attempt+1}/{max_retries})" + ) + time.sleep(15) + else: + max_logging.log(f"Failed to fetch {fpath} to GCS after {max_retries} attempts.") + raise + + +def get_hf_config_and_mappings(maxtext_config): + """Gets HF config and parameter mapping based on the MaxText config.""" + model_key = maxtext_config.model_name + if "-Instruct" in model_key: + model_key = model_key.replace("-Instruct", "") + hf_config_obj = HF_MODEL_CONFIGS[model_key] + hf_config_dict = hf_config_obj.to_dict() + + param_map_mt_to_hf = param_mapping.PARAM_MAPPING[model_key]( + hf_config_dict, maxtext_config, scan_layers=maxtext_config.scan_layers + ) + hook_fn_map_mt = param_mapping.HOOK_FNS[model_key]( + hf_config_dict, + maxtext_config, + scan_layers=maxtext_config.scan_layers, + saving_to_hf=False, + ) + return param_map_mt_to_hf, hook_fn_map_mt + + +def load_sharded_hf_state(path): + """Loads HF state with maximal sharding across TPU mesh to avoid host OOM.""" + t0 = time.time() + context = ocp_v1.Context(checkpoint_layout=ocp_v1.options.CheckpointLayout.SAFETENSORS) + with context: + metadata = ocp_v1.pytree_metadata(path) + simple_abstract_state = metadata.metadata + + # Distributed Sharded Download: Tell JAX to shard the HF Safetensors download + # across the entire TPU mesh to avoid Host OOM. + current_global_devices = jax.devices() + shardings = sharding_utils.construct_maximal_shardings(simple_abstract_state, devices=current_global_devices) + + def combine_sharding(sds, single_sharding): + return jax.ShapeDtypeStruct(shape=sds.shape, dtype=sds.dtype, sharding=single_sharding) + + sharded_abstract_state = jax.tree.map(combine_sharding, simple_abstract_state, shardings) + + max_logging.log("Reading raw Safetensors into memory (Distributed Sharded GCS" " Download)...") + hf_state = ocp_v1.load_pytree(path, sharded_abstract_state) + max_logging.log(f"load_sharded_hf_state took {time.time() - t0:.2f}s") + return hf_state + + +def transform_hf_state_to_mt_state(hf_state, target_tree, param_map_mt_to_hf, hook_fn_map_mt, maxtext_config): + """Transforms HF state into MaxText state by applying param mappings and mathematical hooks.""" + t0 = time.time() + + def tensor_getter(key): + return hf_state.pop(key) + + flat_target = flax.traverse_util.flatten_dict(target_tree, sep=".") + flat_restored = flat_target.copy() + + mapped_count = 0 + keys_missed = [] + max_logging.log("Starting fast in-memory Distributed Transformations...") + + for mt_key, hf_source in param_map_mt_to_hf.items(): + mt_name = mt_key.replace("params-", "").replace("-", ".") + + # Determine the correct key in flat_target + check_name = mt_name + if check_name not in flat_target: + if "params." + mt_name in flat_target: + check_name = "params." + mt_name + elif mt_key.replace("-", ".") in flat_target: + check_name = mt_key.replace("-", ".") + + if check_name not in flat_target: + keys_missed.append(mt_name) + continue + + target_leaf = flat_target[check_name] + hook_fn = hook_fn_map_mt.get(mt_key) + + load_fn = get_hf_loading_function( + hf_source, + tensor_getter, + hook_fn, + target_leaf, + maxtext_config, + ) + + # Execute transformation and assign to flat_restored + t_layer = time.time() + flat_restored[check_name] = load_fn() + + max_logging.log(f"Transformed {check_name} from {hf_source} in" f" {time.time() - t_layer:.4f}s") + mapped_count += 1 + + if mapped_count == 0: + max_logging.log(f"All transformations missed! Sample missed mt_names: {keys_missed[:5]}") + max_logging.log(f"Sample flat_target keys: {list(flat_target.keys())[:5]}") + + max_logging.log(f"Successfully mapped {mapped_count} parameters.") + restored_params = flax.traverse_util.unflatten_dict(flat_restored, sep=".") + + if "params" in restored_params: + restored_params = restored_params["params"] + + max_logging.log(f"transform_hf_state_to_mt_state took {time.time() - t0:.2f}s") + + return {"params": restored_params} + + +def load_safetensors_dynamic_state(path, abstract_unboxed_pre_state, maxtext_config): + """Main entry point to dynamically build and load safetensors into MaxText format. + + Splits execution into: + 1. Deriving Mappings + 2. Loading Sharded arrays directly to TPUs + 3. Processing the transformations natively on TPUs + """ + if maxtext_config is None: + raise ValueError("maxtext_config must be provided for safetensors_dynamic loading.") + + model_name = maxtext_config.model_name + if "-Instruct" in model_name: + model_name = model_name.replace("-Instruct", "") + + if not path: + if model_name not in maxtext_globals.HF_IDS: + raise ValueError("Unsupported model name for automatic HF repo resolution:" f" {model_name}.") + path = maxtext_globals.HF_IDS[model_name] + + if path.startswith("hf://"): + path = path[5:] + + if not path.startswith("gs://") and not os.path.isdir(path): + fs = HfFileSystem(token=maxtext_config.hf_access_token) + repo_id = path + + files = fs.glob(f"{repo_id}/*.safetensors") + + host_id = jax.process_index() + + if hasattr(maxtext_config, "base_output_directory") and maxtext_config.base_output_directory.startswith("gs://"): + gcs_cache_dir = f"{maxtext_config.base_output_directory}/hf_cache/{repo_id.replace('/', '_')}" + path = gcs_cache_dir + + # Only Host 0 downloads to the shared GCS cache + if host_id == 0: + max_logging.log("Dynamic HF Hub Fast DL: Host 0 is downloading to shared GCS" f" Cache: {gcs_cache_dir}") + t_gcs_start = time.time() + + # List existing blobs to avoid spawning processes for already cached + # files + storage_client = storage.Client() + gcs_cache_dir_no_gs = gcs_cache_dir.replace("gs://", "") + bucket_name = gcs_cache_dir_no_gs.split("/", maxsplit=1)[0] + blob_prefix = gcs_cache_dir_no_gs.split("/", maxsplit=1)[1] if "/" in gcs_cache_dir_no_gs else "" + + existing_blobs = {blob.name for blob in storage_client.list_blobs(bucket_name, prefix=blob_prefix)} + + files_to_download = [] + for fpath in files: + expected_blob_name = os.path.join(blob_prefix, os.path.basename(fpath)) + if expected_blob_name not in existing_blobs: + files_to_download.append(fpath) + + if files_to_download: + with concurrent.futures.ProcessPoolExecutor( + max_workers=32, mp_context=multiprocessing.get_context("spawn") + ) as executor: + futures = [ + executor.submit( + build_gcs_cache_worker, + fpath, + gcs_cache_dir, + maxtext_config.hf_access_token, + ) + for fpath in files_to_download + ] + + while futures: + done, futures = concurrent.futures.wait(futures, timeout=10) + + # Raise any exceptions if a worker failed + for f in done: + f.result() + + t_gcs_end = time.time() + max_logging.log( + f"GCS caching complete in {t_gcs_end - t_gcs_start:.2f}s." + f" Downloaded {len(files_to_download)} missing files." + ) + + # Global barrier: all hosts wait for Host 0 to finish downloading to the + # shared GCS bucket + max_logging.log(f"Host {host_id} waiting for GCS cache at {gcs_cache_dir} to be" " populated by Host 0...") + jax.experimental.multihost_utils.sync_global_devices("dynamic_hf_download_complete") + max_logging.log(f"Host {host_id} detected GCS cache is ready!") + + else: + raise ValueError("base_output_directory with gs:// prefix is required for " "huggingface downloads.") + + t_total = time.time() + param_map_mt_to_hf, hook_fn_map_mt = get_hf_config_and_mappings(maxtext_config) + max_logging.log(f"[1/3] Mappings derived in {time.time() - t_total:.2f}s") + + target_tree = ( + abstract_unboxed_pre_state.to_pure_dict() + if isinstance(abstract_unboxed_pre_state, nnx.State) + else abstract_unboxed_pre_state.params + ) + + t1 = time.time() + hf_state = load_sharded_hf_state(path) + max_logging.log(f"[2/3] Distributed Sharded GCS load completed in {time.time() - t1:.2f}s") + + t2 = time.time() + restored_params = transform_hf_state_to_mt_state( + hf_state, target_tree, param_map_mt_to_hf, hook_fn_map_mt, maxtext_config + ) + max_logging.log(f"[3/3] CPU Transformations completed in {time.time() - t2:.2f}s") + max_logging.log(f"Total safetensors_dynamic duration: {time.time() - t_total:.2f}s") + + return None, restored_params diff --git a/src/maxtext/checkpoint_conversion/utils/tensor_handling.py b/src/maxtext/checkpoint_conversion/utils/tensor_handling.py new file mode 100644 index 0000000000..06b23a3504 --- /dev/null +++ b/src/maxtext/checkpoint_conversion/utils/tensor_handling.py @@ -0,0 +1,190 @@ +# Copyright 2023–2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tensor handling utility functions for checkpoint conversion.""" + +import functools +from typing import Any, Callable, List +import jax +import jax.numpy as np +import numpy as onp + + +def apply_hook_fns(weight, target_shape, hook_fns): + """Apply hook functions, essential for to_maxtext and to_huggingface""" + # If hook is unsepecified, use identity + if hook_fns is None: + return weight + if not isinstance(hook_fns, list): + hook_fns = [hook_fns] + # Apply a list of hooks, be careful of order + for hook_fn in hook_fns: + weight = hook_fn(weight, target_shape) + return weight + + +def _build_multi_axis_stacked_tensor( + hf_source_keys: List[List[str]], + tensor_getter_fn: Callable[[str], np.ndarray], + hook_fns: Any, + target_leaf: Any, + config, +) -> np.ndarray: + """Builds a MaxText tensor by stacking HF weights along two axes (experts and layers) directly in place on device.""" + if hasattr(target_leaf, "sharding"): + target_shape = target_leaf.shape + target_sharding = target_leaf.sharding + target_dtype = target_leaf.dtype + else: + target_shape = target_leaf.shape if hasattr(target_leaf, "shape") else target_leaf + target_sharding = None + target_dtype = target_leaf.dtype if hasattr(target_leaf, "dtype") else np.float32 + + mt_slice_shape = target_shape[2:] + + if target_sharding is not None: + stacked_array = jax.jit( + lambda: np.zeros(target_shape, dtype=target_dtype), + out_shardings=target_sharding, + )() + else: + stacked_array = onp.zeros(target_shape, dtype=target_dtype) + + # Outer loop iterates through experts + for exp_idx, layer_keys_for_expert in enumerate(hf_source_keys): + # Inner loop iterates through layers for the current expert + for lyr_idx, hf_key_single in enumerate(layer_keys_for_expert): + hf_tensor_numpy = tensor_getter_fn(hf_key_single) + processed_hf_tensor = apply_hook_fns(hf_tensor_numpy, mt_slice_shape, hook_fns) + + if target_sharding is not None: + exp_idx_device = jax.device_put(exp_idx) + lyr_idx_device = jax.device_put(lyr_idx) + if hasattr(target_sharding, "spec"): + spec_list = list(target_sharding.spec)[2:] + slice_sharding = jax.sharding.NamedSharding(target_sharding.mesh, jax.sharding.PartitionSpec(*spec_list)) + else: + slice_sharding = target_sharding + processed_hf_tensor = jax.device_put(processed_hf_tensor, slice_sharding) + stacked_array = stacked_array.at[exp_idx_device, lyr_idx_device].set(processed_hf_tensor) + else: + stacked_array[exp_idx, lyr_idx] = processed_hf_tensor + + return stacked_array + + +def _build_single_axis_stacked_tensor( + hf_source_keys: List[str], + tensor_getter_fn: Callable[[str], np.ndarray], + hook_fns: Any, + target_leaf: Any, + config, +) -> np.ndarray: + """Builds a MaxText tensor by stacking HF weights along a single axis directly in place on device.""" + if hasattr(target_leaf, "sharding"): + target_shape = target_leaf.shape + target_sharding = target_leaf.sharding + target_dtype = target_leaf.dtype + else: + target_shape = target_leaf.shape if hasattr(target_leaf, "shape") else target_leaf + target_sharding = None + target_dtype = target_leaf.dtype if hasattr(target_leaf, "dtype") else np.float32 + + if config.scan_layers: + # If it's a standard scanned layer, we use the configured param_scan_axis. + axis_to_stack = config.param_scan_axis + else: + # Otherwise, if an unscanned MoE layer, and we stack along the expert axis (0). + axis_to_stack = 0 + + # The hook function needs the shape of an individual slice, not the full stacked tensor. + # We calculate it by removing the stacking dimension from the final target shape. + mt_slice_shape_list = list(target_shape) + del mt_slice_shape_list[axis_to_stack] + mt_slice_shape = tuple(mt_slice_shape_list) + + if target_sharding is not None: + stacked_array = jax.jit( + lambda: np.zeros(target_shape, dtype=target_dtype), + out_shardings=target_sharding, + )() + else: + stacked_array = onp.zeros(target_shape, dtype=target_dtype) + + for i, hf_key_single in enumerate(hf_source_keys): + hf_tensor_numpy = tensor_getter_fn(hf_key_single) + processed_hf_tensor = apply_hook_fns(hf_tensor_numpy, mt_slice_shape, hook_fns) + + # Construct indexing tuple dynamically along axis_to_stack + indexer = [slice(None)] * len(target_shape) + + if target_sharding is not None: + idx = jax.device_put(i) + if hasattr(target_sharding, "spec"): + spec_list = list(target_sharding.spec) + del spec_list[axis_to_stack] + slice_sharding = jax.sharding.NamedSharding(target_sharding.mesh, jax.sharding.PartitionSpec(*spec_list)) + else: + slice_sharding = target_sharding + processed_hf_tensor = jax.device_put(processed_hf_tensor, slice_sharding) + indexer[axis_to_stack] = idx + stacked_array = stacked_array.at[tuple(indexer)].set(processed_hf_tensor) + else: + indexer[axis_to_stack] = i + stacked_array[tuple(indexer)] = processed_hf_tensor + + return stacked_array + + +def get_hf_loading_function(hf_source_keys_or_key, tensor_getter, hook_fn, mt_target_leaf, config): + """Determine the loading function for HF keys.""" + if not isinstance(hf_source_keys_or_key, list): + # Case 1: Single hf key (str) + def _loader(getter, key, leaf, hook): + if hasattr(leaf, "sharding"): + array = apply_hook_fns(getter(key), leaf.shape, hook) + return jax.device_put(array, device=leaf.sharding) + else: + shape = leaf.shape if hasattr(leaf, "shape") else leaf + return apply_hook_fns(getter(key), shape, hook) + + return functools.partial( + _loader, + tensor_getter, + hf_source_keys_or_key, + mt_target_leaf, + hook_fn, + ) + # Stacked mapping + elif not isinstance(hf_source_keys_or_key[0], list): + # Case 2 or 3: Single-Axis Stacked hf keys (un-nested list) + return functools.partial( + _build_single_axis_stacked_tensor, + hf_source_keys_or_key, + tensor_getter, + hook_fn, + mt_target_leaf, + config, + ) + else: + # isinstance(hf_source_keys_or_key[0], list) + # Case 4: Multi-Axis Stacked hf keys (nested list) + return functools.partial( + _build_multi_axis_stacked_tensor, + hf_source_keys_or_key, + tensor_getter, + hook_fn, + mt_target_leaf, + config, + ) diff --git a/src/maxtext/checkpoint_conversion/utils/utils.py b/src/maxtext/checkpoint_conversion/utils/utils.py index cf43763f06..1ab2ce50df 100644 --- a/src/maxtext/checkpoint_conversion/utils/utils.py +++ b/src/maxtext/checkpoint_conversion/utils/utils.py @@ -14,43 +14,39 @@ """Checkpoint conversion utility functions.""" +from concurrent.futures import ThreadPoolExecutor import contextlib +from functools import partial import gc import io +import json import logging import os +import pathlib +import resource import tempfile import time -import json -from concurrent.futures import ThreadPoolExecutor -from typing import Any -from tqdm import tqdm -import resource -import numpy as np -import psutil -import pathlib +from typing import Any, Callable, List from etils import epath - +from flax.training import train_state +from huggingface_hub import HfApi, repo_exists, snapshot_download import jax from jax import tree from jax.experimental import multihost_utils from jaxtyping import Array - -from safetensors import safe_open -from safetensors.numpy import save_file as numpy_save_file -from safetensors.numpy import save as numpy_save -from safetensors.flax import save as save_flax_to_bytes - -from huggingface_hub import HfApi, repo_exists, snapshot_download - -from transformers.models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING_NAMES -from transformers import AutoModelForCausalLM - -from flax.training import train_state from maxtext.common import checkpointing from maxtext.common.gcloud_stub import gcs_storage from maxtext.utils import max_logging +import numpy as np import orbax.checkpoint as ocp +import psutil +from safetensors import safe_open +from safetensors.flax import save as save_flax_to_bytes +from safetensors.numpy import save as numpy_save +from safetensors.numpy import save_file as numpy_save_file +from tqdm import tqdm +from transformers import AutoModelForCausalLM +from transformers.models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING_NAMES _storage = gcs_storage() Client = _storage.Client @@ -1229,3 +1225,146 @@ def save_weights_to_checkpoint( checkpoint_manager.wait_until_finished() max_logging.log(f"Elapse for checkpoint save: {(time.time() - start) / 60:.2f} min") + + +def _build_multi_axis_stacked_tensor( + hf_source_keys: List[List[str]], + tensor_getter_fn: Callable[[str], np.ndarray], + hook_fns: Any, + target_shape: tuple, + config, +) -> np.ndarray: + """Builds a MaxText tensor by stacking HF weights along two axes (experts and layers). + + This function handles the complex case for scanned MoE layers, producing a + tensor + with the shape (num_experts, num_layers, ...). + + Args: + hf_source_keys: A nested (2D) list of Hugging Face parameter names. Outer + list iterates experts, inner list iterates layers. + tensor_getter_fn: A callable that takes a HF key and returns the tensor + (as numpy array). + hook_fns: The hook function(s) to apply to each individual weight. + target_shape: The final shape of the target MaxText tensor. + config: The MaxText pyconfig object. + + Returns: + The final, assembled NumPy array for the MaxText parameter. + """ + all_expert_tensors = [] + # The hook function needs the shape of an individual slice, not the full stacked tensor. + # For multi-axis stacking (experts, layers, ...), the slice shape is target_shape[2:] + mt_slice_shape = target_shape[2:] + + # Outer loop iterates through experts + for layer_keys_for_expert in hf_source_keys: + layer_tensors_for_expert = [] + # Inner loop iterates through layers for the current expert + for hf_key_single in layer_keys_for_expert: + hf_tensor_numpy = tensor_getter_fn(hf_key_single) + processed_hf_tensor = apply_hook_fns(hf_tensor_numpy, mt_slice_shape, hook_fns) + layer_tensors_for_expert.append(processed_hf_tensor) + all_expert_tensors.append(np.stack(layer_tensors_for_expert, axis=0)) + return np.stack(all_expert_tensors, axis=0) + + +def _build_single_axis_stacked_tensor( + hf_source_keys: List[str], + tensor_getter_fn: Callable[[str], np.ndarray], + hook_fns: Any, + target_shape: tuple, + config, +) -> np.ndarray: + """Builds a MaxText tensor by stacking HF weights along a single axis. + + This function handles both standard scanned layers (e.g., attention) and + unscanned MoE layers (which are stacked along the expert axis). + + Args: + hf_source_keys: A 1D list of Hugging Face parameter names. + tensor_getter_fn: A callable that takes a HF key and returns the tensor + (as numpy array). + hook_fns: The hook function(s) to apply to each individual weight. + target_shape: The final shape of the target MaxText tensor. + config: The MaxText pyconfig object. + + Returns: + The final, assembled NumPy array for the MaxText parameter. + """ + tensors_to_stack = [] + + if config.scan_layers: + # If it's a standard scanned layer, we use the configured param_scan_axis. + axis_to_stack = config.param_scan_axis + else: + # Otherwise, if an unscanned MoE layer, and we stack along the expert axis (0). + axis_to_stack = 0 + + # The hook function needs the shape of an individual slice, not the full stacked tensor. + # We calculate it by removing the stacking dimension from the final target shape. + mt_slice_shape_list = list(target_shape) + del mt_slice_shape_list[axis_to_stack] + mt_slice_shape = tuple(mt_slice_shape_list) + + for hf_key_single in hf_source_keys: + hf_tensor_numpy = tensor_getter_fn(hf_key_single) + processed_hf_tensor = apply_hook_fns(hf_tensor_numpy, mt_slice_shape, hook_fns) + tensors_to_stack.append(processed_hf_tensor) + + # Stack all processed tensors along the determined axis. + return np.stack(tensors_to_stack, axis=axis_to_stack) + + +def _get_hf_loading_function( + hf_source_keys_or_key, + tensor_getter, + hook_fn, + mt_target_shape_or_shapes, + config, +): + """Determine the loading function for HF keys. + + HF keys can take four forms: + + Case 1: Unscanned (single string) + Case 2: Scanned (list of strings) + Case 3: Unscanned with expert stacking (list of strings) + Case 4: Scanned with expert stacking (nested list of strings) + """ + load_fn = None + if not isinstance(hf_source_keys_or_key, list): + # Case 1: Single hf key (str) + def _loader(getter, key, shape, hook): + return apply_hook_fns(getter(key), shape, hook) + + load_fn = partial( + _loader, + tensor_getter, + hf_source_keys_or_key, + mt_target_shape_or_shapes, + hook_fn, + ) + # Stacked mapping + elif not isinstance(hf_source_keys_or_key[0], list): + # Case 2 or 3: Single-Axis Stacked hf keys (un-nested list) + load_fn = partial( + _build_single_axis_stacked_tensor, + hf_source_keys_or_key, + tensor_getter, + hook_fn, + mt_target_shape_or_shapes, + config, + ) + else: + # isinstance(hf_source_keys_or_key[0], list) + # Case 4: Multi-Axis Stacked hf keys (nested list) + load_fn = partial( + _build_multi_axis_stacked_tensor, + hf_source_keys_or_key, + tensor_getter, + hook_fn, + mt_target_shape_or_shapes, + config, + ) + return load_fn diff --git a/src/maxtext/common/checkpointing.py b/src/maxtext/common/checkpointing.py index 73f475bb39..8b596cdfed 100644 --- a/src/maxtext/common/checkpointing.py +++ b/src/maxtext/common/checkpointing.py @@ -34,6 +34,8 @@ from maxtext.utils import max_logging from maxtext.utils import gcs_utils from maxtext.utils import elastic_utils +from maxtext.checkpoint_conversion.utils.load_dynamic import load_safetensors_dynamic_state + import numpy as np import orbax.checkpoint as ocp from orbax.checkpoint import v1 as ocp_v1 @@ -773,6 +775,7 @@ def load_state_if_possible( checkpoint_conversion_fn=None, source_checkpoint_layout="orbax", expansion_factor_real_data: int = -1, + maxtext_config: Any | None = None, ): """Loads TrainState as possible from the inputs. @@ -838,9 +841,7 @@ def map_to_pspec(data): (EmergencyCheckpointManager, EmergencyReplicatorCheckpointManager), ): checkpoint_path = str(checkpoint_manager.directory / str(step) / "items") - with handle_checkpoint_mismatch( - "restore NNX checkpoint", checkpoint_path - ): + with handle_checkpoint_mismatch("restore NNX checkpoint", checkpoint_path): restored_nnx = _load_linen_checkpoint_into_nnx( checkpoint_path, abstract_unboxed_pre_state, @@ -876,9 +877,7 @@ def map_to_pspec(data): EmergencyReplicatorCheckpointManager, ), ): - restored = checkpoint_manager.restore( - step, args=Composite(state=checkpoint_args) - ).state + restored = checkpoint_manager.restore(step, args=Composite(state=checkpoint_args)).state _assert_no_shaped_dtype_struct(restored) return ( restored, @@ -906,21 +905,22 @@ def map_to_pspec(data): # Case 3: Default/Fallback case. # This case acts as a wildcard ('_') and matches if none of the preceding cases were met. case _: - restored = checkpoint_manager.restore( - step, args=Composite(items=checkpoint_args) - ) + restored = checkpoint_manager.restore(step, args=Composite(items=checkpoint_args)) _assert_no_shaped_dtype_struct(restored) return (restored, None) - if load_parameters_from_path != "": + if source_checkpoint_layout == "safetensors_dynamic": + path = load_parameters_from_path or load_full_state_from_path + max_logging.log(f"Dynamic On-the-Fly Formatting: Loading SafeTensors from {path}") + + return load_safetensors_dynamic_state(path, abstract_unboxed_pre_state, maxtext_config) + elif load_parameters_from_path != "": if isinstance(abstract_unboxed_pre_state, nnx.State): _, params, _ = nnx.split(abstract_unboxed_pre_state.model, nnx.Param, ...) else: params = abstract_unboxed_pre_state.params - with handle_checkpoint_mismatch( - "load parameters", load_parameters_from_path - ): + with handle_checkpoint_mismatch("load parameters", load_parameters_from_path): restored_params = load_params_from_path( load_parameters_from_path, params, @@ -932,9 +932,7 @@ def map_to_pspec(data): return None, restored_params elif load_full_state_from_path != "": max_logging.log(f"Loading full state from path: {load_full_state_from_path}") - with handle_checkpoint_mismatch( - "load full state", load_full_state_from_path - ): + with handle_checkpoint_mismatch("load full state", load_full_state_from_path): restored_state = _load_full_state_from_path( path=load_full_state_from_path, abstract_unboxed_pre_state=abstract_unboxed_pre_state, @@ -972,13 +970,18 @@ def setup_checkpoint_logger(config) -> Any | None: # pytype: disable=attribute- def load_params_from_path( - load_parameters_from_path, abstract_unboxed_params, checkpoint_storage_concurrent_gb, use_ocdbt=True, use_zarr3=True + load_parameters_from_path, + abstract_unboxed_params, + checkpoint_storage_concurrent_gb, + use_ocdbt=True, + use_zarr3=True, ): """Load decode params from checkpoint at specified path.""" assert load_parameters_from_path, "load_parameters_from_path is not defined." max_logging.log(f"restoring params from {load_parameters_from_path}") - # NNX target: the on-disk checkpoint is in Linen layout; reshape it into the NNX params state. + # NNX target: the on-disk checkpoint is in Linen layout; reshape it into the + # NNX params state. if isinstance(abstract_unboxed_params, nnx.State): return _load_linen_params_into_nnx( load_parameters_from_path, diff --git a/src/maxtext/configs/types.py b/src/maxtext/configs/types.py index 15cdce6290..de62e56d92 100644 --- a/src/maxtext/configs/types.py +++ b/src/maxtext/configs/types.py @@ -345,7 +345,7 @@ class Checkpointing(BaseModel): save_quantized_params_path: PathStr = Field("", description="Path to save params quantized on the fly.") enable_orbax_v1: bool = Field(False, description="Bool flag for enabling Orbax v1.") checkpoint_conversion_fn: None | str = Field(None, description="Function for processing loaded checkpoint dict.") - source_checkpoint_layout: Literal["orbax", "safetensors"] = Field( + source_checkpoint_layout: Literal["orbax", "safetensors", "safetensors_dynamic"] = Field( "orbax", description="The layout of the source checkpoint to load." ) save_checkpoint_on_completion: bool = Field( diff --git a/src/maxtext/layers/quantizations.py b/src/maxtext/layers/quantizations.py index 27d2fb1130..dee42e6500 100644 --- a/src/maxtext/layers/quantizations.py +++ b/src/maxtext/layers/quantizations.py @@ -1014,19 +1014,24 @@ def _wrap(self, f, name=None): import transformer_engine.jax # pylint: disable=import-outside-toplevel # pytype: disable=import-error from transformer_engine.common import recipe # pylint: disable=import-outside-toplevel # pytype: disable=import-error - fp8_recipe = self._recipe + default_fp8_recipe = self._recipe class TEWrapper(transformer_engine.jax.flax.module.TransformerEngineBase): """Wrapper module for TransformerEngine quantization.""" - def generate_quantizer_set(self, postfix: str = "", + def generate_quantizer_set( + self, + postfix: str = "", variable_collection: str | None = None, quantization_checkpoint_name: str | None = None, fp8_recipe: recipe.Recipe | None = None, n_groups: int | None = None, ): + """Generates a quantizer set for the given recipe.""" OVERWRITE_WITH_GRADIENT = "_overwrite_with_gradient" + if fp8_recipe is None: + fp8_recipe = default_fp8_recipe return super().generate_quantizer_set( # pytype: disable=wrong-keyword-args postfix=postfix, variable_collection=OVERWRITE_WITH_GRADIENT, diff --git a/src/maxtext/utils/maxtext_utils.py b/src/maxtext/utils/maxtext_utils.py index 238758da92..3ffb3010db 100644 --- a/src/maxtext/utils/maxtext_utils.py +++ b/src/maxtext/utils/maxtext_utils.py @@ -1496,6 +1496,7 @@ def setup_initial_state( checkpoint_conversion_fn=config.checkpoint_conversion_fn, source_checkpoint_layout=config.source_checkpoint_layout, expansion_factor_real_data=config.expansion_factor_real_data, + maxtext_config=config, ) if restored: