Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 7 additions & 2 deletions transformer_engine/pytorch/module/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -1077,6 +1077,13 @@ def init_fp8_metadata(self, num_gemms: int = 1) -> None:
self.fp8_meta["num_gemms"] = num_gemms
self.init_fp8_meta_tensors(self.fp8_meta["recipe"])

# Force the transpose cache to be kept whenever the recipe is MXFP8 / MXFP4,
# regardless of whether we are currently inside an fp8_autocast region or not.
# reset_parameters() would disable columnwise_usage for params constructed inside
# `fp8_model_init` / `quantized_model_init`, leaving `_columnwise_data=None`).
if self.fp8_meta["recipe"].mxfp8() or self.fp8_meta["recipe"].mxfp4():
self.keep_fp8_weight_transpose_cache = True

if fp8_enabled:
# Set FP8 and other FP8 metadata
self.fp8_meta["num_gemms"] = num_gemms
Expand All @@ -1092,8 +1099,6 @@ def init_fp8_metadata(self, num_gemms: int = 1) -> None:
self.fp8_initialized = True

self.fp8_meta["recipe"] = FP8GlobalStateManager.get_fp8_recipe()
if self.fp8_meta["recipe"].mxfp8() or self.fp8_meta["recipe"].mxfp4():
self.keep_fp8_weight_transpose_cache = True

_current_recipe = self.fp8_meta["recipe"]
if _original_recipe is not None and not (
Expand Down
5 changes: 4 additions & 1 deletion transformer_engine/pytorch/tensor/mxfp8_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -471,7 +471,10 @@ def __torch_dispatch__(cls, func, types, args, kwargs=None):
scale_invs = [tensor._rowwise_scale_inv, tensor._columnwise_scale_inv]
split_sizes_for_scale = [split_size, split_size // MXFP8_BLOCK_SCALING_SIZE]
# Padding requirements: rowwise dim0 should be divisble by 128, columnwise dim0 should be divisble by 4
padding_multiples = [128, 4]
# NOTE: ROCm/HIP backend uses an unpadded scale-inv layout (see `MXFP8Quantizer.make_empty`),
# so applying the padding here would produce a per-shard scale-inv whose dim-0
# does not match the destination scale-inv allocated for the FSDP2 local shard.
padding_multiples = [128, 4] if not IS_HIP_EXTENSION else [1, 1]
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think for gfx1250 we have some other padding requirements, this should be unified with #568

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Agreed. These changes should also be present in that PR accordingly. But I think for now, let's fix the issue on existing archs and make the appropriate changes along with the #568 PR.

Copy link
Copy Markdown
Contributor

@alextmagro alextmagro May 19, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

OK, @matthiasdiener can you work with Sudharshan make sure this is in your PR one way or another?

for scale_inv, scale_split_size, pad_multiple in zip(
scale_invs, split_sizes_for_scale, padding_multiples
):
Expand Down
Loading