
Commit ef2e741

Merge branch 'dev' into add_gemm
2 parents 514cc27 + 49ee9f2 commit ef2e741


42 files changed (+2403, -592 lines)

include/infini_operators.h

Lines changed: 2 additions & 0 deletions
@@ -4,10 +4,12 @@
 #include "ops/causal_softmax/causal_softmax.h"
 #include "ops/expand/expand.h"
 #include "ops/gemm/gemm.h"
+#include "ops/conv/conv.h"
 #include "ops/matmul/matmul.h"
 #include "ops/mlp/mlp.h"
 #include "ops/random_sample/random_sample.h"
 #include "ops/rearrange/rearrange.h"
+#include "ops/relu/relu.h"
 #include "ops/rms_norm/rms_norm.h"
 #include "ops/rotary_embedding/rotary_embedding.h"
 #include "ops/swiglu/swiglu.h"

include/ops/conv/conv.h

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
#ifndef CONV_H
2+
#define CONV_H
3+
4+
#include "../../export.h"
5+
#include "../../operators.h"
6+
7+
typedef struct ConvDescriptor {
8+
Device device;
9+
} ConvDescriptor;
10+
11+
typedef ConvDescriptor *infiniopConvDescriptor_t;
12+
13+
__C __export infiniopStatus_t infiniopCreateConvDescriptor(infiniopHandle_t handle,
14+
infiniopConvDescriptor_t *desc_ptr,
15+
infiniopTensorDescriptor_t y,
16+
infiniopTensorDescriptor_t x,
17+
infiniopTensorDescriptor_t w,
18+
void *pads,
19+
void *strides,
20+
void *dilations,
21+
uint64_t n);
22+
23+
__C __export infiniopStatus_t infiniopGetConvWorkspaceSize(infiniopConvDescriptor_t desc, uint64_t *size);
24+
25+
__C __export infiniopStatus_t infiniopConv(infiniopConvDescriptor_t desc, void *workspace, uint64_t workspace_size, void *y, void const *x, void const *w, void *stream);
26+
27+
__C __export infiniopStatus_t infiniopDestroyConvDescriptor(infiniopConvDescriptor_t desc);
28+
29+
30+
#endif
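For reference, here is a minimal sketch of the call sequence this header implies, mirroring what the Python test later in this commit does via ctypes. The handle, tensor descriptors, device buffers, and the `run_conv` wrapper are hypothetical (assumed to come from the library's existing handle/tensor APIs); the test marshals pads/strides/dilations as 64-bit integer arrays, which is what the sketch assumes.

```c
// Sketch only: handle, y_desc/x_desc/w_desc, and the y/x/w buffers are assumed
// to exist already; checking of the returned status codes is omitted.
#include "ops/conv/conv.h"
#include <stdint.h>
#include <stdlib.h>

void run_conv(infiniopHandle_t handle,
              infiniopTensorDescriptor_t y_desc,
              infiniopTensorDescriptor_t x_desc,
              infiniopTensorDescriptor_t w_desc,
              void *y_data, void const *x_data, void const *w_data,
              void *stream) {
    // One entry per spatial dimension (2D conv here), passed as int64 arrays
    // like the Python test does with ctypes.c_int64.
    int64_t pads[]      = {1, 1};
    int64_t strides[]   = {1, 1};
    int64_t dilations[] = {1, 1};

    infiniopConvDescriptor_t desc;
    infiniopCreateConvDescriptor(handle, &desc, y_desc, x_desc, w_desc,
                                 pads, strides, dilations, 2 /* spatial dims */);

    // Query and allocate the scratch space the backend needs
    // (on GPU backends this would be a device allocation, not malloc).
    uint64_t workspace_size = 0;
    infiniopGetConvWorkspaceSize(desc, &workspace_size);
    void *workspace = malloc(workspace_size);

    infiniopConv(desc, workspace, workspace_size, y_data, x_data, w_data, stream);

    infiniopDestroyConvDescriptor(desc);
    free(workspace);
}
```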

include/ops/relu/relu.h

Lines changed: 25 additions & 0 deletions
@@ -0,0 +1,25 @@
+#ifndef RELU_H
+#define RELU_H
+
+#include "../../export.h"
+#include "../../operators.h"
+
+typedef struct ReluDescriptor {
+    Device device;
+} ReluDescriptor;
+
+typedef ReluDescriptor *infiniopReluDescriptor_t;
+
+__C __export infiniopStatus_t infiniopCreateReluDescriptor(infiniopHandle_t handle,
+                                                           infiniopReluDescriptor_t *desc_ptr,
+                                                           infiniopTensorDescriptor_t y,
+                                                           infiniopTensorDescriptor_t x);
+
+__C __export infiniopStatus_t infiniopRelu(infiniopReluDescriptor_t desc,
+                                           void *y,
+                                           void const *x,
+                                           void *stream);
+
+__C __export infiniopStatus_t infiniopDestroyReluDescriptor(infiniopReluDescriptor_t desc);
+
+#endif
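The ReLU entry points follow the same descriptor pattern as conv but, being element-wise, expose no workspace query. A minimal sketch under the same assumptions as above (handle, descriptors, and buffers created elsewhere):

```c
// Sketch only: create -> run -> destroy, with no workspace step for this op.
infiniopReluDescriptor_t relu_desc;
infiniopCreateReluDescriptor(handle, &relu_desc, y_desc, x_desc);
infiniopRelu(relu_desc, y_data, x_data, NULL /* stream */);
infiniopDestroyReluDescriptor(relu_desc);
```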

include/ops/rms_norm/rms_norm.h

Lines changed: 1 addition & 1 deletion
@@ -21,7 +21,7 @@ __C __export infiniopStatus_t infiniopCreateRMSNormDescriptor(
 __C __export infiniopStatus_t infiniopGetRMSNormWorkspaceSize(infiniopRMSNormDescriptor_t desc, uint64_t *size);
 
 __C __export infiniopStatus_t infiniopRMSNorm(infiniopRMSNormDescriptor_t desc, void *workspace, uint64_t workspace_size,
-                                              void *y, void *x, void *w, void *stream);
+                                              void *y, void const *x, void const *w, void *stream);
 
 __C __export infiniopStatus_t infiniopDestroyRMSNormDescriptor(infiniopRMSNormDescriptor_t desc);
 

include/ops/swiglu/swiglu.h

Lines changed: 0 additions & 7 deletions
@@ -24,11 +24,4 @@ __C __export infiniopStatus_t infiniopSwiGLU(infiniopSwiGLUDescriptor_t desc,
 
 __C __export infiniopStatus_t infiniopDestroySwiGLUDescriptor(infiniopSwiGLUDescriptor_t desc);
 
-// // @deprecated
-// __C __export void *createSwigluDescriptor(Device, void *config);
-// // @deprecated
-// __C __export void destroySwigluDescriptor(SwigluDescriptor *descriptor);
-// // @deprecated
-// __C __export void swiglu(SwigluDescriptor *descriptor, Tensor gate, Tensor up, void *stream);
-
 #endif

include/tensor.h

Lines changed: 0 additions & 8 deletions
@@ -17,12 +17,4 @@ struct TensorDescriptor {
 
 typedef struct TensorDescriptor *infiniopTensorDescriptor_t;
 
-// @depricated
-struct TensorTuple {
-    infiniopTensorDescriptor_t const layout;
-    void *data;
-};
-// @depricated
-typedef struct TensorTuple Tensor;
-
 #endif// __TENSOR_H__

operatorspy/tests/conv.py

Lines changed: 288 additions & 0 deletions
@@ -0,0 +1,288 @@
+from ctypes import POINTER, Structure, c_int32, c_uint64, c_void_p
+import ctypes
+import sys
+import os
+import time
+
+sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..")))
+from operatorspy import (
+    open_lib,
+    to_tensor,
+    DeviceEnum,
+    infiniopHandle_t,
+    infiniopTensorDescriptor_t,
+    create_handle,
+    destroy_handle,
+    check_error,
+)
+
+from operatorspy.tests.test_utils import get_args
+import torch
+import math
+import ctypes
+from torch.nn import functional as F
+from typing import List, Tuple
+
+# constants controlling whether to profile the pytorch and lib functions
+# NOTE: need to manually add a synchronization function to the lib function,
+# e.g., cudaDeviceSynchronize() for CUDA
+PROFILE = False
+NUM_PRERUN = 10
+NUM_ITERATIONS = 1000
+
+
+class ConvDescriptor(Structure):
+    _fields_ = [("device", c_int32)]
+
+
+infiniopConvDescriptor_t = POINTER(ConvDescriptor)
+
+
+def conv(x, w, stride, padding, dilation):
+    match len(x.shape) - 2:
+        case 1:
+            return F.conv1d(
+                x, w, stride=stride, padding=padding, dilation=dilation
+            )
+        case 2:
+            return F.conv2d(
+                x, w, stride=stride, padding=padding, dilation=dilation
+            )
+        case 3:
+            return F.conv3d(
+                x, w, stride=stride, padding=padding, dilation=dilation
+            )
+        case _:
+            print("Error: Pytorch -> Unsupported tensor dimension")
+            return None
+
+
+# infer the shape of the output given the inputs for an N-ary convolution
+def inferShape(
+    x_shape: List[int],
+    w_shape: List[int],
+    pads: List[int],
+    strides: List[int],
+    dilations: List[int],
+) -> Tuple[int, ...]:
+    assert (
+        len(x_shape) == len(w_shape) == len(pads) + 2 == len(dilations) + 2 == len(strides) + 2
+    ), "x and w should have the same length; pads, strides, and dilations should have the same length; the length of pads should be that of x - 2"
+    output_dims = [
+        math.floor(
+            (x_shape[i + 2] + 2 * pads[i] - dilations[i] * (w_shape[i + 2] - 1) - 1)
+            / strides[i]
+            + 1
+        )
+        for i in range(len(pads))
+    ]
+    return (x_shape[0], w_shape[0]) + tuple(output_dims)
+
+
+# convert a python tuple to a ctypes void pointer
+def tuple_to_void_p(py_tuple: Tuple):
+    array = ctypes.c_int64 * len(py_tuple)
+    data_array = array(*py_tuple)
+    return ctypes.cast(data_array, ctypes.c_void_p)
+
+
+def test(
+    lib,
+    handle,
+    torch_device,
+    x_shape,
+    w_shape,
+    pads,
+    strides,
+    dilations,
+    tensor_stride=None,
+    tensor_dtype=torch.float16,
+):
+    assert len(pads) == len(strides) == len(dilations)
+    print(
+        f"Testing Conv on {torch_device} with x_shape: {x_shape}, w_shape: {w_shape}, b_shape: {w_shape[0]}, pads: {pads}, strides: {strides}, dilations: {dilations}, x_stride: {tensor_stride}, dtype: {tensor_dtype}"
+    )
+    x = torch.rand(x_shape, dtype=tensor_dtype).to(torch_device)
+    w = torch.rand(w_shape, dtype=tensor_dtype).to(torch_device)
+    y = torch.zeros(
+        inferShape(x.shape, w.shape, pads, strides, dilations), dtype=tensor_dtype
+    ).to(torch_device)
+
+    for i in range(NUM_PRERUN if PROFILE else 1):
+        ans = conv(x, w, strides, pads, dilations)
+    if PROFILE:
+        start_time = time.time()
+        for i in range(NUM_ITERATIONS):
+            _ = conv(x, w, strides, pads, dilations)
+        elapsed = (time.time() - start_time) / NUM_ITERATIONS
+        print(f"pytorch time: {elapsed:.6f}")
+
+    x_tensor = to_tensor(x, lib)
+    w_tensor = to_tensor(w, lib)
+    y_tensor = to_tensor(y, lib)
+    descriptor = infiniopConvDescriptor_t()
+
+    check_error(
+        lib.infiniopCreateConvDescriptor(
+            handle,
+            ctypes.byref(descriptor),
+            y_tensor.descriptor,
+            x_tensor.descriptor,
+            w_tensor.descriptor,
+            tuple_to_void_p(pads),
+            tuple_to_void_p(strides),
+            tuple_to_void_p(dilations),
+            len(pads),
+        )
+    )
+    workspaceSize = ctypes.c_uint64(0)
+    check_error(
+        lib.infiniopGetConvWorkspaceSize(descriptor, ctypes.byref(workspaceSize))
+    )
+    workspace = torch.zeros(int(workspaceSize.value), dtype=torch.uint8).to(torch_device)
+    workspace_ptr = ctypes.cast(workspace.data_ptr(), ctypes.POINTER(ctypes.c_uint8))
+
+    for i in range(NUM_PRERUN if PROFILE else 1):
+        lib.infiniopConv(
+            descriptor,
+            workspace_ptr,
+            workspaceSize,
+            y_tensor.data,
+            x_tensor.data,
+            w_tensor.data,
+            None,
+        )
+    if PROFILE:
+        start_time = time.time()
+        for i in range(NUM_ITERATIONS):
+            lib.infiniopConv(
+                descriptor,
+                workspace_ptr,
+                workspaceSize,
+                y_tensor.data,
+                x_tensor.data,
+                w_tensor.data,
+                None,
+            )
+        elapsed = (time.time() - start_time) / NUM_ITERATIONS
+        print(f"    lib time: {elapsed:.6f}")
+
+    if tensor_dtype == torch.float16:
+        assert torch.allclose(y, ans, atol=0, rtol=1e-2)
+    else:
+        assert torch.allclose(y, ans, atol=0, rtol=1e-3)
+    check_error(lib.infiniopDestroyConvDescriptor(descriptor))
+
+
+def test_cpu(lib, test_cases):
+    device = DeviceEnum.DEVICE_CPU
+    handle = create_handle(lib, device)
+    for x_shape, w_shape, pads, strides, dilations, x_strides in test_cases:
+        test(lib, handle, "cpu", x_shape, w_shape, pads, strides, dilations, x_strides, tensor_dtype=torch.float16)
+        test(lib, handle, "cpu", x_shape, w_shape, pads, strides, dilations, x_strides, tensor_dtype=torch.float32)
+    destroy_handle(lib, handle)
+
+
+def test_cuda(lib, test_cases):
+    device = DeviceEnum.DEVICE_CUDA
+    handle = create_handle(lib, device)
+    for x_shape, w_shape, pads, strides, dilations, x_strides in test_cases:
+        test(lib, handle, "cuda", x_shape, w_shape, pads, strides, dilations, x_strides, tensor_dtype=torch.float16)
+        test(lib, handle, "cuda", x_shape, w_shape, pads, strides, dilations, x_strides, tensor_dtype=torch.float32)
+    destroy_handle(lib, handle)
+
+
+def test_bang(lib, test_cases):
+    import torch_mlu
+
+    device = DeviceEnum.DEVICE_BANG
+    handle = create_handle(lib, device)
+    for x_shape, w_shape, pads, strides, dilations, x_strides in test_cases:
+        test(lib, handle, "mlu", x_shape, w_shape, pads, strides, dilations, x_strides, tensor_dtype=torch.float16)
+        test(lib, handle, "mlu", x_shape, w_shape, pads, strides, dilations, x_strides, tensor_dtype=torch.float32)
+    destroy_handle(lib, handle)
+
+
+if __name__ == "__main__":
+    test_cases = [
+        # x_shape, w_shape, pads, strides, dilations, x_strides
+        (
+            (32, 3, 4),
+            (32, 3, 5),
+            (1,),
+            (1,),
+            (1,),
+            None,
+        ),
+        (
+            (1, 3, 4, 4),
+            (2, 3, 3, 3),
+            (1, 1),
+            (1, 2),
+            (2, 1),
+            None,
+        ),
+        (
+            (32, 3, 128, 128),
+            (64, 3, 5, 5),
+            (2, 2),
+            (2, 2),
+            (1, 1),
+            None,
+        ),
+        (
+            (1, 1, 4, 4, 4),
+            (1, 1, 5, 5, 5),
+            (1, 1, 1),
+            (1, 1, 1),
+            (1, 1, 1),
+            None,
+        ),
+        (
+            (32, 3, 32, 32, 32),
+            (64, 3, 5, 5, 5),
+            (3, 2, 2),
+            (4, 3, 3),
+            (2, 2, 1),
+            None,
+        ),
+    ]
+    args = get_args()
+    lib = open_lib()
+    lib.infiniopCreateConvDescriptor.restype = c_int32
+    lib.infiniopCreateConvDescriptor.argtypes = [
+        infiniopHandle_t,
+        POINTER(infiniopConvDescriptor_t),
+        infiniopTensorDescriptor_t,
+        infiniopTensorDescriptor_t,
+        infiniopTensorDescriptor_t,
+        c_void_p,
+        c_void_p,
+        c_void_p,
+        c_uint64,
+    ]
+    lib.infiniopConv.restype = c_int32
+    lib.infiniopConv.argtypes = [
+        infiniopConvDescriptor_t,
+        c_void_p,
+        c_uint64,
+        c_void_p,
+        c_void_p,
+        c_void_p,
+        c_void_p,
+    ]
+    lib.infiniopDestroyConvDescriptor.restype = c_int32
+    lib.infiniopDestroyConvDescriptor.argtypes = [
+        infiniopConvDescriptor_t,
+    ]
+
+    if args.cpu:
+        test_cpu(lib, test_cases)
+    if args.cuda:
+        test_cuda(lib, test_cases)
+    if args.bang:
+        test_bang(lib, test_cases)
+    if not (args.cpu or args.cuda or args.bang):
+        test_cpu(lib, test_cases)
+    print("\033[92mTest passed!\033[0m")
