From c0074ec0cc13581d1c5fd30282853a846a7685fc Mon Sep 17 00:00:00 2001
From: Lance Wang
Date: Mon, 10 Nov 2025 02:31:04 +0000
Subject: [PATCH] Exposes graphdef for flax models.

Signed-off-by: Lance Wang
---
 tpu_inference/models/common/model_loader.py | 4 ++--
 tpu_inference/runner/tpu_runner.py          | 2 +-
 2 files changed, 3 insertions(+), 3 deletions(-)

diff --git a/tpu_inference/models/common/model_loader.py b/tpu_inference/models/common/model_loader.py
index dd4082709..f14952ee1 100644
--- a/tpu_inference/models/common/model_loader.py
+++ b/tpu_inference/models/common/model_loader.py
@@ -284,7 +284,7 @@ def combine_hidden_states(graphdef, state, hidden_states):
         "get_mrope_input_positions_fn": get_mrope_input_positions_fn,
     }
 
-    return model_fn, compute_logits_fn, combine_hidden_states_fn, multimodal_fns, state, lora_manager, model
+    return model_fn, compute_logits_fn, combine_hidden_states_fn, multimodal_fns, state, lora_manager, model, graphdef
 
 
 def get_vllm_model(
@@ -305,7 +305,7 @@ def get_vllm_model(
     compute_logits_fn = model.jit_compute_logits_func()
     # the model needs to be returned because lora weights are neither torch.nn.parameter nor torch.nn.buffer. After we load the lora weights and set it to the torch.nn.Module, we can shard it and move it to TPU.
     combine_hidden_states_fn = None
-    return jit_model, compute_logits_fn, combine_hidden_states_fn, None, params, lora_manager, model
+    return jit_model, compute_logits_fn, combine_hidden_states_fn, None, params, lora_manager, model, None
 
 
 def get_model(
diff --git a/tpu_inference/runner/tpu_runner.py b/tpu_inference/runner/tpu_runner.py
index 9a92423f6..47ee1bc1b 100644
--- a/tpu_inference/runner/tpu_runner.py
+++ b/tpu_inference/runner/tpu_runner.py
@@ -468,7 +468,7 @@ def _init_inputs(self) -> None:
                                              dtype=np.int64)
 
     def load_model(self):
-        self.model_fn, self.compute_logits_fn, self.combine_hidden_states_fn, multimodal_fns, self.state, self.lora_manager, self.model = get_model(
+        self.model_fn, self.compute_logits_fn, self.combine_hidden_states_fn, multimodal_fns, self.state, self.lora_manager, self.model, self.graphdef = get_model(
             self.vllm_config, self.rng_key, self.mesh,
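
Note (not part of the patch): a minimal sketch of how the exposed graphdef
could be consumed downstream, assuming the flax models here follow the
standard flax.nnx split/merge convention, where nnx.split separates a module
into a static GraphDef and a State pytree and nnx.merge reassembles it. The
helper name rebuild_model is hypothetical and used only for illustration;
note that the vLLM path above returns None for graphdef, so callers must
guard against that.

    from flax import nnx

    def rebuild_model(graphdef: nnx.GraphDef, state: nnx.State):
        # Hypothetical helper: reassemble a callable flax NNX module from its
        # static structure (graphdef) and its parameters/variables (state).
        # get_vllm_model() returns graphdef=None, so fail loudly in that case.
        assert graphdef is not None, "graphdef is only exposed for flax models"
        return nnx.merge(graphdef, state)

With self.graphdef stored on the runner alongside self.state, the full module
can presumably be rebuilt without another call into the model loader; that
motivation is inferred from the change, not stated in the patch itself.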