Skip to content

Commit 069bf45

Browse files
committed
Remove unnecessary testing utils in the frontend test
1 parent fbd4e0c commit 069bf45

File tree

1 file changed

+1
-29
lines changed

1 file changed

+1
-29
lines changed

operatorspy/tests/batch_norm.py

Lines changed: 1 addition & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,6 @@
2020
import torch
2121
import ctypes
2222
import torch.nn.functional as F
23-
import numpy as np
2423

2524
# constant for control whether profile the pytorch and lib functions
2625
# NOTE: need to manually add synchronization function to the lib function,
@@ -58,26 +57,6 @@ def get_mean_variance(x, dtype):
5857
).to(dtype)
5958

6059

61-
def find_and_print_differing_indices(
62-
x, tensor1, tensor2, mean, scale, var, b, atol=0, rtol=1e-2
63-
):
64-
if tensor1.shape != tensor2.shape:
65-
raise ValueError("Tensors must have the same shape to compare.")
66-
67-
# Calculate the difference mask based on atol and rtol
68-
diff_mask = torch.abs(tensor1 - tensor2) > (atol + rtol * torch.abs(tensor2))
69-
diff_indices = torch.nonzero(diff_mask, as_tuple=False)
70-
71-
# Print the indices and the differing elements
72-
for idx in diff_indices:
73-
index_tuple = tuple(idx.tolist())
74-
print(
75-
f"Index: {index_tuple}, x: {x[index_tuple]}, mean: {mean[index_tuple[1]]}, scale: {scale[index_tuple[1]]}, var: {var[index_tuple[1]]}, b: {b[index_tuple[1]]}, y element: {tensor1[index_tuple]}, ans element: {tensor2[index_tuple]}"
76-
)
77-
78-
return diff_indices
79-
80-
8160
def test(
8261
lib,
8362
handle,
@@ -159,11 +138,6 @@ def test(
159138
elapsed = (time.time() - start_time) / NUM_ITERATIONS
160139
print(f" lib time: {elapsed :6f}")
161140

162-
# print(" - x: \n", x, "\n - y:\n", y, "\n - ans:\n", ans)
163-
# print(" - y:\n", y, "\n - ans:\n", ans)
164-
165-
# find_and_print_differing_indices(x, y, ans, mean, scale, mean, b, atol=1e-7, rtol=1e-3)
166-
# np.testing.assert_allclose(y.numpy(), ans.numpy(), atol=1e-7, rtol=1e-3)
167141
assert torch.allclose(y, ans, atol=1e-7, rtol=1e-3)
168142
check_error(lib.infiniopDestroyBatchNormDescriptor(descriptor))
169143

@@ -200,11 +174,9 @@ def test_bang(lib, test_cases):
200174
if __name__ == "__main__":
201175
test_cases = [
202176
# x_shape, eps
203-
((2, 3, 4), 1e-5),
177+
((2, 5, 7), 1e-5),
204178
((32, 3, 1024), 1e-5),
205-
((1, 3, 4, 4), 1e-5),
206179
((32, 3, 128, 128), 1e-5),
207-
((1, 6, 5, 5, 5), 1e-5),
208180
((32, 3, 64, 64, 64), 1e-5),
209181
]
210182
args = get_args()

0 commit comments

Comments
 (0)