
Commit 31d1be2 ("rebase with main")

1 parent: 23cf53b

File tree: 3 files changed (+37, -76 lines)

xarray/backends/pydap_.py

Lines changed: 25 additions & 50 deletions
@@ -1,7 +1,6 @@
 from __future__ import annotations

 import os
-import warnings
 from collections.abc import Iterable
 from typing import TYPE_CHECKING, Any

@@ -37,10 +36,9 @@


 class PydapArrayWrapper(BackendArray):
-    def __init__(self, array, batch=False, cache=None, checksums=True):
+    def __init__(self, array, batch=None, checksums=True):
         self.array = array
         self._batch = batch
-        self._cache = cache
         self._checksums = checksums

     @property
@@ -57,27 +55,19 @@ def __getitem__(self, key):
         )

     def _getitem(self, key):
-        if self.array.id in self._cache.keys():
-            # safely avoid re-downloading some coordinates
-            result = self._cache[self.array.id]
-        elif self._batch and hasattr(self.array, "dataset"):
+        if self._batch and hasattr(self.array, "dataset"):
             # this are both True only for pydap>3.5.5
-            from pydap.lib import resolve_batch_for_all_variables
+            from pydap.client import data_check, get_batch_data

             dataset = self.array.dataset
-            resolve_batch_for_all_variables(self.array, key, checksums=self._checksums)
-            result = np.asarray(
-                dataset._current_batch_promise.wait_for_result(self.array.id)
-            )
+            get_batch_data(self.array, checksums=self._checksums, key=key)
+            result = data_check(np.asarray(dataset[self.array.id].data), key)
         else:
             result = robust_getitem(self.array, key, catch=ValueError)
-            try:
-                result = np.asarray(result.data)
-            except AttributeError:
-                result = np.asarray(result)
-        axis = tuple(n for n, k in enumerate(key) if isinstance(k, integer_types))
-        if result.ndim + len(axis) != self.array.ndim and axis:
-            result = np.squeeze(result, axis)
+            result = np.asarray(result.data)
+        axis = tuple(n for n, k in enumerate(key) if isinstance(k, integer_types))
+        if result.ndim + len(axis) != self.array.ndim and axis:
+            result = np.squeeze(result, axis)
         return result


@@ -105,7 +95,7 @@ def __init__(
         dataset,
         group=None,
         session=None,
-        batch=False,
+        batch=None,
         protocol=None,
         checksums=True,
     ):
@@ -119,8 +109,6 @@ def __init__(
         self.dataset = dataset
         self.group = group
         self._batch = batch
-        self._batch_done = False
-        self._array_cache = {}  # holds 1D dimension data
         self._protocol = protocol
         self._checksums = checksums  # true by default

@@ -135,7 +123,7 @@ def open(
         timeout=None,
         verify=None,
         user_charset=None,
-        batch=False,
+        batch=None,
         checksums=True,
     ):
         from pydap.client import open_url
@@ -167,34 +155,23 @@ def open(
         elif hasattr(url, "ds"):
             # pydap dataset
             dataset = url.ds
-            args = {"dataset": dataset}
-            args["checksums"] = checksums
+            args = {"dataset": dataset, "checksums": checksums}
         if group:
             args["group"] = group
         if url.startswith(("http", "dap2")):
             args["protocol"] = "dap2"
         elif url.startswith("dap4"):
             args["protocol"] = "dap4"
         if batch:
-            if args["protocol"] == "dap2":
-                warnings.warn(
-                    f"`batch={batch}` is currently only compatible with the `DAP4` "
-                    "protocol. Make sue the OPeNDAP server implements the `DAP4` "
-                    "protocol and then replace the scheme of the url with `dap4` "
-                    "to make use of it. Setting `batch=False`.",
-                    stacklevel=2,
-                )
-            else:
-                # only update if dap4
-                args["batch"] = batch
+            args["batch"] = batch
         return cls(**args)

     def open_store_variable(self, var):
-        try:
+        if hasattr(var, "dims"):
             dimensions = [
                 dim.split("/")[-1] if dim.startswith("/") else dim for dim in var.dims
             ]
-        except AttributeError:
+        else:
             # GridType does not have a dims attribute - instead get `dimensions`
             # see https://github.com/pydap/pydap/issues/485
             dimensions = var.dimensions
@@ -214,7 +191,7 @@ def open_store_variable(self, var):
         else:
             # all non-dimension variables
             data = indexing.LazilyIndexedArray(
-                PydapArrayWrapper(var, self._batch, self._array_cache, self._checksums)
+                PydapArrayWrapper(var, self._batch, self._checksums)
             )

         return Variable(dimensions, data, var.attributes)
@@ -264,16 +241,14 @@ def _get_data_array(self, var):
         """gets dimension data all at once, storing the numpy
         arrays within a cached dictionary
         """
-        from pydap.lib import get_batch_data
+        from pydap.client import get_batch_data

-        if not self._batch_done or var.id not in self._array_cache:
-            # store all dim data into a dict for reuse
-            self._array_cache = get_batch_data(
-                var.parent, self._array_cache, self._checksums
-            )
-            self._batch_done = True
+        if not var._is_data_loaded():
+            # data has not been deserialized yet
+            # runs only once per store/hierarchy
+            get_batch_data(var, checksums=self._checksums)

-        return self._array_cache[var.id]
+        return self.dataset[var.id].data


 class PydapBackendEntrypoint(BackendEntrypoint):
@@ -336,7 +311,7 @@ def open_dataset(
         timeout=None,
         verify=None,
         user_charset=None,
-        batch=False,
+        batch=None,
         checksums=True,
     ) -> Dataset:
         store = PydapDataStore.open(
@@ -382,7 +357,7 @@ def open_datatree(
         timeout=None,
         verify=None,
         user_charset=None,
-        batch=False,
+        batch=None,
         checksums=True,
     ) -> DataTree:
         groups_dict = self.open_groups_as_dict(
@@ -423,7 +398,7 @@ def open_groups_as_dict(
         timeout=None,
         verify=None,
         user_charset=None,
-        batch=False,
+        batch=None,
         checksums=True,
     ) -> dict[str, Dataset]:
         from xarray.core.treenode import NodePath

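For orientation, here is a minimal sketch of how the reworked batch path would be exercised from the user side. It is an illustration, not documentation from this commit: it assumes pydap > 3.5.5 and a DAP4-capable OPeNDAP server, and the URL is the public test server already used by the test suite.

# Sketch only: assumes pydap > 3.5.5 and a DAP4-capable OPeNDAP server.
import xarray as xr

url = "dap4://test.opendap.org/opendap/hyrax/data/nc/coads_climatology.nc"

# batch=None (the new default) keeps the one-request-per-variable behaviour;
# batch=True routes reads through pydap.client.get_batch_data, as in the
# PydapArrayWrapper._getitem change above.
ds = xr.open_dataset(url, engine="pydap", batch=True, decode_times=False)
print(list(ds.data_vars))
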
xarray/tests/test_backends.py

Lines changed: 3 additions & 21 deletions
@@ -6564,14 +6564,15 @@ def test_session(self) -> None:
 @network
 @pytest.mark.parametrize("protocol", ["dap2", "dap4"])
 @pytest.mark.parametrize("batch", [False, True])
-def test_batchdap4_downloads(protocol, batch) -> None:
+def test_batchdap4_downloads(tmpdir, protocol, batch) -> None:
     """Test that in dap4, all dimensions are downloaded at once"""
     import pydap
     from pydap.net import create_session

     _version_ = Version(pydap.__version__)
     # Create a session with pre-set params in pydap backend, to cache urls
-    session = create_session(use_cache=True, cache_kwargs={"cache_name": "debug"})
+    cache_name = tmpdir / "debug"
+    session = create_session(use_cache=True, cache_kwargs={"cache_name": cache_name})
     session.cache.clear()
     url = "https://test.opendap.org/opendap/hyrax/data/nc/coads_climatology.nc"

@@ -6611,25 +6612,6 @@ def test_batchdap4_downloads(protocol, batch) -> None:
         assert len(session.cache.urls()) == 5


-@requires_pydap
-@network
-def test_batch_warnswithdap2() -> None:
-    from pydap.net import create_session
-
-    # Create a session with pre-set retry params in pydap backend, to cache urls
-    session = create_session(use_cache=True, cache_kwargs={"cache_name": "debug"})
-    session.cache.clear()
-
-    url = "dap2://test.opendap.org/opendap/hyrax/data/nc/coads_climatology.nc"
-    with pytest.warns(UserWarning):
-        open_dataset(
-            url, engine="pydap", session=session, batch=True, decode_times=False
-        )
-
-    # no batching is supported here
-    assert len(session.cache.urls()) == 5
-
-
 class TestEncodingInvalid:
     def test_extract_nc4_variable_encoding(self) -> None:
         var = xr.Variable(("x",), [1, 2, 3], {}, {"foo": "bar"})

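The test change above swaps a hard-coded cache name for a per-test temporary directory. A small sketch of that cached-session pattern follows; it only uses calls that appear in the diff (create_session, session.cache.clear), while the helper name itself is made up for illustration.

# Hypothetical helper mirroring the updated test setup; tmpdir is pytest's fixture.
from pydap.net import create_session

def make_cached_session(tmpdir):
    # per-test cache location instead of a shared "debug" file in the working directory
    cache_name = tmpdir / "debug"
    session = create_session(use_cache=True, cache_kwargs={"cache_name": cache_name})
    session.cache.clear()  # start empty so len(session.cache.urls()) is deterministic
    return session
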
xarray/tests/test_backends_datatree.py

Lines changed: 9 additions & 5 deletions
@@ -581,7 +581,8 @@ class TestPyDAPDatatreeIO:
     simplegroup_datatree_url = "dap4://test.opendap.org/opendap/dap4/SimpleGroup.nc4.h5"

     def test_open_datatree_unaligned_hierarchy(
-        self, url=unaligned_datatree_url
+        self,
+        url=unaligned_datatree_url,
     ) -> None:
         with pytest.raises(
             ValueError,
@@ -614,7 +615,7 @@ def test_open_groups(self, url=unaligned_datatree_url) -> None:
         ) as expected:
             assert_identical(unaligned_dict_of_datasets["/Group1/subgroup1"], expected)

-    def test_inherited_coords(self, url=simplegroup_datatree_url) -> None:
+    def test_inherited_coords(self, tmpdir, url=simplegroup_datatree_url) -> None:
         """Test that `open_datatree` inherits coordinates from root tree.

         This particular h5 file is a test file that inherits the time coordinate from the root
@@ -644,7 +645,10 @@ def test_inherited_coords(self, url=simplegroup_datatree_url) -> None:
         from pydap.net import create_session

         # Create a session with pre-set retry params in pydap backend, to cache urls
-        session = create_session(use_cache=True, cache_kwargs={"cache_name": "debug"})
+        cache_name = tmpdir / "debug"
+        session = create_session(
+            use_cache=True, cache_kwargs={"cache_name": cache_name}
+        )
         session.cache.clear()

         _version_ = Version(pydap.__version__)
@@ -661,8 +665,8 @@ def test_inherited_coords(self, url=simplegroup_datatree_url) -> None:
         )

         if _version_ > Version("3.5.5"):
-            # Total downloads are: 1 dmr, + 1 dap url for all dimensions across groups
-            assert len(session.cache.urls()) == 2
+            # Total downloads are: 1 dmr, + 1 dap url for all dimensions for each group
+            assert len(session.cache.urls()) == 3
         else:
             # 1 dmr + 1 dap url per dimension (total there are 4 dimension arrays)
             assert len(session.cache.urls()) == 5

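For completeness, a sketch of the request-count expectation the datatree test now encodes: on pydap > 3.5.5 the inherited-coordinates case is expected to issue one DMR request plus one DAP request per group (3 cached URLs for this file), versus 1 DMR plus one DAP request per dimension (5 URLs) on older pydap. The exact open_datatree call is not visible in these hunks, so the arguments below are assumptions; the URL, session setup, and expected counts come from the diff.

# Illustration of the expected download counts; the cache path is arbitrary.
import tempfile

import pydap
import xarray as xr
from packaging.version import Version
from pydap.net import create_session

url = "dap4://test.opendap.org/opendap/dap4/SimpleGroup.nc4.h5"
session = create_session(
    use_cache=True, cache_kwargs={"cache_name": f"{tempfile.mkdtemp()}/debug"}
)
session.cache.clear()

tree = xr.open_datatree(url, engine="pydap", session=session, decode_times=False)
expected = 3 if Version(pydap.__version__) > Version("3.5.5") else 5
print(len(session.cache.urls()), "cached requests; expected", expected)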