@@ -73,7 +73,12 @@ def permute_sparse_features_ref_(
7373 )
7474 @settings (max_examples = 20 , deadline = None )
7575 def test_permute_sparse_features (
76- self , B : int , T : int , L : int , long_index : bool , has_weight : bool
76+ self ,
77+ B : int ,
78+ T : int ,
79+ L : int ,
80+ long_index : bool ,
81+ has_weight : bool ,
7782 ) -> None :
7883 index_dtype = torch .int64 if long_index else torch .int32
7984 lengths = torch .randint (low = 1 , high = L , size = (T , B )).type (index_dtype )
@@ -193,6 +198,137 @@ def test_permute_sparse_features_with_repeats(
193198 assert permuted_weights_cpu is None
194199
195200
201+ class Permute1DSparseFeaturesTest (unittest .TestCase ):
202+ @unittest .skipIf (* gpu_unavailable )
203+ @given (
204+ T = st .integers (min_value = 1 , max_value = 20 ),
205+ L = st .integers (min_value = 2 , max_value = 20 ),
206+ long_index = st .booleans (),
207+ has_weight = st .booleans (),
208+ weight_columns = st .integers (min_value = 1 , max_value = 20 ),
209+ )
210+ @settings (
211+ max_examples = 20 ,
212+ deadline = None ,
213+ )
214+ def test_permute_1D_sparse_data (
215+ self ,
216+ T : int ,
217+ L : int ,
218+ long_index : bool ,
219+ has_weight : bool ,
220+ weight_columns : int ,
221+ ) -> None :
222+ # Setup: Choose index data type based on test parameter
223+ index_dtype = torch .int64 if long_index else torch .int32
224+
225+ # Create 1D lengths tensor representing sparse feature counts for T features
226+ lengths = torch .randint (
227+ low = 1 ,
228+ high = L ,
229+ size = (T ,), # 1D tensor with T elements
230+ device = torch .accelerator .current_accelerator (),
231+ ).type (index_dtype )
232+
233+ # Create optional 2D weights tensor with dimensions [total_indices, weight_columns]
234+ weights = (
235+ torch .rand (
236+ int (lengths .sum ().item ()),
237+ weight_columns ,
238+ device = torch .accelerator .current_accelerator (),
239+ ).float ()
240+ if has_weight
241+ else None
242+ )
243+
244+ # Create indices tensor containing sparse feature indices
245+ indices = torch .randint (
246+ low = 1 ,
247+ high = int (1e5 ),
248+ size = cast (tuple [int , ...], (lengths .sum ().item (),)),
249+ device = torch .accelerator .current_accelerator (),
250+ ).type (index_dtype )
251+
252+ # Create random permutation for shuffling features
253+ permute_list = list (range (T ))
254+ random .shuffle (permute_list )
255+ permute = torch .IntTensor (permute_list ).to (
256+ device = torch .accelerator .current_accelerator ()
257+ )
258+ # Execute: Call the permute_1D_sparse_data operation
259+ (
260+ lengths_actual ,
261+ values_actual ,
262+ weights_actual ,
263+ ) = torch .ops .fbgemm .permute_1D_sparse_data (
264+ permute , lengths , indices , weights , indices .numel ()
265+ )
266+
267+ # Assert: Verify that the lengths were correctly permuted
268+ # The permuted lengths should match the original lengths indexed by the permutation
269+ self .assertTrue (
270+ torch .equal (
271+ lengths_actual , torch .index_select (lengths , dim = 0 , index = permute )
272+ )
273+ )
274+
275+ # Track the current position in the permuted output for validation
276+ permuted_cumulated_index = 0
277+
278+ # Compute cumulative offsets to locate each feature's data in the original arrays
279+ # Prepend a zero to get offsets: [0, lengths[0], lengths[0]+lengths[1], ...]
280+ cumulative_indices = torch .cumsum (
281+ torch .cat (
282+ (
283+ torch .zeros ((1 ,), dtype = index_dtype , device = lengths .device ),
284+ lengths ,
285+ )
286+ ),
287+ dim = 0 ,
288+ )
289+
290+ # Verify each feature's data was correctly permuted
291+ for i in range (T ):
292+ # Get the original feature index that should appear at position i in the permuted output
293+ permuted_index = permute [i ]
294+
295+ # Assert: Verify that the indices for this feature were correctly copied
296+ # Compare the segment in the permuted output against the original segment
297+ self .assertTrue (
298+ torch .equal (
299+ values_actual [
300+ permuted_cumulated_index : permuted_cumulated_index
301+ + lengths [permuted_index ]
302+ ],
303+ indices [
304+ cumulative_indices [permuted_index ] : lengths [permuted_index ]
305+ + cumulative_indices [permuted_index ]
306+ ],
307+ )
308+ )
309+
310+ # Assert: If weights are present, verify they were also correctly permuted
311+ if has_weight and weights is not None :
312+ self .assertTrue (
313+ torch .equal (
314+ weights_actual [
315+ permuted_cumulated_index : permuted_cumulated_index
316+ + lengths [permuted_index ]
317+ ],
318+ weights [
319+ cumulative_indices [permuted_index ] : lengths [permuted_index ]
320+ + cumulative_indices [permuted_index ]
321+ ],
322+ )
323+ )
324+ else :
325+ # Assert: If no weights were provided, ensure the output also has no weights
326+ assert weights_actual is None
327+
328+ # Move to the next segment in the permuted output
329+ permuted_cumulated_index += lengths [permuted_index ]
330+
331+
196332class Permute2DSparseFeaturesTest (unittest .TestCase ):
197333 @unittest .skipIf (* gpu_unavailable )
198334 def test_permute_2D_sparse_data (self ) -> None :
0 commit comments