From b66916dbcdca869df9d69e93b7fa4648d83a0923 Mon Sep 17 00:00:00 2001
From: RomirJ
Date: Wed, 10 Jun 2026 22:37:42 -0700
Subject: [PATCH] fix(exporters): apply fine-tune weights to text embedder + decoder prefill

Fixes silent base-model export: export_vlm_prefix applied the user's
fine-tuned checkpoint to the vision encoder but NOT to
export_text_embedder or export_decoder_prefill, so both ONNX files
always shipped base-model weights even when a fine-tune state_dict was
provided.

- Add checkpoint_state_dict param to export_text_embedder and
  export_decoder_prefill; apply via _apply_checkpoint_vlm_weights before
  extracting embed_tokens / text_model so Python reference semantics
  carry the updated weights into the exported submodule.
- Thread state_dict=state_dict from export_vlm_prefix's two call sites.
- Add tag arg to _apply_checkpoint_vlm_weights for per-sub-model log
  attribution (e.g. "[vlm-weights/text_embedder]").
- Fix applied-count return: was len(rebased) (checkpoint keys); now
  total_model_keys - missing_keys (actually applied). Add zero-applied
  WARNING log so silent no-ops are visible.
- Fix pre-existing F841 ruff violations (unused text_emb_path /
  decoder_path assignments in export_vlm_prefix).
- Add 9 unit tests in TestApplyCheckpointVLMWeights using a synthetic
  SmolVLMModel-shaped nn.Module (no HF download); tests cover: non-zero
  guard, full-key application, value correctness, None noop,
  embed_tokens reference semantics, text_model reference semantics,
  bad-prefix zero return, partial checkpoint, tag invariance.

Co-Authored-By: Claude Opus 4.7 (1M context) --- src/tether/exporters/vlm_prefix_exporter.py | 111 ++++++-- tests/test_vlm_prefix.py | 268 ++++++++++++++++++++ 2 files changed, 364 insertions(+), 15 deletions(-) diff --git a/src/tether/exporters/vlm_prefix_exporter.py b/src/tether/exporters/vlm_prefix_exporter.py index 8fcaedf..8d96f7b 100644 --- a/src/tether/exporters/vlm_prefix_exporter.py +++ b/src/tether/exporters/vlm_prefix_exporter.py @@ -213,8 +213,23 @@ def patch_onnx_type_mismatches(onnx_path: str | Path) -> int: return fixed_count -def _apply_checkpoint_vlm_weights(base_model, checkpoint_state_dict): - """Overwrite the base VLM's weights with fine-tuned ones from the checkpoint.""" +def _apply_checkpoint_vlm_weights( + base_model, + checkpoint_state_dict, + tag: str = "vlm", +) -> int: + """Overwrite the base VLM's weights with fine-tuned ones from the checkpoint. + + Args: + base_model: The nn.Module to update in-place. + checkpoint_state_dict: Fine-tuned checkpoint state dict (SmolVLA format). + tag: Short label used in log messages to identify which sub-model is + being updated (e.g. "text_embedder", "decoder_prefill"). + + Returns: + Number of keys actually applied (total model keys − missing keys). + Returns 0 when no matching prefix is found or on error. + """ prefixes = [ "model.vlm_with_expert.vlm.", "vlm_with_expert.vlm.", @@ -233,7 +248,7 @@ def _apply_checkpoint_vlm_weights(base_model, checkpoint_state_dict): if len(parts) >= 3: top.add(".".join(parts[:3])) print( - f"[vlm-weights] No matching VLM prefix found. " + f"[vlm-weights/{tag}] No matching VLM prefix found. 
" f"Top-3 prefixes sample: {sorted(list(top))[:5]}", flush=True, ) @@ -245,27 +260,36 @@ def _apply_checkpoint_vlm_weights(base_model, checkpoint_state_dict): if k.startswith(matched_prefix) } print( - f"[vlm-weights] matched_prefix={matched_prefix!r}, rebased {len(rebased)} keys", + f"[vlm-weights/{tag}] matched_prefix={matched_prefix!r}, rebased {len(rebased)} keys", flush=True, ) print( - f"[vlm-weights] sample rebased keys: {sorted(list(rebased.keys()))[:5]}", + f"[vlm-weights/{tag}] sample rebased keys: {sorted(list(rebased.keys()))[:5]}", flush=True, ) try: missing, unexpected = base_model.load_state_dict(rebased, strict=False) + total_model_keys = len(list(base_model.state_dict().keys())) + applied = total_model_keys - len(missing) print( - f"[vlm-weights] load: {len(missing)} missing, {len(unexpected)} unexpected", + f"[vlm-weights/{tag}] applied {applied}/{total_model_keys} fine-tune weights " + f"({len(missing)} missing, {len(unexpected)} unexpected)", flush=True, ) + if applied == 0: + print( + f"[vlm-weights/{tag}] WARNING: zero keys applied — fine-tune weights " + f"did NOT update this sub-model. 
Check key namespace alignment.", + flush=True, + ) if missing: - print(f"[vlm-weights] first 5 missing: {missing[:5]}", flush=True) + print(f"[vlm-weights/{tag}] first 5 missing: {missing[:5]}", flush=True) if unexpected: - print(f"[vlm-weights] first 5 unexpected: {unexpected[:5]}", flush=True) - return len(rebased) + print(f"[vlm-weights/{tag}] first 5 unexpected: {unexpected[:5]}", flush=True) + return applied except Exception as e: - print(f"[vlm-weights] FAILED: {e}", flush=True) + print(f"[vlm-weights/{tag}] FAILED: {e}", flush=True) return 0 @@ -319,7 +343,7 @@ def export_vlm_prefix( ) model.eval() if state_dict is not None: - _apply_checkpoint_vlm_weights(model, state_dict) + _apply_checkpoint_vlm_weights(model, state_dict, tag="vision_encoder") # The AutoModelForImageTextToText wrapper has `self.model = SmolVLMModel(...)` # whose sub-modules are the actual vision/connector/text_model. Drill down @@ -457,18 +481,20 @@ def export_vlm_prefix( config_path.write_text(json.dumps(config, indent=2)) logger.info("Updated config: %s", config_path) - # 8. Export text embedder - text_emb_path = export_text_embedder( + # 8. Export text embedder (pass fine-tune state_dict so embed_tokens are updated) + export_text_embedder( checkpoint_path_or_id=checkpoint_path_or_id, output_dir=output_dir, opset=opset, + checkpoint_state_dict=state_dict, ) - # 9. Export decoder prefill - decoder_path = export_decoder_prefill( + # 9. Export decoder prefill (pass fine-tune state_dict so decoder k/v are updated) + export_decoder_prefill( checkpoint_path_or_id=checkpoint_path_or_id, output_dir=output_dir, opset=opset, + checkpoint_state_dict=state_dict, ) # Update config with all export paths @@ -490,11 +516,22 @@ def export_text_embedder( checkpoint_path_or_id: str = DEFAULT_VLM_MODEL_NAME, output_dir: str | Path = ".", opset: int = 19, + checkpoint_state_dict: dict | None = None, ) -> Path: """Export the token embedding table as a standalone ONNX. 
Input: ``input_ids [B, seq]`` int64 Output: ``text_embeds [B, seq, 960]`` float32 + + Args: + checkpoint_path_or_id: HuggingFace model ID or local path. + output_dir: Directory for output files. + opset: ONNX opset version. + checkpoint_state_dict: Optional fine-tuned checkpoint state dict + (SmolVLA format, keys like ``model.vlm_with_expert.vlm.text_model.*``). + When provided, overwrites the base model's embed_tokens weights with + the fine-tuned values before export so the shipped ONNX reflects the + actual fine-tuned embedding table rather than the base model's. """ output_dir = Path(output_dir) output_dir.mkdir(parents=True, exist_ok=True) @@ -507,6 +544,25 @@ def export_text_embedder( ) model.eval() + # Apply fine-tune weights BEFORE extracting embed_tokens so the exported + # embedding table reflects the fine-tuned checkpoint, not the base model. + # AutoModel returns SmolVLMModel whose state_dict keys are text_model.*, + # vision_model.*, connector.* — matching the rebased checkpoint keys after + # stripping the model.vlm_with_expert.vlm. prefix. + if checkpoint_state_dict is not None: + applied = _apply_checkpoint_vlm_weights( + model, checkpoint_state_dict, tag="text_embedder" + ) + if applied == 0: + logger.warning( + "export_text_embedder: zero fine-tune weights applied — " + "exported embed_tokens will use BASE model weights" + ) + else: + logger.info( + "export_text_embedder: applied %d fine-tune weights", applied + ) + # Find embed_tokens — try multiple attribute paths embed_tokens = None for attr_path in [ @@ -701,6 +757,7 @@ def export_decoder_prefill( opset: int = 19, num_layers: int = SMOLVLA_NUM_DECODER_LAYERS, prefix_seq_len: int = DEFAULT_PREFIX_SEQ_LEN, + checkpoint_state_dict: dict | None = None, ) -> Path: """Export the 16-layer SmolLM2 decoder as decoder_prefill.onnx. @@ -714,6 +771,11 @@ def export_decoder_prefill( opset: ONNX opset version (default 19, needed for GQA). num_layers: Number of decoder layers to keep (default 16). 
prefix_seq_len: Sequence length for dummy inputs (default 75). + checkpoint_state_dict: Optional fine-tuned checkpoint state dict + (SmolVLA format, keys like ``model.vlm_with_expert.vlm.text_model.*``). + When provided, overwrites the base model's decoder weights with the + fine-tuned values before export so the KV prefix computed by this + ONNX reflects the fine-tuned decoder rather than the base model's. Returns: Path to the exported ``decoder_prefill.onnx`` file. @@ -730,6 +792,25 @@ def export_decoder_prefill( ) model.eval() + # Apply fine-tune weights BEFORE extracting text_model so the decoder KV + # prefix reflects the fine-tuned checkpoint, not the base model. + # AutoModel returns SmolVLMModel whose state_dict keys are text_model.*, + # vision_model.*, connector.* — matching the rebased checkpoint keys after + # stripping the model.vlm_with_expert.vlm. prefix. + if checkpoint_state_dict is not None: + applied = _apply_checkpoint_vlm_weights( + model, checkpoint_state_dict, tag="decoder_prefill" + ) + if applied == 0: + logger.warning( + "export_decoder_prefill: zero fine-tune weights applied — " + "exported decoder will use BASE model weights" + ) + else: + logger.info( + "export_decoder_prefill: applied %d fine-tune weights", applied + ) + # 2. 
Extract the text decoder (LlamaModel) text_model = model.text_model total_layers = len(text_model.layers) diff --git a/tests/test_vlm_prefix.py b/tests/test_vlm_prefix.py index 95d14b3..0c5531f 100644 --- a/tests/test_vlm_prefix.py +++ b/tests/test_vlm_prefix.py @@ -19,6 +19,7 @@ from tether.exporters.vlm_prefix_exporter import ( DEFAULT_VLM_KV_DIM, + _apply_checkpoint_vlm_weights, ) from tether.runtime.vlm_components import ( HIDDEN_SIZE, @@ -723,3 +724,270 @@ def test_different_instructions_different_prefix(self): # This test requires real model weights and would only run in CI # with TETHER_INTEGRATION=1 pytest.skip("Integration test requires real model checkpoint") + + +# --------------------------------------------------------------------------- +# Test 13: _apply_checkpoint_vlm_weights — unit tests (no HF model download) +# +# These tests use synthetic nn.Module trees whose state_dict() key namespaces +# mirror the real SmolVLMModel (AutoModel.from_pretrained) layout: +# text_model.embed_tokens.weight +# text_model.norm.weight +# text_model.layers.0.self_attn.k_proj.weight / .bias +# text_model.layers.0.self_attn.v_proj.weight / .bias +# text_model.layers.0.input_layernorm.weight / .bias +# vision_model.embeddings.patch_embedding.weight / .bias +# vision_model.post_layernorm.weight / .bias +# connector.weight / .bias +# +# Checkpoint keys use the SmolVLA prefix model.vlm_with_expert.vlm. so that +# _apply_checkpoint_vlm_weights matches on the first candidate. 
+# --------------------------------------------------------------------------- + + +def _build_synthetic_smolvlm_model(): + """Return a tiny nn.Module whose state_dict key namespace mirrors SmolVLMModel.""" + + class _SelfAttn(torch.nn.Module): + def __init__(self): + super().__init__() + self.k_proj = torch.nn.Linear(8, 4) + self.v_proj = torch.nn.Linear(8, 4) + + class _Layer(torch.nn.Module): + def __init__(self): + super().__init__() + self.self_attn = _SelfAttn() + self.input_layernorm = torch.nn.LayerNorm(8) + + class _TextModel(torch.nn.Module): + def __init__(self): + super().__init__() + self.embed_tokens = torch.nn.Embedding(50, 8) + self.norm = torch.nn.LayerNorm(8) + self.layers = torch.nn.ModuleList([_Layer() for _ in range(2)]) + + class _VisionEmbeddings(torch.nn.Module): + def __init__(self): + super().__init__() + self.patch_embedding = torch.nn.Linear(9, 8) + + class _VisionModel(torch.nn.Module): + def __init__(self): + super().__init__() + self.embeddings = _VisionEmbeddings() + self.post_layernorm = torch.nn.LayerNorm(8) + + class _SmolVLMModel(torch.nn.Module): + """Mirrors AutoModel.from_pretrained(SmolVLM2) key layout.""" + def __init__(self): + super().__init__() + self.text_model = _TextModel() + self.vision_model = _VisionModel() + self.connector = torch.nn.Linear(8, 8) + + return _SmolVLMModel() + + +def _make_finetune_checkpoint(model, prefix="model.vlm_with_expert.vlm.", fill=99.0): + """Return a checkpoint state_dict with all-fill values for the given model.""" + return {prefix + k: torch.full_like(v, fill) for k, v in model.state_dict().items()} + + +class TestApplyCheckpointVLMWeights: + """Unit tests for _apply_checkpoint_vlm_weights using synthetic SmolVLMModel. + + Key namespace evidence: + Checkpoint (after stripping model.vlm_with_expert.vlm.): + text_model.embed_tokens.weight, text_model.norm.weight, + text_model.layers.0.self_attn.k_proj.weight, ..., connector.weight, ... 
+ AutoModel state_dict keys (SmolVLMModel): + text_model.embed_tokens.weight, text_model.norm.weight, + text_model.layers.0.self_attn.k_proj.weight, ..., connector.weight, ... + These are IDENTICAL after prefix strip → non-zero application guaranteed. + """ + + def test_non_zero_keys_applied(self): + """Guard against silent no-op: applied count must be > 0.""" + model = _build_synthetic_smolvlm_model() + checkpoint = _make_finetune_checkpoint(model, fill=99.0) + + applied = _apply_checkpoint_vlm_weights(model, checkpoint, tag="test") + + assert applied > 0, ( + f"Expected non-zero keys applied, got {applied}. " + "Checkpoint keys do not match model key namespace — silent no-op." + ) + + def test_applied_count_equals_total_model_keys(self): + """When checkpoint covers all model keys, applied == total model keys.""" + model = _build_synthetic_smolvlm_model() + total_keys = len(list(model.state_dict().keys())) + checkpoint = _make_finetune_checkpoint(model, fill=99.0) + + applied = _apply_checkpoint_vlm_weights(model, checkpoint, tag="test") + + assert applied == total_keys, ( + f"Expected all {total_keys} keys applied, got {applied}" + ) + + def test_finetune_values_actually_loaded(self): + """After application, model params equal the fine-tune values.""" + model = _build_synthetic_smolvlm_model() + fill_value = 77.0 + checkpoint = _make_finetune_checkpoint(model, fill=fill_value) + + _apply_checkpoint_vlm_weights(model, checkpoint, tag="test") + + # Check a representative selection of parameters + embed_weight = model.text_model.embed_tokens.weight + assert float(embed_weight[0, 0]) == pytest.approx(fill_value), ( + f"embed_tokens not updated: got {float(embed_weight[0, 0])}, expected {fill_value}" + ) + + k_proj_weight = model.text_model.layers[0].self_attn.k_proj.weight + assert float(k_proj_weight[0, 0]) == pytest.approx(fill_value), ( + f"layers[0].self_attn.k_proj not updated: got {float(k_proj_weight[0, 0])}" + ) + + connector_weight = model.connector.weight + 
assert float(connector_weight[0, 0]) == pytest.approx(fill_value), ( + f"connector.weight not updated: got {float(connector_weight[0, 0])}" + ) + + def test_none_checkpoint_is_a_noop(self): + """When checkpoint_state_dict is None, the caller skips the call. + + This test verifies the expected caller pattern: ``if checkpoint is not None`` + guards the call, so a model loaded without a fine-tune checkpoint retains + its original base weights. + """ + model = _build_synthetic_smolvlm_model() + original_sd = {k: v.clone() for k, v in model.state_dict().items()} + + # Caller pattern — guard is on the call site, not inside the function. + checkpoint = None + if checkpoint is not None: + _apply_checkpoint_vlm_weights(model, checkpoint, tag="test") + + for k, orig_v in original_sd.items(): + current_v = model.state_dict()[k] + assert torch.allclose(current_v, orig_v), ( + f"Parameter {k!r} changed when checkpoint was None" + ) + + def test_embed_tokens_reference_updated_after_apply(self): + """embed_tokens extracted AFTER apply reflects fine-tune values (reference semantics). + + This is the critical property for export_text_embedder: the function + applies weights to the full AutoModel, then extracts embed_tokens by + attribute traversal. Since embed_tokens is a Python reference, it must + already carry the fine-tuned weights. + """ + model = _build_synthetic_smolvlm_model() + fill_value = 42.0 + checkpoint = _make_finetune_checkpoint(model, fill=fill_value) + + # Simulate what export_text_embedder does: + # 1. apply_checkpoint_vlm_weights(full_model, ckpt) + # 2. 
embed_tokens = model.text_model.embed_tokens (extracted after) + _apply_checkpoint_vlm_weights(model, checkpoint, tag="test_text_embedder") + embed_tokens = model.text_model.embed_tokens + + assert float(embed_tokens.weight[0, 0]) == pytest.approx(fill_value), ( + "embed_tokens.weight not updated via reference after full-model apply" + ) + + def test_text_model_reference_updated_after_apply(self): + """text_model extracted AFTER apply reflects fine-tune values (decoder prefill pattern). + + This is the critical property for export_decoder_prefill: the function + applies weights to the full AutoModel, then extracts text_model. The + text_model (and its layers) is a Python submodule reference and must + carry fine-tuned k_proj / v_proj weights. + """ + model = _build_synthetic_smolvlm_model() + fill_value = 55.0 + checkpoint = _make_finetune_checkpoint(model, fill=fill_value) + + # Simulate what export_decoder_prefill does: + # 1. apply_checkpoint_vlm_weights(full_model, ckpt) + # 2. text_model = model.text_model (extracted after) + _apply_checkpoint_vlm_weights(model, checkpoint, tag="test_decoder_prefill") + text_model = model.text_model + + k_proj = text_model.layers[0].self_attn.k_proj.weight + assert float(k_proj[0, 0]) == pytest.approx(fill_value), ( + "text_model.layers[0].self_attn.k_proj not updated after full-model apply" + ) + + v_proj = text_model.layers[1].self_attn.v_proj.weight + assert float(v_proj[0, 0]) == pytest.approx(fill_value), ( + "text_model.layers[1].self_attn.v_proj not updated after full-model apply" + ) + + def test_no_prefix_match_returns_zero(self): + """When checkpoint has no known prefix, applied count is 0 and model unchanged.""" + model = _build_synthetic_smolvlm_model() + original_sd = {k: v.clone() for k, v in model.state_dict().items()} + + # Build checkpoint with an unrecognized prefix + bad_checkpoint = { + "unknown.prefix." 
+ k: torch.ones_like(v) * 123.0 + for k, v in model.state_dict().items() + } + + applied = _apply_checkpoint_vlm_weights(model, bad_checkpoint, tag="test_bad_prefix") + + assert applied == 0, f"Expected 0 applied with bad prefix, got {applied}" + for k, orig_v in original_sd.items(): + current_v = model.state_dict()[k] + assert torch.allclose(current_v, orig_v), ( + f"Parameter {k!r} changed despite bad prefix (no keys should have matched)" + ) + + def test_partial_checkpoint_partial_apply(self): + """When checkpoint covers only text_model keys, only those are applied.""" + model = _build_synthetic_smolvlm_model() + prefix = "model.vlm_with_expert.vlm." + fill_value = 33.0 + + # Only include text_model keys + partial_checkpoint = { + prefix + k: torch.full_like(v, fill_value) + for k, v in model.state_dict().items() + if k.startswith("text_model.") + } + text_model_key_count = sum( + 1 for k in model.state_dict() if k.startswith("text_model.") + ) + + applied = _apply_checkpoint_vlm_weights(model, partial_checkpoint, tag="test_partial") + + assert applied == text_model_key_count, ( + f"Expected {text_model_key_count} applied for partial ckpt, got {applied}" + ) + # text_model keys should be updated + assert float(model.text_model.embed_tokens.weight[0, 0]) == pytest.approx(fill_value) + # connector keys should NOT be updated (still random init, not 33.0) + connector_val = float(model.connector.weight[0, 0]) + assert connector_val != pytest.approx(fill_value), ( + "connector.weight was updated but shouldn't have been in partial checkpoint" + ) + + def test_tag_does_not_affect_result(self): + """Changing the tag parameter does not affect which keys are applied.""" + model_a = _build_synthetic_smolvlm_model() + model_b = _build_synthetic_smolvlm_model() + # Make models start from the same state + model_b.load_state_dict(model_a.state_dict()) + + checkpoint = _make_finetune_checkpoint(model_a, fill=11.0) + + applied_a = _apply_checkpoint_vlm_weights(model_a, 
checkpoint, tag="text_embedder") + applied_b = _apply_checkpoint_vlm_weights(model_b, checkpoint, tag="decoder_prefill") + + assert applied_a == applied_b, ( + f"tag should not affect applied count: text_embedder={applied_a}, " + f"decoder_prefill={applied_b}" + )