
Commit 69f89a9

fix(rust/sedona-geoparquet): Don't use ProjectionExec to create GeoParquet 1.1 bounding box columns (#398)

1 parent: 6258163

File tree: 1 file changed (+176 −18 lines)
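For context on the fix below: GeoParquet 1.1 defines a per-row bounding box "covering" column stored as a struct of min/max coordinates, which is what this writer produces. A minimal arrow-rs sketch of one such column, assuming the writer's Float32 representation; the helper name and single-row data are illustrative, not from this commit:

use std::sync::Arc;

use arrow_array::{builder::Float32Builder, ArrayRef, StructArray};
use arrow_schema::{DataType, Field, Fields};

/// Illustrative only: one bbox row for a geometry spanning [0, 10] x [20, 30].
fn example_bbox_column() -> StructArray {
    let (mut xmin, mut ymin) = (Float32Builder::new(), Float32Builder::new());
    let (mut xmax, mut ymax) = (Float32Builder::new(), Float32Builder::new());
    xmin.append_value(0.0);
    ymin.append_value(20.0);
    xmax.append_value(10.0);
    ymax.append_value(30.0);

    // GeoParquet 1.1 names the struct fields xmin/ymin/xmax/ymax
    let fields = Fields::from(vec![
        Field::new("xmin", DataType::Float32, true),
        Field::new("ymin", DataType::Float32, true),
        Field::new("xmax", DataType::Float32, true),
        Field::new("ymax", DataType::Float32, true),
    ]);
    let columns: Vec<ArrayRef> = vec![
        Arc::new(xmin.finish()),
        Arc::new(ymin.finish()),
        Arc::new(xmax.finish()),
        Arc::new(ymax.finish()),
    ];
    StructArray::new(fields, columns, None)
}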

rust/sedona-geoparquet/src/writer.rs (176 additions & 18 deletions)
@@ -15,28 +15,35 @@
 // specific language governing permissions and limitations
 // under the License.

-use std::{collections::HashMap, sync::Arc};
+use std::{any::Any, collections::HashMap, fmt, sync::Arc};

 use arrow_array::{
     builder::{Float32Builder, NullBufferBuilder},
-    ArrayRef, StructArray,
+    ArrayRef, RecordBatch, StructArray,
 };
-use arrow_schema::{DataType, Field, Fields};
+use arrow_schema::{DataType, Field, Fields, Schema, SchemaRef};
+use async_trait::async_trait;
 use datafusion::{
     config::TableParquetOptions,
     datasource::{
-        file_format::parquet::ParquetSink, physical_plan::FileSinkConfig, sink::DataSinkExec,
+        file_format::parquet::ParquetSink,
+        physical_plan::FileSinkConfig,
+        sink::{DataSink, DataSinkExec},
     },
 };
 use datafusion_common::{
     config::ConfigOptions, exec_datafusion_err, exec_err, not_impl_err, DataFusionError, Result,
 };
+use datafusion_execution::{SendableRecordBatchStream, TaskContext};
 use datafusion_expr::{dml::InsertOp, ColumnarValue, ScalarUDF, Volatility};
 use datafusion_physical_expr::{
     expressions::Column, LexRequirement, PhysicalExpr, ScalarFunctionExpr,
 };
-use datafusion_physical_plan::{projection::ProjectionExec, ExecutionPlan};
+use datafusion_physical_plan::{
+    stream::RecordBatchStreamAdapter, DisplayAs, DisplayFormatType, ExecutionPlan,
+};
 use float_next_after::NextAfter;
+use futures::StreamExt;
 use geo_traits::GeometryTrait;
 use sedona_common::sedona_internal_err;
 use sedona_expr::scalar_udf::{SedonaScalarKernel, SedonaScalarUDF};
@@ -58,7 +65,7 @@ use crate::{
 };

 pub fn create_geoparquet_writer_physical_plan(
-    mut input: Arc<dyn ExecutionPlan>,
+    input: Arc<dyn ExecutionPlan>,
     mut conf: FileSinkConfig,
     order_requirements: Option<LexRequirement>,
     options: &TableGeoParquetOptions,
@@ -76,6 +83,8 @@
     // We have geometry and/or geography! Collect the GeoParquetMetadata we'll need to write
     let mut metadata = GeoParquetMetadata::default();
     let mut bbox_columns = HashMap::new();
+    let mut bbox_projection = None;
+    let mut parquet_output_schema = conf.output_schema().clone();

     // Check the version
     match options.geoparquet_version {
@@ -84,9 +93,10 @@
         }
         GeoParquetVersion::V1_1 => {
             metadata.version = "1.1.0".to_string();
-            (input, bbox_columns) = project_bboxes(input, options.overwrite_bbox_columns)?;
-            conf.output_schema = input.schema();
-            output_geometry_column_indices = input.schema().geometry_column_indices()?;
+            (bbox_projection, bbox_columns) =
+                project_bboxes(&input, options.overwrite_bbox_columns)?;
+            parquet_output_schema = compute_final_schema(&bbox_projection, &input.schema())?;
+            output_geometry_column_indices = conf.output_schema.geometry_column_indices()?;
         }
         _ => {
             return not_impl_err!(
@@ -168,10 +178,78 @@
     );

     // Create the sink
-    let sink = Arc::new(ParquetSink::new(conf, parquet_options));
+    let sink_input_schema = conf.output_schema;
+    conf.output_schema = parquet_output_schema.clone();
+    let sink = Arc::new(GeoParquetSink {
+        inner: ParquetSink::new(conf, parquet_options),
+        projection: bbox_projection,
+        sink_input_schema,
+        parquet_output_schema,
+    });
     Ok(Arc::new(DataSinkExec::new(input, sink, order_requirements)) as _)
 }

+/// Implementation of [DataSink] that computes GeoParquet 1.1 bbox columns
+/// if needed. This is used instead of a ProjectionExec because DataFusion's
+/// optimizer rules seem to rearrange the projection in ways that cause
+/// the plan to fail <https://github.com/apache/sedona-db/issues/379>.
+#[derive(Debug)]
+struct GeoParquetSink {
+    inner: ParquetSink,
+    projection: Option<Vec<(Arc<dyn PhysicalExpr>, String)>>,
+    sink_input_schema: SchemaRef,
+    parquet_output_schema: SchemaRef,
+}
+
+impl DisplayAs for GeoParquetSink {
+    fn fmt_as(&self, t: DisplayFormatType, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+        self.inner.fmt_as(t, f)
+    }
+}
+
+#[async_trait]
+impl DataSink for GeoParquetSink {
+    fn as_any(&self) -> &dyn Any {
+        self
+    }
+
+    fn schema(&self) -> &SchemaRef {
+        &self.sink_input_schema
+    }
+
+    async fn write_all(
+        &self,
+        data: SendableRecordBatchStream,
+        context: &Arc<TaskContext>,
+    ) -> Result<u64> {
+        if let Some(projection) = &self.projection {
+            // If we have a projection, apply it here
+            let schema = self.parquet_output_schema.clone();
+            let projection = projection.clone();
+
+            let data = Box::pin(RecordBatchStreamAdapter::new(
+                schema.clone(),
+                data.map(move |batch_result| {
+                    let schema = schema.clone();
+
+                    batch_result.and_then(|batch| {
+                        let mut columns = Vec::with_capacity(projection.len());
+                        for (expr, _) in &projection {
+                            let col = expr.evaluate(&batch)?;
+                            columns.push(col.into_array(batch.num_rows())?);
+                        }
+                        Ok(RecordBatch::try_new(schema.clone(), columns)?)
+                    })
+                }),
+            ));
+
+            self.inner.write_all(data, context).await
+        } else {
+            self.inner.write_all(data, context).await
+        }
+    }
+}
+
 /// Create a regular Parquet writer like DataFusion would otherwise do.
 fn create_inner_writer(
     input: Arc<dyn ExecutionPlan>,
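The core of the new sink above is a common DataFusion pattern: wrap the incoming SendableRecordBatchStream in a RecordBatchStreamAdapter and re-project each batch before delegating to the inner ParquetSink. Isolated as a free function for readability — a sketch restating the hunk above, where `project_stream` is a hypothetical name, not part of the commit:

use std::sync::Arc;

use arrow_array::RecordBatch;
use arrow_schema::SchemaRef;
use datafusion_common::Result;
use datafusion_execution::SendableRecordBatchStream;
use datafusion_physical_expr::PhysicalExpr;
use datafusion_physical_plan::stream::RecordBatchStreamAdapter;
use futures::StreamExt;

/// Evaluate `exprs` against every batch of `input`, yielding a stream
/// with `projected_schema` that a downstream sink can consume unchanged.
fn project_stream(
    input: SendableRecordBatchStream,
    exprs: Vec<(Arc<dyn PhysicalExpr>, String)>,
    projected_schema: SchemaRef,
) -> SendableRecordBatchStream {
    let schema = projected_schema.clone();
    Box::pin(RecordBatchStreamAdapter::new(
        projected_schema,
        input.map(move |batch_result| {
            let schema = schema.clone();
            batch_result.and_then(|batch| {
                // Evaluate each projection expression into a column
                let columns = exprs
                    .iter()
                    .map(|(expr, _)| expr.evaluate(&batch)?.into_array(batch.num_rows()))
                    .collect::<Result<Vec<_>>>()?;
                Ok(RecordBatch::try_new(schema, columns)?)
            })
        }),
    ))
}

Because the projection happens inside the sink rather than as a ProjectionExec node, the optimizer never sees it and cannot rearrange it, which is exactly the failure mode described in issue #379.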
@@ -184,6 +262,11 @@ fn create_inner_writer(
     Ok(Arc::new(DataSinkExec::new(input, sink, order_requirements)) as _)
 }

+type ProjectBboxesResult = (
+    Option<Vec<(Arc<dyn PhysicalExpr>, String)>>,
+    HashMap<String, String>,
+);
+
 /// Create a projection that inserts a bbox column for every geometry column
 ///
 /// This implements creating the GeoParquet 1.1 bounding box columns,
@@ -206,9 +289,9 @@ fn create_inner_writer(
 /// "some_col_bbox", it is unlikely that replacing it would have unintended
 /// consequences.
 fn project_bboxes(
-    input: Arc<dyn ExecutionPlan>,
+    input: &Arc<dyn ExecutionPlan>,
     overwrite_bbox_columns: bool,
-) -> Result<(Arc<dyn ExecutionPlan>, HashMap<String, String>)> {
+) -> Result<ProjectBboxesResult> {
     let input_schema = input.schema();
     let matcher = ArgMatcher::is_geometry();
     let bbox_udf: Arc<ScalarUDF> = Arc::new(geoparquet_bbox_udf().into());
@@ -245,7 +328,7 @@
     // If we don't need to create any bbox columns, don't add an additional
     // projection at the end of the input plan
     if bbox_exprs.is_empty() {
-        return Ok((input, HashMap::new()));
+        return Ok((None, HashMap::new()));
     }

     // Create the projection expressions
@@ -275,13 +358,34 @@ Use overwrite_bbox_columns = True if this is what was intended.",
         exprs.push((column, f.name().clone()));
     }

-    // Create the projection
-    let exec = ProjectionExec::try_new(exprs, input)?;
-
     // Flip the bbox_column_names into the form our caller needs it
     let bbox_column_names_by_field = bbox_column_names.drain().map(|(k, v)| (v, k)).collect();

-    Ok((Arc::new(exec), bbox_column_names_by_field))
+    Ok((Some(exprs), bbox_column_names_by_field))
+}
+
+fn compute_final_schema(
+    bbox_projection: &Option<Vec<(Arc<dyn PhysicalExpr>, String)>>,
+    initial_schema: &SchemaRef,
+) -> Result<SchemaRef> {
+    if let Some(bbox_projection) = bbox_projection {
+        let new_fields = bbox_projection
+            .iter()
+            .map(|(expr, name)| -> Result<Field> {
+                let return_field_ref = expr.return_field(initial_schema)?;
+                Ok(Field::new(
+                    name,
+                    return_field_ref.data_type().clone(),
+                    return_field_ref.is_nullable(),
+                )
+                .with_metadata(return_field_ref.metadata().clone()))
+            })
+            .collect::<Result<Vec<_>>>()?;
+
+        Ok(Arc::new(Schema::new(new_fields)))
+    } else {
+        Ok(initial_schema.clone())
+    }
 }

 fn geoparquet_bbox_udf() -> SedonaScalarUDF {
@@ -419,7 +523,7 @@ mod test {
     };
     use datafusion_common::cast::{as_float32_array, as_struct_array};
    use datafusion_common::ScalarValue;
-    use datafusion_expr::{Expr, LogicalPlanBuilder};
+    use datafusion_expr::{Cast, Expr, LogicalPlanBuilder};
     use sedona_schema::datatypes::WKB_GEOMETRY;
     use sedona_testing::create::create_array;
     use sedona_testing::data::test_geoparquet;
@@ -745,6 +849,60 @@ mod test {
         .unwrap();
     }

+    #[tokio::test]
+    async fn geoparquet_1_1_with_sort_by_expr() {
+        let example = test_geoparquet("ns-water", "water-point");
+
+        // Requires submodules/download-assets.py which not all contributors need
+        let example = match example {
+            Ok(path) => path,
+            Err(err) => {
+                println!("ns-water/water-point is not available: {err}");
+                return;
+            }
+        };
+
+        let ctx = setup_context();
+        let fns = sedona_functions::register::default_function_set();
+
+        let geometry_udf: ScalarUDF = fns.scalar_udf("sd_format").unwrap().clone().into();
+        let bbox_udf: ScalarUDF = geoparquet_bbox_udf().into();
+
+        let df = ctx
+            .table(&example)
+            .await
+            .unwrap()
+            .sort_by(vec![geometry_udf.call(vec![col("geometry")])])
+            .unwrap()
+            .select(vec![
+                Expr::Cast(Cast::new(
+                    geometry_udf.call(vec![col("geometry")]).alias("txt").into(),
+                    DataType::Utf8View,
+                )),
+                col("geometry"),
+            ])
+            .unwrap();
+
+        let mut options = TableGeoParquetOptions::new();
+        options.geoparquet_version = GeoParquetVersion::V1_1;
+
+        let df_batches_with_bbox = df
+            .clone()
+            .select(vec![
+                col("txt"),
+                bbox_udf.call(vec![col("geometry")]).alias("bbox"),
+                col("geometry"),
+            ])
+            .unwrap()
+            .collect()
+            .await
+            .unwrap();
+
+        test_write_dataframe(ctx, df, df_batches_with_bbox, options, vec![])
+            .await
+            .unwrap();
+    }
+
     #[test]
     fn float_bbox() {
         let tester = ScalarUdfTester::new(geoparquet_bbox_udf().into(), vec![WKB_GEOMETRY]);
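A note on the float_bbox test above: the writer stores bbox values as f32 (hence Float32Builder) while coordinates are f64, and the float_next_after import appears to exist so that narrowed bounds can be nudged outward rather than accidentally shrinking the box. A hedged sketch of that outward rounding for a lower bound; the helper name is illustrative and the real logic lives in the bbox UDF:

use float_next_after::NextAfter;

/// Narrow an f64 lower bound to f32 without ever exceeding the true value,
/// so the resulting bbox stays conservative. Illustrative sketch only.
fn round_min_down(value: f64) -> f32 {
    let narrowed = value as f32;
    if f64::from(narrowed) <= value {
        narrowed
    } else {
        // Narrowing rounded up past the true bound: step one ULP toward -inf
        narrowed.next_after(f32::NEG_INFINITY)
    }
}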
