[bugfix] fix MiniCPMV4_6 text-only batch graph break in distributed training #9443
randydl wants to merge 3 commits into
Code Review
This pull request introduces a _post_encode method in swift/template/templates/minicpm.py to handle text-only batches during training by passing a dummy tensor through the vision encoder, ensuring consistent parameter usage across ranks. The reviewer suggested several defensive programming improvements to prevent runtime errors, including unwrapping DDP/FSDP model wrappers, retrieving the model's data type from inputs_embeds instead of model.dtype, and safely handling pooler_output types before concatenation.
    def _post_encode(self, model: nn.Module, inputs: Dict[str, Any]) -> Dict[str, Any]:
        if not self.is_training:
            return inputs

        pixel_values = inputs.get('pixel_values')
        pixel_values_videos = inputs.get('pixel_values_videos')

        if pixel_values is None and pixel_values_videos is None:
            # Text-only batch: run a minimal dummy through the vision encoder
            # so every rank's computation graph includes the same parameters.
            device = inputs['input_ids'].device
            patch_size = model.config.vision_config.patch_size

            dummy_pv = torch.zeros(
                1, 3, 4 * patch_size, 4 * patch_size,
                device=device, dtype=model.dtype
            )
            dummy_ts = torch.tensor(
                [[4, 4]], device=device, dtype=torch.int32
            )

            inputs_embeds = model.get_input_embeddings()(inputs['input_ids'])
            vision_output = model.get_image_features(
                dummy_pv, dummy_ts, downsample_mode=self.downsample_mode
            )

            dummy_feats = torch.cat(vision_output.pooler_output, dim=0).to(
                device=inputs_embeds.device, dtype=inputs_embeds.dtype
            )
            inputs_embeds = inputs_embeds + dummy_feats.mean() * 0.0

            return {'inputs_embeds': inputs_embeds}

        return inputs
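The zero-contribution trick in the hunk above can be checked in isolation. The following is a minimal sketch, not the PR's code: `ToyMultimodal` and `vision_grads` are hypothetical names, the toy model only stands in for a text embedding plus a separate vision branch, and no distributed wrapper is involved. It shows that adding `dummy_feats.mean() * 0.0` makes the vision parameters part of the autograd graph while giving them all-zero gradients.

```python
import torch
from torch import nn

torch.manual_seed(0)


class ToyMultimodal(nn.Module):
    """Hypothetical stand-in: a text embedding plus a separate vision branch."""

    def __init__(self):
        super().__init__()
        self.embed = nn.Embedding(10, 4)
        self.vision = nn.Linear(6, 4)
        self.head = nn.Linear(4, 1)

    def forward(self, input_ids, use_dummy):
        inputs_embeds = self.embed(input_ids)
        if use_dummy:
            # Run an all-zero dummy input through the vision branch and add
            # its output scaled by 0.0, so the values stay unchanged while
            # the vision parameters join the graph.
            dummy_pv = torch.zeros(1, 6)
            dummy_feats = self.vision(dummy_pv)
            inputs_embeds = inputs_embeds + dummy_feats.mean() * 0.0
        return self.head(inputs_embeds).sum()


def vision_grads(use_dummy):
    """Return the gradients of the vision parameters after one backward pass."""
    model = ToyMultimodal()
    loss = model(torch.tensor([[1, 2, 3]]), use_dummy)
    loss.backward()
    return [p.grad for p in model.vision.parameters()]


without_dummy = vision_grads(use_dummy=False)  # vision branch never used
with_dummy = vision_grads(use_dummy=True)      # vision branch used via dummy
```

In the first case the vision parameters receive no gradient at all (`grad` stays `None`), which is the situation that can leave ranks with mismatched sets of used parameters under `DistributedDataParallel`. In the second case every vision parameter gets a gradient tensor, and that tensor is all zeros.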
To ensure robustness and prevent potential runtime errors during training, we should address a few defensive programming concerns:
- DDP/FSDP Unwrapping: If the model is wrapped in `DistributedDataParallel` (DDP) or FSDP, accessing attributes like `config` or methods like `get_input_embeddings` directly on `model` will raise an `AttributeError` because these wrappers do not delegate attribute access. Unwrapping the inner module via `model.module` if present is highly recommended.
- Model Dtype: `model.dtype` is not a standard PyTorch `nn.Module` attribute and might not be available on all custom wrappers or PEFT models. Getting the dtype directly from `inputs_embeds.dtype` is much safer.
- Defensive handling of `pooler_output`: Depending on the model configuration or future updates, `vision_output.pooler_output` might be a single tensor instead of a list of tensors. We should check its type before calling `torch.cat` to avoid potential type errors.
    def _post_encode(self, model: nn.Module, inputs: Dict[str, Any]) -> Dict[str, Any]:
        if not self.is_training:
            return inputs
        pixel_values = inputs.get('pixel_values')
        pixel_values_videos = inputs.get('pixel_values_videos')
        if pixel_values is None and pixel_values_videos is None:
            # Unwrap DDP/FSDP wrapper if present
            if hasattr(model, 'module'):
                model = model.module
            # Text-only batch: run a minimal dummy through the vision encoder
            # so every rank's computation graph includes the same parameters.
            inputs_embeds = model.get_input_embeddings()(inputs['input_ids'])
            device = inputs_embeds.device
            dtype = inputs_embeds.dtype
            patch_size = model.config.vision_config.patch_size
            dummy_pv = torch.zeros(
                1, 3, 4 * patch_size, 4 * patch_size,
                device=device, dtype=dtype
            )
            dummy_ts = torch.tensor(
                [[4, 4]], device=device, dtype=torch.int32
            )
            vision_output = model.get_image_features(
                dummy_pv, dummy_ts, downsample_mode=self.downsample_mode
            )
            pooler_output = vision_output.pooler_output
            if isinstance(pooler_output, (list, tuple)):
                dummy_feats = torch.cat(pooler_output, dim=0)
            else:
                dummy_feats = pooler_output
            dummy_feats = dummy_feats.to(device=device, dtype=dtype)
            inputs_embeds = inputs_embeds + dummy_feats.mean() * 0.0
            return {'inputs_embeds': inputs_embeds}
        return inputs
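Two of the review points can be tried out on their own. This is a small sketch under stated assumptions: `as_single_tensor` is a hypothetical helper name, and a plain `nn.Linear` stands in for a module that is not a Hugging Face model. It checks that a bare `nn.Module` has no `dtype` attribute, and that a tensor-or-list value can be normalized before further use.

```python
import torch
from torch import nn


def as_single_tensor(pooler_output):
    """Hypothetical helper: accept a tensor or a list/tuple of tensors."""
    if isinstance(pooler_output, (list, tuple)):
        return torch.cat(pooler_output, dim=0)
    return pooler_output


plain = nn.Linear(2, 2)
# A bare nn.Module defines no `dtype` attribute; the dtype has to be read
# from a parameter or from an activation tensor instead.
has_dtype = hasattr(plain, 'dtype')
param_dtype = next(plain.parameters()).dtype

# Both input forms end up as one tensor with the same leading dimension.
from_list = as_single_tensor([torch.zeros(2, 4), torch.zeros(3, 4)])
from_tensor = as_single_tensor(torch.zeros(5, 4))
```

Whether `vision_output.pooler_output` can actually be a single tensor for this model is not established in the thread; the helper only illustrates the defensive check the review proposes.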
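The model that _post_encode receives is already unwrapped, so no further handling is needed. When the hook is registered, the original model is obtained through accelerator.unwrap_model(model), and register_forward_pre_hook is registered on that object. During training, the wrapper's call to self.module() triggers the hook on the original model, so the model passed to the callback is MiniCPMV4_6ForConditionalGeneration itself, and accessing .model.language_model.embed_tokens directly works fine.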
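The hook behavior described in this reply can be reproduced with plain PyTorch. A minimal sketch, assuming a hand-written `Wrapper` class as a hypothetical stand-in for a DDP-style wrapper (it is not `DistributedDataParallel`, and the hook here is not the one swift registers): a forward pre-hook registered on the inner module is called with the inner module, even when the forward pass is started through the wrapper.

```python
import torch
from torch import nn


class Wrapper(nn.Module):
    """Hypothetical stand-in for a wrapper that forwards to self.module(...)."""

    def __init__(self, module):
        super().__init__()
        self.module = module

    def forward(self, *args, **kwargs):
        return self.module(*args, **kwargs)


inner = nn.Linear(3, 2)
seen = []


def pre_hook(module, args):
    # Record which object the hook is invoked with.
    seen.append(module)


# Register the hook on the unwrapped module, then call through the wrapper.
handle = inner.register_forward_pre_hook(pre_hook)
wrapped = Wrapper(inner)
wrapped(torch.zeros(1, 3))
handle.remove()
```

After the call, `seen` holds exactly one entry and it is `inner`, not `wrapped`, which matches the argument that no extra unwrapping is needed inside the callback.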
Added a _post_encode override in MiniCPMV4_6Template that intercepts text-only batches and injects a zero-contribution dummy forward pass through the vision encoder: