
Commit 21ea7ed

mcfimeta-codesync[bot] authored and committed

Fix tail handling in arm64 requantize_ (#5153)

Summary:
X-link: https://github.com/facebookresearch/FBGEMM/pull/2153
Pull Request resolved: #5153

This diff fixes incorrect tail handling that jumped ("goto") back into the vector loop. Instead, it reuses most of the vector loop body with only lane 0 of each vector filled. Essentially, although the tail uses vector instructions, it is a scalar loop with the same behavior as the real vector loops.

Reviewed By: Nicoshev

Differential Revision: D87494465

fbshipit-source-id: 0e497f1502dc89611cacbce69dbf69eaa0dfbd93

1 parent 7826ec9 · commit 21ea7ed
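To make the new tail strategy concrete, here is a minimal, hypothetical sketch (not the actual FBGEMM kernel; the per-tensor `scale` and `zero_point` parameters are simplifying assumptions) of processing leftover elements one at a time with the same NEON intrinsics as the vector body, keeping real data only in lane 0:

#include <arm_neon.h>

#include <cassert>
#include <cfenv>
#include <cstdint>

// Hypothetical helper: requantize the tail elements [j, n) of an int32
// buffer into uint8, one element per iteration, reusing vector intrinsics
// with only lane 0 carrying real data (the pattern introduced by this fix).
inline void requantize_tail_sketch(
    const int32_t* src,
    uint8_t* dst,
    int j, // first leftover index, past the last full vector
    int n,
    float scale,        // assumed per-tensor scale
    int16_t zero_point) // assumed output zero point
{
  assert(fegetround() == FE_TONEAREST); // vcvtnq_s32_f32 rounds to nearest
  const float32x4_t scale_v = vdupq_n_f32(scale);
  const int16x8_t zero_point_v = vdupq_n_s16(zero_point);
  while (j < n) {
    int32x4_t x_v = vdupq_n_s32(0);
    x_v = vsetq_lane_s32(src[j], x_v, 0); // only lane 0 is meaningful
    float32x4_t x_scaled_v = vmulq_f32(vcvtq_f32_s32(x_v), scale_v);
    int32x4_t x_rounded_v = vcvtnq_s32_f32(x_scaled_v);
    int16x8_t x_packed_v = vqaddq_s16(
        vcombine_s16(vqmovn_s32(x_rounded_v), vdup_n_s16(0)), zero_point_v);
    uint8x8_t x_clamped_v = vqmovun_s16(x_packed_v); // saturate to [0, 255]
    vst1_lane_u8(dst + j, x_clamped_v, 0); // store lane 0 only
    ++j;
  }
}

Because every iteration touches only one output byte, the loop never reads or writes past n, unlike the old approach of rewinding j to n - VLEN and re-running a full vector iteration.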

File tree

2 files changed: +290 −212 lines changed

src/FbgemmI8DepthwiseNeon-inl.h

Lines changed: 86 additions & 10 deletions
@@ -314,7 +314,6 @@ static ALWAYS_INLINE void requantize_(
     vst1q_u8(C_uint8 + j, xyzw_clamped_v);
   } // j loop vectorized and unrolled 4x
 
-vec_tail:
   for (; j < n / VLEN * VLEN; j += VLEN) {
     int32x4_t x_v = vld1q_s32(C_int32 + j);
 
@@ -401,15 +400,92 @@ static ALWAYS_INLINE void requantize_(
     vst1_lane_u32(C_uint8 + j, vreinterpret_u32_u8(x_clamped_v), 0);
   } // j loop vectorized
 
-  // There are some leftovers that cannot fit in one full vector. Instead of
-  // doing a scalar loop, we prepare j to be n - VLEN and jump back to the
-  // above loop for one extra iteration. Compared to a scalar loop, this reuses
-  // vector loop code so code size bloat is minimal. Another alternative is
-  // to use a partial vector register, but that also bloats code size more than
-  // reusing the above loop body.
-  if (j < n) {
-    j = n - VLEN;
-    goto vec_tail;
+  // leftover handling using minimal code size
+#ifdef __clang__
+#pragma clang loop vectorize(disable) interleave(disable) unroll(disable)
+#elif defined(__GNUC__)
+#pragma GCC novector unroll 0
+#endif
+  while (j < n) {
+    int32x4_t x_v;
+    x_v[0] = C_int32[j];
+
+    if constexpr (!B_SYMMETRIC) {
+      int32x4_t row_offset_v;
+      row_offset_v[0] = row_offsets[j / K_PER_G];
+      if constexpr (
+          Q_GRAN == QuantizationGranularity::OUT_CHANNEL ||
+          (Q_GRAN == QuantizationGranularity::GROUP && K_PER_G == 1)) {
+        B_zero_point_v[0] = B_zero_point[j];
+      } else if constexpr (Q_GRAN == QuantizationGranularity::GROUP) {
+        static_assert(K_PER_G == 2);
+        B_zero_point_v[0] = B_zero_point[j / 2];
+      }
+      x_v = vmlsq_s32(x_v, row_offset_v, B_zero_point_v);
+    }
+    if constexpr (!A_SYMMETRIC) {
+      int32x4_t col_offsets_v;
+      col_offsets_v[0] = col_offsets[j];
+      x_v = vmlsq_s32(x_v, A_zero_point_v, col_offsets_v);
+    }
+
+    // Convert to float
+    float32x4_t xf_v;
+    if constexpr (HAS_BIAS) { // static if
+      if constexpr (std::is_same_v<BIAS_TYPE, float>) {
+        float32x4_t x_bias_v;
+        float32x4_t biasfp_v;
+        float32x4_t act_times_w_scale_v;
+        if constexpr (
+            Q_GRAN == QuantizationGranularity::OUT_CHANNEL ||
+            (Q_GRAN == QuantizationGranularity::GROUP && K_PER_G == 1)) {
+          act_times_w_scale_v[0] = act_times_w_scale[j];
+          biasfp_v[0] = bias[j];
+          x_bias_v = vdivq_f32(biasfp_v, act_times_w_scale_v);
+        } else if constexpr (Q_GRAN == QuantizationGranularity::GROUP) {
+          act_times_w_scale_v[0] = act_times_w_scale[j / 2];
+          biasfp_v[0] = bias[j];
+          x_bias_v = vdivq_f32(biasfp_v, act_times_w_scale_v);
+        } else {
+          biasfp_v[0] = bias[j];
+          x_bias_v = vmulq_f32(biasfp_v, act_times_w_rcp_v);
+        }
+        xf_v = vaddq_f32(vcvtq_f32_s32(x_v), x_bias_v);
+      } else {
+        int32x4_t biasint_v;
+        biasint_v[0] = bias[j];
+        x_v = vaddq_s32(x_v, biasint_v);
+        xf_v = vcvtq_f32_s32(x_v);
+      }
+    } else {
+      xf_v = vcvtq_f32_s32(x_v);
+    }
+
+    if constexpr (
+        Q_GRAN == QuantizationGranularity::OUT_CHANNEL ||
+        (Q_GRAN == QuantizationGranularity::GROUP && K_PER_G == 1)) {
+      multiplier_v[0] = C_multiplier[j];
+    } else if constexpr (Q_GRAN == QuantizationGranularity::GROUP) {
+      multiplier_v[0] = C_multiplier[j / 2];
+    }
+    float32x4_t x_scaled_v = vmulq_f32(xf_v, multiplier_v);
+    // vcvtnq_s32_f32 always rounds to nearest, which is slightly different
+    // from x86's _mm256_cvtps_epi32 which rounds according to the current
+    // rounding mode, which may not be round to nearest. To help catch issues
+    // and debug, we add an assertion here.
+    assert(fegetround() == FE_TONEAREST);
+    int32x4_t x_rounded_v = vcvtnq_s32_f32(x_scaled_v);
+
+    int16x8_t x_packed_v_s16 = vqaddq_s16(
+        vcombine_s16(vqmovn_s32(x_rounded_v), vdup_n_s16(0)),
+        C_zero_point_epi16_v);
+    uint8x8_t x_packed_v_u8 = vqmovun_s16(x_packed_v_s16);
+    uint8x8_t x_clamped_v = vmax_u8(
+        FUSE_RELU ? vget_low_u8(C_zero_point_epi8_v) : vget_low_u8(min_v),
+        x_packed_v_u8);
+
+    vst1_lane_u8(C_uint8 + j, x_clamped_v, 0);
+    j++;
   }
 }
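The assertion in the new tail mirrors the one in the vector loops: vcvtnq_s32_f32 always rounds to nearest, while x86's _mm256_cvtps_epi32 follows the current floating-point rounding mode. A small standalone sketch (assumed setup, not FBGEMM code) of checking and pinning that mode with <cfenv>:

#include <cfenv>
#include <cstdio>

int main() {
  // The NEON kernel assumes round-to-nearest; set and verify it explicitly.
  std::fesetround(FE_TONEAREST);
  if (std::fegetround() != FE_TONEAREST) {
    std::puts("unexpected rounding mode; results may differ from x86");
    return 1;
  }
  std::puts("rounding mode is round-to-nearest");
  return 0;
}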
