500 changes: 500 additions & 0 deletions quantization-bitsandbytes/bitsandbytes_cpu/bitsandbytes_avx512.cpp

Large diffs are not rendered by default.

121 changes: 121 additions & 0 deletions quantization-bitsandbytes/bitsandbytes_cpu/bitsandbytes_avx512.hpp
@@ -0,0 +1,121 @@
// AVX512 implementation - compile with -mavx512f -mavx512bf16
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <cstring>
#include <thread>
#include <type_traits>
#include <omp.h>
#include <immintrin.h>

namespace bitsandbytes_cpu
{
namespace avx512
{
// AMX-BF16 tile shape
#define TILE_M 16
#define TILE_N 16
#define TILE_K 32
// work around a compiler internal error
#define BLOCK_K 128 // 4 * TILE_K

// block size for AMX gemm
constexpr int block_size_m() { return 2 * TILE_M; }

constexpr int block_size_n() { return 2 * TILE_N; }

template <typename T>
inline int get_cache_blocks(int chunk_size)
{
// assume a 2 MB L2 cache and use at most 50% of it
const int L2_size = 2048 * 1024 >> 1;
return std::max(1, int(L2_size / (chunk_size * sizeof(T))));
}

// force inlining so the Unroll helper below fully unrolls on the perf-critical path
#if __has_attribute(always_inline)
#define ALWAYS_INLINE __attribute__((__always_inline__)) inline
#else
#define ALWAYS_INLINE inline
#endif

template <int n>
struct Unroll
{
template <typename Func, typename... Args>
ALWAYS_INLINE void operator()(const Func &f, Args... args) const
{
Unroll<n - 1>{}(f, args...);
f(std::integral_constant<int, n - 1>{}, args...);
}
};

template <>
struct Unroll<1>
{
template <typename Func, typename... Args>
ALWAYS_INLINE void operator()(const Func &f, Args... args) const
{
f(std::integral_constant<int, 0>{}, args...);
}
};

// integer ceiling division
template <typename T, typename std::enable_if<std::is_integral<T>::value, int>::type = 0>
inline T div_up(T x, T y)
{
return (x + y - 1) / y;
}

// m == 1 (gemv-like case): use every available thread;
// otherwise round the thread count down to an even number
inline int adjust_num_threads(int m)
{
int actual_nth = omp_get_max_threads();
if (m == 1)
return actual_nth;
return std::max(1, (actual_nth >> 1) * 2);
}

// run f over the 2-D range [0, m) x [0, n), partitioned across OpenMP threads
template <typename func_t>
inline void parallel_2d(int m, int n, const func_t &f)
{
// pick an nth_m x nth_n thread grid whose aspect ratio roughly matches m:n,
// shrinking nth_m until the grid covers exactly nth threads
int nth = adjust_num_threads(m);
float r = float(m) / n;
int nth_m = std::ceil(std::sqrt(r * nth));
int nth_n = 1;
for (; nth_m > 0; --nth_m)
{
nth_n = nth / nth_m;
if (nth_m * nth_n == nth)
{
break;
}
}
#pragma omp parallel num_threads(nth)
{
int ith = omp_get_thread_num();
int ith_m = ith / nth_n;
int ith_n = ith % nth_n;

int thread_block_m = div_up(m, nth_m);
int thread_block_n = div_up(n, nth_n);

int begin_m = ith_m * thread_block_m;
int end_m = std::min(m, begin_m + thread_block_m);
int begin_n = ith_n * thread_block_n;
int end_n = std::min(n, begin_n + thread_block_n);

f(begin_m, end_m, begin_n, end_n);
}
}

typedef enum DataType_t
{
NF4 = 0,
FP4 = 1,
} DataType_t;

template <typename T, int DATA_TYPE>
void gemm_4bit_inference(
int64_t M, int64_t N, int64_t K, const T *__restrict__ x, const unsigned char *__restrict__ w,
const T *__restrict__ absmax, T *__restrict__ out, int64_t blocksize, int64_t x_stride, int64_t out_stride);
} // namespace avx512
} // namespace bitsandbytes_cpu
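
For illustration (not part of the patch), the sketch below shows how parallel_2d hands out work: it picks an nth_m x nth_n thread grid whose aspect ratio roughly matches m:n and gives each thread one contiguous [begin_m, end_m) x [begin_n, end_n) range. It assumes the header above is on the include path and is built with the same flags plus -fopenmp; the file name is hypothetical.

// sketch_parallel_2d.cpp -- illustrative only: print the range each thread receives
#include <cstdio>
#include <omp.h>
#include "bitsandbytes_avx512.hpp"

int main()
{
    const int m = 64, n = 128; // e.g. 64 row blocks by 128 column blocks
    bitsandbytes_cpu::avx512::parallel_2d(m, n, [](int begin_m, int end_m, int begin_n, int end_n) {
        #pragma omp critical
        std::printf("thread %d: rows [%d, %d) cols [%d, %d)\n",
                    omp_get_thread_num(), begin_m, end_m, begin_n, end_n);
    });
    return 0;
}

On a 16-thread machine, for example, this yields a 2 x 8 thread grid, each thread covering 32 rows and 16 columns.
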
82 changes: 82 additions & 0 deletions quantization-bitsandbytes/bitsandbytes_cpu/bitsandbytes_cpu.cpp
@@ -0,0 +1,82 @@
#include "bitsandbytes_cpu.hpp"

using bitsandbytes_cpu::avx512::DataType_t;

namespace bitsandbytes_cpu
{

// 16-entry NF4 codebook (NormalFloat 4-bit quantization levels)
static const std::vector<float> NF4_DATA = {
-1.00000000f, -0.69619280f, -0.52507305f, -0.39491749f,
-0.28444138f, -0.18477343f, -0.09105004f, 0.00000000f,
0.07958030f, 0.16093020f, 0.24611230f, 0.33791524f,
0.44070983f, 0.56261700f, 0.72295684f, 1.00000000f
};

// 16-entry FP4 codebook (4-bit float quantization levels)
static const std::vector<float> FP4_DATA = {
0.0000f, 0.0052f, 0.6667f, 1.0000f,
0.3333f, 0.5000f, 0.1667f, 0.2500f,
0.0000f, -0.0052f, -0.6667f, -1.0000f,
-0.3333f, -0.5000f, -0.1667f, -0.2500f
};

// Main dispatcher that selects the best implementation based on runtime CPU features
void gemm_4bit(const torch::Tensor &input, const torch::Tensor &weight,
const torch::Tensor &absmax, torch::Tensor &out, int64_t blocksize, int64_t quant_type)
{
int64_t M = input.size(0);
int64_t N = weight.size(0);
int64_t K = input.size(1);
// strides
int64_t x_strideM = input.stride(0);
int64_t out_strideM = out.stride(0);
// Runtime CPU feature detection and dispatch
if (CPUFeatures::hasAVX512BF16())
{
// Use the AVX512 optimized implementation (requires bf16 input/absmax/out);
// quant_type == 1 selects FP4, anything else selects NF4
if (quant_type == 1) {
bitsandbytes_cpu::avx512::gemm_4bit_inference<at::BFloat16, DataType_t::FP4>(
M, N, K,
input.data_ptr<at::BFloat16>(),
weight.data_ptr<unsigned char>(),
absmax.data_ptr<at::BFloat16>(),
out.data_ptr<at::BFloat16>(),
blocksize, x_strideM, out_strideM);
}
else {
bitsandbytes_cpu::avx512::gemm_4bit_inference<at::BFloat16, DataType_t::NF4>(
M, N, K,
input.data_ptr<at::BFloat16>(),
weight.data_ptr<unsigned char>(),
absmax.data_ptr<at::BFloat16>(),
out.data_ptr<at::BFloat16>(),
blocksize, x_strideM, out_strideM);
}
}
else
{
// Generic fallback: unpack the 4-bit codes with torch ops, dequantize, then matmul.
// Each byte of `weight` holds two 4-bit code indices (low and high nibble); the
// reshape/transpose sequence below undoes the block-packed layout (blocks of
// packing_block_n output channels) and yields an (N, K) tensor of code indices.
const int64_t packing_block_n = 32;
int64_t K_half = weight.size(1);
auto device = weight.device();
auto qw = weight.reshape({-1, packing_block_n});
auto low = torch::bitwise_and(qw, 0x0F);
auto high = torch::bitwise_and(torch::bitwise_right_shift(qw, 4), 0x0F);
auto restored = torch::cat({low, high}, 1);
restored = restored.reshape({N / packing_block_n, K_half, packing_block_n, 2});
restored = restored.transpose(-3, -2);
auto unpacked_weight = restored.reshape({N, K_half * 2});
// Look up each code in the FP4/NF4 codebook, then rescale by the per-block absmax.
torch::Tensor table;
if (quant_type == 1) {
table = torch::tensor(FP4_DATA, torch::dtype(torch::kFloat32)).to(device); // FP4
} else {
table = torch::tensor(NF4_DATA, torch::dtype(torch::kFloat32)).to(device); // NF4
}
auto indices = unpacked_weight.to(torch::kLong);
auto dequantized_weight = table.index({indices});
auto scales_expanded = absmax.t().repeat_interleave(blocksize, 1);
auto original_weight = dequantized_weight * scales_expanded;
auto weight_final = original_weight.t().to(input.dtype());
torch::matmul_out(out, input, weight_final);
}
}
} // namespace bitsandbytes_cpu
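
For reference, the per-element math both branches implement is the same: each byte of weight stores two 4-bit codes, each code indexes the 16-entry NF4/FP4 table, and the looked-up value is multiplied by the absmax scale of its quantization block. A scalar sketch follows (a hypothetical helper, not used by either path; how the two nibbles map into the final (N, K) layout is what the reshapes above take care of):

// Dequantize one packed byte: `table` points at NF4_DATA or FP4_DATA,
// `scale` is the absmax entry of the block these two elements belong to.
inline void dequant_byte(unsigned char packed, const float *table, float scale,
                         float &low_val, float &high_val)
{
    low_val  = table[packed & 0x0F] * scale;        // low nibble
    high_val = table[(packed >> 4) & 0x0F] * scale; // high nibble
}
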
15 changes: 15 additions & 0 deletions quantization-bitsandbytes/bitsandbytes_cpu/bitsandbytes_cpu.hpp
@@ -0,0 +1,15 @@
#pragma once

#include "cpu_features.hpp"
#include "bitsandbytes_avx512.hpp"
#include <torch/all.h>
#include <stdexcept>
#include <ATen/ATen.h>

namespace bitsandbytes_cpu
{

// Main dispatcher that selects the best implementation based on runtime CPU features
void gemm_4bit(const torch::Tensor &input, const torch::Tensor &weight,
const torch::Tensor &absmax, torch::Tensor &out, int64_t blocksize, int64_t quant_type);
} // namespace bitsandbytes_cpu
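
Call-site expectations, as read off the dispatcher body above: input is (M, K), the packed weight is (N, K/2) with two 4-bit codes per byte, out is (M, N), and absmax looks like (K/blocksize, N) per-block scales (inferred from the fallback path's absmax.t().repeat_interleave(blocksize, 1); the AVX512 kernel's exact expectations live in the unrendered .cpp). A hedged usage sketch with hypothetical sizes; the AVX512 path additionally requires bf16 tensors:

#include <torch/torch.h>
#include "bitsandbytes_cpu.hpp"

void example_call()
{
    const int64_t M = 4, K = 4096, N = 11008, blocksize = 64; // hypothetical sizes
    auto input  = torch::randn({M, K}, torch::kBFloat16);
    auto weight = torch::randint(0, 256, {N, K / 2}, torch::kUInt8); // two 4-bit codes per byte
    auto absmax = torch::rand({K / blocksize, N}, torch::kBFloat16); // per-block scales (inferred layout)
    auto out    = torch::empty({M, N}, input.options());
    bitsandbytes_cpu::gemm_4bit(input, weight, absmax, out, blocksize, /*quant_type=*/0); // 0 = NF4
}
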
@@ -0,0 +1,25 @@
#include <torch/all.h>
#include "bitsandbytes_cpu.hpp"

// Forward implementation for CPU
torch::Tensor gemm_4bit_cpu_forward(
const torch::Tensor &input, const torch::Tensor &weight,
const torch::Tensor &absmax, int64_t blocksize, int64_t quant_type)
{
TORCH_CHECK(input.is_contiguous(), "input must be contiguous");
TORCH_CHECK(weight.is_contiguous(), "weight must be contiguous");
TORCH_CHECK(absmax.is_contiguous(), "absmax must be contiguous");

auto output = at::empty({input.size(0), weight.size(0)}, input.options());

bitsandbytes_cpu::gemm_4bit(
input,
weight,
absmax,
output,
blocksize,
quant_type
);

return output;
}
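
This excerpt does not show how gemm_4bit_cpu_forward is wired up as a torch operator; that registration is not part of the lines shown here. A minimal sketch of one common way to do it, assuming a hypothetical extension namespace bitsandbytes_cpu_ext (the real namespace, operator name, and schema may differ):

#include <torch/library.h>

// Hypothetical registration -- namespace, operator name, and schema are assumptions.
TORCH_LIBRARY(bitsandbytes_cpu_ext, m)
{
    m.def("gemm_4bit_forward(Tensor input, Tensor weight, Tensor absmax, "
          "int blocksize, int quant_type) -> Tensor");
}

TORCH_LIBRARY_IMPL(bitsandbytes_cpu_ext, CPU, m)
{
    m.impl("gemm_4bit_forward", &gemm_4bit_cpu_forward);
}

With that registration in place, the op would be callable from Python as torch.ops.bitsandbytes_cpu_ext.gemm_4bit_forward(input, weight, absmax, blocksize, quant_type).
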