@@ -314,7 +314,6 @@ static ALWAYS_INLINE void requantize_(
     vst1q_u8(C_uint8 + j, xyzw_clamped_v);
   } // j loop vectorized and unrolled 4x
 
-vec_tail:
   for (; j < n / VLEN * VLEN; j += VLEN) {
     int32x4_t x_v = vld1q_s32(C_int32 + j);
 
@@ -401,15 +400,93 @@ static ALWAYS_INLINE void requantize_(
     vst1_lane_u32(C_uint8 + j, vreinterpret_u32_u8(x_clamped_v), 0);
   } // j loop vectorized
 
-  // There are some leftovers that cannot fit in one full vector. Instead of
-  // doing a scalar loop, we prepare j to be n - VLEN and jump back to the
-  // above loop for one extra iteration. Compared to a scalar loop, this reuses
-  // vector loop code so code size bloat is minimal. Another alternative is
-  // to use a partial vector register, but that also bloats code size more than
-  // reusing the above loop body.
-  if (j < n) {
-    j = n - VLEN;
-    goto vec_tail;
+  // Handle leftover elements with a scalar loop while keeping code size small.
+#ifdef __clang__
+#pragma clang loop vectorize(disable) interleave(disable) unroll(disable)
+#elif defined(__GNUC__)
+#pragma GCC novector
+#pragma GCC unroll 0
+#endif
+  while (j < n) {
+    int32x4_t x_v;
+    x_v[0] = C_int32[j];
+
+    if constexpr (!B_SYMMETRIC) {
+      int32x4_t row_offset_v;
+      row_offset_v[0] = row_offsets[j / K_PER_G];
+      if constexpr (
+          Q_GRAN == QuantizationGranularity::OUT_CHANNEL ||
+          (Q_GRAN == QuantizationGranularity::GROUP && K_PER_G == 1)) {
+        B_zero_point_v[0] = B_zero_point[j];
+      } else if constexpr (Q_GRAN == QuantizationGranularity::GROUP) {
+        static_assert(K_PER_G == 2);
+        B_zero_point_v[0] = B_zero_point[j / 2];
+      }
+      x_v = vmlsq_s32(x_v, row_offset_v, B_zero_point_v);
+    }
+    if constexpr (!A_SYMMETRIC) {
+      int32x4_t col_offsets_v;
+      col_offsets_v[0] = col_offsets[j];
+      x_v = vmlsq_s32(x_v, A_zero_point_v, col_offsets_v);
+    }
+
+    // Convert to float
+    float32x4_t xf_v;
+    if constexpr (HAS_BIAS) { // static if
+      if constexpr (std::is_same_v<BIAS_TYPE, float>) {
+        float32x4_t x_bias_v;
+        float32x4_t biasfp_v;
+        float32x4_t act_times_w_scale_v;
+        if constexpr (
+            Q_GRAN == QuantizationGranularity::OUT_CHANNEL ||
+            (Q_GRAN == QuantizationGranularity::GROUP && K_PER_G == 1)) {
+          act_times_w_scale_v[0] = act_times_w_scale[j];
+          biasfp_v[0] = bias[j];
+          x_bias_v = vdivq_f32(biasfp_v, act_times_w_scale_v);
+        } else if constexpr (Q_GRAN == QuantizationGranularity::GROUP) {
+          act_times_w_scale_v[0] = act_times_w_scale[j / 2];
+          biasfp_v[0] = bias[j];
+          x_bias_v = vdivq_f32(biasfp_v, act_times_w_scale_v);
+        } else {
+          biasfp_v[0] = bias[j];
+          x_bias_v = vmulq_f32(biasfp_v, act_times_w_rcp_v);
+        }
+        xf_v = vaddq_f32(vcvtq_f32_s32(x_v), x_bias_v);
+      } else {
+        int32x4_t biasint_v;
+        biasint_v[0] = bias[j];
+        x_v = vaddq_s32(x_v, biasint_v);
+        xf_v = vcvtq_f32_s32(x_v);
+      }
+    } else {
+      xf_v = vcvtq_f32_s32(x_v);
+    }
+
+    if constexpr (
+        Q_GRAN == QuantizationGranularity::OUT_CHANNEL ||
+        (Q_GRAN == QuantizationGranularity::GROUP && K_PER_G == 1)) {
+      multiplier_v[0] = C_multiplier[j];
+    } else if constexpr (Q_GRAN == QuantizationGranularity::GROUP) {
+      multiplier_v[0] = C_multiplier[j / 2];
+    }
+    float32x4_t x_scaled_v = vmulq_f32(xf_v, multiplier_v);
+    // vcvtnq_s32_f32 always rounds to nearest, unlike x86's
+    // _mm256_cvtps_epi32, which follows the current rounding mode and so
+    // may not round to nearest. To help catch issues and aid debugging,
+    // we add an assertion here.
+    assert(fegetround() == FE_TONEAREST);
+    int32x4_t x_rounded_v = vcvtnq_s32_f32(x_scaled_v);
+
+    int16x8_t x_packed_v_s16 = vqaddq_s16(
+        vcombine_s16(vqmovn_s32(x_rounded_v), vdup_n_s16(0)),
+        C_zero_point_epi16_v);
+    uint8x8_t x_packed_v_u8 = vqmovun_s16(x_packed_v_s16);
+    uint8x8_t x_clamped_v = vmax_u8(
+        FUSE_RELU ? vget_low_u8(C_zero_point_epi8_v) : vget_low_u8(min_v),
+        x_packed_v_u8);
+
+    vst1_lane_u8(C_uint8 + j, x_clamped_v, 0);
+    j++;
   }
 }
 
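The new tail loop above relies on two things: each leftover element is pushed through lane 0 of a full NEON vector, so the scalar epilogue can reuse the same intrinsic sequence as the vector loop, and the loop pragmas stop the compiler from re-vectorizing or re-unrolling that epilogue, which would bring back the code-size bloat the old goto trick was avoiding. Below is a minimal standalone sketch of the same pattern, assuming a plain per-tensor requantization (one scale, one zero point) rather than the templated per-group/per-channel handling; the helper name requantize_tail and its signature are hypothetical illustrations, not FBGEMM API.

// Minimal sketch of the lane-0 tail-handling pattern (hypothetical helper,
// not part of FBGEMM): requantize the trailing elements C_int32[j..n) to
// uint8 one at a time, assuming a single per-tensor scale and zero point.
#include <arm_neon.h>

#include <cassert>
#include <cfenv>
#include <cstdint>

static void requantize_tail(
    const int32_t* C_int32,
    uint8_t* C_uint8,
    int j, // first index the vector loops did not cover
    int n,
    float C_multiplier,
    int16_t C_zero_point) {
  // vcvtnq_s32_f32 always rounds to nearest; assert the FP environment
  // agrees so results match a round-to-nearest scalar reference.
  assert(fegetround() == FE_TONEAREST);
  const float32x4_t multiplier_v = vdupq_n_f32(C_multiplier);
  const int16x8_t zero_point_v = vdupq_n_s16(C_zero_point);
#ifdef __clang__
#pragma clang loop vectorize(disable) interleave(disable) unroll(disable)
#endif
  while (j < n) {
    // Load one element into lane 0; lanes 1..3 are don't-cares because
    // only lane 0 is stored at the end.
    int32x4_t x_v = vsetq_lane_s32(C_int32[j], vdupq_n_s32(0), 0);
    float32x4_t x_scaled_v = vmulq_f32(vcvtq_f32_s32(x_v), multiplier_v);
    int32x4_t x_rounded_v = vcvtnq_s32_f32(x_scaled_v);
    // Saturate to int16, add the output zero point, then saturate to uint8.
    int16x8_t x_packed_v = vqaddq_s16(
        vcombine_s16(vqmovn_s32(x_rounded_v), vdup_n_s16(0)), zero_point_v);
    vst1_lane_u8(C_uint8 + j, vqmovun_s16(x_packed_v), 0);
    j++;
  }
}

Unlike the removed goto path, which rewound j to n - VLEN and re-ran the last full vector iteration (re-storing bytes that were already written), this form touches each output byte exactly once, trading a little per-element speed in the tail for straight-line control flow.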