Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions src/tether/exporters/monolithic.py
Original file line number Diff line number Diff line change
Expand Up @@ -1543,6 +1543,21 @@ def export_monolithic(
except Exception as e:
logger.warning("Weight fusion pass failed (non-fatal): %s", e)

# Re-write VERIFICATION.md now that weight fusion may have replaced
# model.onnx (and/or model.onnx.data) with structurally different bytes.
# The first write happens inside _write_tether_config (called by the
# family exporter above) and hashes the PRE-fusion file; this second
# call overwrites that stale report with hashes of the POST-fusion
# files so that `tether validate` agrees with what is actually on disk.
try:
from tether.verification_report import write_verification_report
write_verification_report(Path(output_dir), parity=None)
logger.debug("VERIFICATION.md refreshed after weight fusion")
except Exception as e:
logger.warning(
"Post-fusion VERIFICATION.md refresh failed (non-fatal): %s", e
)

return result


Expand Down
76 changes: 76 additions & 0 deletions tests/test_export_monolithic_model_type_fallback.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,3 +188,79 @@ def _fake_export_smolvla(model_id, output_dir, *, num_steps=10, target="desktop"
output_dir=tmp_path / "out",
)
assert captured.get("smolvla_called")


# ---------------------------------------------------------------------------
# Post-fusion VERIFICATION.md refresh (bug fix: hash-after-fusion)
# ---------------------------------------------------------------------------


def test_export_monolithic_calls_write_verification_after_fusion(tmp_path, monkeypatch):
    """export_monolithic must call write_verification_report AFTER fuse_weights
    so the VERIFICATION.md hashes the post-fusion file, not the pre-fusion one.

    The family exporter, fuse_weights and write_verification_report are all
    replaced with fakes that append their name to a shared list, so the test
    can assert on call order. The fake exporter returns an ``onnx_path`` that
    points at a real (dummy) file so the fusion step has something to rewrite.
    No model download, no torch, no lerobot required.
    """
    output_dir = tmp_path / "out"
    output_dir.mkdir()

    # Create a fake model.onnx so fuse_weights has a valid path.
    fake_onnx = output_dir / "model.onnx"
    fake_onnx.write_bytes(b"pre-fusion fake onnx")

    # Names of the fakes, in the order export_monolithic invoked them.
    call_order: list[str] = []

    def _fake_export_smolvla(model_id, output_dir, *, num_steps=10, target="desktop"):
        call_order.append("family_export")
        return {"status": "ok", "onnx_path": str(fake_onnx)}

    def _fake_fuse_weights(onnx_path, num_steps=10):
        call_order.append("fuse_weights")
        # Simulate weight fusion rewriting the file.
        Path(onnx_path).write_bytes(b"post-fusion fake onnx -- different bytes")
        return onnx_path

    # Export directories the fake report writer was asked to hash.
    write_verification_calls: list[str] = []

    def _fake_write_verification(export_dir, parity=None, **kwargs):
        call_order.append("write_verification_report")
        write_verification_calls.append(str(export_dir))
        # Return a dummy path so callers don't break.
        return Path(export_dir) / "VERIFICATION.md"

    monkeypatch.setattr(
        "tether.exporters.monolithic.export_smolvla_monolithic", _fake_export_smolvla,
    )
    # Patch fuse_weights on the module that export_monolithic imports it from.
    import tether.exporters.weight_fusion as _wf
    monkeypatch.setattr(_wf, "fuse_weights", _fake_fuse_weights)
    # Patch write_verification_report on its defining module; export_monolithic
    # imports it from there at call time.
    import tether.verification_report as _vr
    monkeypatch.setattr(_vr, "write_verification_report", _fake_write_verification)

    export_monolithic(
        model_id="HuggingFaceVLA/smolvla_libero",
        output_dir=output_dir,
    )

    # fuse_weights must appear in the call order.
    assert "fuse_weights" in call_order, (
        "fuse_weights was never called — test setup may be wrong"
    )
    # Safe: the assertion above guarantees membership, so index() cannot raise.
    fuse_idx = call_order.index("fuse_weights")

    # write_verification_report must be called AFTER fuse_weights at least once.
    post_fusion_verification_calls = [
        i for i, name in enumerate(call_order)
        if name == "write_verification_report" and i > fuse_idx
    ]
    assert post_fusion_verification_calls, (
        f"write_verification_report was never called after fuse_weights. "
        f"Call order: {call_order}"
    )

    # The refresh must hash the directory that was exported to. Compare
    # resolved paths so an equivalent spelling of the same directory passes.
    refreshed_dirs = {Path(p).resolve() for p in write_verification_calls}
    assert output_dir.resolve() in refreshed_dirs, (
        f"write_verification_report was not called on {output_dir}. "
        f"Directories seen: {write_verification_calls}"
    )
63 changes: 63 additions & 0 deletions tests/test_verification_report.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,3 +121,66 @@ def test_overwrites_prior_report(tmp_path):
def test_missing_export_dir_raises(tmp_path):
    """write_verification_report raises FileNotFoundError for an absent directory."""
    absent_dir = tmp_path / "does_not_exist"
    with pytest.raises(FileNotFoundError):
        write_verification_report(absent_dir)


def _extract_sha256_for(report_text: str, filename: str) -> str:
"""Pull the sha256 hex string from the VERIFICATION.md table row for *filename*."""
for line in report_text.splitlines():
if filename in line:
# Table row format: | `filename` | <size> | `<sha256>` |
parts = line.split("|")
for part in parts:
part = part.strip().strip("`")
# sha256 hex is exactly 64 hex chars
if len(part) == 64 and all(c in "0123456789abcdef" for c in part):
return part
raise AssertionError(f"Could not find sha256 for {filename!r} in report")


def test_hash_freshness_after_file_mutation(tmp_path):
    """Regenerating the report after a file changes must record the new digest.

    The test writes a report, replaces the bytes of model.onnx, writes the
    report again, and checks that the sha256 listed for model.onnx both
    changed and equals the digest of the replacement bytes. This is the
    behaviour the post-fusion VERIFICATION.md refresh in export_monolithic
    relies on when model.onnx is rewritten.
    """
    import hashlib

    original_payload = b"pre-fusion onnx bytes -- version 1"
    replacement_payload = b"post-fusion onnx bytes -- version 2 structurally different"

    export_dir = tmp_path / "export"
    export_dir.mkdir()

    config = {
        "model_id": "lerobot/smolvla_base",
        "model_type": "smolvla",
        "target": "desktop",
        "opset": 19,
        "num_denoising_steps": 1,
        "chunk_size": 50,
    }
    (export_dir / "tether_config.json").write_text(json.dumps(config))
    onnx_file = export_dir / "model.onnx"
    onnx_file.write_bytes(original_payload)

    report_path = export_dir / REPORT_FILENAME

    # Report #1 covers the original bytes.
    write_verification_report(export_dir, parity=None)
    hash_before = _extract_sha256_for(report_path.read_text(), "model.onnx")

    # Stand-in for weight fusion: same file name, different content.
    onnx_file.write_bytes(replacement_payload)

    # Report #2 must be recomputed from what is on disk now.
    write_verification_report(export_dir, parity=None)
    hash_after = _extract_sha256_for(report_path.read_text(), "model.onnx")

    assert hash_before != hash_after, (
        "Expected hash to change after mutating model.onnx, "
        f"but both writes recorded: {hash_before}"
    )
    # Cross-check against a digest computed independently of the report.
    expected = hashlib.sha256(replacement_payload).hexdigest()
    assert hash_after == expected, (
        f"Post-mutation hash in report ({hash_after!r}) does not match "
        f"expected sha256 ({expected!r})"
    )
Loading