diff --git a/quantization-bitsandbytes/bitsandbytes_cpu/bitsandbytes_avx512.cpp b/quantization-bitsandbytes/bitsandbytes_cpu/bitsandbytes_avx512.cpp new file mode 100644 index 0000000..5315d6f --- /dev/null +++ b/quantization-bitsandbytes/bitsandbytes_cpu/bitsandbytes_avx512.cpp @@ -0,0 +1,500 @@ +// AVX512 implementation - compile with -mavx512f -mavx512bf16 +#define CPU_CAPABILITY_AVX512 +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace bitsandbytes_cpu +{ + namespace avx512 + { + + static inline __m512 set_nf4_lut() + { + return _mm512_set_ps( + 1.0f, 0.7229568362236023, 0.5626170039176941, 0.44070982933044434, 0.33791524171829224, 0.24611230194568634, + 0.16093020141124725, 0.07958029955625534, 0.0f, -0.09105003625154495, -0.18477343022823334, + -0.28444138169288635, -0.39491748809814453, -0.5250730514526367, -0.6961928009986877, -1.0f); + } + + static inline __m512 set_fp4_lut() + { + return _mm512_set_ps( + -0.2500f, -0.16666667f, -0.5000f, -0.33333333f, -1.0000f, -0.66666667f, -5.208333333e-03f, 0.0000f, 0.2500f, + 0.16666667f, 0.5000f, 0.33333333f, 1.0000f, 0.66666667f, 5.208333333e-03f, 0.0000f); + } + +#define CVT_BF16_TO_FP32(a) _mm512_castsi512_ps(_mm512_slli_epi32(_mm512_cvtepu16_epi32(a), 16)) + + template + inline void unpack_B( + at::BFloat16 *__restrict__ Btmp, const unsigned char *__restrict__ packed_B, + const at::BFloat16 *__restrict__ Bs, // scales [K/gs, N] in bf16 + int64_t N, int64_t K, int blocksize, int64_t ldb, int64_t ldb_tmp, int64_t strideBs) + { + // Dequant: (w - z) * s -> bf16 + const int64_t K2 = K >> 1; // 2 weights packed per byte + const int64_t gs2 = blocksize >> 1; + const int64_t ldb2 = ldb; // packed leading dimension (bytes) + const int64_t ldb_tmp2 = ldb_tmp; // output leading dimension in elements + float *btmp_ptr = reinterpret_cast(Btmp); // direct bf16 storage + + __m256i mask = _mm256_set1_epi8(0xF); // low nibble + __m256i fifteen = _mm256_set1_epi8(15); // shift [-15,15] -> [0,30] for LUT + __m512i lut = DATA_TYPE == 1 + ? 
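+ // 32-entry word LUT: entries 15..30 hold the 16 FP4/NF4 dequant constants as bf16 bit patterns, indexed by (nibble + 15) via _mm512_permutexvar_epi16.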
_mm512_set_epi16( + 0x0000, -0x4180, -0x41D5, -0x4100, -0x4155, -0x4080, -0x40D5, -0x4455, 0x0000, 0x3E80, + 0x3E2B, 0x3F00, 0x3EAB, 0x3F80, 0x3F2B, 0x3BAB, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, + 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000) + : _mm512_set_epi16( + 0x0000, 0x3F80, 0x3F39, 0x3F10, 0x3EE2, 0x3EAD, 0x3E7C, 0x3E25, 0x3DA3, 0x0000, -0x4246, + -0x41C3, -0x416E, -0x4136, -0x40FA, -0x40CE, -0x4080, 0x0000, 0x0000, 0x0000, 0x0000, + 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000); + __m512i s_idx1 = _mm512_set_epi32(15, 15, 14, 14, 13, 13, 12, 12, 11, 11, 10, 10, 9, 9, 8, 8); + __m512i s_idx0 = _mm512_set_epi32(7, 7, 6, 6, 5, 5, 4, 4, 3, 3, 2, 2, 1, 1, 0, 0); + + __m512 scale_lo_fp32, scale_hi_fp32; + __m512 scales[4]; + + for (int64_t n = 0; n < N; n += 32) + { + for (int64_t k = 0; k < K2; ++k) + { + if (k % gs2 == 0) + { + const int64_t kgs = k / gs2; + // Load 32 scales (bf16) -> two fp32 vectors (first16, second16) + __m512i scales_bf16 = _mm512_loadu_si512(reinterpret_cast(Bs + kgs * strideBs + n)); + scale_lo_fp32 = CVT_BF16_TO_FP32(_mm512_extracti32x8_epi32(scales_bf16, 0)); + scale_hi_fp32 = CVT_BF16_TO_FP32(_mm512_extracti32x8_epi32(scales_bf16, 1)); + scales[0] = _mm512_permutexvar_ps(s_idx0, scale_lo_fp32); + scales[1] = _mm512_permutexvar_ps(s_idx1, scale_lo_fp32); + scales[2] = _mm512_permutexvar_ps(s_idx0, scale_hi_fp32); + scales[3] = _mm512_permutexvar_ps(s_idx1, scale_hi_fp32); + } + + // Load packed 32 bytes => 64 int4 + __m256i w_u4 = _mm256_loadu_si256(reinterpret_cast(packed_B + k * ldb2 + n)); + + // Split nibbles + __m256i w_lo = w_u4 & mask; + __m256i w_hi = _mm256_srli_epi16(w_u4, 4) & mask; + + // Shift to [0..30] before LUT + w_lo = _mm256_add_epi8(w_lo, fifteen); + w_hi = _mm256_add_epi8(w_hi, fifteen); + + // Lookup (w - z) -> bf16 using LUT (process 16-byte halves) + __m512i w_lo_bf16 = _mm512_permutexvar_epi16(_mm512_cvtepi8_epi16(w_lo), lut); + __m512i w_hi_bf16 = _mm512_permutexvar_epi16(_mm512_cvtepi8_epi16(w_hi), lut); + + __m512 w_lo_fp32_0 = CVT_BF16_TO_FP32(_mm512_extracti32x8_epi32(w_lo_bf16, 0)) * scales[0]; + __m512 w_hi_fp32_0 = CVT_BF16_TO_FP32(_mm512_extracti32x8_epi32(w_lo_bf16, 1)) * scales[1]; + __m512 w_lo_fp32_1 = CVT_BF16_TO_FP32(_mm512_extracti32x8_epi32(w_hi_bf16, 0)) * scales[2]; + __m512 w_hi_fp32_1 = CVT_BF16_TO_FP32(_mm512_extracti32x8_epi32(w_hi_bf16, 1)) * scales[3]; + + // Pack scaled (first 16 cols) then (second 16 cols) to bf16 + __m512bh packed0 = _mm512_cvtne2ps_pbh(w_hi_fp32_0, w_lo_fp32_0); + __m512bh packed1 = _mm512_cvtne2ps_pbh(w_hi_fp32_1, w_lo_fp32_1); + + // Store: two blocks of 16 bf16 (32 elements) per k iteration + _mm512_storeu_si512(btmp_ptr + (k * ldb_tmp2 + n + 0), (__m512i)packed0); + _mm512_storeu_si512(btmp_ptr + (k * ldb_tmp2 + n + 16), (__m512i)packed1); + } + } + } + + template + struct tinygemm_kernel_nn + { + static inline void apply( + const scalar_t *, const unsigned char *, scalar_t *, const scalar_t *, int64_t, int, int64_t, int64_t, int64_t, + int64_t, int64_t) + { + static_assert(sizeof(scalar_t) == 0, "tinygemm_kernel_nn primary template should never be instantiated"); + } + }; + + // The brgemm will not be used without HAS_TORCH + template + struct brgemm + { + static inline void apply( + const scalar_t *__restrict__ A, const unsigned char *__restrict__ B, scalar_t *__restrict__ C, + const scalar_t *__restrict__ Bs, scalar_t *__restrict__ Btmp, float *__restrict__ Ctmp, int64_t M, int64_t N, + int64_t K, 
int blocksize, int64_t lda, int64_t ldb, int64_t ldc, int64_t strideBs, bool use_brgemm_dequant_out) + { + return; + } + }; + + template + struct tinygemm_kernel_nn + { + static inline void apply( + const at::BFloat16 *__restrict__ A, const unsigned char *__restrict__ B, at::BFloat16 *__restrict__ C, + const at::BFloat16 *__restrict__ Bs, int64_t K, int blocksize, int64_t lda, int64_t ldb, int64_t ldc, int64_t strideBs) + { + static_assert(BLOCK_N % 32 == 0); + constexpr int ROWS = BLOCK_M; // 32 + constexpr int COLS = BLOCK_N / 16; // 2 + + // prefetch distance + constexpr int PREFETCH_SIZE_K = 16 * 4; + + __m512bh va; + __m512bh vb[COLS]; + __m512 vc[ROWS * COLS]; + __m512 vc_master[ROWS * COLS]; + + __m256i mask = _mm256_set1_epi8(0xF); // lower 4 bit + __m256i fifteen = _mm256_set1_epi8(15); + __m512i lut = DATA_TYPE == 1 + ? _mm512_set_epi16( + 0x0000, -0x4180, -0x41D5, -0x4100, -0x4155, -0x4080, -0x40D5, -0x4455, 0x0000, 0x3E80, + 0x3E2B, 0x3F00, 0x3EAB, 0x3F80, 0x3F2B, 0x3BAB, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, + 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000) + : _mm512_set_epi16( + 0x0000, 0x3F80, 0x3F39, 0x3F10, 0x3EE2, 0x3EAD, 0x3E7C, 0x3E25, 0x3DA3, 0x0000, -0x4246, + -0x41C3, -0x416E, -0x4136, -0x40FA, -0x40CE, -0x4080, 0x0000, 0x0000, 0x0000, 0x0000, + 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000); + __m512 scales[COLS]; + const int64_t K2 = K >> 1; + const int64_t lda2 = lda >> 1; + const int64_t ldb2 = ldb; // ldb * 2 >> 1; + const int64_t gs2 = blocksize >> 1; // 64 / 2 = 32 + const float *a_ptr = reinterpret_cast(A); + + auto loadc = [&](auto i) + { + constexpr int col = i % COLS; + vc_master[i] = _mm512_set1_ps(0.f); + }; + Unroll{}(loadc); + + auto pre_compute = [&](auto i, int64_t kgs) + { + constexpr int row = i / COLS; + constexpr int col = i % COLS; + vc[i] = _mm512_set1_ps(0.f); // reset accumulator + + // load scales + if constexpr (row == 0 && col % 2 == 0) + { + // Bs layout: [K/gs, BLOCK_N] : [strideBs, 1], dtype=bf16 + __m512i tmp = _mm512_loadu_si512(reinterpret_cast(Bs + kgs * strideBs + col * 16)); + scales[col] = CVT_BF16_TO_FP32(_mm512_extracti32x8_epi32(tmp, 0)); + scales[col + 1] = CVT_BF16_TO_FP32(_mm512_extracti32x8_epi32(tmp, 1)); + } + }; + auto compute = [&](auto i, int64_t k) + { + constexpr int row = i / COLS; + constexpr int col = i % COLS; + + if constexpr (col == 0) + { + va = (__m512bh)(_mm512_set1_ps(a_ptr[row * lda2 + k])); + } + if constexpr (row == 0 && col % 2 == 0) + { + __m256i vb_u4 = _mm256_loadu_si256(reinterpret_cast(B + k * ldb + col * 16)); + + // deinterleave and lookup to BF16 + __m256i vb_i8_lo = vb_u4 & mask; + __m256i vb_i8_hi = _mm256_srli_epi16(vb_u4, 4) & mask; + vb_i8_lo = _mm256_add_epi8(vb_i8_lo, fifteen); + vb_i8_hi = _mm256_add_epi8(vb_i8_hi, fifteen); + vb[col] = (__m512bh)_mm512_permutexvar_epi16(_mm512_cvtepi8_epi16(vb_i8_lo), lut); + vb[col + 1] = (__m512bh)_mm512_permutexvar_epi16(_mm512_cvtepi8_epi16(vb_i8_hi), lut); + + if constexpr (PREFETCH_SIZE_K > 0) + { + _mm_prefetch(B + (k + PREFETCH_SIZE_K) * ldb2 + col * 16, _MM_HINT_T0); + } + } + vc[i] = _mm512_dpbf16_ps(vc[i], va, vb[col]); + }; + auto post_compute = [&](auto i, int64_t kgs) + { + vc_master[i] = _mm512_fmadd_ps(vc[i], scales[i % COLS], vc_master[i]); + }; + for (int64_t k = 0; k < K2; k += gs2) + { + Unroll{}(pre_compute, k / gs2); + for (int64_t k_offset = 0; k_offset < gs2; ++k_offset) + { + Unroll{}(compute, k + k_offset); + } + Unroll{}(post_compute, k / 
gs2); + } + + auto storec = [&](auto i) + { + constexpr int row = i / COLS; + constexpr int col = i % COLS; + if constexpr (col % 2 == 0) + { + _mm512_storeu_si512( + reinterpret_cast<__m512i *>(C + row * ldc + col * 16), + (__m512i)(_mm512_cvtne2ps_pbh(vc_master[i + 1], vc_master[i]))); + } + }; + Unroll{}(storec); + } + }; + +#define LAUNCH_TINYGEMM_KERNEL_NN(MB_SIZE, NB_SIZE, DATA_TYPE) \ + tinygemm_kernel_nn::apply( \ + A + mb_start * lda, B + nb_start, C + mb_start * ldc + nb_start, Bs + nb_start, K, blocksize, lda, ldb, ldc, \ + strideBs); + + inline uint16_t float_to_bf16_round(float x) + { + uint32_t u; + std::memcpy(&u, &x, sizeof(u)); + uint32_t lsb = (u >> 16) & 1; + uint32_t rounding_bias = 0x7fff + lsb; + u += rounding_bias; + uint16_t hi = static_cast(u >> 16); + // Quiet NaN handling + if ((u & 0x7f800000) == 0x7f800000 && (u & 0x007fffff)) + { + hi = 0xffff; + } + return hi; + } + + template , int> = 0> + inline at::vec::Vectorized convert_from_float_ext(const at::vec::Vectorized& a, const at::vec::Vectorized& b) { + return at::vec::convert_from_float(a, b); + } + + template <> + inline at::vec::Vectorized + convert_from_float_ext(const at::vec::Vectorized& a, const at::vec::Vectorized& b) { + return (__m512i)(_mm512_cvtne2ps_pbh(__m512(b), __m512(a))); + } + + template + inline void copy_stub(scalar_t* __restrict__ out, const float* __restrict__ input, int64_t size) { + using bVec = at::vec::Vectorized; + using fVec = at::vec::Vectorized; + constexpr int kVecSize = bVec::size(); + + int64_t d; + #pragma GCC unroll 4 + for (d = 0; d <= size - kVecSize; d += kVecSize) { + fVec data0 = fVec::loadu(input + d); + fVec data1 = fVec::loadu(input + d + fVec::size()); + bVec out_vec = convert_from_float_ext(data0, data1); + out_vec.store(out + d); + } + for (; d < size; ++d) { + out[d] = static_cast(input[d]); + } + } + + template + struct brgemm + { + static inline void apply( + const at::BFloat16 *__restrict__ A, const unsigned char *__restrict__ B, at::BFloat16 *__restrict__ C, + const at::BFloat16 *__restrict__ Bs, at::BFloat16 *__restrict__ Btmp, float *__restrict__ Ctmp, int64_t M, int64_t N, + int64_t K, int blocksize, int64_t lda, int64_t ldb, int64_t ldc, int64_t strideBs, bool use_brgemm_dequant_out) + { + constexpr int BLOCK_N = block_size_n(); + const int ldb_tmp = BLOCK_N; + if (use_brgemm_dequant_out) + { + at::native::cpublas::brgemm( + M, N, K, lda, ldb_tmp, BLOCK_N, false, A, Btmp, Ctmp); + } + else + { + for (int64_t k = 0; k < K; k += BLOCK_K) + { + int64_t kb_size = std::min(static_cast(BLOCK_K), K - k); + const int64_t kgs = k / blocksize; + + unpack_B( + Btmp, B + (k >> 1) * ldb, Bs + kgs * strideBs, N, kb_size, blocksize, ldb, ldb_tmp, strideBs); + + const bool add_C = k != 0; + at::native::cpublas::brgemm( + M, N, kb_size, lda, ldb_tmp, BLOCK_N, add_C, A + k, Btmp, Ctmp); + } + } + + // copy from Ctmp to C + for (int64_t m = 0; m < M; ++m) + { + copy_stub(C + m * ldc, Ctmp + m * BLOCK_N, N); + } + } + }; + + template + void tinygemm_kernel( + const scalar_t *__restrict__ A, const unsigned char *__restrict__ B, scalar_t *__restrict__ C, + const scalar_t *__restrict__ Bs, scalar_t *__restrict__ Btmp, float *__restrict__ Ctmp, int64_t M, int64_t N, + int64_t K, int blocksize, int64_t lda, int64_t ldb, int64_t ldc, int64_t strideBs, bool brg, + bool use_brgemm_dequant_out = false) + { + if (brg) + { + brgemm::apply( + A, B, C, Bs, Btmp, Ctmp, M, N, K, blocksize, lda, ldb, ldc, strideBs, use_brgemm_dequant_out); + return; + } + constexpr int64_t BLOCK_M = 4; + 
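+ // Non-brgemm path for small M: tile the output into BLOCK_M x BLOCK_N panels and dispatch to a fully unrolled tinygemm_kernel_nn specialization keyed on (mb_size, nb_size).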
constexpr int64_t BLOCK_N = 64; + const int64_t MB = div_up(M, BLOCK_M); + const int64_t NB = div_up(N, BLOCK_N); + for (int mb = 0; mb < MB; ++mb) + { + int64_t mb_start = mb * BLOCK_M; + int64_t mb_size = std::min(BLOCK_M, M - mb_start); + for (int64_t nb = 0; nb < NB; ++nb) + { + int64_t nb_start = nb * BLOCK_N; + int64_t nb_size = std::min(BLOCK_N, N - nb_start); + + switch (mb_size << 4 | nb_size >> 4) + { + // mb_size = 1 + case 0x12: + LAUNCH_TINYGEMM_KERNEL_NN(1, 32, DATA_TYPE); + break; + case 0x14: + LAUNCH_TINYGEMM_KERNEL_NN(1, 64, DATA_TYPE); + break; + // mb_size = 2 + case 0x22: + LAUNCH_TINYGEMM_KERNEL_NN(2, 32, DATA_TYPE); + break; + case 0x24: + LAUNCH_TINYGEMM_KERNEL_NN(2, 64, DATA_TYPE); + break; + // mb_size = 3 + case 0x32: + LAUNCH_TINYGEMM_KERNEL_NN(3, 32, DATA_TYPE); + break; + case 0x34: + LAUNCH_TINYGEMM_KERNEL_NN(3, 64, DATA_TYPE); + break; + // mb_size = 4 + case 0x42: + LAUNCH_TINYGEMM_KERNEL_NN(4, 32, DATA_TYPE); + break; + case 0x44: + LAUNCH_TINYGEMM_KERNEL_NN(4, 64, DATA_TYPE); + break; + default: + { + std::fprintf( + stderr, "[bitsandbytes] Unexpected block size %lldx%lld\n", (long long)mb_size, (long long)nb_size); + std::abort(); // or return; if you prefer silent exit + } + } + } + } + } + + template + void gemm_4bit_inference( + int64_t M, int64_t N, int64_t K, const T *__restrict__ x, const unsigned char *__restrict__ w, + const T *__restrict__ absmax, T *__restrict__ out, int64_t blocksize, int64_t x_stride, int64_t out_stride) + { + constexpr int64_t BLOCK_M = block_size_m(); // 32 + constexpr int64_t BLOCK_N = block_size_n(); // 32 + const int64_t MB = div_up(M, BLOCK_M); // (x + y -1)/ y, res = 1 when M <= 32 + const int64_t NB = div_up(N, BLOCK_N); + // TODO: Find better threshold. + T *Btmp_start = nullptr; + at::Tensor Btmp_t; + const bool use_brgemm = M > 4; + const bool use_brgemm_dequant_out = M > 100; + if (use_brgemm_dequant_out) + { + // Layout: contiguous [N*K] elements, 64-byte aligned for AVX512 loads + Btmp_t = at::zeros({N, K}, c10::CppTypeToScalarType::value); + Btmp_start = Btmp_t.data_ptr(); +#pragma omp parallel for + for (int64_t nb = 0; nb < NB; ++nb) + { + int64_t nb_start = nb * BLOCK_N; + int64_t nb_size = std::min(N - nb_start, BLOCK_N); + T *Btmp = Btmp_start + nb_start * K; + for (int64_t k = 0; k < K; k += BLOCK_K) + { + int64_t kb_size = std::min(BLOCK_K, K - k); + int64_t kgs = k / blocksize; + int64_t strideBs = N; + int64_t ldb = nb_size; + const T *Bs = absmax + nb_start; + const unsigned char *Bw = reinterpret_cast(w + nb_start * K / 2); + unpack_B( + Btmp + k * BLOCK_N, Bw + (k >> 1) * ldb, Bs + kgs * strideBs, nb_size, kb_size, blocksize, ldb, + BLOCK_N, strideBs); + } + } + } + // l2 cache block for n + int64_t cache_blocks_nb = get_cache_blocks(BLOCK_N * K); + parallel_2d(MB, NB, [&](int64_t begin_mb, int64_t end_mb, int64_t begin_nb, int64_t end_nb) + { + // for brgemm, use float32 for accumulate + alignas(64) float Ctmp[BLOCK_M * BLOCK_N]; + alignas(64) T Btmp_inner[BLOCK_N * BLOCK_K]; // BLOCK_K = 128 + for (int64_t nbb = begin_nb; nbb < end_nb; nbb += cache_blocks_nb) { + for (int64_t mb = begin_mb; mb < end_mb; ++mb) { // 0-1 + for (int64_t nb = nbb; nb < std::min(nbb + cache_blocks_nb, end_nb); ++nb) { + int64_t mb_start = mb * BLOCK_M; // 0 + int64_t mb_size = std::min(M - mb_start, BLOCK_M); + int64_t nb_start = nb * BLOCK_N; + int64_t nb_size = std::min(N - nb_start, BLOCK_N); + tinygemm_kernel( + /* A */ x + mb_start * x_stride, + /* B */ w + nb_start * K / 2, // divide by 2 since w is u4 packed in 
u8, K is w.size(1) * 2 + /* C */ out + mb_start * out_stride + nb_start, + /* Bs */ absmax + nb_start, + /* Btmp */ use_brgemm_dequant_out ? Btmp_start + nb_start * K : Btmp_inner, + /* Ctmp */ Ctmp, + /* M */ mb_size, + /* N */ nb_size, + /* K */ K, + /* gs */ blocksize, // blocksize + /* lda */ x_stride, + /* ldb */ nb_size, + /* ldc */ out_stride, + /* sBs */ N, + /* brg */ use_brgemm, + /* dequant choice*/ use_brgemm_dequant_out + ); + } + } + } + if (use_brgemm) { + at::native::cpublas::brgemm_release(); + } }); + } + + //============================================================== + // TEMPLATE DEFINITIONS + //============================================================== + + template void gemm_4bit_inference( + int64_t M, int64_t N, int64_t K, const at::BFloat16 *__restrict__ x, const unsigned char *__restrict__ w, + const at::BFloat16 *__restrict__ absmax, at::BFloat16 *__restrict__ out, int64_t blocksize, int64_t x_stride, int64_t out_stride); + template void gemm_4bit_inference( + int64_t M, int64_t N, int64_t K, const at::BFloat16 *__restrict__ x, const unsigned char *__restrict__ w, + const at::BFloat16 *__restrict__ absmax, at::BFloat16 *__restrict__ out, int64_t blocksize, int64_t x_stride, int64_t out_stride); + } // namespace avx512 +} // namespace bitsandbytes_cpu \ No newline at end of file diff --git a/quantization-bitsandbytes/bitsandbytes_cpu/bitsandbytes_avx512.hpp b/quantization-bitsandbytes/bitsandbytes_cpu/bitsandbytes_avx512.hpp new file mode 100644 index 0000000..6d5cbcd --- /dev/null +++ b/quantization-bitsandbytes/bitsandbytes_cpu/bitsandbytes_avx512.hpp @@ -0,0 +1,121 @@ +// AVX512 implementation - compile with -mavx512f -mavx512bf16 +#include +#include +#include +#include +#include +#include +#include +#include + +namespace bitsandbytes_cpu +{ + namespace avx512 + { +// amx-bf16 +#define TILE_M 16 +#define TILE_N 16 +#define TILE_K 32 +// work around compiler internal error +#define BLOCK_K 128 // 4 * TILE_K + + // block size for AMX gemm + constexpr int block_size_m() { return 2 * TILE_M; } + + constexpr int block_size_n() { return 2 * TILE_N; } + + template + inline int get_cache_blocks(int chunk_size) + { + // L2 2MB and ratio of 50% + const int L2_size = 2048 * 1024 >> 1; + return std::max(1, int(L2_size / (chunk_size * sizeof(T)))); + } + +// forced unroll for perf critical path +#if __has_attribute(always_inline) +#define ALWAYS_INLINE __attribute__((__always_inline__)) inline +#else +#define ALWAYS_INLINE inline +#endif + + template + struct Unroll + { + template + ALWAYS_INLINE void operator()(const Func &f, Args... args) const + { + Unroll{}(f, args...); + f(std::integral_constant{}, args...); + } + }; + + template <> + struct Unroll<1> + { + template + ALWAYS_INLINE void operator()(const Func &f, Args... 
args) const + { + f(std::integral_constant{}, args...); + } + }; + + template ::value, int>::type = 0> + inline T div_up(T x, T y) + { + return (x + y - 1) / y; + } + + inline int adjust_num_threads(int m) + { + int actual_nth = omp_get_max_threads(); + if (m == 1) + return actual_nth; + return std::max(1, (actual_nth >> 1) * 2); + } + + template + inline void parallel_2d(int m, int n, const func_t &f) + { + int nth = adjust_num_threads(m); + float r = float(m) / n; + int nth_m = std::ceil(std::sqrt(r * nth)); + int nth_n = 1; + for (; nth_m > 0; --nth_m) + { + nth_n = nth / nth_m; + if (nth_m * nth_n == nth) + { + break; + } + } +#pragma omp parallel num_threads(nth) + { + int ith = omp_get_thread_num(); + int ith_m = ith / nth_n; + int ith_n = ith % nth_n; + + int thread_block_m = div_up(m, nth_m); + int thread_block_n = div_up(n, nth_n); + + int begin_m = ith_m * thread_block_m; + int end_m = std::min(m, begin_m + thread_block_m); + int begin_n = ith_n * thread_block_n; + int end_n = std::min(n, begin_n + thread_block_n); + + f(begin_m, end_m, begin_n, end_n); + } + } + + typedef enum DataType_t + { + NF4 = 0, + FP4 = 1, + } DataType_t; + + template + void gemm_4bit_inference( + int64_t M, int64_t N, int64_t K, const T *__restrict__ x, const unsigned char *__restrict__ w, + const T *__restrict__ absmax, T *__restrict__ out, int64_t blocksize, int64_t x_stride, int64_t out_stride); + } // namespace avx512 +} // namespace bitsandbytes_cpu \ No newline at end of file diff --git a/quantization-bitsandbytes/bitsandbytes_cpu/bitsandbytes_cpu.cpp b/quantization-bitsandbytes/bitsandbytes_cpu/bitsandbytes_cpu.cpp new file mode 100644 index 0000000..66abdf2 --- /dev/null +++ b/quantization-bitsandbytes/bitsandbytes_cpu/bitsandbytes_cpu.cpp @@ -0,0 +1,47 @@ +#include "bitsandbytes_cpu.hpp" + +using bitsandbytes_cpu::avx512::DataType_t; + +namespace bitsandbytes_cpu +{ + + // Main dispatcher that selects the best implementation based on runtime CPU features + void gemm_4bit(const torch::Tensor &input, const torch::Tensor &weight, + const torch::Tensor &absmax, torch::Tensor &out, int64_t blocksize, int64_t quant_type) + { + int64_t M = input.size(0); + int64_t N = weight.size(0); + int64_t K = input.size(1); + // strides + int64_t x_strideM = input.stride(0); + int64_t out_strideM = out.stride(0); + // Runtime CPU feature detection and dispatch + if (CPUFeatures::hasAVX512BF16()) + { + // Use AVX512 optimized implementation + if (quant_type == 1) { + bitsandbytes_cpu::avx512::gemm_4bit_inference( + M, N, K, + input.data_ptr(), + weight.data_ptr(), + absmax.data_ptr(), + out.data_ptr(), + blocksize, x_strideM, out_strideM); + } + else { + bitsandbytes_cpu::avx512::gemm_4bit_inference( + M, N, K, + input.data_ptr(), + weight.data_ptr(), + absmax.data_ptr(), + out.data_ptr(), + blocksize, x_strideM, out_strideM); + } + } + else + { + // raise error for unsupported CPU + throw std::runtime_error("[bitsandbytes] gemm_4bit: CPU does not support AVX512BF16 instruction set required for 4-bit quantization operations."); + } + } +} // namespace bitsandbytes_cpu diff --git a/quantization-bitsandbytes/bitsandbytes_cpu/bitsandbytes_cpu.hpp b/quantization-bitsandbytes/bitsandbytes_cpu/bitsandbytes_cpu.hpp new file mode 100644 index 0000000..0070b7d --- /dev/null +++ b/quantization-bitsandbytes/bitsandbytes_cpu/bitsandbytes_cpu.hpp @@ -0,0 +1,15 @@ +#pragma once + +#include "cpu_features.hpp" +#include "bitsandbytes_avx512.hpp" +#include +#include +#include + +namespace bitsandbytes_cpu +{ + + // Main dispatcher 
that selects the best implementation based on runtime CPU features + void gemm_4bit(const torch::Tensor &input, const torch::Tensor &weight, + const torch::Tensor &absmax, torch::Tensor &out, int64_t blocksize, int64_t quant_type); +} // namespace bitsandbytes_cpu \ No newline at end of file diff --git a/quantization-bitsandbytes/bitsandbytes_cpu/bitsandbytes_cpu_torch.cpp b/quantization-bitsandbytes/bitsandbytes_cpu/bitsandbytes_cpu_torch.cpp new file mode 100644 index 0000000..1fdbe32 --- /dev/null +++ b/quantization-bitsandbytes/bitsandbytes_cpu/bitsandbytes_cpu_torch.cpp @@ -0,0 +1,25 @@ +#include +#include "bitsandbytes_cpu.hpp" + +// Forward implementation for CPU +torch::Tensor gemm_4bit_cpu_forward( + const torch::Tensor &input, const torch::Tensor &weight, + const torch::Tensor &absmax, int64_t blocksize, int64_t quant_type) +{ + TORCH_CHECK(input.is_contiguous(), "input must be contiguous"); + TORCH_CHECK(weight.is_contiguous(), "weight must be contiguous"); + TORCH_CHECK(absmax.is_contiguous(), "absmax must be contiguous"); + + auto output = at::empty({input.size(0), weight.size(0)}, input.options()); + + bitsandbytes_cpu::gemm_4bit( + input, + weight, + absmax, + output, + blocksize, + quant_type + ); + + return output; +} diff --git a/quantization-bitsandbytes/bitsandbytes_cpu/cpu_features.hpp b/quantization-bitsandbytes/bitsandbytes_cpu/cpu_features.hpp new file mode 100644 index 0000000..de454dc --- /dev/null +++ b/quantization-bitsandbytes/bitsandbytes_cpu/cpu_features.hpp @@ -0,0 +1,176 @@ +#pragma once + +#ifdef _MSC_VER +#include +#else +#include +#endif +#include +#include +#include +#include +namespace bitsandbytes_cpu +{ + + // CPU feature detection + class CPUFeatures + { + public: + static bool hasAVX2() + { + static bool avx2_supported = checkAVX2(); + return avx2_supported; + } + + static bool hasAVX512BF16() + { + static bool bf16_supported = checkAVX512BF16(); + return bf16_supported; + } + + private: + static bool checkAVX2() + { +#ifdef _MSC_VER + int cpu_info[4]; + __cpuid(cpu_info, 0); + int n_ids = cpu_info[0]; + + if (n_ids >= 7) + { + __cpuidex(cpu_info, 7, 0); + return (cpu_info[1] & (1 << 5)) != 0; // EBX bit 5 + } + return false; +#else + unsigned int eax, ebx, ecx, edx; + if (__get_cpuid_max(0, nullptr) < 7) + { + return false; + } + __cpuid_count(7, 0, eax, ebx, ecx, edx); + return (ebx & (1 << 5)) != 0; // EBX bit 5 +#endif + } + + static bool checkAVX512() + { +#ifdef _MSC_VER + int cpu_info[4]; + __cpuid(cpu_info, 0); + int n_ids = cpu_info[0]; + if (n_ids < 7) + return false; + + __cpuidex(cpu_info, 7, 0); + bool avx512f = (cpu_info[1] & (1 << 16)) != 0; // EBX bit 16: AVX-512 Foundation + if (!avx512f) + return false; + + __cpuid(cpu_info, 1); + bool osxsave = (cpu_info[2] & (1 << 27)) != 0; // ECX bit 27: OSXSAVE + if (!osxsave) + return false; + + // check XCR0: bits 1,2 (SSE/AVX) and 5,6,7 (AVX-512 state) must be enabled by OS + unsigned long long xcr0 = _xgetbv(0); + return ((xcr0 & 0xE6ULL) == 0xE6ULL); +#else + unsigned int eax, ebx, ecx, edx; + if (__get_cpuid_max(0, nullptr) < 7) + { + return false; + } + + __cpuid_count(7, 0, eax, ebx, ecx, edx); + bool avx512f = (ebx & (1 << 16)) != 0; // EBX bit 16: AVX-512 Foundation + if (!avx512f) + return false; + + if (__get_cpuid(1, &eax, &ebx, &ecx, &edx) == 0) + { + return false; + } + bool osxsave = (ecx & (1 << 27)) != 0; // ECX bit 27: OSXSAVE + if (!osxsave) + return false; + + unsigned int xcr0_lo = 0, xcr0_hi = 0; + __asm__ volatile("xgetbv" : "=a"(xcr0_lo), "=d"(xcr0_hi) : "c"(0)); + 
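+ // XGETBV returns XCR0 in EDX:EAX; combine both halves before testing the OS-enabled AVX-512 state bits.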
unsigned long long xcr0 = ((unsigned long long)xcr0_hi << 32) | xcr0_lo; + // require XCR0 bits 1,2,5,6,7 set -> mask 0xE6 (0b11100110) + return ((xcr0 & 0xE6ULL) == 0xE6ULL); +#endif + } + + static bool checkAVX512BF16() + { + // require AVX-512 foundation supported and OS enabled + if (!checkAVX512()) + return false; + +#ifndef _MSC_VER + // First: try Linux /proc/cpuinfo flags (most robust on Linux) + std::ifstream f("/proc/cpuinfo"); + if (f) + { + std::string line; + while (std::getline(f, line)) + { + // flags line contains many space-separated tokens including avx512_bf16 on supported CPUs + if (line.find("avx512_bf16") != std::string::npos || + line.find("avx512bf16") != std::string::npos) + { + return true; + } + } + } + + // Fallback: attempt CPUID subleaf check if available. + // Note: exact bit position for AVX512_BF16 may differ across vendors/CPUID versions. + // This fallback tries CPUID(7,1) and checks some common positions; if uncertain returns false. + if (__get_cpuid_max(0, nullptr) < 7) + { + return false; + } + unsigned int eax, ebx, ecx, edx; + // try subleaf 1 + __cpuid_count(7, 1, eax, ebx, ecx, edx); + // There isn't a universally agreed constant here in this file; check common candidate bits: + // - some implementations report AVX512_BF16 in ECX/EBX of subleaf 1. + // Try commonly used positions conservatively. + const unsigned int candidate_masks[] = { + (1u << 5), // candidate (may collide with other features) + (1u << 26), // another candidate position + }; + for (unsigned m : candidate_masks) + { + if ((ebx & m) || (ecx & m) || (edx & m)) + { + return true; + } + } + return false; +#else + // On MSVC / Windows, use CPUID if available (simple check). If unsure, return false. + int cpu_info[4]; + __cpuid(cpu_info, 0); + int n_ids = cpu_info[0]; + if (n_ids < 7) + return false; + __cpuidex(cpu_info, 7, 1); + // same conservative check as above + const int candidate_masks[] = {(1 << 5), (1 << 26)}; + for (int m : candidate_masks) + { + if ((cpu_info[1] & m) || (cpu_info[2] & m) || (cpu_info[3] & m)) + { + return true; + } + } + return false; +#endif + } + }; + +} // namespace bitsandbytes_cpu diff --git a/quantization-bitsandbytes/build.toml b/quantization-bitsandbytes/build.toml new file mode 100644 index 0000000..fe16dcf --- /dev/null +++ b/quantization-bitsandbytes/build.toml @@ -0,0 +1,27 @@ +[general] +name = "quantization_bitsandbytes" +universal = false + +[torch] +src = ["torch-ext/torch_binding.cpp"] + +[kernel.bitsandbytes_cpu] +backend = "cpu" +depends = ["torch"] +src = [ + "bitsandbytes_cpu/bitsandbytes_cpu_torch.cpp", + "bitsandbytes_cpu/bitsandbytes_cpu.cpp", + "bitsandbytes_cpu/bitsandbytes_cpu.hpp", + "bitsandbytes_cpu/cpu_features.hpp", +] +include = ["bitsandbytes_cpu"] + +[kernel.bitsandbytes_cpu_avx512] +backend = "cpu" +depends = ["torch"] +src = [ + "bitsandbytes_cpu/bitsandbytes_avx512.cpp", + "bitsandbytes_cpu/bitsandbytes_avx512.hpp", +] +include = ["bitsandbytes_cpu"] +cxx-flags = ["-mfma", "-fopenmp", "-mf16c", "-mavx512f", "-mavx512bf16", "-mavx512vl", "-mavx512dq", "-mavx512bw", "-mavx512vbmi"] diff --git a/quantization-bitsandbytes/flake.lock b/quantization-bitsandbytes/flake.lock new file mode 100644 index 0000000..fcb402f --- /dev/null +++ b/quantization-bitsandbytes/flake.lock @@ -0,0 +1,95 @@ +{ + "nodes": { + "flake-compat": { + "locked": { + "lastModified": 1761588595, + "narHash": "sha256-XKUZz9zewJNUj46b4AJdiRZJAvSZ0Dqj2BNfXvFlJC4=", + "owner": "edolstra", + "repo": "flake-compat", + "rev": 
"f387cd2afec9419c8ee37694406ca490c3f34ee5", + "type": "github" + }, + "original": { + "owner": "edolstra", + "repo": "flake-compat", + "type": "github" + } + }, + "flake-utils": { + "inputs": { + "systems": "systems" + }, + "locked": { + "lastModified": 1731533236, + "narHash": "sha256-l0KFg5HjrsfsO/JpG+r7fRrqm12kzFHyUHqHCVpMMbI=", + "owner": "numtide", + "repo": "flake-utils", + "rev": "11707dc2f618dd54ca8739b309ec4fc024de578b", + "type": "github" + }, + "original": { + "owner": "numtide", + "repo": "flake-utils", + "type": "github" + } + }, + "kernel-builder": { + "inputs": { + "flake-compat": "flake-compat", + "flake-utils": "flake-utils", + "nixpkgs": "nixpkgs" + }, + "locked": { + "lastModified": 1764280248, + "narHash": "sha256-WNUbjnUnnDJtwK5ATn7VkkOZ+qyVF1UpKXaTkzbJsmI=", + "owner": "huggingface", + "repo": "kernel-builder", + "rev": "cc4997dc60691b597aa76a4173541721bc17c224", + "type": "github" + }, + "original": { + "owner": "huggingface", + "repo": "kernel-builder", + "type": "github" + } + }, + "nixpkgs": { + "locked": { + "lastModified": 1763291491, + "narHash": "sha256-eEYvm+45PPmy+Qe+nZDpn1uhoMUjJwx3PwVVQoO9ksA=", + "owner": "NixOS", + "repo": "nixpkgs", + "rev": "c543a59edf25ada193719764f3bc0c6ba835f94d", + "type": "github" + }, + "original": { + "owner": "NixOS", + "repo": "nixpkgs", + "rev": "c543a59edf25ada193719764f3bc0c6ba835f94d", + "type": "github" + } + }, + "root": { + "inputs": { + "kernel-builder": "kernel-builder" + } + }, + "systems": { + "locked": { + "lastModified": 1681028828, + "narHash": "sha256-Vy1rq5AaRuLzOxct8nz4T6wlgyUR7zLU309k9mBC768=", + "owner": "nix-systems", + "repo": "default", + "rev": "da67096a3b9bf56a91d16901293e51ba5b49a27e", + "type": "github" + }, + "original": { + "owner": "nix-systems", + "repo": "default", + "type": "github" + } + } + }, + "root": "root", + "version": 7 +} diff --git a/quantization-bitsandbytes/flake.nix b/quantization-bitsandbytes/flake.nix new file mode 100644 index 0000000..b8d645d --- /dev/null +++ b/quantization-bitsandbytes/flake.nix @@ -0,0 +1,21 @@ +{ + description = "Flake for Torch kernel extension"; + inputs = { + kernel-builder.url = "github:huggingface/kernel-builder"; + }; + outputs = + { self, kernel-builder }: + kernel-builder.lib.genFlakeOutputs { + inherit self; + path = ./.; + + # This is a workaround, we should be able to specify flags per arch in + # kernel-builder. + torchVersions = + allVersions: + builtins.map ( + version: + version // { systems = builtins.filter (system: system == "x86_64-linux") version.systems; } + ) allVersions; + }; +} diff --git a/quantization-bitsandbytes/tests/__init__.py b/quantization-bitsandbytes/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/quantization-bitsandbytes/tests/test_bitsandbytes.py b/quantization-bitsandbytes/tests/test_bitsandbytes.py new file mode 100644 index 0000000..fe29fdd --- /dev/null +++ b/quantization-bitsandbytes/tests/test_bitsandbytes.py @@ -0,0 +1,105 @@ +import torch +import pytest + +from quantization_bitsandbytes import gemm_4bit_forward + +def unpack_weight_packed_for_cpu(packed_qweight: torch.Tensor, block_n: int = 32): + """ + Inverse of convert_weight_packed_for_cpu. 
+ packed_qweight: (N, K//2) uint8, each byte = (high<<4)|low, both 4-bit values in 0..15 + returns: qweight_final (N, K) uint8 with original 4-bit values (0..15) + """ + assert packed_qweight.dtype == torch.uint8 + assert packed_qweight.dim() == 2 + N, K_half = packed_qweight.shape + assert N % block_n == 0 + BIT_COUNT = block_n # 32 + # reshape to rows of 32 packed bytes + qw = packed_qweight.reshape(-1, BIT_COUNT) # [(N//block_n)*K_half, 32] + low = (qw & 0x0F) + high = (qw >> 4) & 0x0F + # restore 64 nibbles (low first then high, matching original pack order) + restored = torch.cat([low, high], dim=1) # [..., 64] + # reshape back (inverse of flatten) + restored = restored.reshape(N // block_n, K_half, block_n, 2) # [N/block_n, K//2, block_n, 2] + # inverse transpose + restored = restored.transpose(-3, -2) # [N/block_n, block_n, K//2, 2] + # final shape + qweight_final = restored.reshape(N, K_half * 2).to(torch.uint8) + return qweight_final + + +_NF4_QUANT_TABLE = torch.tensor( + [ + -1.0, + -0.6961928009986877, + -0.5250730514526367, + -0.39491748809814453, + -0.28444138169288635, + -0.18477343022823334, + -0.09105003625154495, + 0.0, + 0.07958029955625534, + 0.16093020141124725, + 0.24611230194568634, + 0.33791524171829224, + 0.44070982933044434, + 0.5626170039176941, + 0.7229568362236023, + 1.0, + ], + dtype=torch.float32, +) + +_FP4_QUANT_TABLE = torch.tensor( + [ + 0.0000, + 0.0052, + 0.6667, + 1.0000, + 0.3333, + 0.5000, + 0.1667, + 0.2500, + 0.0000, + -0.0052, + -0.6667, + -1.0000, + -0.3333, + -0.5000, + -0.1667, + -0.2500, + ], + dtype=torch.float32, +) + +def ref_gemm_4bit(x, packed_weight, scales, group_size, quant_type): + unpacked_weight = unpack_weight_packed_for_cpu(packed_weight) + shape = unpacked_weight.shape + table = _FP4_QUANT_TABLE if quant_type == 1 else _NF4_QUANT_TABLE + original_weight = table[unpacked_weight.reshape(-1).int()].reshape(shape) * scales.T.repeat_interleave(group_size, dim=1) + res = torch.matmul(x, original_weight.T.to(x.dtype)) + return res + +@pytest.mark.parametrize("M", [1, 31, 244, 1024, 2666]) +@pytest.mark.parametrize("K", [2048, 4096, 14336]) +@pytest.mark.parametrize("N", [1024, 4096, 7168]) +@pytest.mark.parametrize("group_size", [64, 256]) +@pytest.mark.parametrize("quant_type", [0, 1]) +def test_bitsandbytes(M, K, N, group_size, quant_type): + device = torch.device("cpu") + dtype = torch.bfloat16 + group_size = 64 + + assert K % group_size == 0 + assert K % 2 == 0 + + x = torch.randn((M, K), device=device, dtype=dtype) * 0.1 + w = torch.randint(0, 15, (N, K // 2), device=device, dtype=torch.uint8) + num_groups = K // group_size + scales = torch.rand((num_groups, N), device=device, dtype=torch.bfloat16).pow(4.0) + + output = gemm_4bit_forward(x, w, scales, group_size, quant_type) + ref_out = ref_gemm_4bit(x, w, scales, group_size, quant_type) + + torch.testing.assert_close(output, ref_out, atol=1e-1, rtol=1e-2) diff --git a/quantization-bitsandbytes/torch-ext/quantization_bitsandbytes/__init__.py b/quantization-bitsandbytes/torch-ext/quantization_bitsandbytes/__init__.py new file mode 100644 index 0000000..349225b --- /dev/null +++ b/quantization-bitsandbytes/torch-ext/quantization_bitsandbytes/__init__.py @@ -0,0 +1,3 @@ +from .custom_ops import gemm_4bit_forward + +__all__ = ["gemm_4bit_forward"] \ No newline at end of file diff --git a/quantization-bitsandbytes/torch-ext/quantization_bitsandbytes/custom_ops.py b/quantization-bitsandbytes/torch-ext/quantization_bitsandbytes/custom_ops.py new file mode 100644 index 0000000..7e77afc --- 
/dev/null
+++ b/quantization-bitsandbytes/torch-ext/quantization_bitsandbytes/custom_ops.py
@@ -0,0 +1,19 @@
+import torch
+from ._ops import ops
+
+def gemm_4bit_forward(
+    input: torch.Tensor,
+    weight: torch.Tensor,
+    absmax: torch.Tensor,
+    blocksize: int,
+    quant_type: int,
+) -> torch.Tensor:
+    original_dtype = input.dtype
+    if original_dtype != torch.bfloat16:
+        input = input.to(torch.bfloat16)
+
+    output = ops.gemm_4bit_forward(input, weight, absmax, blocksize, quant_type)
+    if original_dtype != torch.bfloat16:
+        output = output.to(original_dtype)
+
+    return output
diff --git a/quantization-bitsandbytes/torch-ext/torch_binding.cpp b/quantization-bitsandbytes/torch-ext/torch_binding.cpp
new file mode 100644
index 0000000..ee642cd
--- /dev/null
+++ b/quantization-bitsandbytes/torch-ext/torch_binding.cpp
@@ -0,0 +1,39 @@
+#include
+#include "registration.h"
+
+
+#if defined(CPU_KERNEL)
+torch::Tensor gemm_4bit_cpu_forward(
+    const torch::Tensor &input,
+    const torch::Tensor &weight,
+    const torch::Tensor &absmax,
+    int64_t blocksize,
+    int64_t quant_type);
+#endif
+// Unified dispatcher; currently only the CPU backend is registered.
+torch::Tensor gemm_4bit_forward(
+    const torch::Tensor &input,
+    const torch::Tensor &weight,
+    const torch::Tensor &absmax,
+    int64_t blocksize,
+    int64_t quant_type) {
+#if defined(CPU_KERNEL)
+  if (input.device().type() == torch::kCPU) {
+    TORCH_CHECK(weight.device().type() == torch::kCPU, "weight must be on CPU");
+    TORCH_CHECK(absmax.device().type() == torch::kCPU, "absmax must be on CPU");
+    TORCH_CHECK(blocksize > 0, "blocksize must be > 0");
+    return gemm_4bit_cpu_forward(input, weight, absmax, blocksize, quant_type);
+  }
+#endif
+  TORCH_CHECK(false, "Unsupported device type: ", input.device().type());
+}
+
+TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
+  ops.def("gemm_4bit_forward(Tensor input, Tensor weight, Tensor absmax, int blocksize, int quant_type) -> Tensor");
+#if defined(CPU_KERNEL)
+  // Register CPU implementation
+  ops.impl("gemm_4bit_forward", torch::kCPU, &gemm_4bit_forward);
+#endif
+}
+
+REGISTER_EXTENSION(TORCH_EXTENSION_NAME)
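The test's unpack_weight_packed_for_cpu is documented as the inverse of convert_weight_packed_for_cpu, but the packing routine itself is not part of this diff. A minimal packing sketch, derived purely by inverting the unpack steps (and consistent with the unpack_B indexing in bitsandbytes_avx512.cpp), is shown below; treat it as an illustration of the expected (N, K//2) blocked layout rather than the canonical bitsandbytes helper.

import torch

def convert_weight_packed_for_cpu(qweight: torch.Tensor, block_n: int = 32) -> torch.Tensor:
    # qweight: (N, K) uint8 holding 4-bit NF4/FP4 code indices in 0..15
    # returns: (N, K//2) uint8 in the N-blocked layout consumed by gemm_4bit_forward
    N, K = qweight.shape
    assert qweight.dtype == torch.uint8 and N % block_n == 0 and K % 2 == 0
    w = qweight.reshape(N // block_n, block_n, K // 2, 2)  # undo the final reshape of unpack
    w = w.transpose(-3, -2)                                # -> (N//block_n, K//2, block_n, 2)
    w = w.reshape(-1, 2 * block_n)                         # 64 codes per (row-block, K-pair) group
    low, high = w[:, :block_n], w[:, block_n:]             # low nibbles first, then high
    return ((high << 4) | low).reshape(N, K // 2)

# Round-trip check against the test helper (hypothetical usage):
#   q = torch.randint(0, 16, (N, K), dtype=torch.uint8)
#   assert torch.equal(unpack_weight_packed_for_cpu(convert_weight_packed_for_cpu(q)), q)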