Skip to content

Commit 65e2ee7

Browse files
committed
feat(webgl): update pool2d
1 parent 8f74fa6 commit 65e2ee7

File tree

4 files changed

+146
-4
lines changed

4 files changed

+146
-4
lines changed

packages/paddlejs-backend-webgl/src/ops/shader/pool2d_max.ts

Lines changed: 22 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,17 +2,24 @@
22
* @file pool2d_max
33
*/
44

5+
function recoverShape({ total_shape, channel, height_shape, width_shape }) {
6+
return [total_shape / channel / height_shape / width_shape, channel, height_shape, width_shape];
7+
}
8+
59
function mainFunc(
610
{ origin },
7-
{ strides = [], paddings = [], ksize }
11+
{ strides = [], paddings = [], ksize, global_pooling }
812
) {
913
const [stride_v = 1, stride_h = 1] = strides;
1014
const [padTop = 0, padLeft = 0] = paddings;
1115
const [ksize_x, ksize_y] = ksize;
16+
const originShape = recoverShape(origin);
17+
1218
return `
1319
// start函数
1420
void main(void) {
1521
float res = -1. / 0.;
22+
int index = 0;
1623
// 获取output的坐标
1724
ivec4 out_pos = getOutputTensorPos();
1825
int b = out_pos[0];
@@ -40,10 +47,20 @@ function mainFunc(
4047
}
4148
// origin数据
4249
float curr = getValueFromTensorPos_origin(out_pos[0], out_pos[1], oy, ox);
50+
if (layer_run_time == 0 && ${global_pooling === true}) {
51+
if (curr > res) {
52+
index = ${originShape[2] * originShape[3]} * out_pos[1] + ${originShape[3]} * oy + ox;
53+
}
54+
}
4355
res = max(res, curr);
4456
}
4557
}
46-
setOutput(res);
58+
if (layer_run_time == 0 && ${global_pooling === true}) {
59+
setOutput(float(index));
60+
}
61+
else {
62+
setOutput(float(res));
63+
}
4764
}
4865
`;
4966
}
@@ -52,14 +69,16 @@ export default {
5269
params: [
5370
'strides',
5471
'paddings',
55-
'ksize'
72+
'ksize',
73+
'global_pooling'
5674
],
5775
textureFuncConf: {
5876
origin: ['getValueFromTensorPos']
5977
},
6078
behaviors: [
6179
'isMax',
6280
'setPacked',
81+
'setAdaptive',
6382
'isGlobalPooling'
6483
]
6584
};
Lines changed: 106 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,106 @@
1+
{
2+
"chunkNum": 0,
3+
"ops": [
4+
{
5+
"attrs": {
6+
"op_device": ""
7+
},
8+
"inputs": {
9+
"X": [
10+
"feed"
11+
]
12+
},
13+
"outputs": {
14+
"Out": [
15+
"tmp_43"
16+
]
17+
},
18+
"type": "feed"
19+
},
20+
{
21+
"attrs": {
22+
"adaptive": true,
23+
"global_pooling": false,
24+
"ksize": [
25+
1,
26+
1
27+
],
28+
"op_device": "",
29+
"paddings": [
30+
0,
31+
0
32+
],
33+
"strides": [
34+
1,
35+
1
36+
]
37+
},
38+
"inputs": {
39+
"X": [
40+
"tmp_43"
41+
]
42+
},
43+
"outputs": {
44+
"Mask": [
45+
"max_pool2d_with_index_0.tmp_1"
46+
],
47+
"Out": [
48+
"max_pool2d_with_index_0.tmp_0"
49+
]
50+
},
51+
"type": "max_pool2d_with_index"
52+
},
53+
{
54+
"attrs": {
55+
"op_device": ""
56+
},
57+
"inputs": {
58+
"X": [
59+
"max_pool2d_with_index_0.tmp_0"
60+
]
61+
},
62+
"outputs": {
63+
"Out": [
64+
"fetch"
65+
]
66+
},
67+
"type": "fetch"
68+
}
69+
],
70+
"vars": [
71+
{
72+
"name": "tmp_43",
73+
"shape": [
74+
1,
75+
3,
76+
3,
77+
3
78+
],
79+
"data": [
80+
1, 2, 3, 1, 1, 5, 1, 7, 1,
81+
9, 10, 1, 1, 1, 8, 1, 1, 1,
82+
1, 7, 1, 9, 14, 1, 1, 1, 8
83+
]
84+
},
85+
{
86+
"name": "max_pool2d_with_index_0.tmp_0",
87+
"shape": [
88+
1,
89+
3,
90+
1,
91+
1
92+
],
93+
"data": [7, 10, 14]
94+
},
95+
{
96+
"name": "max_pool2d_with_index_0.tmp_1",
97+
"shape": [
98+
1,
99+
3,
100+
1,
101+
1
102+
],
103+
"data": [7, 10, 22]
104+
}
105+
]
106+
}

packages/paddlejs-core/src/opFactory/opBehaviors.ts

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,18 @@ const behaviors : Behaviors = {
2020
}
2121
},
2222

23+
setAdaptive() {
24+
if (
25+
this.attrs.adaptive
26+
&& this.attrs.ksize.length === 2
27+
&& this.attrs.ksize[0] === 1
28+
&& this.attrs.ksize[1] === 1
29+
) {
30+
this.attrs.adaptive = false;
31+
this.attrs.global_pooling = true;
32+
}
33+
},
34+
2335
isGlobalPooling() {
2436
const counter = this.input.X[0] || {};
2537
const length = (counter.shape && counter.shape.length) || 0;

packages/paddlejs-core/src/opFactory/opDataBuilder.ts

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -127,7 +127,8 @@ export default class OpData {
127127
scale: 'scale',
128128
bias: 'bias',
129129
mean: 'mean',
130-
variance: 'variance'
130+
variance: 'variance',
131+
mask: 'out'
131132
};
132133

133134

@@ -170,6 +171,10 @@ export default class OpData {
170171
this.name = 'reshape2';
171172
}
172173

174+
if (this.name.indexOf('max_pool2d_with_index') > -1) {
175+
this.name = 'pool2d_max';
176+
}
177+
173178
const tensorData: ModelVar[] = this.tensorData;
174179
// unique behavior
175180
const opKey = `${GLOBALS.backend}_${this.name}`;

0 commit comments

Comments
 (0)