diff --git a/transformer_engine/pytorch/graph.py b/transformer_engine/pytorch/graph.py index 075db1394b..86b8a4acf4 100644 --- a/transformer_engine/pytorch/graph.py +++ b/transformer_engine/pytorch/graph.py @@ -324,12 +324,7 @@ def _make_graphed_callables( if cache_quantized_params: # Initialize flag that controls FP8 weight updates - qstate = FP8GlobalStateManager.quantization_state - if qstate.skip_fp8_weight_update_tensor is None: - qstate.skip_fp8_weight_update_tensor = torch.empty( - 1, dtype=torch.float32, device="cuda" - ) - qstate.skip_fp8_weight_update_tensor.fill_(False) + FP8GlobalStateManager.set_skip_fp8_weight_update_tensor(False) # Check callables for c in callables: @@ -841,9 +836,7 @@ def forward(ctx, skip_fp8_weight_update, cuda_graph_stream, cuda_graph_event, *i # Set flag for whether to update FP8 weight updates ctx.is_first_module = FP8GlobalStateManager.is_first_fp8_module() if ctx.is_first_module and skip_fp8_weight_update is not None: - FP8GlobalStateManager.quantization_state.skip_fp8_weight_update_tensor.fill_( - skip_fp8_weight_update - ) + FP8GlobalStateManager.set_skip_fp8_weight_update_tensor(skip_fp8_weight_update) ctx.cuda_graph_stream = cuda_graph_stream ctx.cuda_graph_event = cuda_graph_event # Copy values from new tensors into static tensors diff --git a/transformer_engine/pytorch/quantization.py b/transformer_engine/pytorch/quantization.py index 0c40723517..41c72af661 100644 --- a/transformer_engine/pytorch/quantization.py +++ b/transformer_engine/pytorch/quantization.py @@ -409,6 +409,20 @@ class FP8GlobalStateManager: quantization_state = FP8GlobalState() + @classmethod + def set_skip_fp8_weight_update_tensor(cls, skip: bool) -> None: + """Fill the skip-FP8-weight-update flag tensor with `skip`, allocating it if it is None""" + if cls.quantization_state.skip_fp8_weight_update_tensor is None: + cls.quantization_state.skip_fp8_weight_update_tensor = torch.empty( + 1, dtype=torch.float32, device="cuda" + ) + cls.quantization_state.skip_fp8_weight_update_tensor.fill_(skip) + 
@classmethod + def get_skip_fp8_weight_update_tensor(cls) -> Optional[torch.Tensor]: + """Return the skip-FP8-weight-update flag tensor, or None if it has not been allocated""" + return cls.quantization_state.skip_fp8_weight_update_tensor + @classmethod def reset(cls) -> None: """Reset the global state"""