Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 34 additions & 0 deletions quartet2/csrc/binding.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ void rht128_requant(
void eden_fp4(__nv_fp4x4_e2m1* y_ptr, __nv_fp8_e4m3* scale_ptr, float* global_scale_ptr, const nv_bfloat16* x_ptr, const float* amax_ptr, float scale_override, long seed, long rows, long cols);
void rtn_fp4(__nv_fp4x4_e2m1* y_ptr, __nv_fp8_e4m3* scale_ptr, float* global_scale_ptr, const nv_bfloat16* x_ptr, const float* amax_ptr, float scale_override, long rows, long cols);
void four_six_fp4(__nv_fp4x4_e2m1* y_ptr, __nv_fp8_e4m3* scale_ptr, float* global_scale_ptr, const nv_bfloat16* x_ptr, const float* amax_ptr, float scale_override, long rows, long cols);
void gridflip_four_six_fp4(__nv_fp4x4_e2m1* y_ptr, __nv_fp8_e4m3* scale_ptr, float* global_scale_ptr, const nv_bfloat16* x_ptr, const float* amax_ptr, float scale_override, float grid_shift, long rows, long cols);
void dequant_tp_had_quant(
__nv_fp4x2_storage_t* y, __nv_fp8_e4m3* scales_fp8, float* global_scale_ptr,
nv_bfloat16* scratch_scales, unsigned* max_scale, const nv_bfloat16* h,
Expand Down Expand Up @@ -272,6 +273,38 @@ void four_six_fp4_binding(
scale_override, inp.shape(0), inp.shape(1));
}

void gridflip_four_six_fp4_binding(
const CudaArray<>& out,
const CudaArray<>& scales,
const CudaArray<float, nb::shape<>>& global_scale,
const CudaArray<nb::ro>& inp,
const CudaArray<float, nb::ro>& amax_ptr,
float scale_override,
float grid_shift
)
{
nb::dlpack::dtype bf16_dt{static_cast<std::uint8_t>(nb::dlpack::dtype_code::Bfloat), 16, 1};

CHECK_EQ(inp.ndim(), 2ul);
CHECK_EQ(out.ndim(), 2ul);

CHECK_EQ(out.shape(0), inp.shape(0));
CHECK_EQ(out.size(), inp.size() / 2);
CHECK_EQ(out.dtype().bits, static_cast<uint8_t>(8));
CHECK_EQ(scales.size(), inp.size() / 16);
CHECK_EQ(scales.dtype().bits, static_cast<uint8_t>(8));
CHECK_EQ(inp.dtype(), bf16_dt);
CHECK(global_scale.data() != amax_ptr.data());

gridflip_four_six_fp4(
reinterpret_cast<__nv_fp4x4_e2m1*>(out.data()),
reinterpret_cast<__nv_fp8_e4m3*>(scales.data()),
global_scale.data(),
reinterpret_cast<const nv_bfloat16*>(inp.data()),
amax_ptr.data(),
scale_override, grid_shift, inp.shape(0), inp.shape(1));
}

void rtn_fp4_binding(
const CudaArray<>& out,
const CudaArray<>& scales,
Expand Down Expand Up @@ -422,5 +455,6 @@ NB_MODULE(_quartet2, m) {

m.def("eden_fp4", &eden_fp4_binding, nb::arg("out"), nb::arg("scales"), nb::arg("global_scale"), nb::arg("input"), nb::arg("amax"), nb::arg("scale_override"), nb::arg("seed"));
m.def("four_six_fp4", &four_six_fp4_binding, nb::arg("out"), nb::arg("scales"), nb::arg("global_scale"), nb::arg("input"), nb::arg("amax"), nb::arg("scale_override"));
m.def("gridflip_four_six_fp4", &gridflip_four_six_fp4_binding, nb::arg("out"), nb::arg("scales"), nb::arg("global_scale"), nb::arg("input"), nb::arg("amax"), nb::arg("scale_override"), nb::arg("grid_shift"));
m.def("rtn_fp4", &rtn_fp4_binding, nb::arg("out"), nb::arg("scales"), nb::arg("global_scale"), nb::arg("input"), nb::arg("amax"), nb::arg("scale_override"));
}
130 changes: 130 additions & 0 deletions quartet2/csrc/round_four_six.cu
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,27 @@ __device__ __forceinline__ QuantResult quantize(float abs_max, float inv_val_max
return QuantResult{result, s_round_fp8, s_as_fp8};
}

// Quantizes one 8-element bf16 vector to FP4 E2M1 on a *flipped, shifted*
// grid: each value is encoded as round(-x / (s * scale) - grid_shift), so the
// representable points are -(g + grid_shift) * s * scale for g in the E2M1
// grid. The matching inverse mapping is in gridflip_quant_error.
//
// abs_max     absolute maximum over the scale group (the caller passes the
//             max across a lane pair, i.e. 16 elements)
// inv_val_max reciprocal of the target max representable magnitude; the
//             kernel passes scale_override / (candidate + grid_shift)
// scale       global scale; the group scale is stored relative to it as FP8
// grid_shift  shift applied to the flipped grid (validated >= 0 by the host)
// x           the 8 bf16 values to quantize
__device__ __forceinline__ QuantResult quantize_gridflip(float abs_max, float inv_val_max, float scale, float grid_shift, bf16x8& x) {
    // Compute the group scale, then round-trip it through FP8 E4M3 — that is
    // the format it will be stored in — so quantization below uses the
    // *rounded* scale the dequantizer will see.
    float s_group = abs_max * inv_val_max;
    float inv_scale = reciprocal_approximate_ftz(scale);
    __nv_fp8_e4m3 s_as_fp8 = static_cast<__nv_fp8_e4m3>(s_group * inv_scale);
    float s_round_fp8 = static_cast<float>(s_as_fp8);
    // An all-zero group rounds to scale 0; substitute 1 to avoid dividing by zero.
    if (s_round_fp8 == 0) s_round_fp8 = 1.f;

    // factor = 1 / (s * scale); negated in factor2 to flip the grid.
    float factor = reciprocal_approximate_ftz(s_round_fp8 * scale);
    float2 factor2 = {-factor, -factor};
    fp4x8 result;
    for (int k = 0; k < bf16x8::size; k += 2) {
        // Two elements per iteration via packed float2 arithmetic.
        float2 src = make_float2(static_cast<float>(x[k+0]), static_cast<float>(x[k+1]));
        float2 prod = __fmul2_rn(src, factor2);
        // Shift onto the flipped grid before the round-to-nearest FP4 convert.
        float2 scaled = {prod.x - grid_shift, prod.y - grid_shift};
        unsigned char bits = __nv_cvt_float2_to_fp4x2(scaled, __nv_fp4_interpretation_t::__NV_E2M1, cudaRoundMode::cudaRoundNearest);
        result[k/2] = bits;
    }

    return QuantResult{result, s_round_fp8, s_as_fp8};
}

__forceinline__ __device__ float quant_error(bf16x8 x, const QuantResult& q, float scale) {
const float descale = static_cast<float>(q.fp8s) * scale;
float2 sum = {0.f, 0.f};
Expand All @@ -56,6 +77,24 @@ __forceinline__ __device__ float quant_error(bf16x8 x, const QuantResult& q, flo
return local_error;
}

// Squared reconstruction error of a gridflip-quantized 8-element vector,
// summed across the lane pair so both lanes hold the full group error.
//
// Dequantizes q's FP4 codes with the flipped grid —
// recon = -(code + grid_shift) * descale, the inverse of
// quantize_gridflip's forward mapping — and accumulates the squared
// difference against the original bf16 values in x.
__forceinline__ __device__ float gridflip_quant_error(bf16x8 x, const QuantResult& q, float scale, float grid_shift) {
    // Effective dequantization step: rounded FP8 group scale times global scale.
    const float descale = static_cast<float>(q.fp8s) * scale;
    float2 sum = {0.f, 0.f};
    for (int i = 0; i < 4; ++i) {
        // Unpack two FP4 codes per iteration.
        float2 dq = __nv_cvt_fp4x2_to_float2(q.bits[i]);
        float2 xv = {static_cast<float>(x[2*i+0]), static_cast<float>(x[2*i+1])};
        // Invert the forward map (x was encoded as round(-x/descale - grid_shift)).
        float2 recon = {
            -(dq.x + grid_shift) * descale,
            -(dq.y + grid_shift) * descale,
        };
        float2 d = {recon.x - xv.x, recon.y - xv.y};
        sum = __ffma2_rn(d, d, sum);  // sum += d * d, packed
    }
    float local_error = sum.x + sum.y;
    // Adjacent lanes (xor 1) quantize the two halves of one 16-element scale
    // group; exchange and add so each lane holds the combined error.
    local_error += __shfl_xor_sync(0xffffffff, local_error, 1);
    return local_error;
}


template<float... Others>
struct get_candidate_helper;
Expand All @@ -76,6 +115,24 @@ struct get_candidate_helper<> {
}
};

// Compile-time lookup of the i-th element in a pack of float candidates.
// Companion to get_candidate_helper, but returns the raw candidate value
// instead of a derived quantity.
template<float... Others>
struct get_candidate_value_helper;

template<float Value, float... Others>
struct get_candidate_value_helper<Value, Others...> {
    // Peels one value off the pack per recursion step.
    static constexpr __forceinline__ __device__ float get(int i) {
        return (i == 0) ? Value : get_candidate_value_helper<Others...>::get(i - 1);
    }
};

template<>
struct get_candidate_value_helper<> {
    // Past the end of the pack — callers index with i < pack size, so this
    // is never reached at runtime.
    static constexpr __forceinline__ __device__ float get(int) {
        __builtin_unreachable();
    }
};

template<float... Candidates>
__global__ void four_six_fp4_kernel(__nv_fp4x4_e2m1* y_ptr, __nv_fp8_e4m3* scale_ptr, float* global_scale_ptr, const nv_bfloat16* x_ptr, const float* amax_ptr, float scale_override, int nvecs, int cols) {
constexpr int NumCandidates = sizeof...(Candidates);
Expand Down Expand Up @@ -122,6 +179,63 @@ __global__ void four_six_fp4_kernel(__nv_fp4x4_e2m1* y_ptr, __nv_fp8_e4m3* scale
}
}

// Gridflip variant of the 4/6 FP4 quantization kernel.
//
// Launch: 1D grid, one thread per 8-element bf16 vector. Adjacent lanes
// (xor 1) together handle one 16-element scale group. The launcher passes
// nvecs = rows*cols/8 and cols = (input cols)/16, i.e. the width of the
// scale matrix.
//
// For every candidate max value the group is quantized twice — on the
// standard E2M1 grid and on the flipped/shifted grid — and the variant with
// the lower squared error wins. The winning group scale is stored as FP8
// E4M3 with its sign bit (0x80) repurposed as the "flipped grid" flag
// (group scales are expected to be non-negative, so the bit is otherwise
// unused — TODO confirm for negative scale_override).
template<float... Candidates>
__global__ void gridflip_four_six_fp4_kernel(__nv_fp4x4_e2m1* y_ptr, __nv_fp8_e4m3* scale_ptr, float* global_scale_ptr, const nv_bfloat16* x_ptr, const float* amax_ptr, float scale_override, float grid_shift, int nvecs, int cols) {
    constexpr int NumCandidates = sizeof...(Candidates);
    float global_abs_max = *amax_ptr;
    int idx = threadIdx.x + blockIdx.x * blockDim.x;
    // Tail guard. NOTE(review): lanes past nvecs exit here but the
    // __shfl_xor_sync calls below use mask 0xffffffff; in a partially-active
    // tail warp that names exited lanes — verify this matches the intent of
    // the sibling four_six kernel / the supported launch shapes.
    if(idx >= nvecs) return;

    bf16x8 x = bf16x8::load(x_ptr + 8 * idx);

    // 448 is the FP8 E4M3 max; 1/256 presumably leaves headroom for the
    // multi-candidate search — TODO confirm.
    constexpr float inv_scales_max = NumCandidates > 1 ? 1.f / 256.f : 1.f / 448.f;
    constexpr float one_over_six = 1.f / 6.f;
    float inv_val_max = scale_override * one_over_six;
    float scale = global_abs_max == 0 ? 1.f : global_abs_max * inv_scales_max * inv_val_max;
    // Single writer publishes the global scale for the host/dequantizer.
    if (idx == 0) {
        global_scale_ptr[0] = scale;
    }

    // Group abs-max: combine this lane's 8 values with its partner lane's.
    nv_bfloat16 local_abs_max = vecReduceAbsMax(x);
    nv_bfloat16 other_abs_max = __shfl_xor_sync(0xffffffff, local_abs_max, 1);
    float full_abs_max = static_cast<float>(__hmax(local_abs_max, other_abs_max));

    // Track the best (lowest-error) result separately for each grid variant.
    QuantResult best_standard_res;
    float best_standard = INFINITY;
    QuantResult best_gridflip_res;
    float best_gridflip = INFINITY;
    for (int i = 0; i < NumCandidates; ++i) {
        // Standard grid: target max = candidate value.
        float inv_val = get_candidate_helper<Candidates...>::get_inv(i);
        QuantResult standard_res = quantize(full_abs_max, inv_val * scale_override, scale, x);
        float standard_score = quant_error(x, standard_res, scale);
        if (standard_score < best_standard) {
            best_standard = standard_score;
            best_standard_res = standard_res;
        }

        // Flipped grid: the shift extends the reachable magnitude to
        // candidate + grid_shift, hence the adjusted reciprocal.
        float value = get_candidate_value_helper<Candidates...>::get(i);
        QuantResult gridflip_res = quantize_gridflip(full_abs_max, scale_override / (value + grid_shift), scale, grid_shift, x);
        float gridflip_score = gridflip_quant_error(x, gridflip_res, scale, grid_shift);
        if (gridflip_score < best_gridflip) {
            best_gridflip = gridflip_score;
            best_gridflip_res = gridflip_res;
        }
    }

    // Pick the variant with the lower squared error; scores are already
    // group-wide (the error helpers sum across the lane pair), so both lanes
    // agree on the choice.
    bool use_gridflip = best_gridflip < best_standard;
    QuantResult res = use_gridflip ? best_gridflip_res : best_standard_res;
    // 8 FP4 codes = 4 bytes per vector.
    res.bits.store(reinterpret_cast<unsigned char*>(y_ptr) + 4 * idx);
    // Even lane of each pair writes the group's scale byte, with the high
    // bit flagging the flipped-grid encoding.
    if (idx % 2 == 0) {
        int col = (idx / 2) % cols;
        int row = (idx / 2) / cols;
        unsigned char scale_bits = *reinterpret_cast<unsigned char*>(&res.fp8s);
        if (use_gridflip) {
            scale_bits |= 0x80;
        }
        reinterpret_cast<unsigned char*>(scale_ptr)[row * cols + col] = scale_bits;
    }
}

void four_six_fp4(__nv_fp4x4_e2m1* y_ptr, __nv_fp8_e4m3* scale_ptr, float* global_scale_ptr, const nv_bfloat16* x_ptr, const float* amax_ptr, float scale_override, long rows, long cols) {
if (cols % 128 != 0) throw std::invalid_argument("four_six_fp4: cols must be divisible by 128");
CHECK_POINTER(y_ptr);
Expand All @@ -137,6 +251,22 @@ void four_six_fp4(__nv_fp4x4_e2m1* y_ptr, __nv_fp8_e4m3* scale_ptr, float* globa
CUDA_CHECK(cudaGetLastError());
}

// Host-side launcher for the gridflip 4/6 FP4 quantization kernel.
//
// y_ptr            packed FP4 output, rows*cols/8 four-byte vectors
// scale_ptr        one FP8 E4M3 scale per 16 input elements; the sign bit of
//                  a scale byte flags that its group used the flipped grid
// global_scale_ptr single float, written by the kernel (thread 0)
// x_ptr            bf16 input of shape (rows, cols)
// amax_ptr         single float: global absolute maximum of x
// scale_override   multiplier applied to the per-group scale targets
// grid_shift       shift applied to the flipped FP4 grid; must be >= 0
//
// Throws std::invalid_argument on an invalid cols, grid_shift, or a tensor
// too large for the kernel's 32-bit vector index.
void gridflip_four_six_fp4(__nv_fp4x4_e2m1* y_ptr, __nv_fp8_e4m3* scale_ptr, float* global_scale_ptr, const nv_bfloat16* x_ptr, const float* amax_ptr, float scale_override, float grid_shift, long rows, long cols) {
    if (cols % 128 != 0) throw std::invalid_argument("gridflip_four_six_fp4: cols must be divisible by 128");
    // Written as !(x >= 0) rather than (x < 0) so NaN is rejected too: a NaN
    // comparison is always false, and a NaN grid_shift would otherwise pass
    // validation and silently poison every quantized value.
    if (!(grid_shift >= 0.f)) throw std::invalid_argument("gridflip_four_six_fp4: grid_shift must be non-negative");
    CHECK_POINTER(y_ptr);
    CHECK_POINTER(scale_ptr);
    CHECK_POINTER(x_ptr);
    CHECK_POINTER_NOT_NULL(global_scale_ptr);
    CHECK_POINTER_NOT_NULL(amax_ptr);

    // One kernel thread per 8-element vector. Compute the count in long and
    // verify it fits the kernel's int parameter instead of truncating.
    long total_vecs = rows * cols / 8;
    int n_vecs = static_cast<int>(total_vecs);
    if (static_cast<long>(n_vecs) != total_vecs)
        throw std::invalid_argument("gridflip_four_six_fp4: tensor too large (vector count exceeds INT_MAX)");
    int block_size = 256;
    int n_blocks = (n_vecs + block_size - 1) / block_size;
    // cols/16 = width of the scale matrix (one scale per 16 input elements).
    gridflip_four_six_fp4_kernel<6.f, 4.f><<<n_blocks, block_size>>>(y_ptr, scale_ptr, global_scale_ptr, x_ptr, amax_ptr, scale_override, grid_shift, n_vecs, cols / 16);
    CUDA_CHECK(cudaGetLastError());
}

void rtn_fp4(__nv_fp4x4_e2m1* y_ptr, __nv_fp8_e4m3* scale_ptr, float* global_scale_ptr, const nv_bfloat16* x_ptr, const float* amax_ptr, float scale_override, long rows, long cols) {
if (cols % 128 != 0) throw std::invalid_argument("rtn_fp4: cols must be divisible by 128");
CHECK_POINTER(y_ptr);
Expand Down
Loading