Skip to content

Commit 88c5098

Browse files
authored
Merge pull request #141 from BenAnn/master
feat(webgl): update pool2d
2 parents e994cde + 906a48f commit 88c5098

File tree

6 files changed

+156
-7
lines changed

6 files changed

+156
-7
lines changed

packages/paddlejs-backend-webgl/src/ops/shader/pool2d_max.ts

Lines changed: 24 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,17 +2,34 @@
22
* @file pool2d_max
33
*/
44

5+
function recoverShape({ total_shape, channel, height_shape, width_shape }) {
6+
return [total_shape / channel / height_shape / width_shape, channel, height_shape, width_shape];
7+
}
8+
59
function mainFunc(
610
{ origin },
7-
{ strides = [], paddings = [], ksize }
11+
{ strides = [], paddings = [], ksize, global_pooling, runtime }
812
) {
913
const [stride_v = 1, stride_h = 1] = strides;
1014
const [padTop = 0, padLeft = 0] = paddings;
1115
const [ksize_x, ksize_y] = ksize;
16+
const originShape = recoverShape(origin);
17+
let computedIndex = '';
18+
let outputCode = 'setOutput(float(res));';
19+
if (runtime === 0 && global_pooling === true) {
20+
computedIndex = `
21+
if (curr > res) {
22+
index = ${originShape[2] * originShape[3]} * out_pos[1] + ${originShape[3]} * oy + ox;
23+
}
24+
`;
25+
outputCode = 'setOutput(float(index));';
26+
}
27+
1228
return `
1329
// start函数
1430
void main(void) {
1531
float res = -1. / 0.;
32+
int index = 0;
1633
// 获取output的坐标
1734
ivec4 out_pos = getOutputTensorPos();
1835
int b = out_pos[0];
@@ -40,10 +57,11 @@ function mainFunc(
4057
}
4158
// origin数据
4259
float curr = getValueFromTensorPos_origin(out_pos[0], out_pos[1], oy, ox);
60+
${computedIndex}
4361
res = max(res, curr);
4462
}
4563
}
46-
setOutput(res);
64+
${outputCode}
4765
}
4866
`;
4967
}
@@ -52,14 +70,17 @@ export default {
5270
params: [
5371
'strides',
5472
'paddings',
55-
'ksize'
73+
'ksize',
74+
'global_pooling',
75+
'runtime'
5676
],
5777
textureFuncConf: {
5878
origin: ['getValueFromTensorPos']
5979
},
6080
behaviors: [
6181
'isMax',
6282
'setPacked',
83+
'setAdaptive',
6384
'isGlobalPooling'
6485
]
6586
};

packages/paddlejs-backend-webgl/src/ops/utils.ts

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ const inputParams = [
2121
'limit',
2222
'channel',
2323
'total_shape',
24-
'numbers_shape',
24+
'numbers_shape'
2525
];
2626

2727
const outParams = [
@@ -44,7 +44,7 @@ const baseParams = {
4444
]
4545
};
4646

47-
function getTensorParams(inputTensors: Tensor[], ownParams: [], fShaderParams: object): opInfo {
47+
function getTensorParams(inputTensors: Tensor[], ownParams: [], fShaderParams: object, runtime: number): opInfo {
4848
const tensorsParams = {};
4949
const opParams = {};
5050
const tensorNames = [] as string[];
@@ -97,6 +97,9 @@ function getTensorParams(inputTensors: Tensor[], ownParams: [], fShaderParams: o
9797
if (fShaderParams['active_function']) {
9898
opParams['active_function'] = fShaderParams['active_function'];
9999
}
100+
101+
opParams['runtime'] = runtime;
102+
100103
return { textureParams: tensorsParams, opParams, active_function: fShaderParams['active_function'] };
101104
}
102105

packages/paddlejs-backend-webgl/src/webgl/buildShader.ts

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,9 @@ export default function buildShader(textureConf, type, inputTensors, fShaderPara
2020
const { params = {}, mainFunc, textureFuncConf = {}, commonFuncConf } = ops[opName];
2121

2222
// textureList: [filter, origin, bias]
23-
const { textureParams, opParams, active_function } = getTensorParams(inputTensors, params, fShaderParams);
23+
const { textureParams, opParams, active_function } = getTensorParams(
24+
inputTensors, params, fShaderParams, runtime
25+
);
2426

2527
const prefixCode = genPrefixCode(textureConf);
2628

Lines changed: 106 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,106 @@
1+
{
2+
"chunkNum": 0,
3+
"ops": [
4+
{
5+
"attrs": {
6+
"op_device": ""
7+
},
8+
"inputs": {
9+
"X": [
10+
"feed"
11+
]
12+
},
13+
"outputs": {
14+
"Out": [
15+
"tmp_43"
16+
]
17+
},
18+
"type": "feed"
19+
},
20+
{
21+
"attrs": {
22+
"adaptive": true,
23+
"global_pooling": false,
24+
"ksize": [
25+
1,
26+
1
27+
],
28+
"op_device": "",
29+
"paddings": [
30+
0,
31+
0
32+
],
33+
"strides": [
34+
1,
35+
1
36+
]
37+
},
38+
"inputs": {
39+
"X": [
40+
"tmp_43"
41+
]
42+
},
43+
"outputs": {
44+
"Mask": [
45+
"max_pool2d_with_index_0.tmp_1"
46+
],
47+
"Out": [
48+
"max_pool2d_with_index_0.tmp_0"
49+
]
50+
},
51+
"type": "max_pool2d_with_index"
52+
},
53+
{
54+
"attrs": {
55+
"op_device": ""
56+
},
57+
"inputs": {
58+
"X": [
59+
"max_pool2d_with_index_0.tmp_0"
60+
]
61+
},
62+
"outputs": {
63+
"Out": [
64+
"fetch"
65+
]
66+
},
67+
"type": "fetch"
68+
}
69+
],
70+
"vars": [
71+
{
72+
"name": "tmp_43",
73+
"shape": [
74+
1,
75+
3,
76+
3,
77+
3
78+
],
79+
"data": [
80+
1, 2, 3, 1, 1, 5, 1, 7, 1,
81+
9, 10, 1, 1, 1, 8, 1, 1, 1,
82+
1, 7, 1, 9, 14, 1, 1, 1, 8
83+
]
84+
},
85+
{
86+
"name": "max_pool2d_with_index_0.tmp_0",
87+
"shape": [
88+
1,
89+
3,
90+
1,
91+
1
92+
],
93+
"data": [7, 10, 14]
94+
},
95+
{
96+
"name": "max_pool2d_with_index_0.tmp_1",
97+
"shape": [
98+
1,
99+
3,
100+
1,
101+
1
102+
],
103+
"data": [7, 10, 22]
104+
}
105+
]
106+
}

packages/paddlejs-core/src/opFactory/opBehaviors.ts

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,18 @@ const behaviors : Behaviors = {
2020
}
2121
},
2222

23+
setAdaptive() {
24+
if (
25+
this.attrs.adaptive
26+
&& this.attrs.ksize.length === 2
27+
&& this.attrs.ksize[0] === 1
28+
&& this.attrs.ksize[1] === 1
29+
) {
30+
this.attrs.adaptive = false;
31+
this.attrs.global_pooling = true;
32+
}
33+
},
34+
2335
isGlobalPooling() {
2436
const counter = this.input.X[0] || {};
2537
const length = (counter.shape && counter.shape.length) || 0;

packages/paddlejs-core/src/opFactory/opDataBuilder.ts

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -127,7 +127,8 @@ export default class OpData {
127127
scale: 'scale',
128128
bias: 'bias',
129129
mean: 'mean',
130-
variance: 'variance'
130+
variance: 'variance',
131+
mask: 'out'
131132
};
132133

133134

@@ -170,6 +171,10 @@ export default class OpData {
170171
this.name = 'reshape2';
171172
}
172173

174+
if (this.name.indexOf('max_pool2d_with_index') > -1) {
175+
this.name = 'pool2d_max';
176+
}
177+
173178
const tensorData: ModelVar[] = this.tensorData;
174179
// unique behavior
175180
const opKey = `${GLOBALS.backend}_${this.name}`;

0 commit comments

Comments
 (0)