@@ -10,7 +10,10 @@ namespace refactor::kernel {
1010
1111 template <class T > __device__ __forceinline__ static int8_t sub (T, T);
1212 template <> __device__ __forceinline__ int8_t sub<int8_t >(int8_t a, int8_t b) { return a - b; }
13- template <> __device__ __forceinline__ int8_t sub<uint8_t >(uint8_t a, uint8_t b) { return static_cast <int8_t >(static_cast <int16_t >(a) - static_cast <int16_t >(b)); }
13+ template <> __device__ __forceinline__ int8_t sub<uint8_t >(uint8_t a, uint8_t b) {
14+ constexpr static int16_t MAX = 127 ;
15+ return static_cast <int8_t >(CUB_MIN (MAX, static_cast <int16_t >(a) - static_cast <int16_t >(b)));
16+ }
1417
1518 template <class T >
1619 struct MatMulIntegerZPFunctorScalar {
@@ -33,16 +36,16 @@ namespace refactor::kernel {
3336 }
3437
3538 template <class T >
36- struct MatMulIntegerZPFunctorA {
37- dim_t m, n;
39+ struct MatMulIntegerZPFunctor {
40+ dim_t m, n, a, b, c ;
3841 T const *src, *zp;
3942
4043 __device__ int8_t operator ()(size_t idx) const noexcept {
4144 auto
42- // k = idx % n,
45+ k = idx % n,
4346 j = idx / n % m,
4447 i = idx / n / m;
45- return sub (src[idx], zp[i * m + j]);
48+ return sub (src[idx], zp[i * a + j * b + k * c ]);
4649 }
4750 };
4851
@@ -52,38 +55,30 @@ namespace refactor::kernel {
5255 int8_t *dst, void const *src_, void const *zp_) {
5356 thrust::tabulate (thrust::device,
5457 dst, dst + b * m * n,
55- MatMulIntegerZPFunctorA <T>{
58+ MatMulIntegerZPFunctor <T>{
5659 m,
5760 n,
61+ m,
62+ 1 ,
63+ 0 ,
5864 reinterpret_cast <T const *>(src_),
5965 reinterpret_cast <T const *>(zp_),
6066 });
6167 }
6268
63- template <class T >
64- struct MatMulIntegerZPFunctorB {
65- dim_t m, n;
66- T const *src, *zp;
67-
68- __device__ int8_t operator ()(size_t idx) const noexcept {
69- auto
70- k = idx % n,
71- // j = idx / n % m,
72- i = idx / n / m;
73- return sub (src[idx], zp[i * n + k]);
74- }
75- };
76-
7769 template <class T >
7870 static void applyZeroPointB (
7971 dim_t b, dim_t m, dim_t n,
8072 int8_t *dst, void const *src_, void const *zp_) {
8173
8274 thrust::tabulate (thrust::device,
8375 dst, dst + b * m * n,
84- MatMulIntegerZPFunctorB <T>{
76+ MatMulIntegerZPFunctor <T>{
8577 m,
8678 n,
79+ n,
80+ 0 ,
81+ 1 ,
8782 reinterpret_cast <T const *>(src_),
8883 reinterpret_cast <T const *>(zp_),
8984 });
0 commit comments