Skip to content

Commit b01ddde

Browse files
committed
Update FFI export signatures to pass the logical extension codec around
1 parent bea73b1 commit b01ddde

File tree

14 files changed

+193
-164
lines changed

14 files changed

+193
-164
lines changed

examples/datafusion-ffi-example/python/tests/_test_catalog_provider.py

Lines changed: 22 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -19,19 +19,14 @@
1919

2020
import pyarrow as pa
2121
from datafusion import SessionContext
22-
from datafusion_ffi_example import MyCatalogProvider
22+
from datafusion_ffi_example import FixedSchemaProvider, MyCatalogProvider
2323

2424

25-
def test_catalog_provider():
26-
ctx = SessionContext()
27-
28-
my_catalog_name = "my_catalog"
25+
def common_checks(ctx: SessionContext, my_catalog_name: str) -> None:
2926
expected_schema_name = "my_schema"
3027
expected_table_name = "my_table"
3128
expected_table_columns = ["units", "price"]
3229

33-
catalog_provider = MyCatalogProvider(ctx)
34-
ctx.register_catalog_provider(my_catalog_name, catalog_provider)
3530
my_catalog = ctx.catalog(my_catalog_name)
3631

3732
my_catalog_schemas = my_catalog.names()
@@ -58,3 +53,23 @@ def test_catalog_provider():
5853
]
5954
assert col0_result == expected_col0
6055
assert col1_result == expected_col1
56+
57+
58+
def test_catalog_provider():
59+
ctx = SessionContext()
60+
61+
my_catalog_name = "my_catalog"
62+
63+
catalog_provider = MyCatalogProvider()
64+
ctx.register_catalog_provider(my_catalog_name, catalog_provider)
65+
common_checks(ctx, my_catalog_name)
66+
67+
68+
def test_schema_provider():
69+
ctx = SessionContext()
70+
71+
my_schema_name = "my_schema"
72+
73+
schema_provider = FixedSchemaProvider()
74+
ctx.catalog().register_schema(my_schema_name, schema_provider)
75+
common_checks(ctx, "datafusion")

examples/datafusion-ffi-example/src/catalog_provider.rs

Lines changed: 40 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ use datafusion_catalog::{
2727
};
2828
use datafusion_common::error::{DataFusionError, Result};
2929
use datafusion_ffi::catalog_provider::FFI_CatalogProvider;
30-
use datafusion_ffi::proto::logical_extension_codec::FFI_LogicalExtensionCodec;
30+
use datafusion_ffi::schema_provider::FFI_SchemaProvider;
3131
use pyo3::types::PyCapsule;
3232
use pyo3::{pyclass, pymethods, Bound, PyAny, PyResult, Python};
3333

@@ -58,14 +58,19 @@ pub fn my_table() -> Arc<dyn TableProvider + 'static> {
5858
Arc::new(MemTable::try_new(schema, vec![partitions]).unwrap())
5959
}
6060

61+
#[pyclass(
62+
name = "FixedSchemaProvider",
63+
module = "datafusion_ffi_example",
64+
subclass
65+
)]
6166
#[derive(Debug)]
6267
pub struct FixedSchemaProvider {
63-
inner: MemorySchemaProvider,
68+
inner: Arc<MemorySchemaProvider>,
6469
}
6570

6671
impl Default for FixedSchemaProvider {
6772
fn default() -> Self {
68-
let inner = MemorySchemaProvider::new();
73+
let inner = Arc::new(MemorySchemaProvider::new());
6974

7075
let table = my_table();
7176

@@ -75,6 +80,29 @@ impl Default for FixedSchemaProvider {
7580
}
7681
}
7782

83+
#[pymethods]
84+
impl FixedSchemaProvider {
85+
#[new]
86+
pub fn new() -> Self {
87+
Self::default()
88+
}
89+
90+
pub fn __datafusion_schema_provider__<'py>(
91+
&self,
92+
py: Python<'py>,
93+
session: Bound<PyAny>,
94+
) -> PyResult<Bound<'py, PyCapsule>> {
95+
let name = cr"datafusion_schema_provider".into();
96+
97+
let provider = Arc::clone(&self.inner) as Arc<dyn SchemaProvider + Send>;
98+
99+
let codec = ffi_logical_codec_from_pycapsule(session)?;
100+
let provider = FFI_SchemaProvider::new_with_ffi_codec(provider, None, codec);
101+
102+
PyCapsule::new(py, provider, Some(name))
103+
}
104+
}
105+
78106
#[async_trait]
79107
impl SchemaProvider for FixedSchemaProvider {
80108
fn as_any(&self) -> &dyn Any {
@@ -116,7 +144,6 @@ impl SchemaProvider for FixedSchemaProvider {
116144
#[derive(Debug, Clone)]
117145
pub(crate) struct MyCatalogProvider {
118146
inner: Arc<MemoryCatalogProvider>,
119-
logical_codec: FFI_LogicalExtensionCodec,
120147
}
121148

122149
impl CatalogProvider for MyCatalogProvider {
@@ -152,28 +179,27 @@ impl CatalogProvider for MyCatalogProvider {
152179
#[pymethods]
153180
impl MyCatalogProvider {
154181
#[new]
155-
pub fn new(session: &Bound<PyAny>) -> PyResult<Self> {
156-
let logical_codec = ffi_logical_codec_from_pycapsule(session)?;
182+
pub fn new() -> PyResult<Self> {
157183
let inner = Arc::new(MemoryCatalogProvider::new());
158184

159185
let schema_name: &str = "my_schema";
160186
let _ = inner.register_schema(schema_name, Arc::new(FixedSchemaProvider::default()));
161187

162-
Ok(Self {
163-
inner,
164-
logical_codec,
165-
})
188+
Ok(Self { inner })
166189
}
167190

168191
pub fn __datafusion_catalog_provider__<'py>(
169192
&self,
170193
py: Python<'py>,
194+
session: Bound<PyAny>,
171195
) -> PyResult<Bound<'py, PyCapsule>> {
172196
let name = cr"datafusion_catalog_provider".into();
173-
let codec = self.logical_codec.clone();
174-
let catalog_provider =
175-
FFI_CatalogProvider::new_with_ffi_codec(Arc::new(self.clone()), None, codec);
176197

177-
PyCapsule::new(py, catalog_provider, Some(name))
198+
let provider = Arc::clone(&self.inner) as Arc<dyn CatalogProvider + Send>;
199+
200+
let codec = ffi_logical_codec_from_pycapsule(session)?;
201+
let provider = FFI_CatalogProvider::new_with_ffi_codec(provider, None, codec);
202+
203+
PyCapsule::new(py, provider, Some(name))
178204
}
179205
}

examples/datafusion-ffi-example/src/lib.rs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
use pyo3::prelude::*;
1919

2020
use crate::aggregate_udf::MySumUDF;
21-
use crate::catalog_provider::MyCatalogProvider;
21+
use crate::catalog_provider::{FixedSchemaProvider, MyCatalogProvider};
2222
use crate::scalar_udf::IsNullUDF;
2323
use crate::table_function::MyTableFunction;
2424
use crate::table_provider::MyTableProvider;
@@ -37,6 +37,7 @@ fn datafusion_ffi_example(m: &Bound<'_, PyModule>) -> PyResult<()> {
3737
m.add_class::<MyTableProvider>()?;
3838
m.add_class::<MyTableFunction>()?;
3939
m.add_class::<MyCatalogProvider>()?;
40+
m.add_class::<FixedSchemaProvider>()?;
4041
m.add_class::<IsNullUDF>()?;
4142
m.add_class::<MySumUDF>()?;
4243
m.add_class::<MyRankUDF>()?;

examples/datafusion-ffi-example/src/table_function.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ impl MyTableFunction {
4141
fn __datafusion_table_function__<'py>(
4242
&self,
4343
py: Python<'py>,
44-
session: &Bound<PyAny>,
44+
session: Bound<PyAny>,
4545
) -> PyResult<Bound<'py, PyCapsule>> {
4646
let name = cr"datafusion_table_function".into();
4747

examples/datafusion-ffi-example/src/table_provider.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -93,13 +93,13 @@ impl MyTableProvider {
9393
pub fn __datafusion_table_provider__<'py>(
9494
&self,
9595
py: Python<'py>,
96-
session: &Bound<PyAny>,
96+
session: Bound<PyAny>,
9797
) -> PyResult<Bound<'py, PyCapsule>> {
9898
let name = cr"datafusion_table_provider".into();
9999

100100
let provider = self
101101
.create_table()
102-
.map_err(|e| PyRuntimeError::new_err(e.to_string()))?;
102+
.map_err(|e: DataFusionError| PyRuntimeError::new_err(e.to_string()))?;
103103

104104
let codec = ffi_logical_codec_from_pycapsule(session)?;
105105
let provider =

examples/datafusion-ffi-example/src/utils.rs

Lines changed: 10 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -5,23 +5,21 @@ use pyo3::types::PyCapsule;
55
use pyo3::{Bound, PyAny, PyResult};
66

77
pub(crate) fn ffi_logical_codec_from_pycapsule(
8-
obj: &Bound<PyAny>,
8+
obj: Bound<PyAny>,
99
) -> PyResult<FFI_LogicalExtensionCodec> {
1010
let attr_name = "__datafusion_logical_extension_codec__";
11+
let capsule = if obj.hasattr(attr_name)? {
12+
obj.getattr(attr_name)?.call0()?
13+
} else {
14+
obj
15+
};
1116

12-
if obj.hasattr(attr_name)? {
13-
let capsule = obj.getattr(attr_name)?.call0()?;
14-
let capsule = capsule.downcast::<PyCapsule>()?;
15-
validate_pycapsule(capsule, "datafusion_logical_extension_codec")?;
17+
let capsule = capsule.downcast::<PyCapsule>()?;
18+
validate_pycapsule(capsule, "datafusion_logical_extension_codec")?;
1619

17-
let provider = unsafe { capsule.reference::<FFI_LogicalExtensionCodec>() };
20+
let codec = unsafe { capsule.reference::<FFI_LogicalExtensionCodec>() };
1821

19-
Ok(provider.clone())
20-
} else {
21-
Err(PyValueError::new_err(
22-
"Expected PyCapsule object for FFI_LogicalExtensionCodec, but attribute does not exist",
23-
))
24-
}
22+
Ok(codec.clone())
2523
}
2624

2725
pub(crate) fn validate_pycapsule(capsule: &Bound<PyCapsule>, name: &str) -> PyResult<()> {

python/datafusion/catalog.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -273,4 +273,4 @@ class SchemaProviderExportable(Protocol):
273273
https://docs.rs/datafusion/latest/datafusion/catalog/trait.SchemaProvider.html
274274
"""
275275

276-
def __datafusion_schema_provider__(self) -> object: ...
276+
def __datafusion_schema_provider__(self, session: Any) -> object: ...

python/datafusion/context.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,7 @@ class TableProviderExportable(Protocol):
8888
https://datafusion.apache.org/python/user-guide/io/table_provider.html
8989
"""
9090

91-
def __datafusion_table_provider__(self) -> object: ... # noqa: D105
91+
def __datafusion_table_provider__(self, session: Any) -> object: ... # noqa: D105
9292

9393

9494
class CatalogProviderExportable(Protocol):
@@ -97,7 +97,7 @@ class CatalogProviderExportable(Protocol):
9797
https://docs.rs/datafusion/latest/datafusion/catalog/trait.CatalogProvider.html
9898
"""
9999

100-
def __datafusion_catalog_provider__(self) -> object: ... # noqa: D105
100+
def __datafusion_catalog_provider__(self, session: Any) -> object: ... # noqa: D105
101101

102102

103103
class SessionConfig:

src/catalog.rs

Lines changed: 40 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,10 @@ use pyo3::IntoPyObjectExt;
3535
use crate::dataset::Dataset;
3636
use crate::errors::{py_datafusion_err, to_datafusion_err, PyDataFusionError, PyDataFusionResult};
3737
use crate::table::PyTable;
38-
use crate::utils::{extract_logical_extension_codec, validate_pycapsule, wait_for_future};
38+
use crate::utils::{
39+
create_logical_extension_capsule, extract_logical_extension_codec, validate_pycapsule,
40+
wait_for_future,
41+
};
3942

4043
#[pyclass(frozen, name = "RawCatalog", module = "datafusion.catalog", subclass)]
4144
#[derive(Clone)]
@@ -111,23 +114,7 @@ impl PyCatalog {
111114
}
112115

113116
pub fn register_schema(&self, name: &str, schema_provider: Bound<'_, PyAny>) -> PyResult<()> {
114-
let provider = if schema_provider.hasattr("__datafusion_schema_provider__")? {
115-
let capsule = schema_provider
116-
.getattr("__datafusion_schema_provider__")?
117-
.call0()?;
118-
let capsule = capsule.downcast::<PyCapsule>().map_err(py_datafusion_err)?;
119-
validate_pycapsule(capsule, "datafusion_schema_provider")?;
120-
121-
let provider = unsafe { capsule.reference::<FFI_SchemaProvider>() };
122-
let provider: Arc<dyn SchemaProvider + Send> = provider.into();
123-
provider as Arc<dyn SchemaProvider>
124-
} else {
125-
match schema_provider.extract::<PySchema>() {
126-
Ok(py_schema) => py_schema.schema,
127-
Err(_) => Arc::new(RustWrappedPySchemaProvider::new(schema_provider.into()))
128-
as Arc<dyn SchemaProvider>,
129-
}
130-
};
117+
let provider = extract_schema_provider_from_pyobj(schema_provider, self.codec.as_ref())?;
131118

132119
let _ = self
133120
.catalog
@@ -195,14 +182,11 @@ impl PySchema {
195182
Ok(format!("Schema(table_names=[{}])", names.join(";")))
196183
}
197184

198-
fn register_table(&self, name: &str, table_provider: &Bound<'_, PyAny>) -> PyResult<()> {
185+
fn register_table(&self, name: &str, table_provider: Bound<'_, PyAny>) -> PyResult<()> {
199186
let py = table_provider.py();
200-
let codec_capsule = PyCapsule::new(
201-
py,
202-
self.codec.clone(),
203-
Some(cr"datafusion_logical_extension_codec".into()),
204-
)?
205-
.into_bound_py_any(py)?;
187+
let codec_capsule = create_logical_extension_capsule(py, self.codec.as_ref())?
188+
.as_any()
189+
.clone();
206190

207191
let table = PyTable::new(table_provider, Some(codec_capsule))?;
208192

@@ -256,7 +240,7 @@ impl RustWrappedPySchemaProvider {
256240
return Ok(None);
257241
}
258242

259-
let table = PyTable::new(&py_table, None)?;
243+
let table = PyTable::new(py_table, None)?;
260244

261245
Ok(Some(table.table))
262246
})
@@ -370,32 +354,7 @@ impl RustWrappedPyCatalogProvider {
370354
return Ok(None);
371355
}
372356

373-
if py_schema.hasattr("__datafusion_schema_provider__")? {
374-
let capsule = provider
375-
.getattr("__datafusion_schema_provider__")?
376-
.call0()?;
377-
let capsule = capsule.downcast::<PyCapsule>().map_err(py_datafusion_err)?;
378-
validate_pycapsule(capsule, "datafusion_schema_provider")?;
379-
380-
let provider = unsafe { capsule.reference::<FFI_SchemaProvider>() };
381-
let provider: Arc<dyn SchemaProvider + Send> = provider.into();
382-
383-
Ok(Some(provider as Arc<dyn SchemaProvider>))
384-
} else {
385-
if let Ok(inner_schema) = py_schema.getattr("schema") {
386-
if let Ok(inner_schema) = inner_schema.extract::<PySchema>() {
387-
return Ok(Some(inner_schema.schema));
388-
}
389-
}
390-
match py_schema.extract::<PySchema>() {
391-
Ok(inner_schema) => Ok(Some(inner_schema.schema)),
392-
Err(_) => {
393-
let py_schema = RustWrappedPySchemaProvider::new(py_schema.into());
394-
395-
Ok(Some(Arc::new(py_schema) as Arc<dyn SchemaProvider>))
396-
}
397-
}
398-
}
357+
extract_schema_provider_from_pyobj(py_schema, self.codec.as_ref()).map(Some)
399358
})
400359
}
401360
}
@@ -479,6 +438,35 @@ impl CatalogProvider for RustWrappedPyCatalogProvider {
479438
}
480439
}
481440

441+
fn extract_schema_provider_from_pyobj(
442+
mut schema_provider: Bound<PyAny>,
443+
codec: &FFI_LogicalExtensionCodec,
444+
) -> PyResult<Arc<dyn SchemaProvider>> {
445+
if schema_provider.hasattr("__datafusion_schema_provider__")? {
446+
let py = schema_provider.py();
447+
let codec_capsule = create_logical_extension_capsule(py, codec)?;
448+
schema_provider = schema_provider
449+
.getattr("__datafusion_schema_provider__")?
450+
.call1((codec_capsule,))?;
451+
}
452+
453+
let provider = if let Ok(capsule) = schema_provider.downcast::<PyCapsule>() {
454+
validate_pycapsule(capsule, "datafusion_schema_provider")?;
455+
456+
let provider = unsafe { capsule.reference::<FFI_SchemaProvider>() };
457+
let provider: Arc<dyn SchemaProvider + Send> = provider.into();
458+
provider as Arc<dyn SchemaProvider>
459+
} else {
460+
match schema_provider.extract::<PySchema>() {
461+
Ok(py_schema) => py_schema.schema,
462+
Err(_) => Arc::new(RustWrappedPySchemaProvider::new(schema_provider.into()))
463+
as Arc<dyn SchemaProvider>,
464+
}
465+
};
466+
467+
Ok(provider)
468+
}
469+
482470
pub(crate) fn init_module(m: &Bound<'_, PyModule>) -> PyResult<()> {
483471
m.add_class::<PyCatalog>()?;
484472
m.add_class::<PySchema>()?;

0 commit comments

Comments
 (0)