|
20 | 20 | import torch |
21 | 21 | import ctypes |
22 | 22 | import torch.nn.functional as F |
23 | | -import numpy as np |
24 | 23 |
|
25 | 24 | # constant for control whether profile the pytorch and lib functions |
26 | 25 | # NOTE: need to manually add synchronization function to the lib function, |
@@ -58,26 +57,6 @@ def get_mean_variance(x, dtype): |
58 | 57 | ).to(dtype) |
59 | 58 |
|
60 | 59 |
|
61 | | -def find_and_print_differing_indices( |
62 | | - x, tensor1, tensor2, mean, scale, var, b, atol=0, rtol=1e-2 |
63 | | -): |
64 | | - if tensor1.shape != tensor2.shape: |
65 | | - raise ValueError("Tensors must have the same shape to compare.") |
66 | | - |
67 | | - # Calculate the difference mask based on atol and rtol |
68 | | - diff_mask = torch.abs(tensor1 - tensor2) > (atol + rtol * torch.abs(tensor2)) |
69 | | - diff_indices = torch.nonzero(diff_mask, as_tuple=False) |
70 | | - |
71 | | - # Print the indices and the differing elements |
72 | | - for idx in diff_indices: |
73 | | - index_tuple = tuple(idx.tolist()) |
74 | | - print( |
75 | | - f"Index: {index_tuple}, x: {x[index_tuple]}, mean: {mean[index_tuple[1]]}, scale: {scale[index_tuple[1]]}, var: {var[index_tuple[1]]}, b: {b[index_tuple[1]]}, y element: {tensor1[index_tuple]}, ans element: {tensor2[index_tuple]}" |
76 | | - ) |
77 | | - |
78 | | - return diff_indices |
79 | | - |
80 | | - |
81 | 60 | def test( |
82 | 61 | lib, |
83 | 62 | handle, |
@@ -159,11 +138,6 @@ def test( |
159 | 138 | elapsed = (time.time() - start_time) / NUM_ITERATIONS |
160 | 139 | print(f" lib time: {elapsed :6f}") |
161 | 140 |
|
162 | | - # print(" - x: \n", x, "\n - y:\n", y, "\n - ans:\n", ans) |
163 | | - # print(" - y:\n", y, "\n - ans:\n", ans) |
164 | | - |
165 | | - # find_and_print_differing_indices(x, y, ans, mean, scale, mean, b, atol=1e-7, rtol=1e-3) |
166 | | - # np.testing.assert_allclose(y.numpy(), ans.numpy(), atol=1e-7, rtol=1e-3) |
167 | 141 | assert torch.allclose(y, ans, atol=1e-7, rtol=1e-3) |
168 | 142 | check_error(lib.infiniopDestroyBatchNormDescriptor(descriptor)) |
169 | 143 |
|
@@ -200,11 +174,9 @@ def test_bang(lib, test_cases): |
200 | 174 | if __name__ == "__main__": |
201 | 175 | test_cases = [ |
202 | 176 | # x_shape, eps |
203 | | - ((2, 3, 4), 1e-5), |
| 177 | + ((2, 5, 7), 1e-5), |
204 | 178 | ((32, 3, 1024), 1e-5), |
205 | | - ((1, 3, 4, 4), 1e-5), |
206 | 179 | ((32, 3, 128, 128), 1e-5), |
207 | | - ((1, 6, 5, 5, 5), 1e-5), |
208 | 180 | ((32, 3, 64, 64, 64), 1e-5), |
209 | 181 | ] |
210 | 182 | args = get_args() |
|
0 commit comments