Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 3 additions & 14 deletions tests/unit/utilities/test_devices.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,23 +48,12 @@ def test_get_device_cuda_available():

@patch.dict("os.environ", {"TRANSFORMERLENS_ALLOW_MPS": "1"})
def test_get_device_mps_available():
"""Test get_device when MPS is available, PyTorch version >= 2.0, and env var set."""
"""Test get_device when MPS is available and env var set."""
with patch("torch.cuda.is_available", return_value=False):
with patch("torch.backends.mps.is_available", return_value=True):
with patch("torch.backends.mps.is_built", return_value=True):
with patch("torch.__version__", "2.0.0"):
device = get_device()
assert device == "mps"


def test_get_device_mps_pytorch_1x():
"""Test get_device when MPS is available but PyTorch version < 2.0."""
with patch("torch.cuda.is_available", return_value=False):
with patch("torch.backends.mps.is_available", return_value=True):
with patch("torch.backends.mps.is_built", return_value=True):
with patch("torch.__version__", "1.13.0"):
device = get_device()
assert device == "cpu"
device = get_device()
assert device == "mps"


def test_get_device_cpu_fallback():
Expand Down
4 changes: 0 additions & 4 deletions transformer_lens/HookedTransformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,6 @@
import torch.nn.functional as F
import tqdm.auto as tqdm
from jaxtyping import Float, Int
from packaging import version
from transformers import AutoTokenizer, PreTrainedModel, PreTrainedTokenizerBase
from transformers.models.auto.tokenization_auto import AutoTokenizer
from transformers.tokenization_utils_base import PreTrainedTokenizerBase
Expand Down Expand Up @@ -1345,9 +1344,6 @@ def from_pretrained(
load_in_8bit = qc.get("load_in_8bit", False)
quant_method = qc.get("quant_method", "")
assert not load_in_8bit, "8-bit quantization is not supported"
assert not (
load_in_4bit and (version.parse(torch.__version__) < version.parse("2.1.1"))
), "Quantization is only supported for torch versions >= 2.1.1"
assert not (
load_in_4bit and ("llama" not in model_name.lower())
), "Quantization is only supported for Llama models"
Expand Down
23 changes: 10 additions & 13 deletions transformer_lens/utilities/devices.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,8 +57,7 @@ def get_device() -> str:
"""Get the best available device, with MPS safety checks.

MPS is only auto-selected when the environment variable
``TRANSFORMERLENS_ALLOW_MPS=1`` is set **and** the installed PyTorch
version is 2.0 or higher.
``TRANSFORMERLENS_ALLOW_MPS=1`` is set.

Returns:
str: The best available device name (cuda, mps, or cpu)
Expand All @@ -67,17 +66,15 @@ def get_device() -> str:
return "cuda"

if torch.backends.mps.is_available() and torch.backends.mps.is_built():
major_version = int(torch.__version__.split(".")[0])
if major_version >= 2:
# Only auto-select MPS when explicitly opted-in via env var
if os.environ.get("TRANSFORMERLENS_ALLOW_MPS", "") == "1":
return "mps"
logging.info(
"MPS device available but not auto-selected due to known correctness issues "
"(PyTorch %s). Set TRANSFORMERLENS_ALLOW_MPS=1 to override. See: "
"https://github.com/TransformerLensOrg/TransformerLens/issues/1178",
torch.__version__,
)
# Only auto-select MPS when explicitly opted-in via env var
if os.environ.get("TRANSFORMERLENS_ALLOW_MPS", "") == "1":
return "mps"
logging.info(
"MPS device available but not auto-selected due to known correctness issues "
"(PyTorch %s). Set TRANSFORMERLENS_ALLOW_MPS=1 to override. See: "
"https://github.com/TransformerLensOrg/TransformerLens/issues/1178",
torch.__version__,
)

return "cpu"

Expand Down
Loading