
Commit fdbf030

Merge pull request #89 from PanZezhong1725/add_global_avg_pool
Add Global Average Pool
2 parents 1e74a5f + 5e26117

9 files changed: +1089 −0 lines changed
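
For reference, a global average pool collapses every dimension after the batch (N) and channel (C) dimensions to 1, producing one mean per (n, c) pair; this is the computation the CPU kernel in this commit performs and that the test script checks against torch.mean. A minimal stand-alone C++ sketch of that semantics on a contiguous buffer (illustration only, not part of this diff):

    // Global average pooling over a contiguous (N*C) x spatial buffer:
    // y has one element per (n, c) pair, the mean of its `spatial` inputs.
    #include <cstdint>
    #include <vector>

    std::vector<float> global_avg_pool_ref(const std::vector<float> &x,
                                           uint64_t nc,        // N * C
                                           uint64_t spatial) { // product of the remaining dims
        std::vector<float> y(nc, 0.0f);
        for (uint64_t i = 0; i < nc; ++i) {
            double sum = 0.0; // accumulate in double for accuracy
            for (uint64_t j = 0; j < spatial; ++j) {
                sum += x[i * spatial + j];
            }
            y[i] = static_cast<float>(sum / spatial);
        }
        return y;
    }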

include/infini_operators.h

Lines changed: 1 addition & 0 deletions
@@ -2,6 +2,7 @@
 #include "ops/add/add.h"
 #include "ops/attention/attention.h"
 #include "ops/causal_softmax/causal_softmax.h"
+#include "ops/global_avg_pool/global_avg_pool.h"
 #include "ops/expand/expand.h"
 #include "ops/gemm/gemm.h"
 #include "ops/conv/conv.h"
include/ops/global_avg_pool/global_avg_pool.h

Lines changed: 26 additions & 0 deletions
@@ -0,0 +1,26 @@
#ifndef GLOBAL_AVG_POOL_H
#define GLOBAL_AVG_POOL_H

#include "../../export.h"
#include "../../operators.h"

typedef struct GlobalAvgPoolDescriptor {
    Device device;
} GlobalAvgPoolDescriptor;

typedef GlobalAvgPoolDescriptor *infiniopGlobalAvgPoolDescriptor_t;

__C __export infiniopStatus_t infiniopCreateGlobalAvgPoolDescriptor(infiniopHandle_t handle,
                                                                    infiniopGlobalAvgPoolDescriptor_t *desc_ptr,
                                                                    infiniopTensorDescriptor_t y,
                                                                    infiniopTensorDescriptor_t x);

__C __export infiniopStatus_t infiniopGetGlobalAvgPoolWorkspaceSize(infiniopGlobalAvgPoolDescriptor_t desc, uint64_t *size);

__C __export infiniopStatus_t infiniopGlobalAvgPool(infiniopGlobalAvgPoolDescriptor_t desc,
                                                    void *workspace, uint64_t workspace_size,
                                                    void *y, void const *x, void *stream);

__C __export infiniopStatus_t infiniopDestroyGlobalAvgPoolDescriptor(infiniopGlobalAvgPoolDescriptor_t desc);

#endif
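
Taken together, these four entry points follow the usual descriptor/workspace pattern. Below is a hedged sketch of the call sequence; the helper name run_global_avg_pool is illustrative, the handle, tensor descriptors, and data buffers are assumed to be created elsewhere (those helpers are not part of this diff), and malloc is used for the workspace, which is only appropriate for the CPU backend:

    // Sketch of the intended call sequence (assumptions noted above; not repo code).
    #include <cstdint>
    #include <cstdlib>
    #include "ops/global_avg_pool/global_avg_pool.h"

    infiniopStatus_t run_global_avg_pool(infiniopHandle_t handle,
                                         infiniopTensorDescriptor_t y_desc,
                                         infiniopTensorDescriptor_t x_desc,
                                         void *y_data, void const *x_data,
                                         void *stream) {
        infiniopGlobalAvgPoolDescriptor_t desc;
        infiniopStatus_t status =
            infiniopCreateGlobalAvgPoolDescriptor(handle, &desc, y_desc, x_desc);
        if (status != STATUS_SUCCESS) return status;

        uint64_t workspace_size = 0;
        status = infiniopGetGlobalAvgPoolWorkspaceSize(desc, &workspace_size);
        if (status != STATUS_SUCCESS) {
            infiniopDestroyGlobalAvgPoolDescriptor(desc);
            return status;
        }

        // Device memory would be needed here on GPU backends.
        void *workspace = workspace_size ? std::malloc(workspace_size) : nullptr;
        status = infiniopGlobalAvgPool(desc, workspace, workspace_size,
                                       y_data, x_data, stream);

        std::free(workspace);
        infiniopDestroyGlobalAvgPoolDescriptor(desc);
        return status;
    }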
Lines changed: 197 additions & 0 deletions
@@ -0,0 +1,197 @@
from ctypes import POINTER, Structure, c_int32, c_void_p, c_uint64
import ctypes
import sys
import os
import time

sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..")))
from operatorspy import (
    open_lib,
    to_tensor,
    DeviceEnum,
    infiniopHandle_t,
    infiniopTensorDescriptor_t,
    create_handle,
    destroy_handle,
    check_error,
)

from operatorspy.tests.test_utils import get_args
import torch

# Constants that control whether to profile the PyTorch and lib functions.
# NOTE: a synchronization call (e.g., cudaDeviceSynchronize() for CUDA) must be
# added to the lib function manually when profiling.
PROFILE = False
NUM_PRERUN = 10
NUM_ITERATIONS = 1000


class GlobalAvgPoolDescriptor(Structure):
    _fields_ = [("device", c_int32)]


infiniopGlobalAvgPoolDescriptor_t = POINTER(GlobalAvgPoolDescriptor)


def inferShape(x):
    # output shape: keep N and C, reduce every remaining dimension to 1
    return x.shape[:2] + (1,) * (x.dim() - 2)


def globalAvgPool(x):
    # PyTorch reference: mean over every dimension after N and C
    y = torch.mean(x, dim=tuple(range(2, x.dim())), keepdim=True)
    if PROFILE:
        torch.cuda.synchronize()
    return y.view(*inferShape(x))


def test(
    lib,
    handle,
    torch_device,
    x_shape,
    tensor_dtype=torch.float16,
):
    print(
        f"Testing GlobalAvgPool on {torch_device} with input tensor_shape: {x_shape} dtype: {tensor_dtype}"
    )

    x = torch.rand(x_shape, dtype=tensor_dtype).to(torch_device)
    y = torch.zeros(inferShape(x), dtype=tensor_dtype).to(torch_device)

    for i in range(NUM_PRERUN if PROFILE else 1):
        ans = globalAvgPool(x)
    if PROFILE:
        start_time = time.time()
        for i in range(NUM_ITERATIONS):
            _ = globalAvgPool(x)
        elapsed = (time.time() - start_time) / NUM_ITERATIONS
        print(f"pytorch time: {elapsed :6f}")

    x_tensor = to_tensor(x, lib)
    y_tensor = to_tensor(y, lib)
    descriptor = infiniopGlobalAvgPoolDescriptor_t()

    check_error(
        lib.infiniopCreateGlobalAvgPoolDescriptor(
            handle,
            ctypes.byref(descriptor),
            y_tensor.descriptor,
            x_tensor.descriptor,
        )
    )
    workspaceSize = ctypes.c_uint64(0)
    check_error(
        lib.infiniopGetGlobalAvgPoolWorkspaceSize(
            descriptor, ctypes.byref(workspaceSize)
        )
    )
    workspace = torch.zeros(int(workspaceSize.value), dtype=torch.uint8).to(
        torch_device
    )
    workspace_ptr = ctypes.cast(workspace.data_ptr(), ctypes.POINTER(ctypes.c_uint8))

    for i in range(NUM_PRERUN if PROFILE else 1):
        check_error(
            lib.infiniopGlobalAvgPool(
                descriptor, workspace_ptr, workspaceSize, y_tensor.data, x_tensor.data, None
            )
        )
    if PROFILE:
        start_time = time.time()
        for i in range(NUM_ITERATIONS):
            lib.infiniopGlobalAvgPool(
                descriptor, workspace_ptr, workspaceSize, y_tensor.data, x_tensor.data, None
            )
        elapsed = (time.time() - start_time) / NUM_ITERATIONS
        print(f" lib time: {elapsed :6f}")

    # compare the library output against the PyTorch reference
    assert torch.allclose(y, ans, atol=0, rtol=1e-3)
    check_error(lib.infiniopDestroyGlobalAvgPoolDescriptor(descriptor))


def test_cpu(lib, test_cases):
    device = DeviceEnum.DEVICE_CPU
    handle = create_handle(lib, device)
    for x_shape in test_cases:
        test(lib, handle, "cpu", x_shape, tensor_dtype=torch.float16)
        test(lib, handle, "cpu", x_shape, tensor_dtype=torch.float32)
    destroy_handle(lib, handle)


def test_cuda(lib, test_cases):
    device = DeviceEnum.DEVICE_CUDA
    handle = create_handle(lib, device)
    for x_shape in test_cases:
        test(lib, handle, "cuda", x_shape, tensor_dtype=torch.float16)
        test(lib, handle, "cuda", x_shape, tensor_dtype=torch.float32)
    destroy_handle(lib, handle)


def test_bang(lib, test_cases):
    import torch_mlu

    device = DeviceEnum.DEVICE_BANG
    handle = create_handle(lib, device)
    for x_shape in test_cases:
        test(lib, handle, "mlu", x_shape, tensor_dtype=torch.float16)
        test(lib, handle, "mlu", x_shape, tensor_dtype=torch.float32)
    destroy_handle(lib, handle)


if __name__ == "__main__":
    test_cases = [
        # x_shape
        ((1, 3, 3)),
        ((1, 3, 1, 1, 3)),
        ((1, 3, 1, 1, 257)),
        ((1, 2, 1, 1, 514)),
        ((1, 3, 1, 1, 1025)),
        ((32, 256, 1, 112, 112)),
        ((2, 3, 2048000)),
        ((2, 1, 10243)),
        ((2, 20, 100)),
        ((3, 33, 333)),
        ((32, 20, 512)),
        ((3, 3, 11, 11, 11, 3, 2)),
        ((32, 256, 1, 112, 112)),
        ((32, 256, 112, 112)),
    ]
    args = get_args()
    lib = open_lib()
    lib.infiniopCreateGlobalAvgPoolDescriptor.restype = c_int32
    lib.infiniopCreateGlobalAvgPoolDescriptor.argtypes = [
        infiniopHandle_t,
        POINTER(infiniopGlobalAvgPoolDescriptor_t),
        infiniopTensorDescriptor_t,
        infiniopTensorDescriptor_t,
    ]
    lib.infiniopGetGlobalAvgPoolWorkspaceSize.restype = c_int32
    lib.infiniopGetGlobalAvgPoolWorkspaceSize.argtypes = [
        infiniopGlobalAvgPoolDescriptor_t,
        POINTER(c_uint64),
    ]
    lib.infiniopGlobalAvgPool.restype = c_int32
    lib.infiniopGlobalAvgPool.argtypes = [
        infiniopGlobalAvgPoolDescriptor_t,
        c_void_p,
        c_uint64,
        c_void_p,
        c_void_p,
        c_void_p,
    ]
    lib.infiniopDestroyGlobalAvgPoolDescriptor.restype = c_int32
    lib.infiniopDestroyGlobalAvgPoolDescriptor.argtypes = [
        infiniopGlobalAvgPoolDescriptor_t,
    ]

    if args.cpu:
        test_cpu(lib, test_cases)
    if args.cuda:
        test_cuda(lib, test_cases)
    if args.bang:
        test_bang(lib, test_cases)
    if not (args.cpu or args.cuda or args.bang):
        test_cpu(lib, test_cases)
    print("\033[92mTest passed!\033[0m")
Lines changed: 84 additions & 0 deletions
@@ -0,0 +1,84 @@
#include "global_avg_pool_cpu.h"
#include "../../../devices/cpu/common_cpu.h"
#include "../../utils.h"

infiniopStatus_t cpuCreateGlobalAvgPoolDescriptor(infiniopHandle_t,
                                                  GlobalAvgPoolCpuDescriptor_t *desc_ptr,
                                                  infiniopTensorDescriptor_t y,
                                                  infiniopTensorDescriptor_t x) {
    uint64_t ndim = y->ndim;
    if (ndim < 2 || ndim != x->ndim) {
        return STATUS_BAD_TENSOR_SHAPE;
    }
    // y must match x in the first two (N, C) dims and be 1 in every remaining dim
    for (size_t i = 0; i < ndim; ++i) {
        if (i < 2 && y->shape[i] != x->shape[i]) {
            return STATUS_BAD_TENSOR_SHAPE;
        } else if (i >= 2 && y->shape[i] != 1) {
            return STATUS_BAD_TENSOR_SHAPE;
        }
    }
    if (!is_contiguous(y) || !is_contiguous(x)) {
        return STATUS_BAD_TENSOR_STRIDES;
    }
    if (y->dt != F16 && y->dt != F32) {
        return STATUS_BAD_TENSOR_DTYPE;
    }
    if (y->dt != x->dt) {
        return STATUS_BAD_TENSOR_DTYPE;
    }

    // y_data_size: number of outputs (N * C); x_per_NC_data_size: elements averaged per output
    uint64_t y_data_size = std::accumulate(y->shape, y->shape + 2, 1ULL, std::multiplies<uint64_t>());
    uint64_t x_per_NC_data_size = std::accumulate(x->shape + 2, x->shape + ndim, 1ULL, std::multiplies<uint64_t>());

    *desc_ptr = new GlobalAvgPoolCpuDescriptor{
        DevCpu,
        y->dt,
        y_data_size,
        x_per_NC_data_size,
    };

    return STATUS_SUCCESS;
}

infiniopStatus_t cpuGetGlobalAvgPoolWorkspaceSize(GlobalAvgPoolCpuDescriptor_t desc, uint64_t *size) {
    *size = 0;
    return STATUS_SUCCESS;
}

infiniopStatus_t cpuDestroyGlobalAvgPoolDescriptor(GlobalAvgPoolCpuDescriptor_t desc) {
    delete desc;
    return STATUS_SUCCESS;
}

template<typename Tdata>
infiniopStatus_t global_avg_pool_cpu(GlobalAvgPoolCpuDescriptor_t desc, void *y, void const *x) {
    auto x_ = reinterpret_cast<Tdata const *>(x);
    auto y_ = reinterpret_cast<Tdata *>(y);
    const auto x_size = desc->x_per_NC_data_size;

    // one mean per (n, c) pair, averaging the x_size trailing elements
#pragma omp parallel for
    for (uint64_t i = 0; i < desc->y_data_size; ++i) {
        if constexpr (std::is_same<Tdata, uint16_t>::value) {
            // fp16 path: accumulate in fp32 and round back to fp16 once to limit rounding error
            float sum = std::accumulate(x_ + i * x_size, x_ + (i + 1) * x_size, 0.0f,
                                        [](float res, uint16_t value) {
                                            return res + f16_to_f32(value);
                                        });
            y_[i] = f32_to_f16(sum / x_size);
        } else {
            y_[i] = std::accumulate(x_ + i * x_size, x_ + (i + 1) * x_size, Tdata(0)) / x_size;
        }
    }
    return STATUS_SUCCESS;
}

infiniopStatus_t cpuGlobalAvgPool(GlobalAvgPoolCpuDescriptor_t desc,
                                  void *workspace, uint64_t workspace_size, void *y, void const *x,
                                  void *stream) {
    if (desc->dtype == F16) {
        return global_avg_pool_cpu<uint16_t>(desc, y, x);
    }
    if (desc->dtype == F32) {
        return global_avg_pool_cpu<float>(desc, y, x);
    }
    return STATUS_BAD_TENSOR_DTYPE;
}
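
A small worked example of the descriptor bookkeeping computed in cpuCreateGlobalAvgPoolDescriptor, using one of the shapes from the test script (illustration only, not repo code):

    // For x of shape (N, C, H, W) = (32, 256, 112, 112):
    //   y_data_size        = N * C = 32 * 256  = 8192 pooled outputs
    //   x_per_NC_data_size = H * W = 112 * 112 = 12544 elements averaged per output
    #include <cstdint>
    #include <cstdio>
    #include <functional>
    #include <numeric>
    #include <vector>

    int main() {
        std::vector<uint64_t> shape = {32, 256, 112, 112};
        uint64_t y_data_size = std::accumulate(shape.begin(), shape.begin() + 2, 1ULL,
                                               std::multiplies<uint64_t>()); // 8192
        uint64_t x_per_NC = std::accumulate(shape.begin() + 2, shape.end(), 1ULL,
                                            std::multiplies<uint64_t>());    // 12544
        std::printf("%llu %llu\n", (unsigned long long) y_data_size,
                    (unsigned long long) x_per_NC); // prints: 8192 12544
        return 0;
    }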
Lines changed: 29 additions & 0 deletions
@@ -0,0 +1,29 @@
#ifndef __CPU_GLOBAL_AVG_POOL_H__
#define __CPU_GLOBAL_AVG_POOL_H__

#include "operators.h"
#include <numeric>

struct GlobalAvgPoolCpuDescriptor {
    Device device;
    DT dtype;
    uint64_t y_data_size;
    uint64_t x_per_NC_data_size;
};

typedef struct GlobalAvgPoolCpuDescriptor *GlobalAvgPoolCpuDescriptor_t;

infiniopStatus_t cpuCreateGlobalAvgPoolDescriptor(infiniopHandle_t,
                                                  GlobalAvgPoolCpuDescriptor_t *,
                                                  infiniopTensorDescriptor_t y,
                                                  infiniopTensorDescriptor_t x);

infiniopStatus_t cpuGetGlobalAvgPoolWorkspaceSize(GlobalAvgPoolCpuDescriptor_t desc, uint64_t *size);

infiniopStatus_t cpuGlobalAvgPool(GlobalAvgPoolCpuDescriptor_t desc,
                                  void *workspace, uint64_t workspace_size, void *y, void const *x,
                                  void *stream);

infiniopStatus_t cpuDestroyGlobalAvgPoolDescriptor(GlobalAvgPoolCpuDescriptor_t desc);

#endif
