Commit 177a75f

Merge pull request #2510 from bzantium:feature/#2509
PiperOrigin-RevId: 838961966
2 parents: 8c80266 + 496a7b2

4 files changed: +74, -11 lines

src/MaxText/configs/base.yml

Lines changed: 8 additions & 7 deletions
@@ -256,7 +256,7 @@ pipeline_delay_activation_forwarding: False # This delays the activation forward
 # and you must set the number of microbatches to at least 2 * num_stages (the minimum 2 * num_stages is set by default with this delay).
 
 model_fsdp_ag_once: False # This controls whether the Zero-1 optimization is active.
-# This is a memory/time tradeoff - True: This is Zero-1 Sharding. Use ZeroOneTransformer to gather weights once per gradient step.
+# This is a memory/time tradeoff - True: This is Zero-1 Sharding. Use ZeroOneTransformer to gather weights once per gradient step.
 # False: This is Zero-3 Sharing. Use the standard Transformer, which gathers for each microbatch's fwd/bwd pass.
 pipeline_fsdp_ag_once: False # If set to true then all gather all of the weights over FSDP before the first pipeline iteration.
 # This is a memory/time tradeoff - we now have to store the FSDP gathered weights and gradients (typically in bf16), as opposed
@@ -306,7 +306,7 @@ param_scan_axis: 1
 # The attention_type parameter determines the variants of attention, e.g. global or local_sliding
 attention: 'autoselected' # Supported attention: autoselected, dot_product, flash, cudnn_flash_te
 attention_type: 'global' # Supported attention_type: global, local_sliding, chunk, mla
-attention_bias: False # If True, adds a learnable bias to the query, key, and value projections
+attention_bias: False # If True, adds a learnable bias to the query, key, and value projections
 attention_sink: False
 sliding_window_size: 0
 chunk_attn_window_size: 0
@@ -424,7 +424,7 @@ logical_axis_rules: [
   ['embed_no_exp', ['fsdp', 'sequence', 'tensor_transpose', 'context']],
   ['embed_no_exp', ['fsdp', 'fsdp_transpose', 'sequence', 'context']],
   ['embed_no_exp', ['fsdp', 'sequence', 'context']],
-  ['embed_tensor_transpose', ['tensor_transpose']],
+  ['embed_tensor_transpose', ['tensor_transpose']],
   ['q_lora', ['fsdp', 'fsdp_transpose', 'sequence', 'context', 'tensor_transpose', 'expert']],
   ['q_lora', ['fsdp', 'sequence', 'context', 'tensor_transpose', 'expert']],
   ['q_lora', ['fsdp', 'fsdp_transpose', 'sequence', 'context', 'expert']],
@@ -530,7 +530,7 @@ per_device_batch_size: 12.0
 # Each data-loading host will load per_device_batch_size * expansion_factor_real_data.
 # When set to between 0 and 1, it's for grain pipeline to use a smaller chip count to read checkpoint from a larger chip count job.
 # Details in https://github.com/AI-Hypercomputer/maxtext/blob/main/docs/guides/data_input_grain.md#using-grain
-expansion_factor_real_data: -1.0
+expansion_factor_real_data: -1.0
 eval_per_device_batch_size: 0.0
 max_corpus_chars: 10_000_000
 train_data_columns: ['text'] # for DPO dataset containing "chosen" and "rejected"
@@ -595,14 +595,15 @@ grain_train_files: ''
 grain_eval_files: ''
 grain_train_mixture_config_path: '' # Path to a JSON file specifying the mixture weights for Grain training data.
 grain_file_type: 'arrayrecord' # arrayrecord or parquet
-grain_worker_count: 1
+grain_worker_count: 1 # Set to -1 to enable auto-tuning: automatically determines optimal worker count. See https://google-grain.readthedocs.io/en/latest/_autosummary/grain.experimental.pick_performance_config.html
 grain_per_worker_buffer_size: 1
 # num_threads and prefetch_buffer_size are per-worker per-dataset. Used in ReadOptions (https://google-grain.readthedocs.io/en/latest/tutorials/data_loader_tutorial.html#per-worker-readoptions)
 # The default value matches that in the Grain package. If mixing multiple data sources, consider lowering these values to reduce memory usage.
-grain_num_threads: 16
+grain_num_threads: 16
 grain_prefetch_buffer_size: 500
 grain_worker_count_eval: 1
 grain_per_worker_buffer_size_eval: 1
+grain_ram_budget_mb: 1024 # RAM budget (MB) for auto-tuning worker count. Only used when grain_worker_count is -1.
 grain_num_threads_eval: 16
 grain_prefetch_buffer_size_eval: 500
 grain_data_source_max_workers: 16 # Max workers for ThreadPoolExecutor when mixing multiple Grain data sources.
@@ -930,7 +931,7 @@ temporal_patch_size_for_vit: 2
 num_position_embeddings_for_vit: 1024
 deepstack_visual_indexes_for_vit: []
 
-# Subslice shape in the form of "x,y,z" when using pathways (single controller).
+# Subslice shape in the form of "x,y,z" when using pathways (single controller).
 # Example: "8,8" to use a 8x8 subgrid (64 chips) of a full pod (16x16) of trillium.
 subslice_shape: ""
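
The grain hunk above is the substance of this file's change: grain_worker_count: -1 delegates worker selection to Grain's auto-tuner, and grain_ram_budget_mb caps the memory that tuner may plan for. A minimal sketch of enabling this from Python, mirroring the test added at the end of this commit (the data glob and the 2048 MB budget are placeholder values, not part of the commit):

import sys
from MaxText import pyconfig

config = pyconfig.initialize(
    [sys.argv[0], "src/MaxText/configs/base.yml"],
    dataset_type="grain",
    grain_train_files="/data/c4/c4-train.array_record*",  # placeholder glob
    grain_worker_count=-1,     # -1 = let Grain's pick_performance_config choose
    grain_ram_budget_mb=2048,  # placeholder: cap the auto-tuner's RAM plan at 2 GiB
)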

src/MaxText/configs/types.py

Lines changed: 1 addition & 0 deletions
@@ -860,6 +860,7 @@ class GrainDataset(BaseModel):
   grain_per_worker_buffer_size_eval: int = Field(
       1, description="Buffer size for each worker for Grain data loading during evaluation."
   )
+  grain_ram_budget_mb: int = Field(1024, description="RAM budget (MB) for auto-tuning worker count.")
   grain_num_threads: int = Field(16, description="Number of threads for Grain ReadOptions during training.")
   grain_prefetch_buffer_size: int = Field(500, description="Prefetch buffer size for Grain ReadOptions during training.")
   grain_num_threads_eval: int = Field(16, description="Number of threads for Grain ReadOptions during evaluation.")
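
The new field follows the plain pydantic pattern used throughout GrainDataset. A self-contained sketch of the validation behavior it picks up (the GrainTuning mini-model here is hypothetical, for illustration only):

from pydantic import BaseModel, Field

class GrainTuning(BaseModel):
  # Hypothetical mini-model; the real field lives on GrainDataset in types.py.
  grain_ram_budget_mb: int = Field(1024, description="RAM budget (MB) for auto-tuning worker count.")

print(GrainTuning().grain_ram_budget_mb)  # -> 1024 (default)
print(GrainTuning(grain_ram_budget_mb=4096).grain_ram_budget_mb)  # -> 4096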

src/MaxText/input_pipeline/_grain_data_processing.py

Lines changed: 21 additions & 4 deletions
@@ -23,6 +23,7 @@
 
 import jax
 
+from grain.experimental import pick_performance_config
 import grain.python as grain
 
 from MaxText.utils import gcs_utils
@@ -230,12 +231,20 @@ def pretrain_preprocessing_pipeline(
           axis=1,
       )
   )
-  dataset = dataset.mp_prefetch(
-      grain.MultiprocessingOptions(
+  multiprocessing_options = (
+      pick_performance_config(
+          ds=dataset,
+          ram_budget_mb=config.grain_ram_budget_mb,
+          max_workers=None,
+          max_buffer_size=None,
+      ).multiprocessing_options
+      if grain_worker_count == -1
+      else grain.MultiprocessingOptions(
           num_workers=grain_worker_count,
           per_worker_buffer_size=grain_per_worker_buffer_size,
       )
   )
+  dataset = dataset.mp_prefetch(multiprocessing_options)
   return dataset
@@ -273,12 +282,20 @@ def dpo_preprocessing_pipeline(
   batch_size = config.global_batch_size_to_load // jax.process_count()
   batch_fn = functools.partial(grain.experimental.batch_and_pad, batch_size=batch_size, pad_value=pad_id)
   dataset = dataset.batch(batch_size, batch_fn=batch_fn)
-  dataset = dataset.mp_prefetch(
-      grain.MultiprocessingOptions(
+  multiprocessing_options = (
+      pick_performance_config(
+          ds=dataset,
+          ram_budget_mb=config.grain_ram_budget_mb,
+          max_workers=None,
+          max_buffer_size=None,
+      ).multiprocessing_options
+      if grain_worker_count == -1
+      else grain.MultiprocessingOptions(
          num_workers=grain_worker_count,
          per_worker_buffer_size=grain_per_worker_buffer_size,
      )
  )
+  dataset = dataset.mp_prefetch(multiprocessing_options)
   return dataset
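
Both hunks introduce the same conditional, so it reads more clearly factored out. A sketch of the shared logic under the same imports as the file above (the helper name is ours, not the commit's; the calls and keyword arguments are exactly those in the diff):

from grain.experimental import pick_performance_config
import grain.python as grain


def _choose_multiprocessing_options(dataset, worker_count, per_worker_buffer_size, ram_budget_mb):
  """Hypothetical helper: auto-tune when worker_count == -1, else use the explicit settings."""
  if worker_count == -1:
    # Ask Grain to benchmark the pipeline and pick worker/buffer settings
    # that fit within the configured RAM budget.
    return pick_performance_config(
        ds=dataset,
        ram_budget_mb=ram_budget_mb,
        max_workers=None,
        max_buffer_size=None,
    ).multiprocessing_options
  return grain.MultiprocessingOptions(
      num_workers=worker_count,
      per_worker_buffer_size=per_worker_buffer_size,
  )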

tests/grain_data_processing_test.py

Lines changed: 44 additions & 0 deletions
@@ -22,6 +22,7 @@
 import json
 
 import jax
+import pytest
 from jax.sharding import Mesh
 from jax.experimental import mesh_utils
 
@@ -182,6 +183,49 @@ def setUp(self):
     self.train_iter = _grain_data_processing.make_grain_train_iterator(self.config, self.mesh, self.process_indices)
 
 
+class GrainArrayRecordAutoTuneTest(GrainArrayRecordProcessingTest):
+  """Test grain data processing with auto-tuning enabled (grain_worker_count=-1)."""
+
+  def setUp(self):
+    super().setUp()
+    temp_dir = tempfile.gettempdir()
+    self.config = pyconfig.initialize(
+        [sys.argv[0], os.path.join(MAXTEXT_PKG_DIR, "configs", "base.yml")],
+        per_device_batch_size=1,
+        run_name="test",
+        mesh_axes=["data"],
+        logical_axis_rules=[["batch", "data"]],
+        data_sharding=["data"],
+        base_output_directory="gs://max-experiments/",
+        dataset_type="grain",
+        grain_train_files=os.path.join(
+            temp_dir, "gcsfuse", "array-record", "c4", "en", "3.0.1", "c4-train.array_record*"
+        ),
+        grain_worker_count=-1,  # Enable auto-tuning
+        tokenizer_path=os.path.join(MAXTEXT_ASSETS_ROOT, "tokenizer"),
+        enable_checkpointing=False,
+    )
+    self.mesh_shape_1d = (len(jax.devices()),)
+    self.mesh = Mesh(mesh_utils.create_device_mesh(self.mesh_shape_1d), self.config.mesh_axes)
+    self.process_indices = input_pipeline_interface.get_process_loading_real_data(
+        self.config.data_sharding,
+        self.config.global_batch_size_to_load,
+        self.config.global_batch_size_to_train_on,
+        self.config.max_target_length,
+        self.mesh,
+    )
+    self.train_iter = _grain_data_processing.make_grain_train_iterator(self.config, self.mesh, self.process_indices)
+
+  @pytest.mark.skip(
+      reason=(
+          "Auto-tuning tries multiple numbers of workers during the first few batches "
+          "and it affects batch determinism at first."
+      )
+  )
+  def test_batch_determinism(self):
+    super().test_batch_determinism()
+
+
 class GrainParquetProcessingTest(unittest.TestCase):
 
   @classmethod
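
The subclass reruns every GrainArrayRecordProcessingTest case against the auto-tuned pipeline, skipping only the determinism check. A sketch of invoking just the new class (assuming the gcsfuse test data is staged as in the sibling tests):

import pytest

# Run only the auto-tune tests; equivalent to
#   pytest tests/grain_data_processing_test.py -k GrainArrayRecordAutoTuneTest
pytest.main(["tests/grain_data_processing_test.py", "-k", "GrainArrayRecordAutoTuneTest"])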
