From 5415cafd004fa40510f21775ca2cd08a6955bdbc Mon Sep 17 00:00:00 2001 From: Pawel Gadzinski Date: Tue, 31 Mar 2026 16:23:05 +0200 Subject: [PATCH 1/2] drop Signed-off-by: Pawel Gadzinski --- .../dot_product_attention/backends.py | 4 +- .../attention/dot_product_attention/utils.py | 5 +- transformer_engine/pytorch/csrc/common.h | 39 ++++-- transformer_engine/pytorch/csrc/extensions.h | 76 ++++++----- .../pytorch/csrc/extensions/activation.cpp | 127 ++++++++++-------- .../pytorch/csrc/extensions/attention.cpp | 37 +++-- .../pytorch/csrc/extensions/bias.cpp | 31 +++-- .../pytorch/csrc/extensions/cast.cpp | 43 +++--- .../pytorch/csrc/extensions/normalization.cpp | 18 ++- .../pytorch/csrc/extensions/pybind.cpp | 78 ++++++----- transformer_engine/pytorch/csrc/quantizer.cpp | 39 ++++-- transformer_engine/pytorch/module/_common.py | 2 + transformer_engine/pytorch/module/base.py | 37 ++++- .../pytorch/module/grouped_linear.py | 29 +++- .../pytorch/module/layernorm_linear.py | 18 ++- .../pytorch/module/layernorm_mlp.py | 53 +++++--- transformer_engine/pytorch/module/linear.py | 13 +- .../pytorch/ops/basic/activation.py | 2 +- .../pytorch/ops/basic/swiglu.py | 3 +- transformer_engine/pytorch/quantization.py | 10 +- .../pytorch/quantized_tensor.py | 39 ++++-- .../pytorch/tensor/_quantization_helpers.py | 5 +- .../pytorch/tensor/float8_tensor.py | 72 ++++++---- transformer_engine/pytorch/tensor/utils.py | 7 +- 24 files changed, 496 insertions(+), 291 deletions(-) diff --git a/transformer_engine/pytorch/attention/dot_product_attention/backends.py b/transformer_engine/pytorch/attention/dot_product_attention/backends.py index 442366035a..5ea7917a5a 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/backends.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/backends.py @@ -456,10 +456,10 @@ def forward( fp8_recipe = fp8_meta["local_recipes"][0] if fp8_recipe.float8_current_scaling(): S_quantizer = Float8CurrentScalingQuantizer( - fp8_dtype=S_quantizer.dtype, device="cuda" + fp8_dtype=S_quantizer.dtype, ) dP_quantizer = Float8CurrentScalingQuantizer( - fp8_dtype=dP_quantizer.dtype, device="cuda" + fp8_dtype=dP_quantizer.dtype, ) if "2" in qkv_layout or "3" in qkv_layout: diff --git a/transformer_engine/pytorch/attention/dot_product_attention/utils.py b/transformer_engine/pytorch/attention/dot_product_attention/utils.py index 170cb2cd34..bd87082f56 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/utils.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/utils.py @@ -2183,10 +2183,7 @@ def print_quantizers( type_str = "DS" elif isinstance(q, Float8CurrentScalingQuantizer): type_str = "CS" - print( - f"{label} >> {names[i]:14s}: {type_str}, {q.scale.item():.4e} x" - f" {q.amax.item():.4e} = {q.scale.item()*q.amax.item():.4e}" - ) + print(f"{label} >> {names[i]:14s}: {type_str}") def combine_and_quantize(qkv_layout, q, k, v, qkv_quantizer): diff --git a/transformer_engine/pytorch/csrc/common.h b/transformer_engine/pytorch/csrc/common.h index 63a2e86e67..9afa4bc799 100644 --- a/transformer_engine/pytorch/csrc/common.h +++ b/transformer_engine/pytorch/csrc/common.h @@ -194,9 +194,7 @@ class Float8Quantizer : public Quantizer { class Float8CurrentScalingQuantizer : public Quantizer { public: - at::Tensor scale; at::Tensor scale_inv; - at::Tensor amax; DType dtype; bool with_amax_reduction; c10::intrusive_ptr amax_reduction_group; @@ -217,33 +215,52 @@ class Float8CurrentScalingQuantizer : public Quantizer { py::object quantizer, const std::optional& first_dims, size_t logical_first_dim, size_t logical_last_dim) const override; - /*! @brief Construct an unquantized tensor that shares the quantizer's amax pointer. + /*! @brief Construct an unquantized tensor with an amax buffer. * - * The amax is zeroed out. Most TE kernels that output amax expect - * amax to be initialized to zero. + * The provided amax tensor is zeroed out and set on the output tensor. + * Most TE kernels that output amax expect amax to be initialized to zero. */ std::pair create_unquantized_tensor_with_amax( - const std::vector& shape, DType dtype, std::optional data = std::nullopt); + const std::vector& shape, DType dtype, at::Tensor amax, + std::optional data = std::nullopt); std::pair convert_and_update_tensor(py::object shape) const override; + /*! @brief Quantize to FP8 (virtual fallback, allocates local amax/scale) */ void quantize(const TensorWrapper& input, TensorWrapper& out, const std::optional& noop_flag = std::nullopt) override; - /*! @brief Quantize to FP8, skipping local amax computation + /*! @brief Quantize to FP8 using provided amax/scale workspace buffers */ + void quantize(const TensorWrapper& input, TensorWrapper& out, + at::Tensor amax, at::Tensor scale, + const std::optional& noop_flag = std::nullopt); + + /*! @brief Quantize to FP8, skipping local amax computation. * - * The quantizer's amax pointer is assumed to already hold the local - * amax. The amax may still be reduced across the amax reduction - * group. + * The provided amax tensor is assumed to already hold the local + * amax (e.g. computed by a fused LN kernel). The amax may still + * be reduced across the amax reduction group. */ void quantize_with_amax(TensorWrapper& input, TensorWrapper& out, + at::Tensor amax, at::Tensor scale, const std::optional& noop_flag = std::nullopt); private: void quantize_impl(const TensorWrapper& input, TensorWrapper& out, - const std::optional& noop_flag, bool compute_amax); + const std::optional& noop_flag, bool compute_amax, + at::Tensor amax, at::Tensor scale); }; +/*! @brief Extract amax and scale from a quantizer workspace tensor. + * + * Workspace layout: [amax, scale] (2 float32). + */ +inline std::pair split_quantizer_workspace(const at::Tensor& workspace) { + NVTE_CHECK(workspace.numel() >= 2, "Quantizer workspace must have at least 2 float32 elements"); + return {workspace.slice(0, 0, 1).contiguous(), + workspace.slice(0, 1, 2).contiguous()}; +} + class Float8BlockQuantizer : public Quantizer { public: // Which float8 type is used for q data. diff --git a/transformer_engine/pytorch/csrc/extensions.h b/transformer_engine/pytorch/csrc/extensions.h index 1c5116a8da..188c27c6ff 100644 --- a/transformer_engine/pytorch/csrc/extensions.h +++ b/transformer_engine/pytorch/csrc/extensions.h @@ -92,7 +92,8 @@ std::vector fused_attn_fwd( const std::optional cu_seqlens_q_padded, const std::optional cu_seqlens_kv_padded, const std::optional page_table_k, const std::optional page_table_v, - py::handle s_quantizer, py::handle o_quantizer, const std::optional Bias, + py::handle s_quantizer, py::handle o_quantizer, + const std::optional Bias, const std::optional SoftmaxOffset, const std::optional rng_gen, size_t rng_elts_per_thread, bool return_max_logit, bool cuda_graph); @@ -215,57 +216,57 @@ at::Tensor swap_first_dims(at::Tensor tensor, std::optional out = st **************************************************************************************************/ /* GLU (sigmoid gate) */ -py::object glu(const at::Tensor &input, py::handle quantizer); +py::object glu(const at::Tensor &input, py::handle quantizer, std::optional qw = std::nullopt); -py::object dglu(const at::Tensor &grad, const at::Tensor &input, py::handle quantizer); +py::object dglu(const at::Tensor &grad, const at::Tensor &input, py::handle quantizer, std::optional qw = std::nullopt); /* GELU and variants*/ -py::object gelu(const at::Tensor &input, py::handle quantizer); +py::object gelu(const at::Tensor &input, py::handle quantizer, std::optional qw = std::nullopt); -py::object dgelu(const at::Tensor &grad, const at::Tensor &input, py::handle quantizer); +py::object dgelu(const at::Tensor &grad, const at::Tensor &input, py::handle quantizer, std::optional qw = std::nullopt); -py::object geglu(const at::Tensor &input, py::handle quantizer); +py::object geglu(const at::Tensor &input, py::handle quantizer, std::optional qw = std::nullopt); -py::object dgeglu(const at::Tensor &grad, const at::Tensor &input, py::handle quantizer); +py::object dgeglu(const at::Tensor &grad, const at::Tensor &input, py::handle quantizer, std::optional qw = std::nullopt); -py::object qgelu(const at::Tensor &input, py::handle quantizer); +py::object qgelu(const at::Tensor &input, py::handle quantizer, std::optional qw = std::nullopt); -py::object dqgelu(const at::Tensor &grad, const at::Tensor &input, py::handle quantizer); +py::object dqgelu(const at::Tensor &grad, const at::Tensor &input, py::handle quantizer, std::optional qw = std::nullopt); -py::object qgeglu(const at::Tensor &input, py::handle quantizer); +py::object qgeglu(const at::Tensor &input, py::handle quantizer, std::optional qw = std::nullopt); -py::object dqgeglu(const at::Tensor &grad, const at::Tensor &input, py::handle quantizer); +py::object dqgeglu(const at::Tensor &grad, const at::Tensor &input, py::handle quantizer, std::optional qw = std::nullopt); /* ReLU and variants*/ -py::object relu(const at::Tensor &input, py::handle quantizer); +py::object relu(const at::Tensor &input, py::handle quantizer, std::optional qw = std::nullopt); -py::object drelu(const at::Tensor &grad, const at::Tensor &input, py::handle quantizer); +py::object drelu(const at::Tensor &grad, const at::Tensor &input, py::handle quantizer, std::optional qw = std::nullopt); -py::object reglu(const at::Tensor &input, py::handle quantizer); +py::object reglu(const at::Tensor &input, py::handle quantizer, std::optional qw = std::nullopt); -py::object dreglu(const at::Tensor &grad, const at::Tensor &input, py::handle quantizer); +py::object dreglu(const at::Tensor &grad, const at::Tensor &input, py::handle quantizer, std::optional qw = std::nullopt); -py::object srelu(const at::Tensor &input, py::handle quantizer); +py::object srelu(const at::Tensor &input, py::handle quantizer, std::optional qw = std::nullopt); -py::object dsrelu(const at::Tensor &grad, const at::Tensor &input, py::handle quantizer); +py::object dsrelu(const at::Tensor &grad, const at::Tensor &input, py::handle quantizer, std::optional qw = std::nullopt); -py::object sreglu(const at::Tensor &input, py::handle quantizer); +py::object sreglu(const at::Tensor &input, py::handle quantizer, std::optional qw = std::nullopt); -py::object dsreglu(const at::Tensor &grad, const at::Tensor &input, py::handle quantizer); +py::object dsreglu(const at::Tensor &grad, const at::Tensor &input, py::handle quantizer, std::optional qw = std::nullopt); /* Silu and variants*/ -py::object silu(const at::Tensor &input, py::handle quantizer); +py::object silu(const at::Tensor &input, py::handle quantizer, std::optional qw = std::nullopt); -py::object dsilu(const at::Tensor &grad, const at::Tensor &input, py::handle quantizer); +py::object dsilu(const at::Tensor &grad, const at::Tensor &input, py::handle quantizer, std::optional qw = std::nullopt); -py::object swiglu(const at::Tensor &input, py::handle quantizer); +py::object swiglu(const at::Tensor &input, py::handle quantizer, std::optional qw = std::nullopt); -py::object dswiglu(const at::Tensor &grad, const at::Tensor &input, py::handle quantizer); +py::object dswiglu(const at::Tensor &grad, const at::Tensor &input, py::handle quantizer, std::optional qw = std::nullopt); -py::object clamped_swiglu(const at::Tensor &input, py::handle quantizer, float limit, float alpha); +py::object clamped_swiglu(const at::Tensor &input, py::handle quantizer, std::optional qw, float limit, float alpha); py::object clamped_dswiglu(const at::Tensor &grad, const at::Tensor &input, py::handle quantizer, - float limit, float alpha); + std::optional qw, float limit, float alpha); /*************************************************************************************************** * LayerNorm **************************************************************************************************/ @@ -278,7 +279,8 @@ std::vector layernorm_bwd(const at::Tensor &dz, const at::Tensor &x, std::vector layernorm_fwd(py::handle input, py::handle weight, MaybeTensor bias, float eps, py::object ln_out, py::handle quantizer, DType out_dtype, const int sm_margin, - const bool zero_centered_gamma); + const bool zero_centered_gamma, + std::optional quantizer_workspace = std::nullopt); /*************************************************************************************************** * RMSNorm @@ -295,14 +297,16 @@ std::vector rmsnorm_bwd_add(const at::Tensor &dz, const at::Tensor & std::vector rmsnorm_fwd(const py::handle &input, const py::handle &weight, float eps, py::object ln_out, py::handle quantizer, DType otype, - const int sm_margin, const bool zero_centered_gamma); + const int sm_margin, const bool zero_centered_gamma, + std::optional quantizer_workspace = std::nullopt); /*************************************************************************************************** * Cast **************************************************************************************************/ py::object quantize(const at::Tensor &tensor, py::handle quantizer, const py::object &output, - std::optional noop_flag); + std::optional noop_flag, + std::optional workspace = std::nullopt); py::object dequantize(const py::handle &input, DType otype); @@ -315,28 +319,30 @@ std::vector multi_tensor_quantize(const std::vector &ten std::vector split_quantize(const at::Tensor &tensor, const std::vector &split_sections, std::vector quantizer_list, - bool disable_bulk_allocation = false); + bool disable_bulk_allocation = false, + std::optional> quantizer_workspaces = std::nullopt); /*************************************************************************************************** * Bias gradient fusions **************************************************************************************************/ -std::vector bgrad_quantize(const at::Tensor &input, py::handle py_quantizer); +std::vector bgrad_quantize(const at::Tensor &input, py::handle py_quantizer, + std::optional quantizer_workspace = std::nullopt); std::vector dbias_dgelu(const at::Tensor &grad_output, const at::Tensor &act_input, - py::handle quantizer); + py::handle quantizer, std::optional qw = std::nullopt); std::vector dbias_dsilu(const at::Tensor &grad_output, const at::Tensor &act_input, - py::handle quantizer); + py::handle quantizer, std::optional qw = std::nullopt); std::vector dbias_drelu(const at::Tensor &grad_output, const at::Tensor &act_input, - py::handle quantizer); + py::handle quantizer, std::optional qw = std::nullopt); std::vector dbias_dqgelu(const at::Tensor &grad_output, const at::Tensor &act_input, - py::handle quantizer); + py::handle quantizer, std::optional qw = std::nullopt); std::vector dbias_dsrelu(const at::Tensor &grad_output, const at::Tensor &act_input, - py::handle quantizer); + py::handle quantizer, std::optional qw = std::nullopt); /*************************************************************************************************** * Dropout diff --git a/transformer_engine/pytorch/csrc/extensions/activation.cpp b/transformer_engine/pytorch/csrc/extensions/activation.cpp index 99b9c1fefa..95edf77f4c 100644 --- a/transformer_engine/pytorch/csrc/extensions/activation.cpp +++ b/transformer_engine/pytorch/csrc/extensions/activation.cpp @@ -15,8 +15,9 @@ using FuncType = void(const NVTETensor, NVTETensor, cudaStream_t); using DFuncType = void(const NVTETensor, const NVTETensor, NVTETensor, cudaStream_t); template -py::object activation_helper(const at::Tensor& input, py::handle quantizer, int shape_divisor = 1, - Args&&... args) { +py::object activation_helper(const at::Tensor& input, py::handle quantizer, + std::optional quantizer_workspace, + int shape_divisor = 1, Args&&... args) { init_extension(); // Input tensor @@ -86,8 +87,9 @@ py::object activation_helper(const at::Tensor& input, py::handle quantizer, int { auto fp8_quantizer_cpp = dynamic_cast(quantizer_cpp.get()); NVTE_CHECK(fp8_quantizer_cpp != nullptr, "Could not cast to FP8 current scaling quantizer"); + auto [cs_amax, cs_scale] = split_quantizer_workspace(*quantizer_workspace); auto [temp_nvte, _] = - fp8_quantizer_cpp->create_unquantized_tensor_with_amax(output_shape, fake_dtype); + fp8_quantizer_cpp->create_unquantized_tensor_with_amax(output_shape, fake_dtype, cs_amax); NVTE_SCOPED_GIL_RELEASE({ if constexpr (act_func == nullptr) { act_func_with_args(input_nvte.data(), temp_nvte.data(), std::forward(args)..., @@ -96,7 +98,7 @@ py::object activation_helper(const at::Tensor& input, py::handle quantizer, int act_func(input_nvte.data(), temp_nvte.data(), stream); } }); - fp8_quantizer_cpp->quantize_with_amax(temp_nvte, out_nvte); + fp8_quantizer_cpp->quantize_with_amax(temp_nvte, out_nvte, cs_amax, cs_scale); } break; case Impl::FUSED_ACTIVATION_AMAX_NVFP4: @@ -126,7 +128,9 @@ py::object activation_helper(const at::Tensor& input, py::handle quantizer, int template py::object dactivation_helper(const at::Tensor& grad_output, const at::Tensor& input, - py::handle quantizer, Args&&... args) { + py::handle quantizer, + std::optional quantizer_workspace, + Args&&... args) { init_extension(); // Grad output and input tensors @@ -198,8 +202,9 @@ py::object dactivation_helper(const at::Tensor& grad_output, const at::Tensor& i { auto fp8_quantizer_cpp = dynamic_cast(quantizer_cpp.get()); NVTE_CHECK(fp8_quantizer_cpp != nullptr, "Could not cast to FP8 current scaling quantizer"); + auto [cs_amax, cs_scale] = split_quantizer_workspace(*quantizer_workspace); auto [temp_nvte, _] = - fp8_quantizer_cpp->create_unquantized_tensor_with_amax(input_shape, fake_dtype); + fp8_quantizer_cpp->create_unquantized_tensor_with_amax(input_shape, fake_dtype, cs_amax); NVTE_SCOPED_GIL_RELEASE({ if constexpr (dact_func == nullptr) { dact_func_with_args(grad_output_nvte.data(), input_nvte.data(), temp_nvte.data(), @@ -208,7 +213,7 @@ py::object dactivation_helper(const at::Tensor& grad_output, const at::Tensor& i dact_func(grad_output_nvte.data(), input_nvte.data(), temp_nvte.data(), stream); } }); - fp8_quantizer_cpp->quantize_with_amax(temp_nvte, grad_input_nvte); + fp8_quantizer_cpp->quantize_with_amax(temp_nvte, grad_input_nvte, cs_amax, cs_scale); } break; case Impl::FUSED_ACTIVATION_AMAX_NVFP4: @@ -238,103 +243,115 @@ py::object dactivation_helper(const at::Tensor& grad_output, const at::Tensor& i } // namespace /* GELU and variants */ -py::object gelu(const at::Tensor& input, py::handle quantizer) { - return activation_helper(input, quantizer); +py::object gelu(const at::Tensor& input, py::handle quantizer, std::optional qw) { + return activation_helper(input, quantizer, qw); } -py::object dgelu(const at::Tensor& grad, const at::Tensor& input, py::handle quantizer) { - return dactivation_helper(grad, input, quantizer); +py::object dgelu(const at::Tensor& grad, const at::Tensor& input, py::handle quantizer, + std::optional qw) { + return dactivation_helper(grad, input, quantizer, qw); } -py::object glu(const at::Tensor& input, py::handle quantizer) { - return activation_helper(input, quantizer, 2); +py::object glu(const at::Tensor& input, py::handle quantizer, std::optional qw) { + return activation_helper(input, quantizer, qw, 2); } -py::object dglu(const at::Tensor& grad, const at::Tensor& input, py::handle quantizer) { - return dactivation_helper(grad, input, quantizer); +py::object dglu(const at::Tensor& grad, const at::Tensor& input, py::handle quantizer, + std::optional qw) { + return dactivation_helper(grad, input, quantizer, qw); } -py::object geglu(const at::Tensor& input, py::handle quantizer) { - return activation_helper(input, quantizer, 2); +py::object geglu(const at::Tensor& input, py::handle quantizer, std::optional qw) { + return activation_helper(input, quantizer, qw, 2); } -py::object dgeglu(const at::Tensor& grad, const at::Tensor& input, py::handle quantizer) { - return dactivation_helper(grad, input, quantizer); +py::object dgeglu(const at::Tensor& grad, const at::Tensor& input, py::handle quantizer, + std::optional qw) { + return dactivation_helper(grad, input, quantizer, qw); } -py::object qgelu(const at::Tensor& input, py::handle quantizer) { - return activation_helper(input, quantizer); +py::object qgelu(const at::Tensor& input, py::handle quantizer, std::optional qw) { + return activation_helper(input, quantizer, qw); } -py::object dqgelu(const at::Tensor& grad, const at::Tensor& input, py::handle quantizer) { - return dactivation_helper(grad, input, quantizer); +py::object dqgelu(const at::Tensor& grad, const at::Tensor& input, py::handle quantizer, + std::optional qw) { + return dactivation_helper(grad, input, quantizer, qw); } -py::object qgeglu(const at::Tensor& input, py::handle quantizer) { - return activation_helper(input, quantizer, 2); +py::object qgeglu(const at::Tensor& input, py::handle quantizer, std::optional qw) { + return activation_helper(input, quantizer, qw, 2); } -py::object dqgeglu(const at::Tensor& grad, const at::Tensor& input, py::handle quantizer) { - return dactivation_helper(grad, input, quantizer); +py::object dqgeglu(const at::Tensor& grad, const at::Tensor& input, py::handle quantizer, + std::optional qw) { + return dactivation_helper(grad, input, quantizer, qw); } /* ReLU and variants */ -py::object relu(const at::Tensor& input, py::handle quantizer) { - return activation_helper(input, quantizer); +py::object relu(const at::Tensor& input, py::handle quantizer, std::optional qw) { + return activation_helper(input, quantizer, qw); } -py::object drelu(const at::Tensor& grad, const at::Tensor& input, py::handle quantizer) { - return dactivation_helper(grad, input, quantizer); +py::object drelu(const at::Tensor& grad, const at::Tensor& input, py::handle quantizer, + std::optional qw) { + return dactivation_helper(grad, input, quantizer, qw); } -py::object reglu(const at::Tensor& input, py::handle quantizer) { - return activation_helper(input, quantizer, 2); +py::object reglu(const at::Tensor& input, py::handle quantizer, std::optional qw) { + return activation_helper(input, quantizer, qw, 2); } -py::object dreglu(const at::Tensor& grad, const at::Tensor& input, py::handle quantizer) { - return dactivation_helper(grad, input, quantizer); +py::object dreglu(const at::Tensor& grad, const at::Tensor& input, py::handle quantizer, + std::optional qw) { + return dactivation_helper(grad, input, quantizer, qw); } -py::object srelu(const at::Tensor& input, py::handle quantizer) { - return activation_helper(input, quantizer); +py::object srelu(const at::Tensor& input, py::handle quantizer, std::optional qw) { + return activation_helper(input, quantizer, qw); } -py::object dsrelu(const at::Tensor& grad, const at::Tensor& input, py::handle quantizer) { - return dactivation_helper(grad, input, quantizer); +py::object dsrelu(const at::Tensor& grad, const at::Tensor& input, py::handle quantizer, + std::optional qw) { + return dactivation_helper(grad, input, quantizer, qw); } -py::object sreglu(const at::Tensor& input, py::handle quantizer) { - return activation_helper(input, quantizer, 2); +py::object sreglu(const at::Tensor& input, py::handle quantizer, std::optional qw) { + return activation_helper(input, quantizer, qw, 2); } -py::object dsreglu(const at::Tensor& grad, const at::Tensor& input, py::handle quantizer) { - return dactivation_helper(grad, input, quantizer); +py::object dsreglu(const at::Tensor& grad, const at::Tensor& input, py::handle quantizer, + std::optional qw) { + return dactivation_helper(grad, input, quantizer, qw); } /* Silu and variants */ -py::object silu(const at::Tensor& input, py::handle quantizer) { - return activation_helper(input, quantizer); +py::object silu(const at::Tensor& input, py::handle quantizer, std::optional qw) { + return activation_helper(input, quantizer, qw); } -py::object dsilu(const at::Tensor& grad, const at::Tensor& input, py::handle quantizer) { - return dactivation_helper(grad, input, quantizer); +py::object dsilu(const at::Tensor& grad, const at::Tensor& input, py::handle quantizer, + std::optional qw) { + return dactivation_helper(grad, input, quantizer, qw); } -py::object swiglu(const at::Tensor& input, py::handle quantizer) { - return activation_helper(input, quantizer, 2); +py::object swiglu(const at::Tensor& input, py::handle quantizer, std::optional qw) { + return activation_helper(input, quantizer, qw, 2); } -py::object dswiglu(const at::Tensor& grad, const at::Tensor& input, py::handle quantizer) { - return dactivation_helper(grad, input, quantizer); +py::object dswiglu(const at::Tensor& grad, const at::Tensor& input, py::handle quantizer, + std::optional qw) { + return dactivation_helper(grad, input, quantizer, qw); } /* clamped functions */ -py::object clamped_swiglu(const at::Tensor& input, py::handle quantizer, float limit, float alpha) { - return activation_helper(input, quantizer, 2, limit, alpha); +py::object clamped_swiglu(const at::Tensor& input, py::handle quantizer, + std::optional qw, float limit, float alpha) { + return activation_helper(input, quantizer, qw, 2, limit, alpha); } py::object clamped_dswiglu(const at::Tensor& grad, const at::Tensor& input, py::handle quantizer, - float limit, float alpha) { - return dactivation_helper(grad, input, quantizer, limit, alpha); + std::optional qw, float limit, float alpha) { + return dactivation_helper(grad, input, quantizer, qw, limit, alpha); } } // namespace pytorch diff --git a/transformer_engine/pytorch/csrc/extensions/attention.cpp b/transformer_engine/pytorch/csrc/extensions/attention.cpp index ff60bb87bb..4b0999b149 100644 --- a/transformer_engine/pytorch/csrc/extensions/attention.cpp +++ b/transformer_engine/pytorch/csrc/extensions/attention.cpp @@ -58,7 +58,8 @@ NVTE_Fused_Attn_Backend get_fused_attn_backend( std::pair quantizer_helper(py::handle quantizer, const std::vector &shape, DType dtype, bool create_hp_tensor_for_cs, - std::optional data) { + std::optional data, + std::optional quantizer_workspace) { std::unique_ptr T_quantizer = convert_quantizer(quantizer); TensorWrapper te_T; py::object py_T; @@ -79,11 +80,16 @@ std::pair quantizer_helper(py::handle quantizer, // current scaling auto *T_quantizer_fp8 = dynamic_cast(T_quantizer.get()); if (create_hp_tensor_for_cs) { + at::Tensor ws = quantizer_workspace.has_value() + ? *quantizer_workspace + : at::empty({2}, at::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA)); + auto [cs_amax, cs_scale] = split_quantizer_workspace(ws); if (data.has_value()) { std::tie(te_T, py_T) = - T_quantizer_fp8->create_unquantized_tensor_with_amax(shape, dtype, data.value()); + T_quantizer_fp8->create_unquantized_tensor_with_amax(shape, dtype, cs_amax, data.value()); } else { - std::tie(te_T, py_T) = T_quantizer_fp8->create_unquantized_tensor_with_amax(shape, dtype); + std::tie(te_T, py_T) = + T_quantizer_fp8->create_unquantized_tensor_with_amax(shape, dtype, cs_amax); } } else { std::tie(te_T, py_T) = T_quantizer_fp8->create_tensor(shape, dtype); @@ -106,7 +112,8 @@ std::vector fused_attn_fwd( const std::optional cu_seqlens_q_padded, const std::optional cu_seqlens_kv_padded, const std::optional page_table_k, const std::optional page_table_v, - py::handle s_quantizer, py::handle o_quantizer, const std::optional Bias, + py::handle s_quantizer, py::handle o_quantizer, + const std::optional Bias, const std::optional SoftmaxOffset, const std::optional rng_gen, size_t rng_elts_per_thread, bool return_max_logit, bool cuda_graph) { // Ensure that cuDNN handle is created on the correct device, @@ -126,7 +133,8 @@ std::vector fused_attn_fwd( // create S tensor TensorWrapper te_S; py::object py_S; - std::tie(te_S, py_S) = quantizer_helper(s_quantizer, {0}, DType::kFloat32, false, std::nullopt); + std::tie(te_S, py_S) = + quantizer_helper(s_quantizer, {0}, DType::kFloat32, false, std::nullopt, std::nullopt); // create O tensor TensorWrapper te_O; @@ -137,7 +145,8 @@ std::vector fused_attn_fwd( auto o_shape = std::vector{q_shape.begin(), q_shape.end()}; o_shape[o_shape.size() - 1] = v_shape[v_shape.size() - 1]; const DType fake_dtype_te = GetTransformerEngineDType(fake_dtype); - std::tie(te_O, py_O) = quantizer_helper(o_quantizer, o_shape, fake_dtype_te, true, std::nullopt); + std::tie(te_O, py_O) = + quantizer_helper(o_quantizer, o_shape, fake_dtype_te, true, std::nullopt, std::nullopt); // construct NVTE tensors TensorWrapper te_Bias; @@ -317,7 +326,7 @@ std::vector fused_attn_bwd( const py::handle O, const py::handle dO, const at::ScalarType fake_dtype, const DType dqkv_type, const std::vector Aux_CTX_Tensors, const std::optional cu_seqlens_q_padded, - const std::optional cu_seqlens_kv_padded, py::handle s_quantizer, + const std::optional cu_seqlens_kv_padded, py::handle s_quantizer, py::handle dp_quantizer, py::handle dqkv_quantizer, bool cuda_graph) { auto none = py::none(); @@ -332,9 +341,10 @@ std::vector fused_attn_bwd( // create S and dP tensors TensorWrapper te_S, te_dP; py::object py_S, py_dP; - std::tie(te_S, py_S) = quantizer_helper(s_quantizer, {0}, DType::kFloat32, false, std::nullopt); + std::tie(te_S, py_S) = + quantizer_helper(s_quantizer, {0}, DType::kFloat32, false, std::nullopt, std::nullopt); std::tie(te_dP, py_dP) = - quantizer_helper(dp_quantizer, {0}, DType::kFloat32, false, std::nullopt); + quantizer_helper(dp_quantizer, {0}, DType::kFloat32, false, std::nullopt, std::nullopt); // create dQ, dK, dV tensors TensorWrapper te_dQ, te_dK, te_dV; @@ -431,9 +441,12 @@ std::vector fused_attn_bwd( NVTE_ERROR("QKV layout not supported!"); } - std::tie(te_dQ, py_dQ) = quantizer_helper(dqkv_quantizer, q_shape, fake_dtype_te, true, dQ); - std::tie(te_dK, py_dK) = quantizer_helper(dqkv_quantizer, k_shape, fake_dtype_te, true, dK); - std::tie(te_dV, py_dV) = quantizer_helper(dqkv_quantizer, v_shape, fake_dtype_te, true, dV); + std::tie(te_dQ, py_dQ) = + quantizer_helper(dqkv_quantizer, q_shape, fake_dtype_te, true, dQ, std::nullopt); + std::tie(te_dK, py_dK) = + quantizer_helper(dqkv_quantizer, k_shape, fake_dtype_te, true, dK, std::nullopt); + std::tie(te_dV, py_dV) = + quantizer_helper(dqkv_quantizer, v_shape, fake_dtype_te, true, dV, std::nullopt); // construct NVTE tensors if (dqkv_type == DType::kFloat8E4M3 || dqkv_type == DType::kFloat8E5M2) { diff --git a/transformer_engine/pytorch/csrc/extensions/bias.cpp b/transformer_engine/pytorch/csrc/extensions/bias.cpp index c59e3c4f64..6a3a3f5876 100644 --- a/transformer_engine/pytorch/csrc/extensions/bias.cpp +++ b/transformer_engine/pytorch/csrc/extensions/bias.cpp @@ -19,7 +19,8 @@ namespace transformer_engine { namespace pytorch { -std::vector bgrad_quantize(const at::Tensor &grad_output, py::handle quantizer) { +std::vector bgrad_quantize(const at::Tensor &grad_output, py::handle quantizer, + std::optional quantizer_workspace) { using namespace transformer_engine::pytorch::detail; init_extension(); @@ -109,7 +110,8 @@ std::vector dact_dbias( void (*dact_dbias_func)(const NVTETensor, const NVTETensor, NVTETensor, NVTETensor, NVTETensor, cudaStream_t), void (*dact_func)(const NVTETensor, const NVTETensor, NVTETensor, cudaStream_t), - at::Tensor grad_output_torch, at::Tensor act_input_torch, py::handle quantizer_py) { + at::Tensor grad_output_torch, at::Tensor act_input_torch, py::handle quantizer_py, + std::optional quantizer_workspace) { using namespace transformer_engine::pytorch::detail; init_extension(); @@ -208,14 +210,15 @@ std::vector dact_dbias( dynamic_cast(quantizer_cpp.get()); NVTE_CHECK(fp8_quantizer_cpp != nullptr, "Invalid quantizer for fused dact-amax kernel impl"); + auto [cs_amax, cs_scale] = split_quantizer_workspace(*quantizer_workspace); auto [temp_nvte, temp_py] = - fp8_quantizer_cpp->create_unquantized_tensor_with_amax(input_shape, grad_output_dtype); + fp8_quantizer_cpp->create_unquantized_tensor_with_amax(input_shape, grad_output_dtype, cs_amax); NVTE_SCOPED_GIL_RELEASE({ dact_func(grad_output_nvte.data(), act_input_nvte.data(), temp_nvte.data(), stream); }); const auto temp_torch = temp_py.cast(); at::sum_out(grad_bias_torch, temp_torch.reshape({-1, bias_size}), {0}); - fp8_quantizer_cpp->quantize_with_amax(temp_nvte, grad_input_nvte); + fp8_quantizer_cpp->quantize_with_amax(temp_nvte, grad_input_nvte, cs_amax, cs_scale); break; } case Impl::FUSED_DACT_AMAX_NVFP4: @@ -245,28 +248,28 @@ std::vector dact_dbias( } // namespace std::vector dbias_dgelu(const at::Tensor &grad_output, const at::Tensor &act_input, - py::handle quantizer) { - return dact_dbias(nvte_quantize_dbias_dgelu, nvte_dgelu, grad_output, act_input, quantizer); + py::handle quantizer, std::optional qw) { + return dact_dbias(nvte_quantize_dbias_dgelu, nvte_dgelu, grad_output, act_input, quantizer, qw); } std::vector dbias_dsilu(const at::Tensor &grad_output, const at::Tensor &act_input, - py::handle quantizer) { - return dact_dbias(nvte_quantize_dbias_dsilu, nvte_dsilu, grad_output, act_input, quantizer); + py::handle quantizer, std::optional qw) { + return dact_dbias(nvte_quantize_dbias_dsilu, nvte_dsilu, grad_output, act_input, quantizer, qw); } std::vector dbias_drelu(const at::Tensor &grad_output, const at::Tensor &act_input, - py::handle quantizer) { - return dact_dbias(nvte_quantize_dbias_drelu, nvte_drelu, grad_output, act_input, quantizer); + py::handle quantizer, std::optional qw) { + return dact_dbias(nvte_quantize_dbias_drelu, nvte_drelu, grad_output, act_input, quantizer, qw); } std::vector dbias_dqgelu(const at::Tensor &grad_output, const at::Tensor &act_input, - py::handle quantizer) { - return dact_dbias(nvte_quantize_dbias_dqgelu, nvte_dqgelu, grad_output, act_input, quantizer); + py::handle quantizer, std::optional qw) { + return dact_dbias(nvte_quantize_dbias_dqgelu, nvte_dqgelu, grad_output, act_input, quantizer, qw); } std::vector dbias_dsrelu(const at::Tensor &grad_output, const at::Tensor &act_input, - py::handle quantizer) { - return dact_dbias(nvte_quantize_dbias_dsrelu, nvte_dsrelu, grad_output, act_input, quantizer); + py::handle quantizer, std::optional qw) { + return dact_dbias(nvte_quantize_dbias_dsrelu, nvte_dsrelu, grad_output, act_input, quantizer, qw); } } // namespace pytorch diff --git a/transformer_engine/pytorch/csrc/extensions/cast.cpp b/transformer_engine/pytorch/csrc/extensions/cast.cpp index cb3434ec52..de36b99d54 100644 --- a/transformer_engine/pytorch/csrc/extensions/cast.cpp +++ b/transformer_engine/pytorch/csrc/extensions/cast.cpp @@ -34,25 +34,22 @@ std::vector get_tensor_shape(const TensorWrapper &tensor) { } // namespace py::object quantize(const at::Tensor &tensor, py::handle quantizer, const py::object &output, - std::optional noop_flag) { + std::optional noop_flag, std::optional workspace) { // Convert quantizer to C++ object auto quantizer_cpp = convert_quantizer(quantizer); + // Extract amax/scale from workspace for Float8CurrentScaling + const bool is_cs = detail::IsFloat8CurrentScalingQuantizers(quantizer.ptr()); + at::Tensor amax, scale; + if (is_cs) { + NVTE_CHECK(workspace.has_value(), "Float8CurrentScalingQuantizer requires workspace"); + std::tie(amax, scale) = split_quantizer_workspace(*workspace); + } + // Convert input tensor to C++ object auto input_contiguous = tensor.contiguous(); auto input_cpp = makeTransformerEngineTensor(input_contiguous); - // Set amax if use_existing_amax = true (only valid for CS) - bool use_existing_amax = false; - if (detail::IsFloat8CurrentScalingQuantizers(quantizer.ptr())) { - use_existing_amax = quantizer.attr("use_existing_amax").cast(); - if (use_existing_amax) { - const at::Tensor &amax = quantizer.attr("amax").cast(); - input_cpp.set_amax(amax.data_ptr(), GetTransformerEngineDType(amax.scalar_type()), - getTensorShape(amax)); - } - } - // Initialize output tensor TensorWrapper output_cpp; py::object output_py; @@ -71,9 +68,9 @@ py::object quantize(const at::Tensor &tensor, py::handle quantizer, const py::ob } // Perform quantization - if (use_existing_amax) { + if (is_cs) { auto *quantizer_cs = dynamic_cast(quantizer_cpp.get()); - quantizer_cs->quantize_with_amax(input_cpp, output_cpp, noop_flag_cpp); + quantizer_cs->quantize(input_cpp, output_cpp, amax, scale, noop_flag_cpp); } else { quantizer_cpp->quantize(input_cpp, output_cpp, noop_flag_cpp); } @@ -257,7 +254,8 @@ namespace { void multi_tensor_quantize_impl(const std::vector &input_list, std::vector &quantizer_py_list, std::vector> &quantizer_cpp_list, - std::vector &output_list) { + std::vector &output_list, + const std::optional> &workspaces = std::nullopt) { // Check number of tensors const size_t num_tensors = input_list.size(); NVTE_CHECK(quantizer_py_list.size() == num_tensors, "Expected ", num_tensors, @@ -298,7 +296,14 @@ void multi_tensor_quantize_impl(const std::vector &input_list, } else { // Quantize kernels individually for (size_t i = 0; i < num_tensors; ++i) { - quantizer_cpp_list[i]->quantize(input_list[i], output_list[i]); + if (workspaces.has_value() && + detail::IsFloat8CurrentScalingQuantizers(quantizer_py_list[i].ptr())) { + auto *quantizer_cs = dynamic_cast(quantizer_cpp_list[i].get()); + auto [amax, scale] = split_quantizer_workspace((*workspaces)[i]); + quantizer_cs->quantize(input_list[i], output_list[i], amax, scale); + } else { + quantizer_cpp_list[i]->quantize(input_list[i], output_list[i]); + } } } } @@ -1258,7 +1263,8 @@ void split_quantize_nvfp4_impl(const TensorWrapper &input, std::vector split_quantize(const at::Tensor &tensor, const std::vector &split_sections, std::vector quantizer_list, - bool disable_bulk_allocation) { + bool disable_bulk_allocation, + std::optional> quantizer_workspaces) { init_extension(); // Check number of tensors @@ -1405,7 +1411,8 @@ std::vector split_quantize(const at::Tensor &tensor, } default: // General multi-tensor quantization - multi_tensor_quantize_impl(input_list, quantizer_list, quantizer_cpp_list, output_cpp_list); + multi_tensor_quantize_impl(input_list, quantizer_list, quantizer_cpp_list, output_cpp_list, + quantizer_workspaces); } return output_py_list; diff --git a/transformer_engine/pytorch/csrc/extensions/normalization.cpp b/transformer_engine/pytorch/csrc/extensions/normalization.cpp index 3214c3a9db..4519c6d27c 100644 --- a/transformer_engine/pytorch/csrc/extensions/normalization.cpp +++ b/transformer_engine/pytorch/csrc/extensions/normalization.cpp @@ -61,7 +61,8 @@ std::vector layernorm_bwd(const at::Tensor &dz, const at::Tensor &x, std::vector layernorm_fwd(py::handle input, py::handle weight, MaybeTensor bias, float eps, py::object out, py::handle quantizer, DType out_dtype, const int sm_margin, - const bool zero_centered_gamma) { + const bool zero_centered_gamma, + std::optional quantizer_workspace) { using namespace transformer_engine::pytorch::detail; // Ensure that cuDNN handle is created on the correct device, @@ -154,8 +155,9 @@ std::vector layernorm_fwd(py::handle input, py::handle weight, Maybe } break; case Impl::FUSED_NORM_AMAX_FP8: { auto fp8_quantizer_cpp = static_cast(quantizer_cpp.get()); + auto [cs_amax, cs_scale] = split_quantizer_workspace(*quantizer_workspace); std::tie(unquantized_out_nvte, unquantized_out) = - fp8_quantizer_cpp->create_unquantized_tensor_with_amax(shape, out_dtype); + fp8_quantizer_cpp->create_unquantized_tensor_with_amax(shape, out_dtype, cs_amax); kernel_out_nvte = &unquantized_out_nvte; } break; case Impl::FUSED_NORM_AMAX_NVFP4: { @@ -199,7 +201,8 @@ std::vector layernorm_fwd(py::handle input, py::handle weight, Maybe } break; case Impl::FUSED_NORM_AMAX_FP8: { auto fp8_quantizer_cpp = static_cast(quantizer_cpp.get()); - fp8_quantizer_cpp->quantize_with_amax(unquantized_out_nvte, out_nvte); + auto [cs_amax, cs_scale] = split_quantizer_workspace(*quantizer_workspace); + fp8_quantizer_cpp->quantize_with_amax(unquantized_out_nvte, out_nvte, cs_amax, cs_scale); } break; case Impl::FUSED_NORM_AMAX_NVFP4: { auto nvfp4_quantizer_cpp = static_cast(quantizer_cpp.get()); @@ -303,7 +306,8 @@ std::vector rmsnorm_bwd_add(const at::Tensor &dz, const at::Tensor & std::vector rmsnorm_fwd(const py::handle &input, const py::handle &weight, float eps, py::object out, py::handle quantizer, DType out_dtype, - const int sm_margin, const bool zero_centered_gamma) { + const int sm_margin, const bool zero_centered_gamma, + std::optional quantizer_workspace) { using namespace transformer_engine::pytorch::detail; // Ensure that cuDNN handle is created on the correct device, @@ -390,8 +394,9 @@ std::vector rmsnorm_fwd(const py::handle &input, const py::handle &w } break; case Impl::FUSED_NORM_AMAX_FP8: { auto fp8_quantizer_cpp = static_cast(quantizer_cpp.get()); + auto [cs_amax, cs_scale] = split_quantizer_workspace(*quantizer_workspace); std::tie(unquantized_out_nvte, unquantized_out) = - fp8_quantizer_cpp->create_unquantized_tensor_with_amax(shape, out_dtype); + fp8_quantizer_cpp->create_unquantized_tensor_with_amax(shape, out_dtype, cs_amax); kernel_out_nvte = &unquantized_out_nvte; } break; case Impl::FUSED_NORM_AMAX_NVFP4: { @@ -433,7 +438,8 @@ std::vector rmsnorm_fwd(const py::handle &input, const py::handle &w } break; case Impl::FUSED_NORM_AMAX_FP8: { auto fp8_quantizer_cpp = static_cast(quantizer_cpp.get()); - fp8_quantizer_cpp->quantize_with_amax(unquantized_out_nvte, out_nvte); + auto [cs_amax, cs_scale] = split_quantizer_workspace(*quantizer_workspace); + fp8_quantizer_cpp->quantize_with_amax(unquantized_out_nvte, out_nvte, cs_amax, cs_scale); } break; case Impl::FUSED_NORM_AMAX_NVFP4: { auto nvfp4_quantizer_cpp = static_cast(quantizer_cpp.get()); diff --git a/transformer_engine/pytorch/csrc/extensions/pybind.cpp b/transformer_engine/pytorch/csrc/extensions/pybind.cpp index c590a3c9e2..d348f4bdd1 100644 --- a/transformer_engine/pytorch/csrc/extensions/pybind.cpp +++ b/transformer_engine/pytorch/csrc/extensions/pybind.cpp @@ -136,13 +136,15 @@ void init_extension() { PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { NVTE_DECLARE_COMMON_PYBIND11_HANDLES(m) m.def("quantize", transformer_engine::pytorch::quantize, py::arg("tensor"), py::arg("quantizer"), - py::arg("output") = py::none(), py::arg("noop") = py::none()); + py::arg("output") = py::none(), py::arg("noop") = py::none(), + py::arg("workspace") = py::none()); m.def("dequantize", &transformer_engine::pytorch::dequantize, "Dequantize", py::arg("input"), py::arg("otype")); m.def("group_quantize", transformer_engine::pytorch::group_quantize, py::arg("tensor"), py::arg("quantizer"), py::arg("num_tensors"), py::arg("first_dims")); m.def("bgrad_quantize", transformer_engine::pytorch::bgrad_quantize, - "Compute bias gradient and quantize", py::arg("input"), py::arg("quantizer")); + "Compute bias gradient and quantize", py::arg("input"), py::arg("quantizer"), + py::arg("quantizer_workspace") = py::none()); m.def("generic_gemm", transformer_engine::pytorch::gemm, "Compute GEMM (matrix-matrix multiply)", py::arg("A"), py::arg("transA"), py::arg("B"), py::arg("transB"), py::arg("D"), py::arg("quantizer"), py::arg("output_dtype"), py::arg("bias"), py::arg("bias_type"), @@ -153,74 +155,81 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { py::arg("alpha") = 1.0f, py::arg("beta") = std::nullopt); /* GLU (sigmoid gate) */ m.def("glu", transformer_engine::pytorch::glu, "GLU activation", py::arg("input"), - py::arg("quantizer")); + py::arg("quantizer"), py::arg("quantizer_workspace") = py::none()); /* GELU and variants*/ m.def("gelu", transformer_engine::pytorch::gelu, "GeLU activation", py::arg("input"), - py::arg("quantizer")); + py::arg("quantizer"), py::arg("quantizer_workspace") = py::none()); m.def("geglu", transformer_engine::pytorch::geglu, "GeGLU activation", py::arg("input"), - py::arg("quantizer")); + py::arg("quantizer"), py::arg("quantizer_workspace") = py::none()); m.def("qgelu", transformer_engine::pytorch::qgelu, "QuickGELU activation", py::arg("input"), - py::arg("quantizer")); + py::arg("quantizer"), py::arg("quantizer_workspace") = py::none()); m.def("qgeglu", transformer_engine::pytorch::qgeglu, "QuickGeGLU activation", py::arg("input"), - py::arg("quantizer")); + py::arg("quantizer"), py::arg("quantizer_workspace") = py::none()); /* ReLU and variants */ m.def("relu", transformer_engine::pytorch::relu, "ReLU activation", py::arg("input"), - py::arg("quantizer")); + py::arg("quantizer"), py::arg("quantizer_workspace") = py::none()); m.def("reglu", transformer_engine::pytorch::reglu, "ReGLU activation", py::arg("input"), - py::arg("quantizer")); + py::arg("quantizer"), py::arg("quantizer_workspace") = py::none()); m.def("srelu", transformer_engine::pytorch::srelu, "Squared ReLU activation", py::arg("input"), - py::arg("quantizer")); + py::arg("quantizer"), py::arg("quantizer_workspace") = py::none()); m.def("sreglu", transformer_engine::pytorch::sreglu, "Squared ReGLU activation", py::arg("input"), - py::arg("quantizer")); + py::arg("quantizer"), py::arg("quantizer_workspace") = py::none()); /* SwiGLU and variants */ m.def("silu", transformer_engine::pytorch::silu, "SiLU activation", py::arg("input"), - py::arg("quantizer")); + py::arg("quantizer"), py::arg("quantizer_workspace") = py::none()); m.def("swiglu", transformer_engine::pytorch::swiglu, "SwiGLU activation", py::arg("input"), - py::arg("quantizer")); + py::arg("quantizer"), py::arg("quantizer_workspace") = py::none()); m.def("clamped_swiglu", transformer_engine::pytorch::clamped_swiglu, "SwiGLU activation used in GPT OSS", py::arg("input"), py::arg("quantizer"), + py::arg("quantizer_workspace") = py::none(), py::arg("limit") = 7.0f, py::arg("alpha") = 1.702f); /* Backward of GLU */ m.def("dglu", transformer_engine::pytorch::dglu, "Backward of GLU", py::arg("grad"), - py::arg("fwd_input"), py::arg("quantizer")); + py::arg("fwd_input"), py::arg("quantizer"), py::arg("quantizer_workspace") = py::none()); /* Backward of GELU and variants */ m.def("dgelu", transformer_engine::pytorch::dgelu, "Backward of GeLU", py::arg("grad"), - py::arg("fwd_input"), py::arg("quantizer")); + py::arg("fwd_input"), py::arg("quantizer"), py::arg("quantizer_workspace") = py::none()); m.def("dgeglu", transformer_engine::pytorch::dgeglu, "Backward of GeGLU", py::arg("grad"), - py::arg("fwd_input"), py::arg("quantizer")); + py::arg("fwd_input"), py::arg("quantizer"), py::arg("quantizer_workspace") = py::none()); m.def("dqgelu", transformer_engine::pytorch::dqgelu, "Backward of QuickGELU", py::arg("grad"), - py::arg("fwd_input"), py::arg("quantizer")); + py::arg("fwd_input"), py::arg("quantizer"), py::arg("quantizer_workspace") = py::none()); m.def("dqgeglu", transformer_engine::pytorch::dqgeglu, "Backward of QuickGeGLU", py::arg("grad"), - py::arg("fwd_input"), py::arg("quantizer")); + py::arg("fwd_input"), py::arg("quantizer"), py::arg("quantizer_workspace") = py::none()); /* Backward of ReLU and variants */ m.def("drelu", transformer_engine::pytorch::drelu, "Backward of ReLU", py::arg("grad"), - py::arg("fwd_input"), py::arg("quantizer")); + py::arg("fwd_input"), py::arg("quantizer"), py::arg("quantizer_workspace") = py::none()); m.def("dreglu", transformer_engine::pytorch::dreglu, "Backward of ReGLU", py::arg("grad"), - py::arg("fwd_input"), py::arg("quantizer")); + py::arg("fwd_input"), py::arg("quantizer"), py::arg("quantizer_workspace") = py::none()); m.def("dsrelu", transformer_engine::pytorch::dsrelu, "Backward of Squared ReLU", py::arg("grad"), - py::arg("fwd_input"), py::arg("quantizer")); + py::arg("fwd_input"), py::arg("quantizer"), py::arg("quantizer_workspace") = py::none()); m.def("dsreglu", transformer_engine::pytorch::dsreglu, "Backward of Squared ReGLU", - py::arg("grad"), py::arg("fwd_input"), py::arg("quantizer")); + py::arg("grad"), py::arg("fwd_input"), py::arg("quantizer"), + py::arg("quantizer_workspace") = py::none()); /* Backward of SiLU and variants */ m.def("dsilu", transformer_engine::pytorch::dsilu, "Backward of SiLU", py::arg("grad"), - py::arg("fwd_input"), py::arg("quantizer")); + py::arg("fwd_input"), py::arg("quantizer"), py::arg("quantizer_workspace") = py::none()); m.def("dswiglu", transformer_engine::pytorch::dswiglu, "Backward of SwiGLU", py::arg("grad"), - py::arg("fwd_input"), py::arg("quantizer")); + py::arg("fwd_input"), py::arg("quantizer"), py::arg("quantizer_workspace") = py::none()); m.def("clamped_dswiglu", transformer_engine::pytorch::clamped_dswiglu, "Backward of SwiGLU used in GPT OSS", py::arg("grad"), py::arg("fwd_input"), - py::arg("quantizer"), py::arg("limit") = 7.0f, py::arg("alpha") = 1.702f); + py::arg("quantizer"), py::arg("quantizer_workspace") = py::none(), + py::arg("limit") = 7.0f, py::arg("alpha") = 1.702f); /* DBias + DAct fusions*/ m.def("dbias_dgelu", transformer_engine::pytorch::dbias_dgelu, "DGeLU + DBias + Quantize", - py::arg("grad"), py::arg("fwd_input"), py::arg("quantizer")); + py::arg("grad"), py::arg("fwd_input"), py::arg("quantizer"), + py::arg("quantizer_workspace") = py::none()); m.def("dbias_dsilu", transformer_engine::pytorch::dbias_dsilu, "DSiLU + DBias + Quantize", - py::arg("grad"), py::arg("fwd_input"), py::arg("quantizer")); + py::arg("grad"), py::arg("fwd_input"), py::arg("quantizer"), + py::arg("quantizer_workspace") = py::none()); m.def("dbias_drelu", transformer_engine::pytorch::dbias_drelu, "DReLU + DBias + Quantize", - py::arg("grad"), py::arg("fwd_input"), py::arg("quantizer")); + py::arg("grad"), py::arg("fwd_input"), py::arg("quantizer"), + py::arg("quantizer_workspace") = py::none()); m.def("dbias_dqgelu", transformer_engine::pytorch::dbias_dqgelu, "DQGeLU + DBias + Quantize", - py::arg("grad"), py::arg("fwd_input"), py::arg("quantizer")); + py::arg("grad"), py::arg("fwd_input"), py::arg("quantizer"), + py::arg("quantizer_workspace") = py::none()); m.def("dbias_dsrelu", transformer_engine::pytorch::dbias_dsrelu, "DSquaredReLU + DBias + Quantize", py::arg("grad"), py::arg("fwd_input"), - py::arg("quantizer")); + py::arg("quantizer"), py::arg("quantizer_workspace") = py::none()); // Permutation functions m.def("moe_permute_fwd", transformer_engine::pytorch::moe_permute_fwd, "MOE permute FWD", @@ -261,11 +270,13 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { // Other granular functions m.def("layernorm_fwd", &transformer_engine::pytorch::layernorm_fwd, "LayerNorm", py::arg("input"), py::arg("weight"), py::arg("bias"), py::arg("eps"), py::arg("ln_out"), py::arg("quantizer"), - py::arg("otype"), py::arg("sm_margin"), py::arg("zero_centered_gamma")); + py::arg("otype"), py::arg("sm_margin"), py::arg("zero_centered_gamma"), + py::arg("quantizer_workspace") = py::none()); m.def("layernorm_bwd", &transformer_engine::pytorch::layernorm_bwd, "Backward of LayerNorm"); m.def("rmsnorm_fwd", &transformer_engine::pytorch::rmsnorm_fwd, "RMSNorm", py::arg("input"), py::arg("weight"), py::arg("eps"), py::arg("ln_out"), py::arg("quantizer"), - py::arg("otype"), py::arg("sm_margin"), py::arg("zero_centered_gamma")); + py::arg("otype"), py::arg("sm_margin"), py::arg("zero_centered_gamma"), + py::arg("quantizer_workspace") = py::none()); m.def("rmsnorm_bwd", &transformer_engine::pytorch::rmsnorm_bwd, "Backward of RMSNorm"); m.def("rmsnorm_bwd_add", &transformer_engine::pytorch::rmsnorm_bwd_add, "Fused backward of RMSNorm + add"); @@ -273,7 +284,8 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { "Multi-tensor quantize", py::arg("tensor_list"), py::arg("quantizer_list")); m.def("split_quantize", &transformer_engine::pytorch::split_quantize, "Split and multi-tensor quantize", py::arg("tensor"), py::arg("split_sections"), - py::arg("quantizer_list"), py::arg("disable_bulk_allocation") = false); + py::arg("quantizer_list"), py::arg("disable_bulk_allocation") = false, + py::arg("quantizer_workspaces") = py::none()); m.def("te_general_grouped_gemm", &transformer_engine::pytorch::te_general_grouped_gemm, "Grouped GEMM"); m.def("te_general_grouped_gemm_for_grouped_tensor", diff --git a/transformer_engine/pytorch/csrc/quantizer.cpp b/transformer_engine/pytorch/csrc/quantizer.cpp index b59f3fa3c5..cb9c9bb99e 100644 --- a/transformer_engine/pytorch/csrc/quantizer.cpp +++ b/transformer_engine/pytorch/csrc/quantizer.cpp @@ -532,11 +532,7 @@ void Float8Quantizer::quantize(const TensorWrapper& input, TensorWrapper& out, Float8CurrentScalingQuantizer::Float8CurrentScalingQuantizer(const py::handle& quantizer) : Quantizer(quantizer) { - const at::Tensor& scale = quantizer.attr("scale").cast(); - const at::Tensor& amax = quantizer.attr("amax").cast(); const DType type = quantizer.attr("dtype").cast(); - this->amax = amax; - this->scale = scale; this->dtype = type; // Get amax reduction group if needed @@ -557,12 +553,8 @@ Float8CurrentScalingQuantizer::Float8CurrentScalingQuantizer(const py::handle& q } void Float8CurrentScalingQuantizer::set_quantization_params(TensorWrapper* tensor) const { - // transfer amax and scale pointer from quantizer to output tensor (only as gpu buffer, no meaningful data in them) - tensor->set_scale(scale.data_ptr(), GetTransformerEngineDType(scale.scalar_type()), - getTensorShape(scale)); - at::TensorOptions opts = opts.dtype(torch::kFloat32).device(torch::kCUDA); - tensor->set_amax(amax.data_ptr(), GetTransformerEngineDType(amax.scalar_type()), - getTensorShape(amax)); + // No-op: amax and scale buffers are set directly by quantize/quantize_with_amax + // from externally provided workspace tensors. } std::pair Float8CurrentScalingQuantizer::create_tensor( @@ -751,6 +743,7 @@ std::pair Float8CurrentScalingQuantizer::creat std::pair Float8CurrentScalingQuantizer::create_unquantized_tensor_with_amax(const std::vector& shape, DType dtype, + at::Tensor amax, std::optional data) { amax.zero_(); auto out = data.has_value() ? NoneQuantizer(py::none()).create_tensor(shape, dtype, data.value()) @@ -856,7 +849,8 @@ std::pair Float8CurrentScalingQuantizer::convert_and_ void Float8CurrentScalingQuantizer::quantize_impl(const TensorWrapper& input, TensorWrapper& out, const std::optional& noop_flag, - bool compute_amax) { + bool compute_amax, + at::Tensor amax, at::Tensor scale) { auto stream = at::cuda::getCurrentCUDAStream(); // Nothing to be done if input is empty @@ -864,6 +858,12 @@ void Float8CurrentScalingQuantizer::quantize_impl(const TensorWrapper& input, Te return; } + // Set amax and scale buffers on the output tensor + out.set_amax(amax.data_ptr(), GetTransformerEngineDType(amax.scalar_type()), + getTensorShape(amax)); + out.set_scale(scale.data_ptr(), GetTransformerEngineDType(scale.scalar_type()), + getTensorShape(scale)); + // Quantization configs QuantizationConfigWrapper quant_config; if (noop_flag) { @@ -897,15 +897,26 @@ void Float8CurrentScalingQuantizer::quantize_impl(const TensorWrapper& input, Te void Float8CurrentScalingQuantizer::quantize(const TensorWrapper& input, TensorWrapper& out, const std::optional& noop_flag) { - this->quantize_impl(input, out, noop_flag, true); + const auto opts = at::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA); + auto amax = at::empty({1}, opts); + auto scale = at::empty({1}, opts); + this->quantize_impl(input, out, noop_flag, true, amax, scale); +} + +void Float8CurrentScalingQuantizer::quantize(const TensorWrapper& input, TensorWrapper& out, + at::Tensor amax, at::Tensor scale, + const std::optional& noop_flag) { + this->quantize_impl(input, out, noop_flag, true, amax, scale); } void Float8CurrentScalingQuantizer::quantize_with_amax( - TensorWrapper& input, TensorWrapper& out, const std::optional& noop_flag) { + TensorWrapper& input, TensorWrapper& out, + at::Tensor amax, at::Tensor scale, + const std::optional& noop_flag) { NVTE_CHECK(input.get_amax().data_ptr == amax.data_ptr(), "Input does not use the appropriate amax tensor"); input.set_amax(nullptr, DType::kFloat32, input.defaultShape); - this->quantize_impl(input, out, noop_flag, false); + this->quantize_impl(input, out, noop_flag, false, amax, scale); } Float8BlockQuantizer::Float8BlockQuantizer(const py::handle& quantizer) : Quantizer(quantizer) { diff --git a/transformer_engine/pytorch/module/_common.py b/transformer_engine/pytorch/module/_common.py index bf5a230e84..443fd8f46b 100644 --- a/transformer_engine/pytorch/module/_common.py +++ b/transformer_engine/pytorch/module/_common.py @@ -42,6 +42,7 @@ def apply_normalization( normalization: str, fwd_ln_sm_margin: int, zero_centered_gamma: bool, + quantizer_workspace=None, ): """Apply normalization to input.""" normalization_func = _get_normalization_func(normalization, True) @@ -56,6 +57,7 @@ def apply_normalization( TE_DType[output_dtype] if output_dtype in TE_DType else output_dtype, fwd_ln_sm_margin, zero_centered_gamma, + quantizer_workspace, ) diff --git a/transformer_engine/pytorch/module/base.py b/transformer_engine/pytorch/module/base.py index 28da4873f0..cd7500733e 100644 --- a/transformer_engine/pytorch/module/base.py +++ b/transformer_engine/pytorch/module/base.py @@ -661,6 +661,7 @@ def __init__(self, name: Optional[str] = None) -> None: self.fsdp_wrapped = False self.fsdp_group = None self._fp8_workspaces: Dict[str, QuantizedTensor] = {} + self._quantizer_workspace: Optional[torch.Tensor] = None self.activation_dtype: Optional[torch.dtype] = None self.wgrad_accumulation_and_reduce_hooks = [] self.wgrad_store = None @@ -807,8 +808,25 @@ def init_fp8_meta_tensors(self, recipe: Recipe) -> None: self.set_meta_tensor(True, recipe) self.set_meta_tensor(False, recipe) + self._init_quantizer_workspace(recipe) + self.fast_setattr("fp8_meta_tensors_initialized", True) + def _init_quantizer_workspace(self, recipe: Recipe) -> None: + """Allocate shared workspace buffer for stateless quantizers. + + A single 2-float32 buffer ``[amax, scale]`` is reused by all + Float8CurrentScaling quantizers because quantization ops run + sequentially on the same CUDA stream. + """ + if not recipe.float8_current_scaling(): + return + + if self._quantizer_workspace is None or self._quantizer_workspace.numel() < 2: + self._quantizer_workspace = torch.zeros( + 2, dtype=torch.float32, device="cuda", + ) + def get_fp8_meta_tensors(self) -> None: """Get scales and amaxes.""" fwd_key, bwd_key = "scaling_fwd", "scaling_bwd" @@ -1068,6 +1086,7 @@ def init_fp8_metadata(self, num_gemms: int = 1) -> None: ) # Clear cached workspaces as they were created with the old recipe/quantizer type self._fp8_workspaces.clear() + self._quantizer_workspace = None def prepare_forward( self, @@ -1174,6 +1193,7 @@ def grad_output_preprocess( grad_output: torch.Tensor, row_parallel_mode: bool, quantizer: Optional[Quantizer], + quantizer_workspace: Optional[torch.Tensor] = None, ) -> Tuple[Union[torch.Tensor, None], ...]: """Utility function for backward. Returns tuple in order (all optional/None based on training precion/recipe): @@ -1216,7 +1236,7 @@ def grad_output_preprocess( Float8BlockwiseQTensorStorage, ), ): - grad_output = quantizer(grad_output) + grad_output = quantizer(grad_output, workspace=quantizer_workspace) # Copy into communication buffer, and replace original gradient with it grad_output, _ = fill_userbuffers_buffer_for_all_gather( @@ -1236,7 +1256,7 @@ def grad_output_preprocess( # Debug without all-gather: unfused cast and bgrad # bgrad only if wgrad is in FP8, otherwise it is fused with wgrad and we return None if ctx.debug: - grad_output_ = quantizer(grad_output) + grad_output_ = quantizer(grad_output, workspace=quantizer_workspace) if ctx.use_bias: grad_bias = grad_output.view(-1, grad_output.shape[-1]).sum(dim=0) else: @@ -1259,12 +1279,14 @@ def grad_output_preprocess( grad_bias = grad_output.dequantize().view(-1, grad_output.shape[-1]).sum(dim=0) else: if isinstance(quantizer, Float8BlockQuantizer): - # unfuse bgrad for now until cast_transpose + dgrad calculation is ready for Float8BlockQuantizer. grad_bias = grad_output.view(-1, grad_output.shape[-1]).sum(dim=0) else: - grad_bias, grad_output = tex.bgrad_quantize(grad_output, quantizer) + grad_bias, grad_output = tex.bgrad_quantize( + grad_output, quantizer, + quantizer_workspace=quantizer_workspace, + ) if not isinstance(grad_output, QuantizedTensorStorage): - grad_output = quantizer(grad_output) + grad_output = quantizer(grad_output, workspace=quantizer_workspace) return grad_output, grad_bias def register_parameter(self, name, param, **kwargs): @@ -1402,6 +1424,7 @@ def get_weight_workspace( skip_update_flag: Optional[torch.Tensor] = None, fsdp_group: Optional[dist_group_type] = None, workspace_dtype: Optional[torch.dtype] = None, + quantizer_workspace: Optional[torch.Tensor] = None, ) -> QuantizedTensor: """Get workspace buffer for weights and maybe update its values @@ -1500,7 +1523,7 @@ def get_weight_workspace( # Setting internal=True would cause the data to be removed in prepare_for_saving(...). quantizer_internal = quantizer.internal quantizer.internal = False - out = quantizer.quantize(tensor, dtype=workspace_dtype) + out = quantizer.quantize(tensor, dtype=workspace_dtype, workspace=quantizer_workspace) if cache_name is not None: quantizer.internal = quantizer_internal @@ -1518,7 +1541,7 @@ def get_weight_workspace( if hasattr(out, "quantize_"): out.quantize_(tensor, noop_flag=skip_update_flag) else: - tex.quantize(tensor, quantizer, out, skip_update_flag) + tex.quantize(tensor, quantizer, out, skip_update_flag, quantizer_workspace) return out def _load_from_state_dict( diff --git a/transformer_engine/pytorch/module/grouped_linear.py b/transformer_engine/pytorch/module/grouped_linear.py index 0adda48e36..cf4267b22c 100644 --- a/transformer_engine/pytorch/module/grouped_linear.py +++ b/transformer_engine/pytorch/module/grouped_linear.py @@ -96,6 +96,7 @@ def forward( skip_fp8_weight_update, save_original_input, debug, + quantizer_workspaces, ) = non_tensor_args num_gemms = len(m_splits) @@ -153,6 +154,7 @@ def forward( m_splits, input_quantizers, disable_bulk_allocation=cpu_offloading, + quantizer_workspaces=quantizer_workspaces, ) elif debug: inputmats = DebugQuantizer.multi_tensor_quantize( @@ -178,6 +180,7 @@ def forward( update_workspace=update_workspace, skip_update_flag=skip_fp8_weight_update, workspace_dtype=activation_dtype, + quantizer_workspace=quantizer_workspaces[i] if quantizer_workspaces is not None else None, ) weights_fp8.append(weight_fp8) @@ -308,6 +311,7 @@ def forward( ctx.debug = debug ctx.save_original_input = save_original_input ctx.input_quantizers = input_quantizers + ctx.quantizer_workspaces = quantizer_workspaces # [*, in_features] -> [*, out_features] except first dimension changes for SP return out.view(-1, *inp.shape[1:-1], out.shape[-1]) @@ -348,6 +352,7 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], grad_biases[i], grad_output[i] = tex.bgrad_quantize( grad_output_mats[i], ctx.grad_output_quantizers[i], + quantizer_workspace=ctx.quantizer_workspaces[i] if ctx.quantizer_workspaces is not None else None, ) else: # Unfused bias grad and multi-tensor quantize @@ -357,6 +362,7 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], grad_output_view, ctx.m_splits, ctx.grad_output_quantizers, + quantizer_workspaces=ctx.quantizer_workspaces, ) else: # Multi-tensor quantize @@ -364,6 +370,7 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], grad_output_view, ctx.m_splits, ctx.grad_output_quantizers, + quantizer_workspaces=ctx.quantizer_workspaces, ) elif ctx.debug: grad_output_mats = torch.split(grad_output_view, ctx.m_splits) @@ -452,7 +459,8 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], input_quantizer.set_usage(rowwise=False, columnwise=True) inputmats: list if ctx.fp8 and not ctx.debug: - inputmats = tex.split_quantize(inp_view, ctx.m_splits, ctx.input_quantizers) + inputmats = tex.split_quantize(inp_view, ctx.m_splits, ctx.input_quantizers, + quantizer_workspaces=ctx.quantizer_workspaces) elif ctx.debug: inputmats = DebugQuantizer.multi_tensor_quantize( inp_view, @@ -634,6 +642,7 @@ def __init__( super().__init__(name) self.params_dtype = torch.get_default_dtype() if params_dtype is None else params_dtype + self._quantizer_workspaces = None self.num_gemms = num_gemms self.in_features = in_features self.out_features = out_features @@ -741,6 +750,23 @@ def __init__( if name in (f"weight{i}", f"bias{i}"): param.skip_backward_post_hook = True + def _init_quantizer_workspace(self, recipe: Recipe) -> None: + """Allocate per-group workspace buffers for stateless quantizers. + + Each GEMM group needs its own 2-float32 [amax, scale] buffer + because groups may be quantized in parallel. + """ + if not recipe.float8_current_scaling(): + self._quantizer_workspaces = None + return + n = self.num_gemms + needed = n * 2 + if self._quantizer_workspace is None or self._quantizer_workspace.numel() < needed: + self._quantizer_workspace = torch.zeros(needed, dtype=torch.float32, device="cuda") + self._quantizer_workspaces = [ + self._quantizer_workspace[i * 2 : i * 2 + 2] for i in range(n) + ] + def set_meta_tensor(self, fwd: bool, recipe: Recipe) -> None: """Init scales and amaxes for fwd | bwd.""" super().set_meta_tensor(fwd, recipe) @@ -1009,6 +1035,7 @@ def forward( None, # skip_fp8_weight_update self.save_original_input, debug, + self._quantizer_workspaces, ) out = linear_fn(*autograd_ctx, inp, non_tensor_args, *weight_tensors, *bias_tensors) diff --git a/transformer_engine/pytorch/module/layernorm_linear.py b/transformer_engine/pytorch/module/layernorm_linear.py index ed91bc1235..fdc1ec8fad 100644 --- a/transformer_engine/pytorch/module/layernorm_linear.py +++ b/transformer_engine/pytorch/module/layernorm_linear.py @@ -113,6 +113,7 @@ def forward( grad_input_quantizer, grad_weight_quantizer, grad_output_quantizer, + quantizer_workspace, cpu_offloading, tp_group, tp_size, @@ -227,6 +228,7 @@ def forward( normalization, fwd_ln_sm_margin, zero_centered_gamma, + quantizer_workspace=quantizer_workspace, ) nvtx_range_pop(f"{nvtx_label}.norm") @@ -248,16 +250,16 @@ def forward( ln_out_total, _ = gather_along_first_dim(ln_out, tp_group) ln_out_return = ln_out_total if fp8 or debug: - ln_out = input_quantizer(ln_out) + ln_out = input_quantizer(ln_out, workspace=quantizer_workspace) input_quantizer.set_usage(rowwise=True, columnwise=False) - ln_out_total = input_quantizer(ln_out_total) + ln_out_total = input_quantizer(ln_out_total, workspace=quantizer_workspace) else: quantizer = None if fp8 or debug: quantizer = input_quantizer # custom recipe doesn't need to support quantized AG if not with_quantized_norm and not custom: - ln_out = quantizer(ln_out) + ln_out = quantizer(ln_out, workspace=quantizer_workspace) quantizer.set_usage(rowwise=True, columnwise=False) if ub_overlap_ag_fprop: # Initialize Userbuffers all-gather ln_out_total, _ = fill_userbuffers_buffer_for_all_gather( @@ -274,7 +276,7 @@ def forward( ) else: if (fp8 or debug) and not with_quantized_norm: - ln_out = input_quantizer(ln_out) + ln_out = input_quantizer(ln_out, workspace=quantizer_workspace) ln_out_total = ln_out nvtx_range_pop(f"{nvtx_label}.gemm_input_cast_comm") # ------------------------------------------------------ @@ -307,6 +309,7 @@ def forward( skip_update_flag=skip_fp8_weight_update, fsdp_group=fsdp_group, workspace_dtype=activation_dtype, + quantizer_workspace=quantizer_workspace, ) weightmat.update_usage(rowwise_usage=True) @@ -487,6 +490,7 @@ def forward( ctx.grad_weight_quantizer = grad_weight_quantizer ctx.grad_output_quantizer = grad_output_quantizer ctx.input_quantizer = input_quantizer + ctx.quantizer_workspace = quantizer_workspace ctx.owns_input = inputmat is not inp ctx.weight = weight ctx.activation_dtype = activation_dtype @@ -642,6 +646,7 @@ def backward( grad_outputs[0], ctx.parallel_mode == "row", ctx.grad_output_quantizer, + ctx.quantizer_workspace, ) nvtx_range_pop(f"{nvtx_label}.grad_output_preprocess") @@ -819,14 +824,14 @@ def backward( ln_out_total.update_usage(columnwise_usage=True) else: ctx.input_quantizer.set_usage(rowwise=False, columnwise=True) - ln_out_total = ctx.input_quantizer(ln_out_total) + ln_out_total = ctx.input_quantizer(ln_out_total, workspace=ctx.quantizer_workspace) if ctx.fp8 or ctx.debug: if isinstance(grad_output, QuantizedTensorStorage): grad_output.update_usage(columnwise_usage=True) else: ctx.grad_output_quantizer.set_usage(rowwise=False, columnwise=True) - grad_output = ctx.grad_output_quantizer(grad_output) + grad_output = ctx.grad_output_quantizer(grad_output, workspace=ctx.quantizer_workspace) # Figure out whether to use split accumulator use_split_accumulator = _2X_ACC_WGRAD @@ -1551,6 +1556,7 @@ def forward( grad_input_quantizer, grad_weight_quantizer, grad_output_quantizer, + self._quantizer_workspace, is_cpu_offload_enabled(), self.tp_group, self.tp_size, diff --git a/transformer_engine/pytorch/module/layernorm_mlp.py b/transformer_engine/pytorch/module/layernorm_mlp.py index cc3dcc4064..98434b6419 100644 --- a/transformer_engine/pytorch/module/layernorm_mlp.py +++ b/transformer_engine/pytorch/module/layernorm_mlp.py @@ -203,6 +203,7 @@ def _forward( fc2_grad_input_quantizer, fc2_grad_weight_quantizer, fc2_grad_output_quantizer, + quantizer_workspace, cpu_offloading, tp_group, tp_size, @@ -410,6 +411,7 @@ def _forward( normalization, fwd_ln_sm_margin, zero_centered_gamma, + quantizer_workspace=quantizer_workspace, ) ln_out_return = None @@ -430,16 +432,16 @@ def _forward( ln_out_total, _ = gather_along_first_dim(ln_out, tp_group) ln_out_return = ln_out_total if fp8 or debug: - ln_out = fc1_input_quantizer(ln_out) + ln_out = fc1_input_quantizer(ln_out, workspace=quantizer_workspace) fc1_input_quantizer.set_usage(rowwise=True, columnwise=False) - ln_out_total = fc1_input_quantizer(ln_out_total) + ln_out_total = fc1_input_quantizer(ln_out_total, workspace=quantizer_workspace) else: quantizer = None if fp8 or debug: quantizer = fc1_input_quantizer # custom recipe doesn't need to support quantized AG if not with_quantized_norm and not custom: - ln_out = fc1_input_quantizer(ln_out) + ln_out = fc1_input_quantizer(ln_out, workspace=quantizer_workspace) fc1_input_quantizer.set_usage(rowwise=True, columnwise=False) if ub_overlap_ag: # Copy into Userbuffers buffer @@ -459,7 +461,7 @@ def _forward( ) else: if (fp8 or debug) and not with_quantized_norm: - ln_out = fc1_input_quantizer(ln_out) + ln_out = fc1_input_quantizer(ln_out, workspace=quantizer_workspace) ln_out_total = ln_out # Cast weights to expected dtype @@ -490,6 +492,7 @@ def _forward( skip_update_flag=skip_fp8_weight_update, fsdp_group=fsdp_group, workspace_dtype=activation_dtype, + quantizer_workspace=quantizer_workspace, ) fc2_weight_final = module.get_weight_workspace( tensor=fc2_weight, @@ -499,6 +502,7 @@ def _forward( skip_update_flag=skip_fp8_weight_update, fsdp_group=fsdp_group, workspace_dtype=activation_dtype, + quantizer_workspace=quantizer_workspace, ) fc1_weight_final.update_usage(rowwise_usage=True) fc2_weight_final.update_usage(rowwise_usage=True) @@ -588,26 +592,32 @@ def _forward( elif debug: fc1_out, *_ = fc1_outputs act_out = activation_func(fc1_out, None, **act_params) - act_out = fc2_input_quantizer(act_out) + act_out = fc2_input_quantizer(act_out, workspace=quantizer_workspace) else: fc1_out, *_ = fc1_outputs if fp8: recipe = FP8GlobalStateManager.get_fp8_recipe() if recipe.float8_block_scaling(): - # tex.quantize does not support GELU fusion for blockwise act_out = activation_func(fc1_out, None, **act_params) act_out = tex.quantize(act_out, fc2_input_quantizer) elif recipe.custom(): - # tex.quantize does not support custom quantizers act_out = activation_func(fc1_out, None, **act_params) - act_out = fc2_input_quantizer(act_out) + act_out = fc2_input_quantizer(act_out, workspace=quantizer_workspace) else: - act_out = activation_func(fc1_out, fc2_input_quantizer, **act_params) + act_out = activation_func( + fc1_out, fc2_input_quantizer, + quantizer_workspace=quantizer_workspace, + **act_params, + ) else: if fp8_calibration: act_out = activation_func(fc1_out, None, **act_params) else: - act_out = activation_func(fc1_out, fc2_input_quantizer, **act_params) + act_out = activation_func( + fc1_out, fc2_input_quantizer, + quantizer_workspace=quantizer_workspace, + **act_params, + ) if not fp8 and fp8_calibration: if fc2_input_quantizer is not None: @@ -788,6 +798,7 @@ def _forward( ctx.fc2_grad_output_quantizer = fc2_grad_output_quantizer ctx.fc1_input_quantizer = fc1_input_quantizer ctx.fc2_input_quantizer = fc2_input_quantizer + ctx.quantizer_workspace = quantizer_workspace ctx.fc1_weight_requires_grad = fc1_weight.requires_grad ctx.fc2_weight_requires_grad = fc2_weight.requires_grad @@ -1046,7 +1057,8 @@ def backward( grad_output, fc2_bias_grad, ) = TransformerEngineBaseModule.grad_output_preprocess( - ctx, grad_outputs[0], True, ctx.fc2_grad_output_quantizer + ctx, grad_outputs[0], True, ctx.fc2_grad_output_quantizer, + ctx.quantizer_workspace, ) # Launch tensor-parallel communication for FC1 GEMM input @@ -1196,14 +1208,14 @@ def backward( act_out.update_usage(columnwise_usage=True) else: ctx.fc2_input_quantizer.set_usage(rowwise=False, columnwise=True) - act_out = ctx.fc2_input_quantizer(act_out) + act_out = ctx.fc2_input_quantizer(act_out, workspace=ctx.quantizer_workspace) if ctx.fp8 or ctx.debug: if isinstance(grad_output, QuantizedTensorStorage): grad_output.update_usage(columnwise_usage=True) else: ctx.fc2_grad_output_quantizer.set_usage(rowwise=False, columnwise=True) - grad_output = ctx.fc2_grad_output_quantizer(grad_output) + grad_output = ctx.fc2_grad_output_quantizer(grad_output, workspace=ctx.quantizer_workspace) # Whether to set grad arg in general_gemm grad_arg = True @@ -1283,12 +1295,12 @@ def fc2_wgrad_gemm( assert not ctx.fp8 fc1_bias_grad, dact = bgrad_dgelu_fused(fc2_dgrad, fc1_out_without_bias, fc1_bias) if ctx.fc1_grad_output_quantizer is not None: - dact = ctx.fc1_grad_output_quantizer(dact) + dact = ctx.fc1_grad_output_quantizer(dact, workspace=ctx.quantizer_workspace) elif ctx.debug: dact_func = _act_func(ctx.activation)[1] dact = dact_func(fc2_dgrad, fc1_out.to(ctx.activation_dtype), None, **act_params) fc1_bias_grad = dact.sum(dim=0) - dact = ctx.fc1_grad_output_quantizer(dact) + dact = ctx.fc1_grad_output_quantizer(dact, workspace=ctx.quantizer_workspace) elif ( _act_func(ctx.activation, ctx.fp8_recipe if ctx.fp8 else None)[2] is not None and ctx.fp8 @@ -1301,6 +1313,7 @@ def fc2_wgrad_gemm( fc2_dgrad, fc1_out.to(ctx.activation_dtype), ctx.fc1_grad_output_quantizer, + quantizer_workspace=ctx.quantizer_workspace, **act_params, ) # quantize bgrad gelu fused else: @@ -1320,10 +1333,11 @@ def fc2_wgrad_gemm( or ctx.fp8_recipe.custom() ): fc1_bias_grad = dact.view(-1, dact.shape[-1]).sum(dim=0) - dact = ctx.fc1_grad_output_quantizer(dact) + dact = ctx.fc1_grad_output_quantizer(dact, workspace=ctx.quantizer_workspace) else: fc1_bias_grad, dact = tex.bgrad_quantize( - dact, ctx.fc1_grad_output_quantizer + dact, ctx.fc1_grad_output_quantizer, + quantizer_workspace=ctx.quantizer_workspace, ) else: fuse_gemm_and_bias_fc1_wgrad = ( @@ -1437,7 +1451,7 @@ def fc2_wgrad_gemm( ln_out_total.update_usage(columnwise_usage=True) else: ctx.fc1_input_quantizer.set_usage(rowwise=False, columnwise=True) - ln_out_total = ctx.fc1_input_quantizer(ln_out_total) + ln_out_total = ctx.fc1_input_quantizer(ln_out_total, workspace=ctx.quantizer_workspace) # Prepare grad output tensor # Note: Synchronize tensor-parallel communication and @@ -1447,7 +1461,7 @@ def fc2_wgrad_gemm( dact.update_usage(columnwise_usage=True) else: ctx.fc1_grad_output_quantizer.set_usage(rowwise=False, columnwise=True) - dact = ctx.fc1_grad_output_quantizer(dact) + dact = ctx.fc1_grad_output_quantizer(dact, workspace=ctx.quantizer_workspace) # Output buffer for overlapping grad input # reduce-scatter with wgrad GEMM @@ -2121,6 +2135,7 @@ def forward( fc2_grad_input_quantizer, fc2_grad_weight_quantizer, fc2_grad_output_quantizer, + self._quantizer_workspace, is_cpu_offload_enabled(), self.tp_group, self.tp_size, diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index ea921341a4..91c45a571e 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -104,6 +104,7 @@ def forward( grad_input_quantizer, grad_weight_quantizer, grad_output_quantizer, + quantizer_workspace, fuse_wgrad_accumulation, cpu_offloading, tp_group, @@ -198,7 +199,7 @@ def forward( # tensor will not be cached for backward pass input_quantizer.set_usage(columnwise=False) own_quantized_input = False - inputmat = input_quantizer(inputmat) + inputmat = input_quantizer(inputmat, workspace=quantizer_workspace) else: inputmat = cast_if_needed(inp, activation_dtype) # Cast for AMP @@ -231,7 +232,7 @@ def forward( input_quantizer.set_usage( rowwise=True, columnwise=backward_needs_input and not save_original_input ) - inputmat = input_quantizer(inputmat) + inputmat = input_quantizer(inputmat, workspace=quantizer_workspace) own_quantized_input = True else: inputmat = cast_if_needed(inp, activation_dtype) # Cast for AMP @@ -273,6 +274,7 @@ def forward( skip_update_flag=skip_fp8_weight_update, fsdp_group=fsdp_group, workspace_dtype=activation_dtype, + quantizer_workspace=quantizer_workspace, ) weightmat.update_usage(rowwise_usage=True) @@ -446,6 +448,7 @@ def forward( ctx.grad_input_quantizer = grad_input_quantizer ctx.grad_weight_quantizer = grad_weight_quantizer ctx.grad_output_quantizer = grad_output_quantizer + ctx.quantizer_workspace = quantizer_workspace ctx.fuse_wgrad_accumulation = fuse_wgrad_accumulation if fuse_wgrad_accumulation and weight.requires_grad: # This check is needed to ensure that main_grad is not created @@ -601,6 +604,7 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], grad_output, ctx.parallel_mode == "row", ctx.grad_output_quantizer, + ctx.quantizer_workspace, ) nvtx_range_pop(f"{nvtx_label}.grad_output_preprocess") @@ -773,7 +777,7 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], inputmat_total.update_usage(columnwise_usage=True) else: ctx.input_quantizer.set_usage(rowwise=False, columnwise=True) - inputmat_total = ctx.input_quantizer(inputmat_total) + inputmat_total = ctx.input_quantizer(inputmat_total, workspace=ctx.quantizer_workspace) # Prepare grad output tensor # Note: Synchronize tensor-parallel communication and @@ -815,7 +819,7 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], grad_output.update_usage(columnwise_usage=True) else: ctx.grad_output_quantizer.set_usage(rowwise=False, columnwise=True) - grad_output = ctx.grad_output_quantizer(grad_output) + grad_output = ctx.grad_output_quantizer(grad_output, workspace=ctx.quantizer_workspace) # Figure out whether to use split accumulator use_split_accumulator = _2X_ACC_WGRAD @@ -1429,6 +1433,7 @@ def forward( grad_input_quantizer, grad_weight_quantizer, grad_output_quantizer, + self._quantizer_workspace, self.fuse_wgrad_accumulation, is_cpu_offload_enabled(), self.tp_group, diff --git a/transformer_engine/pytorch/ops/basic/activation.py b/transformer_engine/pytorch/ops/basic/activation.py index 13cb519c19..7c7d98e475 100644 --- a/transformer_engine/pytorch/ops/basic/activation.py +++ b/transformer_engine/pytorch/ops/basic/activation.py @@ -105,7 +105,7 @@ def op_forward( # Quantize input to FP8 before caching if needed if self.cache_quantized_input: - input_quantizer = Float8CurrentScalingQuantizer(tex.DType.kFloat8E4M3, x.device) + input_quantizer = Float8CurrentScalingQuantizer(tex.DType.kFloat8E4M3) input_quantizer.set_usage(rowwise=True, columnwise=False) x = input_quantizer(x) diff --git a/transformer_engine/pytorch/ops/basic/swiglu.py b/transformer_engine/pytorch/ops/basic/swiglu.py index b4427df41a..fed165b2d8 100644 --- a/transformer_engine/pytorch/ops/basic/swiglu.py +++ b/transformer_engine/pytorch/ops/basic/swiglu.py @@ -119,7 +119,6 @@ def op_forward( if self.cache_quantized_input: input_quantizer = Float8CurrentScalingQuantizer( tex.DType.kFloat8E4M3, - input_.device, ) input_quantizer.set_usage(rowwise=True, columnwise=False) input_ = input_quantizer(input_) @@ -274,7 +273,7 @@ def op_forward( # Quantize input to FP8 before caching if needed if self.cache_quantized_input: - input_quantizer = Float8CurrentScalingQuantizer(tex.DType.kFloat8E4M3, x.device) + input_quantizer = Float8CurrentScalingQuantizer(tex.DType.kFloat8E4M3) input_quantizer.set_usage(rowwise=True, columnwise=False) x = input_quantizer(x) diff --git a/transformer_engine/pytorch/quantization.py b/transformer_engine/pytorch/quantization.py index 47e6d5c8dc..63847ee4d5 100644 --- a/transformer_engine/pytorch/quantization.py +++ b/transformer_engine/pytorch/quantization.py @@ -1107,7 +1107,6 @@ class Float8CurrentScalingRecipeState(RecipeState): recipe: Float8CurrentScaling mode: str dtype: tex.DType - device: torch.device def __init__( self, @@ -1122,19 +1121,14 @@ def __init__( self.num_quantizers = num_quantizers self.dtype = get_fp8_te_dtype(recipe, mode == "forward") - # Allocate buffers - if device is None: - device = torch.device("cuda") - self.device = device - def make_quantizers(self) -> list: from .tensor.float8_tensor import Float8CurrentScalingQuantizer return [ Float8CurrentScalingQuantizer( - self.dtype, device=self.device, force_pow_2_scales=self.recipe.use_power_2_scales + self.dtype, force_pow_2_scales=self.recipe.use_power_2_scales ) - for i in range(self.num_quantizers) + for _ in range(self.num_quantizers) ] diff --git a/transformer_engine/pytorch/quantized_tensor.py b/transformer_engine/pytorch/quantized_tensor.py index a7722f777e..bca8bd8e91 100644 --- a/transformer_engine/pytorch/quantized_tensor.py +++ b/transformer_engine/pytorch/quantized_tensor.py @@ -269,6 +269,7 @@ def update_quantized( dst: QuantizedTensor, *, noop_flag: Optional[torch.Tensor] = None, + workspace: Optional[torch.Tensor] = None, ) -> QuantizedTensor: """Quantize tensor in-place""" raise NotImplementedError( @@ -281,30 +282,50 @@ def quantize( *, out: Optional[QuantizedTensor] = None, dtype: Optional[torch.dtype] = None, # pylint: disable=unused-argument # used by override + workspace: Optional[torch.Tensor] = None, ) -> QuantizedTensor: - """Quantize tensor""" + """Quantize tensor + + Parameters + ---------- + tensor: torch.Tensor + High-precision tensor to quantize. + out: QuantizedTensor, optional + Pre-allocated output tensor for in-place quantization. + dtype: torch.dtype, optional + Desired output dtype (used by some subclasses). + workspace: torch.Tensor, optional + Float32 workspace buffer for intermediate values (e.g. + amax and scale for current-scaling FP8). When ``None``, + subclasses may allocate a temporary buffer internally. + """ if out is not None: - return self.update_quantized(tensor, out) + return self.update_quantized(tensor, out, workspace=workspace) if (not self.internal) and torch.is_grad_enabled(): - return _QuantizeFunc.apply(tensor, self.quantize_impl) - return _QuantizeFunc.forward(None, tensor, self.quantize_impl) + return _QuantizeFunc.apply(tensor, self.quantize_impl, workspace) + return _QuantizeFunc.forward(None, tensor, self.quantize_impl, workspace) - def quantize_impl(self, tensor: torch.Tensor) -> QuantizedTensor: + def quantize_impl(self, tensor: torch.Tensor, **kwargs) -> QuantizedTensor: """Quantize tensor implementation""" raise NotImplementedError( f"{self.__class__.__name__} class does not implement quantize_impl function" ) - def multi_quantize(self, list_of_tensors): + def multi_quantize(self, list_of_tensors, *, workspace: Optional[torch.Tensor] = None): """Quantize multiple tensors""" list_of_output_tensors = [] for tensor in list_of_tensors: - list_of_output_tensors.append(self.quantize(tensor)) + list_of_output_tensors.append(self.quantize(tensor, workspace=workspace)) return list_of_output_tensors - def __call__(self, tensor: torch.Tensor) -> QuantizedTensor: + def __call__( + self, + tensor: torch.Tensor, + *, + workspace: Optional[torch.Tensor] = None, + ) -> QuantizedTensor: """Quantize tensor""" - return self.quantize(tensor) + return self.quantize(tensor, workspace=workspace) def make_empty( self, diff --git a/transformer_engine/pytorch/tensor/_quantization_helpers.py b/transformer_engine/pytorch/tensor/_quantization_helpers.py index ba3407e13b..bcfd548783 100644 --- a/transformer_engine/pytorch/tensor/_quantization_helpers.py +++ b/transformer_engine/pytorch/tensor/_quantization_helpers.py @@ -24,8 +24,11 @@ def forward( _ctx: Optional[torch.autograd.function.FunctionCtx], # unused tensor: torch.Tensor, quantize_impl: Callable, + workspace: Optional[torch.Tensor] = None, ) -> QuantizedTensor: # pylint: disable=missing-function-docstring + if workspace is not None: + return quantize_impl(tensor, workspace=workspace) return quantize_impl(tensor) @staticmethod @@ -35,7 +38,7 @@ def backward( ) -> Tuple[Optional[torch.Tensor], ...]: # pylint: disable=missing-function-docstring # Assume that we want gradients in full precision - return grad, None + return grad, None, None class _IdentityFunc(torch.autograd.Function): diff --git a/transformer_engine/pytorch/tensor/float8_tensor.py b/transformer_engine/pytorch/tensor/float8_tensor.py index 5f00bc8017..7b81881b04 100644 --- a/transformer_engine/pytorch/tensor/float8_tensor.py +++ b/transformer_engine/pytorch/tensor/float8_tensor.py @@ -232,24 +232,18 @@ class Float8CurrentScalingQuantizer(Quantizer): value ("amax") in the tensor is computed directly by scanning the input high-precision tensor, without the need of any history window. - Unlike delayed scaling, scale and amax tensors are not needed to initialize the - quantizer, becuse they are simply GPU buffers that will be filled by current - scaling quantization kernels, instead of using values taken from delayed scaling - history window. Therefore, device parameter is needed for tensor allocation. + This quantizer is stateless: it holds only configuration, not GPU + workspace buffers. Temporary amax/scale workspace is provided + externally via the ``workspace`` argument to :meth:`quantize`, or + allocated on the fly when no workspace is supplied. Both Float8CurrentScalingQuantizer and Float8Quantizer produces Float8Tensor, because they are both per-tensor scaling, ie. one scaling factor per tensor. """ - """Workspace buffer for FP8 scaling factor""" - scale: torch.Tensor - """Workspace buffer for max-abs value""" - amax: torch.Tensor """FP8 datatype""" dtype: TE_DType - """amax update options""" - use_existing_amax: bool """amax reduction options""" with_amax_reduction: bool amax_reduction_group: Optional[dist_group_type] @@ -257,14 +251,14 @@ class Float8CurrentScalingQuantizer(Quantizer): force_pow_2_scales: bool amax_epsilon: float + WORKSPACE_FLOATS_PER_QUANTIZER: int = 2 # [amax, scale] + def __init__( self, fp8_dtype: TE_DType, - device: torch.device, *, rowwise: bool = True, columnwise: bool = True, - use_existing_amax: bool = False, with_amax_reduction: bool = False, amax_reduction_group: Optional[dist_group_type] = None, force_pow_2_scales: bool = False, @@ -273,18 +267,18 @@ def __init__( amax: Optional[torch.Tensor] = None, ) -> None: super().__init__(rowwise=rowwise, columnwise=columnwise) - if scale is None: - scale = torch.empty(1, dtype=torch.float32, device=device) - if amax is None: - amax = torch.empty(1, dtype=torch.float32, device=device) - self.scale = scale - self.amax = amax self.dtype = fp8_dtype - self.use_existing_amax = use_existing_amax self.with_amax_reduction = with_amax_reduction self.amax_reduction_group = amax_reduction_group self.force_pow_2_scales = force_pow_2_scales self.amax_epsilon = amax_epsilon + # scale/amax are NOT stored by default. They are accepted as + # optional kwargs for call sites that need shared buffers + # (e.g. context-parallel attention). + if scale is not None: + self.scale = scale + if amax is not None: + self.amax = amax def __getstate__(self): """Exclude unpicklable process group from serialized state.""" @@ -292,21 +286,38 @@ def __getstate__(self): state["amax_reduction_group"] = None return state + @staticmethod + def _get_or_alloc_workspace( + workspace: Optional[torch.Tensor], + device: torch.device, + ) -> torch.Tensor: + """Return a float32 workspace of size >= 2: [amax, scale]. + + If workspace is provided, returns it as-is. Otherwise allocates + a fresh temporary buffer. + """ + if workspace is not None: + return workspace + return torch.empty(2, dtype=torch.float32, device=device) + def copy(self) -> Float8CurrentScalingQuantizer: """Create shallow copy""" + kwargs = {} + if hasattr(self, "scale"): + kwargs["scale"] = self.scale + if hasattr(self, "amax"): + kwargs["amax"] = self.amax + quantizer = Float8CurrentScalingQuantizer( fp8_dtype=self.dtype, - device=0, rowwise=self.rowwise_usage, columnwise=self.columnwise_usage, with_amax_reduction=self.with_amax_reduction, amax_reduction_group=self.amax_reduction_group, - use_existing_amax=self.use_existing_amax, force_pow_2_scales=self.force_pow_2_scales, amax_epsilon=self.amax_epsilon, - scale=self.scale, - amax=self.amax, + **kwargs, ) quantizer.internal = self.internal quantizer.optimize_for_gemm = self.optimize_for_gemm @@ -319,6 +330,7 @@ def update_quantized( dst: QuantizedTensor, *, noop_flag: Optional[torch.Tensor] = None, + workspace: Optional[torch.Tensor] = None, ) -> QuantizedTensor: if not isinstance(dst, Float8Tensor): raise ValueError("Float8CurrentScalingQuantizer can only update Float8Tensor") @@ -329,17 +341,25 @@ def update_quantized( if not src.is_contiguous(): src = src.contiguous() + workspace = self._get_or_alloc_workspace(workspace, src.device) + # Launch cast kernel - tex.quantize(src, self, dst, noop_flag) + tex.quantize(src, self, dst, noop_flag, workspace) # Update FP8 dtype dst._fp8_dtype = self.dtype return dst - def quantize_impl(self, tensor: torch.Tensor) -> QuantizedTensor: + def quantize_impl( + self, + tensor: torch.Tensor, + *, + workspace: Optional[torch.Tensor] = None, + ) -> QuantizedTensor: """Quantize tensor implementation""" - return tex.quantize(tensor, self) + workspace = self._get_or_alloc_workspace(workspace, tensor.device) + return tex.quantize(tensor, self, None, None, workspace) def make_empty( self, diff --git a/transformer_engine/pytorch/tensor/utils.py b/transformer_engine/pytorch/tensor/utils.py index c80bc8aaa4..0c045e689f 100644 --- a/transformer_engine/pytorch/tensor/utils.py +++ b/transformer_engine/pytorch/tensor/utils.py @@ -369,8 +369,10 @@ def _cast_master_weights_to_fp8_current_scaling( packed_amaxes = torch.zeros(len(params), dtype=torch.float32, device=device) amaxes = [packed_amaxes[i : i + 1] for i in range(len(params))] - # Collect scales and scale_invs to update them after amax reduction. - scales, scale_invs = [], [] + # Allocate scale buffers and collect scale_invs for the multi-tensor update. + packed_scales = torch.ones(len(params), dtype=torch.float32, device=device) + scales = [packed_scales[i : i + 1] for i in range(len(params))] + scale_invs = [] # --------------------------------------------------------------------------------------------- # Step 1: Iterate through all the none empty master weights and compute amax of them. Store the @@ -397,7 +399,6 @@ def _cast_master_weights_to_fp8_current_scaling( f"expected {amax_epsilon} but got {quantizer.amax_epsilon}" ) - scales.append(quantizer.scale.view(1)) scale_invs.append(model_weight._scale_inv.view(1)) # Compute amax of the master weight and store it in packed_amaxes. From 72134c63bb2ebcd0a5fda6fec29ba7bcc2671b68 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 31 Mar 2026 14:25:07 +0000 Subject: [PATCH 2/2] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- transformer_engine/pytorch/csrc/common.h | 10 +- transformer_engine/pytorch/csrc/extensions.h | 101 +++++++++++------- .../pytorch/csrc/extensions/activation.cpp | 18 ++-- .../pytorch/csrc/extensions/attention.cpp | 25 +++-- .../pytorch/csrc/extensions/bias.cpp | 4 +- .../pytorch/csrc/extensions/cast.cpp | 24 ++--- .../pytorch/csrc/extensions/pybind.cpp | 8 +- transformer_engine/pytorch/csrc/quantizer.cpp | 10 +- transformer_engine/pytorch/module/base.py | 7 +- .../pytorch/module/grouped_linear.py | 18 +++- .../pytorch/module/layernorm_linear.py | 8 +- .../pytorch/module/layernorm_mlp.py | 34 ++++-- transformer_engine/pytorch/module/linear.py | 8 +- 13 files changed, 167 insertions(+), 108 deletions(-) diff --git a/transformer_engine/pytorch/csrc/common.h b/transformer_engine/pytorch/csrc/common.h index 9afa4bc799..20934b9154 100644 --- a/transformer_engine/pytorch/csrc/common.h +++ b/transformer_engine/pytorch/csrc/common.h @@ -231,8 +231,7 @@ class Float8CurrentScalingQuantizer : public Quantizer { const std::optional& noop_flag = std::nullopt) override; /*! @brief Quantize to FP8 using provided amax/scale workspace buffers */ - void quantize(const TensorWrapper& input, TensorWrapper& out, - at::Tensor amax, at::Tensor scale, + void quantize(const TensorWrapper& input, TensorWrapper& out, at::Tensor amax, at::Tensor scale, const std::optional& noop_flag = std::nullopt); /*! @brief Quantize to FP8, skipping local amax computation. @@ -241,8 +240,8 @@ class Float8CurrentScalingQuantizer : public Quantizer { * amax (e.g. computed by a fused LN kernel). The amax may still * be reduced across the amax reduction group. */ - void quantize_with_amax(TensorWrapper& input, TensorWrapper& out, - at::Tensor amax, at::Tensor scale, + void quantize_with_amax(TensorWrapper& input, TensorWrapper& out, at::Tensor amax, + at::Tensor scale, const std::optional& noop_flag = std::nullopt); private: @@ -257,8 +256,7 @@ class Float8CurrentScalingQuantizer : public Quantizer { */ inline std::pair split_quantizer_workspace(const at::Tensor& workspace) { NVTE_CHECK(workspace.numel() >= 2, "Quantizer workspace must have at least 2 float32 elements"); - return {workspace.slice(0, 0, 1).contiguous(), - workspace.slice(0, 1, 2).contiguous()}; + return {workspace.slice(0, 0, 1).contiguous(), workspace.slice(0, 1, 2).contiguous()}; } class Float8BlockQuantizer : public Quantizer { diff --git a/transformer_engine/pytorch/csrc/extensions.h b/transformer_engine/pytorch/csrc/extensions.h index 188c27c6ff..b7d0ea7fa7 100644 --- a/transformer_engine/pytorch/csrc/extensions.h +++ b/transformer_engine/pytorch/csrc/extensions.h @@ -92,8 +92,7 @@ std::vector fused_attn_fwd( const std::optional cu_seqlens_q_padded, const std::optional cu_seqlens_kv_padded, const std::optional page_table_k, const std::optional page_table_v, - py::handle s_quantizer, py::handle o_quantizer, - const std::optional Bias, + py::handle s_quantizer, py::handle o_quantizer, const std::optional Bias, const std::optional SoftmaxOffset, const std::optional rng_gen, size_t rng_elts_per_thread, bool return_max_logit, bool cuda_graph); @@ -216,54 +215,77 @@ at::Tensor swap_first_dims(at::Tensor tensor, std::optional out = st **************************************************************************************************/ /* GLU (sigmoid gate) */ -py::object glu(const at::Tensor &input, py::handle quantizer, std::optional qw = std::nullopt); +py::object glu(const at::Tensor &input, py::handle quantizer, + std::optional qw = std::nullopt); -py::object dglu(const at::Tensor &grad, const at::Tensor &input, py::handle quantizer, std::optional qw = std::nullopt); +py::object dglu(const at::Tensor &grad, const at::Tensor &input, py::handle quantizer, + std::optional qw = std::nullopt); /* GELU and variants*/ -py::object gelu(const at::Tensor &input, py::handle quantizer, std::optional qw = std::nullopt); +py::object gelu(const at::Tensor &input, py::handle quantizer, + std::optional qw = std::nullopt); -py::object dgelu(const at::Tensor &grad, const at::Tensor &input, py::handle quantizer, std::optional qw = std::nullopt); +py::object dgelu(const at::Tensor &grad, const at::Tensor &input, py::handle quantizer, + std::optional qw = std::nullopt); -py::object geglu(const at::Tensor &input, py::handle quantizer, std::optional qw = std::nullopt); +py::object geglu(const at::Tensor &input, py::handle quantizer, + std::optional qw = std::nullopt); -py::object dgeglu(const at::Tensor &grad, const at::Tensor &input, py::handle quantizer, std::optional qw = std::nullopt); +py::object dgeglu(const at::Tensor &grad, const at::Tensor &input, py::handle quantizer, + std::optional qw = std::nullopt); -py::object qgelu(const at::Tensor &input, py::handle quantizer, std::optional qw = std::nullopt); +py::object qgelu(const at::Tensor &input, py::handle quantizer, + std::optional qw = std::nullopt); -py::object dqgelu(const at::Tensor &grad, const at::Tensor &input, py::handle quantizer, std::optional qw = std::nullopt); +py::object dqgelu(const at::Tensor &grad, const at::Tensor &input, py::handle quantizer, + std::optional qw = std::nullopt); -py::object qgeglu(const at::Tensor &input, py::handle quantizer, std::optional qw = std::nullopt); +py::object qgeglu(const at::Tensor &input, py::handle quantizer, + std::optional qw = std::nullopt); -py::object dqgeglu(const at::Tensor &grad, const at::Tensor &input, py::handle quantizer, std::optional qw = std::nullopt); +py::object dqgeglu(const at::Tensor &grad, const at::Tensor &input, py::handle quantizer, + std::optional qw = std::nullopt); /* ReLU and variants*/ -py::object relu(const at::Tensor &input, py::handle quantizer, std::optional qw = std::nullopt); +py::object relu(const at::Tensor &input, py::handle quantizer, + std::optional qw = std::nullopt); -py::object drelu(const at::Tensor &grad, const at::Tensor &input, py::handle quantizer, std::optional qw = std::nullopt); +py::object drelu(const at::Tensor &grad, const at::Tensor &input, py::handle quantizer, + std::optional qw = std::nullopt); -py::object reglu(const at::Tensor &input, py::handle quantizer, std::optional qw = std::nullopt); +py::object reglu(const at::Tensor &input, py::handle quantizer, + std::optional qw = std::nullopt); -py::object dreglu(const at::Tensor &grad, const at::Tensor &input, py::handle quantizer, std::optional qw = std::nullopt); +py::object dreglu(const at::Tensor &grad, const at::Tensor &input, py::handle quantizer, + std::optional qw = std::nullopt); -py::object srelu(const at::Tensor &input, py::handle quantizer, std::optional qw = std::nullopt); +py::object srelu(const at::Tensor &input, py::handle quantizer, + std::optional qw = std::nullopt); -py::object dsrelu(const at::Tensor &grad, const at::Tensor &input, py::handle quantizer, std::optional qw = std::nullopt); +py::object dsrelu(const at::Tensor &grad, const at::Tensor &input, py::handle quantizer, + std::optional qw = std::nullopt); -py::object sreglu(const at::Tensor &input, py::handle quantizer, std::optional qw = std::nullopt); +py::object sreglu(const at::Tensor &input, py::handle quantizer, + std::optional qw = std::nullopt); -py::object dsreglu(const at::Tensor &grad, const at::Tensor &input, py::handle quantizer, std::optional qw = std::nullopt); +py::object dsreglu(const at::Tensor &grad, const at::Tensor &input, py::handle quantizer, + std::optional qw = std::nullopt); /* Silu and variants*/ -py::object silu(const at::Tensor &input, py::handle quantizer, std::optional qw = std::nullopt); +py::object silu(const at::Tensor &input, py::handle quantizer, + std::optional qw = std::nullopt); -py::object dsilu(const at::Tensor &grad, const at::Tensor &input, py::handle quantizer, std::optional qw = std::nullopt); +py::object dsilu(const at::Tensor &grad, const at::Tensor &input, py::handle quantizer, + std::optional qw = std::nullopt); -py::object swiglu(const at::Tensor &input, py::handle quantizer, std::optional qw = std::nullopt); +py::object swiglu(const at::Tensor &input, py::handle quantizer, + std::optional qw = std::nullopt); -py::object dswiglu(const at::Tensor &grad, const at::Tensor &input, py::handle quantizer, std::optional qw = std::nullopt); +py::object dswiglu(const at::Tensor &grad, const at::Tensor &input, py::handle quantizer, + std::optional qw = std::nullopt); -py::object clamped_swiglu(const at::Tensor &input, py::handle quantizer, std::optional qw, float limit, float alpha); +py::object clamped_swiglu(const at::Tensor &input, py::handle quantizer, + std::optional qw, float limit, float alpha); py::object clamped_dswiglu(const at::Tensor &grad, const at::Tensor &input, py::handle quantizer, std::optional qw, float limit, float alpha); @@ -316,33 +338,38 @@ py::object group_quantize(const at::Tensor &tensor, py::handle quantizer, const std::vector multi_tensor_quantize(const std::vector &tensor_list, std::vector quantizer_list); -std::vector split_quantize(const at::Tensor &tensor, - const std::vector &split_sections, - std::vector quantizer_list, - bool disable_bulk_allocation = false, - std::optional> quantizer_workspaces = std::nullopt); +std::vector split_quantize( + const at::Tensor &tensor, const std::vector &split_sections, + std::vector quantizer_list, bool disable_bulk_allocation = false, + std::optional> quantizer_workspaces = std::nullopt); /*************************************************************************************************** * Bias gradient fusions **************************************************************************************************/ -std::vector bgrad_quantize(const at::Tensor &input, py::handle py_quantizer, - std::optional quantizer_workspace = std::nullopt); +std::vector bgrad_quantize( + const at::Tensor &input, py::handle py_quantizer, + std::optional quantizer_workspace = std::nullopt); std::vector dbias_dgelu(const at::Tensor &grad_output, const at::Tensor &act_input, - py::handle quantizer, std::optional qw = std::nullopt); + py::handle quantizer, + std::optional qw = std::nullopt); std::vector dbias_dsilu(const at::Tensor &grad_output, const at::Tensor &act_input, - py::handle quantizer, std::optional qw = std::nullopt); + py::handle quantizer, + std::optional qw = std::nullopt); std::vector dbias_drelu(const at::Tensor &grad_output, const at::Tensor &act_input, - py::handle quantizer, std::optional qw = std::nullopt); + py::handle quantizer, + std::optional qw = std::nullopt); std::vector dbias_dqgelu(const at::Tensor &grad_output, const at::Tensor &act_input, - py::handle quantizer, std::optional qw = std::nullopt); + py::handle quantizer, + std::optional qw = std::nullopt); std::vector dbias_dsrelu(const at::Tensor &grad_output, const at::Tensor &act_input, - py::handle quantizer, std::optional qw = std::nullopt); + py::handle quantizer, + std::optional qw = std::nullopt); /*************************************************************************************************** * Dropout diff --git a/transformer_engine/pytorch/csrc/extensions/activation.cpp b/transformer_engine/pytorch/csrc/extensions/activation.cpp index 95edf77f4c..412108fd1f 100644 --- a/transformer_engine/pytorch/csrc/extensions/activation.cpp +++ b/transformer_engine/pytorch/csrc/extensions/activation.cpp @@ -16,8 +16,8 @@ using DFuncType = void(const NVTETensor, const NVTETensor, NVTETensor, cudaStrea template py::object activation_helper(const at::Tensor& input, py::handle quantizer, - std::optional quantizer_workspace, - int shape_divisor = 1, Args&&... args) { + std::optional quantizer_workspace, int shape_divisor = 1, + Args&&... args) { init_extension(); // Input tensor @@ -88,8 +88,8 @@ py::object activation_helper(const at::Tensor& input, py::handle quantizer, auto fp8_quantizer_cpp = dynamic_cast(quantizer_cpp.get()); NVTE_CHECK(fp8_quantizer_cpp != nullptr, "Could not cast to FP8 current scaling quantizer"); auto [cs_amax, cs_scale] = split_quantizer_workspace(*quantizer_workspace); - auto [temp_nvte, _] = - fp8_quantizer_cpp->create_unquantized_tensor_with_amax(output_shape, fake_dtype, cs_amax); + auto [temp_nvte, _] = fp8_quantizer_cpp->create_unquantized_tensor_with_amax( + output_shape, fake_dtype, cs_amax); NVTE_SCOPED_GIL_RELEASE({ if constexpr (act_func == nullptr) { act_func_with_args(input_nvte.data(), temp_nvte.data(), std::forward(args)..., @@ -128,8 +128,7 @@ py::object activation_helper(const at::Tensor& input, py::handle quantizer, template py::object dactivation_helper(const at::Tensor& grad_output, const at::Tensor& input, - py::handle quantizer, - std::optional quantizer_workspace, + py::handle quantizer, std::optional quantizer_workspace, Args&&... args) { init_extension(); @@ -203,8 +202,8 @@ py::object dactivation_helper(const at::Tensor& grad_output, const at::Tensor& i auto fp8_quantizer_cpp = dynamic_cast(quantizer_cpp.get()); NVTE_CHECK(fp8_quantizer_cpp != nullptr, "Could not cast to FP8 current scaling quantizer"); auto [cs_amax, cs_scale] = split_quantizer_workspace(*quantizer_workspace); - auto [temp_nvte, _] = - fp8_quantizer_cpp->create_unquantized_tensor_with_amax(input_shape, fake_dtype, cs_amax); + auto [temp_nvte, _] = fp8_quantizer_cpp->create_unquantized_tensor_with_amax( + input_shape, fake_dtype, cs_amax); NVTE_SCOPED_GIL_RELEASE({ if constexpr (dact_func == nullptr) { dact_func_with_args(grad_output_nvte.data(), input_nvte.data(), temp_nvte.data(), @@ -351,7 +350,8 @@ py::object clamped_swiglu(const at::Tensor& input, py::handle quantizer, py::object clamped_dswiglu(const at::Tensor& grad, const at::Tensor& input, py::handle quantizer, std::optional qw, float limit, float alpha) { - return dactivation_helper(grad, input, quantizer, qw, limit, alpha); + return dactivation_helper(grad, input, quantizer, qw, limit, + alpha); } } // namespace pytorch diff --git a/transformer_engine/pytorch/csrc/extensions/attention.cpp b/transformer_engine/pytorch/csrc/extensions/attention.cpp index 4b0999b149..13984e7245 100644 --- a/transformer_engine/pytorch/csrc/extensions/attention.cpp +++ b/transformer_engine/pytorch/csrc/extensions/attention.cpp @@ -55,11 +55,10 @@ NVTE_Fused_Attn_Backend get_fused_attn_backend( } // helper function for S and dP quantizers -std::pair quantizer_helper(py::handle quantizer, - const std::vector &shape, DType dtype, - bool create_hp_tensor_for_cs, - std::optional data, - std::optional quantizer_workspace) { +std::pair quantizer_helper( + py::handle quantizer, const std::vector &shape, DType dtype, + bool create_hp_tensor_for_cs, std::optional data, + std::optional quantizer_workspace) { std::unique_ptr T_quantizer = convert_quantizer(quantizer); TensorWrapper te_T; py::object py_T; @@ -80,13 +79,14 @@ std::pair quantizer_helper(py::handle quantizer, // current scaling auto *T_quantizer_fp8 = dynamic_cast(T_quantizer.get()); if (create_hp_tensor_for_cs) { - at::Tensor ws = quantizer_workspace.has_value() - ? *quantizer_workspace - : at::empty({2}, at::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA)); + at::Tensor ws = + quantizer_workspace.has_value() + ? *quantizer_workspace + : at::empty({2}, at::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA)); auto [cs_amax, cs_scale] = split_quantizer_workspace(ws); if (data.has_value()) { - std::tie(te_T, py_T) = - T_quantizer_fp8->create_unquantized_tensor_with_amax(shape, dtype, cs_amax, data.value()); + std::tie(te_T, py_T) = T_quantizer_fp8->create_unquantized_tensor_with_amax( + shape, dtype, cs_amax, data.value()); } else { std::tie(te_T, py_T) = T_quantizer_fp8->create_unquantized_tensor_with_amax(shape, dtype, cs_amax); @@ -112,8 +112,7 @@ std::vector fused_attn_fwd( const std::optional cu_seqlens_q_padded, const std::optional cu_seqlens_kv_padded, const std::optional page_table_k, const std::optional page_table_v, - py::handle s_quantizer, py::handle o_quantizer, - const std::optional Bias, + py::handle s_quantizer, py::handle o_quantizer, const std::optional Bias, const std::optional SoftmaxOffset, const std::optional rng_gen, size_t rng_elts_per_thread, bool return_max_logit, bool cuda_graph) { // Ensure that cuDNN handle is created on the correct device, @@ -326,7 +325,7 @@ std::vector fused_attn_bwd( const py::handle O, const py::handle dO, const at::ScalarType fake_dtype, const DType dqkv_type, const std::vector Aux_CTX_Tensors, const std::optional cu_seqlens_q_padded, - const std::optional cu_seqlens_kv_padded, py::handle s_quantizer, + const std::optional cu_seqlens_kv_padded, py::handle s_quantizer, py::handle dp_quantizer, py::handle dqkv_quantizer, bool cuda_graph) { auto none = py::none(); diff --git a/transformer_engine/pytorch/csrc/extensions/bias.cpp b/transformer_engine/pytorch/csrc/extensions/bias.cpp index 6a3a3f5876..78e56ed0ab 100644 --- a/transformer_engine/pytorch/csrc/extensions/bias.cpp +++ b/transformer_engine/pytorch/csrc/extensions/bias.cpp @@ -211,8 +211,8 @@ std::vector dact_dbias( NVTE_CHECK(fp8_quantizer_cpp != nullptr, "Invalid quantizer for fused dact-amax kernel impl"); auto [cs_amax, cs_scale] = split_quantizer_workspace(*quantizer_workspace); - auto [temp_nvte, temp_py] = - fp8_quantizer_cpp->create_unquantized_tensor_with_amax(input_shape, grad_output_dtype, cs_amax); + auto [temp_nvte, temp_py] = fp8_quantizer_cpp->create_unquantized_tensor_with_amax( + input_shape, grad_output_dtype, cs_amax); NVTE_SCOPED_GIL_RELEASE({ dact_func(grad_output_nvte.data(), act_input_nvte.data(), temp_nvte.data(), stream); }); diff --git a/transformer_engine/pytorch/csrc/extensions/cast.cpp b/transformer_engine/pytorch/csrc/extensions/cast.cpp index de36b99d54..5b6c74c6d0 100644 --- a/transformer_engine/pytorch/csrc/extensions/cast.cpp +++ b/transformer_engine/pytorch/csrc/extensions/cast.cpp @@ -251,11 +251,11 @@ py::object dequantize(const py::handle &input, transformer_engine::DType otype) namespace { -void multi_tensor_quantize_impl(const std::vector &input_list, - std::vector &quantizer_py_list, - std::vector> &quantizer_cpp_list, - std::vector &output_list, - const std::optional> &workspaces = std::nullopt) { +void multi_tensor_quantize_impl( + const std::vector &input_list, std::vector &quantizer_py_list, + std::vector> &quantizer_cpp_list, + std::vector &output_list, + const std::optional> &workspaces = std::nullopt) { // Check number of tensors const size_t num_tensors = input_list.size(); NVTE_CHECK(quantizer_py_list.size() == num_tensors, "Expected ", num_tensors, @@ -298,7 +298,8 @@ void multi_tensor_quantize_impl(const std::vector &input_list, for (size_t i = 0; i < num_tensors; ++i) { if (workspaces.has_value() && detail::IsFloat8CurrentScalingQuantizers(quantizer_py_list[i].ptr())) { - auto *quantizer_cs = dynamic_cast(quantizer_cpp_list[i].get()); + auto *quantizer_cs = + dynamic_cast(quantizer_cpp_list[i].get()); auto [amax, scale] = split_quantizer_workspace((*workspaces)[i]); quantizer_cs->quantize(input_list[i], output_list[i], amax, scale); } else { @@ -1260,11 +1261,10 @@ void split_quantize_nvfp4_impl(const TensorWrapper &input, } // namespace -std::vector split_quantize(const at::Tensor &tensor, - const std::vector &split_sections, - std::vector quantizer_list, - bool disable_bulk_allocation, - std::optional> quantizer_workspaces) { +std::vector split_quantize( + const at::Tensor &tensor, const std::vector &split_sections, + std::vector quantizer_list, bool disable_bulk_allocation, + std::optional> quantizer_workspaces) { init_extension(); // Check number of tensors @@ -1412,7 +1412,7 @@ std::vector split_quantize(const at::Tensor &tensor, default: // General multi-tensor quantization multi_tensor_quantize_impl(input_list, quantizer_list, quantizer_cpp_list, output_cpp_list, - quantizer_workspaces); + quantizer_workspaces); } return output_py_list; diff --git a/transformer_engine/pytorch/csrc/extensions/pybind.cpp b/transformer_engine/pytorch/csrc/extensions/pybind.cpp index d348f4bdd1..a92f34ecbe 100644 --- a/transformer_engine/pytorch/csrc/extensions/pybind.cpp +++ b/transformer_engine/pytorch/csrc/extensions/pybind.cpp @@ -181,8 +181,8 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { py::arg("quantizer"), py::arg("quantizer_workspace") = py::none()); m.def("clamped_swiglu", transformer_engine::pytorch::clamped_swiglu, "SwiGLU activation used in GPT OSS", py::arg("input"), py::arg("quantizer"), - py::arg("quantizer_workspace") = py::none(), - py::arg("limit") = 7.0f, py::arg("alpha") = 1.702f); + py::arg("quantizer_workspace") = py::none(), py::arg("limit") = 7.0f, + py::arg("alpha") = 1.702f); /* Backward of GLU */ m.def("dglu", transformer_engine::pytorch::dglu, "Backward of GLU", py::arg("grad"), py::arg("fwd_input"), py::arg("quantizer"), py::arg("quantizer_workspace") = py::none()); @@ -212,8 +212,8 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { py::arg("fwd_input"), py::arg("quantizer"), py::arg("quantizer_workspace") = py::none()); m.def("clamped_dswiglu", transformer_engine::pytorch::clamped_dswiglu, "Backward of SwiGLU used in GPT OSS", py::arg("grad"), py::arg("fwd_input"), - py::arg("quantizer"), py::arg("quantizer_workspace") = py::none(), - py::arg("limit") = 7.0f, py::arg("alpha") = 1.702f); + py::arg("quantizer"), py::arg("quantizer_workspace") = py::none(), py::arg("limit") = 7.0f, + py::arg("alpha") = 1.702f); /* DBias + DAct fusions*/ m.def("dbias_dgelu", transformer_engine::pytorch::dbias_dgelu, "DGeLU + DBias + Quantize", py::arg("grad"), py::arg("fwd_input"), py::arg("quantizer"), diff --git a/transformer_engine/pytorch/csrc/quantizer.cpp b/transformer_engine/pytorch/csrc/quantizer.cpp index cb9c9bb99e..74168fc0fd 100644 --- a/transformer_engine/pytorch/csrc/quantizer.cpp +++ b/transformer_engine/pytorch/csrc/quantizer.cpp @@ -742,8 +742,7 @@ std::pair Float8CurrentScalingQuantizer::creat std::pair Float8CurrentScalingQuantizer::create_unquantized_tensor_with_amax(const std::vector& shape, - DType dtype, - at::Tensor amax, + DType dtype, at::Tensor amax, std::optional data) { amax.zero_(); auto out = data.has_value() ? NoneQuantizer(py::none()).create_tensor(shape, dtype, data.value()) @@ -849,8 +848,8 @@ std::pair Float8CurrentScalingQuantizer::convert_and_ void Float8CurrentScalingQuantizer::quantize_impl(const TensorWrapper& input, TensorWrapper& out, const std::optional& noop_flag, - bool compute_amax, - at::Tensor amax, at::Tensor scale) { + bool compute_amax, at::Tensor amax, + at::Tensor scale) { auto stream = at::cuda::getCurrentCUDAStream(); // Nothing to be done if input is empty @@ -910,8 +909,7 @@ void Float8CurrentScalingQuantizer::quantize(const TensorWrapper& input, TensorW } void Float8CurrentScalingQuantizer::quantize_with_amax( - TensorWrapper& input, TensorWrapper& out, - at::Tensor amax, at::Tensor scale, + TensorWrapper& input, TensorWrapper& out, at::Tensor amax, at::Tensor scale, const std::optional& noop_flag) { NVTE_CHECK(input.get_amax().data_ptr == amax.data_ptr(), "Input does not use the appropriate amax tensor"); diff --git a/transformer_engine/pytorch/module/base.py b/transformer_engine/pytorch/module/base.py index cd7500733e..32c1da2c69 100644 --- a/transformer_engine/pytorch/module/base.py +++ b/transformer_engine/pytorch/module/base.py @@ -824,7 +824,9 @@ def _init_quantizer_workspace(self, recipe: Recipe) -> None: if self._quantizer_workspace is None or self._quantizer_workspace.numel() < 2: self._quantizer_workspace = torch.zeros( - 2, dtype=torch.float32, device="cuda", + 2, + dtype=torch.float32, + device="cuda", ) def get_fp8_meta_tensors(self) -> None: @@ -1282,7 +1284,8 @@ def grad_output_preprocess( grad_bias = grad_output.view(-1, grad_output.shape[-1]).sum(dim=0) else: grad_bias, grad_output = tex.bgrad_quantize( - grad_output, quantizer, + grad_output, + quantizer, quantizer_workspace=quantizer_workspace, ) if not isinstance(grad_output, QuantizedTensorStorage): diff --git a/transformer_engine/pytorch/module/grouped_linear.py b/transformer_engine/pytorch/module/grouped_linear.py index cf4267b22c..cb3bf6f7b8 100644 --- a/transformer_engine/pytorch/module/grouped_linear.py +++ b/transformer_engine/pytorch/module/grouped_linear.py @@ -180,7 +180,9 @@ def forward( update_workspace=update_workspace, skip_update_flag=skip_fp8_weight_update, workspace_dtype=activation_dtype, - quantizer_workspace=quantizer_workspaces[i] if quantizer_workspaces is not None else None, + quantizer_workspace=( + quantizer_workspaces[i] if quantizer_workspaces is not None else None + ), ) weights_fp8.append(weight_fp8) @@ -352,7 +354,11 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], grad_biases[i], grad_output[i] = tex.bgrad_quantize( grad_output_mats[i], ctx.grad_output_quantizers[i], - quantizer_workspace=ctx.quantizer_workspaces[i] if ctx.quantizer_workspaces is not None else None, + quantizer_workspace=( + ctx.quantizer_workspaces[i] + if ctx.quantizer_workspaces is not None + else None + ), ) else: # Unfused bias grad and multi-tensor quantize @@ -459,8 +465,12 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], input_quantizer.set_usage(rowwise=False, columnwise=True) inputmats: list if ctx.fp8 and not ctx.debug: - inputmats = tex.split_quantize(inp_view, ctx.m_splits, ctx.input_quantizers, - quantizer_workspaces=ctx.quantizer_workspaces) + inputmats = tex.split_quantize( + inp_view, + ctx.m_splits, + ctx.input_quantizers, + quantizer_workspaces=ctx.quantizer_workspaces, + ) elif ctx.debug: inputmats = DebugQuantizer.multi_tensor_quantize( inp_view, diff --git a/transformer_engine/pytorch/module/layernorm_linear.py b/transformer_engine/pytorch/module/layernorm_linear.py index fdc1ec8fad..526e094a82 100644 --- a/transformer_engine/pytorch/module/layernorm_linear.py +++ b/transformer_engine/pytorch/module/layernorm_linear.py @@ -824,14 +824,18 @@ def backward( ln_out_total.update_usage(columnwise_usage=True) else: ctx.input_quantizer.set_usage(rowwise=False, columnwise=True) - ln_out_total = ctx.input_quantizer(ln_out_total, workspace=ctx.quantizer_workspace) + ln_out_total = ctx.input_quantizer( + ln_out_total, workspace=ctx.quantizer_workspace + ) if ctx.fp8 or ctx.debug: if isinstance(grad_output, QuantizedTensorStorage): grad_output.update_usage(columnwise_usage=True) else: ctx.grad_output_quantizer.set_usage(rowwise=False, columnwise=True) - grad_output = ctx.grad_output_quantizer(grad_output, workspace=ctx.quantizer_workspace) + grad_output = ctx.grad_output_quantizer( + grad_output, workspace=ctx.quantizer_workspace + ) # Figure out whether to use split accumulator use_split_accumulator = _2X_ACC_WGRAD diff --git a/transformer_engine/pytorch/module/layernorm_mlp.py b/transformer_engine/pytorch/module/layernorm_mlp.py index 98434b6419..d06312cbef 100644 --- a/transformer_engine/pytorch/module/layernorm_mlp.py +++ b/transformer_engine/pytorch/module/layernorm_mlp.py @@ -605,7 +605,8 @@ def _forward( act_out = fc2_input_quantizer(act_out, workspace=quantizer_workspace) else: act_out = activation_func( - fc1_out, fc2_input_quantizer, + fc1_out, + fc2_input_quantizer, quantizer_workspace=quantizer_workspace, **act_params, ) @@ -614,7 +615,8 @@ def _forward( act_out = activation_func(fc1_out, None, **act_params) else: act_out = activation_func( - fc1_out, fc2_input_quantizer, + fc1_out, + fc2_input_quantizer, quantizer_workspace=quantizer_workspace, **act_params, ) @@ -1057,7 +1059,10 @@ def backward( grad_output, fc2_bias_grad, ) = TransformerEngineBaseModule.grad_output_preprocess( - ctx, grad_outputs[0], True, ctx.fc2_grad_output_quantizer, + ctx, + grad_outputs[0], + True, + ctx.fc2_grad_output_quantizer, ctx.quantizer_workspace, ) @@ -1208,14 +1213,18 @@ def backward( act_out.update_usage(columnwise_usage=True) else: ctx.fc2_input_quantizer.set_usage(rowwise=False, columnwise=True) - act_out = ctx.fc2_input_quantizer(act_out, workspace=ctx.quantizer_workspace) + act_out = ctx.fc2_input_quantizer( + act_out, workspace=ctx.quantizer_workspace + ) if ctx.fp8 or ctx.debug: if isinstance(grad_output, QuantizedTensorStorage): grad_output.update_usage(columnwise_usage=True) else: ctx.fc2_grad_output_quantizer.set_usage(rowwise=False, columnwise=True) - grad_output = ctx.fc2_grad_output_quantizer(grad_output, workspace=ctx.quantizer_workspace) + grad_output = ctx.fc2_grad_output_quantizer( + grad_output, workspace=ctx.quantizer_workspace + ) # Whether to set grad arg in general_gemm grad_arg = True @@ -1333,10 +1342,13 @@ def fc2_wgrad_gemm( or ctx.fp8_recipe.custom() ): fc1_bias_grad = dact.view(-1, dact.shape[-1]).sum(dim=0) - dact = ctx.fc1_grad_output_quantizer(dact, workspace=ctx.quantizer_workspace) + dact = ctx.fc1_grad_output_quantizer( + dact, workspace=ctx.quantizer_workspace + ) else: fc1_bias_grad, dact = tex.bgrad_quantize( - dact, ctx.fc1_grad_output_quantizer, + dact, + ctx.fc1_grad_output_quantizer, quantizer_workspace=ctx.quantizer_workspace, ) else: @@ -1451,7 +1463,9 @@ def fc2_wgrad_gemm( ln_out_total.update_usage(columnwise_usage=True) else: ctx.fc1_input_quantizer.set_usage(rowwise=False, columnwise=True) - ln_out_total = ctx.fc1_input_quantizer(ln_out_total, workspace=ctx.quantizer_workspace) + ln_out_total = ctx.fc1_input_quantizer( + ln_out_total, workspace=ctx.quantizer_workspace + ) # Prepare grad output tensor # Note: Synchronize tensor-parallel communication and @@ -1461,7 +1475,9 @@ def fc2_wgrad_gemm( dact.update_usage(columnwise_usage=True) else: ctx.fc1_grad_output_quantizer.set_usage(rowwise=False, columnwise=True) - dact = ctx.fc1_grad_output_quantizer(dact, workspace=ctx.quantizer_workspace) + dact = ctx.fc1_grad_output_quantizer( + dact, workspace=ctx.quantizer_workspace + ) # Output buffer for overlapping grad input # reduce-scatter with wgrad GEMM diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index 91c45a571e..0411c17cb9 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -777,7 +777,9 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], inputmat_total.update_usage(columnwise_usage=True) else: ctx.input_quantizer.set_usage(rowwise=False, columnwise=True) - inputmat_total = ctx.input_quantizer(inputmat_total, workspace=ctx.quantizer_workspace) + inputmat_total = ctx.input_quantizer( + inputmat_total, workspace=ctx.quantizer_workspace + ) # Prepare grad output tensor # Note: Synchronize tensor-parallel communication and @@ -819,7 +821,9 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], grad_output.update_usage(columnwise_usage=True) else: ctx.grad_output_quantizer.set_usage(rowwise=False, columnwise=True) - grad_output = ctx.grad_output_quantizer(grad_output, workspace=ctx.quantizer_workspace) + grad_output = ctx.grad_output_quantizer( + grad_output, workspace=ctx.quantizer_workspace + ) # Figure out whether to use split accumulator use_split_accumulator = _2X_ACC_WGRAD