Skip to content

Commit fbd4e0c

Browse files
committed
Add batch norm CPU and CUDA implementation
1 parent fdbf030 commit fbd4e0c

File tree

9 files changed

+690
-14
lines changed

9 files changed

+690
-14
lines changed
Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
#ifndef BATCH_NORM_H
2+
#define BATCH_NORM_H
3+
4+
#include "../../export.h"
5+
#include "../../operators.h"
6+
7+
typedef struct BatchNormDescriptor {
8+
Device device;
9+
} BatchNormDescriptor;
10+
11+
typedef BatchNormDescriptor *infiniopBatchNormDescriptor_t;
12+
13+
__C __export infiniopStatus_t infiniopCreateBatchNormDescriptor(infiniopHandle_t handle,
14+
infiniopBatchNormDescriptor_t *desc_ptr,
15+
infiniopTensorDescriptor_t y,
16+
infiniopTensorDescriptor_t x,
17+
infiniopTensorDescriptor_t scale,
18+
infiniopTensorDescriptor_t b,
19+
infiniopTensorDescriptor_t mean,
20+
infiniopTensorDescriptor_t var,
21+
double eps);
22+
23+
__C __export infiniopStatus_t infiniopBatchNorm(infiniopBatchNormDescriptor_t desc,
24+
void *y,
25+
void const *x,
26+
void const *scale,
27+
void const *b,
28+
void const *mean,
29+
void const *var,
30+
void *stream);
31+
32+
__C __export infiniopStatus_t infiniopDestroyBatchNormDescriptor(infiniopBatchNormDescriptor_t desc);
33+
34+
#endif

operatorspy/tests/batch_norm.py

Lines changed: 248 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,248 @@
1+
from ctypes import POINTER, Structure, c_int32, c_void_p, c_double
2+
import ctypes
3+
import sys
4+
import os
5+
import time
6+
7+
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..")))
8+
from operatorspy import (
9+
open_lib,
10+
to_tensor,
11+
DeviceEnum,
12+
infiniopHandle_t,
13+
infiniopTensorDescriptor_t,
14+
create_handle,
15+
destroy_handle,
16+
check_error,
17+
)
18+
19+
from operatorspy.tests.test_utils import get_args
20+
import torch
21+
import ctypes
22+
import torch.nn.functional as F
23+
import numpy as np
24+
25+
# constant for control whether profile the pytorch and lib functions
26+
# NOTE: need to manually add synchronization function to the lib function,
27+
# e.g., cudaDeviceSynchronize() for CUDA
28+
PROFILE = False
29+
NUM_PRERUN = 10
30+
NUM_ITERATIONS = 1000
31+
32+
33+
class BatchNormDescriptor(Structure):
34+
_fields_ = [("device", c_int32)]
35+
36+
37+
infiniopBatchNormDescriptor_t = POINTER(BatchNormDescriptor)
38+
39+
40+
def batch_norm(x, scale, b, mean, var, eps):
41+
ndim = len(x.shape)
42+
if ndim <= 1 or ndim > 5:
43+
print("Error: Pytorch -> Unsupported tensor dimension")
44+
return None
45+
if PROFILE:
46+
ans = F.batch_norm(x, mean, var, scale, b, training=False, eps=eps)
47+
torch.cuda.synchronize()
48+
return ans
49+
return F.batch_norm(x, mean, var, scale, b, training=False, eps=eps)
50+
51+
52+
# get the mean and variance of the input tensor across the batch size N and spatial dimensions
53+
def get_mean_variance(x, dtype):
54+
dims = tuple(range(x.ndim))
55+
reduction_dims = tuple(d for d in dims if d != 1) # Exclude the channel dimension
56+
return x.mean(dim=reduction_dims, dtype=dtype), x.var(
57+
dim=reduction_dims, unbiased=False
58+
).to(dtype)
59+
60+
61+
def find_and_print_differing_indices(
62+
x, tensor1, tensor2, mean, scale, var, b, atol=0, rtol=1e-2
63+
):
64+
if tensor1.shape != tensor2.shape:
65+
raise ValueError("Tensors must have the same shape to compare.")
66+
67+
# Calculate the difference mask based on atol and rtol
68+
diff_mask = torch.abs(tensor1 - tensor2) > (atol + rtol * torch.abs(tensor2))
69+
diff_indices = torch.nonzero(diff_mask, as_tuple=False)
70+
71+
# Print the indices and the differing elements
72+
for idx in diff_indices:
73+
index_tuple = tuple(idx.tolist())
74+
print(
75+
f"Index: {index_tuple}, x: {x[index_tuple]}, mean: {mean[index_tuple[1]]}, scale: {scale[index_tuple[1]]}, var: {var[index_tuple[1]]}, b: {b[index_tuple[1]]}, y element: {tensor1[index_tuple]}, ans element: {tensor2[index_tuple]}"
76+
)
77+
78+
return diff_indices
79+
80+
81+
def test(
82+
lib,
83+
handle,
84+
torch_device,
85+
x_shape,
86+
eps=1e-5,
87+
tensor_dtype=torch.float16,
88+
):
89+
print(
90+
f"Testing BatchNorm on {torch_device} with x_shape: {x_shape}, scale_shape: {x_shape[1]}, b_shape: {x_shape[1]}, mean_shape: {x_shape[1]}, var_shape: {x_shape[1]}, eps: {eps} dtype:{tensor_dtype}"
91+
)
92+
num_channel = x_shape[1]
93+
bn_dtype = tensor_dtype if tensor_dtype != torch.float16 else torch.float32
94+
x = torch.rand(x_shape, dtype=tensor_dtype).to(torch_device) * 10 - 2
95+
scale = torch.rand(num_channel, dtype=bn_dtype).to(torch_device)
96+
b = torch.rand(num_channel, dtype=bn_dtype).to(torch_device)
97+
mean, var = get_mean_variance(x, bn_dtype)
98+
y = torch.zeros(x_shape, dtype=tensor_dtype).to(torch_device)
99+
100+
# get the pytorch answer
101+
for i in range(NUM_PRERUN if PROFILE else 1):
102+
ans = batch_norm(x, scale, b, mean, var, eps)
103+
if PROFILE:
104+
start_time = time.time()
105+
for i in range(NUM_ITERATIONS):
106+
_ = batch_norm(x, scale, b, mean, var, eps)
107+
elapsed = (time.time() - start_time) / NUM_ITERATIONS
108+
print(f"pytorch time: {elapsed :6f}")
109+
110+
# get the operators' answer
111+
x_tensor = to_tensor(x, lib)
112+
scale_tensor = to_tensor(scale, lib)
113+
b_tensor = to_tensor(b, lib)
114+
mean_tensor = to_tensor(mean, lib)
115+
var_tensor = to_tensor(var, lib)
116+
y_tensor = to_tensor(y, lib)
117+
descriptor = infiniopBatchNormDescriptor_t()
118+
119+
check_error(
120+
lib.infiniopCreateBatchNormDescriptor(
121+
handle,
122+
ctypes.byref(descriptor),
123+
y_tensor.descriptor,
124+
x_tensor.descriptor,
125+
scale_tensor.descriptor,
126+
b_tensor.descriptor,
127+
mean_tensor.descriptor,
128+
var_tensor.descriptor,
129+
eps,
130+
)
131+
)
132+
133+
for i in range(NUM_PRERUN if PROFILE else 1):
134+
check_error(
135+
lib.infiniopBatchNorm(
136+
descriptor,
137+
y_tensor.data,
138+
x_tensor.data,
139+
scale_tensor.data,
140+
b_tensor.data,
141+
mean_tensor.data,
142+
var_tensor.data,
143+
None,
144+
)
145+
)
146+
if PROFILE:
147+
start_time = time.time()
148+
for i in range(NUM_ITERATIONS):
149+
lib.infiniopBatchNorm(
150+
descriptor,
151+
y_tensor.data,
152+
x_tensor.data,
153+
scale_tensor.data,
154+
b_tensor.data,
155+
mean_tensor.data,
156+
var_tensor.data,
157+
None,
158+
)
159+
elapsed = (time.time() - start_time) / NUM_ITERATIONS
160+
print(f" lib time: {elapsed :6f}")
161+
162+
# print(" - x: \n", x, "\n - y:\n", y, "\n - ans:\n", ans)
163+
# print(" - y:\n", y, "\n - ans:\n", ans)
164+
165+
# find_and_print_differing_indices(x, y, ans, mean, scale, mean, b, atol=1e-7, rtol=1e-3)
166+
# np.testing.assert_allclose(y.numpy(), ans.numpy(), atol=1e-7, rtol=1e-3)
167+
assert torch.allclose(y, ans, atol=1e-7, rtol=1e-3)
168+
check_error(lib.infiniopDestroyBatchNormDescriptor(descriptor))
169+
170+
171+
def test_cpu(lib, test_cases):
172+
device = DeviceEnum.DEVICE_CPU
173+
handle = create_handle(lib, device)
174+
for x_shape, eps in test_cases:
175+
test(lib, handle, "cpu", x_shape, eps, tensor_dtype=torch.float16)
176+
test(lib, handle, "cpu", x_shape, eps, tensor_dtype=torch.float32)
177+
destroy_handle(lib, handle)
178+
179+
180+
def test_cuda(lib, test_cases):
181+
device = DeviceEnum.DEVICE_CUDA
182+
handle = create_handle(lib, device)
183+
for x_shape, eps in test_cases:
184+
test(lib, handle, "cuda", x_shape, eps, tensor_dtype=torch.float16)
185+
test(lib, handle, "cuda", x_shape, eps, tensor_dtype=torch.float32)
186+
destroy_handle(lib, handle)
187+
188+
189+
def test_bang(lib, test_cases):
190+
import torch_mlu
191+
192+
device = DeviceEnum.DEVICE_BANG
193+
handle = create_handle(lib, device)
194+
for x_shape, eps in test_cases:
195+
test(lib, handle, "mlu", x_shape, eps, tensor_dtype=torch.float16)
196+
test(lib, handle, "mlu", x_shape, eps, tensor_dtype=torch.float32)
197+
destroy_handle(lib, handle)
198+
199+
200+
if __name__ == "__main__":
201+
test_cases = [
202+
# x_shape, eps
203+
((2, 3, 4), 1e-5),
204+
((32, 3, 1024), 1e-5),
205+
((1, 3, 4, 4), 1e-5),
206+
((32, 3, 128, 128), 1e-5),
207+
((1, 6, 5, 5, 5), 1e-5),
208+
((32, 3, 64, 64, 64), 1e-5),
209+
]
210+
args = get_args()
211+
lib = open_lib()
212+
lib.infiniopCreateBatchNormDescriptor.restype = c_int32
213+
lib.infiniopCreateBatchNormDescriptor.argtypes = [
214+
infiniopHandle_t,
215+
POINTER(infiniopBatchNormDescriptor_t),
216+
infiniopTensorDescriptor_t,
217+
infiniopTensorDescriptor_t,
218+
infiniopTensorDescriptor_t,
219+
infiniopTensorDescriptor_t,
220+
infiniopTensorDescriptor_t,
221+
infiniopTensorDescriptor_t,
222+
c_double,
223+
]
224+
lib.infiniopBatchNorm.restype = c_int32
225+
lib.infiniopBatchNorm.argtypes = [
226+
infiniopBatchNormDescriptor_t,
227+
c_void_p,
228+
c_void_p,
229+
c_void_p,
230+
c_void_p,
231+
c_void_p,
232+
c_void_p,
233+
c_void_p,
234+
]
235+
lib.infiniopDestroyBatchNormDescriptor.restype = c_int32
236+
lib.infiniopDestroyBatchNormDescriptor.argtypes = [
237+
infiniopBatchNormDescriptor_t,
238+
]
239+
240+
if args.cpu:
241+
test_cpu(lib, test_cases)
242+
if args.cuda:
243+
test_cuda(lib, test_cases)
244+
if args.bang:
245+
test_bang(lib, test_cases)
246+
if not (args.cpu or args.cuda or args.bang):
247+
test_cpu(lib, test_cases)
248+
print("\033[92mTest passed!\033[0m")

src/devices/cpu/common_cpu.cc

Lines changed: 11 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -44,22 +44,19 @@ uint16_t f32_to_f16(float val) {
4444
int32_t exponent = ((f32 >> 23) & 0xFF) - 127;// Extract and de-bias the exponent
4545
uint32_t mantissa = f32 & 0x7FFFFF; // Extract the mantissa (fraction part)
4646

47-
if (exponent == 128) {// Special case for Inf and NaN
48-
if (mantissa != 0) {
49-
// NaN
50-
return sign | 0x7C00 | (mantissa >> 13);// Convert the NaN payload
51-
} else {
52-
// Infinity
53-
return sign | 0x7C00;
47+
if (exponent >= 31) {// Special cases for Inf and NaN
48+
// NaN
49+
if (exponent == 128 && mantissa != 0) {
50+
return sign | 0x7E00;
5451
}
55-
} else if (exponent > 15) { // Overflow: Larger than float16 max
56-
return sign | 0x7C00; // Return infinity
57-
} else if (exponent >= -14) {// Normalized float16
52+
// Infinity
53+
return sign | 0x7C00;
54+
} else if (exponent >= -14) {// Normalized case
5855
return sign | ((exponent + 15) << 10) | (mantissa >> 13);
59-
} else if (exponent >= -24) { // Subnormal float16 (leading denormals)
60-
mantissa |= 0x800000; // Add implicit leading 1
61-
int32_t shift = -exponent - 1;// Calculate shift for subnormal numbers
62-
return sign | (mantissa >> (13 + shift));
56+
} else if (exponent >= -24) {
57+
mantissa |= 0x800000;// Add implicit leading 1
58+
mantissa >>= (-14 - exponent);
59+
return sign | (mantissa >> 13);
6360
} else {
6461
// Too small for subnormal: return signed zero
6562
return sign;

0 commit comments

Comments
 (0)