11//! External sorter.
22
33use log;
4+ use std:: cmp:: Ordering ;
45use std:: error:: Error ;
56use std:: fmt;
67use std:: fmt:: { Debug , Display } ;
@@ -64,7 +65,7 @@ impl<S: Error, D: Error, I: Error> Display for SortError<S, D, I> {
6465#[ derive( Clone ) ]
6566pub struct ExternalSorterBuilder < T , E , B = LimitedBufferBuilder , C = RmpExternalChunk < T > >
6667where
67- T : Ord + Send ,
68+ T : Send ,
6869 E : Error ,
6970 B : ChunkBufferBuilder < T > ,
7071 C : ExternalChunk < T > ,
8889
8990impl < T , E , B , C > ExternalSorterBuilder < T , E , B , C >
9091where
91- T : Ord + Send ,
92+ T : Send ,
9293 E : Error ,
9394 B : ChunkBufferBuilder < T > ,
9495 C : ExternalChunk < T > ,
@@ -137,7 +138,7 @@ where
137138
138139impl < T , E , B , C > Default for ExternalSorterBuilder < T , E , B , C >
139140where
140- T : Ord + Send ,
141+ T : Send ,
141142 E : Error ,
142143 B : ChunkBufferBuilder < T > ,
143144 C : ExternalChunk < T > ,
@@ -158,7 +159,7 @@ where
158159/// External sorter.
159160pub struct ExternalSorter < T , E , B = LimitedBufferBuilder , C = RmpExternalChunk < T > >
160161where
161- T : Ord + Send ,
162+ T : Send ,
162163 E : Error ,
163164 B : ChunkBufferBuilder < T > ,
164165 C : ExternalChunk < T > ,
@@ -182,7 +183,7 @@ where
182183
183184impl < T , E , B , C > ExternalSorter < T , E , B , C >
184185where
185- T : Ord + Send ,
186+ T : Send ,
186187 E : Error ,
187188 B : ChunkBufferBuilder < T > ,
188189 C : ExternalChunk < T > ,
@@ -246,17 +247,42 @@ where
246247 return Ok ( tmp_dir) ;
247248 }
248249
249- /// Sorts data from input using external sort algorithm .
250+ /// Sorts data from the input .
250251 /// Returns an iterator that can be used to get sorted data stream.
252+ ///
253+ /// # Arguments
254+ /// * `input` - Input stream data to be fetched from
251255 pub fn sort < I > (
252256 & self ,
253257 input : I ,
254258 ) -> Result <
255- BinaryHeapMerger < T , C :: DeserializationError , C > ,
259+ BinaryHeapMerger < T , C :: DeserializationError , impl Fn ( & T , & T ) -> Ordering + Copy , C > ,
260+ SortError < C :: SerializationError , C :: DeserializationError , E > ,
261+ >
262+ where
263+ T : Ord ,
264+ I : IntoIterator < Item = Result < T , E > > ,
265+ {
266+ self . sort_by ( input, T :: cmp)
267+ }
268+
269+ /// Sorts data from the input using a custom compare function.
270+ /// Returns an iterator that can be used to get sorted data stream.
271+ ///
272+ /// # Arguments
273+ /// * `input` - Input stream data to be fetched from
274+ /// * `compare` - Function be be used to compare items
275+ pub fn sort_by < I , F > (
276+ & self ,
277+ input : I ,
278+ compare : F ,
279+ ) -> Result <
280+ BinaryHeapMerger < T , C :: DeserializationError , F , C > ,
256281 SortError < C :: SerializationError , C :: DeserializationError , E > ,
257282 >
258283 where
259284 I : IntoIterator < Item = Result < T , E > > ,
285+ F : Fn ( & T , & T ) -> Ordering + Sync + Send + Copy ,
260286 {
261287 let mut chunk_buf = self . buffer_builder . build ( ) ;
262288 let mut external_chunks = Vec :: new ( ) ;
@@ -268,34 +294,39 @@ where
268294 }
269295
270296 if chunk_buf. is_full ( ) {
271- external_chunks. push ( self . create_chunk ( chunk_buf) ?) ;
297+ external_chunks. push ( self . create_chunk ( chunk_buf, compare ) ?) ;
272298 chunk_buf = self . buffer_builder . build ( ) ;
273299 }
274300 }
275301
276302 if chunk_buf. len ( ) > 0 {
277- external_chunks. push ( self . create_chunk ( chunk_buf) ?) ;
303+ external_chunks. push ( self . create_chunk ( chunk_buf, compare ) ?) ;
278304 }
279305
280306 log:: debug!( "external sort preparation done" ) ;
281307
282- return Ok ( BinaryHeapMerger :: new ( external_chunks) ) ;
308+ return Ok ( BinaryHeapMerger :: new ( external_chunks, compare ) ) ;
283309 }
284310
285- fn create_chunk (
311+ fn create_chunk < F > (
286312 & self ,
287- mut chunk : impl ChunkBuffer < T > ,
288- ) -> Result < C , SortError < C :: SerializationError , C :: DeserializationError , E > > {
313+ mut buffer : impl ChunkBuffer < T > ,
314+ compare : F ,
315+ ) -> Result < C , SortError < C :: SerializationError , C :: DeserializationError , E > >
316+ where
317+ F : Fn ( & T , & T ) -> Ordering + Sync + Send ,
318+ {
289319 log:: debug!( "sorting chunk data ..." ) ;
290320 self . thread_pool . install ( || {
291- chunk . par_sort ( ) ;
321+ buffer . par_sort_by ( compare ) ;
292322 } ) ;
293323
294324 log:: debug!( "saving chunk data" ) ;
295- let external_chunk = ExternalChunk :: build ( & self . tmp_dir , chunk, self . rw_buf_size ) . map_err ( |err| match err {
296- ExternalChunkError :: IO ( err) => SortError :: IO ( err) ,
297- ExternalChunkError :: SerializationError ( err) => SortError :: SerializationError ( err) ,
298- } ) ?;
325+ let external_chunk =
326+ ExternalChunk :: build ( & self . tmp_dir , buffer, self . rw_buf_size ) . map_err ( |err| match err {
327+ ExternalChunkError :: IO ( err) => SortError :: IO ( err) ,
328+ ExternalChunkError :: SerializationError ( err) => SortError :: SerializationError ( err) ,
329+ } ) ?;
299330
300331 return Ok ( external_chunk) ;
301332 }
@@ -307,11 +338,14 @@ mod test {
307338 use std:: path:: Path ;
308339
309340 use rand:: seq:: SliceRandom ;
341+ use rstest:: * ;
310342
311343 use super :: { ExternalSorter , ExternalSorterBuilder , LimitedBufferBuilder } ;
312344
313- #[ test]
314- fn test_external_sorter ( ) {
345+ #[ rstest]
346+ #[ case( false ) ]
347+ #[ case( true ) ]
348+ fn test_external_sorter ( #[ case] reversed : bool ) {
315349 let input_sorted = 0 ..100 ;
316350
317351 let mut input: Vec < Result < i32 , io:: Error > > = Vec :: from_iter ( input_sorted. clone ( ) . map ( |item| Ok ( item) ) ) ;
@@ -324,11 +358,21 @@ mod test {
324358 . build ( )
325359 . unwrap ( ) ;
326360
327- let result = sorter. sort ( input) . unwrap ( ) ;
361+ let compare = if reversed {
362+ |a : & i32 , b : & i32 | a. cmp ( b) . reverse ( )
363+ } else {
364+ |a : & i32 , b : & i32 | a. cmp ( b)
365+ } ;
366+
367+ let result = sorter. sort_by ( input, compare) . unwrap ( ) ;
328368
329369 let actual_result: Result < Vec < i32 > , _ > = result. collect ( ) ;
330370 let actual_result = actual_result. unwrap ( ) ;
331- let expected_result = Vec :: from_iter ( input_sorted. clone ( ) ) ;
371+ let expected_result = if reversed {
372+ Vec :: from_iter ( input_sorted. clone ( ) . rev ( ) )
373+ } else {
374+ Vec :: from_iter ( input_sorted. clone ( ) )
375+ } ;
332376
333377 assert_eq ! ( actual_result, expected_result)
334378 }
0 commit comments