Skip to content

Commit ee2f3b0

Browse files
Merge pull request #417 from yueshuangyan/master
fix: compatible with opBehavior and transform for wasm
2 parents 2a9f390 + fe85969 commit ee2f3b0

File tree

7 files changed

+56
-9
lines changed

7 files changed

+56
-9
lines changed

packages/paddlejs-backend-cpu/src/ops/elementwise_add.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,7 @@ class Attrs {
7575
}
7676

7777
const behaviors = [
78-
'processAxis'
78+
'processElementwiseAxis'
7979
];
8080

8181
const inputsName = [

packages/paddlejs-backend-wasm/package.json

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
{
22
"name": "@paddlejs/paddlejs-backend-wasm",
3-
"version": "1.0.2",
3+
"version": "1.0.3",
44
"description": "",
55
"main": "lib/index",
66
"scripts": {

packages/paddlejs-backend-wasm/src/ops.ts

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -60,25 +60,25 @@ export default {
6060
},
6161
elementwise_add: {
6262
behaviors: [
63-
'processAxis',
63+
'processElementwiseAxis',
6464
'genElementwiseCounterPos'
6565
]
6666
},
6767
elementwise_div: {
6868
behaviors: [
69-
'processAxis',
69+
'processElementwiseAxis',
7070
'genElementwiseCounterPos'
7171
]
7272
},
7373
elementwise_mul: {
7474
behaviors: [
75-
'processAxis',
75+
'processElementwiseAxis',
7676
'genElementwiseCounterPos'
7777
]
7878
},
7979
elementwise_sub: {
8080
behaviors: [
81-
'processAxis',
81+
'processElementwiseAxis',
8282
'genElementwiseCounterPos'
8383
]
8484
},

packages/paddlejs-backend-webgpu/src/ops/index.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,7 @@ export const ops = {
7575
main: elementwise_add_main,
7676
deps: elementwise_add_deps,
7777
behaviors: [
78-
'processAxis'
78+
'processElementwiseAxis'
7979
]
8080
},
8181
split: {

packages/paddlejs-core/package.json

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
{
22
"name": "@paddlejs/paddlejs-core",
3-
"version": "2.1.18",
3+
"version": "2.1.19",
44
"description": "",
55
"main": "lib/index",
66
"scripts": {

packages/paddlejs-core/src/opFactory/opBehaviors.ts

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import env from '../env';
12
import { OpData } from '../commons/interface';
23
import * as Utils from './utils';
34

@@ -219,6 +220,29 @@ const behaviors : Behaviors = {
219220
this.processedAttrs.num = Object.values(this.tensorDataMap)
220221
.filter(item => item.tensorName === 'out').length || 1;
221222
}
223+
224+
// wasm backend is not support any number of inputs, retain temporarily
225+
if (env.get('backend') === 'wasm') {
226+
this.processedAttrs.fourInputs = false;
227+
228+
const counter = this.tensorDataMap['counter'];
229+
if (counter) {
230+
const yShape = Utils.formatShape(counter.shape);
231+
this.processedAttrs.counter_num = yShape[axis];
232+
}
233+
const appender = this.tensorDataMap['appender'];
234+
if (appender) {
235+
const zShape = Utils.formatShape(appender.shape);
236+
this.processedAttrs.append_num = zShape[axis];
237+
}
238+
239+
const fourth = this.tensorDataMap['fourth'];
240+
if (fourth) {
241+
this.processedAttrs.fourInputs = true;
242+
const mShape = Utils.formatShape(fourth.shape);
243+
this.processedAttrs.fourth_num = mShape[axis];
244+
}
245+
}
222246
},
223247

224248
processElementwiseAxis() {

packages/paddlejs-core/src/transform/formatInputsX.ts

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
* @file FormatInputsX
33
*/
44

5+
import env from '../env';
56
import Transformer from './transformer';
67

78
export default class FormatInputsX extends Transformer {
@@ -28,8 +29,30 @@ export default class FormatInputsX extends Transformer {
2829
return;
2930
}
3031

31-
const inputsX = inputs.X || inputs.Input;
3232
// 兼容key为X,value是个长度大于1的数组的情况,如concat
33+
const inputsX = inputs.X || inputs.Input;
34+
35+
// wasm backend is not support any number of inputs, retain temporarily
36+
if (env.get('backend') === 'wasm') {
37+
if (inputsX.length > 4) {
38+
throw Error('Not yet supporting concat input tensors more than 4.');
39+
}
40+
if (inputsX.length > 1) {
41+
// 兼容key为X,value是个长度大于1的数组的情况,如concat
42+
const [x_name, y_name, z_name, m_name] = inputsX;
43+
inputs['X'] = [x_name];
44+
y_name && (inputs['Y'] = [y_name]);
45+
if (z_name) {
46+
inputs['Z'] = [z_name];
47+
originOp.type += '_mul';
48+
}
49+
if (m_name) {
50+
inputs['M'] = [m_name];
51+
}
52+
}
53+
return;
54+
}
55+
3356
if (inputsX.length > 1) {
3457
inputsX.forEach((item, index) => {
3558
inputs[`origin${index > 0 ? `_${index}` : ''}`] = [item];

0 commit comments

Comments
 (0)