@@ -23,7 +23,8 @@ use std::str::FromStr;
2323use std:: sync:: Arc ;
2424
2525use arrow_arith:: boolean:: { and, and_kleene, is_not_null, is_null, not, or, or_kleene} ;
26- use arrow_array:: { Array , ArrayRef , BooleanArray , RecordBatch } ;
26+ use arrow_array:: { Array , ArrayRef , BooleanArray , Datum as ArrowDatum , RecordBatch , Scalar } ;
27+ use arrow_cast:: cast:: cast;
2728use arrow_ord:: cmp:: { eq, gt, gt_eq, lt, lt_eq, neq} ;
2829use arrow_schema:: {
2930 ArrowError , DataType , FieldRef , Schema as ArrowSchema , SchemaRef as ArrowSchemaRef ,
@@ -1106,6 +1107,7 @@ impl BoundPredicateVisitor for PredicateConverter<'_> {
11061107
11071108 Ok ( Box :: new ( move |batch| {
11081109 let left = project_column ( & batch, idx) ?;
1110+ let literal = try_cast_literal ( & literal, left. data_type ( ) ) ?;
11091111 lt ( & left, literal. as_ref ( ) )
11101112 } ) )
11111113 } else {
@@ -1125,6 +1127,7 @@ impl BoundPredicateVisitor for PredicateConverter<'_> {
11251127
11261128 Ok ( Box :: new ( move |batch| {
11271129 let left = project_column ( & batch, idx) ?;
1130+ let literal = try_cast_literal ( & literal, left. data_type ( ) ) ?;
11281131 lt_eq ( & left, literal. as_ref ( ) )
11291132 } ) )
11301133 } else {
@@ -1144,6 +1147,7 @@ impl BoundPredicateVisitor for PredicateConverter<'_> {
11441147
11451148 Ok ( Box :: new ( move |batch| {
11461149 let left = project_column ( & batch, idx) ?;
1150+ let literal = try_cast_literal ( & literal, left. data_type ( ) ) ?;
11471151 gt ( & left, literal. as_ref ( ) )
11481152 } ) )
11491153 } else {
@@ -1163,6 +1167,7 @@ impl BoundPredicateVisitor for PredicateConverter<'_> {
11631167
11641168 Ok ( Box :: new ( move |batch| {
11651169 let left = project_column ( & batch, idx) ?;
1170+ let literal = try_cast_literal ( & literal, left. data_type ( ) ) ?;
11661171 gt_eq ( & left, literal. as_ref ( ) )
11671172 } ) )
11681173 } else {
@@ -1182,6 +1187,7 @@ impl BoundPredicateVisitor for PredicateConverter<'_> {
11821187
11831188 Ok ( Box :: new ( move |batch| {
11841189 let left = project_column ( & batch, idx) ?;
1190+ let literal = try_cast_literal ( & literal, left. data_type ( ) ) ?;
11851191 eq ( & left, literal. as_ref ( ) )
11861192 } ) )
11871193 } else {
@@ -1201,6 +1207,7 @@ impl BoundPredicateVisitor for PredicateConverter<'_> {
12011207
12021208 Ok ( Box :: new ( move |batch| {
12031209 let left = project_column ( & batch, idx) ?;
1210+ let literal = try_cast_literal ( & literal, left. data_type ( ) ) ?;
12041211 neq ( & left, literal. as_ref ( ) )
12051212 } ) )
12061213 } else {
@@ -1220,6 +1227,7 @@ impl BoundPredicateVisitor for PredicateConverter<'_> {
12201227
12211228 Ok ( Box :: new ( move |batch| {
12221229 let left = project_column ( & batch, idx) ?;
1230+ let literal = try_cast_literal ( & literal, left. data_type ( ) ) ?;
12231231 starts_with ( & left, literal. as_ref ( ) )
12241232 } ) )
12251233 } else {
@@ -1239,7 +1247,7 @@ impl BoundPredicateVisitor for PredicateConverter<'_> {
12391247
12401248 Ok ( Box :: new ( move |batch| {
12411249 let left = project_column ( & batch, idx) ?;
1242-
1250+ let literal = try_cast_literal ( & literal , left . data_type ( ) ) ? ;
12431251 // update here if arrow ever adds a native not_starts_with
12441252 not ( & starts_with ( & left, literal. as_ref ( ) ) ?)
12451253 } ) )
@@ -1264,8 +1272,10 @@ impl BoundPredicateVisitor for PredicateConverter<'_> {
12641272 Ok ( Box :: new ( move |batch| {
12651273 // update this if arrow ever adds a native is_in kernel
12661274 let left = project_column ( & batch, idx) ?;
1275+
12671276 let mut acc = BooleanArray :: from ( vec ! [ false ; batch. num_rows( ) ] ) ;
12681277 for literal in & literals {
1278+ let literal = try_cast_literal ( literal, left. data_type ( ) ) ?;
12691279 acc = or ( & acc, & eq ( & left, literal. as_ref ( ) ) ?) ?
12701280 }
12711281
@@ -1294,6 +1304,7 @@ impl BoundPredicateVisitor for PredicateConverter<'_> {
12941304 let left = project_column ( & batch, idx) ?;
12951305 let mut acc = BooleanArray :: from ( vec ! [ true ; batch. num_rows( ) ] ) ;
12961306 for literal in & literals {
1307+ let literal = try_cast_literal ( literal, left. data_type ( ) ) ?;
12971308 acc = and ( & acc, & neq ( & left, literal. as_ref ( ) ) ?) ?
12981309 }
12991310
@@ -1387,14 +1398,35 @@ impl<R: FileRead> AsyncFileReader for ArrowFileReader<R> {
13871398 }
13881399}
13891400
1401+ /// The Arrow type of an array that the Parquet reader reads may not match the exact Arrow type
1402+ /// that Iceberg uses for literals - but they are effectively the same logical type,
1403+ /// i.e. LargeUtf8 and Utf8 or Utf8View and Utf8 or Utf8View and LargeUtf8.
1404+ ///
1405+ /// The Arrow compute kernels that we use must match the type exactly, so first cast the literal
1406+ /// into the type of the batch we read from Parquet before sending it to the compute kernel.
1407+ fn try_cast_literal (
1408+ literal : & Arc < dyn ArrowDatum + Send + Sync > ,
1409+ column_type : & DataType ,
1410+ ) -> std:: result:: Result < Arc < dyn ArrowDatum + Send + Sync > , ArrowError > {
1411+ let literal_array = literal. get ( ) . 0 ;
1412+
1413+ // No cast required
1414+ if literal_array. data_type ( ) == column_type {
1415+ return Ok ( Arc :: clone ( literal) ) ;
1416+ }
1417+
1418+ let literal_array = cast ( literal_array, column_type) ?;
1419+ Ok ( Arc :: new ( Scalar :: new ( literal_array) ) )
1420+ }
1421+
13901422#[ cfg( test) ]
13911423mod tests {
13921424 use std:: collections:: { HashMap , HashSet } ;
13931425 use std:: fs:: File ;
13941426 use std:: sync:: Arc ;
13951427
13961428 use arrow_array:: cast:: AsArray ;
1397- use arrow_array:: { ArrayRef , RecordBatch , StringArray } ;
1429+ use arrow_array:: { ArrayRef , LargeStringArray , RecordBatch , StringArray } ;
13981430 use arrow_schema:: { DataType , Field , Schema as ArrowSchema , TimeUnit } ;
13991431 use futures:: TryStreamExt ;
14001432 use parquet:: arrow:: arrow_reader:: { RowSelection , RowSelector } ;
@@ -1590,7 +1622,8 @@ message schema {
15901622 // Expected: [NULL, "foo"].
15911623 let expected = vec ! [ None , Some ( "foo" . to_string( ) ) ] ;
15921624
1593- let ( file_io, schema, table_location, _temp_dir) = setup_kleene_logic ( data_for_col_a) ;
1625+ let ( file_io, schema, table_location, _temp_dir) =
1626+ setup_kleene_logic ( data_for_col_a, DataType :: Utf8 ) ;
15941627 let reader = ArrowReaderBuilder :: new ( file_io) . build ( ) ;
15951628
15961629 let result_data = test_perform_read ( predicate, schema, table_location, reader) . await ;
@@ -1611,14 +1644,88 @@ message schema {
16111644 // Expected: ["bar"].
16121645 let expected = vec ! [ Some ( "bar" . to_string( ) ) ] ;
16131646
1614- let ( file_io, schema, table_location, _temp_dir) = setup_kleene_logic ( data_for_col_a) ;
1647+ let ( file_io, schema, table_location, _temp_dir) =
1648+ setup_kleene_logic ( data_for_col_a, DataType :: Utf8 ) ;
16151649 let reader = ArrowReaderBuilder :: new ( file_io) . build ( ) ;
16161650
16171651 let result_data = test_perform_read ( predicate, schema, table_location, reader) . await ;
16181652
16191653 assert_eq ! ( result_data, expected) ;
16201654 }
16211655
1656+ #[ tokio:: test]
1657+ async fn test_predicate_cast_literal ( ) {
1658+ let predicates = vec ! [
1659+ // a == 'foo'
1660+ ( Reference :: new( "a" ) . equal_to( Datum :: string( "foo" ) ) , vec![
1661+ Some ( "foo" . to_string( ) ) ,
1662+ ] ) ,
1663+ // a != 'foo'
1664+ (
1665+ Reference :: new( "a" ) . not_equal_to( Datum :: string( "foo" ) ) ,
1666+ vec![ Some ( "bar" . to_string( ) ) ] ,
1667+ ) ,
1668+ // STARTS_WITH(a, 'foo')
1669+ ( Reference :: new( "a" ) . starts_with( Datum :: string( "f" ) ) , vec![
1670+ Some ( "foo" . to_string( ) ) ,
1671+ ] ) ,
1672+ // NOT STARTS_WITH(a, 'foo')
1673+ (
1674+ Reference :: new( "a" ) . not_starts_with( Datum :: string( "f" ) ) ,
1675+ vec![ Some ( "bar" . to_string( ) ) ] ,
1676+ ) ,
1677+ // a < 'foo'
1678+ ( Reference :: new( "a" ) . less_than( Datum :: string( "foo" ) ) , vec![
1679+ Some ( "bar" . to_string( ) ) ,
1680+ ] ) ,
1681+ // a <= 'foo'
1682+ (
1683+ Reference :: new( "a" ) . less_than_or_equal_to( Datum :: string( "foo" ) ) ,
1684+ vec![ Some ( "foo" . to_string( ) ) , Some ( "bar" . to_string( ) ) ] ,
1685+ ) ,
1686+ // a > 'foo'
1687+ (
1688+ Reference :: new( "a" ) . greater_than( Datum :: string( "bar" ) ) ,
1689+ vec![ Some ( "foo" . to_string( ) ) ] ,
1690+ ) ,
1691+ // a >= 'foo'
1692+ (
1693+ Reference :: new( "a" ) . greater_than_or_equal_to( Datum :: string( "foo" ) ) ,
1694+ vec![ Some ( "foo" . to_string( ) ) ] ,
1695+ ) ,
1696+ // a IN ('foo', 'bar')
1697+ (
1698+ Reference :: new( "a" ) . is_in( [ Datum :: string( "foo" ) , Datum :: string( "baz" ) ] ) ,
1699+ vec![ Some ( "foo" . to_string( ) ) ] ,
1700+ ) ,
1701+ // a NOT IN ('foo', 'bar')
1702+ (
1703+ Reference :: new( "a" ) . is_not_in( [ Datum :: string( "foo" ) , Datum :: string( "baz" ) ] ) ,
1704+ vec![ Some ( "bar" . to_string( ) ) ] ,
1705+ ) ,
1706+ ] ;
1707+
1708+ // Table data: ["foo", "bar"]
1709+ let data_for_col_a = vec ! [ Some ( "foo" . to_string( ) ) , Some ( "bar" . to_string( ) ) ] ;
1710+
1711+ let ( file_io, schema, table_location, _temp_dir) =
1712+ setup_kleene_logic ( data_for_col_a, DataType :: LargeUtf8 ) ;
1713+ let reader = ArrowReaderBuilder :: new ( file_io) . build ( ) ;
1714+
1715+ for ( predicate, expected) in predicates {
1716+ println ! ( "testing predicate {predicate}" ) ;
1717+ let result_data = test_perform_read (
1718+ predicate. clone ( ) ,
1719+ schema. clone ( ) ,
1720+ table_location. clone ( ) ,
1721+ reader. clone ( ) ,
1722+ )
1723+ . await ;
1724+
1725+ assert_eq ! ( result_data, expected, "predicate={predicate}" ) ;
1726+ }
1727+ }
1728+
16221729 async fn test_perform_read (
16231730 predicate : Predicate ,
16241731 schema : SchemaRef ,
@@ -1661,6 +1768,7 @@ message schema {
16611768
16621769 fn setup_kleene_logic (
16631770 data_for_col_a : Vec < Option < String > > ,
1771+ col_a_type : DataType ,
16641772 ) -> ( FileIO , SchemaRef , String , TempDir ) {
16651773 let schema = Arc :: new (
16661774 Schema :: builder ( )
@@ -1677,7 +1785,7 @@ message schema {
16771785
16781786 let arrow_schema = Arc :: new ( ArrowSchema :: new ( vec ! [ Field :: new(
16791787 "a" ,
1680- DataType :: Utf8 ,
1788+ col_a_type . clone ( ) ,
16811789 true ,
16821790 )
16831791 . with_metadata( HashMap :: from( [ (
@@ -1690,7 +1798,11 @@ message schema {
16901798
16911799 let file_io = FileIO :: from_path ( & table_location) . unwrap ( ) . build ( ) . unwrap ( ) ;
16921800
1693- let col = Arc :: new ( StringArray :: from ( data_for_col_a) ) as ArrayRef ;
1801+ let col = match col_a_type {
1802+ DataType :: Utf8 => Arc :: new ( StringArray :: from ( data_for_col_a) ) as ArrayRef ,
1803+ DataType :: LargeUtf8 => Arc :: new ( LargeStringArray :: from ( data_for_col_a) ) as ArrayRef ,
1804+ _ => panic ! ( "unexpected col_a_type" ) ,
1805+ } ;
16941806
16951807 let to_write = RecordBatch :: try_new ( arrow_schema. clone ( ) , vec ! [ col] ) . unwrap ( ) ;
16961808
0 commit comments