Skip to content

Commit 514cc27

Browse files
committed
Add cudaDeviceProp and compute capability numbers to the CUDA handle
1 parent 1fe5d07 commit 514cc27

File tree

3 files changed

+25
-6
lines changed

3 files changed

+25
-6
lines changed

src/devices/cuda/cuda_handle.cc

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,25 @@ infiniopStatus_t createCudaHandle(CudaHandle_t *handle_ptr, int device_id) {
2323
checkCudnnError(cudnnCreate(&cudnn_handle));
2424
cudnn_pool->push(std::move(cudnn_handle));
2525

26-
*handle_ptr = new CudaContext{DevNvGpu, device_id, std::move(pool), std::move(cudnn_pool)};
26+
// query the CUDA device properties
27+
cudaDeviceProp prop;
28+
cudaGetDeviceProperties(&prop, device_id);
29+
30+
// set device compute capability numbers
31+
int capability_major;
32+
int capability_minor;
33+
cudaDeviceGetAttribute(&capability_major, cudaDevAttrComputeCapabilityMajor, device_id);
34+
cudaDeviceGetAttribute(&capability_minor, cudaDevAttrComputeCapabilityMinor, device_id);
35+
36+
*handle_ptr = new CudaContext{
37+
DevNvGpu,
38+
device_id,
39+
std::move(pool),
40+
std::move(cudnn_pool),
41+
std::move(prop),
42+
capability_major,
43+
capability_minor,
44+
};
2745

2846
return STATUS_SUCCESS;
2947
}

src/devices/cuda/cuda_handle.h

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,9 @@ struct CudaContext {
1515
int device_id;
1616
std::shared_ptr<Pool<cublasHandle_t>> cublas_handles_t;
1717
std::shared_ptr<Pool<cudnnHandle_t>> cudnn_handles_t;
18+
cudaDeviceProp prop;
19+
int compute_capability_major;
20+
int compute_capability_minor;
1821
};
1922
typedef struct CudaContext *CudaHandle_t;
2023

@@ -35,12 +38,13 @@ void use_cublas(std::shared_ptr<Pool<cublasHandle_t>> cublas_handles_t, int devi
3538
}
3639

3740
template<typename T>
38-
cudnnStatus_t use_cudnn(std::shared_ptr<Pool<cudnnHandle_t>> cudnn_handles_t, int device_id, T const &f) {
41+
cudnnStatus_t use_cudnn(std::shared_ptr<Pool<cudnnHandle_t>> cudnn_handles_t, int device_id, cudaStream_t stream, T const &f) {
3942
auto handle = cudnn_handles_t->pop();
4043
if (!handle) {
4144
cudaSetDevice(device_id);
4245
cudnnCreate(&(*handle));
4346
}
47+
cudnnSetStream(*handle, stream);
4448
cudnnStatus_t status = f(*handle);
4549
cudnn_handles_t->push(std::move(*handle));
4650
return status;

src/ops/expand/cuda/expand.cc

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -22,9 +22,6 @@ infiniopStatus_t cudaCreateExpandDescriptor(CudaHandle_t handle,
2222
x_strides[i] = (i < ndim - x->ndim || y->shape[i] != x->shape[i + x->ndim - ndim]) ? 0 : x->strides[i + x->ndim - ndim];
2323
}
2424

25-
cudaDeviceProp prop;
26-
cudaGetDeviceProperties(&prop, handle->device_id);
27-
2825
int64_t *x_strides_d, *y_strides_d;
2926
char *strides_and_shape_d;
3027
checkCudaErrorWithCode(cudaMalloc(&strides_and_shape_d, ndim * (2 * sizeof(int64_t) + sizeof(uint64_t))), STATUS_MEMORY_NOT_ALLOCATED);
@@ -38,7 +35,7 @@ infiniopStatus_t cudaCreateExpandDescriptor(CudaHandle_t handle,
3835
handle->device_id,
3936
ndim,
4037
y_data_size,
41-
static_cast<uint64_t>(prop.maxGridSize[0]),
38+
static_cast<uint64_t>(handle->prop.maxGridSize[0]),
4239
strides_and_shape_d,
4340
};
4441

0 commit comments

Comments
 (0)