Skip to content

Commit fc15152

Browse files
committed
contigous axes reduce op
1 parent a2d8f77 commit fc15152

File tree

6 files changed

+248
-84
lines changed

6 files changed

+248
-84
lines changed

operatorspy/tests/reducemax.py

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
from typing import Tuple
2222
import numpy as np
2323

24-
PROFILE = False
24+
PROFILE = True
2525
NUM_PRERUN = 1
2626
NUM_ITERATIONS = 1
2727

@@ -141,8 +141,9 @@ def test(
141141
)
142142
elapsed = (time.time() - start_time) / NUM_ITERATIONS
143143
print(f"lib time: {elapsed :10f}")
144-
print(f"custom op output:{y}")
145-
print(f"pytorch output:{ans}")
144+
# print(f"input : {x}")
145+
# print(f"custom op output:{y}")
146+
# print(f"pytorch output:{ans}")
146147
assert torch.allclose(y, ans, atol=0, rtol=1e-3)
147148

148149
check_error(lib.infiniopDestroyReducemaxDescriptor(descriptor))
@@ -182,10 +183,13 @@ def test_cuda(lib, test_cases):
182183
# ((2, 10, 24, 10), None, False, False, None),
183184
# ((2, 3, 4), [0, 1], False, False, None),
184185
#((2, 10, 24, 10), [], True),
185-
((4,), [0], False, False, None, torch.float32),
186-
((1000, 3), [0, 1], False, False, None, torch.float16),
187-
((50, 3), [0, 1], False, False, None, torch.float32),
188-
((1000, 3), [0, 1], False, False, None, torch.float32),
186+
#((4,), [0], False, False, None, torch.float32),
187+
((1000, 300), [0, 1], False, False, None, torch.float16),
188+
((50, 3), [0, 1], False, False, None, torch.float16),
189+
((1000, 300), [0, 1], False, False, None, torch.float16),
190+
((2000, 200, 50), [0, 1], False, True, None, torch.float32),
191+
((1000, 200, 500), [0, 1], False, True, None, torch.float16),
192+
((1000, 200, 50), [0, 1], False, True, None, torch.float32),
189193
]
190194
args = get_args()
191195
lib = open_lib()

operatorspy/tests/reducemean.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
from typing import Tuple
2222
import numpy as np
2323

24-
PROFILE = False
24+
PROFILE = True
2525
NUM_PRERUN = 1
2626
NUM_ITERATIONS = 1
2727

@@ -177,8 +177,13 @@ def test_cuda(lib, test_cases):
177177
# # stride =
178178
# ((2, 10, 24, 10), [0, 1], False, True, None),
179179
# ((2, 10, 24, 10), [2, 3], False , True, None),
180-
((50, 3), [0, 1], False, False, None, torch.float16),
181-
((1000, 3), [0, 1], False, False, None, torch.float16),
180+
#((1000, 300), [0, 1], False, False, None, torch.float16),
181+
((30, 50, 20, 1000), [0, 1, 2, 3], False, False, None, torch.float16),
182+
((30000, 1000, 40), [0, 1], False, False, None, torch.float32),
183+
#((1000, 300), [0, 1], False, False, None, torch.float16),
184+
((2, 2, 5), [0, 1], False, True, None, torch.float32),
185+
((1000, 200, 500), [0, 1], False, True, None, torch.float16),
186+
((1000, 200, 50), [0, 1], False, True, None, torch.float32),
182187
# validate attribute noop_with_empty_axes and keepdims
183188
# ((2, 10, 24, 10), None, True, True, None),
184189
# ((2, 10, 24, 10), None, True, False, None),

operatorspy/tests/reducemin.py

Lines changed: 35 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
from typing import Tuple
2222
import numpy as np
2323

24-
PROFILE = False
24+
PROFILE = True
2525
NUM_PRERUN = 1
2626
NUM_ITERATIONS = 1
2727

@@ -141,41 +141,50 @@ def test(
141141
)
142142
elapsed = (time.time() - start_time) / NUM_ITERATIONS
143143
print(f"lib time: {elapsed :10f}")
144-
print(f"custom op output:{y}")
145-
print(f"pytorch output:{ans}")
144+
# print(f"custom op output:{y}")
145+
# print(f"pytorch output:{ans}")
146146
assert torch.allclose(y, ans, atol=0, rtol=1e-3)
147147

148148
check_error(lib.infiniopDestroyReducemaxDescriptor(descriptor))
149149

150150
def test_cpu(lib, test_cases):
151151
device = DeviceEnum.DEVICE_CPU
152152
handle = create_handle(lib, device)
153-
for x_shape, axes, noop_with_empty_axes, keepdims, dynamic_axes in test_cases:
154-
print(dynamic_axes)
155-
test(lib, handle, "cpu", x_shape, axes, dynamic_axes, noop_with_empty_axes, keepdims, tensor_dtype=torch.float16)
156-
print("\n")
153+
for x_shape, axes, noop_with_empty_axes, keepdims, dynamic_axes, tensor_dtype in test_cases:
154+
test(lib, handle, "cpu", x_shape, axes, dynamic_axes, noop_with_empty_axes, keepdims, tensor_dtype=tensor_dtype)
157155
#test(lib, handle, "cpu", x_shape, axes, tensor_dtype=torch.float32)
158156
destroy_handle(lib, handle)
159157

158+
def test_cuda(lib, test_cases):
159+
device = DeviceEnum.DEVICE_CUDA
160+
handle = create_handle(lib, device)
161+
for x_shape, axes, noop_with_empty_axes, keepdims, dynamic_axes, tensor_dtype in test_cases:
162+
test(lib, handle, "cuda", x_shape, axes, dynamic_axes, noop_with_empty_axes, keepdims, tensor_dtype=tensor_dtype)
163+
print("\n")
164+
destroy_handle(lib, handle)
160165

161166
if __name__ == "__main__":
162167
test_cases = [
163168
# dynamic calc test eg
164-
((2, 3, 4, 5), [0, 2], False, True, None),
165-
((2, 3, 4, 5), [0, 2], False, True, None),
166-
#(input_shape, axis, noop_with_empty_axes, keepdims, dynamic_axes)
167-
((2, 10, 24, 10), [0, 2], False, True, None),
168-
# stride =
169-
((2, 10, 24, 10), [0, 1], False, True, None),
170-
((2, 10, 24, 10), [2, 3], False , True, None),
171-
((2, 10, 24, 10), [0, 1, 2, 3], False, True, None),
172-
# validate attribute noop_with_empty_axes and keepdims
173-
((2, 10, 24, 10), None, True, True, None),
174-
((2, 10, 24, 10), None, True, False, None),
175-
((2, 10, 24, 10), None, False, True, None),
176-
((2, 10, 24, 10), None, False, False, None),
177-
((2, 3, 4), [0, 1], False, False, None),
178-
#((2, 10, 24, 10), [], True),
169+
# ((2, 3, 4, 5), [0, 2], False, True, None),
170+
# ((2, 3, 4, 5), [0, 2], False, True, None),
171+
# #(input_shape, axis, noop_with_empty_axes, keepdims, dynamic_axes)
172+
# ((2, 10, 24, 10), [0, 2], False, True, None),
173+
# # stride =
174+
# ((2, 10, 24, 10), [0, 1], False, True, None),
175+
# ((2, 10, 24, 10), [2, 3], False , True, None),
176+
# ((2, 10, 24, 10), [0, 1, 2, 3], False, True, None),
177+
# # validate attribute noop_with_empty_axes and keepdims
178+
# ((2, 10, 24, 10), None, True, True, None),
179+
# ((2, 10, 24, 10), None, True, False, None),
180+
# ((2, 10, 24, 10), None, False, True, None),
181+
# ((2, 10, 24, 10), None, False, False, None),
182+
# ((2, 3, 4), [0, 1], False, False, None),
183+
# #((2, 10, 24, 10), [], True),
184+
((2, 1000), [0, 1], False, False, None, torch.float32),
185+
((2, 2, 5), [0, 1], False, True, None, torch.float32),
186+
((1000, 200, 500), [0, 1], False, True, None, torch.float16),
187+
((1000, 200, 50), [0, 1], False, True, None, torch.float32),
179188
]
180189
args = get_args()
181190
lib = open_lib()
@@ -201,5 +210,8 @@ def test_cpu(lib, test_cases):
201210
]
202211
lib.infiniopDestroyReduceminDescriptor.restype = c_int32
203212
lib.infiniopDestroyReduceminDescriptor.argtypes = [infiniopReduceminDescriptor_t]
204-
test_cpu(lib, test_cases)
213+
if args.cpu:
214+
test_cpu(lib, test_cases)
215+
if args.cuda:
216+
test_cuda(lib, test_cases)
205217
print("All tests passed!")

src/ops/reduce/cuda/reduce_cuda.cc

Lines changed: 25 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -36,19 +36,18 @@ infiniopStatus_t cudaCreateReduceDescriptor(CudaHandle_t handle,
3636
output_size *= y->shape[i];
3737
}
3838
uint64_t reduce_size = element_num / output_size;
39-
uint64_t ndim = y->ndim;
4039

4140
bool *reduce_mask = new bool[x->ndim];
4241
int64_t *input_strides = new int64_t[x->ndim];
4342
int64_t *output_strides = new int64_t[y->ndim];
4443
uint64_t *input_shape = new uint64_t[x->ndim];
4544
uint64_t *output_shape = new uint64_t[y->ndim];
46-
45+
memset(reduce_mask, 0, x->ndim * sizeof(bool));
4746
memcpy(input_shape, x->shape, x->ndim * sizeof(uint64_t));
4847
memcpy(output_shape, y->shape, y->ndim * sizeof(uint64_t));
4948
memcpy(input_strides, x->strides, x->ndim * sizeof(int64_t));
5049
memcpy(output_strides, y->strides, y->ndim * sizeof(int64_t));
51-
50+
int prefix_size = 1, suffix_size = 1;
5251
bool if_reduce_axes_contiguous = true;
5352
int reduce_mode = 0;
5453
for (uint64_t i = 0; i < axes_size; i++) {
@@ -60,21 +59,25 @@ infiniopStatus_t cudaCreateReduceDescriptor(CudaHandle_t handle,
6059
if (if_reduce_axes_contiguous) {
6160
if (axes_size == x->ndim) {
6261
// all axes are reduced
63-
int reduce_mode = 0;
62+
reduce_mode = 0;
6463
} else {
65-
// some axes are not reduced but axes are contiguous
66-
if (reduce_size > 1024 && output_size < 128) reduce_mode = 1; // multi-thread for each output element
67-
else reduce_mode = 2; // one thread for each output element
64+
for (uint64_t i = 0; i < axes[0]; i++) {
65+
prefix_size *= x->shape[i];
66+
}
67+
for (uint64_t i = axes[axes_size - 1] + 1; i < x->ndim; i++) {
68+
suffix_size *= x->shape[i];
69+
}
70+
reduce_mode = 1;
6871
}
6972
} else {
70-
if (reduce_size > 1024 && output_size < 128) reduce_mode = 3;
73+
if (reduce_size > 1024 && output_size > 128) reduce_mode = 3;
7174
else reduce_mode = 4;
7275
}
73-
bool *d_reduce_mask = new bool[x->ndim];
74-
int64_t *d_input_strides = new int64_t[x->ndim];
75-
int64_t *d_output_strides = new int64_t[y->ndim];
76-
uint64_t *d_input_shape = new uint64_t[x->ndim];
77-
uint64_t *d_output_shape = new uint64_t[y->ndim];
76+
bool *d_reduce_mask;
77+
int64_t *d_input_strides;
78+
int64_t *d_output_strides;
79+
uint64_t *d_input_shape;
80+
uint64_t *d_output_shape;
7881

7982
checkCudaErrorWithCode(cudaMalloc((void**)&d_reduce_mask, x->ndim * sizeof(bool)), STATUS_MEMORY_NOT_ALLOCATED);
8083
checkCudaErrorWithCode(cudaMalloc((void**)&d_input_strides, x->ndim * sizeof(int64_t)), STATUS_MEMORY_NOT_ALLOCATED);
@@ -91,7 +94,8 @@ infiniopStatus_t cudaCreateReduceDescriptor(CudaHandle_t handle,
9194
*desc_ptr = new ReduceCudaDescriptor{
9295
DevNvGpu,
9396
x->dt,
94-
ndim,
97+
y->ndim,
98+
x->ndim,
9599
d_reduce_mask,
96100
d_input_strides,
97101
d_output_strides,
@@ -101,7 +105,13 @@ infiniopStatus_t cudaCreateReduceDescriptor(CudaHandle_t handle,
101105
element_num,
102106
output_size,
103107
static_cast<int>(reduce_op_type),
104-
reduce_mode
108+
reduce_mode,
109+
axes_size,
110+
keepdims,
111+
axes[0],
112+
axes[axes_size - 1],
113+
prefix_size,
114+
suffix_size
105115
};
106116
delete [] reduce_mask;
107117
delete [] input_strides;

0 commit comments

Comments
 (0)