11from __future__ import annotations
22
33import math
4+ from functools import partial
45
56from xarray .core import dtypes , nputils
67
@@ -75,6 +76,47 @@ def least_squares(lhs, rhs, rcond=None, skipna=False):
7576 return coeffs , residuals
7677
7778
79+ def _fill_with_last_one (a , b ):
80+ import numpy as np
81+
82+ # cumreduction apply the push func over all the blocks first so,
83+ # the only missing part is filling the missing values using the
84+ # last data of the previous chunk
85+ return np .where (np .isnan (b ), a , b )
86+
87+
88+ def _dtype_push (a , axis , dtype = None ):
89+ from xarray .core .duck_array_ops import _push
90+
91+ # Not sure why the blelloch algorithm force to receive a dtype
92+ return _push (a , axis = axis )
93+
94+
95+ def _reset_cumsum (a , axis , dtype = None ):
96+ import numpy as np
97+
98+ cumsum = np .cumsum (a , axis = axis )
99+ reset_points = np .maximum .accumulate (np .where (a == 0 , cumsum , 0 ), axis = axis )
100+ return cumsum - reset_points
101+
102+
103+ def _last_reset_cumsum (a , axis , keepdims = None ):
104+ import numpy as np
105+
106+ # Take the last cumulative sum taking into account the reset
107+ # This is useful for blelloch method
108+ return np .take (_reset_cumsum (a , axis = axis ), axis = axis , indices = [- 1 ])
109+
110+
111+ def _combine_reset_cumsum (a , b , axis ):
112+ import numpy as np
113+
114+ # It is going to sum the previous result until the first
115+ # non nan value
116+ bitmask = np .cumprod (b != 0 , axis = axis )
117+ return np .where (bitmask , b + a , b )
118+
119+
78120def push (array , n , axis , method = "blelloch" ):
79121 """
80122 Dask-aware bottleneck.push
@@ -91,16 +133,6 @@ def push(array, n, axis, method="blelloch"):
91133 # TODO: Replace all this function
92134 # once https://github.com/pydata/xarray/issues/9229 being implemented
93135
94- def _fill_with_last_one (a , b ):
95- # cumreduction apply the push func over all the blocks first so,
96- # the only missing part is filling the missing values using the
97- # last data of the previous chunk
98- return np .where (np .isnan (b ), a , b )
99-
100- def _dtype_push (a , axis , dtype = None ):
101- # Not sure why the blelloch algorithm force to receive a dtype
102- return _push (a , axis = axis )
103-
104136 pushed_array = da .reductions .cumreduction (
105137 func = _dtype_push ,
106138 binop = _fill_with_last_one ,
@@ -113,26 +145,9 @@ def _dtype_push(a, axis, dtype=None):
113145 )
114146
115147 if n is not None and 0 < n < array .shape [axis ] - 1 :
116-
117- def _reset_cumsum (a , axis , dtype = None ):
118- cumsum = np .cumsum (a , axis = axis )
119- reset_points = np .maximum .accumulate (np .where (a == 0 , cumsum , 0 ), axis = axis )
120- return cumsum - reset_points
121-
122- def _last_reset_cumsum (a , axis , keepdims = None ):
123- # Take the last cumulative sum taking into account the reset
124- # This is useful for blelloch method
125- return np .take (_reset_cumsum (a , axis = axis ), axis = axis , indices = [- 1 ])
126-
127- def _combine_reset_cumsum (a , b ):
128- # It is going to sum the previous result until the first
129- # non nan value
130- bitmask = np .cumprod (b != 0 , axis = axis )
131- return np .where (bitmask , b + a , b )
132-
133148 valid_positions = da .reductions .cumreduction (
134149 func = _reset_cumsum ,
135- binop = _combine_reset_cumsum ,
150+ binop = partial ( _combine_reset_cumsum , axis = axis ) ,
136151 ident = 0 ,
137152 x = da .isnan (array , dtype = int ),
138153 axis = axis ,
0 commit comments