Skip to content

Commit cd0d4ea

Browse files
committed
wip
1 parent d668f13 commit cd0d4ea

File tree

1 file changed

+8
-1
lines changed

1 file changed

+8
-1
lines changed

fbgemm_gpu/src/memory_utils/memory_utils.cu

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -157,6 +157,13 @@ inline gpuMemLocation new_mem_location_from_device(const int device_id) {
157157
return deviceLoc;
158158
}
159159

160+
inline gpuMemLocation new_mem_location_cpu() {
161+
gpuMemLocation deviceLoc;
162+
deviceLoc.type = cudaMemLocationTypeHost;
163+
deviceLoc.id = cudaCpuDeviceId;
164+
return deviceLoc;
165+
}
166+
160167
} // namespace
161168

162169
Tensor new_managed_tensor(
@@ -177,7 +184,7 @@ Tensor new_managed_tensor(
177184
#if CUDART_VERSION >= 13000
178185
// Starting with CUDA 13, the deviceId arg (int) is replaced with
179186
// cudaMemLocation (struct)
180-
new_mem_location_from_device(cudaCpuDeviceId)
187+
new_mem_location_cpu()
181188
#else
182189
cudaCpuDeviceId
183190
#endif

0 commit comments

Comments
 (0)