From c91e0b3d03f73f987c4e0df6fcc8f0253ba9aaca Mon Sep 17 00:00:00 2001 From: jiqing-feng Date: Thu, 27 Nov 2025 09:16:55 +0000 Subject: [PATCH 01/19] add bnb cpu Signed-off-by: jiqing-feng --- .../bitsandbytes_cpu/bitsandbytes_avx512.cpp | 550 ++++++++++++++++++ .../bitsandbytes_cpu/bitsandbytes_avx512.hpp | 239 ++++++++ .../bitsandbytes_cpu/bitsandbytes_cpu.cpp | 48 ++ .../bitsandbytes_cpu/bitsandbytes_cpu.hpp | 15 + .../bitsandbytes_cpu_torch.cpp | 25 + .../bitsandbytes_cpu/cpu_features.hpp | 176 ++++++ quantization-bitsandbytes/build.toml | 27 + quantization-bitsandbytes/flake.nix | 21 + .../quantization_bitsandbytes/__init__.py | 3 + .../quantization_bitsandbytes/custom_ops.py | 11 + .../torch-ext/torch_binding.cpp | 42 ++ 11 files changed, 1157 insertions(+) create mode 100644 quantization-bitsandbytes/bitsandbytes_cpu/bitsandbytes_avx512.cpp create mode 100644 quantization-bitsandbytes/bitsandbytes_cpu/bitsandbytes_avx512.hpp create mode 100644 quantization-bitsandbytes/bitsandbytes_cpu/bitsandbytes_cpu.cpp create mode 100644 quantization-bitsandbytes/bitsandbytes_cpu/bitsandbytes_cpu.hpp create mode 100644 quantization-bitsandbytes/bitsandbytes_cpu/bitsandbytes_cpu_torch.cpp create mode 100644 quantization-bitsandbytes/bitsandbytes_cpu/cpu_features.hpp create mode 100644 quantization-bitsandbytes/build.toml create mode 100644 quantization-bitsandbytes/flake.nix create mode 100644 quantization-bitsandbytes/torch-ext/quantization_bitsandbytes/__init__.py create mode 100644 quantization-bitsandbytes/torch-ext/quantization_bitsandbytes/custom_ops.py create mode 100644 quantization-bitsandbytes/torch-ext/torch_binding.cpp diff --git a/quantization-bitsandbytes/bitsandbytes_cpu/bitsandbytes_avx512.cpp b/quantization-bitsandbytes/bitsandbytes_cpu/bitsandbytes_avx512.cpp new file mode 100644 index 0000000..be6219b --- /dev/null +++ b/quantization-bitsandbytes/bitsandbytes_cpu/bitsandbytes_avx512.cpp @@ -0,0 +1,550 @@ +// AVX512 implementation - compile with -mavx512f -mavx512bf16 +#include +#include +#include +#include +#include +#include +#include + +namespace bitsandbytes_cpu +{ + namespace avx512 + { + inline __m256i cvt_fp32_to_fp16(const __m512 src) + { + return _mm512_cvtps_ph(src, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)); + } + + inline __m256i cvt_fp32_to_bf16(const __m512 src) + { + if (has_avx512bf16()) + { + return reinterpret_cast<__m256i>(_mm512_cvtneps_pbh(src)); + } + __m512i value = _mm512_castps_si512(src); + __m512i nan = _mm512_set1_epi32(0xffff); + auto mask_value = _mm512_cmp_ps_mask(src, src, _CMP_ORD_Q); + __m512i ones = _mm512_set1_epi32(0x1); + __m512i vec_bias = _mm512_set1_epi32(0x7fff); + // uint32_t lsb = (input >> 16) & 1; + auto t_value = _mm512_and_si512(_mm512_srli_epi32(value, 16), ones); + // uint32_t rounding_bias = 0x7fff + lsb; + t_value = _mm512_add_epi32(t_value, vec_bias); + // input += rounding_bias; + t_value = _mm512_add_epi32(t_value, value); + // input = input >> 16; + t_value = _mm512_srli_epi32(t_value, 16); + // Check NaN before converting back to bf16 + t_value = _mm512_mask_blend_epi32(mask_value, nan, t_value); + return _mm512_cvtusepi32_epi16(t_value); + } + + static inline __m512 set_nf4_lut() + { + return _mm512_set_ps( + 1.0f, 0.7229568362236023, 0.5626170039176941, 0.44070982933044434, 0.33791524171829224, 0.24611230194568634, + 0.16093020141124725, 0.07958029955625534, 0.0f, -0.09105003625154495, -0.18477343022823334, + -0.28444138169288635, -0.39491748809814453, -0.5250730514526367, -0.6961928009986877, -1.0f); + } + + 
static inline __m512 set_fp4_lut() + { + return _mm512_set_ps( + -0.2500f, -0.16666667f, -0.5000f, -0.33333333f, -1.0000f, -0.66666667f, -5.208333333e-03f, 0.0000f, 0.2500f, + 0.16666667f, 0.5000f, 0.33333333f, 1.0000f, 0.66666667f, 5.208333333e-03f, 0.0000f); + } + +#define CVT_BF16_TO_FP32(a) _mm512_castsi512_ps(_mm512_slli_epi32(_mm512_cvtepu16_epi32(a), 16)) + + static inline const at::BFloat16 *cast_at_bf16(const bf16_t *p) + { + static_assert(sizeof(bf16_t) == sizeof(at::BFloat16), "bf16_t size mismatch"); + return reinterpret_cast(p); + } + + static inline at::BFloat16 *cast_at_bf16(bf16_t *p) + { + static_assert(sizeof(bf16_t) == sizeof(at::BFloat16), "bf16_t size mismatch"); + return reinterpret_cast(p); + } + + template + inline void unpack_B( + bf16_t *__restrict__ Btmp, const unsigned char *__restrict__ packed_B, + const bf16_t *__restrict__ Bs, // scales [K/gs, N] in bf16 + int64_t N, int64_t K, int blocksize, int64_t ldb, int64_t ldb_tmp, int64_t strideBs) + { + // Dequant: (w - z) * s -> bf16 + const int64_t K2 = K >> 1; // 2 weights packed per byte + const int64_t gs2 = blocksize >> 1; + const int64_t ldb2 = ldb; // packed leading dimension (bytes) + const int64_t ldb_tmp2 = ldb_tmp; // output leading dimension in elements + float *btmp_ptr = reinterpret_cast(Btmp); // direct bf16 storage + + __m256i mask = _mm256_set1_epi8(0xF); // low nibble + __m256i fifteen = _mm256_set1_epi8(15); // shift [-15,15] -> [0,30] for LUT + __m512i lut = DATA_TYPE == 1 + ? _mm512_set_epi16( + 0x0000, -0x4180, -0x41D5, -0x4100, -0x4155, -0x4080, -0x40D5, -0x4455, 0x0000, 0x3E80, + 0x3E2B, 0x3F00, 0x3EAB, 0x3F80, 0x3F2B, 0x3BAB, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, + 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000) + : _mm512_set_epi16( + 0x0000, 0x3F80, 0x3F39, 0x3F10, 0x3EE2, 0x3EAD, 0x3E7C, 0x3E25, 0x3DA3, 0x0000, -0x4246, + -0x41C3, -0x416E, -0x4136, -0x40FA, -0x40CE, -0x4080, 0x0000, 0x0000, 0x0000, 0x0000, + 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000); + __m512i s_idx1 = _mm512_set_epi32(15, 15, 14, 14, 13, 13, 12, 12, 11, 11, 10, 10, 9, 9, 8, 8); + __m512i s_idx0 = _mm512_set_epi32(7, 7, 6, 6, 5, 5, 4, 4, 3, 3, 2, 2, 1, 1, 0, 0); + + __m512 scale_lo_fp32, scale_hi_fp32; + __m512 scales[4]; + + for (int64_t n = 0; n < N; n += 32) + { + for (int64_t k = 0; k < K2; ++k) + { + if (k % gs2 == 0) + { + const int64_t kgs = k / gs2; + // Load 32 scales (bf16) -> two fp32 vectors (first16, second16) + __m512i scales_bf16 = _mm512_loadu_si512(reinterpret_cast(Bs + kgs * strideBs + n)); + scale_lo_fp32 = CVT_BF16_TO_FP32(_mm512_extracti32x8_epi32(scales_bf16, 0)); + scale_hi_fp32 = CVT_BF16_TO_FP32(_mm512_extracti32x8_epi32(scales_bf16, 1)); + scales[0] = _mm512_permutexvar_ps(s_idx0, scale_lo_fp32); + scales[1] = _mm512_permutexvar_ps(s_idx1, scale_lo_fp32); + scales[2] = _mm512_permutexvar_ps(s_idx0, scale_hi_fp32); + scales[3] = _mm512_permutexvar_ps(s_idx1, scale_hi_fp32); + } + + // Load packed 32 bytes => 64 int4 + __m256i w_u4 = _mm256_loadu_si256(reinterpret_cast(packed_B + k * ldb2 + n)); + + // Split nibbles + __m256i w_lo = w_u4 & mask; + __m256i w_hi = _mm256_srli_epi16(w_u4, 4) & mask; + + // Shift to [0..30] before LUT + w_lo = _mm256_add_epi8(w_lo, fifteen); + w_hi = _mm256_add_epi8(w_hi, fifteen); + + // Lookup (w - z) -> bf16 using LUT (process 16-byte halves) + __m512i w_lo_bf16 = _mm512_permutexvar_epi16(_mm512_cvtepi8_epi16(w_lo), lut); + __m512i w_hi_bf16 = 
_mm512_permutexvar_epi16(_mm512_cvtepi8_epi16(w_hi), lut); + + __m512 w_lo_fp32_0 = CVT_BF16_TO_FP32(_mm512_extracti32x8_epi32(w_lo_bf16, 0)) * scales[0]; + __m512 w_hi_fp32_0 = CVT_BF16_TO_FP32(_mm512_extracti32x8_epi32(w_lo_bf16, 1)) * scales[1]; + __m512 w_lo_fp32_1 = CVT_BF16_TO_FP32(_mm512_extracti32x8_epi32(w_hi_bf16, 0)) * scales[2]; + __m512 w_hi_fp32_1 = CVT_BF16_TO_FP32(_mm512_extracti32x8_epi32(w_hi_bf16, 1)) * scales[3]; + + // Pack scaled (first 16 cols) then (second 16 cols) to bf16 + __m512bh packed0 = _mm512_cvtne2ps_pbh(w_hi_fp32_0, w_lo_fp32_0); + __m512bh packed1 = _mm512_cvtne2ps_pbh(w_hi_fp32_1, w_lo_fp32_1); + + // Store: two blocks of 16 bf16 (32 elements) per k iteration + _mm512_storeu_si512(btmp_ptr + (k * ldb_tmp2 + n + 0), (__m512i)packed0); + _mm512_storeu_si512(btmp_ptr + (k * ldb_tmp2 + n + 16), (__m512i)packed1); + } + } + } + + template + struct tinygemm_kernel_nn + { + static inline void apply( + const scalar_t *, const unsigned char *, scalar_t *, const scalar_t *, int64_t, int, int64_t, int64_t, int64_t, + int64_t, int64_t) + { + static_assert(sizeof(scalar_t) == 0, "tinygemm_kernel_nn primary template should never be instantiated"); + } + }; + + // The brgemm will not be used without HAS_TORCH + template + struct brgemm + { + static inline void apply( + const scalar_t *__restrict__ A, const unsigned char *__restrict__ B, scalar_t *__restrict__ C, + const scalar_t *__restrict__ Bs, scalar_t *__restrict__ Btmp, float *__restrict__ Ctmp, int64_t M, int64_t N, + int64_t K, int blocksize, int64_t lda, int64_t ldb, int64_t ldc, int64_t strideBs, bool use_brgemm_dequant_out) + { + return; + } + }; + + template + struct tinygemm_kernel_nn + { + static inline void apply( + const bf16_t *__restrict__ A, const unsigned char *__restrict__ B, bf16_t *__restrict__ C, + const bf16_t *__restrict__ Bs, int64_t K, int blocksize, int64_t lda, int64_t ldb, int64_t ldc, int64_t strideBs) + { + static_assert(BLOCK_N % 32 == 0); + constexpr int ROWS = BLOCK_M; // 32 + constexpr int COLS = BLOCK_N / 16; // 2 + + // prefetch distance + constexpr int PREFETCH_SIZE_K = 16 * 4; + + __m512bh va; + __m512bh vb[COLS]; + __m512 vc[ROWS * COLS]; + __m512 vc_master[ROWS * COLS]; + + __m256i mask = _mm256_set1_epi8(0xF); // lower 4 bit + __m256i fifteen = _mm256_set1_epi8(15); + __m512i lut = DATA_TYPE == 1 + ? 
_mm512_set_epi16( + 0x0000, -0x4180, -0x41D5, -0x4100, -0x4155, -0x4080, -0x40D5, -0x4455, 0x0000, 0x3E80, + 0x3E2B, 0x3F00, 0x3EAB, 0x3F80, 0x3F2B, 0x3BAB, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, + 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000) + : _mm512_set_epi16( + 0x0000, 0x3F80, 0x3F39, 0x3F10, 0x3EE2, 0x3EAD, 0x3E7C, 0x3E25, 0x3DA3, 0x0000, -0x4246, + -0x41C3, -0x416E, -0x4136, -0x40FA, -0x40CE, -0x4080, 0x0000, 0x0000, 0x0000, 0x0000, + 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000); + __m512 scales[COLS]; + const int64_t K2 = K >> 1; + const int64_t lda2 = lda >> 1; + const int64_t ldb2 = ldb; // ldb * 2 >> 1; + const int64_t gs2 = blocksize >> 1; // 64 / 2 = 32 + const float *a_ptr = reinterpret_cast(A); + + auto loadc = [&](auto i) + { + constexpr int col = i % COLS; + vc_master[i] = _mm512_set1_ps(0.f); + }; + Unroll{}(loadc); + + auto pre_compute = [&](auto i, int64_t kgs) + { + constexpr int row = i / COLS; + constexpr int col = i % COLS; + vc[i] = _mm512_set1_ps(0.f); // reset accumulator + + // load scales + if constexpr (row == 0 && col % 2 == 0) + { + // Bs layout: [K/gs, BLOCK_N] : [strideBs, 1], dtype=bf16 + __m512i tmp = _mm512_loadu_si512(reinterpret_cast(Bs + kgs * strideBs + col * 16)); + scales[col] = CVT_BF16_TO_FP32(_mm512_extracti32x8_epi32(tmp, 0)); + scales[col + 1] = CVT_BF16_TO_FP32(_mm512_extracti32x8_epi32(tmp, 1)); + } + }; + auto compute = [&](auto i, int64_t k) + { + constexpr int row = i / COLS; + constexpr int col = i % COLS; + + if constexpr (col == 0) + { + va = (__m512bh)(_mm512_set1_ps(a_ptr[row * lda2 + k])); + } + if constexpr (row == 0 && col % 2 == 0) + { + __m256i vb_u4 = _mm256_loadu_si256(reinterpret_cast(B + k * ldb + col * 16)); + + // deinterleave and lookup to BF16 + __m256i vb_i8_lo = vb_u4 & mask; + __m256i vb_i8_hi = _mm256_srli_epi16(vb_u4, 4) & mask; + vb_i8_lo = _mm256_add_epi8(vb_i8_lo, fifteen); + vb_i8_hi = _mm256_add_epi8(vb_i8_hi, fifteen); + vb[col] = (__m512bh)_mm512_permutexvar_epi16(_mm512_cvtepi8_epi16(vb_i8_lo), lut); + vb[col + 1] = (__m512bh)_mm512_permutexvar_epi16(_mm512_cvtepi8_epi16(vb_i8_hi), lut); + + if constexpr (PREFETCH_SIZE_K > 0) + { + _mm_prefetch(B + (k + PREFETCH_SIZE_K) * ldb2 + col * 16, _MM_HINT_T0); + } + } + vc[i] = _mm512_dpbf16_ps(vc[i], va, vb[col]); + }; + auto post_compute = [&](auto i, int64_t kgs) + { + vc_master[i] = _mm512_fmadd_ps(vc[i], scales[i % COLS], vc_master[i]); + }; + for (int64_t k = 0; k < K2; k += gs2) + { + Unroll{}(pre_compute, k / gs2); + for (int64_t k_offset = 0; k_offset < gs2; ++k_offset) + { + Unroll{}(compute, k + k_offset); + } + Unroll{}(post_compute, k / gs2); + } + + auto storec = [&](auto i) + { + constexpr int row = i / COLS; + constexpr int col = i % COLS; + if constexpr (col % 2 == 0) + { + _mm512_storeu_si512( + reinterpret_cast<__m512i *>(C + row * ldc + col * 16), + (__m512i)(_mm512_cvtne2ps_pbh(vc_master[i + 1], vc_master[i]))); + } + }; + Unroll{}(storec); + } + }; + +#define LAUNCH_TINYGEMM_KERNEL_NN(MB_SIZE, NB_SIZE, DATA_TYPE) \ + tinygemm_kernel_nn::apply( \ + A + mb_start * lda, B + nb_start, C + mb_start * ldc + nb_start, Bs + nb_start, K, blocksize, lda, ldb, ldc, \ + strideBs); + + inline uint16_t float_to_bf16_round(float x) + { + uint32_t u; + std::memcpy(&u, &x, sizeof(u)); + uint32_t lsb = (u >> 16) & 1; + uint32_t rounding_bias = 0x7fff + lsb; + u += rounding_bias; + uint16_t hi = static_cast(u >> 16); + // Quiet NaN handling + if ((u & 0x7f800000) == 0x7f800000 
&& (u & 0x007fffff)) + { + hi = 0xffff; + } + return hi; + } + + template + inline void copy_stub(scalar_t *__restrict__ out, const float *__restrict__ input, int64_t size) + { + if (has_avx512bf16()) + { + int64_t d = 0; + const int V = 32; + for (; d + V <= size; d += V) + { + __m512 lo = _mm512_loadu_ps(input + d); + __m512 hi = _mm512_loadu_ps(input + d + 16); + __m512bh packed = _mm512_cvtne2ps_pbh(hi, lo); + _mm512_storeu_si512(reinterpret_cast(out + d), (__m512i)packed); + } + for (; d < size; ++d) + { + if constexpr (std::is_same_v) + { + // store raw bf16 bits + reinterpret_cast(out)[d] = float_to_bf16_round(input[d]); + } + else + { + out[d] = static_cast(input[d]); + } + } + } + else + { + for (int64_t d = 0; d < size; ++d) + { + if constexpr (std::is_same_v) + { + reinterpret_cast(out)[d] = float_to_bf16_round(input[d]); + } + else + { + out[d] = static_cast(input[d]); + } + } + } + } + + template + struct brgemm + { + static inline void apply( + const bf16_t *__restrict__ A, const unsigned char *__restrict__ B, bf16_t *__restrict__ C, + const bf16_t *__restrict__ Bs, bf16_t *__restrict__ Btmp, float *__restrict__ Ctmp, int64_t M, int64_t N, + int64_t K, int blocksize, int64_t lda, int64_t ldb, int64_t ldc, int64_t strideBs, bool use_brgemm_dequant_out) + { + constexpr int BLOCK_N = block_size_n(); + const int ldb_tmp = BLOCK_N; + if (use_brgemm_dequant_out) + { + at::native::cpublas::brgemm( + M, N, K, lda, ldb_tmp, BLOCK_N, false, cast_at_bf16(A), cast_at_bf16(Btmp), Ctmp); + } + else + { + for (int64_t k = 0; k < K; k += BLOCK_K) + { + int64_t kb_size = std::min(static_cast(BLOCK_K), K - k); + const int64_t kgs = k / blocksize; + + unpack_B( + Btmp, B + (k >> 1) * ldb, Bs + kgs * strideBs, N, kb_size, blocksize, ldb, ldb_tmp, strideBs); + + const bool add_C = k != 0; + at::native::cpublas::brgemm( + M, N, kb_size, lda, ldb_tmp, BLOCK_N, add_C, cast_at_bf16(A + k), cast_at_bf16(Btmp), Ctmp); + } + } + + // copy from Ctmp to C + for (int64_t m = 0; m < M; ++m) + { + copy_stub(C + m * ldc, Ctmp + m * BLOCK_N, N); + } + } + }; + + template + void tinygemm_kernel( + const scalar_t *__restrict__ A, const unsigned char *__restrict__ B, scalar_t *__restrict__ C, + const scalar_t *__restrict__ Bs, scalar_t *__restrict__ Btmp, float *__restrict__ Ctmp, int64_t M, int64_t N, + int64_t K, int blocksize, int64_t lda, int64_t ldb, int64_t ldc, int64_t strideBs, bool brg, + bool use_brgemm_dequant_out = false) + { + if (brg) + { + brgemm::apply( + A, B, C, Bs, Btmp, Ctmp, M, N, K, blocksize, lda, ldb, ldc, strideBs, use_brgemm_dequant_out); + return; + } + constexpr int64_t BLOCK_M = 4; + constexpr int64_t BLOCK_N = 64; + const int64_t MB = div_up(M, BLOCK_M); + const int64_t NB = div_up(N, BLOCK_N); + for (int mb = 0; mb < MB; ++mb) + { + int64_t mb_start = mb * BLOCK_M; + int64_t mb_size = std::min(BLOCK_M, M - mb_start); + for (int64_t nb = 0; nb < NB; ++nb) + { + int64_t nb_start = nb * BLOCK_N; + int64_t nb_size = std::min(BLOCK_N, N - nb_start); + + switch (mb_size << 4 | nb_size >> 4) + { + // mb_size = 1 + case 0x12: + LAUNCH_TINYGEMM_KERNEL_NN(1, 32, DATA_TYPE); + break; + case 0x14: + LAUNCH_TINYGEMM_KERNEL_NN(1, 64, DATA_TYPE); + break; + // mb_size = 2 + case 0x22: + LAUNCH_TINYGEMM_KERNEL_NN(2, 32, DATA_TYPE); + break; + case 0x24: + LAUNCH_TINYGEMM_KERNEL_NN(2, 64, DATA_TYPE); + break; + // mb_size = 3 + case 0x32: + LAUNCH_TINYGEMM_KERNEL_NN(3, 32, DATA_TYPE); + break; + case 0x34: + LAUNCH_TINYGEMM_KERNEL_NN(3, 64, DATA_TYPE); + break; + // mb_size = 4 + case 0x42: + 
LAUNCH_TINYGEMM_KERNEL_NN(4, 32, DATA_TYPE); + break; + case 0x44: + LAUNCH_TINYGEMM_KERNEL_NN(4, 64, DATA_TYPE); + break; + default: + { + std::fprintf( + stderr, "[bitsandbytes] Unexpected block size %lldx%lld\n", (long long)mb_size, (long long)nb_size); + std::abort(); // or return; if you prefer silent exit + } + } + } + } + } + + template + void gemm_4bit_inference( + int64_t M, int64_t N, int64_t K, const T *__restrict__ x, const unsigned char *__restrict__ w, + const T *__restrict__ absmax, T *__restrict__ out, int64_t blocksize, int64_t x_stride, int64_t out_stride) + { + constexpr int64_t BLOCK_M = block_size_m(); // 32 + constexpr int64_t BLOCK_N = block_size_n(); // 32 + const int64_t MB = div_up(M, BLOCK_M); // (x + y -1)/ y, res = 1 when M <= 32 + const int64_t NB = div_up(N, BLOCK_N); + // TODO: Find better threshold. + T *Btmp_start = nullptr; + const bool use_brgemm = M > 4; + const bool use_brgemm_dequant_out = M > 100; + if (use_brgemm_dequant_out) + { + // Layout: contiguous [N*K] elements, 64-byte aligned for AVX512 loads + at::Tensor Btmp_t = at::zeros({N, K}, at::dtype(at::kBFloat16)); + at::BFloat16 *Btmp_start_pt = Btmp_t.data_ptr(); + Btmp_start = reinterpret_cast(Btmp_start_pt); +#pragma omp parallel for + for (int64_t nb = 0; nb < NB; ++nb) + { + int64_t nb_start = nb * BLOCK_N; + int64_t nb_size = std::min(N - nb_start, BLOCK_N); + T *Btmp = Btmp_start + nb_start * K; + for (int64_t k = 0; k < K; k += BLOCK_K) + { + int64_t kb_size = std::min(BLOCK_K, K - k); + int64_t kgs = k / blocksize; + int64_t strideBs = N; + int64_t ldb = nb_size; + const T *Bs = absmax + nb_start; + const unsigned char *Bw = reinterpret_cast(w + nb_start * K / 2); + unpack_B( + Btmp + k * BLOCK_N, Bw + (k >> 1) * ldb, Bs + kgs * strideBs, nb_size, kb_size, blocksize, ldb, + BLOCK_N, strideBs); + } + } + } + // l2 cache block for n + int64_t cache_blocks_nb = get_cache_blocks(BLOCK_N * K); + parallel_2d(MB, NB, [&](int64_t begin_mb, int64_t end_mb, int64_t begin_nb, int64_t end_nb) + { + // for brgemm, use float32 for accumulate + alignas(64) float Ctmp[BLOCK_M * BLOCK_N]; + alignas(64) T Btmp_inner[BLOCK_N * BLOCK_K]; // BLOCK_K = 128 + for (int64_t nbb = begin_nb; nbb < end_nb; nbb += cache_blocks_nb) { + for (int64_t mb = begin_mb; mb < end_mb; ++mb) { // 0-1 + for (int64_t nb = nbb; nb < std::min(nbb + cache_blocks_nb, end_nb); ++nb) { + int64_t mb_start = mb * BLOCK_M; // 0 + int64_t mb_size = std::min(M - mb_start, BLOCK_M); + int64_t nb_start = nb * BLOCK_N; + int64_t nb_size = std::min(N - nb_start, BLOCK_N); + tinygemm_kernel( + /* A */ x + mb_start * x_stride, + /* B */ w + nb_start * K / 2, // divide by 2 since w is u4 packed in u8, K is w.size(1) * 2 + /* C */ out + mb_start * out_stride + nb_start, + /* Bs */ absmax + nb_start, + /* Btmp */ use_brgemm_dequant_out ? 
Btmp_start + nb_start * K : Btmp_inner, + /* Ctmp */ Ctmp, + /* M */ mb_size, + /* N */ nb_size, + /* K */ K, + /* gs */ blocksize, // blocksize + /* lda */ x_stride, + /* ldb */ nb_size, + /* ldc */ out_stride, + /* sBs */ N, + /* brg */ use_brgemm, + /* dequant choice*/ use_brgemm_dequant_out + ); + } + } + } + if (use_brgemm) { + at::native::cpublas::brgemm_release(); + } }); + } + + //============================================================== + // TEMPLATE DEFINITIONS + //============================================================== + + template void gemm_4bit_inference( + int64_t M, int64_t N, int64_t K, const bf16_t *__restrict__ x, const unsigned char *__restrict__ w, + const bf16_t *__restrict__ absmax, bf16_t *__restrict__ out, int64_t blocksize, int64_t x_stride, int64_t out_stride); + template void gemm_4bit_inference( + int64_t M, int64_t N, int64_t K, const bf16_t *__restrict__ x, const unsigned char *__restrict__ w, + const bf16_t *__restrict__ absmax, bf16_t *__restrict__ out, int64_t blocksize, int64_t x_stride, int64_t out_stride); + } // namespace avx512 +} // namespace bitsandbytes_cpu \ No newline at end of file diff --git a/quantization-bitsandbytes/bitsandbytes_cpu/bitsandbytes_avx512.hpp b/quantization-bitsandbytes/bitsandbytes_cpu/bitsandbytes_avx512.hpp new file mode 100644 index 0000000..28d05fa --- /dev/null +++ b/quantization-bitsandbytes/bitsandbytes_cpu/bitsandbytes_avx512.hpp @@ -0,0 +1,239 @@ +// AVX512 implementation - compile with -mavx512f -mavx512bf16 +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace bitsandbytes_cpu +{ + namespace avx512 + { +// amx-bf16 +#define TILE_M 16 +#define TILE_N 16 +#define TILE_K 32 +// work around compiler internal error +#define BLOCK_K 128 // 4 * TILE_K + + // block size for AMX gemm + constexpr int block_size_m() { return 2 * TILE_M; } + + constexpr int block_size_n() { return 2 * TILE_N; } + + template + inline int get_cache_blocks(int chunk_size) + { + // L2 2MB and ratio of 50% + const int L2_size = 2048 * 1024 >> 1; + return std::max(1, int(L2_size / (chunk_size * sizeof(T)))); + } + +// forced unroll for perf critical path +#if __has_attribute(always_inline) +#define ALWAYS_INLINE __attribute__((__always_inline__)) inline +#else +#define ALWAYS_INLINE inline +#endif + + template + struct Unroll + { + template + ALWAYS_INLINE void operator()(const Func &f, Args... args) const + { + Unroll{}(f, args...); + f(std::integral_constant{}, args...); + } + }; + + template <> + struct Unroll<1> + { + template + ALWAYS_INLINE void operator()(const Func &f, Args... 
args) const + { + f(std::integral_constant{}, args...); + } + }; + + template ::value, int>::type = 0> + inline T div_up(T x, T y) + { + return (x + y - 1) / y; + } + + inline int adjust_num_threads(int m) + { + int actual_nth = omp_get_max_threads(); + if (m == 1) + return actual_nth; + return std::max(1, (actual_nth >> 1) * 2); + } + + template + inline void parallel_2d(int m, int n, const func_t &f) + { + int nth = adjust_num_threads(m); + float r = float(m) / n; + int nth_m = std::ceil(std::sqrt(r * nth)); + int nth_n = 1; + for (; nth_m > 0; --nth_m) + { + nth_n = nth / nth_m; + if (nth_m * nth_n == nth) + { + break; + } + } +#pragma omp parallel num_threads(nth) + { + int ith = omp_get_thread_num(); + int ith_m = ith / nth_n; + int ith_n = ith % nth_n; + + int thread_block_m = div_up(m, nth_m); + int thread_block_n = div_up(n, nth_n); + + int begin_m = ith_m * thread_block_m; + int end_m = std::min(m, begin_m + thread_block_m); + int begin_n = ith_n * thread_block_n; + int end_n = std::min(n, begin_n + thread_block_n); + + f(begin_m, end_m, begin_n, end_n); + } + } + + typedef enum DataType_t + { + NF4 = 0, + FP4 = 1, + } DataType_t; + + struct fp16_t + { + uint16_t v; + }; + + struct bf16_t + { + uint16_t v; + }; + + static inline bf16_t float_to_bf16(float x) + { + uint32_t bits; + std::memcpy(&bits, &x, 4); + uint32_t r = bits + 0x7FFF + ((bits >> 16) & 1); + return bf16_t{static_cast(r >> 16)}; + } + + static float bf16_to_float(uint16_t bf16) + { + uint32_t bits = (uint32_t)bf16 << 16; + float f; + std::memcpy(&f, &bits, sizeof(f)); + return f; + } + + static inline fp16_t float_to_fp16(float x) + { + uint32_t bits; + std::memcpy(&bits, &x, 4); + uint32_t sign = (bits >> 31) & 0x1; + uint32_t exp = (bits >> 23) & 0xFF; + uint32_t mant = bits & 0x7FFFFF; + + uint16_t h; + if (exp == 0xFF) + { // Inf / NaN + uint16_t mant16 = mant ? 
0x200 : 0; // quiet NaN: set MSB of mantissa + h = (sign << 15) | (0x1F << 10) | mant16; + } + else if (exp > 0x70 + 0x1E) + { // overflow: exp_f -127 +15 > 30 (exp_f > 142) + h = (sign << 15) | (0x1F << 10); // Inf + } + else if (exp < 0x71) + { // subnormal or zero (exp_f < 113) + if (exp < 0x67) + { // too small -> zero (exp_f < 103) + h = (sign << 15); + } + else + { + // subnormal: implicit leading 1 + uint32_t shift = 0x71 - exp; + uint32_t mant_with_hidden = mant | 0x800000; + // add rounding bias before shifting (23-10 =13 bits to drop + shift) + uint32_t rounded = (mant_with_hidden + (1u << (shift + 12))) >> (shift + 13); + h = (sign << 15) | (uint16_t)rounded; + } + } + else + { + // normalized + uint32_t exp_h = exp - 127 + 15; + // round mantissa: add 2^(23-10-1) = 0x1000 + uint32_t mant_rounded = mant + 0x00001000; + if (mant_rounded & 0x00800000) + { // mantissa overflow after rounding + mant_rounded = 0; + ++exp_h; + if (exp_h >= 0x1F) + { // overflow to Inf + h = (sign << 15) | (0x1F << 10); + return fp16_t{h}; + } + } + h = (sign << 15) | ((uint16_t)exp_h << 10) | ((uint16_t)(mant_rounded >> 13)); + } + return fp16_t{h}; + } + +#ifdef _MSC_VER +#include + + static inline bool has_avx512f() + { + static bool v = [] + { + int info[4]; + __cpuidex(info, 7, 0); + return (info[1] & (1 << 16)) != 0; // EBX bit16 AVX512F + }(); + return v; + } + static inline bool has_avx512bf16() + { + static bool v = [] + { + int info[4]; + __cpuidex(info, 7, 1); + return (info[0] & (1 << 5)) != 0; // EAX bit5 AVX512_BF16 + }(); + return v; + } +#else + static inline bool has_avx512f() + { + static const bool supported_avx512f = __builtin_cpu_supports("avx512f"); + return supported_avx512f; + } + static inline bool has_avx512bf16() + { + static const bool supported_avx512bf16 = __builtin_cpu_supports("avx512bf16"); + return supported_avx512bf16; + } +#endif + + template + void gemm_4bit_inference( + int64_t M, int64_t N, int64_t K, const T *__restrict__ x, const unsigned char *__restrict__ w, + const T *__restrict__ absmax, T *__restrict__ out, int64_t blocksize, int64_t x_stride, int64_t out_stride); + } // namespace avx512 +} // namespace bitsandbytes_cpu \ No newline at end of file diff --git a/quantization-bitsandbytes/bitsandbytes_cpu/bitsandbytes_cpu.cpp b/quantization-bitsandbytes/bitsandbytes_cpu/bitsandbytes_cpu.cpp new file mode 100644 index 0000000..e485159 --- /dev/null +++ b/quantization-bitsandbytes/bitsandbytes_cpu/bitsandbytes_cpu.cpp @@ -0,0 +1,48 @@ +#include "bitsandbytes_cpu.hpp" + +using bf16_t = bitsandbytes_cpu::avx512::bf16_t; +using bitsandbytes_cpu::avx512::DataType_t; + +namespace bitsandbytes_cpu +{ + + // Main dispatcher that selects the best implementation based on runtime CPU features + void gemm_4bit(const torch::Tensor &input, const torch::Tensor &weight, + const torch::Tensor &absmax, torch::Tensor &out, int64_t blocksize, int64_t quant_type) + { + int64_t M = input.size(0); + int64_t N = weight.size(0); + int64_t K = input.size(1); + // strides + int64_t x_strideM = input.stride(0); + int64_t out_strideM = out.stride(0); + // Runtime CPU feature detection and dispatch + if (CPUFeatures::hasAVX512BF16()) + { + // Use AVX512 optimized implementation + if (quant_type == 1) { + bitsandbytes_cpu::avx512::gemm_4bit_inference( + M, N, K, + reinterpret_cast(input.data_ptr()), + weight.data_ptr(), + reinterpret_cast(absmax.data_ptr()), + reinterpret_cast(out.data_ptr()), + blocksize, x_strideM, out_strideM); + } + else { + bitsandbytes_cpu::avx512::gemm_4bit_inference( 
+ M, N, K, + reinterpret_cast(input.data_ptr()), + weight.data_ptr(), + reinterpret_cast(absmax.data_ptr()), + reinterpret_cast(out.data_ptr()), + blocksize, x_strideM, out_strideM); + } + } + else + { + // raise error for unsupported CPU + throw std::runtime_error("[bitsandbytes] gemm_4bit: CPU does not support AVX512BF16 instruction set required for 4-bit quantization operations."); + } + } +} // namespace bitsandbytes_cpu diff --git a/quantization-bitsandbytes/bitsandbytes_cpu/bitsandbytes_cpu.hpp b/quantization-bitsandbytes/bitsandbytes_cpu/bitsandbytes_cpu.hpp new file mode 100644 index 0000000..0070b7d --- /dev/null +++ b/quantization-bitsandbytes/bitsandbytes_cpu/bitsandbytes_cpu.hpp @@ -0,0 +1,15 @@ +#pragma once + +#include "cpu_features.hpp" +#include "bitsandbytes_avx512.hpp" +#include +#include +#include + +namespace bitsandbytes_cpu +{ + + // Main dispatcher that selects the best implementation based on runtime CPU features + void gemm_4bit(const torch::Tensor &input, const torch::Tensor &weight, + const torch::Tensor &absmax, torch::Tensor &out, int64_t blocksize, int64_t quant_type); +} // namespace bitsandbytes_cpu \ No newline at end of file diff --git a/quantization-bitsandbytes/bitsandbytes_cpu/bitsandbytes_cpu_torch.cpp b/quantization-bitsandbytes/bitsandbytes_cpu/bitsandbytes_cpu_torch.cpp new file mode 100644 index 0000000..1fdbe32 --- /dev/null +++ b/quantization-bitsandbytes/bitsandbytes_cpu/bitsandbytes_cpu_torch.cpp @@ -0,0 +1,25 @@ +#include +#include "bitsandbytes_cpu.hpp" + +// Forward implementation for CPU +torch::Tensor gemm_4bit_cpu_forward( + const torch::Tensor &input, const torch::Tensor &weight, + const torch::Tensor &absmax, int64_t blocksize, int64_t quant_type) +{ + TORCH_CHECK(input.is_contiguous(), "input must be contiguous"); + TORCH_CHECK(weight.is_contiguous(), "weight must be contiguous"); + TORCH_CHECK(absmax.is_contiguous(), "absmax must be contiguous"); + + auto output = at::empty({input.size(0), weight.size(0)}, input.options()); + + bitsandbytes_cpu::gemm_4bit( + input, + weight, + absmax, + output, + blocksize, + quant_type + ); + + return output; +} diff --git a/quantization-bitsandbytes/bitsandbytes_cpu/cpu_features.hpp b/quantization-bitsandbytes/bitsandbytes_cpu/cpu_features.hpp new file mode 100644 index 0000000..de454dc --- /dev/null +++ b/quantization-bitsandbytes/bitsandbytes_cpu/cpu_features.hpp @@ -0,0 +1,176 @@ +#pragma once + +#ifdef _MSC_VER +#include +#else +#include +#endif +#include +#include +#include +#include +namespace bitsandbytes_cpu +{ + + // CPU feature detection + class CPUFeatures + { + public: + static bool hasAVX2() + { + static bool avx2_supported = checkAVX2(); + return avx2_supported; + } + + static bool hasAVX512BF16() + { + static bool bf16_supported = checkAVX512BF16(); + return bf16_supported; + } + + private: + static bool checkAVX2() + { +#ifdef _MSC_VER + int cpu_info[4]; + __cpuid(cpu_info, 0); + int n_ids = cpu_info[0]; + + if (n_ids >= 7) + { + __cpuidex(cpu_info, 7, 0); + return (cpu_info[1] & (1 << 5)) != 0; // EBX bit 5 + } + return false; +#else + unsigned int eax, ebx, ecx, edx; + if (__get_cpuid_max(0, nullptr) < 7) + { + return false; + } + __cpuid_count(7, 0, eax, ebx, ecx, edx); + return (ebx & (1 << 5)) != 0; // EBX bit 5 +#endif + } + + static bool checkAVX512() + { +#ifdef _MSC_VER + int cpu_info[4]; + __cpuid(cpu_info, 0); + int n_ids = cpu_info[0]; + if (n_ids < 7) + return false; + + __cpuidex(cpu_info, 7, 0); + bool avx512f = (cpu_info[1] & (1 << 16)) != 0; // EBX bit 16: AVX-512 
Foundation + if (!avx512f) + return false; + + __cpuid(cpu_info, 1); + bool osxsave = (cpu_info[2] & (1 << 27)) != 0; // ECX bit 27: OSXSAVE + if (!osxsave) + return false; + + // check XCR0: bits 1,2 (SSE/AVX) and 5,6,7 (AVX-512 state) must be enabled by OS + unsigned long long xcr0 = _xgetbv(0); + return ((xcr0 & 0xE6ULL) == 0xE6ULL); +#else + unsigned int eax, ebx, ecx, edx; + if (__get_cpuid_max(0, nullptr) < 7) + { + return false; + } + + __cpuid_count(7, 0, eax, ebx, ecx, edx); + bool avx512f = (ebx & (1 << 16)) != 0; // EBX bit 16: AVX-512 Foundation + if (!avx512f) + return false; + + if (__get_cpuid(1, &eax, &ebx, &ecx, &edx) == 0) + { + return false; + } + bool osxsave = (ecx & (1 << 27)) != 0; // ECX bit 27: OSXSAVE + if (!osxsave) + return false; + + unsigned int xcr0_lo = 0, xcr0_hi = 0; + __asm__ volatile("xgetbv" : "=a"(xcr0_lo), "=d"(xcr0_hi) : "c"(0)); + unsigned long long xcr0 = ((unsigned long long)xcr0_hi << 32) | xcr0_lo; + // require XCR0 bits 1,2,5,6,7 set -> mask 0xE6 (0b11100110) + return ((xcr0 & 0xE6ULL) == 0xE6ULL); +#endif + } + + static bool checkAVX512BF16() + { + // require AVX-512 foundation supported and OS enabled + if (!checkAVX512()) + return false; + +#ifndef _MSC_VER + // First: try Linux /proc/cpuinfo flags (most robust on Linux) + std::ifstream f("/proc/cpuinfo"); + if (f) + { + std::string line; + while (std::getline(f, line)) + { + // flags line contains many space-separated tokens including avx512_bf16 on supported CPUs + if (line.find("avx512_bf16") != std::string::npos || + line.find("avx512bf16") != std::string::npos) + { + return true; + } + } + } + + // Fallback: attempt CPUID subleaf check if available. + // Note: exact bit position for AVX512_BF16 may differ across vendors/CPUID versions. + // This fallback tries CPUID(7,1) and checks some common positions; if uncertain returns false. + if (__get_cpuid_max(0, nullptr) < 7) + { + return false; + } + unsigned int eax, ebx, ecx, edx; + // try subleaf 1 + __cpuid_count(7, 1, eax, ebx, ecx, edx); + // There isn't a universally agreed constant here in this file; check common candidate bits: + // - some implementations report AVX512_BF16 in ECX/EBX of subleaf 1. + // Try commonly used positions conservatively. + const unsigned int candidate_masks[] = { + (1u << 5), // candidate (may collide with other features) + (1u << 26), // another candidate position + }; + for (unsigned m : candidate_masks) + { + if ((ebx & m) || (ecx & m) || (edx & m)) + { + return true; + } + } + return false; +#else + // On MSVC / Windows, use CPUID if available (simple check). If unsure, return false. 
+ int cpu_info[4]; + __cpuid(cpu_info, 0); + int n_ids = cpu_info[0]; + if (n_ids < 7) + return false; + __cpuidex(cpu_info, 7, 1); + // same conservative check as above + const int candidate_masks[] = {(1 << 5), (1 << 26)}; + for (int m : candidate_masks) + { + if ((cpu_info[1] & m) || (cpu_info[2] & m) || (cpu_info[3] & m)) + { + return true; + } + } + return false; +#endif + } + }; + +} // namespace bitsandbytes_cpu diff --git a/quantization-bitsandbytes/build.toml b/quantization-bitsandbytes/build.toml new file mode 100644 index 0000000..b19d07f --- /dev/null +++ b/quantization-bitsandbytes/build.toml @@ -0,0 +1,27 @@ +[general] +name = "bitsandbytes" +universal = false + +[torch] +src = ["torch-ext/torch_binding.cpp"] + +[kernel.bitsandbytes_cpu] +backend = "cpu" +depends = ["torch"] +src = [ + "bitsandbytes_cpu/bitsandbytes_cpu_torch.cpp", + "bitsandbytes_cpu/bitsandbytes_cpu.cpp", + "bitsandbytes_cpu/bitsandbytes_cpu.hpp", + "bitsandbytes_cpu/cpu_features.hpp", +] +include = ["bitsandbytes_cpu"] + +[kernel.bitsandbytes_cpu_avx512] +backend = "cpu" +depends = ["torch"] +src = [ + "bitsandbytes_cpu/bitsandbytes_avx512.cpp", + "bitsandbytes_cpu/bitsandbytes_avx512.hpp", +] +include = ["bitsandbytes_cpu"] +cxx-flags = ["-mfma", "-fopenmp", "-mf16c", "-mavx512f", "-mavx512bf16", "-mavx512vl"] diff --git a/quantization-bitsandbytes/flake.nix b/quantization-bitsandbytes/flake.nix new file mode 100644 index 0000000..b8d645d --- /dev/null +++ b/quantization-bitsandbytes/flake.nix @@ -0,0 +1,21 @@ +{ + description = "Flake for Torch kernel extension"; + inputs = { + kernel-builder.url = "github:huggingface/kernel-builder"; + }; + outputs = + { self, kernel-builder }: + kernel-builder.lib.genFlakeOutputs { + inherit self; + path = ./.; + + # This is a workaround, we should be able to specify flags per arch in + # kernel-builder. 
+ torchVersions = + allVersions: + builtins.map ( + version: + version // { systems = builtins.filter (system: system == "x86_64-linux") version.systems; } + ) allVersions; + }; +} diff --git a/quantization-bitsandbytes/torch-ext/quantization_bitsandbytes/__init__.py b/quantization-bitsandbytes/torch-ext/quantization_bitsandbytes/__init__.py new file mode 100644 index 0000000..349225b --- /dev/null +++ b/quantization-bitsandbytes/torch-ext/quantization_bitsandbytes/__init__.py @@ -0,0 +1,3 @@ +from .custom_ops import gemm_4bit_forward + +__all__ = ["gemm_4bit_forward"] \ No newline at end of file diff --git a/quantization-bitsandbytes/torch-ext/quantization_bitsandbytes/custom_ops.py b/quantization-bitsandbytes/torch-ext/quantization_bitsandbytes/custom_ops.py new file mode 100644 index 0000000..65db6cd --- /dev/null +++ b/quantization-bitsandbytes/torch-ext/quantization_bitsandbytes/custom_ops.py @@ -0,0 +1,11 @@ +import torch +from ._ops import ops + +def gemm_4bit_forward( + input: torch.Tensor, + weight: torch.Tensor, + absmax: torch.Tensor, + blocksize: int, + quant_type: int, +) -> torch.Tensor: + return ops.gemm_4bit_forward(input, weight, absmax, blocksize, quant_type) diff --git a/quantization-bitsandbytes/torch-ext/torch_binding.cpp b/quantization-bitsandbytes/torch-ext/torch_binding.cpp new file mode 100644 index 0000000..c3fb06c --- /dev/null +++ b/quantization-bitsandbytes/torch-ext/torch_binding.cpp @@ -0,0 +1,42 @@ +#include +#include "registration.h" + + +#if defined(CPU_KERNEL) +torch::Tensor gemm_4bit_cpu_forward( + const torch::Tensor &input, + const torch::Tensor &weight, + const torch::Tensor &absmax, + int64_t blocksize, + int64_t quant_type); +#endif +// Unified dispatcher for both CPU and XPU +torch::Tensor gemm_4bit_forward( + const torch::Tensor &input, + const torch::Tensor &weight, + const torch::Tensor &absmax, + int64_t blocksize, + int64_t quant_type) { +#if defined(CPU_KERNEL) + if (input.device().type() == torch::kCPU) { + TORCH_CHECK(input.device().type() == torch::kCPU, "input must be on CPU"); + TORCH_CHECK(weight.device().type() == torch::kCPU, "weight must be on CPU"); + TORCH_CHECK(absmax.device().type() == torch::kCPU, "absmax must be on CPU"); + TORCH_CHECK(blocksize > 0, "blocksize must be > 0"); + return gemm_4bit_cpu_forward(input, weight, absmax, blocksize, quant_type); + } +#endif + else { + TORCH_CHECK(false, "Unsupported device type: ", hidden_states.device().type()); + } +} + +TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { + ops.def("gemm_4bit_forward(Tensor input, Tensor weight, Tensor absmax, int blocksize, int quant_type) -> Tensor"); +#if defined(CPU_KERNEL) + // Register CPU implementation + ops.impl("gemm_4bit_forward", torch::kCPU, &gemm_4bit_forward); +#endif +} + +REGISTER_EXTENSION(TORCH_EXTENSION_NAME) From d37d0b7cff580e3384166f0cc5dad45c475bb98f Mon Sep 17 00:00:00 2001 From: jiqing-feng Date: Thu, 27 Nov 2025 09:50:04 +0000 Subject: [PATCH 02/19] fix check Signed-off-by: jiqing-feng --- quantization-bitsandbytes/torch-ext/torch_binding.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/quantization-bitsandbytes/torch-ext/torch_binding.cpp b/quantization-bitsandbytes/torch-ext/torch_binding.cpp index c3fb06c..ee642cd 100644 --- a/quantization-bitsandbytes/torch-ext/torch_binding.cpp +++ b/quantization-bitsandbytes/torch-ext/torch_binding.cpp @@ -27,7 +27,7 @@ torch::Tensor gemm_4bit_forward( } #endif else { - TORCH_CHECK(false, "Unsupported device type: ", hidden_states.device().type()); + 
TORCH_CHECK(false, "Unsupported device type: ", input.device().type()); } } From 26c60791d226c56932d2614ed0b850f8c1001235 Mon Sep 17 00:00:00 2001 From: jiqing-feng Date: Thu, 27 Nov 2025 09:51:41 +0000 Subject: [PATCH 03/19] rm useless head Signed-off-by: jiqing-feng --- .../bitsandbytes_cpu/bitsandbytes_avx512.cpp | 1 - 1 file changed, 1 deletion(-) diff --git a/quantization-bitsandbytes/bitsandbytes_cpu/bitsandbytes_avx512.cpp b/quantization-bitsandbytes/bitsandbytes_cpu/bitsandbytes_avx512.cpp index be6219b..6758651 100644 --- a/quantization-bitsandbytes/bitsandbytes_cpu/bitsandbytes_avx512.cpp +++ b/quantization-bitsandbytes/bitsandbytes_cpu/bitsandbytes_avx512.cpp @@ -1,7 +1,6 @@ // AVX512 implementation - compile with -mavx512f -mavx512bf16 #include #include -#include #include #include #include From 1fecbbbedbd8a3716801826548a57dfae79c1ace Mon Sep 17 00:00:00 2001 From: jiqing-feng Date: Thu, 27 Nov 2025 09:52:58 +0000 Subject: [PATCH 04/19] rm useless header Signed-off-by: jiqing-feng --- .../bitsandbytes_cpu/bitsandbytes_avx512.hpp | 1 - 1 file changed, 1 deletion(-) diff --git a/quantization-bitsandbytes/bitsandbytes_cpu/bitsandbytes_avx512.hpp b/quantization-bitsandbytes/bitsandbytes_cpu/bitsandbytes_avx512.hpp index 28d05fa..dcda785 100644 --- a/quantization-bitsandbytes/bitsandbytes_cpu/bitsandbytes_avx512.hpp +++ b/quantization-bitsandbytes/bitsandbytes_cpu/bitsandbytes_avx512.hpp @@ -1,7 +1,6 @@ // AVX512 implementation - compile with -mavx512f -mavx512bf16 #include #include -#include #include #include #include From e9eb7bca803af635673c7b1a02342a4287489811 Mon Sep 17 00:00:00 2001 From: jiqing-feng Date: Thu, 27 Nov 2025 09:56:13 +0000 Subject: [PATCH 05/19] add mavx512dq mavx512bw Signed-off-by: jiqing-feng --- quantization-bitsandbytes/build.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/quantization-bitsandbytes/build.toml b/quantization-bitsandbytes/build.toml index b19d07f..ccd1367 100644 --- a/quantization-bitsandbytes/build.toml +++ b/quantization-bitsandbytes/build.toml @@ -24,4 +24,4 @@ src = [ "bitsandbytes_cpu/bitsandbytes_avx512.hpp", ] include = ["bitsandbytes_cpu"] -cxx-flags = ["-mfma", "-fopenmp", "-mf16c", "-mavx512f", "-mavx512bf16", "-mavx512vl"] +cxx-flags = ["-mfma", "-fopenmp", "-mf16c", "-mavx512f", "-mavx512bf16", "-mavx512vl", "mavx512dq", "mavx512bw"] From 3b4d7baa9f21353262b16cfe5121a2f8964c9c46 Mon Sep 17 00:00:00 2001 From: jiqing-feng Date: Thu, 27 Nov 2025 10:00:49 +0000 Subject: [PATCH 06/19] fix avx512dq Signed-off-by: jiqing-feng --- quantization-bitsandbytes/build.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/quantization-bitsandbytes/build.toml b/quantization-bitsandbytes/build.toml index ccd1367..966c1e7 100644 --- a/quantization-bitsandbytes/build.toml +++ b/quantization-bitsandbytes/build.toml @@ -24,4 +24,4 @@ src = [ "bitsandbytes_cpu/bitsandbytes_avx512.hpp", ] include = ["bitsandbytes_cpu"] -cxx-flags = ["-mfma", "-fopenmp", "-mf16c", "-mavx512f", "-mavx512bf16", "-mavx512vl", "mavx512dq", "mavx512bw"] +cxx-flags = ["-mfma", "-fopenmp", "-mf16c", "-mavx512f", "-mavx512bf16", "-mavx512vl", "-mavx512dq", "-mavx512bw"] From bd320ad26bd5ac59adac5b0eae97563b36ae44d0 Mon Sep 17 00:00:00 2001 From: jiqing-feng Date: Thu, 27 Nov 2025 10:02:13 +0000 Subject: [PATCH 07/19] fix name Signed-off-by: jiqing-feng --- quantization-bitsandbytes/build.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/quantization-bitsandbytes/build.toml b/quantization-bitsandbytes/build.toml index 
966c1e7..7ee938a 100644 --- a/quantization-bitsandbytes/build.toml +++ b/quantization-bitsandbytes/build.toml @@ -1,5 +1,5 @@ [general] -name = "bitsandbytes" +name = "bitsandbytes_hf_kernel" universal = false [torch] From 5017d225d49c1fcc394135a1af387d704e3550a6 Mon Sep 17 00:00:00 2001 From: jiqing-feng Date: Thu, 27 Nov 2025 10:12:15 +0000 Subject: [PATCH 08/19] name Signed-off-by: jiqing-feng --- quantization-bitsandbytes/build.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/quantization-bitsandbytes/build.toml b/quantization-bitsandbytes/build.toml index 7ee938a..066f641 100644 --- a/quantization-bitsandbytes/build.toml +++ b/quantization-bitsandbytes/build.toml @@ -1,5 +1,5 @@ [general] -name = "bitsandbytes_hf_kernel" +name = "quantization_bitsandbytes" universal = false [torch] From 7823aa8dbbc99aca57f4ce03ebba7ed71b5b6c7a Mon Sep 17 00:00:00 2001 From: jiqing-feng Date: Thu, 27 Nov 2025 14:13:47 +0000 Subject: [PATCH 09/19] add flake.lock Signed-off-by: jiqing-feng --- quantization-bitsandbytes/flake.lock | 95 ++++++++++++++++++++++++++++ 1 file changed, 95 insertions(+) create mode 100644 quantization-bitsandbytes/flake.lock diff --git a/quantization-bitsandbytes/flake.lock b/quantization-bitsandbytes/flake.lock new file mode 100644 index 0000000..185d4d3 --- /dev/null +++ b/quantization-bitsandbytes/flake.lock @@ -0,0 +1,95 @@ +{ + "nodes": { + "flake-compat": { + "locked": { + "lastModified": 1761588595, + "narHash": "sha256-XKUZz9zewJNUj46b4AJdiRZJAvSZ0Dqj2BNfXvFlJC4=", + "owner": "edolstra", + "repo": "flake-compat", + "rev": "f387cd2afec9419c8ee37694406ca490c3f34ee5", + "type": "github" + }, + "original": { + "owner": "edolstra", + "repo": "flake-compat", + "type": "github" + } + }, + "flake-utils": { + "inputs": { + "systems": "systems" + }, + "locked": { + "lastModified": 1731533236, + "narHash": "sha256-l0KFg5HjrsfsO/JpG+r7fRrqm12kzFHyUHqHCVpMMbI=", + "owner": "numtide", + "repo": "flake-utils", + "rev": "11707dc2f618dd54ca8739b309ec4fc024de578b", + "type": "github" + }, + "original": { + "owner": "numtide", + "repo": "flake-utils", + "type": "github" + } + }, + "kernel-builder": { + "inputs": { + "flake-compat": "flake-compat", + "flake-utils": "flake-utils", + "nixpkgs": "nixpkgs" + }, + "locked": { + "lastModified": 1764164279, + "narHash": "sha256-HVhMEzLOYfektbKqbfSWdo1o1bXLReknDu/DX466TqQ=", + "owner": "huggingface", + "repo": "kernel-builder", + "rev": "df995731f57c7042ce1ae6471c8d57477fce25fb", + "type": "github" + }, + "original": { + "owner": "huggingface", + "repo": "kernel-builder", + "type": "github" + } + }, + "nixpkgs": { + "locked": { + "lastModified": 1763291491, + "narHash": "sha256-eEYvm+45PPmy+Qe+nZDpn1uhoMUjJwx3PwVVQoO9ksA=", + "owner": "NixOS", + "repo": "nixpkgs", + "rev": "c543a59edf25ada193719764f3bc0c6ba835f94d", + "type": "github" + }, + "original": { + "owner": "NixOS", + "repo": "nixpkgs", + "rev": "c543a59edf25ada193719764f3bc0c6ba835f94d", + "type": "github" + } + }, + "root": { + "inputs": { + "kernel-builder": "kernel-builder" + } + }, + "systems": { + "locked": { + "lastModified": 1681028828, + "narHash": "sha256-Vy1rq5AaRuLzOxct8nz4T6wlgyUR7zLU309k9mBC768=", + "owner": "nix-systems", + "repo": "default", + "rev": "da67096a3b9bf56a91d16901293e51ba5b49a27e", + "type": "github" + }, + "original": { + "owner": "nix-systems", + "repo": "default", + "type": "github" + } + } + }, + "root": "root", + "version": 7 +} From fc0be9a4edebdc72d283a24dabce0d12a198ef60 Mon Sep 17 00:00:00 2001 From: jiqing-feng Date: Fri, 28 Nov 
2025 12:41:26 +0000 Subject: [PATCH 10/19] fix lint Signed-off-by: jiqing-feng --- .../bitsandbytes_cpu/bitsandbytes_avx512.cpp | 140 ++++++------------ .../bitsandbytes_cpu/bitsandbytes_avx512.hpp | 117 --------------- .../bitsandbytes_cpu/bitsandbytes_cpu.cpp | 17 +-- 3 files changed, 53 insertions(+), 221 deletions(-) diff --git a/quantization-bitsandbytes/bitsandbytes_cpu/bitsandbytes_avx512.cpp b/quantization-bitsandbytes/bitsandbytes_cpu/bitsandbytes_avx512.cpp index 6758651..35a29b7 100644 --- a/quantization-bitsandbytes/bitsandbytes_cpu/bitsandbytes_avx512.cpp +++ b/quantization-bitsandbytes/bitsandbytes_cpu/bitsandbytes_avx512.cpp @@ -1,6 +1,9 @@ // AVX512 implementation - compile with -mavx512f -mavx512bf16 +#define CPU_CAPABILITY_AVX512 #include #include +#include +#include #include #include #include @@ -10,34 +13,6 @@ namespace bitsandbytes_cpu { namespace avx512 { - inline __m256i cvt_fp32_to_fp16(const __m512 src) - { - return _mm512_cvtps_ph(src, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)); - } - - inline __m256i cvt_fp32_to_bf16(const __m512 src) - { - if (has_avx512bf16()) - { - return reinterpret_cast<__m256i>(_mm512_cvtneps_pbh(src)); - } - __m512i value = _mm512_castps_si512(src); - __m512i nan = _mm512_set1_epi32(0xffff); - auto mask_value = _mm512_cmp_ps_mask(src, src, _CMP_ORD_Q); - __m512i ones = _mm512_set1_epi32(0x1); - __m512i vec_bias = _mm512_set1_epi32(0x7fff); - // uint32_t lsb = (input >> 16) & 1; - auto t_value = _mm512_and_si512(_mm512_srli_epi32(value, 16), ones); - // uint32_t rounding_bias = 0x7fff + lsb; - t_value = _mm512_add_epi32(t_value, vec_bias); - // input += rounding_bias; - t_value = _mm512_add_epi32(t_value, value); - // input = input >> 16; - t_value = _mm512_srli_epi32(t_value, 16); - // Check NaN before converting back to bf16 - t_value = _mm512_mask_blend_epi32(mask_value, nan, t_value); - return _mm512_cvtusepi32_epi16(t_value); - } static inline __m512 set_nf4_lut() { @@ -56,22 +31,10 @@ namespace bitsandbytes_cpu #define CVT_BF16_TO_FP32(a) _mm512_castsi512_ps(_mm512_slli_epi32(_mm512_cvtepu16_epi32(a), 16)) - static inline const at::BFloat16 *cast_at_bf16(const bf16_t *p) - { - static_assert(sizeof(bf16_t) == sizeof(at::BFloat16), "bf16_t size mismatch"); - return reinterpret_cast(p); - } - - static inline at::BFloat16 *cast_at_bf16(bf16_t *p) - { - static_assert(sizeof(bf16_t) == sizeof(at::BFloat16), "bf16_t size mismatch"); - return reinterpret_cast(p); - } - template inline void unpack_B( - bf16_t *__restrict__ Btmp, const unsigned char *__restrict__ packed_B, - const bf16_t *__restrict__ Bs, // scales [K/gs, N] in bf16 + at::BFloat16 *__restrict__ Btmp, const unsigned char *__restrict__ packed_B, + const at::BFloat16 *__restrict__ Bs, // scales [K/gs, N] in bf16 int64_t N, int64_t K, int blocksize, int64_t ldb, int64_t ldb_tmp, int64_t strideBs) { // Dequant: (w - z) * s -> bf16 @@ -171,11 +134,11 @@ namespace bitsandbytes_cpu }; template - struct tinygemm_kernel_nn + struct tinygemm_kernel_nn { static inline void apply( - const bf16_t *__restrict__ A, const unsigned char *__restrict__ B, bf16_t *__restrict__ C, - const bf16_t *__restrict__ Bs, int64_t K, int blocksize, int64_t lda, int64_t ldb, int64_t ldc, int64_t strideBs) + const at::BFloat16 *__restrict__ A, const unsigned char *__restrict__ B, at::BFloat16 *__restrict__ C, + const at::BFloat16 *__restrict__ Bs, int64_t K, int blocksize, int64_t lda, int64_t ldb, int64_t ldc, int64_t strideBs) { static_assert(BLOCK_N % 32 == 0); constexpr int ROWS = BLOCK_M; // 32 @@ 
-307,55 +270,42 @@ namespace bitsandbytes_cpu return hi; } + template , int> = 0> + inline at::vec::Vectorized convert_from_float_ext(const at::vec::Vectorized& a, const at::vec::Vectorized& b) { + return at::vec::convert_from_float(a, b); + } + + template <> + inline at::vec::Vectorized + convert_from_float_ext(const at::vec::Vectorized& a, const at::vec::Vectorized& b) { + return (__m512i)(_mm512_cvtne2ps_pbh(__m512(b), __m512(a))); + } + template - inline void copy_stub(scalar_t *__restrict__ out, const float *__restrict__ input, int64_t size) - { - if (has_avx512bf16()) - { - int64_t d = 0; - const int V = 32; - for (; d + V <= size; d += V) - { - __m512 lo = _mm512_loadu_ps(input + d); - __m512 hi = _mm512_loadu_ps(input + d + 16); - __m512bh packed = _mm512_cvtne2ps_pbh(hi, lo); - _mm512_storeu_si512(reinterpret_cast(out + d), (__m512i)packed); - } - for (; d < size; ++d) - { - if constexpr (std::is_same_v) - { - // store raw bf16 bits - reinterpret_cast(out)[d] = float_to_bf16_round(input[d]); - } - else - { - out[d] = static_cast(input[d]); - } - } + inline void copy_stub(scalar_t* __restrict__ out, const float* __restrict__ input, int64_t size) { + using bVec = at::vec::Vectorized; + using fVec = at::vec::Vectorized; + constexpr int kVecSize = bVec::size(); + + int64_t d; + #pragma GCC unroll 4 + for (d = 0; d <= size - kVecSize; d += kVecSize) { + fVec data0 = fVec::loadu(input + d); + fVec data1 = fVec::loadu(input + d + fVec::size()); + bVec out_vec = convert_from_float_ext(data0, data1); + out_vec.store(out + d); } - else - { - for (int64_t d = 0; d < size; ++d) - { - if constexpr (std::is_same_v) - { - reinterpret_cast(out)[d] = float_to_bf16_round(input[d]); - } - else - { - out[d] = static_cast(input[d]); - } - } + for (; d < size; ++d) { + out[d] = static_cast(input[d]); } } template - struct brgemm + struct brgemm { static inline void apply( - const bf16_t *__restrict__ A, const unsigned char *__restrict__ B, bf16_t *__restrict__ C, - const bf16_t *__restrict__ Bs, bf16_t *__restrict__ Btmp, float *__restrict__ Ctmp, int64_t M, int64_t N, + const at::BFloat16 *__restrict__ A, const unsigned char *__restrict__ B, at::BFloat16 *__restrict__ C, + const at::BFloat16 *__restrict__ Bs, at::BFloat16 *__restrict__ Btmp, float *__restrict__ Ctmp, int64_t M, int64_t N, int64_t K, int blocksize, int64_t lda, int64_t ldb, int64_t ldc, int64_t strideBs, bool use_brgemm_dequant_out) { constexpr int BLOCK_N = block_size_n(); @@ -363,7 +313,7 @@ namespace bitsandbytes_cpu if (use_brgemm_dequant_out) { at::native::cpublas::brgemm( - M, N, K, lda, ldb_tmp, BLOCK_N, false, cast_at_bf16(A), cast_at_bf16(Btmp), Ctmp); + M, N, K, lda, ldb_tmp, BLOCK_N, false, A, Btmp, Ctmp); } else { @@ -377,14 +327,14 @@ namespace bitsandbytes_cpu const bool add_C = k != 0; at::native::cpublas::brgemm( - M, N, kb_size, lda, ldb_tmp, BLOCK_N, add_C, cast_at_bf16(A + k), cast_at_bf16(Btmp), Ctmp); + M, N, kb_size, lda, ldb_tmp, BLOCK_N, add_C, A + k, Btmp, Ctmp); } } // copy from Ctmp to C for (int64_t m = 0; m < M; ++m) { - copy_stub(C + m * ldc, Ctmp + m * BLOCK_N, N); + copy_stub(C + m * ldc, Ctmp + m * BLOCK_N, N); } } }; @@ -539,11 +489,11 @@ namespace bitsandbytes_cpu // TEMPLATE DEFINITIONS //============================================================== - template void gemm_4bit_inference( - int64_t M, int64_t N, int64_t K, const bf16_t *__restrict__ x, const unsigned char *__restrict__ w, - const bf16_t *__restrict__ absmax, bf16_t *__restrict__ out, int64_t blocksize, int64_t x_stride, int64_t 
out_stride); - template void gemm_4bit_inference( - int64_t M, int64_t N, int64_t K, const bf16_t *__restrict__ x, const unsigned char *__restrict__ w, - const bf16_t *__restrict__ absmax, bf16_t *__restrict__ out, int64_t blocksize, int64_t x_stride, int64_t out_stride); + template void gemm_4bit_inference( + int64_t M, int64_t N, int64_t K, const at::BFloat16 *__restrict__ x, const unsigned char *__restrict__ w, + const at::BFloat16 *__restrict__ absmax, at::BFloat16 *__restrict__ out, int64_t blocksize, int64_t x_stride, int64_t out_stride); + template void gemm_4bit_inference( + int64_t M, int64_t N, int64_t K, const at::BFloat16 *__restrict__ x, const unsigned char *__restrict__ w, + const at::BFloat16 *__restrict__ absmax, at::BFloat16 *__restrict__ out, int64_t blocksize, int64_t x_stride, int64_t out_stride); } // namespace avx512 } // namespace bitsandbytes_cpu \ No newline at end of file diff --git a/quantization-bitsandbytes/bitsandbytes_cpu/bitsandbytes_avx512.hpp b/quantization-bitsandbytes/bitsandbytes_cpu/bitsandbytes_avx512.hpp index dcda785..6d5cbcd 100644 --- a/quantization-bitsandbytes/bitsandbytes_cpu/bitsandbytes_avx512.hpp +++ b/quantization-bitsandbytes/bitsandbytes_cpu/bitsandbytes_avx512.hpp @@ -113,123 +113,6 @@ namespace bitsandbytes_cpu FP4 = 1, } DataType_t; - struct fp16_t - { - uint16_t v; - }; - - struct bf16_t - { - uint16_t v; - }; - - static inline bf16_t float_to_bf16(float x) - { - uint32_t bits; - std::memcpy(&bits, &x, 4); - uint32_t r = bits + 0x7FFF + ((bits >> 16) & 1); - return bf16_t{static_cast(r >> 16)}; - } - - static float bf16_to_float(uint16_t bf16) - { - uint32_t bits = (uint32_t)bf16 << 16; - float f; - std::memcpy(&f, &bits, sizeof(f)); - return f; - } - - static inline fp16_t float_to_fp16(float x) - { - uint32_t bits; - std::memcpy(&bits, &x, 4); - uint32_t sign = (bits >> 31) & 0x1; - uint32_t exp = (bits >> 23) & 0xFF; - uint32_t mant = bits & 0x7FFFFF; - - uint16_t h; - if (exp == 0xFF) - { // Inf / NaN - uint16_t mant16 = mant ? 
0x200 : 0; // quiet NaN: set MSB of mantissa - h = (sign << 15) | (0x1F << 10) | mant16; - } - else if (exp > 0x70 + 0x1E) - { // overflow: exp_f -127 +15 > 30 (exp_f > 142) - h = (sign << 15) | (0x1F << 10); // Inf - } - else if (exp < 0x71) - { // subnormal or zero (exp_f < 113) - if (exp < 0x67) - { // too small -> zero (exp_f < 103) - h = (sign << 15); - } - else - { - // subnormal: implicit leading 1 - uint32_t shift = 0x71 - exp; - uint32_t mant_with_hidden = mant | 0x800000; - // add rounding bias before shifting (23-10 =13 bits to drop + shift) - uint32_t rounded = (mant_with_hidden + (1u << (shift + 12))) >> (shift + 13); - h = (sign << 15) | (uint16_t)rounded; - } - } - else - { - // normalized - uint32_t exp_h = exp - 127 + 15; - // round mantissa: add 2^(23-10-1) = 0x1000 - uint32_t mant_rounded = mant + 0x00001000; - if (mant_rounded & 0x00800000) - { // mantissa overflow after rounding - mant_rounded = 0; - ++exp_h; - if (exp_h >= 0x1F) - { // overflow to Inf - h = (sign << 15) | (0x1F << 10); - return fp16_t{h}; - } - } - h = (sign << 15) | ((uint16_t)exp_h << 10) | ((uint16_t)(mant_rounded >> 13)); - } - return fp16_t{h}; - } - -#ifdef _MSC_VER -#include - - static inline bool has_avx512f() - { - static bool v = [] - { - int info[4]; - __cpuidex(info, 7, 0); - return (info[1] & (1 << 16)) != 0; // EBX bit16 AVX512F - }(); - return v; - } - static inline bool has_avx512bf16() - { - static bool v = [] - { - int info[4]; - __cpuidex(info, 7, 1); - return (info[0] & (1 << 5)) != 0; // EAX bit5 AVX512_BF16 - }(); - return v; - } -#else - static inline bool has_avx512f() - { - static const bool supported_avx512f = __builtin_cpu_supports("avx512f"); - return supported_avx512f; - } - static inline bool has_avx512bf16() - { - static const bool supported_avx512bf16 = __builtin_cpu_supports("avx512bf16"); - return supported_avx512bf16; - } -#endif - template void gemm_4bit_inference( int64_t M, int64_t N, int64_t K, const T *__restrict__ x, const unsigned char *__restrict__ w, diff --git a/quantization-bitsandbytes/bitsandbytes_cpu/bitsandbytes_cpu.cpp b/quantization-bitsandbytes/bitsandbytes_cpu/bitsandbytes_cpu.cpp index e485159..66abdf2 100644 --- a/quantization-bitsandbytes/bitsandbytes_cpu/bitsandbytes_cpu.cpp +++ b/quantization-bitsandbytes/bitsandbytes_cpu/bitsandbytes_cpu.cpp @@ -1,6 +1,5 @@ #include "bitsandbytes_cpu.hpp" -using bf16_t = bitsandbytes_cpu::avx512::bf16_t; using bitsandbytes_cpu::avx512::DataType_t; namespace bitsandbytes_cpu @@ -21,21 +20,21 @@ namespace bitsandbytes_cpu { // Use AVX512 optimized implementation if (quant_type == 1) { - bitsandbytes_cpu::avx512::gemm_4bit_inference( + bitsandbytes_cpu::avx512::gemm_4bit_inference( M, N, K, - reinterpret_cast(input.data_ptr()), + input.data_ptr(), weight.data_ptr(), - reinterpret_cast(absmax.data_ptr()), - reinterpret_cast(out.data_ptr()), + absmax.data_ptr(), + out.data_ptr(), blocksize, x_strideM, out_strideM); } else { - bitsandbytes_cpu::avx512::gemm_4bit_inference( + bitsandbytes_cpu::avx512::gemm_4bit_inference( M, N, K, - reinterpret_cast(input.data_ptr()), + input.data_ptr(), weight.data_ptr(), - reinterpret_cast(absmax.data_ptr()), - reinterpret_cast(out.data_ptr()), + absmax.data_ptr(), + out.data_ptr(), blocksize, x_strideM, out_strideM); } } From 4712546affa53d2588fa7bbb203dcdd934b5b96b Mon Sep 17 00:00:00 2001 From: jiqing-feng Date: Fri, 28 Nov 2025 12:41:44 +0000 Subject: [PATCH 11/19] add tests Signed-off-by: jiqing-feng --- quantization-bitsandbytes/tests/__init__.py | 0 
 .../tests/test_bitsandbytes.py                | 105 ++++++++++++++++++
 2 files changed, 105 insertions(+)
 create mode 100644 quantization-bitsandbytes/tests/__init__.py
 create mode 100644 quantization-bitsandbytes/tests/test_bitsandbytes.py

diff --git a/quantization-bitsandbytes/tests/__init__.py b/quantization-bitsandbytes/tests/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/quantization-bitsandbytes/tests/test_bitsandbytes.py b/quantization-bitsandbytes/tests/test_bitsandbytes.py
new file mode 100644
index 0000000..bc4488b
--- /dev/null
+++ b/quantization-bitsandbytes/tests/test_bitsandbytes.py
@@ -0,0 +1,105 @@
+import torch
+import pytest
+
+from quantization_bitsandbytes import gemm_4bit_forward
+
+def unpack_weight_packed_for_cpu(packed_qweight: torch.Tensor, block_n: int = 32):
+    """
+    Inverse of convert_weight_packed_for_cpu.
+    packed_qweight: (N, K//2) uint8, each byte = (high<<4)|low, both 4-bit values in 0..15
+    returns: qweight_final (N, K) uint8 with original 4-bit values (0..15)
+    """
+    assert packed_qweight.dtype == torch.uint8
+    assert packed_qweight.dim() == 2
+    N, K_half = packed_qweight.shape
+    assert N % block_n == 0
+    BIT_COUNT = block_n  # 32
+    # reshape to rows of 32 packed bytes
+    qw = packed_qweight.reshape(-1, BIT_COUNT)  # [(N//block_n)*K_half, 32]
+    low = (qw & 0x0F)
+    high = (qw >> 4) & 0x0F
+    # restore 64 nibbles (low first then high, matching original pack order)
+    restored = torch.cat([low, high], dim=1)  # [..., 64]
+    # reshape back (inverse of flatten)
+    restored = restored.reshape(N // block_n, K_half, block_n, 2)  # [N/block_n, K//2, block_n, 2]
+    # inverse transpose
+    restored = restored.transpose(-3, -2)  # [N/block_n, block_n, K//2, 2]
+    # final shape
+    qweight_final = restored.reshape(N, K_half * 2).to(torch.uint8)
+    return qweight_final
+
+
+_NF4_QUANT_TABLE = torch.tensor(
+    [
+        -1.0,
+        -0.6961928009986877,
+        -0.5250730514526367,
+        -0.39491748809814453,
+        -0.28444138169288635,
+        -0.18477343022823334,
+        -0.09105003625154495,
+        0.0,
+        0.07958029955625534,
+        0.16093020141124725,
+        0.24611230194568634,
+        0.33791524171829224,
+        0.44070982933044434,
+        0.5626170039176941,
+        0.7229568362236023,
+        1.0,
+    ],
+    dtype=torch.float32,
+)
+
+_FP4_QUANT_TABLE = torch.tensor(
+    [
+        0.0000,
+        0.0052,
+        0.6667,
+        1.0000,
+        0.3333,
+        0.5000,
+        0.1667,
+        0.2500,
+        0.0000,
+        -0.0052,
+        -0.6667,
+        -1.0000,
+        -0.3333,
+        -0.5000,
+        -0.1667,
+        -0.2500,
+    ],
+    dtype=torch.float32,
+)
+
+def ref_gemm_4bit(x, packed_weight, scales, group_size, quant_type):
+    unpacked_weight = unpack_weight_packed_for_cpu(packed_weight)
+    shape = unpacked_weight.shape
+    table = _FP4_QUANT_TABLE if quant_type == 1 else _NF4_QUANT_TABLE
+    original_weight = table[unpacked_weight.reshape(-1).int()].reshape(shape) * scales.T.repeat_interleave(group_size, dim=1)
+    res = torch.matmul(x, original_weight.T.to(x.dtype))
+    return res
+
+@pytest.mark.parametrize("M", [1, 4, 32, 128, 512, 1024])
+@pytest.mark.parametrize("K", [2048, 4096])
+@pytest.mark.parametrize("N", [1024, 2048, 4096])
+@pytest.mark.parametrize("quant_type", [0, 1])
+def test_bitsandbytes(M, K, N, quant_type):
+    torch.manual_seed(100)
+    device = torch.device("cpu")
+    dtype = torch.bfloat16
+    group_size = 64
+
+    assert K % group_size == 0
+    assert K % 2 == 0
+
+    x = torch.randn((M, K), device=device, dtype=dtype) * 0.1
+    w = torch.randint(0, 15, (N, K // 2), device=device, dtype=torch.uint8)
+    num_groups = K // group_size
+    scales = torch.rand((num_groups, N), device=device, dtype=torch.bfloat16).pow(4.0)
+
+    output = gemm_4bit_forward(x, w, scales, group_size, quant_type)
+    ref_out = ref_gemm_4bit(x, w, scales, group_size, quant_type)
+
+    torch.testing.assert_close(output, ref_out, atol=1e-1, rtol=1e-2)

From 6563014b10993d843e5b050ab63e1a61a25aba23 Mon Sep 17 00:00:00 2001
From: jiqing-feng
Date: Fri, 28 Nov 2025 12:46:13 +0000
Subject: [PATCH 12/19] fix lint

Signed-off-by: jiqing-feng
---
 .../bitsandbytes_cpu/bitsandbytes_avx512.cpp | 6 ++----
 1 file changed, 2 insertions(+), 4 deletions(-)

diff --git a/quantization-bitsandbytes/bitsandbytes_cpu/bitsandbytes_avx512.cpp b/quantization-bitsandbytes/bitsandbytes_cpu/bitsandbytes_avx512.cpp
index 35a29b7..14d39f0 100644
--- a/quantization-bitsandbytes/bitsandbytes_cpu/bitsandbytes_avx512.cpp
+++ b/quantization-bitsandbytes/bitsandbytes_cpu/bitsandbytes_avx512.cpp
@@ -422,10 +422,8 @@ namespace bitsandbytes_cpu
       if (use_brgemm_dequant_out)
       {
         // Layout: contiguous [N*K] elements, 64-byte aligned for AVX512 loads
-        at::Tensor Btmp_t = at::zeros({N, K}, at::dtype(at::kBFloat16));
-        at::BFloat16 *Btmp_start_pt = Btmp_t.data_ptr();
-        Btmp_start = reinterpret_cast(Btmp_start_pt);
-#pragma omp parallel for
+        at::Tensor Btmp_t = at::zeros({N, K}, c10::CppTypeToScalarType::value);
+        Btmp_start = Btmp_t.data_ptr();
         for (int64_t nb = 0; nb < NB; ++nb)
         {
           int64_t nb_start = nb * BLOCK_N;

From d86aac806c1841e27a1182ecce35cf02102f84b3 Mon Sep 17 00:00:00 2001
From: jiqing-feng
Date: Fri, 28 Nov 2025 12:51:29 +0000
Subject: [PATCH 13/19] fix typo

Signed-off-by: jiqing-feng
---
 .../bitsandbytes_cpu/bitsandbytes_avx512.cpp | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/quantization-bitsandbytes/bitsandbytes_cpu/bitsandbytes_avx512.cpp b/quantization-bitsandbytes/bitsandbytes_cpu/bitsandbytes_avx512.cpp
index 14d39f0..7f1ed5e 100644
--- a/quantization-bitsandbytes/bitsandbytes_cpu/bitsandbytes_avx512.cpp
+++ b/quantization-bitsandbytes/bitsandbytes_cpu/bitsandbytes_avx512.cpp
@@ -4,6 +4,7 @@
 #include
 #include
 #include
+#include
 #include
 #include
 #include
@@ -422,7 +423,7 @@ namespace bitsandbytes_cpu
       if (use_brgemm_dequant_out)
       {
         // Layout: contiguous [N*K] elements, 64-byte aligned for AVX512 loads
-        at::Tensor Btmp_t = at::zeros({N, K}, c10::CppTypeToScalarType::value);
+        at::Tensor Btmp_t = at::zeros({N, K}, c10::CppTypeToScalarType::value);
         Btmp_start = Btmp_t.data_ptr();
         for (int64_t nb = 0; nb < NB; ++nb)
         {

From a79d760fbd0b4774ba63354496e99a12a440acea Mon Sep 17 00:00:00 2001
From: jiqing-feng
Date: Fri, 28 Nov 2025 15:01:59 +0000
Subject: [PATCH 14/19] update build and tests

Signed-off-by: jiqing-feng
---
 quantization-bitsandbytes/build.toml                 | 2 +-
 quantization-bitsandbytes/tests/test_bitsandbytes.py | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/quantization-bitsandbytes/build.toml b/quantization-bitsandbytes/build.toml
index 066f641..fe16dcf 100644
--- a/quantization-bitsandbytes/build.toml
+++ b/quantization-bitsandbytes/build.toml
@@ -24,4 +24,4 @@ src = [
   "bitsandbytes_cpu/bitsandbytes_avx512.hpp",
 ]
 include = ["bitsandbytes_cpu"]
-cxx-flags = ["-mfma", "-fopenmp", "-mf16c", "-mavx512f", "-mavx512bf16", "-mavx512vl", "-mavx512dq", "-mavx512bw"]
+cxx-flags = ["-mfma", "-fopenmp", "-mf16c", "-mavx512f", "-mavx512bf16", "-mavx512vl", "-mavx512dq", "-mavx512bw", "-mavx512vbmi"]
diff --git a/quantization-bitsandbytes/tests/test_bitsandbytes.py b/quantization-bitsandbytes/tests/test_bitsandbytes.py
index bc4488b..d816d5d 100644
--- a/quantization-bitsandbytes/tests/test_bitsandbytes.py
+++ b/quantization-bitsandbytes/tests/test_bitsandbytes.py
@@ -83,7 +83,7 @@ def ref_gemm_4bit(x, packed_weight, scales, group_size, quant_type):
 
 @pytest.mark.parametrize("M", [1, 4, 32, 128, 512, 1024])
 @pytest.mark.parametrize("K", [2048, 4096])
-@pytest.mark.parametrize("N", [1024, 2048, 4096])
+@pytest.mark.parametrize("N", [2048, 4096])
 @pytest.mark.parametrize("quant_type", [0, 1])
 def test_bitsandbytes(M, K, N, quant_type):
     torch.manual_seed(100)

From 6dc06791d6ef4ec609aeae57f2cd8c7f9fab00e8 Mon Sep 17 00:00:00 2001
From: jiqing-feng
Date: Fri, 28 Nov 2025 15:21:37 +0000
Subject: [PATCH 15/19] fix typo

Signed-off-by: jiqing-feng
---
 .../bitsandbytes_cpu/bitsandbytes_avx512.cpp | 1 +
 1 file changed, 1 insertion(+)

diff --git a/quantization-bitsandbytes/bitsandbytes_cpu/bitsandbytes_avx512.cpp b/quantization-bitsandbytes/bitsandbytes_cpu/bitsandbytes_avx512.cpp
index 7f1ed5e..7b08570 100644
--- a/quantization-bitsandbytes/bitsandbytes_cpu/bitsandbytes_avx512.cpp
+++ b/quantization-bitsandbytes/bitsandbytes_cpu/bitsandbytes_avx512.cpp
@@ -425,6 +425,7 @@ namespace bitsandbytes_cpu
         // Layout: contiguous [N*K] elements, 64-byte aligned for AVX512 loads
         at::Tensor Btmp_t = at::zeros({N, K}, c10::CppTypeToScalarType::value);
         Btmp_start = Btmp_t.data_ptr();
+#pragma omp parallel for
         for (int64_t nb = 0; nb < NB; ++nb)
         {
           int64_t nb_start = nb * BLOCK_N;

From 003fa20559b09406eb8b93f28e6959cccd1cc29b Mon Sep 17 00:00:00 2001
From: jiqing-feng
Date: Mon, 1 Dec 2025 15:04:08 +0000
Subject: [PATCH 16/19] fix def aten and expand tests

Signed-off-by: jiqing-feng
---
 .../bitsandbytes_cpu/bitsandbytes_avx512.cpp         | 3 ++-
 quantization-bitsandbytes/tests/test_bitsandbytes.py | 7 +++----
 2 files changed, 5 insertions(+), 5 deletions(-)

diff --git a/quantization-bitsandbytes/bitsandbytes_cpu/bitsandbytes_avx512.cpp b/quantization-bitsandbytes/bitsandbytes_cpu/bitsandbytes_avx512.cpp
index 7b08570..5315d6f 100644
--- a/quantization-bitsandbytes/bitsandbytes_cpu/bitsandbytes_avx512.cpp
+++ b/quantization-bitsandbytes/bitsandbytes_cpu/bitsandbytes_avx512.cpp
@@ -418,12 +418,13 @@ namespace bitsandbytes_cpu
       const int64_t NB = div_up(N, BLOCK_N);
       // TODO: Find better threshold.
       T *Btmp_start = nullptr;
+      at::Tensor Btmp_t;
       const bool use_brgemm = M > 4;
       const bool use_brgemm_dequant_out = M > 100;
       if (use_brgemm_dequant_out)
       {
         // Layout: contiguous [N*K] elements, 64-byte aligned for AVX512 loads
-        at::Tensor Btmp_t = at::zeros({N, K}, c10::CppTypeToScalarType::value);
+        Btmp_t = at::zeros({N, K}, c10::CppTypeToScalarType::value);
         Btmp_start = Btmp_t.data_ptr();
 #pragma omp parallel for
         for (int64_t nb = 0; nb < NB; ++nb)
diff --git a/quantization-bitsandbytes/tests/test_bitsandbytes.py b/quantization-bitsandbytes/tests/test_bitsandbytes.py
index d816d5d..9e3a0a4 100644
--- a/quantization-bitsandbytes/tests/test_bitsandbytes.py
+++ b/quantization-bitsandbytes/tests/test_bitsandbytes.py
@@ -81,12 +81,11 @@ def ref_gemm_4bit(x, packed_weight, scales, group_size, quant_type):
     res = torch.matmul(x, original_weight.T.to(x.dtype))
     return res
 
-@pytest.mark.parametrize("M", [1, 4, 32, 128, 512, 1024])
-@pytest.mark.parametrize("K", [2048, 4096])
-@pytest.mark.parametrize("N", [2048, 4096])
+@pytest.mark.parametrize("M", [1, 4, 32, 97, 128, 244, 512, 1024, 1666])
+@pytest.mark.parametrize("K", [2048, 4096, 14336])
+@pytest.mark.parametrize("N", [1024, 2048, 4096, 7168])
 @pytest.mark.parametrize("quant_type", [0, 1])
 def test_bitsandbytes(M, K, N, quant_type):
-    torch.manual_seed(100)
     device = torch.device("cpu")
     dtype = torch.bfloat16
     group_size = 64

From aa6d9fdb41a1680cd694838c822c350a6c498bc5 Mon Sep 17 00:00:00 2001
From: jiqing-feng
Date: Mon, 1 Dec 2025 15:22:44 +0000
Subject: [PATCH 17/19] update tests

Signed-off-by: jiqing-feng
---
 quantization-bitsandbytes/tests/test_bitsandbytes.py | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/quantization-bitsandbytes/tests/test_bitsandbytes.py b/quantization-bitsandbytes/tests/test_bitsandbytes.py
index 9e3a0a4..fe29fdd 100644
--- a/quantization-bitsandbytes/tests/test_bitsandbytes.py
+++ b/quantization-bitsandbytes/tests/test_bitsandbytes.py
@@ -81,11 +81,11 @@ def ref_gemm_4bit(x, packed_weight, scales, group_size, quant_type):
     res = torch.matmul(x, original_weight.T.to(x.dtype))
     return res
 
-@pytest.mark.parametrize("M", [1, 4, 32, 97, 128, 244, 512, 1024, 1666])
+@pytest.mark.parametrize("M", [1, 31, 244, 1024, 2666])
 @pytest.mark.parametrize("K", [2048, 4096, 14336])
-@pytest.mark.parametrize("N", [1024, 2048, 4096, 7168])
+@pytest.mark.parametrize("N", [1024, 4096, 7168])
+@pytest.mark.parametrize("group_size", [64, 256])
 @pytest.mark.parametrize("quant_type", [0, 1])
-def test_bitsandbytes(M, K, N, quant_type):
+def test_bitsandbytes(M, K, N, group_size, quant_type):
     device = torch.device("cpu")
     dtype = torch.bfloat16
-    group_size = 64

From 807bfc0ea06f63e1e4a70fd17ce46762c2c4e01c Mon Sep 17 00:00:00 2001
From: jiqing-feng
Date: Mon, 1 Dec 2025 15:49:53 +0000
Subject: [PATCH 18/19] update flake

Signed-off-by: jiqing-feng
---
 quantization-bitsandbytes/flake.lock | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/quantization-bitsandbytes/flake.lock b/quantization-bitsandbytes/flake.lock
index 185d4d3..fcb402f 100644
--- a/quantization-bitsandbytes/flake.lock
+++ b/quantization-bitsandbytes/flake.lock
@@ -40,11 +40,11 @@
         "nixpkgs": "nixpkgs"
       },
       "locked": {
-        "lastModified": 1764164279,
-        "narHash": "sha256-HVhMEzLOYfektbKqbfSWdo1o1bXLReknDu/DX466TqQ=",
+        "lastModified": 1764280248,
+        "narHash": "sha256-WNUbjnUnnDJtwK5ATn7VkkOZ+qyVF1UpKXaTkzbJsmI=",
         "owner": "huggingface",
         "repo": "kernel-builder",
-        "rev": "df995731f57c7042ce1ae6471c8d57477fce25fb",
+        "rev": "cc4997dc60691b597aa76a4173541721bc17c224",
         "type": "github"
       },
       "original": {

From 72d7059e202cbe924a22cf3a6df2de059e7c46c7 Mon Sep 17 00:00:00 2001
From: jiqing-feng
Date: Mon, 1 Dec 2025 16:31:55 +0000
Subject: [PATCH 19/19] convert to bfloat16 in custom op

Signed-off-by: jiqing-feng
---
 .../torch-ext/quantization_bitsandbytes/custom_ops.py | 10 +++++++++-
 1 file changed, 9 insertions(+), 1 deletion(-)

diff --git a/quantization-bitsandbytes/torch-ext/quantization_bitsandbytes/custom_ops.py b/quantization-bitsandbytes/torch-ext/quantization_bitsandbytes/custom_ops.py
index 65db6cd..7e77afc 100644
--- a/quantization-bitsandbytes/torch-ext/quantization_bitsandbytes/custom_ops.py
+++ b/quantization-bitsandbytes/torch-ext/quantization_bitsandbytes/custom_ops.py
@@ -8,4 +8,12 @@ def gemm_4bit_forward(
     blocksize: int,
     quant_type: int,
 ) -> torch.Tensor:
-    return ops.gemm_4bit_forward(input, weight, absmax, blocksize, quant_type)
+    original_dtype = input.dtype
+    if original_dtype != torch.bfloat16:
+        input = input.to(torch.bfloat16)
+
+    output = ops.gemm_4bit_forward(input, weight, absmax, blocksize, quant_type)
+    if original_dtype != torch.bfloat16:
+        output = output.to(original_dtype)
+
+    return output