Skip to content

Commit 3f90f69

Browse files
committed
modified random_sample.py
1 parent fe5bd5b commit 3f90f69

File tree

1 file changed

+2
-1
lines changed

1 file changed

+2
-1
lines changed

operatorspy/tests/random_sample.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,8 @@ def test(lib, handle, torch_device, voc, random_val, topp, topk, temperature, x_
8383
f"Testing RandomSample on {torch_device} with voc:{voc} dtype:{x_dtype}"
8484
)
8585

86-
data = torch.rand((voc), dtype=x_dtype).to(torch_device)
86+
data = torch.arange(start=1, end=voc + 1, dtype=x_dtype).to(torch_device) / voc
87+
data = data[torch.randperm(voc)]
8788
if(topp > 0 and topk > 1):
8889
ans = random_sample(data.to("cpu"), random_val, topp, topk, voc, temperature, "cpu")
8990
else:

0 commit comments

Comments
 (0)