diff --git a/torch_frame/__init__.py b/torch_frame/__init__.py index 7161acc3..3e35f4d6 100644 --- a/torch_frame/__init__.py +++ b/torch_frame/__init__.py @@ -27,10 +27,16 @@ if WITH_PT24: import torch + import numpy as np + import _codecs torch.serialization.add_safe_globals([ stype, torch_frame.data.stats.StatType, + np.core.multiarray.scalar, + np.dtype, + type(np.dtype(np.int32)), + _codecs.encode, ]) __version__ = '0.2.3'