fix(exporters): apply fine-tune weights to text embedder + decoder prefill #242
Open
rylinjames wants to merge 1 commit into
…efill

Fixes silent base-model export: export_vlm_prefix applied the user's fine-tuned checkpoint to the vision encoder but NOT to export_text_embedder or export_decoder_prefill, so both ONNX files always shipped base-model weights even when a fine-tune state_dict was provided.

- Add checkpoint_state_dict param to export_text_embedder and export_decoder_prefill; apply via _apply_checkpoint_vlm_weights before extracting embed_tokens / text_model so Python reference semantics carry the updated weights into the exported submodule.
- Thread state_dict=state_dict from export_vlm_prefix's two call sites.
- Add tag arg to _apply_checkpoint_vlm_weights for per-sub-model log attribution (e.g. "[vlm-weights/text_embedder]").
- Fix applied-count return: was len(rebased) (checkpoint keys); now total_model_keys - missing_keys (actually applied). Add zero-applied WARNING log so silent no-ops are visible.
- Fix pre-existing F841 ruff violations (unused text_emb_path / decoder_path assignments in export_vlm_prefix).
- Add 9 unit tests in TestApplyCheckpointVLMWeights using a synthetic SmolVLMModel-shaped nn.Module (no HF download); tests cover: non-zero guard, full-key application, value correctness, None noop, embed_tokens reference semantics, text_model reference semantics, bad-prefix zero return, partial checkpoint, tag invariance.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Bug
When a user exported a fine-tuned SmolVLA checkpoint via export_vlm_prefix, the fine-tuned state_dict was applied to the vision encoder only. export_text_embedder and export_decoder_prefill each called AutoModel.from_pretrained independently and applied no fine-tune weights, silently shipping base-model text embeddings and decoder KV prefix.

A user deploying a fine-tuned SmolVLA via the decomposed path got:

- vision_encoder.onnx -- fine-tuned (correct)
- text_embedder.onnx -- base model (wrong; embed_tokens table was never updated)
- decoder_prefill.onnx -- base model (wrong; all 16 decoder layer k/v projections were never updated)

Key-namespace evidence (no silent no-op)

AutoModel.from_pretrained(SmolVLM2) returns SmolVLMModel with state_dict keys:

After stripping the model.vlm_with_expert.vlm. prefix from the checkpoint keys:

These are identical namespaces. Verified with a synthetic _SmolVLMModel module: load_state_dict(strict=False) returns 0 missing, 0 unexpected -- applied == total_model_keys.

Applied-count logging fix

_apply_checkpoint_vlm_weights previously returned len(rebased) (the number of checkpoint keys) instead of the actual applied count (total_model_keys - missing). Now it correctly computes and logs:

A zero applied count now logs a WARNING so that silent no-ops are visible in future.
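The difference between the old and new return values can be sketched in plain Python, without torch. This is only a model of the key bookkeeping that load_state_dict(strict=False) reports, not the code in this PR: the helper names rebase_keys and count_applied, the logger name, and the example key names are all hypothetical; only the model.vlm_with_expert.vlm. prefix comes from the description above.

```python
import logging

logger = logging.getLogger("vlm_weights_sketch")  # hypothetical logger name

PREFIX = "model.vlm_with_expert.vlm."  # prefix named in the PR description


def rebase_keys(checkpoint_keys, prefix=PREFIX):
    """Keep only checkpoint keys under `prefix` and strip the prefix."""
    return {k[len(prefix):] for k in checkpoint_keys if k.startswith(prefix)}


def count_applied(model_keys, checkpoint_keys, tag="vlm-weights"):
    """Return (applied, rebased_count) for a non-strict load.

    missing    = model keys the checkpoint does not provide
    unexpected = rebased checkpoint keys the model does not have
    applied    = total_model_keys - missing
    """
    model_keys = set(model_keys)
    rebased = rebase_keys(checkpoint_keys)
    missing = model_keys - rebased
    unexpected = rebased - model_keys
    applied = len(model_keys) - len(missing)
    if applied == 0:
        # Make a silent no-op visible in the logs.
        logger.warning("[%s] 0 of %d model keys applied", tag, len(model_keys))
    else:
        logger.info(
            "[%s] applied %d of %d model keys (%d missing, %d unexpected)",
            tag, applied, len(model_keys), len(missing), len(unexpected),
        )
    return applied, len(rebased)


# Hypothetical key names, chosen only for illustration.
MODEL_KEYS = [
    "text_model.embed_tokens.weight",
    "text_model.layers.0.self_attn.k_proj.weight",
    "vision_model.patch_embedding.weight",
]
CHECKPOINT_KEYS = [
    PREFIX + "text_model.embed_tokens.weight",
    PREFIX + "text_model.layers.0.self_attn.k_proj.weight",
    PREFIX + "extra_head.weight",      # under the prefix, but not a model key
    "model.action_head.weight",        # outside the prefix, dropped on rebase
]

applied, rebased_count = count_applied(MODEL_KEYS, CHECKPOINT_KEYS, tag="vlm-weights/text_embedder")
```

In this sketch len(rebased) is 3 while only 2 keys actually match the model, which is the overcount the old return value would have reported. A checkpoint whose keys do not start with the prefix yields an applied count of 0 and triggers the warning.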
Changes
- _apply_checkpoint_vlm_weights: add tag param for per-sub-model log attribution; fix return value to actual applied count; add zero-applied WARNING.
- export_text_embedder: add checkpoint_state_dict param; apply weights to full AutoModel before extracting embed_tokens (Python reference semantics ensure the extracted submodule already carries fine-tuned weights).
- export_decoder_prefill: add checkpoint_state_dict param; apply weights to full AutoModel before extracting text_model.
- export_vlm_prefix: thread checkpoint_state_dict=state_dict to both call sites; also fix two pre-existing F841 ruff violations.

Tests (no HF download)
9 new unit tests in TestApplyCheckpointVLMWeights using a synthetic _SmolVLMModel-shaped nn.Module:

- test_non_zero_keys_applied -- guard against silent no-op: applied > 0
- test_applied_count_equals_total_model_keys -- full coverage checkpoint applies all keys
- test_finetune_values_actually_loaded -- param values equal fill value after apply
- test_none_checkpoint_is_a_noop -- None checkpoint leaves model unchanged
- test_embed_tokens_reference_updated_after_apply -- text embedder pattern: apply to full model, extract embed_tokens, values updated
- test_text_model_reference_updated_after_apply -- decoder prefill pattern: apply to full model, extract text_model, k_proj/v_proj updated
- test_no_prefix_match_returns_zero -- bad checkpoint prefix returns 0, model unchanged
- test_partial_checkpoint_partial_apply -- text_model-only checkpoint updates only text_model keys
- test_tag_does_not_affect_result -- tag is cosmetic only

All 34 tests pass (1 integration test skipped); ruff clean; py_compile clean.
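For readers who want to see the "apply to the full model, then extract the submodule" pattern in isolation, here is a minimal sketch using a tiny synthetic module. It is not the PR's test code: TinyTextModel, TinyVLM, apply_and_extract and the 0.5 fill value are hypothetical stand-ins, and the real SmolVLMModel layout is assumed to be much larger. It relies only on nn.Module.load_state_dict(strict=False), which returns missing_keys and unexpected_keys.

```python
import torch
from torch import nn

PREFIX = "model.vlm_with_expert.vlm."  # prefix named in the PR description
FILL = 0.5  # hypothetical "fine-tuned" value, exactly representable in float32


class TinyTextModel(nn.Module):
    """Hypothetical stand-in for the text tower (embedding table + one projection)."""

    def __init__(self, vocab=8, dim=4):
        super().__init__()
        self.embed_tokens = nn.Embedding(vocab, dim)
        self.k_proj = nn.Linear(dim, dim, bias=False)


class TinyVLM(nn.Module):
    """Hypothetical stand-in for a SmolVLMModel-shaped module."""

    def __init__(self):
        super().__init__()
        self.text_model = TinyTextModel()
        self.vision_proj = nn.Linear(4, 4, bias=False)


def apply_and_extract(model, checkpoint):
    """Apply prefixed checkpoint weights to the full model, then extract embed_tokens."""
    rebased = {k[len(PREFIX):]: v for k, v in checkpoint.items() if k.startswith(PREFIX)}
    result = model.load_state_dict(rebased, strict=False)
    applied = len(model.state_dict()) - len(result.missing_keys)
    # Extraction returns a reference to the same submodule object,
    # so it already holds the weights that were just loaded.
    return applied, model.text_model.embed_tokens


model = TinyVLM()
# Fake fine-tune checkpoint: every tensor filled with FILL, keys under PREFIX.
checkpoint = {PREFIX + k: torch.full_like(v, FILL) for k, v in model.state_dict().items()}

applied, embed = apply_and_extract(model, checkpoint)
```

Because the extracted embed_tokens is the same object the full model holds, no second load step is needed on the submodule, which is the property the two reference-semantics tests are described as checking.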
Follow-up (noted, not blocked)
Tokenizer: export_text_embedder loads the model via AutoModel, not AutoTokenizer, so any fine-tuned vocabulary additions in the tokenizer are not applied here. The tokenizer is loaded separately in VLMPrefixOrchestrator._load_tokenizer_and_processor from checkpoint_path_or_id -- this is likely correct already, but worth an explicit integration test once TETHER_INTEGRATION=1 infra is wired.

Generated with Claude Code