Skip to content

Commit fe324d6

Browse files
mcfimeta-codesync[bot]
authored and committed
Vectorize requantize_ for Arm64 with NEON intrinsics (#5130)
Summary: Pull Request resolved: #5130 X-link: https://github.com/facebookresearch/FBGEMM/pull/2132 This change adds a vectorized requantize_ for Arm64 with NEON intrinsics: 1. The newly added NEON intrinsics follow what the existing AVX2 code does. 2. The scalar loop was moved to a new function requantize_i8dw_ref_ to make the code more readable and testable. 3. Added new tests to verify that requantize_ and requantize_i8dw_ref_ produce identical results. Differential Revision: D86216347
1 parent 1fd545d commit fe324d6

File tree

6 files changed

+769
-53
lines changed

6 files changed

+769
-53
lines changed

src/FbgemmI8Depthwise2DAvx2-inl.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88

99
#pragma once
1010

11-
#include "./FbgemmI8DepthwiseAvx2-inl.h" // @manual
11+
#include "./FbgemmI8DepthwiseUtils.h" // @manual
1212
#include "./GenerateI8Depthwise.h" // @manual
1313
#include "./MaskAvx2.h" // @manual
1414
#include "fbgemm/Utils.h"

src/FbgemmI8Depthwise3DAvx2.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
#include <stdexcept> // for logic_error
1313
#include <string>
1414

15-
#include "./FbgemmI8DepthwiseAvx2-inl.h" // @manual
15+
#include "./FbgemmI8DepthwiseUtils.h" // @manual
1616
#include "./GenerateI8Depthwise.h" // @manual
1717
#include "./MaskAvx2.h" // @manual
1818
#include "fbgemm/Utils.h"

src/FbgemmI8DepthwiseUtils.h

Lines changed: 96 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,96 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#pragma once
10+
11+
#include <math.h>
12+
#include <algorithm> // for min and max
13+
#include <cassert>
14+
#include <cmath> // for lrintf and sqrt
15+
#include <cstdint>
16+
#include <type_traits> // for is_same
17+
18+
#include "fbgemm/FbgemmBuild.h"
19+
#include "fbgemm/UtilsAvx2.h"
20+
21+
namespace fbgemm {
22+
23+
// Almost same as ReQuantizeOutput in OutputProcessing-inh.h but different
24+
// row_offsets for each row because of depth-wise convolution
25+
26+
template <
27+
bool FUSE_RELU,
28+
bool HAS_BIAS,
29+
QuantizationGranularity Q_GRAN,
30+
bool A_SYMMETRIC,
31+
bool B_SYMMETRIC,
32+
int K_PER_G,
33+
typename BIAS_TYPE>
34+
static ALWAYS_INLINE void requantize_i8dw_ref_(
35+
std::int32_t A_zero_point,
36+
const std::int32_t* B_zero_point,
37+
const float* C_multiplier,
38+
std::int32_t C_zero_point,
39+
const std::int32_t* C_int32,
40+
std::uint8_t* C_uint8,
41+
int n,
42+
int j, // starting index
43+
const std::int32_t* row_offsets,
44+
const std::int32_t* col_offsets,
45+
const BIAS_TYPE* bias [[maybe_unused]],
46+
const float* act_times_w_scale = nullptr) {
47+
for (; j < n; ++j) {
48+
std::int32_t raw = C_int32[j];
49+
int quant_param_idx = 0;
50+
if constexpr (
51+
Q_GRAN == QuantizationGranularity::OUT_CHANNEL ||
52+
(Q_GRAN == QuantizationGranularity::GROUP && K_PER_G == 1)) {
53+
quant_param_idx = j;
54+
} else if constexpr (Q_GRAN == QuantizationGranularity::GROUP) {
55+
quant_param_idx = j / 2;
56+
}
57+
if constexpr (!B_SYMMETRIC) {
58+
raw -= B_zero_point[quant_param_idx] * row_offsets[j / K_PER_G];
59+
}
60+
if constexpr (!A_SYMMETRIC) {
61+
raw -= A_zero_point * col_offsets[j];
62+
}
63+
float raw_f = NAN;
64+
if constexpr (HAS_BIAS) { // static if
65+
if constexpr (std::is_same_v<BIAS_TYPE, float>) {
66+
raw_f = raw;
67+
raw_f += bias[j] / act_times_w_scale[quant_param_idx];
68+
} else {
69+
raw += bias[j];
70+
raw_f = raw;
71+
}
72+
} else {
73+
raw_f = raw;
74+
}
75+
76+
float ab = raw_f * C_multiplier[quant_param_idx];
77+
long rounded = lrintf(ab) + C_zero_point;
78+
79+
C_uint8[j] = std::max(
80+
FUSE_RELU ? static_cast<long>(C_zero_point) : 0l,
81+
std::min(255l, rounded));
82+
}
83+
}
84+
85+
// Splits n into a factor pair (a, n / a) with a <= n / a, choosing the
// pair whose members are as close together as possible. Assumes n >= 1.
static inline std::pair<int, int> closest_factors_(int n) {
  int small = static_cast<int>(std::sqrt(n));
  // Walk downward from floor(sqrt(n)) until a divisor is found;
  // terminates at 1, which divides everything.
  for (; n % small != 0; --small) {
  }
  return {small, n / small}; // small <= n / small
}
92+
93+
} // namespace fbgemm
94+
95+
#include "FbgemmI8DepthwiseUtilsAvx2.h"
96+
#include "FbgemmI8DepthwiseUtilsNeon.h"

src/FbgemmI8DepthwiseAvx2-inl.h renamed to src/FbgemmI8DepthwiseUtilsAvx2.h

Lines changed: 25 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -8,25 +8,21 @@
88

99
#pragma once
1010

11-
#include <algorithm> // for min and max
11+
#if defined(__x86_64__) || defined(__i386__) || \
12+
(defined(_MSC_VER) && (defined(_M_X64) || defined(_M_IX86)))
13+
1214
#include <cassert>
1315
#include <cmath> // for lrintf and sqrt
1416
#include <cstdint>
1517
#include <type_traits> // for is_same
1618

17-
#if defined(__x86_64__) || defined(__i386__) || \
18-
(defined(_MSC_VER) && (defined(_M_X64) || defined(_M_IX86)))
1919
#include <immintrin.h>
20-
#include <math.h>
21-
#endif
2220

2321
#include "fbgemm/FbgemmBuild.h"
2422
#include "fbgemm/UtilsAvx2.h"
2523

2624
namespace fbgemm {
2725

28-
// Almost same as ReQuantizeOutput in OutputProcessing-inh.h but different
29-
// row_offsets for each row because of depth-wise convolution
3026
template <
3127
bool FUSE_RELU,
3228
bool HAS_BIAS,
@@ -503,50 +499,28 @@ static ALWAYS_INLINE void requantize_(
503499
_mm256_castsi256_si128(x_clamped_v));
504500
} // j loop vectorized
505501

506-
for (; j < n; ++j) {
507-
std::int32_t raw = C_int32[j];
508-
int quant_param_idx = 0;
509-
if constexpr (
510-
Q_GRAN == QuantizationGranularity::OUT_CHANNEL ||
511-
(Q_GRAN == QuantizationGranularity::GROUP && K_PER_G == 1)) {
512-
quant_param_idx = j;
513-
} else if constexpr (Q_GRAN == QuantizationGranularity::GROUP) {
514-
quant_param_idx = j / 2;
515-
}
516-
if constexpr (!B_SYMMETRIC) {
517-
raw -= B_zero_point[quant_param_idx] * row_offsets[j / K_PER_G];
518-
}
519-
if constexpr (!A_SYMMETRIC) {
520-
raw -= A_zero_point * col_offsets[j];
521-
}
522-
float raw_f = NAN;
523-
if constexpr (HAS_BIAS) { // static if
524-
if constexpr (std::is_same_v<BIAS_TYPE, float>) {
525-
raw_f = raw;
526-
raw_f += bias[j] / act_times_w_scale[quant_param_idx];
527-
} else {
528-
raw += bias[j];
529-
raw_f = raw;
530-
}
531-
} else {
532-
raw_f = raw;
533-
}
534-
535-
float ab = raw_f * C_multiplier[quant_param_idx];
536-
long rounded = lrintf(ab) + C_zero_point;
537-
538-
C_uint8[j] = std::max(
539-
FUSE_RELU ? static_cast<long>(C_zero_point) : 0l,
540-
std::min(255l, rounded));
541-
}
542-
}
543-
544-
static inline std::pair<int, int> closest_factors_(int n) {
545-
int a = static_cast<int>(std::sqrt(n));
546-
while (n % a != 0) {
547-
a--;
548-
}
549-
return {a, n / a}; // a <= n / a
502+
requantize_i8dw_ref_<
503+
FUSE_RELU,
504+
HAS_BIAS,
505+
Q_GRAN,
506+
A_SYMMETRIC,
507+
B_SYMMETRIC,
508+
K_PER_G,
509+
BIAS_TYPE>(
510+
A_zero_point,
511+
B_zero_point,
512+
C_multiplier,
513+
C_zero_point,
514+
C_int32,
515+
C_uint8,
516+
n,
517+
j,
518+
row_offsets,
519+
col_offsets,
520+
bias,
521+
act_times_w_scale);
550522
}
551523

552524
} // namespace fbgemm
525+
526+
#endif

0 commit comments

Comments
 (0)