@@ -539,44 +539,54 @@ def numba_funcify_DimShuffle(op, **kwargs):
539539
540540 ndim_new_shape = len (shuffle ) + len (augment )
541541
542+ no_transpose = all (i == j for i , j in enumerate (transposition ))
543+ if no_transpose :
544+
545+ @numba_basic .numba_njit
546+ def transpose (x ):
547+ return x
548+
549+ else :
550+
551+ @numba_basic .numba_njit
552+ def transpose (x ):
553+ return np .transpose (x , transposition )
554+
555+ shape_template = (1 ,) * ndim_new_shape
556+
557+ # When `len(shuffle) == 0`, the `shuffle_shape[j]` expression above is
558+ # is typed as `getitem(Tuple(), int)`, which has no implementation
559+ # (since getting an item from an empty sequence doesn't make sense).
560+ # To avoid this compile-time error, we omit the expression altogether.
542561 if len (shuffle ) > 0 :
543562
544563 @numba_basic .numba_njit
545- def populate_new_shape (i , j , new_shape , shuffle_shape ):
546- if i in augment :
547- new_shape = numba_basic .tuple_setitem (new_shape , i , 1 )
548- return j , new_shape
549- else :
550- new_shape = numba_basic .tuple_setitem (new_shape , i , shuffle_shape [j ])
551- return j + 1 , new_shape
564+ def find_shape (array_shape ):
565+ shape = shape_template
566+ j = 0
567+ for i in range (ndim_new_shape ):
568+ if i not in augment :
569+ length = array_shape [j ]
570+ shape = numba_basic .tuple_setitem (shape , i , length )
571+ j = j + 1
572+ return shape
552573
553574 else :
554- # When `len(shuffle) == 0`, the `shuffle_shape[j]` expression above is
555- # is typed as `getitem(Tuple(), int)`, which has no implementation
556- # (since getting an item from an empty sequence doesn't make sense).
557- # To avoid this compile-time error, we omit the expression altogether.
558- @numba_basic .numba_njit (inline = "always" )
559- def populate_new_shape (i , j , new_shape , shuffle_shape ):
560- return j , numba_basic .tuple_setitem (new_shape , i , 1 )
575+
576+ @numba_basic .numba_njit
577+ def find_shape (array_shape ):
578+ return shape_template
561579
562580 if ndim_new_shape > 0 :
563- create_zeros_tuple = numba_basic .create_tuple_creator (
564- lambda _ : 0 , ndim_new_shape
565- )
566581
567582 @numba_basic .numba_njit
568583 def dimshuffle_inner (x , shuffle ):
569- res = np .transpose (x , transposition )
570- shuffle_shape = res .shape [: len (shuffle )]
571-
572- new_shape = create_zeros_tuple ()
573-
574- j = 0
575- for i in range (len (new_shape )):
576- j , new_shape = populate_new_shape (i , j , new_shape , shuffle_shape )
584+ x = transpose (x )
585+ shuffle_shape = x .shape [: len (shuffle )]
586+ new_shape = find_shape (shuffle_shape )
577587
578588 # FIXME: Numba's `array.reshape` only accepts C arrays.
579- res_reshape = np .reshape (np .ascontiguousarray (res ), new_shape )
589+ res_reshape = np .reshape (np .ascontiguousarray (x ), new_shape )
580590
581591 if not inplace :
582592 return res_reshape .copy ()
0 commit comments