Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
111 changes: 96 additions & 15 deletions src/tether/exporters/vlm_prefix_exporter.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,8 +213,23 @@ def patch_onnx_type_mismatches(onnx_path: str | Path) -> int:
return fixed_count


def _apply_checkpoint_vlm_weights(base_model, checkpoint_state_dict):
"""Overwrite the base VLM's weights with fine-tuned ones from the checkpoint."""
def _apply_checkpoint_vlm_weights(
base_model,
checkpoint_state_dict,
tag: str = "vlm",
) -> int:
"""Overwrite the base VLM's weights with fine-tuned ones from the checkpoint.

Args:
base_model: The nn.Module to update in-place.
checkpoint_state_dict: Fine-tuned checkpoint state dict (SmolVLA format).
tag: Short label used in log messages to identify which sub-model is
being updated (e.g. "text_embedder", "decoder_prefill").

Returns:
Number of keys actually applied (total model keys − missing keys).
Returns 0 when no matching prefix is found or on error.
"""
prefixes = [
"model.vlm_with_expert.vlm.",
"vlm_with_expert.vlm.",
Expand All @@ -233,7 +248,7 @@ def _apply_checkpoint_vlm_weights(base_model, checkpoint_state_dict):
if len(parts) >= 3:
top.add(".".join(parts[:3]))
print(
f"[vlm-weights] No matching VLM prefix found. "
f"[vlm-weights/{tag}] No matching VLM prefix found. "
f"Top-3 prefixes sample: {sorted(list(top))[:5]}",
flush=True,
)
Expand All @@ -245,27 +260,36 @@ def _apply_checkpoint_vlm_weights(base_model, checkpoint_state_dict):
if k.startswith(matched_prefix)
}
print(
f"[vlm-weights] matched_prefix={matched_prefix!r}, rebased {len(rebased)} keys",
f"[vlm-weights/{tag}] matched_prefix={matched_prefix!r}, rebased {len(rebased)} keys",
flush=True,
)
print(
f"[vlm-weights] sample rebased keys: {sorted(list(rebased.keys()))[:5]}",
f"[vlm-weights/{tag}] sample rebased keys: {sorted(list(rebased.keys()))[:5]}",
flush=True,
)

try:
missing, unexpected = base_model.load_state_dict(rebased, strict=False)
total_model_keys = len(list(base_model.state_dict().keys()))
applied = total_model_keys - len(missing)
print(
f"[vlm-weights] load: {len(missing)} missing, {len(unexpected)} unexpected",
f"[vlm-weights/{tag}] applied {applied}/{total_model_keys} fine-tune weights "
f"({len(missing)} missing, {len(unexpected)} unexpected)",
flush=True,
)
if applied == 0:
print(
f"[vlm-weights/{tag}] WARNING: zero keys applied — fine-tune weights "
f"did NOT update this sub-model. Check key namespace alignment.",
flush=True,
)
if missing:
print(f"[vlm-weights] first 5 missing: {missing[:5]}", flush=True)
print(f"[vlm-weights/{tag}] first 5 missing: {missing[:5]}", flush=True)
if unexpected:
print(f"[vlm-weights] first 5 unexpected: {unexpected[:5]}", flush=True)
return len(rebased)
print(f"[vlm-weights/{tag}] first 5 unexpected: {unexpected[:5]}", flush=True)
return applied
except Exception as e:
print(f"[vlm-weights] FAILED: {e}", flush=True)
print(f"[vlm-weights/{tag}] FAILED: {e}", flush=True)
return 0


Expand Down Expand Up @@ -319,7 +343,7 @@ def export_vlm_prefix(
)
model.eval()
if state_dict is not None:
_apply_checkpoint_vlm_weights(model, state_dict)
_apply_checkpoint_vlm_weights(model, state_dict, tag="vision_encoder")

# The AutoModelForImageTextToText wrapper has `self.model = SmolVLMModel(...)`
# whose sub-modules are the actual vision/connector/text_model. Drill down
Expand Down Expand Up @@ -457,18 +481,20 @@ def export_vlm_prefix(
config_path.write_text(json.dumps(config, indent=2))
logger.info("Updated config: %s", config_path)

# 8. Export text embedder
text_emb_path = export_text_embedder(
# 8. Export text embedder (pass fine-tune state_dict so embed_tokens are updated)
export_text_embedder(
checkpoint_path_or_id=checkpoint_path_or_id,
output_dir=output_dir,
opset=opset,
checkpoint_state_dict=state_dict,
)

# 9. Export decoder prefill
decoder_path = export_decoder_prefill(
# 9. Export decoder prefill (pass fine-tune state_dict so decoder k/v are updated)
export_decoder_prefill(
checkpoint_path_or_id=checkpoint_path_or_id,
output_dir=output_dir,
opset=opset,
checkpoint_state_dict=state_dict,
)

# Update config with all export paths
Expand All @@ -490,11 +516,22 @@ def export_text_embedder(
checkpoint_path_or_id: str = DEFAULT_VLM_MODEL_NAME,
output_dir: str | Path = ".",
opset: int = 19,
checkpoint_state_dict: dict | None = None,
) -> Path:
"""Export the token embedding table as a standalone ONNX.

Input: ``input_ids [B, seq]`` int64
Output: ``text_embeds [B, seq, 960]`` float32

Args:
checkpoint_path_or_id: HuggingFace model ID or local path.
output_dir: Directory for output files.
opset: ONNX opset version.
checkpoint_state_dict: Optional fine-tuned checkpoint state dict
(SmolVLA format, keys like ``model.vlm_with_expert.vlm.text_model.*``).
When provided, overwrites the base model's embed_tokens weights with
the fine-tuned values before export so the shipped ONNX reflects the
actual fine-tuned embedding table rather than the base model's.
"""
output_dir = Path(output_dir)
output_dir.mkdir(parents=True, exist_ok=True)
Expand All @@ -507,6 +544,25 @@ def export_text_embedder(
)
model.eval()

# Apply fine-tune weights BEFORE extracting embed_tokens so the exported
# embedding table reflects the fine-tuned checkpoint, not the base model.
# AutoModel returns SmolVLMModel whose state_dict keys are text_model.*,
# vision_model.*, connector.* — matching the rebased checkpoint keys after
# stripping the model.vlm_with_expert.vlm. prefix.
if checkpoint_state_dict is not None:
applied = _apply_checkpoint_vlm_weights(
model, checkpoint_state_dict, tag="text_embedder"
)
if applied == 0:
logger.warning(
"export_text_embedder: zero fine-tune weights applied — "
"exported embed_tokens will use BASE model weights"
)
else:
logger.info(
"export_text_embedder: applied %d fine-tune weights", applied
)

# Find embed_tokens — try multiple attribute paths
embed_tokens = None
for attr_path in [
Expand Down Expand Up @@ -701,6 +757,7 @@ def export_decoder_prefill(
opset: int = 19,
num_layers: int = SMOLVLA_NUM_DECODER_LAYERS,
prefix_seq_len: int = DEFAULT_PREFIX_SEQ_LEN,
checkpoint_state_dict: dict | None = None,
) -> Path:
"""Export the 16-layer SmolLM2 decoder as decoder_prefill.onnx.

Expand All @@ -714,6 +771,11 @@ def export_decoder_prefill(
opset: ONNX opset version (default 19, needed for GQA).
num_layers: Number of decoder layers to keep (default 16).
prefix_seq_len: Sequence length for dummy inputs (default 75).
checkpoint_state_dict: Optional fine-tuned checkpoint state dict
(SmolVLA format, keys like ``model.vlm_with_expert.vlm.text_model.*``).
When provided, overwrites the base model's decoder weights with the
fine-tuned values before export so the KV prefix computed by this
ONNX reflects the fine-tuned decoder rather than the base model's.

Returns:
Path to the exported ``decoder_prefill.onnx`` file.
Expand All @@ -730,6 +792,25 @@ def export_decoder_prefill(
)
model.eval()

# Apply fine-tune weights BEFORE extracting text_model so the decoder KV
# prefix reflects the fine-tuned checkpoint, not the base model.
# AutoModel returns SmolVLMModel whose state_dict keys are text_model.*,
# vision_model.*, connector.* — matching the rebased checkpoint keys after
# stripping the model.vlm_with_expert.vlm. prefix.
if checkpoint_state_dict is not None:
applied = _apply_checkpoint_vlm_weights(
model, checkpoint_state_dict, tag="decoder_prefill"
)
if applied == 0:
logger.warning(
"export_decoder_prefill: zero fine-tune weights applied — "
"exported decoder will use BASE model weights"
)
else:
logger.info(
"export_decoder_prefill: applied %d fine-tune weights", applied
)

# 2. Extract the text decoder (LlamaModel)
text_model = model.text_model
total_layers = len(text_model.layers)
Expand Down
Loading
Loading