Skip to content

Commit 535134b

Browse files
authored
Merge pull request #59 from InfiniTensor/dev
Dev
2 parents d5e1b02 + d0c4692 commit 535134b

File tree

81 files changed

+2226
-442
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

81 files changed

+2226
-442
lines changed

scripts/compare/compare.py

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ def parse_args():
2323
args.actual,
2424
)
2525

26+
2627
def getDiff(base, test):
2728
absolute_diff = np.subtract(base, test)
2829
max_absolute_diff = np.max(np.abs(absolute_diff))
@@ -35,16 +36,19 @@ def getDiff(base, test):
3536

3637
return max_absolute_diff, max_relative_diff
3738

38-
def compare_npy(actual_path, expect_path, edge, node):
39+
40+
def compare_npy(node, actual_path, expect_path):
3941
actual = np.load(actual_path)
4042
expect = np.load(expect_path)
4143
if np.isnan(actual).any():
42-
print(f"NAN value in node:{node} edge:{edge}")
44+
print(f"NAN value in node:{node}\t{actual_path}\t{expect_path}")
4345
return
44-
46+
4547
max_absolute_diff, max_relative_diff = getDiff(expect, actual)
46-
if max_absolute_diff != 0.0: ## No need to print tensor with no diff
47-
print(f'{max_absolute_diff}\t{max_relative_diff}\t{node}\t{edge}')
48+
if max_absolute_diff != 0.0: ## No need to print tensor with no diff
49+
print(
50+
f"{max_absolute_diff}\t{max_relative_diff}\t{node}\t{actual_path}\t{expect_path}"
51+
)
4852

4953

5054
def main():
@@ -70,9 +74,7 @@ def main():
7074
expect_file = expect_file + ".npy"
7175
expect_file_path = os.path.join(expect_dir, expect_file)
7276
if os.path.exists(expect_file_path):
73-
compare_npy(
74-
actual_file_path, expect_file_path, edge_name, node_name
75-
)
77+
compare_npy(meta_file, actual_file_path, expect_file_path)
7678

7779

7880
if __name__ == "__main__":

src/04kernel/include/kernel/attributes/broadcaster.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ namespace refactor::kernel {
1515
explicit Broadcaster(std::vector<slice_t<dim_t>>);
1616
explicit Broadcaster(TensorRefs const &inputs);
1717
void locate(dim_t k, dim_t ans[]) const noexcept;
18+
bool needBroadcast() const noexcept;
1819
};
1920

2021
}// namespace refactor::kernel
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,20 @@
1-
#ifndef KERNEL_MATMUL_INFO_H
2-
#define KERNEL_MATMUL_INFO_H
1+
#ifndef KERNEL_MAT_MUL_INFO_H
2+
#define KERNEL_MAT_MUL_INFO_H
33

44
#include "kernel/attributes/broadcaster.h"
55
#include "kernel/attributes/expand_info.h"
6-
#include <variant>
76

87
namespace refactor::kernel {
98

109
struct MatMulInfo {
1110
DataType dataType;
1211
float alpha, beta;
1312
bool transA, transB;
14-
size_t m, k, n;
13+
dim_t m, k, n;
1514
// Expand operation info for bias
1615
std::optional<ExpandInfo> biasExpand;
17-
// A constant batch or a 2-directional broadcaster that deals with dimensions before the last 2 dimensions
18-
std::variant<Broadcaster, size_t> broadcasterOrBatch;
16+
// A 2-directional broadcaster that deals with dimensions before the last 2 dimensions
17+
Broadcaster broadcaster;
1918

2019
MatMulInfo(Tensor const &, Tensor const &,
2120
std::optional<std::reference_wrapper<Tensor const>>,
@@ -24,4 +23,4 @@ namespace refactor::kernel {
2423

2524
}// namespace refactor::kernel
2625

27-
#endif// KERNEL_MATMUL_INFO_H
26+
#endif// KERNEL_MAT_MUL_INFO_H
Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
#ifndef KERNEL_MAT_MUL_INTEGER_INFO_H
2+
#define KERNEL_MAT_MUL_INTEGER_INFO_H
3+
4+
#include "kernel/attributes/broadcaster.h"
5+
6+
namespace refactor::kernel {
7+
8+
struct MatMulIntegerInfo {
9+
struct Input {
10+
bool
11+
withZeroPoint,
12+
signed_,
13+
scalar;
14+
15+
Input(TensorRefs const &, size_t i) noexcept;
16+
};
17+
18+
Input a, b;
19+
dim_t m, k, n;
20+
Broadcaster broadcaster;
21+
22+
explicit MatMulIntegerInfo(TensorRefs const &inputs) noexcept;
23+
dim_t batch() const noexcept;
24+
};
25+
26+
}// namespace refactor::kernel
27+
28+
#endif// KERNEL_MAT_MUL_INTEGER_INFO_H
Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
#ifndef KERNEL_DEQUANTIZE_LINEAR_H
2+
#define KERNEL_DEQUANTIZE_LINEAR_H
3+
4+
#include "../collector.h"
5+
6+
namespace refactor::kernel {
7+
8+
struct DequantizeLinearCollector final : public InfoCollector {
9+
10+
explicit DequantizeLinearCollector(decltype(_target)) noexcept;
11+
12+
std::vector<KernelBox>
13+
filter(TensorRefs inputs, TensorRefs outputs) const final;
14+
};
15+
16+
}// namespace refactor::kernel
17+
18+
#endif// KERNEL_DEQUANTIZE_LINEAR_H
Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
#ifndef KERNEL_DYNAMIC_QUANTIZE_LINEAR_H
2+
#define KERNEL_DYNAMIC_QUANTIZE_LINEAR_H
3+
4+
#include "../collector.h"
5+
6+
namespace refactor::kernel {
7+
8+
struct DynamicQuantizeLinearCollector final : public InfoCollector {
9+
10+
explicit DynamicQuantizeLinearCollector(decltype(_target)) noexcept;
11+
12+
std::vector<KernelBox>
13+
filter(TensorRefs inputs, TensorRefs outputs) const final;
14+
};
15+
16+
}// namespace refactor::kernel
17+
18+
#endif// KERNEL_DYNAMIC_QUANTIZE_LINEAR_H
Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
#ifndef KERNEL_MAT_MUL_INTEGER_H
2+
#define KERNEL_MAT_MUL_INTEGER_H
3+
4+
#include "../collector.h"
5+
6+
namespace refactor::kernel {
7+
8+
struct MatMulIntegerCollector final : public InfoCollector {
9+
10+
constexpr MatMulIntegerCollector(decltype(_target) target) noexcept
11+
: InfoCollector(target) {}
12+
13+
std::vector<KernelBox>
14+
filter(TensorRefs inputs, TensorRefs outputs) const final;
15+
};
16+
17+
}// namespace refactor::kernel
18+
19+
#endif// KERNEL_MAT_MUL_INTEGER_H

src/04kernel/src/attributes/broadcaster.cc

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -96,4 +96,8 @@ namespace refactor::kernel {
9696
}
9797
}
9898

99+
bool Broadcaster::needBroadcast() const noexcept {
100+
return !strides.empty();
101+
}
102+
99103
}// namespace refactor::kernel

src/04kernel/src/attributes/matmul_info.cc renamed to src/04kernel/src/attributes/mat_mul_info.cc

Lines changed: 6 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,17 @@
1-
#include "kernel/attributes/matmul_info.h"
2-
#include <cstddef>
3-
#include <numeric>
1+
#include "kernel/attributes/mat_mul_info.h"
42

53
namespace refactor::kernel {
64

7-
ExpandInfo buildBias(size_t m, size_t n,
5+
ExpandInfo buildBias(dim_t m, dim_t n,
86
Tensor const &a,
97
Tensor const &b,
108
Tensor const &c) {
119
std::vector<dim_t> output(std::max(a.rank(), b.rank()));
1210
auto it = output.rbegin();
1311
*it++ = n;
1412
*it++ = m;
15-
for (auto da = static_cast<size_t>(a.rank() - 2),
16-
db = static_cast<size_t>(b.rank() - 2);
13+
for (auto da = static_cast<dim_t>(a.rank() - 2),
14+
db = static_cast<dim_t>(b.rank() - 2);
1715
auto i : range0_(output.size() - 2)) {
1816
auto a_ = i < da ? a.shape[da - i - 1] : 1;
1917
auto b_ = i < db ? b.shape[db - i - 1] : 1;
@@ -26,13 +24,6 @@ namespace refactor::kernel {
2624
slice(output.data(), output.size()));
2725
}
2826

29-
std::variant<Broadcaster, size_t> buildBroadcasterOrBatch(slice_t<dim_t> dimA, slice_t<dim_t> dimB) {
30-
if (std::equal(dimA.begin(), dimA.end(), dimB.begin(), dimB.end())) {
31-
return std::accumulate(dimA.begin(), dimA.end(), (size_t) 1, std::multiplies<size_t>());
32-
}
33-
return Broadcaster({dimA, dimB});
34-
}
35-
3627
MatMulInfo::MatMulInfo(
3728
Tensor const &a, Tensor const &b,
3829
std::optional<std::reference_wrapper<Tensor const>> c,
@@ -44,7 +35,8 @@ namespace refactor::kernel {
4435
k(transA ? a.shape.rbegin()[1] : a.shape.rbegin()[0]),
4536
n(transB ? b.shape.rbegin()[1] : b.shape.rbegin()[0]),
4637
biasExpand(c ? std::make_optional(buildBias(m, n, a, b, *c)) : std::nullopt),
47-
broadcasterOrBatch(buildBroadcasterOrBatch(slice(a.shape.data(), a.shape.size() - 2), slice(b.shape.data(), b.shape.size() - 2))) {
38+
broadcaster({slice(a.shape.data(), a.shape.size() - 2),
39+
slice(b.shape.data(), b.shape.size() - 2)}) {
4840
auto kB = transB ? b.shape.rbegin()[0] : b.shape.rbegin()[1];
4941
ASSERT(k == kB, "MatMul: input shape not matched.");
5042
}
Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
#include "kernel/attributes/mat_mul_integer_info.h"
2+
3+
namespace refactor::kernel {
4+
5+
MatMulIntegerInfo::Input::Input(TensorRefs const &inputs, size_t i) noexcept
6+
: withZeroPoint(false),
7+
signed_(true),
8+
scalar(true) {
9+
if (inputs.size() > i + 2) {
10+
auto const &t = inputs[i + 2].get();
11+
auto size = t.elementsSize();
12+
if (t.data) {
13+
auto data = slice(t.data->get<uint8_t>(), size);
14+
if (std::all_of(data.begin(), data.end(), [](auto x) { return x == 0; })) {
15+
return;
16+
}
17+
}
18+
withZeroPoint = true;
19+
signed_ = t.dataType == DataType::I8;
20+
scalar = size == 1;
21+
}
22+
}
23+
24+
MatMulIntegerInfo::MatMulIntegerInfo(TensorRefs const &inputs) noexcept
25+
: a(inputs, 0),
26+
b(inputs, 1),
27+
#define A (inputs[0].get().shape)
28+
#define B (inputs[1].get().shape)
29+
m(A.rbegin()[1]),
30+
k(A.rbegin()[0]),
31+
n(B.rbegin()[0]),
32+
broadcaster({slice(A.data(), A.size() - 2),
33+
slice(B.data(), B.size() - 2)}) {
34+
}
35+
#undef A
36+
#undef B
37+
38+
dim_t MatMulIntegerInfo::batch() const noexcept {
39+
return broadcaster.outputsCount;
40+
}
41+
42+
}// namespace refactor::kernel

0 commit comments

Comments
 (0)