Fix memory overheads with FP4 native weights #2834
WanZzzzzz wants to merge 3 commits into NVIDIA:main from
Conversation
Signed-off-by: qiyuw <qiyuw@nvidia.com>
for more information, see https://pre-commit.ci
Greptile Summary
This PR reverts a previous "optimization" in `quantize_master_weights` that bulk-converted master weights through a concatenated buffer.
Confidence Score: 4/5
Safe to merge; the change correctly reverts a memory-regressing optimization, and the per-parameter cast logic is functionally equivalent to the prior per-type implementations. The fix is small and well-scoped, and the logic is straightforward: the removed torch.cat block was creating unnecessary peak memory, and the replacement per-parameter .to() cast is exactly what the FP8 paths were already doing. No test coverage is added (the checklist item is unchecked), which is the only reason the score is not 5. No files require special attention; the single changed file is clean.
Flowchart
%%{init: {'theme': 'neutral'}}%%
flowchart TD
A[quantize_master_weights called] --> B[Loop over model_weights / master_weights]
B --> C[clear_high_precision_init_val]
C --> D{master_weight is not None?}
D -- Yes --> E["master_weight = master_weight.to(model_weight.dtype)\n(per-parameter cast, no torch.cat)"]
D -- No --> F[get quantizer]
E --> F
F --> G{quantizer type?}
G -- NVFP4Quantizer --> H[nvfp4_params.append]
G -- Float8Quantizer --> I[delayed_scaling_params.append]
G -- Float8CurrentScalingQuantizer --> J[current_scaling_params.append]
G -- Float8BlockQuantizer --> K[blockwise_scaling_params.append]
G -- MXFP8Quantizer --> L[mxfp8_scaling_params.append]
G -- other --> M[raise ValueError]
H & I & J & K & L --> N[End of loop]
N --> O[_cast_master_weights_to_fp8_delayed_scaling]
N --> P[_cast_master_weights_to_fp8_current_scaling]
N --> Q[_cast_master_weights_to_fp8_blockwise_scaling]
N --> R[_cast_master_weights_to_fp8_mxfp8_scaling]
N --> S[_cast_master_weights_to_nvfp4_2d]
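The loop in the flowchart above can be sketched as follows. This is a hypothetical reconstruction based only on the diagram, not the actual TransformerEngine source: the class names and list names are taken from the flowchart, while the function signature and tuple grouping are illustrative assumptions.

```python
import torch

# Placeholder quantizer classes; the real ones live in TransformerEngine.
class NVFP4Quantizer: ...
class Float8Quantizer: ...
class Float8CurrentScalingQuantizer: ...
class Float8BlockQuantizer: ...
class MXFP8Quantizer: ...

def quantize_master_weights(model_weights, master_weights, quantizers):
    """Group parameters by quantizer type, casting each master weight
    in place of the removed bulk torch.cat path (signature illustrative)."""
    nvfp4_params = []
    delayed_scaling_params = []
    current_scaling_params = []
    blockwise_scaling_params = []
    mxfp8_scaling_params = []
    for model_weight, master_weight, quantizer in zip(
        model_weights, master_weights, quantizers
    ):
        # Per-parameter cast: no torch.cat, so no full-size FP32 staging buffer.
        if master_weight is not None:
            master_weight = master_weight.to(model_weight.dtype)
        if isinstance(quantizer, NVFP4Quantizer):
            nvfp4_params.append((model_weight, master_weight))
        elif isinstance(quantizer, Float8Quantizer):
            delayed_scaling_params.append((model_weight, master_weight))
        elif isinstance(quantizer, Float8CurrentScalingQuantizer):
            current_scaling_params.append((model_weight, master_weight))
        elif isinstance(quantizer, Float8BlockQuantizer):
            blockwise_scaling_params.append((model_weight, master_weight))
        elif isinstance(quantizer, MXFP8Quantizer):
            mxfp8_scaling_params.append((model_weight, master_weight))
        else:
            raise ValueError(f"Unsupported quantizer type: {type(quantizer)}")
    # Each group is then handed to its _cast_master_weights_to_* helper.
    return (nvfp4_params, delayed_scaling_params, current_scaling_params,
            blockwise_scaling_params, mxfp8_scaling_params)
```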
/te-ci pytorch L1
Description
The previous implementation concatenated all master weights into one tensor and performed the FP32->BF16 conversion once. However, that torch.cat creates a full FP32 copy of all master weights in one contiguous buffer, increasing peak memory usage and diminishing the memory savings of FP4 native weights. This PR reverts the change and returns to per-parameter conversion.
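The difference between the two strategies can be sketched in a few lines. This is a minimal illustration, not the PR's actual code; the tensor names and sizes are made up, and in practice the point is that the concatenated path's transient FP32 footprint scales with the *total* master-weight size, while the per-parameter path's scales only with the largest single parameter.

```python
import torch

master_weights = [torch.randn(1024, dtype=torch.float32) for _ in range(4)]

# Reverted approach: one bulk cast through a concatenated buffer.
# torch.cat materializes a full FP32 copy of every master weight at once,
# raising peak memory by the total master-weight size.
flat = torch.cat([w.view(-1) for w in master_weights])   # extra FP32 copy
flat_bf16 = flat.to(torch.bfloat16)
bulk = list(flat_bf16.split([w.numel() for w in master_weights]))

# This PR's approach: cast each parameter independently, so the transient
# FP32 footprint is bounded by the largest single parameter.
per_param = [w.to(torch.bfloat16) for w in master_weights]

# Both paths produce numerically identical BF16 results.
assert all(torch.equal(a, b) for a, b in zip(bulk, per_param))
```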
Fixes # (issue)
Type of change
Changes
Please list the changes introduced in this PR:
Checklist: