@@ -5,6 +5,7 @@
 
 import torch
 import torch.distributed as dist
+import torch.distributed._symmetric_memory as symm_mem
 from torch._C._autograd import DeviceType
 from torch._C._distributed_c10d import _SymmetricMemory
 from torch._inductor.utils import fresh_inductor_cache, run_and_get_triton_code
@@ -81,9 +82,25 @@ def _init_process(self):
             rank=self.rank,
             store=store,
         )
-        enable_symm_mem_for_group(dist.group.WORLD.group_name)
         torch.manual_seed(42 + self.rank)
 
+    @skipIfRocm
+    @skip_if_lt_x_gpu(2)
+    def test_cuda_nvlink_connectivity_detection(self) -> None:
+        from torch._C._distributed_c10d import _detect_dma_connectivity
+
+        connectivity = _detect_dma_connectivity(DeviceType.CUDA, "nvlink")
+        self.assertEqual(connectivity.device_type, DeviceType.CUDA)
+        self.assertEqual(connectivity.connection_type, "nvlink")
+        self.assertEqual(len(connectivity.matrix), torch.cuda.device_count())
+        for row in connectivity.matrix:
+            self.assertEqual(len(row), torch.cuda.device_count())
+
+    @skipIfRocm
+    def test_large_alloc(self) -> None:
+        t = symm_mem.empty(2 * 1024**3, dtype=torch.uint8, device="cuda")
+        self.assertEqual(t.numel() * t.element_size(), 2 * 1024**3)
+
     def _get_test_alloc_args(self):
         shape = (64, 64)
         stride = (64, 1)
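For reference, `_detect_dma_connectivity` is a private binding, and the assertions above only pin down the shape of the result. A minimal sketch of inspecting it outside the test harness; the peer-listing logic is illustrative, and the API is private and subject to change:

```python
# Sketch: probe NVLink connectivity the way the test above does.
# Assumes a CUDA build of PyTorch; _detect_dma_connectivity is private API.
import torch
from torch._C._autograd import DeviceType
from torch._C._distributed_c10d import _detect_dma_connectivity

if torch.cuda.is_available():
    conn = _detect_dma_connectivity(DeviceType.CUDA, "nvlink")
    # conn.matrix is device_count() x device_count(); a nonzero entry is
    # taken here to mean the two GPUs are reachable over NVLink.
    for src, row in enumerate(conn.matrix):
        peers = [dst for dst, w in enumerate(row) if w > 0 and dst != src]
        print(f"cuda:{src} -> NVLink peers: {peers}")
```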
@@ -92,64 +109,56 @@ def _get_test_alloc_args(self):
         group_name = "0"
         return (shape, stride, dtype, device, group_name)
 
-    def _verify_symmetric_memory(self, symm_mem):
-        self.assertEqual(symm_mem.world_size, 2)
+    def _verify_symmetric_memory(self, symm_mem_hdl):
+        self.assertEqual(symm_mem_hdl.world_size, 2)
 
-        buf = symm_mem.get_buffer(0, (symm_mem.buffer_size // 4,), torch.float32)
+        buf = symm_mem_hdl.get_buffer(
+            0, (symm_mem_hdl.buffer_size // 4,), torch.float32
+        )
         self.assertEqual(buf.storage_offset(), 0)
-        self.assertEqual(buf.untyped_storage().size(), symm_mem.buffer_size)
+        self.assertEqual(buf.untyped_storage().size(), symm_mem_hdl.buffer_size)
 
-        if symm_mem.rank == 0:
-            symm_mem.wait_signal(src_rank=1)
+        if symm_mem_hdl.rank == 0:
+            symm_mem_hdl.wait_signal(src_rank=1)
             self.assertTrue(buf.eq(42).all())
         else:
             buf.fill_(42)
-            symm_mem.put_signal(dst_rank=0)
+            symm_mem_hdl.put_signal(dst_rank=0)
 
-        symm_mem.barrier()
+        symm_mem_hdl.barrier()
 
-        if symm_mem.rank == 0:
-            symm_mem.barrier()
+        if symm_mem_hdl.rank == 0:
+            symm_mem_hdl.barrier()
             self.assertTrue(buf.eq(43).all())
         else:
             buf.fill_(43)
-            symm_mem.barrier()
+            symm_mem_hdl.barrier()
 
-        symm_mem.barrier()
-
-    @skipIfRocm
-    @skip_if_lt_x_gpu(2)
-    def test_cuda_nvlink_connectivity_detection(self) -> None:
-        from torch._C._distributed_c10d import _detect_dma_connectivity
-
-        connectivity = _detect_dma_connectivity(DeviceType.CUDA, "nvlink")
-        self.assertEqual(connectivity.device_type, DeviceType.CUDA)
-        self.assertEqual(connectivity.connection_type, "nvlink")
-        self.assertEqual(len(connectivity.matrix), torch.cuda.device_count())
-        for row in connectivity.matrix:
-            self.assertEqual(len(row), torch.cuda.device_count())
+        symm_mem_hdl.barrier()
 
     @skipIfRocm
     @skip_if_lt_x_gpu(2)
     def test_empty_strided_p2p(self) -> None:
         self._init_process()
+        enable_symm_mem_for_group(dist.group.WORLD.group_name)
 
         alloc_args = self._get_test_alloc_args()
 
         t = torch.empty((64, 64), device=self.device)
         self.assertIsNone(_SymmetricMemory.rendezvous(t))
 
         t = _SymmetricMemory.empty_strided_p2p(*alloc_args)
-        symm_mem = _SymmetricMemory.rendezvous(t)
+        symm_mem_hdl = _SymmetricMemory.rendezvous(t)
 
         del t
-        self._verify_symmetric_memory(symm_mem)
+        self._verify_symmetric_memory(symm_mem_hdl)
         dist.destroy_process_group()
 
     @skipIfRocm
     @skip_if_lt_x_gpu(2)
     def test_empty_strided_p2p_persistent(self) -> None:
         self._init_process()
+        enable_symm_mem_for_group(dist.group.WORLD.group_name)
 
         alloc_args = self._get_test_alloc_args()
 
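The renamed `_verify_symmetric_memory` helper exercises both synchronization primitives on the handle: a directed `put_signal`/`wait_signal` handshake and the collective `barrier`. A condensed sketch of that producer/consumer protocol, assuming a two-rank job and an already-rendezvoused handle (`hdl` and `buf` are placeholder names, not part of the test):

```python
# Sketch of the producer/consumer handshake from _verify_symmetric_memory.
# hdl: a _SymmetricMemory handle after rendezvous; buf: a get_buffer() view.
def handshake(hdl, buf) -> None:
    if hdl.rank == 0:
        hdl.wait_signal(src_rank=1)  # block until rank 1 flags the data ready
        assert buf.eq(42).all()      # rank 1's write is now visible to rank 0
    else:
        buf.fill_(42)                # write through the symmetric buffer
        hdl.put_signal(dst_rank=0)   # then raise the signal for rank 0
    hdl.barrier()                    # collective: no rank proceeds until all arrive
```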
@@ -168,51 +177,47 @@ def test_empty_strided_p2p_persistent(self) -> None:
         t = _SymmetricMemory.empty_strided_p2p(*alloc_args, alloc_id=42)
         self.assertEqual(t.data_ptr(), data_ptr)
 
-        symm_mem = _SymmetricMemory.rendezvous(t)
-        self._verify_symmetric_memory(symm_mem)
+        symm_mem_hdl = _SymmetricMemory.rendezvous(t)
+        self._verify_symmetric_memory(symm_mem_hdl)
         dist.destroy_process_group()
 
     @skipIfRocm
     @skip_if_lt_x_gpu(2)
     def test_get_signal_pad(self) -> None:
         self._init_process()
 
-        t = _SymmetricMemory.empty_strided_p2p(*self._get_test_alloc_args())
-        symm_mem = _SymmetricMemory.rendezvous(t)
+        t = symm_mem.empty(1, device="cuda")
+        symm_mem_hdl = symm_mem.rendezvous(t, group=dist.group.WORLD)
         peer_rank = (self.rank + 1) % self.world_size
 
-        signal_pad = symm_mem.get_signal_pad(self.rank)
-        self.assertEqual(signal_pad.data_ptr(), symm_mem.signal_pad_ptrs[symm_mem.rank])
+        signal_pad = symm_mem_hdl.get_signal_pad(self.rank)
+        self.assertEqual(
+            signal_pad.data_ptr(), symm_mem_hdl.signal_pad_ptrs[symm_mem_hdl.rank]
+        )
 
-        signal_pad = symm_mem.get_signal_pad(peer_rank)
+        signal_pad = symm_mem_hdl.get_signal_pad(peer_rank)
         self.assertEqual(signal_pad.dtype, torch.uint32)
-        self.assertEqual(signal_pad.numel(), symm_mem.signal_pad_size // 4)
+        self.assertEqual(signal_pad.numel(), symm_mem_hdl.signal_pad_size // 4)
 
         # Only specify sizes
-        signal_pad = symm_mem.get_signal_pad(peer_rank, (8, 8))
+        signal_pad = symm_mem_hdl.get_signal_pad(peer_rank, (8, 8))
         self.assertEqual(signal_pad.dtype, torch.uint32)
         self.assertEqual(signal_pad.numel(), 64)
 
         # Only specify dtype
-        signal_pad = symm_mem.get_signal_pad(peer_rank, dtype=torch.uint64)
+        signal_pad = symm_mem_hdl.get_signal_pad(peer_rank, dtype=torch.uint64)
         self.assertEqual(signal_pad.dtype, torch.uint64)
-        self.assertEqual(signal_pad.numel(), symm_mem.signal_pad_size // 8)
+        self.assertEqual(signal_pad.numel(), symm_mem_hdl.signal_pad_size // 8)
 
         # Specify both sizes and dtype
-        signal_pad = symm_mem.get_signal_pad(peer_rank, (8, 8), dtype=torch.uint64)
+        signal_pad = symm_mem_hdl.get_signal_pad(peer_rank, (8, 8), dtype=torch.uint64)
         self.assertEqual(signal_pad.dtype, torch.uint64)
         self.assertEqual(signal_pad.numel(), 64)
 
         # Sanity check that writes to the buffer don't corrupt the signal pad
-        t = _SymmetricMemory.empty_strided_p2p(
-            (0,),
-            (0,),
-            torch.float32,
-            self.device,
-            dist.group.WORLD.group_name,
-        )
-        symm_mem = _SymmetricMemory.rendezvous(t)
-        signal_pad = symm_mem.get_signal_pad(self.rank)
+        t = symm_mem.empty(0, device="cuda")
+        symm_mem_hdl = symm_mem.rendezvous(t)
+        signal_pad = symm_mem_hdl.get_signal_pad(self.rank)
         signal_pad.fill_(42)
         t.fill_(0)
         self.assertTrue(signal_pad.eq(42).all())
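The assertions above pin down `get_signal_pad`'s sizing rules: the default view spans the whole pad as `uint32`, and `sizes`/`dtype` reinterpret the same memory rather than allocating. A sketch of those rules as a standalone check (`hdl` is a placeholder for any rendezvoused handle):

```python
# Sketch: get_signal_pad view sizing, per the assertions above.
import torch

def check_signal_pad_views(hdl) -> None:
    pad = hdl.get_signal_pad(hdl.rank)                        # whole pad as uint32
    assert pad.numel() == hdl.signal_pad_size // 4
    pad64 = hdl.get_signal_pad(hdl.rank, dtype=torch.uint64)  # same bytes, wider lanes
    assert pad64.numel() == hdl.signal_pad_size // 8
    tile = hdl.get_signal_pad(hdl.rank, (8, 8))               # caller-chosen shape
    assert tile.numel() == 64
```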
@@ -224,14 +229,12 @@ def test_get_signal_pad(self) -> None:
     def test_barrier_timeout(self) -> None:
         self._init_process()
 
-        alloc_args = self._get_test_alloc_args()
-
-        t = _SymmetricMemory.empty_strided_p2p(*alloc_args)
-        symm_mem = _SymmetricMemory.rendezvous(t)
+        t = symm_mem.empty(1, device="cuda")
+        symm_mem_hdl = _SymmetricMemory.rendezvous(t, group=dist.group.WORLD)
 
         if self.rank == 0:
             with self.assertRaises(RuntimeError):
-                symm_mem.barrier(timeout_ms=1000)
+                symm_mem_hdl.barrier(timeout_ms=1000)
                 torch.cuda.synchronize()
         else:
             torch.cuda.synchronize()
@@ -247,17 +250,15 @@ def test_barrier_timeout(self) -> None:
     def test_put_signal_timeout(self) -> None:
         self._init_process()
 
-        alloc_args = self._get_test_alloc_args()
-
-        t = _SymmetricMemory.empty_strided_p2p(*alloc_args)
-        symm_mem = _SymmetricMemory.rendezvous(t)
+        t = symm_mem.empty(1, device="cuda")
+        symm_mem_hdl = _SymmetricMemory.rendezvous(t, group=dist.group.WORLD)
 
         if self.rank == 0:
             with self.assertRaises(RuntimeError):
                 # First, put a signal into rank 1's signal pad. Since rank 1
                 # doesn't wait on this signal, the subsequent put will time out.
-                symm_mem.put_signal(dst_rank=1)
-                symm_mem.put_signal(dst_rank=1, timeout_ms=1000)
+                symm_mem_hdl.put_signal(dst_rank=1)
+                symm_mem_hdl.put_signal(dst_rank=1, timeout_ms=1000)
                 torch.cuda.synchronize()
         else:
             torch.cuda.synchronize()
@@ -273,14 +274,12 @@ def test_put_signal_timeout(self) -> None:
     def test_wait_signal_timeout(self) -> None:
         self._init_process()
 
-        alloc_args = self._get_test_alloc_args()
-
-        t = _SymmetricMemory.empty_strided_p2p(*alloc_args)
-        symm_mem = _SymmetricMemory.rendezvous(t)
+        t = symm_mem.empty(1, device="cuda")
+        symm_mem_hdl = _SymmetricMemory.rendezvous(t, group=dist.group.WORLD)
 
         if self.rank == 0:
             with self.assertRaises(RuntimeError):
-                symm_mem.wait_signal(src_rank=1, timeout_ms=1000)
+                symm_mem_hdl.wait_signal(src_rank=1, timeout_ms=1000)
                 torch.cuda.synchronize()
         else:
             torch.cuda.synchronize()
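All three timeout tests share one shape: one rank arms a timeout, the peer deliberately never completes the handshake, and the test synchronizes inside `assertRaises` because the wait is enqueued on the stream and the failure may only surface at synchronization. A condensed sketch of that pattern (two-rank job; `hdl` is a placeholder for a rendezvoused handle):

```python
# Sketch of the timeout pattern shared by the tests above (2 ranks).
import torch

def expect_wait_timeout(hdl) -> None:
    if hdl.rank == 0:
        try:
            hdl.wait_signal(src_rank=1, timeout_ms=1000)  # rank 1 never signals
            torch.cuda.synchronize()                      # timeout surfaces by here
        except RuntimeError as exc:
            print(f"wait_signal timed out as expected: {exc}")
    else:
        torch.cuda.synchronize()  # rank 1 intentionally never calls put_signal
```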
@@ -685,7 +684,6 @@ def _init_process(self):
             rank=self.rank,
             store=store,
         )
-        enable_symm_mem_for_group(dist.group.WORLD.group_name)
         torch.manual_seed(42 + self.rank)
 
     @skipIfRocm
@@ -699,18 +697,10 @@ def test_subgroup(self) -> None:
 
         world = dist.group.WORLD
         subgroup = subgroup_0 if world.rank() < world.size() // 2 else subgroup_1
-        enable_symm_mem_for_group(subgroup.group_name)
 
-        t = _SymmetricMemory.empty_strided_p2p(
-            size=(64,),
-            stride=(1,),
-            dtype=torch.float32,
-            device=self.device,
-        )
-        symm_mem_world = _SymmetricMemory.rendezvous(t, group_name=world.group_name)
-        symm_mem_subgroup = _SymmetricMemory.rendezvous(
-            t, group_name=subgroup.group_name
-        )
+        t = symm_mem.empty(64, device="cuda")
+        symm_mem_world = symm_mem.rendezvous(t, group=world)
+        symm_mem_subgroup = symm_mem.rendezvous(t, group=subgroup)
 
         self.assertEqual(symm_mem_world.world_size, world.size())
         self.assertEqual(symm_mem_world.rank, world.rank())
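Worth noting from this hunk: the same symmetric tensor can be rendezvoused repeatedly against different process groups, and each resulting handle reports group-relative rank and world_size. A sketch of that setup (assumes a 4-rank job; the `dist.new_group` calls are illustrative and, being collective, must be made by every rank):

```python
# Sketch: one tensor, two rendezvous, two group-relative handles (4 ranks).
import torch.distributed as dist
import torch.distributed._symmetric_memory as symm_mem

subgroup_0 = dist.new_group([0, 1])  # every rank creates both subgroups
subgroup_1 = dist.new_group([2, 3])
world = dist.group.WORLD
subgroup = subgroup_0 if world.rank() < 2 else subgroup_1

t = symm_mem.empty(64, device="cuda")
hdl_world = symm_mem.rendezvous(t, group=world)   # rank/world_size w.r.t. WORLD
hdl_sub = symm_mem.rendezvous(t, group=subgroup)  # rank/world_size w.r.t. subgroup
assert hdl_world.world_size == 4 and hdl_sub.world_size == 2
```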