@@ -3,7 +3,6 @@ use arrow::{
33 compute:: { SortColumn , concat_batches, lexsort_to_indices} ,
44 record_batch:: RecordBatch ,
55} ;
6- use async_trait:: async_trait;
76use datafusion:: {
87 common:: { internal_datafusion_err, internal_err} ,
98 error:: { DataFusionError , Result } ,
@@ -12,188 +11,99 @@ use datafusion::{
1211} ;
1312use std:: sync:: Arc ;
1413
15- /// Validator for query outputs
16- /// - asserts that the set of rows is the same when running a query on two different execution contexts
17- /// - uses oracles to assert properties of the query output
18- pub struct Validator {
19- pub test_ctx : SessionContext ,
20- pub compare_ctx : SessionContext ,
21- pub oracles : Vec < Box < dyn Oracle + Send + Sync > > ,
22- }
23-
24- impl Validator {
25- /// Create a new Validator.
26- /// - [actual_ctx] is the context we want to test. It produces the "actual" results
27- /// - [expected_ctx] is the context we want to compare against. It produces the "expected"
28- /// results
29- pub async fn new ( test_ctx : SessionContext , compare_ctx : SessionContext ) -> Result < Self > {
30- let oracles: Vec < Box < dyn Oracle + Send + Sync > > =
31- vec ! [ Box :: new( ResultSetOracle { } ) , Box :: new( OrderingOracle { } ) ] ;
32-
33- Ok ( Validator {
34- test_ctx,
35- compare_ctx,
36- oracles,
37- } )
38- }
39-
40- /// Create a new Validator with ordering checks enabled.
41- pub async fn new_with_ordering (
42- test_ctx : SessionContext ,
43- compare_ctx : SessionContext ,
44- ) -> Result < Self > {
45- let oracles: Vec < Box < dyn Oracle + Send + Sync > > =
46- vec ! [ Box :: new( ResultSetOracle { } ) , Box :: new( OrderingOracle { } ) ] ;
47-
48- Ok ( Validator {
49- test_ctx,
50- compare_ctx,
51- oracles,
52- } )
53- }
54-
55- // runs a query and returns an error if there is a validation failure. Ok(None) is returned
56- // if the query errors in both the [actual_ctx] and [expected_ctx], otherwise the actual record batches
57- // are returned.
58- pub async fn run ( & self , query : & str ) -> Result < Option < Vec < RecordBatch > > > {
59- let result = self . run_inner ( query) . await ;
60- for oracle in & self . oracles {
61- oracle
62- . validate ( & self . test_ctx , & self . compare_ctx , query, & result)
63- . await ?;
14+ pub async fn compare_result_set (
15+ _test_ctx : & SessionContext ,
16+ compare_ctx : & SessionContext ,
17+ query : & str ,
18+ test_result : & Result < Vec < RecordBatch > > ,
19+ ) -> Result < ( ) > {
20+ let df = compare_ctx. sql ( query) . await ?;
21+ let compare_result = df. collect ( ) . await ;
22+ let test_batches = match test_result. as_ref ( ) {
23+ Ok ( batches) => batches,
24+ Err ( e) => {
25+ if compare_result. is_ok ( ) {
26+ return internal_err ! (
27+ "query failed in test_ctx but succeeded in the compare_ctx: {}" ,
28+ e
29+ ) ;
30+ }
31+ return Ok ( ( ) ) ; // Both errored, so the query is valid
6432 }
33+ } ;
6534
66- if result. is_err ( ) {
67- return Ok ( None ) ;
35+ let compare_batches = match compare_result. as_ref ( ) {
36+ Ok ( batches) => batches,
37+ Err ( e) => {
38+ if compare_result. is_ok ( ) {
39+ return internal_err ! (
40+ "test_ctx query succeeded but failed in the compare_ctx: {}" ,
41+ e
42+ ) ;
43+ }
44+ return Ok ( ( ) ) ; // Both errored, so the query is valid
6845 }
69- result. map ( Some )
70- }
46+ } ;
7147
72- async fn run_inner ( & self , query : & str ) -> Result < Vec < RecordBatch > > {
73- let df = self . test_ctx . sql ( query) . await ?;
74- df. collect ( ) . await
75- }
76- }
48+ // Compare result sets
49+ records_equal_as_sets ( test_batches, compare_batches)
50+ . map_err ( |e| internal_datafusion_err ! ( "ResultSetOracle validation failed: {}" , e) ) ?;
7751
78- /// Trait for query result validation oracles
79- #[ async_trait]
80- pub trait Oracle {
81- async fn validate (
82- & self ,
83- test_ctx : & SessionContext ,
84- compare_ctx : & SessionContext ,
85- query : & str ,
86- results : & Result < Vec < RecordBatch > > ,
87- ) -> Result < ( ) > ;
52+ Ok ( ( ) )
8853}
8954
90- /// Oracle that verifies the set of rows is the same between the test context and compare context.
91- pub struct ResultSetOracle { }
92-
93- #[ async_trait]
94- impl Oracle for ResultSetOracle {
95- async fn validate (
96- & self ,
97- _test_ctx : & SessionContext ,
98- compare_ctx : & SessionContext ,
99- query : & str ,
100- test_result : & Result < Vec < RecordBatch > > ,
101- ) -> Result < ( ) > {
102- let df = compare_ctx. sql ( query) . await ?;
103- let compare_result = df. collect ( ) . await ;
104- let test_batches = match test_result. as_ref ( ) {
105- Ok ( batches) => batches,
106- Err ( e) => {
107- if compare_result. is_ok ( ) {
108- return internal_err ! (
109- "query failed in test_ctx but succeeded in the compare_ctx: {}" ,
110- e
111- ) ;
112- }
113- return Ok ( ( ) ) ; // Both errored, so the query is valid
114- }
115- } ;
116-
117- let compare_batches = match compare_result. as_ref ( ) {
118- Ok ( batches) => batches,
119- Err ( e) => {
120- if compare_result. is_ok ( ) {
121- return internal_err ! (
122- "test_ctx query succeeded but failed in the compare_ctx: {}" ,
123- e
124- ) ;
125- }
126- return Ok ( ( ) ) ; // Both errored, so the query is valid
127- }
128- } ;
129-
130- // Compare result sets
131- records_equal_as_sets ( test_batches, compare_batches)
132- . map_err ( |e| internal_datafusion_err ! ( "ResultSetOracle validation failed: {}" , e) ) ?;
133-
134- Ok ( ( ) )
55+ pub async fn compare_ordering (
56+ ctx : & SessionContext ,
57+ compare_ctx : & SessionContext ,
58+ query : & str ,
59+ test_result : & Result < Vec < RecordBatch > > ,
60+ ) -> Result < ( ) > {
61+ // Only validate if the query succeeded
62+ let test_batches = match test_result. as_ref ( ) {
63+ Ok ( batches) => batches,
64+ Err ( _) => return Ok ( ( ) ) ,
65+ } ;
66+
67+ let df = ctx. sql ( query) . await ?;
68+ let physical_plan = df. create_physical_plan ( ) . await ?;
69+ let actual_ordering = physical_plan. properties ( ) . output_ordering ( ) ;
70+
71+ let df = compare_ctx. sql ( query) . await ?;
72+ let physical_plan = df. create_physical_plan ( ) . await ?;
73+ let expected_ordering = physical_plan. properties ( ) . output_ordering ( ) ;
74+
75+ if actual_ordering != expected_ordering {
76+ return internal_err ! (
77+ "Ordering Oracle validation failed: expected ordering: {:?}, actual ordering: {:?}" ,
78+ expected_ordering,
79+ actual_ordering
80+ ) ;
13581 }
136- }
13782
138- /// Oracle that asserts ordering based on the ordering properties from the physical plan.
139- pub struct OrderingOracle ;
140-
141- #[ async_trait]
142- impl Oracle for OrderingOracle {
143- async fn validate (
144- & self ,
145- ctx : & SessionContext ,
146- compare_ctx : & SessionContext ,
147- query : & str ,
148- test_result : & Result < Vec < RecordBatch > > ,
149- ) -> Result < ( ) > {
150- // Only validate if the query succeeded
151- let test_batches = match test_result. as_ref ( ) {
152- Ok ( batches) => batches,
153- Err ( _) => return Ok ( ( ) ) ,
83+ // If there's no ordering, there's nothing to validate.
84+ let Some ( lex_ordering) = actual_ordering else {
85+ return Ok ( ( ) ) ;
86+ } ;
87+
88+ // Coalesce all batches into a single batch to check ordering across the entire result set
89+ if !test_batches. is_empty ( ) {
90+ let coalesced_batch = if test_batches. len ( ) == 1 {
91+ test_batches[ 0 ] . clone ( )
92+ } else {
93+ concat_batches ( & test_batches[ 0 ] . schema ( ) , test_batches) ?
15494 } ;
15595
156- let df = ctx. sql ( query) . await ?;
157- let physical_plan = df. create_physical_plan ( ) . await ?;
158- let actual_ordering = physical_plan. properties ( ) . output_ordering ( ) ;
159-
160- let df = compare_ctx. sql ( query) . await ?;
161- let physical_plan = df. create_physical_plan ( ) . await ?;
162- let expected_ordering = physical_plan. properties ( ) . output_ordering ( ) ;
163-
164- if actual_ordering != expected_ordering {
96+ // Check if the coalesced batch maintains the expected ordering
97+ let is_sorted = is_table_same_after_sort ( lex_ordering, & coalesced_batch) ?;
98+ if !is_sorted {
16599 return internal_err ! (
166- "Ordering Oracle validation failed: expected ordering: {:?}, actual ordering: {:?}" ,
167- expected_ordering,
168- actual_ordering
100+ "OrderingOracle validation failed: result set is not properly sorted according to expected ordering: {:?}" ,
101+ lex_ordering
169102 ) ;
170103 }
171-
172- // If there's no ordering, there's nothing to validate.
173- let Some ( lex_ordering) = actual_ordering else {
174- return Ok ( ( ) ) ;
175- } ;
176-
177- // Coalesce all batches into a single batch to check ordering across the entire result set
178- if !test_batches. is_empty ( ) {
179- let coalesced_batch = if test_batches. len ( ) == 1 {
180- test_batches[ 0 ] . clone ( )
181- } else {
182- concat_batches ( & test_batches[ 0 ] . schema ( ) , test_batches) ?
183- } ;
184-
185- // Check if the coalesced batch maintains the expected ordering
186- let is_sorted = is_table_same_after_sort ( lex_ordering, & coalesced_batch) ?;
187- if !is_sorted {
188- return internal_err ! (
189- "OrderingOracle validation failed: result set is not properly sorted according to expected ordering: {:?}" ,
190- lex_ordering
191- ) ;
192- }
193- }
194-
195- Ok ( ( ) )
196104 }
105+
106+ Ok ( ( ) )
197107}
198108
199109/// Compare two sets of record batches for equality (order-independent)
@@ -536,31 +446,27 @@ mod tests {
536446 . register_batch ( "test_table" , batch. clone ( ) )
537447 . unwrap ( ) ;
538448
539- let oracle = OrderingOracle ;
540-
541449 // Query which sorted by id should pass.
542450 let ordered_query = "SELECT * FROM test_table ORDER BY id" ;
543451 let df = actual_ctx. sql ( ordered_query) . await . unwrap ( ) ;
544452 let result = df. collect ( ) . await ;
545453 assert ! (
546- oracle
547- . validate( & actual_ctx, & expected_ctx, ordered_query, & result)
454+ compare_ordering( & actual_ctx, & expected_ctx, ordered_query, & result)
548455 . await
549456 . is_ok( )
550457 ) ;
551458
552459 // This should fail because the batch is not sorted by value
553460 let result = Ok ( vec ! [ batch] ) ;
554461 assert ! (
555- oracle
556- . validate(
557- & actual_ctx,
558- & expected_ctx,
559- "SELECT * FROM test_table ORDER BY value" ,
560- & result
561- )
562- . await
563- . is_err( )
462+ compare_ordering(
463+ & actual_ctx,
464+ & expected_ctx,
465+ "SELECT * FROM test_table ORDER BY value" ,
466+ & result
467+ )
468+ . await
469+ . is_err( )
564470 ) ;
565471 }
566472}
0 commit comments