diff --git a/python/datafusion/__init__.py b/python/datafusion/__init__.py index d871fdb71..ecf5545bc 100644 --- a/python/datafusion/__init__.py +++ b/python/datafusion/__init__.py @@ -26,7 +26,7 @@ except ImportError: import importlib_metadata -from . import functions, object_store, substrait +from . import functions, object_store, substrait, unparser # The following imports are okay to remain as opaque to the user. from ._internal import Config @@ -89,6 +89,7 @@ "udaf", "udf", "udwf", + "unparser", ] diff --git a/python/datafusion/unparser.py b/python/datafusion/unparser.py new file mode 100644 index 000000000..7ca5b9190 --- /dev/null +++ b/python/datafusion/unparser.py @@ -0,0 +1,80 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +"""This module provides support for unparsing datafusion plans to SQL. + +For additional information about unparsing, see https://docs.rs/datafusion-sql/latest/datafusion_sql/unparser/index.html +""" + +from ._internal import unparser as unparser_internal +from .plan import LogicalPlan + + +class Dialect: + """DataFusion data catalog.""" + + def __init__(self, dialect: unparser_internal.Dialect) -> None: + """This constructor is not typically called by the end user.""" + self.dialect = dialect + + @staticmethod + def default() -> "Dialect": + """Create a new default dialect.""" + return Dialect(unparser_internal.Dialect.default()) + + @staticmethod + def mysql() -> "Dialect": + """Create a new MySQL dialect.""" + return Dialect(unparser_internal.Dialect.mysql()) + + @staticmethod + def postgres() -> "Dialect": + """Create a new PostgreSQL dialect.""" + return Dialect(unparser_internal.Dialect.postgres()) + + @staticmethod + def sqlite() -> "Dialect": + """Create a new SQLite dialect.""" + return Dialect(unparser_internal.Dialect.sqlite()) + + @staticmethod + def duckdb() -> "Dialect": + """Create a new DuckDB dialect.""" + return Dialect(unparser_internal.Dialect.duckdb()) + + +class Unparser: + """DataFusion unparser.""" + + def __init__(self, dialect: Dialect) -> None: + """This constructor is not typically called by the end user.""" + self.unparser = unparser_internal.Unparser(dialect.dialect) + + def plan_to_sql(self, plan: LogicalPlan) -> str: + """Convert a logical plan to a SQL string.""" + return self.unparser.plan_to_sql(plan._raw_plan) + + def with_pretty(self, pretty: bool) -> "Unparser": + """Set the pretty flag.""" + self.unparser = self.unparser.with_pretty(pretty) + return self + + +__all__ = [ + "Dialect", + "Unparser", +] diff --git a/python/tests/test_unparser.py b/python/tests/test_unparser.py new file mode 100644 index 000000000..c4e05780c --- /dev/null +++ b/python/tests/test_unparser.py @@ -0,0 +1,33 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from datafusion.context import SessionContext +from datafusion.unparser import Dialect, Unparser + + +def test_unparser(): + ctx = SessionContext() + df = ctx.sql("SELECT 1") + for dialect in [ + Dialect.mysql(), + Dialect.postgres(), + Dialect.sqlite(), + Dialect.duckdb(), + ]: + unparser = Unparser(dialect) + sql = unparser.plan_to_sql(df.logical_plan()) + assert sql == "SELECT 1" diff --git a/src/lib.rs b/src/lib.rs index ce93ff0c3..6eeda0878 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -52,6 +52,7 @@ pub mod pyarrow_util; mod record_batch; pub mod sql; pub mod store; +pub mod unparser; #[cfg(feature = "substrait")] pub mod substrait; @@ -103,6 +104,10 @@ fn _internal(py: Python, m: Bound<'_, PyModule>) -> PyResult<()> { expr::init_module(&expr)?; m.add_submodule(&expr)?; + let unparser = PyModule::new(py, "unparser")?; + unparser::init_module(&unparser)?; + m.add_submodule(&unparser)?; + // Register the functions as a submodule let funcs = PyModule::new(py, "functions")?; functions::init_module(&funcs)?; diff --git a/src/unparser/dialect.rs b/src/unparser/dialect.rs new file mode 100644 index 000000000..caeef9949 --- /dev/null +++ b/src/unparser/dialect.rs @@ -0,0 +1,63 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::sync::Arc; + +use datafusion::sql::unparser::dialect::{ + DefaultDialect, Dialect, DuckDBDialect, MySqlDialect, PostgreSqlDialect, SqliteDialect, +}; +use pyo3::prelude::*; + +#[pyclass(name = "Dialect", module = "datafusion.unparser", subclass)] +#[derive(Clone)] +pub struct PyDialect { + pub dialect: Arc, +} + +#[pymethods] +impl PyDialect { + #[staticmethod] + pub fn default() -> Self { + Self { + dialect: Arc::new(DefaultDialect {}), + } + } + #[staticmethod] + pub fn postgres() -> Self { + Self { + dialect: Arc::new(PostgreSqlDialect {}), + } + } + #[staticmethod] + pub fn mysql() -> Self { + Self { + dialect: Arc::new(MySqlDialect {}), + } + } + #[staticmethod] + pub fn sqlite() -> Self { + Self { + dialect: Arc::new(SqliteDialect {}), + } + } + #[staticmethod] + pub fn duckdb() -> Self { + Self { + dialect: Arc::new(DuckDBDialect::new()), + } + } +} diff --git a/src/unparser/mod.rs b/src/unparser/mod.rs new file mode 100644 index 000000000..b4b0fed10 --- /dev/null +++ b/src/unparser/mod.rs @@ -0,0 +1,66 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +mod dialect; + +use std::sync::Arc; + +use datafusion::sql::unparser::{dialect::Dialect, Unparser}; +use dialect::PyDialect; +use pyo3::{exceptions::PyValueError, prelude::*}; + +use crate::sql::logical::PyLogicalPlan; + +#[pyclass(name = "Unparser", module = "datafusion.unparser", subclass)] +#[derive(Clone)] +pub struct PyUnparser { + dialect: Arc, + pretty: bool, +} + +#[pymethods] +impl PyUnparser { + #[new] + pub fn new(dialect: PyDialect) -> Self { + Self { + dialect: dialect.dialect.clone(), + pretty: false, + } + } + + pub fn plan_to_sql(&self, plan: &PyLogicalPlan) -> PyResult { + let mut unparser = Unparser::new(self.dialect.as_ref()); + unparser = unparser.with_pretty(self.pretty); + let sql = unparser + .plan_to_sql(&plan.plan()) + .map_err(|e| PyValueError::new_err(e.to_string()))?; + Ok(sql.to_string()) + } + + pub fn with_pretty(&self, pretty: bool) -> Self { + Self { + dialect: self.dialect.clone(), + pretty, + } + } +} + +pub(crate) fn init_module(m: &Bound<'_, PyModule>) -> PyResult<()> { + m.add_class::()?; + m.add_class::()?; + Ok(()) +}