Skip to content

Commit c0074ec

Browse files
committed
Exposes graphdef for flax models.
Signed-off-by: Lance Wang <lancewang@google.com>
1 parent 7f7be82 commit c0074ec

File tree

2 files changed

+3
-3
lines changed

2 files changed

+3
-3
lines changed

tpu_inference/models/common/model_loader.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -284,7 +284,7 @@ def combine_hidden_states(graphdef, state, hidden_states):
284284
"get_mrope_input_positions_fn": get_mrope_input_positions_fn,
285285
}
286286

287-
return model_fn, compute_logits_fn, combine_hidden_states_fn, multimodal_fns, state, lora_manager, model
287+
return model_fn, compute_logits_fn, combine_hidden_states_fn, multimodal_fns, state, lora_manager, model, graphdef
288288

289289

290290
def get_vllm_model(
@@ -305,7 +305,7 @@ def get_vllm_model(
305305
compute_logits_fn = model.jit_compute_logits_func()
306306
# the model needs to be returned because lora weights are neither torch.nn.parameter nor torch.nn.buffer. After we load the lora weights and set it to the torch.nn.Module, we can shard it and move it to TPU.
307307
combine_hidden_states_fn = None
308-
return jit_model, compute_logits_fn, combine_hidden_states_fn, None, params, lora_manager, model
308+
return jit_model, compute_logits_fn, combine_hidden_states_fn, None, params, lora_manager, model, None
309309

310310

311311
def get_model(

tpu_inference/runner/tpu_runner.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -468,7 +468,7 @@ def _init_inputs(self) -> None:
468468
dtype=np.int64)
469469

470470
def load_model(self):
471-
self.model_fn, self.compute_logits_fn, self.combine_hidden_states_fn, multimodal_fns, self.state, self.lora_manager, self.model = get_model(
471+
self.model_fn, self.compute_logits_fn, self.combine_hidden_states_fn, multimodal_fns, self.state, self.lora_manager, self.model, self.graphdef = get_model(
472472
self.vllm_config,
473473
self.rng_key,
474474
self.mesh,

0 commit comments

Comments
 (0)