File tree Expand file tree Collapse file tree 2 files changed +84
-1
lines changed
packages/paddlejs-backend-webgl/src/ops Expand file tree Collapse file tree 2 files changed +84
-1
lines changed Original file line number Diff line number Diff line change @@ -42,6 +42,7 @@ import connect from './shader/connect';
4242import squeeze2 from './shader/squeeze2' ;
4343import pad3d from './shader/pad3d' ;
4444import reduce_mean from './shader/reduce_mean' ;
45+ import shuffle_channel from './shader/shuffle_channel' ;
4546import hard_swish from './shader/hard_swish' ;
4647import nearest_interp from './shader/nearest_interp' ;
4748import nearest_interp_v2 from './shader/nearest_interp_v2' ;
@@ -96,7 +97,8 @@ const ops = {
9697 sqrt : dynamic ( 'sqrt' ) ,
9798 squeeze2,
9899 pad3d,
99- bilinear_interp_v2
100+ bilinear_interp_v2,
101+ shuffle_channel
100102} ;
101103export {
102104 ops
Original file line number Diff line number Diff line change 1+ /**
2+ * @file shuffle_channel
3+ * @description reshape2 transpose2 reshape2
4+ */
5+
6+
7+ function mainFunc (
8+ {
9+ out
10+ } ,
11+ {
12+ group = 2
13+ }
14+ ) {
15+ const { total_shape, height_shape, width_shape, channel } = out ;
16+ const channels_per_group = channel / group ;
17+
18+ const [
19+ perm_0 ,
20+ perm_1 ,
21+ perm_2 ,
22+ perm_3
23+ ] = [ 1 , 0 , 2 , 3 ] ;
24+
25+ return `
26+ // start函数
27+ void main(void) {
28+ // 输出数据
29+ ivec4 oPos = getOutputTensorPos();
30+ float o = 0.0;
31+
32+ int sumVal = oPos.a
33+ + oPos.b * ${ width_shape }
34+ + oPos.g * ${ height_shape } * ${ width_shape }
35+ + oPos.r * ${ channel } * ${ width_shape } * ${ height_shape } ;
36+
37+ ivec4 transpose_out_pos = transferFromNHWCtoNCHW(
38+ sumVal,
39+ ${ group } ,
40+ ${ width_shape } ,
41+ ${ height_shape } ,
42+ ${ total_shape }
43+ );
44+
45+ ivec4 transpose_in_pos = ivec4(transpose_out_pos[${ perm_0 } ],
46+ transpose_out_pos[${ perm_1 } ], transpose_out_pos[${ perm_2 } ], transpose_out_pos[${ perm_3 } ]);
47+ int sumVal2 = transpose_in_pos.a
48+ + transpose_in_pos.b * ${ width_shape }
49+ + transpose_in_pos.g * ${ height_shape } * ${ width_shape }
50+ + transpose_in_pos.r * ${ channels_per_group } * ${ width_shape } * ${ height_shape } ;
51+ ivec4 origin_oPos = transferFromNHWCtoNCHW(
52+ sumVal2,
53+ ${ channel } ,
54+ ${ width_shape } ,
55+ ${ height_shape } ,
56+ ${ total_shape }
57+ );
58+
59+
60+ o = getValueFromTensorPos_origin(
61+ origin_oPos[0],
62+ origin_oPos[1],
63+ origin_oPos[2],
64+ origin_oPos[3]
65+ );
66+
67+ setOutput(float(o));
68+ }
69+ ` ;
70+ }
71+
72+ export default {
73+ mainFunc,
74+ params : [
75+ 'group'
76+ ] ,
77+ textureFuncConf : {
78+ origin : [ 'getValueFromTensorPos' ]
79+ } ,
80+ commonFuncConf : [ 'transferFromNHWCtoNCHW' ]
81+ } ;
You can’t perform that action at this time.
0 commit comments