Skip to content
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 49 additions & 4 deletions crates/cudnn/src/error.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use std::{error::Error, ffi::CStr, fmt::Display};
use std::{error::Error, ffi::CStr, fmt::Display};

/// Enum encapsulating function status returns. All cuDNN library functions return their status.
///
Expand Down Expand Up @@ -52,6 +52,18 @@ pub enum CudnnError {
RuntimeFpOverflow,
#[cfg(not(cudnn9))]
VersionMismatch,
/// A version mismatch was detected between cuDNN sub-libraries (cuDNN 9+).
#[cfg(cudnn9)]
SublibraryVersionMismatch,
/// A serialization version mismatch was detected (cuDNN 9+).
#[cfg(cudnn9)]
SerializationVersionMismatch,
/// A deprecated API was called (cuDNN 9+).
#[cfg(cudnn9)]
Deprecated,
/// A required sub-library could not be loaded (cuDNN 9+).
#[cfg(cudnn9)]
SublibraryLoadingFailed,
}

impl CudnnError {
Expand All @@ -78,6 +90,14 @@ impl CudnnError {
CudnnError::RuntimeFpOverflow => CUDNN_STATUS_RUNTIME_FP_OVERFLOW,
#[cfg(not(cudnn9))]
CudnnError::VersionMismatch => CUDNN_STATUS_VERSION_MISMATCH,
#[cfg(cudnn9)]
CudnnError::SublibraryVersionMismatch => CUDNN_STATUS_SUBLIBRARY_VERSION_MISMATCH,
#[cfg(cudnn9)]
CudnnError::SerializationVersionMismatch => CUDNN_STATUS_SERIALIZATION_VERSION_MISMATCH,
#[cfg(cudnn9)]
CudnnError::Deprecated => CUDNN_STATUS_DEPRECATED,
#[cfg(cudnn9)]
CudnnError::SublibraryLoadingFailed => CUDNN_STATUS_SUBLIBRARY_LOADING_FAILED,
}
}
}
Expand Down Expand Up @@ -124,8 +144,33 @@ impl IntoResult for cudnn_sys::cudnnStatus_t {
CUDNN_STATUS_RUNTIME_FP_OVERFLOW => CudnnError::RuntimeFpOverflow,
#[cfg(not(cudnn9))]
CUDNN_STATUS_VERSION_MISMATCH => CudnnError::VersionMismatch,
// TODO(adamcavendish): implement cuDNN 9 error codes.
_ => todo!(),
// cuDNN 9 introduced a hierarchical status code system. Specific sub-codes
// (e.g. CUDNN_STATUS_BAD_PARAM_NULL_POINTER = 2002) are mapped to their
// parent category variant for backwards-compatible error handling.
#[cfg(cudnn9)]
CUDNN_STATUS_SUBLIBRARY_VERSION_MISMATCH => CudnnError::SublibraryVersionMismatch,
#[cfg(cudnn9)]
CUDNN_STATUS_SERIALIZATION_VERSION_MISMATCH => CudnnError::SerializationVersionMismatch,
#[cfg(cudnn9)]
CUDNN_STATUS_DEPRECATED => CudnnError::Deprecated,
#[cfg(cudnn9)]
CUDNN_STATUS_SUBLIBRARY_LOADING_FAILED => CudnnError::SublibraryLoadingFailed,
#[cfg(cudnn9)]
s => {
use cudnn_sys::cudnnStatus_t::*;
// Map cuDNN 9 hierarchical sub-codes to their parent category variant.
// Sub-codes share the same thousands digit as their parent:
// 2xxx -> BAD_PARAM, 3xxx -> NOT_SUPPORTED,
// 4xxx -> INTERNAL_ERROR, 5xxx -> EXECUTION_FAILED
let category = (s as u32) / 1000 * 1000;
match category {
c if c == CUDNN_STATUS_BAD_PARAM as u32 => CudnnError::BadParam,
c if c == CUDNN_STATUS_NOT_SUPPORTED as u32 => CudnnError::NotSupported,
c if c == CUDNN_STATUS_INTERNAL_ERROR as u32 => CudnnError::InternalError,
c if c == CUDNN_STATUS_EXECUTION_FAILED as u32 => CudnnError::ExecutionFailed,
_ => CudnnError::InternalError,
}
}
})
}
}
}
Loading