Skip to content
Open

Bnb #77

Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
500 changes: 500 additions & 0 deletions quantization-bitsandbytes/bitsandbytes_cpu/bitsandbytes_avx512.cpp

Large diffs are not rendered by default.

121 changes: 121 additions & 0 deletions quantization-bitsandbytes/bitsandbytes_cpu/bitsandbytes_avx512.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
// AVX512 implementation - compile with -mavx512f -mavx512bf16
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <cstring>
#include <thread>
#include <type_traits>
#include <omp.h>
#include <immintrin.h>

namespace bitsandbytes_cpu
{
namespace avx512
{
// Tile geometry for the amx-bf16 GEMM path.
// These are preprocessor macros, so the enclosing namespaces do not scope them.
#define TILE_M 16
#define TILE_N 16
#define TILE_K 32
// work around compiler internal error
#define BLOCK_K 128 // 4 * TILE_K

// Block extents for the AMX gemm: one block spans two tiles along M ...
constexpr int block_size_m()
{
    return TILE_M * 2;
}

// ... and two tiles along N.
constexpr int block_size_n()
{
    return TILE_N * 2;
}

// Returns how many chunks of `chunk_size` elements of type T fit into the L2
// cache budget, never less than 1.
//
// chunk_size: number of T elements per chunk. Non-positive values yield 1
//             (a zero chunk would otherwise divide by zero).
template <typename T>
inline int get_cache_blocks(int chunk_size)
{
    // L2 2MB and ratio of 50%
    constexpr std::size_t l2_budget_bytes = (2048 * 1024) >> 1;

    // Zero would divide by zero; a negative count would wrap to a huge
    // unsigned byte size whose quotient is 0. One block is the floor either way.
    if (chunk_size <= 0)
    {
        return 1;
    }

    const std::size_t chunk_bytes = static_cast<std::size_t>(chunk_size) * sizeof(T);
    return std::max(1, static_cast<int>(l2_budget_bytes / chunk_bytes));
}

// Force-inline marker for the perf critical path (used by the Unroll helpers).
//
// `__has_attribute` is a compiler extension. On a preprocessor that does not
// define it, `#if __has_attribute(...)` is a syntax error rather than "false",
// so its existence is tested before it is invoked.
#if defined(__has_attribute)
#if __has_attribute(always_inline)
#define ALWAYS_INLINE __attribute__((__always_inline__)) inline
#endif
#endif
// Fallback: plain `inline` when the attribute is unavailable.
#ifndef ALWAYS_INLINE
#define ALWAYS_INLINE inline
#endif

// Compile-time loop. Unroll<n>{}(f, args...) calls
//   f(std::integral_constant<int, i>{}, args...)
// for i = 0, 1, ..., n - 1 in ascending order, by recursing into Unroll<n - 1>
// before issuing its own call. The extra arguments are taken by value at each level.
template <int n>
struct Unroll
{
    template <typename Fn, typename... Rest>
    ALWAYS_INLINE void operator()(const Fn &body, Rest... rest) const
    {
        using Index = std::integral_constant<int, n - 1>;

        // Lower indices first, then this level's own index (n - 1).
        Unroll<n - 1>()(body, rest...);
        body(Index(), rest...);
    }
};

// Recursion terminator for Unroll: a single call with index 0.
template <>
struct Unroll<1>
{
    template <typename Fn, typename... Rest>
    ALWAYS_INLINE void operator()(const Fn &body, Rest... rest) const
    {
        using Index = std::integral_constant<int, 0>;
        body(Index(), rest...);
    }
};

// Integer division rounded up, computed as (x + y - 1) / y.
// Only participates in overload resolution for integral T.
// The formula equals ceil(x / y) for x >= 0 and y > 0; y == 0 is not handled.
template <typename T, typename std::enable_if<std::is_integral<T>::value, int>::type = 0>
inline T div_up(T x, T y)
{
    // Bias the numerator so truncating division rounds any remainder upward.
    const auto biased = x + y - 1;
    return static_cast<T>(biased / y);
}

// Chooses the thread count for a problem with m rows.
// m == 1: the OpenMP maximum is returned unchanged.
// otherwise: the maximum rounded down to an even number, but at least 1.
inline int adjust_num_threads(int m)
{
    const int max_threads = omp_get_max_threads();
    if (m != 1)
    {
        // Clear the lowest bit to get an even count; max() covers max_threads == 1.
        const int even_threads = (max_threads >> 1) * 2;
        return std::max(1, even_threads);
    }
    return max_threads;
}

template <typename func_t>
inline void parallel_2d(int m, int n, const func_t &f)
{
int nth = adjust_num_threads(m);
float r = float(m) / n;
int nth_m = std::ceil(std::sqrt(r * nth));
int nth_n = 1;
for (; nth_m > 0; --nth_m)
{
nth_n = nth / nth_m;
if (nth_m * nth_n == nth)
{
break;
}
}
#pragma omp parallel num_threads(nth)
{
int ith = omp_get_thread_num();
int ith_m = ith / nth_n;
int ith_n = ith % nth_n;

int thread_block_m = div_up(m, nth_m);
int thread_block_n = div_up(n, nth_n);

int begin_m = ith_m * thread_block_m;
int end_m = std::min(m, begin_m + thread_block_m);
int begin_n = ith_n * thread_block_n;
int end_n = std::min(n, begin_n + thread_block_n);

f(begin_m, end_m, begin_n, end_n);
}
}

// Selector for the 4-bit format handled by gemm_4bit_inference.
// Kept as an unscoped enum: the enumerators are passed as the kernel's
// `int DATA_TYPE` template argument and rely on the implicit conversion.
enum DataType_t
{
    NF4 = 0,
    FP4 = 1,
};

// 4-bit GEMM inference kernel (declaration only; the definition is not in this header).
//
// Template parameters:
//   T         - element type of x, absmax and out. The dispatcher in
//               bitsandbytes_cpu.cpp instantiates it with at::BFloat16.
//   DATA_TYPE - a DataType_t enumerator (NF4 or FP4) passed as an int.
//
// Arguments, as supplied by that dispatcher:
//   M, N, K    - input.size(0), weight.size(0), input.size(1)
//   x          - input data pointer
//   w          - weight data pointer (unsigned char)
//   absmax     - absmax data pointer
//   out        - output data pointer
//   blocksize  - forwarded unchanged from the caller
//   x_stride   - input.stride(0)
//   out_stride - out.stride(0)
//
// NOTE(review): presumably w holds two packed 4-bit values per byte and absmax
// holds one scale per `blocksize` weights - confirm against the definition.
// The __restrict__ qualifiers promise that the four buffers do not alias.
template <typename T, int DATA_TYPE>
void gemm_4bit_inference(
int64_t M, int64_t N, int64_t K, const T *__restrict__ x, const unsigned char *__restrict__ w,
const T *__restrict__ absmax, T *__restrict__ out, int64_t blocksize, int64_t x_stride, int64_t out_stride);
} // namespace avx512
} // namespace bitsandbytes_cpu
47 changes: 47 additions & 0 deletions quantization-bitsandbytes/bitsandbytes_cpu/bitsandbytes_cpu.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
#include "bitsandbytes_cpu.hpp"

using bitsandbytes_cpu::avx512::DataType_t;

namespace bitsandbytes_cpu
{

// Main dispatcher: checks the running CPU and forwards to the matching
// AVX512 kernel instantiation.
//
// quant_type == 1 selects the FP4 kernel; every other value selects NF4.
// All tensors are accessed as at::BFloat16 except weight (unsigned char).
// Throws std::runtime_error when the CPU lacks AVX512BF16.
void gemm_4bit(const torch::Tensor &input, const torch::Tensor &weight,
               const torch::Tensor &absmax, torch::Tensor &out, int64_t blocksize, int64_t quant_type)
{
    // Problem sizes and row strides come straight from the tensors.
    const int64_t M = input.size(0);
    const int64_t N = weight.size(0);
    const int64_t K = input.size(1);
    const int64_t x_strideM = input.stride(0);
    const int64_t out_strideM = out.stride(0);

    // Runtime CPU feature detection: bail out before touching any data pointer.
    if (!CPUFeatures::hasAVX512BF16())
    {
        // raise error for unsupported CPU
        throw std::runtime_error("[bitsandbytes] gemm_4bit: CPU does not support AVX512BF16 instruction set required for 4-bit quantization operations.");
    }

    const at::BFloat16 *x_ptr = input.data_ptr<at::BFloat16>();
    const unsigned char *w_ptr = weight.data_ptr<unsigned char>();
    const at::BFloat16 *absmax_ptr = absmax.data_ptr<at::BFloat16>();
    at::BFloat16 *out_ptr = out.data_ptr<at::BFloat16>();

    // Use the AVX512 optimized implementation for the requested format.
    if (quant_type == 1)
    {
        bitsandbytes_cpu::avx512::gemm_4bit_inference<at::BFloat16, DataType_t::FP4>(
            M, N, K, x_ptr, w_ptr, absmax_ptr, out_ptr, blocksize, x_strideM, out_strideM);
        return;
    }

    bitsandbytes_cpu::avx512::gemm_4bit_inference<at::BFloat16, DataType_t::NF4>(
        M, N, K, x_ptr, w_ptr, absmax_ptr, out_ptr, blocksize, x_strideM, out_strideM);
}
} // namespace bitsandbytes_cpu
15 changes: 15 additions & 0 deletions quantization-bitsandbytes/bitsandbytes_cpu/bitsandbytes_cpu.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
#pragma once

#include "cpu_features.hpp"
#include "bitsandbytes_avx512.hpp"
#include <torch/all.h>
#include <stdexcept>
#include <ATen/ATen.h>

namespace bitsandbytes_cpu
{

// Main dispatcher that selects the best implementation based on runtime CPU features.
//
// Computes the 4-bit GEMM of `input` against `weight`/`absmax` and writes the
// result into the caller-provided `out`.
//   input      - read as at::BFloat16; size(0) and size(1) are used as M and K
//   weight     - read as unsigned char; size(0) is used as N
//   absmax     - read as at::BFloat16
//   out        - written as at::BFloat16; must already be allocated
//   blocksize  - forwarded unchanged to the kernel
//   quant_type - 1 selects the FP4 kernel, any other value selects NF4
// Throws std::runtime_error if the CPU does not support AVX512BF16.
void gemm_4bit(const torch::Tensor &input, const torch::Tensor &weight,
const torch::Tensor &absmax, torch::Tensor &out, int64_t blocksize, int64_t quant_type);
} // namespace bitsandbytes_cpu
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
#include <torch/all.h>
#include "bitsandbytes_cpu.hpp"

// Forward implementation for CPU.
//
// Validates the inputs, allocates the result with input's options and shape
// {input.size(0), weight.size(0)}, and runs bitsandbytes_cpu::gemm_4bit into it.
//
//   input      - contiguous 2D tensor
//   weight     - contiguous tensor; size(0) becomes the output's second extent
//   absmax     - contiguous tensor
//   blocksize  - forwarded unchanged to the kernel
//   quant_type - forwarded unchanged to the kernel
// Returns the freshly allocated output tensor.
// Raises (via TORCH_CHECK) when a tensor is not contiguous or input is not 2D.
torch::Tensor gemm_4bit_cpu_forward(
    const torch::Tensor &input, const torch::Tensor &weight,
    const torch::Tensor &absmax, int64_t blocksize, int64_t quant_type)
{
    TORCH_CHECK(input.is_contiguous(), "input must be contiguous");
    TORCH_CHECK(weight.is_contiguous(), "weight must be contiguous");
    TORCH_CHECK(absmax.is_contiguous(), "absmax must be contiguous");
    // Only size(0)/size(1) of input are consulted downstream and the output is
    // allocated as 2D, so any other rank would be misread silently.
    TORCH_CHECK(input.dim() == 2, "input must be a 2D tensor, but got ", input.dim(), " dimension(s)");

    auto output = at::empty({input.size(0), weight.size(0)}, input.options());

    bitsandbytes_cpu::gemm_4bit(input, weight, absmax, output, blocksize, quant_type);

    return output;
}
176 changes: 176 additions & 0 deletions quantization-bitsandbytes/bitsandbytes_cpu/cpu_features.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,176 @@
#pragma once

#ifdef _MSC_VER
#include <intrin.h>
#else
#include <cpuid.h>
#endif
#include <fstream>
#include <iostream>
#include <string>
#include <cstdlib>
namespace bitsandbytes_cpu
{

// CPU feature detection.
//
// Runtime probes for the x86 instruction-set extensions used by the kernels.
// Each public query runs its probe once and caches the answer in a
// function-local static.
class CPUFeatures
{
public:
    // True when CPUID reports AVX2 (leaf 7, sub-leaf 0, EBX bit 5).
    // Only the CPU feature bit is checked; OS support for the YMM state is not.
    static bool hasAVX2()
    {
        static bool avx2_supported = checkAVX2();
        return avx2_supported;
    }

    // True when AVX-512 Foundation is present and OS-enabled, and AVX512_BF16
    // is reported by /proc/cpuinfo (non-MSVC builds) or by CPUID.
    static bool hasAVX512BF16()
    {
        static bool bf16_supported = checkAVX512BF16();
        return bf16_supported;
    }

private:
    // Register values produced by one CPUID invocation.
    struct CpuidRegs
    {
        unsigned int eax;
        unsigned int ebx;
        unsigned int ecx;
        unsigned int edx;
    };

    // Highest basic CPUID leaf supported by the processor.
    static unsigned int maxBasicLeaf()
    {
#ifdef _MSC_VER
        int cpu_info[4];
        __cpuid(cpu_info, 0);
        return static_cast<unsigned int>(cpu_info[0]);
#else
        return __get_cpuid_max(0, nullptr);
#endif
    }

    // Executes CPUID for (leaf, subleaf). The caller must have verified that
    // `leaf` does not exceed maxBasicLeaf().
    static CpuidRegs cpuid(unsigned int leaf, unsigned int subleaf)
    {
        CpuidRegs regs = {0, 0, 0, 0};
#ifdef _MSC_VER
        int cpu_info[4];
        __cpuidex(cpu_info, static_cast<int>(leaf), static_cast<int>(subleaf));
        regs.eax = static_cast<unsigned int>(cpu_info[0]);
        regs.ebx = static_cast<unsigned int>(cpu_info[1]);
        regs.ecx = static_cast<unsigned int>(cpu_info[2]);
        regs.edx = static_cast<unsigned int>(cpu_info[3]);
#else
        __cpuid_count(leaf, subleaf, regs.eax, regs.ebx, regs.ecx, regs.edx);
#endif
        return regs;
    }

    // Reads XCR0. Only valid once CPUID leaf 1 has reported OSXSAVE.
    static unsigned long long readXcr0()
    {
#ifdef _MSC_VER
        return _xgetbv(0);
#else
        unsigned int xcr0_lo = 0, xcr0_hi = 0;
        __asm__ volatile("xgetbv" : "=a"(xcr0_lo), "=d"(xcr0_hi) : "c"(0));
        return (static_cast<unsigned long long>(xcr0_hi) << 32) | xcr0_lo;
#endif
    }

    static bool checkAVX2()
    {
        if (maxBasicLeaf() < 7)
        {
            return false;
        }
        constexpr unsigned int kAvx2Bit = 1u << 5; // EBX bit 5
        return (cpuid(7, 0).ebx & kAvx2Bit) != 0;
    }

    static bool checkAVX512()
    {
        if (maxBasicLeaf() < 7)
        {
            return false;
        }

        constexpr unsigned int kAvx512fBit = 1u << 16; // EBX bit 16: AVX-512 Foundation
        if ((cpuid(7, 0).ebx & kAvx512fBit) == 0)
        {
            return false;
        }

        constexpr unsigned int kOsxsaveBit = 1u << 27; // ECX bit 27: OSXSAVE
        if ((cpuid(1, 0).ecx & kOsxsaveBit) == 0)
        {
            return false;
        }

        // check XCR0: bits 1,2 (SSE/AVX) and 5,6,7 (AVX-512 state) must be enabled by OS
        constexpr unsigned long long kAvx512StateMask = 0xE6ULL; // 0b11100110
        return (readXcr0() & kAvx512StateMask) == kAvx512StateMask;
    }

    static bool checkAVX512BF16()
    {
        // require AVX-512 foundation supported and OS enabled
        // (this also guarantees that CPUID leaf 7 exists)
        if (!checkAVX512())
        {
            return false;
        }

#ifndef _MSC_VER
        // First: look for the flag in /proc/cpuinfo, when that file is available.
        std::ifstream cpuinfo("/proc/cpuinfo");
        if (cpuinfo)
        {
            std::string line;
            while (std::getline(cpuinfo, line))
            {
                // flags line contains many space-separated tokens including avx512_bf16 on supported CPUs
                if (line.find("avx512_bf16") != std::string::npos ||
                    line.find("avx512bf16") != std::string::npos)
                {
                    return true;
                }
            }
        }
#endif

        // CPUID path (fallback on non-MSVC builds, only path on MSVC).
        // CPUID.(EAX=07H,ECX=0):EAX holds the highest valid sub-leaf of leaf 7.
        if (cpuid(7, 0).eax < 1)
        {
            return false;
        }

        // AVX512_BF16 is CPUID.(EAX=07H,ECX=1):EAX bit 5. The other registers of
        // this sub-leaf describe unrelated features and must not be consulted.
        constexpr unsigned int kAvx512Bf16Bit = 1u << 5;
        return (cpuid(7, 1).eax & kAvx512Bf16Bit) != 0;
    }
};

} // namespace bitsandbytes_cpu
Loading