[Feature] Add BC Loss for behavior cloning#3667

Open
ParamThakkar123 wants to merge 28 commits into pytorch:main from ParamThakkar123:bc_loss

Conversation

@ParamThakkar123
Contributor

Summary

This PR implements the BC Loss module for behavior cloning as requested in issue #3635.

Changes

  • torchrl/objectives/bc.py: New BCLoss module that supports both stochastic and deterministic policies
  • torchrl/objectives/__init__.py: Add BCLoss to module exports
  • test/objectives/test_bc.py: Comprehensive test suite covering all functionality
  • docs/source/reference/objectives_other.rst: Add BCLoss to documentation

Features

  • Auto-detects policy type based on whether actor outputs log_prob
  • For stochastic policies: minimizes -E[log π(a_expert | s)]
  • For deterministic policies: minimizes distance(a_pred, a_expert) with configurable loss functions (l1, l2, smooth_l1)
  • Follows standard LossModule pattern with proper keys, dispatch, and reduction
  • Integrates cleanly with existing offline RL stack
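
As a rough illustration of the two objectives listed above, the sketch below computes them with plain PyTorch rather than with the BCLoss module itself, whose API is not shown in this thread. The helper names `stochastic_bc_loss` and `deterministic_bc_loss`, the diagonal Gaussian policy, and the tensor shapes are all assumptions made for the example.

```python
import torch
import torch.nn.functional as F
from torch.distributions import Independent, Normal


# Hypothetical helper for illustration; not the PR's BCLoss API.
def stochastic_bc_loss(loc, scale, expert_action):
    """Negative log-likelihood of expert actions: -E[log pi(a_expert | s)]."""
    # Independent(..., 1) sums per-dimension log-probs over the action dimension.
    dist = Independent(Normal(loc, scale), 1)
    return -dist.log_prob(expert_action).mean()


# Distance functions matching the names mentioned in the PR description.
_DISTANCES = {"l1": F.l1_loss, "l2": F.mse_loss, "smooth_l1": F.smooth_l1_loss}


# Hypothetical helper for illustration; not the PR's BCLoss API.
def deterministic_bc_loss(pred_action, expert_action, loss_function="l2"):
    """Distance between predicted and expert actions."""
    if loss_function not in _DISTANCES:
        raise ValueError(f"Unsupported loss_function: {loss_function}")
    return _DISTANCES[loss_function](pred_action, expert_action)


torch.manual_seed(0)
obs = torch.randn(8, 4)            # batch of 8 observations (assumed shape)
expert_action = torch.randn(8, 2)  # 2-dimensional continuous expert actions

policy = torch.nn.Linear(4, 4)     # toy network emitting loc and raw scale
loc, raw_scale = policy(obs).chunk(2, dim=-1)
scale = F.softplus(raw_scale) + 1e-4  # keep the scale strictly positive

nll = stochastic_bc_loss(loc, scale, expert_action)
l2 = deterministic_bc_loss(loc, expert_action, "l2")
nll.backward()  # gradients flow back into the toy policy
```

Note that the negative log-likelihood and the distance losses live on different scales, and for continuous densities the former can be negative while the latter cannot, so the two paths are not interchangeable training signals.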

Tests

All tests pass including:

  • Forward/backward passes for both policy types
  • Different loss functions and reduction modes
  • Training convergence verification
  • Custom key configurations

Closes #3635

- Add BCLoss module in torchrl/objectives/bc.py
- Supports both stochastic and deterministic policies
- Auto-detects policy type based on log_prob output
- Configurable loss functions for deterministic policies
- Add comprehensive tests in test/objectives/test_bc.py
- Update documentation and module exports
@pytorch-bot

pytorch-bot Bot commented Apr 23, 2026

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/rl/3667

Note: Links to docs will display an error until the docs builds have been completed.

❗ 1 Active SEVs

There is 1 currently active SEV. If your PR is affected, please view it below:

❌ 6 New Failures, 1 Cancelled Job

As of commit 06b7d0b with merge base 33475e3 (image):

NEW FAILURES - The following jobs have failed:

CANCELLED JOB - The following job was cancelled. Please retry:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@meta-cla meta-cla Bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Apr 23, 2026
@github-actions
Contributor

⚠️ PR Title Label Error

PR title must start with a label prefix in brackets (e.g., [BugFix]).

Current title: Add BC Loss for behavior cloning

Supported Prefixes (case-sensitive)

Your PR title must start with exactly one of these prefixes:

| Prefix | Label Applied | Example |
|---|---|---|
| [BugFix] | BugFix | [BugFix] Fix memory leak in collector |
| [Feature] | Feature | [Feature] Add new optimizer |
| [Doc] or [Docs] | Documentation | [Doc] Update installation guide |
| [Refactor] | Refactoring | [Refactor] Clean up module imports |
| [CI] | CI | [CI] Fix workflow permissions |
| [Test] or [Tests] | Tests | [Tests] Add unit tests for buffer |
| [Environment] or [Environments] | Environments | [Environments] Add Gymnasium support |
| [Data] | Data | [Data] Fix replay buffer sampling |
| [Performance] or [Perf] | Performance | [Performance] Optimize tensor ops |
| [BC-Breaking] | bc breaking | [BC-Breaking] Remove deprecated API |
| [Deprecation] | Deprecation | [Deprecation] Mark old function |
| [Quality] | Quality | [Quality] Fix typos and add codespell |

Note: Common variations like singular/plural are supported (e.g., [Doc] or [Docs]).

@github-actions github-actions Bot added Documentation Improvements or additions to documentation Objectives labels Apr 23, 2026
@ParamThakkar123 ParamThakkar123 changed the title Add BC Loss for behavior cloning [Feature] Add BC Loss for behavior cloning Apr 23, 2026
@github-actions github-actions Bot added the Feature New feature label Apr 23, 2026
@ParamThakkar123
Contributor Author

I haven't run pre-commit on this because the pre-commit setup fails in my environment due to some compatibility issues.

@Xmaster6y
Contributor

Thanks for the work, eager to use this.

I just happened to read this and saw that you use the log_prob key as a switch, which seems fragile to me. In some modules the key can be action_log_prob, or a custom name in general. Maybe we want something more explicit that computes the distribution and log-probs in the loss directly, closer to CQLLoss.actor_bc_loss.
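
The fragility described here can be shown with a toy example. Below, a plain dict stands in for the actor's output, and the helper `looks_stochastic` is a hypothetical reconstruction of key-based detection, not the PR's actual code. An actor that writes its log-probability under `action_log_prob` is misclassified, whereas an explicit flag (the assumed `deterministic` argument) does not depend on key names.

```python
# Hypothetical reconstruction of key-based detection; not the PR's code.
def looks_stochastic(actor_output: dict) -> bool:
    """Guess the policy type from the presence of a hard-coded key."""
    return "log_prob" in actor_output


def select_objective(actor_output: dict, deterministic=None) -> str:
    """Pick the BC objective, preferring an explicit flag over key sniffing."""
    if deterministic is None:
        # Fall back to the fragile heuristic only when nothing was specified.
        deterministic = not looks_stochastic(actor_output)
    return "distance" if deterministic else "nll"


out_default = {"action": [0.1], "log_prob": -1.2}
out_custom = {"action": [0.1], "action_log_prob": -1.2}  # custom key name

print(select_objective(out_default))                      # nll
print(select_objective(out_custom))                       # distance (misclassified)
print(select_objective(out_custom, deterministic=False))  # nll
```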

I'll make sure to put it to the test in the coming days or weeks.

@ParamThakkar123
Contributor Author

Sure, I will make this correction 🫡.

@theap06
Contributor

theap06 commented Apr 25, 2026

@ParamThakkar123 Also, please make sure that the MSE path works well and that the module structure is clean. Running a quick reproduction script against the three standard BC scenarios surfaces some issues worth fixing before merge. Thanks for the PR!

import torch, torch.nn as nn
from tensordict import TensorDict
from tensordict.nn import TensorDictModule, ProbabilisticTensorDictModule, ProbabilisticTensorDictSequential
from torchrl.modules import Actor
from torchrl.modules.distributions import NormalParamExtractor, TanhNormal
from torchrl.objectives.bc import BCLoss

torch.manual_seed(0)

# Failure 1 — cross_entropy not in string dispatch
actor = Actor(nn.Linear(4, 4))
loss = BCLoss(actor, loss_function='cross_entropy')
td = TensorDict({'observation': torch.randn(8,4), 'action': torch.randint(0,4,(8,)).long()}, [8])
loss(td)
# ValueError: Unsupported loss_function: cross_entropy

# Failure 2 — integer action labels crash
loss2 = BCLoss(Actor(nn.Linear(4, 4)))
td2 = TensorDict({'observation': torch.randn(8,4), 'action': torch.randint(0,4,(8,)).long()}, [8])
loss2(td2)
# RuntimeError: size of tensor a (4) must match tensor b (8)

# Failure 3 — stochastic actor silently computes MSE instead of NLL
net = nn.Sequential(nn.Linear(4, 8), NormalParamExtractor())
mod = TensorDictModule(net, in_keys=['observation'], out_keys=['loc','scale'])
stoch = ProbabilisticTensorDictSequential(mod,
    ProbabilisticTensorDictModule(['loc','scale'], ['action'], TanhNormal, return_log_prob=True))
loss3 = BCLoss(stoch)
td3 = TensorDict({'observation': torch.randn(8,4), 'action': torch.randn(8,4)}, [8])
print(loss3(td3)['loss_bc'].item())   # 0.9426  — MSE, always >= 0
# Real NLL for same inputs: 68.21 — completely different signal
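
The shape error in Failure 2 and the motivation for a cross-entropy path can be reproduced without BCLoss. The sketch below uses plain PyTorch only. It assumes the actor emits one logit per discrete action and that expert actions are stored as integer class indices, as in the script above.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
logits = torch.randn(8, 4)          # one logit per discrete action
labels = torch.randint(0, 4, (8,))  # expert actions as class indices

# An element-wise distance cannot broadcast (8, 4) against (8,):
# the trailing dimensions 4 and 8 are incompatible.
try:
    (logits - labels).pow(2).mean()
    shape_error = None
except RuntimeError as err:
    shape_error = err

# Cross-entropy consumes integer class indices directly.
ce = F.cross_entropy(logits, labels)

# The same quantity written as a negative log-likelihood under a Categorical
# distribution, which is the stochastic BC objective for discrete actions.
nll = -torch.distributions.Categorical(logits=logits).log_prob(labels).mean()

print(shape_error is not None, torch.allclose(ce, nll))
```

Seen this way, a cross-entropy option and the stochastic NLL path describe the same objective for discrete actions, which may be one way to unify the two code paths.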

@ParamThakkar123
Contributor Author

Thanks for the insights, @theap06. I will look into it and fix this 🫡

Comment thread torchrl/modules/tensordict_module/actors.py Outdated
Comment thread torchrl/objectives/bc.py
class BCLoss(LossModule):
    """Behavior Cloning Loss Module.

    Implements behavior cloning loss for both stochastic and deterministic policies.
Collaborator

Let's add the reference to the arXiv paper if we have it.

Contributor Author

This is the paper:

"Integrating Behavior Cloning and Reinforcement Learning for Improved Performance in Dense and Sparse Reward Environments"
https://arxiv.org/abs/1910.04281

Comment thread torchrl/objectives/bc.py Outdated
@ParamThakkar123
Contributor Author

On it

@ParamThakkar123
Contributor Author

@vmoens I have implemented all the fixes requested in the reviews 🫡

@ParamThakkar123 ParamThakkar123 requested a review from vmoens April 25, 2026 21:01
Collaborator

@vmoens vmoens left a comment

LGTM thanks!

@ParamThakkar123
Contributor Author

@vmoens The SOTA tests and one unit test seem to fail, but those failures appear unrelated to my changes.

Labels

CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. Documentation Improvements or additions to documentation Feature New feature Integrations/torch_geometric Integrations Modules Objectives

Projects

None yet

Development

Successfully merging this pull request may close these issues.

[Feature Request] Add BC Loss for behavior cloning

4 participants