Fix EpisodeWrapper dropping sub-step metrics during action_repeat (#610) #670

Open
SAY-5 wants to merge 1 commit into google:main from SAY-5:fix-episode-wrapper-metrics-accumulation

Conversation

SAY-5 commented Apr 17, 2026

Fixes #610.

Problem

EpisodeWrapper.step uses jax.lax.scan to execute action_repeat sub-steps and correctly sums rewards across them. However, it only returns nstate.reward from the scan body — nstate.metrics is never accumulated. After the scan, state.metrics contains only the last sub-step's values, so any sparse or per-step metric (e.g. an action-change penalty, a contact indicator, or a sparse goal reward that fires on sub-step 1 but not sub-step 3) is silently dropped.

This is visible when logging metrics to TensorBoard: sparse rewards always show zero because only the last sub-step (which typically has no sparse event) is recorded.
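
A minimal standalone illustration of the drop (the sub_step_metrics function and the sparse_goal metric below are made up for this example, not Brax code):

import jax
import jax.numpy as jp

# Hypothetical sub-step metric: a sparse event that fires only on sub-step 1
# of an action_repeat=3 macro step.
def sub_step_metrics(i):
    return {'sparse_goal': jp.where(i == 1, 1.0, 0.0)}

def f(i, _):
    return i + 1, sub_step_metrics(i)

_, stacked = jax.lax.scan(f, jp.array(0), (), length=3)
print(stacked['sparse_goal'])          # [0. 1. 0.]  per sub-step
print(stacked['sparse_goal'][-1])      # 0.0  <- last sub-step only, what gets logged today
print(jp.sum(stacked['sparse_goal']))  # 1.0  <- the total this PR reports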

Fix

Return nstate.metrics alongside nstate.reward from the scan body so that lax.scan stacks them across the sub-step axis. Then sum each metric over that axis with jax.tree_util.tree_map(lambda m: jp.sum(m, axis=0), all_metrics) and update state.metrics with the summed values before the existing episode-metrics aggregation loop.

# Before: only rewards accumulated
def f(state, _):
    nstate = self.env.step(state, action)
    return nstate, nstate.reward

state, rewards = jax.lax.scan(f, state, (), self.action_repeat)
# state.metrics == last sub-step only

# After: metrics accumulated alongside rewards
def f(state, _):
    nstate = self.env.step(state, action)
    return nstate, (nstate.reward, nstate.metrics)

state, (rewards, all_metrics) = jax.lax.scan(f, state, (), self.action_repeat)
summed_metrics = jax.tree_util.tree_map(lambda m: jp.sum(m, axis=0), all_metrics)
state = state.replace(metrics=summed_metrics)
# state.metrics == sum across all sub-steps

This matches the pattern already used for rewards (jp.sum(rewards, axis=0)) and ensures the downstream episode-metrics aggregation loop (which reads state.metrics) sees the correct totals.
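
For context, a rough sketch of the aggregation pattern the downstream loop follows (an illustration only, not the exact Brax code; episode_metrics and prev_done are assumed names):

# Per-episode totals are built from the per-macro-step state.metrics, so with
# this PR each addition already includes all action_repeat sub-steps.
def accumulate_episode_metrics(episode_metrics, step_metrics, prev_done):
    out = {}
    for name, value in step_metrics.items():
        # Add this macro step's total, then zero out if the previous step
        # ended the episode.
        out[name] = (episode_metrics[name] + value) * (1.0 - prev_done)
    return out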

Impact

  • action_repeat=1: no change (sum over axis of length 1 is identity)
  • action_repeat>1: all per-step metrics are now correctly summed, not just the last sub-step's (see the test sketch below)
  • No change to the public API or State schema
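
A quick standalone check of the new scan pattern for both cases (no Brax imports; env_step, macro_step, and the sparse_goal metric are made-up stand-ins, not Brax API):

import jax
import jax.numpy as jp

def env_step(i, action):
    # Made-up sub-step dynamics: constant reward, sparse metric on sub-step 1.
    return i + 1, jp.float32(1.0), {'sparse_goal': jp.where(i == 1, 1.0, 0.0)}

def macro_step(action_repeat):
    def f(i, _):
        ni, reward, metrics = env_step(i, None)
        return ni, (reward, metrics)

    _, (rewards, all_metrics) = jax.lax.scan(f, jp.array(0), (), length=action_repeat)
    summed_metrics = jax.tree_util.tree_map(lambda m: jp.sum(m, axis=0), all_metrics)
    return jp.sum(rewards, axis=0), summed_metrics

reward, metrics = macro_step(action_repeat=3)
assert reward == 3.0
assert metrics['sparse_goal'] == 1.0   # the last sub-step alone would report 0.0

reward, metrics = macro_step(action_repeat=1)
assert reward == 1.0
assert metrics['sparse_goal'] == 0.0   # identical to the single sub-step's value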

Fix EpisodeWrapper dropping sub-step metrics during action_repeat (google#610)

The lax.scan loop in EpisodeWrapper.step collected rewards across
action-repeat sub-steps but not metrics. After the scan, state.metrics
only contained the last sub-step's values, so any sparse or per-step
metric (e.g. action-change penalty) was silently dropped for all
earlier sub-steps.

Return metrics alongside rewards from the scan body and sum them
across the sub-step axis, matching how rewards are already handled.
Update state.metrics with the summed values before the episode
metrics aggregation loop so downstream consumers see the correct
totals.


Development

Successfully merging this pull request may close these issues.

EpisodeWrapper only preserves the last sub‑step’s metrics when using action_repeat
