@@ -18,6 +18,29 @@ def sample_ds_1d():
1818 return ds
1919
2020
21+ @pytest .fixture (scope = 'module' )
22+ def sample_ds_3d ():
23+ shape = (10 , 50 , 100 )
24+ ds = xr .Dataset (
25+ {
26+ 'foo' : (['time' , 'y' , 'x' ], np .random .rand (* shape )),
27+ 'bar' : (['time' , 'y' , 'x' ], np .random .randint (0 , 10 , shape )),
28+ },
29+ {
30+ 'x' : (['x' ], np .arange (shape [- 1 ])),
31+ 'y' : (['y' ], np .arange (shape [- 2 ])),
32+ },
33+ )
34+ return ds
35+
36+
37+ def test_constructor_coerces_to_dataset ():
38+ da = xr .DataArray (np .random .rand (10 ), dims = 'x' , name = 'foo' )
39+ bg = BatchGenerator (da , input_dims = {'x' : 2 })
40+ assert isinstance (bg .ds , xr .Dataset )
41+ assert bg .ds .equals (da .to_dataset ())
42+
43+
2144# TODO: decide how to handle bsizes like 15 that don't evenly divide the dimension
2245# Should we enforce that each batch size always has to be the same
2346@pytest .mark .parametrize ('bsize' , [5 , 10 ])
@@ -86,22 +109,6 @@ def test_batch_1d_overlap(sample_ds_1d, olap):
86109 assert ds_batch .equals (ds_batch_expected )
87110
88111
89- @pytest .fixture (scope = 'module' )
90- def sample_ds_3d ():
91- shape = (10 , 50 , 100 )
92- ds = xr .Dataset (
93- {
94- 'foo' : (['time' , 'y' , 'x' ], np .random .rand (* shape )),
95- 'bar' : (['time' , 'y' , 'x' ], np .random .randint (0 , 10 , shape )),
96- },
97- {
98- 'x' : (['x' ], np .arange (shape [- 1 ])),
99- 'y' : (['y' ], np .arange (shape [- 2 ])),
100- },
101- )
102- return ds
103-
104-
105112@pytest .mark .parametrize ('bsize' , [5 , 10 ])
106113def test_batch_3d_1d_input (sample_ds_3d , bsize ):
107114
@@ -160,3 +167,25 @@ def test_batch_3d_2d_input_concat(sample_ds_3d, bsize):
160167 * (sample_ds_3d .dims ['y' ] // bsize )
161168 * sample_ds_3d .dims ['time' ]
162169 )
170+
171+
172+ def test_preload_batch_false (sample_ds_1d ):
173+ sample_ds_1d_dask = sample_ds_1d .chunk ({'x' : 2 })
174+ bg = BatchGenerator (
175+ sample_ds_1d_dask , input_dims = {'x' : 2 }, preload_batch = False
176+ )
177+ assert bg .preload_batch is False
178+ for ds_batch in bg :
179+ assert isinstance (ds_batch , xr .Dataset )
180+ assert ds_batch .chunks
181+
182+
183+ def test_preload_batch_true (sample_ds_1d ):
184+ sample_ds_1d_dask = sample_ds_1d .chunk ({'x' : 2 })
185+ bg = BatchGenerator (
186+ sample_ds_1d_dask , input_dims = {'x' : 2 }, preload_batch = True
187+ )
188+ assert bg .preload_batch is True
189+ for ds_batch in bg :
190+ assert isinstance (ds_batch , xr .Dataset )
191+ assert not ds_batch .chunks
0 commit comments