Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 19 additions & 2 deletions deepnote_toolkit/sql/duckdb_sql.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
import sys

import duckdb
from duckdb_extensions import import_extension
from packaging.version import Version

from deepnote_toolkit.logging import LoggerManager

_DEEPNOTE_DUCKDB_CONNECTION = None
_DEFAULT_DUCKDB_SAMPLE_SIZE = 20_000

Expand Down Expand Up @@ -40,16 +43,30 @@ def _get_duckdb_connection():
duckdb.Connection: A connection to the DuckDB database.
"""
global _DEEPNOTE_DUCKDB_CONNECTION
logger = LoggerManager().get_logger()

if not _DEEPNOTE_DUCKDB_CONNECTION:
_DEEPNOTE_DUCKDB_CONNECTION = duckdb.connect(
database=":memory:", read_only=False
)

# DuckDB extensions are loaded from included wheels to prevent loading them
# from the internet on every notebook start
#
# Install and load the spatial extension. Primary use case: reading xlsx files
# e.g. SELECT * FROM st_read('excel.xlsx')
_DEEPNOTE_DUCKDB_CONNECTION.execute("install spatial;")
_DEEPNOTE_DUCKDB_CONNECTION.execute("load spatial;")
# There is also an official excel extension; its documentation notes that Excel support in the
# spatial extension may be removed in the future (see: https://duckdb.org/docs/stable/core_extensions/excel)
for extension_name in ["spatial", "excel"]:
try:
import_extension(
name=extension_name,
force_install=True,
con=_DEEPNOTE_DUCKDB_CONNECTION,
)
_DEEPNOTE_DUCKDB_CONNECTION.load_extension(extension_name)
except Exception as e:
logger.error(f"Failed to load DuckDB {extension_name} extension: {e}")

_set_sample_size(_DEEPNOTE_DUCKDB_CONNECTION, _DEFAULT_DUCKDB_SAMPLE_SIZE)

Expand Down
52 changes: 51 additions & 1 deletion poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 3 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,9 @@ dependencies = [
"duckdb>=1.1.0,<2.0.0; python_version < '3.12'",
"duckdb>=1.1.0,<2.0.0; python_version >= '3.12'",
"duckdb>=1.4.1,<2.0.0; python_version >= '3.13'",
"duckdb-extensions>=1.1.0,<2.0.0", # bake in as dependency to not pull extensions from the internet on every notebook start
"duckdb-extension-spatial>=1.1.0,<2.0.0",
"duckdb-extension-excel>=1.1.0,<2.0.0",
"google-cloud-bigquery-storage==2.16.2; python_version < '3.13'",
"google-cloud-bigquery-storage>=2.33.1,<3; python_version>='3.13'",

Expand Down
97 changes: 97 additions & 0 deletions tests/unit/test_duckdb_sql.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
import pandas as pd
import pytest

from deepnote_toolkit.sql.duckdb_sql import (
_get_duckdb_connection,
_set_sample_size,
_set_scan_all_frames,
)


@pytest.fixture(scope="function")
def duckdb_connection():
    """Yield a freshly created DuckDB connection and tear it down afterwards.

    The module-level cached connection is cleared before the test so that
    ``_get_duckdb_connection()`` creates a new one, and cleared again once the
    connection has been closed so the closed handle is not left in the cache.
    """
    import deepnote_toolkit.sql.duckdb_sql as sql_module

    # Drop any cached connection so every test starts from a clean slate.
    sql_module._DEEPNOTE_DUCKDB_CONNECTION = None
    fresh_connection = _get_duckdb_connection()

    try:
        yield fresh_connection
    finally:
        fresh_connection.close()
        sql_module._DEEPNOTE_DUCKDB_CONNECTION = None


@pytest.mark.parametrize("extension_name", ["spatial", "excel"])
def test_extension_installed_and_loadable(duckdb_connection, extension_name):
    """Check that the extension is reported as both installed and loaded.

    Both flags are read from ``duckdb_extensions()`` in a single query, with the
    extension name passed as a bound parameter instead of being interpolated
    into the SQL text.
    """
    row = duckdb_connection.execute(
        "SELECT installed, loaded FROM duckdb_extensions() WHERE extension_name = ?",
        [extension_name],
    ).fetchone()

    # Guard first so a missing row fails with a readable message rather than
    # a TypeError when unpacking below.
    assert (
        row is not None
    ), f"{extension_name} extension should be found in duckdb_extensions()"

    installed, loaded = row
    assert installed is True, f"{extension_name} extension should be installed"
    assert loaded is True, f"{extension_name} extension should be loaded"


def test_connection_singleton_pattern(duckdb_connection):
    """Repeated calls must hand back the same cached connection object.

    The ``duckdb_connection`` fixture is used so that the connection created
    for this test is closed and the module-level cache is reset afterwards,
    instead of leaving an open connection behind for later tests.
    """
    conn1 = _get_duckdb_connection()
    conn2 = _get_duckdb_connection()

    # The fixture obtained its connection through the same accessor, so all
    # three references should point at one object.
    assert conn1 is duckdb_connection, "Connection should be a singleton"
    assert conn1 is conn2, "Connection should be a singleton"


def test_set_sample_size(duckdb_connection):
    """``_set_sample_size`` should be reflected in the ``pandas_analyze_sample`` setting."""
    expected_size = 50000
    _set_sample_size(duckdb_connection, expected_size)

    setting_query = (
        "SELECT value FROM duckdb_settings() WHERE name = 'pandas_analyze_sample'"
    )
    row = duckdb_connection.execute(setting_query).fetchone()

    # The value is converted with int() before comparing against the size set above.
    assert int(row[0]) == expected_size


def test_set_scan_all_frames(duckdb_connection):
    """``_set_scan_all_frames`` should toggle the ``python_scan_all_frames`` setting."""
    setting_query = (
        "SELECT value FROM duckdb_settings() WHERE name = 'python_scan_all_frames'"
    )

    # Disable first, then enable, checking the reported setting after each change.
    for flag, expected_text in ((False, "false"), (True, "true")):
        _set_scan_all_frames(duckdb_connection, flag)
        row = duckdb_connection.execute(setting_query).fetchone()
        assert row[0] == expected_text


def test_excel_extension_roundtrip(duckdb_connection, tmp_path):
    """Export a DataFrame to xlsx through DuckDB and read it back two ways.

    The file is written with ``COPY ... (FORMAT xlsx)`` and then read once via
    ``st_read`` and once via ``read_xlsx``; each result is compared against
    the original frame.
    """
    excel_path = tmp_path / "test_data.xlsx"
    source_frame = pd.DataFrame(
        {
            "id": [1, 2, 3],
            "name": ["Alice", "Bob", "Charlie"],
            "score": [95.5, 87.3, 91.2],
        }
    )

    # Register the frame only for the duration of the export.
    duckdb_connection.register("test_table", source_frame)
    export_sql = f"COPY test_table TO '{excel_path}' WITH (FORMAT xlsx, HEADER true)"
    duckdb_connection.execute(export_sql)
    duckdb_connection.unregister("test_table")

    assert excel_path.exists(), "Excel file should be created"

    read_queries = (
        # read with spatial extension
        f"SELECT * FROM st_read('{excel_path}')",
        # read with excel extension
        f"SELECT * FROM read_xlsx('{excel_path}')",
    )
    for read_sql in read_queries:
        loaded_frame = duckdb_connection.execute(read_sql).df()
        mismatches = source_frame.compare(loaded_frame)
        assert mismatches.empty, "Data should be the same"