Skip to content

Commit 6298a03

Browse files
committed
Remove direct dependency on libcuda
Pull Request resolved: #1926 Dynamically load any cuda driver functions so that monarch library can be loaded a machine that doesn't have a gpu even if the library is built with one. Differential Revision: [D87380631](https://our.internmc.facebook.com/intern/diff/D87380631/) ghstack-source-id: 324224674
1 parent 2edc5dc commit 6298a03

File tree

11 files changed

+427
-97
lines changed

11 files changed

+427
-97
lines changed

cuda-sys/build.rs

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -31,11 +31,11 @@ fn main() {
3131
.clang_arg("c++")
3232
.clang_arg("-std=gnu++20")
3333
.parse_callbacks(Box::new(bindgen::CargoCallbacks::new()))
34-
// Allow the specified functions and types
35-
.allowlist_function("cu.*")
36-
.allowlist_function("CU.*")
37-
.allowlist_type("cu.*")
38-
.allowlist_type("CU.*")
34+
// Allow the specified functions and types (CUDA Runtime API only)
35+
.allowlist_function("cuda.*")
36+
.allowlist_function("CUDA.*")
37+
.allowlist_type("cuda.*")
38+
.allowlist_type("CUDA.*")
3939
// Use newtype enum style
4040
.default_enum_style(bindgen::EnumVariation::NewType {
4141
is_bitfield: false,
@@ -78,7 +78,6 @@ fn main() {
7878
}
7979
};
8080
println!("cargo:rustc-link-search=native={}", cuda_lib_dir);
81-
println!("cargo:rustc-link-lib=cuda");
8281
println!("cargo:rustc-link-lib=cudart");
8382

8483
// Generate bindings - fail fast if this doesn't work

cuda-sys/src/lib.rs

Lines changed: 0 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -34,20 +34,3 @@ mod inner {
3434
}
3535

3636
pub use inner::*;
37-
38-
#[cfg(test)]
39-
mod tests {
40-
use std::mem::MaybeUninit;
41-
42-
use super::*;
43-
44-
#[test]
45-
fn sanity() {
46-
// SAFETY: testing bindings
47-
unsafe {
48-
let mut version = MaybeUninit::<i32>::uninit();
49-
let result = cuDriverGetVersion(version.as_mut_ptr());
50-
assert_eq!(result, cudaError_enum(0));
51-
}
52-
}
53-
}

cuda-sys/src/wrapper.h

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,5 +8,4 @@
88

99
#pragma once
1010

11-
#include <cuda.h>
1211
#include <cuda_runtime.h>

monarch_rdma/src/macros.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,9 +9,9 @@
99
#[macro_export]
1010
macro_rules! cu_check {
1111
($result:expr) => {
12-
if $result != cuda_sys::CUresult::CUDA_SUCCESS {
12+
if $result != rdmaxcel_sys::CUDA_SUCCESS {
1313
let mut error_string: *const std::os::raw::c_char = std::ptr::null();
14-
cuda_sys::cuGetErrorString($result, &mut error_string);
14+
rdmaxcel_sys::rdmaxcel_cuGetErrorString($result, &mut error_string);
1515
panic!(
1616
"cuda failure {}:{} {:?} '{}'",
1717
file!(),

monarch_rdma/src/rdma_manager_actor.rs

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -366,13 +366,13 @@ impl RdmaManagerActor {
366366
) -> Result<(RdmaMemoryRegionView, String), anyhow::Error> {
367367
unsafe {
368368
let mut mem_type: i32 = 0;
369-
let ptr = addr as cuda_sys::CUdeviceptr;
370-
let err = cuda_sys::cuPointerGetAttribute(
369+
let ptr = addr as rdmaxcel_sys::CUdeviceptr;
370+
let err = rdmaxcel_sys::rdmaxcel_cuPointerGetAttribute(
371371
&mut mem_type as *mut _ as *mut std::ffi::c_void,
372-
cuda_sys::CUpointer_attribute_enum::CU_POINTER_ATTRIBUTE_MEMORY_TYPE,
372+
rdmaxcel_sys::CU_POINTER_ATTRIBUTE_MEMORY_TYPE,
373373
ptr,
374374
);
375-
let is_cuda = err == cuda_sys::CUresult::CUDA_SUCCESS;
375+
let is_cuda = err == rdmaxcel_sys::CUDA_SUCCESS;
376376

377377
let mut selected_rdma_device = None;
378378

@@ -457,11 +457,11 @@ impl RdmaManagerActor {
457457
mrv = maybe_mrv.unwrap();
458458
} else if is_cuda {
459459
let mut fd: i32 = -1;
460-
cuda_sys::cuMemGetHandleForAddressRange(
461-
&mut fd as *mut i32 as *mut std::ffi::c_void,
462-
addr as cuda_sys::CUdeviceptr,
460+
rdmaxcel_sys::rdmaxcel_cuMemGetHandleForAddressRange(
461+
&mut fd,
462+
addr as rdmaxcel_sys::CUdeviceptr,
463463
size,
464-
cuda_sys::CUmemRangeHandleType::CU_MEM_RANGE_HANDLE_TYPE_DMA_BUF_FD,
464+
rdmaxcel_sys::CU_MEM_RANGE_HANDLE_TYPE_DMA_BUF_FD,
465465
0,
466466
);
467467
mr = rdmaxcel_sys::ibv_reg_dmabuf_mr(domain_pd, 0, size, 0, fd, access.0 as i32);

monarch_rdma/src/test_utils.rs

Lines changed: 62 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -46,25 +46,25 @@ pub fn is_cuda_available() -> bool {
4646
fn check_cuda_available() -> bool {
4747
unsafe {
4848
// Try to initialize CUDA
49-
let result = cuda_sys::cuInit(0);
49+
let result = rdmaxcel_sys::rdmaxcel_cuInit(0);
5050

51-
if result != cuda_sys::CUresult::CUDA_SUCCESS {
51+
if result != rdmaxcel_sys::CUDA_SUCCESS {
5252
return false;
5353
}
5454

5555
// Check if there are any CUDA devices
5656
let mut device_count: i32 = 0;
57-
let count_result = cuda_sys::cuDeviceGetCount(&mut device_count);
57+
let count_result = rdmaxcel_sys::rdmaxcel_cuDeviceGetCount(&mut device_count);
5858

59-
if count_result != cuda_sys::CUresult::CUDA_SUCCESS || device_count <= 0 {
59+
if count_result != rdmaxcel_sys::CUDA_SUCCESS || device_count <= 0 {
6060
return false;
6161
}
6262

6363
// Try to get the first device to verify it's actually accessible
64-
let mut device: cuda_sys::CUdevice = std::mem::zeroed();
65-
let device_result = cuda_sys::cuDeviceGet(&mut device, 0);
64+
let mut device: rdmaxcel_sys::CUdevice = std::mem::zeroed();
65+
let device_result = rdmaxcel_sys::rdmaxcel_cuDeviceGet(&mut device, 0);
6666

67-
if device_result != cuda_sys::CUresult::CUDA_SUCCESS {
67+
if device_result != rdmaxcel_sys::CUDA_SUCCESS {
6868
return false;
6969
}
7070

@@ -270,8 +270,8 @@ pub mod test_utils {
270270
pub actor_2: ActorRef<RdmaManagerActor>,
271271
pub rdma_handle_1: RdmaBuffer,
272272
pub rdma_handle_2: RdmaBuffer,
273-
cuda_context_1: Option<cuda_sys::CUcontext>,
274-
cuda_context_2: Option<cuda_sys::CUcontext>,
273+
cuda_context_1: Option<rdmaxcel_sys::CUcontext>,
274+
cuda_context_2: Option<rdmaxcel_sys::CUcontext>,
275275
}
276276

277277
#[derive(Debug, Clone)]
@@ -375,46 +375,53 @@ pub mod test_utils {
375375
}
376376
// CUDA case
377377
unsafe {
378-
cu_check!(cuda_sys::cuInit(0));
378+
cu_check!(rdmaxcel_sys::rdmaxcel_cuInit(0));
379379

380-
let mut dptr: cuda_sys::CUdeviceptr = std::mem::zeroed();
381-
let mut handle: cuda_sys::CUmemGenericAllocationHandle = std::mem::zeroed();
380+
let mut dptr: rdmaxcel_sys::CUdeviceptr = std::mem::zeroed();
381+
let mut handle: rdmaxcel_sys::CUmemGenericAllocationHandle = std::mem::zeroed();
382382

383-
let mut device: cuda_sys::CUdevice = std::mem::zeroed();
384-
cu_check!(cuda_sys::cuDeviceGet(&mut device, accel.1 as i32));
383+
let mut device: rdmaxcel_sys::CUdevice = std::mem::zeroed();
384+
cu_check!(rdmaxcel_sys::rdmaxcel_cuDeviceGet(
385+
&mut device,
386+
accel.1 as i32
387+
));
385388

386-
let mut context: cuda_sys::CUcontext = std::mem::zeroed();
387-
cu_check!(cuda_sys::cuCtxCreate_v2(&mut context, 0, accel.1 as i32));
388-
cu_check!(cuda_sys::cuCtxSetCurrent(context));
389+
let mut context: rdmaxcel_sys::CUcontext = std::mem::zeroed();
390+
cu_check!(rdmaxcel_sys::rdmaxcel_cuCtxCreate_v2(
391+
&mut context,
392+
0,
393+
accel.1 as i32
394+
));
395+
cu_check!(rdmaxcel_sys::rdmaxcel_cuCtxSetCurrent(context));
389396

390397
let mut granularity: usize = 0;
391-
let mut prop: cuda_sys::CUmemAllocationProp = std::mem::zeroed();
392-
prop.type_ = cuda_sys::CUmemAllocationType::CU_MEM_ALLOCATION_TYPE_PINNED;
393-
prop.location.type_ = cuda_sys::CUmemLocationType::CU_MEM_LOCATION_TYPE_DEVICE;
398+
let mut prop: rdmaxcel_sys::CUmemAllocationProp = std::mem::zeroed();
399+
prop.type_ = rdmaxcel_sys::CU_MEM_ALLOCATION_TYPE_PINNED;
400+
prop.location.type_ = rdmaxcel_sys::CU_MEM_LOCATION_TYPE_DEVICE;
394401
prop.location.id = device;
395402
prop.allocFlags.gpuDirectRDMACapable = 1;
396403
prop.requestedHandleTypes =
397-
cuda_sys::CUmemAllocationHandleType::CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR;
404+
rdmaxcel_sys::CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR;
398405

399-
cu_check!(cuda_sys::cuMemGetAllocationGranularity(
406+
cu_check!(rdmaxcel_sys::rdmaxcel_cuMemGetAllocationGranularity(
400407
&mut granularity as *mut usize,
401408
&prop,
402-
cuda_sys::CUmemAllocationGranularity_flags::CU_MEM_ALLOC_GRANULARITY_MINIMUM,
409+
rdmaxcel_sys::CU_MEM_ALLOC_GRANULARITY_MINIMUM,
403410
));
404411

405412
// ensure our size is aligned
406413
let /*mut*/ padded_size: usize = ((buffer_size - 1) / granularity + 1) * granularity;
407414
assert!(padded_size == buffer_size);
408415

409-
cu_check!(cuda_sys::cuMemCreate(
410-
&mut handle as *mut cuda_sys::CUmemGenericAllocationHandle,
416+
cu_check!(rdmaxcel_sys::rdmaxcel_cuMemCreate(
417+
&mut handle as *mut rdmaxcel_sys::CUmemGenericAllocationHandle,
411418
padded_size,
412419
&prop,
413420
0
414421
));
415422
// reserve and map the memory
416-
cu_check!(cuda_sys::cuMemAddressReserve(
417-
&mut dptr as *mut cuda_sys::CUdeviceptr,
423+
cu_check!(rdmaxcel_sys::rdmaxcel_cuMemAddressReserve(
424+
&mut dptr as *mut rdmaxcel_sys::CUdeviceptr,
418425
padded_size,
419426
0,
420427
0,
@@ -424,25 +431,28 @@ pub mod test_utils {
424431
assert!(padded_size.is_multiple_of(granularity));
425432

426433
// fails if a add cu_check macro; but passes if we don't
427-
let err = cuda_sys::cuMemMap(
428-
dptr as cuda_sys::CUdeviceptr,
434+
let err = rdmaxcel_sys::rdmaxcel_cuMemMap(
435+
dptr as rdmaxcel_sys::CUdeviceptr,
429436
padded_size,
430437
0,
431-
handle as cuda_sys::CUmemGenericAllocationHandle,
438+
handle as rdmaxcel_sys::CUmemGenericAllocationHandle,
432439
0,
433440
);
434-
if err != cuda_sys::CUresult::CUDA_SUCCESS {
441+
if err != rdmaxcel_sys::CUDA_SUCCESS {
435442
panic!("failed reserving and mapping memory {:?}", err);
436443
}
437444

438445
// set access
439-
let mut access_desc: cuda_sys::CUmemAccessDesc = std::mem::zeroed();
440-
access_desc.location.type_ =
441-
cuda_sys::CUmemLocationType::CU_MEM_LOCATION_TYPE_DEVICE;
446+
let mut access_desc: rdmaxcel_sys::CUmemAccessDesc = std::mem::zeroed();
447+
access_desc.location.type_ = rdmaxcel_sys::CU_MEM_LOCATION_TYPE_DEVICE;
442448
access_desc.location.id = device;
443-
access_desc.flags =
444-
cuda_sys::CUmemAccess_flags::CU_MEM_ACCESS_FLAGS_PROT_READWRITE;
445-
cu_check!(cuda_sys::cuMemSetAccess(dptr, padded_size, &access_desc, 1));
449+
access_desc.flags = rdmaxcel_sys::CU_MEM_ACCESS_FLAGS_PROT_READWRITE;
450+
cu_check!(rdmaxcel_sys::rdmaxcel_cuMemSetAccess(
451+
dptr,
452+
padded_size,
453+
&access_desc,
454+
1
455+
));
446456
buf_vec.push(Buffer {
447457
ptr: dptr,
448458
len: padded_size,
@@ -460,11 +470,11 @@ pub mod test_utils {
460470
}
461471
unsafe {
462472
// Use the CUDA context that was created for the first buffer
463-
cu_check!(cuda_sys::cuCtxSetCurrent(
473+
cu_check!(rdmaxcel_sys::rdmaxcel_cuCtxSetCurrent(
464474
cuda_contexts[0].expect("No CUDA context found")
465475
));
466476

467-
cu_check!(cuda_sys::cuMemcpyHtoD_v2(
477+
cu_check!(rdmaxcel_sys::rdmaxcel_cuMemcpyHtoD_v2(
468478
buf_vec[0].ptr,
469479
temp_buffer.as_ptr() as *const std::ffi::c_void,
470480
temp_buffer.len()
@@ -514,30 +524,30 @@ pub mod test_utils {
514524
.await?;
515525
if self.cuda_context_1.is_some() {
516526
unsafe {
517-
cu_check!(cuda_sys::cuCtxSetCurrent(
527+
cu_check!(rdmaxcel_sys::rdmaxcel_cuCtxSetCurrent(
518528
self.cuda_context_1.expect("No CUDA context found")
519529
));
520-
cu_check!(cuda_sys::cuMemUnmap(
521-
self.buffer_1.ptr as cuda_sys::CUdeviceptr,
530+
cu_check!(rdmaxcel_sys::rdmaxcel_cuMemUnmap(
531+
self.buffer_1.ptr as rdmaxcel_sys::CUdeviceptr,
522532
self.buffer_1.len
523533
));
524-
cu_check!(cuda_sys::cuMemAddressFree(
525-
self.buffer_1.ptr as cuda_sys::CUdeviceptr,
534+
cu_check!(rdmaxcel_sys::rdmaxcel_cuMemAddressFree(
535+
self.buffer_1.ptr as rdmaxcel_sys::CUdeviceptr,
526536
self.buffer_1.len
527537
));
528538
}
529539
}
530540
if self.cuda_context_2.is_some() {
531541
unsafe {
532-
cu_check!(cuda_sys::cuCtxSetCurrent(
542+
cu_check!(rdmaxcel_sys::rdmaxcel_cuCtxSetCurrent(
533543
self.cuda_context_2.expect("No CUDA context found")
534544
));
535-
cu_check!(cuda_sys::cuMemUnmap(
536-
self.buffer_2.ptr as cuda_sys::CUdeviceptr,
545+
cu_check!(rdmaxcel_sys::rdmaxcel_cuMemUnmap(
546+
self.buffer_2.ptr as rdmaxcel_sys::CUdeviceptr,
537547
self.buffer_2.len
538548
));
539-
cu_check!(cuda_sys::cuMemAddressFree(
540-
self.buffer_2.ptr as cuda_sys::CUdeviceptr,
549+
cu_check!(rdmaxcel_sys::rdmaxcel_cuMemAddressFree(
550+
self.buffer_2.ptr as rdmaxcel_sys::CUdeviceptr,
541551
self.buffer_2.len
542552
));
543553
}
@@ -579,12 +589,12 @@ pub mod test_utils {
579589
let mut temp_buffer = vec![0u8; size].into_boxed_slice();
580590
// SAFETY: The buffer is allocated with the correct size and the pointer is valid.
581591
unsafe {
582-
cu_check!(cuda_sys::cuCtxSetCurrent(
592+
cu_check!(rdmaxcel_sys::rdmaxcel_cuCtxSetCurrent(
583593
cuda_context.expect("No CUDA context found")
584594
));
585-
cu_check!(cuda_sys::cuMemcpyDtoH_v2(
595+
cu_check!(rdmaxcel_sys::rdmaxcel_cuMemcpyDtoH_v2(
586596
temp_buffer.as_mut_ptr() as *mut std::ffi::c_void,
587-
virtual_addr as cuda_sys::CUdeviceptr,
597+
virtual_addr as rdmaxcel_sys::CUdeviceptr,
588598
size
589599
));
590600
}

rdmaxcel-sys/build.rs

Lines changed: 23 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,10 +21,15 @@ fn main() {
2121
// Link against the mlx5 library
2222
println!("cargo:rustc-link-lib=mlx5");
2323

24+
// Link against dl for dynamic loading
25+
println!("cargo:rustc-link-lib=dl");
26+
2427
// Tell cargo to invalidate the built crate whenever the wrapper changes
2528
println!("cargo:rerun-if-changed=src/rdmaxcel.h");
2629
println!("cargo:rerun-if-changed=src/rdmaxcel.c");
2730
println!("cargo:rerun-if-changed=src/rdmaxcel.cpp");
31+
println!("cargo:rerun-if-changed=src/driver_api.h");
32+
println!("cargo:rerun-if-changed=src/driver_api.cpp");
2833

2934
// Validate CUDA installation and get CUDA home path
3035
let cuda_home = match build_utils::validate_cuda_installation() {
@@ -88,6 +93,10 @@ fn main() {
8893
.allowlist_function("pt_cuda_allocator_compatibility")
8994
.allowlist_function("register_segments")
9095
.allowlist_function("deregister_segments")
96+
.allowlist_function("rdmaxcel_cu.*")
97+
.allowlist_function("get_cuda_pci_address_from_ptr")
98+
.allowlist_function("rdmaxcel_print_device_info")
99+
.allowlist_function("rdmaxcel_error_string")
91100
.allowlist_type("ibv_.*")
92101
.allowlist_type("mlx5dv_.*")
93102
.allowlist_type("mlx5_wqe_.*")
@@ -149,7 +158,8 @@ fn main() {
149158
}
150159
};
151160
println!("cargo:rustc-link-search=native={}", cuda_lib_dir);
152-
println!("cargo:rustc-link-lib=cuda");
161+
// Note: libcuda is now loaded dynamically via dlopen in driver_api.cpp
162+
// Only link cudart (CUDA Runtime API)
153163
println!("cargo:rustc-link-lib=cudart");
154164

155165
// Link PyTorch C++ libraries for c10 symbols
@@ -213,7 +223,8 @@ fn main() {
213223

214224
// Compile the C++ source file for CUDA allocator compatibility
215225
let cpp_source_path = format!("{}/src/rdmaxcel.cpp", manifest_dir);
216-
if Path::new(&cpp_source_path).exists() {
226+
let driver_api_cpp_path = format!("{}/src/driver_api.cpp", manifest_dir);
227+
if Path::new(&cpp_source_path).exists() && Path::new(&driver_api_cpp_path).exists() {
217228
let mut libtorch_include_dirs: Vec<PathBuf> = vec![];
218229

219230
// Use the same approach as torch-sys: Python discovery first, env vars as fallback
@@ -249,6 +260,7 @@ fn main() {
249260
let mut cpp_build = cc::Build::new();
250261
cpp_build
251262
.file(&cpp_source_path)
263+
.file(&driver_api_cpp_path)
252264
.include(format!("{}/src", manifest_dir))
253265
.flag("-fPIC")
254266
.cpp(true)
@@ -270,7 +282,15 @@ fn main() {
270282

271283
cpp_build.compile("rdmaxcel_cpp");
272284
} else {
273-
panic!("C++ source file not found at {}", cpp_source_path);
285+
if !Path::new(&cpp_source_path).exists() {
286+
panic!("C++ source file not found at {}", cpp_source_path);
287+
}
288+
if !Path::new(&driver_api_cpp_path).exists() {
289+
panic!(
290+
"Driver API C++ source file not found at {}",
291+
driver_api_cpp_path
292+
);
293+
}
274294
}
275295
// Compile the CUDA source file
276296
let cuda_source_path = format!("{}/src/rdmaxcel.cu", manifest_dir);

0 commit comments

Comments
 (0)