Skip to content

Commit 906a48f

Browse files
committed
feat(webgl): update pool2d
1 parent 65e2ee7 commit 906a48f

File tree

3 files changed

+23
-16
lines changed

3 files changed

+23
-16
lines changed

packages/paddlejs-backend-webgl/src/ops/shader/pool2d_max.ts

Lines changed: 15 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -8,12 +8,22 @@ function recoverShape({ total_shape, channel, height_shape, width_shape }) {
88

99
function mainFunc(
1010
{ origin },
11-
{ strides = [], paddings = [], ksize, global_pooling }
11+
{ strides = [], paddings = [], ksize, global_pooling, runtime }
1212
) {
1313
const [stride_v = 1, stride_h = 1] = strides;
1414
const [padTop = 0, padLeft = 0] = paddings;
1515
const [ksize_x, ksize_y] = ksize;
1616
const originShape = recoverShape(origin);
17+
let computedIndex = '';
18+
let outputCode = 'setOutput(float(res));';
19+
if (runtime === 0 && global_pooling === true) {
20+
computedIndex = `
21+
if (curr > res) {
22+
index = ${originShape[2] * originShape[3]} * out_pos[1] + ${originShape[3]} * oy + ox;
23+
}
24+
`;
25+
outputCode = 'setOutput(float(index));';
26+
}
1727

1828
return `
1929
// start函数
@@ -47,20 +57,11 @@ function mainFunc(
4757
}
4858
// origin数据
4959
float curr = getValueFromTensorPos_origin(out_pos[0], out_pos[1], oy, ox);
50-
if (layer_run_time == 0 && ${global_pooling === true}) {
51-
if (curr > res) {
52-
index = ${originShape[2] * originShape[3]} * out_pos[1] + ${originShape[3]} * oy + ox;
53-
}
54-
}
60+
${computedIndex}
5561
res = max(res, curr);
5662
}
5763
}
58-
if (layer_run_time == 0 && ${global_pooling === true}) {
59-
setOutput(float(index));
60-
}
61-
else {
62-
setOutput(float(res));
63-
}
64+
${outputCode}
6465
}
6566
`;
6667
}
@@ -70,7 +71,8 @@ export default {
7071
'strides',
7172
'paddings',
7273
'ksize',
73-
'global_pooling'
74+
'global_pooling',
75+
'runtime'
7476
],
7577
textureFuncConf: {
7678
origin: ['getValueFromTensorPos']

packages/paddlejs-backend-webgl/src/ops/utils.ts

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ const inputParams = [
2121
'limit',
2222
'channel',
2323
'total_shape',
24-
'numbers_shape',
24+
'numbers_shape'
2525
];
2626

2727
const outParams = [
@@ -44,7 +44,7 @@ const baseParams = {
4444
]
4545
};
4646

47-
function getTensorParams(inputTensors: Tensor[], ownParams: [], fShaderParams: object): opInfo {
47+
function getTensorParams(inputTensors: Tensor[], ownParams: [], fShaderParams: object, runtime: number): opInfo {
4848
const tensorsParams = {};
4949
const opParams = {};
5050
const tensorNames = [] as string[];
@@ -97,6 +97,9 @@ function getTensorParams(inputTensors: Tensor[], ownParams: [], fShaderParams: o
9797
if (fShaderParams['active_function']) {
9898
opParams['active_function'] = fShaderParams['active_function'];
9999
}
100+
101+
opParams['runtime'] = runtime;
102+
100103
return { textureParams: tensorsParams, opParams, active_function: fShaderParams['active_function'] };
101104
}
102105

packages/paddlejs-backend-webgl/src/webgl/buildShader.ts

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,9 @@ export default function buildShader(textureConf, type, inputTensors, fShaderPara
2020
const { params = {}, mainFunc, textureFuncConf = {}, commonFuncConf } = ops[opName];
2121

2222
// textureList: [filter, origin, bias]
23-
const { textureParams, opParams, active_function } = getTensorParams(inputTensors, params, fShaderParams);
23+
const { textureParams, opParams, active_function } = getTensorParams(
24+
inputTensors, params, fShaderParams, runtime
25+
);
2426

2527
const prefixCode = genPrefixCode(textureConf);
2628

0 commit comments

Comments
 (0)