From dd9e028b66ecb271b95a759e1d5d748ac13c1a87 Mon Sep 17 00:00:00 2001 From: Dan Raviv Date: Fri, 19 Jun 2026 21:23:31 -0700 Subject: [PATCH] tidy: PyTorch now always at 2.x, remove unnecessary code that thinks it can still be 1.x --- tests/unit/utilities/test_devices.py | 17 +++-------------- transformer_lens/HookedTransformer.py | 4 ---- transformer_lens/utilities/devices.py | 23 ++++++++++------------- 3 files changed, 13 insertions(+), 31 deletions(-) diff --git a/tests/unit/utilities/test_devices.py b/tests/unit/utilities/test_devices.py index 3116c6c04..8160356d7 100644 --- a/tests/unit/utilities/test_devices.py +++ b/tests/unit/utilities/test_devices.py @@ -48,23 +48,12 @@ def test_get_device_cuda_available(): @patch.dict("os.environ", {"TRANSFORMERLENS_ALLOW_MPS": "1"}) def test_get_device_mps_available(): - """Test get_device when MPS is available, PyTorch version >= 2.0, and env var set.""" + """Test get_device when MPS is available and env var set.""" with patch("torch.cuda.is_available", return_value=False): with patch("torch.backends.mps.is_available", return_value=True): with patch("torch.backends.mps.is_built", return_value=True): - with patch("torch.__version__", "2.0.0"): - device = get_device() - assert device == "mps" - - -def test_get_device_mps_pytorch_1x(): - """Test get_device when MPS is available but PyTorch version < 2.0.""" - with patch("torch.cuda.is_available", return_value=False): - with patch("torch.backends.mps.is_available", return_value=True): - with patch("torch.backends.mps.is_built", return_value=True): - with patch("torch.__version__", "1.13.0"): - device = get_device() - assert device == "cpu" + device = get_device() + assert device == "mps" def test_get_device_cpu_fallback(): diff --git a/transformer_lens/HookedTransformer.py b/transformer_lens/HookedTransformer.py index 9baf479e5..d903c9c67 100644 --- a/transformer_lens/HookedTransformer.py +++ b/transformer_lens/HookedTransformer.py @@ -35,7 +35,6 @@ import torch.nn.functional as F import tqdm.auto as tqdm from jaxtyping import Float, Int -from packaging import version from transformers import AutoTokenizer, PreTrainedModel, PreTrainedTokenizerBase from transformers.models.auto.tokenization_auto import AutoTokenizer from transformers.tokenization_utils_base import PreTrainedTokenizerBase @@ -1345,9 +1344,6 @@ def from_pretrained( load_in_8bit = qc.get("load_in_8bit", False) quant_method = qc.get("quant_method", "") assert not load_in_8bit, "8-bit quantization is not supported" - assert not ( - load_in_4bit and (version.parse(torch.__version__) < version.parse("2.1.1")) - ), "Quantization is only supported for torch versions >= 2.1.1" assert not ( load_in_4bit and ("llama" not in model_name.lower()) ), "Quantization is only supported for Llama models" diff --git a/transformer_lens/utilities/devices.py b/transformer_lens/utilities/devices.py index c152158b6..9e7c340ba 100644 --- a/transformer_lens/utilities/devices.py +++ b/transformer_lens/utilities/devices.py @@ -57,8 +57,7 @@ def get_device() -> str: """Get the best available device, with MPS safety checks. MPS is only auto-selected when the environment variable - ``TRANSFORMERLENS_ALLOW_MPS=1`` is set **and** the installed PyTorch - version is 2.0 or higher. + ``TRANSFORMERLENS_ALLOW_MPS=1`` is set. Returns: str: The best available device name (cuda, mps, or cpu) @@ -67,17 +66,15 @@ def get_device() -> str: return "cuda" if torch.backends.mps.is_available() and torch.backends.mps.is_built(): - major_version = int(torch.__version__.split(".")[0]) - if major_version >= 2: - # Only auto-select MPS when explicitly opted-in via env var - if os.environ.get("TRANSFORMERLENS_ALLOW_MPS", "") == "1": - return "mps" - logging.info( - "MPS device available but not auto-selected due to known correctness issues " - "(PyTorch %s). Set TRANSFORMERLENS_ALLOW_MPS=1 to override. See: " - "https://github.com/TransformerLensOrg/TransformerLens/issues/1178", - torch.__version__, - ) + # Only auto-select MPS when explicitly opted-in via env var + if os.environ.get("TRANSFORMERLENS_ALLOW_MPS", "") == "1": + return "mps" + logging.info( + "MPS device available but not auto-selected due to known correctness issues " + "(PyTorch %s). Set TRANSFORMERLENS_ALLOW_MPS=1 to override. See: " + "https://github.com/TransformerLensOrg/TransformerLens/issues/1178", + torch.__version__, + ) return "cpu"