Skip to content

Commit da01409

Browse files
authored
Merge pull request #158 from changy1105/op
feat(paddlejs-backend-webgl): add elementwise_pow, elementwise_sub, t…
2 parents 7564935 + 322e188 commit da01409

File tree

14 files changed

+659
-63
lines changed

14 files changed

+659
-63
lines changed

packages/paddlejs-backend-webgl/src/ops/atom/common_func.ts

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -65,12 +65,17 @@ const sqrt = `
6565
}
6666
`;
6767

68-
const pow = `
69-
float pow(float x, float factor, float offset) {
68+
const pow_func = `
69+
float pow_func(float x, float factor, float offset) {
7070
return pow(x, factor);
7171
}
7272
`;
7373

74+
const tanh_func = `
75+
float tanh_func(float x, float y, float z) {
76+
return tanh(x);
77+
}`;
78+
7479
export {
7580
prelu,
7681
relu6,
@@ -80,7 +85,8 @@ export {
8085
hardSigmoid,
8186
scaleWidthBias,
8287
sqrt,
83-
pow,
88+
pow_func,
89+
tanh_func,
8490
transferFromNHWCtoNCHW
8591
};
8692

packages/paddlejs-backend-webgl/src/ops/index.ts

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,9 @@ import shuffle_channel from './shader/shuffle_channel';
4646
import hard_swish from './shader/hard_swish';
4747
import nearest_interp from './shader/nearest_interp';
4848
import nearest_interp_v2 from './shader/nearest_interp_v2';
49+
import elementwise_pow from './shader/elementwise_pow';
50+
import elementwise_sub from './shader/elementwise_sub';
51+
import cast from './shader/cast';
4952

5053
const ops = {
5154
arg_max,
@@ -62,6 +65,8 @@ const ops = {
6265
elementwise_add,
6366
elementwise_mul,
6467
elementwise_div,
68+
elementwise_pow,
69+
elementwise_sub,
6570
mul,
6671
matmul,
6772
fc,
@@ -86,6 +91,7 @@ const ops = {
8691
hard_swish,
8792
nearest_interp,
8893
nearest_interp_v2,
94+
cast,
8995
prelu: dynamic('prelu'),
9096
relu6: dynamic('relu6'),
9197
leakyRelu: dynamic('leakyRelu'),
@@ -95,6 +101,7 @@ const ops = {
95101
hard_sigmoid: dynamic('hard_sigmoid'),
96102
pow: dynamic('pow'),
97103
sqrt: dynamic('sqrt'),
104+
tanh: dynamic('tanh'),
98105
squeeze2,
99106
pad3d,
100107
bilinear_interp_v2,
Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
/**
2+
* @file cast 该OP将 x 的数据类型转换为 dtype 并输出
3+
*/
4+
5+
/**
6+
* data_type 值映射关系
7+
* BOOL = 0;
8+
* INT16 = 1;
9+
* INT32 = 2;
10+
* INT64 = 3;
11+
* FP16 = 4;
12+
* FP32 = 5;
13+
* FP64 = 6;
14+
*/
15+
16+
function mainFunc(
17+
{},
18+
{ out_dtype }
19+
) {
20+
21+
let middleStr = '';
22+
switch (out_dtype) {
23+
case 0:
24+
middleStr = `
25+
float res_bool = 0.0;
26+
if (o != 0.0) {
27+
res_bool = 1.0;
28+
}
29+
setOutput(res_bool);`;
30+
break;
31+
32+
case 1:
33+
case 2:
34+
case 3:
35+
middleStr = `
36+
int res_int = int(o);
37+
setOutput(float(res_int));`;
38+
break;
39+
40+
default:
41+
middleStr = `
42+
float res_float = o;
43+
setOutput(res_float);`;
44+
}
45+
return `
46+
void main() {
47+
// 输出数据
48+
ivec4 oPos = getOutputTensorPos();
49+
float o = getValueFromTensorPos_origin(oPos.r, oPos.g, oPos.b, oPos.a);
50+
${middleStr}
51+
}
52+
`;
53+
}
54+
export default {
55+
mainFunc,
56+
params: [
57+
'out_dtype'
58+
],
59+
textureFuncConf: {
60+
origin: ['getValueFromTensorPos']
61+
}
62+
};

packages/paddlejs-backend-webgl/src/ops/shader/dynamic.ts

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,9 @@ const commonFuncBehaviors = {
1010
scale: ['transToScale'],
1111
sigmoid: ['transToSigmoid'],
1212
hard_sigmoid: ['transToHardSigmoid'],
13-
pow: ['transToPow']
13+
pow: ['transToPow'],
14+
sqrt: ['transToSqrt'],
15+
tanh: ['transToTanh']
1416
};
1517

1618
function mainFunc(
Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
/**
2+
* @file elementwise_pow 逐元素对输入Tensor进行幂操作
3+
*/
4+
5+
function mainFunc(
6+
{},
7+
{
8+
counterPos,
9+
Scale_y = 1.0,
10+
Scale_x = 1.0,
11+
Scale_out = 1.0
12+
}
13+
) {
14+
return `
15+
void main(void) {
16+
// 输出数据
17+
ivec4 oPos = getOutputTensorPos();
18+
float o = getValueFromTensorPos_origin(oPos.r, oPos.g, oPos.b, oPos.a);
19+
20+
float c = getValueFromTensorPos_counter(${counterPos});
21+
float res = pow(float(${Scale_out / Scale_x}) * o, float(${Scale_out / Scale_y}) * c);
22+
setOutput(float(res));
23+
}
24+
`;
25+
}
26+
export default {
27+
mainFunc,
28+
params: [
29+
'Scale_y',
30+
'Scale_x',
31+
'Scale_out',
32+
'counterPos'
33+
],
34+
textureFuncConf: {
35+
counter: ['getValueFromTensorPos'],
36+
origin: ['getValueFromTensorPos']
37+
},
38+
behaviors: [
39+
'processAxis',
40+
'genElementwiseCounterPos'
41+
]
42+
};
Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
/**
2+
* @file elementwise_sub 逐元素相减算子
3+
*/
4+
5+
function mainFunc(
6+
{},
7+
{
8+
counterPos,
9+
Scale_y = 1.0,
10+
Scale_x = 1.0,
11+
Scale_out = 1.0
12+
}
13+
) {
14+
return `
15+
void main(void) {
16+
// 输出数据
17+
ivec4 oPos = getOutputTensorPos();
18+
float o = getValueFromTensorPos_origin(oPos.r, oPos.g, oPos.b, oPos.a);
19+
20+
float c = getValueFromTensorPos_counter(${counterPos});
21+
float res = float(${Scale_out / Scale_x}) * o - float(${Scale_out / Scale_y}) * c;
22+
setOutput(float(res));
23+
}
24+
`;
25+
}
26+
export default {
27+
mainFunc,
28+
params: [
29+
'Scale_y',
30+
'Scale_x',
31+
'Scale_out',
32+
'counterPos'
33+
],
34+
textureFuncConf: {
35+
counter: ['getValueFromTensorPos'],
36+
origin: ['getValueFromTensorPos']
37+
},
38+
behaviors: [
39+
'processAxis',
40+
'genElementwiseCounterPos'
41+
]
42+
};
Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
{
2+
"ops": [
3+
{
4+
"attrs": {
5+
"in_dtype": 3,
6+
"op_device": "",
7+
"out_dtype": 3
8+
},
9+
"inputs": {
10+
"X": [
11+
"numel_0.tmp_0"
12+
]
13+
},
14+
"outputs": {
15+
"Out": [
16+
"cast_0.tmp_0"
17+
]
18+
},
19+
"type": "cast"
20+
}
21+
],
22+
"vars": [
23+
{
24+
"data": [
25+
524288.0, 1, 0, -1, 0.9, 1.1
26+
],
27+
"name": "numel_0.tmp_0",
28+
"shape": [
29+
6
30+
]
31+
},
32+
{
33+
"name": "cast_0.tmp_0",
34+
"shape": [
35+
6
36+
]
37+
}
38+
]
39+
}
Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
{
2+
"ops": [
3+
{
4+
"attrs": {
5+
"axis": 1,
6+
"Scale_y": 1,
7+
"use_mkldnn": false,
8+
"x_data_format": "",
9+
"y_data_format": ""
10+
},
11+
"inputs": {
12+
"X": [
13+
"fc_0.tmp_0"
14+
],
15+
"Y": [
16+
"fc7_offset"
17+
]
18+
},
19+
"outputs": {
20+
"Out": [
21+
"fc_0.tmp_1"
22+
]
23+
},
24+
"type": "elementwise_pow"
25+
}
26+
],
27+
"vars": [
28+
{
29+
"data": [
30+
2,
31+
3,
32+
4
33+
],
34+
"name": "fc_0.tmp_0",
35+
"shape": [
36+
3
37+
]
38+
},
39+
{
40+
"data": [
41+
1,
42+
5,
43+
2
44+
],
45+
"name": "fc7_offset",
46+
"shape": [
47+
3
48+
]
49+
},
50+
{
51+
"data": [
52+
2,
53+
243,
54+
16
55+
],
56+
"name": "fc_0.tmp_1",
57+
"shape": [
58+
3
59+
]
60+
}
61+
]
62+
}

0 commit comments

Comments
 (0)