diff --git a/include/fbgemm/QuantUtilsNeon.h b/include/fbgemm/QuantUtilsNeon.h
index 13169c8a05..69845ec67a 100644
--- a/include/fbgemm/QuantUtilsNeon.h
+++ b/include/fbgemm/QuantUtilsNeon.h
@@ -36,6 +36,13 @@ void Fused8BitRowwiseQuantizedSBFloatToFloatOrHalfNeon(
     int input_columns,
     OutputType* output);
 
+template <typename InputType, int BIT_RATE>
+void FloatOrHalfToFusedNBitRowwiseQuantizedSBHalfNeon(
+    const InputType* input,
+    size_t input_rows,
+    int input_columns,
+    std::uint8_t* output);
+
 } // namespace fbgemm
 
 #endif // __aarch64__
diff --git a/src/QuantUtils.cc b/src/QuantUtils.cc
index 5301909193..06030df46b 100644
--- a/src/QuantUtils.cc
+++ b/src/QuantUtils.cc
@@ -636,6 +636,26 @@ void FloatOrHalfToFusedNBitRowwiseQuantizedSBHalf(
     throw std::runtime_error("Unsupported number of columns");
   }
 
+#if HAVE_SVE
+  switch (bit_rate) {
+    case 2:
+      FloatOrHalfToFusedNBitRowwiseQuantizedSBHalfNeon<InputType, 2>(
+          input, input_rows, input_columns, output);
+      break;
+    case 4:
+      FloatOrHalfToFusedNBitRowwiseQuantizedSBHalfNeon<InputType, 4>(
+          input, input_rows, input_columns, output);
+      break;
+    case 8:
+      FloatOrHalfToFusedNBitRowwiseQuantizedSBHalfNeon<InputType, 8>(
+          input, input_rows, input_columns, output);
+      break;
+    default:
+      FloatOrHalfToFusedNBitRowwiseQuantizedSBHalfRef(
+          bit_rate, input, input_rows, input_columns, output);
+  }
+#else
+
   if (cpuinfo_initialize() && fbgemmHasAvx2Support()) {
 #if CPUINFO_ARCH_X86 || CPUINFO_ARCH_X86_64
     switch (bit_rate) {
@@ -660,6 +680,8 @@ void FloatOrHalfToFusedNBitRowwiseQuantizedSBHalf(
     FloatOrHalfToFusedNBitRowwiseQuantizedSBHalfRef(
         bit_rate, input, input_rows, input_columns, output);
   }
+
+#endif
 }
 
 template <typename InputType>
diff --git a/src/QuantUtilsNeon.cc b/src/QuantUtilsNeon.cc
index 8fef86b94f..7868fbb519 100644
--- a/src/QuantUtilsNeon.cc
+++ b/src/QuantUtilsNeon.cc
@@ -95,8 +95,12 @@ void FindMinMax(const float* m, float* min, float* max, int64_t len) {
 
 #if HAVE_SVE
 
-static inline void
-FindMinMaxImpl_f16(const float16_t* m, float* min, float* max, uint64_t count) {
+template <typename OutType>
+static inline void FindMinMaxImpl_f16(
+    const float16_t* m,
+    OutType* min,
+    OutType* max,
+    uint64_t count) {
   float16_t first = *m;
 
   float16_t tmp_min_s = first;
@@ -141,8 +145,8 @@ FindMinMaxImpl_f16(const float16_t* m, float* min, float* max, uint64_t count) {
     tmp_max_s = vmaxh_f16(tmp_max_s, tmp);
   }
 
-  *min = static_cast<float>(tmp_min_s);
-  *max = static_cast<float>(tmp_max_s);
+  *min = static_cast<OutType>(tmp_min_s);
+  *max = static_cast<OutType>(tmp_max_s);
 }
 
 template <typename InputType>
@@ -257,6 +261,236 @@ void FloatOrHalfToFused8BitRowwiseQuantizedSBFloatNeon(
   } // for each row
 }
 
+template <typename InputType, int BIT_RATE>
+void FloatOrHalfToFusedNBitRowwiseQuantizedSBHalfNeon(
+    const InputType* input,
+    size_t input_rows,
+    int input_columns,
+    std::uint8_t* output) {
+  if (input_rows == 0 || input_columns <= 0) {
+    return;
+  }
+
+  static_assert(
+      std::is_same<InputType, float>() || std::is_same<InputType, float16>(),
+      "Only float and float16 types are allowed.");
+
+  static_assert(
+      (BIT_RATE == 8) || (BIT_RATE == 4) || (BIT_RATE == 2),
+      "Only bit rates of 8, 4 and 2 are allowed.");
+
+  constexpr uint64_t num_elem_per_byte = 8 / BIT_RATE;
+  uint64_t column_count = static_cast<uint64_t>(input_columns);
+  const int output_columns =
+      (column_count + num_elem_per_byte - 1) / num_elem_per_byte +
+      2 * sizeof(float16);
+
+  for (size_t row = 0; __builtin_expect(row < input_rows, 1); ++row) {
+    const InputType* input_row = input + row * column_count;
+    std::uint8_t* output_row = output + row * output_columns;
+    float16_t* output_row_scale_bias = reinterpret_cast<float16_t*>(
+        output_row +
+        (column_count + num_elem_per_byte - 1) / num_elem_per_byte);
+
+    float minimum_element;
+    float maximum_element;
+    float16_t minimum_element_fp16;
+    if constexpr (std::is_same<InputType, float>()) {
+      FindMinMaxImpl_f32(
+          input_row, &minimum_element, &maximum_element, column_count);
+      minimum_element_fp16 = static_cast<float16_t>(minimum_element);
+      minimum_element = static_cast<float>(minimum_element_fp16);
+    } else {
+      float16_t maximum_element_fp16;
+      FindMinMaxImpl_f16(
+          reinterpret_cast<const float16_t*>(input_row),
+          &minimum_element_fp16,
+          &maximum_element_fp16,
+          column_count);
+      minimum_element = static_cast<float>(minimum_element_fp16);
+      maximum_element = static_cast<float>(maximum_element_fp16);
+    }
+
+    const float range = maximum_element - minimum_element;
+
+    float scale = range == 0 ? 1.0f : range / ((1 << BIT_RATE) - 1);
+    float16_t scale_fp16 = static_cast<float16_t>(scale);
+    scale = static_cast<float>(scale_fp16);
+    svfloat32_t inverse_scale_sv;
+    if (scale != 0.0f) {
+      float inverse_scale = 1.0f / scale;
+      inverse_scale_sv = svdup_n_f32(inverse_scale);
+      bool isInf = svptest_any(
+          svptrue_b8(),
+          svcmpuo_f32(
+              svptrue_b8(),
+              svsub_f32_x(svptrue_b8(), inverse_scale_sv, inverse_scale_sv),
+              svdup_n_f32(0.0)));
+      if (isInf) {
+        scale_fp16 = static_cast<float16_t>(1.0f);
+        scale = 1.0f;
+        inverse_scale_sv = svdup_n_f32(1.0f);
+      }
+    } else {
+      // Corner case handling when maximum_element == minimum_element
+      // Any scale would work because X - minimum_element will be 0 for all X
+      scale_fp16 = static_cast<float16_t>(1.0f);
+      scale = 1.0f;
+      inverse_scale_sv = svdup_n_f32(1.0f);
+    }
+
+    constexpr uint64_t kItemsPerIter = 8;
+    uint64_t loopIters = column_count / kItemsPerIter;
+    uint64_t loopRemainder = column_count % kItemsPerIter;
+
+    output_row_scale_bias[0] = scale_fp16;
+    output_row_scale_bias[1] = minimum_element_fp16;
+
+    float32x4_t inverse_scale_v = svget_neonq(inverse_scale_sv);
+    float32x4_t min_v = vdupq_n_f32(minimum_element);
+
+    constexpr unsigned int maxValPerBitRate = (1ul << BIT_RATE) - 1;
+    uint32x4_t maxval_v = vdupq_n_u32(maxValPerBitRate);
+
+    svbool_t lastPredA = svwhilelt_b32_u64(0, loopRemainder);
+    svbool_t lastPredB = svwhilelt_b32_u64(4, loopRemainder);
+
+    while (__builtin_expect(loopIters > 0, 1)) {
+      float32x4_t v0;
+      float32x4_t v1;
+
+      if constexpr (std::is_same<InputType, float>()) {
+        v0 = vld1q_f32(input_row);
+        v1 = vld1q_f32(input_row + 4);
+      } else {
+        float16x8_t h0 =
+            vld1q_f16(reinterpret_cast<const float16_t*>(input_row));
+        v0 = vcvt_f32_f16(vget_low_f16(h0));
+        v1 = vcvt_high_f32_f16(h0);
+      }
+
+      input_row += kItemsPerIter;
+      loopIters -= 1;
+
+      v0 = vsubq_f32(v0, min_v);
+      v1 = vsubq_f32(v1, min_v);
+
+      v0 = vmulq_f32(v0, inverse_scale_v);
+      v1 = vmulq_f32(v1, inverse_scale_v);
+
+      int32x4_t i0 = vcvtnq_s32_f32(v0);
+      int32x4_t i1 = vcvtnq_s32_f32(v1);
+
+      uint32x4_t u0 = vminq_u32(vreinterpretq_u32_s32(i0), maxval_v);
+      uint32x4_t u1 = vminq_u32(vreinterpretq_u32_s32(i1), maxval_v);
+
+      if constexpr (num_elem_per_byte == 1) {
+        svst1b_u32(
+            svptrue_b8(), output_row, svset_neonq_u32(svundef_u32(), u0));
+        svst1b_u32(
+            svptrue_b8(), output_row + 4, svset_neonq_u32(svundef_u32(), u1));
+      } else {
+        constexpr uint64_t shiftVar = num_elem_per_byte == 2 ? 28 : 30;
+
+        uint64x2_t u2 = vreinterpretq_u64_u32(u0) >> shiftVar;
+        uint64x2_t u3 = vreinterpretq_u64_u32(u1) >> shiftVar;
+
+        u2 = veorq_u64(u2, vreinterpretq_u64_u32(u0));
+        u3 = veorq_u64(u3, vreinterpretq_u64_u32(u1));
+
+        if constexpr (num_elem_per_byte == 2) {
+          svst1b_u64(
+              svptrue_b8(), output_row, svset_neonq_u64(svundef_u64(), u2));
+          svst1b_u64(
+              svptrue_b8(), output_row + 2, svset_neonq_u64(svundef_u64(), u3));
+
+        } else if constexpr (num_elem_per_byte == 4) {
+          auto u4 = vdup_laneq_u8(vreinterpretq_u8_u64(u2), 8);
+          auto u5 = vdup_laneq_u8(vreinterpretq_u8_u64(u3), 8);
+
+          u4 = u4 << 4;
+          u5 = u5 << 4;
+
+          u4 = veor_u8(u4, vget_low_u8(u2));
+          u5 = veor_u8(u5, vget_low_u8(u3));
+
+          vst1_lane_u8(output_row, u4, 0);
+          vst1_lane_u8(output_row + 1, u5, 0);
+        }
+      }
+
+      constexpr uint64_t bytesStored = kItemsPerIter / num_elem_per_byte;
+      output_row += bytesStored;
+    }
+
+    if (loopRemainder > 0) {
+      float32x4_t v0;
+      float32x4_t v1;
+
+      if constexpr (std::is_same<InputType, float>()) {
+        v0 = svget_neonq(svld1_f32(lastPredA, input_row));
+        v1 = svget_neonq(svld1_f32(lastPredB, input_row + 4));
+      } else {
+        auto h0 = svld1uh_u32(
+            lastPredA, reinterpret_cast<const uint16_t*>(input_row));
+        auto h1 = svld1uh_u32(
+            lastPredB, reinterpret_cast<const uint16_t*>(input_row + 4));
+        v0 = svget_neonq(
+            svcvt_f32_f16_x(svptrue_b8(), svreinterpret_f16_u32(h0)));
+        v1 = svget_neonq(
+            svcvt_f32_f16_x(svptrue_b8(), svreinterpret_f16_u32(h1)));
+      }
+
+      v0 = vsubq_f32(v0, min_v);
+      v1 = vsubq_f32(v1, min_v);
+
+      v0 = vmulq_f32(v0, inverse_scale_v);
+      v1 = vmulq_f32(v1, inverse_scale_v);
+
+      int32x4_t i0 = vcvtnq_s32_f32(v0);
+      int32x4_t i1 = vcvtnq_s32_f32(v1);
+
+      uint32x4_t u0 = vminq_u32(vreinterpretq_u32_s32(i0), maxval_v);
+      uint32x4_t u1 = vminq_u32(vreinterpretq_u32_s32(i1), maxval_v);
+
+      if constexpr (num_elem_per_byte == 1) {
+        svst1b_u32(lastPredA, output_row, svset_neonq_u32(svundef_u32(), u0));
+        svst1b_u32(
+            lastPredB, output_row + 4, svset_neonq_u32(svundef_u32(), u1));
+      } else {
+        constexpr uint64_t shiftVar = num_elem_per_byte == 2 ? 28 : 30;
+
+        uint64x2_t u2 = vreinterpretq_u64_u32(u0) >> shiftVar;
+        uint64x2_t u3 = vreinterpretq_u64_u32(u1) >> shiftVar;
+
+        u2 = veorq_u64(u2, vreinterpretq_u64_u32(u0));
+        u3 = veorq_u64(u3, vreinterpretq_u64_u32(u1));
+
+        if constexpr (num_elem_per_byte == 2) {
+          svst1b_u64(lastPredA, output_row, svset_neonq_u64(svundef_u64(), u2));
+          svst1b_u64(
+              lastPredB, output_row + 2, svset_neonq_u64(svundef_u64(), u3));
+
+        } else if constexpr (num_elem_per_byte == 4) {
+          auto u4 = vdup_laneq_u8(vreinterpretq_u8_u64(u2), 8);
+          auto u5 = vdup_laneq_u8(vreinterpretq_u8_u64(u3), 8);
+
+          u4 = u4 << 4;
+          u5 = u5 << 4;
+
+          u4 = veor_u8(u4, vget_low_u8(u2));
+          u5 = veor_u8(u5, vget_low_u8(u3));
+
+          vst1_lane_u8(output_row, u4, 0);
+          if (loopRemainder > 4) {
+            vst1_lane_u8(output_row + 1, u5, 0);
+          }
+        }
+      }
+    }
+  } // for each row
+}
+
 template <typename OutputType>
 void Fused8BitRowwiseQuantizedSBFloatToFloatOrHalfNeon(
     const std::uint8_t* input,
@@ -372,6 +606,24 @@ INSTANTIATE_QuantizationNeonFunctions8Bits(float16)
 // clang-format on
 #undef INSTANTIATE_QuantizationNeonFunctions8Bits
 
+#define INSTANTIATE_QuantizationNeonFunctionsNBits(type, bit_rate)  \
+  template void                                                     \
+  FloatOrHalfToFusedNBitRowwiseQuantizedSBHalfNeon<type, bit_rate>( \
+      const type* input,                                            \
+      size_t input_rows,                                            \
+      int input_columns,                                            \
+      std::uint8_t* output);
+
+// clang-format off
+INSTANTIATE_QuantizationNeonFunctionsNBits(float, 2)
+INSTANTIATE_QuantizationNeonFunctionsNBits(float, 4)
+INSTANTIATE_QuantizationNeonFunctionsNBits(float, 8)
+INSTANTIATE_QuantizationNeonFunctionsNBits(float16, 2)
+INSTANTIATE_QuantizationNeonFunctionsNBits(float16, 4)
+INSTANTIATE_QuantizationNeonFunctionsNBits(float16, 8)
+// clang-format on
+#undef INSTANTIATE_QuantizationNeonFunctionsNBits
+
 #endif // HAVE_SVE
 
 } // namespace fbgemm
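Reviewer note, not part of the patch: the packing in the kernel (shift right by 28 or 30, XOR, then predicated byte stores) is easiest to check against a scalar model of the fused row format it produces: ceil(cols / num_elem_per_byte) packed bytes, followed by the fp16 scale and fp16 bias. Below is a minimal sketch of that model for float input. QuantizeRowNBitScalar is a hypothetical helper, _Float16 stands in for fbgemm's float16 conversion routines, and the explicit clamp to [0, kMaxQ] replaces the kernel's vminq_u32-after-signed-convert trick.

#include <cmath>
#include <cstdint>
#include <cstring>

// Scalar model of the fused N-bit rowwise layout: element i occupies bits
// ((i % kElemsPerByte) * kBitRate) of byte (i / kElemsPerByte); the fp16
// scale and fp16 bias (row minimum) follow the packed bytes.
template <int kBitRate>
void QuantizeRowNBitScalar(const float* in, int cols, std::uint8_t* out) {
  constexpr int kElemsPerByte = 8 / kBitRate;
  constexpr int kMaxQ = (1 << kBitRate) - 1;

  float mn = in[0];
  float mx = in[0];
  for (int i = 1; i < cols; ++i) {
    mn = std::fmin(mn, in[i]);
    mx = std::fmax(mx, in[i]);
  }

  // Scale and bias are stored as fp16, so quantization must use the
  // fp16-rounded values, just as the kernel re-rounds minimum_element
  // and scale through float16_t before use.
  _Float16 bias_h = static_cast<_Float16>(mn);
  float range = mx - static_cast<float>(bias_h);
  _Float16 scale_h = static_cast<_Float16>(range == 0 ? 1.0f : range / kMaxQ);
  float inv_scale = 1.0f / static_cast<float>(scale_h);

  const int packed_bytes = (cols + kElemsPerByte - 1) / kElemsPerByte;
  std::memset(out, 0, packed_bytes);
  for (int i = 0; i < cols; ++i) {
    // Round to nearest, ties to even, matching vcvtnq_s32_f32.
    float q = std::nearbyint((in[i] - static_cast<float>(bias_h)) * inv_scale);
    auto v = static_cast<std::uint8_t>(
        std::fmin(std::fmax(q, 0.0f), static_cast<float>(kMaxQ)));
    out[i / kElemsPerByte] |=
        static_cast<std::uint8_t>(v << ((i % kElemsPerByte) * kBitRate));
  }
  std::memcpy(out + packed_bytes, &scale_h, sizeof(scale_h));
  std::memcpy(out + packed_bytes + sizeof(scale_h), &bias_h, sizeof(bias_h));
}

Working through the shift/XOR trick for BIT_RATE == 2 shows the same ordering: each 64-bit lane holds two 32-bit quantized values, the shift by 30 brings the odd element down next to the even one, and the XOR merges them, so element 0 lands in the least significant bits of byte 0, consistent with the scalar model above.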