@@ -59,6 +59,7 @@ typedef struct shmStruct_st {
5959 size_t nprocesses;
6060 int barrier;
6161 int sense;
62+ cudaMemAllocationHandleType handleType;
6263 int devices[MAX_DEVICES];
6364 cudaMemPoolPtrExportData exportPtrData[MAX_DEVICES];
6465} shmStruct;
@@ -126,7 +127,7 @@ static void childProcess(int id) {
126127
127128 std::vector<cudaMemPool_t> pools (shm->nprocesses );
128129
129- cudaMemAllocationHandleType handleType = cudaMemHandleTypePosixFileDescriptor ;
130+ cudaMemAllocationHandleType handleType = shm-> handleType ;
130131
131132 // Import mem pools from all the devices created in the master
132133 // process using shareable handles received via socket
@@ -239,6 +240,7 @@ static void parentProcess(char *app) {
239240 volatile shmStruct *shm = NULL ;
240241 std::vector<void *> ptrs;
241242 std::vector<Process> processes;
243+ cudaMemAllocationHandleType handleType = cudaMemHandleTypeNone;
242244
243245 checkCudaErrors (cudaGetDeviceCount (&devCount));
244246 std::vector<CUdevice> devices (devCount);
@@ -270,22 +272,32 @@ static void parentProcess(char *app) {
270272 printf (" Device %d does not support cuda memory pools, skipping...\n " , i);
271273 continue ;
272274 }
273- int deviceSupportsIpcHandle = 0 ;
274- #if defined(__linux__)
275- checkCudaErrors (cuDeviceGetAttribute (
276- &deviceSupportsIpcHandle,
277- CU_DEVICE_ATTRIBUTE_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR_SUPPORTED,
278- devices[i]));
279- #else
280- cuDeviceGetAttribute (&deviceSupportsIpcHandle,
281- CU_DEVICE_ATTRIBUTE_HANDLE_TYPE_WIN32_HANDLE_SUPPORTED,
282- devices[i]);
283- #endif
284-
285- if (!deviceSupportsIpcHandle) {
286- printf (" Device %d does not support CUDA IPC Handle, skipping...\n " , i);
275+ int supportedHandleTypes = 0 ;
276+ checkCudaErrors (cudaDeviceGetAttribute (&supportedHandleTypes,
277+ cudaDevAttrMemoryPoolSupportedHandleTypes, i));
278+ if (supportedHandleTypes == 0 ) {
279+ printf (" Device %d does not support Memory pool based IPC, skipping...\n " , i);
287280 continue ;
288281 }
282+
283+ if (handleType == cudaMemHandleTypeNone) {
284+ if (supportedHandleTypes & cudaMemHandleTypePosixFileDescriptor) {
285+ handleType = cudaMemHandleTypePosixFileDescriptor;
286+ }
287+ else if (supportedHandleTypes & cudaMemHandleTypeWin32) {
288+ handleType = cudaMemHandleTypeWin32;
289+ }
290+ else {
291+ printf (" Device %d does not support any supported handle types, skipping...\n " , i);
292+ continue ;
293+ }
294+ }
295+ else {
296+ if ((supportedHandleTypes & handleType) != handleType) {
297+ printf (" Mixed handle types are not supported, waiving test\n " );
298+ exit (EXIT_WAIVED);
299+ }
300+ }
289301 // This sample requires two processes accessing each device, so we need
290302 // to ensure exclusive or prohibited mode is not set
291303 if (prop.computeMode != cudaComputeModeDefault) {
@@ -337,6 +349,11 @@ static void parentProcess(char *app) {
337349 exit (EXIT_WAIVED);
338350 }
339351
352+ if (handleType == cudaMemHandleTypeNone) {
353+ printf (" No supported handle types found, waiving test\n " );
354+ exit (EXIT_WAIVED);
355+ }
356+
340357 std::vector<ShareableHandle> shareableHandles (shm->nprocesses );
341358 std::vector<cudaStream_t> streams (shm->nprocesses );
342359 std::vector<cudaMemPool_t> pools (shm->nprocesses );
@@ -352,16 +369,14 @@ static void parentProcess(char *app) {
352369 cudaMemPoolProps poolProps;
353370 memset (&poolProps, 0 , sizeof (cudaMemPoolProps));
354371 poolProps.allocType = cudaMemAllocationTypePinned;
355- poolProps.handleTypes = cudaMemHandleTypePosixFileDescriptor ;
372+ poolProps.handleTypes = handleType ;
356373
357374 poolProps.location .type = cudaMemLocationTypeDevice;
358375 poolProps.location .id = shm->devices [i];
359376
360377 checkCudaErrors (cudaMemPoolCreate (&pools[i], &poolProps));
361378
362379 // Query the shareable handle for the pool
363- cudaMemAllocationHandleType handleType =
364- cudaMemHandleTypePosixFileDescriptor;
365380 // Allocate memory in a stream from the pool just created
366381 checkCudaErrors (cudaMallocAsync (&ptr, DATA_SIZE, pools[i], streams[i]));
367382
@@ -378,6 +393,8 @@ static void parentProcess(char *app) {
378393 ptrs.push_back (ptr);
379394 }
380395
396+ shm->handleType = handleType;
397+
381398 // Launch the child processes!
382399 for (i = 0 ; i < shm->nprocesses ; i++) {
383400 char devIdx[10 ];
@@ -430,7 +447,7 @@ static void parentProcess(char *app) {
430447int main (int argc, char **argv) {
431448#if defined(__arm__) || defined(__aarch64__) || defined(WIN32) || \
432449 defined (_WIN32) || defined (WIN64) || defined (_WIN64)
433- printf (" Not supported on ARM\n " );
450+ printf (" Not supported on ARM or Windows \n " );
434451 return EXIT_WAIVED;
435452#else
436453 if (argc == 1 ) {
0 commit comments