diff --git a/Cargo.lock b/Cargo.lock index 666d0f4..8f0f1b5 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -145,12 +145,21 @@ dependencies = [ "snap", "strum 0.27.2", "strum_macros 0.27.2", - "thiserror", + "thiserror 2.0.17", "uuid", "xz2", "zstd", ] +[[package]] +name = "approx" +version = "0.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cab112f0a86d568ea0e627cc1d6be74a1e9cd55214684db5561995f6dad897c6" +dependencies = [ + "num-traits", +] + [[package]] name = "array-init" version = "2.1.0" @@ -336,12 +345,14 @@ name = "arrow-pg" version = "0.8.1" dependencies = [ "arrow", + "arrow-schema", "async-trait", "bytes", "chrono", "datafusion", "duckdb", "futures", + "geoarrow-schema", "pgwire", "postgres-types", "rust_decimal", @@ -1921,6 +1932,39 @@ dependencies = [ "version_check", ] +[[package]] +name = "geo-traits" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2e7c353d12a704ccfab1ba8bfb1a7fe6cb18b665bf89d37f4f7890edcd260206" +dependencies = [ + "geo-types", +] + +[[package]] +name = "geo-types" +version = "0.7.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "75a4dcd69d35b2c87a7c83bce9af69fd65c9d68d3833a0ded568983928f3fc99" +dependencies = [ + "approx", + "num-traits", + "serde", +] + +[[package]] +name = "geoarrow-schema" +version = "0.6.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "02f1b18b1c9a44ecd72be02e53d6e63bbccfdc8d1765206226af227327e2be6e" +dependencies = [ + "arrow-schema", + "geo-traits", + "serde", + "serde_json", + "thiserror 1.0.69", +] + [[package]] name = "getrandom" version = "0.2.16" @@ -2602,7 +2646,7 @@ dependencies = [ "itertools", "parking_lot", "percent-encoding", - "thiserror", + "thiserror 2.0.17", "tokio", "tracing", "url", @@ -2750,7 +2794,7 @@ dependencies = [ "serde", "serde_json", "stringprep", - "thiserror", + "thiserror 2.0.17", "tokio", "tokio-rustls", "tokio-util", @@ -2836,6 +2880,7 @@ dependencies = [ "bytes", "chrono", "fallible-iterator 0.2.0", + "geo-types", "postgres-protocol", "serde_core", "serde_json", @@ -3630,13 +3675,33 @@ dependencies = [ "unicode-width 0.1.14", ] +[[package]] +name = "thiserror" +version = "1.0.69" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b6aaf5339b578ea85b50e080feb250a3e8ae8cfcdff9a461c9ec2904bc923f52" +dependencies = [ + "thiserror-impl 1.0.69", +] + [[package]] name = "thiserror" version = "2.0.17" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f63587ca0f12b72a0600bcba1d40081f830876000bb46dd2337a3051618f4fc8" dependencies = [ - "thiserror-impl", + "thiserror-impl 2.0.17", +] + +[[package]] +name = "thiserror-impl" +version = "1.0.69" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4fee6c4efc90059e10f81e6d42c60a18f76588c3d74cb83a0b242a2b6c7504c1" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.106", ] [[package]] @@ -4322,7 +4387,7 @@ dependencies = [ "ring", "signature", "spki", - "thiserror", + "thiserror 2.0.17", "zeroize", ] diff --git a/Cargo.toml b/Cargo.toml index ca07b69..4da72fc 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -14,6 +14,7 @@ documentation = "https://docs.rs/crate/datafusion-postgres/" [workspace.dependencies] arrow = "56" +arrow-schema = "56" bytes = "1.10.1" chrono = { version = "0.4", features = ["std"] } datafusion = { version = "50", default-features = false } diff --git a/arrow-pg/Cargo.toml b/arrow-pg/Cargo.toml index 92c60ca..2be4e76 100644 --- a/arrow-pg/Cargo.toml +++ b/arrow-pg/Cargo.toml @@ -13,9 +13,10 @@ readme = "../README.md" rust-version.workspace = true [features] -default = ["arrow"] +default = ["arrow", "geo"] arrow = ["dep:arrow"] datafusion = ["dep:datafusion"] +geo = ["postgres-types/with-geo-types-0_7", "dep:geoarrow-schema"] # for testing _duckdb = [] _bundled = ["duckdb/bundled"] @@ -23,6 +24,8 @@ _bundled = ["duckdb/bundled"] [dependencies] arrow = { workspace = true, optional = true } +arrow-schema = { workspace = true } +geoarrow-schema = { version = "0.6", optional = true } bytes.workspace = true chrono.workspace = true datafusion = { workspace = true, optional = true } diff --git a/arrow-pg/examples/duckdb.rs b/arrow-pg/examples/duckdb.rs index 29faa1e..7298680 100644 --- a/arrow-pg/examples/duckdb.rs +++ b/arrow-pg/examples/duckdb.rs @@ -3,6 +3,7 @@ use std::sync::{Arc, Mutex}; use arrow_pg::datatypes::arrow_schema_to_pg_fields; use arrow_pg::datatypes::encode_recordbatch; use arrow_pg::datatypes::into_pg_type; +use arrow_schema::Field; use async_trait::async_trait; use duckdb::{params, Connection, Statement, ToSql}; use futures::stream; @@ -137,11 +138,13 @@ fn row_desc_from_stmt(stmt: &Statement, format: &Format) -> PgWireResult PgWireResult { - Ok(match arrow_type { - DataType::Null => Type::UNKNOWN, - DataType::Boolean => Type::BOOL, - DataType::Int8 | DataType::UInt8 => Type::CHAR, - DataType::Int16 | DataType::UInt16 => Type::INT2, - DataType::Int32 | DataType::UInt32 => Type::INT4, - DataType::Int64 | DataType::UInt64 => Type::INT8, - DataType::Timestamp(_, tz) => { - if tz.is_some() { - Type::TIMESTAMPTZ - } else { - Type::TIMESTAMP +pub fn into_pg_type(field: &Arc) -> PgWireResult { + let arrow_type = field.data_type(); + + match field.extension_type_name() { + #[cfg(feature = "geo")] + Some(geoarrow_schema::PointType::NAME) => Ok(Type::POINT), + _ => Ok(match arrow_type { + DataType::Null => Type::UNKNOWN, + DataType::Boolean => Type::BOOL, + DataType::Int8 | DataType::UInt8 => Type::CHAR, + DataType::Int16 | DataType::UInt16 => Type::INT2, + DataType::Int32 | DataType::UInt32 => Type::INT4, + DataType::Int64 | DataType::UInt64 => Type::INT8, + DataType::Timestamp(_, tz) => { + if tz.is_some() { + Type::TIMESTAMPTZ + } else { + Type::TIMESTAMP + } } - } - DataType::Time32(_) | DataType::Time64(_) => Type::TIME, - DataType::Date32 | DataType::Date64 => Type::DATE, - DataType::Interval(_) => Type::INTERVAL, - DataType::Binary - | DataType::FixedSizeBinary(_) - | DataType::LargeBinary - | DataType::BinaryView => Type::BYTEA, - DataType::Float16 | DataType::Float32 => Type::FLOAT4, - DataType::Float64 => Type::FLOAT8, - DataType::Decimal128(_, _) => Type::NUMERIC, - DataType::Utf8 | DataType::LargeUtf8 | DataType::Utf8View => Type::TEXT, - DataType::List(field) | DataType::FixedSizeList(field, _) | DataType::LargeList(field) => { - match field.data_type() { + DataType::Time32(_) | DataType::Time64(_) => Type::TIME, + DataType::Date32 | DataType::Date64 => Type::DATE, + DataType::Interval(_) => Type::INTERVAL, + DataType::Binary + | DataType::FixedSizeBinary(_) + | DataType::LargeBinary + | DataType::BinaryView => Type::BYTEA, + DataType::Float16 | DataType::Float32 => Type::FLOAT4, + DataType::Float64 => Type::FLOAT8, + DataType::Decimal128(_, _) => Type::NUMERIC, + DataType::Utf8 | DataType::LargeUtf8 | DataType::Utf8View => Type::TEXT, + DataType::List(field) + | DataType::FixedSizeList(field, _) + | DataType::LargeList(field) + | DataType::ListView(field) + | DataType::LargeListView(field) => match field.data_type() { DataType::Boolean => Type::BOOL_ARRAY, DataType::Int8 | DataType::UInt8 => Type::CHAR_ARRAY, DataType::Int16 | DataType::UInt16 => Type::INT2_ARRAY, @@ -68,10 +77,10 @@ pub fn into_pg_type(arrow_type: &DataType) -> PgWireResult { DataType::Float16 | DataType::Float32 => Type::FLOAT4_ARRAY, DataType::Float64 => Type::FLOAT8_ARRAY, DataType::Utf8 | DataType::LargeUtf8 | DataType::Utf8View => Type::TEXT_ARRAY, - struct_type @ DataType::Struct(_) => Type::new( + DataType::Struct(_) => Type::new( Type::RECORD_ARRAY.name().into(), Type::RECORD_ARRAY.oid(), - Kind::Array(into_pg_type(struct_type)?), + Kind::Array(into_pg_type(field)?), Type::RECORD_ARRAY.schema().into(), ), list_type => { @@ -81,35 +90,42 @@ pub fn into_pg_type(arrow_type: &DataType) -> PgWireResult { format!("Unsupported List Datatype {list_type}"), )))); } + }, + DataType::Dictionary(_, value_type) => { + let field = Arc::new(Field::new( + Field::LIST_FIELD_DEFAULT_NAME, + *value_type.clone(), + true, + )); + into_pg_type(&field)? } - } - DataType::Dictionary(_, value_type) => into_pg_type(value_type)?, - DataType::Struct(fields) => { - let name: String = fields - .iter() - .map(|x| x.name().clone()) - .reduce(|a, b| a + ", " + &b) - .map(|x| format!("({x})")) - .unwrap_or("()".to_string()); - let kind = Kind::Composite( - fields + DataType::Struct(fields) => { + let name: String = fields .iter() - .map(|x| { - into_pg_type(x.data_type()) - .map(|_type| postgres_types::Field::new(x.name().clone(), _type)) - }) - .collect::, PgWireError>>()?, - ); - Type::new(name, Type::RECORD.oid(), kind, Type::RECORD.schema().into()) - } - _ => { - return Err(PgWireError::UserError(Box::new(ErrorInfo::new( - "ERROR".to_owned(), - "XX000".to_owned(), - format!("Unsupported Datatype {arrow_type}"), - )))); - } - }) + .map(|x| x.name().clone()) + .reduce(|a, b| a + ", " + &b) + .map(|x| format!("({x})")) + .unwrap_or("()".to_string()); + let kind = Kind::Composite( + fields + .iter() + .map(|x| { + into_pg_type(x) + .map(|_type| postgres_types::Field::new(x.name().clone(), _type)) + }) + .collect::, PgWireError>>()?, + ); + Type::new(name, Type::RECORD.oid(), kind, Type::RECORD.schema().into()) + } + _ => { + return Err(PgWireError::UserError(Box::new(ErrorInfo::new( + "ERROR".to_owned(), + "XX000".to_owned(), + format!("Unsupported Datatype {arrow_type}"), + )))); + } + }), + } } pub fn arrow_schema_to_pg_fields( @@ -123,7 +139,7 @@ pub fn arrow_schema_to_pg_fields( .iter() .enumerate() .map(|(idx, f)| { - let pg_type = into_pg_type(f.data_type())?; + let pg_type = into_pg_type(f)?; let mut field_info = FieldInfo::new(f.name().into(), None, None, pg_type, format.format_for(idx)); if let Some(data_format_options) = &data_format_options { diff --git a/arrow-pg/src/datatypes/df.rs b/arrow-pg/src/datatypes/df.rs index 2959128..1451502 100644 --- a/arrow-pg/src/datatypes/df.rs +++ b/arrow-pg/src/datatypes/df.rs @@ -2,7 +2,7 @@ use std::iter; use std::sync::Arc; use chrono::{DateTime, FixedOffset, NaiveDate, NaiveDateTime, NaiveTime, Timelike}; -use datafusion::arrow::datatypes::{DataType, Date32Type}; +use datafusion::arrow::datatypes::{DataType, Date32Type, Field}; use datafusion::arrow::record_batch::RecordBatch; use datafusion::common::ParamValues; use datafusion::prelude::*; @@ -70,7 +70,7 @@ where if let Some(ty) = pg_type_hint { Ok(ty.clone()) } else if let Some(infer_type) = inferenced_type { - into_pg_type(infer_type) + into_pg_type(&Arc::new(Field::new("item", infer_type.clone(), true))) } else { Ok(Type::UNKNOWN) } diff --git a/arrow-pg/src/encoder.rs b/arrow-pg/src/encoder.rs index e35fadd..b7144a0 100644 --- a/arrow-pg/src/encoder.rs +++ b/arrow-pg/src/encoder.rs @@ -284,10 +284,9 @@ pub fn encode_value( encoder: &mut T, arr: &Arc, idx: usize, + _arrow_filed: &Field, pg_field: &FieldInfo, ) -> PgWireResult<()> { - let type_ = pg_field.datatype(); - match arr.data_type() { DataType::Null => encoder.encode_field(&None::, pg_field)?, DataType::Boolean => encoder.encode_field(&get_bool_value(arr, idx), pg_field)?, @@ -422,16 +421,8 @@ pub fn encode_value( let value = encode_list(array, pg_field)?; encoder.encode_field(&value, pg_field)? } - DataType::Struct(_) => { - let fields = match type_.kind() { - postgres_types::Kind::Composite(fields) => fields, - _ => { - return Err(PgWireError::ApiError(ToSqlError::from(format!( - "Failed to unwrap a composite type from type {type_}" - )))); - } - }; - let value = encode_struct(arr, idx, fields, pg_field)?; + DataType::Struct(arrow_fields) => { + let value = encode_struct(arr, idx, arrow_fields, pg_field)?; encoder.encode_field(&value, pg_field)? } DataType::Dictionary(_, value_type) => { @@ -462,7 +453,9 @@ pub fn encode_value( )) })?; - encode_value(encoder, values, idx, pg_field)? + let inner_arrow_field = Field::new(pg_field.name(), *value_type.clone(), true); + + encode_value(encoder, values, idx, &inner_arrow_field, pg_field)? } _ => { return Err(PgWireError::ApiError(ToSqlError::from(format!( @@ -511,8 +504,9 @@ mod tests { let mut encoder = MockEncoder::default(); + let arrow_field = Field::new("x", DataType::Utf8, true); let pg_field = FieldInfo::new("x".to_string(), None, None, Type::TEXT, FieldFormat::Text); - let result = encode_value(&mut encoder, &dict_arr, 2, &pg_field); + let result = encode_value(&mut encoder, &dict_arr, 2, &arrow_field, &pg_field); assert!(result.is_ok()); diff --git a/arrow-pg/src/list_encoder.rs b/arrow-pg/src/list_encoder.rs index d7dca3d..0da22fb 100644 --- a/arrow-pg/src/list_encoder.rs +++ b/arrow-pg/src/list_encoder.rs @@ -386,27 +386,9 @@ pub(crate) fn encode_list(arr: Arc, pg_field: &FieldInfo) -> PgWireRe } } }, - DataType::Struct(_) => { - let fields = match type_.kind() { - postgres_types::Kind::Array(struct_type_) => Ok(struct_type_), - _ => Err(format!( - "Expected list type found type {} of kind {:?}", - type_, - type_.kind() - )), - } - .and_then(|struct_type| match struct_type.kind() { - postgres_types::Kind::Composite(fields) => Ok(fields), - _ => Err(format!( - "Failed to unwrap a composite type inside from type {} kind {:?}", - type_, - type_.kind() - )), - }) - .map_err(ToSqlError::from)?; - + DataType::Struct(arrow_fields) => { let values: PgWireResult> = (0..arr.len()) - .map(|row| encode_struct(&arr, row, fields, pg_field)) + .map(|row| encode_struct(&arr, row, arrow_fields, pg_field)) .map(|x| { if matches!(format, FieldFormat::Text) { x.map(|opt| { diff --git a/arrow-pg/src/row_encoder.rs b/arrow-pg/src/row_encoder.rs index 674751b..f9b0a32 100644 --- a/arrow-pg/src/row_encoder.rs +++ b/arrow-pg/src/row_encoder.rs @@ -33,12 +33,17 @@ impl RowEncoder { if self.curr_idx == self.rb.num_rows() { return None; } + let arrow_schema = self.rb.schema_ref(); let mut encoder = DataRowEncoder::new(self.fields.clone()); for col in 0..self.rb.num_columns() { let array = self.rb.column(col); - let field = &self.fields[col]; + let arrow_field = arrow_schema.field(col); + let pg_field = &self.fields[col]; - encode_value(&mut encoder, array, self.curr_idx, field).unwrap(); + if let Err(e) = encode_value(&mut encoder, array, self.curr_idx, arrow_field, pg_field) + { + return Some(Err(e)); + }; } self.curr_idx += 1; Some(encoder.finish()) diff --git a/arrow-pg/src/struct_encoder.rs b/arrow-pg/src/struct_encoder.rs index ad86b96..7db119f 100644 --- a/arrow-pg/src/struct_encoder.rs +++ b/arrow-pg/src/struct_encoder.rs @@ -2,6 +2,7 @@ use std::sync::Arc; #[cfg(not(feature = "datafusion"))] use arrow::array::{Array, StructArray}; +use arrow_schema::Fields; #[cfg(feature = "datafusion")] use datafusion::arrow::array::{Array, StructArray}; @@ -11,23 +12,33 @@ use pgwire::error::PgWireResult; use pgwire::types::{ToSqlText, QUOTE_CHECK, QUOTE_ESCAPE}; use postgres_types::{Field, IsNull, ToSql}; +use crate::datatypes::into_pg_type; use crate::encoder::{encode_value, EncodedValue, Encoder}; pub(crate) fn encode_struct( arr: &Arc, idx: usize, - fields: &[Field], + arrow_fields: &Fields, parent_pg_field_info: &FieldInfo, ) -> PgWireResult> { let arr = arr.as_any().downcast_ref::().unwrap(); if arr.is_null(idx) { return Ok(None); } - let mut row_encoder = StructEncoder::new(fields.len()); + + let fields = arrow_fields + .iter() + .map(|f| into_pg_type(f).map(|t| Field::new(f.name().to_owned(), t))) + .collect::>>()?; + + let mut row_encoder = StructEncoder::new(arrow_fields.len()); + for (i, arr) in arr.columns().iter().enumerate() { let field = &fields[i]; let type_ = field.type_(); + let arrow_field = &arrow_fields[i]; + let mut pg_field = FieldInfo::new( field.name().to_string(), None, @@ -35,10 +46,9 @@ pub(crate) fn encode_struct( type_.clone(), parent_pg_field_info.format(), ); - pg_field = pg_field.with_format_options(parent_pg_field_info.format_options().clone()); - encode_value(&mut row_encoder, arr, idx, &pg_field).unwrap(); + encode_value(&mut row_encoder, arr, idx, arrow_field, &pg_field).unwrap(); } Ok(Some(EncodedValue { bytes: row_encoder.row_buffer, diff --git a/datafusion-postgres/src/handlers.rs b/datafusion-postgres/src/handlers.rs index 7da8fb4..2a9aa9f 100644 --- a/datafusion-postgres/src/handlers.rs +++ b/datafusion-postgres/src/handlers.rs @@ -396,7 +396,8 @@ impl ExtendedQueryHandler for DfSessionService { for param_type in ordered_param_types(¶ms).iter() { // Fixed: Use ¶ms if let Some(datatype) = param_type { - let pgtype = into_pg_type(datatype)?; + let pgtype = + into_pg_type(&Arc::new(Field::new("item", (*datatype).clone(), true)))?; param_types.push(pgtype); } else { param_types.push(Type::UNKNOWN);