Skip to content

Commit df244cc

Browse files
committed
torch.load fix for all pytorch versions.
1 parent 98ac171 commit df244cc

File tree

2 files changed

+17
-3
lines changed

2 files changed

+17
-3
lines changed

src/pytorch_tabular/tabular_datamodule.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -758,7 +758,12 @@ def _load_dataset_from_cache(self, tag: str = "train"):
758758
)
759759
elif self.cache_mode is self.CACHE_MODES.DISK:
760760
try:
761-
dataset = torch.load(self.cache_dir / f"{tag}_dataset", weights_only=False)
761+
# get the torch version
762+
torch_version = torch.__version__
763+
if torch_version < "2.6":
764+
dataset = torch.load(self.cache_dir / f"{tag}_dataset") # fix for torch version change of torch.load
765+
elif torch_version >= "2.6":
766+
dataset = torch.load(self.cache_dir / f"{tag}_dataset", weights_only=False)
762767
except FileNotFoundError:
763768
raise FileNotFoundError(
764769
f"{tag}_dataset not found in {self.cache_dir}. Please provide the" f" data for {tag} dataloader"

src/pytorch_tabular/utils/python_utils.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -74,15 +74,24 @@ def pl_load(
7474
"""
7575
if not isinstance(path_or_url, (str, Path)):
7676
# any sort of BytesIO or similar
77-
return torch.load(path_or_url, map_location=map_location, weights_only=False)
77+
# get the torch version
78+
torch_version = torch.__version__
79+
if torch_version < "2.6":
80+
return torch.load(path_or_url, map_location=map_location) # for torch version < 2.6
81+
elif torch_version >= "2.6":
82+
return torch.load(path_or_url, map_location=map_location, weights_only=False)
7883
if str(path_or_url).startswith("http"):
7984
return torch.hub.load_state_dict_from_url(
8085
str(path_or_url),
8186
map_location=map_location, # type: ignore[arg-type] # upstream annotation is not correct
8287
)
8388
fs = get_filesystem(path_or_url)
8489
with fs.open(path_or_url, "rb") as f:
85-
return torch.load(f, map_location=map_location, weights_only=False)
90+
if torch_version < "2.6":
91+
return torch.load(f, map_location=map_location) # for torch version < 2.6
92+
elif torch_version >= "2.6":
93+
return torch.load(f, map_location=map_location, weights_only=False)
94+
8695

8796

8897
def check_numpy(x):

0 commit comments

Comments
 (0)