Skip to content

Commit 49ee9f2

Browse files
Merge pull request #68 from PanZezhong1725/add_relu
Add ReLU CPU and CUDA implementation
2 parents 0cb4203 + e9f3ec2 commit 49ee9f2

File tree

9 files changed

+559
-0
lines changed

9 files changed

+559
-0
lines changed

include/infini_operators.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
#include "ops/mlp/mlp.h"
88
#include "ops/random_sample/random_sample.h"
99
#include "ops/rearrange/rearrange.h"
10+
#include "ops/relu/relu.h"
1011
#include "ops/rms_norm/rms_norm.h"
1112
#include "ops/rotary_embedding/rotary_embedding.h"
1213
#include "ops/swiglu/swiglu.h"

include/ops/relu/relu.h

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
#ifndef RELU_H
2+
#define RELU_H
3+
4+
#include "../../export.h"
5+
#include "../../operators.h"
6+
7+
typedef struct ReluDescriptor {
8+
Device device;
9+
} ReluDescriptor;
10+
11+
typedef ReluDescriptor *infiniopReluDescriptor_t;
12+
13+
__C __export infiniopStatus_t infiniopCreateReluDescriptor(infiniopHandle_t handle,
14+
infiniopReluDescriptor_t *desc_ptr,
15+
infiniopTensorDescriptor_t y,
16+
infiniopTensorDescriptor_t x);
17+
18+
__C __export infiniopStatus_t infiniopRelu(infiniopReluDescriptor_t desc,
19+
void *y,
20+
void const *x,
21+
void *stream);
22+
23+
__C __export infiniopStatus_t infiniopDestroyReluDescriptor(infiniopReluDescriptor_t desc);
24+
25+
#endif

operatorspy/tests/relu.py

Lines changed: 175 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,175 @@
1+
from ctypes import POINTER, Structure, c_int32, c_void_p
2+
import ctypes
3+
import sys
4+
import os
5+
import time
6+
7+
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..")))
8+
from operatorspy import (
9+
open_lib,
10+
to_tensor,
11+
DeviceEnum,
12+
infiniopHandle_t,
13+
infiniopTensorDescriptor_t,
14+
create_handle,
15+
destroy_handle,
16+
check_error,
17+
)
18+
19+
from operatorspy.tests.test_utils import get_args
20+
from enum import Enum, auto
21+
import torch
22+
23+
# constant for control whether profile the pytorch and lib functions
24+
# NOTE: need to manually add synchronization function to the lib function,
25+
# e.g., cudaDeviceSynchronize() for CUDA
26+
PROFILE = False
27+
NUM_PRERUN = 10
28+
NUM_ITERATIONS = 1000
29+
30+
31+
class Inplace(Enum):
32+
OUT_OF_PLACE = auto()
33+
INPLACE_X = auto()
34+
35+
36+
class ReluDescriptor(Structure):
37+
_fields_ = [("device", c_int32)]
38+
39+
40+
infiniopReluDescriptor_t = POINTER(ReluDescriptor)
41+
42+
43+
def relu(x):
44+
if PROFILE:
45+
ans = torch.nn.functional.relu(x).to(x.dtype)
46+
torch.cuda.synchronize()
47+
return ans
48+
return torch.nn.functional.relu(x).to(x.dtype)
49+
50+
51+
def test(
52+
lib,
53+
handle,
54+
torch_device,
55+
tensor_shape,
56+
tensor_dtype=torch.float16,
57+
inplace=Inplace.OUT_OF_PLACE,
58+
):
59+
print(
60+
f"Testing Relu on {torch_device} with tensor_shape:{tensor_shape} dtype:{tensor_dtype} inplace: {inplace.name}"
61+
)
62+
63+
x = torch.rand(tensor_shape, dtype=tensor_dtype).to(torch_device) * 2 - 1
64+
y = torch.rand(tensor_shape, dtype=tensor_dtype).to(torch_device) if inplace == Inplace.OUT_OF_PLACE else x
65+
66+
for i in range(NUM_PRERUN if PROFILE else 1):
67+
ans = relu(x)
68+
if PROFILE:
69+
start_time = time.time()
70+
for i in range(NUM_ITERATIONS):
71+
_ = relu(x)
72+
elapsed = (time.time() - start_time) / NUM_ITERATIONS
73+
print(f"pytorch time: {elapsed :6f}")
74+
75+
x_tensor = to_tensor(x, lib)
76+
y_tensor = to_tensor(y, lib) if inplace == Inplace.OUT_OF_PLACE else x_tensor
77+
descriptor = infiniopReluDescriptor_t()
78+
79+
check_error(
80+
lib.infiniopCreateReluDescriptor(
81+
handle,
82+
ctypes.byref(descriptor),
83+
y_tensor.descriptor,
84+
x_tensor.descriptor,
85+
)
86+
)
87+
for i in range(NUM_PRERUN if PROFILE else 1):
88+
lib.infiniopRelu(
89+
descriptor, y_tensor.data, x_tensor.data, None
90+
)
91+
if PROFILE:
92+
start_time = time.time()
93+
for i in range(NUM_ITERATIONS):
94+
lib.infiniopRelu(
95+
descriptor, y_tensor.data, x_tensor.data, None
96+
)
97+
elapsed = (time.time() - start_time) / NUM_ITERATIONS
98+
print(f" lib time: {elapsed :6f}")
99+
100+
assert torch.allclose(y, ans, atol=0, rtol=1e-3)
101+
check_error(lib.infiniopDestroyReluDescriptor(descriptor))
102+
103+
104+
def test_cpu(lib, test_cases):
105+
device = DeviceEnum.DEVICE_CPU
106+
handle = create_handle(lib, device)
107+
for tensor_shape, inplace in test_cases:
108+
test(lib, handle, "cpu", tensor_shape, tensor_dtype=torch.float16, inplace=inplace)
109+
test(lib, handle, "cpu", tensor_shape, tensor_dtype=torch.float32, inplace=inplace)
110+
destroy_handle(lib, handle)
111+
112+
113+
def test_cuda(lib, test_cases):
114+
device = DeviceEnum.DEVICE_CUDA
115+
handle = create_handle(lib, device)
116+
for tensor_shape, inplace in test_cases:
117+
test(lib, handle, "cuda", tensor_shape, tensor_dtype=torch.float16, inplace=inplace)
118+
test(lib, handle, "cuda", tensor_shape, tensor_dtype=torch.float32, inplace=inplace)
119+
destroy_handle(lib, handle)
120+
121+
122+
def test_bang(lib, test_cases):
123+
import torch_mlu
124+
125+
device = DeviceEnum.DEVICE_BANG
126+
handle = create_handle(lib, device)
127+
for tensor_shape, inplace in test_cases:
128+
test(lib, handle, "mlu", tensor_shape, tensor_dtype=torch.float16, inplace=inplace)
129+
test(lib, handle, "mlu", tensor_shape, tensor_dtype=torch.float32, inplace=inplace)
130+
destroy_handle(lib, handle)
131+
132+
133+
if __name__ == "__main__":
134+
test_cases = [
135+
# tensor_shape, inplace
136+
((), Inplace.OUT_OF_PLACE),
137+
((), Inplace.INPLACE_X),
138+
((1, 3), Inplace.OUT_OF_PLACE),
139+
((3, 3), Inplace.OUT_OF_PLACE),
140+
((3, 3, 13, 9, 17), Inplace.INPLACE_X),
141+
((32, 20, 512), Inplace.INPLACE_X),
142+
((33, 333, 333), Inplace.OUT_OF_PLACE),
143+
((32, 256, 112, 112), Inplace.OUT_OF_PLACE),
144+
]
145+
args = get_args()
146+
lib = open_lib()
147+
lib.infiniopCreateReluDescriptor.restype = c_int32
148+
lib.infiniopCreateReluDescriptor.argtypes = [
149+
infiniopHandle_t,
150+
POINTER(infiniopReluDescriptor_t),
151+
infiniopTensorDescriptor_t,
152+
infiniopTensorDescriptor_t,
153+
]
154+
lib.infiniopRelu.restype = c_int32
155+
lib.infiniopRelu.argtypes = [
156+
infiniopReluDescriptor_t,
157+
c_void_p,
158+
c_void_p,
159+
c_void_p,
160+
]
161+
lib.infiniopDestroyReluDescriptor.restype = c_int32
162+
lib.infiniopDestroyReluDescriptor.argtypes = [
163+
infiniopReluDescriptor_t,
164+
]
165+
166+
if args.cpu:
167+
test_cpu(lib, test_cases)
168+
if args.cuda:
169+
test_cuda(lib, test_cases)
170+
if args.bang:
171+
test_bang(lib, test_cases)
172+
if not (args.cpu or args.cuda or args.bang):
173+
test_cpu(lib, test_cases)
174+
print("\033[92mTest passed!\033[0m")
175+

src/ops/relu/cpu/relu_cpu.cc

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
#include "relu_cpu.h"
2+
#include "../../../devices/cpu/common_cpu.h"
3+
#include "../../utils.h"
4+
5+
infiniopStatus_t cpuCreateReluDescriptor(infiniopHandle_t,
6+
ReluCpuDescriptor_t *desc_ptr,
7+
infiniopTensorDescriptor_t y,
8+
infiniopTensorDescriptor_t x) {
9+
uint64_t ndim = y->ndim;
10+
if (ndim != x->ndim) {
11+
return STATUS_BAD_TENSOR_SHAPE;
12+
}
13+
for (size_t i = 0; i < ndim; ++i) {
14+
if (y->shape[i] != x->shape[i]) {
15+
return STATUS_BAD_TENSOR_SHAPE;
16+
}
17+
}
18+
if (!is_contiguous(y) || !is_contiguous(x)) {
19+
return STATUS_BAD_TENSOR_STRIDES;
20+
}
21+
if (y->dt != F16 && y->dt != F32) {
22+
return STATUS_BAD_TENSOR_DTYPE;
23+
}
24+
if (y->dt != x->dt) {
25+
return STATUS_BAD_TENSOR_DTYPE;
26+
}
27+
28+
uint64_t data_size = std::accumulate(y->shape, y->shape + y->ndim, 1ULL, std::multiplies<uint64_t>());
29+
30+
*desc_ptr = new ReluCpuDescriptor{
31+
DevCpu,
32+
y->dt,
33+
data_size,
34+
};
35+
36+
return STATUS_SUCCESS;
37+
}
38+
39+
infiniopStatus_t cpuDestroyReluDescriptor(ReluCpuDescriptor_t desc) {
40+
delete desc;
41+
return STATUS_SUCCESS;
42+
}
43+
44+
template<typename Tdata>
45+
infiniopStatus_t relu_cpu(ReluCpuDescriptor_t desc, void *y, void const *x) {
46+
auto x_ = reinterpret_cast<Tdata const *>(x);
47+
auto y_ = reinterpret_cast<Tdata *>(y);
48+
49+
#pragma omp parallel for
50+
for (uint64_t i = 0; i < desc->data_size; ++i) {
51+
if constexpr (std::is_same<Tdata, uint16_t>::value) {
52+
float x_f32 = f16_to_f32(x_[i]);
53+
y_[i] = f32_to_f16(x_f32 < 0 ? 0 : x_f32);
54+
} else {
55+
Tdata x_val = x_[i];
56+
y_[i] = x_val < 0 ? 0 : x_val;
57+
}
58+
}
59+
return STATUS_SUCCESS;
60+
}
61+
62+
infiniopStatus_t cpuRelu(ReluCpuDescriptor_t desc,
63+
void *y, void const *x,
64+
void *stream) {
65+
if (desc->dtype == F16) {
66+
return relu_cpu<uint16_t>(desc, y, x);
67+
}
68+
if (desc->dtype == F32) {
69+
return relu_cpu<float>(desc, y, x);
70+
}
71+
return STATUS_BAD_TENSOR_DTYPE;
72+
}

src/ops/relu/cpu/relu_cpu.h

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
#ifndef __CPU_RELU_H__
2+
#define __CPU_RELU_H__
3+
4+
#include "operators.h"
5+
#include <numeric>
6+
7+
struct ReluCpuDescriptor {
8+
Device device;
9+
DT dtype;
10+
uint64_t data_size;
11+
};
12+
13+
typedef struct ReluCpuDescriptor *ReluCpuDescriptor_t;
14+
15+
infiniopStatus_t cpuCreateReluDescriptor(infiniopHandle_t,
16+
ReluCpuDescriptor_t *,
17+
infiniopTensorDescriptor_t y,
18+
infiniopTensorDescriptor_t x);
19+
20+
infiniopStatus_t cpuRelu(ReluCpuDescriptor_t desc,
21+
void *y, void const *x,
22+
void *stream);
23+
24+
infiniopStatus_t cpuDestroyReluDescriptor(ReluCpuDescriptor_t desc);
25+
26+
#endif

src/ops/relu/cuda/relu.cc

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
#include "relu.cuh"
2+
#include "../../../devices/cuda/common_cuda.h"
3+
#include "../../utils.h"
4+
5+
infiniopStatus_t cudaCreateReluDescriptor(CudaHandle_t handle,
6+
ReluCudaDescriptor_t *desc_ptr,
7+
infiniopTensorDescriptor_t y,
8+
infiniopTensorDescriptor_t x) {
9+
uint64_t ndim = y->ndim;
10+
if (ndim != x->ndim) {
11+
return STATUS_BAD_TENSOR_SHAPE;
12+
}
13+
for (size_t i = 0; i < ndim; ++i) {
14+
if (y->shape[i] != x->shape[i]) {
15+
return STATUS_BAD_TENSOR_SHAPE;
16+
}
17+
}
18+
if (!is_contiguous(y) || !is_contiguous(x)) {
19+
return STATUS_BAD_TENSOR_STRIDES;
20+
}
21+
if (y->dt != F16 && y->dt != F32) {
22+
return STATUS_BAD_TENSOR_DTYPE;
23+
}
24+
if (y->dt != x->dt) {
25+
return STATUS_BAD_TENSOR_DTYPE;
26+
}
27+
28+
uint64_t data_size = std::accumulate(y->shape, y->shape + y->ndim, 1ULL, std::multiplies<uint64_t>());
29+
30+
*desc_ptr = new ReluCudaDescriptor{
31+
DevNvGpu,
32+
y->dt,
33+
handle->device_id,
34+
ndim,
35+
data_size,
36+
static_cast<uint64_t>(handle->prop.maxGridSize[0]),
37+
};
38+
39+
return STATUS_SUCCESS;
40+
}
41+
42+
infiniopStatus_t cudaDestroyReluDescriptor(ReluCudaDescriptor_t desc) {
43+
delete desc;
44+
return STATUS_SUCCESS;
45+
}

0 commit comments

Comments
 (0)