From d16fc716e335f053a903891f3ba2081d7fa1d317 Mon Sep 17 00:00:00 2001 From: ZX-ModelCloud Date: Tue, 2 Sep 2025 07:34:41 +0000 Subject: [PATCH 1/8] add a100_qlinear.py Signed-off-by: ZX-ModelCloud --- gptqmodel/nn_modules/qlinear/tritonv2.py | 21 ++- .../nn_modules/triton_utils/a100_qlinear.py | 120 ++++++++++++++++++ 2 files changed, 136 insertions(+), 5 deletions(-) create mode 100644 gptqmodel/nn_modules/triton_utils/a100_qlinear.py diff --git a/gptqmodel/nn_modules/qlinear/tritonv2.py b/gptqmodel/nn_modules/qlinear/tritonv2.py index bd81f40aa..59115b673 100644 --- a/gptqmodel/nn_modules/qlinear/tritonv2.py +++ b/gptqmodel/nn_modules/qlinear/tritonv2.py @@ -19,6 +19,7 @@ import torch from packaging import version +from .a100_linear import a100_qlinear from ...adapter.adapter import Adapter, Lora from ...models._const import DEVICE, PLATFORM from ...utils.backend import BACKEND @@ -148,15 +149,25 @@ def forward(self, x): out_shape = x.shape[:-1] + (self.out_features,) - out = QuantLinearFunction.apply( + # out = QuantLinearFunction.apply( + # x.reshape(-1, x.shape[-1]), + # self.qweight, + # self.scales, + # self.qzeros, + # self.g_idx, + # self.bits, + # self.pack_dtype_bits, + # self.maxq, + # ).reshape(out_shape) + + block_size_m = x.shape[0] + # TODO test a100_qlinear + out = a100_qlinear.apply( x.reshape(-1, x.shape[-1]), self.qweight, self.scales, self.qzeros, - self.g_idx, - self.bits, - self.pack_dtype_bits, - self.maxq, + block_size_m, ).reshape(out_shape) if self.bias is not None: diff --git a/gptqmodel/nn_modules/triton_utils/a100_qlinear.py b/gptqmodel/nn_modules/triton_utils/a100_qlinear.py new file mode 100644 index 000000000..8d6cd2018 --- /dev/null +++ b/gptqmodel/nn_modules/triton_utils/a100_qlinear.py @@ -0,0 +1,120 @@ +import triton +import triton.language as tl +import torch + + +@triton.jit() +def _a100_quantized_matmul(a_ptr, b_ptr, c_ptr, scales_ptr, zeros_ptr, + stride_am, stride_ak, + stride_bk, stride_bn, + stride_cm, stride_cn, + stride_scales_g, stride_scales_n, + stride_zeros_g, stride_zeros_n, + groupsize, + m, n, k, + block_size_m: tl.constexpr, block_size_n: tl.constexpr, block_size_k: tl.constexpr, + group_size_m: tl.constexpr, + ): + pid = tl.program_id(0) + + total_blocks_m = tl.cdiv(m, block_size_m) + total_blocks_n = tl.cdiv(n, block_size_n) + total_blocks_k = tl.cdiv(k, block_size_k) + + num_blocks_in_group = group_size_m * total_blocks_n + group_id = pid // num_blocks_in_group + group_size = min(total_blocks_m - group_id * group_size_m, group_size_m) + + pid_m = group_id * group_size_m + (pid % group_size) + pid_n = (pid % num_blocks_in_group) // (group_size) + + offs_m = (pid_m * block_size_m + tl.arange(0, block_size_m)) % m + offs_n = (pid_n * block_size_n + tl.arange(0, block_size_n)) % n + + offs_am = tl.max_contiguous(tl.multiple_of(offs_m, block_size_m), block_size_m) + offs_bn = tl.max_contiguous(tl.multiple_of(offs_n, block_size_n), block_size_n) + offs_k = tl.arange(0, block_size_k) + + a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak) + b_ptrs = b_ptr + ((offs_k[:, None] // 8) * stride_bk + offs_bn[None, :] * stride_bn) + + scales_ptrs = scales_ptr + offs_bn * stride_scales_n + zeros_ptrs = zeros_ptr + ((offs_bn // 8) * stride_zeros_n) + + shifter = (offs_k % 8) * 4 + zeros_shifter = (offs_bn % 8) * 4 + + output = tl.zeros((block_size_m, block_size_n), dtype=tl.float32) + for k in range(0, total_blocks_k): + a = tl.load(a_ptrs) + b = tl.load(b_ptrs) + g_id = k // (groupsize // block_size_k) + + ptr = 
scales_ptrs + g_id * stride_scales_g + scales = tl.load(ptr) + + ptr = zeros_ptrs + g_id * stride_zeros_g + zeros = tl.load(ptr) + + zeros = (zeros >> zeros_shifter) & 0xF + zeros = (zeros + 1) * scales + + b = (b >> shifter[:, None]) & 0xF # b -> int32 + b = b * scales[None, :] - zeros[None, :] # b -> fp16 + + output += tl.dot(a, b) + a_ptrs += stride_ak * block_size_k + b_ptrs += (block_size_k // 8) * stride_bk + + output.to(tl.float16) + offs_cm = pid_m * block_size_m + tl.arange(0, block_size_m) + offs_cn = pid_n * block_size_n + tl.arange(0, block_size_n) + c_ptrs = c_ptr + (offs_cm[:, None] * stride_cm + offs_cn[None, :] * stride_cn) + tl.store(c_ptrs, output) + + +class a100_qlinear(torch.autograd.Function): + def forward(ctx, a, b, scales, zeros): + m, k = a.shape + _, n = b.shape + + quant_groupsize = 128 + block_size_m = 16 + block_size_n = 32 # [N = 4096 // 32] = 128 blocks + block_size_k = 256 + group_size_m = 8 + num_warps = 4 + num_stages = 8 + total_blocks_m = triton.cdiv(m, block_size_m) + total_blocks_n = triton.cdiv(n, block_size_n) + total_programs = total_blocks_m * total_blocks_n + grid = (total_programs, 1) + + c = torch.zeros((m, n), device=b.device, dtype=torch.float16) + k = _a100_quantized_matmul[grid]( + a, b, c, scales, zeros, + a.stride(0), a.stride(1), + b.stride(0), b.stride(1), + c.stride(0), c.stride(1), + scales.stride(0), scales.stride(1), + zeros.stride(0), zeros.stride(1), + quant_groupsize, + m, n, k, + block_size_m, block_size_n, block_size_k, group_size_m, + num_warps=num_warps, num_stages=num_stages, + ) + + print(f"{k.n_regs} registers used, {k.n_spills} spills, {k.shared / 1000} kB shared memory\n") + + with open('dequant_simple.txt', 'w') as f: + print(f"{k.n_regs} registers used, {k.n_spills} spills, {k.shared / 1000} kB shared memory\n", file=f) + print("IR", k.asm['ttir'], file=f) + print("TTGIR", k.asm['ttgir'], file=f) + print("PTX", k.asm['ptx'], file=f) + print(f"{k.n_regs} registers used, {k.n_spills} spills, {k.shared / 1000} kB shared memory\n", file=f) + + print(f"{total_blocks_m=} x {total_blocks_n=} = {total_programs=}") + return c + + +a100_qlinear = a100_qlinear.apply \ No newline at end of file From b75da795b40c489b4468c918b5bc2022bc218d76 Mon Sep 17 00:00:00 2001 From: ZX-ModelCloud Date: Tue, 2 Sep 2025 07:35:46 +0000 Subject: [PATCH 2/8] cleanup Signed-off-by: ZX-ModelCloud --- gptqmodel/nn_modules/triton_utils/a100_qlinear.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/gptqmodel/nn_modules/triton_utils/a100_qlinear.py b/gptqmodel/nn_modules/triton_utils/a100_qlinear.py index 8d6cd2018..a519deeb8 100644 --- a/gptqmodel/nn_modules/triton_utils/a100_qlinear.py +++ b/gptqmodel/nn_modules/triton_utils/a100_qlinear.py @@ -115,6 +115,3 @@ def forward(ctx, a, b, scales, zeros): print(f"{total_blocks_m=} x {total_blocks_n=} = {total_programs=}") return c - - -a100_qlinear = a100_qlinear.apply \ No newline at end of file From d01de45d9306402beb5e83261816a368401a0b27 Mon Sep 17 00:00:00 2001 From: ZX-ModelCloud Date: Tue, 2 Sep 2025 07:39:01 +0000 Subject: [PATCH 3/8] cleanup Signed-off-by: ZX-ModelCloud --- .../nn_modules/triton_utils/a100_qlinear.py | 18 ++++++++++++------ 1 file changed, 12 insertions(+), 6 deletions(-) diff --git a/gptqmodel/nn_modules/triton_utils/a100_qlinear.py b/gptqmodel/nn_modules/triton_utils/a100_qlinear.py index a519deeb8..8f5c96f07 100644 --- a/gptqmodel/nn_modules/triton_utils/a100_qlinear.py +++ b/gptqmodel/nn_modules/triton_utils/a100_qlinear.py @@ -74,17 +74,23 @@ def 
_a100_quantized_matmul(a_ptr, b_ptr, c_ptr, scales_ptr, zeros_ptr, class a100_qlinear(torch.autograd.Function): - def forward(ctx, a, b, scales, zeros): + def forward(ctx, a, b, scales, zeros, group_size, block_size_m, ): m, k = a.shape _, n = b.shape - quant_groupsize = 128 - block_size_m = 16 - block_size_n = 32 # [N = 4096 // 32] = 128 blocks - block_size_k = 256 + # quant_groupsize = 128 + # block_size_m = 16 + # block_size_n = 32 # [N = 4096 // 32] = 128 blocks + # block_size_k = 256 + # group_size_m = 8 + # num_warps = 4 + # num_stages = 8 + block_size_n = b.shape[1] // group_size + block_size_k = b.shape[0] // group_size group_size_m = 8 num_warps = 4 num_stages = 8 + total_blocks_m = triton.cdiv(m, block_size_m) total_blocks_n = triton.cdiv(n, block_size_n) total_programs = total_blocks_m * total_blocks_n @@ -98,7 +104,7 @@ def forward(ctx, a, b, scales, zeros): c.stride(0), c.stride(1), scales.stride(0), scales.stride(1), zeros.stride(0), zeros.stride(1), - quant_groupsize, + group_size, m, n, k, block_size_m, block_size_n, block_size_k, group_size_m, num_warps=num_warps, num_stages=num_stages, From 7ce21b8cef6556b5420da4708012562cbfc8396f Mon Sep 17 00:00:00 2001 From: ZX-ModelCloud Date: Tue, 2 Sep 2025 08:08:50 +0000 Subject: [PATCH 4/8] fix import error Signed-off-by: ZX-ModelCloud --- gptqmodel/nn_modules/qlinear/tritonv2.py | 14 ++++++------- .../nn_modules/triton_utils/a100_qlinear.py | 20 +++++++++---------- 2 files changed, 17 insertions(+), 17 deletions(-) diff --git a/gptqmodel/nn_modules/qlinear/tritonv2.py b/gptqmodel/nn_modules/qlinear/tritonv2.py index 59115b673..4582112f2 100644 --- a/gptqmodel/nn_modules/qlinear/tritonv2.py +++ b/gptqmodel/nn_modules/qlinear/tritonv2.py @@ -13,13 +13,12 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
- from typing import Optional, Tuple import torch from packaging import version -from .a100_linear import a100_qlinear +from ..triton_utils.a100_qlinear import a100_qlinear from ...adapter.adapter import Adapter, Lora from ...models._const import DEVICE, PLATFORM from ...utils.backend import BACKEND @@ -29,15 +28,15 @@ try: # TODO: triton is not compatible with free threading - if not has_gil_disabled(): - raise Exception("GIL is disabled so Triton is not (yet) compatible.") + # if not has_gil_disabled(): + # raise Exception("GIL is disabled so Triton is not (yet) compatible.") import triton import triton.language as tl from triton import __version__ as triton_version - from ..triton_utils.dequant import QuantLinearFunction - from ..triton_utils.mixin import TritonModuleMixin + # from ..triton_utils.dequant import QuantLinearFunction + # from ..triton_utils.mixin import TritonModuleMixin if version.parse(triton_version) < version.parse("2.0.0"): raise ImportError(f"triton version must be >= 2.0.0: actual = {triton_version}") TRITON_AVAILABLE = True @@ -52,7 +51,7 @@ class TritonModuleMixin: log = setup_logger() -class TritonV2QuantLinear(TorchQuantLinear, TritonModuleMixin): +class TritonV2QuantLinear(TorchQuantLinear): SUPPORTS_BITS = [2, 4, 8] SUPPORTS_GROUP_SIZE = [-1, 16, 32, 64, 128] SUPPORTS_DESC_ACT = [True, False] @@ -167,6 +166,7 @@ def forward(self, x): self.qweight, self.scales, self.qzeros, + self.group_size, block_size_m, ).reshape(out_shape) diff --git a/gptqmodel/nn_modules/triton_utils/a100_qlinear.py b/gptqmodel/nn_modules/triton_utils/a100_qlinear.py index 8f5c96f07..cd075eb27 100644 --- a/gptqmodel/nn_modules/triton_utils/a100_qlinear.py +++ b/gptqmodel/nn_modules/triton_utils/a100_qlinear.py @@ -110,14 +110,14 @@ def forward(ctx, a, b, scales, zeros, group_size, block_size_m, ): num_warps=num_warps, num_stages=num_stages, ) - print(f"{k.n_regs} registers used, {k.n_spills} spills, {k.shared / 1000} kB shared memory\n") - - with open('dequant_simple.txt', 'w') as f: - print(f"{k.n_regs} registers used, {k.n_spills} spills, {k.shared / 1000} kB shared memory\n", file=f) - print("IR", k.asm['ttir'], file=f) - print("TTGIR", k.asm['ttgir'], file=f) - print("PTX", k.asm['ptx'], file=f) - print(f"{k.n_regs} registers used, {k.n_spills} spills, {k.shared / 1000} kB shared memory\n", file=f) - - print(f"{total_blocks_m=} x {total_blocks_n=} = {total_programs=}") + # print(f"{k.n_regs} registers used, {k.n_spills} spills, {k.shared / 1000} kB shared memory\n") + # + # with open('dequant_simple.txt', 'w') as f: + # print(f"{k.n_regs} registers used, {k.n_spills} spills, {k.shared / 1000} kB shared memory\n", file=f) + # print("IR", k.asm['ttir'], file=f) + # print("TTGIR", k.asm['ttgir'], file=f) + # print("PTX", k.asm['ptx'], file=f) + # print(f"{k.n_regs} registers used, {k.n_spills} spills, {k.shared / 1000} kB shared memory\n", file=f) + # + # print(f"{total_blocks_m=} x {total_blocks_n=} = {total_programs=}") return c From 5f3ce2693f49bf338120f0e6a6e595ec5ba9d65f Mon Sep 17 00:00:00 2001 From: ZX-ModelCloud Date: Wed, 3 Sep 2025 01:01:12 +0000 Subject: [PATCH 5/8] revert tritonv2.py changes Signed-off-by: ZX-ModelCloud --- gptqmodel/nn_modules/qlinear/tritonv2.py | 34 ++++++++---------------- gptqmodel/utils/importer.py | 4 ++- 2 files changed, 14 insertions(+), 24 deletions(-) diff --git a/gptqmodel/nn_modules/qlinear/tritonv2.py b/gptqmodel/nn_modules/qlinear/tritonv2.py index 4582112f2..8277aac42 100644 --- a/gptqmodel/nn_modules/qlinear/tritonv2.py +++ 
b/gptqmodel/nn_modules/qlinear/tritonv2.py @@ -13,12 +13,12 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + from typing import Optional, Tuple import torch from packaging import version -from ..triton_utils.a100_qlinear import a100_qlinear from ...adapter.adapter import Adapter, Lora from ...models._const import DEVICE, PLATFORM from ...utils.backend import BACKEND @@ -28,15 +28,15 @@ try: # TODO: triton is not compatible with free threading - # if not has_gil_disabled(): - # raise Exception("GIL is disabled so Triton is not (yet) compatible.") + if not has_gil_disabled(): + raise Exception("GIL is disabled so Triton is not (yet) compatible.") import triton import triton.language as tl from triton import __version__ as triton_version - # from ..triton_utils.dequant import QuantLinearFunction - # from ..triton_utils.mixin import TritonModuleMixin + from ..triton_utils.dequant import QuantLinearFunction + from ..triton_utils.mixin import TritonModuleMixin if version.parse(triton_version) < version.parse("2.0.0"): raise ImportError(f"triton version must be >= 2.0.0: actual = {triton_version}") TRITON_AVAILABLE = True @@ -51,7 +51,7 @@ class TritonModuleMixin: log = setup_logger() -class TritonV2QuantLinear(TorchQuantLinear): +class TritonV2QuantLinear(TorchQuantLinear, TritonModuleMixin): SUPPORTS_BITS = [2, 4, 8] SUPPORTS_GROUP_SIZE = [-1, 16, 32, 64, 128] SUPPORTS_DESC_ACT = [True, False] @@ -148,26 +148,15 @@ def forward(self, x): out_shape = x.shape[:-1] + (self.out_features,) - # out = QuantLinearFunction.apply( - # x.reshape(-1, x.shape[-1]), - # self.qweight, - # self.scales, - # self.qzeros, - # self.g_idx, - # self.bits, - # self.pack_dtype_bits, - # self.maxq, - # ).reshape(out_shape) - - block_size_m = x.shape[0] - # TODO test a100_qlinear - out = a100_qlinear.apply( + out = QuantLinearFunction.apply( x.reshape(-1, x.shape[-1]), self.qweight, self.scales, self.qzeros, - self.group_size, - block_size_m, + self.g_idx, + self.bits, + self.pack_dtype_bits, + self.maxq, ).reshape(out_shape) if self.bias is not None: @@ -219,4 +208,3 @@ def triton_xpu_available(): except Exception: return False - diff --git a/gptqmodel/utils/importer.py b/gptqmodel/utils/importer.py index c3c9322bf..76b4ccad3 100644 --- a/gptqmodel/utils/importer.py +++ b/gptqmodel/utils/importer.py @@ -32,6 +32,7 @@ from ..nn_modules.qlinear.qqq import QQQQuantLinear from ..nn_modules.qlinear.torch import TorchQuantLinear from ..nn_modules.qlinear.torch_fused import TorchFusedQuantLinear +from ..nn_modules.qlinear.triton_a100 import TritonA100QuantLinear from ..nn_modules.qlinear.tritonv2 import TRITON_AVAILABLE, TRITON_INSTALL_HINT, TritonV2QuantLinear from ..quantization import FORMAT from ..utils.logger import setup_logger @@ -225,7 +226,8 @@ def select_quant_linear( if backend == BACKEND.TRITON: if not TRITON_AVAILABLE: raise ValueError(TRITON_INSTALL_HINT) - qlinear = TritonV2QuantLinear + # qlinear = TritonV2QuantLinear + qlinear = TritonA100QuantLinear elif backend == BACKEND.BITBLAS: qlinear = BitBLASQuantLinear elif backend in [BACKEND.MARLIN, BACKEND.MARLIN_FP16]: From 20540629e2e5874e85349902dd27d50a75c2430b Mon Sep 17 00:00:00 2001 From: ZX-ModelCloud Date: Wed, 3 Sep 2025 01:06:05 +0000 Subject: [PATCH 6/8] add TritonA100QuantLinear Signed-off-by: ZX-ModelCloud --- gptqmodel/models/loader.py | 3 +- gptqmodel/nn_modules/qlinear/triton_a100.py | 161 ++++++++++++++++++ 
.../nn_modules/triton_utils/a100_qlinear.py | 1 + gptqmodel/utils/importer.py | 7 +- 4 files changed, 168 insertions(+), 4 deletions(-) create mode 100644 gptqmodel/nn_modules/qlinear/triton_a100.py diff --git a/gptqmodel/models/loader.py b/gptqmodel/models/loader.py index c68261162..fe163d332 100644 --- a/gptqmodel/models/loader.py +++ b/gptqmodel/models/loader.py @@ -274,7 +274,8 @@ def from_quantized( os.environ['VLLM_ATTENTION_BACKEND'] = 'FLASHINFER' if backend == BACKEND.TRITON: - from ..nn_modules.qlinear.tritonv2 import TRITON_AVAILABLE, TRITON_INSTALL_HINT + # from ..nn_modules.qlinear.tritonv2 import TRITON_AVAILABLE, TRITON_INSTALL_HINT + from ..nn_modules.qlinear.triton_a100 import TRITON_AVAILABLE, TRITON_INSTALL_HINT if not TRITON_AVAILABLE: raise ValueError(TRITON_INSTALL_HINT) diff --git a/gptqmodel/nn_modules/qlinear/triton_a100.py b/gptqmodel/nn_modules/qlinear/triton_a100.py new file mode 100644 index 000000000..fe426347c --- /dev/null +++ b/gptqmodel/nn_modules/qlinear/triton_a100.py @@ -0,0 +1,161 @@ +# Copyright 2024-2025 ModelCloud.ai +# Copyright 2024-2025 qubitium@modelcloud.ai +# Contact: qubitium@modelcloud.ai, x.com/qubitium +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import torch + +from ..triton_utils.a100_qlinear import a100_qlinear +from ...adapter.adapter import Adapter, Lora +from ...models._const import DEVICE, PLATFORM +from ...nn_modules.qlinear import PackableQuantLinear +from ...utils.backend import BACKEND +from ...utils.logger import setup_logger +from packaging import version + +log = setup_logger() + +try: + # TODO: triton is not compatible with free threading + # if not has_gil_disabled(): + # raise Exception("GIL is disabled so Triton is not (yet) compatible.") + + import triton + import triton.language as tl + from triton import __version__ as triton_version + + if version.parse(triton_version) < version.parse("2.0.0"): + raise ImportError(f"triton version must be >= 2.0.0: actual = {triton_version}") + TRITON_AVAILABLE = True +except BaseException: + TRITON_AVAILABLE = False + +TRITON_INSTALL_HINT = "Trying to use the triton backend, but it could not be imported. Please install triton by 'pip install gptqmodel[triton] --no-build-isolation'" +TRITON_XPU_INSTALL_HINT = "Trying to use the triton backend and xpu device, but it could not be imported. 
Please install triton by [intel-xpu-backend-for-triton](https://github.com/intel/intel-xpu-backend-for-triton)" + +class TritonA100QuantLinear(PackableQuantLinear): + SUPPORTS_BITS = [2, 4, 8] + SUPPORTS_GROUP_SIZE = [-1, 16, 32, 64, 128] + SUPPORTS_DESC_ACT = [True, False] + SUPPORTS_SYM = [True, False] + SUPPORTS_SHARDS = True + SUPPORTS_TRAINING = True + SUPPORTS_AUTO_PADDING = True + SUPPORTS_IN_FEATURES_DIVISIBLE_BY = [32] + SUPPORTS_OUT_FEATURES_DIVISIBLE_BY = [32] + + SUPPORTS_DEVICES = [DEVICE.CUDA] # Intel XPU can use Triton but this has been validated + SUPPORTS_PLATFORM = [PLATFORM.LINUX, PLATFORM.WIN32] + SUPPORTS_PACK_DTYPES = [torch.int32, torch.int16, torch.int8] + SUPPORTS_ADAPTERS = [Lora] + + SUPPORTS_DTYPES = [torch.float16, torch.bfloat16] + + # for transformers/optimum tests compat + QUANT_TYPE = "triton_a100" + + def __init__( + self, + bits: int, + group_size: int, + sym: bool, + desc_act: bool, + in_features: int, + out_features: int, + bias: bool = False, + pack_dtype: torch.dtype = torch.int32, + adapter: Adapter = None, + register_buffers: bool = True, + **kwargs, + ): + super().__init__( + bits=bits, + group_size=group_size, + sym=sym, + desc_act=desc_act, + in_features=in_features, + out_features=out_features, + bias=bias, + pack_dtype=pack_dtype, + backend=kwargs.pop("backend", BACKEND.TORCH), + adapter=adapter, + register_buffers=register_buffers, + **kwargs) + + self.dequant_dtype = torch.int16 if self.bits == 8 else torch.int8 + + # if self.group_size != self.in_features: + # self.padded_infeatures = self.in_features + (-self.in_features % self.group_size) + # else: + # self.padded_infeatures = self.in_features + + def post_init(self): + # if self.padded_infeatures != self.in_features: + # self.qweight.resize_(self.padded_infeatures // self.pack_dtype_bits * self.bits, self.out_features) + # self.qzeros.resize_( + # math.ceil(self.padded_infeatures / self.group_size), + # self.out_features // self.pack_dtype_bits * self.bits + # ) + # self.scales.resize_((math.ceil(self.padded_infeatures / self.group_size), self.out_features), ) + # self.g_idx = torch.tensor([i // self.group_size for i in range(self.padded_infeatures)], dtype=torch.int32, + # device=self.g_idx.device) + + super().post_init() + + # torch benefits the most from torch.compile, enable it by default + self.optimize() + + def optimize(self, backend: str = None, mode: str = None, fullgraph: bool = False): + if self.optimized: + return + + if backend is None: + # MPS doesn't support inductor. 
+ backend = "inductor" if self.list_buffers()[0].device.type != "mps" else "aot_eager" + + # # compile dequantize + # self.dequantize_weight = torch_compile(self.dequantize_weight, backend=backend, mode=mode, fullgraph=fullgraph) + + if self.adapter: + self.adapter.optimize(backend=backend, mode=mode, fullgraph=fullgraph) + + super().optimize() + + def forward(self, x: torch.Tensor): + if self.training: + return super().forward(x) + + out_shape = x.shape[:-1] + (self.out_features,) + + block_size_m = x.shape[0] + # TODO test a100_qlinear + out = a100_qlinear.apply( + x.reshape(-1, x.shape[-1]), + self.qweight, + self.scales, + self.qzeros, + self.group_size, + block_size_m, + ).reshape(out_shape) + + if self.bias is not None: + out.add_(self.bias) + + if self.adapter: + out = self.adapter.apply(x=x, out=out) + + return out.to(dtype=x.dtype) + +__all__ = ["TritonA100QuantLinear"] diff --git a/gptqmodel/nn_modules/triton_utils/a100_qlinear.py b/gptqmodel/nn_modules/triton_utils/a100_qlinear.py index cd075eb27..a4fbb231a 100644 --- a/gptqmodel/nn_modules/triton_utils/a100_qlinear.py +++ b/gptqmodel/nn_modules/triton_utils/a100_qlinear.py @@ -96,6 +96,7 @@ def forward(ctx, a, b, scales, zeros, group_size, block_size_m, ): total_programs = total_blocks_m * total_blocks_n grid = (total_programs, 1) + print("aa bb", a.dtype, a.shape, b.dtype, b.shape) c = torch.zeros((m, n), device=b.device, dtype=torch.float16) k = _a100_quantized_matmul[grid]( a, b, c, scales, zeros, diff --git a/gptqmodel/utils/importer.py b/gptqmodel/utils/importer.py index 76b4ccad3..6513dccdc 100644 --- a/gptqmodel/utils/importer.py +++ b/gptqmodel/utils/importer.py @@ -32,8 +32,8 @@ from ..nn_modules.qlinear.qqq import QQQQuantLinear from ..nn_modules.qlinear.torch import TorchQuantLinear from ..nn_modules.qlinear.torch_fused import TorchFusedQuantLinear -from ..nn_modules.qlinear.triton_a100 import TritonA100QuantLinear -from ..nn_modules.qlinear.tritonv2 import TRITON_AVAILABLE, TRITON_INSTALL_HINT, TritonV2QuantLinear +from ..nn_modules.qlinear.triton_a100 import TRITON_AVAILABLE, TRITON_INSTALL_HINT, TritonA100QuantLinear +# from ..nn_modules.qlinear.tritonv2 import TRITON_AVAILABLE, TRITON_INSTALL_HINT, TritonV2QuantLinear from ..quantization import FORMAT from ..utils.logger import setup_logger from . 
import BACKEND @@ -49,7 +49,8 @@ BACKEND.EXLLAMA_V2: ExllamaV2QuantLinear, # optimized for bs > 1 BACKEND.EXLLAMA_V1: ExllamaQuantLinear, # optimized for bs == 1 BACKEND.TORCH_FUSED: TorchFusedQuantLinear, # optimized for Intel XPU - BACKEND.TRITON: TritonV2QuantLinear, # good all around kernel that JIT compiles + # BACKEND.TRITON: TritonV2QuantLinear, # good all around kernel that JIT compiles + BACKEND.TRITON: TritonA100QuantLinear, # good all around kernel that JIT compiles # BACKEND.CUDA: DynamicCudaQuantLinear, BACKEND.IPEX: IPEXQuantLinear, # best kernel Intel XPU and CPU with amx/avx512/xmx BACKEND.BITBLAS: BitBLASQuantLinear, # super slow AOT pre-compiler but fastest for bs=1 From 7a53e3baf155f4219f8759c0e47cbc72d084bce7 Mon Sep 17 00:00:00 2001 From: ZX-ModelCloud Date: Wed, 3 Sep 2025 01:35:48 +0000 Subject: [PATCH 7/8] cleanup Signed-off-by: ZX-ModelCloud --- gptqmodel/nn_modules/qlinear/triton_a100.py | 117 ++++++++++++++++- .../nn_modules/triton_utils/a100_qlinear.py | 124 ------------------ 2 files changed, 115 insertions(+), 126 deletions(-) delete mode 100644 gptqmodel/nn_modules/triton_utils/a100_qlinear.py diff --git a/gptqmodel/nn_modules/qlinear/triton_a100.py b/gptqmodel/nn_modules/qlinear/triton_a100.py index fe426347c..665a6e6e9 100644 --- a/gptqmodel/nn_modules/qlinear/triton_a100.py +++ b/gptqmodel/nn_modules/qlinear/triton_a100.py @@ -17,7 +17,6 @@ import torch -from ..triton_utils.a100_qlinear import a100_qlinear from ...adapter.adapter import Adapter, Lora from ...models._const import DEVICE, PLATFORM from ...nn_modules.qlinear import PackableQuantLinear @@ -147,7 +146,6 @@ def forward(self, x: torch.Tensor): self.scales, self.qzeros, self.group_size, - block_size_m, ).reshape(out_shape) if self.bias is not None: @@ -159,3 +157,118 @@ def forward(self, x: torch.Tensor): return out.to(dtype=x.dtype) __all__ = ["TritonA100QuantLinear"] + +@triton.jit() +def _a100_quantized_matmul(a_ptr, b_ptr, c_ptr, scales_ptr, zeros_ptr, + stride_am, stride_ak, + stride_bk, stride_bn, + stride_cm, stride_cn, + stride_scales_g, stride_scales_n, + stride_zeros_g, stride_zeros_n, + groupsize, + m, n, k, + block_size_m: tl.constexpr, block_size_n: tl.constexpr, block_size_k: tl.constexpr, + group_size_m: tl.constexpr, + ): + pid = tl.program_id(0) + + total_blocks_m = tl.cdiv(m, block_size_m) + total_blocks_n = tl.cdiv(n, block_size_n) + total_blocks_k = tl.cdiv(k, block_size_k) + + num_blocks_in_group = group_size_m * total_blocks_n + group_id = pid // num_blocks_in_group + group_size = min(total_blocks_m - group_id * group_size_m, group_size_m) + + pid_m = group_id * group_size_m + (pid % group_size) + pid_n = (pid % num_blocks_in_group) // (group_size) + + offs_m = (pid_m * block_size_m + tl.arange(0, block_size_m)) % m + offs_n = (pid_n * block_size_n + tl.arange(0, block_size_n)) % n + + offs_am = tl.max_contiguous(tl.multiple_of(offs_m, block_size_m), block_size_m) + offs_bn = tl.max_contiguous(tl.multiple_of(offs_n, block_size_n), block_size_n) + offs_k = tl.arange(0, block_size_k) + + a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak) + b_ptrs = b_ptr + ((offs_k[:, None] // 8) * stride_bk + offs_bn[None, :] * stride_bn) + + scales_ptrs = scales_ptr + offs_bn * stride_scales_n + zeros_ptrs = zeros_ptr + ((offs_bn // 8) * stride_zeros_n) + + shifter = (offs_k % 8) * 4 + zeros_shifter = (offs_bn % 8) * 4 + + output = tl.zeros((block_size_m, block_size_n), dtype=tl.float32) + for k in range(0, total_blocks_k): + a = tl.load(a_ptrs) + b = 
tl.load(b_ptrs) + g_id = k // (groupsize // block_size_k) + + ptr = scales_ptrs + g_id * stride_scales_g + scales = tl.load(ptr) + + ptr = zeros_ptrs + g_id * stride_zeros_g + zeros = tl.load(ptr) + + zeros = (zeros >> zeros_shifter) & 0xF + zeros = (zeros + 1) * scales + + b = (b >> shifter[:, None]) & 0xF # b -> int32 + b = b * scales[None, :] - zeros[None, :] # b -> fp16 + + output += tl.dot(a, b) + a_ptrs += stride_ak * block_size_k + b_ptrs += (block_size_k // 8) * stride_bk + + output.to(tl.float16) + offs_cm = pid_m * block_size_m + tl.arange(0, block_size_m) + offs_cn = pid_n * block_size_n + tl.arange(0, block_size_n) + c_ptrs = c_ptr + (offs_cm[:, None] * stride_cm + offs_cn[None, :] * stride_cn) + tl.store(c_ptrs, output) + + +class a100_qlinear(torch.autograd.Function): + def forward(ctx, a, b, scales, zeros, group_size): + m, k = a.shape + _, n = b.shape + + # quant_groupsize = 128 + quant_groupsize = group_size + block_size_m = 16 + block_size_n = 32 # [N = 4096 // 32] = 128 blocks + block_size_k = 256 + group_size_m = 8 + num_warps = 4 + num_stages = 4 + total_blocks_m = triton.cdiv(m, block_size_m) + total_blocks_n = triton.cdiv(n, block_size_n) + total_programs = total_blocks_m * total_blocks_n + grid = (total_programs, 1) + + c = torch.zeros((m, n), device=b.device, dtype=torch.float16) + k = _a100_quantized_matmul[grid]( + a, b, c, scales, zeros, + a.stride(0), a.stride(1), + b.stride(0), b.stride(1), + c.stride(0), c.stride(1), + scales.stride(0), scales.stride(1), + zeros.stride(0), zeros.stride(1), + quant_groupsize, + m, n, k, + block_size_m, block_size_n, block_size_k, group_size_m, + num_warps=num_warps, num_stages=num_stages, + ) + + # print(f"{k.n_regs} registers used, {k.n_spills} spills, {k.shared / 1000} kB shared memory\n") + # + # with open('dequant_simple.txt', 'w') as f: + # print(f"{k.n_regs} registers used, {k.n_spills} spills, {k.shared / 1000} kB shared memory\n", file=f) + # print("IR", k.asm['ttir'], file=f) + # print("TTGIR", k.asm['ttgir'], file=f) + # print("PTX", k.asm['ptx'], file=f) + # print(f"{k.n_regs} registers used, {k.n_spills} spills, {k.shared / 1000} kB shared memory\n", file=f) + # + # print(f"{total_blocks_m=} x {total_blocks_n=} = {total_programs=}") + return c + diff --git a/gptqmodel/nn_modules/triton_utils/a100_qlinear.py b/gptqmodel/nn_modules/triton_utils/a100_qlinear.py deleted file mode 100644 index a4fbb231a..000000000 --- a/gptqmodel/nn_modules/triton_utils/a100_qlinear.py +++ /dev/null @@ -1,124 +0,0 @@ -import triton -import triton.language as tl -import torch - - -@triton.jit() -def _a100_quantized_matmul(a_ptr, b_ptr, c_ptr, scales_ptr, zeros_ptr, - stride_am, stride_ak, - stride_bk, stride_bn, - stride_cm, stride_cn, - stride_scales_g, stride_scales_n, - stride_zeros_g, stride_zeros_n, - groupsize, - m, n, k, - block_size_m: tl.constexpr, block_size_n: tl.constexpr, block_size_k: tl.constexpr, - group_size_m: tl.constexpr, - ): - pid = tl.program_id(0) - - total_blocks_m = tl.cdiv(m, block_size_m) - total_blocks_n = tl.cdiv(n, block_size_n) - total_blocks_k = tl.cdiv(k, block_size_k) - - num_blocks_in_group = group_size_m * total_blocks_n - group_id = pid // num_blocks_in_group - group_size = min(total_blocks_m - group_id * group_size_m, group_size_m) - - pid_m = group_id * group_size_m + (pid % group_size) - pid_n = (pid % num_blocks_in_group) // (group_size) - - offs_m = (pid_m * block_size_m + tl.arange(0, block_size_m)) % m - offs_n = (pid_n * block_size_n + tl.arange(0, block_size_n)) % n - - offs_am = 
tl.max_contiguous(tl.multiple_of(offs_m, block_size_m), block_size_m) - offs_bn = tl.max_contiguous(tl.multiple_of(offs_n, block_size_n), block_size_n) - offs_k = tl.arange(0, block_size_k) - - a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak) - b_ptrs = b_ptr + ((offs_k[:, None] // 8) * stride_bk + offs_bn[None, :] * stride_bn) - - scales_ptrs = scales_ptr + offs_bn * stride_scales_n - zeros_ptrs = zeros_ptr + ((offs_bn // 8) * stride_zeros_n) - - shifter = (offs_k % 8) * 4 - zeros_shifter = (offs_bn % 8) * 4 - - output = tl.zeros((block_size_m, block_size_n), dtype=tl.float32) - for k in range(0, total_blocks_k): - a = tl.load(a_ptrs) - b = tl.load(b_ptrs) - g_id = k // (groupsize // block_size_k) - - ptr = scales_ptrs + g_id * stride_scales_g - scales = tl.load(ptr) - - ptr = zeros_ptrs + g_id * stride_zeros_g - zeros = tl.load(ptr) - - zeros = (zeros >> zeros_shifter) & 0xF - zeros = (zeros + 1) * scales - - b = (b >> shifter[:, None]) & 0xF # b -> int32 - b = b * scales[None, :] - zeros[None, :] # b -> fp16 - - output += tl.dot(a, b) - a_ptrs += stride_ak * block_size_k - b_ptrs += (block_size_k // 8) * stride_bk - - output.to(tl.float16) - offs_cm = pid_m * block_size_m + tl.arange(0, block_size_m) - offs_cn = pid_n * block_size_n + tl.arange(0, block_size_n) - c_ptrs = c_ptr + (offs_cm[:, None] * stride_cm + offs_cn[None, :] * stride_cn) - tl.store(c_ptrs, output) - - -class a100_qlinear(torch.autograd.Function): - def forward(ctx, a, b, scales, zeros, group_size, block_size_m, ): - m, k = a.shape - _, n = b.shape - - # quant_groupsize = 128 - # block_size_m = 16 - # block_size_n = 32 # [N = 4096 // 32] = 128 blocks - # block_size_k = 256 - # group_size_m = 8 - # num_warps = 4 - # num_stages = 8 - block_size_n = b.shape[1] // group_size - block_size_k = b.shape[0] // group_size - group_size_m = 8 - num_warps = 4 - num_stages = 8 - - total_blocks_m = triton.cdiv(m, block_size_m) - total_blocks_n = triton.cdiv(n, block_size_n) - total_programs = total_blocks_m * total_blocks_n - grid = (total_programs, 1) - - print("aa bb", a.dtype, a.shape, b.dtype, b.shape) - c = torch.zeros((m, n), device=b.device, dtype=torch.float16) - k = _a100_quantized_matmul[grid]( - a, b, c, scales, zeros, - a.stride(0), a.stride(1), - b.stride(0), b.stride(1), - c.stride(0), c.stride(1), - scales.stride(0), scales.stride(1), - zeros.stride(0), zeros.stride(1), - group_size, - m, n, k, - block_size_m, block_size_n, block_size_k, group_size_m, - num_warps=num_warps, num_stages=num_stages, - ) - - # print(f"{k.n_regs} registers used, {k.n_spills} spills, {k.shared / 1000} kB shared memory\n") - # - # with open('dequant_simple.txt', 'w') as f: - # print(f"{k.n_regs} registers used, {k.n_spills} spills, {k.shared / 1000} kB shared memory\n", file=f) - # print("IR", k.asm['ttir'], file=f) - # print("TTGIR", k.asm['ttgir'], file=f) - # print("PTX", k.asm['ptx'], file=f) - # print(f"{k.n_regs} registers used, {k.n_spills} spills, {k.shared / 1000} kB shared memory\n", file=f) - # - # print(f"{total_blocks_m=} x {total_blocks_n=} = {total_programs=}") - return c From 879f45d688f41d585f617925ca0f5499bbdd770b Mon Sep 17 00:00:00 2001 From: ZX-ModelCloud Date: Wed, 3 Sep 2025 02:00:28 +0000 Subject: [PATCH 8/8] test TritonA100QuantLinear Signed-off-by: ZX-ModelCloud --- tests/test_kernel_output.py | 166 ++++++++++++++++++------------------ 1 file changed, 84 insertions(+), 82 deletions(-) diff --git a/tests/test_kernel_output.py b/tests/test_kernel_output.py index 1bb1af6ef..b6507a1b6 100644 
--- a/tests/test_kernel_output.py +++ b/tests/test_kernel_output.py @@ -7,7 +7,8 @@ from gptqmodel.nn_modules.qlinear.exllamav2 import ExllamaV2QuantLinear from gptqmodel.nn_modules.qlinear.marlin import MarlinQuantLinear from gptqmodel.nn_modules.qlinear.torch import TorchQuantLinear -from gptqmodel.nn_modules.qlinear.tritonv2 import TritonV2QuantLinear +# from gptqmodel.nn_modules.qlinear.tritonv2 import TritonV2QuantLinear +from gptqmodel.nn_modules.qlinear.triton_a100 import TritonA100QuantLinear from gptqmodel.utils.model import find_modules from logbar import LogBar from parameterized import parameterized @@ -30,7 +31,8 @@ class TestKernelOutput(unittest.TestCase): # BACKEND.EXLLAMA_V1: ExllamaQuantLinear, BACKEND.EXLLAMA_EORA: ExllamaEoraQuantLinear, BACKEND.EXLLAMA_V2: ExllamaV2QuantLinear, - BACKEND.TRITON: TritonV2QuantLinear, + # BACKEND.TRITON: TritonV2QuantLinear, + BACKEND.TRITON: TritonA100QuantLinear, BACKEND.TORCH: TorchQuantLinear, # BACKEND.TORCH_FUSED: TorchFusedQuantLinear, # BACKEND.BITBLAS: BitBLASQuantLinear, @@ -116,14 +118,14 @@ def assert_on_mismatch(self, a: Tensor, b: Tensor, atol): #torch.allclose(a, b, rtol=0.15, atol=atol) @parameterized.expand([ - (BACKEND.TORCH, torch.float16, 0.0000), + # (BACKEND.TORCH, torch.float16, 0.0000), # (BACKEND.TORCH_FUSED, torch.float16, 0.0001), (BACKEND.TRITON, torch.float16, 0.00001), # (BACKEND.EXLLAMA_V1, torch.float16, 0.0050), - (BACKEND.EXLLAMA_V2, torch.float16, 0.0068), - (BACKEND.MARLIN, torch.float16, 0.00035), + # (BACKEND.EXLLAMA_V2, torch.float16, 0.0068), + # (BACKEND.MARLIN, torch.float16, 0.00035), # (BACKEND.MARLIN_FP16, torch.float16, 0.0035), - (BACKEND.EXLLAMA_EORA, torch.float16, 0.0025), + # (BACKEND.EXLLAMA_EORA, torch.float16, 0.0025), ]) def test_kernel_float16(self, backend: BACKEND, dtype: torch.dtype, a_tolerance: float): out = self.forward(backend=backend, dtype=dtype) @@ -143,79 +145,79 @@ def test_kernel_float16(self, backend: BACKEND, dtype: torch.dtype, a_tolerance f"Actual with Lora output: dtype = `{dtype}`, backed = `{backend}`, i = `{i}`, {out[i][:10]}") raise AssertionError - @parameterized.expand([ - (BACKEND.TORCH, torch.bfloat16, 0.0000), - # (BACKEND.TORCH_FUSED, torch.bfloat16, 0.0001), - (BACKEND.TRITON, torch.bfloat16, 0.00001), - # (BACKEND.EXLLAMA_V1, torch.bfloat16, 0.0064), - (BACKEND.EXLLAMA_V2, torch.bfloat16, 0.0054), - (BACKEND.MARLIN, torch.bfloat16, 0.0031), - # (BACKEND.MARLIN_FP16, torch.bfloat16, 0.012), - # (BACKEND.EXLLAMA_EORA, torch.bfloat16, 0.0031), TODO FIX, abnormal output when Exllama Eora kernel is using bfloat16 - ]) - def test_kernel_bfloat16(self, backend: BACKEND, dtype: torch.dtype, a_tolerance: float): - out = self.forward(backend=backend, dtype=dtype) - - # torch as ref - pb = log.pb(len(out)).title(f"Actual Kernel Output With Lora {dtype}").manual() - for i in pb: - data = self.data[dtype] - pb.subtitle(f"backed = `{backend}`").draw() - try: - self.assert_on_mismatch(data.torch_kernel_out[i], out[i], - a_tolerance) # use torch as reference - except AssertionError: - log.error( - f"Torch with Lora output: dtype = `{dtype}`, backed = `{BACKEND.TORCH}`, i = `{i}`, {data.torch_kernel_out[i][:10]}") - log.error( - f"Actual with Lora output: dtype = `{dtype}`, backed = `{backend}`, i = `{i}`, {out[i][:10]}") - raise AssertionError - - @parameterized.expand([ - (BACKEND.TORCH, torch.float16, 0.0000), - # (BACKEND.TORCH_FUSED, torch.float16, 0.0001), - (BACKEND.TRITON, torch.float16, 0.00001), - # (BACKEND.EXLLAMA_V1, torch.float16, 0.0054), - 
(BACKEND.EXLLAMA_V2, torch.float16, 0.0065), - (BACKEND.MARLIN, torch.float16, 0.00035), - # (BACKEND.MARLIN_FP16, torch.float16, 0.0035), - (BACKEND.EXLLAMA_EORA, torch.float16, 0.0020) - ]) - def test_kernel_float16_with_lora(self, backend: BACKEND, dtype: torch.dtype, a_tolerance: float): - data = self.data[dtype] - out = self.forward(backend=backend, dtype=dtype, adapter=data.adapter) - - # torch as ref - pb = log.pb(len(out)).title(f"Actual Kernel Output With Lora {dtype}").manual() - for i in pb: - pb.subtitle(f"backed = `{backend}`").draw() - try: - self.assert_on_mismatch(data.torch_kernel_out_with_lora[i], out[i], a_tolerance) # use torch as reference - except AssertionError: - log.error(f"Torch with Lora output: backed = dtype = `{dtype}`, `{backend}`, i = `{i}`, {data.torch_kernel_out_with_lora[i][:10]}") - raise AssertionError - - - @parameterized.expand([ - (BACKEND.TORCH, torch.bfloat16, 0.0000), - # (BACKEND.TORCH_FUSED, torch.bfloat16, 0.0001), - (BACKEND.TRITON, torch.bfloat16, 0.00001), - # (BACKEND.EXLLAMA_V1, torch.bfloat16, 0.0062), - (BACKEND.EXLLAMA_V2, torch.bfloat16, 0.0059), - (BACKEND.MARLIN, torch.bfloat16, 0.0033), - # (BACKEND.MARLIN_FP16, torch.bfloat16, 0.011), - # (BACKEND.EXLLAMA_EORA, torch.bfloat16, 0.0014) TODO FIX, abnormal output when Exllama Eora kernel is using bfloat16 - ]) - def test_kernel_bfloat16_with_lora(self, backend: BACKEND, dtype: torch.dtype, a_tolerance: float): - data = self.data[dtype] - out = self.forward(backend=backend, dtype=dtype, adapter=data.adapter) - - # torch as ref - pb = log.pb(len(out)).title(f"Actual Kernel Output With Lora {dtype}").manual() - for i in pb: - pb.subtitle(f"backed = `{backend}`").draw() - try: - self.assert_on_mismatch(data.torch_kernel_out_with_lora[i], out[i], a_tolerance) # use torch as reference - except AssertionError: - log.error(f"Torch with Lora output: dtype = `{dtype}`, backed = `{backend}`, i = `{i}`, {data.torch_kernel_out_with_lora[i][:10]}") - raise AssertionError + # @parameterized.expand([ + # (BACKEND.TORCH, torch.bfloat16, 0.0000), + # # (BACKEND.TORCH_FUSED, torch.bfloat16, 0.0001), + # (BACKEND.TRITON, torch.bfloat16, 0.00001), + # # (BACKEND.EXLLAMA_V1, torch.bfloat16, 0.0064), + # (BACKEND.EXLLAMA_V2, torch.bfloat16, 0.0054), + # (BACKEND.MARLIN, torch.bfloat16, 0.0031), + # # (BACKEND.MARLIN_FP16, torch.bfloat16, 0.012), + # # (BACKEND.EXLLAMA_EORA, torch.bfloat16, 0.0031), TODO FIX, abnormal output when Exllama Eora kernel is using bfloat16 + # ]) + # def test_kernel_bfloat16(self, backend: BACKEND, dtype: torch.dtype, a_tolerance: float): + # out = self.forward(backend=backend, dtype=dtype) + # + # # torch as ref + # pb = log.pb(len(out)).title(f"Actual Kernel Output With Lora {dtype}").manual() + # for i in pb: + # data = self.data[dtype] + # pb.subtitle(f"backed = `{backend}`").draw() + # try: + # self.assert_on_mismatch(data.torch_kernel_out[i], out[i], + # a_tolerance) # use torch as reference + # except AssertionError: + # log.error( + # f"Torch with Lora output: dtype = `{dtype}`, backed = `{BACKEND.TORCH}`, i = `{i}`, {data.torch_kernel_out[i][:10]}") + # log.error( + # f"Actual with Lora output: dtype = `{dtype}`, backed = `{backend}`, i = `{i}`, {out[i][:10]}") + # raise AssertionError + # + # @parameterized.expand([ + # (BACKEND.TORCH, torch.float16, 0.0000), + # # (BACKEND.TORCH_FUSED, torch.float16, 0.0001), + # (BACKEND.TRITON, torch.float16, 0.00001), + # # (BACKEND.EXLLAMA_V1, torch.float16, 0.0054), + # (BACKEND.EXLLAMA_V2, torch.float16, 0.0065), + # 
(BACKEND.MARLIN, torch.float16, 0.00035), + # # (BACKEND.MARLIN_FP16, torch.float16, 0.0035), + # (BACKEND.EXLLAMA_EORA, torch.float16, 0.0020) + # ]) + # def test_kernel_float16_with_lora(self, backend: BACKEND, dtype: torch.dtype, a_tolerance: float): + # data = self.data[dtype] + # out = self.forward(backend=backend, dtype=dtype, adapter=data.adapter) + # + # # torch as ref + # pb = log.pb(len(out)).title(f"Actual Kernel Output With Lora {dtype}").manual() + # for i in pb: + # pb.subtitle(f"backed = `{backend}`").draw() + # try: + # self.assert_on_mismatch(data.torch_kernel_out_with_lora[i], out[i], a_tolerance) # use torch as reference + # except AssertionError: + # log.error(f"Torch with Lora output: backed = dtype = `{dtype}`, `{backend}`, i = `{i}`, {data.torch_kernel_out_with_lora[i][:10]}") + # raise AssertionError + # + # + # @parameterized.expand([ + # (BACKEND.TORCH, torch.bfloat16, 0.0000), + # # (BACKEND.TORCH_FUSED, torch.bfloat16, 0.0001), + # (BACKEND.TRITON, torch.bfloat16, 0.00001), + # # (BACKEND.EXLLAMA_V1, torch.bfloat16, 0.0062), + # (BACKEND.EXLLAMA_V2, torch.bfloat16, 0.0059), + # (BACKEND.MARLIN, torch.bfloat16, 0.0033), + # # (BACKEND.MARLIN_FP16, torch.bfloat16, 0.011), + # # (BACKEND.EXLLAMA_EORA, torch.bfloat16, 0.0014) TODO FIX, abnormal output when Exllama Eora kernel is using bfloat16 + # ]) + # def test_kernel_bfloat16_with_lora(self, backend: BACKEND, dtype: torch.dtype, a_tolerance: float): + # data = self.data[dtype] + # out = self.forward(backend=backend, dtype=dtype, adapter=data.adapter) + # + # # torch as ref + # pb = log.pb(len(out)).title(f"Actual Kernel Output With Lora {dtype}").manual() + # for i in pb: + # pb.subtitle(f"backed = `{backend}`").draw() + # try: + # self.assert_on_mismatch(data.torch_kernel_out_with_lora[i], out[i], a_tolerance) # use torch as reference + # except AssertionError: + # log.error(f"Torch with Lora output: dtype = `{dtype}`, backed = `{backend}`, i = `{i}`, {data.torch_kernel_out_with_lora[i][:10]}") + # raise AssertionError
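
---

A quick plain-PyTorch reference for the dequantization convention assumed by `_a100_quantized_matmul` in the patches above — a minimal sketch, not part of the patch series: `dequant_ref` and the random tensors below are illustrative assumptions (4-bit weights packed eight per int32 along K, zero-points packed eight per int32 along N, and the `(zeros + 1) * scales` offset that appears in the kernel body).

    import torch

    def dequant_ref(qweight, scales, qzeros, group_size):
        # assumed shapes: qweight (K // 8, N) int32, scales (K // group_size, N),
        # qzeros (K // group_size, N // 8) int32
        K, N = qweight.shape[0] * 8, qweight.shape[1]
        shifts = torch.arange(0, 32, 4, device=qweight.device)       # 4-bit lanes, as in shifter / zeros_shifter
        # unpack weights: lane j of qweight[i, n] holds the value for row k = 8 * i + j
        w = (qweight.unsqueeze(1) >> shifts.view(1, -1, 1)) & 0xF    # (K // 8, 8, N)
        w = w.reshape(K, N).to(scales.dtype)
        # unpack zero-points: lane j of qzeros[g, i] holds the zero for column n = 8 * i + j
        z = (qzeros.unsqueeze(-1) >> shifts.view(1, 1, -1)) & 0xF    # (G, N // 8, 8)
        z = z.reshape(-1, N).to(scales.dtype)
        g = torch.arange(K, device=qweight.device) // group_size     # group id per input row
        # same convention as the kernel: b * scales - (zeros + 1) * scales
        return w * scales[g] - (z[g] + 1) * scales[g]                # (K, N)

    # shape-only smoke test with random data (float32 so it also runs on CPU)
    K, N, group_size = 512, 256, 128
    qweight = torch.randint(0, 2**31 - 1, (K // 8, N), dtype=torch.int32)
    qzeros = torch.randint(0, 2**31 - 1, (K // group_size, N // 8), dtype=torch.int32)
    scales = torch.rand(K // group_size, N)
    x = torch.rand(4, K)
    out = x @ dequant_ref(qweight, scales, qzeros, group_size)       # (4, N)

With matching packing, `x @ dequant_ref(qweight, scales, qzeros, group_size)` should agree with the Triton kernel's output up to float16 accumulation error, which is what the tolerances in tests/test_kernel_output.py measure against the Torch reference kernel.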