diff --git a/tpu_inference/models/common/model_loader.py b/tpu_inference/models/common/model_loader.py
index dd4082709..f14952ee1 100644
--- a/tpu_inference/models/common/model_loader.py
+++ b/tpu_inference/models/common/model_loader.py
@@ -284,7 +284,7 @@ def combine_hidden_states(graphdef, state, hidden_states):
         "get_mrope_input_positions_fn": get_mrope_input_positions_fn,
     }
 
-    return model_fn, compute_logits_fn, combine_hidden_states_fn, multimodal_fns, state, lora_manager, model
+    return model_fn, compute_logits_fn, combine_hidden_states_fn, multimodal_fns, state, lora_manager, model, graphdef
 
 
 def get_vllm_model(
@@ -305,7 +305,7 @@ def get_vllm_model(
     compute_logits_fn = model.jit_compute_logits_func()
     # the model needs to be returned because lora weights are neither torch.nn.parameter nor torch.nn.buffer. After we load the lora weights and set it to the torch.nn.Module, we can shard it and move it to TPU.
     combine_hidden_states_fn = None
-    return jit_model, compute_logits_fn, combine_hidden_states_fn, None, params, lora_manager, model
+    return jit_model, compute_logits_fn, combine_hidden_states_fn, None, params, lora_manager, model, None
 
 
 def get_model(
diff --git a/tpu_inference/runner/tpu_runner.py b/tpu_inference/runner/tpu_runner.py
index 9a92423f6..47ee1bc1b 100644
--- a/tpu_inference/runner/tpu_runner.py
+++ b/tpu_inference/runner/tpu_runner.py
@@ -468,7 +468,7 @@ def _init_inputs(self) -> None:
                                               dtype=np.int64)
 
     def load_model(self):
-        self.model_fn, self.compute_logits_fn, self.combine_hidden_states_fn, multimodal_fns, self.state, self.lora_manager, self.model = get_model(
+        self.model_fn, self.compute_logits_fn, self.combine_hidden_states_fn, multimodal_fns, self.state, self.lora_manager, self.model, self.graphdef = get_model(
             self.vllm_config,
             self.rng_key,
             self.mesh,
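
The `graphdef`/`state` pairing in this patch matches flax's `nnx.split`/`nnx.merge` pattern: the graphdef is the static module structure, the state holds the arrays, and the two together are enough to reconstruct the module. Presumably that is why `load_model` now stores `self.graphdef` next to `self.state`, though the PR itself does not say so. Below is a minimal standalone sketch of that pattern under the assumption that flax's `nnx` API is in play; `TinyModel` is a hypothetical stand-in, not code from this repository.

```python
# Minimal sketch of the flax.nnx split/merge pattern that returning `graphdef`
# alongside `state` enables. Assumption: the models here are flax.nnx modules;
# `TinyModel` is a hypothetical example, not part of this PR.
import jax.numpy as jnp
from flax import nnx

class TinyModel(nnx.Module):
    def __init__(self, din: int, dout: int, *, rngs: nnx.Rngs):
        self.linear = nnx.Linear(din, dout, rngs=rngs)

    def __call__(self, x):
        return self.linear(x)

model = TinyModel(4, 2, rngs=nnx.Rngs(0))
graphdef, state = nnx.split(model)    # static structure + array state
rebuilt = nnx.merge(graphdef, state)  # reconstruct the module from both parts
y = rebuilt(jnp.ones((1, 4)))         # rebuilt module behaves like the original
```

Note that `get_vllm_model` returns `None` in the `graphdef` slot, so any caller using the stored `self.graphdef` this way would need to handle the vLLM-model path separately.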