Skip to content

Commit 8886643

Browse files
authored
Merge pull request #5 from dapper91/compare-feature
- sorting custom comparator feature implemented.
2 parents 564bd43 + 8b65fe5 commit 8886643

File tree

3 files changed

+145
-36
lines changed

3 files changed

+145
-36
lines changed

src/main.rs

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,15 @@ fn main() {
6363
}
6464
};
6565

66-
let sorted_stream = match sorter.sort(input_stream.lines()) {
66+
let compare = |a: &String, b: &String| {
67+
if order == Order::Asc {
68+
a.cmp(&b)
69+
} else {
70+
a.cmp(&b).reverse()
71+
}
72+
};
73+
74+
let sorted_stream = match sorter.sort_by(input_stream.lines(), compare) {
6775
Ok(sorted_stream) => sorted_stream,
6876
Err(err) => {
6977
log::error!("data sorting error: {}", err);
@@ -115,7 +123,7 @@ impl std::str::FromStr for LogLevel {
115123
}
116124
}
117125

118-
#[derive(Copy, Clone, clap::ArgEnum)]
126+
#[derive(Copy, Clone, PartialEq, clap::ArgEnum)]
119127
enum Order {
120128
Asc,
121129
Desc,

src/merger.rs

Lines changed: 69 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,36 +1,88 @@
11
//! Binary heap merger.
22
3+
use std::cmp::Ordering;
34
use std::collections::BinaryHeap;
45
use std::error::Error;
56

7+
/// Value wrapper binding custom compare function to a value.
8+
struct OrderedWrapper<T, F>
9+
where
10+
F: Fn(&T, &T) -> Ordering,
11+
{
12+
value: T,
13+
compare: F,
14+
}
15+
16+
impl<T, F> OrderedWrapper<T, F>
17+
where
18+
F: Fn(&T, &T) -> Ordering,
19+
{
20+
fn wrap(value: T, compare: F) -> Self {
21+
OrderedWrapper { value, compare }
22+
}
23+
24+
fn unwrap(self) -> T {
25+
self.value
26+
}
27+
}
28+
29+
impl<T, F> PartialEq for OrderedWrapper<T, F>
30+
where
31+
F: Fn(&T, &T) -> Ordering,
32+
{
33+
fn eq(&self, other: &Self) -> bool {
34+
self.cmp(other) == Ordering::Equal
35+
}
36+
}
37+
38+
impl<T, F> Eq for OrderedWrapper<T, F> where F: Fn(&T, &T) -> Ordering {}
39+
40+
impl<T, F> PartialOrd for OrderedWrapper<T, F>
41+
where
42+
F: Fn(&T, &T) -> Ordering,
43+
{
44+
fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
45+
Some(self.cmp(other))
46+
}
47+
}
48+
impl<T, F> Ord for OrderedWrapper<T, F>
49+
where
50+
F: Fn(&T, &T) -> Ordering,
51+
{
52+
fn cmp(&self, other: &Self) -> Ordering {
53+
(self.compare)(&self.value, &other.value)
54+
}
55+
}
56+
657
/// Binary heap merger implementation.
758
/// Merges multiple sorted inputs into a single sorted output.
859
/// Time complexity is *m* \* log(*n*) in worst case where *m* is the number of items,
960
/// *n* is the number of chunks (inputs).
10-
pub struct BinaryHeapMerger<T, E, C>
61+
pub struct BinaryHeapMerger<T, E, F, C>
1162
where
12-
T: Ord,
1363
E: Error,
64+
F: Fn(&T, &T) -> Ordering,
1465
C: IntoIterator<Item = Result<T, E>>,
1566
{
1667
// binary heap is max-heap by default so we reverse it to convert it to min-heap
17-
items: BinaryHeap<(std::cmp::Reverse<T>, usize)>,
68+
items: BinaryHeap<(std::cmp::Reverse<OrderedWrapper<T, F>>, usize)>,
1869
chunks: Vec<C::IntoIter>,
1970
initiated: bool,
71+
compare: F,
2072
}
2173

22-
impl<T, E, C> BinaryHeapMerger<T, E, C>
74+
impl<T, E, F, C> BinaryHeapMerger<T, E, F, C>
2375
where
24-
T: Ord,
2576
E: Error,
77+
F: Fn(&T, &T) -> Ordering,
2678
C: IntoIterator<Item = Result<T, E>>,
2779
{
2880
/// Creates an instance of a binary heap merger using chunks as inputs.
2981
/// Chunk items should be sorted in ascending order otherwise the result is undefined.
3082
///
3183
/// # Arguments
3284
/// * `chunks` - Chunks to be merged in a single sorted one
33-
pub fn new<I>(chunks: I) -> Self
85+
pub fn new<I>(chunks: I, compare: F) -> Self
3486
where
3587
I: IntoIterator<Item = C>,
3688
{
@@ -40,15 +92,16 @@ where
4092
return BinaryHeapMerger {
4193
chunks,
4294
items,
95+
compare,
4396
initiated: false,
4497
};
4598
}
4699
}
47100

48-
impl<T, E, C> Iterator for BinaryHeapMerger<T, E, C>
101+
impl<T, E, F, C> Iterator for BinaryHeapMerger<T, E, F, C>
49102
where
50-
T: Ord,
51103
E: Error,
104+
F: Fn(&T, &T) -> Ordering + Copy,
52105
C: IntoIterator<Item = Result<T, E>>,
53106
{
54107
type Item = Result<T, E>;
@@ -59,7 +112,9 @@ where
59112
for (idx, chunk) in self.chunks.iter_mut().enumerate() {
60113
if let Some(item) = chunk.next() {
61114
match item {
62-
Ok(item) => self.items.push((std::cmp::Reverse(item), idx)),
115+
Ok(item) => self
116+
.items
117+
.push((std::cmp::Reverse(OrderedWrapper::wrap(item, self.compare)), idx)),
63118
Err(err) => return Some(Err(err)),
64119
}
65120
}
@@ -70,12 +125,14 @@ where
70125
let (result, idx) = self.items.pop()?;
71126
if let Some(item) = self.chunks[idx].next() {
72127
match item {
73-
Ok(item) => self.items.push((std::cmp::Reverse(item), idx)),
128+
Ok(item) => self
129+
.items
130+
.push((std::cmp::Reverse(OrderedWrapper::wrap(item, self.compare)), idx)),
74131
Err(err) => return Some(Err(err)),
75132
}
76133
}
77134

78-
return Some(Ok(result.0));
135+
return Some(Ok(result.0.unwrap()));
79136
}
80137
}
81138

@@ -131,7 +188,7 @@ mod test {
131188
#[case] chunks: Vec<Vec<Result<i32, io::Error>>>,
132189
#[case] expected_result: Vec<Result<i32, io::Error>>,
133190
) {
134-
let merger = BinaryHeapMerger::new(chunks);
191+
let merger = BinaryHeapMerger::new(chunks, i32::cmp);
135192
let actual_result = merger.collect();
136193
assert!(
137194
compare_vectors_of_result::<_, io::Error>(&actual_result, &expected_result),

src/sort.rs

Lines changed: 66 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
//! External sorter.
22
33
use log;
4+
use std::cmp::Ordering;
45
use std::error::Error;
56
use std::fmt;
67
use std::fmt::{Debug, Display};
@@ -64,7 +65,7 @@ impl<S: Error, D: Error, I: Error> Display for SortError<S, D, I> {
6465
#[derive(Clone)]
6566
pub struct ExternalSorterBuilder<T, E, B = LimitedBufferBuilder, C = RmpExternalChunk<T>>
6667
where
67-
T: Ord + Send,
68+
T: Send,
6869
E: Error,
6970
B: ChunkBufferBuilder<T>,
7071
C: ExternalChunk<T>,
@@ -88,7 +89,7 @@ where
8889

8990
impl<T, E, B, C> ExternalSorterBuilder<T, E, B, C>
9091
where
91-
T: Ord + Send,
92+
T: Send,
9293
E: Error,
9394
B: ChunkBufferBuilder<T>,
9495
C: ExternalChunk<T>,
@@ -137,7 +138,7 @@ where
137138

138139
impl<T, E, B, C> Default for ExternalSorterBuilder<T, E, B, C>
139140
where
140-
T: Ord + Send,
141+
T: Send,
141142
E: Error,
142143
B: ChunkBufferBuilder<T>,
143144
C: ExternalChunk<T>,
@@ -158,7 +159,7 @@ where
158159
/// External sorter.
159160
pub struct ExternalSorter<T, E, B = LimitedBufferBuilder, C = RmpExternalChunk<T>>
160161
where
161-
T: Ord + Send,
162+
T: Send,
162163
E: Error,
163164
B: ChunkBufferBuilder<T>,
164165
C: ExternalChunk<T>,
@@ -182,7 +183,7 @@ where
182183

183184
impl<T, E, B, C> ExternalSorter<T, E, B, C>
184185
where
185-
T: Ord + Send,
186+
T: Send,
186187
E: Error,
187188
B: ChunkBufferBuilder<T>,
188189
C: ExternalChunk<T>,
@@ -246,17 +247,42 @@ where
246247
return Ok(tmp_dir);
247248
}
248249

249-
/// Sorts data from input using external sort algorithm.
250+
/// Sorts data from the input.
250251
/// Returns an iterator that can be used to get sorted data stream.
252+
///
253+
/// # Arguments
254+
/// * `input` - Input stream data to be fetched from
251255
pub fn sort<I>(
252256
&self,
253257
input: I,
254258
) -> Result<
255-
BinaryHeapMerger<T, C::DeserializationError, C>,
259+
BinaryHeapMerger<T, C::DeserializationError, impl Fn(&T, &T) -> Ordering + Copy, C>,
260+
SortError<C::SerializationError, C::DeserializationError, E>,
261+
>
262+
where
263+
T: Ord,
264+
I: IntoIterator<Item = Result<T, E>>,
265+
{
266+
self.sort_by(input, T::cmp)
267+
}
268+
269+
/// Sorts data from the input using a custom compare function.
270+
/// Returns an iterator that can be used to get sorted data stream.
271+
///
272+
/// # Arguments
273+
/// * `input` - Input stream data to be fetched from
274+
/// * `compare` - Function be be used to compare items
275+
pub fn sort_by<I, F>(
276+
&self,
277+
input: I,
278+
compare: F,
279+
) -> Result<
280+
BinaryHeapMerger<T, C::DeserializationError, F, C>,
256281
SortError<C::SerializationError, C::DeserializationError, E>,
257282
>
258283
where
259284
I: IntoIterator<Item = Result<T, E>>,
285+
F: Fn(&T, &T) -> Ordering + Sync + Send + Copy,
260286
{
261287
let mut chunk_buf = self.buffer_builder.build();
262288
let mut external_chunks = Vec::new();
@@ -268,34 +294,39 @@ where
268294
}
269295

270296
if chunk_buf.is_full() {
271-
external_chunks.push(self.create_chunk(chunk_buf)?);
297+
external_chunks.push(self.create_chunk(chunk_buf, compare)?);
272298
chunk_buf = self.buffer_builder.build();
273299
}
274300
}
275301

276302
if chunk_buf.len() > 0 {
277-
external_chunks.push(self.create_chunk(chunk_buf)?);
303+
external_chunks.push(self.create_chunk(chunk_buf, compare)?);
278304
}
279305

280306
log::debug!("external sort preparation done");
281307

282-
return Ok(BinaryHeapMerger::new(external_chunks));
308+
return Ok(BinaryHeapMerger::new(external_chunks, compare));
283309
}
284310

285-
fn create_chunk(
311+
fn create_chunk<F>(
286312
&self,
287-
mut chunk: impl ChunkBuffer<T>,
288-
) -> Result<C, SortError<C::SerializationError, C::DeserializationError, E>> {
313+
mut buffer: impl ChunkBuffer<T>,
314+
compare: F,
315+
) -> Result<C, SortError<C::SerializationError, C::DeserializationError, E>>
316+
where
317+
F: Fn(&T, &T) -> Ordering + Sync + Send,
318+
{
289319
log::debug!("sorting chunk data ...");
290320
self.thread_pool.install(|| {
291-
chunk.par_sort();
321+
buffer.par_sort_by(compare);
292322
});
293323

294324
log::debug!("saving chunk data");
295-
let external_chunk = ExternalChunk::build(&self.tmp_dir, chunk, self.rw_buf_size).map_err(|err| match err {
296-
ExternalChunkError::IO(err) => SortError::IO(err),
297-
ExternalChunkError::SerializationError(err) => SortError::SerializationError(err),
298-
})?;
325+
let external_chunk =
326+
ExternalChunk::build(&self.tmp_dir, buffer, self.rw_buf_size).map_err(|err| match err {
327+
ExternalChunkError::IO(err) => SortError::IO(err),
328+
ExternalChunkError::SerializationError(err) => SortError::SerializationError(err),
329+
})?;
299330

300331
return Ok(external_chunk);
301332
}
@@ -307,11 +338,14 @@ mod test {
307338
use std::path::Path;
308339

309340
use rand::seq::SliceRandom;
341+
use rstest::*;
310342

311343
use super::{ExternalSorter, ExternalSorterBuilder, LimitedBufferBuilder};
312344

313-
#[test]
314-
fn test_external_sorter() {
345+
#[rstest]
346+
#[case(false)]
347+
#[case(true)]
348+
fn test_external_sorter(#[case] reversed: bool) {
315349
let input_sorted = 0..100;
316350

317351
let mut input: Vec<Result<i32, io::Error>> = Vec::from_iter(input_sorted.clone().map(|item| Ok(item)));
@@ -324,11 +358,21 @@ mod test {
324358
.build()
325359
.unwrap();
326360

327-
let result = sorter.sort(input).unwrap();
361+
let compare = if reversed {
362+
|a: &i32, b: &i32| a.cmp(b).reverse()
363+
} else {
364+
|a: &i32, b: &i32| a.cmp(b)
365+
};
366+
367+
let result = sorter.sort_by(input, compare).unwrap();
328368

329369
let actual_result: Result<Vec<i32>, _> = result.collect();
330370
let actual_result = actual_result.unwrap();
331-
let expected_result = Vec::from_iter(input_sorted.clone());
371+
let expected_result = if reversed {
372+
Vec::from_iter(input_sorted.clone().rev())
373+
} else {
374+
Vec::from_iter(input_sorted.clone())
375+
};
332376

333377
assert_eq!(actual_result, expected_result)
334378
}

0 commit comments

Comments
 (0)