diff --git a/include/infini_operators.h b/include/infini_operators.h index 9a5a2555..7373bcca 100644 --- a/include/infini_operators.h +++ b/include/infini_operators.h @@ -2,11 +2,12 @@ #include "ops/add/add.h" #include "ops/attention/attention.h" #include "ops/avg_pool/avg_pool.h" +#include "ops/batch_norm.h" #include "ops/causal_softmax/causal_softmax.h" -#include "ops/global_avg_pool/global_avg_pool.h" +#include "ops/conv/conv.h" #include "ops/expand/expand.h" #include "ops/gemm/gemm.h" -#include "ops/conv/conv.h" +#include "ops/global_avg_pool/global_avg_pool.h" #include "ops/matmul/matmul.h" #include "ops/max_pool/max_pool.h" #include "ops/mlp/mlp.h" diff --git a/include/ops/batch_norm/batch_norm.h b/include/ops/batch_norm/batch_norm.h new file mode 100644 index 00000000..5943e3b8 --- /dev/null +++ b/include/ops/batch_norm/batch_norm.h @@ -0,0 +1,34 @@ +#ifndef BATCH_NORM_H +#define BATCH_NORM_H + +#include "../../export.h" +#include "../../operators.h" + +typedef struct BatchNormDescriptor { + Device device; +} BatchNormDescriptor; + +typedef BatchNormDescriptor *infiniopBatchNormDescriptor_t; + +__C __export infiniopStatus_t infiniopCreateBatchNormDescriptor(infiniopHandle_t handle, + infiniopBatchNormDescriptor_t *desc_ptr, + infiniopTensorDescriptor_t y, + infiniopTensorDescriptor_t x, + infiniopTensorDescriptor_t scale, + infiniopTensorDescriptor_t b, + infiniopTensorDescriptor_t mean, + infiniopTensorDescriptor_t var, + double eps); + +__C __export infiniopStatus_t infiniopBatchNorm(infiniopBatchNormDescriptor_t desc, + void *y, + void const *x, + void const *scale, + void const *b, + void const *mean, + void const *var, + void *stream); + +__C __export infiniopStatus_t infiniopDestroyBatchNormDescriptor(infiniopBatchNormDescriptor_t desc); + +#endif diff --git a/operatorspy/tests/batch_norm.py b/operatorspy/tests/batch_norm.py new file mode 100644 index 00000000..8c43c739 --- /dev/null +++ b/operatorspy/tests/batch_norm.py @@ -0,0 +1,216 @@ +from ctypes import POINTER, Structure, c_int32, c_void_p, c_double +import ctypes +import sys +import os +import time + +sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..", ".."))) +from operatorspy import ( + open_lib, + to_tensor, + DeviceEnum, + infiniopHandle_t, + infiniopTensorDescriptor_t, + create_handle, + destroy_handle, + check_error, + device_enum_to_str, +) + +from operatorspy.tests.test_utils import get_args +from enum import Enum, auto +import torch +import ctypes +import torch.nn.functional as F + +# constant for control whether profile the pytorch and lib functions +# NOTE: need to manually add synchronization function to the lib function, +# e.g., cudaDeviceSynchronize() for CUDA +PROFILE = False +NUM_PRERUN = 10 +NUM_ITERATIONS = 1000 + +class Inplace(Enum): + OUT_OF_PLACE = auto() + INPLACE_X = auto() + + +class BatchNormDescriptor(Structure): + _fields_ = [("device", c_int32)] + + +infiniopBatchNormDescriptor_t = POINTER(BatchNormDescriptor) + + +def batch_norm(x, scale, b, mean, var, eps): + ndim = len(x.shape) + if ndim <= 1 or ndim > 5: + print("Error: Pytorch -> Unsupported tensor dimension") + return None + if PROFILE: + ans = F.batch_norm(x, mean, var, scale, b, training=False, eps=eps) + torch.cuda.synchronize() + return ans + return F.batch_norm(x, mean, var, scale, b, training=False, eps=eps) + + +# get the mean and variance of the input tensor across the batch size N and spatial dimensions +def get_mean_variance(x, dtype): + dims = tuple(range(x.ndim)) + reduction_dims = tuple(d for 
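+        # Statistics are per channel: reduce over the batch dim (0) and all spatial dims.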
d in dims if d != 1) # Exclude the channel dimension + return x.mean(dim=reduction_dims, dtype=dtype), x.var( + dim=reduction_dims, unbiased=False + ).to(dtype) + + +def test( + lib, + handle, + torch_device, + x_shape, + eps=1e-5, + tensor_dtype=torch.float16, + inplace=Inplace.OUT_OF_PLACE, +): + print( + f"Testing BatchNorm on {torch_device} with x_shape: {x_shape}, scale_shape: {x_shape[1]}, b_shape: {x_shape[1]}, mean_shape: {x_shape[1]}, var_shape: {x_shape[1]}, eps: {eps}, dtype:{tensor_dtype}, Inplace:{inplace}" + ) + num_channel = x_shape[1] + bn_dtype = tensor_dtype if tensor_dtype != torch.float16 else torch.float32 + x = torch.rand(x_shape, dtype=tensor_dtype).to(torch_device) * 10 - 2 + scale = torch.rand(num_channel, dtype=bn_dtype).to(torch_device) + b = torch.rand(num_channel, dtype=bn_dtype).to(torch_device) + mean, var = get_mean_variance(x, bn_dtype) + y = torch.zeros(x_shape, dtype=tensor_dtype).to(torch_device) if inplace == Inplace.OUT_OF_PLACE else x + + # get the pytorch answer + for i in range(NUM_PRERUN if PROFILE else 1): + ans = batch_norm(x, scale, b, mean, var, eps) + if PROFILE: + start_time = time.time() + for i in range(NUM_ITERATIONS): + _ = batch_norm(x, scale, b, mean, var, eps) + elapsed = (time.time() - start_time) / NUM_ITERATIONS + print(f"pytorch time: {elapsed :6f}") + + # get the operators' answer + x_tensor = to_tensor(x, lib) + scale_tensor = to_tensor(scale, lib) + b_tensor = to_tensor(b, lib) + mean_tensor = to_tensor(mean, lib) + var_tensor = to_tensor(var, lib) + y_tensor = to_tensor(y, lib) if inplace == Inplace.OUT_OF_PLACE else x_tensor + descriptor = infiniopBatchNormDescriptor_t() + + check_error( + lib.infiniopCreateBatchNormDescriptor( + handle, + ctypes.byref(descriptor), + y_tensor.descriptor, + x_tensor.descriptor, + scale_tensor.descriptor, + b_tensor.descriptor, + mean_tensor.descriptor, + var_tensor.descriptor, + eps, + ) + ) + + for i in range(NUM_PRERUN if PROFILE else 1): + check_error( + lib.infiniopBatchNorm( + descriptor, + y_tensor.data, + x_tensor.data, + scale_tensor.data, + b_tensor.data, + mean_tensor.data, + var_tensor.data, + None, + ) + ) + if PROFILE: + start_time = time.time() + for i in range(NUM_ITERATIONS): + check_error( + lib.infiniopBatchNorm( + descriptor, + y_tensor.data, + x_tensor.data, + scale_tensor.data, + b_tensor.data, + mean_tensor.data, + var_tensor.data, + None, + ) + ) + elapsed = (time.time() - start_time) / NUM_ITERATIONS + print(f" lib time: {elapsed :6f}") + + if (tensor_dtype == torch.float16): + assert torch.allclose(y, ans, atol=1e-5, rtol=1e-3) + else: # float32 + assert torch.allclose(y, ans, atol=1e-6, rtol=1e-5) + check_error(lib.infiniopDestroyBatchNormDescriptor(descriptor)) + + +def test_operator(lib, device, test_cases, tensor_dtypes): + handle = create_handle(lib, device) + for x_shape, eps, inplace in test_cases: + for tensor_dtype in tensor_dtypes: + test(lib, handle, device_enum_to_str(device), x_shape, eps, inplace=inplace, tensor_dtype=tensor_dtype) + destroy_handle(lib, handle) + + +if __name__ == "__main__": + test_cases = [ + # x_shape, eps, inplace + ((2, 5, 7), 1e-5, Inplace.OUT_OF_PLACE), + ((2, 5, 7), 1e-5, Inplace.INPLACE_X), + ((32, 3, 1024), 1e-5, Inplace.OUT_OF_PLACE), + ((32, 3, 128, 128), 1e-5, Inplace.OUT_OF_PLACE), + ((32, 3, 64, 64, 64), 1e-5, Inplace.OUT_OF_PLACE), + ] + tensor_dtypes = [ + torch.float16, torch.float32, + ] + args = get_args() + lib = open_lib() + lib.infiniopCreateBatchNormDescriptor.restype = c_int32 + 
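+    # Argument order mirrors infiniopCreateBatchNormDescriptor in
+    # include/ops/batch_norm/batch_norm.h: handle, pointer to the descriptor,
+    # six tensor descriptors (y, x, scale, b, mean, var), and a double eps.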
lib.infiniopCreateBatchNormDescriptor.argtypes = [ + infiniopHandle_t, + POINTER(infiniopBatchNormDescriptor_t), + infiniopTensorDescriptor_t, + infiniopTensorDescriptor_t, + infiniopTensorDescriptor_t, + infiniopTensorDescriptor_t, + infiniopTensorDescriptor_t, + infiniopTensorDescriptor_t, + c_double, + ] + lib.infiniopBatchNorm.restype = c_int32 + lib.infiniopBatchNorm.argtypes = [ + infiniopBatchNormDescriptor_t, + c_void_p, + c_void_p, + c_void_p, + c_void_p, + c_void_p, + c_void_p, + c_void_p, + ] + lib.infiniopDestroyBatchNormDescriptor.restype = c_int32 + lib.infiniopDestroyBatchNormDescriptor.argtypes = [ + infiniopBatchNormDescriptor_t, + ] + + if args.cpu: + test_operator(lib, DeviceEnum.DEVICE_CPU, test_cases, tensor_dtypes) + if args.cuda: + test_operator(lib, DeviceEnum.DEVICE_CUDA, test_cases, tensor_dtypes) + if args.bang: + import torch_mlu + test_operator(lib, DeviceEnum.DEVICE_BANG, test_cases, tensor_dtypes) + if not (args.cpu or args.cuda or args.bang): + test_operator(lib, DeviceEnum.DEVICE_CPU, test_cases, tensor_dtypes) + print("\033[92mTest passed!\033[0m") diff --git a/operatorspy/utils.py b/operatorspy/utils.py index b079d871..0a1c3ced 100644 --- a/operatorspy/utils.py +++ b/operatorspy/utils.py @@ -1,4 +1,5 @@ import ctypes +from .devices import DeviceEnum from .data_layout import * from .liboperators import infiniopTensorDescriptor_t, CTensor, infiniopHandle_t @@ -106,3 +107,14 @@ def rearrange_tensor(tensor, new_strides): new_tensor.set_(new_tensor.untyped_storage(), offset, shape, tuple(new_strides)) return new_tensor + +def device_enum_to_str(device: DeviceEnum): + if device == DeviceEnum.DEVICE_CPU: + return "cpu" + if device == DeviceEnum.DEVICE_CUDA: + return "cuda" + if device == DeviceEnum.DEVICE_BANG: + return "mlu" + if device == DeviceEnum.DEVICE_ASCEND: + return "npu" + return "" diff --git a/src/ops/batch_norm/cpu/batch_norm_cpu.cc b/src/ops/batch_norm/cpu/batch_norm_cpu.cc new file mode 100644 index 00000000..b4f8d3ef --- /dev/null +++ b/src/ops/batch_norm/cpu/batch_norm_cpu.cc @@ -0,0 +1,98 @@ +#include "batch_norm_cpu.h" +#include "../../../devices/cpu/common_cpu.h" +#include "../../utils.h" + +infiniopStatus_t cpuCreateBatchNormDescriptor(infiniopHandle_t, + BatchNormCpuDescriptor_t *desc_ptr, + infiniopTensorDescriptor_t y, + infiniopTensorDescriptor_t x, + infiniopTensorDescriptor_t scale, + infiniopTensorDescriptor_t b, + infiniopTensorDescriptor_t mean, + infiniopTensorDescriptor_t var, + double eps) { + uint64_t ndim = y->ndim; + if (ndim != x->ndim || scale->ndim != b->ndim || scale->ndim != mean->ndim || scale->ndim != var->ndim || scale->ndim != 1) { + return STATUS_BAD_TENSOR_SHAPE; + } + for (size_t i = 0; i < ndim; ++i) { + if (y->shape[i] != x->shape[i]) { + return STATUS_BAD_TENSOR_SHAPE; + } + } + if (x->shape[1] != scale->shape[0] || scale->shape[0] != b->shape[0] || scale->shape[0] != mean->shape[0] || scale->shape[0] != var->shape[0]) { + return STATUS_BAD_TENSOR_SHAPE; + } + if (!is_contiguous(y) || !is_contiguous(x)) { + return STATUS_BAD_TENSOR_STRIDES; + } + if (y->dt != F16 && y->dt != F32) { + return STATUS_BAD_TENSOR_DTYPE; + } + if (y->dt != x->dt) { + return STATUS_BAD_TENSOR_DTYPE; + } + if (eps < 0) { + return STATUS_BAD_PARAM; + } + + uint64_t spatial_data_size = std::accumulate(x->shape + 2, x->shape + x->ndim, 1ULL, std::multiplies()); + uint64_t batch_size = x->shape[0]; + uint64_t channel_size = x->shape[1]; + + *desc_ptr = new BatchNormCpuDescriptor{ + DevCpu, + y->dt, + batch_size, + channel_size, + 
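+        // spatial size (product of dims after the channel dim) and per-batch size (C * spatial)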
spatial_data_size, + channel_size * spatial_data_size, + eps, + }; + + return STATUS_SUCCESS; +} + +infiniopStatus_t cpuDestroyBatchNormDescriptor(BatchNormCpuDescriptor_t desc) { + delete desc; + return STATUS_SUCCESS; +} + +template +infiniopStatus_t batch_norm_cpu(BatchNormCpuDescriptor_t desc, void *y, void const *x, + void const *scale, void const *b, void const *mean, void const *var) { + auto x_ = reinterpret_cast(x); + auto scale_ = reinterpret_cast(scale); + auto b_ = reinterpret_cast(b); + auto mean_ = reinterpret_cast(mean); + auto var_ = reinterpret_cast(var); + auto y_ = reinterpret_cast(y); + +#pragma omp parallel for collapse(3) + for (uint64_t i = 0; i < desc->batch_size; ++i) { + for (uint64_t c = 0; c < desc->channel_size; ++c) { + for (uint64_t j = 0; j < desc->spatial_data_size; ++j) { + auto idx = (i * desc->channel_size + c) * desc->spatial_data_size + j; + Pdata invsqrt = 1 / std::sqrt(var_[c] + desc->eps); + if constexpr (std::is_same::value) { + y_[idx] = f32_to_f16((f16_to_f32(x_[idx]) - mean_[c]) * invsqrt * scale_[c] + b_[c]); + } else { + y_[idx] = (x_[idx] - mean_[c]) * invsqrt * scale_[c] + b_[c]; + } + } + } + } + return STATUS_SUCCESS; +} + +infiniopStatus_t cpuBatchNorm(BatchNormCpuDescriptor_t desc, + void *y, void const *x, void const *scale, void const *b, + void const *mean, void const *var, void *stream) { + if (desc->dtype == F16) { + return batch_norm_cpu(desc, y, x, scale, b, mean, var); + } + if (desc->dtype == F32) { + return batch_norm_cpu(desc, y, x, scale, b, mean, var); + } + return STATUS_BAD_TENSOR_DTYPE; +} diff --git a/src/ops/batch_norm/cpu/batch_norm_cpu.h b/src/ops/batch_norm/cpu/batch_norm_cpu.h new file mode 100644 index 00000000..726cc4ad --- /dev/null +++ b/src/ops/batch_norm/cpu/batch_norm_cpu.h @@ -0,0 +1,36 @@ +#ifndef __CPU_BATCH_NORM_H__ +#define __CPU_BATCH_NORM_H__ + +#include "operators.h" +#include +#include + +struct BatchNormCpuDescriptor { + Device device; + DT dtype; + uint64_t batch_size; + uint64_t channel_size; + uint64_t spatial_data_size; + uint64_t per_batch_data_size; + double eps; +}; + +typedef struct BatchNormCpuDescriptor *BatchNormCpuDescriptor_t; + +infiniopStatus_t cpuCreateBatchNormDescriptor(infiniopHandle_t, + BatchNormCpuDescriptor_t *, + infiniopTensorDescriptor_t y, + infiniopTensorDescriptor_t x, + infiniopTensorDescriptor_t scale, + infiniopTensorDescriptor_t b, + infiniopTensorDescriptor_t mean, + infiniopTensorDescriptor_t var, + double eps); + +infiniopStatus_t cpuBatchNorm(BatchNormCpuDescriptor_t desc, + void *y, void const *x, void const *scale, void const *b, + void const *mean, void const *var, void *stream); + +infiniopStatus_t cpuDestroyBatchNormDescriptor(BatchNormCpuDescriptor_t desc); + +#endif diff --git a/src/ops/batch_norm/cuda/batch_norm.cc b/src/ops/batch_norm/cuda/batch_norm.cc new file mode 100644 index 00000000..e902f586 --- /dev/null +++ b/src/ops/batch_norm/cuda/batch_norm.cc @@ -0,0 +1,112 @@ +#include "batch_norm.cuh" +#include "../../../devices/cuda/common_cuda.h" +#include "../../utils.h" + +infiniopStatus_t cudaCreateBatchNormDescriptor(CudaHandle_t handle, + BatchNormCudaDescriptor_t *desc_ptr, + infiniopTensorDescriptor_t y, + infiniopTensorDescriptor_t x, + infiniopTensorDescriptor_t scale, + infiniopTensorDescriptor_t b, + infiniopTensorDescriptor_t mean, + infiniopTensorDescriptor_t var, + double eps) { + uint64_t ndim = y->ndim; + if (ndim != x->ndim || scale->ndim != b->ndim || scale->ndim != mean->ndim || scale->ndim != var->ndim || scale->ndim != 1) { + return 
STATUS_BAD_TENSOR_SHAPE; + } + for (size_t i = 0; i < ndim; ++i) { + if (y->shape[i] != x->shape[i]) { + return STATUS_BAD_TENSOR_SHAPE; + } + } + if (x->shape[1] != scale->shape[0] || scale->shape[0] != b->shape[0] || scale->shape[0] != mean->shape[0] || scale->shape[0] != var->shape[0]) { + return STATUS_BAD_TENSOR_SHAPE; + } + if (!is_contiguous(y) || !is_contiguous(x)) { + return STATUS_BAD_TENSOR_STRIDES; + } + if (y->dt != F16 && y->dt != F32) { + return STATUS_BAD_TENSOR_DTYPE; + } + if (y->dt != x->dt) { + return STATUS_BAD_TENSOR_DTYPE; + } + if (eps < CUDNN_BN_MIN_EPSILON) { + return STATUS_BAD_PARAM; + } + + const auto new_ndim = std::max(4UL, ndim); + int32_t x_shape[new_ndim]; + int32_t y_shape[new_ndim]; + int32_t x_strides[new_ndim]; + int32_t y_strides[new_ndim]; + int32_t bn_shape[new_ndim]; + int32_t bn_strides[new_ndim]; + for (size_t i = 0; i < new_ndim; ++i) { + x_shape[i] = i < ndim ? static_cast(x->shape[i]) : 1; + x_strides[i] = i < ndim ? static_cast(x->strides[i]) : 1; + y_shape[i] = i < ndim ? static_cast(y->shape[i]) : 1; + y_strides[i] = i < ndim ? static_cast(y->strides[i]) : 1; + bn_shape[i] = i == 1 ? x->shape[i] : 1; + bn_strides[i] = 1; + } + + // get the data types of the tensors and the conv operator + CREATE_CHECK_ERROR(auto tensor_dt = dataTypeMap[x->dt], tensor_dt, -1, STATUS_BAD_PARAM); + cudnnDataType_t bn_dt = [&] { + switch (tensor_dt) { + case CUDNN_DATA_INT8: + case CUDNN_DATA_HALF: + case CUDNN_DATA_FLOAT: + return CUDNN_DATA_FLOAT; + default: + return CUDNN_DATA_DOUBLE; + } + }(); + + // get the input tensor descriptor + cudnnTensorDescriptor_t x_desc; + checkCudnnError(cudnnCreateTensorDescriptor(&x_desc)); + checkCudnnError(cudnnSetTensorNdDescriptor(x_desc, static_cast(tensor_dt), new_ndim, x_shape, x_strides)); + + // get the secondary tensor descriptor + cudnnTensorDescriptor_t bn_desc; + cudnnBatchNormMode_t mode; + checkCudnnError(cudnnCreateTensorDescriptor(&bn_desc)); + if (handle->compute_capability_major > 6 || (handle->compute_capability_major == 6 && handle->compute_capability_minor >= 0)) { + mode = CUDNN_BATCHNORM_SPATIAL; + } else { + mode = CUDNN_BATCHNORM_SPATIAL; + } + // checkCudnnError(cudnnDeriveBNTensorDescriptor(bn_desc, x_desc, mode)); + checkCudnnError(cudnnSetTensorNdDescriptor(bn_desc, static_cast(bn_dt), new_ndim, bn_shape, bn_strides)); + + // get the output tensor descriptor + cudnnTensorDescriptor_t y_desc; + checkCudnnError(cudnnCreateTensorDescriptor(&y_desc)); + checkCudnnError(cudnnSetTensorNdDescriptor(y_desc, static_cast(tensor_dt), new_ndim, y_shape, y_strides)); + + float alpha = 1.0f, beta = 0.0f; + + *desc_ptr = new BatchNormCudaDescriptor{ + DevNvGpu, + y->dt, + handle->device_id, + handle->cudnn_handles_t, + x_desc, + bn_desc, + y_desc, + alpha, + beta, + eps, + mode, + }; + + return STATUS_SUCCESS; +} + +infiniopStatus_t cudaDestroyBatchNormDescriptor(BatchNormCudaDescriptor_t desc) { + delete desc; + return STATUS_SUCCESS; +} diff --git a/src/ops/batch_norm/cuda/batch_norm.cu b/src/ops/batch_norm/cuda/batch_norm.cu new file mode 100644 index 00000000..7e0acb96 --- /dev/null +++ b/src/ops/batch_norm/cuda/batch_norm.cu @@ -0,0 +1,21 @@ +#include "../../../devices/cuda/common_cuda.h" +#include "../../utils.h" +#include "batch_norm.cuh" + +infiniopStatus_t batch_norm_nv_gpu(BatchNormCudaDescriptor_t desc, void *y, void const *x, void const *scale, void const *b, void const *mean, void const *var, void *stream) { + checkCudaError(cudaSetDevice(desc->device_id)); + 
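+    // Inference-mode batch norm via cuDNN: with alpha = 1 and beta = 0 this computes
+    // y = scale * (x - mean) / sqrt(var + eps) + b per channel, where scale/b/mean/var
+    // follow desc->bn_desc (float even for half x/y, as set in cudaCreateBatchNormDescriptor).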
checkCudnnError(use_cudnn(desc->cudnn_handles_t, desc->device_id, (cudaStream_t) stream, + [&](cudnnHandle_t handle) { return cudnnBatchNormalizationForwardInference(handle, desc->mode, &desc->alpha, &desc->beta, + desc->x_desc, x, desc->y_desc, y, desc->bn_desc, + scale, b, mean, var, desc->eps); })); + return STATUS_SUCCESS; +} + +infiniopStatus_t cudaBatchNorm(BatchNormCudaDescriptor_t desc, void *y, void const *x, + void const *scale, void const *b, void const *mean, void const *var, + void *stream) { + if (desc->dtype == F16 || desc->dtype == F32) { + return batch_norm_nv_gpu(desc, y, x, scale, b, mean, var, stream); + } + return STATUS_BAD_TENSOR_DTYPE; +} diff --git a/src/ops/batch_norm/cuda/batch_norm.cuh b/src/ops/batch_norm/cuda/batch_norm.cuh new file mode 100644 index 00000000..e7aff0e3 --- /dev/null +++ b/src/ops/batch_norm/cuda/batch_norm.cuh @@ -0,0 +1,42 @@ +#ifndef __CUDA_BATCH_NORM_H__ +#define __CUDA_BATCH_NORM_H__ + +#include "../../../devices/cuda/common_cuda.h" +#include "../../../devices/cuda/cuda_handle.h" +#include "operators.h" +#include +#include + +struct BatchNormCudaDescriptor { + Device device; + DT dtype; + int device_id; + std::shared_ptr> cudnn_handles_t; + const cudnnTensorDescriptor_t x_desc; + const cudnnTensorDescriptor_t bn_desc; + const cudnnTensorDescriptor_t y_desc; + const float alpha; + const float beta; + const double eps; + cudnnBatchNormMode_t mode; +}; + +typedef struct BatchNormCudaDescriptor *BatchNormCudaDescriptor_t; + +infiniopStatus_t cudaCreateBatchNormDescriptor(CudaHandle_t, + BatchNormCudaDescriptor_t *, + infiniopTensorDescriptor_t y, + infiniopTensorDescriptor_t x, + infiniopTensorDescriptor_t scale, + infiniopTensorDescriptor_t b, + infiniopTensorDescriptor_t mean, + infiniopTensorDescriptor_t var, + double eps); + +infiniopStatus_t cudaBatchNorm(BatchNormCudaDescriptor_t desc, + void *y, void const *x, void const *scale, void const *b, + void const *mean, void const *var, void *stream); + +infiniopStatus_t cudaDestroyBatchNormDescriptor(BatchNormCudaDescriptor_t desc); + +#endif diff --git a/src/ops/batch_norm/operator.cc b/src/ops/batch_norm/operator.cc new file mode 100644 index 00000000..ccf159d4 --- /dev/null +++ b/src/ops/batch_norm/operator.cc @@ -0,0 +1,78 @@ +#include "../utils.h" +#include "operators.h" +#include "ops/batch_norm/batch_norm.h" + +#ifdef ENABLE_CPU +#include "cpu/batch_norm_cpu.h" +#endif +#ifdef ENABLE_NV_GPU +#include "../../devices/cuda/cuda_handle.h" +#include "cuda/batch_norm.cuh" +#endif + +__C infiniopStatus_t infiniopCreateBatchNormDescriptor( + infiniopHandle_t handle, + infiniopBatchNormDescriptor_t *desc_ptr, + infiniopTensorDescriptor_t y, + infiniopTensorDescriptor_t x, + infiniopTensorDescriptor_t scale, + infiniopTensorDescriptor_t b, + infiniopTensorDescriptor_t mean, + infiniopTensorDescriptor_t var, + double eps) { + switch (handle->device) { +#ifdef ENABLE_CPU + case DevCpu: + return cpuCreateBatchNormDescriptor(handle, (BatchNormCpuDescriptor_t *) desc_ptr, y, x, scale, b, mean, var, eps); +#endif +#ifdef ENABLE_NV_GPU + case DevNvGpu: { + return cudaCreateBatchNormDescriptor((CudaHandle_t) handle, (BatchNormCudaDescriptor_t *) desc_ptr, y, x, scale, b, mean, var, eps); + } + +#endif +#ifdef ENABLE_CAMBRICON_MLU + // TODO +#endif + } + return STATUS_BAD_DEVICE; +} + +__C infiniopStatus_t infiniopBatchNorm(infiniopBatchNormDescriptor_t desc, void *y, void const *x, void const *scale, void const *b, + void const *mean, void const *var, void *stream) { + switch (desc->device) { +#ifdef 
ENABLE_CPU + case DevCpu: + return cpuBatchNorm((BatchNormCpuDescriptor_t) desc, y, x, scale, b, mean, var, stream); +#endif +#ifdef ENABLE_NV_GPU + case DevNvGpu: { + return cudaBatchNorm((BatchNormCudaDescriptor_t) desc, y, x, scale, b, mean, var, stream); + } + +#endif +#ifdef ENABLE_CAMBRICON_MLU + // TODO +#endif + } + return STATUS_BAD_DEVICE; +} + +__C infiniopStatus_t infiniopDestroyBatchNormDescriptor(infiniopBatchNormDescriptor_t desc) { + switch (desc->device) { +#ifdef ENABLE_CPU + case DevCpu: + return cpuDestroyBatchNormDescriptor((BatchNormCpuDescriptor_t) desc); +#endif +#ifdef ENABLE_NV_GPU + case DevNvGpu: { + return cudaDestroyBatchNormDescriptor((BatchNormCudaDescriptor_t) desc); + } + +#endif +#ifdef ENABLE_CAMBRICON_MLU + // TODO +#endif + } + return STATUS_BAD_DEVICE; +}
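
For reference, a minimal sketch of how a caller drives the new C API end to end. This is illustrative only: the helper name run_batch_norm_inference is not part of this change, the handle, tensor descriptors, and device buffers are assumed to have been created through the existing infiniop APIs, and eps is hard-coded to the test default of 1e-5.

#include "infini_operators.h"

// Illustrative helper only: handle, descriptors, and buffers come from the caller.
infiniopStatus_t run_batch_norm_inference(infiniopHandle_t handle,
                                          infiniopTensorDescriptor_t y_desc, void *y,
                                          infiniopTensorDescriptor_t x_desc, void const *x,
                                          infiniopTensorDescriptor_t scale_desc, void const *scale,
                                          infiniopTensorDescriptor_t b_desc, void const *b,
                                          infiniopTensorDescriptor_t mean_desc, void const *mean,
                                          infiniopTensorDescriptor_t var_desc, void const *var,
                                          void *stream) {
    infiniopBatchNormDescriptor_t desc;
    infiniopStatus_t status = infiniopCreateBatchNormDescriptor(
        handle, &desc, y_desc, x_desc, scale_desc, b_desc, mean_desc, var_desc, 1e-5);
    if (status != STATUS_SUCCESS) {
        return status;
    }
    // y = scale * (x - mean) / sqrt(var + eps) + b, computed per channel
    status = infiniopBatchNorm(desc, y, x, scale, b, mean, var, stream);
    infiniopDestroyBatchNormDescriptor(desc);
    return status;
}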