@@ -53,10 +53,13 @@ def _drop_input_dims(ds, input_dims, suffix='_input'):
5353 out .coords [dim ] = newdim , ds [dim ].data , ds [dim ].attrs
5454 return out
5555
56- def _maybe_stack_batch_dims (ds , input_dims , stacked_dim_name = 'sample' ):
56+ def _maybe_stack_batch_dims (ds , input_dims , squeeze_batch_dim , stacked_dim_name = 'sample' ):
5757 batch_dims = [d for d in ds .dims if d not in input_dims ]
5858 if len (batch_dims ) < 2 :
59- return ds .expand_dims (stacked_dim_name , 0 )
59+ if (squeeze_batch_dim ):
60+ return ds
61+ else :
62+ return ds .expand_dims (stacked_dim_name , 0 )
6063 ds_stack = ds .stack (** {stacked_dim_name : batch_dims })
6164 # ensure correct order
6265 dim_order = (stacked_dim_name ,) + tuple (input_dims )
@@ -89,6 +92,10 @@ class BatchGenerator:
8992 preload_batch : bool, optional
9093 If ``True``, each batch will be loaded into memory before reshaping /
9194 processing, triggering any dask arrays to be computed.
95+ squeeze_batch_dim : bool, optional
96+ If ``False``, each batch's dataset will have a "sample" dimension of size 1
97+ prepended to the array. This functionality is useful for interoperability
98+ with Keras / TensorFlow. Default is ``True``.
9299
93100 Yields
94101 ------
@@ -104,6 +111,7 @@ def __init__(
104111 batch_dims = {},
105112 concat_input_dims = False ,
106113 preload_batch = True ,
114+ squeeze_batch_dim = True
107115 ):
108116
109117 self .ds = _as_xarray_dataset (ds )
@@ -113,6 +121,7 @@ def __init__(
113121 self .batch_dims = OrderedDict (batch_dims )
114122 self .concat_input_dims = concat_input_dims
115123 self .preload_batch = preload_batch
124+ self .squeeze_batch_dim = squeeze_batch_dim
116125
117126 def __iter__ (self ):
118127 for ds_batch in self ._iterate_batch_dims (self .ds ):
@@ -131,11 +140,11 @@ def __iter__(self):
131140 new_input_dims = [
132141 dim + new_dim_suffix for dim in self .input_dims
133142 ]
134- yield _maybe_stack_batch_dims (dsc , new_input_dims )
143+ yield _maybe_stack_batch_dims (dsc , new_input_dims , self . squeeze_batch_dim )
135144 else :
136145 for ds_input in input_generator :
137146 yield _maybe_stack_batch_dims (
138- ds_input , list (self .input_dims )
147+ ds_input , list (self .input_dims ), self . squeeze_batch_dim
139148 )
140149
141150 def _iterate_batch_dims (self , ds ):
0 commit comments