|
22 | 22 | # Configuration (Internal Use Only) |
23 | 23 | # ============================================================================== |
24 | 24 | # These are not meant to be imported from other modules |
| 25 | +_TEST_CASES_ = [ |
| 26 | + ((13, 4), None, None, None), |
| 27 | + ((13, 4), (10, 1), (10, 1), (10, 1)), |
| 28 | + ((13, 4, 4), None, None, None), |
| 29 | + ((13, 4, 4), (20, 4, 1), (20, 4, 1), (20, 4, 1)), |
| 30 | + ((16, 5632), None, None, None), |
| 31 | + ((16, 5632), (13312, 1), (13312, 1), (13312, 1)), |
| 32 | + ((4, 4, 5632), None, None, None), |
| 33 | + ((4, 4, 5632), (45056, 5632, 1), (45056, 5632, 1), (45056, 5632, 1)), |
| 34 | +] |
| 35 | + |
| 36 | +# Inplace options applied for each test case in _TEST_CASES_ |
| 37 | +_INPLACE = [ |
| 38 | + "Inplace.OUT_OF_PLACE", |
| 39 | + "Inplace.INPLACE_A", |
| 40 | + "Inplace.INPLACE_B", |
| 41 | +] |
| 42 | + |
| 43 | +# Form the test cases by appending each element of _INPLACE to each tuple in _TEST_CASES_ |
25 | 44 | _TEST_CASES = [ |
26 | | - # shape, a_stride, b_stride, c_stride, inplace |
27 | | - ((13, 4), None, None, None, Inplace.OUT_OF_PLACE), |
28 | | - ((13, 4), None, None, None, Inplace.INPLACE_A), |
29 | | - ((13, 4), None, None, None, Inplace.INPLACE_B), |
30 | | - ((13, 4), (10, 1), (10, 1), (10, 1), Inplace.OUT_OF_PLACE), |
31 | | - ((13, 4), (10, 1), (10, 1), (10, 1), Inplace.INPLACE_A), |
32 | | - ((13, 4), (10, 1), (10, 1), (10, 1), Inplace.INPLACE_B), |
33 | | - ((13, 4, 4), None, None, None, Inplace.OUT_OF_PLACE), |
34 | | - ((13, 4, 4), None, None, None, Inplace.INPLACE_A), |
35 | | - ((13, 4, 4), None, None, None, Inplace.INPLACE_B), |
36 | | - ((13, 4, 4), (20, 4, 1), (20, 4, 1), (20, 4, 1), Inplace.OUT_OF_PLACE), |
37 | | - ((13, 4, 4), (20, 4, 1), (20, 4, 1), (20, 4, 1), Inplace.INPLACE_A), |
38 | | - ((13, 4, 4), (20, 4, 1), (20, 4, 1), (20, 4, 1), Inplace.INPLACE_B), |
39 | | - ((16, 5632), None, None, None, Inplace.OUT_OF_PLACE), |
40 | | - ((16, 5632), None, None, None, Inplace.INPLACE_A), |
41 | | - ((16, 5632), None, None, None, Inplace.INPLACE_B), |
42 | | - ((16, 5632), (13312, 1), (13312, 1), (13312, 1), Inplace.OUT_OF_PLACE), |
43 | | - ((16, 5632), (13312, 1), (13312, 1), (13312, 1), Inplace.INPLACE_A), |
44 | | - ((16, 5632), (13312, 1), (13312, 1), (13312, 1), Inplace.INPLACE_B), |
45 | | - ((4, 4, 5632), None, None, None, Inplace.OUT_OF_PLACE), |
46 | | - ((4, 4, 5632), None, None, None, Inplace.INPLACE_A), |
47 | | - ((4, 4, 5632), None, None, None, Inplace.INPLACE_B), |
48 | | - ( |
49 | | - (4, 4, 5632), |
50 | | - (45056, 5632, 1), |
51 | | - (45056, 5632, 1), |
52 | | - (45056, 5632, 1), |
53 | | - Inplace.OUT_OF_PLACE, |
54 | | - ), |
55 | | - ( |
56 | | - (4, 4, 5632), |
57 | | - (45056, 5632, 1), |
58 | | - (45056, 5632, 1), |
59 | | - (45056, 5632, 1), |
60 | | - Inplace.INPLACE_A, |
61 | | - ), |
62 | | - ( |
63 | | - (4, 4, 5632), |
64 | | - (45056, 5632, 1), |
65 | | - (45056, 5632, 1), |
66 | | - (45056, 5632, 1), |
67 | | - Inplace.INPLACE_B, |
68 | | - ), |
| 45 | + test_case + (inplace_item,) |
| 46 | + for test_case in _TEST_CASES_ |
| 47 | + for inplace_item in _INPLACE |
69 | 48 | ] |
70 | 49 |
|
71 | 50 | # Data types used for testing |
@@ -166,7 +145,6 @@ def lib_swiglu(): |
166 | 145 | if DEBUG: |
167 | 146 | debug(c, ans, atol=atol, rtol=rtol) |
168 | 147 | assert torch.allclose(c, ans, atol=atol, rtol=rtol) |
169 | | - print("out-of-place Test passed!") |
170 | 148 |
|
171 | 149 | # Profiling workflow |
172 | 150 | if PROFILE: |
|
0 commit comments