Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 13 additions & 7 deletions .ci/scripts/export_model_artifact.sh
Original file line number Diff line number Diff line change
Expand Up @@ -467,21 +467,27 @@ if [ "$MODEL_NAME" = "qwen3_5_moe" ]; then
exit 0
fi

# Gemma 4 31B uses a prequantized checkpoint and custom export script
# Gemma 4 31B: download the Q4_K_M GGUF and export via the GGUF loader
if [ "$MODEL_NAME" = "gemma4_31b" ]; then
pip install safetensors huggingface_hub gguf

# Download prequantized model outside OUTPUT_DIR to avoid uploading on failure
# Download GGUF + tokenizer outside OUTPUT_DIR to avoid uploading on failure.
# The unsloth GGUF repo ships the .gguf but no tokenizer.json, so the tokenizer
# is fetched from the (non-GGUF) unsloth/gemma-4-31B-it repo.
LOCAL_MODEL_DIR=$(mktemp -d)
INDUCTOR_CACHE=$(mktemp -d)
trap 'rm -rf "$LOCAL_MODEL_DIR" "$INDUCTOR_CACHE"' EXIT

python -c "from huggingface_hub import snapshot_download; snapshot_download('${HF_MODEL}', local_dir='${LOCAL_MODEL_DIR}')"
GGUF_FILE="gemma-4-31B-it-Q4_K_M.gguf"
python -c "from huggingface_hub import hf_hub_download; hf_hub_download('unsloth/gemma-4-31B-it-GGUF', '${GGUF_FILE}', local_dir='${LOCAL_MODEL_DIR}')"
python -c "from huggingface_hub import hf_hub_download; hf_hub_download('unsloth/gemma-4-31B-it', 'tokenizer.json', local_dir='${LOCAL_MODEL_DIR}')"
GGUF_PATH="${LOCAL_MODEL_DIR}/${GGUF_FILE}"

# Sanity check: run inference on the prequantized model
# Sanity check: run inference on the GGUF model
echo "::group::Inference sanity check"
INFERENCE_OUTPUT=$(python -m executorch.examples.models.gemma4_31b.inference \
--prequantized "$LOCAL_MODEL_DIR" \
--gguf "$GGUF_PATH" \
--tokenizer-path "${LOCAL_MODEL_DIR}/tokenizer.json" \
--prompt "What is the capital of France?" \
--max-new-tokens 32 \
--temperature 0 \
Expand All @@ -494,13 +500,13 @@ if [ "$MODEL_NAME" = "gemma4_31b" ]; then
echo "::endgroup::"

# Copy tokenizer for the runner
cp "$LOCAL_MODEL_DIR/tokenizer.json" "${OUTPUT_DIR}/tokenizer.json"
cp "${LOCAL_MODEL_DIR}/tokenizer.json" "${OUTPUT_DIR}/tokenizer.json"

# Export to .pte/.ptd (short cache dir avoids objcopy symbol length issues)
echo "::group::Export"
TORCHINDUCTOR_CACHE_DIR="$INDUCTOR_CACHE" \
python -m executorch.examples.models.gemma4_31b.export \
--prequantized "$LOCAL_MODEL_DIR" \
--gguf "$GGUF_PATH" \
--output-dir "${OUTPUT_DIR}"
echo "::endgroup::"

Expand Down
1 change: 1 addition & 0 deletions backends/cuda/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,7 @@ if(CMAKE_CUDA_COMPILER)
_aoti_cuda_shim_sources
runtime/shims/int4mm.cu
runtime/shims/int4_plain_mm.cu
runtime/shims/int6_plain_mm.cu
runtime/shims/int8_plain_mm.cu
runtime/shims/sort.cu
runtime/shims/rand.cu
Expand Down
7 changes: 7 additions & 0 deletions backends/cuda/cuda_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,6 +231,8 @@ def get_supported_fallback_kernels(cls) -> Dict[str, Any]:
"aoti_torch_cuda_randint_low_out": None,
"executorch_cuda::int4_plain_mm": None,
"aoti_torch_cuda_int4_plain_mm": None,
"executorch_cuda::int6_plain_mm": None,
"aoti_torch_cuda_int6_plain_mm": None,
"executorch_cuda::int8_plain_mm": None,
"aoti_torch_cuda_int8_plain_mm": None,
}
Expand Down Expand Up @@ -314,6 +316,11 @@ def get_aoti_compile_options(
"AtenTensorHandle, AtenTensorHandle, AtenTensorHandle, "
"AtenTensorHandle, int64_t, AtenTensorHandle*)"
],
torch.ops.executorch_cuda.int6_plain_mm.default: [
"AOTITorchError aoti_torch_cuda_int6_plain_mm("
"AtenTensorHandle, AtenTensorHandle, AtenTensorHandle, "
"AtenTensorHandle, int64_t, AtenTensorHandle*)"
],
torch.ops.executorch_cuda.int8_plain_mm.default: [
"AOTITorchError aoti_torch_cuda_int8_plain_mm("
"AtenTensorHandle, AtenTensorHandle, AtenTensorHandle, "
Expand Down
207 changes: 207 additions & 0 deletions backends/cuda/packed_int6_tensor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,207 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

"""ExecuTorch-internal packed-INT6 tensor for the CUDA W6A8 dp4a decode kernel.

``CudaPackedInt6Tensor`` is an ExecuTorch-internal tensor subclass that stores a
genuine 6-bit packed weight (0.75 B/elem), used for GGUF Q6_K weights. Unlike
the int8 path (``IntxUnpackedToInt8Tensor``, one int8 per 6-bit value), this
format wastes no bits and carries no zero tensor — Q6_K is symmetric.

The stored value is ``u = q + 32`` in ``[0, 63]`` (``q`` in ``[-32, 31]``); the
constant ``-32`` offset is applied in the decode kernel. The 6 bits are split
into two planes that mirror the INT4 nibble layout so the kernel can reuse the
INT4 even/odd extraction verbatim:

ql : (N, K/2) uint8 — low-nibble plane, nibble-packed even/odd
(``ql[:, j] = lo[:, 2j] | (lo[:, 2j+1] << 4)``, ``lo = u & 0xF``).
qh : (N, K/4) uint8 — high-2-bit plane, 4 values/byte, arranged per
32-weight chunk as ``hi_even_packed[4]`` then ``hi_odd_packed[4]``;
each byte holds the four 2-bit highs (``hi = (u >> 4) & 0x3``) of one
8-weight dp4a word, bit field ``j`` (bits ``2j..2j+1``) = the high 2
bits of that word's ``j``-th even/odd weight.
scale : (N, K/gs) bf16 — per-group scales, row-major (already coalesced; the
decode kernel reads it row-for-row, no transpose).

The pack/unpack helpers (:func:`pack_int6`, :func:`unpack_int6`) must stay in
lockstep with ``int6_plain_mm.cuh`` (the decode kernel) — the per-32-weight
``hi_even``/``hi_odd`` byte order is the single most error-prone detail and is
covered by the pack round-trip and the C++ gtest.
"""

from typing import List, Optional, Tuple

import torch
from torchao.utils import TorchAOBaseTensor

__all__ = [
"CudaPackedInt6Tensor",
"pack_int6",
"unpack_int6",
]


def pack_int6(q: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
"""Pack symmetric Q6_K int values into the (ql, qh) planes.

Args:
q: (N, K) integer tensor with values in ``[-32, 31]``.

Returns:
``(ql, qh)`` where ``ql`` is ``(N, K/2)`` uint8 and ``qh`` is
``(N, K/4)`` uint8 (see the module docstring for the layout).
"""
if q.dim() != 2:
raise ValueError(f"pack_int6 expects a 2-D tensor, got shape {tuple(q.shape)}")
N, K = int(q.shape[0]), int(q.shape[1])
if K % 32 != 0:
raise ValueError(f"K={K} must be a multiple of 32 for INT6 packing")

# All intermediates are uint8 (values fit in a byte) to keep peak memory low
# — important for the ~1.4B-element tied token embedding.
u = (q.to(torch.int16) + 32).to(torch.uint8) # [0, 63]
lo = u & 0xF # low nibble (uint8)
hi = (u >> 4) & 0x3 # high 2 bits (uint8)

# ql: nibble-pack the low plane even/odd, exactly like the INT4 path.
ql = lo[:, 0::2] | (lo[:, 1::2] << 4) # (N, K/2) uint8

# qh: per 32-weight chunk -> [hi_even_packed[4], hi_odd_packed[4]]; each byte
# packs the four 2-bit highs of one 8-weight dp4a word, field j at bits 2j.
chunks = K // 32
hw = hi.reshape(N, chunks, 4, 8) # (N, chunk, word, pos-in-word)
even = hw[..., 0::2] # (N, chunk, 4, 4) positions 0,2,4,6
odd = hw[..., 1::2] # (N, chunk, 4, 4) positions 1,3,5,7
# Explicit OR (not sum) keeps the result uint8 (torch.sum would promote).
hi_even_byte = (
even[..., 0] | (even[..., 1] << 2) | (even[..., 2] << 4) | (even[..., 3] << 6)
) # (N, chunk, 4) uint8
hi_odd_byte = (
odd[..., 0] | (odd[..., 1] << 2) | (odd[..., 2] << 4) | (odd[..., 3] << 6)
)
qh = torch.cat([hi_even_byte, hi_odd_byte], dim=-1) # (N, chunk, 8) uint8
qh = qh.reshape(N, K // 4)
return ql.contiguous(), qh.contiguous()


def unpack_int6(ql: torch.Tensor, qh: torch.Tensor, N: int, K: int) -> torch.Tensor:
"""Inverse of :func:`pack_int6`. Returns ``(N, K)`` int16 q in ``[-32, 31]``.

Intermediates are uint8 to keep peak memory low; only the final ``- 32`` shift
(which produces negatives) widens to int16.
"""
qlu = ql.to(torch.uint8)
lo_even = qlu & 0xF # low nibble -> even weights
lo_odd = (qlu >> 4) & 0xF # high nibble -> odd weights
lo = torch.stack([lo_even, lo_odd], dim=-1).reshape(N, K) # uint8

chunks = K // 32
qhu = qh.to(torch.uint8).reshape(N, chunks, 8)
hi_even_byte = qhu[:, :, 0:4] # (N, chunk, 4) word w
hi_odd_byte = qhu[:, :, 4:8] # (N, chunk, 4)
hi_even = torch.stack(
[(hi_even_byte >> s) & 0x3 for s in (0, 2, 4, 6)], dim=-1
) # (N, chunk, 4, 4) uint8
hi_odd = torch.stack([(hi_odd_byte >> s) & 0x3 for s in (0, 2, 4, 6)], dim=-1)
hi = torch.empty(N, chunks, 4, 8, dtype=torch.uint8, device=ql.device)
hi[..., 0::2] = hi_even
hi[..., 1::2] = hi_odd
hi = hi.reshape(N, K)

u = lo | (hi << 4) # [0, 63] uint8
return u.to(torch.int16) - 32


class CudaPackedInt6Tensor(TorchAOBaseTensor):
"""Packed 6-bit weight (ql/qh planes + per-group scale), symmetric.

ExecuTorch-internal; see the module docstring. The CUDA decode/prefill
dispatch (``int6_dispatch.py``) is selected by *type* — it is registered on
this class only.
"""

tensor_data_names = ["ql", "qh", "scale"]
tensor_attribute_names = ["block_size", "shape"]

def __new__(
cls,
ql: torch.Tensor,
qh: torch.Tensor,
scale: torch.Tensor,
block_size: List[int],
shape: torch.Size,
):
kwargs = {}
kwargs["device"] = ql.device
kwargs["dtype"] = scale.dtype
kwargs["requires_grad"] = False
return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs) # type: ignore[attr-defined]

def __init__(
self,
ql: torch.Tensor,
qh: torch.Tensor,
scale: torch.Tensor,
block_size: List[int],
shape: torch.Size,
):
super().__init__()
self.ql = ql
self.qh = qh
self.scale = scale
self.block_size = block_size

def _quantization_type(self):
return (
f"shape={self.shape}, block_size={self.block_size}, "
f"device={self.device}"
)

@classmethod
def from_intx_int8(cls, t: torch.Tensor) -> "CudaPackedInt6Tensor":
"""Build from a torchao ``IntxUnpackedToInt8Tensor`` decoded from Q6_K.

The source is symmetric (zero_point == 0), ``qdata`` is int8 in
``[-32, 31]`` and ``scale`` is ``(N, K/16)``. The ql/qh bit-pack is baked
into the serialized weight constant here, once at pack time.
"""
q = t.qdata
if not bool(torch.all(t.zero_point == 0)):
raise ValueError(
"CudaPackedInt6Tensor.from_intx_int8 requires symmetric Q6_K "
"weights (zero_point == 0)"
)
q_min, q_max = int(q.min()), int(q.max())
if q_min < -32 or q_max > 31:
raise ValueError(
f"Q6_K values must be in [-32, 31], got [{q_min}, {q_max}]"
)
ql, qh = pack_int6(q)
return cls(
ql,
qh,
t.scale.contiguous(),
list(t.block_size),
t.shape,
)

def dequantize(self, output_dtype: Optional[torch.dtype] = None) -> torch.Tensor:
"""Dequantize to a dense tensor (symmetric: ``w = q * scale``).

Used for the tied lm_head / token embedding (which can't gather a packed
tensor) and as the numerical reference.
"""
dtype = output_dtype if output_dtype is not None else self.scale.dtype
N, K = int(self.shape[0]), int(self.shape[1])
gs = self.block_size[-1]
q = unpack_int6(self.ql, self.qh, N, K).to(dtype)
scale = self.scale.to(dtype).repeat_interleave(gs, dim=-1)
return (q * scale).to(dtype)


# Allow a model with CudaPackedInt6Tensor weights to be loaded with
# `weights_only=True` (mirrors torchao quantized tensors).
torch.serialization.add_safe_globals([CudaPackedInt6Tensor])
5 changes: 4 additions & 1 deletion backends/cuda/quantize_op_dispatch/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,11 @@
dequant logic instead of torchao's defaults. It registers:

* INT4 (``CudaCoalescedInt4Tensor``) → ``executorch_cuda::int4_plain_mm``
* INT6 (``CudaPackedInt6Tensor``) → ``executorch_cuda::int6_plain_mm``
* INT8 (``IntxUnpackedToInt8Tensor``) → ``executorch_cuda::int8_plain_mm``

See ``int4_dispatch`` and ``int8_dispatch`` for the per-dtype details.
See ``int4_dispatch``, ``int6_dispatch`` and ``int8_dispatch`` for the per-dtype
details.

Import this package before using nn.Linear with quantized weights::

Expand All @@ -22,5 +24,6 @@

from executorch.backends.cuda.quantize_op_dispatch import ( # noqa: F401
int4_dispatch,
int6_dispatch,
int8_dispatch,
)
7 changes: 4 additions & 3 deletions backends/cuda/quantize_op_dispatch/_library.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,10 @@

"""Shared torch.library handle for the ``executorch_cuda`` op namespace.

``int4_dispatch`` and ``int8_dispatch`` both register custom ops into the same
``executorch_cuda`` namespace, so they must share a single ``DEF`` library
instance — PyTorch allows only one ``DEF`` per namespace per process.
``int4_dispatch``, ``int6_dispatch`` and ``int8_dispatch`` all register custom
ops into the same ``executorch_cuda`` namespace, so they must share a single
``DEF`` library instance — PyTorch allows only one ``DEF`` per namespace per
process.
"""

from torch.library import Library
Expand Down
Loading
Loading