@@ -35,7 +35,10 @@ use pyo3::IntoPyObjectExt;
3535use crate :: dataset:: Dataset ;
3636use crate :: errors:: { py_datafusion_err, to_datafusion_err, PyDataFusionError , PyDataFusionResult } ;
3737use crate :: table:: PyTable ;
38- use crate :: utils:: { extract_logical_extension_codec, validate_pycapsule, wait_for_future} ;
38+ use crate :: utils:: {
39+ create_logical_extension_capsule, extract_logical_extension_codec, validate_pycapsule,
40+ wait_for_future,
41+ } ;
3942
4043#[ pyclass( frozen, name = "RawCatalog" , module = "datafusion.catalog" , subclass) ]
4144#[ derive( Clone ) ]
@@ -111,23 +114,7 @@ impl PyCatalog {
111114 }
112115
113116 pub fn register_schema ( & self , name : & str , schema_provider : Bound < ' _ , PyAny > ) -> PyResult < ( ) > {
114- let provider = if schema_provider. hasattr ( "__datafusion_schema_provider__" ) ? {
115- let capsule = schema_provider
116- . getattr ( "__datafusion_schema_provider__" ) ?
117- . call0 ( ) ?;
118- let capsule = capsule. downcast :: < PyCapsule > ( ) . map_err ( py_datafusion_err) ?;
119- validate_pycapsule ( capsule, "datafusion_schema_provider" ) ?;
120-
121- let provider = unsafe { capsule. reference :: < FFI_SchemaProvider > ( ) } ;
122- let provider: Arc < dyn SchemaProvider + Send > = provider. into ( ) ;
123- provider as Arc < dyn SchemaProvider >
124- } else {
125- match schema_provider. extract :: < PySchema > ( ) {
126- Ok ( py_schema) => py_schema. schema ,
127- Err ( _) => Arc :: new ( RustWrappedPySchemaProvider :: new ( schema_provider. into ( ) ) )
128- as Arc < dyn SchemaProvider > ,
129- }
130- } ;
117+ let provider = extract_schema_provider_from_pyobj ( schema_provider, self . codec . as_ref ( ) ) ?;
131118
132119 let _ = self
133120 . catalog
@@ -195,14 +182,11 @@ impl PySchema {
195182 Ok ( format ! ( "Schema(table_names=[{}])" , names. join( ";" ) ) )
196183 }
197184
198- fn register_table ( & self , name : & str , table_provider : & Bound < ' _ , PyAny > ) -> PyResult < ( ) > {
185+ fn register_table ( & self , name : & str , table_provider : Bound < ' _ , PyAny > ) -> PyResult < ( ) > {
199186 let py = table_provider. py ( ) ;
200- let codec_capsule = PyCapsule :: new (
201- py,
202- self . codec . clone ( ) ,
203- Some ( cr"datafusion_logical_extension_codec" . into ( ) ) ,
204- ) ?
205- . into_bound_py_any ( py) ?;
187+ let codec_capsule = create_logical_extension_capsule ( py, self . codec . as_ref ( ) ) ?
188+ . as_any ( )
189+ . clone ( ) ;
206190
207191 let table = PyTable :: new ( table_provider, Some ( codec_capsule) ) ?;
208192
@@ -256,7 +240,7 @@ impl RustWrappedPySchemaProvider {
256240 return Ok ( None ) ;
257241 }
258242
259- let table = PyTable :: new ( & py_table, None ) ?;
243+ let table = PyTable :: new ( py_table, None ) ?;
260244
261245 Ok ( Some ( table. table ) )
262246 } )
@@ -370,32 +354,7 @@ impl RustWrappedPyCatalogProvider {
370354 return Ok ( None ) ;
371355 }
372356
373- if py_schema. hasattr ( "__datafusion_schema_provider__" ) ? {
374- let capsule = provider
375- . getattr ( "__datafusion_schema_provider__" ) ?
376- . call0 ( ) ?;
377- let capsule = capsule. downcast :: < PyCapsule > ( ) . map_err ( py_datafusion_err) ?;
378- validate_pycapsule ( capsule, "datafusion_schema_provider" ) ?;
379-
380- let provider = unsafe { capsule. reference :: < FFI_SchemaProvider > ( ) } ;
381- let provider: Arc < dyn SchemaProvider + Send > = provider. into ( ) ;
382-
383- Ok ( Some ( provider as Arc < dyn SchemaProvider > ) )
384- } else {
385- if let Ok ( inner_schema) = py_schema. getattr ( "schema" ) {
386- if let Ok ( inner_schema) = inner_schema. extract :: < PySchema > ( ) {
387- return Ok ( Some ( inner_schema. schema ) ) ;
388- }
389- }
390- match py_schema. extract :: < PySchema > ( ) {
391- Ok ( inner_schema) => Ok ( Some ( inner_schema. schema ) ) ,
392- Err ( _) => {
393- let py_schema = RustWrappedPySchemaProvider :: new ( py_schema. into ( ) ) ;
394-
395- Ok ( Some ( Arc :: new ( py_schema) as Arc < dyn SchemaProvider > ) )
396- }
397- }
398- }
357+ extract_schema_provider_from_pyobj ( py_schema, self . codec . as_ref ( ) ) . map ( Some )
399358 } )
400359 }
401360}
@@ -479,6 +438,35 @@ impl CatalogProvider for RustWrappedPyCatalogProvider {
479438 }
480439}
481440
441+ fn extract_schema_provider_from_pyobj (
442+ mut schema_provider : Bound < PyAny > ,
443+ codec : & FFI_LogicalExtensionCodec ,
444+ ) -> PyResult < Arc < dyn SchemaProvider > > {
445+ if schema_provider. hasattr ( "__datafusion_schema_provider__" ) ? {
446+ let py = schema_provider. py ( ) ;
447+ let codec_capsule = create_logical_extension_capsule ( py, codec) ?;
448+ schema_provider = schema_provider
449+ . getattr ( "__datafusion_schema_provider__" ) ?
450+ . call1 ( ( codec_capsule, ) ) ?;
451+ }
452+
453+ let provider = if let Ok ( capsule) = schema_provider. downcast :: < PyCapsule > ( ) {
454+ validate_pycapsule ( capsule, "datafusion_schema_provider" ) ?;
455+
456+ let provider = unsafe { capsule. reference :: < FFI_SchemaProvider > ( ) } ;
457+ let provider: Arc < dyn SchemaProvider + Send > = provider. into ( ) ;
458+ provider as Arc < dyn SchemaProvider >
459+ } else {
460+ match schema_provider. extract :: < PySchema > ( ) {
461+ Ok ( py_schema) => py_schema. schema ,
462+ Err ( _) => Arc :: new ( RustWrappedPySchemaProvider :: new ( schema_provider. into ( ) ) )
463+ as Arc < dyn SchemaProvider > ,
464+ }
465+ } ;
466+
467+ Ok ( provider)
468+ }
469+
482470pub ( crate ) fn init_module ( m : & Bound < ' _ , PyModule > ) -> PyResult < ( ) > {
483471 m. add_class :: < PyCatalog > ( ) ?;
484472 m. add_class :: < PySchema > ( ) ?;
0 commit comments