Skip to content

Commit 416cd2e

Browse files
committed
fix(kernel): 稍微调整 MatMulInteger 逻辑
Signed-off-by: YdrMaster <ydrml@hotmail.com>
1 parent bed3627 commit 416cd2e

File tree

2 files changed

+26
-29
lines changed

2 files changed

+26
-29
lines changed

scripts/compare/compare.py

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ def parse_args():
2323
args.actual,
2424
)
2525

26+
2627
def getDiff(base, test):
2728
absolute_diff = np.subtract(base, test)
2829
max_absolute_diff = np.max(np.abs(absolute_diff))
@@ -35,16 +36,19 @@ def getDiff(base, test):
3536

3637
return max_absolute_diff, max_relative_diff
3738

38-
def compare_npy(actual_path, expect_path, edge, node):
39+
40+
def compare_npy(node, actual_path, expect_path):
3941
actual = np.load(actual_path)
4042
expect = np.load(expect_path)
4143
if np.isnan(actual).any():
42-
print(f"NAN value in node:{node} edge:{edge}")
44+
print(f"NAN value in node:{node}\t{actual_path}\t{expect_path}")
4345
return
44-
46+
4547
max_absolute_diff, max_relative_diff = getDiff(expect, actual)
46-
if max_absolute_diff != 0.0: ## No need to print tensor with no diff
47-
print(f'{max_absolute_diff}\t{max_relative_diff}\t{node}\t{edge}')
48+
if max_absolute_diff != 0.0: ## No need to print tensor with no diff
49+
print(
50+
f"{max_absolute_diff}\t{max_relative_diff}\t{node}\t{actual_path}\t{expect_path}"
51+
)
4852

4953

5054
def main():
@@ -70,9 +74,7 @@ def main():
7074
expect_file = expect_file + ".npy"
7175
expect_file_path = os.path.join(expect_dir, expect_file)
7276
if os.path.exists(expect_file_path):
73-
compare_npy(
74-
actual_file_path, expect_file_path, edge_name, node_name
75-
)
77+
compare_npy(meta_file, actual_file_path, expect_file_path)
7678

7779

7880
if __name__ == "__main__":

src/04kernel/src/kernels/mat_mul_integer/cublas_kernel.cu

Lines changed: 16 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,10 @@ namespace refactor::kernel {
1010

1111
template<class T> __device__ __forceinline__ static int8_t sub(T, T);
1212
template<> __device__ __forceinline__ int8_t sub<int8_t>(int8_t a, int8_t b) { return a - b; }
13-
template<> __device__ __forceinline__ int8_t sub<uint8_t>(uint8_t a, uint8_t b) { return static_cast<int8_t>(static_cast<int16_t>(a) - static_cast<int16_t>(b)); }
13+
template<> __device__ __forceinline__ int8_t sub<uint8_t>(uint8_t a, uint8_t b) {
14+
constexpr static int16_t MAX = 127;
15+
return static_cast<int8_t>(CUB_MIN(MAX, static_cast<int16_t>(a) - static_cast<int16_t>(b)));
16+
}
1417

1518
template<class T>
1619
struct MatMulIntegerZPFunctorScalar {
@@ -33,16 +36,16 @@ namespace refactor::kernel {
3336
}
3437

3538
template<class T>
36-
struct MatMulIntegerZPFunctorA {
37-
dim_t m, n;
39+
struct MatMulIntegerZPFunctor {
40+
dim_t m, n, a, b, c;
3841
T const *src, *zp;
3942

4043
__device__ int8_t operator()(size_t idx) const noexcept {
4144
auto
42-
// k = idx % n,
45+
k = idx % n,
4346
j = idx / n % m,
4447
i = idx / n / m;
45-
return sub(src[idx], zp[i * m + j]);
48+
return sub(src[idx], zp[i * a + j * b + k * c]);
4649
}
4750
};
4851

@@ -52,38 +55,30 @@ namespace refactor::kernel {
5255
int8_t *dst, void const *src_, void const *zp_) {
5356
thrust::tabulate(thrust::device,
5457
dst, dst + b * m * n,
55-
MatMulIntegerZPFunctorA<T>{
58+
MatMulIntegerZPFunctor<T>{
5659
m,
5760
n,
61+
m,
62+
1,
63+
0,
5864
reinterpret_cast<T const *>(src_),
5965
reinterpret_cast<T const *>(zp_),
6066
});
6167
}
6268

63-
template<class T>
64-
struct MatMulIntegerZPFunctorB {
65-
dim_t m, n;
66-
T const *src, *zp;
67-
68-
__device__ int8_t operator()(size_t idx) const noexcept {
69-
auto
70-
k = idx % n,
71-
// j = idx / n % m,
72-
i = idx / n / m;
73-
return sub(src[idx], zp[i * n + k]);
74-
}
75-
};
76-
7769
template<class T>
7870
static void applyZeroPointB(
7971
dim_t b, dim_t m, dim_t n,
8072
int8_t *dst, void const *src_, void const *zp_) {
8173

8274
thrust::tabulate(thrust::device,
8375
dst, dst + b * m * n,
84-
MatMulIntegerZPFunctorB<T>{
76+
MatMulIntegerZPFunctor<T>{
8577
m,
8678
n,
79+
n,
80+
0,
81+
1,
8782
reinterpret_cast<T const *>(src_),
8883
reinterpret_cast<T const *>(zp_),
8984
});

0 commit comments

Comments
 (0)