
Commit aff079a

modified random sample test function

1 parent c927faf

2 files changed: +49 −54 lines

test/infiniop/causal_softmax.py
Lines changed: 0 additions & 1 deletion

@@ -21,7 +21,6 @@
 # Configuration (Internal Use Only)
 # ==============================================================================
 # These are not meant to be imported from other modules
-
 _TEST_CASES = [
     # x_shape, x_stride
     ((32, 512), None),

test/infiniop/random_sample.py
Lines changed: 49 additions & 53 deletions

@@ -22,7 +22,6 @@
 # Configuration (Internal Use Only)
 # ==============================================================================
 # These are not meant to be imported from other modules
-
 _TEST_CASES = [
     # voc, random_val, topp, topk, temperature
     (512, 0.8, 0.8, 3, 0.5),
@@ -59,53 +58,52 @@ class RandomSampleDescriptor(Structure):
 
 
 def random_sample(data, random_val, topp, topk, voc, temperature, torch_device):
-    indices = torch.zeros([topk], dtype=torch.int64)
-    dataNp = data.clone().detach()
-    sorted_indices = torch.arange(voc)
-
-    for i in range(topk):
-        for j in range(i + 1, voc):
-            if dataNp[i] < dataNp[j]:
-                tmp = dataNp[i].clone().detach()
-                dataNp[i] = dataNp[j].clone().detach()
-                dataNp[j] = tmp
-
-                tmpInd = sorted_indices[i].clone().detach()
-                sorted_indices[i] = sorted_indices[j].clone().detach()
-                sorted_indices[j] = tmpInd
-
-    # sorted_indices = torch.argsort(dataNp, descending=True)
-    indices = sorted_indices[:topk]
-
-    dataNp = dataNp[sorted_indices]
-
-    globalM = dataNp[0]
-    dataNp = (dataNp - globalM) / temperature
-    dataNp = torch.softmax(dataNp.float(), dim=0)
-    sum_s = 0
-    for end in range(topk):
-        sum_s += dataNp[end]
-        if sum_s >= topp:
-            break
-    if end < topk - 1:
-        end += 1
+    if topp > 0 and topk > 1:
+        indices = torch.zeros([topk], dtype=torch.int64)
+        dataNp = data.clone().detach()
+        sorted_indices = torch.arange(voc)
+
+        for i in range(topk):
+            for j in range(i + 1, voc):
+                if dataNp[i] < dataNp[j]:
+                    tmp = dataNp[i].clone().detach()
+                    dataNp[i] = dataNp[j].clone().detach()
+                    dataNp[j] = tmp
+
+                    tmpInd = sorted_indices[i].clone().detach()
+                    sorted_indices[i] = sorted_indices[j].clone().detach()
+                    sorted_indices[j] = tmpInd
+
+        # sorted_indices = torch.argsort(dataNp, descending=True)
+        indices = sorted_indices[:topk]
+
+        dataNp = dataNp[sorted_indices]
+
+        globalM = dataNp[0]
+        dataNp = (dataNp - globalM) / temperature
+        dataNp = torch.softmax(dataNp.float(), dim=0)
+        sum_s = 0
+        for end in range(topk):
+            sum_s += dataNp[end]
+            if sum_s >= topp:
+                break
+        if end < topk - 1:
+            end += 1
+        else:
+            end = topk
+
+        sum_s = 0
+        for i in range(end):
+            sum_s += dataNp[i]
+        random_val *= sum_s
+
+        sum_s = 0
+        for i in range(end):
+            sum_s += dataNp[i]
+            if random_val < sum_s:
+                return indices[i]
     else:
-        end = topk
-
-    sum_s = 0
-    for i in range(end):
-        sum_s += dataNp[i]
-    random_val *= sum_s
-
-    sum_s = 0
-    for i in range(end):
-        sum_s += dataNp[i]
-        if random_val < sum_s:
-            return indices[i]
-
-
-def random_sample_0(data):
-    return torch.argmax(data)
+        return torch.argmax(data)
 
 
 def test(
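The reference sampler above finds the top-k logits with an O(topk · voc) selection sort written in pure Python, which is the main reason it is slow when the tensor lives on an accelerator. The commented-out torch.argsort line already hints at a vectorized form; a minimal sketch of the same top-k / top-p / temperature logic (an illustrative assumption for comparison, not part of this commit) could look like:

import torch

def random_sample_vectorized(data, random_val, topp, topk, temperature):
    # Greedy path, mirroring the reference: plain argmax when sampling is disabled.
    if not (topp > 0 and topk > 1):
        return torch.argmax(data)

    # One descending sort replaces the O(topk * voc) selection pass.
    sorted_vals, sorted_idx = torch.sort(data, descending=True)

    # Temperature-scaled softmax over the whole sorted vocabulary,
    # subtracting the max for numerical stability, as the reference does.
    probs = torch.softmax((sorted_vals - sorted_vals[0]).float() / temperature, dim=0)

    # Smallest prefix of the top-k whose cumulative probability reaches top-p.
    cum = torch.cumsum(probs[:topk], dim=0)
    over = torch.nonzero(cum >= topp)
    end = int(over[0]) + 1 if over.numel() > 0 else topk

    # Scale the random value into the kept prefix and return the first
    # sorted index whose cumulative probability exceeds it.
    cum_kept = cum[:end]
    threshold = random_val * cum_kept[-1]
    return sorted_idx[int(torch.nonzero(cum_kept > threshold)[0])]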
@@ -124,12 +122,10 @@ def test(
     data = torch.arange(voc).float() * 0.0001
     _perm = torch.randperm(voc)
     data = data[_perm].to(x_dtype).to(torch_device)
-    if topp > 0 and topk > 1:
-        ans = random_sample(
-            data.to("cpu"), random_val, topp, topk, voc, temperature, "cpu"
-        )
-    else:
-        ans = random_sample_0(data)
+
+    ans = random_sample(
+        data, random_val, topp, topk, voc, temperature, torch_device
+    )  # This function can be very slow on the device; moving data to the CPU with data.to("cpu") speeds up the computation
 
     indices = torch.zeros([1], dtype=torch.int64).to(torch_device)
 
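As the inline comment (translated above) points out, the pure-Python reference loops are much slower when data stays on the device. A sketch of the CPU-bound call the comment alludes to, which has the same shape as the call this commit replaces, would be:

# Optional speed-up suggested by the inline comment: run the Python
# reference sampler on a CPU copy of the data instead of on the device.
ans = random_sample(
    data.to("cpu"), random_val, topp, topk, voc, temperature, "cpu"
)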