2222# Configuration (Internal Use Only)
2323# ==============================================================================
2424# These are not meant to be imported from other modules
25-
2625_TEST_CASES = [
2726 # voc, random_val, topp, topk, temperature
2827 (512 , 0.8 , 0.8 , 3 , 0.5 ),
@@ -59,53 +58,52 @@ class RandomSampleDescriptor(Structure):
5958
6059
6160def random_sample (data , random_val , topp , topk , voc , temperature , torch_device ):
62- indices = torch .zeros ([topk ], dtype = torch .int64 )
63- dataNp = data .clone ().detach ()
64- sorted_indices = torch .arange (voc )
65-
66- for i in range (topk ):
67- for j in range (i + 1 , voc ):
68- if dataNp [i ] < dataNp [j ]:
69- tmp = dataNp [i ].clone ().detach ()
70- dataNp [i ] = dataNp [j ].clone ().detach ()
71- dataNp [j ] = tmp
72-
73- tmpInd = sorted_indices [i ].clone ().detach ()
74- sorted_indices [i ] = sorted_indices [j ].clone ().detach ()
75- sorted_indices [j ] = tmpInd
76-
77- # sorted_indices = torch.argsort(dataNp, descending=True)
78- indices = sorted_indices [:topk ]
79-
80- dataNp = dataNp [sorted_indices ]
81-
82- globalM = dataNp [0 ]
83- dataNp = (dataNp - globalM ) / temperature
84- dataNp = torch .softmax (dataNp .float (), dim = 0 )
85- sum_s = 0
86- for end in range (topk ):
87- sum_s += dataNp [end ]
88- if sum_s >= topp :
89- break
90- if end < topk - 1 :
91- end += 1
61+ if topp > 0 and topk > 1 :
62+ indices = torch .zeros ([topk ], dtype = torch .int64 )
63+ dataNp = data .clone ().detach ()
64+ sorted_indices = torch .arange (voc )
65+
66+ for i in range (topk ):
67+ for j in range (i + 1 , voc ):
68+ if dataNp [i ] < dataNp [j ]:
69+ tmp = dataNp [i ].clone ().detach ()
70+ dataNp [i ] = dataNp [j ].clone ().detach ()
71+ dataNp [j ] = tmp
72+
73+ tmpInd = sorted_indices [i ].clone ().detach ()
74+ sorted_indices [i ] = sorted_indices [j ].clone ().detach ()
75+ sorted_indices [j ] = tmpInd
76+
77+ # sorted_indices = torch.argsort(dataNp, descending=True)
78+ indices = sorted_indices [:topk ]
79+
80+ dataNp = dataNp [sorted_indices ]
81+
82+ globalM = dataNp [0 ]
83+ dataNp = (dataNp - globalM ) / temperature
84+ dataNp = torch .softmax (dataNp .float (), dim = 0 )
85+ sum_s = 0
86+ for end in range (topk ):
87+ sum_s += dataNp [end ]
88+ if sum_s >= topp :
89+ break
90+ if end < topk - 1 :
91+ end += 1
92+ else :
93+ end = topk
94+
95+ sum_s = 0
96+ for i in range (end ):
97+ sum_s += dataNp [i ]
98+ random_val *= sum_s
99+
100+ sum_s = 0
101+ for i in range (end ):
102+ sum_s += dataNp [i ]
103+ if random_val < sum_s :
104+ return indices [i ]
92105 else :
93- end = topk
94-
95- sum_s = 0
96- for i in range (end ):
97- sum_s += dataNp [i ]
98- random_val *= sum_s
99-
100- sum_s = 0
101- for i in range (end ):
102- sum_s += dataNp [i ]
103- if random_val < sum_s :
104- return indices [i ]
105-
106-
107- def random_sample_0 (data ):
108- return torch .argmax (data )
106+ return torch .argmax (data )
109107
110108
111109def test (
@@ -124,12 +122,10 @@ def test(
124122 data = torch .arange (voc ).float () * 0.0001
125123 _perm = torch .randperm (voc )
126124 data = data [_perm ].to (x_dtype ).to (torch_device )
127- if topp > 0 and topk > 1 :
128- ans = random_sample (
129- data .to ("cpu" ), random_val , topp , topk , voc , temperature , "cpu"
130- )
131- else :
132- ans = random_sample_0 (data )
125+
126+ ans = random_sample (
127+ data , random_val , topp , topk , voc , temperature , torch_device
128+ ) # 这个函数在device速度可能会很慢,可以通过data.to("cpu")方式加快计算过程
133129
134130 indices = torch .zeros ([1 ], dtype = torch .int64 ).to (torch_device )
135131
0 commit comments