Skip to content

Commit d519ed1

Browse files
committed
issue/66: modified random_sample, swiglu, rms_norm, test
1 parent aff079a commit d519ed1

File tree

3 files changed

+25
-50
lines changed

3 files changed

+25
-50
lines changed

test/infiniop/random_sample.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -188,13 +188,9 @@ def lib_random_sample():
188188
# Profiling workflow
189189
if PROFILE:
190190
# fmt: off
191-
if topp > 0 and topk > 1:
192-
profile_operation("PyTorch", lambda: random_sample(
193-
data.to("cpu"), random_val, topp, topk, voc, temperature, "cpu"
191+
profile_operation("PyTorch", lambda: random_sample(
192+
data, random_val, topp, topk, voc, temperature, torch_device
194193
), torch_device, NUM_PRERUN, NUM_ITERATIONS)
195-
else:
196-
profile_operation("PyTorch", lambda: random_sample_0(data), torch_device, NUM_PRERUN, NUM_ITERATIONS)
197-
198194
profile_operation(" lib", lambda: lib_random_sample(), torch_device, NUM_PRERUN, NUM_ITERATIONS)
199195
# fmt: on
200196
check_error(lib.infiniopDestroyRandomSampleDescriptor(descriptor))

test/infiniop/rms_norm.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -133,6 +133,7 @@ def lib_rms_norm():
133133
if DEBUG:
134134
debug(y, ans, atol=atol, rtol=rtol)
135135
assert torch.allclose(y, ans, atol=atol, rtol=rtol)
136+
136137
# Profiling workflow
137138
if PROFILE:
138139
# fmt: off

test/infiniop/swiglu.py

Lines changed: 22 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -22,50 +22,29 @@
2222
# Configuration (Internal Use Only)
2323
# ==============================================================================
2424
# These are not meant to be imported from other modules
25+
_TEST_CASES_ = [
26+
((13, 4), None, None, None),
27+
((13, 4), (10, 1), (10, 1), (10, 1)),
28+
((13, 4, 4), None, None, None),
29+
((13, 4, 4), (20, 4, 1), (20, 4, 1), (20, 4, 1)),
30+
((16, 5632), None, None, None),
31+
((16, 5632), (13312, 1), (13312, 1), (13312, 1)),
32+
((4, 4, 5632), None, None, None),
33+
((4, 4, 5632), (45056, 5632, 1), (45056, 5632, 1), (45056, 5632, 1)),
34+
]
35+
36+
# Inplace options applied for each test case in _TEST_CASES_
37+
_INPLACE = [
38+
"Inplace.OUT_OF_PLACE",
39+
"Inplace.INPLACE_A",
40+
"Inplace.INPLACE_B",
41+
]
42+
43+
# Form the test cases by appending each element of _INPLACE to each tuple in _TEST_CASES_
2544
_TEST_CASES = [
26-
# shape, a_stride, b_stride, c_stride, inplace
27-
((13, 4), None, None, None, Inplace.OUT_OF_PLACE),
28-
((13, 4), None, None, None, Inplace.INPLACE_A),
29-
((13, 4), None, None, None, Inplace.INPLACE_B),
30-
((13, 4), (10, 1), (10, 1), (10, 1), Inplace.OUT_OF_PLACE),
31-
((13, 4), (10, 1), (10, 1), (10, 1), Inplace.INPLACE_A),
32-
((13, 4), (10, 1), (10, 1), (10, 1), Inplace.INPLACE_B),
33-
((13, 4, 4), None, None, None, Inplace.OUT_OF_PLACE),
34-
((13, 4, 4), None, None, None, Inplace.INPLACE_A),
35-
((13, 4, 4), None, None, None, Inplace.INPLACE_B),
36-
((13, 4, 4), (20, 4, 1), (20, 4, 1), (20, 4, 1), Inplace.OUT_OF_PLACE),
37-
((13, 4, 4), (20, 4, 1), (20, 4, 1), (20, 4, 1), Inplace.INPLACE_A),
38-
((13, 4, 4), (20, 4, 1), (20, 4, 1), (20, 4, 1), Inplace.INPLACE_B),
39-
((16, 5632), None, None, None, Inplace.OUT_OF_PLACE),
40-
((16, 5632), None, None, None, Inplace.INPLACE_A),
41-
((16, 5632), None, None, None, Inplace.INPLACE_B),
42-
((16, 5632), (13312, 1), (13312, 1), (13312, 1), Inplace.OUT_OF_PLACE),
43-
((16, 5632), (13312, 1), (13312, 1), (13312, 1), Inplace.INPLACE_A),
44-
((16, 5632), (13312, 1), (13312, 1), (13312, 1), Inplace.INPLACE_B),
45-
((4, 4, 5632), None, None, None, Inplace.OUT_OF_PLACE),
46-
((4, 4, 5632), None, None, None, Inplace.INPLACE_A),
47-
((4, 4, 5632), None, None, None, Inplace.INPLACE_B),
48-
(
49-
(4, 4, 5632),
50-
(45056, 5632, 1),
51-
(45056, 5632, 1),
52-
(45056, 5632, 1),
53-
Inplace.OUT_OF_PLACE,
54-
),
55-
(
56-
(4, 4, 5632),
57-
(45056, 5632, 1),
58-
(45056, 5632, 1),
59-
(45056, 5632, 1),
60-
Inplace.INPLACE_A,
61-
),
62-
(
63-
(4, 4, 5632),
64-
(45056, 5632, 1),
65-
(45056, 5632, 1),
66-
(45056, 5632, 1),
67-
Inplace.INPLACE_B,
68-
),
45+
test_case + (inplace_item,)
46+
for test_case in _TEST_CASES_
47+
for inplace_item in _INPLACE
6948
]
7049

7150
# Data types used for testing
@@ -166,7 +145,6 @@ def lib_swiglu():
166145
if DEBUG:
167146
debug(c, ans, atol=atol, rtol=rtol)
168147
assert torch.allclose(c, ans, atol=atol, rtol=rtol)
169-
print("out-of-place Test passed!")
170148

171149
# Profiling workflow
172150
if PROFILE:

0 commit comments

Comments
 (0)