Commit 81eff05
[tests] switch lm_eval invocation to use pre-loaded transformers model (#2018)
SUMMARY: `lm_eval==0.4.9.1` has a broken entrypoint when evaluating a model with a compressed-tensors quantization config via `--model hf`:

```
FAILED tests/lmeval/test_lmeval.py::TestLMEval::test_lm_eval[tests/lmeval/configs/vl_w4a16_actorder_weight.yaml] - ValueError: The model is quantized with CompressedTensorsConfig but you are passing a dict config. Please make sure to pass the same quantization config class to `from_pretrained` with different loading attributes.
```

This has been fixed on lm_eval main, though a separate issue persists that is resolved by EleutherAI/lm-evaluation-harness#3393. While that fix is in transit, and to avoid having to pin our CI/CD to lm_eval main, this PR resolves the issue by pre-loading the model with `AutoModelForCausalLM` rather than relying on lm_eval's model-loading logic.

TEST PLAN: Tests pass now. For some reason the vl test is very slow on ibm-h100-1; the same slowdown happens on main. I've seen this before but am not sure what causes it; it seemed to correct itself the following day.

---------

Signed-off-by: Brian Dellabetta <bdellabe@redhat.com>
1 parent 63c175b commit 81eff05
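In short, the workaround is to load the checkpoint with `transformers` first and hand the resulting `PreTrainedModel` instance to lm_eval, instead of letting lm_eval construct the model from a string. Below is a minimal sketch of the pattern, using lm_eval's `HFLM` wrapper directly rather than the registry lookup the test uses; the checkpoint id, task, and batch size are illustrative, not the test suite's actual config:

```python
import lm_eval
from lm_eval.models.huggingface import HFLM
from transformers import AutoModelForCausalLM

# Pre-load the model ourselves so lm_eval never has to interpret the
# compressed-tensors quantization config in the checkpoint.
loaded_model = AutoModelForCausalLM.from_pretrained(
    "org/some-w4a16-model",  # hypothetical checkpoint id
    torch_dtype="auto",
    device_map="cuda:0",
)

# Wrap the already-loaded model; lm_eval skips its own from_pretrained call.
results = lm_eval.simple_evaluate(
    model=HFLM(pretrained=loaded_model, batch_size=8),
    tasks=["gsm8k"],
    num_fewshot=5,
    limit=100,
)
```

This mirrors what the new `_eval_model` helper in tests/lmeval/test_lmeval.py does below, where the wrapper class is resolved via `lm_eval.api.registry.get_model`.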

File tree

2 files changed: +40 -43 lines changed

  tests/e2e/e2e_utils.py
  tests/lmeval/test_lmeval.py

2 files changed

+40
-43
lines changed

tests/e2e/e2e_utils.py

Lines changed: 7 additions & 11 deletions
```diff
@@ -10,15 +10,12 @@
 from tests.testing_utils import process_dataset
 
 
-@log_time
-def _load_model_and_processor(
-    model: str,
-    model_class: str,
-):
+def load_model(model: str, model_class: str, device_map: str | None = None):
     pretrained_model_class = getattr(transformers, model_class)
-    loaded_model = pretrained_model_class.from_pretrained(model, torch_dtype="auto")
-    processor = AutoProcessor.from_pretrained(model)
-    return loaded_model, processor
+    loaded_model = pretrained_model_class.from_pretrained(
+        model, torch_dtype="auto", device_map=device_map
+    )
+    return loaded_model
 
 
 @log_time
@@ -41,9 +38,8 @@ def run_oneshot_for_e2e_testing(
     # Load model.
     oneshot_kwargs = {}
 
-    loaded_model, processor = _load_model_and_processor(
-        model=model, model_class=model_class
-    )
+    loaded_model = load_model(model=model, model_class=model_class)
+    processor = AutoProcessor.from_pretrained(model)
 
     if dataset_id:
         ds = load_dataset(dataset_id, name=dataset_config, split=dataset_split)
```
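A hypothetical call site for the refactored helper, assuming the caller now loads the processor itself (the model id is illustrative):

```python
from transformers import AutoProcessor

from tests.e2e.e2e_utils import load_model

model_id = "Qwen/Qwen2.5-0.5B-Instruct"  # illustrative model id
loaded_model = load_model(
    model=model_id,
    model_class="AutoModelForCausalLM",
    device_map="cuda:0",  # optional; defaults to None as in the diff above
)
processor = AutoProcessor.from_pretrained(model_id)
```

Dropping the processor from the helper lets tests/lmeval/test_lmeval.py reuse `load_model` without creating a processor it does not need.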

tests/lmeval/test_lmeval.py

Lines changed: 33 additions & 32 deletions
```diff
@@ -13,7 +13,7 @@
 from pydantic import BaseModel
 
 from llmcompressor.core import active_session
-from tests.e2e.e2e_utils import run_oneshot_for_e2e_testing
+from tests.e2e.e2e_utils import load_model, run_oneshot_for_e2e_testing
 from tests.test_timer.timer_utils import get_singleton_manager, log_time
 from tests.testing_utils import requires_gpu
 
@@ -35,6 +35,10 @@ class LmEvalConfig(BaseModel):
 
 try:
     import lm_eval
+    import lm_eval.api.registry
+
+    # needed to populate model registry
+    import lm_eval.models  # noqa
 
     lm_eval_installed = True
 except ImportError:
@@ -120,7 +124,7 @@ def test_lm_eval(self, test_data_file: str):
 
         # Always evaluate base model for recovery testing
         logger.info("================= Evaluating BASE model ======================")
-        self.base_results = self._eval_base_model()
+        base_results = self._eval_base_model()
 
         if not self.save_dir:
             self.save_dir = self.model.split("/")[1] + f"-{self.scheme}"
@@ -145,22 +149,41 @@ def test_lm_eval(self, test_data_file: str):
         self._handle_recipe()
 
         logger.info("================= Running LM Eval on COMPRESSED model ==========")
-        self._run_lm_eval()
+        compressed_results = self._eval_compressed_model()
+
+        # Always use recovery testing
+        self._validate_recovery(base_results, compressed_results)
+
+        # If absolute metrics provided, show warnings (not failures)
+        if self.lmeval.metrics:
+            self._check_absolute_warnings(compressed_results)
 
         self.tear_down()
 
     @log_time
-    def _eval_base_model(self):
+    def _eval_base_model(self) -> dict:
         """Evaluate the base (uncompressed) model."""
-        model_args = {**self.lmeval.model_args, "pretrained": self.model}
+        return self._eval_model(self.model)
+
+    @log_time
+    def _eval_compressed_model(self) -> dict:
+        """Evaluate the compressed model."""
+        return self._eval_model(self.save_dir)
+
+    def _eval_model(self, model: str) -> dict:
+        # NOTE: pass in PreTrainedModel to avoid lm_eval's model-loading logic
+        # https://github.com/EleutherAI/lm-evaluation-harness/pull/3393
+        lm_eval_cls = lm_eval.api.registry.get_model(self.lmeval.model)
 
         results = lm_eval.simple_evaluate(
-            model=self.lmeval.model,
-            model_args=model_args,
+            model=lm_eval_cls(
+                pretrained=load_model(model, self.model_class, device_map="cuda:0"),
+                batch_size=self.lmeval.batch_size,
+                **self.lmeval.model_args,
+            ),
             tasks=[self.lmeval.task],
             num_fewshot=self.lmeval.num_fewshot,
             limit=self.lmeval.limit,
-            device="cuda:0",
             apply_chat_template=self.lmeval.apply_chat_template,
             batch_size=self.lmeval.batch_size,
         )
@@ -181,31 +204,9 @@ def _handle_recipe(self):
             fp.write(recipe_yaml_str)
         session.reset()
 
-    @log_time
-    def _run_lm_eval(self):
-        model_args = {"pretrained": self.save_dir}
-        model_args.update(self.lmeval.model_args)
-        results = lm_eval.simple_evaluate(
-            model=self.lmeval.model,
-            model_args=model_args,
-            tasks=[self.lmeval.task],
-            num_fewshot=self.lmeval.num_fewshot,
-            limit=self.lmeval.limit,
-            device="cuda:0",
-            apply_chat_template=self.lmeval.apply_chat_template,
-            batch_size=self.lmeval.batch_size,
-        )
-
-        # Always use recovery testing
-        self._validate_recovery(results)
-
-        # If absolute metrics provided, show warnings (not failures)
-        if self.lmeval.metrics:
-            self._check_absolute_warnings(results)
-
-    def _validate_recovery(self, compressed_results):
+    def _validate_recovery(self, base_results, compressed_results):
         """Validate using recovery testing - compare against base model."""
-        base_metrics = self.base_results["results"][self.lmeval.task]
+        base_metrics = base_results["results"][self.lmeval.task]
         compressed_metrics = compressed_results["results"][self.lmeval.task]
         higher_is_better_map = compressed_results.get("higher_is_better", {}).get(
             self.lmeval.task, {}
```
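One detail from the hunk above: `lm_eval.api.registry.get_model` can only resolve wrapper classes that have been registered, and registration happens as a side effect of importing `lm_eval.models` (hence the extra `# noqa` import added to the try block). A quick sketch of that assumption:

```python
import lm_eval.api.registry
import lm_eval.models  # noqa: F401  # side-effect import that registers "hf" and the other backends

# With the registry populated, the lookup used by _eval_model resolves
# a name like "hf" to its LM wrapper class.
hf_cls = lm_eval.api.registry.get_model("hf")
print(hf_cls.__name__)
```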
