Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -456,10 +456,10 @@ def forward(
fp8_recipe = fp8_meta["local_recipes"][0]
if fp8_recipe.float8_current_scaling():
S_quantizer = Float8CurrentScalingQuantizer(
fp8_dtype=S_quantizer.dtype, device="cuda"
fp8_dtype=S_quantizer.dtype,
)
dP_quantizer = Float8CurrentScalingQuantizer(
fp8_dtype=dP_quantizer.dtype, device="cuda"
fp8_dtype=dP_quantizer.dtype,
)

if "2" in qkv_layout or "3" in qkv_layout:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2183,10 +2183,7 @@ def print_quantizers(
type_str = "DS"
elif isinstance(q, Float8CurrentScalingQuantizer):
type_str = "CS"
print(
f"{label} >> {names[i]:14s}: {type_str}, {q.scale.item():.4e} x"
f" {q.amax.item():.4e} = {q.scale.item()*q.amax.item():.4e}"
)
print(f"{label} >> {names[i]:14s}: {type_str}")


def combine_and_quantize(qkv_layout, q, k, v, qkv_quantizer):
Expand Down
39 changes: 27 additions & 12 deletions transformer_engine/pytorch/csrc/common.h
Original file line number Diff line number Diff line change
Expand Up @@ -194,9 +194,7 @@ class Float8Quantizer : public Quantizer {

class Float8CurrentScalingQuantizer : public Quantizer {
public:
at::Tensor scale;
at::Tensor scale_inv;
at::Tensor amax;
DType dtype;
bool with_amax_reduction;
c10::intrusive_ptr<dist_group_type> amax_reduction_group;
Expand All @@ -217,33 +215,50 @@ class Float8CurrentScalingQuantizer : public Quantizer {
py::object quantizer, const std::optional<at::Tensor>& first_dims, size_t logical_first_dim,
size_t logical_last_dim) const override;

/*! @brief Construct an unquantized tensor that shares the quantizer's amax pointer.
/*! @brief Construct an unquantized tensor with an amax buffer.
*
* The amax is zeroed out. Most TE kernels that output amax expect
* amax to be initialized to zero.
* The provided amax tensor is zeroed out and set on the output tensor.
* Most TE kernels that output amax expect amax to be initialized to zero.
*/
std::pair<TensorWrapper, py::object> create_unquantized_tensor_with_amax(
const std::vector<size_t>& shape, DType dtype, std::optional<at::Tensor> data = std::nullopt);
const std::vector<size_t>& shape, DType dtype, at::Tensor amax,
std::optional<at::Tensor> data = std::nullopt);

std::pair<TensorWrapper, py::object> convert_and_update_tensor(py::object shape) const override;

/*! @brief Quantize to FP8 (virtual fallback, allocates local amax/scale) */
void quantize(const TensorWrapper& input, TensorWrapper& out,
const std::optional<TensorWrapper>& noop_flag = std::nullopt) override;

/*! @brief Quantize to FP8, skipping local amax computation
/*! @brief Quantize to FP8 using provided amax/scale workspace buffers */
void quantize(const TensorWrapper& input, TensorWrapper& out, at::Tensor amax, at::Tensor scale,
const std::optional<TensorWrapper>& noop_flag = std::nullopt);

/*! @brief Quantize to FP8, skipping local amax computation.
*
* The quantizer's amax pointer is assumed to already hold the local
* amax. The amax may still be reduced across the amax reduction
* group.
* The provided amax tensor is assumed to already hold the local
* amax (e.g. computed by a fused LN kernel). The amax may still
* be reduced across the amax reduction group.
*/
void quantize_with_amax(TensorWrapper& input, TensorWrapper& out,
void quantize_with_amax(TensorWrapper& input, TensorWrapper& out, at::Tensor amax,
at::Tensor scale,
const std::optional<TensorWrapper>& noop_flag = std::nullopt);

private:
void quantize_impl(const TensorWrapper& input, TensorWrapper& out,
const std::optional<TensorWrapper>& noop_flag, bool compute_amax);
const std::optional<TensorWrapper>& noop_flag, bool compute_amax,
at::Tensor amax, at::Tensor scale);
};

/*! @brief Split a quantizer workspace tensor into its amax and scale parts.
 *
 * Workspace layout: [amax, scale] (2 float32).
 *
 * Element 0 along the first dimension is returned as the amax tensor
 * and element 1 as the scale tensor, each as a single-element slice.
 * Only the element count of the workspace is validated here; the
 * dtype mentioned in the error message is not checked.
 *
 * NOTE(review): the slices are passed through contiguous(). This is
 * presumably a no-op for a flat workspace, so the results alias the
 * workspace storage; for other layouts it could yield copies. Confirm
 * that callers always pass a 1-D buffer.
 *
 * @param workspace Tensor holding at least two elements.
 * @return Pair of tensors {amax, scale}.
 */
inline std::pair<at::Tensor, at::Tensor> split_quantizer_workspace(const at::Tensor& workspace) {
  NVTE_CHECK(workspace.numel() >= 2, "Quantizer workspace must have at least 2 float32 elements");
  // Carve out the two leading entries along dim 0: [0, 1) is amax, [1, 2) is scale.
  const at::Tensor amax_view = workspace.slice(0, 0, 1);
  const at::Tensor scale_view = workspace.slice(0, 1, 2);
  return {amax_view.contiguous(), scale_view.contiguous()};
}

class Float8BlockQuantizer : public Quantizer {
public:
// Which float8 type is used for q data.
Expand Down
107 changes: 70 additions & 37 deletions transformer_engine/pytorch/csrc/extensions.h
Original file line number Diff line number Diff line change
Expand Up @@ -215,57 +215,80 @@ at::Tensor swap_first_dims(at::Tensor tensor, std::optional<at::Tensor> out = st
**************************************************************************************************/

/* GLU (sigmoid gate) */
py::object glu(const at::Tensor &input, py::handle quantizer);
py::object glu(const at::Tensor &input, py::handle quantizer,
std::optional<at::Tensor> qw = std::nullopt);

py::object dglu(const at::Tensor &grad, const at::Tensor &input, py::handle quantizer);
py::object dglu(const at::Tensor &grad, const at::Tensor &input, py::handle quantizer,
std::optional<at::Tensor> qw = std::nullopt);

/* GELU and variants*/
py::object gelu(const at::Tensor &input, py::handle quantizer);
py::object gelu(const at::Tensor &input, py::handle quantizer,
std::optional<at::Tensor> qw = std::nullopt);

py::object dgelu(const at::Tensor &grad, const at::Tensor &input, py::handle quantizer);
py::object dgelu(const at::Tensor &grad, const at::Tensor &input, py::handle quantizer,
std::optional<at::Tensor> qw = std::nullopt);

py::object geglu(const at::Tensor &input, py::handle quantizer);
py::object geglu(const at::Tensor &input, py::handle quantizer,
std::optional<at::Tensor> qw = std::nullopt);

py::object dgeglu(const at::Tensor &grad, const at::Tensor &input, py::handle quantizer);
py::object dgeglu(const at::Tensor &grad, const at::Tensor &input, py::handle quantizer,
std::optional<at::Tensor> qw = std::nullopt);

py::object qgelu(const at::Tensor &input, py::handle quantizer);
py::object qgelu(const at::Tensor &input, py::handle quantizer,
std::optional<at::Tensor> qw = std::nullopt);

py::object dqgelu(const at::Tensor &grad, const at::Tensor &input, py::handle quantizer);
py::object dqgelu(const at::Tensor &grad, const at::Tensor &input, py::handle quantizer,
std::optional<at::Tensor> qw = std::nullopt);

py::object qgeglu(const at::Tensor &input, py::handle quantizer);
py::object qgeglu(const at::Tensor &input, py::handle quantizer,
std::optional<at::Tensor> qw = std::nullopt);

py::object dqgeglu(const at::Tensor &grad, const at::Tensor &input, py::handle quantizer);
py::object dqgeglu(const at::Tensor &grad, const at::Tensor &input, py::handle quantizer,
std::optional<at::Tensor> qw = std::nullopt);

/* ReLU and variants*/
py::object relu(const at::Tensor &input, py::handle quantizer);
py::object relu(const at::Tensor &input, py::handle quantizer,
std::optional<at::Tensor> qw = std::nullopt);

py::object drelu(const at::Tensor &grad, const at::Tensor &input, py::handle quantizer);
py::object drelu(const at::Tensor &grad, const at::Tensor &input, py::handle quantizer,
std::optional<at::Tensor> qw = std::nullopt);

py::object reglu(const at::Tensor &input, py::handle quantizer);
py::object reglu(const at::Tensor &input, py::handle quantizer,
std::optional<at::Tensor> qw = std::nullopt);

py::object dreglu(const at::Tensor &grad, const at::Tensor &input, py::handle quantizer);
py::object dreglu(const at::Tensor &grad, const at::Tensor &input, py::handle quantizer,
std::optional<at::Tensor> qw = std::nullopt);

py::object srelu(const at::Tensor &input, py::handle quantizer);
py::object srelu(const at::Tensor &input, py::handle quantizer,
std::optional<at::Tensor> qw = std::nullopt);

py::object dsrelu(const at::Tensor &grad, const at::Tensor &input, py::handle quantizer);
py::object dsrelu(const at::Tensor &grad, const at::Tensor &input, py::handle quantizer,
std::optional<at::Tensor> qw = std::nullopt);

py::object sreglu(const at::Tensor &input, py::handle quantizer);
py::object sreglu(const at::Tensor &input, py::handle quantizer,
std::optional<at::Tensor> qw = std::nullopt);

py::object dsreglu(const at::Tensor &grad, const at::Tensor &input, py::handle quantizer);
py::object dsreglu(const at::Tensor &grad, const at::Tensor &input, py::handle quantizer,
std::optional<at::Tensor> qw = std::nullopt);

/* Silu and variants*/
py::object silu(const at::Tensor &input, py::handle quantizer);
py::object silu(const at::Tensor &input, py::handle quantizer,
std::optional<at::Tensor> qw = std::nullopt);

py::object dsilu(const at::Tensor &grad, const at::Tensor &input, py::handle quantizer);
py::object dsilu(const at::Tensor &grad, const at::Tensor &input, py::handle quantizer,
std::optional<at::Tensor> qw = std::nullopt);

py::object swiglu(const at::Tensor &input, py::handle quantizer);
py::object swiglu(const at::Tensor &input, py::handle quantizer,
std::optional<at::Tensor> qw = std::nullopt);

py::object dswiglu(const at::Tensor &grad, const at::Tensor &input, py::handle quantizer);
py::object dswiglu(const at::Tensor &grad, const at::Tensor &input, py::handle quantizer,
std::optional<at::Tensor> qw = std::nullopt);

py::object clamped_swiglu(const at::Tensor &input, py::handle quantizer, float limit, float alpha);
py::object clamped_swiglu(const at::Tensor &input, py::handle quantizer,
std::optional<at::Tensor> qw, float limit, float alpha);

py::object clamped_dswiglu(const at::Tensor &grad, const at::Tensor &input, py::handle quantizer,
float limit, float alpha);
std::optional<at::Tensor> qw, float limit, float alpha);
/***************************************************************************************************
* LayerNorm
**************************************************************************************************/
Expand All @@ -278,7 +301,8 @@ std::vector<py::object> layernorm_bwd(const at::Tensor &dz, const at::Tensor &x,
std::vector<py::object> layernorm_fwd(py::handle input, py::handle weight, MaybeTensor bias,
float eps, py::object ln_out, py::handle quantizer,
DType out_dtype, const int sm_margin,
const bool zero_centered_gamma);
const bool zero_centered_gamma,
std::optional<at::Tensor> quantizer_workspace = std::nullopt);

/***************************************************************************************************
* RMSNorm
Expand All @@ -295,14 +319,16 @@ std::vector<py::object> rmsnorm_bwd_add(const at::Tensor &dz, const at::Tensor &

std::vector<py::object> rmsnorm_fwd(const py::handle &input, const py::handle &weight, float eps,
py::object ln_out, py::handle quantizer, DType otype,
const int sm_margin, const bool zero_centered_gamma);
const int sm_margin, const bool zero_centered_gamma,
std::optional<at::Tensor> quantizer_workspace = std::nullopt);

/***************************************************************************************************
* Cast
**************************************************************************************************/

py::object quantize(const at::Tensor &tensor, py::handle quantizer, const py::object &output,
std::optional<at::Tensor> noop_flag);
std::optional<at::Tensor> noop_flag,
std::optional<at::Tensor> workspace = std::nullopt);

py::object dequantize(const py::handle &input, DType otype);

Expand All @@ -312,31 +338,38 @@ py::object group_quantize(const at::Tensor &tensor, py::handle quantizer, const
std::vector<py::object> multi_tensor_quantize(const std::vector<at::Tensor> &tensor_list,
std::vector<py::handle> quantizer_list);

std::vector<py::object> split_quantize(const at::Tensor &tensor,
const std::vector<size_t> &split_sections,
std::vector<py::handle> quantizer_list,
bool disable_bulk_allocation = false);
std::vector<py::object> split_quantize(
const at::Tensor &tensor, const std::vector<size_t> &split_sections,
std::vector<py::handle> quantizer_list, bool disable_bulk_allocation = false,
std::optional<std::vector<at::Tensor>> quantizer_workspaces = std::nullopt);

/***************************************************************************************************
* Bias gradient fusions
**************************************************************************************************/

std::vector<py::object> bgrad_quantize(const at::Tensor &input, py::handle py_quantizer);
std::vector<py::object> bgrad_quantize(
const at::Tensor &input, py::handle py_quantizer,
std::optional<at::Tensor> quantizer_workspace = std::nullopt);

std::vector<py::object> dbias_dgelu(const at::Tensor &grad_output, const at::Tensor &act_input,
py::handle quantizer);
py::handle quantizer,
std::optional<at::Tensor> qw = std::nullopt);

std::vector<py::object> dbias_dsilu(const at::Tensor &grad_output, const at::Tensor &act_input,
py::handle quantizer);
py::handle quantizer,
std::optional<at::Tensor> qw = std::nullopt);

std::vector<py::object> dbias_drelu(const at::Tensor &grad_output, const at::Tensor &act_input,
py::handle quantizer);
py::handle quantizer,
std::optional<at::Tensor> qw = std::nullopt);

std::vector<py::object> dbias_dqgelu(const at::Tensor &grad_output, const at::Tensor &act_input,
py::handle quantizer);
py::handle quantizer,
std::optional<at::Tensor> qw = std::nullopt);

std::vector<py::object> dbias_dsrelu(const at::Tensor &grad_output, const at::Tensor &act_input,
py::handle quantizer);
py::handle quantizer,
std::optional<at::Tensor> qw = std::nullopt);

/***************************************************************************************************
* Dropout
Expand Down
Loading