From 84741d5294941f0cb0d6556f866b6b9c110c5230 Mon Sep 17 00:00:00 2001 From: RomirJ Date: Thu, 11 Jun 2026 01:27:09 -0700 Subject: [PATCH] fix(openvla): match action decode to bin centers Use the OpenVLA effective text vocab for action token decode and map bins through centers between the configured action edges. Co-Authored-By: Claude Opus 4.7 (1M context) --- src/tether/exporters/openvla.py | 6 +++- src/tether/postprocess/openvla.py | 22 +++++++++---- tests/test_openvla_postprocess.py | 55 ++++++++++++++++++++----------- 3 files changed, 57 insertions(+), 26 deletions(-) diff --git a/src/tether/exporters/openvla.py b/src/tether/exporters/openvla.py index 047cbf9d..3a3847c6 100644 --- a/src/tether/exporters/openvla.py +++ b/src/tether/exporters/openvla.py @@ -23,9 +23,13 @@ followed by a bin-to-continuous lookup: bin_idx = vocab_size - token_id - 1 - action_normalized = linspace(-1, 1, 256)[bin_idx] + action_normalized = bin_centers[bin_idx] action = unnormalize(action_normalized, norm_stats[dataset]) +where ``vocab_size`` is the effective text vocab +(``text_config.vocab_size - pad_to_multiple_of``, 32000 for openvla-7b) +and ``bin_centers`` are the centers between 256 edges over [-1, 1]. + There is no dedicated action expert to reconstruct. The full model is Llama-2-7B + DINOv2 + SigLIP + 3-layer projector — ~7.5B params of standard transformers architecture that HuggingFace's optimum-onnx diff --git a/src/tether/postprocess/openvla.py b/src/tether/postprocess/openvla.py index fd88332a..c103c379 100644 --- a/src/tether/postprocess/openvla.py +++ b/src/tether/postprocess/openvla.py @@ -4,9 +4,13 @@ which map onto the top N bins of the vocabulary via: bin_idx = vocab_size - token_id - 1 - action_normalized = linspace(action_low, action_high, N)[bin_idx] + action_normalized = bin_centers[bin_idx] action_unnorm = unnormalize(action_normalized, norm_stats[dataset]) +OpenVLA's HuggingFace model computes the decode vocab as +`text_config.vocab_size - pad_to_multiple_of` (32000 for openvla-7b), then +uses centers between `n_action_bins` edges. + This module provides the postprocessing step that wraps a standard Llama ONNX/PyTorch forward pass and turns its logits into actions. Works with either the full LM output (logits over vocab) or the @@ -41,13 +45,13 @@ def tokens_to_action_bins( vocab_size: int, n_bins: int = 256, ) -> np.ndarray: - """Convert predicted tokens to bin indices in [0, n_bins). + """Convert predicted tokens to bin-center indices in [0, n_bins - 2]. OpenVLA assigns the top n_bins tokens of the Llama vocab to actions: bin_idx = vocab_size - token_id - 1 """ bin_idx = vocab_size - token_ids - 1 - return np.clip(bin_idx, 0, n_bins - 1) + return np.clip(bin_idx, 0, n_bins - 2) def bins_to_normalized( @@ -56,9 +60,15 @@ def bins_to_normalized( action_low: float = -1.0, action_high: float = 1.0, ) -> np.ndarray: - """Map bin indices to normalized actions in [action_low, action_high].""" + """Map bin indices to normalized action bin centers. + + OpenVLA defines `n_action_bins` edges and uses the centers between them, + so the largest valid bin index is `n_bins - 2`. + """ bins = np.linspace(action_low, action_high, n_bins, dtype=np.float32) - return bins[bin_idx] + bin_centers = (bins[:-1] + bins[1:]) / 2.0 + safe_idx = np.clip(bin_idx, 0, bin_centers.shape[0] - 1) + return bin_centers[safe_idx] def unnormalize_actions( @@ -100,7 +110,7 @@ def decode_actions( action_dim: int, norm_stats: dict[str, Any] | None = None, dataset_name: str | None = None, - vocab_size: int = 32064, + vocab_size: int = 32000, n_bins: int = 256, action_low: float = -1.0, action_high: float = 1.0, diff --git a/tests/test_openvla_postprocess.py b/tests/test_openvla_postprocess.py index 47d45fd4..116c7cc2 100644 --- a/tests/test_openvla_postprocess.py +++ b/tests/test_openvla_postprocess.py @@ -4,14 +4,25 @@ import pytest from tether.postprocess.openvla import ( + bins_to_normalized, + decode_actions, logits_to_tokens, tokens_to_action_bins, - bins_to_normalized, unnormalize_actions, - decode_actions, ) +def _center( + index: int, + n_bins: int = 256, + action_low: float = -1.0, + action_high: float = 1.0, +) -> np.float32: + edges = np.linspace(action_low, action_high, n_bins, dtype=np.float32) + centers = (edges[:-1] + edges[1:]) / 2.0 + return centers[index] + + class TestLogitsToTokens: def test_picks_argmax_last_action_dim(self): # Batch of 1, seq 10, vocab 100; force token 42 at positions -3: @@ -28,28 +39,28 @@ def test_rejects_wrong_ndim(self): class TestTokensToActionBins: def test_top_token_is_bin_zero(self): - # top-of-vocab token (vocab_size-1) → bin 0 - tokens = np.array([[32063, 32062, 32061]]) - bins = tokens_to_action_bins(tokens, vocab_size=32064, n_bins=256) + # OpenVLA decodes against effective vocab_size=32000, not the padded LM size. + tokens = np.array([[31999, 31998, 31997]]) + bins = tokens_to_action_bins(tokens, vocab_size=32000, n_bins=256) assert (bins == np.array([[0, 1, 2]])).all() def test_clips_out_of_range(self): - # Any token below the top 256 should clip to bin 255 (lowest bin) + # Any token below the action token band clips to the last valid center. tokens = np.array([[0, 100, 1000]]) - bins = tokens_to_action_bins(tokens, vocab_size=32064, n_bins=256) - assert (bins == 255).all() + bins = tokens_to_action_bins(tokens, vocab_size=32000, n_bins=256) + assert (bins == 254).all() class TestBinsToNormalized: - def test_bin_0_maps_to_low(self): + def test_bin_0_maps_to_first_center(self): bins = np.array([[0]]) out = bins_to_normalized(bins, n_bins=256, action_low=-1.0, action_high=1.0) - assert out[0, 0] == pytest.approx(-1.0) + assert out[0, 0] == pytest.approx(_center(0)) - def test_bin_last_maps_to_high(self): + def test_bin_last_maps_to_last_center(self): bins = np.array([[255]]) out = bins_to_normalized(bins) - assert out[0, 0] == pytest.approx(1.0) + assert out[0, 0] == pytest.approx(_center(254)) class TestUnnormalizeActions: @@ -86,20 +97,26 @@ def test_unknown_dataset_raises(self): class TestDecodeActions: def test_full_pipeline_normalized(self): - # 1 batch, seq 8, vocab 32064, pick top token at last 7 positions + # 1 batch, seq 8, padded vocab 32064; OpenVLA decode uses token 31999. logits = np.zeros((1, 8, 32064), dtype=np.float32) - logits[0, -7:, 32063] = 10.0 # topmost token = bin 0 = -1.0 normalized + logits[0, -7:, 31999] = 10.0 # effective top token = bin 0 = first center out = decode_actions(logits, action_dim=7) assert out.shape == (1, 7) - assert np.all(out == -1.0) + assert np.allclose(out, _center(0)) def test_with_norm_stats(self): logits = np.zeros((1, 8, 32064), dtype=np.float32) - logits[0, -7:, 32063] = 10.0 + logits[0, -7:, 31999] = 10.0 norm_stats = { - "bridge": {"action": {"q01": [0.0]*7, "q99": [2.0]*7, "mask": [True]*7}} + "bridge": { + "action": { + "q01": [0.0] * 7, + "q99": [2.0] * 7, + "mask": [True] * 7, + } + } } out = decode_actions(logits, action_dim=7, norm_stats=norm_stats, dataset_name="bridge") assert out.shape == (1, 7) - # normalized=-1, q01=0, q99=2 → 0.5*(-1+1)*2 + 0 = 0.0 - assert np.allclose(out, 0.0) + # q01=0, q99=2 maps normalized x to x + 1. + assert np.allclose(out, _center(0) + 1.0)