diff --git a/gptqmodel/__init__.py b/gptqmodel/__init__.py index 32253c8db..fd99d280d 100644 --- a/gptqmodel/__init__.py +++ b/gptqmodel/__init__.py @@ -13,48 +13,51 @@ patch_safetensors_save_file() patch_triton_autotuner() -from .utils.env import env_flag -from .utils.logger import setup_logger - - -DEBUG_ON = env_flag("DEBUG") - -from .utils.linalg_warmup import run_torch_linalg_warmup -from .utils.threadx import DeviceThreadPool - - -DEVICE_THREAD_POOL = DeviceThreadPool( - inference_mode=True, - warmups={ - "cuda": run_torch_linalg_warmup, - "xpu": run_torch_linalg_warmup, - "mps": run_torch_linalg_warmup, - "cpu": run_torch_linalg_warmup, - }, - workers={ - "cuda:per": 4, - "xpu:per": 1, - "mps": 8, - "cpu": min(12, max(1, (os.cpu_count() or 1) // 2)), - "model_loader:cpu": 2, - }, - empty_cache_every_n=512, -) - -from .models import GPTQModel, get_best_device -from .models.auto import ASCII_LOGO -from .quantization import BaseQuantizeConfig, QuantizeConfig -from .utils import BACKEND -from .utils.exllama import exllama_set_max_input_length -from .version import __version__ - - -setup_logger().info("\n%s", ASCII_LOGO) - - -if os.getenv('GPTQMODEL_USE_MODELSCOPE', 'False').lower() in ['true', '1']: - try: - from modelscope.utils.hf_util.patcher import patch_hub - patch_hub() - except Exception: - raise ModuleNotFoundError("you have set GPTQMODEL_USE_MODELSCOPE env, but doesn't have modelscope? install it with `pip install modelscope`") +if os.environ.get("GPTQMODEL_SKIP_INIT", "0") == "1": + __all__ = [] +else: + from .utils.env import env_flag + from .utils.logger import setup_logger + + + DEBUG_ON = env_flag("DEBUG") + + from .utils.linalg_warmup import run_torch_linalg_warmup + from .utils.threadx import DeviceThreadPool + + + DEVICE_THREAD_POOL = DeviceThreadPool( + inference_mode=True, + warmups={ + "cuda": run_torch_linalg_warmup, + "xpu": run_torch_linalg_warmup, + "mps": run_torch_linalg_warmup, + "cpu": run_torch_linalg_warmup, + }, + workers={ + "cuda:per": 4, + "xpu:per": 1, + "mps": 8, + "cpu": min(12, max(1, (os.cpu_count() or 1) // 2)), + "model_loader:cpu": 2, + }, + empty_cache_every_n=512, + ) + + from .models import GPTQModel, get_best_device + from .models.auto import ASCII_LOGO + from .quantization import BaseQuantizeConfig, QuantizeConfig + from .utils import BACKEND + from .utils.exllama import exllama_set_max_input_length + from .version import __version__ + + + setup_logger().info("\n%s", ASCII_LOGO) + + + if os.getenv('GPTQMODEL_USE_MODELSCOPE', 'False').lower() in ['true', '1']: + try: + from modelscope.utils.hf_util.patcher import patch_hub + patch_hub() + except Exception: + raise ModuleNotFoundError("you have set GPTQMODEL_USE_MODELSCOPE env, but doesn't have modelscope? install it with `pip install modelscope`") diff --git a/gptqmodel/looper/awq_processor.py b/gptqmodel/looper/awq_processor.py index 3ad6b8a51..28b87eae3 100644 --- a/gptqmodel/looper/awq_processor.py +++ b/gptqmodel/looper/awq_processor.py @@ -68,6 +68,10 @@ def __init__(self, tokenizer, qcfg: QuantizeConfig, calibration, prepare_dataset self.export_compatible = False self.version = qcfg.format + if self.version == FORMAT.GEMM: + self.version = FORMAT.GEMM_V2 + elif self.version == FORMAT.GEMV: + self.version = FORMAT.GEMV_V2 # TODO Can it be configured? # The maximum sequence length of the calibration dataset. Discard samples greater than max_calib_seq_len. 
@@ -675,13 +679,13 @@ def _apply_quant(self, module, named_linears: Dict[str, NamedModule], start_time linear_layer.weight.data = wq - if self.version == "gemm": + if self.version in ("gemm", "gemm_v2"): scales = scales.t().contiguous() if zeros is not None: zeros = zeros.t().contiguous() q_linear_module = WQLinear_GEMM - elif self.version == "gemv": + elif self.version in ("gemv", "gemv_v2"): q_linear_module = WQLinear_GEMV elif self.version == "marlin": @@ -790,9 +794,11 @@ def submodule_finalize(self, module: NamedModule, **kwargs): module.state.pop("w", None) # no need for original weights now def finalize(self, model: BaseQModel, **kwargs): - if model.quantize_config.format == FORMAT.GEMM: + if model.quantize_config.format in (FORMAT.GEMM, FORMAT.GEMM_V2): + model.quantize_config.format = FORMAT.GEMM_V2 model.qlinear_kernel = AwqGEMMQuantLinear - elif model.quantize_config.format == FORMAT.GEMV: + elif model.quantize_config.format in (FORMAT.GEMV, FORMAT.GEMV_V2): + model.quantize_config.format = FORMAT.GEMV_V2 model.qlinear_kernel = AwqGEMVQuantLinear elif model.quantize_config.format == FORMAT.GEMV_FAST: model.qlinear_kernel = AwqGEMVFastQuantLinear diff --git a/gptqmodel/models/base.py b/gptqmodel/models/base.py index bcd1ff641..3ce74f29d 100644 --- a/gptqmodel/models/base.py +++ b/gptqmodel/models/base.py @@ -841,11 +841,10 @@ def quantize( ) if self.quantize_config.quant_method == METHOD.AWQ: - if self.quantize_config.format == FORMAT.GEMV_FAST: - # AWQ GEMV_FAST only supports pack_dtype is torch.int16 - log.info("Quantize Model: Auto fix `pack_dtype` to `torch.int16`") + if self.quantize_config.format in (FORMAT.GEMM, FORMAT.GEMM_V2, FORMAT.GEMV, FORMAT.GEMV_V2, FORMAT.GEMV_FAST): + log.info("Quantize Model: Auto fix `pack_dtype` to `torch.int16` for AWQ layout") self.quantize_config.pack_dtype = torch.int16 - elif self.quantize_config.format == FORMAT.MARLIN: + if self.quantize_config.format == FORMAT.MARLIN: # AWQ MARLIN only supports zero_point is false log.info("Quantize Model: Auto fix `zero_point` to `False`") self.quantize_config.zero_point = False diff --git a/gptqmodel/models/loader.py b/gptqmodel/models/loader.py index 28baf7fe2..619887657 100644 --- a/gptqmodel/models/loader.py +++ b/gptqmodel/models/loader.py @@ -387,7 +387,7 @@ def from_quantized( if backend == BACKEND.VLLM or backend == BACKEND.SGLANG: if backend == BACKEND.VLLM: - if qcfg.format != FORMAT.GPTQ and qcfg.format != FORMAT.GEMM: + if qcfg.format not in (FORMAT.GPTQ, FORMAT.GEMM, FORMAT.GEMM_V2): raise ValueError(f"{backend} backend only supports FORMAT.GPTQ or FORMAT.GEMM: actual = {qcfg.format}") elif backend == BACKEND.SGLANG: if qcfg.format != FORMAT.GPTQ: diff --git a/gptqmodel/nn_modules/qlinear/__init__.py b/gptqmodel/nn_modules/qlinear/__init__.py index 442e925be..8427e8be5 100644 --- a/gptqmodel/nn_modules/qlinear/__init__.py +++ b/gptqmodel/nn_modules/qlinear/__init__.py @@ -1030,6 +1030,8 @@ def pack_original(self, linear: nn.Module, scales: t.Tensor, zeros: t.Tensor, g_ # print("self qw", self.qweight, self.scales, self.qzeros) class AWQuantLinear(BaseQuantLinear): + REQUIRES_FORMAT_V2 = False + def __init__(self, bias: bool = False, register_buffers: bool = False, diff --git a/gptqmodel/nn_modules/qlinear/awq_gemm.py b/gptqmodel/nn_modules/qlinear/awq_gemm.py index e2b95100e..6d167d432 100644 --- a/gptqmodel/nn_modules/qlinear/awq_gemm.py +++ b/gptqmodel/nn_modules/qlinear/awq_gemm.py @@ -1,19 +1,91 @@ -# SPDX-FileCopyrightText: 2024-2025 ModelCloud.ai -# SPDX-FileCopyrightText: 2024-2025 
qubitium@modelcloud.ai # SPDX-License-Identifier: Apache-2.0 -# Contact: qubitium@modelcloud.ai, x.com/qubitium +from __future__ import annotations import torch from ...adapter.adapter import Adapter, Lora from ...models._const import DEVICE, PLATFORM from ...nn_modules.qlinear import AWQuantLinear -from ...quantization.awq.modules.linear.gemm import WQLinearMMFunction +from ...quantization.awq.utils.mit_repacker import ( + multiply_scale_qzero_negative as mit_multiply_scale_qzero_negative, + packing_v2_from_unpacked as mit_packing_v2_from_unpacked, + qweight_unpack as mit_qweight_unpack, +) +from ...quantization.awq.utils.module import try_import from ...utils.backend import BACKEND -from ...utils.logger import setup_logger +from ...utils.gemv import calculate_zeros_width +__all__ = ["AwqGEMMQuantLinear"] + +awq_ext, msg = try_import("gptqmodel_awq_kernels") + +def _convert_awq_v1_to_v2( + qweight: torch.Tensor, + qzeros: torch.Tensor, + scales: torch.Tensor, + *, + bits: int, + group_size: int, + in_features: int, + out_features: int, + interleave: int, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + device = qweight.device + unpacked = mit_qweight_unpack(qweight.to(device)) + if unpacked.shape == (in_features, out_features): + unpacked = unpacked.transpose(0, 1).contiguous() + elif unpacked.shape != (out_features, in_features): + raise ValueError( + f"Unexpected legacy qweight shape {tuple(unpacked.shape)}; expected " + f"({out_features}, {in_features}) or ({in_features}, {out_features})." + ) + qweight_v2 = mit_packing_v2_from_unpacked(unpacked, interleave, 64).contiguous() + + pack_num = 32 // bits + zeros_width = calculate_zeros_width(in_features, group_size, pack_num=pack_num) + + groups = 1 if group_size in (-1, 0) else in_features // group_size + scales_legacy = scales.to(device=device).contiguous() + if scales_legacy.shape == (out_features, groups): + scales_groups_first = scales_legacy.transpose(0, 1).contiguous() + elif scales_legacy.shape == (groups, out_features): + scales_groups_first = scales_legacy + else: + raise ValueError( + f"Unexpected legacy scales shape {tuple(scales_legacy.shape)}; " + f"expected ({out_features}, {groups}) or ({groups}, {out_features})." + ) + + qzeros_legacy = qzeros.to(device=device).contiguous() + expected_zero_cols = out_features // pack_num + if qzeros_legacy.shape == (out_features, expected_zero_cols): + qzeros_groups_first = qzeros_legacy.transpose(0, 1).contiguous() + elif qzeros_legacy.shape == (expected_zero_cols, out_features): + qzeros_groups_first = qzeros_legacy.transpose(0, 1).contiguous() + elif qzeros_legacy.shape == (groups, expected_zero_cols): + qzeros_groups_first = qzeros_legacy + else: + raise ValueError( + f"Unexpected legacy qzeros shape {tuple(qzeros_legacy.shape)}; " + f"expected one of {{({out_features}, {expected_zero_cols}), ({expected_zero_cols}, {out_features}), ({groups}, {expected_zero_cols})}}." 
+ ) + + scaled_zeros_groups_first = mit_multiply_scale_qzero_negative( + scales_groups_first, qzeros_groups_first, zp_shift=0 + ) + + padded_rows = zeros_width * pack_num + scales_processed = torch.zeros( + (padded_rows, out_features), + dtype=scales_groups_first.dtype, + device=device, + ) + zeros_processed = torch.zeros_like(scales_processed) + rows = min(padded_rows, scales_groups_first.shape[0]) + scales_processed[:rows, :] = scales_groups_first[:rows, :] + zeros_processed[:rows, :] = scaled_zeros_groups_first[:rows, :] + return qweight_v2, scales_processed, zeros_processed -log = setup_logger() class AwqGEMMQuantLinear(AWQuantLinear): SUPPORTS_BITS = [4] @@ -28,15 +100,13 @@ class AwqGEMMQuantLinear(AWQuantLinear): SUPPORTS_DEVICES = [DEVICE.ALL] SUPPORTS_PLATFORM = [PLATFORM.ALL] - SUPPORTS_PACK_DTYPES = [torch.int32] + SUPPORTS_PACK_DTYPES = [torch.int16] SUPPORTS_ADAPTERS = [Lora] - SUPPORTS_DTYPES = [torch.float16] - - REQUIRES_FORMAT_V2 = False + SUPPORTS_DTYPES = [torch.float16, torch.bfloat16] - # for transformers/optimum tests compat QUANT_TYPE = "awq_gemm" + INTERLEAVE = 4 def __init__( self, @@ -47,11 +117,12 @@ def __init__( in_features: int, out_features: int, bias: bool = False, - pack_dtype: torch.dtype = torch.int32, + pack_dtype: torch.dtype = torch.int16, adapter: Adapter = None, register_buffers: bool = False, **kwargs, - ): + ) -> None: + backend = kwargs.pop("backend", BACKEND.GEMM) super().__init__( bits=bits, group_size=group_size, @@ -61,65 +132,103 @@ def __init__( out_features=out_features, bias=bias, pack_dtype=pack_dtype, - backend=kwargs.pop("backend", BACKEND.GEMM), + backend=backend, adapter=adapter, - register_buffers=register_buffers, - **kwargs) - - def post_init(self): - # if self.padded_infeatures != self.in_features: - # self.qweight.resize_(self.padded_infeatures // self.pack_dtype_bits * self.bits, self.out_features) - # self.qzeros.resize_( - # math.ceil(self.padded_infeatures / self.group_size), - # self.out_features // self.pack_dtype_bits * self.bits - # ) - # self.scales.resize_((math.ceil(self.padded_infeatures / self.group_size), self.out_features), ) - # self.g_idx = torch.tensor([i // self.group_size for i in range(self.padded_infeatures)], dtype=torch.int32, - # device=self.g_idx.device) - - # awq only accepts float16 - if self.scales is not None: - self.scales = self.scales.to(dtype=torch.float16) - - super().post_init() - - def forward(self, x: torch.Tensor): - out_shape = x.shape[:-1] + (self.out_features,) - - input_dtype = x.dtype - if input_dtype != torch.float16: - x = x.half() - - if self.training: - out = WQLinearMMFunction.apply( + register_buffers=False, + **kwargs, + ) + self.interleave = self.INTERLEAVE + self.split_k_iters = 8 + self.bias = None + if register_buffers: + self._init_buffers() + + def _init_buffers(self) -> None: + int16_pack = 16 // self.bits + self.register_buffer( + "qweight", + torch.zeros( + (self.out_features // self.INTERLEAVE, self.in_features // int16_pack * self.INTERLEAVE), + dtype=torch.int16, + ), + ) + self.register_buffer( + "qzeros", + torch.zeros((self.in_features // self.group_size, self.out_features), dtype=torch.float16), + ) + self.register_buffer( + "scales", + torch.zeros((self.in_features // self.group_size, self.out_features), dtype=torch.float16), + ) + if self.bias is not None: + self.register_buffer("bias", torch.zeros(self.out_features, dtype=torch.float16)) + + def load_legacy_tensors( + self, + qweight: torch.Tensor, + qzeros: torch.Tensor, + scales: torch.Tensor, + bias: 
torch.Tensor | None, + ) -> None: + device = qweight.device + for name in ("qweight", "scales", "qzeros", "bias"): + if hasattr(self, name): + delattr(self, name) + if name in getattr(self, "_buffers", {}): + del self._buffers[name] + qweight_v2, scales_processed, zeros_processed = _convert_awq_v1_to_v2( + qweight, + qzeros, + scales, + bits=self.bits, + group_size=self.group_size, + in_features=self.in_features, + out_features=self.out_features, + interleave=self.INTERLEAVE, + ) + self.register_buffer("qweight", qweight_v2) + self.register_buffer("scales", scales_processed) + self.register_buffer("qzeros", zeros_processed) + if bias is not None: + self.register_buffer("bias", bias.to(device=device, dtype=scales.dtype)) + else: + self.bias = None + self.pack_dtype = torch.int16 + + def forward(self, x: torch.Tensor) -> torch.Tensor: + if awq_ext is None: + raise ModuleNotFoundError("External AWQ kernels are not properly installed." + msg) + if x.dtype not in (torch.float16, torch.bfloat16): + raise ValueError(f"AWQ GEMM kernels support float16/bfloat16 inputs only. Got {x.dtype}.") + if self.scales.dtype != x.dtype: + self.scales = self.scales.to(dtype=x.dtype) + if self.qzeros.dtype != x.dtype: + self.qzeros = self.qzeros.to(dtype=x.dtype) + if self.bias is not None and self.bias.dtype != x.dtype: + self.bias = self.bias.to(dtype=x.dtype) + num_tokens = x.numel() // x.shape[-1] + use_fp32_accum = True + if num_tokens < 8: + out = awq_ext.gemv_forward_cuda( x, self.qweight, - self.qzeros, self.scales, - self.bits, - self.group_size, - self.bias, + self.qzeros, + num_tokens, self.out_features, + self.in_features, + self.group_size, ) else: - with torch.inference_mode(): - out = WQLinearMMFunction.apply( - x, - self.qweight, - self.qzeros, - self.scales, - self.bits, - self.group_size, - self.bias, - self.out_features, - ) - - if input_dtype != torch.float16: - out = out.to(dtype=input_dtype) - + out = awq_ext.gemm_forward_cuda( + x, + self.qweight, + self.scales, + self.qzeros, + use_fp32_accum, + ) + if self.bias is not None: + out = out + self.bias if self.adapter: out = self.adapter.apply(x=x, out=out) - - return out.reshape(out_shape) - -__all__ = ["AwqGEMMQuantLinear"] + return out diff --git a/gptqmodel/nn_modules/qlinear/awq_gemm_ipex.py b/gptqmodel/nn_modules/qlinear/awq_gemm_ipex.py deleted file mode 100644 index 0138ffd46..000000000 --- a/gptqmodel/nn_modules/qlinear/awq_gemm_ipex.py +++ /dev/null @@ -1,148 +0,0 @@ -# SPDX-FileCopyrightText: 2024-2025 ModelCloud.ai -# SPDX-FileCopyrightText: 2024-2025 qubitium@modelcloud.ai -# SPDX-License-Identifier: Apache-2.0 -# Contact: qubitium@modelcloud.ai, x.com/qubitium - -import torch - -from ...adapter.adapter import Adapter, Lora -from ...models._const import DEVICE, PLATFORM -from ...quantization.awq.utils.packing_utils import dequantize_gemm -from ...utils.backend import BACKEND -from ...utils.logger import setup_logger -from .awq_gemm import AwqGEMMQuantLinear - - -log = setup_logger() - -try: - from intel_extension_for_pytorch.llm.quantization import IPEXWeightOnlyQuantizedLinear - - assert hasattr(IPEXWeightOnlyQuantizedLinear, "from_weight"), "The minimum version for ipex is at least 2.4" - IPEX_INSTALLED = True -except: - IPEX_INSTALLED = False - - -class Awq_IPEXQuantLinear(AwqGEMMQuantLinear): - SUPPORTS_BITS = [4] - SUPPORTS_GROUP_SIZE = [-1, 16, 32, 64, 128] - SUPPORTS_DESC_ACT = [True, False] - SUPPORTS_SYM = [True, False] - SUPPORTS_SHARDS = True - SUPPORTS_TRAINING = False - SUPPORTS_AUTO_PADDING = False - 
SUPPORTS_IN_FEATURES_DIVISIBLE_BY = [1] - SUPPORTS_OUT_FEATURES_DIVISIBLE_BY = [1] - - SUPPORTS_DEVICES = [DEVICE.CPU, DEVICE.XPU] - SUPPORTS_PLATFORM = [PLATFORM.ALL] - SUPPORTS_PACK_DTYPES = [torch.int32] - SUPPORTS_ADAPTERS = [Lora] - - SUPPORTS_DTYPES = [torch.float16, torch.bfloat16] - - # for transformers/optimum tests compat - QUANT_TYPE = "awq_gemm_ipex" - - def __init__( - self, - bits: int, - group_size: int, - sym: bool, - desc_act: bool, - in_features: int, - out_features: int, - bias: bool = False, - pack_dtype: torch.dtype = torch.int32, - adapter: Adapter = None, - **kwargs, - ): - assert IPEX_INSTALLED, \ - "Please install IPEX package with `pip install intel_extension_for_pytorch`." - - self.init_ipex = False - - super().__init__( - bits=bits, - group_size=group_size, - sym=sym, - desc_act=desc_act, - in_features=in_features, - out_features=out_features, - bias=bias, - pack_dtype=pack_dtype, - backend=kwargs.pop("backend", BACKEND.IPEX), - adapter=adapter, - **kwargs) - - def post_init(self): - # if self.padded_infeatures != self.in_features: - # self.qweight.resize_(self.padded_infeatures // self.pack_dtype_bits * self.bits, self.out_features) - # self.qzeros.resize_( - # math.ceil(self.padded_infeatures / self.group_size), - # self.out_features // self.pack_dtype_bits * self.bits - # ) - # self.scales.resize_((math.ceil(self.padded_infeatures / self.group_size), self.out_features), ) - # self.g_idx = torch.tensor([i // self.group_size for i in range(self.padded_infeatures)], dtype=torch.int32, - # device=self.g_idx.device) - - # awq only accepts float16 - self.scales = self.scales.to(dtype=torch.float16) - - device_type = self.qweight.device.type - if device_type != "meta": - assert device_type in ("cpu", "xpu") - - super().post_init() - - def init_ipex_linear(self): - if not self.training: - self.ipex_linear = IPEXWeightOnlyQuantizedLinear.from_weight(self.qweight, self.scales, self.qzeros, - self.in_features, self.out_features, None, - self.bias, - self.group_size, None, quant_method=1, dtype=0) - - def forward(self, x: torch.Tensor): - assert IPEX_INSTALLED, ( - "IPEX kernels could not be loaded. 
" - "Please install with `pip install intel_extension_for_pytorch` and " - "refer to the detial https://github.com/intel/intel-extension-for-pytorch/tree/main") - - if not self.init_ipex: - self.init_ipex_linear() - self.init_ipex = True - - out_shape = x.shape[:-1] + (self.out_features,) - - if hasattr(self, "ipex_linear"): - with torch.inference_mode(): - out = self.ipex_linear(x) - else: - out = dequantize_gemm(self.qweight, self.qzeros, self.scales, self.bits, self.group_size).to(x.dtype) - out = torch.matmul(x, out) - - if self.adapter: - out = self.adapter.apply(x=x, out=out) - - return out.reshape(out_shape) - - def backward(self, grad_output): - weights = dequantize_gemm(self.qweight, self.qzeros, self.scales, self.bits, self.group_size).to( - grad_output.dtype) - batch_size = grad_output.shape[0] - grad_input = grad_output.bmm(weights.transpose(0, 1).unsqueeze(0).repeat(batch_size, 1, 1)) - - return grad_input, None, None, None, None, None, None, None - - def extra_repr(self) -> str: - return ("in_features={}, out_features={}, bias={}, bits={}, group_size={}".format( - self.in_features, - self.out_features, - self.bias is not None, - self.bits, - self.group_size, - )) - - -__all__ = ["Awq_IPEXQuantLinear"] diff --git a/gptqmodel/nn_modules/qlinear/awq_gemv.py b/gptqmodel/nn_modules/qlinear/awq_gemv.py index 150ad3b69..079d2ba76 100644 --- a/gptqmodel/nn_modules/qlinear/awq_gemv.py +++ b/gptqmodel/nn_modules/qlinear/awq_gemv.py @@ -1,162 +1,4 @@ -# SPDX-FileCopyrightText: 2024-2025 ModelCloud.ai -# SPDX-FileCopyrightText: 2024-2025 qubitium@modelcloud.ai # SPDX-License-Identifier: Apache-2.0 -# Contact: qubitium@modelcloud.ai, x.com/qubitium - -import torch - -from ...adapter.adapter import Adapter, Lora -from ...models._const import DEVICE, PLATFORM -from ...nn_modules.qlinear import AWQuantLinear -from ...quantization.awq.utils.module import try_import -from ...utils.backend import BACKEND -from ...utils.gemv import calculate_zeros_width -from ...utils.logger import setup_logger - - -log = setup_logger() - -awq_ext, msg = try_import("gptqmodel_awq_kernels") - -class AwqGEMVQuantLinear(AWQuantLinear): - SUPPORTS_BITS = [4] - SUPPORTS_GROUP_SIZE = [-1, 16, 32, 64, 128] - SUPPORTS_DESC_ACT = [True, False] - SUPPORTS_SYM = [True, False] - SUPPORTS_SHARDS = True - SUPPORTS_TRAINING = True - SUPPORTS_AUTO_PADDING = False - SUPPORTS_IN_FEATURES_DIVISIBLE_BY = [1] - SUPPORTS_OUT_FEATURES_DIVISIBLE_BY = [1] - - SUPPORTS_DEVICES = [DEVICE.ALL] - SUPPORTS_PLATFORM = [PLATFORM.ALL] - SUPPORTS_PACK_DTYPES = [torch.int32] - SUPPORTS_ADAPTERS = [Lora] - - SUPPORTS_DTYPES = [torch.float16] - - # for transformers/optimum tests compat - QUANT_TYPE = "awq_gemv" - - def __init__( - self, - bits: int, - group_size: int, - sym: bool, - desc_act: bool, - in_features: int, - out_features: int, - bias: bool = False, - pack_dtype: torch.dtype = torch.int32, - adapter: Adapter = None, - register_buffers: bool = False, - **kwargs, - ): - backend = kwargs.pop("backend", BACKEND.GEMV) - super().__init__( - bits=bits, - group_size=group_size, - sym=sym, - desc_act=desc_act, - in_features=in_features, - out_features=out_features, - bias=bias, - pack_dtype=pack_dtype, - backend=backend, - adapter=adapter, - register_buffers=False, - **kwargs) - - self.split_k_iters = 8 - - self.bias = None - - if register_buffers: - self.register_buffer( - "qweight", - torch.zeros((out_features, in_features // self.pack_factor), dtype=self.pack_dtype), - ) - self.register_buffer( - "qzeros", - torch.zeros( - out_features, - 
calculate_zeros_width(in_features, self.group_size), - dtype=self.pack_dtype, - ), - ) - self.register_buffer( - "scales", - torch.zeros( - out_features, - calculate_zeros_width(in_features, self.group_size) * self.pack_factor, - dtype=torch.float16, - ), - ) - - if bias: - self.register_buffer("bias", torch.zeros(out_features, dtype=torch.float16)) - - def post_init(self): - # if self.padded_infeatures != self.in_features: - # self.qweight.resize_(self.padded_infeatures // self.pack_dtype_bits * self.bits, self.out_features) - # self.qzeros.resize_( - # math.ceil(self.padded_infeatures / self.group_size), - # self.out_features // self.pack_dtype_bits * self.bits - # ) - # self.scales.resize_((math.ceil(self.padded_infeatures / self.group_size), self.out_features), ) - # self.g_idx = torch.tensor([i // self.group_size for i in range(self.padded_infeatures)], dtype=torch.int32, - # device=self.g_idx.device) - - # awq only accepts float16 - self.scales = self.scales.to(dtype=torch.float16) - - super().post_init() - - def forward(self, x: torch.Tensor): - if awq_ext is None: - raise ModuleNotFoundError("External AWQ kernels are not properly installed." + msg) - - out_shape = x.shape[:-1] + (self.out_features,) - inputs = x.reshape(-1, x.shape[-1]) - - input_dtype = inputs.dtype - if input_dtype != torch.float16: - inputs = inputs.half() - - if inputs.shape[0] > 8: - out = awq_ext.gemmv2_forward_cuda( - inputs, - self.qweight, - self.scales, - self.qzeros, - self.group_size, - self.split_k_iters, - ) - else: - out = awq_ext.gemv_forward_cuda( - inputs, self.qweight, self.scales, self.qzeros, self.group_size - ) - - if input_dtype != torch.float16: - out = out.to(dtype=input_dtype) - - out = out + self.bias if self.bias is not None else out - - if self.adapter: - out = self.adapter.apply(x=x, out=out) - - return out.reshape(out_shape) - - def extra_repr(self) -> str: - return ( - "in_features={}, out_features={}, bias={}, bits={}, group_size={}".format( - self.in_features, - self.out_features, - self.bias is not None, - self.bits, - self.group_size, - ) - ) +from .awq_gemm import AwqGEMMQuantLinear as AwqGEMVQuantLinear __all__ = ["AwqGEMVQuantLinear"] diff --git a/gptqmodel/nn_modules/qlinear/awq_gemv_fast.py b/gptqmodel/nn_modules/qlinear/awq_gemv_fast.py index e12337046..ec64d5e43 100644 --- a/gptqmodel/nn_modules/qlinear/awq_gemv_fast.py +++ b/gptqmodel/nn_modules/qlinear/awq_gemv_fast.py @@ -16,7 +16,8 @@ log = setup_logger() -awq_v2_ext, msg = try_import("gptqmodel_awq_v2_kernels") +awq_ext, msg = try_import("gptqmodel_awq_kernels") + class AwqGEMVFastQuantLinear(AWQuantLinear): SUPPORTS_BITS = [4] @@ -34,7 +35,7 @@ class AwqGEMVFastQuantLinear(AWQuantLinear): SUPPORTS_PACK_DTYPES = [torch.int16] SUPPORTS_ADAPTERS = [Lora] - SUPPORTS_DTYPES = [torch.float16] + SUPPORTS_DTYPES = [torch.float16, torch.bfloat16] # for transformers/optimum tests compat QUANT_TYPE = "awq_gemv_fast" @@ -48,7 +49,7 @@ def __init__( in_features: int, out_features: int, bias: bool = False, - pack_dtype: torch.dtype = torch.int32, + pack_dtype: torch.dtype = torch.int16, adapter: Adapter = None, register_buffers: bool = False, **kwargs, @@ -114,14 +115,27 @@ def post_init(self): super().post_init() def forward(self, x: torch.Tensor): - if awq_v2_ext is None: - raise ModuleNotFoundError("External AWQ V2 kernels are not properly installed." + msg) + if awq_ext is None: + raise ModuleNotFoundError("External AWQ kernels are not properly installed." 
+ msg) inputs = x batch_size, n_tokens, _ = inputs.shape + if inputs.dtype not in (torch.float16, torch.bfloat16): + raise ValueError(f"{self.__class__.__name__} only supports dtypes {{torch.float16, torch.bfloat16}}: got {inputs.dtype}.") + + if self.scales is not None and self.scales.dtype != inputs.dtype: + self.scales = self.scales.to(dtype=inputs.dtype) + + if self.qzeros is not None and self.qzeros.dtype != inputs.dtype: + self.qzeros = self.qzeros.to(dtype=inputs.dtype) + + if self.bias is not None and self.bias.dtype != inputs.dtype: + self.bias = self.bias.to(dtype=inputs.dtype) + + use_fp32_accum = True if batch_size < 8 and n_tokens == 1: - out = awq_v2_ext.gemv_forward_cuda_decode( + out = awq_ext.gemv_forward_cuda( inputs, self.qweight, self.scales, @@ -132,10 +146,11 @@ def forward(self, x: torch.Tensor): self.group_size, ) else: - out = awq_v2_ext.gemm_forward_cuda_prefill( - inputs, self.qweight, self.scales, self.qzeros + out = awq_ext.gemm_forward_cuda( + inputs, self.qweight, self.scales, self.qzeros, use_fp32_accum ) - out = out + self.bias if self.bias is not None else out + if self.bias is not None: + out = out + self.bias if self.adapter: out = self.adapter.apply(x=x, out=out) diff --git a/gptqmodel/nn_modules/qlinear/awq_marlin.py b/gptqmodel/nn_modules/qlinear/awq_marlin.py index 5b41a9d6f..8e741a96c 100644 --- a/gptqmodel/nn_modules/qlinear/awq_marlin.py +++ b/gptqmodel/nn_modules/qlinear/awq_marlin.py @@ -53,9 +53,7 @@ class AwqMarlinQuantLinear(AWQuantLinear): SUPPORTS_PACK_DTYPES = [torch.int32] SUPPORTS_ADAPTERS = [Lora] - SUPPORTS_DTYPES = [torch.float16] - - REQUIRES_FORMAT_V2 = False + SUPPORTS_DTYPES = [torch.float16, torch.bfloat16] # for transformers/optimum tests compat QUANT_TYPE = "awq_marlin" diff --git a/gptqmodel/nn_modules/qlinear/awq_torch.py b/gptqmodel/nn_modules/qlinear/awq_torch.py index 5e298d7dc..c6ff85aa6 100644 --- a/gptqmodel/nn_modules/qlinear/awq_torch.py +++ b/gptqmodel/nn_modules/qlinear/awq_torch.py @@ -7,13 +7,19 @@ from ...adapter.adapter import Adapter, Lora from ...models._const import DEVICE, PLATFORM -from ...quantization.awq.utils.packing_utils import dequantize_gemm +from ...quantization.awq.utils.mit_repacker import ( + multiply_scale_qzero_negative as mit_multiply_scale_qzero_negative, + qweight_unpack as mit_qweight_unpack, +) +from ...quantization.awq.utils.module import try_import from ...utils.backend import BACKEND from ...utils.logger import setup_logger from . 
import AWQuantLinear +from .awq_gemm import _convert_awq_v1_to_v2 log = setup_logger() +awq_ext, msg = try_import("gptqmodel_awq_kernels") class AwqTorchQuantLinear(AWQuantLinear): @@ -32,9 +38,8 @@ class AwqTorchQuantLinear(AWQuantLinear): SUPPORTS_PACK_DTYPES = [torch.int32] SUPPORTS_ADAPTERS = [Lora] - SUPPORTS_DTYPES = [torch.float16] + SUPPORTS_DTYPES = [torch.float16, torch.bfloat16] - REQUIRES_FORMAT_V2 = False QUANT_TYPE = "awq_torch" @@ -67,6 +72,10 @@ def __init__( **kwargs, ) + self._legacy_qweight: torch.Tensor | None = None + self._legacy_qzeros: torch.Tensor | None = None + self._legacy_scales: torch.Tensor | None = None + def post_init(self): super().post_init() @@ -76,19 +85,127 @@ def extra_repr(self) -> str: f"bias={self.bias is not None}, bits={self.bits}, group_size={self.group_size}" ) + def load_legacy_tensors( + self, + qweight: torch.Tensor, + qzeros: torch.Tensor, + scales: torch.Tensor, + bias: torch.Tensor | None, + ) -> None: + device = qweight.device + for name in ("qweight", "scales", "qzeros", "bias"): + if hasattr(self, name): + delattr(self, name) + if name in getattr(self, "_buffers", {}): + del self._buffers[name] + + self._legacy_qweight = qweight.to(device) + self._legacy_qzeros = qzeros.to(device) + self._legacy_scales = scales.to(device) + + qweight_v2, scales_processed, zeros_processed = _convert_awq_v1_to_v2( + qweight, + qzeros, + scales, + bits=self.bits, + group_size=self.group_size, + in_features=self.in_features, + out_features=self.out_features, + interleave=4, + ) + self.register_buffer("qweight", qweight_v2) + self.register_buffer("scales", scales_processed) + self.register_buffer("qzeros", zeros_processed) + if bias is not None: + self.register_buffer("bias", bias.to(device=device, dtype=scales.dtype)) + else: + self.bias = None + self.pack_dtype = torch.int16 + + def _dequantize_weight_fallback(self, *, device: torch.device, dtype: torch.dtype) -> torch.Tensor: + if ( + self._legacy_qweight is None + or self._legacy_qzeros is None + or self._legacy_scales is None + ): + raise RuntimeError("Legacy AWQ tensors unavailable for Torch fallback.") + + qweight = self._legacy_qweight.to(device=device) + qzeros = self._legacy_qzeros.to(device=device) + scales = self._legacy_scales.to(device=device) + + unpacked = mit_qweight_unpack(qweight) + if unpacked.shape == (self.in_features, self.out_features): + unpacked = unpacked.transpose(0, 1).contiguous() + elif unpacked.shape != (self.out_features, self.in_features): + raise ValueError( + f"Unexpected unpacked qweight shape {tuple(unpacked.shape)}; " + f"expected ({self.out_features}, {self.in_features})" + ) + unpacked = unpacked.to(torch.float32) + + groups = 1 if self.group_size in (-1, 0) else self.in_features // self.group_size + scales_groups = scales + if scales_groups.shape == (self.out_features, groups): + scales_groups = scales_groups.transpose(0, 1).contiguous() + elif scales_groups.shape != (groups, self.out_features): + raise ValueError( + f"Unexpected legacy scales shape {tuple(scales_groups.shape)}; " + f"expected ({groups}, {self.out_features}) or ({self.out_features}, {groups})." 
+ ) + + pack_num = 32 // self.bits + expected_zero_cols = self.out_features // pack_num + qzeros_groups = qzeros + if qzeros_groups.shape == (self.out_features, expected_zero_cols): + qzeros_groups = qzeros_groups.transpose(0, 1).contiguous() + elif qzeros_groups.shape == (expected_zero_cols, self.out_features): + qzeros_groups = qzeros_groups.transpose(0, 1).contiguous() + elif qzeros_groups.shape != (groups, expected_zero_cols): + raise ValueError( + f"Unexpected legacy qzeros shape {tuple(qzeros_groups.shape)}; " + f"expected one of {{({self.out_features}, {expected_zero_cols}), ({expected_zero_cols}, {self.out_features}), ({groups}, {expected_zero_cols})}}." + ) + + scaled_zeros_groups = mit_multiply_scale_qzero_negative(scales_groups, qzeros_groups, zp_shift=0) + + weight = torch.empty((self.out_features, self.in_features), dtype=torch.float32, device=device) + for group_idx in range(groups): + start = group_idx * self.group_size + end = min(start + self.group_size, self.in_features) + weight[:, start:end] = ( + unpacked[:, start:end] * scales_groups[group_idx].to(torch.float32).unsqueeze(1) + + scaled_zeros_groups[group_idx].to(torch.float32).unsqueeze(1) + ) + + return weight.transpose(0, 1).contiguous().to(dtype=dtype) + def forward(self, x: torch.Tensor): original_shape = x.shape[:-1] + (self.out_features,) device = x.device x_flat = x.reshape(-1, x.shape[-1]) - weight = dequantize_gemm(self.qweight, self.qzeros, self.scales, self.bits, self.group_size) - assert weight.dtype == torch.float16, f"weight {weight.dtype} is not float16" - if weight.dtype != x_flat.dtype or weight.device != device: - weight = weight.to(device=device, dtype=x_flat.dtype) - - output = torch.matmul(x_flat, weight) + if awq_ext is not None and hasattr(awq_ext, "dequantize_weights_cuda"): + weight = awq_ext.dequantize_weights_cuda( + self.qweight, + self.scales, + self.qzeros, + 0, + 0, + 0, + False, + ) + weight = weight.to(dtype=x_flat.dtype, device=device) + output = torch.matmul(x_flat, weight) + else: + matmul_weight = self._dequantize_weight_fallback(device=device, dtype=torch.float32) + matmul_input = x_flat.to(torch.float32) + output = torch.matmul(matmul_input, matmul_weight) + output = output.to(dtype=x_flat.dtype) if self.bias is not None: + if self.bias.dtype != output.dtype: + self.bias = self.bias.to(dtype=output.dtype) output = output + self.bias if self.adapter: diff --git a/gptqmodel/quantization/awq/modules/linear/__init__.py b/gptqmodel/quantization/awq/modules/linear/__init__.py index 4c298707e..162de045c 100644 --- a/gptqmodel/quantization/awq/modules/linear/__init__.py +++ b/gptqmodel/quantization/awq/modules/linear/__init__.py @@ -6,7 +6,6 @@ from .exllama import WQLinear_Exllama, exllama_post_init from .exllamav2 import WQLinear_ExllamaV2, exllamav2_post_init from .gemm import WQLinear_GEMM -from .gemm_ipex import WQLinear_IPEX, ipex_post_init from .gemv import WQLinear_GEMV from .gemv_fast import WQLinear_GEMVFast from .marlin import WQLinear_Marlin, marlin_post_init diff --git a/gptqmodel/quantization/awq/modules/linear/gemm.py b/gptqmodel/quantization/awq/modules/linear/gemm.py index ad8e87825..0dd7a15a5 100644 --- a/gptqmodel/quantization/awq/modules/linear/gemm.py +++ b/gptqmodel/quantization/awq/modules/linear/gemm.py @@ -1,306 +1,4 @@ -# SPDX-FileCopyrightText: 2024-2025 ModelCloud.ai -# SPDX-FileCopyrightText: 2024-2025 qubitium@modelcloud.ai # SPDX-License-Identifier: Apache-2.0 -# Contact: qubitium@modelcloud.ai, x.com/qubitium +from .gemv_fast import WQLinear_GEMVFast as 
WQLinear_GEMM -import warnings - -import torch -import torch.nn as nn -from torch.autograd import Function - -from gptqmodel.quantization.awq.utils.module import try_import -from gptqmodel.quantization.awq.utils.packing_utils import dequantize_gemm -from gptqmodel.quantization.awq.utils.utils import get_best_device - - -# NOTE: We check if awq_ext or triton is available. awq_ext will be preferred if both are installed. - -awq_ext, msg = try_import("gptqmodel_awq_kernels") -user_has_been_warned = False - -try: - from gptqmodel.quantization.awq.modules.triton.gemm import awq_dequantize_triton, awq_gemm_triton - - # covers CUDA, ROCm and XPU. If we can import triton, then we can use it. - TRITON_AVAILABLE = True - -except ImportError: - TRITON_AVAILABLE = False - -# Adapted from https://github.com/compressa-ai/AutoAWQ/tree/dev -class WQLinearMMFunction(Function): - @staticmethod - # ctx is the first argument to forward - def forward( - ctx, - x, - qweight, - qzeros, - scales, - w_bit=4, - group_size=128, - bias=None, - out_features=0, - ): - # The forward pass can use ctx. - ctx.save_for_backward(x, qweight, qzeros, scales, bias) - ctx.out_features = out_features - - out_shape = x.shape[:-1] + (out_features,) - x = x.to(torch.float16) - if x.shape[0] == 0: - return torch.zeros(out_shape, dtype=x.dtype, device=x.device) - - if awq_ext is not None: - FP16_MATMUL_HEURISTIC_CONDITION = x.shape[0] * x.shape[1] >= 1024 - - if FP16_MATMUL_HEURISTIC_CONDITION: - out = awq_ext.dequantize_weights_cuda( - qweight, scales, qzeros, 0, 0, 0, False - ) - out = torch.matmul(x, out) - else: - out = awq_ext.gemm_forward_cuda( - x.reshape(-1, x.shape[-1]), qweight, scales, qzeros, 8 - ) - - elif TRITON_AVAILABLE: - FP16_MATMUL_HEURISTIC_CONDITION = x.shape[0] * x.shape[1] >= 1024 - - if FP16_MATMUL_HEURISTIC_CONDITION: - out = awq_dequantize_triton(qweight, scales, qzeros) - out = torch.matmul(x, out.to(x.dtype)) - else: - out = awq_gemm_triton( - x.reshape(-1, x.shape[-1]), qweight, scales, qzeros, split_k_iters=8, - ) - - else: - global user_has_been_warned - if not user_has_been_warned: - warnings.warn("Using naive (slow) implementation." + msg) - user_has_been_warned = True - out = dequantize_gemm(qweight, qzeros, scales, w_bit, group_size) - out = torch.matmul(x, out) - - out = out + bias if bias is not None else out - out = out.reshape(out_shape) - - # always want 3D tensor if tensor is 2D - if len(out.shape) == 2: - out = out.unsqueeze(0) - - return out - - @staticmethod - def backward(ctx, grad_output): - input, qweight, qzeros, scales, bias = ctx.saved_tensors - - if awq_ext is None and not TRITON_AVAILABLE: - raise ValueError( - "either triton or autoawq-kernels is needed to be installed to use `.backward()`. Make sure to install the auto-awq kernels" - " by following the installation guides in https://github.com/casper-hansen/AutoAWQ_kernels" - ) - - # Cast to correct dtype for mixed precision training - if awq_ext is not None: - weights = awq_ext.dequantize_weights_cuda( - qweight, scales, qzeros, 1, 0, 0, False - ).to(grad_output.dtype) - else: - weights = awq_dequantize_triton( - qweight, scales, qzeros - ).to(grad_output.dtype) - - if ctx.needs_input_grad[0]: - # 3D matmul using torch.bmm: https://pytorch.org/docs/stable/generated/torch.bmm.html#torch.bmm - # to propagate gradient across all batch sizes. 
- batch_size = grad_output.shape[0] - grad_input = grad_output.bmm(weights.transpose(0, 1).unsqueeze(0).repeat(batch_size, 1, 1)) - - return grad_input, None, None, None, None, None, None, None - -class WQLinear_GEMM(nn.Module): - def __init__( - self, w_bit, group_size, in_features, out_features, bias, dev, training=False - ): - super().__init__() - - if w_bit not in [4]: - raise NotImplementedError("Only 4-bit are supported for now.") - - self.in_features = in_features - self.out_features = out_features - self.w_bit = w_bit - self.group_size = group_size if group_size != -1 else in_features - self.training = training - - # quick sanity check (make sure aligment) - assert self.in_features % self.group_size == 0 - assert out_features % (32 // self.w_bit) == 0 - - self.register_buffer( - "qweight", - torch.zeros( - (in_features, out_features // (32 // self.w_bit)), - dtype=torch.int32, - device=dev, - ), - ) - self.register_buffer( - "qzeros", - torch.zeros( - (in_features // self.group_size, out_features // (32 // self.w_bit)), - dtype=torch.int32, - device=dev, - ), - ) - self.register_buffer( - "scales", - torch.zeros( - (in_features // self.group_size, out_features), - dtype=torch.float16, - device=dev, - ), - ) - if bias: - self.register_buffer( - "bias", - torch.zeros( - (out_features), - dtype=torch.float16, - device=dev, - ), - ) - else: - self.bias = None - - @classmethod - def from_linear( - cls, linear, w_bit, group_size, init_only=False, scales=None, zeros=None - ): - awq_linear = cls( - w_bit, - group_size, - linear.in_features, - linear.out_features, - linear.bias is not None, - linear.weight.device, - ) - if init_only: # just prepare for loading sd - return awq_linear - - # need scales and zeros info for real quantization - assert scales is not None and zeros is not None - scale_zeros = zeros * scales - - awq_linear.scales = scales.clone().half() - if linear.bias is not None: - awq_linear.bias = linear.bias.clone().half() - - pack_num = 32 // awq_linear.w_bit - - intweight = [] - for idx in range(awq_linear.in_features): - intweight.append( - torch.round( - (linear.weight.data[:, idx] + scale_zeros[idx // group_size]) - / awq_linear.scales[idx // group_size] - ).to(torch.int)[:, None] - ) - intweight = torch.cat(intweight, dim=1) - intweight = intweight.t().contiguous() - intweight = intweight.to(dtype=torch.int32) - - best_device = get_best_device() - - # Avoid: The operator 'aten::__lshift__.Scalar' is not currently implemented for the MPS device - if "mps" in best_device: - intweight = intweight.to("cpu") - - qweight = torch.zeros( - (intweight.shape[0], intweight.shape[1] // 32 * awq_linear.w_bit), - dtype=torch.int32, - device=intweight.device, - ) - - for col in range(intweight.shape[1] // pack_num): - if awq_linear.w_bit == 4: - order_map = [0, 2, 4, 6, 1, 3, 5, 7] - else: - raise NotImplementedError("Only 4-bit are supported for now.") - for i in range(pack_num): - qweight_col = intweight[:, col * pack_num + order_map[i]] - qweight[:, col] |= qweight_col << (i * awq_linear.w_bit) - awq_linear.qweight = qweight - - zeros = zeros.to(dtype=torch.int32, device=best_device) - - if "mps" in best_device: - zeros = zeros.to("cpu") - - qzeros = torch.zeros( - (zeros.shape[0], zeros.shape[1] // 32 * awq_linear.w_bit), - dtype=torch.int32, - device=zeros.device, - ) - - for col in range(zeros.shape[1] // pack_num): - if awq_linear.w_bit == 4: - order_map = [0, 2, 4, 6, 1, 3, 5, 7] - else: - raise NotImplementedError("Only 4-bit are supported for now.") - for i in 
range(pack_num): - qzero_col = zeros[:, col * pack_num + order_map[i]] - qzeros[:, col] |= qzero_col << (i * awq_linear.w_bit) - awq_linear.qzeros = qzeros - - return awq_linear - - def forward(self, x): - out_shape = x.shape[:-1] + (self.out_features,) - - input_dtype = x.dtype - if input_dtype != torch.float16: - x = x.half() - - if self.training: - out = WQLinearMMFunction.apply( - x, - self.qweight, - self.qzeros, - self.scales, - self.w_bit, - self.group_size, - self.bias, - self.out_features, - ) - else: - with torch.inference_mode(): - out = WQLinearMMFunction.apply( - x, - self.qweight, - self.qzeros, - self.scales, - self.w_bit, - self.group_size, - self.bias, - self.out_features, - ) - - if input_dtype != torch.float16: - out = out.to(dtype=input_dtype) - - return out.reshape(out_shape) - - def extra_repr(self) -> str: - return ( - "in_features={}, out_features={}, bias={}, w_bit={}, group_size={}".format( - self.in_features, - self.out_features, - self.bias is not None, - self.w_bit, - self.group_size, - ) - ) +__all__ = ["WQLinear_GEMM"] diff --git a/gptqmodel/quantization/awq/modules/linear/gemm_ipex.py b/gptqmodel/quantization/awq/modules/linear/gemm_ipex.py deleted file mode 100644 index 7ab4eec70..000000000 --- a/gptqmodel/quantization/awq/modules/linear/gemm_ipex.py +++ /dev/null @@ -1,141 +0,0 @@ -# SPDX-FileCopyrightText: 2024-2025 ModelCloud.ai -# SPDX-FileCopyrightText: 2024-2025 qubitium@modelcloud.ai -# SPDX-License-Identifier: Apache-2.0 -# Contact: qubitium@modelcloud.ai, x.com/qubitium - -import torch -import torch.nn as nn - -from gptqmodel.quantization.awq.utils.packing_utils import dequantize_gemm - -from .gemm import WQLinear_GEMM - - -try: - from intel_extension_for_pytorch.llm.quantization import IPEXWeightOnlyQuantizedLinear - assert hasattr(IPEXWeightOnlyQuantizedLinear, "from_weight"), "The minimum version for ipex is at least 2.4" - IPEX_INSTALLED = True -except: - IPEX_INSTALLED = False - - -class WQLinear_IPEX(WQLinear_GEMM): - - def __init__(self, w_bit, group_size, in_features, out_features, bias, dev, training=False): - nn.Module.__init__(self) - assert IPEX_INSTALLED, \ - "Please install IPEX package with `pip install intel_extension_for_pytorch`." - assert w_bit == 4, "Only 4 bit are supported for now." - - self.use_bf16 = True # Intel platform support bf16 even without amx. 
- - self.in_features = in_features - self.out_features = out_features - self.w_bit = w_bit - self.group_size = group_size if group_size != -1 else in_features - self.scale_dtype = torch.float32 - self.training = training - - # quick sanity check (make sure aligment) - assert self.in_features % self.group_size == 0 - assert out_features % (32 // self.w_bit) == 0 - self.pack_num = 32 // self.w_bit - - self.init_ipex = False - - self.register_buffer( - "qzeros", - torch.zeros( - (in_features // self.group_size, out_features // self.pack_num), - dtype=torch.int32, - device=dev, - ), - ) - self.register_buffer( - "scales", - torch.zeros( - (in_features // self.group_size, out_features), - dtype=torch.bfloat16 if self.use_bf16 else torch.float32, - device=dev, - )) - if bias: - self.register_buffer( - "bias", - torch.zeros((out_features), dtype=torch.bfloat16 if self.use_bf16 else torch.float32, device=dev), - ) - else: - self.register_buffer( - "bias", - None, - ) - qweight = torch.zeros((in_features, out_features // self.pack_num), dtype=torch.int32, device=dev) - self.register_buffer("qweight", qweight) - - def post_init(self): - device_type = self.qweight.device.type - if device_type != "meta": - assert device_type in ("cpu", "xpu") - - def init_ipex_linear(self): - if not self.training: - self.ipex_linear = IPEXWeightOnlyQuantizedLinear.from_weight(self.qweight, self.scales, self.qzeros, \ - self.in_features, self.out_features, None, self.bias, \ - self.group_size, None, quant_method=1, dtype=0) - - @classmethod - def from_linear(cls, linear, w_bit, group_size, init_only=False, scales=None): - awq_linear = cls( - w_bit, - group_size, - linear.in_features, - linear.out_features, - linear.bias is not None, - linear.weight.device, - ) - if init_only: # just prepare for loading sd - return awq_linear - - raise NotImplementedError("Only inference is supported for IPEX kernels") - - def forward(self, x): - assert IPEX_INSTALLED, ( - "IPEX kernels could not be loaded. 
" - "Please install with `pip install intel_extension_for_pytorch` and " - "refer to the detial https://github.com/intel/intel-extension-for-pytorch/tree/main") - - if not self.init_ipex: - self.init_ipex_linear() - self.init_ipex = True - - if hasattr(self, "ipex_linear"): - with torch.inference_mode(): - outputs = self.ipex_linear(x) - else: - outputs = dequantize_gemm(self.qweight, self.qzeros, self.scales, self.w_bit, self.group_size).to(x.dtype) - outputs = torch.matmul(x, outputs) - - return outputs - - def backward(self, grad_output): - weights = dequantize_gemm(self.qweight, self.qzeros, self.scales, self.w_bit, self.group_size).to(grad_output.dtype) - batch_size = grad_output.shape[0] - grad_input = grad_output.bmm(weights.transpose(0, 1).unsqueeze(0).repeat(batch_size, 1, 1)) - - return grad_input, None, None, None, None, None, None, None - - def extra_repr(self) -> str: - return ("in_features={}, out_features={}, bias={}, w_bit={}, group_size={}".format( - self.in_features, - self.out_features, - self.bias is not None, - self.w_bit, - self.group_size, - )) - - -def ipex_post_init(model): - for _, submodule in model.named_modules(): - if isinstance(submodule, WQLinear_IPEX): - submodule.post_init() - - return model diff --git a/gptqmodel/quantization/awq/modules/linear/gemv.py b/gptqmodel/quantization/awq/modules/linear/gemv.py index c62863289..f02b95d1e 100644 --- a/gptqmodel/quantization/awq/modules/linear/gemv.py +++ b/gptqmodel/quantization/awq/modules/linear/gemv.py @@ -3,202 +3,6 @@ # SPDX-License-Identifier: Apache-2.0 # Contact: qubitium@modelcloud.ai, x.com/qubitium +from .gemm import WQLinear_GEMM as WQLinear_GEMV -import torch -import torch.nn as nn - -from gptqmodel.quantization.awq.utils.module import try_import - - -awq_ext, msg = try_import("gptqmodel_awq_kernels") - -def make_divisible(c, divisor): - return (c + divisor - 1) // divisor - - -def calculate_zeros_width(in_features, group_size=128, pack_num=8): - if group_size >= 128: - size_multiplier = 1 - elif group_size == 64: - size_multiplier = 2 - elif group_size == 32: - size_multiplier = 4 - else: - raise NotImplementedError - - base_width = make_divisible(in_features // group_size, pack_num) - base_width = make_divisible(base_width, size_multiplier) * size_multiplier - return base_width - - -class WQLinear_GEMV(nn.Module): - def __init__(self, w_bit, group_size, in_features, out_features, bias, dev): - super().__init__() - - if w_bit not in [4]: - raise NotImplementedError("Only 4-bit are supported for now.") - - self.in_features = in_features - self.out_features = out_features - self.w_bit = w_bit - self.group_size = group_size if group_size != -1 else in_features - self.split_k_iters = 8 - - # quick sanity check (make sure aligment) - assert self.in_features % self.group_size == 0 - assert out_features % (32 // self.w_bit) == 0 - pack_num = 32 // self.w_bit - - self.register_buffer( - "qweight", - torch.zeros( - (out_features, in_features // pack_num), dtype=torch.int32, device=dev - ), - ) - self.register_buffer( - "qzeros", - torch.zeros( - (out_features, calculate_zeros_width(in_features, self.group_size)), - dtype=torch.int32, - device=dev, - ), - ) - self.register_buffer( - "scales", - torch.zeros( - ( - out_features, - calculate_zeros_width(in_features, self.group_size) * pack_num, - ), - dtype=torch.float16, - device=dev, - ), - ) - if bias: - self.register_buffer( - "bias", torch.zeros((out_features), dtype=torch.float16, device=dev) - ) - else: - self.bias = None - - @classmethod - def 
from_linear( - cls, linear, w_bit, group_size, init_only=False, scales=None, zeros=None - ): - awq_linear = cls( - w_bit, - group_size, - linear.in_features, - linear.out_features, - linear.bias is not None, - linear.weight.device, - ) - if init_only: # just prepare for loading sd - return awq_linear - - # need scales and zeros info for real quantization - assert scales is not None and zeros is not None - scale_zeros = zeros * scales - - pack_num = 32 // awq_linear.w_bit - qscales = torch.zeros( - ( - scales.shape[0], - calculate_zeros_width(linear.in_features, group_size) * pack_num, - ), - dtype=torch.float16, - device=scales.device, - ) - qscales[:, : scales.shape[1]] = scales - awq_linear.scales = qscales - if linear.bias is not None: - awq_linear.bias = linear.bias.clone().half() - - intweight = [] - for idx in range(awq_linear.in_features): - intweight.append( - torch.round( - (linear.weight.data[:, idx] + scale_zeros[:, idx // group_size]) - / awq_linear.scales[:, idx // group_size] - ).to(torch.int)[:, None] - ) - intweight = torch.cat(intweight, dim=1) - intweight = intweight.to(dtype=torch.int32) - qweight = torch.zeros( - (intweight.shape[0], intweight.shape[1] // 32 * awq_linear.w_bit), - dtype=torch.int32, - device=intweight.device, - ) - - for col in range(intweight.shape[1] // pack_num): - if awq_linear.w_bit == 4: - order_map = [0, 1, 2, 3, 4, 5, 6, 7] - else: - raise NotImplementedError("Only 4-bit are supported for now.") - for i in range(pack_num): - qweight_col = intweight[:, col * pack_num + order_map[i]] - qweight[:, col] |= qweight_col << (i * awq_linear.w_bit) - awq_linear.qweight = qweight - - zeros = zeros.to(dtype=torch.int32) - qzeros = torch.zeros( - (zeros.shape[0], calculate_zeros_width(linear.in_features, group_size)), - dtype=torch.int32, - device=zeros.device, - ) - - for col in range((zeros.shape[1] + pack_num - 1) // pack_num): - if awq_linear.w_bit == 4: - order_map = [0, 1, 2, 3, 4, 5, 6, 7] - else: - raise NotImplementedError("Only 4-bit are supported for now.") - for i in range(pack_num): - if col * pack_num + order_map[i] >= zeros.shape[1]: - continue - qzero_col = zeros[:, col * pack_num + order_map[i]] - qzeros[:, col] |= qzero_col << (i * awq_linear.w_bit) - awq_linear.qzeros = qzeros - return awq_linear - - @torch.inference_mode() - def forward(self, x): - if awq_ext is None: - raise ModuleNotFoundError("External AWQ kernels are not properly installed." 
+ msg) - - out_shape = x.shape[:-1] + (self.out_features,) - inputs = x.reshape(-1, x.shape[-1]) - - input_dtype = inputs.dtype - if input_dtype != torch.float16: - inputs = inputs.half() - - if inputs.shape[0] > 8: - out = awq_ext.gemmv2_forward_cuda( - inputs, - self.qweight, - self.scales, - self.qzeros, - self.group_size, - self.split_k_iters, - ) - else: - out = awq_ext.gemv_forward_cuda( - inputs, self.qweight, self.scales, self.qzeros, self.group_size - ) - - if input_dtype != torch.float16: - out = out.to(dtype=input_dtype) - - out = out + self.bias if self.bias is not None else out - return out.reshape(out_shape) - - def extra_repr(self) -> str: - return ( - "in_features={}, out_features={}, bias={}, w_bit={}, group_size={}".format( - self.in_features, - self.out_features, - self.bias is not None, - self.w_bit, - self.group_size, - ) - ) +__all__ = ["WQLinear_GEMV"] diff --git a/gptqmodel/quantization/awq/modules/linear/gemv_fast.py b/gptqmodel/quantization/awq/modules/linear/gemv_fast.py index 756e725ac..4fce4c185 100644 --- a/gptqmodel/quantization/awq/modules/linear/gemv_fast.py +++ b/gptqmodel/quantization/awq/modules/linear/gemv_fast.py @@ -9,7 +9,7 @@ from gptqmodel.quantization.awq.utils.module import try_import -awq_v2_ext, msg = try_import("gptqmodel_awq_v2_kernels") +awq_ext, msg = try_import("gptqmodel_awq_kernels") def make_divisible(c, divisor): return (c + divisor - 1) // divisor @@ -191,12 +191,26 @@ def from_linear( @torch.inference_mode() def forward(self, x): - if awq_v2_ext is None: - raise ModuleNotFoundError("External AWQ V2 kernels are not properly installed." + msg) + if awq_ext is None: + raise ModuleNotFoundError("External AWQ kernels are not properly installed." + msg) inputs = x batch_size, n_tokens, _ = inputs.shape + if inputs.dtype not in (torch.float16, torch.bfloat16): + raise ValueError(f"WQLinear_GEMVFast only supports dtypes {{torch.float16, torch.bfloat16}}: got {inputs.dtype}.") + + if self.scales is not None and self.scales.dtype != inputs.dtype: + self.scales = self.scales.to(dtype=inputs.dtype) + + if self.qzeros is not None and self.qzeros.dtype != inputs.dtype: + self.qzeros = self.qzeros.to(dtype=inputs.dtype) + + if self.bias is not None and self.bias.dtype != inputs.dtype: + self.bias = self.bias.to(dtype=inputs.dtype) + + use_fp32_accum = True + if batch_size < 8 and n_tokens == 1: - out = awq_v2_ext.gemv_forward_cuda_decode( + out = awq_ext.gemv_forward_cuda( inputs, self.qweight, self.scales, @@ -207,9 +221,10 @@ def forward(self, x): self.group_size, ) else: - out = awq_v2_ext.gemm_forward_cuda_prefill( - inputs, self.qweight, self.scales, self.qzeros + out = awq_ext.gemm_forward_cuda( + inputs, self.qweight, self.scales, self.qzeros, use_fp32_accum ) - out = out + self.bias if self.bias is not None else out + if self.bias is not None: + out = out + self.bias return out diff --git a/gptqmodel/quantization/awq/utils/mit_repacker.py b/gptqmodel/quantization/awq/utils/mit_repacker.py new file mode 100644 index 000000000..85bf0c234 --- /dev/null +++ b/gptqmodel/quantization/awq/utils/mit_repacker.py @@ -0,0 +1,89 @@ +# SPDX-FileCopyrightText: 2023 MIT Han Lab +# SPDX-License-Identifier: Apache-2.0 + +"""Helpers mirrored from MIT Han Lab's AWQ offline weight repacker. + +Source: +https://github.com/mit-han-lab/llm-awq/blob/main/tinychat/offline-weight-repacker.py + +These utilities convert legacy AWQ v1 tensors into the layout expected by the +v2 CUDA kernels. 
The implementations are intentionally close to the reference +script to minimise drift; only minimal Torch wrappers are added so they can run +on tensors that may already live on device memory. +""" + +from __future__ import annotations + +import torch + + +def qweight_unpack(qweight: torch.Tensor) -> torch.Tensor: + """Unpack int4 weights into individual nibbles (reference implementation).""" + if qweight.dtype != torch.int32: + qweight = qweight.to(torch.int32) + n = qweight.shape[0] + k = qweight.shape[1] * 8 + unpacked = torch.zeros((n, k), dtype=torch.int32, device=qweight.device) + mask = torch.tensor(0x0000000F, dtype=torch.int32, device=qweight.device) + for kk in range(k): + ele_offset = kk // 8 + bit_offset = (kk % 8) * 4 + unpacked[:, kk] = (qweight[:, ele_offset] >> bit_offset) & mask + return unpacked + + +def packing_v2_from_unpacked( + unpacked_qweight: torch.Tensor, interleave: int, kstride: int +) -> torch.Tensor: + """Pack unpacked weights into the v2 kernel layout (reference implementation).""" + n = unpacked_qweight.shape[0] + k = unpacked_qweight.shape[1] + + packed_kernel = ( + unpacked_qweight.detach() + .cpu() + .numpy() + .reshape(n, k // 32, 32) + ) + packed_kernel = packed_kernel.reshape(n, k // 32, 4, 4, 2).transpose(0, 1, 3, 2, 4) + packed_kernel = packed_kernel.reshape(n, k // 32, 32) + + packed_kernel = packed_kernel.reshape(n, k // 32, 4, 8) + packed_kernel = packed_kernel.reshape(n, k // 32, 4, 4, 2).transpose(0, 1, 2, 4, 3) + packed_kernel = packed_kernel.reshape(n, k) + + packed_kernel = packed_kernel.reshape(n // interleave, interleave, k // kstride, kstride) + packed_kernel = packed_kernel.transpose(0, 2, 1, 3) + packed_kernel = packed_kernel.reshape(n // interleave, k // kstride, kstride, interleave) + packed_kernel = ( + packed_kernel[..., 0] + | (packed_kernel[..., 1] << 4) + | (packed_kernel[..., 2] << 8) + | (packed_kernel[..., 3] << 12) + ) + packed_kernel = packed_kernel.reshape(n // interleave, k) + qweight_v2 = torch.tensor(packed_kernel.astype("int16"), device=unpacked_qweight.device).contiguous() + return qweight_v2 + + +def multiply_scale_qzero_negative( + scales: torch.Tensor, qzeros: torch.Tensor, zp_shift: int = 0 +) -> torch.Tensor: + """Compute scaled zero-points in the format consumed by v2 kernels.""" + pack_size = 8 + k_groups = scales.shape[1] + scaled_zeros = torch.zeros_like(scales) + qzeros = qzeros.to(torch.int32) + for group_idx in range(k_groups): + zero_idx = group_idx // pack_size + zero_offset = group_idx % pack_size + zero = (qzeros[:, zero_idx] >> (4 * zero_offset)) & 0x0000000F + scaled_zeros[:, group_idx] = scales[:, group_idx] * zero.to(scales.dtype) + return -(scaled_zeros + (zp_shift * scales)) + + +__all__ = [ + "qweight_unpack", + "packing_v2_from_unpacked", + "multiply_scale_qzero_negative", +] diff --git a/gptqmodel/quantization/config.py b/gptqmodel/quantization/config.py index df35ec839..e8b1e9386 100644 --- a/gptqmodel/quantization/config.py +++ b/gptqmodel/quantization/config.py @@ -67,7 +67,9 @@ class FORMAT(str, Enum): QQQ = "qqq" GEMM = "gemm" + GEMM_V2 = "gemm_v2" GEMV = "gemv" + GEMV_V2 = "gemv_v2" GEMV_FAST = "gemv_fast" @@ -95,7 +97,9 @@ class VRAMStrategy(str, Enum): }, METHOD.AWQ: { FORMAT.GEMM, + FORMAT.GEMM_V2, FORMAT.GEMV, + FORMAT.GEMV_V2, FORMAT.GEMV_FAST, FORMAT.MARLIN, }, @@ -283,10 +287,23 @@ def __post_init__(self): self.damp_auto_increment = 0.01 # TODO FIXME awq compat which didn't have checkpoint_format before merging to gptqmodel - if self.quant_method == METHOD.AWQ and self.format not 
in [FORMAT.MARLIN, FORMAT.GEMV, FORMAT.GEMV_FAST, FORMAT.GEMM]: + if self.quant_method == METHOD.AWQ and self.format not in [ + FORMAT.MARLIN, + FORMAT.GEMV, + FORMAT.GEMV_V2, + FORMAT.GEMV_FAST, + FORMAT.GEMM, + FORMAT.GEMM_V2, + ]: log.info(f"QuantizeConfig: Auto fix `format` to `{FORMAT.GEMM}`") self.format = FORMAT.GEMM + if self.quant_method == METHOD.AWQ: + if self.format == FORMAT.GEMM: + self.format = FORMAT.GEMM_V2 + elif self.format == FORMAT.GEMV: + self.format = FORMAT.GEMV_V2 + if self.format not in valid_formats: raise ValueError( f"QuantizeConfig: checkpoint `format` used is {self.format}, and the quantization method is {self.quant_method}. " diff --git a/gptqmodel/utils/importer.py b/gptqmodel/utils/importer.py index 0f72d4f92..b3965c748 100644 --- a/gptqmodel/utils/importer.py +++ b/gptqmodel/utils/importer.py @@ -85,8 +85,10 @@ FORMAT.QQQ: [BACKEND.QQQ], }, METHOD.AWQ: { - FORMAT.GEMM: [BACKEND.MACHETE, BACKEND.MARLIN, BACKEND.EXLLAMA_V2, BACKEND.EXLLAMA_V1, BACKEND.GEMM, BACKEND.TORCH_AWQ], - FORMAT.GEMV: [BACKEND.GEMV], + FORMAT.GEMM: [BACKEND.MACHETE, BACKEND.MARLIN, BACKEND.EXLLAMA_V2, BACKEND.EXLLAMA_V1, BACKEND.GEMM, BACKEND.TORCH_AWQ], + FORMAT.GEMM_V2: [BACKEND.MACHETE, BACKEND.MARLIN, BACKEND.EXLLAMA_V2, BACKEND.EXLLAMA_V1, BACKEND.GEMM, BACKEND.TORCH_AWQ], + FORMAT.GEMV: [BACKEND.GEMV], + FORMAT.GEMV_V2: [BACKEND.GEMV], FORMAT.GEMV_FAST: [BACKEND.GEMV_FAST], FORMAT.MARLIN: [BACKEND.MACHETE, BACKEND.MARLIN], } diff --git a/gptqmodel_ext/awq/pybind_awq.cpp b/gptqmodel_ext/awq/pybind_awq.cpp deleted file mode 100644 index 0b839c419..000000000 --- a/gptqmodel_ext/awq/pybind_awq.cpp +++ /dev/null @@ -1,13 +0,0 @@ -#include -#include -#include "quantization/gemm_cuda.h" -#include "quantization/gemv_cuda.h" - -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) -{ - m.def("gemm_forward_cuda", &gemm_forward_cuda, "Quantized GEMM kernel."); - m.def("grouped_gemm_forward", &grouped_gemm_forward, "Quantized grouped GEMM kernel."); - m.def("gemmv2_forward_cuda", &gemmv2_forward_cuda, "Quantized v2 GEMM kernel."); - m.def("gemv_forward_cuda", &gemv_forward_cuda, "Quantized GEMV kernel."); - m.def("dequantize_weights_cuda", &dequantize_weights_cuda, "Dequantize weights."); -} \ No newline at end of file diff --git a/gptqmodel_ext/awq/pybind_awq_v2.cpp b/gptqmodel_ext/awq/pybind_awq_v2.cpp index 9499e8b8f..e568da847 100644 --- a/gptqmodel_ext/awq/pybind_awq_v2.cpp +++ b/gptqmodel_ext/awq/pybind_awq_v2.cpp @@ -3,8 +3,28 @@ #include "quantization_new/gemm/gemm_cuda.h" #include "quantization_new/gemv/gemv_cuda.h" +namespace py = pybind11; + PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("gemm_forward_cuda_prefill", &gemm_forward_cuda_prefill, "New quantized GEMM kernel."); - m.def("gemv_forward_cuda_decode", &gemv_forward_cuda_decode, "New quantized GEMM kernel."); -} \ No newline at end of file + m.def( + "gemm_forward_cuda", + &gemm_forward_cuda_prefill, + "Quantized GEMM kernel (v2).", + py::arg("in_feats"), + py::arg("kernel"), + py::arg("scales"), + py::arg("zeros"), + py::arg("use_fp32") = false); + m.def( + "gemm_forward_cuda_prefill", + &gemm_forward_cuda_prefill, + "Quantized GEMM kernel (v2).", + py::arg("in_feats"), + py::arg("kernel"), + py::arg("scales"), + py::arg("zeros"), + py::arg("use_fp32") = false); + m.def("gemv_forward_cuda", &gemv_forward_cuda_decode, "Quantized GEMV kernel (v2)."); + m.def("gemv_forward_cuda_decode", &gemv_forward_cuda_decode, "Quantized GEMV kernel (v2)."); +} diff --git a/gptqmodel_ext/awq/quantization/dequantize.cuh 
b/gptqmodel_ext/awq/quantization/dequantize.cuh deleted file mode 100644 index 5d333b35c..000000000 --- a/gptqmodel_ext/awq/quantization/dequantize.cuh +++ /dev/null @@ -1,79 +0,0 @@ -/* -Modified from NVIDIA FasterTransformer: https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h - -@article{lin2023awq, - title={AWQ: Activation-aware Weight Quantization for LLM Compression and Acceleration}, - author={Lin, Ji and Tang, Jiaming and Tang, Haotian and Yang, Shang and Dang, Xingyu and Han, Song}, - journal={arXiv}, - year={2023} -} -*/ - -#pragma once - - -__device__ uint4 dequantize_s4_to_fp16x2(uint32_t const& source) -{ - uint4 result; - - uint32_t* h = reinterpret_cast(&result); - uint32_t const i4s = reinterpret_cast(source); - - // First, we extract the i4s and construct an intermediate fp16 number. - static constexpr uint32_t immLut = (0xf0 & 0xcc) | 0xaa; - static constexpr uint32_t BOTTOM_MASK = 0x000f000f; - static constexpr uint32_t TOP_MASK = 0x00f000f0; - static constexpr uint32_t I4s_TO_F16s_MAGIC_NUM = 0x64006400; - - // Note that the entire sequence only requires 1 shift instruction. This is thanks to the register packing - // format and the fact that we force our integers to be unsigned, and account for this in the fp16 subtractions. - // In addition, I exploit the fact that sub and fma have the same throughput in order to convert elt_23 and - // elt_67 to fp16 without having to shift them to the bottom bits before hand. - - // Shift right by 8 to now consider elt_45 and elt_67. Issue first to hide RAW dependency if we issue - // immediately before required. - const uint32_t top_i4s = i4s >> 8; - // Extract elt_01 - (i4s & 0x000f000f) | 0x64006400 - asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" - : "=r"(h[0]) - : "r"(i4s), "n"(BOTTOM_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut)); - // Extract elt_23 (i4s & 0x00f000f0) | 0x64006400 - asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" - : "=r"(h[1]) - : "r"(i4s), "n"(TOP_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut)); - // Extract elt_45 (top_i4s & 0x000f000f) | 0x64006400 - asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" - : "=r"(h[2]) - : "r"(top_i4s), "n"(BOTTOM_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut)); - // Extract elt_67 (top_i4s & 0x00f000f0) | 0x64006400 - asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" - : "=r"(h[3]) - : "r"(top_i4s), "n"(TOP_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut)); - - // I use inline PTX below because I am not sure if the compiler will emit float2half instructions if I use the - // half2 ctor. In this case, I chose performance reliability over code readability. - - // This is the half2 {1032, 1032} represented as an integer. - // static constexpr uint32_t FP16_TOP_MAGIC_NUM = 0x64086408; - // Haotian: subtract {1024, 1024} instead, we do not need to map to [-8, 7] - static constexpr uint32_t FP16_TOP_MAGIC_NUM = 0x64006400; - // This is the half2 {1 / 16, 1 / 16} represented as an integer. - static constexpr uint32_t ONE_SIXTEENTH = 0x2c002c00; - // This is the half2 {-72, -72} represented as an integer. - // static constexpr uint32_t NEG_72 = 0xd480d480; - // Haotian: Let's use {-64, -64}. - static constexpr uint32_t NEG_64 = 0xd400d400; - - // Finally, we construct the output numbers. 
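As a quick illustrative check of the magic-number conversion the comments above describe (not part of the kernel): OR-ing a 4-bit code into 0x6400, the fp16 bit pattern for 1024.0 whose mantissa ulp is exactly 1.0, and then subtracting 1024 recovers the integer. The kernel does the same on two fp16 lanes packed in one 32-bit register.

import numpy as np

code = 0x9                                           # one 4-bit weight code
word = np.array([0x6400 | code], dtype=np.uint16)    # OR the nibble into fp16 1024.0 (0x6400)
as_fp16 = word.view(np.float16)[0]                   # bit pattern now reads as 1033.0 == 1024 + 9
print(as_fp16 - np.float16(1024.0))                  # 9.0 recovered without an int-to-float convert

The TOP_MASK lanes land 16x too large, which is why those lanes are finished with an fma by 1/16 and -64 instead of a plain subtract.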
- // Convert elt_01 - asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[0]) : "r"(h[0]), "r"(FP16_TOP_MAGIC_NUM)); - // Convert elt_23 - asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(h[1]) : "r"(h[1]), "r"(ONE_SIXTEENTH), "r"(NEG_64)); - // Convert elt_45 - asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[2]) : "r"(h[2]), "r"(FP16_TOP_MAGIC_NUM)); - // Convert elt_67 - asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(h[3]) : "r"(h[3]), "r"(ONE_SIXTEENTH), "r"(NEG_64)); - - return result; -} - diff --git a/gptqmodel_ext/awq/quantization/gemm_cuda.h b/gptqmodel_ext/awq/quantization/gemm_cuda.h deleted file mode 100644 index afc816515..000000000 --- a/gptqmodel_ext/awq/quantization/gemm_cuda.h +++ /dev/null @@ -1,23 +0,0 @@ -#include - -torch::Tensor gemm_forward_cuda(torch::Tensor _in_feats, torch::Tensor _kernel, - torch::Tensor _scaling_factors, torch::Tensor _zeros, int split_k_iters); - -torch::Tensor grouped_gemm_forward( - torch::Tensor _in_feats, - torch::Tensor _kernel, - torch::Tensor _scaling_factors, - torch::Tensor _zeros, - torch::Tensor _topk_weights, - torch::Tensor _sorted_token_ids_ptr, - torch::Tensor _expert_ids_ptr, - torch::Tensor _num_tokens_post_padded, - bool mul_weights, - int split_k_iters); - -torch::Tensor gemmv2_forward_cuda(torch::Tensor _in_feats, torch::Tensor _kernel, - torch::Tensor _scaling_factors, torch::Tensor _zeros, int group_size, int split_k_iters); - -// Source - https://github.com/compressa-ai/AutoAWQ/blob/6673333456b8871522b11a7fb110de612edfdf95/awq_cuda/quantization/gemm_cuda.h#L9C1-L10C106 -torch::Tensor dequantize_weights_cuda(torch::Tensor _kernel, - torch::Tensor _scaling_factors, torch::Tensor _zeros, int split_k_iters, int thx, int thy, bool dbg); \ No newline at end of file diff --git a/gptqmodel_ext/awq/quantization/gemm_cuda_gen.cu b/gptqmodel_ext/awq/quantization/gemm_cuda_gen.cu deleted file mode 100644 index 98f49efac..000000000 --- a/gptqmodel_ext/awq/quantization/gemm_cuda_gen.cu +++ /dev/null @@ -1,1288 +0,0 @@ -/* - -@article{lin2023awq, - title={AWQ: Activation-aware Weight Quantization for LLM Compression and Acceleration}, - author={Lin, Ji and Tang, Jiaming and Tang, Haotian and Yang, Shang and Dang, Xingyu and Han, Song}, - journal={arXiv}, - year={2023} -} - - */ - -#include -#include -#include "gemm_cuda.h" -#include "dequantize.cuh" -#include -#include -#include - - -// Pack two half values. 
-static inline __device__ __host__ unsigned -__pack_half2(const half x, const half y) { - unsigned v0 = *((unsigned short *)&x); - unsigned v1 = *((unsigned short *)&y); - return (v1 << 16) | v0; -} - -__device__ __forceinline__ int make_divisible(int c, int divisor){ - return (c + divisor - 1) / divisor; -} - -__global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16n128k32(int G, int split_k_iters, half* __restrict__ A, int* __restrict__ B, half* __restrict__ scaling_factors, int* __restrict__ zeros, int M, int IC, int OC, half* __restrict__ C) -{ - static constexpr uint32_t ZERO = 0x0; - float C_warp[32]; - __shared__ half A_shared[16 * (32 + 8)]; - __shared__ half B_shared[32 * (128 + 8)]; - - int j_factors1 = ((OC + 128 - 1) / 128); - int blockIdx_x = 0; - int blockIdx_y = blockIdx.x % ((M + 16 - 1) / 16 * j_factors1); - int blockIdx_z = blockIdx.x / ((M + 16 - 1) / 16 * j_factors1); - - half A_shared_warp[8]; - half B_shared_warp[32]; - for (int j_0_4_init = 0; j_0_4_init < 4; ++j_0_4_init) { - for (int i = 0; i < 8; ++i) { - C_warp[(j_0_4_init * 8) + i] = 0.0; - } - } - - static constexpr int row_stride_warp = 32 * 8 / 32; - static constexpr int row_stride = 2 * 32 * 8 / 128; - bool ld_zero_flag = (threadIdx.y * 32 + threadIdx.x) * 8 < 128; - // TODO: Haotian: blockIdx_y / j_factors1 in A loading to support bsz > 16 - bool ld_A_flag = (blockIdx_y / j_factors1 * 16 + threadIdx.y * row_stride_warp + threadIdx.x * 8 / 32) < M; // threadIdx.y is warp_id - // bool wb_C_flag = (threadIdx.x / 4) < M; - - half* A_ptr = A - + (((int)blockIdx_y) / j_factors1 * 16 + (((int)threadIdx.y) * row_stride_warp) + ((int)threadIdx.x) / (32 / 8)) * IC - + (((int)threadIdx.x) % (32 / 8)) * 8; - - int* B_ptr = B - + ((int)threadIdx.y) * (OC / 8) * 2 - + (((int)threadIdx.x) / (128 / 8)) * (OC / 8) - + (((int)blockIdx_y) % j_factors1) * (128 / 8) - + (((int)threadIdx.x) % (128 / 8)) * 1; -// Why * 1 in the above line? - - half* A_shared_ptr = A_shared - + ((int)threadIdx.y) * row_stride_warp * (32 + 8) - + (((int)threadIdx.x) / (32 / 8)) * (32 + 8) - + (((int)threadIdx.x) % (32 / 8) ) * 8; - - half* B_shared_ptr = B_shared - + ((int)threadIdx.y) * (row_stride / 2) * (128 + 8) - + (((int)threadIdx.x) / (128 / 8)) * (128 + 8) - + (((int)threadIdx.x) % (128 / 8)) * 8; - - int* zeros_ptr = zeros - + (((int)blockIdx_y) % j_factors1) * (128 / 8) - + ((int)threadIdx.x) % (128 / 8); - - half* scaling_factors_ptr = scaling_factors - + (((int)blockIdx_y) % j_factors1) * (128) - + (((int)threadIdx.x) % (128 / 8)) * 8; - - half* C_ptr = C - + static_cast(blockIdx_z) * M * OC // blockIdz.x -> split_k dim - + (((int)blockIdx_y) % j_factors1) * 128 - + ((int)threadIdx.y) * 64 - + (((int)threadIdx.x) % 4) * 2; - - // preload s.f. 
and zeros - int k_bound = (IC / 32 + split_k_iters - 1) / split_k_iters; - if ((k_bound - 1) * split_k_iters * 32 + blockIdx_z * 32 >= IC) k_bound -= 1; - for (int _k_0_0 = 0; _k_0_0 < k_bound; ++_k_0_0) { - int k_0_0 = _k_0_0 * split_k_iters + blockIdx_z; - __syncthreads(); - // TODO: Haotian: blockIdx_y / j_factors1 in A loading to support bsz > 16 - if (ld_A_flag) - { - *(uint4*)(A_shared_ptr) = *(uint4*)(A_ptr + (k_0_0 * 32)); - } - else - { - *(uint4*)(A_shared_ptr) = make_uint4(0, 0, 0, 0); - } - - // for (int ax0_ax1_fused_0 = 0; ax0_ax1_fused_0 < 2; ++ax0_ax1_fused_0) { - uint32_t zeros_loaded = *(uint32_t*)(zeros_ptr + k_0_0 * 32 / G * (OC / 8)); - uint4 B_loaded_zero = dequantize_s4_to_fp16x2(zeros_loaded); - uint4 B_loaded_scale = *(uint4*)(scaling_factors_ptr + k_0_0 * 32 / G * (OC)); - /* - if (blockIdx_z == 0 && blockIdx_y == 0 && k_0_0 == 0 && threadIdx.x == 0 && threadIdx.y == 0){ - printf("%x %x %x %x %x %x %x %x\n", B_loaded_scale.x, B_loaded_scale.y, B_loaded_scale.z, B_loaded_scale.w, B_loaded_zero.x, B_loaded_zero.y, B_loaded_zero.z, B_loaded_zero.w); - } - */ - // uint4 B_loaded_scale = make_uint4(0, 0, 0, 0); - int* B_ptr_local = B_ptr + k_0_0 * 32 * (OC / 8); - - for (int ax0_ax1_fused_0 = 0; ax0_ax1_fused_0 < 8; ++ax0_ax1_fused_0) { - - // B: 32 x 136 (128+8) float16 - // each warp: 32 x 4 - // each thr: read 32 bit -> convert to 8xFP16 (a UINT4) -> scale and minus zero -> WB UINT4 - // *(uint4*)(B_shared + ((((ax0_ax1_fused_0 * 544) + (((int)threadIdx.y) * 272)) + ((((int)threadIdx.x) >> 4) * 136)) + ((((int)threadIdx.x) & 15) * 8))) = *(uint4*)(B + ((((((k_0_0 * 163840) + (ax0_ax1_fused_0 * 20480)) + (((int)threadIdx.y) * 10240)) + ((((int)threadIdx.x) >> 4) * 5120)) + (((int)blockIdx_y) * 128)) + ((((int)threadIdx.x) & 15) * 8))); - // row stride in shared memory: (NWARPS * 32 * 8 / cta_N) - uint32_t B_loaded = *(uint32_t*)(B_ptr_local + ax0_ax1_fused_0 * row_stride * (OC / 8)); - uint4 B_loaded_fp16 = dequantize_s4_to_fp16x2(B_loaded); - //uint4 B_loaded_zero = *(uint4*)(zeros_shared + (threadIdx.x % (cta_N / 8)) * 8); - - // uint4 B_loaded_scale = *(uint4*)(scaling_factors_shared + (threadIdx.x % (cta_N / 8)) * 8); - // - zero and * scale - // TODO (Haotian): can save 4 assembly instructions if sormulate as deq = q * scale - zero * scale. 
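The TODO just above notes that dequantization can be reformulated from scale * (q - zero) into q * scale - zero * scale; the new mit_repacker helper multiply_scale_qzero_negative appears to precompute exactly that -(zero * scale) term (with zp_shift left at 0), so a v2-style kernel can apply it in a single fused multiply-add. A quick sketch of the equivalence with made-up values:

import torch

scale = torch.tensor(0.25, dtype=torch.float16)   # chosen exactly representable so both forms match bit-for-bit
zero = torch.tensor(7.0, dtype=torch.float16)
q = torch.arange(16, dtype=torch.float16)         # all possible 4-bit codes

deq_sub_then_scale = scale * (q - zero)           # what the kernel above computes per element
scaled_zero = -(scale * zero)                     # what the repacker precomputes offline
deq_fma = q * scale + scaled_zero                 # single multiply-add form

assert torch.allclose(deq_sub_then_scale, deq_fma)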
- asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.x) : "r"(B_loaded_fp16.x), "r"(B_loaded_zero.x)); - asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.x) : "r"(B_loaded_fp16.x), "r"(B_loaded_scale.x), "r"(ZERO)); - asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.y) : "r"(B_loaded_fp16.y), "r"(B_loaded_zero.y)); - asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.y) : "r"(B_loaded_fp16.y), "r"(B_loaded_scale.y), "r"(ZERO)); - asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.z) : "r"(B_loaded_fp16.z), "r"(B_loaded_zero.z)); - asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.z) : "r"(B_loaded_fp16.z), "r"(B_loaded_scale.z), "r"(ZERO)); - asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.w) : "r"(B_loaded_fp16.w), "r"(B_loaded_zero.w)); - asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.w) : "r"(B_loaded_fp16.w), "r"(B_loaded_scale.w), "r"(ZERO)); - /* - if (ax0_ax1_fused_0 == 0 && blockIdx_z == 0 && blockIdx_y == 0 && k_0_0 == 0 && threadIdx.x == 17 && threadIdx.y == 0){ - printf("[x] %X %X %X %X\n", B_loaded_fp16.x, B_loaded_fp16.y, B_loaded_fp16.z, B_loaded_fp16.w); - } - */ - - // write back - *(uint4*)(B_shared_ptr + ax0_ax1_fused_0 * row_stride * (128 + 8)) = B_loaded_fp16; - } - __syncthreads(); - - for (int k_0_1 = 0; k_0_1 < 2; ++k_0_1) { - { - unsigned int addr; - asm volatile( - "{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; }\n" - : "=r"(addr) - : "l"((void *)((&(A_shared[(k_0_1 * 16)])) + (((((int)threadIdx.x) & 15) * 40) + ((((int)threadIdx.x) >> 4) * 8)))) - ); - - - asm volatile( - "ldmatrix.sync.aligned.m8n8.x4.shared.b16" - "{%0, %1, %2, %3}, [%4];\n" - : "=r"(((unsigned *)(A_shared_warp + 0))[0]), "=r"(((unsigned *)(A_shared_warp + 0))[1]), "=r"(((unsigned *)(A_shared_warp + 0))[2]), "=r"(((unsigned *)(A_shared_warp + 0))[3]) - : "r"(addr) - ); - } - - for (int ax1_0 = 0; ax1_0 < 4; ++ax1_0) { - { - unsigned int addr; - asm volatile( - "{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; }\n" - : "=r"(addr) - : "l"((void *)((&(B_shared[(((k_0_1 * 2176) + (((int)threadIdx.y) * 64)) + (ax1_0 * 16))])) + (((((int)threadIdx.x) & 15) * 136) + ((((int)threadIdx.x) >> 4) * 8)))) - ); - asm volatile( - "ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16" - "{%0, %1, %2, %3}, [%4];\n" - : "=r"(((unsigned *)(B_shared_warp + (ax1_0 * 8)))[0]), "=r"(((unsigned *)(B_shared_warp + (ax1_0 * 8)))[1]), "=r"(((unsigned *)(B_shared_warp + (ax1_0 * 8)))[2]), "=r"(((unsigned *)(B_shared_warp + (ax1_0 * 8)))[3]) - : "r"(addr) - ); - } - } - for (int j_0_4 = 0; j_0_4 < 4; ++j_0_4) { -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 - { - asm volatile( - "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32" - "{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n" - : "=f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[3]) - : "r"(((unsigned *)(A_shared_warp + 0))[0]), "r"(((unsigned *)(A_shared_warp + 0))[1]), "r"(((unsigned *)(B_shared_warp + (j_0_4 * 8)))[0]), "f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "f"(((float *)(C_warp + (j_0_4 * 8)))[3])); - } - - { - asm volatile( - "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32" - "{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n" - : "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "=f"(((float 
*)(C_warp + ((j_0_4 * 8) + 4)))[1]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3]) - : "r"(((unsigned *)(A_shared_warp + 0))[0]), "r"(((unsigned *)(A_shared_warp + 0))[1]), "r"(((unsigned *)(B_shared_warp + ((j_0_4 * 8) + 4)))[0]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3])); - } - - { - asm volatile( - "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32" - "{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n" - : "=f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[3]) - : "r"(((unsigned *)(A_shared_warp + 0))[2]), "r"(((unsigned *)(A_shared_warp + 0))[3]), "r"(((unsigned *)(B_shared_warp + (j_0_4 * 8)))[1]), "f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "f"(((float *)(C_warp + (j_0_4 * 8)))[3])); - } - - { - asm volatile( - "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32" - "{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n" - : "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3]) - : "r"(((unsigned *)(A_shared_warp + 0))[2]), "r"(((unsigned *)(A_shared_warp + 0))[3]), "r"(((unsigned *)(B_shared_warp + ((j_0_4 * 8) + 4)))[1]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3])); - } -#else - { - asm volatile( - "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32" - "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, %13};\n" - : "=f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[3]) - : "r"(((unsigned *)(A_shared_warp + 0))[0]), "r"(((unsigned *)(A_shared_warp + 0))[1]), "r"(((unsigned *)(A_shared_warp + 0))[2]), "r"(((unsigned *)(A_shared_warp + 0))[3]), "r"(((unsigned *)(B_shared_warp + (j_0_4 * 8)))[0]), "r"(((unsigned *)(B_shared_warp + (j_0_4 * 8)))[1]), "f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "f"(((float *)(C_warp + (j_0_4 * 8)))[3])); - } - - { - asm volatile( - "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32" - "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, %13};\n" - : "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3]) - : "r"(((unsigned *)(A_shared_warp + 0))[0]), "r"(((unsigned *)(A_shared_warp + 0))[1]), "r"(((unsigned *)(A_shared_warp + 0))[2]), "r"(((unsigned *)(A_shared_warp + 0))[3]), "r"(((unsigned *)(B_shared_warp + ((j_0_4 * 8) + 4)))[0]), "r"(((unsigned *)(B_shared_warp + ((j_0_4 * 8) + 4)))[1]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3])); - } -#endif - } - } - } - -// TODO: Shang: Hoist loop invariance. 
- for (int ax1_0_1 = 0; ax1_0_1 < 4; ++ax1_0_1) { - for (int local_id = 0; local_id < 8; ++local_id) { - int row_offset = (((int)blockIdx_y) / j_factors1) * 16 + ((int)threadIdx.x) / 4 + (local_id % 4) / 2 * 8; - if (row_offset < M) - { - *(C_ptr + ax1_0_1 * 16 + row_offset * OC + (local_id / 4) * 8 + local_id % 2) = __float2half(C_warp[(ax1_0_1 * 8) + local_id]); - } - } - } -} - - -__global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16n64k32(int G, int split_k_iters, half* __restrict__ A, int* __restrict__ B, half* __restrict__ scaling_factors, int* __restrict__ zeros, int M, int IC, int OC, half* __restrict__ C) -{ - static constexpr uint32_t ZERO = 0x0; - float C_warp[32]; - __shared__ half A_shared[16 * (32 + 8)]; - __shared__ half B_shared[32 * (64 + 8)]; - - __shared__ half scaling_factors_shared[64]; - __shared__ half zeros_shared[64]; - - int j_factors1 = ((OC + 64 - 1) / 64); - - int blockIdx_x = 0; - int blockIdx_y = blockIdx.x % ((M + 16 - 1) / 16 * j_factors1); - int blockIdx_z = blockIdx.x / ((M + 16 - 1) / 16 * j_factors1); - - half A_shared_warp[8]; - half B_shared_warp[16]; - for (int j_0_4_init = 0; j_0_4_init < 2; ++j_0_4_init) { - for (int i = 0; i < 8; ++i) { - C_warp[(j_0_4_init * 8) + i] = 0.0; - } - } - - static constexpr int row_stride_warp = 32 * 8 / 32; - static constexpr int row_stride = 2 * 32 * 8 / 64; - bool ld_zero_flag = (threadIdx.y * 32 + threadIdx.x) * 8 < 64; - // TODO: Haotian: blockIdx_y / j_factors1 in A loading to support bsz > 16 - bool ld_A_flag = (blockIdx_y / j_factors1 * 16 + threadIdx.y * row_stride_warp + threadIdx.x * 8 / 32) < M; // threadIdx.y is warp_id - // bool wb_C_flag = (threadIdx.x / 4) < M; - - half* A_ptr = A - + (((int)blockIdx_y) / j_factors1 * 16 + (((int)threadIdx.y) * row_stride_warp) + ((int)threadIdx.x) / (32 / 8)) * IC - + (((int)threadIdx.x) % (32 / 8)) * 8; - - int* B_ptr = B - + ((int)threadIdx.y) * (OC / 8) * 4 - + (((int)threadIdx.x) / (64 / 8)) * (OC / 8) - + (((int)blockIdx_y) % j_factors1) * (64 / 8) - + (((int)threadIdx.x) % (64 / 8)) * 1; -// Why * 1 in the above line? - - half* A_shared_ptr = A_shared - + ((int)threadIdx.y) * row_stride_warp * (32 + 8) - + (((int)threadIdx.x) / (32 / 8)) * (32 + 8) - + (((int)threadIdx.x) % (32 / 8) ) * 8; - - half* B_shared_ptr = B_shared - + ((int)threadIdx.y) * (row_stride / 2) * (64 + 8) - + (((int)threadIdx.x) / (64 / 8)) * (64 + 8) - + (((int)threadIdx.x) % (64 / 8)) * 8; - - int* zeros_ptr = zeros - + (((int)blockIdx_y) % j_factors1) * (64 / 8) - + ((int)threadIdx.x) % (64 / 8); - - half* scaling_factors_ptr = scaling_factors - + (((int)blockIdx_y) % j_factors1) * (64) - + (((int)threadIdx.x) % (64 / 8)) * 8; - - half* C_ptr = C - + static_cast(blockIdx_z) * M * OC // blockIdz.x -> split_k dim - + (((int)blockIdx_y) % j_factors1) * 64 - + ((int)threadIdx.y) * 32 - + (((int)threadIdx.x) % 4) * 2; - - // preload s.f. 
and zeros - int k_bound = (IC / 32 + split_k_iters - 1) / split_k_iters; - if ((k_bound - 1) * split_k_iters * 32 + blockIdx_z * 32 >= IC) k_bound -= 1; - for (int _k_0_0 = 0; _k_0_0 < k_bound; ++_k_0_0) { - int k_0_0 = _k_0_0 * split_k_iters + blockIdx_z; - __syncthreads(); - // TODO: Haotian: blockIdx_y / j_factors1 in A loading to support bsz > 16 - if (ld_A_flag) - { - *(uint4*)(A_shared_ptr) = *(uint4*)(A_ptr + (k_0_0 * 32)); - } - else - { - *(uint4*)(A_shared_ptr) = make_uint4(0, 0, 0, 0); - } - - // for (int ax0_ax1_fused_0 = 0; ax0_ax1_fused_0 < 2; ++ax0_ax1_fused_0) { - uint32_t zeros_loaded = *(uint32_t*)(zeros_ptr + k_0_0 * 32 / G * (OC / 8)); - uint4 B_loaded_zero = dequantize_s4_to_fp16x2(zeros_loaded); - uint4 B_loaded_scale = *(uint4*)(scaling_factors_ptr + k_0_0 * 32 / G * (OC)); - /* - if (blockIdx_z == 0 && blockIdx_y == 0 && k_0_0 == 0 && threadIdx.x == 0 && threadIdx.y == 0){ - printf("%x %x %x %x %x %x %x %x\n", B_loaded_scale.x, B_loaded_scale.y, B_loaded_scale.z, B_loaded_scale.w, B_loaded_zero.x, B_loaded_zero.y, B_loaded_zero.z, B_loaded_zero.w); - } - */ - // uint4 B_loaded_scale = make_uint4(0, 0, 0, 0); - int* B_ptr_local = B_ptr + k_0_0 * 32 * (OC / 8); - - for (int ax0_ax1_fused_0 = 0; ax0_ax1_fused_0 < 4; ++ax0_ax1_fused_0) { - - // B: 32 x 136 (128+8) float16 - // each warp: 32 x 4 - // each thr: read 32 bit -> convert to 8xFP16 (a UINT4) -> scale and minus zero -> WB UINT4 - // *(uint4*)(B_shared + ((((ax0_ax1_fused_0 * 544) + (((int)threadIdx.y) * 272)) + ((((int)threadIdx.x) >> 4) * 136)) + ((((int)threadIdx.x) & 15) * 8))) = *(uint4*)(B + ((((((k_0_0 * 163840) + (ax0_ax1_fused_0 * 20480)) + (((int)threadIdx.y) * 10240)) + ((((int)threadIdx.x) >> 4) * 5120)) + (((int)blockIdx_y) * 128)) + ((((int)threadIdx.x) & 15) * 8))); - // row stride in shared memory: (NWARPS * 32 * 8 / cta_N) - uint32_t B_loaded = *(uint32_t*)(B_ptr_local + ax0_ax1_fused_0 * row_stride * (OC / 8)); - uint4 B_loaded_fp16 = dequantize_s4_to_fp16x2(B_loaded); - //uint4 B_loaded_zero = *(uint4*)(zeros_shared + (threadIdx.x % (cta_N / 8)) * 8); - - // uint4 B_loaded_scale = *(uint4*)(scaling_factors_shared + (threadIdx.x % (cta_N / 8)) * 8); - // - zero and * scale - // TODO (Haotian): can save 4 assembly instructions if sormulate as deq = q * scale - zero * scale. 
- asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.x) : "r"(B_loaded_fp16.x), "r"(B_loaded_zero.x)); - asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.x) : "r"(B_loaded_fp16.x), "r"(B_loaded_scale.x), "r"(ZERO)); - asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.y) : "r"(B_loaded_fp16.y), "r"(B_loaded_zero.y)); - asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.y) : "r"(B_loaded_fp16.y), "r"(B_loaded_scale.y), "r"(ZERO)); - asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.z) : "r"(B_loaded_fp16.z), "r"(B_loaded_zero.z)); - asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.z) : "r"(B_loaded_fp16.z), "r"(B_loaded_scale.z), "r"(ZERO)); - asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.w) : "r"(B_loaded_fp16.w), "r"(B_loaded_zero.w)); - asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.w) : "r"(B_loaded_fp16.w), "r"(B_loaded_scale.w), "r"(ZERO)); - /* - if (ax0_ax1_fused_0 == 0 && blockIdx_z == 0 && blockIdx_y == 0 && k_0_0 == 0 && threadIdx.x == 17 && threadIdx.y == 0){ - printf("[x] %X %X %X %X\n", B_loaded_fp16.x, B_loaded_fp16.y, B_loaded_fp16.z, B_loaded_fp16.w); - } - */ - - // write back - *(uint4*)(B_shared_ptr + ax0_ax1_fused_0 * row_stride * (64 + 8)) = B_loaded_fp16; - } - __syncthreads(); - - for (int k_0_1 = 0; k_0_1 < 2; ++k_0_1) - { - { - unsigned int addr; - asm volatile( - "{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; }\n" - : "=r"(addr) - : "l"((void *)((&(A_shared[(k_0_1 * 16)])) + (((((int)threadIdx.x) & 15) * 40) + ((((int)threadIdx.x) >> 4) * 8)))) - ); - asm volatile( - "ldmatrix.sync.aligned.m8n8.x4.shared.b16" - "{%0, %1, %2, %3}, [%4];\n" - : "=r"(((unsigned *)(A_shared_warp + 0))[0]), "=r"(((unsigned *)(A_shared_warp + 0))[1]), "=r"(((unsigned *)(A_shared_warp + 0))[2]), "=r"(((unsigned *)(A_shared_warp + 0))[3]) - : "r"(addr) - ); - } - - - for (int ax1_0 = 0; ax1_0 < 2; ++ax1_0) - { - { - unsigned int addr; - asm volatile( - "{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; }\n" - : "=r"(addr) - : "l"((void *)((&(B_shared[(((k_0_1 * 1152) + (((int)threadIdx.y) * 32)) + (ax1_0 * 16))])) + (((((int)threadIdx.x) & 15) * 72) + ((((int)threadIdx.x) >> 4) * 8)))) - ); - asm volatile( - "ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16" - "{%0, %1, %2, %3}, [%4];\n" - : "=r"(((unsigned *)(B_shared_warp + (ax1_0 * 8)))[0]), "=r"(((unsigned *)(B_shared_warp + (ax1_0 * 8)))[1]), "=r"(((unsigned *)(B_shared_warp + (ax1_0 * 8)))[2]), "=r"(((unsigned *)(B_shared_warp + (ax1_0 * 8)))[3]) - : "r"(addr) - ); - } - } - - for (int j_0_4 = 0; j_0_4 < 2; ++j_0_4) - { -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 - { - asm volatile( - "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32" - "{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n" - : "=f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[3]) - : "r"(((unsigned *)(A_shared_warp + 0))[0]), "r"(((unsigned *)(A_shared_warp + 0))[1]), "r"(((unsigned *)(B_shared_warp + (j_0_4 * 8)))[0]), "f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "f"(((float *)(C_warp + (j_0_4 * 8)))[3])); - } - - { - asm volatile( - "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32" - "{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n" - : "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), 
"=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3]) - : "r"(((unsigned *)(A_shared_warp + 0))[0]), "r"(((unsigned *)(A_shared_warp + 0))[1]), "r"(((unsigned *)(B_shared_warp + ((j_0_4 * 8) + 4)))[0]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3])); - } - - { - asm volatile( - "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32" - "{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n" - : "=f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[3]) - : "r"(((unsigned *)(A_shared_warp + 0))[2]), "r"(((unsigned *)(A_shared_warp + 0))[3]), "r"(((unsigned *)(B_shared_warp + (j_0_4 * 8)))[1]), "f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "f"(((float *)(C_warp + (j_0_4 * 8)))[3])); - } - - { - asm volatile( - "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32" - "{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n" - : "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3]) - : "r"(((unsigned *)(A_shared_warp + 0))[2]), "r"(((unsigned *)(A_shared_warp + 0))[3]), "r"(((unsigned *)(B_shared_warp + ((j_0_4 * 8) + 4)))[1]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3])); - } -#else - { - asm volatile( - "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32" - "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, %13};\n" - : "=f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[3]) - : "r"(((unsigned *)(A_shared_warp + 0))[0]), "r"(((unsigned *)(A_shared_warp + 0))[1]), "r"(((unsigned *)(A_shared_warp + 0))[2]), "r"(((unsigned *)(A_shared_warp + 0))[3]), "r"(((unsigned *)(B_shared_warp + (j_0_4 * 8)))[0]), "r"(((unsigned *)(B_shared_warp + (j_0_4 * 8)))[1]), "f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "f"(((float *)(C_warp + (j_0_4 * 8)))[3])); - } - - { - asm volatile( - "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32" - "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, %13};\n" - : "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3]) - : "r"(((unsigned *)(A_shared_warp + 0))[0]), "r"(((unsigned *)(A_shared_warp + 0))[1]), "r"(((unsigned *)(A_shared_warp + 0))[2]), "r"(((unsigned *)(A_shared_warp + 0))[3]), "r"(((unsigned *)(B_shared_warp + ((j_0_4 * 8) + 4)))[0]), "r"(((unsigned *)(B_shared_warp + ((j_0_4 * 8) + 4)))[1]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3])); - } -#endif - } - } - } - -// TODO: Shang: Hoist loop invariance. 
- for (int ax1_0_1 = 0; ax1_0_1 < 2; ++ax1_0_1) { - for (int local_id = 0; local_id < 8; ++local_id) { - int row_offset = (((int)blockIdx_y) / j_factors1) * 16 + ((int)threadIdx.x) / 4 + (local_id % 4) / 2 * 8; - if (row_offset < M) - { - *(C_ptr + ax1_0_1 * 16 + row_offset * OC + (local_id / 4) * 8 + local_id % 2) = __float2half(C_warp[(ax1_0_1 * 8) + local_id]); - } - } - } -} - -template -__global__ void __launch_bounds__(128) gemmv2_forward_4bit_cuda_m128n64k32(int split_k_iters, half* __restrict__ A, int* __restrict__ B, half* __restrict__ scaling_factors, int* zeros, int M, int IC, int OC, half* __restrict__ C) -{ - static constexpr uint32_t ZERO = 0x0; - float C_warp[64]; - __shared__ half A_shared[128 * (32 + 8)]; - __shared__ half B_shared[64 * (32 + 8)]; - - // __shared__ half scaling_factors_shared[64]; - // __shared__ half zeros_shared[64]; - - int j_factors1 = ((OC + 64 - 1) / 64); - - int blockIdx_x = 0; - int blockIdx_y = blockIdx.x % ((M + 128 - 1) / 128 * j_factors1); - int blockIdx_z = blockIdx.x / ((M + 128 - 1) / 128 * j_factors1); - - half A_shared_warp[32]; - half B_shared_warp[16]; - for (int i_0_3_init = 0; i_0_3_init < 4; ++i_0_3_init) { - for (int j_0_4_init = 0; j_0_4_init < 2; ++j_0_4_init) { - for (int i = 0; i < 8; ++i) { - C_warp[((i_0_3_init * 16) + (j_0_4_init * 8)) + i] = 0.0; - } - } - } - - static constexpr int row_stride_warp = 32 * 8 / 32; - static constexpr int row_stride_A = 4 * 32 * 8 / 32; - static constexpr int row_stride = 4 * 32 * 8 / 32; - const int make_divisible_multipler = 128 / G; - const int zeros_w = make_divisible(make_divisible(IC / G, 8), make_divisible_multipler) * make_divisible_multipler; - const int sf_w = zeros_w * 8; - - bool ld_zero_flag = (threadIdx.y * 32 + threadIdx.x) * 8 < 64; - int ld_A_row = (blockIdx_y / j_factors1 * 128 + threadIdx.y * row_stride_warp + threadIdx.x * 8 / 32); // threadIdx.y is warp_id - // bool wb_C_flag = (threadIdx.x / 4) < M; - - half* A_ptr = A - + (((int)blockIdx_y) / j_factors1 * 128 + (((int)threadIdx.y) * row_stride_warp) + ((int)threadIdx.x) / (32 / 8)) * IC - + (((int)threadIdx.x) % (32 / 8)) * 8; - - int* B_ptr = B - + ((int)threadIdx.y) * (IC / 8) * 8 - + (((int)threadIdx.x) / (32 / 8)) * (IC / 8) - + (((int)blockIdx_y) % j_factors1) * 64 * (IC / 8) - + (((int)threadIdx.x) % (32 / 8)) * 1; - -// Why * 1 in the above line? - - half* A_shared_ptr = A_shared - + ((int)threadIdx.y) * row_stride_warp * (32 + 8) - + (((int)threadIdx.x) / (32 / 8)) * (32 + 8) - + (((int)threadIdx.x) % (32 / 8) ) * 8; - - half* B_shared_ptr = B_shared - + ((int)threadIdx.y) * (row_stride / 4) * (32 + 8) - + (((int)threadIdx.x) / (32 / 8)) * (32 + 8) - + (((int)threadIdx.x) % (32 / 8)) * 8; - - - int* zeros_ptr = zeros - + ((int)threadIdx.y) * zeros_w * 8 - + (((int)threadIdx.x) / (32 / 8)) * zeros_w - + (((int)blockIdx_y) % j_factors1) * 64 * zeros_w - // this term is zero - + (((int)threadIdx.x) % (32 / 8)) / G ; - - half* scaling_factors_ptr = scaling_factors - + ((int)threadIdx.y) * sf_w * 8 - + (((int)threadIdx.x) / (32 / 8)) * sf_w - + (((int)blockIdx_y) % j_factors1) * (64) * sf_w - // this term is zero - + (((int)threadIdx.x) % (32 / 8)) * 8 / G; - - - // Haotian: TBD, check, May 29 11:46 AM PST - half* C_ptr = C - + static_cast(blockIdx_z) * M * OC // blockIdx_z -> split_k dim - + (((int)blockIdx_y) % j_factors1) * 64 - + (((int)threadIdx.y) / 2) * 32 - + (((int)threadIdx.x) % 4) * 2; - - // preload s.f. 
and zeros - int k_bound = make_divisible(IC / 32, split_k_iters); // (IC / 32 + split_k_iters - 1) / split_k_iters; - if ((k_bound - 1) * 32 + blockIdx_z >= IC) k_bound -= 1; - - // TODO (Haotian): load scales and zero points to smem - - for (int _k_0_0 = 0; _k_0_0 < k_bound; ++_k_0_0) { - int k_0_0 = _k_0_0 * split_k_iters + blockIdx_z; - __syncthreads(); - // TODO: Haotian: Here we assume M % cta_M = 0. - for (int ax0_ax1_fused_0 = 0; ax0_ax1_fused_0 < 4; ++ax0_ax1_fused_0) - { - if (ld_A_row + ax0_ax1_fused_0 * row_stride_A < M) - { - *(uint4*)(A_shared_ptr + ax0_ax1_fused_0 * row_stride_A * 40) = *(uint4*)(A_ptr + (ax0_ax1_fused_0 * row_stride_A * IC) + (k_0_0 * 32)); - } - else - { - *(uint4*)(A_shared_ptr + ax0_ax1_fused_0 * row_stride_A * 40) = make_uint4(0, 0, 0, 0); - } - } - - - int* zeros_ptr_local = zeros_ptr + k_0_0 * 32 / G / 8; - half* scaling_factors_ptr_local = scaling_factors_ptr + k_0_0 * 32 / G; - - // uint4 B_loaded_scale = make_uint4(0, 0, 0, 0); - int* B_ptr_local = B_ptr + k_0_0 * (32 / 8); - - for (int ax0_ax1_fused_0 = 0; ax0_ax1_fused_0 < 2; ++ax0_ax1_fused_0) { - - // B: 32 x 136 (128+8) float16 - // each warp: 32 x 4 - // each thr: read 32 bit -> convert to 8xFP16 (a UINT4) -> scale and minus zero -> WB UINT4 - // row stride in shared memory: (NWARPS * 32 * 8 / cta_N) - int B_loaded_current = *(B_ptr_local + ax0_ax1_fused_0 * row_stride * (IC / 8)); - int zeros_loaded = *(zeros_ptr_local + ax0_ax1_fused_0 * row_stride * zeros_w); - zeros_loaded >>= ((k_0_0 * 32 / G) % 8) * 4; - float current_zeros = (float)(zeros_loaded & 0xF); - half scaling_factors_loaded = *(scaling_factors_ptr_local + ax0_ax1_fused_0 * row_stride * sf_w); - half B_loaded_fp16[8]; - #pragma unroll - for (int ic_1 = 0; ic_1 < 8; ic_1++){ - float current_single_weight_fp = (float)(B_loaded_current & 0xF); - half dequantized_weight = __float2half(__half2float(scaling_factors_loaded) * (current_single_weight_fp - current_zeros)); - B_loaded_current = B_loaded_current >> 4; - B_loaded_fp16[ic_1] = dequantized_weight; - } - // write back - *(uint4*)(B_shared_ptr + ax0_ax1_fused_0 * row_stride * (32 + 8)) = *reinterpret_cast(B_loaded_fp16); - } - __syncthreads(); - for (int k_0_1 = 0; k_0_1 < 2; ++k_0_1) { - for (int ax0_0 = 0; ax0_0 < 4; ++ax0_0) { - { - unsigned int addr; - asm volatile( - "{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; }\n" - : "=r"(addr) - : "l"((void *)((&(A_shared[((((((int)threadIdx.y) & 1) * 2560) + (ax0_0 * 640)) + (k_0_1 * 16))])) + (((((int)threadIdx.x) & 15) * 40) + ((((int)threadIdx.x) >> 4) * 8)))) - ); - asm volatile( - "ldmatrix.sync.aligned.m8n8.x4.shared.b16" - "{%0, %1, %2, %3}, [%4];\n" - : "=r"(((unsigned *)(A_shared_warp + (ax0_0 * 8)))[0]), "=r"(((unsigned *)(A_shared_warp + (ax0_0 * 8)))[1]), "=r"(((unsigned *)(A_shared_warp + (ax0_0 * 8)))[2]), "=r"(((unsigned *)(A_shared_warp + (ax0_0 * 8)))[3]) - : "r"(addr) - ); - } - } - - for (int ax0_0_1 = 0; ax0_0_1 < 2; ++ax0_0_1) { - { - unsigned int addr; - asm volatile( - "{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; }\n" - : "=r"(addr) - : "l"((void *)((&(B_shared[((((((int)threadIdx.y) >> 1) * 1280) + (ax0_0_1 * 640)) + (k_0_1 * 16))])) + ((((((int)threadIdx.x) >> 4) * 320) + ((((int)threadIdx.x) & 7) * 40)) + (((((int)threadIdx.x) & 15) >> 3) * 8)))) - ); - asm volatile( - "ldmatrix.sync.aligned.m8n8.x4.shared.b16" - "{%0, %1, %2, %3}, [%4];\n" - : "=r"(((unsigned *)(B_shared_warp + (ax0_0_1 * 8)))[0]), "=r"(((unsigned *)(B_shared_warp + (ax0_0_1 * 8)))[1]), 
"=r"(((unsigned *)(B_shared_warp + (ax0_0_1 * 8)))[2]), "=r"(((unsigned *)(B_shared_warp + (ax0_0_1 * 8)))[3]) - : "r"(addr) - ); - } - } - - for (int i_0_3 = 0; i_0_3 < 4; ++i_0_3) { - for (int j_0_4 = 0; j_0_4 < 2; ++j_0_4) { -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 - { - asm volatile( - "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32" - "{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n" - : "=f"(((float *)(C_warp + ((i_0_3 * 16) + (j_0_4 * 8))))[0]), "=f"(((float *)(C_warp + ((i_0_3 * 16) + (j_0_4 * 8))))[1]), "=f"(((float *)(C_warp + ((i_0_3 * 16) + (j_0_4 * 8))))[2]), "=f"(((float *)(C_warp + ((i_0_3 * 16) + (j_0_4 * 8))))[3]) - : "r"(((unsigned *)(A_shared_warp + (i_0_3 * 8)))[0]), "r"(((unsigned *)(A_shared_warp + (i_0_3 * 8)))[1]), "r"(((unsigned *)(B_shared_warp + (j_0_4 * 8)))[0]), "f"(((float *)(C_warp + ((i_0_3 * 16) + (j_0_4 * 8))))[0]), "f"(((float *)(C_warp + ((i_0_3 * 16) + (j_0_4 * 8))))[1]), "f"(((float *)(C_warp + ((i_0_3 * 16) + (j_0_4 * 8))))[2]), "f"(((float *)(C_warp + ((i_0_3 * 16) + (j_0_4 * 8))))[3])); - } - - { - asm volatile( - "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32" - "{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n" - : "=f"(((float *)(C_warp + ((i_0_3 * 16) + (j_0_4 * 8) + 4)))[0]), "=f"(((float *)(C_warp + ((i_0_3 * 16) + (j_0_4 * 8) + 4)))[1]), "=f"(((float *)(C_warp + ((i_0_3 * 16) + (j_0_4 * 8) + 4)))[2]), "=f"(((float *)(C_warp + ((i_0_3 * 16) + (j_0_4 * 8) + 4)))[3]) - : "r"(((unsigned *)(A_shared_warp + (i_0_3 * 8)))[0]), "r"(((unsigned *)(A_shared_warp + (i_0_3 * 8)))[1]), "r"(((unsigned *)(B_shared_warp + (j_0_4 * 8 + 4)))[0]), "f"(((float *)(C_warp + ((i_0_3 * 16) + (j_0_4 * 8) + 4)))[0]), "f"(((float *)(C_warp + ((i_0_3 * 16) + (j_0_4 * 8) + 4)))[1]), "f"(((float *)(C_warp + ((i_0_3 * 16) + (j_0_4 * 8) + 4)))[2]), "f"(((float *)(C_warp + ((i_0_3 * 16) + (j_0_4 * 8) + 4)))[3])); - } - - { - asm volatile( - "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32" - "{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n" - : "=f"(((float *)(C_warp + ((i_0_3 * 16) + (j_0_4 * 8))))[0]), "=f"(((float *)(C_warp + ((i_0_3 * 16) + (j_0_4 * 8))))[1]), "=f"(((float *)(C_warp + ((i_0_3 * 16) + (j_0_4 * 8))))[2]), "=f"(((float *)(C_warp + ((i_0_3 * 16) + (j_0_4 * 8))))[3]) - : "r"(((unsigned *)(A_shared_warp + (i_0_3 * 8)))[2]), "r"(((unsigned *)(A_shared_warp + (i_0_3 * 8)))[3]), "r"(((unsigned *)(B_shared_warp + (j_0_4 * 8)))[1]), "f"(((float *)(C_warp + ((i_0_3 * 16) + (j_0_4 * 8))))[0]), "f"(((float *)(C_warp + ((i_0_3 * 16) + (j_0_4 * 8))))[1]), "f"(((float *)(C_warp + ((i_0_3 * 16) + (j_0_4 * 8))))[2]), "f"(((float *)(C_warp + ((i_0_3 * 16) + (j_0_4 * 8))))[3])); - } - - { - asm volatile( - "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32" - "{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n" - : "=f"(((float *)(C_warp + ((i_0_3 * 16) + (j_0_4 * 8) + 4)))[0]), "=f"(((float *)(C_warp + ((i_0_3 * 16) + (j_0_4 * 8) + 4)))[1]), "=f"(((float *)(C_warp + ((i_0_3 * 16) + (j_0_4 * 8) + 4)))[2]), "=f"(((float *)(C_warp + ((i_0_3 * 16) + (j_0_4 * 8) + 4)))[3]) - : "r"(((unsigned *)(A_shared_warp + (i_0_3 * 8)))[2]), "r"(((unsigned *)(A_shared_warp + (i_0_3 * 8)))[3]), "r"(((unsigned *)(B_shared_warp + (j_0_4 * 8 + 4)))[1]), "f"(((float *)(C_warp + ((i_0_3 * 16) + (j_0_4 * 8) + 4)))[0]), "f"(((float *)(C_warp + ((i_0_3 * 16) + (j_0_4 * 8) + 4)))[1]), "f"(((float *)(C_warp + ((i_0_3 * 16) + (j_0_4 * 8) + 4)))[2]), "f"(((float *)(C_warp + ((i_0_3 * 16) + (j_0_4 * 8) + 4)))[3])); - } -#else - { - asm volatile( - 
"mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32" - "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, %13};\n" - : "=f"(((float *)(C_warp + ((i_0_3 * 16) + (j_0_4 * 8))))[0]), "=f"(((float *)(C_warp + ((i_0_3 * 16) + (j_0_4 * 8))))[1]), "=f"(((float *)(C_warp + ((i_0_3 * 16) + (j_0_4 * 8))))[2]), "=f"(((float *)(C_warp + ((i_0_3 * 16) + (j_0_4 * 8))))[3]) - : "r"(((unsigned *)(A_shared_warp + (i_0_3 * 8)))[0]), "r"(((unsigned *)(A_shared_warp + (i_0_3 * 8)))[1]), "r"(((unsigned *)(A_shared_warp + (i_0_3 * 8)))[2]), "r"(((unsigned *)(A_shared_warp + (i_0_3 * 8)))[3]), "r"(((unsigned *)(B_shared_warp + (j_0_4 * 8)))[0]), "r"(((unsigned *)(B_shared_warp + (j_0_4 * 8)))[1]), "f"(((float *)(C_warp + ((i_0_3 * 16) + (j_0_4 * 8))))[0]), "f"(((float *)(C_warp + ((i_0_3 * 16) + (j_0_4 * 8))))[1]), "f"(((float *)(C_warp + ((i_0_3 * 16) + (j_0_4 * 8))))[2]), "f"(((float *)(C_warp + ((i_0_3 * 16) + (j_0_4 * 8))))[3])); - } - - { - asm volatile( - "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32" - "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, %13};\n" - : "=f"(((float *)(C_warp + (((i_0_3 * 16) + (j_0_4 * 8)) + 4)))[0]), "=f"(((float *)(C_warp + (((i_0_3 * 16) + (j_0_4 * 8)) + 4)))[1]), "=f"(((float *)(C_warp + (((i_0_3 * 16) + (j_0_4 * 8)) + 4)))[2]), "=f"(((float *)(C_warp + (((i_0_3 * 16) + (j_0_4 * 8)) + 4)))[3]) - : "r"(((unsigned *)(A_shared_warp + (i_0_3 * 8)))[0]), "r"(((unsigned *)(A_shared_warp + (i_0_3 * 8)))[1]), "r"(((unsigned *)(A_shared_warp + (i_0_3 * 8)))[2]), "r"(((unsigned *)(A_shared_warp + (i_0_3 * 8)))[3]), "r"(((unsigned *)(B_shared_warp + ((j_0_4 * 8) + 4)))[0]), "r"(((unsigned *)(B_shared_warp + ((j_0_4 * 8) + 4)))[1]), "f"(((float *)(C_warp + (((i_0_3 * 16) + (j_0_4 * 8)) + 4)))[0]), "f"(((float *)(C_warp + (((i_0_3 * 16) + (j_0_4 * 8)) + 4)))[1]), "f"(((float *)(C_warp + (((i_0_3 * 16) + (j_0_4 * 8)) + 4)))[2]), "f"(((float *)(C_warp + (((i_0_3 * 16) + (j_0_4 * 8)) + 4)))[3])); - } -#endif - } - } - } - } - -// Haotian: Here (May 29 11:46AM PST) -// TODO: Shang: Hoist loop invariance. 
- for (int ax0_0_2 = 0; ax0_0_2 < 4; ++ax0_0_2) { - for (int ax1_0 = 0; ax1_0 < 2; ++ax1_0) { - for (int local_id = 0; local_id < 8; ++local_id) { - int row_offset = (((int)blockIdx_y) / j_factors1) * 128 + (threadIdx.y % 2) * 64 + ax0_0_2 * 16 + (local_id % 4) / 2 * 8 + ((int)threadIdx.x) / 4; - if (row_offset < M) - { - *(C_ptr + ax1_0 * 16 + row_offset * OC + (local_id / 4) * 8 + local_id % 2) = __float2half(C_warp[(ax0_0_2 * 16) + (ax1_0 * 8) + local_id]); - } - } - } - } -} - -// Dequantization to fp16 -// kernel -// Source - https://github.com/compressa-ai/AutoAWQ/blob/6673333456b8871522b11a7fb110de612edfdf95/awq_cuda/quantization/gemm_cuda_gen.cu#L32C1-L32C1 -__global__ void __launch_bounds__(64) dequantize_weights(int* __restrict__ B, // 4096x64 4096 rows 64 cols - half* __restrict__ scaling_factors, // 32x512 32 rows 512 cols - int* __restrict__ zeros, // 32x64 32 rows 64 cols - half* __restrict__ C, // 4096x512 4096 rows 512 cols - int G, - int in_c, - int out_c) -{ - if (blockIdx.z > 0) { - B = B + blockIdx.z * in_c * out_c / 8; - scaling_factors = scaling_factors + blockIdx.z * in_c * out_c / G; - zeros = zeros + blockIdx.z * in_c * out_c / G / 8; - C = C + blockIdx.z * in_c * out_c; - } - int j_factors1 = 4; - int row_stride2 = 4; - int split_k_iters = 1; - static constexpr uint32_t ZERO = 0x0; - half B_shared[32 * (128 + 8)]; - - half* B_shared_ptr2 = B_shared; - - half B_shared_warp[32]; - int OC = 512; - - int N = blockDim.x * gridDim.x; // 2 - int col = (blockIdx.x * blockDim.x + threadIdx.x); - int row = blockIdx.y * blockDim.y + threadIdx.y; - int index1 = 8 * col + 8 * row * N; // + i (<8) - half* C_ptr2 = C + index1; - - int index2 = col + row * N; - int* B_ptr2 = B + index2; - - int index3 = col + (int)(row / G) * N; - int* zeros_ptr2 = zeros + index3; - int index4 = 8 * col + (int)(row / G) * N * 8; // + i (<8) - half* scaling_factors_ptr2 = scaling_factors + index4; - - - uint32_t zeros_loaded = *(uint32_t*)(zeros_ptr2); - uint4 B_loaded_zero = dequantize_s4_to_fp16x2(zeros_loaded); - uint4 B_loaded_scale = *(uint4*)(scaling_factors_ptr2); -int j=0; - - uint32_t B_loaded = *(uint32_t*)(B_ptr2 + j); - uint4 B_loaded_fp16 = dequantize_s4_to_fp16x2(B_loaded); - asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.x) : "r"(B_loaded_fp16.x), "r"(B_loaded_zero.x)); - asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.x) : "r"(B_loaded_fp16.x), "r"(B_loaded_scale.x), "r"(ZERO)); - asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.y) : "r"(B_loaded_fp16.y), "r"(B_loaded_zero.y)); - asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.y) : "r"(B_loaded_fp16.y), "r"(B_loaded_scale.y), "r"(ZERO)); - asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.z) : "r"(B_loaded_fp16.z), "r"(B_loaded_zero.z)); - asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.z) : "r"(B_loaded_fp16.z), "r"(B_loaded_scale.z), "r"(ZERO)); - asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.w) : "r"(B_loaded_fp16.w), "r"(B_loaded_zero.w)); - asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.w) : "r"(B_loaded_fp16.w), "r"(B_loaded_scale.w), "r"(ZERO)); - - *(uint4*)(B_shared_ptr2 + j) = B_loaded_fp16; - - for (int i=0; i<8; ++i) { - *(C_ptr2 + i) = B_shared[i]; - } -} - -template -__global__ void __launch_bounds__(64) group_gemm_forward_4bit_cuda_m16nXk32( - int G, - int split_k_iters, - half* __restrict__ A, - int* __restrict__ B, - half* __restrict__ scaling_factors, - int* __restrict__ zeros, - const 
float* __restrict__ topk_weights, - const int* __restrict__ sorted_token_ids_ptr, - const int* __restrict__ expert_ids_ptr, - const int* __restrict__ num_tokens_post_padded, - const int num_valid_tokens, - const int top_k, - const int expert_num, - int pad_M, - int M, - int IC, - int OC, - half* __restrict__ C) -{ - // Only support matrix n = 64 or 128 - assert(N == 64 || N == 128); -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 750 - assert(false); -#else - int num_tokens = *num_tokens_post_padded; - int j_factors1 = ((OC + N - 1) / N); - int blockIdx_x = 0; - int blockIdx_y = blockIdx.x % ((pad_M + 16 - 1) / 16 * j_factors1); - int blockIdx_z = blockIdx.x / ((pad_M + 16 - 1) / 16 * j_factors1); - int block = blockIdx_y / j_factors1; - if (block * 16 >= num_tokens) return; - - static constexpr uint32_t ZERO = 0x0; - float C_warp[32]; - __shared__ half A_shared[16 * (32 + 8)]; - __shared__ half B_shared[32 * (N + 8)]; - - __shared__ half scaling_factors_shared[N]; - __shared__ half zeros_shared[N]; - - half A_shared_warp[8]; - half B_shared_warp[N / 4]; - for (int j_0_4_init = 0; j_0_4_init < N / 32; ++j_0_4_init) { - for (int i = 0; i < 8; ++i) { - C_warp[(j_0_4_init * 8) + i] = 0.0; - } - } - - static constexpr int row_stride_warp = 32 * 8 / 32; - static constexpr int row_stride = 2 * 32 * 8 / N; - bool ld_zero_flag = (threadIdx.y * 32 + threadIdx.x) * 8 < N; - // TODO: Haotian: blockIdx_y / j_factors1 in A loading to support bsz > 16 - - int row = (block * 16 + threadIdx.y * row_stride_warp + threadIdx.x * 8 / 32); - int token_id = sorted_token_ids_ptr[row]; - bool ld_A_flag = (token_id < num_valid_tokens); - half* A_ptr = A + token_id / top_k * IC + (((int)threadIdx.x) % (32 / 8)) * 8; - - int expert_id = expert_ids_ptr[block]; - B = B + OC * IC / 8 * expert_id; - scaling_factors = scaling_factors + OC * IC / G * expert_id; - zeros = zeros + OC * IC / G / 8 * expert_id; - - int* B_ptr = B - + ((int)threadIdx.y) * (OC / 8) * (256 / N) - + (((int)threadIdx.x) / (N / 8)) * (OC / 8) - + (((int)blockIdx_y) % j_factors1) * (N / 8) - + (((int)threadIdx.x) % (N / 8)) * 1; - // Why * 1 in the above line? - - half* A_shared_ptr = A_shared - + ((int)threadIdx.y) * row_stride_warp * (32 + 8) - + (((int)threadIdx.x) / (32 / 8)) * (32 + 8) - + (((int)threadIdx.x) % (32 / 8) ) * 8; - - half* B_shared_ptr = B_shared - + ((int)threadIdx.y) * (row_stride / 2) * (N + 8) - + (((int)threadIdx.x) / (N / 8)) * (N + 8) - + (((int)threadIdx.x) % (N / 8)) * 8; - - int* zeros_ptr = zeros - + (((int)blockIdx_y) % j_factors1) * (N / 8) - + ((int)threadIdx.x) % (N / 8); - - half* scaling_factors_ptr = scaling_factors - + (((int)blockIdx_y) % j_factors1) * N - + (((int)threadIdx.x) % (N / 8)) * 8; - - half* C_ptr = C - + static_cast(blockIdx_z) * M * OC * expert_num // blockIdz.x -> split_k dim - + (((int)blockIdx_y) % j_factors1) * N - + ((int)threadIdx.y) * (N / 2) - + (((int)threadIdx.x) % 4) * 2; - - // preload s.f. 
and zeros - int k_bound = (IC / 32 + split_k_iters - 1) / split_k_iters; - if ((k_bound - 1) * split_k_iters * 32 + blockIdx_z * 32 >= IC) k_bound -= 1; - for (int _k_0_0 = 0; _k_0_0 < k_bound; ++_k_0_0) { - int k_0_0 = _k_0_0 * split_k_iters + blockIdx_z; - __syncthreads(); - // TODO: Haotian: blockIdx_y / j_factors1 in A loading to support bsz > 16 - if (ld_A_flag) - { - *(uint4*)(A_shared_ptr) = *(uint4*)(A_ptr + (k_0_0 * 32)); - } - else - { - *(uint4*)(A_shared_ptr) = make_uint4(0, 0, 0, 0); - } - - uint32_t zeros_loaded = *(uint32_t*)(zeros_ptr + k_0_0 * 32 / G * (OC / 8)); - uint4 B_loaded_zero = dequantize_s4_to_fp16x2(zeros_loaded); - uint4 B_loaded_scale = *(uint4*)(scaling_factors_ptr + k_0_0 * 32 / G * (OC)); - - int* B_ptr_local = B_ptr + k_0_0 * 32 * (OC / 8); - - for (int ax0_ax1_fused_0 = 0; ax0_ax1_fused_0 < N / 16; ++ax0_ax1_fused_0) { - - uint32_t B_loaded = *(uint32_t*)(B_ptr_local + ax0_ax1_fused_0 * row_stride * (OC / 8)); - uint4 B_loaded_fp16 = dequantize_s4_to_fp16x2(B_loaded); - - // TODO (Haotian): can save 4 assembly instructions if sormulate as deq = q * scale - zero * scale. - asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.x) : "r"(B_loaded_fp16.x), "r"(B_loaded_zero.x)); - asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.x) : "r"(B_loaded_fp16.x), "r"(B_loaded_scale.x), "r"(ZERO)); - asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.y) : "r"(B_loaded_fp16.y), "r"(B_loaded_zero.y)); - asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.y) : "r"(B_loaded_fp16.y), "r"(B_loaded_scale.y), "r"(ZERO)); - asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.z) : "r"(B_loaded_fp16.z), "r"(B_loaded_zero.z)); - asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.z) : "r"(B_loaded_fp16.z), "r"(B_loaded_scale.z), "r"(ZERO)); - asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.w) : "r"(B_loaded_fp16.w), "r"(B_loaded_zero.w)); - asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.w) : "r"(B_loaded_fp16.w), "r"(B_loaded_scale.w), "r"(ZERO)); - - // write back - *(uint4*)(B_shared_ptr + ax0_ax1_fused_0 * row_stride * (N + 8)) = B_loaded_fp16; - } - __syncthreads(); - - for (int k_0_1 = 0; k_0_1 < 2; ++k_0_1) { - { - unsigned int addr; - asm volatile( - "{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; }\n" - : "=r"(addr) - : "l"((void *)((&(A_shared[(k_0_1 * 16)])) + (((((int)threadIdx.x) & 15) * 40) + ((((int)threadIdx.x) >> 4) * 8)))) - ); - - - asm volatile( - "ldmatrix.sync.aligned.m8n8.x4.shared.b16" - "{%0, %1, %2, %3}, [%4];\n" - : "=r"(((unsigned *)(A_shared_warp + 0))[0]), "=r"(((unsigned *)(A_shared_warp + 0))[1]), "=r"(((unsigned *)(A_shared_warp + 0))[2]), "=r"(((unsigned *)(A_shared_warp + 0))[3]) - : "r"(addr) - ); - } - - for (int ax1_0 = 0; ax1_0 < N / 32; ++ax1_0) { - { - unsigned int addr; - asm volatile( - "{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; }\n" - : "=r"(addr) - : "l"((void *)((&(B_shared[(((k_0_1 * (N * 16 + 128)) + (((int)threadIdx.y) * (N / 2))) + (ax1_0 * 16))])) + (((((int)threadIdx.x) & 15) * (N + 8)) + ((((int)threadIdx.x) >> 4) * 8)))) - ); - asm volatile( - "ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16" - "{%0, %1, %2, %3}, [%4];\n" - : "=r"(((unsigned *)(B_shared_warp + (ax1_0 * 8)))[0]), "=r"(((unsigned *)(B_shared_warp + (ax1_0 * 8)))[1]), "=r"(((unsigned *)(B_shared_warp + (ax1_0 * 8)))[2]), "=r"(((unsigned *)(B_shared_warp + (ax1_0 * 8)))[3]) - : "r"(addr) - ); - } - } - for (int j_0_4 
= 0; j_0_4 < N / 32; ++j_0_4) { -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ == 750 - { - asm volatile( - "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32" - "{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n" - : "=f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[3]) - : "r"(((unsigned *)(A_shared_warp + 0))[0]), "r"(((unsigned *)(A_shared_warp + 0))[1]), "r"(((unsigned *)(B_shared_warp + (j_0_4 * 8)))[0]), "f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "f"(((float *)(C_warp + (j_0_4 * 8)))[3])); - } - - { - asm volatile( - "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32" - "{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n" - : "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3]) - : "r"(((unsigned *)(A_shared_warp + 0))[0]), "r"(((unsigned *)(A_shared_warp + 0))[1]), "r"(((unsigned *)(B_shared_warp + ((j_0_4 * 8) + 4)))[0]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3])); - } - - { - asm volatile( - "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32" - "{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n" - : "=f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[3]) - : "r"(((unsigned *)(A_shared_warp + 0))[2]), "r"(((unsigned *)(A_shared_warp + 0))[3]), "r"(((unsigned *)(B_shared_warp + (j_0_4 * 8)))[1]), "f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "f"(((float *)(C_warp + (j_0_4 * 8)))[3])); - } - - { - asm volatile( - "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32" - "{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n" - : "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3]) - : "r"(((unsigned *)(A_shared_warp + 0))[2]), "r"(((unsigned *)(A_shared_warp + 0))[3]), "r"(((unsigned *)(B_shared_warp + ((j_0_4 * 8) + 4)))[1]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3])); - } -#else - { - asm volatile( - "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32" - "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, %13};\n" - : "=f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[3]) - : "r"(((unsigned *)(A_shared_warp + 0))[0]), "r"(((unsigned *)(A_shared_warp + 0))[1]), "r"(((unsigned *)(A_shared_warp + 0))[2]), "r"(((unsigned *)(A_shared_warp + 0))[3]), "r"(((unsigned *)(B_shared_warp + (j_0_4 * 8)))[0]), "r"(((unsigned *)(B_shared_warp + (j_0_4 * 8)))[1]), "f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "f"(((float *)(C_warp + (j_0_4 * 8)))[3])); - } - - { - asm volatile( - 
"mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32" - "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, %13};\n" - : "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3]) - : "r"(((unsigned *)(A_shared_warp + 0))[0]), "r"(((unsigned *)(A_shared_warp + 0))[1]), "r"(((unsigned *)(A_shared_warp + 0))[2]), "r"(((unsigned *)(A_shared_warp + 0))[3]), "r"(((unsigned *)(B_shared_warp + ((j_0_4 * 8) + 4)))[0]), "r"(((unsigned *)(B_shared_warp + ((j_0_4 * 8) + 4)))[1]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3])); - } - -#endif - } - } - } - -// TODO: Shang: Hoist loop invariance. - for (int ax1_0_1 = 0; ax1_0_1 < N / 32; ++ax1_0_1) { - for (int local_id = 0; local_id < 8; ++local_id) { - int row_offset = block * 16 + ((int)threadIdx.x) / 4 + (local_id % 4) / 2 * 8; - int token_id = sorted_token_ids_ptr[row_offset]; - if (token_id < num_valid_tokens) - { - float value = C_warp[(ax1_0_1 * 8) + local_id]; - if (topk_weights) { - value = value * topk_weights[token_id]; - } - *(C_ptr + ax1_0_1 * 16 + token_id * OC + (local_id / 4) * 8 + local_id % 2) = __float2half(value); - } - } - } -#endif -} - - -torch::Tensor grouped_gemm_forward( - torch::Tensor _in_feats, - torch::Tensor _kernel, - torch::Tensor _scaling_factors, - torch::Tensor _zeros, - torch::Tensor _topk_weights, - torch::Tensor _sorted_token_ids_ptr, - torch::Tensor _expert_ids_ptr, - torch::Tensor _num_tokens_post_padded, - bool mul_weights, - int split_k_iters) -{ - int num_in_feats = _in_feats.size(0); - int pad_num_in_feats = _sorted_token_ids_ptr.size(0); - int num_in_channels = _in_feats.size(2); - const at::cuda::OptionalCUDAGuard device_guard(device_of(_in_feats)); - - auto options = torch::TensorOptions().dtype(_in_feats.dtype()).device(_in_feats.device()); - int num_experts = _topk_weights.size(1); - int top_k = num_experts / _in_feats.size(1); - int group_size = num_in_channels / _scaling_factors.size(1); - - at::Tensor _out_feats = torch::empty({split_k_iters, num_in_feats, _topk_weights.size(1), _kernel.size(2) * 8}, options); - int num_out_channels = _out_feats.size(-1); - - auto in_feats = reinterpret_cast(_in_feats.data_ptr()); - auto kernel = reinterpret_cast(_kernel.data_ptr()); - auto out_feats = reinterpret_cast(_out_feats.data_ptr()); - auto scaling_factors = reinterpret_cast(_scaling_factors.data_ptr()); - auto zeros = reinterpret_cast(_zeros.data_ptr()); - auto topk_weights = mul_weights ? 
reinterpret_cast(_topk_weights.data_ptr()) : nullptr; - auto sorted_token_ids_ptr = reinterpret_cast(_sorted_token_ids_ptr.data_ptr()); - auto expert_ids_ptr = reinterpret_cast(_expert_ids_ptr.data_ptr()); - auto num_tokens_post_padded = reinterpret_cast(_num_tokens_post_padded.data_ptr()); - - const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); - if (num_out_channels % 128 == 0) - { - int j_factors1 = num_out_channels / 128 / 1; - dim3 num_blocks((pad_num_in_feats + 16 - 1) / 16 * j_factors1 * split_k_iters); - // threadIdx.x: 32 - // threadIdx.y: i_factors[2] * j_factors[2] - dim3 threads_per_block(32, 2); - group_gemm_forward_4bit_cuda_m16nXk32<128><<>>( - group_size, split_k_iters, in_feats, kernel, scaling_factors, zeros, - topk_weights, sorted_token_ids_ptr, expert_ids_ptr, num_tokens_post_padded, - _topk_weights.numel(), top_k, num_experts, pad_num_in_feats, - num_in_feats, num_in_channels, num_out_channels, out_feats); - } - else if (num_out_channels % 64 == 0) - { - int j_factors1 = num_out_channels / 64 / 1; - dim3 num_blocks((pad_num_in_feats + 16 - 1) / 16 * j_factors1 * split_k_iters); - - // threadIdx.x: 32 - // threadIdx.y: i_factors[2] * j_factors[2] - dim3 threads_per_block(32, 2); - group_gemm_forward_4bit_cuda_m16nXk32<64><<>>( - group_size, split_k_iters, in_feats, kernel, scaling_factors, zeros, - topk_weights, sorted_token_ids_ptr, expert_ids_ptr, num_tokens_post_padded, - _topk_weights.numel(), top_k, num_experts, pad_num_in_feats, - num_in_feats, num_in_channels, num_out_channels, out_feats); - } - return _out_feats.sum(0); -} - -// Dequantization to fp16 -// Source - https://github.com/compressa-ai/AutoAWQ/blob/6673333456b8871522b11a7fb110de612edfdf95/awq_cuda/quantization/gemm_cuda_gen.cu#L935C1-L987C2 -torch::Tensor dequantize_weights_cuda( - torch::Tensor _kernel, - torch::Tensor _scaling_factors, - torch::Tensor _zeros, - int split_k_iters, - int thx, - int thy, - bool dbg) -{ - int in_c = _kernel.dim() == 2 ? _kernel.size(0) : _kernel.size(1); - int qout_c = _kernel.dim() == 2 ? _kernel.size(1) : _kernel.size(2); - int num_experts = _kernel.dim() == 2 ? 1 : _kernel.size(0); - int out_c = qout_c * 8; - int G = in_c / (_kernel.dim() == 2 ? 
_scaling_factors.size(0) : _scaling_factors.size(1)); - - int x_thread = thx; - int y_thread = thy; - - int x_blocks = 1; - int y_blocks = 1; - if (thx==0) { - x_thread = qout_c; - } - if (thy==0) { - y_thread = in_c; - } - int dbg_ = true; - if (thx==0 && thy==0) { - dbg_ = false; - x_thread = 8; - y_thread = 8; - x_blocks = (int)(qout_c / 8); - y_blocks = (int)(in_c / 8); - } - dbg = dbg && dbg_; - - const at::cuda::OptionalCUDAGuard device_guard(device_of(_scaling_factors)); - - auto options = torch::TensorOptions().dtype(_scaling_factors.dtype()).device(_scaling_factors.device()); - at::Tensor _de_kernel; - if (num_experts == 1) { - _de_kernel = torch::empty({in_c, out_c}, options); - } else { - _de_kernel = torch::empty({num_experts, in_c, out_c}, options); - } - - auto kernel = reinterpret_cast(_kernel.data_ptr()); - auto de_kernel = reinterpret_cast(_de_kernel.data_ptr()); - auto scaling_factors = reinterpret_cast(_scaling_factors.data_ptr()); - auto zeros = reinterpret_cast(_zeros.data_ptr()); - - dim3 num_blocks(x_blocks, y_blocks, num_experts); - dim3 threads_per_block(x_thread, y_thread); // col, row 64x4096 - - dequantize_weights<<>>(kernel, scaling_factors, zeros, de_kernel, G, in_c, out_c); - - return _de_kernel; -} - -// in_feats: M, IC [float16] -// kernel: IC, OC // 8 [int32] -> cast to IC, OC [uint4b] -// scaling_factors: IC // G, OC [float16] -// zeros: IC // G, OC // 8 [int32] -> cast to IC // G, OC [uint4b] -// assume that batch_size < 16 for now - -torch::Tensor gemmv2_forward_cuda( - torch::Tensor _in_feats, - torch::Tensor _kernel, - torch::Tensor _scaling_factors, - torch::Tensor _zeros, - int group_size, - int split_k_iters) -{ - int num_in_feats = _in_feats.size(0); - int num_in_channels = _in_feats.size(1); - const at::cuda::OptionalCUDAGuard device_guard(device_of(_in_feats)); - - auto options = torch::TensorOptions().dtype(_in_feats.dtype()).device(_in_feats.device()); - // for int4, need _kernel.size(1) * 8 - at::Tensor _out_feats = torch::empty({split_k_iters, num_in_feats, _kernel.size(0)}, options); - int num_out_feats = _out_feats.size(-2); - int num_out_channels = _out_feats.size(-1); - - auto in_feats = reinterpret_cast(_in_feats.data_ptr()); - auto kernel = reinterpret_cast(_kernel.data_ptr()); - auto out_feats = reinterpret_cast(_out_feats.data_ptr()); - auto scaling_factors = reinterpret_cast(_scaling_factors.data_ptr()); - auto zeros = reinterpret_cast(_zeros.data_ptr()); - - // blockIdx_x: i_factors[0] * j_factors[0] - // blockIdx_y: i_factors[1] * j_factors[1] - - if (num_out_channels % 64 != 0) - throw std::invalid_argument("OC is not multiple of cta_N = 64"); - if (num_out_channels % 8 != 0) - throw std::invalid_argument("OC is not multiple of pack_num = 8"); - int j_factors1 = num_out_channels / 64 / 1; - dim3 num_blocks((num_out_feats + 128 - 1) / 128 * j_factors1 * split_k_iters); - - // threadIdx.x: 32 - // threadIdx.y: i_factors[2] * j_factors[2] - dim3 threads_per_block(32, 4); - if (group_size == 128) - { - gemmv2_forward_4bit_cuda_m128n64k32<128><<>>( - split_k_iters, in_feats, kernel, scaling_factors, zeros, num_in_feats, num_in_channels, num_out_channels, out_feats); - } - else if (group_size == 64) - { - gemmv2_forward_4bit_cuda_m128n64k32<64><<>>( - split_k_iters, in_feats, kernel, scaling_factors, zeros, num_in_feats, num_in_channels, num_out_channels, out_feats); - } - else - { - throw std::invalid_argument("Group size temporarily not supported."); - } - return _out_feats.sum(0); -} - -// in_feats: M, IC [float16] -// kernel: IC, 
OC // 8 [int32] -> cast to IC, OC [uint4b] -// scaling_factors: IC // G, OC [float16] -// zeros: IC // G, OC // 8 [int32] -> cast to IC // G, OC [uint4b] -// assume that batch_size < 16 for now - -torch::Tensor gemm_forward_cuda( - torch::Tensor _in_feats, - torch::Tensor _kernel, - torch::Tensor _scaling_factors, - torch::Tensor _zeros, - int split_k_iters) -{ - int num_in_feats = _in_feats.size(0); - int num_in_channels = _in_feats.size(1); - const at::cuda::OptionalCUDAGuard device_guard(device_of(_in_feats)); - - auto options = torch::TensorOptions().dtype(_in_feats.dtype()).device(_in_feats.device()); - at::Tensor _out_feats = torch::empty({split_k_iters, num_in_feats, _kernel.size(1) * 8}, options); - int num_out_feats = _out_feats.size(-2); - int num_out_channels = _out_feats.size(-1); - - auto in_feats = reinterpret_cast(_in_feats.data_ptr()); - auto kernel = reinterpret_cast(_kernel.data_ptr()); - auto out_feats = reinterpret_cast(_out_feats.data_ptr()); - auto scaling_factors = reinterpret_cast(_scaling_factors.data_ptr()); - auto zeros = reinterpret_cast(_zeros.data_ptr()); - int group_size = num_in_channels / _scaling_factors.size(0); - const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); - - if (num_out_channels % 64 != 0) - throw std::invalid_argument("OC is not multiple of cta_N = 64"); - if (num_out_channels % 8 != 0) - throw std::invalid_argument("OC is not multiple of pack_num = 8"); - if (group_size % 32 != 0) - throw std::invalid_argument("Group size should be a multiple of 32"); - if (num_out_channels % group_size != 0) - throw std::invalid_argument("OC is not multiple of Group size"); - - if (num_out_channels % 128 == 0) - { - int j_factors1 = num_out_channels / 128 / 1; - dim3 num_blocks((num_out_feats + 16 - 1) / 16 * j_factors1 * split_k_iters); - // threadIdx.x: 32 - // threadIdx.y: i_factors[2] * j_factors[2] - dim3 threads_per_block(32, 2); - gemm_forward_4bit_cuda_m16n128k32<<>>( - group_size, split_k_iters, in_feats, kernel, scaling_factors, zeros, num_in_feats, num_in_channels, num_out_channels, out_feats); - } - else if (num_out_channels % 64 == 0) - { - int j_factors1 = num_out_channels / 64 / 1; - dim3 num_blocks(1 * (num_out_feats + 16 - 1) / 16 * j_factors1 * split_k_iters); - - // threadIdx.x: 32 - // threadIdx.y: i_factors[2] * j_factors[2] - dim3 threads_per_block(32, 2); - gemm_forward_4bit_cuda_m16n64k32<<>>( - group_size, split_k_iters, in_feats, kernel, scaling_factors, zeros, num_in_feats, num_in_channels, num_out_channels, out_feats); - } - return _out_feats.sum(0); -} diff --git a/gptqmodel_ext/awq/quantization/gemv_cuda.cu b/gptqmodel_ext/awq/quantization/gemv_cuda.cu deleted file mode 100644 index d4a26a066..000000000 --- a/gptqmodel_ext/awq/quantization/gemv_cuda.cu +++ /dev/null @@ -1,249 +0,0 @@ -// Inspired by https://github.com/ankan-ban/llama_cu_awq -/* - -@article{lin2023awq, - title={AWQ: Activation-aware Weight Quantization for LLM Compression and Acceleration}, - author={Lin, Ji and Tang, Jiaming and Tang, Haotian and Yang, Shang and Dang, Xingyu and Han, Song}, - journal={arXiv}, - year={2023} -} - -*/ - -#include -#include -#include -#include -#include "gemv_cuda.h" -#define VECTORIZE_FACTOR 8 -#define Q_VECTORIZE_FACTOR 8 -#define PACK_FACTOR 8 -#define WARP_SIZE 32 - - -// Reduce sum within the warp using the tree reduction algorithm. 
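For context on the reduction that the comment above names and that the GEMV kernels in this deleted file rely on: each thread accumulates a partial dot product in psum, and the warp then folds those 32 partials together with shuffle instructions before lane 0 writes the output element. A minimal sketch of that shuffle-based tree reduction follows (warp_reduce_sum_sketch is illustrative only, not part of the extension; a 32-lane warp is assumed, matching the WARP_SIZE define above):

// Sketch of the warp tree reduction: each step folds lanes offset by
// 16, 8, 4, 2, 1 (i.e. 1 << i for i = 4..0), so after five steps lane 0
// holds the sum over all 32 lanes of the warp.
__device__ __forceinline__ float warp_reduce_sum_sketch(float sum) {
#pragma unroll
    for (int offset = 16; offset > 0; offset >>= 1) {
        sum += __shfl_down_sync(0xffffffff, sum, offset);
    }
    return sum;  // complete only in lane 0; other lanes hold partial sums
}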
-__device__ __forceinline__ float warp_reduce_sum(float sum) { - #pragma unroll - for(int i = 4; i >= 0; i--){ - sum += __shfl_down_sync(0xffffffff, sum, 1<(zeros + oc_idx * zeros_w + packed_group_idx * 2); - uint32_t packed_weights[4]; - // use float4 to load weights, each thread load 32 int4 numbers (1 x float4) - *((float4*)(packed_weights)) = *((float4*)(weight + oc_idx * weight_w + packed_group_idx * (WARP_SIZE * 4) + threadIdx.x * 4)); - // load scaling factors - // g64: two threads -> 64 numbers -> 1 group; 1 warp = 16 groups. - float scaling_factor = __half2float(scaling_factors[oc_idx * sf_w + packed_group_idx * 16 + (threadIdx.x / 2)]); - float current_zeros = (float)((packed_zeros >> (threadIdx.x / 2 * 4)) & 0xF); - int inputs_ptr_delta = packed_group_idx * WARP_SIZE * 4 + threadIdx.x * 4; - const float4* inputs_ptr = inputs + inputs_ptr_delta; - // multiply 32 weights with 32 inputs - #pragma unroll - for (int ic_0 = 0; ic_0 < 4; ic_0++){ - // iterate over different uint32_t packed_weights in this loop - uint32_t current_packed_weight = packed_weights[ic_0]; - half packed_inputs[PACK_FACTOR]; - // each thread load 8 inputs, starting index is packed_group_idx * 128 * 8 (because each iter loads 128*8) - if (inputs_ptr_delta + ic_0 < IC / PACK_FACTOR) { - *((float4*)packed_inputs) = *(inputs_ptr + ic_0); - #pragma unroll - for (int ic_1 = 0; ic_1 < PACK_FACTOR; ic_1++){ - // iterate over 8 numbers packed within each uint32_t number - float current_single_weight_fp = (float)(current_packed_weight & 0xF); - float dequantized_weight = scaling_factor * (current_single_weight_fp - current_zeros); - //if(blockIdx.x == 0 && blockIdx.y == 0 && threadIdx.x == 0 && threadIdx.y == 0 && ic_0 == 0 && ic_1 == 0 && packed_group_idx == 0) printf("%f %f %f %f %X %X\n", dequantized_weight, current_single_weight_fp, scaling_factor, current_zeros, current_packed_weight, packed_zeros); - psum += dequantized_weight * __half2float(packed_inputs[ic_1]); - current_packed_weight = current_packed_weight >> 4; - } - } - } - } - psum = warp_reduce_sum(psum); - if (threadIdx.x == 0) { - outputs[oc_idx] = __float2half(psum); - } -} - - -/* -Computes GEMV (group_size = 128). - -Args: - inputs: vector of shape [batch_size, IC]; - weight: matrix of shape [OC, IC / 8]; - output: vector of shape [OC]; - zeros: matrix of shape [OC, IC / group_size / 8]; - scaling_factors: matrix of shape [OC, IC / group_size]; - -Notes: - One cannot infer group_size from the shape of scaling factors. - the second dimension is rounded up to a multiple of PACK_FACTOR. 
-*/ -__global__ void gemv_kernel_g128( - const float4* _inputs, const uint32_t* weight, const uint32_t* zeros, const half* scaling_factors, half* _outputs, - const int IC, const int OC){ - const int group_size = 128; - float psum = 0; - const int batch_idx = blockIdx.z; - const int oc_idx = blockIdx.y * blockDim.y + threadIdx.y; - const float4* inputs = _inputs + batch_idx * IC / PACK_FACTOR; - half* outputs = _outputs + batch_idx * OC; - const int num_groups_packed = make_divisible(IC / group_size, PACK_FACTOR); - const int weight_w = IC / PACK_FACTOR; - // TODO (Haotian): zeros_w is incorrect, after fixing we got misaligned address - const int zeros_w = make_divisible(IC / group_size, PACK_FACTOR); - // consistent with input shape - const int sf_w = make_divisible(IC / group_size, PACK_FACTOR) * PACK_FACTOR; - //if(blockIdx.x == 0 && blockIdx.y == 0 && threadIdx.x == 0 && threadIdx.y == 0) printf("%d %d %d %d\n", IC, group_size, PACK_FACTOR, zeros_w); - // tile size: 4 OC x 1024 IC per iter - for(int packed_group_idx = 0; packed_group_idx < num_groups_packed; packed_group_idx++){ - // 1024 numbers in one iteration across warp. Need 1024 / group_size zeros. - uint32_t packed_zeros = *(zeros + oc_idx * zeros_w + packed_group_idx); - uint32_t packed_weights[4]; - // use float4 to load weights, each thread load 32 int4 numbers (1 x float4) - *((float4*)(packed_weights)) = *((float4*)(weight + oc_idx * weight_w + packed_group_idx * (WARP_SIZE * 4) + threadIdx.x * 4)); - // load scaling factors - // g128: four threads -> 128 numbers -> 1 group; 1 warp = 8 groups. - float scaling_factor = __half2float(scaling_factors[oc_idx * sf_w + packed_group_idx * 8 + (threadIdx.x / 4)]); - float current_zeros = (float)((packed_zeros >> (threadIdx.x / 4 * 4)) & 0xF); - int inputs_ptr_delta = packed_group_idx * WARP_SIZE * 4 + threadIdx.x * 4; - const float4* inputs_ptr = inputs + inputs_ptr_delta; - // multiply 32 weights with 32 inputs - #pragma unroll - for (int ic_0 = 0; ic_0 < 4; ic_0++){ - // iterate over different uint32_t packed_weights in this loop - uint32_t current_packed_weight = packed_weights[ic_0]; - half packed_inputs[PACK_FACTOR]; - // each thread load 8 inputs, starting index is packed_group_idx * 128 * 8 (because each iter loads 128*8) - if (inputs_ptr_delta + ic_0 < IC / PACK_FACTOR) { - *((float4*)packed_inputs) = *(inputs_ptr + ic_0); - #pragma unroll - for (int ic_1 = 0; ic_1 < PACK_FACTOR; ic_1++){ - // iterate over 8 numbers packed within each uint32_t number - float current_single_weight_fp = (float)(current_packed_weight & 0xF); - float dequantized_weight = scaling_factor * (current_single_weight_fp - current_zeros); - //if(blockIdx.x == 0 && blockIdx.y == 0 && threadIdx.x == 0 && threadIdx.y == 0 && ic_0 == 0 && ic_1 == 0 && packed_group_idx == 0) printf("%f %f %f %f %X %X\n", dequantized_weight, current_single_weight_fp, scaling_factor, current_zeros, current_packed_weight, packed_zeros); - psum += dequantized_weight * __half2float(packed_inputs[ic_1]); - current_packed_weight = current_packed_weight >> 4; - } - } - } - } - psum = warp_reduce_sum(psum); - if (threadIdx.x == 0) { - outputs[oc_idx] = __float2half(psum); - } -} - - -/* -Computes GEMV (PyTorch interface). 
- -Args: - _in_feats: tensor of shape [B, IC]; - _kernel: int tensor of shape [OC, IC // 8]; - _zeros: int tensor of shape [OC, IC // G // 8]; - _scaling_factors: tensor of shape [OC, IC // G]; - blockDim_x: size of thread block, dimension x, where blockDim_x * workload_per_thread = IC; - blockDim_y: size of thread block, dimension y, where blockDim_y * gridDim_y = OC; - -Returns: - out_feats: tensor of shape [B, OC]; -*/ -torch::Tensor gemv_forward_cuda( - torch::Tensor _in_feats, - torch::Tensor _kernel, - torch::Tensor _scaling_factors, - torch::Tensor _zeros, - int group_size) -{ - int num_in_feats = _in_feats.size(0); - int num_in_channels = _in_feats.size(1); - // int kernel_volume = _out_in_map.size(1); - auto in_feats = reinterpret_cast(_in_feats.data_ptr()); - auto kernel = reinterpret_cast(_kernel.data_ptr()); - auto zeros = reinterpret_cast(_zeros.data_ptr()); - auto scaling_factors = reinterpret_cast(_scaling_factors.data_ptr()); - // auto out_in_map = _out_in_map.data_ptr(); - auto options = - torch::TensorOptions().dtype(_in_feats.dtype()).device(_in_feats.device()); - // kernel is [OC, IC] - at::Tensor _out_feats = torch::empty({num_in_feats, _kernel.size(0)}, options); - int num_out_feats = _out_feats.size(-2); - int num_out_channels = _out_feats.size(-1); - auto out_feats = reinterpret_cast(_out_feats.data_ptr()); - int blockDim_z = num_out_feats; - dim3 num_blocks(1, num_out_channels / 4, num_out_feats); - dim3 num_threads(32, 4); - const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); - if (group_size == 64) - { - gemv_kernel_g64<<>>( - // pointers - in_feats, kernel, zeros, scaling_factors, out_feats, - // constants - num_in_channels, num_out_channels - ); - } - else if (group_size == 128) - { - gemv_kernel_g128<<>>( - // pointers - in_feats, kernel, zeros, scaling_factors, out_feats, - // constants - num_in_channels, num_out_channels - ); - } - return _out_feats; -;} - diff --git a/gptqmodel_ext/awq/quantization/gemv_cuda.h b/gptqmodel_ext/awq/quantization/gemv_cuda.h deleted file mode 100644 index 748abc5d1..000000000 --- a/gptqmodel_ext/awq/quantization/gemv_cuda.h +++ /dev/null @@ -1,9 +0,0 @@ -#pragma once -#include - -torch::Tensor gemv_forward_cuda( - torch::Tensor _in_feats, - torch::Tensor _kernel, - torch::Tensor _scaling_factors, - torch::Tensor _zeros, - int group_size); diff --git a/gptqmodel_ext/awq/quantization_new/dequantize.cuh b/gptqmodel_ext/awq/quantization_new/dequantize.cuh index fa02fb771..9917ec949 100644 --- a/gptqmodel_ext/awq/quantization_new/dequantize.cuh +++ b/gptqmodel_ext/awq/quantization_new/dequantize.cuh @@ -9,9 +9,14 @@ Modified from NVIDIA FasterTransformer: https://github.com/NVIDIA/FasterTransfor } */ #include +#include #pragma once -__inline__ __device__ void dequantize_s4_to_fp16x2(half2 const &source, uint4 *result) +template +__inline__ __device__ void dequantize_s4_to_fp16x2(half2 const &source, uint4 *result); + +template <> +__inline__ __device__ void dequantize_s4_to_fp16x2(half2 const &source, uint4 *result) { // uint4 result; @@ -19,10 +24,10 @@ __inline__ __device__ void dequantize_s4_to_fp16x2(half2 const &source, uint4 *r uint32_t const i4s = reinterpret_cast(source); // First, we extract the i4s and construct an intermediate fp16 number. 
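The constants right below implement the usual exponent-bias trick: OR-ing a 4-bit nibble q into the mantissa of the fp16 constant 0x6400 (which encodes 1024.0, see I4s_TO_F16s_MAGIC_NUM) yields exactly the fp16 value 1024 + q, so the later sub.f16x2 against the 1024 bias recovers q, while the TOP_MASK lanes hold 1024 + 16*q, which the fma.rn.f16x2 with ONE_SIXTEENTH and NEG_64 maps back to q since (1024 + 16q)/16 - 64 = q. The new bf16 specialization plays the same game with 0x4300 (128.0). A host-side scalar illustration, compilable on its own (fp16_bits_to_float is a hypothetical helper written out by hand, not an API of this extension):

#include <cmath>
#include <cstdint>
#include <cstdio>

// Decodes a positive, normal fp16 bit pattern; enough for this demo,
// where every value is of the form 0x6400 | q with q in [0, 15].
static float fp16_bits_to_float(uint16_t h) {
    int exponent = (h >> 10) & 0x1f;   // 5 exponent bits, bias 15
    int mantissa = h & 0x3ff;          // 10 mantissa bits
    return std::ldexp(1.0f + mantissa / 1024.0f, exponent - 15);
}

int main() {
    for (unsigned q = 0; q < 16; ++q) {
        uint16_t biased = static_cast<uint16_t>(0x6400u | q); // (q & 0xF) | fp16 magic
        float dq = fp16_bits_to_float(biased) - 1024.0f;      // subtract the 1024.0 bias
        std::printf("q=%2u -> %g\n", q, dq);                  // prints q exactly
    }
    return 0;
}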
- static constexpr uint32_t immLut = (0xf0 & 0xcc) | 0xaa; - static constexpr uint32_t BOTTOM_MASK = 0x000f000f; - static constexpr uint32_t TOP_MASK = 0x00f000f0; - static constexpr uint32_t I4s_TO_F16s_MAGIC_NUM = 0x64006400; + constexpr uint32_t immLut = (0xf0 & 0xcc) | 0xaa; + constexpr uint32_t BOTTOM_MASK = 0x000f000f; + constexpr uint32_t TOP_MASK = 0x00f000f0; + constexpr uint32_t I4s_TO_F16s_MAGIC_NUM = 0x64006400; // Note that the entire sequence only requires 1 shift instruction. This is thanks to the register packing // format and the fact that we force our integers to be unsigned, and account for this in the fp16 subtractions. @@ -72,6 +77,47 @@ __inline__ __device__ void dequantize_s4_to_fp16x2(half2 const &source, uint4 *r asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[2]) : "r"(h[2]), "r"(FP16_TOP_MAGIC_NUM)); // Convert elt_67 asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(h[3]) : "r"(h[3]), "r"(ONE_SIXTEENTH), "r"(NEG_64)); +} + +template <> +__inline__ __device__ void dequantize_s4_to_fp16x2(half2 const &source, uint4 *result) +{ + // uint4 result; + + uint32_t *h = reinterpret_cast(result); + uint32_t const i4s = reinterpret_cast(source); + + // First, we extract the i4s and construct an intermediate bf16 number. + constexpr uint32_t immLut = (0xf0 & 0xcc) | 0xaa; + constexpr uint32_t BOTTOM_MASK = 0x000f000f; + constexpr uint32_t I4s_TO_BF16s_MAGIC_NUM = 0x43004300; + + // Shift right by 4, 8, 12 to consider elt_23, elt_45 and elt_67. + const uint32_t i4s1 = i4s >> 4; + const uint32_t i4s2 = i4s >> 8; + const uint32_t i4s3 = i4s >> 12; + // Extract elt_01 - (i4s & 0x000f000f) | 0x43004300 + asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" + : "=r"(h[0]) + : "r"(i4s), "n"(BOTTOM_MASK), "n"(I4s_TO_BF16s_MAGIC_NUM), "n"(immLut)); + // Extract elt_23 - (i4s & 0x000f000f) | 0x43004300 + asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" + : "=r"(h[1]) + : "r"(i4s1), "n"(BOTTOM_MASK), "n"(I4s_TO_BF16s_MAGIC_NUM), "n"(immLut)); + // Extract elt_45 - (i4s & 0x000f000f) | 0x43004300 + asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" + : "=r"(h[2]) + : "r"(i4s2), "n"(BOTTOM_MASK), "n"(I4s_TO_BF16s_MAGIC_NUM), "n"(immLut)); + // Extract elt_67 - (i4s & 0x000f000f) | 0x43004300 + asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" + : "=r"(h[3]) + : "r"(i4s3), "n"(BOTTOM_MASK), "n"(I4s_TO_BF16s_MAGIC_NUM), "n"(immLut)); + + // This is the nv_bfloat162 {128, 128} represented as an integer + static constexpr uint32_t BF16_TOP_MAGIC_NUM = 0x43004300; - // return result; + reinterpret_cast<__nv_bfloat162*>(h)[0] = __hsub2(reinterpret_cast<__nv_bfloat162*>(h)[0], reinterpret_cast(BF16_TOP_MAGIC_NUM)); + reinterpret_cast<__nv_bfloat162*>(h)[1] = __hsub2(reinterpret_cast<__nv_bfloat162*>(h)[1], reinterpret_cast(BF16_TOP_MAGIC_NUM)); + reinterpret_cast<__nv_bfloat162*>(h)[2] = __hsub2(reinterpret_cast<__nv_bfloat162*>(h)[2], reinterpret_cast(BF16_TOP_MAGIC_NUM)); + reinterpret_cast<__nv_bfloat162*>(h)[3] = __hsub2(reinterpret_cast<__nv_bfloat162*>(h)[3], reinterpret_cast(BF16_TOP_MAGIC_NUM)); } \ No newline at end of file diff --git a/gptqmodel_ext/awq/quantization_new/dispatch_utils.cuh b/gptqmodel_ext/awq/quantization_new/dispatch_utils.cuh new file mode 100644 index 000000000..82eda3041 --- /dev/null +++ b/gptqmodel_ext/awq/quantization_new/dispatch_utils.cuh @@ -0,0 +1,22 @@ +#pragma once + +#include +#include + +#include +#include +#include + +#define DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(pytorch_dtype, c_type, ...) 
\ + if (pytorch_dtype == at::ScalarType::Half) { \ + using c_type = half; \ + __VA_ARGS__ \ + } else if (pytorch_dtype == at::ScalarType::BFloat16) { \ + using c_type = nv_bfloat16; \ + __VA_ARGS__ \ + } else { \ + std::ostringstream oss; \ + oss << __PRETTY_FUNCTION__ << " failed to dispatch data type " << pytorch_dtype; \ + TORCH_CHECK(false, oss.str()); \ + } + diff --git a/gptqmodel_ext/awq/quantization_new/gemm/gemm_cuda.cu b/gptqmodel_ext/awq/quantization_new/gemm/gemm_cuda.cu index b9c8c1fbc..68ff84d89 100644 --- a/gptqmodel_ext/awq/quantization_new/gemm/gemm_cuda.cu +++ b/gptqmodel_ext/awq/quantization_new/gemm/gemm_cuda.cu @@ -1,7 +1,9 @@ #include +#include #include "semaphore.h" #include "gemm_cuda.h" #include "../dequantize.cuh" +#include "../dispatch_utils.cuh" #include #include @@ -28,7 +30,7 @@ auto semaphores = reinterpret_cast(_semaphores.data_ptr()); \ constexpr int NUM_WARPS = (CTA_M / WARP_M) * (CTA_N / WARP_N) * (CTA_K / WARP_K); \ constexpr int SCALES_SMEM_SIZE = (G >= CTA_K) ? (CTA_N / (G / CTA_K) * STAGES * 2) : (CTA_N * (CTA_K / G) * STAGES * 2); \ - constexpr int kSmemByteSize = (CTA_M * (CTA_K + SMEM_PAD_A) + CTA_N * (CTA_K + SMEM_PAD_B) / kInterleave + SCALES_SMEM_SIZE) * STAGES * sizeof(half); \ + constexpr int kSmemByteSize = (CTA_M * (CTA_K + SMEM_PAD_A) + CTA_N * (CTA_K + SMEM_PAD_B) / kInterleave + SCALES_SMEM_SIZE) * STAGES * sizeof(ctype); \ if (kSmemByteSize >= 99 * 1024) \ { \ printf("This kernel requires %d Bytes of shared memory, which exceeds device limit.\n", kSmemByteSize); \ @@ -37,10 +39,20 @@ int j_factors1 = num_out_channels / CTA_N / 1; \ dim3 num_blocks((num_out_feats + CTA_M - 1) / CTA_M * j_factors1 * SPLITK); \ dim3 threads_per_block(WARP_SIZE, NUM_WARPS); \ - auto kernel_func = gemm_w4a16_T1; \ - cudaFuncSetAttribute(kernel_func, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemByteSize); \ - kernel_func<<>>( \ - in_feats, kernel, scales, zeros, out_feats, semaphores, num_in_feats, num_out_channels, num_in_channels); + if (use_fp32) \ + { \ + auto kernel_func = gemm_w4a16_T1; \ + cudaFuncSetAttribute(kernel_func, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemByteSize); \ + kernel_func<<>>( \ + in_feats, kernel, scales, zeros, out_feats, semaphores, num_in_feats, num_out_channels, num_in_channels); \ + } \ + else \ + { \ + auto kernel_func = gemm_w4a16_T1; \ + cudaFuncSetAttribute(kernel_func, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemByteSize); \ + kernel_func<<>>( \ + in_feats, kernel, scales, zeros, out_feats, semaphores, num_in_feats, num_out_channels, num_in_channels); \ + } template __inline__ __host__ __device__ int get_log_tile(int n) @@ -87,18 +99,20 @@ __inline__ __device__ uint32_t cast_smem_ptr_to_uint(void const *const ptr) return smem_int_ptr; } -__inline__ __device__ void ldmatrix_m8n8_x4_b16(half *shared_warp, int ax0_0, uint32_t addr) +template +__inline__ __device__ void ldmatrix_m8n8_x4_b16(T *shared_warp, int ax0_0, uint32_t addr) { - asm volatile( + __asm__ __volatile__( "ldmatrix.sync.aligned.m8n8.x4.shared.b16" "{%0, %1, %2, %3}, [%4];" : "=r"(((unsigned *)(shared_warp + (ax0_0 * 8)))[0]), "=r"(((unsigned *)(shared_warp + (ax0_0 * 8)))[1]), "=r"(((unsigned *)(shared_warp + (ax0_0 * 8)))[2]), "=r"(((unsigned *)(shared_warp + (ax0_0 * 8)))[3]) : "r"(addr)); } -__inline__ __device__ void ldmatrix_m8n8_x4_trans_b16(half *shared_warp, int ax0_0, uint32_t addr) +template +__inline__ __device__ void ldmatrix_m8n8_x4_trans_b16(T *shared_warp, int ax0_0, uint32_t addr) { - asm volatile( + __asm__ __volatile__( 
"ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16" "{%0, %1, %2, %3}, [%4];" : "=r"(((unsigned *)(shared_warp + (ax0_0 * 8)))[0]), "=r"(((unsigned *)(shared_warp + (ax0_0 * 8)))[1]), "=r"(((unsigned *)(shared_warp + (ax0_0 * 8)))[2]), "=r"(((unsigned *)(shared_warp + (ax0_0 * 8)))[3]) @@ -118,17 +132,36 @@ __inline__ __device__ void cp_async_cg_A(uint32_t smem_int_ptr, const uint4 *__r "n"(cp_size)); } -__device__ __inline__ void mma_m16n8k16(float *C_warp, half *A_shared_warp, half *B_shared_warp) +__device__ __inline__ void mma_m16n8k16_f16f16f16(half *C_warp, half *A_shared_warp, half *B_shared_warp) { - asm volatile( + __asm__ __volatile__( + "mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16" + "{%0, %1}, {%2, %3, %4, %5}, {%6, %7}, {%8, %9};" + : "=r"(((unsigned *)C_warp)[0]), "=r"(((unsigned *)C_warp)[1]) + : "r"(((unsigned *)A_shared_warp)[0]), "r"(((unsigned *)A_shared_warp)[1]), "r"(((unsigned *)A_shared_warp)[2]), "r"(((unsigned *)A_shared_warp)[3]), "r"(((unsigned *)B_shared_warp)[0]), "r"(((unsigned *)B_shared_warp)[1]), "r"(((unsigned *)C_warp)[0]), "r"(((unsigned *)C_warp)[1])); +} + +__device__ __inline__ void mma_m16n8k16_f32f16f16f32(float *C_warp, half *A_shared_warp, half *B_shared_warp) +{ + __asm__ __volatile__( "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32" "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, %13};" - : "=f"(((float *)C_warp)[0]), "=f"(((float *)C_warp)[1]), "=f"(((float *)C_warp)[2]), "=f"(((float *)C_warp)[3]) - : "r"(((unsigned *)A_shared_warp)[0]), "r"(((unsigned *)A_shared_warp)[1]), "r"(((unsigned *)A_shared_warp)[2]), "r"(((unsigned *)A_shared_warp)[3]), "r"(((unsigned *)B_shared_warp)[0]), "r"(((unsigned *)B_shared_warp)[1]), "f"(((float *)C_warp)[0]), "f"(((float *)C_warp)[1]), "f"(((float *)C_warp)[2]), "f"(((float *)C_warp)[3])); + : "=f"(C_warp[0]), "=f"(C_warp[1]), "=f"(C_warp[2]), "=f"(C_warp[3]) + : "r"(((unsigned *)A_shared_warp)[0]), "r"(((unsigned *)A_shared_warp)[1]), "r"(((unsigned *)A_shared_warp)[2]), "r"(((unsigned *)A_shared_warp)[3]), "r"(((unsigned *)B_shared_warp)[0]), "r"(((unsigned *)B_shared_warp)[1]), "f"(C_warp[0]), "f"(C_warp[1]), "f"(C_warp[2]), "f"(C_warp[3])); +} + +__device__ __inline__ void mma_m16n8k16_bf16bf16f32(float *C_warp, nv_bfloat16 *A_shared_warp, nv_bfloat16 *B_shared_warp) +{ + + __asm__ __volatile__( + "mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32" + "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, %13};" + : "=f"(C_warp[0]), "=f"(C_warp[1]), "=f"(C_warp[2]), "=f"(C_warp[3]) + : "r"(((unsigned *)A_shared_warp)[0]), "r"(((unsigned *)A_shared_warp)[1]), "r"(((unsigned *)A_shared_warp)[2]), "r"(((unsigned *)A_shared_warp)[3]), "r"(((unsigned *)B_shared_warp)[0]), "r"(((unsigned *)B_shared_warp)[1]), "f"(C_warp[0]), "f"(C_warp[1]), "f"(C_warp[2]), "f"(C_warp[3])); } -template -__device__ __inline__ void global_to_share_one_stage_A(half *src, half *dst, int global_nrows, int global_ncols, int cta_offset_m, int cta_offset_n, int cta_offset_k, int global_iter_k, int shared_iter_k, bool mask) +template +__device__ __inline__ void global_to_share_one_stage_A(T *src, T *dst, int global_nrows, int global_ncols, int cta_offset_m, int cta_offset_n, int cta_offset_k, int global_iter_k, int shared_iter_k, bool mask) { constexpr int threads_needed = (CTA_M * CTA_K) / PACK_SIZE / SHARED_K_ITERS; constexpr int threads_used = threads_needed < CTA_SIZE ? 
threads_needed : CTA_SIZE; @@ -161,8 +194,8 @@ __device__ __inline__ void global_to_share_one_stage_A(half *src, half *dst, int } } -template -__device__ __inline__ void global_to_share_one_stage_B(half *src, half *dst, int global_ncols, int cta_offset_m, int cta_offset_n, int cta_offset_k, int global_iter_k, int shared_iter_k, bool mask) +template +__device__ __inline__ void global_to_share_one_stage_B(T *src, T *dst, int global_ncols, int cta_offset_m, int cta_offset_n, int cta_offset_k, int global_iter_k, int shared_iter_k, bool mask) { constexpr int threads_needed = (CTA_N / kInterleave * CTA_K) / PACK_SIZE / SHARED_K_ITERS; constexpr int threads_used = threads_needed < CTA_SIZE ? threads_needed : CTA_SIZE; @@ -196,8 +229,8 @@ __device__ __inline__ void global_to_share_one_stage_B(half *src, half *dst, int } } -template -__device__ __inline__ void global_to_share_one_stage_scales(half *src, half *dst, half *src_z, half *dst_z, int global_ncols, int cta_offset_m, int cta_offset_n, int cta_offset_k, int global_iter_k, int shared_iter_k, bool mask) +template +__device__ __inline__ void global_to_share_one_stage_scales(T *src, T *dst, T *src_z, T *dst_z, int global_ncols, int cta_offset_m, int cta_offset_n, int cta_offset_k, int global_iter_k, int shared_iter_k, bool mask) { constexpr int LD_AMOUNT = (G >= CTA_K) ? CTA_N : CTA_N * CTA_K / G; constexpr int threads_needed = LD_AMOUNT / PACK_SIZE / 1; @@ -229,8 +262,8 @@ __device__ __inline__ void global_to_share_one_stage_scales(half *src, half *dst } } -template -__device__ __inline__ void share_to_reg_one_stage_A(half *src, half *dst, int warp_offset_m, int warp_offset_n, int warp_offset_k, int k_0_1) +template +__device__ __inline__ void share_to_reg_one_stage_A(T *src, T *dst, int warp_offset_m, int warp_offset_n, int warp_offset_k, int k_0_1) { constexpr int kSmemCol = CTA_K + SMEM_PAD_A; @@ -247,9 +280,10 @@ __device__ __inline__ void share_to_reg_one_stage_A(half *src, half *dst, int wa } } -template -__device__ __inline__ void share_to_reg_one_stage_B(half *src, half *src_scales, half *src_zeros, half *dst, half *dst_fp16, int warp_offset_m, int warp_offset_n, int warp_offset_k, int k_0_1) +template +__device__ __inline__ void share_to_reg_one_stage_B(T *src, T *src_scales, T *src_zeros, T *dst, T *dst_fp16, int warp_offset_m, int warp_offset_n, int warp_offset_k, int k_0_1) { + using T2 = typename std::conditional::value, half2, nv_bfloat162>::type; constexpr int kSmemCol = CTA_K + SMEM_PAD_B; int r0 = ((threadIdx.x / 8 / 2) * 8 + threadIdx.x % 8); int c0 = ((threadIdx.x / 8) % 2) * 8; @@ -271,13 +305,21 @@ __device__ __inline__ void share_to_reg_one_stage_B(half *src, half *src_scales, #pragma unroll for (int shared_iter = 0; shared_iter < shared_iters; ++shared_iter) { - half scale = src_scales[(warp_offset_k / G) * CTA_N + warp_offset_n + 16 * shared_iter + 8 * (k_0_1 % 2) + threadIdx.x / 4]; - half zero = src_zeros[(warp_offset_k / G) * CTA_N + warp_offset_n + 16 * shared_iter + 8 * (k_0_1 % 2) + threadIdx.x / 4]; - half2 scale2 = make_half2(scale, scale); - half2 zero2 = make_half2(zero, zero); - half2 loaded[4]; - - dequantize_s4_to_fp16x2(*reinterpret_cast(dst + (k_0_1 % 2) * 4 + (k_0_1 / 2 * 2) + shared_iter * 8), reinterpret_cast(loaded)); + T scale = src_scales[(warp_offset_k / G) * CTA_N + warp_offset_n + 16 * shared_iter + 8 * (k_0_1 % 2) + threadIdx.x / 4]; + T zero = src_zeros[(warp_offset_k / G) * CTA_N + warp_offset_n + 16 * shared_iter + 8 * (k_0_1 % 2) + threadIdx.x / 4]; + T2 scale2, zero2; + if constexpr 
(std::is_same::value) + { + scale2 = __half2half2(scale); + zero2 = __half2half2(zero); + } + else + { + scale2 = __bfloat162bfloat162(scale); + zero2 = __bfloat162bfloat162(zero); + } + T2 loaded[4]; + dequantize_s4_to_fp16x2(*reinterpret_cast(dst + (k_0_1 % 2) * 4 + (k_0_1 / 2 * 2) + shared_iter * 8), reinterpret_cast(loaded)); #pragma unroll for (int i = 0; i < 4; i++) { @@ -287,9 +329,10 @@ __device__ __inline__ void share_to_reg_one_stage_B(half *src, half *src_scales, } } -template -__global__ void gemm_w4a16_T1(half *__restrict__ A, half *__restrict__ B, half *__restrict__ scales, half *__restrict__ zeros, half *__restrict__ C, int *__restrict__ semaphores, int M, int N, int K) +template +__global__ void gemm_w4a16_T1(T *__restrict__ A, T *__restrict__ B, T *__restrict__ scales, T *__restrict__ zeros, T *__restrict__ C, int *__restrict__ semaphores, int M, int N, int K) { + using DTypeAccum = typename std::conditional::value, float, half>::type; constexpr int NUM_WARPS_MN = CTA_M / WARP_M * CTA_N / WARP_N; constexpr int NUM_WARPS = NUM_WARPS_MN * CTA_K / WARP_K; constexpr int CTA_SIZE = NUM_WARPS * WARP_SIZE; @@ -307,7 +350,7 @@ __global__ void gemm_w4a16_T1(half *__restrict__ A, half *__restrict__ B, half * blockIdx_m = block_idx_mapping.x; blockIdx_n = block_idx_mapping.y; - float C_warp[CTA_M * CTA_N / CTA_SIZE_MN]; + DTypeAccum C_warp[CTA_M * CTA_N / CTA_SIZE_MN]; constexpr int kSmemPadKA = CTA_K + SMEM_PAD_A; constexpr int kSmemPadKB = CTA_K + SMEM_PAD_B; constexpr int kSmemSizeAPerStage = CTA_M * kSmemPadKA; @@ -319,16 +362,16 @@ __global__ void gemm_w4a16_T1(half *__restrict__ A, half *__restrict__ B, half * constexpr int kSmemSizeScales = CTA_N * STAGES / scales_load_interval * scales_per_load; constexpr int kSmemSizeZeros = CTA_N * STAGES / scales_load_interval * scales_per_load; extern __shared__ half mem_shared[]; - half *A_shared = mem_shared; - half *B_shared = mem_shared + kSmemSizeA; - half *scales_shared = mem_shared + kSmemSizeA + kSmemSizeB; - half *zeros_shared = mem_shared + kSmemSizeA + kSmemSizeB + kSmemSizeScales; - float *C_shared = reinterpret_cast(mem_shared); - half A_shared_warp_[2][WARP_M * INTRIN_K / + T *A_shared = (T*)mem_shared; + T *B_shared = (T*)mem_shared + kSmemSizeA; + T *scales_shared = (T*)mem_shared + kSmemSizeA + kSmemSizeB; + T *zeros_shared = (T*)mem_shared + kSmemSizeA + kSmemSizeB + kSmemSizeScales; + T *C_shared = (T*)(mem_shared); + T A_shared_warp_[2][WARP_M * INTRIN_K / WARP_SIZE]; - half B_shared_warp_[2][WARP_N * 32 / + T B_shared_warp_[2][WARP_N * 32 / WARP_SIZE]; - half B_shared_warp_tmp_[2][WARP_N * 16 / + T B_shared_warp_tmp_[2][WARP_N * 16 / WARP_SIZE]; int cta_offset_m = blockIdx_m * CTA_M; int cta_offset_n = blockIdx_n * CTA_N; @@ -371,10 +414,10 @@ __global__ void gemm_w4a16_T1(half *__restrict__ A, half *__restrict__ B, half * { int ld_stage = k_0_0_ld % STAGES; int compute_stage = k_0_0 % STAGES; - half *A_shared_this_compute_stage; - half *B_shared_this_compute_stage; - half *scales_shared_this_compute_stage; - half *zeros_shared_this_compute_stage; + T *A_shared_this_compute_stage; + T *B_shared_this_compute_stage; + T *scales_shared_this_compute_stage; + T *zeros_shared_this_compute_stage; #pragma unroll for (int iter_k = 0; iter_k < SHARED_K_ITERS; ++iter_k) @@ -418,15 +461,31 @@ __global__ void gemm_w4a16_T1(half *__restrict__ A, half *__restrict__ B, half * warp_offset_m, warp_offset_n, warp_offset_k, (iter_k + 1) % SHARED_K_ITERS); } } - half *A_shared_warp = A_shared_warp_[iter_k % 2]; - half *B_shared_warp = 
B_shared_warp_[(iter_k / 2) % 2]; + T *A_shared_warp = A_shared_warp_[iter_k % 2]; + T *B_shared_warp = B_shared_warp_[(iter_k / 2) % 2]; for (int i_0_3 = 0; i_0_3 < WARP_M / INTRIN_M; ++i_0_3) { for (int j_0_4 = 0; j_0_4 < WARP_N / INTRIN_N; ++j_0_4) { - mma_m16n8k16(C_warp + i_0_3 * WARP_N / INTRIN_N * 8 + j_0_4 * 8, A_shared_warp + i_0_3 * 8, B_shared_warp + j_0_4 * 16 + (iter_k % 2) * 4); - mma_m16n8k16(C_warp + i_0_3 * WARP_N / INTRIN_N * 8 + j_0_4 * 8 + 4, A_shared_warp + i_0_3 * 8, B_shared_warp + j_0_4 * 16 + (iter_k % 2) * 4 + 8); + if constexpr (std::is_same::value) + { + if constexpr (UseFP32Accum) + { + mma_m16n8k16_f32f16f16f32(C_warp + i_0_3 * WARP_N / INTRIN_N * 8 + j_0_4 * 8, A_shared_warp + i_0_3 * 8, B_shared_warp + j_0_4 * 16 + (iter_k % 2) * 4); + mma_m16n8k16_f32f16f16f32(C_warp + i_0_3 * WARP_N / INTRIN_N * 8 + j_0_4 * 8 + 4, A_shared_warp + i_0_3 * 8, B_shared_warp + j_0_4 * 16 + (iter_k % 2) * 4 + 8); + } + else + { + mma_m16n8k16_f16f16f16(C_warp + i_0_3 * WARP_N / INTRIN_N * 8 + j_0_4 * 8, A_shared_warp + i_0_3 * 8, B_shared_warp + j_0_4 * 16 + (iter_k % 2) * 4); + mma_m16n8k16_f16f16f16(C_warp + i_0_3 * WARP_N / INTRIN_N * 8 + j_0_4 * 8 + 4, A_shared_warp + i_0_3 * 8, B_shared_warp + j_0_4 * 16 + (iter_k % 2) * 4 + 8); + } + } + else + { + mma_m16n8k16_bf16bf16f32(C_warp + i_0_3 * WARP_N / INTRIN_N * 8 + j_0_4 * 8, A_shared_warp + i_0_3 * 8, B_shared_warp + j_0_4 * 16 + (iter_k % 2) * 4); + mma_m16n8k16_bf16bf16f32(C_warp + i_0_3 * WARP_N / INTRIN_N * 8 + j_0_4 * 8 + 4, A_shared_warp + i_0_3 * 8, B_shared_warp + j_0_4 * 16 + (iter_k % 2) * 4 + 8); + } } } @@ -464,129 +523,390 @@ __global__ void gemm_w4a16_T1(half *__restrict__ A, half *__restrict__ B, half * __pipeline_commit(); __pipeline_wait_prior(0); __syncthreads(); - if constexpr (SLICES > 1) + + if constexpr (std::is_same::value) { -#pragma unroll - for (int z = 0; z < SLICES; ++z) + if constexpr (!UseFP32Accum) { - if (slice_id == z) + if constexpr (SLICES > 1) { #pragma unroll - for (int ax0_0_1 = 0; ax0_0_1 < WARP_M / INTRIN_M; ++ax0_0_1) + for (int z = 0; z < SLICES; ++z) { -#pragma unroll - for (int ax1_0_1 = 0; ax1_0_1 < WARP_N / INTRIN_N; ++ax1_0_1) + if (slice_id == z) { #pragma unroll - for (int local_id = 0; local_id < OP_M * 16 / WARP_SIZE; ++local_id) + for (int ax0_0_1 = 0; ax0_0_1 < WARP_M / INTRIN_M; ++ax0_0_1) { - if (z > 0) +#pragma unroll + for (int ax1_0_1 = 0; ax1_0_1 < WARP_N / INTRIN_N; ++ax1_0_1) { - C_warp[ax0_0_1 * WARP_N / INTRIN_N * 8 + ax1_0_1 * 8 + local_id] += C_shared[warp_offset_m * CTA_N + ax0_0_1 * OP_M * CTA_N + warp_offset_n + ax1_0_1 * 16 + ((local_id % 4) / 2 * 8 + (threadIdx.x / 4)) * CTA_N + (local_id / 4) * 8 + (local_id % 2) + (threadIdx.x % 4) * 2]; +#pragma unroll + for (int local_id = 0; local_id < OP_M * 16 / WARP_SIZE; ++local_id) + { + if (z > 0) + { + C_warp[ax0_0_1 * WARP_N / INTRIN_N * 8 + ax1_0_1 * 8 + local_id] += C_shared[warp_offset_m * CTA_N + ax0_0_1 * OP_M * CTA_N + warp_offset_n + ax1_0_1 * 16 + ((local_id % 4) / 2 * 8 + (threadIdx.x / 4)) * CTA_N + (local_id / 4) * 8 + (local_id % 2) + (threadIdx.x % 4) * 2]; + } + C_shared[warp_offset_m * CTA_N + ax0_0_1 * OP_M * CTA_N + warp_offset_n + ax1_0_1 * 16 + ((local_id % 4) / 2 * 8 + (threadIdx.x / 4)) * CTA_N + (local_id / 4) * 8 + (local_id % 2) + (threadIdx.x % 4) * 2] = C_warp[ax0_0_1 * WARP_N / INTRIN_N * 8 + ax1_0_1 * 8 + local_id]; + }; } - C_shared[warp_offset_m * CTA_N + ax0_0_1 * OP_M * CTA_N + warp_offset_n + ax1_0_1 * 16 + ((local_id % 4) / 2 * 8 + (threadIdx.x / 4)) * CTA_N + (local_id / 4) 
* 8 + (local_id % 2) + (threadIdx.x % 4) * 2] = C_warp[ax0_0_1 * WARP_N / INTRIN_N * 8 + ax1_0_1 * 8 + local_id]; - }; + } + } + __syncthreads(); + } + if (slice_id == 0) + { +#pragma unroll + for (int ax0_0_1 = 0; ax0_0_1 < WARP_M / INTRIN_M; ++ax0_0_1) + { +#pragma unroll + for (int ax1_0_1 = 0; ax1_0_1 < WARP_N / INTRIN_N; ++ax1_0_1) + { +#pragma unroll + for (int local_id = 0; local_id < OP_M * 16 / WARP_SIZE; ++local_id) + { + C_warp[ax0_0_1 * WARP_N / INTRIN_N * 8 + ax1_0_1 * 8 + local_id] = C_shared[warp_offset_m * CTA_N + ax0_0_1 * OP_M * CTA_N + warp_offset_n + ax1_0_1 * 16 + ((local_id % 4) / 2 * 8 + (threadIdx.x / 4)) * CTA_N + (local_id / 4) * 8 + (local_id % 2) + (threadIdx.x % 4) * 2]; + }; + } } } } - __syncthreads(); + + if (slice_id == 0) + { + Semaphore semaphore(semaphores + blockIdx_y, threadIdx.x); + + if constexpr (SPLITK > 1) + { + semaphore.fetch(); + } + + if (blockIdx_z != 0) + { + semaphore.wait(blockIdx_z); + for (int ax0_0_1 = 0; ax0_0_1 < WARP_M / INTRIN_M; ++ax0_0_1) + { + for (int ax1_0_1 = 0; ax1_0_1 < WARP_N / INTRIN_N; ++ax1_0_1) + { + for (int local_id = 0; local_id < OP_M * 16 / WARP_SIZE; local_id += 2) + { + int write_row = cta_offset_m + warp_offset_m + ax0_0_1 * OP_M + ((local_id % 4) / 2 * 8 + (threadIdx.x / 4)); + + if (write_row < M) + { + half2 *existing_psum_ptr = reinterpret_cast( + C + write_row * N + + cta_offset_n + warp_offset_n + ax1_0_1 * 16 + + (local_id / 4) * 8 + (local_id % 2) + (threadIdx.x % 4) * 2); + + *existing_psum_ptr = __hadd2(*existing_psum_ptr, + *reinterpret_cast(C_warp + ax0_0_1 * WARP_N / INTRIN_N * 8 + + ax1_0_1 * 8 + local_id)); + } + }; + } + } + } + else + { + for (int ax0_0_1 = 0; ax0_0_1 < WARP_M / INTRIN_M; ++ax0_0_1) + { + for (int ax1_0_1 = 0; ax1_0_1 < WARP_N / INTRIN_N; ++ax1_0_1) + { + for (int local_id = 0; local_id < OP_M * 16 / WARP_SIZE; local_id += 2) + { + int write_row = cta_offset_m + warp_offset_m + ax0_0_1 * OP_M + ((local_id % 4) / 2 * 8 + (threadIdx.x / 4)); + if (write_row < M) + { + *reinterpret_cast( + C + write_row * N + + cta_offset_n + warp_offset_n + ax1_0_1 * 16 + + (local_id / 4) * 8 + (local_id % 2) + (threadIdx.x % 4) * 2) = + *reinterpret_cast(C_warp + ax0_0_1 * WARP_N / INTRIN_N * 8 + + ax1_0_1 * 8 + local_id); + } + }; + } + } + } + + if constexpr (SPLITK > 1) + { + + int lock = 0; + if (SPLITK == blockIdx_z + 1) + { + + lock = 0; + } + else + { + lock = blockIdx_z + 1; + } + semaphore.release(lock); + } + } } - if (slice_id == 0) + else { -#pragma unroll - for (int ax0_0_1 = 0; ax0_0_1 < WARP_M / INTRIN_M; ++ax0_0_1) + if constexpr (SLICES > 1) { #pragma unroll - for (int ax1_0_1 = 0; ax1_0_1 < WARP_N / INTRIN_N; ++ax1_0_1) + for (int z = 0; z < SLICES; ++z) + { + if (slice_id == z) + { +#pragma unroll + for (int ax0_0_1 = 0; ax0_0_1 < WARP_M / INTRIN_M; ++ax0_0_1) + { +#pragma unroll + for (int ax1_0_1 = 0; ax1_0_1 < WARP_N / INTRIN_N; ++ax1_0_1) + { +#pragma unroll + for (int local_id = 0; local_id < OP_M * 16 / WARP_SIZE; ++local_id) + { + int shared_index = warp_offset_m * CTA_N + ax0_0_1 * OP_M * CTA_N + warp_offset_n + ax1_0_1 * 16 + ((local_id % 4) / 2 * 8 + (threadIdx.x / 4)) * CTA_N + (local_id / 4) * 8 + (local_id % 2) + (threadIdx.x % 4) * 2; + if (z > 0) + { + C_warp[ax0_0_1 * WARP_N / INTRIN_N * 8 + ax1_0_1 * 8 + local_id] += __half2float(C_shared[shared_index]); + } + C_shared[shared_index] = __float2half(C_warp[ax0_0_1 * WARP_N / INTRIN_N * 8 + ax1_0_1 * 8 + local_id]); + }; + } + } + } + __syncthreads(); + } + if (slice_id == 0) { #pragma unroll - for (int local_id = 
0; local_id < OP_M * 16 / WARP_SIZE; ++local_id) + for (int ax0_0_1 = 0; ax0_0_1 < WARP_M / INTRIN_M; ++ax0_0_1) { - C_warp[ax0_0_1 * WARP_N / INTRIN_N * 8 + ax1_0_1 * 8 + local_id] = C_shared[warp_offset_m * CTA_N + ax0_0_1 * OP_M * CTA_N + warp_offset_n + ax1_0_1 * 16 + ((local_id % 4) / 2 * 8 + (threadIdx.x / 4)) * CTA_N + (local_id / 4) * 8 + (local_id % 2) + (threadIdx.x % 4) * 2]; - }; +#pragma unroll + for (int ax1_0_1 = 0; ax1_0_1 < WARP_N / INTRIN_N; ++ax1_0_1) + { +#pragma unroll + for (int local_id = 0; local_id < OP_M * 16 / WARP_SIZE; ++local_id) + { + int shared_index = warp_offset_m * CTA_N + ax0_0_1 * OP_M * CTA_N + warp_offset_n + ax1_0_1 * 16 + ((local_id % 4) / 2 * 8 + (threadIdx.x / 4)) * CTA_N + (local_id / 4) * 8 + (local_id % 2) + (threadIdx.x % 4) * 2; + C_warp[ax0_0_1 * WARP_N / INTRIN_N * 8 + ax1_0_1 * 8 + local_id] = __half2float(C_shared[shared_index]); + }; + } + } + } + } + + if (slice_id == 0) + { + Semaphore semaphore(semaphores + blockIdx_y, threadIdx.x); + + if constexpr (SPLITK > 1) + { + semaphore.fetch(); + } + + if (blockIdx_z != 0) + { + semaphore.wait(blockIdx_z); + for (int ax0_0_1 = 0; ax0_0_1 < WARP_M / INTRIN_M; ++ax0_0_1) + { + for (int ax1_0_1 = 0; ax1_0_1 < WARP_N / INTRIN_N; ++ax1_0_1) + { + for (int local_id = 0; local_id < OP_M * 16 / WARP_SIZE; local_id += 2) + { + int write_row = cta_offset_m + warp_offset_m + ax0_0_1 * OP_M + ((local_id % 4) / 2 * 8 + (threadIdx.x / 4)); + + if (write_row < M) + { + half2 *existing_psum_ptr = reinterpret_cast( + C + write_row * N + + cta_offset_n + warp_offset_n + ax1_0_1 * 16 + + (local_id / 4) * 8 + (local_id % 2) + (threadIdx.x % 4) * 2); + float val0 = C_warp[ax0_0_1 * WARP_N / INTRIN_N * 8 + ax1_0_1 * 8 + local_id]; + float val1 = C_warp[ax0_0_1 * WARP_N / INTRIN_N * 8 + ax1_0_1 * 8 + local_id + 1]; + half2 packed = __floats2half2_rn(val0, val1); + *existing_psum_ptr = __hadd2(*existing_psum_ptr, packed); + } + }; + } + } + } + else + { + for (int ax0_0_1 = 0; ax0_0_1 < WARP_M / INTRIN_M; ++ax0_0_1) + { + for (int ax1_0_1 = 0; ax1_0_1 < WARP_N / INTRIN_N; ++ax1_0_1) + { + for (int local_id = 0; local_id < OP_M * 16 / WARP_SIZE; local_id += 2) + { + int write_row = cta_offset_m + warp_offset_m + ax0_0_1 * OP_M + ((local_id % 4) / 2 * 8 + (threadIdx.x / 4)); + if (write_row < M) + { + float val0 = C_warp[ax0_0_1 * WARP_N / INTRIN_N * 8 + ax1_0_1 * 8 + local_id]; + float val1 = C_warp[ax0_0_1 * WARP_N / INTRIN_N * 8 + ax1_0_1 * 8 + local_id + 1]; + *reinterpret_cast( + C + write_row * N + + cta_offset_n + warp_offset_n + ax1_0_1 * 16 + + (local_id / 4) * 8 + (local_id % 2) + (threadIdx.x % 4) * 2) = + __floats2half2_rn(val0, val1); + } + }; + } + } + } + + if constexpr (SPLITK > 1) + { + int lock = 0; + if (SPLITK == blockIdx_z + 1) + { + lock = 0; + } + else + { + lock = blockIdx_z + 1; + } + semaphore.release(lock); } } } } - - if (slice_id == 0) + else { - Semaphore semaphore(semaphores + blockIdx_y, threadIdx.x); - - if constexpr (SPLITK > 1) + // first convert fp32 to bf16 + nv_bfloat16 C_warp16[CTA_M * CTA_N / CTA_SIZE_MN]; +#pragma unroll + for (int i = 0; i < CTA_M * CTA_N / CTA_SIZE_MN / 2; ++i) { - semaphore.fetch(); + ((nv_bfloat162*)C_warp16)[i] = __float22bfloat162_rn(((float2*)C_warp)[i]); } - if (blockIdx_z != 0) + // the following is the same as fp16. Maybe there is a neat way to implement this. 
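To make the comment above concrete: bf16 matrix inputs only come with an fp32 accumulator in PTX (mma.sync...f32.bf16.bf16.f32), so this epilogue first repacks the per-thread fp32 accumulators into nv_bfloat162 pairs with __float22bfloat162_rn and then reuses the fp16-style slice reduction and split-K writeback, swapping half2 loads/stores and __hadd2 for their bf16 counterparts. A minimal standalone sketch of the repacking step (pack_fp32_accum_to_bf16 is a hypothetical helper, not part of the kernel; it assumes an even count and float2-compatible layout, as the register array above provides):

#include <cuda_bf16.h>

// __float22bfloat162_rn rounds two fp32 lanes to bf16 (round-to-nearest-even)
// and packs them into one nv_bfloat162, so n accumulators convert in n / 2 steps.
__device__ void pack_fp32_accum_to_bf16(const float* acc, nv_bfloat16* out, int n) {
    for (int i = 0; i < n / 2; ++i) {
        reinterpret_cast<nv_bfloat162*>(out)[i] =
            __float22bfloat162_rn(reinterpret_cast<const float2*>(acc)[i]);
    }
}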
+ if constexpr (SLICES > 1) { - semaphore.wait(blockIdx_z); - for (int ax0_0_1 = 0; ax0_0_1 < WARP_M / INTRIN_M; ++ax0_0_1) +#pragma unroll + for (int z = 0; z < SLICES; ++z) { - for (int ax1_0_1 = 0; ax1_0_1 < WARP_N / INTRIN_N; ++ax1_0_1) + if (slice_id == z) { - for (int local_id = 0; local_id < OP_M * 16 / WARP_SIZE; local_id += 2) +#pragma unroll + for (int ax0_0_1 = 0; ax0_0_1 < WARP_M / INTRIN_M; ++ax0_0_1) { - int write_row = cta_offset_m + warp_offset_m + ax0_0_1 * OP_M + ((local_id % 4) / 2 * 8 + (threadIdx.x / 4)); - - if (write_row < M) +#pragma unroll + for (int ax1_0_1 = 0; ax1_0_1 < WARP_N / INTRIN_N; ++ax1_0_1) { - half2 *existing_psum_ptr = reinterpret_cast( - C + write_row * N + - cta_offset_n + warp_offset_n + ax1_0_1 * 16 + - (local_id / 4) * 8 + (local_id % 2) + (threadIdx.x % 4) * 2); - - *existing_psum_ptr = __hadd2(*existing_psum_ptr, - __float22half2_rn(*reinterpret_cast(C_warp + ax0_0_1 * WARP_N / INTRIN_N * 8 + - ax1_0_1 * 8 + local_id))); +#pragma unroll + for (int local_id = 0; local_id < OP_M * 16 / WARP_SIZE; ++local_id) + { + if (z > 0) + { + C_warp16[ax0_0_1 * WARP_N / INTRIN_N * 8 + ax1_0_1 * 8 + local_id] += C_shared[warp_offset_m * CTA_N + ax0_0_1 * OP_M * CTA_N + warp_offset_n + ax1_0_1 * 16 + ((local_id % 4) / 2 * 8 + (threadIdx.x / 4)) * CTA_N + (local_id / 4) * 8 + (local_id % 2) + (threadIdx.x % 4) * 2]; + } + C_shared[warp_offset_m * CTA_N + ax0_0_1 * OP_M * CTA_N + warp_offset_n + ax1_0_1 * 16 + ((local_id % 4) / 2 * 8 + (threadIdx.x / 4)) * CTA_N + (local_id / 4) * 8 + (local_id % 2) + (threadIdx.x % 4) * 2] = C_warp16[ax0_0_1 * WARP_N / INTRIN_N * 8 + ax1_0_1 * 8 + local_id]; + }; } - }; + } } + __syncthreads(); } - } - else - { - for (int ax0_0_1 = 0; ax0_0_1 < WARP_M / INTRIN_M; ++ax0_0_1) + if (slice_id == 0) { - for (int ax1_0_1 = 0; ax1_0_1 < WARP_N / INTRIN_N; ++ax1_0_1) +#pragma unroll + for (int ax0_0_1 = 0; ax0_0_1 < WARP_M / INTRIN_M; ++ax0_0_1) { - for (int local_id = 0; local_id < OP_M * 16 / WARP_SIZE; local_id += 2) +#pragma unroll + for (int ax1_0_1 = 0; ax1_0_1 < WARP_N / INTRIN_N; ++ax1_0_1) { - int write_row = cta_offset_m + warp_offset_m + ax0_0_1 * OP_M + ((local_id % 4) / 2 * 8 + (threadIdx.x / 4)); - if (write_row < M) +#pragma unroll + for (int local_id = 0; local_id < OP_M * 16 / WARP_SIZE; ++local_id) { - *reinterpret_cast( - C + write_row * N + - cta_offset_n + warp_offset_n + ax1_0_1 * 16 + - (local_id / 4) * 8 + (local_id % 2) + (threadIdx.x % 4) * 2) = - __float22half2_rn(*reinterpret_cast(C_warp + ax0_0_1 * WARP_N / INTRIN_N * 8 + - ax1_0_1 * 8 + local_id)); - } - }; + C_warp16[ax0_0_1 * WARP_N / INTRIN_N * 8 + ax1_0_1 * 8 + local_id] = C_shared[warp_offset_m * CTA_N + ax0_0_1 * OP_M * CTA_N + warp_offset_n + ax1_0_1 * 16 + ((local_id % 4) / 2 * 8 + (threadIdx.x / 4)) * CTA_N + (local_id / 4) * 8 + (local_id % 2) + (threadIdx.x % 4) * 2]; + }; + } } } - } + } - if constexpr (SPLITK > 1) + if (slice_id == 0) { + Semaphore semaphore(semaphores + blockIdx_y, threadIdx.x); + + if constexpr (SPLITK > 1) + { + semaphore.fetch(); + } - int lock = 0; - if (SPLITK == blockIdx_z + 1) + if (blockIdx_z != 0) { + semaphore.wait(blockIdx_z); + for (int ax0_0_1 = 0; ax0_0_1 < WARP_M / INTRIN_M; ++ax0_0_1) + { + for (int ax1_0_1 = 0; ax1_0_1 < WARP_N / INTRIN_N; ++ax1_0_1) + { + for (int local_id = 0; local_id < OP_M * 16 / WARP_SIZE; local_id += 2) + { + int write_row = cta_offset_m + warp_offset_m + ax0_0_1 * OP_M + ((local_id % 4) / 2 * 8 + (threadIdx.x / 4)); - lock = 0; + if (write_row < M) + { + nv_bfloat162 
*existing_psum_ptr = reinterpret_cast( + C + write_row * N + + cta_offset_n + warp_offset_n + ax1_0_1 * 16 + + (local_id / 4) * 8 + (local_id % 2) + (threadIdx.x % 4) * 2); + + *existing_psum_ptr = __hadd2(*existing_psum_ptr, + *reinterpret_cast(C_warp16 + ax0_0_1 * WARP_N / INTRIN_N * 8 + + ax1_0_1 * 8 + local_id)); + } + }; + } + } } else { - lock = blockIdx_z + 1; + for (int ax0_0_1 = 0; ax0_0_1 < WARP_M / INTRIN_M; ++ax0_0_1) + { + for (int ax1_0_1 = 0; ax1_0_1 < WARP_N / INTRIN_N; ++ax1_0_1) + { + for (int local_id = 0; local_id < OP_M * 16 / WARP_SIZE; local_id += 2) + { + int write_row = cta_offset_m + warp_offset_m + ax0_0_1 * OP_M + ((local_id % 4) / 2 * 8 + (threadIdx.x / 4)); + if (write_row < M) + { + *reinterpret_cast( + C + write_row * N + + cta_offset_n + warp_offset_n + ax1_0_1 * 16 + + (local_id / 4) * 8 + (local_id % 2) + (threadIdx.x % 4) * 2) = + *reinterpret_cast(C_warp16 + ax0_0_1 * WARP_N / INTRIN_N * 8 + + ax1_0_1 * 8 + local_id); + } + }; + } + } + } + + if constexpr (SPLITK > 1) + { + + int lock = 0; + if (SPLITK == blockIdx_z + 1) + { + + lock = 0; + } + else + { + lock = blockIdx_z + 1; + } + semaphore.release(lock); } - semaphore.release(lock); } } } -template -__device__ __inline__ void global_to_share_one_stage_A_T2(half *src, half *dst, int global_nrows, int global_ncols, int cta_offset_m, int cta_offset_n, int global_iter_k, int shared_iter_k, bool mask) +template +__device__ __inline__ void global_to_share_one_stage_A_T2(T *src, T *dst, int global_nrows, int global_ncols, int cta_offset_m, int cta_offset_n, int global_iter_k, int shared_iter_k, bool mask) { constexpr int threads_needed = (CTA_M * CTA_K) / PACK_SIZE / SHARED_K_ITERS; constexpr int threads_used = threads_needed < CTA_SIZE ? threads_needed : CTA_SIZE; @@ -619,8 +939,8 @@ __device__ __inline__ void global_to_share_one_stage_A_T2(half *src, half *dst, } } -template -__device__ __inline__ void global_to_share_one_stage_B_T2(half *src, half *dst, int global_ncols, int cta_offset_m, int cta_offset_n, int global_iter_k, int shared_iter_k, bool mask) +template +__device__ __inline__ void global_to_share_one_stage_B_T2(T *src, T *dst, int global_ncols, int cta_offset_m, int cta_offset_n, int global_iter_k, int shared_iter_k, bool mask) { constexpr int threads_needed = (CTA_N / kInterleave * CTA_K) / PACK_SIZE / SHARED_K_ITERS; constexpr int threads_used = threads_needed < CTA_SIZE ? threads_needed : CTA_SIZE; @@ -654,8 +974,8 @@ __device__ __inline__ void global_to_share_one_stage_B_T2(half *src, half *dst, } } -template -__device__ __inline__ void global_to_share_one_stage_scales_T2(half *src, half *dst, half *src_z, half *dst_z, int global_ncols, int cta_offset_m, int cta_offset_n, int global_iter_k, int shared_iter_k, bool mask) +template +__device__ __inline__ void global_to_share_one_stage_scales_T2(T *src, T *dst, T *src_z, T *dst_z, int global_ncols, int cta_offset_m, int cta_offset_n, int global_iter_k, int shared_iter_k, bool mask) { constexpr int threads_needed = CTA_N / PACK_SIZE / 1; constexpr int threads_used = threads_needed < CTA_SIZE ? 
threads_needed : CTA_SIZE; @@ -686,8 +1006,8 @@ __device__ __inline__ void global_to_share_one_stage_scales_T2(half *src, half * } } -template -__device__ __inline__ void share_to_reg_one_stage_A_T2(half *src, half *dst, int warp_offset_m, int warp_offset_n, int k_0_1) +template +__device__ __inline__ void share_to_reg_one_stage_A_T2(T *src, T *dst, int warp_offset_m, int warp_offset_n, int k_0_1) { constexpr int kSmemCol = CTA_K + SMEM_PAD_A; @@ -704,9 +1024,10 @@ __device__ __inline__ void share_to_reg_one_stage_A_T2(half *src, half *dst, int } } -template -__device__ __inline__ void share_to_reg_one_stage_B_T2(half *src, half *src_scales, half *src_zeros, half *dst, half *dst_fp16, int warp_offset_m, int warp_offset_n, int k_0_1) +template +__device__ __inline__ void share_to_reg_one_stage_B_T2(T *src, T *src_scales, T *src_zeros, T *dst, T *dst_fp16, int warp_offset_m, int warp_offset_n, int k_0_1) { + using T2 = typename std::conditional::value, half2, nv_bfloat162>::type; constexpr int kSmemCol = CTA_K + SMEM_PAD_B; int r0 = ((threadIdx.x / 8 / 2) * 8 + threadIdx.x % 8); int c0 = ((threadIdx.x / 8) % 2) * 8; @@ -728,12 +1049,21 @@ __device__ __inline__ void share_to_reg_one_stage_B_T2(half *src, half *src_scal #pragma unroll for (int shared_iter = 0; shared_iter < shared_iters; ++shared_iter) { - half scale = src_scales[warp_offset_n + 16 * shared_iter + 8 * (k_0_1 % 2) + threadIdx.x / 4]; - half zero = src_zeros[warp_offset_n + 16 * shared_iter + 8 * (k_0_1 % 2) + threadIdx.x / 4]; - half2 scale2 = make_half2(scale, scale); - half2 zero2 = make_half2(zero, zero); - half2 loaded[4]; - dequantize_s4_to_fp16x2(*reinterpret_cast(dst + (k_0_1 % 2) * 4 + (k_0_1 / 2 * 2) + shared_iter * 8), reinterpret_cast(loaded)); + T scale = src_scales[warp_offset_n + 16 * shared_iter + 8 * (k_0_1 % 2) + threadIdx.x / 4]; + T zero = src_zeros[warp_offset_n + 16 * shared_iter + 8 * (k_0_1 % 2) + threadIdx.x / 4]; + T2 scale2, zero2; + if constexpr (std::is_same::value) + { + scale2 = __half2half2(scale); + zero2 = __half2half2(zero); + } + else + { + scale2 = __bfloat162bfloat162(scale); + zero2 = __bfloat162bfloat162(zero); + } + T2 loaded[4]; + dequantize_s4_to_fp16x2(*reinterpret_cast(dst + (k_0_1 % 2) * 4 + (k_0_1 / 2 * 2) + shared_iter * 8), reinterpret_cast(loaded)); #pragma unroll for (int i = 0; i < 4; i++) { @@ -743,9 +1073,10 @@ __device__ __inline__ void share_to_reg_one_stage_B_T2(half *src, half *src_scal } } -template -__global__ void gemm_w4a16_T2(half *__restrict__ A, half *__restrict__ B, half *__restrict__ scales, half *__restrict__ zeros, half *__restrict__ C, int M, int N, int K) +template +__global__ void gemm_w4a16_T2(T *__restrict__ A, T *__restrict__ B, T *__restrict__ scales, T *__restrict__ zeros, T *__restrict__ C, int M, int N, int K) { + using DTypeAccum = typename std::conditional::value, float, half>::type; constexpr int NUM_WARPS = CTA_M / WARP_M * CTA_N / WARP_N; constexpr int CTA_SIZE = NUM_WARPS * WARP_SIZE; int num_blocks_n = (N + CTA_N - 1) / CTA_N; @@ -760,7 +1091,7 @@ __global__ void gemm_w4a16_T2(half *__restrict__ A, half *__restrict__ B, half * blockIdx_m = block_idx_mapping.x; blockIdx_n = block_idx_mapping.y; - float C_warp[CTA_M * CTA_N / CTA_SIZE]; + DTypeAccum C_warp[CTA_M * CTA_N / CTA_SIZE]; constexpr int kSmemPadKA = CTA_K + SMEM_PAD_A; constexpr int kSmemPadKB = CTA_K + SMEM_PAD_B; constexpr int kSmemSizeAPerStage = CTA_M * kSmemPadKA; @@ -771,15 +1102,15 @@ __global__ void gemm_w4a16_T2(half *__restrict__ A, half *__restrict__ B, half * constexpr int 
kSmemSizeZeros = CTA_N * STAGES / 2; constexpr int scales_load_interval = G / CTA_K; extern __shared__ half mem_shared[]; - half *A_shared = mem_shared; - half *B_shared = mem_shared + kSmemSizeA; - half *scales_shared = mem_shared + kSmemSizeA + kSmemSizeB; - half *zeros_shared = mem_shared + kSmemSizeA + kSmemSizeB + kSmemSizeScales; - half A_shared_warp_[2][WARP_M * INTRIN_K / + T *A_shared = (T*)mem_shared; + T *B_shared = (T*)mem_shared + kSmemSizeA; + T *scales_shared = (T*)mem_shared + kSmemSizeA + kSmemSizeB; + T *zeros_shared = (T*)mem_shared + kSmemSizeA + kSmemSizeB + kSmemSizeScales; + T A_shared_warp_[2][WARP_M * INTRIN_K / WARP_SIZE]; - half B_shared_warp_[2][WARP_N * 32 / + T B_shared_warp_[2][WARP_N * 32 / WARP_SIZE]; - half B_shared_warp_tmp_[2][WARP_N * 16 / + T B_shared_warp_tmp_[2][WARP_N * 16 / WARP_SIZE]; int cta_offset_m = blockIdx_m * CTA_M; int cta_offset_n = blockIdx_n * CTA_N; @@ -817,10 +1148,10 @@ __global__ void gemm_w4a16_T2(half *__restrict__ A, half *__restrict__ B, half * { int ld_stage = k_0_0_ld % STAGES; int compute_stage = k_0_0 % STAGES; - half *A_shared_this_compute_stage; - half *B_shared_this_compute_stage; - half *scales_shared_this_compute_stage; - half *zeros_shared_this_compute_stage; + T *A_shared_this_compute_stage; + T *B_shared_this_compute_stage; + T *scales_shared_this_compute_stage; + T *zeros_shared_this_compute_stage; for (int iter_k = 0; iter_k < SHARED_K_ITERS; ++iter_k) { @@ -864,14 +1195,30 @@ __global__ void gemm_w4a16_T2(half *__restrict__ A, half *__restrict__ B, half * } } __syncthreads(); - half *A_shared_warp = A_shared_warp_[iter_k % 2]; - half *B_shared_warp = B_shared_warp_[(iter_k / 2) % 2]; + T *A_shared_warp = A_shared_warp_[iter_k % 2]; + T *B_shared_warp = B_shared_warp_[(iter_k / 2) % 2]; for (int i_0_3 = 0; i_0_3 < WARP_M / INTRIN_M; ++i_0_3) { for (int j_0_4 = 0; j_0_4 < WARP_N / INTRIN_N; ++j_0_4) { - mma_m16n8k16(C_warp + i_0_3 * WARP_N / INTRIN_N * 8 + j_0_4 * 8, A_shared_warp + i_0_3 * 8, B_shared_warp + j_0_4 * 16 + (iter_k % 2) * 4); - mma_m16n8k16(C_warp + i_0_3 * WARP_N / INTRIN_N * 8 + j_0_4 * 8 + 4, A_shared_warp + i_0_3 * 8, B_shared_warp + j_0_4 * 16 + (iter_k % 2) * 4 + 8); + if constexpr (std::is_same::value) + { + if constexpr (UseFP32Accum) + { + mma_m16n8k16_f32f16f16f32(C_warp + i_0_3 * WARP_N / INTRIN_N * 8 + j_0_4 * 8, A_shared_warp + i_0_3 * 8, B_shared_warp + j_0_4 * 16 + (iter_k % 2) * 4); + mma_m16n8k16_f32f16f16f32(C_warp + i_0_3 * WARP_N / INTRIN_N * 8 + j_0_4 * 8 + 4, A_shared_warp + i_0_3 * 8, B_shared_warp + j_0_4 * 16 + (iter_k % 2) * 4 + 8); + } + else + { + mma_m16n8k16_f16f16f16(C_warp + i_0_3 * WARP_N / INTRIN_N * 8 + j_0_4 * 8, A_shared_warp + i_0_3 * 8, B_shared_warp + j_0_4 * 16 + (iter_k % 2) * 4); + mma_m16n8k16_f16f16f16(C_warp + i_0_3 * WARP_N / INTRIN_N * 8 + j_0_4 * 8 + 4, A_shared_warp + i_0_3 * 8, B_shared_warp + j_0_4 * 16 + (iter_k % 2) * 4 + 8); + } + } + else + { + mma_m16n8k16_bf16bf16f32(C_warp + i_0_3 * WARP_N / INTRIN_N * 8 + j_0_4 * 8, A_shared_warp + i_0_3 * 8, B_shared_warp + j_0_4 * 16 + (iter_k % 2) * 4); + mma_m16n8k16_bf16bf16f32(C_warp + i_0_3 * WARP_N / INTRIN_N * 8 + j_0_4 * 8 + 4, A_shared_warp + i_0_3 * 8, B_shared_warp + j_0_4 * 16 + (iter_k % 2) * 4 + 8); + } } } @@ -914,12 +1261,37 @@ __global__ void gemm_w4a16_T2(half *__restrict__ A, half *__restrict__ B, half * int write_row = cta_offset_m + warp_offset_m + ax0_0_1 * OP_M + ((local_id % 4) / 2 * 8 + (threadIdx.x / 4)); if (write_row < M) { - *reinterpret_cast( - C + write_row * N + - 
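The MMA dispatch above selects f32 or f16 accumulation for half inputs (bf16 always accumulates in f32). The motivation is accumulation error: a long dot product summed in fp16 drifts noticeably as K grows. A quick, hedged torch demonstration of the effect, emulating the two accumulator widths in plain Python rather than on the MMA path:

import torch

torch.manual_seed(0)
K = 4096
a = torch.randn(K, dtype=torch.float16)
x = torch.randn(K, dtype=torch.float16)

products = a * x                                   # per-element products, rounded to fp16
reference = torch.dot(a.double(), x.double()).item()

acc_fp16 = torch.tensor(0.0, dtype=torch.float16)
for p in products:                                 # running sum kept in fp16, like the f16 accumulator path
    acc_fp16 = acc_fp16 + p
acc_wide = products.double().sum().item()          # same products, accumulated in a wider type

print("fp16 accumulation error:", abs(float(acc_fp16) - reference))
print("wide accumulation error:", abs(acc_wide - reference))

The fp16-accumulated result is typically off by orders of magnitude more than the widened one, which is why the host exposes the use_fp32 switch below.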
cta_offset_n + warp_offset_n + ax1_0_1 * 16 + - (local_id / 4) * 8 + (local_id % 2) + (threadIdx.x % 4) * 2) = - __float22half2_rn(*reinterpret_cast(C_warp + ax0_0_1 * WARP_N / INTRIN_N * 8 + - ax1_0_1 * 8 + local_id)); + if constexpr (std::is_same::value) + { + if constexpr (UseFP32Accum) + { + float val0 = C_warp[ax0_0_1 * WARP_N / INTRIN_N * 8 + ax1_0_1 * 8 + local_id]; + float val1 = C_warp[ax0_0_1 * WARP_N / INTRIN_N * 8 + ax1_0_1 * 8 + local_id + 1]; + *reinterpret_cast( + C + write_row * N + + cta_offset_n + warp_offset_n + ax1_0_1 * 16 + + (local_id / 4) * 8 + (local_id % 2) + (threadIdx.x % 4) * 2) = + __floats2half2_rn(val0, val1); + } + else + { + *reinterpret_cast( + C + write_row * N + + cta_offset_n + warp_offset_n + ax1_0_1 * 16 + + (local_id / 4) * 8 + (local_id % 2) + (threadIdx.x % 4) * 2) = + (*reinterpret_cast(C_warp + ax0_0_1 * WARP_N / INTRIN_N * 8 + + ax1_0_1 * 8 + local_id)); + } + } + else + { + *reinterpret_cast( + C + write_row * N + + cta_offset_n + warp_offset_n + ax1_0_1 * 16 + + (local_id / 4) * 8 + (local_id % 2) + (threadIdx.x % 4) * 2) = + (__float22bfloat162_rn(*reinterpret_cast(C_warp + ax0_0_1 * WARP_N / INTRIN_N * 8 + + ax1_0_1 * 8 + local_id))); + } } }; } @@ -930,16 +1302,13 @@ torch::Tensor gemm_forward_cuda_prefill( torch::Tensor _in_feats, torch::Tensor _kernel, torch::Tensor _scales, - torch::Tensor _zeros) + torch::Tensor _zeros, + bool use_fp32) { std::vector output_shape = _in_feats.sizes().vec(); output_shape.back() = _kernel.size(0) * kInterleave; int num_in_feats = _in_feats.numel() / _in_feats.size(-1); int num_in_channels = _in_feats.size(-1); - auto in_feats = reinterpret_cast(_in_feats.data_ptr()); - auto kernel = reinterpret_cast(_kernel.data_ptr()); - auto scales = reinterpret_cast(_scales.data_ptr()); - auto zeros = reinterpret_cast(_zeros.data_ptr()); auto options = torch::TensorOptions().dtype(_in_feats.dtype()).device(_in_feats.device()); auto options_int = @@ -947,87 +1316,107 @@ torch::Tensor gemm_forward_cuda_prefill( at::Tensor _out_feats = torch::empty(output_shape, options); int num_out_feats = _out_feats.numel() / _out_feats.size(-1); int num_out_channels = _out_feats.size(-1); - auto out_feats = reinterpret_cast(_out_feats.data_ptr()); - if (num_out_feats <= 32) - { - constexpr int G = 128; - constexpr int CTA_M = 16; - constexpr int CTA_N = 128; - constexpr int CTA_K = 128; - constexpr int WARP_M = 16; - constexpr int WARP_N = 32; - constexpr int WARP_K = 64; - constexpr int SPLITK = 2; - constexpr int STAGES = 4; - KERNEL_LAUNCH_CODE - } - else if (num_out_feats <= 64) - { + auto data_type = _in_feats.scalar_type(); + TORCH_CHECK(_scales.scalar_type() == data_type); + TORCH_CHECK(_zeros.scalar_type() == data_type); - constexpr int G = 128; - constexpr int CTA_M = 16; - constexpr int CTA_N = 128; - constexpr int CTA_K = 128; - constexpr int WARP_M = 16; - constexpr int WARP_N = 32; - constexpr int WARP_K = 64; - constexpr int SPLITK = 1; - constexpr int STAGES = 3; - KERNEL_LAUNCH_CODE - } - else if (num_out_feats <= 128) - { - constexpr int G = 128; - constexpr int CTA_M = 32; - constexpr int CTA_N = 128; - constexpr int CTA_K = 128; - constexpr int WARP_M = 32; - constexpr int WARP_N = 32; - constexpr int WARP_K = 64; - constexpr int SPLITK = 1; - constexpr int STAGES = 4; - KERNEL_LAUNCH_CODE - } - else if (num_out_feats <= 192) - { - constexpr int G = 128; - constexpr int CTA_M = 64; - constexpr int CTA_N = 128; - constexpr int CTA_K = 64; - constexpr int WARP_M = 64; - constexpr int WARP_N = 32; - constexpr int 
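gemm_forward_cuda_prefill chooses a tile configuration from the number of output rows (num_out_feats); G is 128 on every path. The helper below is only a readable restatement of the thresholds and constants copied from the hunk, not part of the extension; the fallback branch launches without split-K, so SPLITK=1 there is shown for completeness.

def select_prefill_tile_config(num_out_feats: int) -> dict:
    """Mirror of the CTA/WARP/SPLITK/STAGES selection in gemm_forward_cuda_prefill."""
    if num_out_feats <= 32:
        return dict(CTA_M=16, CTA_N=128, CTA_K=128, WARP_M=16, WARP_N=32, WARP_K=64, SPLITK=2, STAGES=4)
    if num_out_feats <= 64:
        return dict(CTA_M=16, CTA_N=128, CTA_K=128, WARP_M=16, WARP_N=32, WARP_K=64, SPLITK=1, STAGES=3)
    if num_out_feats <= 128:
        return dict(CTA_M=32, CTA_N=128, CTA_K=128, WARP_M=32, WARP_N=32, WARP_K=64, SPLITK=1, STAGES=4)
    if num_out_feats <= 192:
        return dict(CTA_M=64, CTA_N=128, CTA_K=64, WARP_M=64, WARP_N=32, WARP_K=64, SPLITK=1, STAGES=4)
    # large-M fallback: launches the kernel directly, no split-K on this path
    return dict(CTA_M=64, CTA_N=128, CTA_K=64, WARP_M=64, WARP_N=32, WARP_K=64, SPLITK=1, STAGES=4)

assert select_prefill_tile_config(24)["SPLITK"] == 2
assert select_prefill_tile_config(300)["CTA_K"] == 64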
WARP_K = 64; - constexpr int SPLITK = 1; - constexpr int STAGES = 4; - KERNEL_LAUNCH_CODE - } - else - { - constexpr int G = 128; - constexpr int CTA_M = 64; - constexpr int CTA_N = 128; - constexpr int CTA_K = 64; - constexpr int WARP_M = 64; - constexpr int WARP_N = 32; - constexpr int WARP_K = 64; - constexpr int STAGES = 4; - - constexpr int NUM_WARPS = (CTA_M / WARP_M) * (CTA_N / WARP_N); - constexpr int kSmemByteSize = (CTA_M * (CTA_K + SMEM_PAD_A) + CTA_N * (CTA_K + SMEM_PAD_B) / kInterleave + CTA_N) * STAGES * sizeof(half); - if (kSmemByteSize >= 99 * 1024) + DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(data_type, ctype, { + auto in_feats = reinterpret_cast(_in_feats.data_ptr()); + auto kernel = reinterpret_cast(_kernel.data_ptr()); + auto scales = reinterpret_cast(_scales.data_ptr()); + auto zeros = reinterpret_cast(_zeros.data_ptr()); + auto out_feats = reinterpret_cast(_out_feats.data_ptr()); + + if (num_out_feats <= 32) { - printf("This kernel requires %d Bytes of shared memory, which exceeds device limit.\n", kSmemByteSize); - return _out_feats; + constexpr int G = 128; + constexpr int CTA_M = 16; + constexpr int CTA_N = 128; + constexpr int CTA_K = 128; + constexpr int WARP_M = 16; + constexpr int WARP_N = 32; + constexpr int WARP_K = 64; + constexpr int SPLITK = 2; + constexpr int STAGES = 4; + KERNEL_LAUNCH_CODE } - int j_factors1 = num_out_channels / CTA_N / 1; - dim3 num_blocks((num_out_feats + CTA_M - 1) / CTA_M * j_factors1); - dim3 threads_per_block(WARP_SIZE, NUM_WARPS); - auto kernel_func = gemm_w4a16_T2; - cudaFuncSetAttribute(kernel_func, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemByteSize); - kernel_func<<>>( - in_feats, kernel, scales, zeros, out_feats, num_in_feats, num_out_channels, num_in_channels); - } + else if (num_out_feats <= 64) + { + constexpr int G = 128; + constexpr int CTA_M = 16; + constexpr int CTA_N = 128; + constexpr int CTA_K = 128; + constexpr int WARP_M = 16; + constexpr int WARP_N = 32; + constexpr int WARP_K = 64; + constexpr int SPLITK = 1; + constexpr int STAGES = 3; + KERNEL_LAUNCH_CODE + } + else if (num_out_feats <= 128) + { + constexpr int G = 128; + constexpr int CTA_M = 32; + constexpr int CTA_N = 128; + constexpr int CTA_K = 128; + constexpr int WARP_M = 32; + constexpr int WARP_N = 32; + constexpr int WARP_K = 64; + constexpr int SPLITK = 1; + constexpr int STAGES = 4; + KERNEL_LAUNCH_CODE + } + else if (num_out_feats <= 192) + { + constexpr int G = 128; + constexpr int CTA_M = 64; + constexpr int CTA_N = 128; + constexpr int CTA_K = 64; + constexpr int WARP_M = 64; + constexpr int WARP_N = 32; + constexpr int WARP_K = 64; + constexpr int SPLITK = 1; + constexpr int STAGES = 4; + KERNEL_LAUNCH_CODE + } + else + { + constexpr int G = 128; + constexpr int CTA_M = 64; + constexpr int CTA_N = 128; + constexpr int CTA_K = 64; + constexpr int WARP_M = 64; + constexpr int WARP_N = 32; + constexpr int WARP_K = 64; + constexpr int STAGES = 4; + + constexpr int NUM_WARPS = (CTA_M / WARP_M) * (CTA_N / WARP_N); + constexpr int kSmemByteSize = (CTA_M * (CTA_K + SMEM_PAD_A) + CTA_N * (CTA_K + SMEM_PAD_B) / kInterleave + CTA_N) * STAGES * sizeof(ctype); + if (kSmemByteSize >= 99 * 1024) + { + printf("This kernel requires %d Bytes of shared memory, which exceeds device limit.\n", kSmemByteSize); + return _out_feats; + } + int j_factors1 = num_out_channels / CTA_N / 1; + dim3 num_blocks((num_out_feats + CTA_M - 1) / CTA_M * j_factors1); + dim3 threads_per_block(WARP_SIZE, NUM_WARPS); + if (use_fp32) + { + auto kernel_func = gemm_w4a16_T2; + 
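The large-M path computes a compile-time shared-memory budget (kSmemByteSize) and bails out if it exceeds 99 KiB. The Python helper below restates that formula; SMEM_PAD_A, SMEM_PAD_B and kInterleave are defined elsewhere in the file, so the values used here (a pad of 8 elements and interleave of 4) are assumptions for illustration only.

def smem_bytes(CTA_M, CTA_N, CTA_K, STAGES, elem_size,
               smem_pad_a=8, smem_pad_b=8, k_interleave=4):
    """kSmemByteSize from gemm_forward_cuda_prefill (pad/interleave values assumed)."""
    return (CTA_M * (CTA_K + smem_pad_a)
            + CTA_N * (CTA_K + smem_pad_b) // k_interleave
            + CTA_N) * STAGES * elem_size

# Default large-M tile: CTA_M=64, CTA_N=128, CTA_K=64, STAGES=4, 2-byte half/bf16 elements.
size = smem_bytes(64, 128, 64, 4, elem_size=2)
print(size, "bytes;", "fits" if size < 99 * 1024 else "exceeds the 99 KiB limit")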
cudaFuncSetAttribute(kernel_func, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemByteSize); + kernel_func<<>>( + in_feats, kernel, scales, zeros, out_feats, num_in_feats, num_out_channels, num_in_channels); + } + else + { + auto kernel_func = gemm_w4a16_T2; + cudaFuncSetAttribute(kernel_func, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemByteSize); + kernel_func<<>>( + in_feats, kernel, scales, zeros, out_feats, num_in_feats, num_out_channels, num_in_channels); + } + } + }); return _out_feats; -} \ No newline at end of file +} diff --git a/gptqmodel_ext/awq/quantization_new/gemm/gemm_cuda.h b/gptqmodel_ext/awq/quantization_new/gemm/gemm_cuda.h index 60c9ece40..ce142be14 100644 --- a/gptqmodel_ext/awq/quantization_new/gemm/gemm_cuda.h +++ b/gptqmodel_ext/awq/quantization_new/gemm/gemm_cuda.h @@ -1,3 +1,3 @@ #include -torch::Tensor gemm_forward_cuda_prefill(torch::Tensor _in_feats, torch::Tensor _kernel, torch::Tensor _scales, torch::Tensor _zeros); +torch::Tensor gemm_forward_cuda_prefill(torch::Tensor _in_feats, torch::Tensor _kernel, torch::Tensor _scales, torch::Tensor _zeros, bool use_fp32); diff --git a/gptqmodel_ext/awq/quantization_new/gemv/gemv_cuda.cu b/gptqmodel_ext/awq/quantization_new/gemv/gemv_cuda.cu index 78d12b49b..d8da3a696 100644 --- a/gptqmodel_ext/awq/quantization_new/gemv/gemv_cuda.cu +++ b/gptqmodel_ext/awq/quantization_new/gemv/gemv_cuda.cu @@ -25,45 +25,26 @@ */ #include +#include #include #include #include "gemv_cuda.h" #include "../dequantize.cuh" +#include "../dispatch_utils.cuh" #define PACK_FACTOR 8 #define WARP_SIZE 32 #define MEM_ACCESS_SIZE 128 - -static inline __device__ float to_float(half src) -{ - return __half2float(src); -} - -static inline __device__ float to_float(float src) -{ - return src; -} - -static inline __device__ half to_half(float src) -{ - return __float2half(src); -} - -static inline __device__ half to_half(half src) -{ - return src; -} - // Reduce sum within the warp using the tree reduction algorithm. 
-template -__device__ __forceinline__ static void warp_reduce(half* psum, float (*out_smem)[Num * 4]) +template +__device__ __forceinline__ static void warp_reduce(T* psum, float (*out_smem)[Num * 4]) { // kInterleave = 4 float fpsum[Num]; #pragma unroll for (int i = 0; i < Num; ++i) { - fpsum[i] = to_float(psum[i]); + fpsum[i] = static_cast(psum[i]); } #pragma unroll @@ -91,9 +72,9 @@ __device__ __forceinline__ int make_divisible(int c, int divisor){ return (c + divisor - 1) / divisor; } -template +template __global__ void gemv_kernel( - const half* inputs, const uint32_t* weight, const half* scales, const half* zeros, half* outputs, + const T* inputs, const uint32_t* weight, const T* scales, const T* zeros, T* outputs, const int IC, const int OC) { const int kStride = 64; @@ -101,6 +82,12 @@ __global__ void gemv_kernel( const int kThreadsNumPerTile = kStride / kElemsPerThread; // assert(MEM_ACCESS_SIZE == 128); + using T2 = typename std::conditional< + std::is_same::value, + half2, + nv_bfloat162 + >::type; + static constexpr int kShuffleSize = 32; static constexpr int kShuffleBasicTile = 2; static constexpr int kShuffleContinous = 4; @@ -109,19 +96,20 @@ __global__ void gemv_kernel( constexpr int Num = NPerBlock * Batch; constexpr int kInterleave = 4; - half local_inputs[kElemsPerThread]; + T local_inputs[kElemsPerThread]; uint32_t local_qweights[MEM_ACCESS_SIZE / 32]; - half half_weight_buffer[kElemsPerThread]; - half dequantized_weight[kElemsPerThread * NPerBlock]; - half local_scale[NPerBlock]; - half local_scaled_zeros[NPerBlock]; + T half_weight_buffer[kElemsPerThread]; + T dequantized_weight[kElemsPerThread * NPerBlock]; + T local_scale[NPerBlock]; + T local_scaled_zeros[NPerBlock]; - half psum[Num]; + T psum[Num]; for (int i = 0; i < Num; ++i) - psum[i] = to_half(0.f); + psum[i] = static_cast(0.f); - extern __shared__ uint8_t shmem[]; - float(*out_smem)[Num * kInterleave] = reinterpret_cast(shmem); + // extern __shared__ uint8_t shmem[]; + // float(*out_smem)[Num * kInterleave] = reinterpret_cast(shmem); + __shared__ float out_smem[BlockSize / WARP_SIZE * 2][Num * kInterleave]; const int blk_row_offset = blockIdx.x * NPerBlock * kInterleave; const int thd_row_offset = (threadIdx.x / kThreadsNumPerTile) % kInterleave; @@ -130,9 +118,9 @@ __global__ void gemv_kernel( const int group_offset = act_k_offset / GroupSize; // TODO: use make_divisible const uint32_t* blk_weight_ptr = weight + blk_row_offset * IC / PACK_FACTOR; - const half* scale_ptr = scales + blk_row_offset + thd_row_offset + group_offset * OC; - const half* zeros_ptr = zeros + blk_row_offset + thd_row_offset + group_offset * OC; - const half* inputs_ptr = inputs + act_k_offset; + const T* scale_ptr = scales + blk_row_offset + thd_row_offset + group_offset * OC; + const T* zeros_ptr = zeros + blk_row_offset + thd_row_offset + group_offset * OC; + const T* inputs_ptr = inputs + act_k_offset; const int act_forward_step = BlockSize * kElemsPerThread / kInterleave; const int scale_forward_step = act_forward_step / GroupSize * OC; @@ -155,7 +143,7 @@ __global__ void gemv_kernel( for (int i = 0; i < MEM_ACCESS_SIZE / 32; ++i) { // Converts 32 bits (8 x int4) to 8 fp16 - dequantize_s4_to_fp16x2(*reinterpret_cast(local_qweights + i), reinterpret_cast(half_weight_buffer + i * PACK_FACTOR)); + dequantize_s4_to_fp16x2(*reinterpret_cast(local_qweights + i), reinterpret_cast(half_weight_buffer + i * PACK_FACTOR)); } // Dequantize (apply s/z) and shuffle elements to match the weight packing format @@ -165,11 +153,18 @@ __global__ 
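The gemv kernel now declares out_smem as a fixed __shared__ array instead of sizing it through extern __shared__; its footprint follows from BlockSize, the warp count, and Num = NPerBlock * Batch. With the launch constants the host code uses later in this file (BLOCK_SIZE=256, N_PER_BLOCK=2, kInterleave=4), the static buffer stays tiny even at the largest dispatched batch, a rough check:

WARP_SIZE, BLOCK_SIZE, N_PER_BLOCK, K_INTERLEAVE = 32, 256, 2, 4

def out_smem_bytes(batch: int) -> int:
    """float out_smem[BlockSize/WARP_SIZE * 2][Num * kInterleave], Num = NPerBlock * Batch."""
    rows = BLOCK_SIZE // WARP_SIZE * 2
    cols = N_PER_BLOCK * batch * K_INTERLEAVE
    return rows * cols * 4  # float32 elements

for batch in range(1, 8):  # the host switch dispatches batch sizes 1..7
    print(batch, out_smem_bytes(batch), "bytes")
# batch=7 -> 16 * 56 * 4 = 3584 bytes, well under any shared-memory limit.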
void gemv_kernel( #pragma unroll for (int j = 0; j < kShuffleStrided; ++j) { - half2 w = - *reinterpret_cast( + T2 w = + *reinterpret_cast( half_weight_buffer + (i + j * kShuffleContinous)* kShuffleBasicTile ); - w = __hfma2(w, __half2half2(local_scale[idx]), __half2half2(local_scaled_zeros[idx])); + if constexpr (std::is_same::value) + { + w = __hfma2(w, __half2half2(local_scale[idx]), __half2half2(local_scaled_zeros[idx])); + } + else + { + w = __hfma2(w, __bfloat162bfloat162(local_scale[idx]), __bfloat162bfloat162(local_scaled_zeros[idx])); + } dequantized_weight[((i * kShuffleStrided + j) * kShuffleBasicTile + 0) * NPerBlock + idx] = w.x; @@ -182,7 +177,7 @@ __global__ void gemv_kernel( #pragma unroll for (int batch_idx = 0; batch_idx < Batch; ++batch_idx) { - const half* local_inputs_ptr = inputs_ptr + batch_idx * IC; + const T* local_inputs_ptr = inputs_ptr + batch_idx * IC; #pragma unroll for (int idx = 0; idx < kElemsPerThread / 8; ++idx) { @@ -196,10 +191,20 @@ __global__ void gemv_kernel( #pragma unroll for (int y = 0; y < kElemsPerThread; ++y) { - *reinterpret_cast(psum + batch_idx * NPerBlock + x * 2) - = __hfma2(*reinterpret_cast(dequantized_weight + y * NPerBlock + x * 2), - __half2half2(local_inputs[y]), - *reinterpret_cast(psum + batch_idx * NPerBlock + x * 2)); + if constexpr (std::is_same::value) + { + *reinterpret_cast(psum + batch_idx * NPerBlock + x * 2) + = __hfma2(*reinterpret_cast(dequantized_weight + y * NPerBlock + x * 2), + __half2half2(local_inputs[y]), + *reinterpret_cast(psum + batch_idx * NPerBlock + x * 2)); + } + else + { + *reinterpret_cast(psum + batch_idx * NPerBlock + x * 2) + = __hfma2(*reinterpret_cast(dequantized_weight + y * NPerBlock + x * 2), + __bfloat162bfloat162(local_inputs[y]), + *reinterpret_cast(psum + batch_idx * NPerBlock + x * 2)); + } } } } @@ -220,7 +225,7 @@ __global__ void gemv_kernel( { acc += out_smem[j][i]; } - outputs[batch_idx * OC + blk_row_offset + oc_idx] = to_half(acc); + outputs[batch_idx * OC + blk_row_offset + oc_idx] = static_cast(acc); } } @@ -252,78 +257,83 @@ torch::Tensor gemv_forward_cuda_decode( std::vector output_shape = _in_feats.sizes().vec(); output_shape.back() = n; - auto in_feats = reinterpret_cast(_in_feats.data_ptr()); - auto kernel = reinterpret_cast(_kernel.data_ptr()); - auto zeros = reinterpret_cast(_zeros.data_ptr()); - auto scaling_factors = reinterpret_cast(_scaling_factors.data_ptr()); + auto data_type = _in_feats.scalar_type(); + TORCH_CHECK(_scaling_factors.scalar_type() == data_type); + TORCH_CHECK(_zeros.scalar_type() == data_type); auto options = torch::TensorOptions().dtype(_in_feats.dtype()).device(_in_feats.device()); at::Tensor _out_feats = torch::empty(output_shape, options); - half * out_feats = reinterpret_cast(_out_feats.data_ptr()); - - static constexpr int N_PER_BLOCK = 2; - static constexpr int K_INTERLEAVE = 4; - static constexpr int BLOCK_SIZE = 256; - dim3 num_blocks(n / N_PER_BLOCK / K_INTERLEAVE); - dim3 num_threads(BLOCK_SIZE); + DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(data_type, ctype, { + auto in_feats = reinterpret_cast(_in_feats.data_ptr()); + auto kernel = reinterpret_cast(_kernel.data_ptr()); + auto zeros = reinterpret_cast(_zeros.data_ptr()); + auto scaling_factors = reinterpret_cast(_scaling_factors.data_ptr()); + auto out_feats = reinterpret_cast(_out_feats.data_ptr()); + + static constexpr int N_PER_BLOCK = 2; + static constexpr int K_INTERLEAVE = 4; + static constexpr int BLOCK_SIZE = 256; - // if (group_size == 64) - // { - // gemv_kernel_g64<<>>( - // // pointers - 
// in_feats, kernel, zeros, scaling_factors, out_feats, - // // constants - // num_in_channels, num_out_channels - // ); - // } - if (group_size == 128) - { - switch (m) + dim3 num_blocks(n / N_PER_BLOCK / K_INTERLEAVE); + dim3 num_threads(BLOCK_SIZE); + + // if (group_size == 64) + // { + // gemv_kernel_g64<<>>( + // // pointers + // in_feats, kernel, zeros, scaling_factors, out_feats, + // // constants + // num_in_channels, num_out_channels + // ); + // } + if (group_size == 128) + { + switch (m) + { + case 1: + gemv_kernel<<>>( + in_feats, kernel, scaling_factors, zeros, out_feats, k, n + ); + break; + case 2: + gemv_kernel<<>>( + in_feats, kernel, scaling_factors, zeros, out_feats, k, n + ); + break; + case 3: + gemv_kernel<<>>( + in_feats, kernel, scaling_factors, zeros, out_feats, k, n + ); + break; + case 4: + gemv_kernel<<>>( + in_feats, kernel, scaling_factors, zeros, out_feats, k, n + ); + break; + case 5: + gemv_kernel<<>>( + in_feats, kernel, scaling_factors, zeros, out_feats, k, n + ); + break; + case 6: + gemv_kernel<<>>( + in_feats, kernel, scaling_factors, zeros, out_feats, k, n + ); + break; + case 7: + gemv_kernel<<>>( + in_feats, kernel, scaling_factors, zeros, out_feats, k, n + ); + break; + default: + throw std::runtime_error("Unsupported batch size for gemv kernel.\n"); + } + } + else { - case 1: - gemv_kernel<<>>( - in_feats, kernel, scaling_factors, zeros, out_feats, k, n - ); - break; - case 2: - gemv_kernel<<>>( - in_feats, kernel, scaling_factors, zeros, out_feats, k, n - ); - break; - case 3: - gemv_kernel<<>>( - in_feats, kernel, scaling_factors, zeros, out_feats, k, n - ); - break; - case 4: - gemv_kernel<<>>( - in_feats, kernel, scaling_factors, zeros, out_feats, k, n - ); - break; - case 5: - gemv_kernel<<>>( - in_feats, kernel, scaling_factors, zeros, out_feats, k, n - ); - break; - case 6: - gemv_kernel<<>>( - in_feats, kernel, scaling_factors, zeros, out_feats, k, n - ); - break; - case 7: - gemv_kernel<<>>( - in_feats, kernel, scaling_factors, zeros, out_feats, k, n - ); - break; - default: - throw std::runtime_error("Unsupported batch size for gemv kernel.\n"); + throw std::runtime_error("Unsupported group size for gemv kernel.\n"); } - } - else - { - throw std::runtime_error("Unsupported group size for gemv kernel.\n"); - } + }); return _out_feats; } - diff --git a/setup.py b/setup.py index 10c8ac90b..931cd0e58 100644 --- a/setup.py +++ b/setup.py @@ -750,22 +750,8 @@ def _hipify_compile_flags(flags): print("Skipping AWQ kernels on ROCm: inline PTX is CUDA-only.") else: extensions += [ - # contain un-hipifiable inline PTX cpp_ext.CUDAExtension( "gptqmodel_awq_kernels", - [ - "gptqmodel_ext/awq/pybind_awq.cpp", - "gptqmodel_ext/awq/quantization/gemm_cuda_gen.cu", - "gptqmodel_ext/awq/quantization/gemv_cuda.cu", - ], - extra_link_args=extra_link_args, - extra_compile_args=extra_compile_args, - ), - # TODO only compatible with ampere? 
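The gemv_forward_cuda_decode dispatch above only supports group_size 128 and batch sizes 1 through 7, and launches n / N_PER_BLOCK / K_INTERLEAVE blocks of 256 threads. A compact Python restatement of that launch plan; the assumption that the switch instantiates the kernel with Batch equal to m is inferred from the error message and the kernel's template, not shown verbatim in the hunk.

def gemv_decode_launch_plan(m: int, n: int, group_size: int) -> dict:
    """Grid/threads and per-batch instantiation chosen by gemv_forward_cuda_decode."""
    N_PER_BLOCK, K_INTERLEAVE, BLOCK_SIZE = 2, 4, 256
    if group_size != 128:
        raise RuntimeError("Unsupported group size for gemv kernel.")
    if not 1 <= m <= 7:
        raise RuntimeError("Unsupported batch size for gemv kernel.")
    return {
        "num_blocks": n // N_PER_BLOCK // K_INTERLEAVE,
        "num_threads": BLOCK_SIZE,
        "Batch": m,             # inferred template parameter of gemv_kernel
        "NPerBlock": N_PER_BLOCK,
        "GroupSize": group_size,
    }

print(gemv_decode_launch_plan(m=3, n=4096, group_size=128))
# {'num_blocks': 512, 'num_threads': 256, 'Batch': 3, 'NPerBlock': 2, 'GroupSize': 128}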
- # arch_flags = get_compute_capabilities({80, 86, 89, 90}) - # extra_compile_args_v2 = get_extra_compile_args(arch_flags, generator_flags) - cpp_ext.CUDAExtension( - "gptqmodel_awq_v2_kernels", [ "gptqmodel_ext/awq/pybind_awq_v2.cpp", "gptqmodel_ext/awq/quantization_new/gemv/gemv_cuda.cu", diff --git a/tests/test_awq_layout_conversion.py b/tests/test_awq_layout_conversion.py new file mode 100644 index 000000000..c23112320 --- /dev/null +++ b/tests/test_awq_layout_conversion.py @@ -0,0 +1,206 @@ +# SPDX-License-Identifier: Apache-2.0 + +from pathlib import Path + +import torch +import pytest + +from gptqmodel.nn_modules.qlinear.awq_gemm import AwqGEMMQuantLinear +from gptqmodel.quantization.awq.utils.mit_repacker import ( + multiply_scale_qzero_negative as mit_multiply_scale_qzero_negative, + packing_v2_from_unpacked as mit_packing_v2_from_unpacked, + qweight_unpack as mit_qweight_unpack, +) +from gptqmodel.quantization.awq.utils.module import try_import +from gptqmodel.quantization.awq.utils.packing_utils import dequantize_gemm +from gptqmodel.quantization.awq.modules.linear.gemv_fast import calculate_zeros_width +from safetensors.torch import load_file + + +def _pack_weights_v1(intweight: torch.Tensor, bits: int) -> torch.Tensor: + pack_num = 32 // bits + order_map = list(range(pack_num)) + rows, cols = intweight.shape + packed_cols = (cols + pack_num - 1) // pack_num + packed = torch.zeros((rows, packed_cols), dtype=torch.int32) + mask = (1 << bits) - 1 + for col in range(packed_cols): + for idx, order in enumerate(order_map): + src = col * pack_num + order + if src >= cols: + continue + packed[:, col] |= ((intweight[:, src].to(torch.int32) & mask) << (idx * bits)) + return packed + + +def _pack_zeros_v1(zeros: torch.Tensor, bits: int) -> torch.Tensor: + pack_num = 32 // bits + order_map = list(range(pack_num)) + rows, cols = zeros.shape + packed_cols = (cols + pack_num - 1) // pack_num + packed = torch.zeros((rows, packed_cols), dtype=torch.int32) + mask = (1 << bits) - 1 + for col in range(packed_cols): + for idx, order in enumerate(order_map): + src = col * pack_num + order + if src >= cols: + continue + packed[:, col] |= ((zeros[:, src].to(torch.int32) & mask) << (idx * bits)) + return packed + + +def test_awq_gemm_legacy_conversion_matches_v2(): + torch.manual_seed(0) + bits = 4 + group_size = 128 + in_features = 256 + out_features = 128 + + groups = in_features // group_size + assert groups * group_size == in_features + + intweight = torch.randint(0, 2 ** bits, size=(out_features, in_features), dtype=torch.int32) + scales = torch.rand(groups, out_features, dtype=torch.float16) * 2.0 + 0.5 + zeros = torch.randint(0, 2 ** bits, size=(groups, out_features), dtype=torch.int32) + intweight_t = intweight.t().contiguous() + qweight_v1 = _pack_weights_v1(intweight_t, bits) + qzeros_v1 = _pack_zeros_v1(zeros, bits) + scales_v1 = scales.clone() + + module = AwqGEMMQuantLinear( + bits=bits, + group_size=group_size, + sym=True, + desc_act=False, + in_features=in_features, + out_features=out_features, + bias=True, + register_buffers=False, + ) + module.load_legacy_tensors( + qweight_v1.clone(), + qzeros_v1.clone(), + scales_v1.clone(), + torch.zeros(out_features, dtype=torch.float16), + ) + + unpacked_expected = mit_qweight_unpack(qweight_v1.clone()) + if unpacked_expected.shape[0] == in_features and unpacked_expected.shape[1] == out_features: + unpacked_expected = unpacked_expected.transpose(0, 1).contiguous() + expected_qweight = mit_packing_v2_from_unpacked(unpacked_expected, interleave=4, 
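_pack_weights_v1 in the new test packs eight 4-bit values into each int32 column, least-significant nibble first. A small round-trip check that unpacking recovers the original integers; the unpack helper here is written only for illustration and is not the mit_qweight_unpack the test compares against.

import torch

def unpack_v1(packed: torch.Tensor, bits: int = 4) -> torch.Tensor:
    """Inverse of _pack_weights_v1: expand each int32 column back into 32//bits values."""
    pack_num = 32 // bits
    mask = (1 << bits) - 1
    cols = [(packed >> (idx * bits)) & mask for idx in range(pack_num)]
    return torch.stack(cols, dim=-1).reshape(packed.shape[0], -1)

ints = torch.randint(0, 16, (8, 32), dtype=torch.int32)
packed = _pack_weights_v1(ints, bits=4)      # helper defined in this test file
assert packed.shape == (8, 32 // 8)          # 8 nibbles per int32 column
torch.testing.assert_close(unpack_v1(packed), ints)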
kstride=64) + pack_num = 32 // bits + zeros_width = calculate_zeros_width(in_features, group_size, pack_num=pack_num) + expected_scales = torch.zeros((zeros_width * pack_num, out_features), dtype=torch.float16) + expected_scales[: scales.shape[0], :] = scales + + def reference_scaled_zeros(scales_ref: torch.Tensor, qzeros_ref: torch.Tensor, shift: int = 0) -> torch.Tensor: + pack_size = 8 + rows, cols = scales_ref.shape + qzeros_ref = qzeros_ref.to(torch.int32) + col_indices = torch.arange(cols, device=scales_ref.device, dtype=torch.int32) + zero_idx = col_indices // pack_size + zero_offset = col_indices % pack_size + zeros = (qzeros_ref[:, zero_idx] >> (zero_offset * 4)) & 0xF + zeros = zeros.to(scales_ref.dtype) + scaled = scales_ref * zeros + if shift: + scaled = scaled + shift * scales_ref + return -scaled + + expected_qzeros = torch.zeros_like(expected_scales) + expected_qzeros[: scales.shape[0], :] = reference_scaled_zeros(scales, qzeros_v1) + expected_bias = torch.zeros(out_features, dtype=torch.float16) + + assert module.qweight.dtype == torch.int16 + torch.testing.assert_close(module.qweight.to(torch.int32), expected_qweight.to(torch.int32)) + torch.testing.assert_close(module.scales.to(torch.float32), expected_scales.to(torch.float32)) + torch.testing.assert_close(module.qzeros.to(torch.float32), expected_qzeros.to(torch.float32)) + torch.testing.assert_close(module.bias.to(torch.float16), expected_bias.to(torch.float16)) + + awq_ext, _ = try_import("gptqmodel_awq_kernels") + if awq_ext is None or not torch.cuda.is_available(): + pytest.skip("AWQ CUDA kernels unavailable for forward validation") + + module = module.to("cuda") + inputs = torch.randn(2, 8, in_features, dtype=torch.float16, device="cuda") + + def dense_weight_from_v1(intweight_ref: torch.Tensor, zeros_ref: torch.Tensor, scales_ref: torch.Tensor) -> torch.Tensor: + intweight_ref = intweight_ref.to(torch.float32) + zeros_ref = zeros_ref.to(torch.float32) + scales_ref = scales_ref.to(torch.float32) + weight = torch.empty(in_features, out_features, dtype=torch.float32) + for group_idx in range(groups): + start = group_idx * group_size + end = start + group_size + block = intweight_ref[:, start:end].T # (group, out_features) + zero = zeros_ref[group_idx].unsqueeze(0) + scale = scales_ref[group_idx].unsqueeze(0) + weight[start:end] = (block - zero) * scale + return weight + + weight_ref = dense_weight_from_v1(intweight, zeros, scales).to(device="cuda", dtype=torch.float16) + with torch.no_grad(): + expected_out = torch.matmul(inputs.reshape(-1, in_features), weight_ref) + expected_out = expected_out.view(inputs.shape[0], inputs.shape[1], out_features) + actual_out = module(inputs) + torch.testing.assert_close(actual_out, expected_out, atol=5e-1, rtol=5e-3) + + +@pytest.mark.cuda +def test_awq_gemm_conversion_real_checkpoint(): + ckpt_dir = Path("/monster/data/model/deepseek-r1-distill-qwen-7b-awq") + safetensor_path = ckpt_dir / "model-00001-of-00002.safetensors" + if not safetensor_path.exists(): + pytest.skip("DeepSeek AWQ checkpoint unavailable at expected path.") + + tensors = load_file(str(safetensor_path)) + tensor_prefix = "model.layers.0.self_attn.q_proj" + qweight = tensors[f"{tensor_prefix}.qweight"] + qzeros = tensors[f"{tensor_prefix}.qzeros"] + scales = tensors[f"{tensor_prefix}.scales"] + + bits = 4 + group_size = 128 + groups = scales.shape[0] + out_features = scales.shape[1] + in_features = groups * group_size + + module = AwqGEMMQuantLinear( + bits=bits, + group_size=group_size, + sym=True, + 
desc_act=False, + in_features=in_features, + out_features=out_features, + bias=False, + register_buffers=False, + ) + module.load_legacy_tensors(qweight, qzeros, scales, bias=None) + + pack_num = 32 // bits + zeros_width = calculate_zeros_width(in_features, group_size, pack_num=pack_num) + + unpacked = mit_qweight_unpack(qweight) + if unpacked.shape == (in_features, out_features): + unpacked = unpacked.transpose(0, 1).contiguous() + expected_qweight = mit_packing_v2_from_unpacked(unpacked, interleave=4, kstride=64) + + scales_groups = scales if scales.shape == (groups, out_features) else scales.transpose(0, 1).contiguous() + expected_zero_cols = out_features // pack_num + qzeros_groups = qzeros + if qzeros_groups.shape == (out_features, expected_zero_cols): + qzeros_groups = qzeros_groups.transpose(0, 1).contiguous() + elif qzeros_groups.shape == (expected_zero_cols, out_features): + qzeros_groups = qzeros_groups.transpose(0, 1).contiguous() + + scaled_zeros_groups = mit_multiply_scale_qzero_negative(scales_groups, qzeros_groups, zp_shift=0) + + padded_rows = zeros_width * pack_num + expected_scales = torch.zeros((padded_rows, out_features), dtype=scales_groups.dtype) + expected_zeros = torch.zeros_like(expected_scales) + expected_scales[: groups, :] = scales_groups + expected_zeros[: groups, :] = scaled_zeros_groups + + torch.testing.assert_close(module.qweight.to(torch.int32), expected_qweight.to(torch.int32)) + torch.testing.assert_close(module.scales.to(torch.float32), expected_scales.to(torch.float32)) + torch.testing.assert_close(module.qzeros.to(torch.float32), expected_zeros.to(torch.float32)) diff --git a/tests/test_kernel_output.py b/tests/test_kernel_output.py index 0469f2763..989013b50 100644 --- a/tests/test_kernel_output.py +++ b/tests/test_kernel_output.py @@ -231,16 +231,18 @@ def _maybe_skip_backend(self, backend: BACKEND): self.skipTest("Machete requires NVIDIA Hopper or newer (SM90+)") float16_cases = [ - (BACKEND.TORCH, torch.float16, 0.0000), - (BACKEND.TRITON, torch.float16, 0.00001), - (BACKEND.EXLLAMA_V2, torch.float16, 0.0068), - (BACKEND.MACHETE, torch.float16, 0.00040), - (BACKEND.MARLIN, torch.float16, 0.00035), - (BACKEND.BITBLAS, torch.float16, 0.0035), + ("torch_fp16", BACKEND.TORCH, torch.float16, 0.0000), + ("triton_fp16", BACKEND.TRITON, torch.float16, 0.00001), + ("exllamav2_fp16", BACKEND.EXLLAMA_V2, torch.float16, 0.0068), + ("machete_fp16", BACKEND.MACHETE, torch.float16, 0.00040), + ("marlin_fp16", BACKEND.MARLIN, torch.float16, 0.00035), + ("bitblas_fp16", BACKEND.BITBLAS, torch.float16, 0.0035), ] @parameterized.expand(float16_cases) - def test_kernel_float16(self, backend: BACKEND, dtype: torch.dtype, a_tolerance: float): + def test_kernel_float16( + self, _case_name: str, backend: BACKEND, dtype: torch.dtype, a_tolerance: float + ): self._maybe_skip_backend(backend) data = self.data[dtype] @@ -257,16 +259,18 @@ def test_kernel_float16(self, backend: BACKEND, dtype: torch.dtype, a_tolerance ) bfloat16_cases = [ - (BACKEND.TORCH, torch.bfloat16, 0.0000), - (BACKEND.TRITON, torch.bfloat16, 0.00001), - (BACKEND.EXLLAMA_V2, torch.bfloat16, 0.0054), - (BACKEND.MACHETE, torch.bfloat16, 0.0033), - (BACKEND.MARLIN, torch.bfloat16, 0.0031), - (BACKEND.BITBLAS, torch.bfloat16, 0.0031), + ("torch_bf16", BACKEND.TORCH, torch.bfloat16, 0.0000), + ("triton_bf16", BACKEND.TRITON, torch.bfloat16, 0.00001), + ("exllamav2_bf16", BACKEND.EXLLAMA_V2, torch.bfloat16, 0.0054), + ("machete_bf16", BACKEND.MACHETE, torch.bfloat16, 0.0033), + ("marlin_bf16", BACKEND.MARLIN, 
torch.bfloat16, 0.0031), + ("bitblas_bf16", BACKEND.BITBLAS, torch.bfloat16, 0.0031), ] @parameterized.expand(bfloat16_cases) - def test_kernel_bfloat16(self, backend: BACKEND, dtype: torch.dtype, a_tolerance: float): + def test_kernel_bfloat16( + self, _case_name: str, backend: BACKEND, dtype: torch.dtype, a_tolerance: float + ): self._maybe_skip_backend(backend) data = self.data[dtype] @@ -283,16 +287,18 @@ def test_kernel_bfloat16(self, backend: BACKEND, dtype: torch.dtype, a_tolerance ) float16_lora_cases = [ - (BACKEND.TORCH, torch.float16, 0.0000), - (BACKEND.TRITON, torch.float16, 0.00001), - (BACKEND.EXLLAMA_V2, torch.float16, 0.0065), - (BACKEND.MACHETE, torch.float16, 0.00040), - (BACKEND.MARLIN, torch.float16, 0.00035), - (BACKEND.BITBLAS, torch.float16, 0.00035), + ("torch_fp16", BACKEND.TORCH, torch.float16, 0.0000), + ("triton_fp16", BACKEND.TRITON, torch.float16, 0.00001), + ("exllamav2_fp16", BACKEND.EXLLAMA_V2, torch.float16, 0.0065), + ("machete_fp16", BACKEND.MACHETE, torch.float16, 0.00040), + ("marlin_fp16", BACKEND.MARLIN, torch.float16, 0.00035), + ("bitblas_fp16", BACKEND.BITBLAS, torch.float16, 0.00035), ] @parameterized.expand(float16_lora_cases) - def test_kernel_float16_with_lora(self, backend: BACKEND, dtype: torch.dtype, a_tolerance: float): + def test_kernel_float16_with_lora( + self, _case_name: str, backend: BACKEND, dtype: torch.dtype, a_tolerance: float + ): self._maybe_skip_backend(backend) data = self.data[dtype] @@ -308,16 +314,18 @@ def test_kernel_float16_with_lora(self, backend: BACKEND, dtype: torch.dtype, a_ ) bfloat16_lora_cases = [ - (BACKEND.TORCH, torch.bfloat16, 0.0000), - (BACKEND.TRITON, torch.bfloat16, 0.00001), - (BACKEND.EXLLAMA_V2, torch.bfloat16, 0.0059), - (BACKEND.MACHETE, torch.bfloat16, 0.0033), - (BACKEND.MARLIN, torch.bfloat16, 0.0050), - (BACKEND.BITBLAS, torch.bfloat16, 0.0033), + ("torch_bf16", BACKEND.TORCH, torch.bfloat16, 0.0000), + ("triton_bf16", BACKEND.TRITON, torch.bfloat16, 0.00001), + ("exllamav2_bf16", BACKEND.EXLLAMA_V2, torch.bfloat16, 0.0059), + ("machete_bf16", BACKEND.MACHETE, torch.bfloat16, 0.0033), + ("marlin_bf16", BACKEND.MARLIN, torch.bfloat16, 0.0050), + ("bitblas_bf16", BACKEND.BITBLAS, torch.bfloat16, 0.0033), ] @parameterized.expand(bfloat16_lora_cases) - def test_kernel_bfloat16_with_lora(self, backend: BACKEND, dtype: torch.dtype, a_tolerance: float): + def test_kernel_bfloat16_with_lora( + self, _case_name: str, backend: BACKEND, dtype: torch.dtype, a_tolerance: float + ): self._maybe_skip_backend(backend) data = self.data[dtype] diff --git a/tests/test_kernel_output_awq.py b/tests/test_kernel_output_awq.py index 7b3c0a5bf..74e4373e7 100644 --- a/tests/test_kernel_output_awq.py +++ b/tests/test_kernel_output_awq.py @@ -36,21 +36,33 @@ RESET = "\033[0m" +def _reorder_packed_to_awq_order(packed: torch.Tensor, bits: int) -> torch.Tensor: + if bits != 4: + return packed + order = [0, 2, 4, 6, 1, 3, 5, 7] + mask = (1 << bits) - 1 + result = torch.zeros_like(packed) + for dst, src in enumerate(order): + nib = (packed >> (src * bits)) & mask + result |= nib << (dst * bits) + return result + + class TestAwqKernelOutput(unittest.TestCase): MODEL_PATH = Path("/monster/data/model/deepseek-r1-distill-qwen-7b-awq") TARGET = "model.layers.20.self_attn.v_proj" BITS = 4 GROUP_SIZE = 128 - SUPPORTED_DTYPES = (torch.float16,) + SUPPORTED_DTYPES = (torch.float16, torch.bfloat16) baseline_backend = BACKEND.TORCH_AWQ backend_cases = [ - (baseline_backend, torch.float16, 0.0), - # (baseline_backend, torch.bfloat16, 
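_reorder_packed_to_awq_order permutes the eight nibbles inside each packed int32 from sequential order into AWQ's interleaved [0, 2, 4, 6, 1, 3, 5, 7] order. A quick check of what it does to a word whose nibble i holds the value i, so the result is easy to read, plus an inverse permutation written only for illustration:

import torch

# A single int32 whose nibbles are 0,1,2,...,7 (nibble i holds value i).
word = torch.tensor([sum(i << (4 * i) for i in range(8))], dtype=torch.int32)

reordered = _reorder_packed_to_awq_order(word, 4)
nibbles = [(int(reordered[0]) >> (4 * d)) & 0xF for d in range(8)]
print(nibbles)  # [0, 2, 4, 6, 1, 3, 5, 7]: destination nibble d now holds source nibble order[d]

def _reorder_back(packed: torch.Tensor, bits: int = 4) -> torch.Tensor:
    """Illustrative inverse: map AWQ nibble order back to sequential order."""
    order = [0, 2, 4, 6, 1, 3, 5, 7]
    mask = (1 << bits) - 1
    out = torch.zeros_like(packed)
    for dst, src in enumerate(order):
        out |= ((packed >> (dst * bits)) & mask) << (src * bits)
    return out

torch.testing.assert_close(_reorder_back(reordered), word)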
0.0), - (BACKEND.GEMM, torch.float16, 0.001), - # (BACKEND.GEMM, torch.bfloat16, 0.05), - (BACKEND.MARLIN, torch.float16, 0.01), - # (BACKEND.MARLIN, torch.bfloat16, 0.05), + ("torch_awq_fp16", baseline_backend, torch.float16, 0.0), + ("torch_awq_bf16", baseline_backend, torch.bfloat16, 0.0), + ("gemm_fp16", BACKEND.GEMM, torch.float16, 0.0039), + ("gemm_bf16", BACKEND.GEMM, torch.bfloat16, 0.016), + ("marlin_fp16", BACKEND.MARLIN, torch.float16, 0.016), + ("marlin_bf16", BACKEND.MARLIN, torch.bfloat16, 0.016), ] @classmethod @@ -62,6 +74,7 @@ def setUpClass(cls) -> None: cls.log = log cls._weight_map = cls._load_weight_map() cls.backend_skip_reason: Dict[BACKEND, str] = {} + cls._forward_kwargs: Dict[torch.dtype, Dict[str, torch.dtype]] = {} try: tensors = cls._load_awq_tensors(cls.TARGET) @@ -112,6 +125,7 @@ def setUpClass(cls) -> None: "compute_dtype": torch.float16, "output_dtype": dtype, } + cls._forward_kwargs[dtype] = forward_kwargs cls.reference_outputs[dtype] = cls._forward( torch_module, converted_inputs, @@ -165,16 +179,17 @@ def _build_gemm_module( out_features=cls.out_features, bias=True, adapter=None, - register_buffers=True, + register_buffers=False, ).to(cls.device) - module.qweight.copy_(qweight_cpu.to(cls.device)) - module.qzeros.copy_(qzeros_cpu.to(cls.device)) - module.scales.copy_(scales_cpu.to(cls.device)) - module.bias.copy_(bias_cpu.to(cls.device)) + module.load_legacy_tensors( + qweight_cpu.to(cls.device), + qzeros_cpu.to(cls.device), + scales_cpu.to(cls.device), + bias_cpu.to(cls.device), + ) module.eval() - module.post_init() return module @classmethod @@ -209,8 +224,11 @@ def _build_marlin_module( register_buffers=True, ).to(cls.device) - module.qweight.data.copy_(qweight_cpu.to(cls.device)) - module.qzeros.data.copy_(qzeros_cpu.to(cls.device)) + qweight_reordered = _reorder_packed_to_awq_order(qweight_cpu, cls.BITS) + qzeros_reordered = _reorder_packed_to_awq_order(qzeros_cpu, cls.BITS) + + module.qweight.data.copy_(qweight_reordered.to(cls.device)) + module.qzeros.data.copy_(qzeros_reordered.to(cls.device)) module.scales.data.copy_(scales_cpu.to(torch.float16).to(cls.device)) module.bias.data.copy_(bias_cpu.to(torch.float16).to(cls.device)) @@ -235,49 +253,71 @@ def _build_torch_awq_module( out_features=cls.out_features, bias=True, adapter=None, - register_buffers=True, + register_buffers=False, ).to(cls.device) - module.qweight.copy_(qweight_cpu.to(cls.device)) - module.qzeros.copy_(qzeros_cpu.to(cls.device)) - module.scales.copy_(scales_cpu.to(cls.device)) - module.bias.copy_(bias_cpu.to(cls.device)) + module.load_legacy_tensors( + qweight_cpu.to(cls.device), + qzeros_cpu.to(cls.device), + scales_cpu.to(cls.device), + bias_cpu.to(cls.device), + ) module.eval() - module.post_init() return module @classmethod - def _generate_inputs(cls) -> List[torch.Tensor]: - large_shapes = [(4, 32), (2, 64), (1, 96)] - medium_shapes = [(2, 32), (1, 48), (1, 32)] - small_shapes = [(1, 32), (1, 24), (1, 16)] + def _parse_shapes(cls, expr: str) -> List[Tuple[int, int]]: + shapes: List[Tuple[int, int]] = [] + for part in expr.split(","): + part = part.strip() + if not part: + continue + dim_str, samples_str = part.split(":", 1) + shapes.append((int(dim_str), int(samples_str))) + return shapes - try: - total_mem_gb = ( - torch.cuda.get_device_properties(cls.device).total_memory - / (1024 ** 3) - ) - except Exception: # pragma: no cover - total_mem_gb = 0.0 + @classmethod + def _generate_inputs(cls) -> List[torch.Tensor]: + large_shapes = [(1, 256), (16, 128), (32, 64), (64, 32), 
(128, 16)] + medium_shapes = [(1, 128), (16, 64), (32, 32), (64, 16)] + small_shapes = [(1, 64), (8, 32), (16, 16)] - if os.getenv("GPTQMODEL_FAST_TESTS", "0") == "1": - shapes = small_shapes - elif total_mem_gb >= 80: - shapes = large_shapes - elif total_mem_gb >= 48: - shapes = medium_shapes + env_shapes = os.getenv("GPTQMODEL_KERNEL_TEST_SHAPES") + if env_shapes: + shapes = cls._parse_shapes(env_shapes) else: - shapes = small_shapes + total_mem_gb = 0.0 + if torch.cuda.is_available(): + device_index = cls.device.index if cls.device.index is not None else 0 + try: # pragma: no cover - hardware dependent + if torch.cuda.device_count() > device_index: + props = torch.cuda.get_device_properties(device_index) + total_mem_gb = props.total_memory / (1024 ** 3) + except Exception: # pragma: no cover - fall back to smallest shapes + total_mem_gb = 0.0 + + if os.getenv("GPTQMODEL_FAST_TESTS", "0") == "1": + shapes = small_shapes + elif total_mem_gb >= 80: + shapes = large_shapes + elif total_mem_gb >= 48: + shapes = medium_shapes + else: + shapes = small_shapes + + cls._shape_plan = shapes + cls._random_input_sample_size = sum(samples for _, samples in shapes) inputs: List[torch.Tensor] = [] - for batch, tokens in shapes: - tensor = torch.rand( - (batch, tokens, cls.in_features), - device=cls.device, - dtype=torch.float16, - ) - inputs.append(tensor) + for leading_dim, samples in shapes: + for _ in range(samples): + tensor = torch.rand( + (leading_dim, cls.in_features), + device=cls.device, + dtype=torch.float16, + ) + inputs.append(tensor) return inputs @classmethod @@ -327,15 +367,18 @@ def _summarize_results( diff = torch.abs(reference_fp32 - actual_fp32) max_abs_diff = max(max_abs_diff, float(diff.max().item())) mean_abs_diff += float(diff.mean().item()) - is_close_tensor = torch.isclose(reference_fp32, actual_fp32, rtol=0.15, atol=atol) + is_close_tensor = torch.isclose(reference, actual, rtol=0.15, atol=atol) if not bool(torch.all(is_close_tensor)): + sample_max = float(diff.max().item()) + sample_mean = float(diff.mean().item()) failures.append( - "Sample {idx}:\nExpected ({ref_label}) = {expected}\nActual = {actual_val}".format( + "Sample {idx}: max_abs_diff={max_diff:.6f}, mean_abs_diff={mean_diff:.6f}, " + "rtol=0.15, atol={atol:.6f}".format( idx=idx, - ref_label=reference_label, - expected=reference_fp32.detach().cpu().tolist(), - actual_val=actual_fp32.detach().cpu().tolist(), - ) + max_diff=sample_max, + mean_diff=sample_mean, + atol=atol, + ), ) status = f"{GREEN}PASS{RESET}" if not failures else f"{RED}FAIL{RESET}" @@ -370,12 +413,18 @@ def _summarize_results( self.log.info("\n" + title + "\n" + table) if failures: + preview = "\n".join(failures[:5]) + if len(failures) > 5: + preview += f"\n... 
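_generate_inputs now honors a GPTQMODEL_KERNEL_TEST_SHAPES override of the form "leading_dim:samples,leading_dim:samples,...". The standalone parser below mirrors _parse_shapes (it is not imported from the test) and shows how a plan translates into the number of forward samples:

from typing import List, Tuple

def parse_shape_plan(expr: str) -> List[Tuple[int, int]]:
    """Parse 'leading_dim:samples,...' as accepted via GPTQMODEL_KERNEL_TEST_SHAPES."""
    shapes = []
    for part in expr.split(","):
        part = part.strip()
        if not part:
            continue
        dim_str, samples_str = part.split(":", 1)
        shapes.append((int(dim_str), int(samples_str)))
    return shapes

plan = parse_shape_plan("1:256, 16:128, 32:64")
assert plan == [(1, 256), (16, 128), (32, 64)]
# Each entry yields `samples` random inputs of shape (leading_dim, in_features),
# so this plan produces 256 + 128 + 64 = 448 forward samples.
print(sum(samples for _, samples in plan))  # 448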
({len(failures) - 5} additional mismatches)" raise AssertionError( - f"{len(failures)} mismatched outputs for backend `{backend}`" + f"{len(failures)} mismatched outputs for backend `{backend}` " + f"(rtol=0.15, atol={atol:.6f})\n{preview}" ) @parameterized.expand(backend_cases) - def test_awq_kernel_outputs(self, backend: BACKEND, dtype: torch.dtype, atol: float) -> None: + def test_awq_kernel_outputs( + self, _case_name: str, backend: BACKEND, dtype: torch.dtype, atol: float + ) -> None: self._maybe_skip_backend(backend) module = self.modules.get(backend) @@ -387,7 +436,8 @@ def test_awq_kernel_outputs(self, backend: BACKEND, dtype: torch.dtype, atol: fl if backend == self.baseline_backend: actual_outputs = reference_outputs else: - actual_outputs = self._forward(module, inputs) + forward_kwargs = self._forward_kwargs.get(dtype, {}) + actual_outputs = self._forward(module, inputs, **forward_kwargs) self._summarize_results( reference_outputs=reference_outputs, actual_outputs=actual_outputs,
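The pass/fail decision above relies on torch.isclose(reference, actual, rtol=0.15, atol=atol), which elementwise requires |reference - actual| <= atol + rtol * |actual|. A tiny worked example using the GEMM fp16 tolerance from backend_cases (atol=0.0039); the sample values are invented purely to show the bound:

import torch

reference = torch.tensor([0.0, 0.5, -2.0], dtype=torch.float16)
actual = torch.tensor([0.003, 0.57, -2.25], dtype=torch.float16)

atol, rtol = 0.0039, 0.15
close = torch.isclose(reference, actual, rtol=rtol, atol=atol)
# elementwise bound atol + rtol * |actual| ~= [0.0043, 0.0894, 0.3414]
# absolute difference                     ~= [0.0030, 0.0698, 0.2500]
print(close.tolist())  # [True, True, True]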