@@ -2758,13 +2758,16 @@ def _setup_multi_policy_and_weights(
 
         if weight_sync_schemes is not None:
             # Weight sync schemes handle all weight distribution
-            # No need to extract weights or create stateful policies here
+            # Extract weights so schemes can access them, but don't do in-place replacement
             self._policy_weights_dict = {}
-            self._fallback_policy = policy
-            self._get_weights_fn = None
+            self._fallback_policy = None
+
+            if not any(policy_factory) and policy is not None:
+                # Extract weights for the first device so schemes can access them
+                # Use first device as representative
+                first_device = self.policy_device[0] if self.policy_device else None
 
-            # Validate device types for SharedMemWeightSyncScheme
-            if not any(policy_factory):
+                # Validate device types for SharedMemWeightSyncScheme
                 for scheme in weight_sync_schemes.values():
                     if isinstance(scheme, SharedMemWeightSyncScheme):
                         for policy_device in self.policy_device:
@@ -2776,6 +2779,29 @@ def _setup_multi_policy_and_weights(
                                    f"Device type '{policy_device.type}' not supported for SharedMemWeightSyncScheme. "
                                    f"Only 'cpu' and 'cuda' are supported."
                                )
+
+                # Extract weights from policy
+                weights = (
+                    TensorDict.from_module(policy)
+                    if isinstance(policy, nn.Module)
+                    else TensorDict()
+                )
+
+                # For SharedMemWeightSyncScheme, share the weights
+                if any(
+                    isinstance(scheme, SharedMemWeightSyncScheme)
+                    for scheme in weight_sync_schemes.values()
+                ):
+                    if first_device and first_device.type == "cpu":
+                        weights = weights.share_memory_()
+                    elif first_device and first_device.type == "cuda":
+                        # CUDA tensors maintain shared references through mp.Queue
+                        weights = weights.to(first_device).share_memory_()
+
+                self._policy_weights_dict[first_device] = weights
+                self._fallback_policy = policy
+
+            self._get_weights_fn = None
         else:
             # Using legacy weight updater - extract weights and create stateful policies
             self._setup_multi_policy_and_weights_legacy(
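For context, a minimal standalone sketch of the extraction-and-sharing pattern the new branch relies on: `TensorDict.from_module` pulls a module's parameters and buffers into a TensorDict, and `share_memory_()` moves CPU storages into shared memory so worker processes can read the same tensors. The toy policy module and device choice below are illustrative only, not objects from the collector.

```python
# Minimal sketch (illustrative policy/device, not the collector's own objects).
import torch
from torch import nn
from tensordict import TensorDict

policy = nn.Sequential(nn.Linear(4, 64), nn.Tanh(), nn.Linear(64, 2))

# Pull parameters and buffers out of the module as a nested TensorDict.
weights = TensorDict.from_module(policy)

first_device = torch.device("cpu")
if first_device.type == "cpu":
    # CPU storages are moved to shared memory so spawned workers can map them.
    weights = weights.share_memory_()
elif first_device.type == "cuda":
    # CUDA tensors keep shared references when passed through mp.Queue.
    weights = weights.to(first_device).share_memory_()

# Keys mirror the module's state_dict structure, e.g. weights["0", "weight"].
print(weights)
```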
@@ -3067,21 +3093,24 @@ def _run_processes(self) -> None:
             1, torch.get_num_threads() - total_workers
         )  # 1 more thread for this proc
 
-        # Initialize weight sync schemes to create queues before workers start
+        # Set up for worker processes
         torch.set_num_threads(self.num_threads)
         queue_out = mp.Queue(self._queue_len)  # sends data from proc to main
         self.procs = []
         self.pipes = []
         self._traj_pool = _TrajectoryPool(lock=True)
 
-        # Initialize all weight sync schemes early
-        # Schemes own their queues and handle distribution internally
+        # Initialize weight sync schemes early for SharedMemWeightSyncScheme
+        # (queue created in __init__ will be pickled with scheme to workers)
+        # For MultiProcessWeightSyncScheme, we'll initialize after pipes are available
         if self._weight_sync_schemes:
             for model_id, scheme in self._weight_sync_schemes.items():
-                # Check if scheme has new API
-                if hasattr(scheme, "init_on_sender"):
+                # Only initialize SharedMemWeightSyncScheme now (needs queue before workers)
+                # MultiProcessWeightSyncScheme will be initialized after workers are created
+                if isinstance(scheme, SharedMemWeightSyncScheme) and hasattr(
+                    scheme, "init_on_sender"
+                ):
                     scheme.init_on_sender(model_id=model_id, context=self)
-                    # Get the initialized sender
                     self._weight_senders[model_id] = scheme.get_sender()
 
         # Create a policy on the right device
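The constraint behind this ordering is the one named in the comment above: a multiprocessing queue owned by the scheme is pickled together with the scheme into each worker, and that hand-off only works while the child process is being constructed, i.e. before spawning completes. A hedged, standalone illustration of that constraint, using a stand-in class rather than the real torchrl scheme:

```python
# Stand-in illustration only: a toy "scheme" that owns its queue.
import torch.multiprocessing as mp


class ToyScheme:
    def __init__(self):
        # Created in the parent, before any worker exists.
        self.queue = mp.Queue()


def worker(scheme: ToyScheme):
    # The child receives the same queue through the pickled scheme object.
    print("worker received:", scheme.queue.get())


if __name__ == "__main__":
    scheme = ToyScheme()
    proc = mp.Process(target=worker, args=(scheme,))
    proc.start()  # the queue travels to the child as part of `scheme`
    scheme.queue.put("fresh policy weights")
    proc.join()
```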
@@ -3257,6 +3286,18 @@ def _run_processes(self) -> None:
                 # Legacy string error message
                 raise RuntimeError(msg)
 
+        # Initialize MultiProcessWeightSyncScheme now that workers are ready and pipes are available
+        # (SharedMemWeightSyncScheme was already initialized before workers)
+        if self._weight_sync_schemes:
+            for model_id, scheme in self._weight_sync_schemes.items():
+                # Only initialize non-SharedMem schemes here (need pipes)
+                if not isinstance(scheme, SharedMemWeightSyncScheme) and hasattr(
+                    scheme, "init_on_sender"
+                ):
+                    scheme.init_on_sender(model_id=model_id, context=self)
+                    # Get the initialized sender
+                    self._weight_senders[model_id] = scheme.get_sender()
+
         self.queue_out = queue_out
         self.closed = False
 
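Taken together with the earlier hunk, sender initialization is now split into two phases around worker creation. A compact, hypothetical sketch of that ordering; the collector and scheme classes below are stand-ins that only mirror the `init_on_sender`/`get_sender` surface visible in the diff:

```python
# Hypothetical stand-ins; only the call order mirrors the diff above.
class SharedMemLikeScheme:
    def init_on_sender(self, model_id, context):
        # Needs its shared buffers/queue ready *before* workers are spawned.
        self._sender = f"shared-mem sender[{model_id}]"

    def get_sender(self):
        return self._sender


class PipeLikeScheme:
    def init_on_sender(self, model_id, context):
        # Needs the parent ends of the worker pipes, which exist only after spawn.
        self._sender = f"pipe sender[{model_id}] over {len(context.pipes)} pipes"

    def get_sender(self):
        return self._sender


class CollectorLike:
    def __init__(self, schemes):
        self._weight_sync_schemes = schemes
        self._weight_senders = {}
        self.pipes = []

    def run_processes(self):
        # Phase 1: shared-memory schemes, before any worker exists.
        for model_id, scheme in self._weight_sync_schemes.items():
            if isinstance(scheme, SharedMemLikeScheme):
                scheme.init_on_sender(model_id=model_id, context=self)
                self._weight_senders[model_id] = scheme.get_sender()

        self.pipes = ["pipe-0", "pipe-1"]  # placeholder for spawning workers

        # Phase 2: pipe-based schemes, once the pipes are available.
        for model_id, scheme in self._weight_sync_schemes.items():
            if not isinstance(scheme, SharedMemLikeScheme):
                scheme.init_on_sender(model_id=model_id, context=self)
                self._weight_senders[model_id] = scheme.get_sender()


collector = CollectorLike({"policy": SharedMemLikeScheme(), "value": PipeLikeScheme()})
collector.run_processes()
print(collector._weight_senders)
```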