Skip to content

[bugfix] fix MiniCPMV4_6 text-only batch graph break in distributed training#9443

Open
randydl wants to merge 3 commits into
modelscope:mainfrom
randydl:dev
Open

[bugfix] fix MiniCPMV4_6 text-only batch graph break in distributed training#9443
randydl wants to merge 3 commits into
modelscope:mainfrom
randydl:dev

Conversation

@randydl
Copy link
Copy Markdown
Contributor

@randydl randydl commented May 28, 2026

Added a _post_encode override in MiniCPMV4_6Template that intercepts text-only batches and injects a zero-contribution dummy forward pass through the vision encoder:

  1. Detects text-only batches by checking that both pixel_values and pixel_values_videos are None.
  2. Constructs a minimal dummy image (1 tile of 4×4 patches, the smallest valid tile size) on the correct device/dtype.
  3. Runs the vision encoder via model.get_image_features() with this dummy input, ensuring all vision encoder parameters are part of the computation graph on every rank.
  4. Adds the dummy features (multiplied by 0.0) to inputs_embeds — gradients flow through the vision encoder onto a tensor that feeds into the rest of the model, but the dummy contribution has zero effect on the actual output.

Copy link
Copy Markdown
Contributor

@gemini-code-assist gemini-code-assist Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request introduces a _post_encode method in swift/template/templates/minicpm.py to handle text-only batches during training by passing a dummy tensor through the vision encoder, ensuring consistent parameter usage across ranks. The reviewer suggested several defensive programming improvements to prevent runtime errors, including unwrapping DDP/FSDP model wrappers, retrieving the model's data type from inputs_embeds instead of model.dtype, and safely handling pooler_output types before concatenation.

Comment on lines +690 to +723
def _post_encode(self, model: nn.Module, inputs: Dict[str, Any]) -> Dict[str, Any]:
if not self.is_training:
return inputs

pixel_values = inputs.get('pixel_values')
pixel_values_videos = inputs.get('pixel_values_videos')

if pixel_values is None and pixel_values_videos is None:
# Text-only batch: run a minimal dummy through the vision encoder
# so every rank's computation graph includes the same parameters.
device = inputs['input_ids'].device
patch_size = model.config.vision_config.patch_size

dummy_pv = torch.zeros(
1, 3, 4 * patch_size, 4 * patch_size,
device=device, dtype=model.dtype
)
dummy_ts = torch.tensor(
[[4, 4]], device=device, dtype=torch.int32
)

inputs_embeds = model.get_input_embeddings()(inputs['input_ids'])
vision_output = model.get_image_features(
dummy_pv, dummy_ts, downsample_mode=self.downsample_mode
)

dummy_feats = torch.cat(vision_output.pooler_output, dim=0).to(
device=inputs_embeds.device, dtype=inputs_embeds.dtype
)
inputs_embeds = inputs_embeds + dummy_feats.mean() * 0.0

return {'inputs_embeds': inputs_embeds}

return inputs
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

To ensure robustness and prevent potential runtime errors during training, we should address a few defensive programming concerns:

  1. DDP/FSDP Unwrapping: If the model is wrapped in DistributedDataParallel (DDP) or FSDP, accessing attributes like config or methods like get_input_embeddings directly on model will raise an AttributeError because these wrappers do not delegate attribute access. Unwrapping the inner module via model.module if present is highly recommended.
  2. Model Dtype: model.dtype is not a standard PyTorch nn.Module attribute and might not be available on all custom wrappers or PEFT models. Getting the dtype directly from inputs_embeds.dtype is much safer.
  3. Defensive handling of pooler_output: Depending on the model configuration or future updates, vision_output.pooler_output might be a single tensor instead of a list of tensors. We should check its type before calling torch.cat to avoid potential type errors.
    def _post_encode(self, model: nn.Module, inputs: Dict[str, Any]) -> Dict[str, Any]:
        if not self.is_training:
            return inputs

        pixel_values = inputs.get('pixel_values')
        pixel_values_videos = inputs.get('pixel_values_videos')

        if pixel_values is None and pixel_values_videos is None:
            # Unwrap DDP/FSDP wrapper if present
            if hasattr(model, 'module'):
                model = model.module

            # Text-only batch: run a minimal dummy through the vision encoder
            # so every rank's computation graph includes the same parameters.
            inputs_embeds = model.get_input_embeddings()(inputs['input_ids'])
            device = inputs_embeds.device
            dtype = inputs_embeds.dtype
            patch_size = model.config.vision_config.patch_size

            dummy_pv = torch.zeros(
                1, 3, 4 * patch_size, 4 * patch_size,
                device=device, dtype=dtype
            )
            dummy_ts = torch.tensor(
                [[4, 4]], device=device, dtype=torch.int32
            )

            vision_output = model.get_image_features(
                dummy_pv, dummy_ts, downsample_mode=self.downsample_mode
            )

            pooler_output = vision_output.pooler_output
            if isinstance(pooler_output, (list, tuple)):
                dummy_feats = torch.cat(pooler_output, dim=0)
            else:
                dummy_feats = pooler_output

            dummy_feats = dummy_feats.to(device=device, dtype=dtype)
            inputs_embeds = inputs_embeds + dummy_feats.mean() * 0.0

            return {'inputs_embeds': inputs_embeds}

        return inputs

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

_post_encode 收到的 model 已经是 unwrap 后的,不需要再处理。hook 注册时经过 accelerator.unwrap_model(model) 拿到原始模型,然后 register_forward_pre_hook 注册在这个对象上。训练时 wrapper 调用 self.module() 触发的是原始模型上的 hook,所以回调中的 model 就是 MiniCPMV4_6ForConditionalGeneration 本身,直接访问 .model.language_model.embed_tokens 没问题。

@randydl randydl changed the title [bugfix] fix MiniCPMV4_6 text-only batch DDP vision encoder graph break [bugfix] fix MiniCPMV4_6 text-only batch graph break in distributed training May 28, 2026
@randydl
Copy link
Copy Markdown
Contributor Author

randydl commented May 29, 2026

Before Fix:

[Gloo] Rank 1[Gloo] Rank  is connected to 1 peer ranks. 0Expected number of connected peer ranks is :  is connected to 11
 peer ranks. Expected number of connected peer ranks is : 1
Train:   0%|                                                                                                                                                                           | 0/313 [00:00<?, ?it/s][INFO:swift] use_logits_to_keep: False
Train:   0%|▌                                                                                                                                                                | 1/313 [00:14<1:14:42, 14.37s/it][rank1]:W0529 12:11:02.706000 717238 site-packages/torch/distributed/distributed_c10d.py:3081] _object_to_tensor size: 27 hash value: 3048292174749185915
[rank0]:W0529 12:11:02.706000 717235 site-packages/torch/distributed/distributed_c10d.py:3081] _object_to_tensor size: 27 hash value: 3048292174749185915
[rank1]:W0529 12:11:02.734000 717238 site-packages/torch/distributed/distributed_c10d.py:3096] _tensor_to_object size: 27 hash value: 9135131298210845903
[rank0]:W0529 12:11:02.734000 717235 site-packages/torch/distributed/distributed_c10d.py:3096] _tensor_to_object size: 27 hash value: 9135131298210845903
[rank1]:W0529 12:11:02.737000 717238 site-packages/torch/distributed/distributed_c10d.py:3096] _tensor_to_object size: 27 hash value: 9135131298210845903
[rank0]:W0529 12:11:02.737000 717235 site-packages/torch/distributed/distributed_c10d.py:3096] _tensor_to_object size: 27 hash value: 9135131298210845903
{'loss': '1.534', 'grad_norm': '16.03', 'learning_rate': '1e-06', 'token_acc': '0.6462', 'epoch': '0.0032', 'global_step/max_steps': '1/313', 'elapsed_time': '14s', 'remaining_time': '1h 14m 58s', 'memory(GiB)': '60.16', 'train_speed(s/it)': '14.42'}
Train:   3%|█████▏                                                                                                                                                            | 10/313 [01:51<53:23, 10.57s/it][rank0]:W0529 12:12:39.911000 717235 site-packages/torch/distributed/distributed_c10d.py:3081] _object_to_tensor size: 27 hash value: 3048292174749185915
[rank1]:W0529 12:12:39.911000 717238 site-packages/torch/distributed/distributed_c10d.py:3081] _object_to_tensor size: 27 hash value: 3048292174749185915
[rank0]:W0529 12:12:39.916000 717235 site-packages/torch/distributed/distributed_c10d.py:3096] _tensor_to_object size: 27 hash value: 9135131298210845903
[rank0]:W0529 12:12:39.918000 717235 site-packages/torch/distributed/distributed_c10d.py:3096] _tensor_to_object size: 27 hash value: 9135131298210845903
[rank1]:W0529 12:12:39.917000 717238 site-packages/torch/distributed/distributed_c10d.py:3096] _tensor_to_object size: 27 hash value: 9135131298210845903
[rank1]:W0529 12:12:39.919000 717238 site-packages/torch/distributed/distributed_c10d.py:3096] _tensor_to_object size: 27 hash value: 9135131298210845903
{'loss': '1.295', 'grad_norm': '5.649', 'learning_rate': '1e-05', 'token_acc': '0.6781', 'epoch': '0.032', 'global_step/max_steps': '10/313', 'elapsed_time': '1m 52s', 'remaining_time': '56m 21s', 'memory(GiB)': '60.17', 'train_speed(s/it)': '11.16'}
Train:   4%|█████▋                                                                                                                                                            | 11/313 [02:03<54:43, 10.87s/it][INFO:swift] last_model_checkpoint: None
[INFO:swift] best_model_checkpoint: None
[INFO:swift] images_dir: /nas_train/app.e0016372/train/sft/full/temp/v0-20260529-121018/images
[rank1]: Traceback (most recent call last):
[rank1]:   File "/nas_user/app.e0016372/projects/ms-swift/swift/cli/sft.py", line 20, in <module>
[rank1]:     sft_main()
[rank1]:   File "/nas_user/app.e0016372/projects/ms-swift/swift/pipelines/train/sft.py", line 353, in sft_main
[rank1]:     return SwiftSft(args).main()
[rank1]:            ^^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/nas_user/app.e0016372/projects/ms-swift/swift/pipelines/base.py", line 52, in main
[rank1]:     result = self.run()
[rank1]:              ^^^^^^^^^^
[rank1]:   File "/nas_user/app.e0016372/projects/ms-swift/swift/ray_utils/base.py", line 168, in wrapper
[rank1]:     return func(self, *args, **kwargs)
[rank1]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/nas_user/app.e0016372/projects/ms-swift/swift/pipelines/train/sft.py", line 197, in run
[rank1]:     return self.train(trainer)
[rank1]:            ^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/nas_user/app.e0016372/projects/ms-swift/swift/pipelines/train/sft.py", line 271, in train
[rank1]:     trainer.train(resume_checkpoint)
[rank1]:   File "/nas_user/app.e0016372/projects/ms-swift/swift/trainers/mixin.py", line 903, in train
[rank1]:     res = super().train(*args, **kwargs)
[rank1]:           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/nas_user/app.e0016372/miniforge3/envs/swift/lib/python3.12/site-packages/transformers/trainer.py", line 1427, in train
[rank1]:     return inner_training_loop(
[rank1]:            ^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/nas_user/app.e0016372/miniforge3/envs/swift/lib/python3.12/site-packages/transformers/trainer.py", line 1509, in _inner_training_loop
[rank1]:     self._run_epoch(
[rank1]:   File "/nas_user/app.e0016372/miniforge3/envs/swift/lib/python3.12/site-packages/transformers/trainer.py", line 1737, in _run_epoch
[rank1]:     tr_loss_step = self.training_step(model, inputs, num_items_in_batch)
[rank1]:                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/nas_user/app.e0016372/projects/ms-swift/swift/trainers/seq2seq_trainer.py", line 236, in training_step
[rank1]:     return super().training_step(model, inputs, *args, **kwargs)
[rank1]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/nas_user/app.e0016372/miniforge3/envs/swift/lib/python3.12/site-packages/transformers/trainer.py", line 1937, in training_step
[rank1]:     self.accelerator.backward(loss, **kwargs)
[rank1]:   File "/nas_user/app.e0016372/miniforge3/envs/swift/lib/python3.12/site-packages/accelerate/accelerator.py", line 2830, in backward
[rank1]:     self.deepspeed_engine_wrapped.backward(loss, sync_gradients=self.sync_gradients, **kwargs)
[rank1]:   File "/nas_user/app.e0016372/miniforge3/envs/swift/lib/python3.12/site-packages/accelerate/utils/deepspeed.py", line 270, in backward
[rank1]:     self.engine.backward(loss, **kwargs)
[rank1]:   File "/nas_user/app.e0016372/miniforge3/envs/swift/lib/python3.12/site-packages/deepspeed/utils/nvtx.py", line 20, in wrapped_fn
[rank1]:     ret_val = func(*args, **kwargs)
[rank1]:               ^^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/nas_user/app.e0016372/miniforge3/envs/swift/lib/python3.12/site-packages/deepspeed/runtime/engine.py", line 2625, in backward
[rank1]:     loss.backward(**backward_kwargs)
[rank1]:   File "/nas_user/app.e0016372/miniforge3/envs/swift/lib/python3.12/site-packages/torch/_tensor.py", line 625, in backward
[rank1]:     torch.autograd.backward(
[rank1]:   File "/nas_user/app.e0016372/miniforge3/envs/swift/lib/python3.12/site-packages/torch/autograd/__init__.py", line 354, in backward
[rank1]:     _engine_run_backward(
[rank1]:   File "/nas_user/app.e0016372/miniforge3/envs/swift/lib/python3.12/site-packages/torch/autograd/graph.py", line 841, in _engine_run_backward
[rank1]:     return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
[rank1]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/nas_user/app.e0016372/miniforge3/envs/swift/lib/python3.12/site-packages/deepspeed/runtime/zero/stage_1_and_2.py", line 1063, in grad_handling_hook
[rank1]:     self.process_gradients(param, i)
[rank1]:   File "/nas_user/app.e0016372/miniforge3/envs/swift/lib/python3.12/site-packages/deepspeed/runtime/zero/stage_1_and_2.py", line 1612, in process_gradients
[rank1]:     self.reduce_ready_partitions_and_remove_grads(param, i)
[rank1]:   File "/nas_user/app.e0016372/miniforge3/envs/swift/lib/python3.12/site-packages/deepspeed/runtime/zero/stage_1_and_2.py", line 1616, in reduce_ready_partitions_and_remove_grads
[rank1]:     self.reduce_independent_p_g_buckets_and_remove_grads(param, i)
[rank1]:   File "/nas_user/app.e0016372/miniforge3/envs/swift/lib/python3.12/site-packages/deepspeed/runtime/zero/stage_1_and_2.py", line 1093, in reduce_independent_p_g_buckets_and_remove_grads
[rank1]:     self.reduce_ipg_grads(comm_dtype=comm_dtype)
[rank1]:   File "/nas_user/app.e0016372/miniforge3/envs/swift/lib/python3.12/site-packages/deepspeed/runtime/zero/stage_1_and_2.py", line 1564, in reduce_ipg_grads
[rank1]:     self.average_tensor(bucket.buffer[bucket.index].narrow(0, 0, bucket.elements), comm_dtype)
[rank1]:   File "/nas_user/app.e0016372/miniforge3/envs/swift/lib/python3.12/site-packages/deepspeed/runtime/zero/stage_1_and_2.py", line 1309, in average_tensor
[rank1]:     self.allreduce_and_scatter(buckets[bucket_key],
[rank1]:   File "/nas_user/app.e0016372/miniforge3/envs/swift/lib/python3.12/site-packages/deepspeed/runtime/zero/stage_1_and_2.py", line 1214, in allreduce_and_scatter
[rank1]:     self.allreduce_and_copy_with_multiple_ranks(small_bucket,
[rank1]:   File "/nas_user/app.e0016372/miniforge3/envs/swift/lib/python3.12/site-packages/deepspeed/runtime/zero/stage_1_and_2.py", line 1176, in allreduce_and_copy_with_multiple_ranks
[rank1]:     allreduced = self.allreduce_bucket(small_bucket,
[rank1]:                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/nas_user/app.e0016372/miniforge3/envs/swift/lib/python3.12/site-packages/deepspeed/runtime/zero/stage_1_and_2.py", line 1705, in allreduce_bucket
[rank1]:     dist.all_reduce(tensor_to_allreduce, group=process_group)
[rank1]:   File "/nas_user/app.e0016372/miniforge3/envs/swift/lib/python3.12/site-packages/deepspeed/comm/comm.py", line 118, in log_wrapper
[rank1]:     return func(*args, **kwargs)
[rank1]:            ^^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/nas_user/app.e0016372/miniforge3/envs/swift/lib/python3.12/site-packages/deepspeed/comm/comm.py", line 654, in all_reduce
[rank1]:     return cdb.all_reduce(tensor, op, group, async_op)
[rank1]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/nas_user/app.e0016372/miniforge3/envs/swift/lib/python3.12/site-packages/deepspeed/comm/torch.py", line 167, in all_reduce
[rank1]:     return torch.distributed.all_reduce(tensor=tensor, op=op, group=group, async_op=async_op)
[rank1]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/nas_user/app.e0016372/miniforge3/envs/swift/lib/python3.12/site-packages/torch/distributed/c10d_logger.py", line 81, in wrapper
[rank1]:     return func(*args, **kwargs)
[rank1]:            ^^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/nas_user/app.e0016372/miniforge3/envs/swift/lib/python3.12/site-packages/torch/distributed/distributed_c10d.py", line 2935, in all_reduce
[rank1]:     work = group.allreduce([tensor], opts)
[rank1]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]: RuntimeError: Detected mismatch between collectives on ranks. Rank 1 is running collective: CollectiveFingerPrint(SequenceNumber=3104, OpType=ALLREDUCE, TensorShape=[195812480], TensorDtypes=BFloat16, TensorDeviceTypes=TensorOptions(dtype=float (default), device=cuda, layout=Strided (default), requires_grad=false (default), pinned_memory=false (default), memory_format=(nullopt))), but Rank 0 is running collective: CollectiveFingerPrint(SequenceNumber=3104, OpType=ALLREDUCE, TensorShape=[51123456], TensorDtypes=BFloat16, TensorDeviceTypes=TensorOptions(dtype=float (default), device=cuda, layout=Strided (default), requires_grad=false (default), pinned_memory=false (default), memory_format=(nullopt))).Collectives differ in the following aspects:   Tensor Tensor shapes: 195812480vs 51123456

After Fix:

[RANK 0] Gradient accumulation steps mismatch: GradientAccumulationPlugin has 1, DeepSpeed config has 8. Using DeepSpeed's value.
[transformers] The tokenizer has new PAD/BOS/EOS tokens that differ from the model config and generation config. The model config and generation config were aligned accordingly, being updated with the tokenizer's values. Updated tokens: {'bos_token_id': None}.
[Gloo] Rank 1 is connected to 1 peer ranks. Expected number of connected peer ranks is : 1
[Gloo] Rank 0 is connected to 1 peer ranks. Expected number of connected peer ranks is : 1
Train:   0%|                                                                                                                                                                           | 0/313 [00:00<?, ?it/s][INFO:swift] use_logits_to_keep: False
Train:   0%|▌                                                                                                                                                                | 1/313 [00:15<1:19:06, 15.21s/it][rank1]:W0529 12:13:04.593000 867520 site-packages/torch/distributed/distributed_c10d.py:3081] _object_to_tensor size: 27 hash value: 3048292174749185915
[rank0]:W0529 12:13:04.593000 867519 site-packages/torch/distributed/distributed_c10d.py:3081] _object_to_tensor size: 27 hash value: 3048292174749185915
[rank1]:W0529 12:13:04.615000 867520 site-packages/torch/distributed/distributed_c10d.py:3096] _tensor_to_object size: 27 hash value: 9135131298210845903
[rank0]:W0529 12:13:04.615000 867519 site-packages/torch/distributed/distributed_c10d.py:3096] _tensor_to_object size: 27 hash value: 9135131298210845903
[rank1]:W0529 12:13:04.618000 867520 site-packages/torch/distributed/distributed_c10d.py:3096] _tensor_to_object size: 27 hash value: 9135131298210845903
[rank0]:W0529 12:13:04.618000 867519 site-packages/torch/distributed/distributed_c10d.py:3096] _tensor_to_object size: 27 hash value: 9135131298210845903
{'loss': '1.534', 'grad_norm': '16.03', 'learning_rate': '1e-06', 'token_acc': '0.6462', 'epoch': '0.0032', 'global_step/max_steps': '1/313', 'elapsed_time': '15s', 'remaining_time': '1h 19m 19s', 'memory(GiB)': '60.16', 'train_speed(s/it)': '15.25'}
Train:   3%|█████▏                                                                                                                                                            | 10/313 [01:56<55:50, 11.06s/it][rank0]:W0529 12:14:45.600000 867519 site-packages/torch/distributed/distributed_c10d.py:3081] _object_to_tensor size: 27 hash value: 3048292174749185915
[rank1]:W0529 12:14:45.601000 867520 site-packages/torch/distributed/distributed_c10d.py:3081] _object_to_tensor size: 27 hash value: 3048292174749185915
[rank0]:W0529 12:14:45.608000 867519 site-packages/torch/distributed/distributed_c10d.py:3096] _tensor_to_object size: 27 hash value: 9135131298210845903
[rank1]:W0529 12:14:45.608000 867520 site-packages/torch/distributed/distributed_c10d.py:3096] _tensor_to_object size: 27 hash value: 9135131298210845903
[rank0]:W0529 12:14:45.610000 867519 site-packages/torch/distributed/distributed_c10d.py:3096] _tensor_to_object size: 27 hash value: 9135131298210845903
[rank1]:W0529 12:14:45.611000 867520 site-packages/torch/distributed/distributed_c10d.py:3096] _tensor_to_object size: 27 hash value: 9135131298210845903
{'loss': '1.297', 'grad_norm': '4.986', 'learning_rate': '1e-05', 'token_acc': '0.6778', 'epoch': '0.032', 'global_step/max_steps': '10/313', 'elapsed_time': '1m 56s', 'remaining_time': '58m 42s', 'memory(GiB)': '60.17', 'train_speed(s/it)': '11.62'}
Train:   6%|██████████▎                                                                                                                                                       | 20/313 [03:28<44:11,  9.05s/it][rank0]:W0529 12:16:17.865000 867519 site-packages/torch/distributed/distributed_c10d.py:3081] _object_to_tensor size: 27 hash value: 3048292174749185915
[rank1]:W0529 12:16:17.866000 867520 site-packages/torch/distributed/distributed_c10d.py:3081] _object_to_tensor size: 27 hash value: 3048292174749185915
[rank1]:W0529 12:16:17.873000 867520 site-packages/torch/distributed/distributed_c10d.py:3096] _tensor_to_object size: 27 hash value: 9135131298210845903
[rank0]:W0529 12:16:17.873000 867519 site-packages/torch/distributed/distributed_c10d.py:3096] _tensor_to_object size: 27 hash value: 9135131298210845903
[rank1]:W0529 12:16:17.876000 867520 site-packages/torch/distributed/distributed_c10d.py:3096] _tensor_to_object size: 27 hash value: 9135131298210845903
[rank0]:W0529 12:16:17.876000 867519 site-packages/torch/distributed/distributed_c10d.py:3096] _tensor_to_object size: 27 hash value: 9135131298210845903
{'loss': '1.088', 'grad_norm': '4.925', 'learning_rate': '9.98e-06', 'token_acc': '0.7156', 'epoch': '0.064', 'global_step/max_steps': '20/313', 'elapsed_time': '3m 29s', 'remaining_time': '50m 55s', 'memory(GiB)': '60.17', 'train_speed(s/it)': '10.43'}
Train:  10%|███████████████▌                                                                                                                                                  | 30/313 [04:56<39:58,  8.48s/it][rank1]:W0529 12:17:46.141000 867520 site-packages/torch/distributed/distributed_c10d.py:3081] _object_to_tensor size: 27 hash value: 3048292174749185915
[rank0]:W0529 12:17:46.141000 867519 site-packages/torch/distributed/distributed_c10d.py:3081] _object_to_tensor size: 27 hash value: 3048292174749185915
[rank1]:W0529 12:17:46.151000 867520 site-packages/torch/distributed/distributed_c10d.py:3096] _tensor_to_object size: 27 hash value: 9135131298210845903
[rank0]:W0529 12:17:46.151000 867519 site-packages/torch/distributed/distributed_c10d.py:3096] _tensor_to_object size: 27 hash value: 9135131298210845903
[rank1]:W0529 12:17:46.153000 867520 site-packages/torch/distributed/distributed_c10d.py:3096] _tensor_to_object size: 27 hash value: 9135131298210845903
[rank0]:W0529 12:17:46.153000 867519 site-packages/torch/distributed/distributed_c10d.py:3096] _tensor_to_object size: 27 hash value: 9135131298210845903
{'loss': '0.9697', 'grad_norm': '14.87', 'learning_rate': '9.9e-06', 'token_acc': '0.7338', 'epoch': '0.096', 'global_step/max_steps': '30/313', 'elapsed_time': '4m 57s', 'remaining_time': '46m 40s', 'memory(GiB)': '60.18', 'train_speed(s/it)': '9.893'}
Train:  11%|█████████████████                                                                                                                                                 | 33/313 [05:27<44:50,  9.61s/it

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant