Skip to content

Commit 7fa7a28

Browse files
authored
Merge pull request #132 from changy1105/reduce_mean
feat(paddlejs-backend-webgl): add reduce_mean op
2 parents 901ce8a + a9b98ca commit 7fa7a28

File tree

3 files changed

+99
-0
lines changed

3 files changed

+99
-0
lines changed

packages/paddlejs-backend-webgl/src/ops/index.ts

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@ import where from './shader/where';
4040
import connect from './shader/connect';
4141
import squeeze2 from './shader/squeeze2';
4242
import pad3d from './shader/pad3d';
43+
import reduce_mean from './shader/reduce_mean';
4344

4445
const ops = {
4546
arg_max,
@@ -76,6 +77,7 @@ const ops = {
7677
reduce_sum,
7778
where,
7879
connect,
80+
reduce_mean,
7981
prelu: dynamic('prelu'),
8082
relu6: dynamic('relu6'),
8183
leakyRelu: dynamic('leakyRelu'),
Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
2+
/**
3+
* @file reduce_mean
4+
*/
5+
6+
function mainFunc(
7+
{},
8+
{ inputs_dim, dim }
9+
) {
10+
return `
11+
// start函数
12+
void main(void) {
13+
ivec4 oPos = getOutputTensorPos();
14+
// 输出坐标转换为输入坐标
15+
float o = 0.0;
16+
for (int i = 0; i < ${inputs_dim}; i++) {
17+
oPos[${dim}] = i;
18+
o += getValueFromTensorPos_origin(oPos.r, oPos.g, oPos.b, oPos.a);
19+
}
20+
o = o / float(${inputs_dim});
21+
setOutput(o);
22+
}
23+
`;
24+
}
25+
export default {
26+
mainFunc,
27+
params: [
28+
'inputs_dim',
29+
'dim'
30+
],
31+
textureFuncConf: {
32+
origin: ['getValueFromTensorPos']
33+
},
34+
behaviors: [
35+
'normalizeDim'
36+
]
37+
};
Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
{
2+
"ops": [
3+
{
4+
"attrs": {
5+
"op_device": ""
6+
},
7+
"inputs": {
8+
"X": [
9+
"feed"
10+
]
11+
},
12+
"outputs": {
13+
"Out": [
14+
"mean_0.tmp_0"
15+
]
16+
},
17+
"type": "feed"
18+
},
19+
{
20+
"attrs": {
21+
"axis": 0
22+
},
23+
"inputs": {
24+
"X": ["mean_0.tmp_0"]
25+
},
26+
"outputs": {
27+
"Out": ["mean_0.tmp_out"]
28+
},
29+
"type": "reduce_mean"
30+
},
31+
{
32+
"attrs": {
33+
"op_device": ""
34+
},
35+
"inputs": {
36+
"X": [
37+
"mean_0.tmp_out"
38+
]
39+
},
40+
"outputs": {
41+
"Out": [
42+
"fetch"
43+
]
44+
},
45+
"type": "fetch"
46+
}
47+
],
48+
"vars": [
49+
{
50+
"data": [1, 2, 3, 4, 5, 6, 7, 8],
51+
"name": "mean_0.tmp_0",
52+
"shape": [2, 4]
53+
},
54+
{
55+
"data":[3, 4, 5, 6],
56+
"name": "mean_0.tmp_out",
57+
"shape": [1, 4]
58+
}
59+
]
60+
}

0 commit comments

Comments
 (0)