 from pydantic import BaseModel
 
 from llmcompressor.core import active_session
-from tests.e2e.e2e_utils import run_oneshot_for_e2e_testing
+from tests.e2e.e2e_utils import load_model, run_oneshot_for_e2e_testing
 from tests.test_timer.timer_utils import get_singleton_manager, log_time
 from tests.testing_utils import requires_gpu
 
@@ -35,6 +35,10 @@ class LmEvalConfig(BaseModel):
 
 try:
     import lm_eval
+    import lm_eval.api.registry
+
+    # needed to populate model registry
+    import lm_eval.models  # noqa
 
     lm_eval_installed = True
 except ImportError:
@@ -120,7 +124,7 @@ def test_lm_eval(self, test_data_file: str):
 
         # Always evaluate base model for recovery testing
         logger.info("================= Evaluating BASE model ======================")
-        self.base_results = self._eval_base_model()
+        base_results = self._eval_base_model()
 
         if not self.save_dir:
             self.save_dir = self.model.split("/")[1] + f"-{self.scheme}"
@@ -145,22 +149,41 @@ def test_lm_eval(self, test_data_file: str):
         self._handle_recipe()
 
         logger.info("================= Running LM Eval on COMPRESSED model ==========")
-        self._run_lm_eval()
+        compressed_results = self._eval_compressed_model()
+
+        # Always use recovery testing
+        self._validate_recovery(base_results, compressed_results)
+
+        # If absolute metrics provided, show warnings (not failures)
+        if self.lmeval.metrics:
+            self._check_absolute_warnings(compressed_results)
 
         self.tear_down()
 
     @log_time
-    def _eval_base_model(self):
+    def _eval_base_model(self) -> dict:
         """Evaluate the base (uncompressed) model."""
-        model_args = {**self.lmeval.model_args, "pretrained": self.model}
+        return self._eval_model(self.model)
+
+    @log_time
+    def _eval_compressed_model(self) -> dict:
+        """Evaluate the compressed model."""
+        return self._eval_model(self.save_dir)
+
+    def _eval_model(self, model: str) -> dict:
+        # NOTE: pass in PreTrainedModel to avoid lm_eval's model-loading logic
+        # https://github.com/EleutherAI/lm-evaluation-harness/pull/3393
+        lm_eval_cls = lm_eval.api.registry.get_model(self.lmeval.model)
 
         results = lm_eval.simple_evaluate(
-            model=self.lmeval.model,
-            model_args=model_args,
+            model=lm_eval_cls(
+                pretrained=load_model(model, self.model_class, device_map="cuda:0"),
+                batch_size=self.lmeval.batch_size,
+                **self.lmeval.model_args,
+            ),
             tasks=[self.lmeval.task],
             num_fewshot=self.lmeval.num_fewshot,
             limit=self.lmeval.limit,
-            device="cuda:0",
             apply_chat_template=self.lmeval.apply_chat_template,
             batch_size=self.lmeval.batch_size,
         )
@@ -181,31 +204,9 @@ def _handle_recipe(self):
             fp.write(recipe_yaml_str)
         session.reset()
 
-    @log_time
-    def _run_lm_eval(self):
-        model_args = {"pretrained": self.save_dir}
-        model_args.update(self.lmeval.model_args)
-        results = lm_eval.simple_evaluate(
-            model=self.lmeval.model,
-            model_args=model_args,
-            tasks=[self.lmeval.task],
-            num_fewshot=self.lmeval.num_fewshot,
-            limit=self.lmeval.limit,
-            device="cuda:0",
-            apply_chat_template=self.lmeval.apply_chat_template,
-            batch_size=self.lmeval.batch_size,
-        )
-
-        # Always use recovery testing
-        self._validate_recovery(results)
-
-        # If absolute metrics provided, show warnings (not failures)
-        if self.lmeval.metrics:
-            self._check_absolute_warnings(results)
-
-    def _validate_recovery(self, compressed_results):
+    def _validate_recovery(self, base_results, compressed_results):
         """Validate using recovery testing - compare against base model."""
-        base_metrics = self.base_results["results"][self.lmeval.task]
+        base_metrics = base_results["results"][self.lmeval.task]
         compressed_metrics = compressed_results["results"][self.lmeval.task]
         higher_is_better_map = compressed_results.get("higher_is_better", {}).get(
             self.lmeval.task, {}
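
For reference, below is a minimal standalone sketch, for illustration only, of the registry-based pattern this diff adopts: look up the LM wrapper class by name and hand lm_eval.simple_evaluate an already-constructed instance instead of a model_args string. The model name, task, and device are illustrative assumptions rather than values from this test suite, and a CUDA device is assumed to be available.

import lm_eval
import lm_eval.api.registry

# importing lm_eval.models populates the model registry (same reason as in the diff)
import lm_eval.models  # noqa: F401

# "hf" is the registry key for the Hugging Face backend
lm_cls = lm_eval.api.registry.get_model("hf")
results = lm_eval.simple_evaluate(
    model=lm_cls(pretrained="facebook/opt-125m", device="cuda:0", batch_size=8),
    tasks=["lambada_openai"],
    limit=32,  # small limit keeps the example quick
)
print(results["results"]["lambada_openai"])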