From b557c974f250daa6797f38f06f18a3e943b2f964 Mon Sep 17 00:00:00 2001
From: Ben Niu
Date: Tue, 18 Nov 2025 14:13:39 -0800
Subject: [PATCH] Vectorize requantize_ for Arm64 with NEON intrinsics (#5130)

Summary:
Pull Request resolved: https://github.com/pytorch/FBGEMM/pull/5130

X-link: https://github.com/facebookresearch/FBGEMM/pull/2132

This change adds a vectorized requantize_ for Arm64 with NEON intrinsics:

1. The newly added NEON intrinsics follow what the existing AVX2 code does.
2. The scalar loop was moved to a new function, requantize_i8dw_ref_, to make
   the code more readable and testable.
3. Added new tests to make sure requantize_ and requantize_i8dw_ref_ produce
   identical results.

Reviewed By: Nicoshev

Differential Revision: D86216347
---
 src/FbgemmI8Depthwise2DAvx2-inl.h |   2 +-
 src/FbgemmI8Depthwise3DAvx2.cc    |   2 +-
 src/FbgemmI8DepthwiseAvx2-inl.h   |  80 ++----
 src/FbgemmI8DepthwiseNeon-inl.h   | 418 ++++++++++++++++++++++++++++++
 src/FbgemmI8DepthwiseUtils.h      |  96 +++++++
 test/I8DepthwiseTest.cc           | 219 ++++++++++++++++
 6 files changed, 763 insertions(+), 54 deletions(-)
 create mode 100644 src/FbgemmI8DepthwiseNeon-inl.h
 create mode 100644 src/FbgemmI8DepthwiseUtils.h

diff --git a/src/FbgemmI8Depthwise2DAvx2-inl.h b/src/FbgemmI8Depthwise2DAvx2-inl.h
index 10a1c83b1e..a17c3bc456 100644
--- a/src/FbgemmI8Depthwise2DAvx2-inl.h
+++ b/src/FbgemmI8Depthwise2DAvx2-inl.h
@@ -8,7 +8,7 @@
 
 #pragma once
 
-#include "./FbgemmI8DepthwiseAvx2-inl.h" // @manual
+#include "./FbgemmI8DepthwiseUtils.h" // @manual
 #include "./GenerateI8Depthwise.h" // @manual
 #include "./MaskAvx2.h" // @manual
 #include "fbgemm/Utils.h"
diff --git a/src/FbgemmI8Depthwise3DAvx2.cc b/src/FbgemmI8Depthwise3DAvx2.cc
index 4200f3e991..4a3023f4dd 100644
--- a/src/FbgemmI8Depthwise3DAvx2.cc
+++ b/src/FbgemmI8Depthwise3DAvx2.cc
@@ -12,7 +12,7 @@
 #include <stdexcept> // for logic_error
 #include <string>
 
-#include "./FbgemmI8DepthwiseAvx2-inl.h" // @manual
+#include "./FbgemmI8DepthwiseUtils.h" // @manual
 #include "./GenerateI8Depthwise.h" // @manual
 #include "./MaskAvx2.h" // @manual
 #include "fbgemm/Utils.h"
diff --git a/src/FbgemmI8DepthwiseAvx2-inl.h b/src/FbgemmI8DepthwiseAvx2-inl.h
index 4701f07c30..4eeeedb889 100644
--- a/src/FbgemmI8DepthwiseAvx2-inl.h
+++ b/src/FbgemmI8DepthwiseAvx2-inl.h
@@ -8,25 +8,21 @@
 
 #pragma once
 
-#include <algorithm> // for min and max
+#if defined(__x86_64__) || defined(__i386__) || \
+    (defined(_MSC_VER) && (defined(_M_X64) || defined(_M_IX86)))
+
 #include <cassert>
 #include <cmath> // for lrintf and sqrt
 #include <cstdint>
 #include <type_traits> // for is_same
 
-#if defined(__x86_64__) || defined(__i386__) || \
-    (defined(_MSC_VER) && (defined(_M_X64) || defined(_M_IX86)))
 #include <immintrin.h>
-#include
-#endif
 
 #include "fbgemm/FbgemmBuild.h"
 #include "fbgemm/UtilsAvx2.h"
 
 namespace fbgemm {
 
-// Almost same as ReQuantizeOutput in OutputProcessing-inh.h but different
-// row_offsets for each row because of depth-wise convolution
 template <
     bool FUSE_RELU,
     bool HAS_BIAS,
@@ -47,6 +43,8 @@ static ALWAYS_INLINE void requantize_(
     const std::int32_t* col_offsets,
     const BIAS_TYPE* bias [[maybe_unused]],
     const float* act_times_w_scale = nullptr) {
+  int j = 0;
+#ifdef __AVX2__
   __m256 multiplier_v = _mm256_setzero_ps();
   // Broadcasted reciprocal of act_times_w_scale
   __m256 act_times_w_rcp_v [[maybe_unused]] = _mm256_setzero_ps();
@@ -73,7 +71,6 @@ static ALWAYS_INLINE void requantize_(
       _mm256_set_epi32(0x07, 0x03, 0x06, 0x02, 0x05, 0x01, 0x04, 0x00);
 
   constexpr int VLEN = 8;
-  int j = 0;
   for (; j < n / (VLEN * 4) * (VLEN * 4); j += (VLEN * 4)) {
     __m256i x_v =
         _mm256_loadu_si256(reinterpret_cast<const __m256i*>(C_int32 + j));
@@ -502,51 +499,30 @@ static ALWAYS_INLINE void requantize_(
         reinterpret_cast<__m128i*>(C_uint8 + j),
         _mm256_castsi256_si128(x_clamped_v));
   } // j loop vectorized
+#endif
 
-  for (; j < n; ++j) {
-    std::int32_t raw = C_int32[j];
-    int quant_param_idx = 0;
-    if constexpr (
-        Q_GRAN == QuantizationGranularity::OUT_CHANNEL ||
-        (Q_GRAN == QuantizationGranularity::GROUP && K_PER_G == 1)) {
-      quant_param_idx = j;
-    } else if constexpr (Q_GRAN == QuantizationGranularity::GROUP) {
-      quant_param_idx = j / 2;
-    }
-    if constexpr (!B_SYMMETRIC) {
-      raw -= B_zero_point[quant_param_idx] * row_offsets[j / K_PER_G];
-    }
-    if constexpr (!A_SYMMETRIC) {
-      raw -= A_zero_point * col_offsets[j];
-    }
-    float raw_f = NAN;
-    if constexpr (HAS_BIAS) { // static if
-      if constexpr (std::is_same_v<BIAS_TYPE, float>) {
-        raw_f = raw;
-        raw_f += bias[j] / act_times_w_scale[quant_param_idx];
-      } else {
-        raw += bias[j];
-        raw_f = raw;
-      }
-    } else {
-      raw_f = raw;
-    }
-
-    float ab = raw_f * C_multiplier[quant_param_idx];
-    long rounded = lrintf(ab) + C_zero_point;
-
-    C_uint8[j] = std::max(
-        FUSE_RELU ? static_cast<long>(C_zero_point) : 0l,
-        std::min(255l, rounded));
-  }
-}
-
-static inline std::pair<int, int> closest_factors_(int n) {
-  int a = static_cast<int>(std::sqrt(n));
-  while (n % a != 0) {
-    a--;
-  }
-  return {a, n / a}; // a <= n / a
+  requantize_i8dw_ref_<
+      FUSE_RELU,
+      HAS_BIAS,
+      Q_GRAN,
+      A_SYMMETRIC,
+      B_SYMMETRIC,
+      K_PER_G,
+      BIAS_TYPE>(
+      A_zero_point,
+      B_zero_point,
+      C_multiplier,
+      C_zero_point,
+      C_int32,
+      C_uint8,
+      n,
+      j,
+      row_offsets,
+      col_offsets,
+      bias,
+      act_times_w_scale);
 }
 
 } // namespace fbgemm
+
+#endif
diff --git a/src/FbgemmI8DepthwiseNeon-inl.h b/src/FbgemmI8DepthwiseNeon-inl.h
new file mode 100644
index 0000000000..2366bbf66b
--- /dev/null
+++ b/src/FbgemmI8DepthwiseNeon-inl.h
@@ -0,0 +1,418 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#pragma once
+
+#if defined(__aarch64__) || (defined(_MSC_VER) && defined(_M_ARM64))
+
+#include <cassert>
+#include <cfenv>
+#include <cmath> // for lrintf and sqrt
+#include <cstdint>
+#include <type_traits> // for is_same
+
+#include <arm_neon.h>
+
+#include "fbgemm/FbgemmBuild.h"
+#include "fbgemm/UtilsAvx2.h"
+
+namespace fbgemm {
+
+template <
+    bool FUSE_RELU,
+    bool HAS_BIAS,
+    QuantizationGranularity Q_GRAN,
+    bool A_SYMMETRIC,
+    bool B_SYMMETRIC,
+    int K_PER_G,
+    typename BIAS_TYPE>
+static ALWAYS_INLINE void requantize_(
+    std::int32_t A_zero_point,
+    const std::int32_t* B_zero_point,
+    const float* C_multiplier,
+    std::int32_t C_zero_point,
+    const std::int32_t* C_int32,
+    std::uint8_t* C_uint8,
+    int n,
+    const std::int32_t* row_offsets,
+    const std::int32_t* col_offsets,
+    const BIAS_TYPE* bias [[maybe_unused]],
+    const float* act_times_w_scale = nullptr) {
+  float32x4_t multiplier_v = vdupq_n_f32(0.0f);
+  // Broadcasted reciprocal of act_times_w_scale
+  float32x4_t act_times_w_rcp_v [[maybe_unused]] = vdupq_n_f32(0.0f);
+  int32x4_t B_zero_point_v = vdupq_n_s32(0);
+  if constexpr (Q_GRAN == QuantizationGranularity::TENSOR) {
+    multiplier_v = vdupq_n_f32(*C_multiplier);
+    if constexpr (std::is_same_v<BIAS_TYPE, float>) {
+      act_times_w_rcp_v = vdupq_n_f32(1.0 / (*act_times_w_scale));
+    }
+    B_zero_point_v = vdupq_n_s32(B_zero_point[0]);
+  }
+
+  uint8x16_t min_v = vdupq_n_u8(0);
+
+  if constexpr (A_SYMMETRIC) {
+    assert(A_zero_point == 0 || col_offsets == nullptr);
+  }
+  int32x4_t A_zero_point_v = vdupq_n_s32(A_zero_point);
+  int16x8_t C_zero_point_epi16_v = vdupq_n_s16(C_zero_point);
+  uint8x16_t C_zero_point_epi8_v = vdupq_n_u8(C_zero_point);
+
+  constexpr int VLEN = 4;
+  int j = 0;
+  for (; j < n / (VLEN * 4) * (VLEN * 4); j += (VLEN * 4)) {
+    int32x4_t x_v = vld1q_s32(C_int32 + j);
+    int32x4_t y_v = vld1q_s32(C_int32 + j + VLEN);
+    int32x4_t z_v = vld1q_s32(C_int32 + j + 2 * VLEN);
+    int32x4_t w_v = vld1q_s32(C_int32 + j + 3 * VLEN);
+
+    int32x4_t row_offset_v;
+    if constexpr (!B_SYMMETRIC) {
+      if constexpr (K_PER_G == 1) {
+        row_offset_v = vld1q_s32(row_offsets + j);
+      } else {
+        static_assert(K_PER_G == 2);
+        // Load row_offsets for 2 groups and broadcast by 2 times.
+        row_offset_v =
+            vcombine_s32(vld1_s32(row_offsets + j / 2), vdup_n_s32(0));
+        row_offset_v = vzip1q_s32(row_offset_v, row_offset_v);
+      }
+      if constexpr (
+          Q_GRAN == QuantizationGranularity::OUT_CHANNEL ||
+          (Q_GRAN == QuantizationGranularity::GROUP && K_PER_G == 1)) {
+        B_zero_point_v = vld1q_s32(B_zero_point + j);
+      } else if constexpr (Q_GRAN == QuantizationGranularity::GROUP) {
+        static_assert(K_PER_G == 2);
+        B_zero_point_v =
+            vcombine_s32(vld1_s32(B_zero_point + j / 2), vdup_n_s32(0));
+        B_zero_point_v = vzip1q_s32(B_zero_point_v, B_zero_point_v);
+      }
+      x_v = vmlsq_s32(x_v, row_offset_v, B_zero_point_v);
+    }
+    int32x4_t col_off_v;
+    if constexpr (!A_SYMMETRIC) {
+      x_v = vmlsq_s32(x_v, A_zero_point_v, vld1q_s32(col_offsets + j));
+    }
+
+    if constexpr (!B_SYMMETRIC) {
+      if constexpr (K_PER_G == 1) {
+        row_offset_v = vld1q_s32(row_offsets + j + VLEN);
+      } else {
+        row_offset_v =
+            vcombine_s32(vld1_s32(row_offsets + (j + VLEN) / 2), vdup_n_s32(0));
+        row_offset_v = vzip1q_s32(row_offset_v, row_offset_v);
+      }
+      if constexpr (
+          Q_GRAN == QuantizationGranularity::OUT_CHANNEL ||
+          (Q_GRAN == QuantizationGranularity::GROUP && K_PER_G == 1)) {
+        B_zero_point_v = vld1q_s32(B_zero_point + j + VLEN);
+      } else if constexpr (Q_GRAN == QuantizationGranularity::GROUP) {
+        B_zero_point_v = vcombine_s32(
+            vld1_s32(B_zero_point + (j + VLEN) / 2), vdup_n_s32(0));
+        B_zero_point_v = vzip1q_s32(B_zero_point_v, B_zero_point_v);
+      }
+      y_v = vmlsq_s32(y_v, row_offset_v, B_zero_point_v);
+    }
+    if constexpr (!A_SYMMETRIC) {
+      y_v = vmlsq_s32(y_v, A_zero_point_v, vld1q_s32(col_offsets + j + VLEN));
+    }
+
+    if constexpr (!B_SYMMETRIC) {
+      if constexpr (K_PER_G == 1) {
+        row_offset_v = vld1q_s32(row_offsets + j + 2 * VLEN);
+      } else {
+        row_offset_v = vcombine_s32(
+            vld1_s32(row_offsets + (j + 2 * VLEN) / 2), vdup_n_s32(0));
+        row_offset_v = vzip1q_s32(row_offset_v, row_offset_v);
+      }
+      if constexpr (
+          Q_GRAN == QuantizationGranularity::OUT_CHANNEL ||
+          (Q_GRAN == QuantizationGranularity::GROUP && K_PER_G == 1)) {
+        B_zero_point_v = vld1q_s32(B_zero_point + j + 2 * VLEN);
+      } else if constexpr (Q_GRAN == QuantizationGranularity::GROUP) {
+        B_zero_point_v = vcombine_s32(
+            vld1_s32(B_zero_point + (j + 2 * VLEN) / 2), vdup_n_s32(0));
+        B_zero_point_v = vzip1q_s32(B_zero_point_v, B_zero_point_v);
+      }
+      z_v = vmlsq_s32(z_v, row_offset_v, B_zero_point_v);
+    }
+    if constexpr (!A_SYMMETRIC) {
+      z_v =
+          vmlsq_s32(z_v, A_zero_point_v, vld1q_s32(col_offsets + j + 2 * VLEN));
+    }
+
+    if constexpr (!B_SYMMETRIC) {
+      if constexpr (K_PER_G == 1) {
+        row_offset_v = vld1q_s32(row_offsets + j + 3 * VLEN);
+      } else {
+        row_offset_v = vcombine_s32(
+            vld1_s32(row_offsets + (j + 3 * VLEN) / 2), vdup_n_s32(0));
+        row_offset_v = vzip1q_s32(row_offset_v, row_offset_v);
+      }
+      if constexpr (
+          Q_GRAN == QuantizationGranularity::OUT_CHANNEL ||
+          (Q_GRAN == QuantizationGranularity::GROUP && K_PER_G == 1)) {
+        B_zero_point_v = vld1q_s32(B_zero_point + j + 3 * VLEN);
+      } else if constexpr (Q_GRAN == QuantizationGranularity::GROUP) {
+        B_zero_point_v = vcombine_s32(
+            vld1_s32(B_zero_point + (j + 3 * VLEN) / 2), vdup_n_s32(0));
+        B_zero_point_v = vzip1q_s32(B_zero_point_v, B_zero_point_v);
+      }
+      w_v = vmlsq_s32(w_v, row_offset_v, B_zero_point_v);
+    }
+    if constexpr (!A_SYMMETRIC) {
+      w_v =
+          vmlsq_s32(w_v, A_zero_point_v, vld1q_s32(col_offsets + j + 3 * VLEN));
+    }
+
+    // convert to float
+    float32x4_t xf_v, yf_v, zf_v, wf_v;
+    if constexpr (HAS_BIAS) { // static if
+      if constexpr (std::is_same_v<BIAS_TYPE, float>) {
+        float32x4_t x_bias_v, y_bias_v, z_bias_v, w_bias_v;
+        if constexpr (
+            Q_GRAN == QuantizationGranularity::OUT_CHANNEL ||
+            (Q_GRAN == QuantizationGranularity::GROUP && K_PER_G == 1)) {
+          x_bias_v = vdivq_f32(
+              vld1q_f32(bias + j + 0 * VLEN),
+              vld1q_f32(act_times_w_scale + j + 0 * VLEN));
+          y_bias_v = vdivq_f32(
+              vld1q_f32(bias + j + 1 * VLEN),
+              vld1q_f32(act_times_w_scale + j + 1 * VLEN));
+          z_bias_v = vdivq_f32(
+              vld1q_f32(bias + j + 2 * VLEN),
+              vld1q_f32(act_times_w_scale + j + 2 * VLEN));
+          w_bias_v = vdivq_f32(
+              vld1q_f32(bias + j + 3 * VLEN),
+              vld1q_f32(act_times_w_scale + j + 3 * VLEN));
+        } else if constexpr (Q_GRAN == QuantizationGranularity::GROUP) {
+          static_assert(K_PER_G == 2);
+          auto tmp = vcombine_f32(
+              vld1_f32(act_times_w_scale + (j + 0 * VLEN) / 2),
+              vdup_n_f32(0.0f));
+
+          x_bias_v =
+              vdivq_f32(vld1q_f32(bias + j + 0 * VLEN), vzip1q_f32(tmp, tmp));
+
+          tmp = vcombine_f32(
+              vld1_f32(act_times_w_scale + (j + 1 * VLEN) / 2),
+              vdup_n_f32(0.0f));
+          y_bias_v =
+              vdivq_f32(vld1q_f32(bias + j + 1 * VLEN), vzip1q_f32(tmp, tmp));
+
+          tmp = vcombine_f32(
+              vld1_f32(act_times_w_scale + (j + 2 * VLEN) / 2),
+              vdup_n_f32(0.0f));
+          z_bias_v =
+              vdivq_f32(vld1q_f32(bias + j + 2 * VLEN), vzip1q_f32(tmp, tmp));
+
+          tmp = vcombine_f32(
+              vld1_f32(act_times_w_scale + (j + 3 * VLEN) / 2),
+              vdup_n_f32(0.0f));
+          w_bias_v =
+              vdivq_f32(vld1q_f32(bias + j + 3 * VLEN), vzip1q_f32(tmp, tmp));
+
+        } else {
+          x_bias_v =
+              vmulq_f32(vld1q_f32(bias + j + 0 * VLEN), act_times_w_rcp_v);
+          y_bias_v =
+              vmulq_f32(vld1q_f32(bias + j + 1 * VLEN), act_times_w_rcp_v);
+          z_bias_v =
+              vmulq_f32(vld1q_f32(bias + j + 2 * VLEN), act_times_w_rcp_v);
+          w_bias_v =
+              vmulq_f32(vld1q_f32(bias + j + 3 * VLEN), act_times_w_rcp_v);
+        }
+        xf_v = vaddq_f32(vcvtq_f32_s32(x_v), x_bias_v);
+        yf_v = vaddq_f32(vcvtq_f32_s32(y_v), y_bias_v);
+        zf_v = vaddq_f32(vcvtq_f32_s32(z_v), z_bias_v);
+        wf_v = vaddq_f32(vcvtq_f32_s32(w_v), w_bias_v);
+      } else {
+        x_v = vaddq_s32(x_v, vld1q_s32(bias + j + 0 * VLEN));
+        y_v = vaddq_s32(y_v, vld1q_s32(bias + j + 1 * VLEN));
+        z_v = vaddq_s32(z_v, vld1q_s32(bias + j + 2 * VLEN));
+        w_v = vaddq_s32(w_v, vld1q_s32(bias + j + 3 * VLEN));
+        xf_v = vcvtq_f32_s32(x_v);
+        yf_v = vcvtq_f32_s32(y_v);
+        zf_v = vcvtq_f32_s32(z_v);
+        wf_v = vcvtq_f32_s32(w_v);
+      }
+    } else {
+      xf_v = vcvtq_f32_s32(x_v);
+      yf_v = vcvtq_f32_s32(y_v);
+      zf_v = vcvtq_f32_s32(z_v);
+      wf_v = vcvtq_f32_s32(w_v);
+    }
+
+    if constexpr (
+        Q_GRAN == QuantizationGranularity::OUT_CHANNEL ||
+        (Q_GRAN == QuantizationGranularity::GROUP && K_PER_G == 1)) {
+      multiplier_v = vld1q_f32(C_multiplier + j + 0 * VLEN);
+    } else if constexpr (Q_GRAN == QuantizationGranularity::GROUP) {
+      multiplier_v =
+          vcombine_f32(vld1_f32(C_multiplier + j / 2), vdup_n_f32(0.0f));
+      multiplier_v = vzip1q_f32(multiplier_v, multiplier_v);
+    }
+    float32x4_t x_scaled_v = vmulq_f32(xf_v, multiplier_v);
+    if constexpr (
+        Q_GRAN == QuantizationGranularity::OUT_CHANNEL ||
+        (Q_GRAN == QuantizationGranularity::GROUP && K_PER_G == 1)) {
+      multiplier_v = vld1q_f32(C_multiplier + j + 1 * VLEN);
+    } else if constexpr (Q_GRAN == QuantizationGranularity::GROUP) {
+      multiplier_v = vcombine_f32(
+          vld1_f32(C_multiplier + (j + VLEN) / 2), vdup_n_f32(0.0f));
+      multiplier_v = vzip1q_f32(multiplier_v, multiplier_v);
+    }
+    float32x4_t y_scaled_v = vmulq_f32(yf_v, multiplier_v);
+    if constexpr (
+        Q_GRAN == QuantizationGranularity::OUT_CHANNEL ||
+        (Q_GRAN == QuantizationGranularity::GROUP && K_PER_G == 1)) {
+      multiplier_v = vld1q_f32(C_multiplier + j + 2 * VLEN);
+    } else if constexpr (Q_GRAN == QuantizationGranularity::GROUP) {
+      multiplier_v = vcombine_f32(
+          vld1_f32(C_multiplier + (j + 2 * VLEN) / 2), vdup_n_f32(0.0f));
+      multiplier_v = vzip1q_f32(multiplier_v, multiplier_v);
+    }
+    float32x4_t z_scaled_v = vmulq_f32(zf_v, multiplier_v);
+    if constexpr (
+        Q_GRAN == QuantizationGranularity::OUT_CHANNEL ||
+        (Q_GRAN == QuantizationGranularity::GROUP && K_PER_G == 1)) {
+      multiplier_v = vld1q_f32(C_multiplier + j + 3 * VLEN);
+    } else if constexpr (Q_GRAN == QuantizationGranularity::GROUP) {
+      multiplier_v = vcombine_f32(
+          vld1_f32(C_multiplier + (j + 3 * VLEN) / 2), vdup_n_f32(0.0f));
+      multiplier_v = vzip1q_f32(multiplier_v, multiplier_v);
+    }
+    float32x4_t w_scaled_v = vmulq_f32(wf_v, multiplier_v);
+
+    // vcvtnq_s32_f32 always rounds to nearest, which is slightly different
+    // from x86's _mm256_cvtps_epi32 which rounds according to the current
+    // rounding mode, which may not be round to nearest. To help catch issues
+    // and debug, we add an assertion here.
+    assert(fegetround() == FE_TONEAREST);
+    int32x4_t x_rounded_v = vcvtnq_s32_f32(x_scaled_v);
+    int32x4_t y_rounded_v = vcvtnq_s32_f32(y_scaled_v);
+    int32x4_t z_rounded_v = vcvtnq_s32_f32(z_scaled_v);
+    int32x4_t w_rounded_v = vcvtnq_s32_f32(w_scaled_v);
+
+    int16x8_t xy_packed_v = vqaddq_s16(
+        vcombine_s16(vqmovn_s32(x_rounded_v), vqmovn_s32(y_rounded_v)),
+        C_zero_point_epi16_v);
+    int16x8_t zw_packed_v = vqaddq_s16(
+        vcombine_s16(vqmovn_s32(z_rounded_v), vqmovn_s32(w_rounded_v)),
+        C_zero_point_epi16_v);
+    uint8x16_t xyzw_packed_v =
+        vcombine_u8(vqmovun_s16(xy_packed_v), vqmovun_s16(zw_packed_v));
+    uint8x16_t xyzw_clamped_v =
+        vmaxq_u8(FUSE_RELU ? C_zero_point_epi8_v : min_v, xyzw_packed_v);
+
+    vst1q_u8(C_uint8 + j, xyzw_clamped_v);
+  } // j loop vectorized and unrolled 4x
+
+vec_tail:
+  for (; j < n / VLEN * VLEN; j += VLEN) {
+    int32x4_t x_v = vld1q_s32(C_int32 + j);
+
+    if constexpr (!B_SYMMETRIC) {
+      int32x4_t row_offset_v;
+      if constexpr (K_PER_G == 1) {
+        row_offset_v = vld1q_s32(row_offsets + j);
+      } else {
+        static_assert(K_PER_G == 2);
+        // Load row_offsets for 2 groups and broadcast by 2 times.
+        row_offset_v =
+            vcombine_s32(vld1_s32(row_offsets + j / 2), vdup_n_s32(0));
+        row_offset_v = vzip1q_s32(row_offset_v, row_offset_v);
+      }
+      if constexpr (
+          Q_GRAN == QuantizationGranularity::OUT_CHANNEL ||
+          (Q_GRAN == QuantizationGranularity::GROUP && K_PER_G == 1)) {
+        B_zero_point_v = vld1q_s32(B_zero_point + j);
+      } else if constexpr (Q_GRAN == QuantizationGranularity::GROUP) {
+        static_assert(K_PER_G == 2);
+        B_zero_point_v =
+            vcombine_s32(vld1_s32(B_zero_point + j / 2), vdup_n_s32(0));
+        B_zero_point_v = vzip1q_s32(B_zero_point_v, B_zero_point_v);
+      }
+      x_v = vmlsq_s32(x_v, row_offset_v, B_zero_point_v);
+    }
+    if constexpr (!A_SYMMETRIC) {
+      x_v = vmlsq_s32(x_v, A_zero_point_v, vld1q_s32(col_offsets + j));
+    }
+
+    // Convert to float
+    float32x4_t xf_v;
+    if constexpr (HAS_BIAS) { // static if
+      if constexpr (std::is_same_v<BIAS_TYPE, float>) {
+        float32x4_t x_bias_v;
+        if constexpr (
+            Q_GRAN == QuantizationGranularity::OUT_CHANNEL ||
+            (Q_GRAN == QuantizationGranularity::GROUP && K_PER_G == 1)) {
+          x_bias_v =
+              vdivq_f32(vld1q_f32(bias + j), vld1q_f32(act_times_w_scale + j));
+        } else if constexpr (Q_GRAN == QuantizationGranularity::GROUP) {
+          auto tmp = vcombine_f32(
+              vld1_f32(act_times_w_scale + j / 2), vdup_n_f32(0.0f));
+
+          x_bias_v = vdivq_f32(vld1q_f32(bias + j), vzip1q_f32(tmp, tmp));
+        } else {
+          x_bias_v = vmulq_f32(vld1q_f32(bias + j), act_times_w_rcp_v);
+        }
+        xf_v = vaddq_f32(vcvtq_f32_s32(x_v), x_bias_v);
+      } else {
+        x_v = vaddq_s32(
+            x_v, vld1q_s32(reinterpret_cast<const std::int32_t*>(bias + j)));
+        xf_v = vcvtq_f32_s32(x_v);
+      }
+    } else {
+      xf_v = vcvtq_f32_s32(x_v);
+    }
+
+    if constexpr (
+        Q_GRAN == QuantizationGranularity::OUT_CHANNEL ||
+        (Q_GRAN == QuantizationGranularity::GROUP && K_PER_G == 1)) {
+      multiplier_v = vld1q_f32(C_multiplier + j);
+    } else if constexpr (Q_GRAN == QuantizationGranularity::GROUP) {
+      multiplier_v =
+          vcombine_f32(vld1_f32(C_multiplier + j / 2), vdup_n_f32(0.0f));
+      multiplier_v = vzip1q_f32(multiplier_v, multiplier_v);
+    }
+    float32x4_t x_scaled_v = vmulq_f32(xf_v, multiplier_v);
+
+    // vcvtnq_s32_f32 always rounds to nearest, which is slightly different
+    // from x86's _mm256_cvtps_epi32 which rounds according to the current
+    // rounding mode, which may not be round to nearest. To help catch issues
+    // and debug, we add an assertion here.
+    assert(fegetround() == FE_TONEAREST);
+    int32x4_t x_rounded_v = vcvtnq_s32_f32(x_scaled_v);
+
+    int16x8_t x_packed_v_s16 = vqaddq_s16(
+        vcombine_s16(vqmovn_s32(x_rounded_v), vdup_n_s16(0)),
+        C_zero_point_epi16_v);
+    uint8x8_t x_packed_v_u8 = vqmovun_s16(x_packed_v_s16);
+    uint8x8_t x_clamped_v = vmax_u8(
+        FUSE_RELU ? vget_low_u8(C_zero_point_epi8_v) : vget_low_u8(min_v),
+        x_packed_v_u8);
+
+    vst1_lane_u32(
+        reinterpret_cast<std::uint32_t*>(C_uint8 + j),
+        vreinterpret_u32_u8(x_clamped_v),
+        0);
+  } // j loop vectorized
+
+  // There are some leftovers that cannot fit in one full vector. Instead of
+  // doing a scalar loop, we prepare j to be n - VLEN and jump back to the
+  // above loop for one extra iteration. Compared to a scalar loop, this reuses
+  // vector loop code so code size bloat is minimal. Another alternative is
+  // to use a partial vector register, but that also bloats code size more than
+  // reusing the above loop body.
+  if (j < n) {
+    j = n - VLEN;
+    goto vec_tail;
+  }
+}
+
+} // namespace fbgemm
+
+#endif
diff --git a/src/FbgemmI8DepthwiseUtils.h b/src/FbgemmI8DepthwiseUtils.h
new file mode 100644
index 0000000000..83d2d61a94
--- /dev/null
+++ b/src/FbgemmI8DepthwiseUtils.h
@@ -0,0 +1,96 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#pragma once
+
+#include <algorithm> // for min and max
+#include <cassert>
+#include <cfenv>
+#include <cmath> // for lrintf and sqrt
+#include <cstdint>
+#include <type_traits> // for is_same
+
+#include "fbgemm/FbgemmBuild.h"
+#include "fbgemm/UtilsAvx2.h"
+
+namespace fbgemm {
+
+// Almost same as ReQuantizeOutput in OutputProcessing-inh.h but different
+// row_offsets for each row because of depth-wise convolution
+
+template <
+    bool FUSE_RELU,
+    bool HAS_BIAS,
+    QuantizationGranularity Q_GRAN,
+    bool A_SYMMETRIC,
+    bool B_SYMMETRIC,
+    int K_PER_G,
+    typename BIAS_TYPE>
+static ALWAYS_INLINE void requantize_i8dw_ref_(
+    std::int32_t A_zero_point,
+    const std::int32_t* B_zero_point,
+    const float* C_multiplier,
+    std::int32_t C_zero_point,
+    const std::int32_t* C_int32,
+    std::uint8_t* C_uint8,
+    int n,
+    int j, // starting index
+    const std::int32_t* row_offsets,
+    const std::int32_t* col_offsets,
+    const BIAS_TYPE* bias [[maybe_unused]],
+    const float* act_times_w_scale = nullptr) {
+  for (; j < n; ++j) {
+    std::int32_t raw = C_int32[j];
+    int quant_param_idx = 0;
+    if constexpr (
+        Q_GRAN == QuantizationGranularity::OUT_CHANNEL ||
+        (Q_GRAN == QuantizationGranularity::GROUP && K_PER_G == 1)) {
+      quant_param_idx = j;
+    } else if constexpr (Q_GRAN == QuantizationGranularity::GROUP) {
+      quant_param_idx = j / 2;
+    }
+    if constexpr (!B_SYMMETRIC) {
+      raw -= B_zero_point[quant_param_idx] * row_offsets[j / K_PER_G];
+    }
+    if constexpr (!A_SYMMETRIC) {
+      raw -= A_zero_point * col_offsets[j];
+    }
+    float raw_f = NAN;
+    if constexpr (HAS_BIAS) { // static if
+      if constexpr (std::is_same_v<BIAS_TYPE, float>) {
+        raw_f = raw;
+        raw_f += bias[j] / act_times_w_scale[quant_param_idx];
+      } else {
+        raw += bias[j];
+        raw_f = raw;
+      }
+    } else {
+      raw_f = raw;
+    }
+
+    float ab = raw_f * C_multiplier[quant_param_idx];
+    long rounded = lrintf(ab) + C_zero_point;
+
+    C_uint8[j] = std::max(
+        FUSE_RELU ? static_cast<long>(C_zero_point) : 0l,
+        std::min(255l, rounded));
+  }
+}
+
+static inline std::pair<int, int> closest_factors_(int n) {
+  int a = static_cast<int>(std::sqrt(n));
+  while (n % a != 0) {
+    a--;
+  }
+  return {a, n / a}; // a <= n / a
+}
+
+} // namespace fbgemm
+
+#include "FbgemmI8DepthwiseAvx2-inl.h"
+#include "FbgemmI8DepthwiseNeon-inl.h"
diff --git a/test/I8DepthwiseTest.cc b/test/I8DepthwiseTest.cc
index 6d4a682b77..96f9646604 100644
--- a/test/I8DepthwiseTest.cc
+++ b/test/I8DepthwiseTest.cc
@@ -11,9 +11,11 @@
 
 #include <gtest/gtest.h>
 
+#include "./TestUtils.h"
 #include "bench/AlignedVec.h" // @manual
 #include "bench/BenchUtils.h" // @manual
 #include "fbgemm/FbgemmI8DepthwiseAvx2.h"
+#include "src/FbgemmI8DepthwiseUtils.h"
 #include "src/RefImplementations.h" // @manual
 
 using namespace std;
@@ -105,6 +107,12 @@ class FBGemmDepthWisePerChannelQuantizationTest
 class FBGemmDepthWisePackUnpackTest
     : public testing::TestWithParam<tuple<int, int>> {};
 
+// tuple represents QuantizationGranularity, A symmetric, B symmetric,
+// test_bias, test_float_bias
+class FbgemmI8DepthwiseRequantizationTest
+    : public testing::TestWithParam<
+          tuple<QuantizationGranularity, bool, bool, bool, bool>> {};
+
 } // namespace
 
 INSTANTIATE_TEST_SUITE_P(
@@ -702,4 +710,215 @@ TEST_P(FBGemmDepthWisePackUnpackTest, TestPackUnpack) {
       << "Original and unpacked data elements are not the same";
 } // TestPackUnpack
 
+static void runRequantizeI8DepthWiseTest() {
+  const int n = 70; // OC == 70
+  aligned_vector<int32_t> B_zero_point(n, 0);
+  aligned_vector<float> C_multiplier(n, 0.0f);
+  aligned_vector<int32_t> C_int32(n, 0);
+  aligned_vector<int32_t> row_offsets(n, 0);
+  aligned_vector<int32_t> col_offsets(n, 0);
+  aligned_vector<float> act_times_w_scale(n, 0.0f);
+  aligned_vector<int32_t> bias(n, 0);
+  aligned_vector<float> fbias(n, 0.0f);
+  aligned_vector<uint8_t> C_int8_scalar(n, 0);
+  aligned_vector<uint8_t> C_int8_vector(n, 0);
+  int32_t A_zero_point = 0;
+  int32_t C_zero_point = 0;
+  aligned_vector<int32_t> Zero_point(2, 0);
+
+  randFill(Zero_point, 0, 10);
+  A_zero_point = Zero_point[0];
+  C_zero_point = Zero_point[1];
+  randFill(B_zero_point, -3, 3);
+  randFill(C_multiplier, 0.1234f / 2, 0.1234f * 3 / 2);
+  randFill(C_int32, -5, 5);
+  randFill(row_offsets, -100, 100);
+  randFill(col_offsets, -100, 100);
+  randFill(bias, -8, 8);
+
+  randFill(act_times_w_scale, 0.1234f / 2, 0.1234f * 3 / 2);
+
+  requantize_i8dw_ref_<
+      true,
+      true,
+      QuantizationGranularity::GROUP,
+      false,
+      false,
+      1>(
+      A_zero_point,
+      B_zero_point.data(),
+      C_multiplier.data(),
+      C_zero_point,
+      C_int32.data(),
+      C_int8_scalar.data(),
+      n,
+      0,
+      row_offsets.data(),
+      col_offsets.data(),
+      bias.data(),
+      act_times_w_scale.data());
+
+  requantize_<true, true, QuantizationGranularity::GROUP, false, false, 1>(
+      A_zero_point,
+      B_zero_point.data(),
+      C_multiplier.data(),
+      C_zero_point,
+      C_int32.data(),
+      C_int8_vector.data(),
+      n,
+      row_offsets.data(),
+      col_offsets.data(),
+      bias.data(),
+      act_times_w_scale.data());
+
+  compare_validate_buffers(
+      C_int8_scalar.data(),
+      C_int8_vector.data(),
+      n,
+      1,
+      1,
+      static_cast<uint8_t>(0));
+
+  requantize_i8dw_ref_<
+      true,
+      true,
+      QuantizationGranularity::GROUP,
+      false,
+      false,
+      2>(
+      A_zero_point,
+      B_zero_point.data(),
+      C_multiplier.data(),
+      C_zero_point,
+      C_int32.data(),
+      C_int8_scalar.data(),
+      n,
+      0,
+      row_offsets.data(),
+      col_offsets.data(),
+      bias.data(),
+      act_times_w_scale.data());
+
+  requantize_<true, true, QuantizationGranularity::GROUP, false, false, 2>(
+      A_zero_point,
+      B_zero_point.data(),
+      C_multiplier.data(),
+      C_zero_point,
+      C_int32.data(),
+      C_int8_vector.data(),
+      n,
+      row_offsets.data(),
+      col_offsets.data(),
+      bias.data(),
+      act_times_w_scale.data());
+
+  compare_validate_buffers(
+      C_int8_scalar.data(),
+      C_int8_vector.data(),
+      n,
+      1,
+      1,
+      static_cast<uint8_t>(0));
+
+  transform(
+      act_times_w_scale.begin(),
+      act_times_w_scale.end(),
+      bias.begin(),
+      fbias.begin(),
+      multiplies<float>());
+
+  requantize_i8dw_ref_<
+      true,
+      true,
+      QuantizationGranularity::GROUP,
+      false,
+      false,
+      1>(
+      A_zero_point,
+      B_zero_point.data(),
+      C_multiplier.data(),
+      C_zero_point,
+      C_int32.data(),
+      C_int8_scalar.data(),
+      n,
+      0,
+      row_offsets.data(),
+      col_offsets.data(),
+      fbias.data(),
+      act_times_w_scale.data());
+
+  requantize_<true, true, QuantizationGranularity::GROUP, false, false, 1>(
+      A_zero_point,
+      B_zero_point.data(),
+      C_multiplier.data(),
+      C_zero_point,
+      C_int32.data(),
+      C_int8_vector.data(),
+      n,
+      row_offsets.data(),
+      col_offsets.data(),
+      fbias.data(),
+      act_times_w_scale.data());
+
+  compare_validate_buffers(
+      C_int8_scalar.data(),
+      C_int8_vector.data(),
+      n,
+      1,
+      1,
+      static_cast<uint8_t>(0));
+
+  for (int g = 0; g < 2; ++g) {
+    for (int c = 0; c < n / 2; ++c) {
+      fbias[g * n / 2 + c] =
+          act_times_w_scale[g] * static_cast<float>(bias[g * n / 2 + c]);
+    }
+  }
+
+  requantize_i8dw_ref_<
+      true,
+      true,
+      QuantizationGranularity::GROUP,
+      false,
+      false,
+      2>(
+      A_zero_point,
+      B_zero_point.data(),
+      C_multiplier.data(),
+      C_zero_point,
+      C_int32.data(),
+      C_int8_scalar.data(),
+      n,
+      0,
+      row_offsets.data(),
+      col_offsets.data(),
+      fbias.data(),
+      act_times_w_scale.data());
+
+  requantize_<true, true, QuantizationGranularity::GROUP, false, false, 2>(
+      A_zero_point,
+      B_zero_point.data(),
+      C_multiplier.data(),
+      C_zero_point,
+      C_int32.data(),
+      C_int8_vector.data(),
+      n,
+      row_offsets.data(),
+      col_offsets.data(),
+      fbias.data(),
+      act_times_w_scale.data());
+
+  compare_validate_buffers(
+      C_int8_scalar.data(),
+      C_int8_vector.data(),
+      n,
+      1,
+      1,
+      static_cast<uint8_t>(0));
+}
+
+TEST(FbgemmI8DepthwiseRequantizationTest, requantizeI8DepthWiseTest) {
+  runRequantizeI8DepthWiseTest();
+}
+
 } // namespace fbgemm
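A minimal reference sketch, not part of the diff above, of the per-element requantization math that both requantize_ and requantize_i8dw_ref_ implement. It is simplified to per-tensor granularity with an int32 bias and no ReLU fusion, and the function and parameter names below are illustrative only.

#include <algorithm>
#include <cmath>
#include <cstdint>

// out = clamp(lrintf((acc - B_zp * row_offset - A_zp * col_offset + bias)
//                    * C_multiplier) + C_zp, 0, 255)
inline std::uint8_t requantize_one_element(
    std::int32_t acc, // raw int32 accumulator, i.e. C_int32[j]
    std::int32_t A_zero_point,
    std::int32_t B_zero_point,
    std::int32_t row_offset, // per-row activation sum (depth-wise specific)
    std::int32_t col_offset, // per-channel weight sum
    std::int32_t bias,
    float C_multiplier, // act_scale * weight_scale / out_scale
    std::int32_t C_zero_point) {
  std::int32_t raw =
      acc - B_zero_point * row_offset - A_zero_point * col_offset + bias;
  // std::lrintf rounds with the current FP rounding mode; the NEON path
  // asserts FE_TONEAREST so vcvtnq_s32_f32 matches this scalar reference.
  long rounded =
      std::lrintf(static_cast<float>(raw) * C_multiplier) + C_zero_point;
  return static_cast<std::uint8_t>(std::max(0l, std::min(255l, rounded)));
}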