
Commit 643894e

mcfimeta-codesync[bot] authored and committed
Vectorize requantize_ for Arm64 with NEON intrinsics (#5130)
Summary:

Pull Request resolved: #5130
X-link: https://github.com/facebookresearch/FBGEMM/pull/2132

This change adds a vectorized requantize_ for Arm64 using NEON intrinsics:

1. The newly added NEON intrinsics follow what the existing AVX2 code does.
2. The scalar loop was moved to a new function, requantize_i8dw_ref_, to make the code more readable and testable.
3. New tests verify that requantize_ and requantize_i8dw_ref_ produce identical results.

Reviewed By: Nicoshev
Differential Revision: D86216347
fbshipit-source-id: 1e61ec1bd4ea6ff571de96c59b0191e082258878
1 parent 47e79da commit 643894e
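
As background on the approach the summary describes, here is a minimal sketch of a NEON requantize loop that mirrors the AVX2 pattern: load int32 accumulators, scale by a float multiplier, round, add the output zero point, saturate to uint8, and let a scalar tail finish the remainder. This is illustrative only, not the FBGEMM code; the function name and the single per-tensor multiplier are assumptions, and the real requantize_ also handles zero-point corrections, row/column offsets, bias, and ReLU fusion.

    #include <arm_neon.h>

    #include <algorithm> // for std::max and std::min
    #include <cmath>     // for lrintf
    #include <cstdint>
    #include <cstring>   // for std::memcpy

    // Hypothetical sketch: per-tensor requantization of n int32 values to uint8.
    void requantize_neon_sketch(
        const std::int32_t* src,
        std::uint8_t* dst,
        int n,
        float multiplier,
        std::int32_t zero_point) {
      int j = 0;
      const float32x4_t mul_v = vdupq_n_f32(multiplier);
      const int32x4_t zp_v = vdupq_n_s32(zero_point);
      for (; j + 4 <= n; j += 4) {
        int32x4_t x = vld1q_s32(src + j);   // load 4 int32 accumulators
        float32x4_t xf = vcvtq_f32_s32(x);  // convert to float
        // Scale, then round to nearest with ties to even (A64 FCVTNS),
        // matching lrintf under the default rounding mode.
        int32x4_t r = vcvtnq_s32_f32(vmulq_f32(xf, mul_v));
        r = vaddq_s32(r, zp_v);             // add the output zero point
        int16x4_t r16 = vqmovn_s32(r);      // saturating narrow to int16
        // Saturating unsigned narrow clamps to [0, 255].
        uint8x8_t r8 = vqmovun_s16(vcombine_s16(r16, r16));
        std::uint32_t packed = vget_lane_u32(vreinterpret_u32_u8(r8), 0);
        std::memcpy(dst + j, &packed, sizeof(packed)); // store 4 bytes
      }
      for (; j < n; ++j) { // scalar tail, in the spirit of requantize_i8dw_ref_
        long rounded = lrintf(src[j] * multiplier) + zero_point;
        dst[j] = static_cast<std::uint8_t>(
            std::max(0l, std::min(255l, rounded)));
      }
    }

Note how j is declared before the vector loop so the scalar tail resumes exactly where vectorization stopped; the diff below hoists int j = 0; out of the AVX2 block for the same reason.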

File tree

6 files changed: +763 -54 lines changed

src/FbgemmI8Depthwise2DAvx2-inl.h

Lines changed: 1 addition & 1 deletion

@@ -8,7 +8,7 @@
 
 #pragma once
 
-#include "./FbgemmI8DepthwiseAvx2-inl.h" // @manual
+#include "./FbgemmI8DepthwiseUtils.h" // @manual
 #include "./GenerateI8Depthwise.h" // @manual
 #include "./MaskAvx2.h" // @manual
 #include "fbgemm/Utils.h"

src/FbgemmI8Depthwise3DAvx2.cc

Lines changed: 1 addition & 1 deletion

@@ -12,7 +12,7 @@
 #include <stdexcept> // for logic_error
 #include <string>
 
-#include "./FbgemmI8DepthwiseAvx2-inl.h" // @manual
+#include "./FbgemmI8DepthwiseUtils.h" // @manual
 #include "./GenerateI8Depthwise.h" // @manual
 #include "./MaskAvx2.h" // @manual
 #include "fbgemm/Utils.h"

src/FbgemmI8DepthwiseAvx2-inl.h

Lines changed: 28 additions & 52 deletions

@@ -8,25 +8,21 @@
 
 #pragma once
 
-#include <algorithm> // for min and max
+#if defined(__x86_64__) || defined(__i386__) || \
+    (defined(_MSC_VER) && (defined(_M_X64) || defined(_M_IX86)))
+
 #include <cassert>
 #include <cmath> // for lrintf and sqrt
 #include <cstdint>
 #include <type_traits> // for is_same
 
-#if defined(__x86_64__) || defined(__i386__) || \
-    (defined(_MSC_VER) && (defined(_M_X64) || defined(_M_IX86)))
 #include <immintrin.h>
-#include <math.h>
-#endif
 
 #include "fbgemm/FbgemmBuild.h"
 #include "fbgemm/UtilsAvx2.h"
 
 namespace fbgemm {
 
-// Almost same as ReQuantizeOutput in OutputProcessing-inh.h but different
-// row_offsets for each row because of depth-wise convolution
 template <
     bool FUSE_RELU,
     bool HAS_BIAS,
@@ -47,6 +43,8 @@ static ALWAYS_INLINE void requantize_(
     const std::int32_t* col_offsets,
     const BIAS_TYPE* bias [[maybe_unused]],
     const float* act_times_w_scale = nullptr) {
+  int j = 0;
+#ifdef __AVX2__
   __m256 multiplier_v = _mm256_setzero_ps();
   // Broadcasted reciprocal of act_times_w_scale
   __m256 act_times_w_rcp_v [[maybe_unused]] = _mm256_setzero_ps();
@@ -73,7 +71,6 @@
       _mm256_set_epi32(0x07, 0x03, 0x06, 0x02, 0x05, 0x01, 0x04, 0x00);
 
   constexpr int VLEN = 8;
-  int j = 0;
   for (; j < n / (VLEN * 4) * (VLEN * 4); j += (VLEN * 4)) {
     __m256i x_v =
         _mm256_loadu_si256(reinterpret_cast<const __m256i*>(C_int32 + j));
@@ -502,51 +499,30 @@
         reinterpret_cast<__m128i*>(C_uint8 + j),
         _mm256_castsi256_si128(x_clamped_v));
   } // j loop vectorized
+#endif
 
-  for (; j < n; ++j) {
-    std::int32_t raw = C_int32[j];
-    int quant_param_idx = 0;
-    if constexpr (
-        Q_GRAN == QuantizationGranularity::OUT_CHANNEL ||
-        (Q_GRAN == QuantizationGranularity::GROUP && K_PER_G == 1)) {
-      quant_param_idx = j;
-    } else if constexpr (Q_GRAN == QuantizationGranularity::GROUP) {
-      quant_param_idx = j / 2;
-    }
-    if constexpr (!B_SYMMETRIC) {
-      raw -= B_zero_point[quant_param_idx] * row_offsets[j / K_PER_G];
-    }
-    if constexpr (!A_SYMMETRIC) {
-      raw -= A_zero_point * col_offsets[j];
-    }
-    float raw_f = NAN;
-    if constexpr (HAS_BIAS) { // static if
-      if constexpr (std::is_same_v<BIAS_TYPE, float>) {
-        raw_f = raw;
-        raw_f += bias[j] / act_times_w_scale[quant_param_idx];
-      } else {
-        raw += bias[j];
-        raw_f = raw;
-      }
-    } else {
-      raw_f = raw;
-    }
-
-    float ab = raw_f * C_multiplier[quant_param_idx];
-    long rounded = lrintf(ab) + C_zero_point;
-
-    C_uint8[j] = std::max(
-        FUSE_RELU ? static_cast<long>(C_zero_point) : 0l,
-        std::min(255l, rounded));
-  }
-}
-
-static inline std::pair<int, int> closest_factors_(int n) {
-  int a = static_cast<int>(std::sqrt(n));
-  while (n % a != 0) {
-    a--;
-  }
-  return {a, n / a}; // a <= n / a
+  requantize_i8dw_ref_<
+      FUSE_RELU,
+      HAS_BIAS,
+      Q_GRAN,
+      A_SYMMETRIC,
+      B_SYMMETRIC,
+      K_PER_G,
+      BIAS_TYPE>(
+      A_zero_point,
+      B_zero_point,
+      C_multiplier,
+      C_zero_point,
+      C_int32,
+      C_uint8,
+      n,
+      j,
+      row_offsets,
+      col_offsets,
+      bias,
+      act_times_w_scale);
 }
 
 } // namespace fbgemm
+
+#endif
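
The body of the new requantize_i8dw_ref_ is presumably the scalar loop deleted above, now living in the new FbgemmI8DepthwiseUtils.h (not shown in this diff; the commit touches six files but only three appear here). A hypothetical declaration, inferred purely from the call site and the types used in the deleted loop:

    // Inferred sketch, not the actual header: scalar reference requantization
    // over columns [start, n), used as the tail after any vectorized prefix.
    template <
        bool FUSE_RELU,
        bool HAS_BIAS,
        QuantizationGranularity Q_GRAN,
        bool A_SYMMETRIC,
        bool B_SYMMETRIC,
        int K_PER_G,
        typename BIAS_TYPE>
    void requantize_i8dw_ref_(
        std::int32_t A_zero_point,
        const std::int32_t* B_zero_point,
        const float* C_multiplier,
        std::int32_t C_zero_point,
        const std::int32_t* C_int32,
        std::uint8_t* C_uint8,
        int n,
        int start, // requantize_ passes j, where its vector loop stopped
        const std::int32_t* row_offsets,
        const std::int32_t* col_offsets,
        const BIAS_TYPE* bias,
        const float* act_times_w_scale);

Two related moves are visible in the hunks above: the whole AVX2 header is now wrapped in an x86-only #if/#endif guard, so Arm64 builds no longer compile it at all, and closest_factors_ was removed from this header, presumably relocated alongside the reference loop.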
