
Commit 78b7d86

add where_cuda op

1 parent 5818dae

File tree

14 files changed: +520 -94 lines changed

env.sh

Lines changed: 0 additions & 6 deletions
This file was deleted.

operatorspy/tests/clip.py

Lines changed: 11 additions & 14 deletions

@@ -21,7 +21,7 @@
 from typing import Tuple
 import numpy as np

-PROFILE = False
+PROFILE = True
 NUM_PRERUN = 10
 NUM_ITERATIONS = 1000

@@ -112,28 +112,21 @@ def test(
     )
     elapsed = (time.time() - start_time) / NUM_ITERATIONS
     print(f"lib time: {elapsed :10f}")
-    print("x:", x)
-    print("custom op ans:", output)
-    print("ans:", ans) if max != None or min != None else print("ans:", x)
     assert torch.allclose(output, ans, atol=0, rtol=0) if max != None or min != None else torch.allclose(output, x, atol=0, rtol=0)
     check_error(lib.infiniopDestroyClipDescriptor(descriptor))

 def test_cpu(lib, test_cases):
     device = DeviceEnum.DEVICE_CPU
     handle = create_handle(lib, device)
-    for x_shape, min, max in test_cases:
-        test(lib, handle, "cpu", x_shape, min, max, tensor_dtype=torch.float16)
-        print("\n")
-        #test(lib, handle, "cpu", x_shape, axes, tensor_dtype=torch.float32)
+    for x_shape, min, max, tensor_type in test_cases:
+        test(lib, handle, "cpu", x_shape, min, max, tensor_dtype=tensor_type)
     destroy_handle(lib, handle)

 def test_cuda(lib, test_cases):
     device = DeviceEnum.DEVICE_CUDA
     handle = create_handle(lib, device)
     for x_shape, min, max, tensor_type in test_cases:
         test(lib, handle, "cuda", x_shape, min, max, tensor_dtype=tensor_type)
-        print("\n")
-        #test(lib, handle, "cpu", x_shape, axes, tensor_dtype=torch.float32)
     destroy_handle(lib, handle)

@@ -145,15 +138,16 @@ def test_cuda(lib, test_cases):
         ((3, 4), None, None, torch.float32),
         ((16), -1, 1, torch.float32),
         ((1024, 1024), -1, 1, torch.float32),
-
+        ((4096, 4096), -1, 1, torch.float32),
+
+        ((13), -1, 1, torch.float32),
         ((3, 4), -1, 1, torch.float16),
         ((3, 4), None, 1, torch.float16),
         ((3, 4), -1, None, torch.float16),
         ((3, 4), None, None, torch.float16),
         ((16), -1, 1, torch.float16),
         ((1024, 1024), -1, 1, torch.float16),
-
-        # stride =
+        ((4096, 4096), -1, 1, torch.float16),
     ]
     args = get_args()
     lib = open_lib()

@@ -175,5 +169,8 @@ def test_cuda(lib, test_cases):
     ]
     lib.infiniopDestroyClipDescriptor.restype = c_int32
     lib.infiniopDestroyClipDescriptor.argtypes = [infiniopClipDescriptor_t]
-    test_cuda(lib, test_cases)
+    if args.cuda:
+        test_cuda(lib, test_cases)
+    if args.cpu:
+        test_cpu(lib, test_cases)
     print("All tests passed!")

operatorspy/tests/where.py

Lines changed: 44 additions & 18 deletions

@@ -20,9 +20,8 @@
 import torch
 from typing import Tuple
 import numpy as np
-import onnx

-PROFILE = True
+PROFILE = False
 NUM_PRERUN = 10
 NUM_ITERATIONS = 1000

@@ -43,15 +42,22 @@ def tuple_to_void_p(py_tuple: Tuple):
 def inferShape(x_shape, y_shape):
     ndim_x = len(x_shape)
     ndim_y = len(y_shape)
-    ndim = 0
-    output_shape = []
     ndim = max(ndim_x, ndim_y)
-    for i in range(ndim - 1, -1, -1):
-        dim_x = x_shape[i] if i < ndim_x else 1
-        dim_y = y_shape[i] if i < ndim_y else 1
-        output_shape.append(max(dim_x, dim_y))
-    output_shape.reverse()
+    output_shape = []
+
+    for i in range(-1, -ndim-1, -1):
+        dim_x = x_shape[i] if i >= -ndim_x else 1
+        dim_y = y_shape[i] if i >= -ndim_y else 1
+
+        if dim_x != dim_y:
+            if dim_x != 1 and dim_y != 1:
+                raise ValueError(f"Shapes {x_shape} and {y_shape} cannot be broadcast together")
+
+        output_dim = max(dim_x, dim_y)
+        output_shape.insert(0, output_dim)
+
     return tuple(output_shape)

@@ -68,8 +74,7 @@ def test(
     condition = torch.randint(0, 2, condition_shape, dtype=torch.uint8).to(torch_device)
     src1 = torch.randn(src1_shape, dtype=tensor_dtype, device=torch_device)
     src2 = torch.randn(src2_shape, dtype=tensor_dtype, device=torch_device)
-    output = torch.randn(inferShape(src1_shape, src2_shape), dtype=tensor_dtype, device=torch_device)
-
+    output = torch.randn(inferShape(inferShape(src1_shape, src2_shape), condition_shape), dtype=tensor_dtype, device=torch_device)

     for i in range(NUM_PRERUN if PROFILE else 1):
         ans = where(condition, src1, src2)

@@ -130,18 +135,33 @@ def test(
 def test_cpu(lib, test_cases):
     device = DeviceEnum.DEVICE_CPU
     handle = create_handle(lib, device)
-    for condition_shape, src1_shape, src2_shape in test_cases:
-        test(lib, handle, "cpu", condition_shape, src1_shape, src2_shape, tensor_dtype=torch.float16)
+    for condition_shape, src1_shape, src2_shape, tensor_dtype in test_cases:
+        test(lib, handle, "cpu", condition_shape, src1_shape, src2_shape, tensor_dtype=tensor_dtype)
+        print("\n")
+    destroy_handle(lib, handle)
+
+def test_cuda(lib, test_cases):
+    device = DeviceEnum.DEVICE_CUDA
+    handle = create_handle(lib, device)
+    for condition_shape, src1_shape, src2_shape, tensor_dtype in test_cases:
+        test(lib, handle, "cuda", condition_shape, src1_shape, src2_shape, tensor_dtype=tensor_dtype)
         print("\n")
     destroy_handle(lib, handle)


 if __name__ == "__main__":
     test_cases = [
-        ((2, 3, 4, 5), (2, 3, 4, 5), (2, 3, 4, 5)),
-        ((3, 1), (3, 4), (1, 4)),
-        ((1,), (3, 4), (3, 4)),
-        ((2, 1, 3), (1, 4, 3), (2, 4, 1)),
+        ((2, 16), (2, 16), (2, 16), torch.float32),
+        ((2, 3, 1, 1), (1, 4, 5), (2, 3, 4, 5), torch.float32),
+        ((3, 1), (3, 4), (1, 4), torch.float32),
+        ((1,), (3, 4), (3, 4), torch.float32),
+        ((2, 1, 3), (1, 4, 3), (2, 4, 1), torch.float32),
+
+        ((2, 16), (2, 16), (2, 16), torch.float16),
+        ((2, 3, 1, 1), (1, 4, 5), (2, 3, 4, 5), torch.float16),
+        ((3, 1), (3, 4), (1, 4), torch.float16),
+        ((1,), (3, 4), (3, 4), torch.float16),
+        ((2, 1, 3), (1, 4, 3), (2, 4, 1), torch.float16),
     ]
     args = get_args()
     lib = open_lib()

@@ -150,6 +170,9 @@ def test_cpu(lib, test_cases):
        infiniopHandle_t,
        POINTER(infiniopWhereDescriptor_t),
        infiniopTensorDescriptor_t,
+       infiniopTensorDescriptor_t,
+       infiniopTensorDescriptor_t,
+       infiniopTensorDescriptor_t
    ]
    lib.infiniopWhere.restype = c_int32
    lib.infiniopWhere.argtypes = [

@@ -162,5 +185,8 @@ def test_cpu(lib, test_cases):
    ]
    lib.infiniopDestroyWhereDescriptor.restype = c_int32
    lib.infiniopDestroyWhereDescriptor.argtypes = [infiniopWhereDescriptor_t]
-   test_cpu(lib, test_cases)
+   if args.cpu:
+       test_cpu(lib, test_cases)
+   if args.cuda:
+       test_cuda(lib, test_cases)
    print("All tests passed!")

src/ops/clip/cuda/clip_cuda.cc

Lines changed: 0 additions & 4 deletions

@@ -24,15 +24,11 @@ infiniopStatus_t cudaCreateClipDescriptor(CudaHandle_t handle,
         element_num *= x->shape[i];
     }
     uint64_t ndim = y->ndim;
-    uint64_t S = ndim == 2 ? y->shape[0] : 1;
-    uint64_t K = ndim == 2 ? y->shape[1] : 1;
     *desc_ptr = new ClipCudaDescriptor{
         DevNvGpu,
         x->dt,
         ndim,
         element_num,
-        S,
-        K
     };
     return STATUS_SUCCESS;
 }

src/ops/clip/cuda/clip_cuda.cu

Lines changed: 17 additions & 3 deletions

@@ -37,11 +37,17 @@ __global__ void clip_f32x4_kernel(float *a, float *b, float max_value, float min

 __global__ void clip_f16x8_pack_kernel(half *a, half *b, float max_value, float min_value, int N){
     int idx = 8 * (blockDim.x * blockIdx.x + threadIdx.x);
+    if (idx >= N) return;
     const half min_half = __float2half(min_value);
     const half max_half = __float2half(max_value);
-    if (idx >= N) return;
     half pack_a[8], pack_b[8];
-    LDST128BITS(pack_a[0]) = LDST128BITS(a[idx]);
+    if (idx + 7 < N) {
+        LDST128BITS(pack_a[0]) = LDST128BITS(a[idx]);
+    } else {
+        for (int i = 0; i < 8 && (idx + i) < N; i++) {
+            pack_a[i] = a[idx + i];
+        }
+    }
 #pragma unroll
     for (int i = 0; i < 8; i++)
     {

@@ -66,6 +72,14 @@ infiniopStatus_t clip_nv_gpu(
     int per_thread_element,
     void* stream) {
     uint64_t N = desc->element_num;
+    dim3 block(256 / per_thread_element);
+    dim3 grid((N + 256 - 1) / 256);
+    if constexpr (std::is_same<Tdata, float>::value) {
+        clip_f32x4_kernel<<<grid, block, 0, (cudaStream_t)stream>>>(reinterpret_cast<float *>(x), reinterpret_cast<float *>(y), max_value, min_value, N);
+    } else {
+        clip_f16x8_pack_kernel<<<grid, block, 0, (cudaStream_t)stream>>>(reinterpret_cast<half *>(x), reinterpret_cast<half *>(y), max_value, min_value, N);
+    }
+    /*
     if (desc->ndim != 2){
         dim3 block(256 / per_thread_element);
         dim3 grid((N + 256 - 1) / 256);

@@ -94,6 +108,7 @@ infiniopStatus_t clip_nv_gpu(
         }
     }
 }
+    */
     return STATUS_SUCCESS;
 }

@@ -105,7 +120,6 @@ infiniopStatus_t cudaClip(ClipCudaDescriptor_t desc,
     void *stream){
     bool has_min = true;
     bool has_max = true;
-    uint64_t N = desc->element_num;
     if (min == nullptr){
         has_min = false;
     }
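The new launch configuration makes every block cover exactly 256 elements: per_thread_element is presumably 4 for the float4 kernel and 8 for the half8 pack kernel, giving 64- and 32-thread blocks respectively. A plain-Python sanity check of that arithmetic (a sketch, not part of the commit):

    def launch_shape(n_elements, per_thread_element):
        threads_per_block = 256 // per_thread_element
        n_blocks = (n_elements + 256 - 1) // 256   # ceil-divide; each block clips 256 values
        return n_blocks, threads_per_block

    print(launch_shape(13, 8))            # (1, 32): the new ((13),) test case exercises the tail loop
    print(launch_shape(4096 * 4096, 4))   # (65536, 64) for the float32 path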

src/ops/clip/cuda/clip_cuda.h

Lines changed: 0 additions & 2 deletions

@@ -10,8 +10,6 @@ typedef struct ClipCudaDescriptor {
     DT dtype;
     uint64_t ndim;
     uint64_t element_num;
-    uint64_t S;
-    uint64_t K;
 } ClipCudaDescriptor;

 typedef struct ClipCudaDescriptor *ClipCudaDescriptor_t;

src/ops/clip/operator.cc

Lines changed: 0 additions & 1 deletion

@@ -28,7 +28,6 @@ __C infiniopStatus_t infiniopCreateClipDescriptor(
     }
 #endif
     }
-    std::cout << "Creating Clip Descriptorxx" << std::endl;
     return STATUS_BAD_DEVICE;
 }

src/ops/utils.h

Lines changed: 28 additions & 0 deletions

@@ -104,6 +104,34 @@ inline bool getBroadcastShape(const uint64_t *shape1, uint64_t ndim1,
     return true;
 }

+inline bool getBroadcastShape(const uint64_t *shape1, uint64_t ndim1,
+                              const uint64_t *shape2, uint64_t ndim2,
+                              const uint64_t *shape3, uint64_t ndim3,
+                              uint64_t *broadcast_shape, uint64_t *padded_shape1,
+                              uint64_t *padded_shape2, uint64_t *padded_shape3,
+                              uint64_t max_rank) {
+    // prepending and initializing
+    std::fill(padded_shape1, padded_shape1 + max_rank, 1);
+    std::fill(padded_shape2, padded_shape2 + max_rank, 1);
+    std::fill(padded_shape3, padded_shape3 + max_rank, 1);
+    std::copy(shape1, shape1 + ndim1, padded_shape1 + max_rank - ndim1);
+    std::copy(shape2, shape2 + ndim2, padded_shape2 + max_rank - ndim2);
+    std::copy(shape3, shape3 + ndim3, padded_shape3 + max_rank - ndim3);
+
+    // compute broadcasted shape
+    for (size_t i = 0; i < max_rank; ++i) {
+        if ((padded_shape1[i] == padded_shape2[i] || padded_shape1[i] == 1 || padded_shape2[i] == 1) &&
+            (padded_shape1[i] == padded_shape3[i] || padded_shape1[i] == 1 || padded_shape3[i] == 1)) {
+            broadcast_shape[i] = std::max(std::max(padded_shape1[i], padded_shape2[i]), padded_shape3[i]);
+        } else {
+            return false;
+        }
+    }
+
+    return true;
+}
+
 // check if the shape of tensor c is valid after broadcasting tensors a and b and also get the broadcasted shapes
 inline bool isValidBroadcastShape(infiniopTensorDescriptor_t a, infiniopTensorDescriptor_t b, infiniopTensorDescriptor_t c,
                                   uint64_t broadcast_ndim) {
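The new three-tensor overload applies the same right-aligned broadcasting rule as the two-tensor version above, padding each shape with leading 1s out to max_rank. For a quick cross-check of the expected results, NumPy's np.broadcast_shapes (available since NumPy 1.20) computes the same shapes for the valid where.py test cases:

    import numpy as np

    # matches the test cases in operatorspy/tests/where.py
    assert np.broadcast_shapes((2, 3, 1, 1), (1, 4, 5), (2, 3, 4, 5)) == (2, 3, 4, 5)
    assert np.broadcast_shapes((2, 1, 3), (1, 4, 3), (2, 4, 1)) == (2, 4, 3)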
