Fix memory overheads with FP4 native weights #2834

Open
WanZzzzzz wants to merge 3 commits into NVIDIA:main from WanZzzzzz:fix-fp4-mem

Conversation

@WanZzzzzz
Contributor

WanZzzzzz commented Apr 3, 2026

Description

The previous implementation concatenated the master weights into one tensor and performed the fp32->bf16 conversion once. However, this torch.cat creates a full FP32 copy of ALL master weights in one contiguous buffer, increasing peak memory usage and diminishing the memory savings of FP4 native weights. This PR reverts that change and restores per-parameter conversion.
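To see why the batched cast hurts, a back-of-envelope byte count is enough. The sketch below is illustrative only (the function names and parameter sizes are made up, and the real code operates on torch tensors); it just compares the transient memory each strategy needs on top of the weights themselves:

```python
# Illustrative peak-overhead comparison; counts transient bytes only.

def batch_cast_peak_bytes(param_numels):
    # torch.cat materializes ONE contiguous FP32 buffer holding every
    # master weight at once: 4 transient bytes per element, summed
    # over ALL parameters.
    return 4 * sum(param_numels)

def per_param_cast_peak_bytes(param_numels):
    # Casting each master weight separately only ever holds one extra
    # BF16 copy (2 bytes per element) of a single parameter at a time.
    return 2 * max(param_numels)

# Hypothetical parameter sizes for a small model.
sizes = [1024 * 1024, 4096 * 4096, 512 * 512]
print(batch_cast_peak_bytes(sizes))      # transient FP32 copy of everything
print(per_param_cast_peak_bytes(sizes))  # one BF16 copy of the largest weight
```

The batched path's transient buffer grows with the total parameter count, while the per-parameter path's overhead is bounded by the single largest weight, which is why the revert restores the expected FP4 memory savings.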

Fixes # (issue)

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

  • Change A
  • Change B

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

WanZzzzzz and others added 3 commits April 3, 2026 14:02
Signed-off-by: qiyuw <qiyuw@nvidia.com>
Signed-off-by: qiyuw <qiyuw@nvidia.com>
@greptile-apps
Contributor

greptile-apps bot commented Apr 3, 2026

Greptile Summary

This PR reverts a previous "optimization" in quantize_master_weights that concatenated all FP32 master weights into a single buffer before doing a dtype cast, which inadvertently created a full FP32 peak-memory copy of every master weight simultaneously. The fix restores per-parameter casting (master_weight.to(model_weight.dtype)) and consolidates this logic to apply to all quantizer types (NVFP4 and FP8 alike) rather than being duplicated inside an else branch.

Key changes:

  • Removes the torch.cat / torch.split batch-conversion block that was the source of the memory regression for FP4 native weights.
  • Moves master_weight.to(model_weight.dtype) to the top of the per-parameter loop so it is shared across all quantizer branches.
  • The elif chain for FP8 quantizer subtypes is promoted out of the else block, resulting in a flatter, cleaner dispatch structure.
  • Comment is updated from fp8_primary_weights to fp8/fp4_primary_weights to reflect the wider applicability of the cast logic.
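The flattened dispatch described above can be sketched as follows. This is a minimal stand-in with dummy classes, not the actual transformer_engine/pytorch/tensor/utils.py source: the real quantizer types come from Transformer Engine and the cast operates on torch tensors.

```python
# Dummy stand-ins for the real Transformer Engine quantizer classes.
class NVFP4Quantizer: ...
class Float8Quantizer: ...
class Float8CurrentScalingQuantizer: ...

def sort_params_by_quantizer(pairs):
    """pairs: iterable of (quantizer, weight) tuples.

    Mirrors the restored structure: the shared per-parameter cast
    happens once at the top of the loop, then a single flat
    isinstance chain dispatches each weight to its bucket.
    """
    nvfp4, delayed, current = [], [], []
    for quantizer, weight in pairs:
        # In the real code the shared cast goes here:
        # master_weight = master_weight.to(model_weight.dtype)
        if isinstance(quantizer, NVFP4Quantizer):
            nvfp4.append(weight)
        elif isinstance(quantizer, Float8Quantizer):
            delayed.append(weight)
        elif isinstance(quantizer, Float8CurrentScalingQuantizer):
            current.append(weight)
        else:
            raise ValueError(f"Unsupported quantizer: {type(quantizer).__name__}")
    return nvfp4, delayed, current
```

Keeping the cast above the branch (rather than duplicated inside each one) is what lets the elif chain stay flat across NVFP4 and the FP8 variants.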

Confidence Score: 4/5

Safe to merge; the change correctly reverts a memory-regressing optimization and the per-parameter cast logic is functionally equivalent to the prior per-type implementations.

The fix is small, well-scoped, and the logic is straightforward: the removed torch.cat block was creating unnecessary peak memory, and the replacement per-parameter .to() cast is exactly what FP8 paths were already doing. No test coverage is added (checklist item is unchecked), which is the only reason the score is not 5.

No files require special attention; the single changed file is clean.

Important Files Changed

Filename Overview
transformer_engine/pytorch/tensor/utils.py Removes batch torch.cat/torch.split conversion for NVFP4 weights; replaces with per-parameter dtype cast applied uniformly to all quantizer types. Logic is correct and memory regression is fixed; no new tests added.

Flowchart

%%{init: {'theme': 'neutral'}}%%
flowchart TD
    A[quantize_master_weights called] --> B[Loop over model_weights / master_weights]
    B --> C[clear_high_precision_init_val]
    C --> D{master_weight is not None?}
    D -- Yes --> E["master_weight = master_weight.to(model_weight.dtype)\n(per-parameter cast, no torch.cat)"]
    D -- No --> F[get quantizer]
    E --> F
    F --> G{quantizer type?}
    G -- NVFP4Quantizer --> H[nvfp4_params.append]
    G -- Float8Quantizer --> I[delayed_scaling_params.append]
    G -- Float8CurrentScalingQuantizer --> J[current_scaling_params.append]
    G -- Float8BlockQuantizer --> K[blockwise_scaling_params.append]
    G -- MXFP8Quantizer --> L[mxfp8_scaling_params.append]
    G -- other --> M[raise ValueError]
    H & I & J & K & L --> N[End of loop]
    N --> O[_cast_master_weights_to_fp8_delayed_scaling]
    N --> P[_cast_master_weights_to_fp8_current_scaling]
    N --> Q[_cast_master_weights_to_fp8_blockwise_scaling]
    N --> R[_cast_master_weights_to_fp8_mxfp8_scaling]
    N --> S[_cast_master_weights_to_nvfp4_2d]

Reviews (1): Last reviewed commit: "[pre-commit.ci] auto fixes from pre-comm..."

Collaborator

@timmoon10 left a comment

LGTM

@timmoon10
Collaborator

/te-ci pytorch L1
