Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
117 changes: 96 additions & 21 deletions ggsql-python/python/ggsql/__init__.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,24 @@
from __future__ import annotations

import json
from typing import Any, Union
from typing import Any, Protocol, Union, runtime_checkable

import altair
import narwhals as nw
from narwhals.typing import IntoFrame
import polars as pl

from ggsql._ggsql import (
DuckDBReader,
VegaLiteWriter,
VegaLiteWriter as _RustVegaLiteWriter,
Validated,
Spec,
validate,
execute,
ParseError,
ValidationError,
ReaderError,
WriterError,
)

__all__ = [
Expand All @@ -22,12 +27,18 @@
"VegaLiteWriter",
"Validated",
"Spec",
"Reader",
# Functions
"validate",
"execute",
"render_altair",
# Exceptions
"ParseError",
"ValidationError",
"ReaderError",
"WriterError",
]
__version__ = "0.1.0"
__version__ = "0.1.4"

# Type alias for any Altair chart type
AltairChart = Union[
Expand All @@ -41,6 +52,87 @@
]


@runtime_checkable
class Reader(Protocol):
    """Structural interface for ggsql database readers.

    Any object that provides the two methods below can be handed to
    ``ggsql.execute()`` as its reader; no inheritance is required.
    Native readers such as ``DuckDBReader`` already expose both methods
    and therefore satisfy the protocol as-is.

    Because the class is decorated with ``runtime_checkable``,
    ``isinstance(obj, Reader)`` can be used, but it only checks that the
    methods exist, not their signatures.
    """

    def execute_sql(self, sql: str) -> pl.DataFrame:
        """Execute a SQL query and return the result as a polars DataFrame."""
        ...

    def register(self, name: str, df: pl.DataFrame, replace: bool = False) -> None:
        """Register ``df`` under ``name`` so SQL queries can refer to it.

        ``replace`` defaults to ``False``.
        """
        ...


def _json_to_altair_chart(vegalite_json: str, **kwargs: Any) -> AltairChart:
    """Convert a Vega-Lite JSON string to the appropriate Altair chart type.

    The top-level keys of the parsed spec decide which Altair class is used
    to load it; a spec with none of the composition keys is loaded as a
    plain ``altair.Chart``.

    Parameters
    ----------
    vegalite_json
        A Vega-Lite specification serialized as JSON.
    **kwargs
        Forwarded unchanged to the selected class's ``from_json()``.

    Returns
    -------
    AltairChart
        The chart object built by ``from_json()`` on the selected class.

    Raises
    ------
    json.JSONDecodeError
        If ``vegalite_json`` is not valid JSON.
    """
    spec = json.loads(vegalite_json)

    if "layer" in spec:
        chart_cls = altair.LayerChart
    elif "repeat" in spec:
        # A repeat spec also carries a top-level "spec" key, so it must be
        # recognised before the facet check below (which matches on "spec");
        # otherwise repeat charts would be loaded as FacetChart.
        chart_cls = altair.RepeatChart
    elif "facet" in spec or "spec" in spec:
        chart_cls = altair.FacetChart
    elif "concat" in spec:
        chart_cls = altair.ConcatChart
    elif "hconcat" in spec:
        chart_cls = altair.HConcatChart
    elif "vconcat" in spec:
        chart_cls = altair.VConcatChart
    else:
        chart_cls = altair.Chart

    return chart_cls.from_json(vegalite_json, **kwargs)


class VegaLiteWriter:
    """Vega-Lite v6 JSON output writer.

    Thin Python wrapper around the native writer held in ``self._inner``.

    Methods
    -------
    render(spec)
        Produce the Vega-Lite JSON string for a Spec.
    render_chart(spec, **kwargs)
        Produce an Altair chart object for a Spec.
    """

    def __init__(self) -> None:
        # The native writer does the actual rendering; this class delegates to it.
        self._inner = _RustVegaLiteWriter()

    def render(self, spec: Spec) -> str:
        """Return the Vega-Lite JSON string the native writer produces for ``spec``."""
        rendered: str = self._inner.render(spec)
        return rendered

    def render_chart(self, spec: Spec, **kwargs: Any) -> AltairChart:
        """Render a Spec and load the result as an Altair chart object.

        Parameters
        ----------
        spec
            The resolved visualization specification to render.
        **kwargs
            Extra keyword arguments passed through to the Altair
            ``from_json()`` call, e.g. ``validate=False`` to skip schema
            validation.

        Returns
        -------
        AltairChart
            An Altair chart object (Chart, LayerChart, FacetChart, etc.).
        """
        return _json_to_altair_chart(self.render(spec), **kwargs)


def render_altair(
df: IntoFrame,
viz: str,
Expand Down Expand Up @@ -86,21 +178,4 @@ def render_altair(
writer = VegaLiteWriter()
vegalite_json = writer.render(spec)

# Parse to determine the correct Altair class
spec = json.loads(vegalite_json)

# Determine the correct Altair class based on spec structure
if "layer" in spec:
return altair.LayerChart.from_json(vegalite_json, **kwargs)
elif "facet" in spec or "spec" in spec:
return altair.FacetChart.from_json(vegalite_json, **kwargs)
elif "concat" in spec:
return altair.ConcatChart.from_json(vegalite_json, **kwargs)
elif "hconcat" in spec:
return altair.HConcatChart.from_json(vegalite_json, **kwargs)
elif "vconcat" in spec:
return altair.VConcatChart.from_json(vegalite_json, **kwargs)
elif "repeat" in spec:
return altair.RepeatChart.from_json(vegalite_json, **kwargs)
else:
return altair.Chart.from_json(vegalite_json, **kwargs)
return _json_to_altair_chart(vegalite_json, **kwargs)
84 changes: 65 additions & 19 deletions ggsql-python/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
// See: https://github.com/PyO3/pyo3/issues/4327
#![allow(clippy::useless_conversion)]

use pyo3::create_exception;
use pyo3::exceptions::PyValueError;
use pyo3::prelude::*;
use pyo3::types::{PyBytes, PyDict, PyList};
use std::io::Cursor;
Expand All @@ -12,6 +14,48 @@ use ggsql::validate::{validate as rust_validate, ValidationWarning};
use ggsql::writer::{VegaLiteWriter as RustVegaLiteWriter, Writer as RustWriter};
use ggsql::GgsqlError;

// ============================================================================
// Custom Exception Classes
// ============================================================================

// Typed Python exception classes, one per user-facing `GgsqlError` kind.
// All subclass ValueError for backwards compatibility: Python code that
// catches `ValueError` keeps catching these more specific exceptions too.
// Each invocation passes: module name, exception name, base class, docstring.
create_exception!(
ggsql,
ParseError,
PyValueError,
"Raised on query syntax errors."
);
create_exception!(
ggsql,
ValidationError,
PyValueError,
"Raised on semantic validation errors."
);
create_exception!(
ggsql,
ReaderError,
PyValueError,
"Raised on data source errors."
);
create_exception!(
ggsql,
WriterError,
PyValueError,
"Raised on output generation errors."
);

/// Convert a GgsqlError to the appropriate typed Python exception.
fn ggsql_err_to_py(e: GgsqlError) -> PyErr {
let msg = e.to_string();
match e {
GgsqlError::ParseError(_) => PyErr::new::<ParseError, _>(msg),
GgsqlError::ValidationError(_) => PyErr::new::<ValidationError, _>(msg),
GgsqlError::ReaderError(_) => PyErr::new::<ReaderError, _>(msg),
GgsqlError::WriterError(_) => PyErr::new::<WriterError, _>(msg),
GgsqlError::InternalError(_) => PyErr::new::<PyValueError, _>(msg),
}
}

use polars::prelude::{DataFrame, IpcReader, IpcWriter, SerReader, SerWriter};

// ============================================================================
Expand Down Expand Up @@ -142,9 +186,13 @@ impl Reader for PyReaderBridge {
Python::attach(|py| {
let py_df =
polars_to_py(py, &df).map_err(|e| GgsqlError::ReaderError(e.to_string()))?;
let kwargs = PyDict::new(py);
kwargs
.set_item("replace", replace)
.map_err(|e| GgsqlError::ReaderError(e.to_string()))?;
self.obj
.bind(py)
.call_method1("register", (name, py_df, replace))
.call_method("register", (name, py_df), Some(&kwargs))
.map_err(|e| GgsqlError::ReaderError(format!("Reader.register() failed: {}", e)))?;
Ok(())
})
Expand Down Expand Up @@ -175,7 +223,7 @@ macro_rules! try_native_readers {
if let Ok(native) = $reader.downcast::<$native_type>() {
return native.borrow().inner.execute($query)
.map(|s| PySpec { inner: s })
.map_err(|e| PyErr::new::<pyo3::exceptions::PyValueError, _>(e.to_string()));
.map_err(ggsql_err_to_py);
}
)*
}};
Expand Down Expand Up @@ -224,8 +272,8 @@ impl PyDuckDBReader {
/// If the connection string is invalid or the database cannot be opened.
#[new]
fn new(connection: &str) -> PyResult<Self> {
let inner = RustDuckDBReader::from_connection_string(connection)
.map_err(|e| PyErr::new::<pyo3::exceptions::PyValueError, _>(e.to_string()))?;
let inner =
RustDuckDBReader::from_connection_string(connection).map_err(ggsql_err_to_py)?;
Ok(Self { inner })
}

Expand Down Expand Up @@ -255,7 +303,7 @@ impl PyDuckDBReader {
let rust_df = py_to_polars(py, df)?;
self.inner
.register(name, rust_df, replace)
.map_err(|e| PyErr::new::<pyo3::exceptions::PyValueError, _>(e.to_string()))
.map_err(ggsql_err_to_py)
}

/// Unregister a previously registered table.
Expand All @@ -270,9 +318,7 @@ impl PyDuckDBReader {
/// ValueError
/// If the table wasn't registered via this reader or unregistration fails.
fn unregister(&self, name: &str) -> PyResult<()> {
self.inner
.unregister(name)
.map_err(|e| PyErr::new::<pyo3::exceptions::PyValueError, _>(e.to_string()))
self.inner.unregister(name).map_err(ggsql_err_to_py)
}

/// Execute a SQL query and return the result as a DataFrame.
Expand All @@ -292,10 +338,7 @@ impl PyDuckDBReader {
/// ValueError
/// If the SQL is invalid or execution fails.
fn execute_sql(&self, py: Python<'_>, sql: &str) -> PyResult<Py<PyAny>> {
let df = self
.inner
.execute_sql(sql)
.map_err(|e| PyErr::new::<pyo3::exceptions::PyValueError, _>(e.to_string()))?;
let df = self.inner.execute_sql(sql).map_err(ggsql_err_to_py)?;
polars_to_py(py, &df)
}

Expand Down Expand Up @@ -330,7 +373,7 @@ impl PyDuckDBReader {
self.inner
.execute(query)
.map(|s| PySpec { inner: s })
.map_err(|e| PyErr::new::<pyo3::exceptions::PyValueError, _>(e.to_string()))
.map_err(ggsql_err_to_py)
}
}

Expand Down Expand Up @@ -391,9 +434,7 @@ impl PyVegaLiteWriter {
/// >>> writer = VegaLiteWriter()
/// >>> json_output = writer.render(spec)
fn render(&self, spec: &PySpec) -> PyResult<String> {
self.inner
.render(&spec.inner)
.map_err(|e| PyErr::new::<pyo3::exceptions::PyValueError, _>(e.to_string()))
self.inner.render(&spec.inner).map_err(ggsql_err_to_py)
}
}

Expand Down Expand Up @@ -657,8 +698,7 @@ impl PySpec {
/// If validation fails unexpectedly (not for syntax errors, which are captured).
#[pyfunction]
fn validate(query: &str) -> PyResult<PyValidated> {
let v = rust_validate(query)
.map_err(|e| PyErr::new::<pyo3::exceptions::PyValueError, _>(e.to_string()))?;
let v = rust_validate(query).map_err(ggsql_err_to_py)?;

Ok(PyValidated {
sql: v.sql().to_string(),
Expand Down Expand Up @@ -739,7 +779,7 @@ fn execute(query: &str, reader: &Bound<'_, PyAny>) -> PyResult<PySpec> {
bridge
.execute(query)
.map(|s| PySpec { inner: s })
.map_err(|e| PyErr::new::<pyo3::exceptions::PyValueError, _>(e.to_string()))
.map_err(ggsql_err_to_py)
}

// ============================================================================
Expand All @@ -748,6 +788,12 @@ fn execute(query: &str, reader: &Bound<'_, PyAny>) -> PyResult<PySpec> {

#[pymodule]
fn _ggsql(m: &Bound<'_, PyModule>) -> PyResult<()> {
// Exceptions
m.add("ParseError", m.py().get_type::<ParseError>())?;
m.add("ValidationError", m.py().get_type::<ValidationError>())?;
m.add("ReaderError", m.py().get_type::<ReaderError>())?;
m.add("WriterError", m.py().get_type::<WriterError>())?;

// Classes
m.add_class::<PyDuckDBReader>()?;
m.add_class::<PyVegaLiteWriter>()?;
Expand Down
Loading
Loading