Fix EpisodeWrapper dropping sub-step metrics during action_repeat (#610)#670
Open
SAY-5 wants to merge 1 commit intogoogle:mainfrom
Open
Fix EpisodeWrapper dropping sub-step metrics during action_repeat (#610)#670SAY-5 wants to merge 1 commit intogoogle:mainfrom
SAY-5 wants to merge 1 commit intogoogle:mainfrom
Conversation
…ogle#610) The lax.scan loop in EpisodeWrapper.step collected rewards across action-repeat sub-steps but not metrics. After the scan, state.metrics only contained the last sub-step's values, so any sparse or per-step metric (e.g. action-change penalty) was silently dropped for all earlier sub-steps. Return metrics alongside rewards from the scan body and sum them across the sub-step axis, matching how rewards are already handled. Update state.metrics with the summed values before the episode metrics aggregation loop so downstream consumers see the correct totals.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Fixes #610.
Problem
EpisodeWrapper.stepusesjax.lax.scanto executeaction_repeatsub-steps and correctly sums rewards across them. However, it only returnsnstate.rewardfrom the scan body —nstate.metricsis never accumulated. After the scan,state.metricscontains only the last sub-step's values, so any sparse or per-step metric (e.g. an action-change penalty, a contact indicator, or a sparse goal reward that fires on sub-step 1 but not sub-step 3) is silently dropped.This is visible when logging metrics to TensorBoard: sparse rewards always show zero because only the last sub-step (which typically has no sparse event) is recorded.
Fix
Return
nstate.metricsalongsidenstate.rewardfrom the scan body so thatlax.scanstacks them across the sub-step axis. Then sum each metric over that axis withjax.tree_util.tree_map(lambda m: jp.sum(m, axis=0), all_metrics)and updatestate.metricswith the summed values before the existing episode-metrics aggregation loop.This matches the pattern already used for rewards (
jp.sum(rewards, axis=0)) and ensures the downstream episode-metrics aggregation loop (which readsstate.metrics) sees the correct totals.Impact
action_repeat=1: no change (sum over axis of length 1 is identity)action_repeat>1: all per-step metrics are now correctly summed, not just the last sub-step's