22 * @file FormatInputsX
33 */
44
5+ import env from '../env' ;
56import Transformer from './transformer' ;
67
78export default class FormatInputsX extends Transformer {
@@ -28,8 +29,30 @@ export default class FormatInputsX extends Transformer {
2829 return ;
2930 }
3031
31- const inputsX = inputs . X || inputs . Input ;
3232 // 兼容key为X,value是个长度大于1的数组的情况,如concat
33+ const inputsX = inputs . X || inputs . Input ;
34+
35+ // wasm backend is not support any number of inputs, retain temporarily
36+ if ( env . get ( 'backend' ) === 'wasm' ) {
37+ if ( inputsX . length > 4 ) {
38+ throw Error ( 'Not yet supporting concat input tensors more than 4.' ) ;
39+ }
40+ if ( inputsX . length > 1 ) {
41+ // 兼容key为X,value是个长度大于1的数组的情况,如concat
42+ const [ x_name , y_name , z_name , m_name ] = inputsX ;
43+ inputs [ 'X' ] = [ x_name ] ;
44+ y_name && ( inputs [ 'Y' ] = [ y_name ] ) ;
45+ if ( z_name ) {
46+ inputs [ 'Z' ] = [ z_name ] ;
47+ originOp . type += '_mul' ;
48+ }
49+ if ( m_name ) {
50+ inputs [ 'M' ] = [ m_name ] ;
51+ }
52+ }
53+ return ;
54+ }
55+
3356 if ( inputsX . length > 1 ) {
3457 inputsX . forEach ( ( item , index ) => {
3558 inputs [ `origin${ index > 0 ? `_${ index } ` : '' } ` ] = [ item ] ;
0 commit comments