Skip to content

Commit 4fdd9b1

Browse files
committed
Add batch norm CPU and CUDA implementation
1 parent 0ab53f1 commit 4fdd9b1

File tree

8 files changed

+679
-0
lines changed

8 files changed

+679
-0
lines changed
Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
#ifndef BATCH_NORM_H
2+
#define BATCH_NORM_H
3+
4+
#include "../../export.h"
5+
#include "../../operators.h"
6+
7+
typedef struct BatchNormDescriptor {
8+
Device device;
9+
} BatchNormDescriptor;
10+
11+
typedef BatchNormDescriptor *infiniopBatchNormDescriptor_t;
12+
13+
/**
 * Create a batch-norm descriptor and hand it back through `desc_ptr`.
 *
 * @param handle    Library handle.
 * @param desc_ptr  Receives the newly created descriptor.
 * @param y         Descriptor of the output tensor.
 * @param x         Descriptor of the input tensor.
 * @param scale     Descriptor of the per-channel scale tensor.
 * @param b         Descriptor of the per-channel bias tensor.
 * @param mean      Descriptor of the per-channel mean tensor.
 * @param var       Descriptor of the per-channel variance tensor.
 * @param eps       Value added to the variance before the square root.
 * @return Status code. The CPU implementation in this commit reports bad
 *         shapes, strides, dtypes and a negative `eps` as distinct errors.
 */
__C __export infiniopStatus_t infiniopCreateBatchNormDescriptor(
    infiniopHandle_t handle,
    infiniopBatchNormDescriptor_t *desc_ptr,
    infiniopTensorDescriptor_t y,
    infiniopTensorDescriptor_t x,
    infiniopTensorDescriptor_t scale,
    infiniopTensorDescriptor_t b,
    infiniopTensorDescriptor_t mean,
    infiniopTensorDescriptor_t var,
    double eps);
22+
23+
/**
 * Run batch normalization using a previously created descriptor.
 *
 * The CPU kernel in this commit computes, per channel `c`,
 * `y = (x - mean[c]) * (1 / sqrt(var[c] + eps)) * scale[c] + b[c]`.
 *
 * @param desc    Descriptor from infiniopCreateBatchNormDescriptor.
 * @param y       Output buffer.
 * @param x       Input buffer.
 * @param scale   Per-channel scale buffer.
 * @param b       Per-channel bias buffer.
 * @param mean    Per-channel mean buffer.
 * @param var     Per-channel variance buffer.
 * @param stream  Backend stream; the CPU implementation does not use it and
 *                the Python test passes None.
 * @return Status code.
 */
__C __export infiniopStatus_t infiniopBatchNorm(
    infiniopBatchNormDescriptor_t desc,
    void *y,
    void const *x,
    void const *scale,
    void const *b,
    void const *mean,
    void const *var,
    void *stream);
31+
32+
/**
 * Release a descriptor obtained from infiniopCreateBatchNormDescriptor.
 *
 * @param desc  Descriptor to destroy.
 * @return Status code.
 */
__C __export infiniopStatus_t infiniopDestroyBatchNormDescriptor(
    infiniopBatchNormDescriptor_t desc);
33+
34+
#endif

operatorspy/tests/batch_norm.py

Lines changed: 248 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,248 @@
1+
from ctypes import POINTER, Structure, c_int32, c_void_p, c_double
2+
import ctypes
3+
import sys
4+
import os
5+
import time
6+
7+
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..")))
8+
from operatorspy import (
9+
open_lib,
10+
to_tensor,
11+
DeviceEnum,
12+
infiniopHandle_t,
13+
infiniopTensorDescriptor_t,
14+
create_handle,
15+
destroy_handle,
16+
check_error,
17+
)
18+
19+
from operatorspy.tests.test_utils import get_args
20+
import torch
21+
import ctypes
22+
import torch.nn.functional as F
23+
import numpy as np
24+
25+
# Profiling switch: when True, the PyTorch reference and the library call are
# each warmed up and then timed.
# NOTE: a synchronization call has to be added manually to the library
# function, e.g., cudaDeviceSynchronize() for CUDA.
PROFILE = False
# Number of untimed warm-up calls made when PROFILE is True.
NUM_PRERUN = 10
# Number of timed calls whose average is printed when PROFILE is True.
NUM_ITERATIONS = 1000
31+
32+
33+
class BatchNormDescriptor(Structure):
    """ctypes layout of the C ``BatchNormDescriptor`` struct: one int32 ``device`` field."""

    _fields_ = [
        ("device", c_int32),
    ]


# Pointer type used wherever the C API takes ``infiniopBatchNormDescriptor_t``.
infiniopBatchNormDescriptor_t = POINTER(BatchNormDescriptor)
38+
39+
40+
def batch_norm(x, scale, b, mean, var, eps):
    """Compute the PyTorch reference result for inference-mode batch norm.

    Args:
        x: Input tensor; must have between 2 and 5 dimensions.
        scale: Per-channel weight passed to ``F.batch_norm``.
        b: Per-channel bias passed to ``F.batch_norm``.
        mean: Per-channel running mean.
        var: Per-channel running variance.
        eps: Value added to the variance for numerical stability.

    Returns:
        The normalized tensor, or ``None`` (after printing an error) when the
        number of dimensions of ``x`` is unsupported.
    """
    ndim = len(x.shape)
    if ndim <= 1 or ndim > 5:
        print("Error: Pytorch -> Unsupported tensor dimension")
        return None
    ans = F.batch_norm(x, mean, var, scale, b, training=False, eps=eps)
    # When profiling, wait for the queued CUDA work so the measured time covers
    # the kernel itself. Only CUDA tensors are synchronized: an unconditional
    # torch.cuda.synchronize() raises on machines without CUDA, which broke
    # profiling of the CPU path.
    if PROFILE and x.is_cuda:
        torch.cuda.synchronize()
    return ans
50+
51+
52+
# get the mean and variance of the input tensor across the batch size N and spatial dimensions
53+
def get_mean_variance(x, dtype):
    """Return the per-channel mean and biased variance of ``x`` as ``dtype`` tensors.

    Every dimension except dimension 1 (the channel dimension) is reduced.
    The mean is accumulated directly in ``dtype``; the variance is computed
    with ``unbiased=False`` and converted to ``dtype`` afterwards.
    """
    channel_axis = 1
    reduce_over = tuple(axis for axis in range(x.ndim) if axis != channel_axis)
    channel_mean = x.mean(dim=reduce_over, dtype=dtype)
    channel_var = x.var(dim=reduce_over, unbiased=False).to(dtype)
    return channel_mean, channel_var
59+
60+
61+
def find_and_print_differing_indices(
    x, tensor1, tensor2, mean, scale, var, b, atol=0, rtol=1e-2
):
    """Print and return the positions where ``tensor1`` and ``tensor2`` disagree.

    An element differs when ``|tensor1 - tensor2| > atol + rtol * |tensor2|``.
    For each such position the input element, the per-channel parameters
    (looked up with the index along dimension 1) and both values are printed.

    Raises:
        ValueError: If the two tensors do not have the same shape.

    Returns:
        The tensor of differing indices, one row per mismatch.
    """
    if tensor1.shape != tensor2.shape:
        raise ValueError("Tensors must have the same shape to compare.")

    # Per-element tolerance, relative to the magnitude of the second tensor
    tolerance = atol + rtol * tensor2.abs()
    diff_indices = ((tensor1 - tensor2).abs() > tolerance).nonzero()

    # Report every mismatch together with the values that produced it
    for row in diff_indices.tolist():
        index_tuple = tuple(row)
        print(
            f"Index: {index_tuple}, x: {x[index_tuple]}, mean: {mean[index_tuple[1]]}, scale: {scale[index_tuple[1]]}, var: {var[index_tuple[1]]}, b: {b[index_tuple[1]]}, y element: {tensor1[index_tuple]}, ans element: {tensor2[index_tuple]}"
        )

    return diff_indices
79+
80+
81+
def test(
    lib,
    handle,
    torch_device,
    x_shape,
    eps=1e-5,
    tensor_dtype=torch.float16,
):
    """Compare the library's batch norm against PyTorch for one shape and dtype.

    Random inputs are generated on ``torch_device``, the PyTorch reference is
    computed, the library operator is run through a freshly created
    descriptor, and the two outputs are asserted to be close. When ``PROFILE``
    is set, both implementations are also warmed up and timed.
    """
    print(
        f"Testing BatchNorm on {torch_device} with x_shape: {x_shape}, scale_shape: {x_shape[1]}, b_shape: {x_shape[1]}, mean_shape: {x_shape[1]}, var_shape: {x_shape[1]}, eps: {eps} dtype:{tensor_dtype}"
    )
    num_channel = x_shape[1]
    # float16 inputs use float32 parameters; any other dtype is reused as-is
    bn_dtype = torch.float32 if tensor_dtype == torch.float16 else tensor_dtype
    # Input values lie in [-2, 8)
    x = torch.rand(x_shape, dtype=tensor_dtype).to(torch_device) * 10 - 2
    scale = torch.rand(num_channel, dtype=bn_dtype).to(torch_device)
    b = torch.rand(num_channel, dtype=bn_dtype).to(torch_device)
    mean, var = get_mean_variance(x, bn_dtype)
    y = torch.zeros(x_shape, dtype=tensor_dtype).to(torch_device)

    # One plain run normally; NUM_PRERUN warm-up runs when profiling
    warmup_runs = NUM_PRERUN if PROFILE else 1

    # Reference answer from PyTorch
    for _ in range(warmup_runs):
        ans = batch_norm(x, scale, b, mean, var, eps)
    if PROFILE:
        start_time = time.time()
        for _ in range(NUM_ITERATIONS):
            batch_norm(x, scale, b, mean, var, eps)
        elapsed = (time.time() - start_time) / NUM_ITERATIONS
        print(f"pytorch time: {elapsed :6f}")

    # Wrap the torch tensors for the library and create the operator descriptor
    x_tensor = to_tensor(x, lib)
    scale_tensor = to_tensor(scale, lib)
    b_tensor = to_tensor(b, lib)
    mean_tensor = to_tensor(mean, lib)
    var_tensor = to_tensor(var, lib)
    y_tensor = to_tensor(y, lib)
    descriptor = infiniopBatchNormDescriptor_t()

    create_status = lib.infiniopCreateBatchNormDescriptor(
        handle,
        ctypes.byref(descriptor),
        y_tensor.descriptor,
        x_tensor.descriptor,
        scale_tensor.descriptor,
        b_tensor.descriptor,
        mean_tensor.descriptor,
        var_tensor.descriptor,
        eps,
    )
    check_error(create_status)

    def run_lib():
        # Single invocation of the library operator; None is passed as the stream
        return lib.infiniopBatchNorm(
            descriptor,
            y_tensor.data,
            x_tensor.data,
            scale_tensor.data,
            b_tensor.data,
            mean_tensor.data,
            var_tensor.data,
            None,
        )

    for _ in range(warmup_runs):
        check_error(run_lib())
    if PROFILE:
        start_time = time.time()
        for _ in range(NUM_ITERATIONS):
            run_lib()
        elapsed = (time.time() - start_time) / NUM_ITERATIONS
        print(f" lib time: {elapsed :6f}")

    # Debugging aids (uncomment when a comparison fails):
    # print(" - x: \n", x, "\n - y:\n", y, "\n - ans:\n", ans)
    # find_and_print_differing_indices(x, y, ans, mean, scale, var, b, atol=1e-7, rtol=1e-3)
    # np.testing.assert_allclose(y.numpy(), ans.numpy(), atol=1e-7, rtol=1e-3)
    assert torch.allclose(y, ans, atol=1e-7, rtol=1e-3)
    check_error(lib.infiniopDestroyBatchNormDescriptor(descriptor))
169+
170+
171+
def test_cpu(lib, test_cases):
    """Run every (x_shape, eps) case on the CPU device, in float16 then float32."""
    handle = create_handle(lib, DeviceEnum.DEVICE_CPU)
    for x_shape, eps in test_cases:
        for dtype in (torch.float16, torch.float32):
            test(lib, handle, "cpu", x_shape, eps, tensor_dtype=dtype)
    destroy_handle(lib, handle)
178+
179+
180+
def test_cuda(lib, test_cases):
    """Run every (x_shape, eps) case on the CUDA device, in float16 then float32."""
    handle = create_handle(lib, DeviceEnum.DEVICE_CUDA)
    for x_shape, eps in test_cases:
        for dtype in (torch.float16, torch.float32):
            test(lib, handle, "cuda", x_shape, eps, tensor_dtype=dtype)
    destroy_handle(lib, handle)
187+
188+
189+
def test_bang(lib, test_cases):
    """Run every (x_shape, eps) case on the BANG device ("mlu"), in float16 then float32."""
    # Imported here so the module is only needed when this backend is selected
    import torch_mlu

    handle = create_handle(lib, DeviceEnum.DEVICE_BANG)
    for x_shape, eps in test_cases:
        for dtype in (torch.float16, torch.float32):
            test(lib, handle, "mlu", x_shape, eps, tensor_dtype=dtype)
    destroy_handle(lib, handle)
198+
199+
200+
if __name__ == "__main__":
    # Each entry is (x_shape, eps)
    test_cases = [
        ((2, 3, 4), 1e-5),
        ((32, 3, 1024), 1e-5),
        ((1, 3, 4, 4), 1e-5),
        ((32, 3, 128, 128), 1e-5),
        ((1, 6, 5, 5, 5), 1e-5),
        ((32, 3, 64, 64, 64), 1e-5),
    ]
    args = get_args()
    lib = open_lib()

    # Declare the C signatures so ctypes converts arguments and results correctly.
    # Create: handle, out-descriptor, six tensor descriptors (y, x, scale, b, mean, var), eps
    lib.infiniopCreateBatchNormDescriptor.restype = c_int32
    lib.infiniopCreateBatchNormDescriptor.argtypes = (
        [infiniopHandle_t, POINTER(infiniopBatchNormDescriptor_t)]
        + [infiniopTensorDescriptor_t] * 6
        + [c_double]
    )
    # Run: descriptor, six data pointers (y, x, scale, b, mean, var), stream
    lib.infiniopBatchNorm.restype = c_int32
    lib.infiniopBatchNorm.argtypes = [infiniopBatchNormDescriptor_t] + [c_void_p] * 7
    lib.infiniopDestroyBatchNormDescriptor.restype = c_int32
    lib.infiniopDestroyBatchNormDescriptor.argtypes = [
        infiniopBatchNormDescriptor_t,
    ]

    # Collect the requested backends in a fixed order; default to CPU when none is given
    selected = []
    if args.cpu:
        selected.append(test_cpu)
    if args.cuda:
        selected.append(test_cuda)
    if args.bang:
        selected.append(test_bang)
    if not selected:
        selected.append(test_cpu)
    for run_backend in selected:
        run_backend(lib, test_cases)
    print("\033[92mTest passed!\033[0m")
Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,100 @@
1+
#include "batch_norm_cpu.h"
2+
#include "../../../devices/cpu/common_cpu.h"
3+
#include "../../utils.h"
4+
5+
/**
 * Validates the tensor descriptors for batch normalization on the CPU and, on
 * success, allocates a descriptor into `*desc_ptr`.
 *
 * Requirements checked here:
 *  - `x` and `y` have the same rank (at least 2) and the same shape;
 *  - `scale`, `b`, `mean` and `var` share a rank, and each of their
 *    dimensions equals the channel count `x->shape[1]`;
 *  - `x` and `y` are contiguous and share a dtype of F16 or F32;
 *  - `scale`, `b`, `mean` and `var` are F32, because the kernel reads them as
 *    `float` for both supported data types;
 *  - `eps` is a non-negative number.
 *
 * @param desc_ptr Receives the new descriptor; release it with
 *                 cpuDestroyBatchNormDescriptor.
 * @param y        Output tensor descriptor.
 * @param x        Input tensor descriptor.
 * @param scale    Per-channel scale descriptor.
 * @param b        Per-channel bias descriptor.
 * @param mean     Per-channel mean descriptor.
 * @param var      Per-channel variance descriptor.
 * @param eps      Value added to the variance before the square root.
 * @return STATUS_SUCCESS, or STATUS_BAD_TENSOR_SHAPE / STATUS_BAD_TENSOR_STRIDES /
 *         STATUS_BAD_TENSOR_DTYPE / STATUS_BAD_PARAM for the failed check.
 */
infiniopStatus_t cpuCreateBatchNormDescriptor(infiniopHandle_t,
                                              BatchNormCpuDescriptor_t *desc_ptr,
                                              infiniopTensorDescriptor_t y,
                                              infiniopTensorDescriptor_t x,
                                              infiniopTensorDescriptor_t scale,
                                              infiniopTensorDescriptor_t b,
                                              infiniopTensorDescriptor_t mean,
                                              infiniopTensorDescriptor_t var,
                                              double eps) {
    uint64_t ndim = y->ndim;
    // shape[0] (batch) and shape[1] (channel) are read below, so tensors of
    // rank 0 or 1 must be rejected before any indexing takes place.
    if (ndim < 2 || ndim != x->ndim || scale->ndim != b->ndim || scale->ndim != mean->ndim || scale->ndim != var->ndim) {
        return STATUS_BAD_TENSOR_SHAPE;
    }
    for (size_t i = 0; i < ndim; ++i) {
        if (y->shape[i] != x->shape[i]) {
            return STATUS_BAD_TENSOR_SHAPE;
        }
    }
    for (size_t i = 0; i < scale->ndim; ++i) {
        if (x->shape[1] != scale->shape[i] || scale->shape[i] != b->shape[i] || scale->shape[i] != mean->shape[i] || scale->shape[i] != var->shape[i]) {
            return STATUS_BAD_TENSOR_SHAPE;
        }
    }
    if (!is_contiguous(y) || !is_contiguous(x)) {
        return STATUS_BAD_TENSOR_STRIDES;
    }
    if (y->dt != F16 && y->dt != F32) {
        return STATUS_BAD_TENSOR_DTYPE;
    }
    if (y->dt != x->dt) {
        return STATUS_BAD_TENSOR_DTYPE;
    }
    // The kernel reinterprets the per-channel parameters as float regardless of
    // the data dtype; any other parameter dtype would be misread.
    if (scale->dt != F32 || b->dt != F32 || mean->dt != F32 || var->dt != F32) {
        return STATUS_BAD_TENSOR_DTYPE;
    }
    // Written as a negated comparison so that NaN is rejected as well.
    if (!(eps >= 0)) {
        return STATUS_BAD_PARAM;
    }

    // Product of all dimensions after the channel dimension (1 when there are none)
    uint64_t spatial_data_size = std::accumulate(x->shape + 2, x->shape + x->ndim, 1ULL, std::multiplies<uint64_t>());
    uint64_t batch_size = x->shape[0];
    uint64_t channel_size = x->shape[1];

    *desc_ptr = new BatchNormCpuDescriptor{
        DevCpu,
        y->dt,
        batch_size,
        channel_size,
        spatial_data_size,
        channel_size * spatial_data_size,
        eps,
    };

    return STATUS_SUCCESS;
}
57+
58+
/**
 * Frees a descriptor allocated by cpuCreateBatchNormDescriptor.
 *
 * @param desc Descriptor to delete.
 * @return Always STATUS_SUCCESS.
 */
infiniopStatus_t cpuDestroyBatchNormDescriptor(BatchNormCpuDescriptor_t desc) {
    delete desc;
    return STATUS_SUCCESS;
}
62+
63+
template<typename Tdata, typename Pdata>
64+
infiniopStatus_t batch_norm_cpu(BatchNormCpuDescriptor_t desc, void *y, void const *x,
65+
void const *scale, void const *b, void const *mean, void const *var) {
66+
auto x_ = reinterpret_cast<Tdata const *>(x);
67+
auto scale_ = reinterpret_cast<Pdata const *>(scale);
68+
auto b_ = reinterpret_cast<Pdata const *>(b);
69+
auto mean_ = reinterpret_cast<Pdata const *>(mean);
70+
auto var_ = reinterpret_cast<Pdata const *>(var);
71+
auto y_ = reinterpret_cast<Tdata *>(y);
72+
73+
#pragma omp parallel for collapse(3)
74+
for (uint64_t i = 0; i < desc->batch_size; ++i) {
75+
for (uint64_t c = 0; c < desc->channel_size; ++c) {
76+
for (uint64_t j = 0; j < desc->spatial_data_size; ++j) {
77+
auto idx = (i * desc->channel_size + c) * desc->spatial_data_size + j;
78+
Pdata invsqrt = 1 / std::sqrt(var_[c] + desc->eps);
79+
if constexpr (std::is_same<Tdata, uint16_t>::value) {
80+
y_[idx] = f32_to_f16((f16_to_f32(x_[idx]) - mean_[c]) * invsqrt * scale_[c] + b_[c]);
81+
} else {
82+
y_[idx] = (x_[idx] - mean_[c]) * invsqrt * scale_[c] + b_[c];
83+
}
84+
}
85+
}
86+
}
87+
return STATUS_SUCCESS;
88+
}
89+
90+
/**
 * Runs the CPU batch-norm kernel that matches the descriptor's dtype.
 *
 * F32 data uses `float` elements and F16 data uses `uint16_t` storage; the
 * per-channel parameters are read as `float` in both cases.
 *
 * @param stream Not used by this implementation.
 * @return The kernel's status, or STATUS_BAD_TENSOR_DTYPE for any other dtype.
 */
infiniopStatus_t cpuBatchNorm(BatchNormCpuDescriptor_t desc,
                              void *y, void const *x, void const *scale, void const *b,
                              void const *mean, void const *var, void *stream) {
    if (desc->dtype == F32) {
        return batch_norm_cpu<float, float>(desc, y, x, scale, b, mean, var);
    }
    if (desc->dtype == F16) {
        return batch_norm_cpu<uint16_t, float>(desc, y, x, scale, b, mean, var);
    }
    return STATUS_BAD_TENSOR_DTYPE;
}

0 commit comments

Comments
 (0)