Skip to content

Commit ef6200a

Browse files
committed
feat(paddlejs-backend-webgl): add hard_swish、nearest_interp op
1 parent 8c38a04 commit ef6200a

File tree

6 files changed

+152
-0
lines changed

6 files changed

+152
-0
lines changed

packages/paddlejs-backend-webgl/src/ops/index.ts

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,9 @@ import connect from './shader/connect';
4242
import squeeze2 from './shader/squeeze2';
4343
import pad3d from './shader/pad3d';
4444
import reduce_mean from './shader/reduce_mean';
45+
import hard_swish from './shader/hard_swish';
46+
import nearest_interp from './shader/nearest_interp';
47+
import nearest_interp_v2 from './shader/nearest_interp_v2';
4548

4649
const ops = {
4750
arg_max,
@@ -79,6 +82,9 @@ const ops = {
7982
where,
8083
connect,
8184
reduce_mean,
85+
hard_swish,
86+
nearest_interp,
87+
nearest_interp_v2,
8288
prelu: dynamic('prelu'),
8389
relu6: dynamic('relu6'),
8490
leakyRelu: dynamic('leakyRelu'),
Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
/**
2+
* @file hard_swish
3+
*/
4+
5+
function mainFunc(
6+
{},
7+
{
8+
offset = 3.0,
9+
scale = 6.0,
10+
threshold = 6.0
11+
}
12+
) {
13+
return `
14+
void main(void) {
15+
// 输出数据
16+
ivec4 oPos = getOutputTensorPos();
17+
float o = getValueFromTensorPos_origin(oPos.r, oPos.g, oPos.b, oPos.a);
18+
float res = o * min(max(0.0, o + float(${offset})), float(${threshold})) / float(${scale});
19+
setOutput(res);
20+
}
21+
`;
22+
}
23+
export default {
24+
mainFunc,
25+
params: [
26+
'offset',
27+
'scale',
28+
'threshold'
29+
],
30+
textureFuncConf: {
31+
origin: ['getValueFromTensorPos']
32+
}
33+
};
Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
/**
2+
* @file nearest_interp
3+
*/
4+
5+
function mainFunc(
6+
{ origin, out },
7+
{ align_corners }
8+
) {
9+
10+
return `
11+
// start函数
12+
int getData(float n, float scale, bool align_corners) {
13+
float m = align_corners ? (n / scale + 0.5) : (n / scale);
14+
return int(floor(m));
15+
}
16+
17+
void main(void) {
18+
// 输出数据
19+
ivec4 oPos = getOutputTensorPos();
20+
21+
float scale_x = 0.0;
22+
float scale_y = 0.0;
23+
if (${align_corners}) {
24+
scale_x = float(${out.width_shape} -1) / float(${origin.width_shape} - 1);
25+
scale_y = float(${out.height_shape} - 1) / float(${origin.height_shape} - 1);
26+
}
27+
else {
28+
scale_x = float(${out.width_shape}) / float(${origin.width_shape});
29+
scale_y = float(${out.height_shape}) / float(${origin.height_shape});
30+
}
31+
32+
int vx = getData(float(oPos.a), scale_x, ${align_corners});
33+
int vy = getData(float(oPos.b), scale_y, ${align_corners});
34+
35+
float o = getValueFromTensorPos_origin(oPos.r, oPos.g, vy, vx);
36+
setOutput(float(o));
37+
}
38+
`;
39+
}
40+
41+
export default {
42+
mainFunc,
43+
params: [
44+
'align_corners'
45+
],
46+
textureFuncConf: {
47+
origin: ['getValueFromTensorPos']
48+
},
49+
commonFuncConf: ['transferFromNHWCtoNCHW']
50+
};
Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
/**
2+
* @file nearest_interp_v2
3+
*/
4+
5+
import nearest_interp from './nearest_interp';
6+
7+
export default nearest_interp;
Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
{
2+
"ops": [
3+
{
4+
"attrs": {
5+
"offset": 3,
6+
"scale": 6,
7+
"threshold": 6
8+
},
9+
"inputs": {
10+
"X": ["hard_swish.tmp_0"]
11+
},
12+
"outputs": {
13+
"Out": ["hard_swish.tmp_1"]
14+
},
15+
"type": "hard_swish"
16+
}
17+
],
18+
"vars": [
19+
{
20+
"data": [1, 2, 3, 4],
21+
"name": "hard_swish.tmp_0",
22+
"shape": [1, 4]
23+
},
24+
{
25+
"name": "hard_swish.tmp_1",
26+
"shape": [1, 4]
27+
}
28+
]
29+
}
Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
{
2+
"ops": [
3+
{
4+
"attrs": {
5+
"align_corners": false
6+
},
7+
"inputs": {
8+
"X": ["nearest_interp.tmp_0"]
9+
},
10+
"outputs": {
11+
"Out": ["nearest_interp.tmp_1"]
12+
},
13+
"type": "nearest_interp"
14+
}
15+
],
16+
"vars": [
17+
{
18+
"data": [2, 3, 6, 10],
19+
"name": "nearest_interp.tmp_0",
20+
"shape": [1, 4]
21+
},
22+
{
23+
"name": "nearest_interp.tmp_1",
24+
"shape": [1, 4]
25+
}
26+
]
27+
}

0 commit comments

Comments
 (0)