Skip to content

MXFP8 training bug fixes for quantized_model_init and Torch FSDP fp8 all gather#587

Open
sudhu2k wants to merge 2 commits into
devfrom
sudhu/mxfp8_bug_fixes
Open

MXFP8 training bug fixes for quantized_model_init and Torch FSDP fp8 all gather#587
sudhu2k wants to merge 2 commits into
devfrom
sudhu/mxfp8_bug_fixes

Conversation

@sudhu2k
Copy link
Copy Markdown
Contributor

@sudhu2k sudhu2k commented May 15, 2026

Description

Ensure keep_fp8_weight_transpose_cache flag is set to True not only for autocast but also for quantized_model_init.
Fix padding during fp8 all-gather

Fixes: #15425
#15420

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

…el_init case and not just autocast case.

Fix padding during fp8 all-gather
@sudhu2k sudhu2k self-assigned this May 15, 2026
@sudhu2k sudhu2k added the ci-level 3 CI test level 3 label May 15, 2026
# NOTE: ROCm/HIP backend uses an unpadded scale-inv layout (see `MXFP8Quantizer.make_empty`),
# so applying the padding here would produce a per-shard scale-inv whose dim-0
# does not match the destination scale-inv allocated for the FSDP2 local shard.
padding_multiples = [128, 4] if not IS_HIP_EXTENSION else [1, 1]
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think for gfx1250 we have some other padding requirements, this should be unified with #568

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Agreed. These changes should also be present in that PR accordingly. But I think for now, let's fix the issue on existing archs and make the appropriate changes along with the #568 PR.

Copy link
Copy Markdown
Contributor

@alextmagro alextmagro May 19, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

OK, @matthiasdiener can you work with Sudharshan make sure this is in your PR one way or another?

Comment thread transformer_engine/pytorch/tensor/mxfp8_tensor.py Outdated
@sudhu2k sudhu2k requested a review from alextmagro May 19, 2026 20:10
@alextmagro
Copy link
Copy Markdown
Contributor

LGTM! Just sync with Matthias on that one padding thing please.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

ci-level 3 CI test level 3

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants