
Commit 1e74a5f

Merge pull request #93 from PanZezhong1725/add_gemm
Add GEMM & Expand
2 parents 17ce764 + 82de992 · commit 1e74a5f

File tree

24 files changed: +1197 −93 lines


include/infini_operators.h

Lines changed: 2 additions & 0 deletions
@@ -2,6 +2,8 @@
 #include "ops/add/add.h"
 #include "ops/attention/attention.h"
 #include "ops/causal_softmax/causal_softmax.h"
+#include "ops/expand/expand.h"
+#include "ops/gemm/gemm.h"
 #include "ops/conv/conv.h"
 #include "ops/matmul/matmul.h"
 #include "ops/mlp/mlp.h"

include/ops/expand/expand.h

Lines changed: 25 additions & 0 deletions
#ifndef EXPAND_H
#define EXPAND_H

#include "../../export.h"
#include "../../operators.h"

typedef struct ExpandDescriptor {
    Device device;
} ExpandDescriptor;

typedef ExpandDescriptor *infiniopExpandDescriptor_t;

__C __export infiniopStatus_t infiniopCreateExpandDescriptor(infiniopHandle_t handle,
                                                             infiniopExpandDescriptor_t *desc_ptr,
                                                             infiniopTensorDescriptor_t y,
                                                             infiniopTensorDescriptor_t x);

__C __export infiniopStatus_t infiniopExpand(infiniopExpandDescriptor_t desc,
                                             void *y,
                                             void const *x,
                                             void *stream);

__C __export infiniopStatus_t infiniopDestroyExpandDescriptor(infiniopExpandDescriptor_t desc);

#endif
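
For context, here is a minimal usage sketch of the Expand API declared above. It is not part of this commit: the handle, the tensor descriptors y_desc/x_desc, and the data buffers are assumed to have been created already with the existing infiniop APIs, and status checks are omitted.

#include "ops/expand/expand.h"

/* Hypothetical helper, for illustration only: broadcasts x into y according
 * to the shapes recorded in the two descriptors. Error handling omitted. */
static void run_expand(infiniopHandle_t handle,
                       infiniopTensorDescriptor_t y_desc,
                       infiniopTensorDescriptor_t x_desc,
                       void *y, void const *x, void *stream) {
    infiniopExpandDescriptor_t desc;
    infiniopCreateExpandDescriptor(handle, &desc, y_desc, x_desc);
    infiniopExpand(desc, y, x, stream); /* stream may be NULL on CPU */
    infiniopDestroyExpandDescriptor(desc);
}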

include/ops/gemm/gemm.h

Lines changed: 36 additions & 0 deletions
#ifndef GEMM_H
#define GEMM_H

#include "../../export.h"
#include "../../operators.h"

typedef struct GEMMDescriptor {
    Device device;
} GEMMDescriptor;

typedef GEMMDescriptor *infiniopGEMMDescriptor_t;

__C __export infiniopStatus_t infiniopCreateGEMMDescriptor(infiniopHandle_t handle,
                                                           infiniopGEMMDescriptor_t *desc_ptr,
                                                           infiniopTensorDescriptor_t y_desc,
                                                           infiniopTensorDescriptor_t a_desc,
                                                           infiniopTensorDescriptor_t b_desc,
                                                           infiniopTensorDescriptor_t c_desc,
                                                           float alpha,
                                                           float beta,
                                                           bool transA,
                                                           bool transB);

__C __export infiniopStatus_t infiniopGetGEMMWorkspaceSize(infiniopGEMMDescriptor_t desc, uint64_t *size);

__C __export infiniopStatus_t infiniopGEMM(infiniopGEMMDescriptor_t desc,
                                           void *workspace,
                                           uint64_t workspace_size,
                                           void *y,
                                           void const *a,
                                           void const *b,
                                           void const *c,
                                           void *stream);

__C __export infiniopStatus_t infiniopDestroyGEMMDescriptor(infiniopGEMMDescriptor_t desc);
#endif
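
Likewise, a minimal usage sketch of the GEMM API declared above, not part of this commit. The alpha/beta/transA/transB parameters suggest the conventional y = alpha * op(A) * op(B) + beta * C semantics; the exact behaviour is defined by the implementation files in this pull request. The handle, descriptors, and buffers are assumed to exist already, and error checks are omitted.

#include <stdlib.h>
#include "ops/gemm/gemm.h"

/* Hypothetical helper, for illustration only: queries the workspace size,
 * runs one GEMM, and releases the descriptor. On GPU backends the workspace
 * would need to be device memory rather than malloc'd host memory. */
static void run_gemm(infiniopHandle_t handle,
                     infiniopTensorDescriptor_t y_desc,
                     infiniopTensorDescriptor_t a_desc,
                     infiniopTensorDescriptor_t b_desc,
                     infiniopTensorDescriptor_t c_desc,
                     void *y, void const *a, void const *b, void const *c,
                     void *stream) {
    infiniopGEMMDescriptor_t desc;
    infiniopCreateGEMMDescriptor(handle, &desc, y_desc, a_desc, b_desc, c_desc,
                                 1.0f /* alpha */, 1.0f /* beta */,
                                 false /* transA */, false /* transB */);

    uint64_t workspace_size = 0;
    infiniopGetGEMMWorkspaceSize(desc, &workspace_size);
    void *workspace = malloc(workspace_size);

    infiniopGEMM(desc, workspace, workspace_size, y, a, b, c, stream);

    free(workspace);
    infiniopDestroyGEMMDescriptor(desc);
}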

operatorspy/tests/expand.py

Lines changed: 177 additions & 0 deletions
from ctypes import POINTER, Structure, c_int32, c_void_p
import ctypes
import sys
import os
import time

sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..")))
from operatorspy import (
    open_lib,
    to_tensor,
    DeviceEnum,
    infiniopHandle_t,
    infiniopTensorDescriptor_t,
    create_handle,
    destroy_handle,
    check_error,
    rearrange_tensor,
)

from operatorspy.tests.test_utils import get_args
import torch

# Constants controlling whether to profile the PyTorch and lib functions.
# NOTE: to profile the lib function, a synchronization call must be added
# manually, e.g., cudaDeviceSynchronize() for CUDA.
PROFILE = False
NUM_PRERUN = 10
NUM_ITERATIONS = 1000


class ExpandDescriptor(Structure):
    _fields_ = [("device", c_int32)]


infiniopExpandDescriptor_t = POINTER(ExpandDescriptor)


def expand(x, y):
    if PROFILE:
        ans = x.expand_as(y).clone()
        torch.cuda.synchronize()
        return ans
    return x.expand_as(y)


def test(
    lib,
    handle,
    torch_device,
    y_shape,
    x_shape,
    y_stride=None,
    x_stride=None,
    tensor_dtype=torch.float16,
):
    print(
        f"Testing Expand on {torch_device} with x_shape:{x_shape} y_shape:{y_shape} x_stride:{x_stride} y_stride:{y_stride} dtype:{tensor_dtype}"
    )

    x = torch.rand(x_shape, dtype=tensor_dtype).to(torch_device)
    y = torch.rand(y_shape, dtype=tensor_dtype).to(torch_device)

    if x_stride is not None:
        x = rearrange_tensor(x, x_stride)
    if y_stride is not None:
        y = rearrange_tensor(y, y_stride)

    for i in range(NUM_PRERUN if PROFILE else 1):
        ans = expand(x, y)
    if PROFILE:
        start_time = time.time()
        for i in range(NUM_ITERATIONS):
            _ = expand(x, y)
        elapsed = (time.time() - start_time) / NUM_ITERATIONS
        print(f"pytorch time: {elapsed :6f}")

    x_tensor = to_tensor(x, lib)
    y_tensor = to_tensor(y, lib)
    descriptor = infiniopExpandDescriptor_t()

    check_error(
        lib.infiniopCreateExpandDescriptor(
            handle,
            ctypes.byref(descriptor),
            y_tensor.descriptor,
            x_tensor.descriptor,
        )
    )

    for i in range(NUM_PRERUN if PROFILE else 1):
        lib.infiniopExpand(
            descriptor, y_tensor.data, x_tensor.data, None
        )
    if PROFILE:
        start_time = time.time()
        for i in range(NUM_ITERATIONS):
            lib.infiniopExpand(
                descriptor, y_tensor.data, x_tensor.data, None
            )
        elapsed = (time.time() - start_time) / NUM_ITERATIONS
        print(f"    lib time: {elapsed :6f}")
    assert torch.allclose(y, ans, atol=0, rtol=1e-3)
    check_error(lib.infiniopDestroyExpandDescriptor(descriptor))


def test_cpu(lib, test_cases):
    device = DeviceEnum.DEVICE_CPU
    handle = create_handle(lib, device)
    for y_shape, x_shape, y_stride, x_stride in test_cases:
        test(lib, handle, "cpu", y_shape, x_shape, y_stride, x_stride, tensor_dtype=torch.float16)
        test(lib, handle, "cpu", y_shape, x_shape, y_stride, x_stride, tensor_dtype=torch.float32)
    destroy_handle(lib, handle)


def test_cuda(lib, test_cases):
    device = DeviceEnum.DEVICE_CUDA
    handle = create_handle(lib, device)
    for y_shape, x_shape, y_stride, x_stride in test_cases:
        test(lib, handle, "cuda", y_shape, x_shape, y_stride, x_stride, tensor_dtype=torch.float16)
        test(lib, handle, "cuda", y_shape, x_shape, y_stride, x_stride, tensor_dtype=torch.float32)
    destroy_handle(lib, handle)


def test_bang(lib, test_cases):
    import torch_mlu

    device = DeviceEnum.DEVICE_BANG
    handle = create_handle(lib, device)
    for y_shape, x_shape, y_stride, x_stride in test_cases:
        test(lib, handle, "mlu", y_shape, x_shape, y_stride, x_stride, tensor_dtype=torch.float16)
        test(lib, handle, "mlu", y_shape, x_shape, y_stride, x_stride, tensor_dtype=torch.float32)
    destroy_handle(lib, handle)


if __name__ == "__main__":
    test_cases = [
        # y_shape, x_shape, y_stride, x_stride
        ((), (), None, None),
        ((3, 3), (1,), None, None),
        ((5, 4, 3), (4, 3,), None, (6, 1)),
        ((99, 111), (111,), None, None),
        ((2, 4, 3), (1, 3), None, None),
        ((2, 20, 3), (2, 1, 3), None, None),
        ((2, 3, 4, 5), (5,), None, None),
        ((3, 2, 4, 5), (3, 2, 1, 1), None, None),
        ((32, 256, 112, 112), (32, 256, 112, 1), None, None),
    ]
    args = get_args()
    lib = open_lib()
    lib.infiniopCreateExpandDescriptor.restype = c_int32
    lib.infiniopCreateExpandDescriptor.argtypes = [
        infiniopHandle_t,
        POINTER(infiniopExpandDescriptor_t),
        infiniopTensorDescriptor_t,
        infiniopTensorDescriptor_t,
    ]
    lib.infiniopExpand.restype = c_int32
    lib.infiniopExpand.argtypes = [
        infiniopExpandDescriptor_t,
        c_void_p,
        c_void_p,
        c_void_p,
    ]
    lib.infiniopDestroyExpandDescriptor.restype = c_int32
    lib.infiniopDestroyExpandDescriptor.argtypes = [
        infiniopExpandDescriptor_t,
    ]

    if args.cpu:
        test_cpu(lib, test_cases)
    if args.cuda:
        test_cuda(lib, test_cases)
    if args.bang:
        test_bang(lib, test_cases)
    if not (args.cpu or args.cuda or args.bang):
        test_cpu(lib, test_cases)
    print("\033[92mTest passed!\033[0m")
