@@ -144,6 +144,19 @@ std::tuple<void*, size_t> adjust_to_page_boundaries(void* ptr, size_t size) {
144144 return std::make_tuple ((void *)raw_ptr_adjusted, (size_t )size_adjusted);
145145}
146146
147+ #ifdef USE_ROCM
148+ using gpuMemLocation = hipMemLocation;
149+ #else
150+ using gpuMemLocation = cudaMemLocation;
151+ #endif
152+
153+ inline gpuMemLocation new_mem_location_from_device (const int device_id) {
154+ gpuMemLocation deviceLoc;
155+ deviceLoc.type = cudaMemLocationTypeDevice;
156+ deviceLoc.id = device_id;
157+ return deviceLoc;
158+ }
159+
147160} // namespace
148161
149162Tensor new_managed_tensor (
@@ -158,11 +171,31 @@ Tensor new_managed_tensor(
158171
159172 // Set preferred memory location to host memory
160173 AT_CUDA_CHECK (cudaMemAdvise (
161- ptr, size_bytes, cudaMemAdviseSetPreferredLocation, cudaCpuDeviceId));
174+ ptr,
175+ size_bytes,
176+ cudaMemAdviseSetPreferredLocation,
177+ #if CUDA_VERSION >= 13000
178+ // Starting with CUDA 13, the deviceId arg (int) is replaced with
179+ // cudaMemLocation (struct)
180+ new_mem_location_from_device (cudaCpuDeviceId)
181+ #else
182+ cudaCpuDeviceId
183+ #endif
184+ ));
185+
162186 // User hints with "accessed by": GPU will establish direct mapping of data
163187 // in CPU memory, no page faults will be generated
164188 AT_CUDA_CHECK (cudaMemAdvise (
165- ptr, size_bytes, cudaMemAdviseSetAccessedBy, at::cuda::current_device ()));
189+ ptr,
190+ size_bytes,
191+ cudaMemAdviseSetAccessedBy,
192+ #if CUDA_VERSION >= 13000
193+ new_mem_location_from_device (at::cuda::current_device ())
194+ #else
195+ at::cuda::current_device ()
196+ #endif
197+ ));
198+
166199 C10_CUDA_KERNEL_LAUNCH_CHECK ();
167200
168201 // Work around fork issue - see uvm_mem_advice_dont_fork for details
@@ -353,7 +386,12 @@ void uvm_cuda_mem_advise(const Tensor& t, int64_t cuda_memory_advise) {
353386 ptr,
354387 size_bytes,
355388 static_cast <enum cudaMemoryAdvise>(cuda_memory_advise),
356- hint_device));
389+ #if CUDA_VERSION >= 13000
390+ new_mem_location_from_device (hint_device)
391+ #else
392+ hint_device
393+ #endif
394+ ));
357395 return ;
358396}
359397
@@ -379,7 +417,18 @@ void uvm_cuda_mem_prefetch_async(
379417
380418 auto stream = at::cuda::getCurrentCUDAStream ();
381419
382- AT_CUDA_CHECK (cudaMemPrefetchAsync (ptr, size_bytes, prefetch_device, stream));
420+ AT_CUDA_CHECK (cudaMemPrefetchAsync (
421+ ptr,
422+ size_bytes,
423+ #if CUDA_VERSION >= 13000
424+ new_mem_location_from_device (prefetch_device),
425+ // Flags argument needs to be set to zero for now, see:
426+ // https://docs.nvidia.com/cuda/archive/13.0.0/cuda-runtime-api/group__CUDART__MEMORY.html
427+ 0 ,
428+ #else
429+ prefetch_device,
430+ #endif
431+ stream));
383432
384433 return ;
385434}
0 commit comments