@@ -92,6 +92,21 @@ def get_zeros_int64(shape):
9292 """Get zeros."""
9393 return np .zeros (shape ).astype (np .int64 )
9494
95+ def get_ones_int32 (shape ):
96+ """Get ones."""
97+ return np .ones (shape ).astype (np .int32 )
98+
99+ def get_small_rand_int32 (shape ):
100+ """Get random ints in range [1, 99]"""
101+ return np .random .randint (low = 1 , high = 100 , size = shape , dtype = np .int32 )
102+
103+ def get_zeros_then_ones (shape ):
104+ """Fill half the tensor with zeros and the rest with ones"""
105+ cnt = np .prod (shape )
106+ zeros_cnt = cnt // 2
107+ ones_cnt = cnt - zeros_cnt
108+ return np .concatenate ((np .zeros (zeros_cnt , dtype = np .int32 ), np .ones (ones_cnt , dtype = np .int32 ))).reshape (shape )
109+
95110def get_wav (shape ):
96111 """Get sound data."""
97112 return np .sin (np .linspace (- np .pi , np .pi , shape [0 ]), dtype = np .float32 )
@@ -107,8 +122,12 @@ def get_wav(shape):
107122 "get_wav" : get_wav ,
108123 "get_zeros_int32" : get_zeros_int32 ,
109124 "get_zeros_int64" : get_zeros_int64 ,
125+ "get_ones_int32" : get_ones_int32 ,
126+ "get_small_rand_int32" : get_small_rand_int32 ,
127+ "get_zeros_then_ones" : get_zeros_then_ones
110128}
111129
130+
112131OpsetConstraint = namedtuple ("OpsetConstraint" , "domain, min_version, max_version, excluded_version" )
113132
114133
0 commit comments