22import sys
33import warnings
44from contextlib import contextmanager
5+ from copy import copy
56from functools import singledispatch
67from textwrap import dedent
78from typing import Union
1516from numba import types
1617from numba .core .errors import TypingError
1718from numba .cpython .unsafe .tuple import tuple_setitem # noqa: F401
18- from numba .extending import box
19+ from numba .extending import box , overload
1920
2021from pytensor import config
2122from pytensor .compile .builders import OpFromGraph
4748from pytensor .tensor .type_other import MakeSlice , NoneConst
4849
4950
51+ def global_numba_func (func ):
52+ """Use to return global numba functions in numba_funcify_*.
53+
54+ This allows tests to remove the compilation using mock.
55+ """
56+ return func
57+
58+
5059def numba_njit (* args , ** kwargs ):
5160
5261 kwargs = kwargs .copy ()
@@ -573,29 +582,36 @@ def numba_funcify_IncSubtensor(op, node, **kwargs):
573582 return numba_njit (incsubtensor_fn , boundscheck = True )
574583
575584
585+ @numba_njit (boundscheck = True )
586+ def advancedincsubtensor1_inplace_set (x , vals , idxs ):
587+ for idx , val in zip (idxs , vals ):
588+ x [idx ] = val
589+ return x
590+
591+
592+ @numba_njit (boundscheck = True )
593+ def advancedincsubtensor1_inplace_inc (x , vals , idxs ):
594+ for idx , val in zip (idxs , vals ):
595+ x [idx ] += val
596+ return x
597+
598+
576599@numba_funcify .register (AdvancedIncSubtensor1 )
577600def numba_funcify_AdvancedIncSubtensor1 (op , node , ** kwargs ):
578601 inplace = op .inplace
579602 set_instead_of_inc = op .set_instead_of_inc
580603
581604 if set_instead_of_inc :
582-
583- @numba_njit (boundscheck = True )
584- def advancedincsubtensor1_inplace (x , vals , idxs ):
585- for idx , val in zip (idxs , vals ):
586- x [idx ] = val
587- return x
588-
605+ advancedincsubtensor1_inplace = global_numba_func (
606+ advancedincsubtensor1_inplace_set
607+ )
589608 else :
590-
591- @numba_njit (boundscheck = True )
592- def advancedincsubtensor1_inplace (x , vals , idxs ):
593- for idx , val in zip (idxs , vals ):
594- x [idx ] += val
595- return x
609+ advancedincsubtensor1_inplace = global_numba_func (
610+ advancedincsubtensor1_inplace_inc
611+ )
596612
597613 if inplace :
598- return advancedincsubtensor1_inplace
614+ return global_numba_func ( advancedincsubtensor1_inplace )
599615 else :
600616
601617 @numba_njit
@@ -606,51 +622,48 @@ def advancedincsubtensor1(x, vals, idxs):
606622 return advancedincsubtensor1
607623
608624
609- @ numba_funcify . register ( DeepCopyOp )
610- def numba_funcify_DeepCopyOp ( op , node , ** kwargs ):
625+ def deepcopyop ( x ):
626+ return copy ( x )
611627
612- # Scalars are apparently returned as actual Python scalar types and not
613- # NumPy scalars, so we need two separate Numba functions for each case.
614628
615- # The type can also be RandomType with no ndims
616- if not hasattr (node .outputs [0 ].type , "ndim" ) or node .outputs [0 ].type .ndim == 0 :
617- # TODO: Do we really need to compile a pass-through function like this?
618- @numba_njit (inline = "always" )
619- def deepcopyop (x ):
620- return x
629+ @overload (deepcopyop )
630+ def dispatch_deepcopyop (x ):
631+ if isinstance (x , types .Array ):
632+ return lambda x : np .copy (x )
621633
622- else :
634+ return lambda x : x
623635
624- @numba_njit (inline = "always" )
625- def deepcopyop (x ):
626- return x .copy ()
627636
637+ @numba_funcify .register (DeepCopyOp )
638+ def numba_funcify_DeepCopyOp (op , node , ** kwargs ):
628639 return deepcopyop
629640
630641
642+ @numba_njit
643+ def makeslice (* x ):
644+ return slice (* x )
645+
646+
631647@numba_funcify .register (MakeSlice )
632648def numba_funcify_MakeSlice (op , ** kwargs ):
633- @numba_njit
634- def makeslice (* x ):
635- return slice (* x )
649+ return global_numba_func (makeslice )
636650
637- return makeslice
651+
652+ @numba_njit
653+ def shape (x ):
654+ return np .asarray (np .shape (x ))
638655
639656
640657@numba_funcify .register (Shape )
641658def numba_funcify_Shape (op , ** kwargs ):
642- @numba_njit (inline = "always" )
643- def shape (x ):
644- return np .asarray (np .shape (x ))
645-
646- return shape
659+ return global_numba_func (shape )
647660
648661
649662@numba_funcify .register (Shape_i )
650663def numba_funcify_Shape_i (op , ** kwargs ):
651664 i = op .i
652665
653- @numba_njit ( inline = "always" )
666+ @numba_njit
654667 def shape_i (x ):
655668 return np .shape (x )[i ]
656669
@@ -683,13 +696,13 @@ def numba_funcify_Reshape(op, **kwargs):
683696
684697 if ndim == 0 :
685698
686- @numba_njit ( inline = "always" )
699+ @numba_njit
687700 def reshape (x , shape ):
688701 return x .item ()
689702
690703 else :
691704
692- @numba_njit ( inline = "always" )
705+ @numba_njit
693706 def reshape (x , shape ):
694707 # TODO: Use this until https://github.com/numba/numba/issues/7353 is closed.
695708 return np .reshape (
@@ -732,15 +745,15 @@ def int_to_float_fn(inputs, out_dtype):
732745
733746 args_dtype = np .dtype (f"f{ out_dtype .itemsize } " )
734747
735- @numba_njit ( inline = "always" )
748+ @numba_njit
736749 def inputs_cast (x ):
737750 return x .astype (args_dtype )
738751
739752 else :
740753 args_dtype_sz = max (_arg .type .numpy_dtype .itemsize for _arg in inputs )
741754 args_dtype = np .dtype (f"f{ args_dtype_sz } " )
742755
743- @numba_njit ( inline = "always" )
756+ @numba_njit
744757 def inputs_cast (x ):
745758 return x .astype (args_dtype )
746759
@@ -755,7 +768,7 @@ def numba_funcify_Dot(op, node, **kwargs):
755768 out_dtype = node .outputs [0 ].type .numpy_dtype
756769 inputs_cast = int_to_float_fn (node .inputs , out_dtype )
757770
758- @numba_njit ( inline = "always" )
771+ @numba_njit
759772 def dot (x , y ):
760773 return np .asarray (np .dot (inputs_cast (x ), inputs_cast (y ))).astype (out_dtype )
761774
@@ -770,13 +783,14 @@ def numba_funcify_Softplus(op, node, **kwargs):
770783 @numba_njit
771784 def softplus (x ):
772785 if x < - 37.0 :
773- return direct_cast ( np .exp (x ), x_dtype )
786+ value = np .exp (x )
774787 elif x < 18.0 :
775- return direct_cast ( np .log1p (np .exp (x )), x_dtype )
788+ value = np .log1p (np .exp (x ))
776789 elif x < 33.3 :
777- return direct_cast ( x + np .exp (- x ), x_dtype )
790+ value = x + np .exp (- x )
778791 else :
779- return direct_cast (x , x_dtype )
792+ value = x
793+ return direct_cast (value , x_dtype )
780794
781795 return softplus
782796
@@ -791,7 +805,7 @@ def numba_funcify_Cholesky(op, node, **kwargs):
791805
792806 inputs_cast = int_to_float_fn (node .inputs , out_dtype )
793807
794- @numba_njit ( inline = "always" )
808+ @numba_njit
795809 def cholesky (a ):
796810 return np .linalg .cholesky (inputs_cast (a )).astype (out_dtype )
797811
@@ -852,7 +866,7 @@ def solve(a, b):
852866 out_dtype = node .outputs [0 ].type .numpy_dtype
853867 inputs_cast = int_to_float_fn (node .inputs , out_dtype )
854868
855- @numba_njit ( inline = "always" )
869+ @numba_njit
856870 def solve (a , b ):
857871 return np .linalg .solve (
858872 inputs_cast (a ),
0 commit comments