
Commit e3e0863

fix the fixes:
1 parent (2e73a09), commit e3e0863

File tree

2 files changed: +61, -27 lines

torchrl/collectors/collectors.py

Lines changed: 52 additions & 11 deletions
@@ -2758,13 +2758,16 @@ def _setup_multi_policy_and_weights(
 
         if weight_sync_schemes is not None:
             # Weight sync schemes handle all weight distribution
-            # No need to extract weights or create stateful policies here
+            # Extract weights so schemes can access them, but don't do in-place replacement
             self._policy_weights_dict = {}
-            self._fallback_policy = policy
-            self._get_weights_fn = None
+            self._fallback_policy = None
+
+            if not any(policy_factory) and policy is not None:
+                # Extract weights for the first device so schemes can access them
+                # Use first device as representative
+                first_device = self.policy_device[0] if self.policy_device else None
 
-            # Validate device types for SharedMemWeightSyncScheme
-            if not any(policy_factory):
+                # Validate device types for SharedMemWeightSyncScheme
                 for scheme in weight_sync_schemes.values():
                     if isinstance(scheme, SharedMemWeightSyncScheme):
                         for policy_device in self.policy_device:
@@ -2776,6 +2779,29 @@ def _setup_multi_policy_and_weights(
                                 f"Device type '{policy_device.type}' not supported for SharedMemWeightSyncScheme. "
                                 f"Only 'cpu' and 'cuda' are supported."
                             )
+
+                # Extract weights from policy
+                weights = (
+                    TensorDict.from_module(policy)
+                    if isinstance(policy, nn.Module)
+                    else TensorDict()
+                )
+
+                # For SharedMemWeightSyncScheme, share the weights
+                if any(
+                    isinstance(scheme, SharedMemWeightSyncScheme)
+                    for scheme in weight_sync_schemes.values()
+                ):
+                    if first_device and first_device.type == "cpu":
+                        weights = weights.share_memory_()
+                    elif first_device and first_device.type == "cuda":
+                        # CUDA tensors maintain shared references through mp.Queue
+                        weights = weights.to(first_device).share_memory_()
+
+                self._policy_weights_dict[first_device] = weights
+                self._fallback_policy = policy
+
+            self._get_weights_fn = None
         else:
             # Using legacy weight updater - extract weights and create stateful policies
             self._setup_multi_policy_and_weights_legacy(
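
For context, the weight-extraction branch added above reduces to the following standalone pattern. This is a minimal sketch, not collector code: the toy module and variable names are illustrative, and it assumes a CPU policy device.

from tensordict import TensorDict
from torch import nn

# Toy policy standing in for the collector's policy (illustrative only).
policy = nn.Sequential(nn.Linear(4, 32), nn.Tanh(), nn.Linear(32, 2))

# Snapshot parameters and buffers into a TensorDict, as the new branch does.
weights = TensorDict.from_module(policy)

# With a CPU policy device, move the snapshot into shared memory so worker
# processes can map the same storage.
weights = weights.share_memory_()

# The collector then keeps the snapshot keyed by a representative device,
# roughly: self._policy_weights_dict[first_device] = weights
print(weights)

Because share_memory_() relocates the underlying storages in place, workers that receive this TensorDict at startup keep seeing later in-place updates to it, which is what SharedMemWeightSyncScheme relies on.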
@@ -3067,21 +3093,24 @@ def _run_processes(self) -> None:
                 1, torch.get_num_threads() - total_workers
             )  # 1 more thread for this proc
 
-        # Initialize weight sync schemes to create queues before workers start
+        # Set up for worker processes
         torch.set_num_threads(self.num_threads)
         queue_out = mp.Queue(self._queue_len)  # sends data from proc to main
         self.procs = []
         self.pipes = []
         self._traj_pool = _TrajectoryPool(lock=True)
 
-        # Initialize all weight sync schemes early
-        # Schemes own their queues and handle distribution internally
+        # Initialize weight sync schemes early for SharedMemWeightSyncScheme
+        # (queue created in __init__ will be pickled with scheme to workers)
+        # For MultiProcessWeightSyncScheme, we'll initialize after pipes are available
         if self._weight_sync_schemes:
             for model_id, scheme in self._weight_sync_schemes.items():
-                # Check if scheme has new API
-                if hasattr(scheme, "init_on_sender"):
+                # Only initialize SharedMemWeightSyncScheme now (needs queue before workers)
+                # MultiProcessWeightSyncScheme will be initialized after workers are created
+                if isinstance(scheme, SharedMemWeightSyncScheme) and hasattr(
+                    scheme, "init_on_sender"
+                ):
                     scheme.init_on_sender(model_id=model_id, context=self)
-                    # Get the initialized sender
                     self._weight_senders[model_id] = scheme.get_sender()
 
         # Create a policy on the right device
@@ -3257,6 +3286,18 @@ def _run_processes(self) -> None:
                     # Legacy string error message
                     raise RuntimeError(msg)
 
+        # Initialize MultiProcessWeightSyncScheme now that workers are ready and pipes are available
+        # (SharedMemWeightSyncScheme was already initialized before workers)
+        if self._weight_sync_schemes:
+            for model_id, scheme in self._weight_sync_schemes.items():
+                # Only initialize non-SharedMem schemes here (need pipes)
+                if not isinstance(scheme, SharedMemWeightSyncScheme) and hasattr(
+                    scheme, "init_on_sender"
+                ):
+                    scheme.init_on_sender(model_id=model_id, context=self)
+                    # Get the initialized sender
+                    self._weight_senders[model_id] = scheme.get_sender()
+
         self.queue_out = queue_out
         self.closed = False
 
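
Taken together, the two hunks in _run_processes split scheme initialization into two phases around worker startup: shared-memory schemes are initialized before the workers are spawned (their queue must travel with the pickled scheme), while pipe-based schemes are initialized afterwards, once the pipes exist. A schematic sketch of that ordering, using stub classes rather than the TorchRL implementations:

class SharedMemScheme:
    """Stub: owns a queue that must exist before workers are spawned."""

    def init_on_sender(self, model_id, context=None):
        print(f"[before workers] shared-mem init for {model_id!r}")


class PipeScheme:
    """Stub: needs per-worker pipes, which exist only after workers start."""

    def init_on_sender(self, model_id, context=None):
        print(f"[after workers] pipe-based init for {model_id!r}")


def run_processes(schemes, spawn_workers):
    # Phase 1: shared-memory schemes, so their queue travels with the scheme
    # when the worker processes are created.
    for model_id, scheme in schemes.items():
        if isinstance(scheme, SharedMemScheme):
            scheme.init_on_sender(model_id=model_id)

    pipes = spawn_workers()  # workers and their pipes come into existence here

    # Phase 2: everything else, now that pipes are available.
    for model_id, scheme in schemes.items():
        if not isinstance(scheme, SharedMemScheme):
            scheme.init_on_sender(model_id=model_id)
    return pipes


run_processes(
    {"policy": SharedMemScheme(), "value_net": PipeScheme()},
    spawn_workers=lambda: ["pipe-0", "pipe-1"],
)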

torchrl/weight_update/weight_sync_schemes.py

Lines changed: 9 additions & 16 deletions
@@ -5,14 +5,19 @@
 from __future__ import annotations
 
 import abc
+import time
 
 import weakref
 from collections.abc import Iterator
+from queue import Empty
 from typing import Any, Literal, Protocol
 
+import torch
+import torch.distributed
+
 from tensordict import TensorDict, TensorDictBase
 
-from torch import nn
+from torch import multiprocessing as mp, nn
 
 __all__ = [
     "TransportBackend",
@@ -195,8 +200,6 @@ def _infer_device(self, td: TensorDictBase):
         Returns:
             torch.device or None if no tensors found or all on different devices.
         """
-        import torch
-
         for value in td.values(True, True):
             if isinstance(value, torch.Tensor):
                 return value.device
@@ -688,8 +691,6 @@ def send_ack(self, message: str = "updated") -> None:
 
     def check_connection(self) -> bool:
         """Check if torch.distributed is initialized."""
-        import torch.distributed
-
         return torch.distributed.is_initialized()
 
 
@@ -1602,7 +1603,8 @@ def __init__(
         self._shared_transport = SharedMemTransport(
             self.policy_weights, auto_register=auto_register
         )
-        self._weight_init_queue = None  # Created during init_on_sender
+        # Create queue immediately so it's available when scheme is pickled to workers
+        self._weight_init_queue = mp.Queue()
 
     def register_shared_weights(self, model_id: str, weights: TensorDictBase) -> None:
         """Register shared memory weights for a model.
@@ -1659,13 +1661,8 @@ def init_on_sender(
                 "device_to_workers mapping must be provided via context or kwargs"
            )
 
-        # Create queue once for this scheme instance (owned by scheme, not collector)
-        if self._weight_init_queue is None:
-            from torch import multiprocessing as mp
-
-            self._weight_init_queue = mp.Queue()
-
         # Set worker info in transport
+        # Queue was already created in __init__ so it's available to workers
         self._shared_transport.set_worker_info(device_to_workers)
         self._shared_transport._weight_queue = self._weight_init_queue
 
@@ -1723,8 +1720,6 @@ def init_on_worker(
         # Receive weights from the scheme's queue if available
         if self._weight_init_queue is not None and worker_idx is not None:
             # Read from queue until we find our worker_idx and model_id
-            from queue import Empty
-
             timeout = kwargs.get("timeout", 10.0)
             try:
                 while True:
@@ -1746,8 +1741,6 @@
                             (msg_worker_idx, msg_model_id, shared_weights)
                         )
                         # Small sleep to avoid immediately picking up the same message
-                        import time
-
                         time.sleep(0.001)
             except Empty:
                 # No weights pre-registered for this model (will use auto-register or policy_factory)
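
The reason the queue now has to be created in __init__ is a multiprocessing constraint: a Queue can only reach another process by being pickled (or inherited) as part of process creation; it cannot be shipped later through another queue. Creating it eagerly means it is simply an attribute of the scheme when the scheme is handed to the workers. Below is a self-contained sketch of that mechanism together with the worker-side matching loop used by init_on_worker; class and attribute names here are illustrative, not TorchRL's.

import time
from queue import Empty

import torch.multiprocessing as mp


class TinyScheme:
    def __init__(self):
        # Created up front, like the scheme's _weight_init_queue, so it travels
        # with the object when worker processes are started.
        self.weight_init_queue = mp.Queue()


def worker(scheme, worker_idx, model_id):
    # Drain the queue until we find our own (worker_idx, model_id) entry;
    # anything addressed to another worker goes back on the queue.
    while True:
        try:
            msg_idx, msg_model, payload = scheme.weight_init_queue.get(timeout=10.0)
        except Empty:
            return  # nothing pre-registered for this worker
        if msg_idx == worker_idx and msg_model == model_id:
            print(f"worker {worker_idx} received weights for {msg_model!r}: {payload}")
            return
        scheme.weight_init_queue.put((msg_idx, msg_model, payload))
        time.sleep(0.001)  # avoid immediately re-reading the same message


if __name__ == "__main__":
    scheme = TinyScheme()
    scheme.weight_init_queue.put((0, "policy", {"version": 1}))
    p = mp.Process(target=worker, args=(scheme, 0, "policy"))
    p.start()
    p.join()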
