Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion src/tether/exporters/openvla.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,13 @@
followed by a bin-to-continuous lookup:

bin_idx = vocab_size - token_id - 1
action_normalized = linspace(-1, 1, 256)[bin_idx]
action_normalized = bin_centers[bin_idx]
action = unnormalize(action_normalized, norm_stats[dataset])

where ``vocab_size`` is the effective text vocabulary size
(``text_config.vocab_size - pad_to_multiple_of``, 32000 for openvla-7b)
and ``bin_centers`` are the 255 midpoints between 256 evenly spaced edges
over [-1, 1].

There is no dedicated action expert to reconstruct. The full model is
Llama-2-7B + DINOv2 + SigLIP + 3-layer projector — ~7.5B params of
standard transformers architecture that HuggingFace's optimum-onnx
Expand Down
22 changes: 16 additions & 6 deletions src/tether/postprocess/openvla.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,13 @@
which map onto the top N bins of the vocabulary via:

bin_idx = vocab_size - token_id - 1
action_normalized = linspace(action_low, action_high, N)[bin_idx]
action_normalized = bin_centers[bin_idx]
action_unnorm = unnormalize(action_normalized, norm_stats[dataset])

OpenVLA's HuggingFace model computes the decode vocabulary size as
`text_config.vocab_size - pad_to_multiple_of` (32000 for openvla-7b), then
uses the `n_action_bins - 1` midpoints between `n_action_bins` evenly
spaced edges.

This module provides the postprocessing step that wraps a standard
Llama ONNX/PyTorch forward pass and turns its logits into actions.
Works with either the full LM output (logits over vocab) or the
Expand Down Expand Up @@ -41,13 +45,13 @@ def tokens_to_action_bins(
vocab_size: int,
n_bins: int = 256,
) -> np.ndarray:
"""Convert predicted tokens to bin indices in [0, n_bins).
"""Convert predicted tokens to bin-center indices in [0, n_bins - 2].

OpenVLA assigns the top n_bins tokens of the Llama vocab to actions:
bin_idx = vocab_size - token_id - 1
"""
bin_idx = vocab_size - token_ids - 1
return np.clip(bin_idx, 0, n_bins - 1)
return np.clip(bin_idx, 0, n_bins - 2)


def bins_to_normalized(
Expand All @@ -56,9 +60,15 @@ def bins_to_normalized(
action_low: float = -1.0,
action_high: float = 1.0,
) -> np.ndarray:
"""Map bin indices to normalized actions in [action_low, action_high]."""
"""Map bin indices to normalized action bin centers.

OpenVLA defines `n_action_bins` edges and uses the centers between them,
so the largest valid bin index is `n_bins - 2`.
"""
bins = np.linspace(action_low, action_high, n_bins, dtype=np.float32)
return bins[bin_idx]
bin_centers = (bins[:-1] + bins[1:]) / 2.0
safe_idx = np.clip(bin_idx, 0, bin_centers.shape[0] - 1)
return bin_centers[safe_idx]


def unnormalize_actions(
Expand Down Expand Up @@ -100,7 +110,7 @@ def decode_actions(
action_dim: int,
norm_stats: dict[str, Any] | None = None,
dataset_name: str | None = None,
vocab_size: int = 32064,
vocab_size: int = 32000,
n_bins: int = 256,
action_low: float = -1.0,
action_high: float = 1.0,
Expand Down
55 changes: 36 additions & 19 deletions tests/test_openvla_postprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,25 @@
import pytest

from tether.postprocess.openvla import (
bins_to_normalized,
decode_actions,
logits_to_tokens,
tokens_to_action_bins,
bins_to_normalized,
unnormalize_actions,
decode_actions,
)


def _center(
    index: int,
    n_bins: int = 256,
    action_low: float = -1.0,
    action_high: float = 1.0,
) -> np.float32:
    """Return the reference bin-center value at ``index``.

    Builds ``n_bins`` evenly spaced float32 edges over
    ``[action_low, action_high]`` and takes the midpoint of every adjacent
    pair, which yields ``n_bins - 1`` centers. Tests compare decoded actions
    against this independent computation.

    Args:
        index: Position in the centers array (0 .. ``n_bins - 2``).
        n_bins: Number of edges, not the number of centers.
        action_low: Value of the first edge.
        action_high: Value of the last edge.

    Returns:
        The float32 center at ``index``.

    Raises:
        IndexError: If ``index`` is outside the ``n_bins - 1`` centers.
    """
    edges = np.linspace(action_low, action_high, n_bins, dtype=np.float32)
    # Midpoint of each adjacent edge pair; one fewer center than edges.
    centers = (edges[:-1] + edges[1:]) / 2.0
    return centers[index]


class TestLogitsToTokens:
def test_picks_argmax_last_action_dim(self):
# Batch of 1, seq 10, vocab 100; force token 42 at positions -3:
Expand All @@ -28,28 +39,28 @@ def test_rejects_wrong_ndim(self):

class TestTokensToActionBins:
    """Check the token-id to bin-index mapping for a 32000-token vocab."""

    def test_top_token_is_bin_zero(self):
        # OpenVLA decodes against effective vocab_size=32000, not the padded LM size.
        tokens = np.array([[31999, 31998, 31997]])
        bins = tokens_to_action_bins(tokens, vocab_size=32000, n_bins=256)
        assert (bins == np.array([[0, 1, 2]])).all()

    def test_clips_out_of_range(self):
        # Any token below the action token band clips to the last valid center
        # (index n_bins - 2 = 254, since 256 edges give 255 centers).
        tokens = np.array([[0, 100, 1000]])
        bins = tokens_to_action_bins(tokens, vocab_size=32000, n_bins=256)
        assert (bins == 254).all()


class TestBinsToNormalized:
    """Check that bin indices map to bin centers rather than bin edges."""

    def test_bin_0_maps_to_first_center(self):
        bins = np.array([[0]])
        out = bins_to_normalized(bins, n_bins=256, action_low=-1.0, action_high=1.0)
        assert out[0, 0] == pytest.approx(_center(0))

    def test_bin_last_maps_to_last_center(self):
        # Index 255 is one past the last center (254); the lookup is expected
        # to clamp it instead of raising.
        bins = np.array([[255]])
        out = bins_to_normalized(bins)
        assert out[0, 0] == pytest.approx(_center(254))


class TestUnnormalizeActions:
Expand Down Expand Up @@ -86,20 +97,26 @@ def test_unknown_dataset_raises(self):

class TestDecodeActions:
    """End-to-end checks: logits -> tokens -> bin centers -> actions."""

    def test_full_pipeline_normalized(self):
        # 1 batch, seq 8, padded vocab 32064; OpenVLA decode uses token 31999.
        logits = np.zeros((1, 8, 32064), dtype=np.float32)
        logits[0, -7:, 31999] = 10.0  # effective top token = bin 0 = first center
        out = decode_actions(logits, action_dim=7)
        assert out.shape == (1, 7)
        # Centers are float32 midpoints, so compare with a tolerance.
        assert np.allclose(out, _center(0))

    def test_with_norm_stats(self):
        logits = np.zeros((1, 8, 32064), dtype=np.float32)
        logits[0, -7:, 31999] = 10.0
        norm_stats = {
            "bridge": {
                "action": {
                    "q01": [0.0] * 7,
                    "q99": [2.0] * 7,
                    "mask": [True] * 7,
                }
            }
        }
        out = decode_actions(
            logits,
            action_dim=7,
            norm_stats=norm_stats,
            dataset_name="bridge",
        )
        assert out.shape == (1, 7)
        # q01=0, q99=2 maps normalized x to x + 1.
        assert np.allclose(out, _center(0) + 1.0)
Loading