11use arrow:: {
2- array:: { ArrayRef , UInt32Array } ,
2+ array:: { ArrayRef , Float16Array , Float32Array , Float64Array , UInt32Array } ,
33 compute:: { SortColumn , concat_batches, lexsort_to_indices} ,
44 record_batch:: RecordBatch ,
55} ;
66use datafusion:: {
77 common:: { internal_datafusion_err, internal_err} ,
88 error:: { DataFusionError , Result } ,
9- execution:: context:: SessionContext ,
109 physical_expr:: LexOrdering ,
10+ physical_plan:: ExecutionPlan ,
1111} ;
1212use std:: sync:: Arc ;
1313
14+ /// compares the set of record batches for equality
1415pub async fn compare_result_set (
15- _test_ctx : & SessionContext ,
16- compare_ctx : & SessionContext ,
17- query : & str ,
18- test_result : & Result < Vec < RecordBatch > > ,
16+ actual_result : & Result < Vec < RecordBatch > > ,
17+ expected_result : & Result < Vec < RecordBatch > > ,
1918) -> Result < ( ) > {
20- let df = compare_ctx. sql ( query) . await ?;
21- let compare_result = df. collect ( ) . await ;
22- let test_batches = match test_result. as_ref ( ) {
19+ let test_batches = match actual_result. as_ref ( ) {
2320 Ok ( batches) => batches,
2421 Err ( e) => {
25- if compare_result. is_ok ( ) {
26- return internal_err ! (
27- "query failed in test_ctx but succeeded in the compare_ctx: {}" ,
28- e
29- ) ;
22+ if expected_result. is_ok ( ) {
23+ return internal_err ! ( "expected no error but got: {}" , e) ;
3024 }
3125 return Ok ( ( ) ) ; // Both errored, so the query is valid
3226 }
3327 } ;
3428
35- let compare_batches = match compare_result . as_ref ( ) {
29+ let compare_batches = match expected_result . as_ref ( ) {
3630 Ok ( batches) => batches,
3731 Err ( e) => {
38- if compare_result. is_ok ( ) {
39- return internal_err ! (
40- "test_ctx query succeeded but failed in the compare_ctx: {}" ,
41- e
42- ) ;
32+ if actual_result. is_ok ( ) {
33+ return internal_err ! ( "expected error but got none, error: {}" , e) ;
4334 }
4435 return Ok ( ( ) ) ; // Both errored, so the query is valid
4536 }
4637 } ;
4738
48- // Compare result sets
4939 records_equal_as_sets ( test_batches, compare_batches)
50- . map_err ( |e| internal_datafusion_err ! ( "ResultSetOracle validation failed: {}" , e) ) ?;
51-
52- Ok ( ( ) )
40+ . map_err ( |e| internal_datafusion_err ! ( "result sets were not equal: {}" , e) )
5341}
5442
43+ // Ensures that the plans have the same ordering properties and that the actual result is sorted
44+ // correctly.
5545pub async fn compare_ordering (
56- ctx : & SessionContext ,
57- compare_ctx : & SessionContext ,
58- query : & str ,
59- test_result : & Result < Vec < RecordBatch > > ,
46+ actual_physical_plan : Arc < dyn ExecutionPlan > ,
47+ expected_physical_plan : Arc < dyn ExecutionPlan > ,
48+ actual_result : & Result < Vec < RecordBatch > > ,
6049) -> Result < ( ) > {
6150 // Only validate if the query succeeded
62- let test_batches = match test_result . as_ref ( ) {
51+ let test_batches = match actual_result . as_ref ( ) {
6352 Ok ( batches) => batches,
6453 Err ( _) => return Ok ( ( ) ) ,
6554 } ;
6655
67- let df = ctx. sql ( query) . await ?;
68- let physical_plan = df. create_physical_plan ( ) . await ?;
69- let actual_ordering = physical_plan. properties ( ) . output_ordering ( ) ;
70-
71- let df = compare_ctx. sql ( query) . await ?;
72- let physical_plan = df. create_physical_plan ( ) . await ?;
73- let expected_ordering = physical_plan. properties ( ) . output_ordering ( ) ;
56+ let actual_ordering = actual_physical_plan. properties ( ) . output_ordering ( ) ;
57+ let expected_ordering = expected_physical_plan. properties ( ) . output_ordering ( ) ;
7458
7559 if actual_ordering != expected_ordering {
7660 return internal_err ! (
77- "Ordering Oracle validation failed : expected ordering: {:?}, actual ordering: {:?}" ,
61+ "ordering mismatch : expected ordering: {:?}, actual ordering: {:?}" ,
7862 expected_ordering,
7963 actual_ordering
8064 ) ;
@@ -93,11 +77,10 @@ pub async fn compare_ordering(
9377 concat_batches ( & test_batches[ 0 ] . schema ( ) , test_batches) ?
9478 } ;
9579
96- // Check if the coalesced batch maintains the expected ordering
9780 let is_sorted = is_table_same_after_sort ( lex_ordering, & coalesced_batch) ?;
9881 if !is_sorted {
9982 return internal_err ! (
100- "OrderingOracle validation failed: result set is not properly sorted according to expected ordering: {:?}" ,
83+ "ordering validation failed: results are not properly sorted according to expected ordering: {:?}" ,
10184 lex_ordering
10285 ) ;
10386 }
@@ -178,7 +161,7 @@ fn detailed_batch_comparison(
178161 left_only. len( )
179162 ) ) ;
180163 for row in left_only {
181- error_msg. push_str ( & format ! ( "\n {}" , row ) ) ;
164+ error_msg. push_str ( & format ! ( "\n {row}" ) ) ;
182165 }
183166 }
184167
@@ -188,7 +171,7 @@ fn detailed_batch_comparison(
188171 right_only. len( )
189172 ) ) ;
190173 for row in right_only {
191- error_msg. push_str ( & format ! ( "\n {}" , row ) ) ;
174+ error_msg. push_str ( & format ! ( "\n {row}" ) ) ;
192175 }
193176 }
194177
@@ -213,6 +196,12 @@ fn batch_rows_to_strings(batches: &[RecordBatch]) -> Vec<String> {
213196
214197 if array. is_null ( row_idx) {
215198 row_values. push ( "NULL" . to_string ( ) ) ;
199+ } else if let Some ( arr) = array. as_any ( ) . downcast_ref :: < Float16Array > ( ) {
200+ row_values. push ( format ! ( "{:.1$}" , arr. value( row_idx) , 2 ) ) ;
201+ } else if let Some ( arr) = array. as_any ( ) . downcast_ref :: < Float32Array > ( ) {
202+ row_values. push ( format ! ( "{:.1$}" , arr. value( row_idx) , 2 ) ) ;
203+ } else if let Some ( arr) = array. as_any ( ) . downcast_ref :: < Float64Array > ( ) {
204+ row_values. push ( format ! ( "{:.1$}" , arr. value( row_idx) , 2 ) ) ;
216205 } else {
217206 // Use Arrow's deterministic string representation
218207 let value_str = array_value_to_string ( array, row_idx)
@@ -292,6 +281,8 @@ mod tests {
292281
293282 use arrow:: array:: { Int32Array , StringArray } ;
294283 use arrow:: datatypes:: { DataType , Field , Schema } ;
284+ use datafusion:: physical_plan:: collect;
285+ use datafusion:: prelude:: SessionContext ;
295286
296287 use std:: sync:: Arc ;
297288
@@ -421,7 +412,7 @@ mod tests {
421412 }
422413
423414 #[ tokio:: test]
424- async fn test_ordering_oracle ( ) {
415+ async fn test_ordering_validation ( ) {
425416 let actual_ctx = SessionContext :: new ( ) ;
426417 let expected_ctx = SessionContext :: new ( ) ;
427418
@@ -447,26 +438,28 @@ mod tests {
447438 . unwrap ( ) ;
448439
449440 // A query sorted by value should pass.
450- let ordered_query = "SELECT * FROM test_table ORDER BY id" ;
441+ let ordered_query = "SELECT * FROM test_table ORDER BY value" ;
442+
451443 let df = actual_ctx. sql ( ordered_query) . await . unwrap ( ) ;
452- let result = df. collect ( ) . await ;
444+ let task_ctx = actual_ctx. task_ctx ( ) ;
445+ let actual_plan = df. create_physical_plan ( ) . await . unwrap ( ) ;
446+ let results = collect ( actual_plan. clone ( ) , task_ctx) . await ;
447+
448+ let df = expected_ctx. sql ( ordered_query) . await . unwrap ( ) ;
449+ let expected_plan = df. create_physical_plan ( ) . await . unwrap ( ) ;
450+
453451 assert ! (
454- compare_ordering( & actual_ctx , & expected_ctx , ordered_query , & result )
452+ compare_ordering( actual_plan . clone ( ) , expected_plan . clone ( ) , & results )
455453 . await
456454 . is_ok( )
457455 ) ;
458456
459457 // This should fail because the batch is not sorted by value
460458 let result = Ok ( vec ! [ batch] ) ;
461459 assert ! (
462- compare_ordering(
463- & actual_ctx,
464- & expected_ctx,
465- "SELECT * FROM test_table ORDER BY value" ,
466- & result
467- )
468- . await
469- . is_err( )
460+ compare_ordering( actual_plan. clone( ) , expected_plan. clone( ) , & result)
461+ . await
462+ . is_err( )
470463 ) ;
471464 }
472465}
0 commit comments