
Commit 6559de0

clean up implementation
Signed-off-by: Kyle Sayers <kylesayrs@gmail.com>
1 parent e6f0abc commit 6559de0

File tree

1 file changed: +12 -11 lines

src/llmcompressor/utils/helpers.py

Lines changed: 12 additions & 11 deletions
@@ -1074,11 +1074,19 @@ def disable_lm_head(model: torch.nn.Module):
     """
     _, lm_head = get_embeddings(model)
     if lm_head is not None:
-        if not isinstance(lm_head, torch.nn.Linear):
-            raise NotImplementedError(
-                f"Cannot disable LM head of type {lm_head.__class__.__name__}"
-            )
+        logger.warning(
+            f"Attempted to disable lm_head of instance {model.__class__.__name__}, "
+            "but was unable to to find lm_head. This may lead to unexpected OOM."
+        )
+        yield
+        return
+
+    elif not isinstance(lm_head, torch.nn.Linear):
+        logger.warning(f"Cannot disable LM head of type {lm_head.__class__.__name__}")
+        yield
+        return
 
+    else:
         dummy_weight = lm_head.weight.to("meta")
 
         def dummy_forward(self, input: torch.Tensor) -> torch.Tensor:
@@ -1087,13 +1095,6 @@ def dummy_forward(self, input: torch.Tensor) -> torch.Tensor:
         with patch_attr(lm_head, "forward", dummy_forward.__get__(lm_head)):
             yield
 
-    else:
-        logger.warning(
-            f"Attempted to disable lm_head of instance {model.__class__.__name__}, "
-            "but was unable to to find lm_head. This may lead to unexpected OOM."
-        )
-        yield
-
 
 @contextlib.contextmanager
 def patch_attr(base: object, attr: str, value: Any):
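
For readers skimming the diff: the helper being reorganized here temporarily replaces the LM head's forward with a dummy that computes against a meta-device copy of the weight, so full-vocabulary logits are never materialized (the warning text notes that otherwise this may lead to unexpected OOM). Below is a minimal, self-contained sketch of that pattern, assuming a plain torch.nn.Linear head; disable_lm_head_sketch and the toy shapes are illustrative names only, not the repository's API, which instead relies on the get_embeddings and patch_attr helpers visible in the diff.

    # Sketch only: the real helper is disable_lm_head() in
    # src/llmcompressor/utils/helpers.py; names and shapes here are hypothetical.
    import contextlib

    import torch


    @contextlib.contextmanager
    def disable_lm_head_sketch(lm_head: torch.nn.Linear):
        # Only metadata (shape/dtype) is kept; no storage is allocated on "meta".
        dummy_weight = lm_head.weight.to("meta")

        def dummy_forward(self, input: torch.Tensor) -> torch.Tensor:
            # Produces a shape-correct meta tensor instead of real logits.
            return input.to("meta") @ dummy_weight.T

        original_forward = lm_head.forward
        lm_head.forward = dummy_forward.__get__(lm_head)  # bind as an instance method
        try:
            yield
        finally:
            lm_head.forward = original_forward


    if __name__ == "__main__":
        lm_head = torch.nn.Linear(64, 32000, bias=False)
        hidden_states = torch.randn(2, 8, 64)
        with disable_lm_head_sketch(lm_head):
            logits = lm_head(hidden_states)
        # torch.Size([2, 8, 32000]) on device "meta": no logits memory was allocated.
        print(logits.shape, logits.device)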

0 commit comments