 // specific language governing permissions and limitations
 // under the License.
 
-use std::{collections::HashMap, sync::Arc};
+use std::{any::Any, collections::HashMap, fmt, sync::Arc};
 
 use arrow_array::{
     builder::{Float32Builder, NullBufferBuilder},
-    ArrayRef, StructArray,
+    ArrayRef, RecordBatch, StructArray,
 };
-use arrow_schema::{DataType, Field, Fields};
+use arrow_schema::{DataType, Field, Fields, Schema, SchemaRef};
+use async_trait::async_trait;
 use datafusion::{
     config::TableParquetOptions,
     datasource::{
-        file_format::parquet::ParquetSink, physical_plan::FileSinkConfig, sink::DataSinkExec,
+        file_format::parquet::ParquetSink,
+        physical_plan::FileSinkConfig,
+        sink::{DataSink, DataSinkExec},
     },
 };
 use datafusion_common::{
     config::ConfigOptions, exec_datafusion_err, exec_err, not_impl_err, DataFusionError, Result,
 };
+use datafusion_execution::{SendableRecordBatchStream, TaskContext};
 use datafusion_expr::{dml::InsertOp, ColumnarValue, ScalarUDF, Volatility};
 use datafusion_physical_expr::{
     expressions::Column, LexRequirement, PhysicalExpr, ScalarFunctionExpr,
 };
-use datafusion_physical_plan::{projection::ProjectionExec, ExecutionPlan};
+use datafusion_physical_plan::{
+    stream::RecordBatchStreamAdapter, DisplayAs, DisplayFormatType, ExecutionPlan,
+};
 use float_next_after::NextAfter;
+use futures::StreamExt;
 use geo_traits::GeometryTrait;
 use sedona_common::sedona_internal_err;
 use sedona_expr::scalar_udf::{SedonaScalarKernel, SedonaScalarUDF};
@@ -58,7 +65,7 @@ use crate::{
 };
 
 pub fn create_geoparquet_writer_physical_plan(
-    mut input: Arc<dyn ExecutionPlan>,
+    input: Arc<dyn ExecutionPlan>,
     mut conf: FileSinkConfig,
     order_requirements: Option<LexRequirement>,
     options: &TableGeoParquetOptions,
@@ -76,6 +83,8 @@ pub fn create_geoparquet_writer_physical_plan(
     // We have geometry and/or geography! Collect the GeoParquetMetadata we'll need to write
     let mut metadata = GeoParquetMetadata::default();
     let mut bbox_columns = HashMap::new();
+    let mut bbox_projection = None;
+    let mut parquet_output_schema = conf.output_schema().clone();
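+    // For GeoParquet 1.1, these may be replaced below by the bbox projection
+    // and the schema that results from applying it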
 
     // Check the version
     match options.geoparquet_version {
@@ -84,9 +93,10 @@ pub fn create_geoparquet_writer_physical_plan(
         }
         GeoParquetVersion::V1_1 => {
             metadata.version = "1.1.0".to_string();
-            (input, bbox_columns) = project_bboxes(input, options.overwrite_bbox_columns)?;
-            conf.output_schema = input.schema();
-            output_geometry_column_indices = input.schema().geometry_column_indices()?;
+            (bbox_projection, bbox_columns) =
+                project_bboxes(&input, options.overwrite_bbox_columns)?;
+            parquet_output_schema = compute_final_schema(&bbox_projection, &input.schema())?;
+            output_geometry_column_indices = conf.output_schema.geometry_column_indices()?;
         }
         _ => {
             return not_impl_err!(
@@ -168,10 +178,78 @@ pub fn create_geoparquet_writer_physical_plan(
     );
 
     // Create the sink
-    let sink = Arc::new(ParquetSink::new(conf, parquet_options));
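+    // The sink advertises the input plan's schema to DataFusion but writes
+    // batches using the (possibly bbox-extended) parquet output schema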
+    let sink_input_schema = conf.output_schema;
+    conf.output_schema = parquet_output_schema.clone();
+    let sink = Arc::new(GeoParquetSink {
+        inner: ParquetSink::new(conf, parquet_options),
+        projection: bbox_projection,
+        sink_input_schema,
+        parquet_output_schema,
+    });
     Ok(Arc::new(DataSinkExec::new(input, sink, order_requirements)) as _)
 }
 
+/// Implementation of [DataSink] that computes GeoParquet 1.1 bbox columns
+/// if needed. This is used instead of a ProjectionExec because DataFusion's
+/// optimizer rules seem to rearrange the projection in ways that cause
+/// the plan to fail <https://github.com/apache/sedona-db/issues/379>.
+#[derive(Debug)]
+struct GeoParquetSink {
+    inner: ParquetSink,
+    projection: Option<Vec<(Arc<dyn PhysicalExpr>, String)>>,
+    sink_input_schema: SchemaRef,
+    parquet_output_schema: SchemaRef,
+}
+
+impl DisplayAs for GeoParquetSink {
+    fn fmt_as(&self, t: DisplayFormatType, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+        self.inner.fmt_as(t, f)
+    }
+}
+
+#[async_trait]
+impl DataSink for GeoParquetSink {
+    fn as_any(&self) -> &dyn Any {
+        self
+    }
+
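+    /// The sink consumes batches with the input plan's schema; the bbox
+    /// columns are only computed inside `write_all`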
+    fn schema(&self) -> &SchemaRef {
+        &self.sink_input_schema
+    }
+
+    async fn write_all(
+        &self,
+        data: SendableRecordBatchStream,
+        context: &Arc<TaskContext>,
+    ) -> Result<u64> {
+        if let Some(projection) = &self.projection {
+            // If we have a projection, apply it here
+            let schema = self.parquet_output_schema.clone();
+            let projection = projection.clone();
+
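+            // Adapt the input stream so the projection is evaluated against
+            // every batch, yielding batches that match parquet_output_schema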
+            let data = Box::pin(RecordBatchStreamAdapter::new(
+                schema.clone(),
+                data.map(move |batch_result| {
+                    let schema = schema.clone();
+
+                    batch_result.and_then(|batch| {
+                        let mut columns = Vec::with_capacity(projection.len());
+                        for (expr, _) in &projection {
+                            let col = expr.evaluate(&batch)?;
+                            columns.push(col.into_array(batch.num_rows())?);
+                        }
+                        Ok(RecordBatch::try_new(schema.clone(), columns)?)
+                    })
+                }),
+            ));
+
+            self.inner.write_all(data, context).await
+        } else {
+            self.inner.write_all(data, context).await
+        }
+    }
+}
+
 /// Create a regular Parquet writer like DataFusion would otherwise do.
 fn create_inner_writer(
     input: Arc<dyn ExecutionPlan>,
@@ -184,6 +262,11 @@ fn create_inner_writer(
     Ok(Arc::new(DataSinkExec::new(input, sink, order_requirements)) as _)
 }
 
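+/// Result of [project_bboxes]: the optional bbox projection expressions (None
+/// when no bbox columns are needed) plus the generated bbox column names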
+type ProjectBboxesResult = (
+    Option<Vec<(Arc<dyn PhysicalExpr>, String)>>,
+    HashMap<String, String>,
+);
+
 /// Create a projection that inserts a bbox column for every geometry column
 ///
 /// This implements creating the GeoParquet 1.1 bounding box columns,
@@ -206,9 +289,9 @@ fn create_inner_writer(
 /// "some_col_bbox", it is unlikely that replacing it would have unintended
 /// consequences.
 fn project_bboxes(
-    input: Arc<dyn ExecutionPlan>,
+    input: &Arc<dyn ExecutionPlan>,
     overwrite_bbox_columns: bool,
-) -> Result<(Arc<dyn ExecutionPlan>, HashMap<String, String>)> {
+) -> Result<ProjectBboxesResult> {
     let input_schema = input.schema();
     let matcher = ArgMatcher::is_geometry();
     let bbox_udf: Arc<ScalarUDF> = Arc::new(geoparquet_bbox_udf().into());
@@ -245,7 +328,7 @@ fn project_bboxes(
     // If we don't need to create any bbox columns, don't add an additional
     // projection at the end of the input plan
     if bbox_exprs.is_empty() {
-        return Ok((input, HashMap::new()));
+        return Ok((None, HashMap::new()));
     }
 
     // Create the projection expressions
@@ -275,13 +358,34 @@ Use overwrite_bbox_columns = True if this is what was intended.",
         exprs.push((column, f.name().clone()));
     }
 
-    // Create the projection
-    let exec = ProjectionExec::try_new(exprs, input)?;
-
     // Flip the bbox_column_names into the form our caller needs it
     let bbox_column_names_by_field = bbox_column_names.drain().map(|(k, v)| (v, k)).collect();
 
-    Ok((Arc::new(exec), bbox_column_names_by_field))
+    Ok((Some(exprs), bbox_column_names_by_field))
+}
+
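+/// Compute the schema of the batches as written to Parquet: the result of
+/// applying the bbox projection to the input schema, or the input schema
+/// unchanged when no projection is needed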
+fn compute_final_schema(
+    bbox_projection: &Option<Vec<(Arc<dyn PhysicalExpr>, String)>>,
+    initial_schema: &SchemaRef,
+) -> Result<SchemaRef> {
+    if let Some(bbox_projection) = bbox_projection {
+        let new_fields = bbox_projection
+            .iter()
+            .map(|(expr, name)| -> Result<Field> {
+                let return_field_ref = expr.return_field(initial_schema)?;
+                Ok(Field::new(
+                    name,
+                    return_field_ref.data_type().clone(),
+                    return_field_ref.is_nullable(),
+                )
+                .with_metadata(return_field_ref.metadata().clone()))
+            })
+            .collect::<Result<Vec<_>>>()?;
+
+        Ok(Arc::new(Schema::new(new_fields)))
+    } else {
+        Ok(initial_schema.clone())
+    }
 }
 
 fn geoparquet_bbox_udf() -> SedonaScalarUDF {
@@ -419,7 +523,7 @@ mod test {
     };
     use datafusion_common::cast::{as_float32_array, as_struct_array};
     use datafusion_common::ScalarValue;
-    use datafusion_expr::{Expr, LogicalPlanBuilder};
+    use datafusion_expr::{Cast, Expr, LogicalPlanBuilder};
     use sedona_schema::datatypes::WKB_GEOMETRY;
     use sedona_testing::create::create_array;
     use sedona_testing::data::test_geoparquet;
@@ -745,6 +849,60 @@ mod test {
             .unwrap();
     }
 
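+    // Writes GeoParquet 1.1 from a plan that sorts by an expression, the
+    // scenario from <https://github.com/apache/sedona-db/issues/379>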
+    #[tokio::test]
+    async fn geoparquet_1_1_with_sort_by_expr() {
+        let example = test_geoparquet("ns-water", "water-point");
+
+        // Requires submodules/download-assets.py which not all contributors need
+        let example = match example {
+            Ok(path) => path,
+            Err(err) => {
+                println!("ns-water/water-point is not available: {err}");
+                return;
+            }
+        };
+
+        let ctx = setup_context();
+        let fns = sedona_functions::register::default_function_set();
+
+        let geometry_udf: ScalarUDF = fns.scalar_udf("sd_format").unwrap().clone().into();
+        let bbox_udf: ScalarUDF = geoparquet_bbox_udf().into();
+
+        let df = ctx
+            .table(&example)
+            .await
+            .unwrap()
+            .sort_by(vec![geometry_udf.call(vec![col("geometry")])])
+            .unwrap()
+            .select(vec![
+                Expr::Cast(Cast::new(
+                    geometry_udf.call(vec![col("geometry")]).alias("txt").into(),
+                    DataType::Utf8View,
+                )),
+                col("geometry"),
+            ])
+            .unwrap();
+
+        let mut options = TableGeoParquetOptions::new();
+        options.geoparquet_version = GeoParquetVersion::V1_1;
+
+        let df_batches_with_bbox = df
+            .clone()
+            .select(vec![
+                col("txt"),
+                bbox_udf.call(vec![col("geometry")]).alias("bbox"),
+                col("geometry"),
+            ])
+            .unwrap()
+            .collect()
+            .await
+            .unwrap();
+
+        test_write_dataframe(ctx, df, df_batches_with_bbox, options, vec![])
+            .await
+            .unwrap();
+    }
+
     #[test]
     fn float_bbox() {
         let tester = ScalarUdfTester::new(geoparquet_bbox_udf().into(), vec![WKB_GEOMETRY]);