Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1,618 changes: 1,618 additions & 0 deletions tests/pytorch/test_hybrid_quantization.py

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
Expand Up @@ -161,13 +161,15 @@ __global__ void __launch_bounds__(THREADS_PER_BLOCK)
static_assert(std::is_same<CType, float>::value);
const CType scale_inv = 1.0f / block_tile_scale;

size_t row_idx = tile_id_y;
size_t col_idx = tile_id_x;
tile_scales_inv_c[row_idx * scale_stride_y + col_idx * scale_stride_x] = scale_inv;
if (tile_scales_inv_c != nullptr) {
size_t row_idx = tile_id_y;
size_t col_idx = tile_id_x;
tile_scales_inv_c[row_idx * scale_stride_y + col_idx * scale_stride_x] = scale_inv;
}

if constexpr (kReturnTranspose) {
row_idx = tile_id_x;
col_idx = tile_id_y;
size_t row_idx = tile_id_x;
size_t col_idx = tile_id_y;
tile_scales_inv_t[row_idx * scale_t_stride_y + col_idx * scale_t_stride_x] = scale_inv;
}
}
Expand All @@ -189,7 +191,9 @@ __global__ void __launch_bounds__(THREADS_PER_BLOCK)
thrd_tile_out_trans[j].data.elt[i] = scaled_elt;
}
}
tmp_output_c.store_to(output_c + thread_tile_start_idx + i * row_length);
if (output_c != nullptr) {
tmp_output_c.store_to(output_c + thread_tile_start_idx + i * row_length);
}
}

// Step 4: store transpose into shared memory
Expand Down Expand Up @@ -388,13 +392,15 @@ __global__ void __launch_bounds__(THREADS_PER_BLOCK) block_scaled_cast_transpose
static_assert(std::is_same<CType, float>::value);
const CType scale_inv = 1.0f / block_tile_scale;

size_t row_idx = tile_id_y;
size_t col_idx = tile_id_x;
tile_scales_inv_c[row_idx * scale_stride_y + col_idx * scale_stride_x] = scale_inv;
if (tile_scales_inv_c != nullptr) {
size_t row_idx = tile_id_y;
size_t col_idx = tile_id_x;
tile_scales_inv_c[row_idx * scale_stride_y + col_idx * scale_stride_x] = scale_inv;
}

if constexpr (kReturnTranspose) {
row_idx = tile_id_x;
col_idx = tile_id_y;
size_t row_idx = tile_id_x;
size_t col_idx = tile_id_y;
tile_scales_inv_t[row_idx * scale_t_stride_y + col_idx * scale_t_stride_x] = scale_inv;
}
}
Expand Down Expand Up @@ -433,8 +439,10 @@ __global__ void __launch_bounds__(THREADS_PER_BLOCK) block_scaled_cast_transpose
thrd_tile_out_trans[j].data.elt[i] = scaled_elt;
}
}
tmp_output_c.store_to_elts(output_c + thread_tile_start_idx + i * row_length, 0,
thread_tile_ncols);
if (output_c != nullptr) {
tmp_output_c.store_to_elts(output_c + thread_tile_start_idx + i * row_length, 0,
thread_tile_ncols);
}
}

if constexpr (kReturnTranspose) {
Expand Down Expand Up @@ -492,19 +500,26 @@ void quantize_transpose_square_blockwise(const SimpleTensor& input, SimpleTensor
"with MXFP8, which requires using power of two scaling factors.");
}

NVTE_CHECK(input.shape == output.shape, "Input and output must have the same shape.");
const bool return_identity = output.dptr != nullptr;
if (return_identity) {
NVTE_CHECK(input.shape == output.shape, "Input and output must have the same shape.");
}
NVTE_CHECK(return_identity || return_transpose,
"At least one of rowwise or columnwise output must be requested.");
const size_t row_length = input.shape.size() > 0 ? input.shape.back() : 1;
size_t num_rows = 1;
for (size_t i = 0; (i < input.shape.size() - 1) && (input.shape.size() > 0); ++i) {
num_rows *= input.shape.at(i);
}

NVTE_CHECK(scale_inv.shape.size() == 2, "scale_inv must have 2 dimensions.");

size_t scale_k = scale_inv.shape[1];

const size_t scale_stride_x = 1;
const size_t scale_stride_y = scale_k;
size_t scale_k = 0;
const size_t scale_stride_x = return_identity ? 1 : 0;
size_t scale_stride_y = 0;
if (return_identity) {
NVTE_CHECK(scale_inv.shape.size() == 2, "scale_inv must have 2 dimensions.");
scale_k = scale_inv.shape[1];
scale_stride_y = scale_k;
}

size_t scale_t_stride_x = 0;
size_t scale_t_stride_y = 0;
Expand All @@ -522,22 +537,26 @@ void quantize_transpose_square_blockwise(const SimpleTensor& input, SimpleTensor
") and output_t (shape=", output_t.shape, ") have incompatible dims.");
}
}
NVTE_CHECK(output.dtype == output_t.dtype, "output and output_t need to have the same type.");
if (return_identity) {
NVTE_CHECK(output.dtype == output_t.dtype, "output and output_t need to have the same type.");
}

NVTE_CHECK(scale_inv_t.shape.size() == 2, "scale_inv_t must have 2 dimensions.");

scale_t_stride_x = 1;
scale_t_stride_y = scale_inv_t.shape[1];
}

const auto out_dtype = return_identity ? output.dtype : output_t.dtype;

const size_t num_blocks_x = DIVUP(row_length, BLOCK_TILE_DIM);
const size_t num_blocks_y = DIVUP(num_rows, BLOCK_TILE_DIM);

TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(
input.dtype, InputType,

TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY(
output.dtype, OutputType,
out_dtype, OutputType,

TRANSFORMER_ENGINE_SWITCH_CONDITION(
return_transpose, kReturnTranspose,
Expand Down
3 changes: 3 additions & 0 deletions transformer_engine/pytorch/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,9 @@
from transformer_engine.pytorch.tensor import MXFP8Tensor
from transformer_engine.pytorch.tensor import Float8BlockwiseQTensor
from transformer_engine.pytorch.tensor import NVFP4Tensor
from transformer_engine.pytorch.tensor import HybridQuantizer
from transformer_engine.pytorch.tensor import HybridQuantizedTensorStorage
from transformer_engine.pytorch.tensor import HybridQuantizedTensor

try:
torch._dynamo.config.error_on_nested_jit_trace = False
Expand Down
37 changes: 37 additions & 0 deletions transformer_engine/pytorch/cpp_extensions/gemm.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from ..quantized_tensor import Quantizer
from ..tensor.storage.float8_blockwise_tensor_storage import Float8BlockwiseQTensorStorage
from ..tensor.utils import is_custom
from ..tensor.storage.hybrid_tensor_storage import HybridQuantizedTensorStorage
from ..custom_recipes.gemm import custom_gemm
from ...debug.pytorch.debug_quantization import DebugQuantizer

Expand Down Expand Up @@ -69,6 +70,36 @@ def validate_gemm_scale(scale: Optional[float], required: bool) -> float:
return 0.0


def _unwrap_hybrid_A(tensor, layout):
    """Pick the native sub-storage of a hybrid tensor for GEMM operand A.

    The transpose flag for operand A is ``layout[0]``:
      * ``"T"`` (transposed)      → rowwise sub-storage (``.data`` consumed by C++)
      * ``"N"`` (not transposed)  → columnwise sub-storage (``.columnwise_data`` consumed by C++)
    Anything that is not a HybridQuantizedTensorStorage is returned untouched.
    """
    if not isinstance(tensor, HybridQuantizedTensorStorage):
        return tensor
    return tensor.rowwise_sub_storage if layout[0] == "T" else tensor.columnwise_sub_storage


def _unwrap_hybrid_B(tensor, layout):
    """Pick the native sub-storage of a hybrid tensor for GEMM operand B.

    The transpose flag for operand B is ``layout[1]``:
      * ``"N"`` (not transposed)  → rowwise sub-storage (``.data`` consumed by C++)
      * ``"T"`` (transposed)      → columnwise sub-storage (``.columnwise_data`` consumed by C++)
    Anything that is not a HybridQuantizedTensorStorage is returned untouched.
    """
    if not isinstance(tensor, HybridQuantizedTensorStorage):
        return tensor
    return tensor.rowwise_sub_storage if layout[1] == "N" else tensor.columnwise_sub_storage


def general_gemm(
A: torch.Tensor,
B: torch.Tensor,
Expand All @@ -95,6 +126,9 @@ def general_gemm(
transa = layout[0] == "T"
transb = layout[1] == "T"

A = _unwrap_hybrid_A(A, layout)
B = _unwrap_hybrid_B(B, layout)

alpha = validate_gemm_scale(alpha, True)
beta = validate_gemm_scale(beta, accumulate)
workspace = get_cublas_workspace(A.device.index, ub is not None, False)
Expand Down Expand Up @@ -204,6 +238,9 @@ def general_grouped_gemm(
"""
num_gemms = len(A)

A = [_unwrap_hybrid_A(a, layout) for a in A]
B = [_unwrap_hybrid_B(b, layout) for b in B]

transa = layout[0] == "T"
transb = layout[1] == "T"

Expand Down
6 changes: 4 additions & 2 deletions transformer_engine/pytorch/module/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
from ..tensor.float8_tensor import Float8Quantizer, Float8CurrentScalingQuantizer
from ..tensor.mxfp8_tensor import MXFP8Quantizer
from ..tensor.float8_blockwise_tensor import Float8BlockQuantizer
from ..tensor.hybrid_tensor import HybridQuantizer
from ..tensor.storage.float8_tensor_storage import Float8TensorStorage
from ..tensor.storage.mxfp8_tensor_storage import MXFP8TensorStorage
from ..tensor.storage.nvfp4_tensor_storage import NVFP4TensorStorage
Expand Down Expand Up @@ -1258,8 +1259,9 @@ def grad_output_preprocess(
):
grad_bias = grad_output.dequantize().view(-1, grad_output.shape[-1]).sum(dim=0)
else:
if isinstance(quantizer, Float8BlockQuantizer):
# unfuse bgrad for now until cast_transpose + dgrad calculation is ready for Float8BlockQuantizer.
if isinstance(quantizer, (Float8BlockQuantizer, HybridQuantizer)):
# Float8BlockQuantizer: unfused until cast_transpose + dgrad is ready.
# HybridQuantizer: tex.bgrad_quantize doesn't recognize hybrid quantizers.
grad_bias = grad_output.view(-1, grad_output.shape[-1]).sum(dim=0)
else:
grad_bias, grad_output = tex.bgrad_quantize(grad_output, quantizer)
Expand Down
66 changes: 63 additions & 3 deletions transformer_engine/pytorch/module/grouped_linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,9 +51,50 @@
prepare_for_saving,
restore_from_func_ctx,
)
from ..tensor.hybrid_tensor import HybridQuantizer
from ...debug.pytorch.debug_quantization import DebugQuantizer
from ...debug.pytorch.debug_state import TEDebugState


def _has_hybrid_quantizer(quantizers):
    """Return True when at least one non-None entry is a HybridQuantizer."""
    for quantizer in quantizers:
        if quantizer is not None and isinstance(quantizer, HybridQuantizer):
            return True
    return False


def _hybrid_split_quantize(tensor, m_splits, quantizers):
    """Grouped split+quantize for quantizer lists containing HybridQuantizers.

    Fast path (every entry is a HybridQuantizer): runs tex.split_quantize
    twice — once per direction with the native sub-quantizers — then zips
    the results into HybridQuantizedTensorStorage objects.

    Mixed lists fall back to per-split Python quantization, because the
    fused path reads the hybrid-only ``rowwise_quantizer`` /
    ``columnwise_quantizer`` attributes and would raise AttributeError on a
    non-hybrid entry.

    Args:
        tensor: 2D input tensor, split along dim 0.
        m_splits: split sizes along dim 0, one per quantizer.
        quantizers: one quantizer per split; entries may mix hybrid and
            non-hybrid quantizers.

    Returns:
        List with one quantized tensor/storage per split.
    """
    from ..tensor.storage.hybrid_tensor_storage import HybridQuantizedTensorStorage as HybridStorage

    if all(isinstance(q, HybridQuantizer) for q in quantizers):
        # Fast path: one fused split_quantize call per data direction.
        row_quantizers = [q.rowwise_quantizer for q in quantizers]
        col_quantizers = [q.columnwise_quantizer for q in quantizers]

        row_results = tex.split_quantize(tensor, m_splits, row_quantizers)
        col_results = tex.split_quantize(tensor, m_splits, col_quantizers)

        return [
            HybridStorage(
                rowwise_storage=row,
                columnwise_storage=col,
                rowwise_quantizer=rq,
                columnwise_quantizer=cq,
                quantizer=q,
                fake_dtype=tensor.dtype,
            )
            for row, col, rq, cq, q in zip(
                row_results, col_results, row_quantizers, col_quantizers, quantizers
            )
        ]

    # Mixed quantizer list: quantize each split individually in Python.
    # Slower than the fused path, but correct for any combination.
    splits = torch.split(tensor, m_splits, dim=0)
    out = []
    for split, quantizer in zip(splits, quantizers):
        if quantizer is None:
            # NOTE(review): assumes a missing quantizer means this split is
            # passed through in high precision — confirm against call sites.
            out.append(split)
        else:
            # Quantizer.__call__ quantizes a single tensor; hybrid quantizers
            # produce hybrid tensors the GEMM unwrappers can consume.
            out.append(quantizer(split))
    return out


__all__ = ["GroupedLinear"]


Expand Down Expand Up @@ -144,7 +185,8 @@ def forward(
)
inp_view = inp.reshape(-1, in_features)
inputmats: list
if fp8 and not debug:
hybrid = _has_hybrid_quantizer(input_quantizers)
if fp8 and not debug and not hybrid:
# Disable bulk allocation when CPU offloading is active: offloading skips small
# tensors (like scales), but bulk allocation shares storage across all tensors,
# so if scales can't be offloaded, nothing in the group can be offloaded.
Expand All @@ -154,6 +196,8 @@ def forward(
input_quantizers,
disable_bulk_allocation=cpu_offloading,
)
elif fp8 and hybrid:
inputmats = _hybrid_split_quantize(inp_view, m_splits, input_quantizers)
elif debug:
inputmats = DebugQuantizer.multi_tensor_quantize(
inp_view, input_quantizers, m_splits, activation_dtype
Expand Down Expand Up @@ -338,7 +382,8 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None],
grad_output_view = grad_output.contiguous().view(-1, grad_output.shape[-1])
grad_output = [None] * ctx.num_gemms
grad_biases = [None] * ctx.num_gemms
if ctx.fp8 and not ctx.debug:
grad_output_hybrid = _has_hybrid_quantizer(ctx.grad_output_quantizers)
if ctx.fp8 and not ctx.debug and not grad_output_hybrid:
if ctx.use_bias:
grad_output_mats = torch.split(grad_output_view, ctx.m_splits)
recipe = ctx.fp8_recipe
Expand All @@ -365,6 +410,16 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None],
ctx.m_splits,
ctx.grad_output_quantizers,
)
elif ctx.fp8 and grad_output_hybrid:
if ctx.use_bias:
grad_output_mats = torch.split(grad_output_view, ctx.m_splits)
for i in range(ctx.num_gemms):
grad_biases[i] = grad_output_mats[i].sum(dim=0)
grad_output = _hybrid_split_quantize(
grad_output_view,
ctx.m_splits,
ctx.grad_output_quantizers,
)
elif ctx.debug:
grad_output_mats = torch.split(grad_output_view, ctx.m_splits)
for i in range(ctx.num_gemms):
Expand Down Expand Up @@ -451,8 +506,13 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None],
else:
input_quantizer.set_usage(rowwise=False, columnwise=True)
inputmats: list
if ctx.fp8 and not ctx.debug:
input_hybrid = _has_hybrid_quantizer(ctx.input_quantizers)
if ctx.fp8 and not ctx.debug and not input_hybrid:
inputmats = tex.split_quantize(inp_view, ctx.m_splits, ctx.input_quantizers)
elif ctx.fp8 and input_hybrid:
inputmats = _hybrid_split_quantize(
inp_view, ctx.m_splits, ctx.input_quantizers
)
elif ctx.debug:
inputmats = DebugQuantizer.multi_tensor_quantize(
inp_view,
Expand Down
3 changes: 3 additions & 0 deletions transformer_engine/pytorch/module/layernorm_linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@
)
from ...debug.pytorch.debug_state import TEDebugState
from ..tensor.mxfp8_tensor import MXFP8Quantizer
from ..tensor.hybrid_tensor import HybridQuantizer
from ..cpu_offload import (
is_cpu_offload_enabled,
start_offload,
Expand Down Expand Up @@ -206,12 +207,14 @@ def forward(
# Avoid quantized norm kernel if norm output will be returned
# or if a gather of ln_out must be in high precision.
custom = is_custom(input_quantizer)
hybrid = isinstance(input_quantizer, HybridQuantizer)
with_quantized_norm = (
fp8
and not debug
and not return_layernorm_output
and not return_layernorm_output_gathered
and not custom # TODO(negvet): and not FP8GlobalStateManager.get_fp8_recipe().custom()
and not hybrid
)

# Apply normalization
Expand Down
3 changes: 3 additions & 0 deletions transformer_engine/pytorch/module/layernorm_mlp.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@
from ..tensor.mxfp8_tensor import MXFP8Quantizer
from ..tensor.nvfp4_tensor import NVFP4Quantizer
from ..tensor.float8_blockwise_tensor import Float8BlockQuantizer
from ..tensor.hybrid_tensor import HybridQuantizer
from ._common import apply_normalization, WeightGradStore
from ..cpu_offload import (
is_cpu_offload_enabled,
Expand Down Expand Up @@ -390,12 +391,14 @@ def _forward(
# for debug: : layernorm output = High precision to enable processing of this norm

custom = is_custom(fc1_input_quantizer)
hybrid = isinstance(fc1_input_quantizer, HybridQuantizer)
with_quantized_norm = (
fp8
and not debug
and not return_layernorm_output
and not return_layernorm_output_gathered
and not custom
and not hybrid
)

# Apply normalization
Expand Down
Loading
Loading