fix(exporters): apply fine-tune weights to text embedder + decoder prefill#242

Open
rylinjames wants to merge 1 commit into main from fix/finetune-text-decoder-weights
@rylinjames

Collaborator

Bug

When a user exported a fine-tuned SmolVLA checkpoint via export_vlm_prefix, the fine-tuned state_dict was applied to the vision encoder only. export_text_embedder and export_decoder_prefill each called AutoModel.from_pretrained independently and applied no fine-tune weights, so the export silently shipped base-model text embeddings and a base-model decoder KV prefix.

A user deploying a fine-tuned SmolVLA via the decomposed path got:

  • vision_encoder.onnx -- fine-tuned (correct)
  • text_embedder.onnx -- base model (wrong; embed_tokens table was never updated)
  • decoder_prefill.onnx -- base model (wrong; all 16 decoder layer k/v projections were never updated)

Key-namespace evidence (no silent no-op)

AutoModel.from_pretrained(SmolVLM2) returns SmolVLMModel with state_dict keys:

text_model.embed_tokens.weight
text_model.layers.0.self_attn.k_proj.weight
text_model.layers.0.self_attn.v_proj.weight
...
vision_model.embeddings.patch_embedding.weight
connector.weight

After stripping model.vlm_with_expert.vlm. from the checkpoint:

text_model.embed_tokens.weight           <- matches
text_model.layers.0.self_attn.k_proj.weight  <- matches
...
connector.weight                         <- matches

These are identical namespaces. Verified with a synthetic _SmolVLMModel module: load_state_dict(strict=False) returns 0 missing, 0 unexpected -- applied == total_model_keys.
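
The namespace check above can be sketched without torch or an HF download. This is only an illustration: rebase_keys is a hypothetical helper (not the function in this PR), the weight values are placeholder strings, and the "model.action_head.weight" key is an invented example of a non-VLM checkpoint entry.

```python
# Prefix named in the PR description; everything else here is illustrative.
PREFIX = "model.vlm_with_expert.vlm."


def rebase_keys(checkpoint, prefix=PREFIX):
    """Keep only keys under `prefix` and strip it (hypothetical helper)."""
    return {
        key[len(prefix):]: value
        for key, value in checkpoint.items()
        if key.startswith(prefix)
    }


# Toy checkpoint: placeholder strings stand in for tensors.
checkpoint = {
    PREFIX + "text_model.embed_tokens.weight": "w0",
    PREFIX + "text_model.layers.0.self_attn.k_proj.weight": "w1",
    PREFIX + "connector.weight": "w2",
    "model.action_head.weight": "w3",  # invented non-VLM key, dropped by rebase
}

# Keys the AutoModel-side module would expose (subset listed in the PR).
model_keys = {
    "text_model.embed_tokens.weight",
    "text_model.layers.0.self_attn.k_proj.weight",
    "connector.weight",
}

rebased = rebase_keys(checkpoint)
missing = model_keys - rebased.keys()      # model keys the checkpoint lacks
unexpected = rebased.keys() - model_keys   # checkpoint keys the model lacks
print(f"{len(missing)} missing, {len(unexpected)} unexpected")
```

If the prefix were wrong, rebased would be empty and every model key would show up as missing, which is the silent no-op this PR guards against.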

Applied-count logging fix

_apply_checkpoint_vlm_weights previously returned len(rebased) (number of checkpoint keys) instead of the actual applied count (total_model_keys - missing). Now it correctly computes and logs:

[vlm-weights/text_embedder] applied 247/247 fine-tune weights (0 missing, 0 unexpected)

An applied count of zero now logs a WARNING, so future silent no-ops are visible.
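
A minimal sketch of the counting logic, assuming the fix relies on the missing_keys list that torch's load_state_dict(strict=False) returns. apply_and_count is a hypothetical stand-in for _apply_checkpoint_vlm_weights, and the nn.Linear model and the "stale.weight" key are invented for the example.

```python
import logging

import torch
from torch import nn

log = logging.getLogger("vlm-weights")


def apply_and_count(model, rebased, tag="vlm"):
    """Load `rebased` into `model`; return how many model keys were filled.

    Hypothetical stand-in for the PR's helper, not its actual code.
    """
    total = len(model.state_dict())
    result = model.load_state_dict(rebased, strict=False)
    # Applied = model keys that received a value, not len(rebased).
    applied = total - len(result.missing_keys)
    log.info(
        "[vlm-weights/%s] applied %d/%d fine-tune weights (%d missing, %d unexpected)",
        tag, applied, total, len(result.missing_keys), len(result.unexpected_keys),
    )
    if applied == 0:
        log.warning("[vlm-weights/%s] no fine-tune weights were applied", tag)
    return applied


model = nn.Linear(4, 2)  # state_dict keys: "weight", "bias"
rebased = {
    "weight": torch.ones(2, 4),
    "stale.weight": torch.zeros(1),  # invented key the model does not have
}
applied = apply_and_count(model, rebased, tag="text_embedder")
```

Here len(rebased) is 2, but only one model key is actually filled: "bias" is missing and "stale.weight" is unexpected. Returning len(rebased) would have over-reported the applied count.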

Changes

  • _apply_checkpoint_vlm_weights: add tag param for per-sub-model log attribution; fix return value to actual applied count; add zero-applied WARNING.
  • export_text_embedder: add checkpoint_state_dict param; apply weights to full AutoModel before extracting embed_tokens (Python reference semantics ensure the extracted submodule already carries fine-tuned weights).
  • export_decoder_prefill: add checkpoint_state_dict param; apply weights to full AutoModel before extracting text_model.
  • export_vlm_prefix: thread checkpoint_state_dict=state_dict to both call sites; also fix two pre-existing F841 ruff violations.
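
The "Python reference semantics" point in the list above can be illustrated as follows. TinyVLM is a hypothetical stand-in for the AutoModel result; only the text_model.embed_tokens attribute path is taken from the PR. load_state_dict copies values into the existing parameters in place, so a submodule extracted from the full model afterwards carries the fine-tuned weights.

```python
import torch
from torch import nn


class TinyVLM(nn.Module):
    """Hypothetical stand-in with the same attribute path as the real model."""

    def __init__(self):
        super().__init__()
        self.text_model = nn.Module()
        self.text_model.embed_tokens = nn.Embedding(8, 4)


full = TinyVLM()

# Pretend fine-tuned weights: every entry set to 0.5 so the change is easy to see.
fine_tuned = {"text_model.embed_tokens.weight": torch.full((8, 4), 0.5)}
full.load_state_dict(fine_tuned, strict=False)

# Extract the submodule only after applying weights to the full model.
embed = full.text_model.embed_tokens
print(embed.weight[0])
```

The extracted embed object is the same module instance held by full, so no second copy step is needed before export.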

Tests (no HF download)

9 new unit tests in TestApplyCheckpointVLMWeights using a synthetic _SmolVLMModel-shaped nn.Module:

  1. test_non_zero_keys_applied -- guard against silent no-op: applied > 0
  2. test_applied_count_equals_total_model_keys -- full coverage checkpoint applies all keys
  3. test_finetune_values_actually_loaded -- param values equal fill value after apply
  4. test_none_checkpoint_is_a_noop -- None checkpoint leaves model unchanged
  5. test_embed_tokens_reference_updated_after_apply -- text embedder pattern: apply to full model, extract embed_tokens, values updated
  6. test_text_model_reference_updated_after_apply -- decoder prefill pattern: apply to full model, extract text_model, k_proj/v_proj updated
  7. test_no_prefix_match_returns_zero -- bad checkpoint prefix returns 0, model unchanged
  8. test_partial_checkpoint_partial_apply -- text_model-only checkpoint updates only text_model keys
  9. test_tag_does_not_affect_result -- tag is cosmetic only

All 34 tests pass (1 integration test skipped); ruff clean; py_compile clean.

Follow-up (noted, not blocked)

Tokenizer: export_text_embedder loads the model via AutoModel, not AutoTokenizer, so any fine-tuned vocabulary additions in the tokenizer are not applied here. The tokenizer is loaded separately in VLMPrefixOrchestrator._load_tokenizer_and_processor from checkpoint_path_or_id -- this is likely correct already, but worth an explicit integration test once TETHER_INTEGRATION=1 infra is wired.

Generated with Claude Code

…efill

Fixes silent base-model export: export_vlm_prefix applied the user's
fine-tuned checkpoint to the vision encoder but NOT to export_text_embedder
or export_decoder_prefill, so both ONNX files always shipped base-model
weights even when a fine-tune state_dict was provided.

- Add checkpoint_state_dict param to export_text_embedder and
  export_decoder_prefill; apply via _apply_checkpoint_vlm_weights before
  extracting embed_tokens / text_model so Python reference semantics
  carry the updated weights into the exported submodule.
- Thread checkpoint_state_dict=state_dict from export_vlm_prefix's two call sites.
- Add tag arg to _apply_checkpoint_vlm_weights for per-sub-model log
  attribution (e.g. "[vlm-weights/text_embedder]").
- Fix applied-count return: was len(rebased) (checkpoint keys); now
  total_model_keys - missing_keys (actually applied). Add zero-applied
  WARNING log so silent no-ops are visible.
- Fix pre-existing F841 ruff violations (unused text_emb_path /
  decoder_path assignments in export_vlm_prefix).
- Add 9 unit tests in TestApplyCheckpointVLMWeights using a synthetic
  SmolVLMModel-shaped nn.Module (no HF download); tests cover: non-zero
  guard, full-key application, value correctness, None noop, embed_tokens
  reference semantics, text_model reference semantics, bad-prefix zero
  return, partial checkpoint, tag invariance.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>