From 2f1203e5874efac764a3e92b732c83a6abf896e9 Mon Sep 17 00:00:00 2001 From: prefill-dev2 Date: Tue, 9 Jun 2026 22:03:05 -0700 Subject: [PATCH 1/4] [cuda][prefill] window-aware SDPA: skip fully-masked KV blocks (idea #1) Block-sparse early-exit in _sdpa_fwd_kernel_body: skip KV blocks that are entirely masked (sliding-window via HAS_MASK sum==0, causal via start_n>max_seq_pos). Exact (skipped blocks are x1,+0 no-ops). Prefill +46-88% all lengths; decode safe; SDPA nsys 58.1%->18.5%. Numerically bf16-exact vs dense+mask (unit test). --- backends/cuda/triton/kernels/sdpa.py | 102 ++++++++++++++++----------- 1 file changed, 62 insertions(+), 40 deletions(-) diff --git a/backends/cuda/triton/kernels/sdpa.py b/backends/cuda/triton/kernels/sdpa.py index 9f42a474b36..fb665e538bf 100644 --- a/backends/cuda/triton/kernels/sdpa.py +++ b/backends/cuda/triton/kernels/sdpa.py @@ -422,21 +422,22 @@ def _sdpa_fwd_kernel_body( offs_n_init = tl.arange(0, BLOCK_N) + # Window-aware early-exit. A KV block that is fully masked (sliding-window + # or causal) contributes nothing to the online softmax — every entry is + # -inf, so p=0 and m_i/l_i/acc are left unchanged. We detect such blocks up + # front and skip their K/V loads and both matmuls. This is exact: it only + # skips work the mask would have zeroed out anyway. At seq=2048 the 50 + # sliding-window(1024) layers and the 10 causal layers each leave roughly + # half (or more) of their KV blocks fully masked, so this is a large cut to + # the dominant prefill cost. The skip condition is a CTA-wide reduction, so + # the branch is uniform and turns into a real skip (not predication). + if IS_CAUSAL: + max_seq_pos = tl.max(seq_pos) + for start_n in tl.range(0, Lk, BLOCK_N): offs_n = start_n + offs_n_init - # K load: uniform (single KV head, shared across all Q heads in tile) - k_ptrs = K_ptr + ( - b * stride_kb - + h_kv * stride_kh - + (offs_n[:, None] * stride_kn) - + (offs_d[None, :] * stride_kd) - ) - k_mask = (offs_n[:, None] < Lk) & (offs_d[None, :] < HEAD_DIM) - k = tl.load(k_ptrs, mask=k_mask, other=0.0).to(tl.bfloat16) - - qk = (tl.dot(q, tl.trans(k)).to(tl.float32) * sm_scale).to(tl.float32) - + # Decide whether any row in this tile actually attends to this KV block. if HAS_MASK: mask_ptrs = Mask_ptr + ( b * stride_mb @@ -445,39 +446,60 @@ def _sdpa_fwd_kernel_body( ) mn_mask = row_valid[:, None] & (offs_n[None, :] < Lk) mask_block = tl.load(mask_ptrs, mask=mn_mask, other=False) - qk = tl.where( - mask_block, qk, tl.full(qk.shape, -float("inf"), dtype=tl.float32) + block_active = tl.sum(mask_block.to(tl.int32)) > 0 + elif IS_CAUSAL: + # Block is entirely in the future for every row -> skip. + block_active = start_n <= max_seq_pos + else: + block_active = True + + if block_active: + # K load: uniform (single KV head, shared across Q heads in tile) + k_ptrs = K_ptr + ( + b * stride_kb + + h_kv * stride_kh + + (offs_n[:, None] * stride_kn) + + (offs_d[None, :] * stride_kd) ) + k_mask = (offs_n[:, None] < Lk) & (offs_d[None, :] < HEAD_DIM) + k = tl.load(k_ptrs, mask=k_mask, other=0.0).to(tl.bfloat16) - if IS_CAUSAL: - causal = offs_n[None, :] > seq_pos[:, None] - qk = tl.where( - causal, tl.full(qk.shape, -float("inf"), dtype=tl.float32), qk - ) + qk = (tl.dot(q, tl.trans(k)).to(tl.float32) * sm_scale).to(tl.float32) - m_ij = tl.maximum(m_i, tl.max(qk, axis=1).to(tl.float32)) - safe_diff = tl.where( - m_ij[:, None] > -float("inf"), qk - m_ij[:, None], -float("inf") - ) - p_f32 = tl.exp(safe_diff).to(tl.float32) - l_ij = tl.sum(p_f32, axis=1).to(tl.float32) - safe_alpha_diff = tl.where(m_ij > -float("inf"), m_i - m_ij, 0.0) - alpha = tl.exp(safe_alpha_diff).to(tl.float32) + if HAS_MASK: + qk = tl.where( + mask_block, qk, tl.full(qk.shape, -float("inf"), dtype=tl.float32) + ) - # V load: uniform (single KV head) - v_ptrs = V_ptr + ( - b * stride_vb - + h_kv * stride_vh - + (offs_n[:, None] * stride_vn) - + (offs_d[None, :] * stride_vd) - ) - v_mask = (offs_n[:, None] < Lk) & (offs_d[None, :] < HEAD_DIM) - v = tl.load(v_ptrs, mask=v_mask, other=0.0).to(tl.bfloat16) + if IS_CAUSAL: + causal = offs_n[None, :] > seq_pos[:, None] + qk = tl.where( + causal, tl.full(qk.shape, -float("inf"), dtype=tl.float32), qk + ) - p_bf16 = p_f32.to(tl.bfloat16) - acc = (acc * alpha[:, None] + tl.dot(p_bf16, v)).to(tl.float32) - l_i = (l_i * alpha + l_ij).to(tl.float32) - m_i = m_ij + m_ij = tl.maximum(m_i, tl.max(qk, axis=1).to(tl.float32)) + safe_diff = tl.where( + m_ij[:, None] > -float("inf"), qk - m_ij[:, None], -float("inf") + ) + p_f32 = tl.exp(safe_diff).to(tl.float32) + l_ij = tl.sum(p_f32, axis=1).to(tl.float32) + safe_alpha_diff = tl.where(m_ij > -float("inf"), m_i - m_ij, 0.0) + alpha = tl.exp(safe_alpha_diff).to(tl.float32) + + # V load: uniform (single KV head) + v_ptrs = V_ptr + ( + b * stride_vb + + h_kv * stride_vh + + (offs_n[:, None] * stride_vn) + + (offs_d[None, :] * stride_vd) + ) + v_mask = (offs_n[:, None] < Lk) & (offs_d[None, :] < HEAD_DIM) + v = tl.load(v_ptrs, mask=v_mask, other=0.0).to(tl.bfloat16) + + p_bf16 = p_f32.to(tl.bfloat16) + acc = (acc * alpha[:, None] + tl.dot(p_bf16, v)).to(tl.float32) + l_i = (l_i * alpha + l_ij).to(tl.float32) + m_i = m_ij inv_l_i = tl.where(l_i > 0, 1.0 / l_i, 0.0) acc = acc * inv_l_i[:, None] From 390238ee00900d0a68b9508938f3cb6c7d873b8b Mon Sep 17 00:00:00 2001 From: gasoonjia Date: Thu, 11 Jun 2026 21:30:16 -0700 Subject: [PATCH 2/4] [cuda] GGUF Q6_K real packed INT6 (W6A8 dp4a) + GGUF CI export Add a genuine 6-bit packed weight path for GGUF Q6_K on the CUDA backend, parallel to the int4/int8 plain_mm paths: - int6_plain_mm CUDA shim (W6A8 dp4a; ql/qh planes; spread2; -32 symmetric offset) - CudaPackedInt6Tensor (ql/qh + per-group bf16 scale; symmetric, no zero tensor) - int6_dispatch: F.linear routing (M<=4 -> executorch_cuda::int6_plain_mm op, M>4 -> dequant) - backend fallback-kernel + custom_ops_to_c_shims registration; CMake build - GGUF Q6_K: gguf_loader returns the native torchao IntxUnpackedToInt8Tensor and the backend packer (pack_cuda.pack_linear_for_cuda) repacks a symmetric Q6_K weight into CudaPackedInt6Tensor -- mirroring Int4Tensor -> CudaCoalescedInt4Tensor, so the loader stays backend-agnostic; dequantize_weight handles the tied embedding - tests: int6 gtest, test_int6_dispatch.py, pack round-trip; fix stale int4/int6 type asserts CI (export_model_artifact.sh, gemma4_31b): download the Q4_K_M GGUF from unsloth/gemma-4-31B-it-GGUF (tokenizer from unsloth/gemma-4-31B-it) and run the inference sanity check + export via the GGUF loader (--gguf) instead of the prequantized HF checkpoint. Signed-off-by: gasoonjia --- .ci/scripts/export_model_artifact.sh | 20 +- backends/cuda/CMakeLists.txt | 1 + backends/cuda/cuda_backend.py | 7 + backends/cuda/packed_int6_tensor.py | 209 +++++++++++ .../cuda/quantize_op_dispatch/__init__.py | 5 +- .../cuda/quantize_op_dispatch/_library.py | 7 +- .../quantize_op_dispatch/int6_dispatch.py | 116 ++++++ backends/cuda/runtime/shims/int6_plain_mm.cu | 81 ++++ backends/cuda/runtime/shims/int6_plain_mm.cuh | 353 ++++++++++++++++++ backends/cuda/runtime/shims/int6_plain_mm.h | 61 +++ .../cuda/runtime/shims/tests/CMakeLists.txt | 5 +- .../test_aoti_torch_cuda_int6_plain_mm.cpp | 306 +++++++++++++++ backends/cuda/tests/test_int6_dispatch.py | 226 +++++++++++ examples/models/gemma4_31b/gguf_loader.py | 12 +- examples/models/gemma4_31b/quant/pack_cuda.py | 39 +- examples/models/gemma4_31b/quant/quantize.py | 7 + .../gemma4_31b/quant/tests/test_pack_cuda.py | 104 ++++++ .../gemma4_31b/tests/test_cuda_pipeline.py | 11 +- 18 files changed, 1545 insertions(+), 25 deletions(-) create mode 100644 backends/cuda/packed_int6_tensor.py create mode 100644 backends/cuda/quantize_op_dispatch/int6_dispatch.py create mode 100644 backends/cuda/runtime/shims/int6_plain_mm.cu create mode 100644 backends/cuda/runtime/shims/int6_plain_mm.cuh create mode 100644 backends/cuda/runtime/shims/int6_plain_mm.h create mode 100644 backends/cuda/runtime/shims/tests/test_aoti_torch_cuda_int6_plain_mm.cpp create mode 100644 backends/cuda/tests/test_int6_dispatch.py diff --git a/.ci/scripts/export_model_artifact.sh b/.ci/scripts/export_model_artifact.sh index db447bb907f..e9218dce625 100755 --- a/.ci/scripts/export_model_artifact.sh +++ b/.ci/scripts/export_model_artifact.sh @@ -467,21 +467,27 @@ if [ "$MODEL_NAME" = "qwen3_5_moe" ]; then exit 0 fi -# Gemma 4 31B uses a prequantized checkpoint and custom export script +# Gemma 4 31B: download the Q4_K_M GGUF and export via the GGUF loader if [ "$MODEL_NAME" = "gemma4_31b" ]; then pip install safetensors huggingface_hub gguf - # Download prequantized model outside OUTPUT_DIR to avoid uploading on failure + # Download GGUF + tokenizer outside OUTPUT_DIR to avoid uploading on failure. + # The unsloth GGUF repo ships the .gguf but no tokenizer.json, so the tokenizer + # is fetched from the (non-GGUF) unsloth/gemma-4-31B-it repo. LOCAL_MODEL_DIR=$(mktemp -d) INDUCTOR_CACHE=$(mktemp -d) trap 'rm -rf "$LOCAL_MODEL_DIR" "$INDUCTOR_CACHE"' EXIT - python -c "from huggingface_hub import snapshot_download; snapshot_download('${HF_MODEL}', local_dir='${LOCAL_MODEL_DIR}')" + GGUF_FILE="gemma-4-31B-it-Q4_K_M.gguf" + python -c "from huggingface_hub import hf_hub_download; hf_hub_download('unsloth/gemma-4-31B-it-GGUF', '${GGUF_FILE}', local_dir='${LOCAL_MODEL_DIR}')" + python -c "from huggingface_hub import hf_hub_download; hf_hub_download('unsloth/gemma-4-31B-it', 'tokenizer.json', local_dir='${LOCAL_MODEL_DIR}')" + GGUF_PATH="${LOCAL_MODEL_DIR}/${GGUF_FILE}" - # Sanity check: run inference on the prequantized model + # Sanity check: run inference on the GGUF model echo "::group::Inference sanity check" INFERENCE_OUTPUT=$(python -m executorch.examples.models.gemma4_31b.inference \ - --prequantized "$LOCAL_MODEL_DIR" \ + --gguf "$GGUF_PATH" \ + --tokenizer-path "${LOCAL_MODEL_DIR}/tokenizer.json" \ --prompt "What is the capital of France?" \ --max-new-tokens 32 \ --temperature 0 \ @@ -494,13 +500,13 @@ if [ "$MODEL_NAME" = "gemma4_31b" ]; then echo "::endgroup::" # Copy tokenizer for the runner - cp "$LOCAL_MODEL_DIR/tokenizer.json" "${OUTPUT_DIR}/tokenizer.json" + cp "${LOCAL_MODEL_DIR}/tokenizer.json" "${OUTPUT_DIR}/tokenizer.json" # Export to .pte/.ptd (short cache dir avoids objcopy symbol length issues) echo "::group::Export" TORCHINDUCTOR_CACHE_DIR="$INDUCTOR_CACHE" \ python -m executorch.examples.models.gemma4_31b.export \ - --prequantized "$LOCAL_MODEL_DIR" \ + --gguf "$GGUF_PATH" \ --output-dir "${OUTPUT_DIR}" echo "::endgroup::" diff --git a/backends/cuda/CMakeLists.txt b/backends/cuda/CMakeLists.txt index 0ce48d85e92..51b459f02fa 100644 --- a/backends/cuda/CMakeLists.txt +++ b/backends/cuda/CMakeLists.txt @@ -114,6 +114,7 @@ if(CMAKE_CUDA_COMPILER) _aoti_cuda_shim_sources runtime/shims/int4mm.cu runtime/shims/int4_plain_mm.cu + runtime/shims/int6_plain_mm.cu runtime/shims/int8_plain_mm.cu runtime/shims/sort.cu runtime/shims/rand.cu diff --git a/backends/cuda/cuda_backend.py b/backends/cuda/cuda_backend.py index c07cc29b102..f9f23a842f9 100644 --- a/backends/cuda/cuda_backend.py +++ b/backends/cuda/cuda_backend.py @@ -231,6 +231,8 @@ def get_supported_fallback_kernels(cls) -> Dict[str, Any]: "aoti_torch_cuda_randint_low_out": None, "executorch_cuda::int4_plain_mm": None, "aoti_torch_cuda_int4_plain_mm": None, + "executorch_cuda::int6_plain_mm": None, + "aoti_torch_cuda_int6_plain_mm": None, "executorch_cuda::int8_plain_mm": None, "aoti_torch_cuda_int8_plain_mm": None, } @@ -314,6 +316,11 @@ def get_aoti_compile_options( "AtenTensorHandle, AtenTensorHandle, AtenTensorHandle, " "AtenTensorHandle, int64_t, AtenTensorHandle*)" ], + torch.ops.executorch_cuda.int6_plain_mm.default: [ + "AOTITorchError aoti_torch_cuda_int6_plain_mm(" + "AtenTensorHandle, AtenTensorHandle, AtenTensorHandle, " + "AtenTensorHandle, int64_t, AtenTensorHandle*)" + ], torch.ops.executorch_cuda.int8_plain_mm.default: [ "AOTITorchError aoti_torch_cuda_int8_plain_mm(" "AtenTensorHandle, AtenTensorHandle, AtenTensorHandle, " diff --git a/backends/cuda/packed_int6_tensor.py b/backends/cuda/packed_int6_tensor.py new file mode 100644 index 00000000000..104ed5bbfa0 --- /dev/null +++ b/backends/cuda/packed_int6_tensor.py @@ -0,0 +1,209 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +"""ExecuTorch-internal packed-INT6 tensor for the CUDA W6A8 dp4a decode kernel. + +``CudaPackedInt6Tensor`` is an ExecuTorch-internal tensor subclass that stores a +genuine 6-bit packed weight (0.75 B/elem), used for GGUF Q6_K weights. Unlike +the int8 path (``IntxUnpackedToInt8Tensor``, one int8 per 6-bit value), this +format wastes no bits and carries no zero tensor — Q6_K is symmetric. + +The stored value is ``u = q + 32`` in ``[0, 63]`` (``q`` in ``[-32, 31]``); the +constant ``-32`` offset is applied in the decode kernel. The 6 bits are split +into two planes that mirror the INT4 nibble layout so the kernel can reuse the +INT4 even/odd extraction verbatim: + + ql : (N, K/2) uint8 — low-nibble plane, nibble-packed even/odd + (``ql[:, j] = lo[:, 2j] | (lo[:, 2j+1] << 4)``, ``lo = u & 0xF``). + qh : (N, K/4) uint8 — high-2-bit plane, 4 values/byte, arranged per + 32-weight chunk as ``hi_even_packed[4]`` then ``hi_odd_packed[4]``; + each byte holds the four 2-bit highs (``hi = (u >> 4) & 0x3``) of one + 8-weight dp4a word, bit field ``j`` (bits ``2j..2j+1``) = the high 2 + bits of that word's ``j``-th even/odd weight. + scale : (N, K/gs) bf16 — per-group scales, row-major (already coalesced; the + decode kernel reads it row-for-row, no transpose). + +The pack/unpack helpers (:func:`pack_int6`, :func:`unpack_int6`) must stay in +lockstep with ``int6_plain_mm.cuh`` (the decode kernel) — the per-32-weight +``hi_even``/``hi_odd`` byte order is the single most error-prone detail and is +covered by the pack round-trip and the C++ gtest. +""" + +from typing import List, Optional, Tuple + +import torch +from torchao.utils import TorchAOBaseTensor + +__all__ = [ + "CudaPackedInt6Tensor", + "pack_int6", + "unpack_int6", +] + + +def pack_int6(q: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + """Pack symmetric Q6_K int values into the (ql, qh) planes. + + Args: + q: (N, K) integer tensor with values in ``[-32, 31]``. + + Returns: + ``(ql, qh)`` where ``ql`` is ``(N, K/2)`` uint8 and ``qh`` is + ``(N, K/4)`` uint8 (see the module docstring for the layout). + """ + if q.dim() != 2: + raise ValueError(f"pack_int6 expects a 2-D tensor, got shape {tuple(q.shape)}") + N, K = int(q.shape[0]), int(q.shape[1]) + if K % 32 != 0: + raise ValueError(f"K={K} must be a multiple of 32 for INT6 packing") + + # All intermediates are uint8 (values fit in a byte) to keep peak memory low + # — important for the ~1.4B-element tied token embedding. + u = (q.to(torch.int16) + 32).to(torch.uint8) # [0, 63] + lo = u & 0xF # low nibble (uint8) + hi = (u >> 4) & 0x3 # high 2 bits (uint8) + + # ql: nibble-pack the low plane even/odd, exactly like the INT4 path. + ql = lo[:, 0::2] | (lo[:, 1::2] << 4) # (N, K/2) uint8 + + # qh: per 32-weight chunk -> [hi_even_packed[4], hi_odd_packed[4]]; each byte + # packs the four 2-bit highs of one 8-weight dp4a word, field j at bits 2j. + chunks = K // 32 + hw = hi.reshape(N, chunks, 4, 8) # (N, chunk, word, pos-in-word) + even = hw[..., 0::2] # (N, chunk, 4, 4) positions 0,2,4,6 + odd = hw[..., 1::2] # (N, chunk, 4, 4) positions 1,3,5,7 + # Explicit OR (not sum) keeps the result uint8 (torch.sum would promote). + hi_even_byte = ( + even[..., 0] | (even[..., 1] << 2) | (even[..., 2] << 4) | (even[..., 3] << 6) + ) # (N, chunk, 4) uint8 + hi_odd_byte = ( + odd[..., 0] | (odd[..., 1] << 2) | (odd[..., 2] << 4) | (odd[..., 3] << 6) + ) + qh = torch.cat([hi_even_byte, hi_odd_byte], dim=-1) # (N, chunk, 8) uint8 + qh = qh.reshape(N, K // 4) + return ql.contiguous(), qh.contiguous() + + +def unpack_int6(ql: torch.Tensor, qh: torch.Tensor, N: int, K: int) -> torch.Tensor: + """Inverse of :func:`pack_int6`. Returns ``(N, K)`` int16 q in ``[-32, 31]``. + + Intermediates are uint8 to keep peak memory low; only the final ``- 32`` shift + (which produces negatives) widens to int16. + """ + qlu = ql.to(torch.uint8) + lo_even = qlu & 0xF # low nibble -> even weights + lo_odd = (qlu >> 4) & 0xF # high nibble -> odd weights + lo = torch.stack([lo_even, lo_odd], dim=-1).reshape(N, K) # uint8 + + chunks = K // 32 + qhu = qh.to(torch.uint8).reshape(N, chunks, 8) + hi_even_byte = qhu[:, :, 0:4] # (N, chunk, 4) word w + hi_odd_byte = qhu[:, :, 4:8] # (N, chunk, 4) + hi_even = torch.stack( + [(hi_even_byte >> s) & 0x3 for s in (0, 2, 4, 6)], dim=-1 + ) # (N, chunk, 4, 4) uint8 + hi_odd = torch.stack( + [(hi_odd_byte >> s) & 0x3 for s in (0, 2, 4, 6)], dim=-1 + ) + hi = torch.empty(N, chunks, 4, 8, dtype=torch.uint8, device=ql.device) + hi[..., 0::2] = hi_even + hi[..., 1::2] = hi_odd + hi = hi.reshape(N, K) + + u = lo | (hi << 4) # [0, 63] uint8 + return u.to(torch.int16) - 32 + + +class CudaPackedInt6Tensor(TorchAOBaseTensor): + """Packed 6-bit weight (ql/qh planes + per-group scale), symmetric. + + ExecuTorch-internal; see the module docstring. The CUDA decode/prefill + dispatch (``int6_dispatch.py``) is selected by *type* — it is registered on + this class only. + """ + + tensor_data_names = ["ql", "qh", "scale"] + tensor_attribute_names = ["block_size", "shape"] + + def __new__( + cls, + ql: torch.Tensor, + qh: torch.Tensor, + scale: torch.Tensor, + block_size: List[int], + shape: torch.Size, + ): + kwargs = {} + kwargs["device"] = ql.device + kwargs["dtype"] = scale.dtype + kwargs["requires_grad"] = False + return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs) # type: ignore[attr-defined] + + def __init__( + self, + ql: torch.Tensor, + qh: torch.Tensor, + scale: torch.Tensor, + block_size: List[int], + shape: torch.Size, + ): + super().__init__() + self.ql = ql + self.qh = qh + self.scale = scale + self.block_size = block_size + + def _quantization_type(self): + return ( + f"shape={self.shape}, block_size={self.block_size}, " + f"device={self.device}" + ) + + @classmethod + def from_intx_int8(cls, t: torch.Tensor) -> "CudaPackedInt6Tensor": + """Build from a torchao ``IntxUnpackedToInt8Tensor`` decoded from Q6_K. + + The source is symmetric (zero_point == 0), ``qdata`` is int8 in + ``[-32, 31]`` and ``scale`` is ``(N, K/16)``. The ql/qh bit-pack is baked + into the serialized weight constant here, once at pack time. + """ + q = t.qdata + if not bool(torch.all(t.zero_point == 0)): + raise ValueError( + "CudaPackedInt6Tensor.from_intx_int8 requires symmetric Q6_K " + "weights (zero_point == 0)" + ) + q_min, q_max = int(q.min()), int(q.max()) + if q_min < -32 or q_max > 31: + raise ValueError( + f"Q6_K values must be in [-32, 31], got [{q_min}, {q_max}]" + ) + ql, qh = pack_int6(q) + return cls( + ql, + qh, + t.scale.contiguous(), + list(t.block_size), + t.shape, + ) + + def dequantize(self, output_dtype: Optional[torch.dtype] = None) -> torch.Tensor: + """Dequantize to a dense tensor (symmetric: ``w = q * scale``). + + Used for the tied lm_head / token embedding (which can't gather a packed + tensor) and as the numerical reference. + """ + dtype = output_dtype if output_dtype is not None else self.scale.dtype + N, K = int(self.shape[0]), int(self.shape[1]) + gs = self.block_size[-1] + q = unpack_int6(self.ql, self.qh, N, K).to(dtype) + scale = self.scale.to(dtype).repeat_interleave(gs, dim=-1) + return (q * scale).to(dtype) + + +# Allow a model with CudaPackedInt6Tensor weights to be loaded with +# `weights_only=True` (mirrors torchao quantized tensors). +torch.serialization.add_safe_globals([CudaPackedInt6Tensor]) diff --git a/backends/cuda/quantize_op_dispatch/__init__.py b/backends/cuda/quantize_op_dispatch/__init__.py index 005c2b6e7c7..bc45b3906f9 100644 --- a/backends/cuda/quantize_op_dispatch/__init__.py +++ b/backends/cuda/quantize_op_dispatch/__init__.py @@ -11,9 +11,11 @@ dequant logic instead of torchao's defaults. It registers: * INT4 (``CudaCoalescedInt4Tensor``) → ``executorch_cuda::int4_plain_mm`` + * INT6 (``CudaPackedInt6Tensor``) → ``executorch_cuda::int6_plain_mm`` * INT8 (``IntxUnpackedToInt8Tensor``) → ``executorch_cuda::int8_plain_mm`` -See ``int4_dispatch`` and ``int8_dispatch`` for the per-dtype details. +See ``int4_dispatch``, ``int6_dispatch`` and ``int8_dispatch`` for the per-dtype +details. Import this package before using nn.Linear with quantized weights:: @@ -22,5 +24,6 @@ from executorch.backends.cuda.quantize_op_dispatch import ( # noqa: F401 int4_dispatch, + int6_dispatch, int8_dispatch, ) diff --git a/backends/cuda/quantize_op_dispatch/_library.py b/backends/cuda/quantize_op_dispatch/_library.py index c256e856c2c..2308ecf7102 100644 --- a/backends/cuda/quantize_op_dispatch/_library.py +++ b/backends/cuda/quantize_op_dispatch/_library.py @@ -6,9 +6,10 @@ """Shared torch.library handle for the ``executorch_cuda`` op namespace. -``int4_dispatch`` and ``int8_dispatch`` both register custom ops into the same -``executorch_cuda`` namespace, so they must share a single ``DEF`` library -instance — PyTorch allows only one ``DEF`` per namespace per process. +``int4_dispatch``, ``int6_dispatch`` and ``int8_dispatch`` all register custom +ops into the same ``executorch_cuda`` namespace, so they must share a single +``DEF`` library instance — PyTorch allows only one ``DEF`` per namespace per +process. """ from torch.library import Library diff --git a/backends/cuda/quantize_op_dispatch/int6_dispatch.py b/backends/cuda/quantize_op_dispatch/int6_dispatch.py new file mode 100644 index 00000000000..a26814ded1e --- /dev/null +++ b/backends/cuda/quantize_op_dispatch/int6_dispatch.py @@ -0,0 +1,116 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +"""CudaPackedInt6Tensor F.linear dispatch for CUDA — eager / export trace time. + +This module registers an F.linear dispatch on ``CudaPackedInt6Tensor`` (an +ExecuTorch-internal subclass, see ``packed_int6_tensor.py``) so that +torch.export traces through our custom op and dequant logic. Routing is by +*type*: only GGUF Q6_K weights (converted to ``CudaPackedInt6Tensor``) take the +packed-int6 path; genuine INT8 weights stay on the int8 path. The code here runs +during eager inference and AOTI export tracing — it does NOT run at .pte runtime. + +At .pte runtime, the captured graph is executed by the AOTI-generated .so: + - The custom op ``executorch_cuda::int6_plain_mm`` maps to a C shim that runs + the W6A8 dp4a matvec kernel (backends/cuda/runtime/shims/int6_plain_mm.*). + - The inline dequant + F.linear is compiled by inductor into fused Triton + dequant + cuBLAS matmul kernels. + +Dispatch strategy (determines what gets captured in the export graph): + Decode (M<=4): Custom op ``executorch_cuda::int6_plain_mm`` + Prefill (M>4): Inline dequant + F.linear (standard PyTorch ops) + +The packed-int6 weight is symmetric (no zero point): ``w = q * scale`` with +``q`` in ``[-32, 31]`` stored as the ql/qh planes. The op signature mirrors +int4_plain_mm / int8_plain_mm but takes two weight planes (ql, qh) instead of +one, and no zero tensor. + +Importing the parent ``quantize_op_dispatch`` package registers this dispatch +override (along with the INT4 / INT8 ones):: + + import executorch.backends.cuda.quantize_op_dispatch # noqa: F401 +""" + +import torch +import torch.nn.functional as F +from executorch.backends.cuda.packed_int6_tensor import ( + CudaPackedInt6Tensor, + unpack_int6, +) +from executorch.backends.cuda.quantize_op_dispatch._library import lib as _lib +from torch.library import impl + +# --------------------------------------------------------------------------- +# Custom op for INT6 decode (M<=4): W6A8 dp4a matvec in C shim. +# --------------------------------------------------------------------------- + +_lib.define( + "int6_plain_mm(Tensor self, Tensor ql, Tensor qh, Tensor scale, int group_size) -> Tensor" +) + + +@impl(_lib, "int6_plain_mm", "Meta") +def _meta_int6(self, ql, qh, scale, group_size): + return torch.empty(self.shape[0], ql.shape[0], dtype=self.dtype, device=self.device) + + +@impl(_lib, "int6_plain_mm", "CUDA") +def _cuda_int6(self, ql, qh, scale, group_size): + return _dequant_matmul_int6(self, ql, qh, scale, group_size) + + +def _dequant_matmul_int6(x, ql, qh, scale, group_size): + """Dequant packed-INT6 weights to input dtype and call F.linear. + + ql [N, K/2] / qh [N, K/4] pack symmetric Q6_K values q in [-32, 31]; + scale [N, K//gs]. Dequant: w[n, k] = q[n, k] * scale[n, k//gs]. + """ + N = ql.shape[0] + K = ql.shape[1] * 2 + n_groups = K // group_size + dtype = x.dtype + + q = unpack_int6(ql, qh, N, K).to(dtype).reshape(N, n_groups, group_size) + s = scale.to(dtype).reshape(N, n_groups, 1) + w_deq = (q * s).reshape(N, K) + + return F.linear(x, w_deq) + + +# --------------------------------------------------------------------------- +# CudaPackedInt6Tensor F.linear dispatch (W6A8 dp4a for decode) +# --------------------------------------------------------------------------- + +aten = torch.ops.aten +_implements_i6 = CudaPackedInt6Tensor.implements +_implements_torch_function_i6 = CudaPackedInt6Tensor.implements_torch_function + + +@_implements_i6([aten.linear.default]) +@_implements_torch_function_i6([F.linear]) +def _(func, types, args, kwargs): + input_tensor = args[0] + weight_tensor = args[1] + bias = args[2] if len(args) > 2 else None + + orig_shape = input_tensor.shape + x_2d = input_tensor.reshape(-1, orig_shape[-1]) + + ql = weight_tensor.ql + qh = weight_tensor.qh + scale = weight_tensor.scale + gs = weight_tensor.block_size[-1] + + M = x_2d.shape[0] + if M <= 4: + out = torch.ops.executorch_cuda.int6_plain_mm(x_2d, ql, qh, scale, gs) + else: + out = _dequant_matmul_int6(x_2d, ql, qh, scale, gs) + + out = out.reshape(*orig_shape[:-1], -1) + if bias is not None: + out = out + bias + return out diff --git a/backends/cuda/runtime/shims/int6_plain_mm.cu b/backends/cuda/runtime/shims/int6_plain_mm.cu new file mode 100644 index 00000000000..dd068a5766b --- /dev/null +++ b/backends/cuda/runtime/shims/int6_plain_mm.cu @@ -0,0 +1,81 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include + +#include +#include +#include +#include +#include + +namespace executorch::backends::cuda { +#ifdef __cplusplus +extern "C" { +#endif + +AOTITorchError aoti_torch_cuda_int6_plain_mm( + Tensor* self, + Tensor* ql, + Tensor* qh, + Tensor* scale, + int64_t group_size, + Tensor** ret0) { + ET_CHECK_OR_RETURN_ERROR( + self != nullptr, + InvalidArgument, + "aoti_torch_cuda_int6_plain_mm: self is null"); + + ET_CHECK_OR_RETURN_ERROR( + ql != nullptr, + InvalidArgument, + "aoti_torch_cuda_int6_plain_mm: ql is null"); + + ET_CHECK_OR_RETURN_ERROR( + qh != nullptr, + InvalidArgument, + "aoti_torch_cuda_int6_plain_mm: qh is null"); + + ET_CHECK_OR_RETURN_ERROR( + scale != nullptr, + InvalidArgument, + "aoti_torch_cuda_int6_plain_mm: scale is null"); + + ET_CHECK_OR_RETURN_ERROR( + ret0 != nullptr, + InvalidArgument, + "aoti_torch_cuda_int6_plain_mm: ret0 is null"); + + int32_t M = self->size(0); + int32_t N = ql->size(0); + Tensor* C = nullptr; + std::array c_shape = {M, N}; + std::array c_stride = {N, 1}; + aoti_torch_empty_strided( + 2, + c_shape.data(), + c_stride.data(), + static_cast( + executorch::backends::aoti::slim::c10::ScalarType::BFloat16), + static_cast( + executorch::backends::aoti::slim::c10::DeviceType::CUDA), + 0, + &C); + + _int6_plain_mm_cuda(*self, *ql, *qh, *scale, group_size, C); + ET_CUDA_KERNEL_LAUNCH_CHECK_OR_RETURN_ERROR(); + + *ret0 = C; + return Error::Ok; +} + +#ifdef __cplusplus +} +#endif +} // namespace executorch::backends::cuda diff --git a/backends/cuda/runtime/shims/int6_plain_mm.cuh b/backends/cuda/runtime/shims/int6_plain_mm.cuh new file mode 100644 index 00000000000..a1c7206e6a7 --- /dev/null +++ b/backends/cuda/runtime/shims/int6_plain_mm.cuh @@ -0,0 +1,353 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +// W6A8 dp4a matvec for packed INT6 decode (M <= 4), used for GGUF Q6_K weights. +// +// Reads a genuine 6-bit packed weight (CudaPackedInt6Tensor format), split into +// two planes: +// ql : [N, K/2] uint8 — low-nibble plane, nibble-packed even/odd exactly +// like the INT4 path (ql[:,j] = lo[:,2j] | (lo[:,2j+1] << 4)). +// qh : [N, K/4] uint8 — high-2-bit plane, 4 values/byte, arranged per +// 32-weight chunk as hi_even_packed[4] then hi_odd_packed[4] (each +// byte holds the four 2-bit highs of one dp4a word in even/odd order). +// scale : [N, K/gs] bf16 — per-group scales, row-major (coalesced; no zero). +// The stored 6-bit value is u = q + 32 in [0, 63] (q in [-32, 31]); the constant +// -32 offset is applied in the kernel, so Q6_K's symmetry means NO zero tensor. +// +// Dynamically quantizes bf16 activations to INT8 (per-32-element blocks, even/odd +// order, identical to the INT4 path), reconstructs full 6-bit weight bytes per +// dp4a word (vfull = vi_lo | (spread2(hi_byte) << 4)), and uses dp4a for fused +// int6xint8 dot products with vectorized weight loads and warp-cooperative +// quantization. +// +// Symbol names are suffixed _i6 / distinct from int4_plain_mm.cuh and +// int8_plain_mm.cuh so all three translation units can be linked together +// without ODR conflicts. + +#pragma once + +#include +#include +#include + +#include +#include +#include +#include +#include + +namespace executorch::backends::cuda { + +using executorch::backends::aoti::Tensor; +namespace c10 = executorch::backends::aoti::slim::c10; + +// --------------------------------------------------------------------------- +// Constants +// --------------------------------------------------------------------------- + +constexpr int32_t MV6_NWARPS = 8; +constexpr int32_t MV6_WARP_SIZE = 32; +constexpr int32_t MV6_THREADS = MV6_NWARPS * MV6_WARP_SIZE; +constexpr int32_t Q8_BLOCK_SIZE_I6 = 32; + +__host__ __forceinline__ int32_t log2_pow2_i6(int32_t v) { + int32_t r = 0; + while (v > 1) { + v >>= 1; + r++; + } + return r; +} + +// Expand a byte's four 2-bit fields into four byte lanes (each in bits 0-1): +// in : b = [.. b7 b6 | b5 b4 | b3 b2 | b1 b0] +// out : lane0=[b1 b0], lane1=[b3 b2], lane2=[b5 b4], lane3=[b7 b6] +// ~6 ALU ops; verified by truth-table. Used to place the high 2 bits of each +// weight into bits 4-5 of the corresponding dp4a byte lane. +__device__ __forceinline__ uint32_t spread2_i6(uint32_t b) { + uint32_t t = (b | (b << 12)) & 0x000F000F; + uint32_t r = (t | (t << 6)) & 0x03030303; + return r; +} + +// --------------------------------------------------------------------------- +// Activation quantization: bf16 -> int8 (warp-cooperative, per-32-element +// blocks, EVEN/ODD order — identical to the INT4 path's Q8Block). +// --------------------------------------------------------------------------- + +// alignas(16) pads sizeof(Q8Block_i6) to 48 so each block (and its qs_even/qs_odd +// 16-byte halves) is 16-byte aligned, allowing two vectorized uint4 loads of a +// block's int8 activations instead of eight scalar int32 loads. +struct alignas(16) Q8Block_i6 { + int8_t qs_even[Q8_BLOCK_SIZE_I6 / 2]; + int8_t qs_odd[Q8_BLOCK_SIZE_I6 / 2]; + float d; // scale +}; + +__global__ void quantize_activations_q8_i6_kernel( + const __nv_bfloat16* __restrict__ A, + Q8Block_i6* __restrict__ q8, + int32_t K) { + const int32_t m = blockIdx.y; + const int32_t block_id = blockIdx.x * blockDim.y + threadIdx.y; + const int32_t n_blocks = K / Q8_BLOCK_SIZE_I6; + if (block_id >= n_blocks) + return; + + const int32_t lane = threadIdx.x; + const __nv_bfloat16* src = + A + static_cast(m) * K + block_id * Q8_BLOCK_SIZE_I6; + Q8Block_i6* dst = q8 + static_cast(m) * n_blocks + block_id; + + float val = __bfloat162float(src[lane]); + + float amax = fabsf(val); + for (int offset = 16; offset > 0; offset >>= 1) + amax = fmaxf(amax, __shfl_xor_sync(0xffffffff, amax, offset)); + + float d = amax / 127.0f; + float id = (d > 0.0f) ? 1.0f / d : 0.0f; + int32_t q = __float2int_rn(val * id); + q = max(-128, min(127, q)); + + if (lane % 2 == 0) + dst->qs_even[lane / 2] = static_cast(q); + else + dst->qs_odd[lane / 2] = static_cast(q); + + if (lane == 0) + dst->d = d; +} + +// --------------------------------------------------------------------------- +// W6A8 dp4a matvec kernel +// +// dp4a is linear, so reconstructing v = lo + (hi<<4) and dotting once is +// equivalent to two separate dp4a passes. We reconstruct the full 6-bit byte +// (vfull = vi_lo | (spread2(hi_byte) << 4)) so a single dp4a per even/odd half +// covers the whole weight. The per-group zero is the constant 32 (in u-space), +// applied as out += scale * a_scale * (dp - 32 * a_sum) — no zero load. +// --------------------------------------------------------------------------- + +__global__ void __launch_bounds__(MV6_THREADS) int6_w6a8_matvec_kernel( + const uint8_t* __restrict__ ql, // [N, K/2] + const uint8_t* __restrict__ qh, // [N, K/4] + const __nv_bfloat16* __restrict__ w_scale, // [N, n_groups] + const Q8Block_i6* __restrict__ q8, + __nv_bfloat16* __restrict__ out, + int32_t N, + int32_t K, + int32_t gs_shift, + int32_t n_groups) { + const int32_t n = blockIdx.x * MV6_NWARPS + threadIdx.y; + const int32_t m = blockIdx.y; + if (n >= N) + return; + + const int32_t K_half = K / 2; + const int32_t K_quarter = K / 4; + const int32_t lane_id = threadIdx.x; + const int32_t n_q8_blocks = K / Q8_BLOCK_SIZE_I6; + + const uint8_t* qlrow = ql + static_cast(n) * K_half; + const uint8_t* qhrow = qh + static_cast(n) * K_quarter; + const __nv_bfloat16* scale_row = w_scale + static_cast(n) * n_groups; + const Q8Block_i6* q8_row = q8 + static_cast(m) * n_q8_blocks; + + // Vectorized loads: one uint4 of ql (32 weights) + one uint2 of qh (the + // 8 high-bit bytes for the same 32-weight chunk) per iteration. + const uint4* qlrow16 = reinterpret_cast(qlrow); + const uint2* qhrow8 = reinterpret_cast(qhrow); + const int32_t K_half_16 = K_half / 16; + + float sum = 0.0f; + + int32_t prev_g = -1; + float ws = 0.0f; + + for (int32_t i = lane_id; i < K_half_16; i += MV6_WARP_SIZE) { + uint4 packed16 = __ldg(&qlrow16[i]); + uint2 qh_chunk = __ldg(&qhrow8[i]); + int32_t k_base = i * 32; + uint32_t words[4] = {packed16.x, packed16.y, packed16.z, packed16.w}; + // qh_chunk.x bytes = hi_even_packed[0..3], qh_chunk.y = hi_odd_packed[0..3]. + uint32_t hi_even_word = qh_chunk.x; + uint32_t hi_odd_word = qh_chunk.y; + + // One uint4 (32 weights) maps to exactly one Q8 activation block (32 + // activations), i.e. q8_block_idx == i. Load the whole block with two + // vectorized uint4 loads (+ one scale load). + const Q8Block_i6* qb = &q8_row[i]; + uint4 ae = *reinterpret_cast(qb->qs_even); + uint4 ao = *reinterpret_cast(qb->qs_odd); + float a_scale = qb->d; + const uint32_t a_even[4] = {ae.x, ae.y, ae.z, ae.w}; + const uint32_t a_odd[4] = {ao.x, ao.y, ao.z, ao.w}; + +#pragma unroll + for (int32_t w = 0; w < 4; w++) { + uint32_t packed = words[w]; + int32_t k_word = k_base + w * 8; + int32_t g = k_word >> gs_shift; + + if (g != prev_g) { + ws = __bfloat162float(__ldg(&scale_row[g])); + prev_g = g; + } + + int32_t vi_lo = static_cast(packed & 0x0F0F0F0F); + int32_t vi_hi = static_cast((packed >> 4) & 0x0F0F0F0F); + + uint32_t hi_even_byte = (hi_even_word >> (w * 8)) & 0xFF; + uint32_t hi_odd_byte = (hi_odd_word >> (w * 8)) & 0xFF; + + // Reconstruct full 6-bit weight bytes (u in [0, 63]). + int32_t vfull_even = + vi_lo | static_cast(spread2_i6(hi_even_byte) << 4); + int32_t vfull_odd = + vi_hi | static_cast(spread2_i6(hi_odd_byte) << 4); + + int32_t dp = __dp4a(vfull_even, static_cast(a_even[w]), 0); + dp = __dp4a(vfull_odd, static_cast(a_odd[w]), dp); + + int32_t a_sum = __dp4a(0x01010101, static_cast(a_even[w]), 0); + a_sum = __dp4a(0x01010101, static_cast(a_odd[w]), a_sum); + + // q = u - 32, so the -32 offset replaces the per-group zero point. + sum += ws * a_scale * + (static_cast(dp) - 32.0f * static_cast(a_sum)); + } + } + + for (int offset = MV6_WARP_SIZE / 2; offset > 0; offset >>= 1) + sum += __shfl_xor_sync(0xffffffff, sum, offset); + + if (lane_id == 0) + out[static_cast(m) * N + n] = __float2bfloat16(sum); +} + +// --------------------------------------------------------------------------- +// Persistent Q8 buffer (lazy init, not thread-safe — single-stream only). +// Freed at process exit via a static guard so leak detectors stay quiet; the +// CUDA runtime would otherwise reclaim it on teardown anyway. +// --------------------------------------------------------------------------- + +static Q8Block_i6* g_q8_buf_i6 = nullptr; +static size_t g_q8_buf_i6_size = 0; + +namespace { +struct Q8BufferGuardI6 { + ~Q8BufferGuardI6() { + if (g_q8_buf_i6) { + // Ignore errors: during process teardown the CUDA context may already be + // gone (cudaErrorCudartUnloading), which is harmless here. + cudaFree(g_q8_buf_i6); + g_q8_buf_i6 = nullptr; + g_q8_buf_i6_size = 0; + } + } +}; +Q8BufferGuardI6 g_q8_buf_i6_guard; +} // namespace + +static Q8Block_i6* get_q8_buffer_i6(size_t needed) { + if (g_q8_buf_i6_size < needed) { + if (g_q8_buf_i6) + cudaFree(g_q8_buf_i6); + cudaError_t err = cudaMalloc(&g_q8_buf_i6, needed); + ET_CHECK_MSG( + err == cudaSuccess, + "cudaMalloc failed for Q8 buffer (int6): %s", + cudaGetErrorString(err)); + g_q8_buf_i6_size = needed; + } + return g_q8_buf_i6; +} + +// --------------------------------------------------------------------------- +// Main entry point +// --------------------------------------------------------------------------- + +inline void _int6_plain_mm_cuda( + const Tensor& A, // [M, K] bf16 + const Tensor& ql, // [N, K/2] uint8 + const Tensor& qh, // [N, K/4] uint8 + const Tensor& scale, // [N, K/gs] bf16 + int64_t group_size, + Tensor* output) { // [M, N] bf16, pre-allocated + int32_t M = A.size(0); + int32_t K = A.size(1); + int32_t N = ql.size(0); + + ET_CHECK(A.dtype() == c10::ScalarType::BFloat16); + ET_CHECK( + ql.dtype() == c10::ScalarType::Byte || + ql.dtype() == c10::ScalarType::Char); + ET_CHECK( + qh.dtype() == c10::ScalarType::Byte || + qh.dtype() == c10::ScalarType::Char); + ET_CHECK(scale.dtype() == c10::ScalarType::BFloat16); + ET_CHECK(A.dim() == 2); + ET_CHECK(ql.dim() == 2); + ET_CHECK(ql.size(1) == K / 2); + ET_CHECK(qh.dim() == 2); + ET_CHECK(qh.size(1) == K / 4); + ET_CHECK(scale.dim() == 2); + ET_CHECK(scale.size(0) == N); + + int32_t gs = static_cast(group_size); + ET_CHECK_MSG( + gs > 0 && (gs & (gs - 1)) == 0, "group_size=%d must be a power of 2", gs); + // group_size must be a multiple of 8 (the dp4a word stride) so a word never + // straddles a group boundary; gs=16 covers GGUF Q6_K. + ET_CHECK_MSG( + gs % 8 == 0, + "group_size=%d must be a multiple of 8 (e.g. 16 for GGUF Q6_K)", + gs); + ET_CHECK_MSG( + K >= Q8_BLOCK_SIZE_I6 && K % Q8_BLOCK_SIZE_I6 == 0, + "K=%d must be a positive multiple of %d for dp4a int6 kernel", + K, + Q8_BLOCK_SIZE_I6); + + auto stream_result = getCurrentCUDAStream(0); + ET_CHECK_MSG(stream_result.ok(), "Failed to get CUDA stream"); + cudaStream_t stream = stream_result.get(); + + int32_t gs_shift = log2_pow2_i6(gs); + + // Quantize activations to INT8 (even/odd order) + int32_t n_q8_blocks = K / Q8_BLOCK_SIZE_I6; + size_t q8_bytes = static_cast(M) * n_q8_blocks * sizeof(Q8Block_i6); + Q8Block_i6* q8_buf = get_q8_buffer_i6(q8_bytes); + + constexpr int32_t Q8_WARPS = 8; + int32_t blocks_per_m = (n_q8_blocks + Q8_WARPS - 1) / Q8_WARPS; + dim3 q8_grid(blocks_per_m, M); + dim3 q8_block(MV6_WARP_SIZE, Q8_WARPS); + quantize_activations_q8_i6_kernel<<>>( + reinterpret_cast(A.data_ptr()), q8_buf, K); + + // dp4a matvec + dim3 grid((N + MV6_NWARPS - 1) / MV6_NWARPS, M); + dim3 block(MV6_WARP_SIZE, MV6_NWARPS); + + int32_t n_groups = static_cast(scale.size(1)); + int6_w6a8_matvec_kernel<<>>( + reinterpret_cast(ql.data_ptr()), + reinterpret_cast(qh.data_ptr()), + reinterpret_cast(scale.data_ptr()), + q8_buf, + reinterpret_cast<__nv_bfloat16*>(output->data_ptr()), + N, + K, + gs_shift, + n_groups); +} + +} // namespace executorch::backends::cuda diff --git a/backends/cuda/runtime/shims/int6_plain_mm.h b/backends/cuda/runtime/shims/int6_plain_mm.h new file mode 100644 index 00000000000..e093fb9f055 --- /dev/null +++ b/backends/cuda/runtime/shims/int6_plain_mm.h @@ -0,0 +1,61 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include +#include +#include + +namespace executorch::backends::cuda { + +using executorch::backends::aoti::AOTITorchError; +using executorch::backends::aoti::Tensor; + +#ifdef __cplusplus +extern "C" { +#endif + +/** + * Packed INT6 matrix multiplication for GGUF Q6_K weights (symmetric). + * + * The 6-bit weight is split into two planes plus a per-group scale; there is + * NO zero tensor — Q6_K is symmetric and the stored value is u = q + 32 in + * [0, 63] (q in [-32, 31]), with the constant -32 offset applied in the kernel. + * + * Weight format: + * ql : [N, K/2] uint8 — low-nibble plane, nibble-packed even/odd + * (ql[:,j] = (u[:,2j] & 0xF) | ((u[:,2j+1] & 0xF) << 4)). + * qh : [N, K/4] uint8 — high-2-bit plane, 4 values/byte, arranged per + * 32-weight chunk as hi_even_packed[4] then hi_odd_packed[4]; each + * byte holds the four 2-bit highs of one dp4a word, bit field j + * (bits 2j..2j+1) = high 2 bits of that word's j-th even/odd weight. + * scale : [N, K//group_size] bf16 per-group scales (row-major). + * W6A8 dp4a matvec: dynamically quantizes activations to INT8, reconstructs + * full 6-bit weight bytes, then uses dp4a for fused int6xint8 dot products. + * + * @param self Input activation [M, K] bf16 + * @param ql Low-nibble plane [N, K/2] uint8 + * @param qh High-2-bit plane [N, K/4] uint8 + * @param scale Per-group scales [N, K//group_size] bf16 + * @param group_size Quantization group size (multiple of 8; e.g. 16 for Q6_K) + * @param ret0 Output [M, N] bf16 + */ +AOTI_SHIM_EXPORT AOTITorchError aoti_torch_cuda_int6_plain_mm( + Tensor* self, + Tensor* ql, + Tensor* qh, + Tensor* scale, + int64_t group_size, + Tensor** ret0); + +#ifdef __cplusplus +} +#endif + +} // namespace executorch::backends::cuda diff --git a/backends/cuda/runtime/shims/tests/CMakeLists.txt b/backends/cuda/runtime/shims/tests/CMakeLists.txt index 62e9180d603..072e4effad4 100644 --- a/backends/cuda/runtime/shims/tests/CMakeLists.txt +++ b/backends/cuda/runtime/shims/tests/CMakeLists.txt @@ -49,8 +49,9 @@ set(CUDA_SHIM_TESTS ) # CUDA-specific tests requiring GPU kernels -set(CUDA_KERNEL_TESTS test_aoti_torch_cuda__weight_int4pack_mm - test_aoti_torch_cuda_int4_plain_mm +set(CUDA_KERNEL_TESTS + test_aoti_torch_cuda__weight_int4pack_mm test_aoti_torch_cuda_int4_plain_mm + test_aoti_torch_cuda_int6_plain_mm ) enable_testing() diff --git a/backends/cuda/runtime/shims/tests/test_aoti_torch_cuda_int6_plain_mm.cpp b/backends/cuda/runtime/shims/tests/test_aoti_torch_cuda_int6_plain_mm.cpp new file mode 100644 index 00000000000..43d3946294a --- /dev/null +++ b/backends/cuda/runtime/shims/tests/test_aoti_torch_cuda_int6_plain_mm.cpp @@ -0,0 +1,306 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +using executorch::backends::cuda::aoti_torch_cuda_int6_plain_mm; +using executorch::backends::cuda::aoti_torch_empty_strided; +using executorch::backends::cuda::AOTITorchError; +using executorch::runtime::Error; +namespace slim_c10 = executorch::backends::aoti::slim::c10; + +using Tensor = executorch::backends::aoti::slim::SlimTensor; + +// W6A8 dp4a matvec shim for packed-INT6 decode (CudaPackedInt6Tensor layout, +// GGUF Q6_K). The 6-bit weight is split into two planes plus a per-group scale; +// there is NO zero tensor (Q6_K is symmetric, the -32 offset is applied in the +// kernel): +// ql : [N, K/2] uint8 — low-nibble plane, nibble-packed even/odd +// qh : [N, K/4] uint8 — high-2-bit plane, 4 values/byte (per 32-weight +// chunk: hi_even_packed[4] then hi_odd_packed[4]) +// scale : [N, K//gs] bf16 — per-group scales (row-major) +// +// Expected outputs are generated from the export-path reference +// _dequant_matmul_int6 (backends/cuda/quantize_op_dispatch/int6_dispatch.py): +// w[n, k] = q[n, k] * scale[n, k//gs]; out = A @ w^T (q symmetric, in +// [-32,31]). +// The kernel runs W6A8 (it also quantizes activations to int8), so a 0.5 atol +// absorbs the activation-quant noise (same tolerance as the int4/int8 tests). +class AOTITorchInt6PlainMMTest : public ::testing::Test { + protected: + void SetUp() override { + et_pal_init(); + + int device_count = 0; + cudaError_t err = cudaGetDeviceCount(&device_count); + if (err != cudaSuccess || device_count == 0) { + GTEST_SKIP() << "CUDA not available"; + } + } + + Tensor* create_tensor( + const std::vector& sizes, + slim_c10::ScalarType dtype) { + Tensor* tensor; + AOTITorchError error = aoti_torch_empty_strided( + sizes.size(), + sizes.data(), + nullptr, + static_cast(dtype), + static_cast(slim_c10::DeviceType::CUDA), + 0, + &tensor); + return (error == Error::Ok) ? tensor : nullptr; + } + + Tensor* create_bf16(const std::vector& sizes) { + return create_tensor(sizes, slim_c10::ScalarType::BFloat16); + } + + // ql / qh are uint8 (ScalarType::Byte) packed planes. + Tensor* create_uint8(const std::vector& sizes) { + return create_tensor(sizes, slim_c10::ScalarType::Byte); + } + + // Upload raw bytes to a CUDA tensor. + void upload(Tensor* t, const void* host_data, size_t bytes) { + cudaMemcpy(t->data_ptr(), host_data, bytes, cudaMemcpyHostToDevice); + } + + // Download CUDA tensor to host buffer. + void download(const Tensor* t, void* host_data, size_t bytes) { + cudaMemcpy(host_data, t->data_ptr(), bytes, cudaMemcpyDeviceToHost); + } + + // Run the shim and return the output tensor (asserts success). + Tensor* + run(Tensor* A, Tensor* ql, Tensor* qh, Tensor* scale, int64_t group_size) { + Tensor* output = nullptr; + AOTITorchError error = + aoti_torch_cuda_int6_plain_mm(A, ql, qh, scale, group_size, &output); + EXPECT_EQ(error, Error::Ok); + EXPECT_NE(output, nullptr); + return output; + } + + // Check output bf16 values against expected, with absolute tolerance. + void check_bf16_output( + Tensor* output, + const uint16_t* expected_data, + int64_t count, + float atol = 0.5f) { + std::vector actual(count); + download(output, actual.data(), count * sizeof(uint16_t)); + cudaDeviceSynchronize(); + + for (int64_t i = 0; i < count; i++) { + // Convert bf16 raw bits to float: bf16 is the upper 16 bits of float32. + uint32_t actual_bits = static_cast(actual[i]) << 16; + uint32_t expected_bits = static_cast(expected_data[i]) << 16; + float actual_f, expected_f; + memcpy(&actual_f, &actual_bits, sizeof(float)); + memcpy(&expected_f, &expected_bits, sizeof(float)); + + EXPECT_NEAR(actual_f, expected_f, atol) + << "Mismatch at index " << i << ": actual=" << actual_f + << " expected=" << expected_f; + } + } + + // Upload data and run the shim. ql/qh are uint8; scale/A are bf16. + Tensor* setup_and_run( + int64_t M, + int64_t N, + int64_t K, + int64_t gs, + const uint8_t* ql_host, + const uint8_t* qh_host, + const uint16_t* scale_host, + const uint16_t* A_host) { + int64_t ng = K / gs; + Tensor* A = create_bf16({M, K}); + Tensor* ql = create_uint8({N, K / 2}); + Tensor* qh = create_uint8({N, K / 4}); + Tensor* scale = create_bf16({N, ng}); + EXPECT_NE(A, nullptr); + EXPECT_NE(ql, nullptr); + EXPECT_NE(qh, nullptr); + EXPECT_NE(scale, nullptr); + + upload(A, A_host, static_cast(M) * K * sizeof(uint16_t)); + upload(ql, ql_host, static_cast(N) * (K / 2) * sizeof(uint8_t)); + upload(qh, qh_host, static_cast(N) * (K / 4) * sizeof(uint8_t)); + upload(scale, scale_host, static_cast(N) * ng * sizeof(uint16_t)); + + return run(A, ql, qh, scale, gs); + } +}; + +// Q6KGroupSize16: M=2, N=4, K=64, gs=16, symmetric (no zero), q in [-32,31]. +// The canonical GGUF Q6_K shape (group_size=16). +TEST_F(AOTITorchInt6PlainMMTest, Q6KGroupSize16) { + int64_t M = 2, N = 4, K = 64, gs = 16; + + // clang-format off + uint8_t ql_host[] = { + 249, 176, 113, 205, 113, 130, 205, 208, 208, 220, 36, 28, 90, 117, 20, 139, + 24, 99, 43, 2, 253, 112, 107, 185, 154, 203, 229, 119, 15, 8, 139, 95, + 117, 50, 27, 48, 120, 65, 40, 224, 147, 165, 182, 177, 210, 160, 239, 192, + 136, 20, 241, 201, 43, 56, 64, 34, 219, 104, 39, 103, 79, 70, 196, 157, + 193, 90, 70, 26, 31, 78, 234, 55, 53, 19, 198, 24, 26, 71, 88, 181, + 205, 210, 95, 167, 16, 80, 183, 76, 106, 66, 44, 124, 17, 197, 49, 227, + 46, 51, 2, 185, 46, 243, 128, 59, 39, 121, 45, 252, 221, 98, 155, 170, + 27, 31, 108, 91, 235, 129, 177, 104, 44, 22, 110, 142, 169, 226, 255, 217 + }; + uint8_t qh_host[] = { + 21, 230, 10, 92, 55, 212, 46, 90, 227, 91, 52, 88, 49, 132, 203, 60, + 255, 132, 109, 173, 8, 49, 181, 163, 130, 224, 227, 13, 216, 86, 234, 219, + 180, 142, 137, 139, 87, 161, 244, 72, 109, 20, 107, 165, 31, 47, 99, 59, + 215, 173, 1, 159, 180, 83, 227, 190, 15, 222, 95, 108, 117, 157, 225, 105 + }; + uint16_t scale_host[] = { + 0x3C6E, 0xBCF6, 0x3CC3, 0xBB88, 0xBD0C, 0x3D5A, 0x3B40, 0x3D43, 0xBB71, 0xBD6A, 0x3D16, 0xBCC3, + 0xBC1E, 0x3D2A, 0xBCC3, 0xBD37 + }; + uint16_t A_host[] = { + 0x3F5C, 0xBF3E, 0x0000, 0xBC33, 0xBE9A, 0x3CAA, 0x3F7A, 0xBF94, 0xC016, 0xBFF6, 0x0000, 0x3E71, + 0xBFD3, 0x3F5E, 0xBF96, 0x3E2A, 0x4023, 0x3EC0, 0x3E90, 0xC00C, 0x3F84, 0xBEEA, 0xBE32, 0x3F71, + 0x0000, 0x3EC9, 0xBEE2, 0x3EE8, 0x3F30, 0xBECB, 0x3F1F, 0xBF2F, 0xBF2A, 0x3F01, 0x3F11, 0x3F88, + 0xBF6A, 0x3FD4, 0xBDD5, 0x3F8F, 0xBF5F, 0xBEBA, 0xBF24, 0xBF45, 0xBF3F, 0x3E51, 0xBE7D, 0xBF35, + 0x3E73, 0x3F1B, 0x3F34, 0x3EA2, 0xBF13, 0xBF4F, 0xBEE2, 0x4006, 0x3F37, 0x3EC5, 0x3F9F, 0xBD79, + 0x3F21, 0xBF0C, 0xBEA9, 0x3FF2, 0x3F55, 0x3FD6, 0x3FAB, 0x3F89, 0xBDA1, 0x3EDD, 0xBF8D, 0xBE4F, + 0xC005, 0xBFBD, 0xBF59, 0x3CD7, 0x3E07, 0xBEEA, 0x3EAC, 0x4038, 0x3F7E, 0xBE4B, 0xBE3A, 0xBF99, + 0xBFCC, 0x3EF0, 0xBF84, 0xBEE8, 0xBF6E, 0xBC97, 0xBF57, 0xBF3F, 0x3FD7, 0xBFB5, 0x3F0C, 0x3E3F, + 0x3F77, 0xBE45, 0x3FAA, 0x3FE1, 0x3D9C, 0x3F8F, 0xBF38, 0xBF1F, 0xBF07, 0xBE94, 0xBF58, 0xBF85, + 0x3FCE, 0x3F2A, 0x3EAC, 0xBF45, 0x3DC4, 0x3E9E, 0xBF9C, 0x3F0A, 0x3E8F, 0x3EA7, 0xBEFB, 0xBE65, + 0xBFB1, 0xBF58, 0xBF88, 0x3EC2, 0xC008, 0x3F7C, 0xBFFC, 0xBF66 + }; + uint16_t expected[] = { + 0x3F46, 0xC02B, 0x40C5, 0xBED9, 0xBECA, 0x4098, 0x3F96, 0x3F19 + }; + // clang-format on + + Tensor* output = + setup_and_run(M, N, K, gs, ql_host, qh_host, scale_host, A_host); + ASSERT_NE(output, nullptr); + EXPECT_EQ(output->size(0), M); + EXPECT_EQ(output->size(1), N); + check_bf16_output(output, expected, M * N, 0.5f); +} + +// LargeKGroupSize16: M=1, N=2, K=256, gs=16, symmetric — a larger-K decode case +// (16 groups) exercising the multi-iteration warp loop on the gs=16 path. +TEST_F(AOTITorchInt6PlainMMTest, LargeKGroupSize16) { + int64_t M = 1, N = 2, K = 256, gs = 16; + + // clang-format off + uint8_t ql_host[] = { + 69, 12, 100, 182, 132, 79, 45, 206, 141, 218, 39, 249, 136, 245, 75, 210, + 18, 150, 51, 178, 183, 119, 174, 151, 235, 77, 75, 247, 29, 241, 55, 154, + 12, 189, 29, 93, 92, 153, 20, 52, 67, 219, 12, 178, 99, 207, 12, 151, + 5, 133, 30, 141, 56, 234, 26, 101, 93, 150, 46, 101, 80, 30, 33, 153, + 240, 83, 103, 193, 72, 152, 248, 85, 69, 52, 240, 168, 4, 81, 134, 98, + 101, 106, 122, 199, 212, 244, 190, 139, 33, 62, 6, 147, 243, 106, 105, 196, + 120, 49, 123, 17, 38, 205, 200, 90, 10, 248, 177, 182, 9, 195, 90, 9, + 127, 194, 250, 109, 105, 141, 182, 53, 35, 162, 151, 192, 134, 134, 246, 198, + 202, 191, 86, 93, 221, 185, 60, 230, 242, 167, 247, 189, 35, 210, 188, 146, + 8, 218, 95, 120, 119, 39, 177, 110, 158, 144, 0, 36, 69, 219, 134, 94, + 29, 25, 81, 213, 207, 185, 206, 89, 113, 1, 50, 59, 238, 29, 69, 128, + 97, 97, 229, 181, 211, 253, 157, 118, 71, 232, 63, 21, 171, 62, 115, 78, + 3, 109, 188, 187, 172, 5, 144, 190, 60, 214, 171, 194, 232, 6, 192, 189, + 136, 45, 201, 26, 110, 239, 63, 229, 197, 85, 25, 121, 147, 63, 227, 20, + 30, 66, 228, 231, 197, 90, 65, 116, 255, 50, 51, 88, 142, 60, 112, 10, + 18, 192, 52, 144, 148, 19, 197, 32, 3, 157, 152, 52, 176, 31, 38, 242 + }; + uint8_t qh_host[] = { + 235, 21, 174, 144, 160, 216, 229, 90, 25, 104, 128, 211, 93, 165, 189, 219, + 87, 210, 115, 144, 79, 31, 166, 108, 199, 41, 50, 92, 21, 45, 124, 158, + 142, 126, 0, 139, 23, 77, 180, 181, 218, 246, 98, 252, 50, 141, 10, 82, + 82, 31, 128, 233, 230, 216, 156, 120, 193, 161, 94, 122, 62, 85, 233, 8, + 199, 237, 102, 124, 105, 252, 43, 58, 34, 218, 77, 242, 219, 85, 16, 221, + 102, 49, 77, 226, 23, 30, 142, 36, 110, 63, 97, 59, 164, 214, 221, 103, + 253, 67, 106, 140, 18, 75, 207, 144, 21, 18, 108, 84, 110, 217, 45, 114, + 180, 170, 6, 111, 131, 171, 200, 246, 55, 206, 40, 185, 16, 114, 54, 62 + }; + uint16_t scale_host[] = { + 0x3CF1, 0x3C5B, 0x3B89, 0x3B53, 0x3865, 0xBD3E, 0x3D2F, 0x3AD1, 0x3CC6, 0x3D06, 0xBCFE, 0x3BDD, + 0x3D60, 0x3BD0, 0xBD1A, 0x3D1F, 0xBBBA, 0x3D58, 0x3CD5, 0xBCD3, 0x3BB7, 0x3CF3, 0x3D05, 0x3D0B, + 0x3D42, 0xBBF0, 0x3CC5, 0xBC17, 0xBD73, 0xBC09, 0xBC01, 0xBD24 + }; + uint16_t A_host[] = { + 0xBF33, 0xBF48, 0xBE27, 0x3F25, 0xBFF5, 0x3F5C, 0xBFCE, 0xBF36, 0x3DFA, 0x3EE3, 0xBF64, 0x3E14, + 0xBF41, 0x3E5C, 0x3ED3, 0xBF93, 0xBF45, 0x3BC7, 0xBEF0, 0x3D95, 0xBF20, 0x3E4D, 0xBEA8, 0xBF49, + 0x3F65, 0xBF75, 0xBEA2, 0x3F35, 0x3DE0, 0xBDB1, 0xBEA7, 0xBF5B, 0x3F7F, 0x3F47, 0x3FA4, 0x3FB6, + 0xBE20, 0xBFDE, 0xBD38, 0xBFC6, 0x3F22, 0xBF91, 0xBEA8, 0xBFEA, 0x3FA0, 0xBFAB, 0x3F78, 0xBFAC, + 0x3EA4, 0x3FB3, 0xBF88, 0xBF3B, 0xBEA4, 0x3EDF, 0x3F01, 0x3E7A, 0xBF5F, 0xBD3E, 0x3FA3, 0xBF68, + 0xBF32, 0x3EC0, 0xBF59, 0x3EE9, 0xBEB9, 0xBEC4, 0x3F1E, 0xBE8A, 0x3FBE, 0x3F19, 0x3FC2, 0xC00B, + 0xBEF4, 0xBF45, 0xBEC8, 0x3FC7, 0x3F09, 0x3F97, 0x3F43, 0xBF47, 0x3FCF, 0x3E26, 0x3E10, 0xBEA9, + 0x3EA2, 0x3FAE, 0x3F3F, 0x3E93, 0xBFB6, 0x3FCA, 0x3F70, 0x3FD6, 0x3E58, 0xBF17, 0x3FB2, 0xBE16, + 0x4006, 0x3FC1, 0x3F7D, 0x3F3E, 0xBE03, 0x3ED5, 0x3F0A, 0xBE95, 0xBE89, 0x3F8E, 0x3EF0, 0x3FBB, + 0x3F83, 0xBFCB, 0x3E18, 0x3FA8, 0x3F60, 0x3F1D, 0xBFB4, 0x3FB8, 0xBDB3, 0xBF77, 0xBEBC, 0x3E68, + 0x3EAC, 0x3F54, 0x3F72, 0xC01B, 0x3E4C, 0x3FA9, 0xBDCC, 0xBE59, 0xBF8D, 0xBE29, 0x3E80, 0x3FB9, + 0xBFD0, 0x3E11, 0xBF42, 0xBECE, 0xBE42, 0x4016, 0x3C98, 0x3E5B, 0x3F43, 0x3FB1, 0x3F30, 0xBE69, + 0x3F2C, 0x3F4A, 0x3F43, 0x3FAB, 0x3E4C, 0xBF9C, 0xBEF7, 0xBF87, 0x3DA9, 0x3F2E, 0xBEA8, 0xBF4A, + 0x3F80, 0xBF1E, 0xBE81, 0x3EA5, 0x3F0E, 0xBF50, 0x3EA4, 0x3FD3, 0xBE3C, 0x3F8D, 0xBF38, 0xBEB3, + 0x3E86, 0x3F79, 0xBF77, 0x3E26, 0x3F6E, 0x3DDF, 0xBCB2, 0x3F92, 0xBE11, 0xBF0E, 0xBFFE, 0xBF6A, + 0x3FA0, 0x0000, 0xBF84, 0x3FA7, 0x3F23, 0x3F8F, 0xBF90, 0xBF2F, 0x3F8A, 0x0000, 0xBDA4, 0x3F6A, + 0x3E9D, 0x3FAB, 0xBEDB, 0x3F06, 0x3EFB, 0xBF86, 0x3DAD, 0xBE1C, 0xBF85, 0x3F65, 0xBF5C, 0xBE89, + 0x3EC4, 0x3F85, 0x3EF7, 0x3C47, 0x3E98, 0x3EFB, 0x3DC9, 0x3D1B, 0xBECD, 0x4007, 0x3ED0, 0xBF28, + 0x3F99, 0x3E9F, 0xBF7A, 0x3EBD, 0xBEEE, 0xBF1C, 0xBED0, 0xBF01, 0x3F76, 0xBE8A, 0xBF8C, 0x3EDD, + 0x3FE6, 0x3ECA, 0x3F45, 0xBF64, 0xBE8F, 0x3FC7, 0x3FD4, 0xBF2D, 0x3F0C, 0x3F58, 0x3F45, 0x3E8B, + 0x3A08, 0x3F9E, 0x4004, 0x3F9D, 0xBFDE, 0xBF69, 0xBF8E, 0xBF0B, 0x3F89, 0x3DFA, 0xBF91, 0xC019, + 0x3DAA, 0x3F09, 0x3F69, 0x3F3E + }; + uint16_t expected[] = { + 0xC196, 0x40F7 + }; + // clang-format on + + Tensor* output = + setup_and_run(M, N, K, gs, ql_host, qh_host, scale_host, A_host); + ASSERT_NE(output, nullptr); + EXPECT_EQ(output->size(0), M); + EXPECT_EQ(output->size(1), N); + check_bf16_output(output, expected, M * N, 0.5f); +} + +TEST_F(AOTITorchInt6PlainMMTest, NullInputHandling) { + int64_t M = 2, K = 128, N = 64, gs = 16; + + Tensor* A = create_bf16({M, K}); + Tensor* ql = create_uint8({N, K / 2}); + Tensor* qh = create_uint8({N, K / 4}); + Tensor* scale = create_bf16({N, K / gs}); + Tensor* output = nullptr; + + EXPECT_EQ( + aoti_torch_cuda_int6_plain_mm(nullptr, ql, qh, scale, gs, &output), + Error::InvalidArgument); + EXPECT_EQ( + aoti_torch_cuda_int6_plain_mm(A, nullptr, qh, scale, gs, &output), + Error::InvalidArgument); + EXPECT_EQ( + aoti_torch_cuda_int6_plain_mm(A, ql, nullptr, scale, gs, &output), + Error::InvalidArgument); + EXPECT_EQ( + aoti_torch_cuda_int6_plain_mm(A, ql, qh, nullptr, gs, &output), + Error::InvalidArgument); + EXPECT_EQ( + aoti_torch_cuda_int6_plain_mm(A, ql, qh, scale, gs, nullptr), + Error::InvalidArgument); +} diff --git a/backends/cuda/tests/test_int6_dispatch.py b/backends/cuda/tests/test_int6_dispatch.py new file mode 100644 index 00000000000..63602618b3a --- /dev/null +++ b/backends/cuda/tests/test_int6_dispatch.py @@ -0,0 +1,226 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +"""Tests for CudaPackedInt6Tensor F.linear dispatch via int6_dispatch. + +These tests validate the eager / trace-time dispatch path — the same code that +torch.export traces through when building the AOTI graph. They do NOT test the +.pte runtime C shim (W6A8 dp4a kernel); that is covered by +test_aoti_torch_cuda_int6_plain_mm.cpp (C++ unit tests). + +The API contract: after importing int6_dispatch, F.linear / nn.Linear with a +CudaPackedInt6Tensor weight produce numerically correct results, routed by +batch size (decode M<=4 -> custom op, prefill M>4 -> inline dequant). Routing +tests run without a GPU by recording calls to the decode custom op. + +Usage: + python -m pytest backends/cuda/tests/test_int6_dispatch.py -v +""" + +import contextlib +import unittest +from unittest import mock + +import executorch.backends.cuda.quantize_op_dispatch.int6_dispatch # noqa: F401 +import torch +import torch.nn as nn +import torch.nn.functional as F +from executorch.backends.cuda.packed_int6_tensor import CudaPackedInt6Tensor, pack_int6 +from executorch.backends.cuda.quantize_op_dispatch.int6_dispatch import ( + _dequant_matmul_int6, +) + + +def _require_cuda(tc: unittest.TestCase) -> None: + if not torch.cuda.is_available(): + tc.skipTest("CUDA required") + + +def _make_int6_tensor(N, K, group_size=16): + """Build a CudaPackedInt6Tensor (symmetric Q6_K) and return (tensor, q, scale). + + ``q`` (int8 in [-32, 31]) and ``scale`` are the originals, so tests can + measure against the exact dequant reference ``w = q * scale``. + """ + q = torch.randint(-32, 32, (N, K), dtype=torch.int8) + scale = (torch.rand(N, K // group_size) * 0.1 + 0.01).to(torch.bfloat16) + ql, qh = pack_int6(q) + t = CudaPackedInt6Tensor(ql, qh, scale, [1, group_size], torch.Size([N, K])) + return t, q, scale + + +def _ref_weight(q, scale, group_size, dtype=torch.bfloat16): + """Exact dequant reference: w[n, k] = q[n, k] * scale[n, k//gs].""" + N, K = q.shape + ng = K // group_size + w = q.to(dtype).reshape(N, ng, group_size) * scale.to(dtype).reshape(N, ng, 1) + return w.reshape(N, K) + + +@contextlib.contextmanager +def _record_int6_plain_mm(): + """Record calls to the decode custom op without needing a GPU. + + Replaces ``torch.ops.executorch_cuda.int6_plain_mm`` (whose real impl is the + CUDA C shim) with a recorder that computes the result via the eager CPU + dequant, so the dispatch handler still returns a valid tensor. + """ + calls = [] + + def _fake(self, ql, qh, scale, group_size): + calls.append((tuple(self.shape), group_size)) + return _dequant_matmul_int6(self, ql, qh, scale, group_size) + + with mock.patch.object(torch.ops.executorch_cuda, "int6_plain_mm", _fake): + yield calls + + +class TestDispatchRouting(unittest.TestCase): + """Type-based routing: M<=4 -> int6_plain_mm op, M>4 -> inline dequant. + + Runs without a GPU by recording calls to the decode custom op and computing + the result with the eager CPU dequant. + """ + + def setUp(self): + torch.manual_seed(0) + + def _rel_err(self, out, ref): + return ( + (out.float() - ref.float()).abs().mean() / ref.float().abs().mean() + ).item() + + def test_decode_routes_to_int6_plain_mm(self): + """M<=4 routes to the decode custom op.""" + t, _, _ = _make_int6_tensor(16, 64) + x = torch.randn(1, 64, dtype=torch.bfloat16) # M=1 (decode regime) + with _record_int6_plain_mm() as calls: + out = F.linear(x, t) + self.assertEqual(len(calls), 1) + self.assertEqual(out.shape, (1, 16)) + + def test_prefill_uses_dequant(self): + """M>4 uses inline dequant (no custom op) and is numerically correct.""" + t, q, scale = _make_int6_tensor(16, 64) + x = torch.randn(8, 64, dtype=torch.bfloat16) # M=8 > 4 (prefill regime) + with _record_int6_plain_mm() as calls: + out = F.linear(x, t) + self.assertEqual(calls, []) + ref = F.linear(x, _ref_weight(q, scale, 16)) + self.assertLess(self._rel_err(out, ref), 0.02) + + def test_decode_result_matches_reference(self): + """The decode op (eager -> dequant) is numerically correct.""" + t, q, scale = _make_int6_tensor(24, 128) + x = torch.randn(2, 128, dtype=torch.bfloat16) + with _record_int6_plain_mm() as calls: + out = F.linear(x, t) + self.assertEqual(len(calls), 1) + ref = F.linear(x, _ref_weight(q, scale, 16)) + self.assertLess(self._rel_err(out, ref), 0.02) + + def test_with_bias(self): + """Bias is added after the matmul on the decode path.""" + t, q, scale = _make_int6_tensor(16, 64) + bias = torch.randn(16, dtype=torch.bfloat16) + x = torch.randn(1, 64, dtype=torch.bfloat16) + with _record_int6_plain_mm(): + out = F.linear(x, t, bias) + ref = F.linear(x, _ref_weight(q, scale, 16), bias) + self.assertLess(self._rel_err(out, ref), 0.02) + + def test_3d_batched_input(self): + """3D input is flattened and the output shape is restored.""" + t, q, scale = _make_int6_tensor(16, 64) + x = torch.randn(2, 8, 64, dtype=torch.bfloat16) # flattened M=16 > 4 + with _record_int6_plain_mm() as calls: + out = F.linear(x, t) + self.assertEqual(calls, []) # prefill regime + self.assertEqual(out.shape, (2, 8, 16)) + ref = F.linear(x, _ref_weight(q, scale, 16)) + self.assertLess(self._rel_err(out, ref), 0.02) + + def test_from_intx_int8_roundtrip(self): + """from_intx_int8 packs a symmetric int8 tensor and dispatch is correct.""" + from torchao.quantization import IntxUnpackedToInt8Tensor + + N, K, gs = 16, 64, 16 + q = torch.randint(-32, 32, (N, K), dtype=torch.int8) + scale = (torch.rand(N, K // gs) * 0.1 + 0.01).to(torch.bfloat16) + intx = IntxUnpackedToInt8Tensor( + qdata=q, + scale=scale, + zero_point=torch.zeros_like(scale, dtype=torch.int8), + target_dtype=torch.int8, + block_size=(1, gs), + dtype=torch.bfloat16, + activation_quantization=None, + ) + t = CudaPackedInt6Tensor.from_intx_int8(intx) + x = torch.randn(1, K, dtype=torch.bfloat16) + with _record_int6_plain_mm() as calls: + out = F.linear(x, t) + self.assertEqual(len(calls), 1) + ref = F.linear(x, _ref_weight(q, scale, gs)) + self.assertLess(self._rel_err(out, ref), 0.02) + + def test_from_intx_int8_rejects_asymmetric(self): + """A non-zero zero_point (not Q6_K) is rejected.""" + from torchao.quantization import IntxUnpackedToInt8Tensor + + N, K, gs = 8, 64, 16 + q = torch.randint(-32, 32, (N, K), dtype=torch.int8) + scale = (torch.rand(N, K // gs) * 0.1 + 0.01).to(torch.bfloat16) + intx = IntxUnpackedToInt8Tensor( + qdata=q, + scale=scale, + zero_point=torch.ones_like(scale, dtype=torch.int8), + target_dtype=torch.int8, + block_size=(1, gs), + dtype=torch.bfloat16, + activation_quantization=None, + ) + with self.assertRaises(ValueError): + CudaPackedInt6Tensor.from_intx_int8(intx) + + +class TestFLinearDispatchCuda(unittest.TestCase): + """F.linear with a CudaPackedInt6Tensor weight on CUDA (eager -> dequant).""" + + def setUp(self): + _require_cuda(self) + torch.manual_seed(42) + + def _check(self, out, ref, tol=0.02): + rel_err = (out.float() - ref.float()).abs().mean() / ref.float().abs().mean() + self.assertLess(rel_err.item(), tol) + + def _linear(self, N, K, gs=16): + t, q, scale = _make_int6_tensor(N, K, gs) + module = nn.Linear(K, N, bias=False, dtype=torch.bfloat16) + module.weight = nn.Parameter(t, requires_grad=False) + module.cuda() + return module, _ref_weight(q, scale, gs).cuda() + + def test_decode_m1(self): + module, w_ref = self._linear(256, 512) + x = torch.randn(1, 512, dtype=torch.bfloat16, device="cuda") + self._check(module(x), F.linear(x, w_ref)) + + def test_prefill_m64(self): + module, w_ref = self._linear(256, 512) + x = torch.randn(64, 512, dtype=torch.bfloat16, device="cuda") + self._check(module(x), F.linear(x, w_ref)) + + def test_dequantize_matches_reference(self): + t, q, scale = _make_int6_tensor(32, 128) + ref = _ref_weight(q, scale, 16) + self.assertTrue(torch.equal(t.dequantize().cpu(), ref)) + + +if __name__ == "__main__": + unittest.main() diff --git a/examples/models/gemma4_31b/gguf_loader.py b/examples/models/gemma4_31b/gguf_loader.py index 5d7c5ec540d..1cd9c0db8b0 100644 --- a/examples/models/gemma4_31b/gguf_loader.py +++ b/examples/models/gemma4_31b/gguf_loader.py @@ -16,9 +16,11 @@ by the MLX GGUF pattern (Q6_K custom kernels, Q4_K native affine ops) for both linear and embedding. ``embed_tokens`` and ``lm_head`` stay tied -- they share the one quantized tensor. -* **CUDA**: Q4_K -> ``Int4Tensor``, Q6_K -> ``IntxUnpackedToInt8Tensor``; +* **CUDA**: Q4_K -> ``Int4Tensor``, Q6_K -> ``IntxUnpackedToInt8Tensor`` (native + torchao tensors; the backend packer in ``quant/pack_cuda.py`` repacks them into + ``CudaCoalescedInt4Tensor`` / the genuine 6-bit ``CudaPackedInt6Tensor``). ``lm_head`` keeps the quantized tensor but the token embedding is dequantized to - bf16 (``Int4Tensor`` can't gather), so they are untied. + bf16 (the packed tensors can't gather), so they are untied. Usage: model, config = load_gguf_model("model.gguf", backend="cuda") @@ -91,7 +93,11 @@ def _convert_weight(model, model_key: str, gtensor, backend: str): """Convert an ``ExportableGGUFTensor`` to the per-backend module weight.""" if backend == "mlx": return gtensor - # CUDA: native torchao quantized tensors. + # CUDA: native torchao quantized tensors. Q4_K -> Int4Tensor; Q6_K (and any + # other quant type) -> IntxUnpackedToInt8Tensor. The backend packer in + # quant/pack_cuda.py repacks these into the ExecuTorch-internal CUDA layouts + # (CudaCoalescedInt4Tensor / CudaPackedInt6Tensor), so the loader itself stays + # backend-agnostic and carries no backends/cuda dependency. if gtensor.ggml_type == "q4_k": return gtensor.to_int4_tensor() return gtensor.to_intx_unpacked_to_int8_tensor() diff --git a/examples/models/gemma4_31b/quant/pack_cuda.py b/examples/models/gemma4_31b/quant/pack_cuda.py index 655d773e7b3..e22e99789b6 100644 --- a/examples/models/gemma4_31b/quant/pack_cuda.py +++ b/examples/models/gemma4_31b/quant/pack_cuda.py @@ -6,11 +6,17 @@ """CUDA packer: assign quantized weights to model modules. -Converts ``Int4Tensor`` weights to the ExecuTorch-internal -``CudaCoalescedInt4Tensor`` (which owns the scale/zero transpose to the -coalesced [N, n_groups] layout) and passes ``IntxUnpackedToInt8Tensor`` through -as ``nn.Parameter`` without conversion. The quantize_op_dispatch package -(``int4_dispatch`` / ``int8_dispatch``) handles F.linear at runtime. +Repacks native torchao quantized tensors into the ExecuTorch-internal CUDA +layouts read by the decode kernels: + + * ``Int4Tensor`` -> ``CudaCoalescedInt4Tensor`` (bakes the scale/zero transpose + into the coalesced [N, n_groups] layout). + * symmetric Q6_K ``IntxUnpackedToInt8Tensor`` -> ``CudaPackedInt6Tensor`` (the + genuine 6-bit ql/qh planes). + +A genuine INT8 ``IntxUnpackedToInt8Tensor`` is left unchanged for the int8 path. +The quantize_op_dispatch package (``int4_dispatch`` / ``int6_dispatch`` / +``int8_dispatch``) handles F.linear at runtime. No CUDA is required for packing. The backend-agnostic ``pack_model`` dispatcher lives in ``pack.py``. @@ -28,9 +34,26 @@ # Per-module packers +def _is_symmetric_q6k(w) -> bool: + """True if ``w`` is a symmetric Q6_K ``IntxUnpackedToInt8Tensor``. + + GGUF Q6_K decodes (``gguf.to_intx_unpacked_to_int8_tensor``) to a symmetric + int8 tensor with 16-wide groups and values in ``[-32, 31]``. Those three + properties together distinguish it from a genuine INT8 weight (wider groups + and/or the full int8 range), so the int8 path is never misrouted into the + 6-bit packer. + """ + if tuple(int(b) for b in w.block_size) != (1, 16): + return False + if not bool(torch.all(w.zero_point == 0)): + return False + return int(w.qdata.min()) >= -32 and int(w.qdata.max()) <= 31 + + def pack_linear_for_cuda(module: nn.Module, weights: dict[str, torch.Tensor]) -> None: """Assign a quantized weight to an ``nn.Linear`` module.""" from executorch.backends.cuda.coalesced_int4_tensor import CudaCoalescedInt4Tensor + from executorch.backends.cuda.packed_int6_tensor import CudaPackedInt6Tensor from torchao.quantization import IntxUnpackedToInt8Tensor from torchao.quantization.quantize_.workflows.int4.int4_tensor import Int4Tensor @@ -48,6 +71,12 @@ def pack_linear_for_cuda(module: nn.Module, weights: dict[str, torch.Tensor]) -> w = CudaCoalescedInt4Tensor.from_int4_tensor(w) module.weight = nn.Parameter(w, requires_grad=False) elif isinstance(w, IntxUnpackedToInt8Tensor): + # GGUF Q6_K decodes to a symmetric int8 tensor; repack it into the genuine + # 6-bit CudaPackedInt6Tensor (ql/qh planes, 0.75 B/elem) for the W6A8 dp4a + # decode kernel — the bit-pack is baked into the weight constant here, + # once. A genuine INT8 weight is left unchanged for the int8 path. + if _is_symmetric_q6k(w): + w = CudaPackedInt6Tensor.from_intx_int8(w) module.weight = nn.Parameter(w, requires_grad=False) else: raise ValueError(f"Unsupported weight type: {type(w).__name__}") diff --git a/examples/models/gemma4_31b/quant/quantize.py b/examples/models/gemma4_31b/quant/quantize.py index ade85efd788..1baf65a1c3e 100644 --- a/examples/models/gemma4_31b/quant/quantize.py +++ b/examples/models/gemma4_31b/quant/quantize.py @@ -272,6 +272,13 @@ def dequantize_weight( zero = weight.zero_point.float().repeat_interleave(gs, dim=-1) return ((weight.qdata.float() - zero) * scale).to(dtype) + # CudaPackedInt6Tensor (GGUF Q6_K on CUDA) carries its own dequant (symmetric, + # ql/qh planes). Imported lazily to avoid a hard backends/cuda dependency. + from executorch.backends.cuda.packed_int6_tensor import CudaPackedInt6Tensor + + if isinstance(weight, CudaPackedInt6Tensor): + return weight.dequantize(dtype) + raise TypeError(f"Cannot dequantize {type(weight).__name__}") diff --git a/examples/models/gemma4_31b/quant/tests/test_pack_cuda.py b/examples/models/gemma4_31b/quant/tests/test_pack_cuda.py index e4f68fce43c..38eca18f5b8 100644 --- a/examples/models/gemma4_31b/quant/tests/test_pack_cuda.py +++ b/examples/models/gemma4_31b/quant/tests/test_pack_cuda.py @@ -18,6 +18,11 @@ import executorch.backends.cuda.quantize_op_dispatch # noqa: F401 import torch import torch.nn as nn +from executorch.backends.cuda.packed_int6_tensor import ( + CudaPackedInt6Tensor, + pack_int6, + unpack_int6, +) from executorch.examples.models.gemma4_31b.quant.pack import pack_one from executorch.examples.models.gemma4_31b.quant.pack_cuda import ( DEFAULT_CUDA_PACKERS, @@ -124,6 +129,105 @@ def test_unsupported_type_raises(self): pack_linear_for_cuda(module, {"weight": torch.randn(32, 64)}) +class TestPackLinearInt6(unittest.TestCase): + """pack_linear_for_cuda converts a symmetric Q6_K IntxUnpackedToInt8Tensor + (the gguf_loader output) into a CudaPackedInt6Tensor. + + The pack/unpack round-trip is lossless and dequantize() == q * scale (no + CUDA required); the F.linear correctness check is CUDA-only. A genuine INT8 + weight is left on the int8 path. + """ + + def setUp(self): + torch.manual_seed(0) + + def _make_int6(self, N, K, gs=16): + q = torch.randint(-32, 32, (N, K), dtype=torch.int8) + scale = (torch.rand(N, K // gs) * 0.1 + 0.01).to(torch.bfloat16) + ql, qh = pack_int6(q) + t = CudaPackedInt6Tensor(ql, qh, scale, [1, gs], torch.Size([N, K])) + return t, q, scale + + def _make_q6k_intx(self, N, K, gs=16): + """Build a symmetric Q6_K IntxUnpackedToInt8Tensor (mirrors gguf.py).""" + from torchao.quantization import IntxUnpackedToInt8Tensor + + q = torch.randint(-32, 32, (N, K), dtype=torch.int8) + scale = (torch.rand(N, K // gs) * 0.1 + 0.01).to(torch.bfloat16) + zero = torch.zeros(N, K // gs, dtype=torch.int8) + t = IntxUnpackedToInt8Tensor( + qdata=q, + scale=scale, + zero_point=zero, + target_dtype=torch.int8, + block_size=(1, gs), + dtype=torch.bfloat16, + activation_quantization=None, + ) + return t, q, scale + + def test_pack_unpack_roundtrip(self): + q = torch.randint(-32, 32, (64, 128), dtype=torch.int8) + ql, qh = pack_int6(q) + self.assertEqual(tuple(ql.shape), (64, 64)) # [N, K/2] + self.assertEqual(tuple(qh.shape), (64, 32)) # [N, K/4] + q_rt = unpack_int6(ql, qh, 64, 128).to(torch.int8) + self.assertTrue(torch.equal(q_rt, q)) + + def test_dequantize_equals_q_scale(self): + t, q, scale = self._make_int6(32, 128, gs=16) + ref = q.to(torch.bfloat16) * scale.to(torch.bfloat16).repeat_interleave( + 16, dim=-1 + ) + self.assertTrue(torch.equal(t.dequantize(), ref)) + + def test_pack_linear_converts_q6k(self): + t, _, _ = self._make_q6k_intx(32, 128) + with torch.device("meta"): + module = nn.Linear(128, 32, bias=False) + pack_linear_for_cuda(module, {"weight": t}) + self.assertIsInstance(module.weight.data, CudaPackedInt6Tensor) + self.assertEqual(module.weight.shape, torch.Size([32, 128])) + + def test_pack_linear_real_int8_passthrough(self): + """A genuine INT8 weight (wide groups, full range) is NOT repacked.""" + from torchao.quantization import IntxUnpackedToInt8Tensor + + q = torch.randint(-128, 128, (32, 128), dtype=torch.int8) + scale = (torch.rand(32, 128 // 32) * 0.1 + 0.01).to(torch.bfloat16) + zero = torch.zeros(32, 128 // 32, dtype=torch.int8) + t = IntxUnpackedToInt8Tensor( + qdata=q, + scale=scale, + zero_point=zero, + target_dtype=torch.int8, + block_size=(1, 32), + dtype=torch.bfloat16, + activation_quantization=None, + ) + with torch.device("meta"): + module = nn.Linear(128, 32, bias=False) + pack_linear_for_cuda(module, {"weight": t}) + self.assertIsInstance(module.weight.data, IntxUnpackedToInt8Tensor) + + def test_matmul_correct(self): + _require_cuda(self) + t, q, scale = self._make_q6k_intx(256, 128, gs=16) + module = nn.Linear(128, 256, bias=False) + pack_linear_for_cuda(module, {"weight": t}) + self.assertIsInstance(module.weight.data, CudaPackedInt6Tensor) + module.cuda() + x = torch.randn(1, 128, dtype=torch.bfloat16, device="cuda") + w_ref = ( + q.to(torch.bfloat16) + * scale.to(torch.bfloat16).repeat_interleave(16, dim=-1) + ).cuda() + ref = torch.nn.functional.linear(x, w_ref) + out = module(x) + rel_error = (out.float() - ref.float()).abs().mean() / ref.float().abs().mean() + self.assertLess(rel_error.item(), 0.02) + + class TestPackEmbedding(unittest.TestCase): """pack_embedding_for_cuda with INT8 per-axis weights.""" diff --git a/examples/models/gemma4_31b/tests/test_cuda_pipeline.py b/examples/models/gemma4_31b/tests/test_cuda_pipeline.py index 0e31a50f37b..4cee363a123 100644 --- a/examples/models/gemma4_31b/tests/test_cuda_pipeline.py +++ b/examples/models/gemma4_31b/tests/test_cuda_pipeline.py @@ -245,12 +245,12 @@ def _load(self, tmp): return load_gguf_model(path, backend="cuda", config=GGUF_CONFIG) def test_load_converts_weights(self): - """GGUF -> CUDA: Q4_K -> CudaCoalescedInt4Tensor, Q6_K -> IntxUnpacked, + """GGUF -> CUDA: Q4_K -> CudaCoalescedInt4Tensor, Q6_K -> CudaPackedInt6Tensor, embedding bf16.""" from executorch.backends.cuda.coalesced_int4_tensor import ( CudaCoalescedInt4Tensor, ) - from torchao.quantization import IntxUnpackedToInt8Tensor + from executorch.backends.cuda.packed_int6_tensor import CudaPackedInt6Tensor with tempfile.TemporaryDirectory() as tmp: model, _ = self._load(tmp) @@ -259,9 +259,12 @@ def test_load_converts_weights(self): model.layers[0].self_attn.q_proj.weight.data, CudaCoalescedInt4Tensor ) self.assertIsInstance( - model.layers[0].mlp.down_proj.weight.data, IntxUnpackedToInt8Tensor + model.layers[0].mlp.down_proj.weight.data, CudaPackedInt6Tensor ) - # Token embedding is dequantized to bf16 (Int4/Intx can't gather). + # Tied lm_head is repacked to int6 by pack_cuda (it keeps quantization, + # unlike the token embedding which is dequantized for the gather). + self.assertIsInstance(model.lm_head.weight.data, CudaPackedInt6Tensor) + # Token embedding is dequantized to bf16 (Int4/packed-int6 can't gather). self.assertEqual(model.embed_tokens.weight.dtype, torch.bfloat16) def test_generate(self): From 7db4bba168119202fa68904045c05749645d206f Mon Sep 17 00:00:00 2001 From: gasoonjia Date: Fri, 12 Jun 2026 13:41:07 -0700 Subject: [PATCH 3/4] remove comment --- examples/models/gemma4_31b/gguf_loader.py | 15 +++++---------- 1 file changed, 5 insertions(+), 10 deletions(-) diff --git a/examples/models/gemma4_31b/gguf_loader.py b/examples/models/gemma4_31b/gguf_loader.py index 1cd9c0db8b0..6606ccaa524 100644 --- a/examples/models/gemma4_31b/gguf_loader.py +++ b/examples/models/gemma4_31b/gguf_loader.py @@ -16,11 +16,10 @@ by the MLX GGUF pattern (Q6_K custom kernels, Q4_K native affine ops) for both linear and embedding. ``embed_tokens`` and ``lm_head`` stay tied -- they share the one quantized tensor. -* **CUDA**: Q4_K -> ``Int4Tensor``, Q6_K -> ``IntxUnpackedToInt8Tensor`` (native - torchao tensors; the backend packer in ``quant/pack_cuda.py`` repacks them into - ``CudaCoalescedInt4Tensor`` / the genuine 6-bit ``CudaPackedInt6Tensor``). - ``lm_head`` keeps the quantized tensor but the token embedding is dequantized to - bf16 (the packed tensors can't gather), so they are untied. +* **CUDA**: Q4_K -> ``Int4Tensor``, Q6_K -> ``CudaPackedInt6Tensor`` (a genuine + 6-bit packed weight, lossless, symmetric); ``lm_head`` keeps the quantized + tensor but the token embedding is dequantized to bf16 (the packed tensors can't + gather), so they are untied. Usage: model, config = load_gguf_model("model.gguf", backend="cuda") @@ -93,11 +92,7 @@ def _convert_weight(model, model_key: str, gtensor, backend: str): """Convert an ``ExportableGGUFTensor`` to the per-backend module weight.""" if backend == "mlx": return gtensor - # CUDA: native torchao quantized tensors. Q4_K -> Int4Tensor; Q6_K (and any - # other quant type) -> IntxUnpackedToInt8Tensor. The backend packer in - # quant/pack_cuda.py repacks these into the ExecuTorch-internal CUDA layouts - # (CudaCoalescedInt4Tensor / CudaPackedInt6Tensor), so the loader itself stays - # backend-agnostic and carries no backends/cuda dependency. + # CUDA: native torchao quantized tensors. if gtensor.ggml_type == "q4_k": return gtensor.to_int4_tensor() return gtensor.to_intx_unpacked_to_int8_tensor() From eaea4a78a67b9e657c3717cf4dcce8922eed27a1 Mon Sep 17 00:00:00 2001 From: gasoonjia Date: Fri, 12 Jun 2026 13:45:45 -0700 Subject: [PATCH 4/4] lin --- backends/cuda/packed_int6_tensor.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/backends/cuda/packed_int6_tensor.py b/backends/cuda/packed_int6_tensor.py index 104ed5bbfa0..06582df197f 100644 --- a/backends/cuda/packed_int6_tensor.py +++ b/backends/cuda/packed_int6_tensor.py @@ -105,9 +105,7 @@ def unpack_int6(ql: torch.Tensor, qh: torch.Tensor, N: int, K: int) -> torch.Ten hi_even = torch.stack( [(hi_even_byte >> s) & 0x3 for s in (0, 2, 4, 6)], dim=-1 ) # (N, chunk, 4, 4) uint8 - hi_odd = torch.stack( - [(hi_odd_byte >> s) & 0x3 for s in (0, 2, 4, 6)], dim=-1 - ) + hi_odd = torch.stack([(hi_odd_byte >> s) & 0x3 for s in (0, 2, 4, 6)], dim=-1) hi = torch.empty(N, chunks, 4, 8, dtype=torch.uint8, device=ql.device) hi[..., 0::2] = hi_even hi[..., 1::2] = hi_odd