Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,7 @@
]

autosummary_generate = False
smartquotes = False


autosectionlabel_prefix_document = True
Expand Down
9 changes: 4 additions & 5 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ uuid = ["uuid-utils"]

[dependency-groups]
benchmarks = ["sqlalchemy[asyncio]", "psutil", "types-psutil", "duckdb-engine>=0.17.0"]
build = ["bump-my-version", "hatch-mypyc", "pydantic-settings"]
build = ["bump-my-version", "hatch-mypyc", "mypy>=1.19.1", "pydantic-settings"]
dev = [
{ include-group = "extras" },
{ include-group = "lint" },
Expand Down Expand Up @@ -115,7 +115,7 @@ extras = [
"dishka",
]
lint = [
"mypy>=1.13.0",
"mypy>=1.19.1",
"pre-commit>=3.5.0",
"pyright>=1.1.386",
"ruff>=0.7.1",
Expand Down Expand Up @@ -176,7 +176,7 @@ packages = ["sqlspec"]


[tool.hatch.build.targets.wheel.hooks.mypyc]
dependencies = ["hatch-mypyc", "hatch-cython"]
dependencies = ["hatch-mypyc", "hatch-cython", "mypy>=1.19.1"]
enable-by-default = false
exclude = [
"tests/**", # Test files
Expand All @@ -201,7 +201,6 @@ exclude = [
"sqlspec/adapters/**/data_dictionary.py", # Cross-module inheritance causes mypyc segfaults
"sqlspec/observability/_formatting.py", # Inherits from non-compiled logging.Formatter
"sqlspec/utils/arrow_helpers.py", # Arrow operations cause segfaults when compiled
"sqlspec/storage/backends/_iterators.py", # Async __anext__ + asyncio.to_thread causes mypyc segfault
]
include = [
"sqlspec/core/**/*.py", # Core module
Expand All @@ -212,7 +211,7 @@ include = [
"sqlspec/driver/**/*.py", # Driver module
"sqlspec/storage/registry.py", # Safe storage registry/runtime routing
"sqlspec/storage/errors.py", # Safe storage error normalization
"sqlspec/storage/backends/base.py", # Storage backend runtime base classes (iterators in _iterators.py)
"sqlspec/storage/backends/base.py", # Storage backend runtime base classes
"sqlspec/data_dictionary/**/*.py", # Data dictionary mixin (required for adapter inheritance)
"sqlspec/adapters/**/core.py", # Adapter compiled helpers
"sqlspec/adapters/**/type_converter.py", # All adapters type converters
Expand Down
6 changes: 6 additions & 0 deletions sqlspec/adapters/psycopg/_typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,9 @@

from psycopg import AsyncConnection, AsyncCursor, Connection, Cursor
from psycopg.rows import DictRow as PsycopgDictRow
from psycopg.sql import SQL as PsycopgSQL # noqa: N811
from psycopg.sql import Composed as PsycopgComposed
from psycopg.sql import Identifier as PsycopgIdentifier

if TYPE_CHECKING:
from collections.abc import Callable
Expand Down Expand Up @@ -203,8 +206,11 @@ async def __aexit__(
"PsycopgAsyncCursor",
"PsycopgAsyncRawCursor",
"PsycopgAsyncSessionContext",
"PsycopgComposed",
"PsycopgDictRow",
"PsycopgIdentifier",
"PsycopgPipelineDriver",
"PsycopgSQL",
"PsycopgSyncConnection",
"PsycopgSyncCursor",
"PsycopgSyncRawCursor",
Expand Down
20 changes: 10 additions & 10 deletions sqlspec/adapters/psycopg/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,9 @@
from collections.abc import Sized
from typing import TYPE_CHECKING, Any, NamedTuple, cast

from psycopg import sql as psycopg_sql
from typing_extensions import LiteralString

from sqlspec.adapters.psycopg._typing import PsycopgComposed, PsycopgIdentifier, PsycopgSQL
from sqlspec.core import (
SQL,
DriverParameterProfile,
Expand Down Expand Up @@ -88,7 +88,7 @@ class PreparedStackOperation(NamedTuple):
operation_index: int
operation: "StackOperation"
statement: "SQL"
sql: "LiteralString | psycopg_sql.SQL"
sql: "LiteralString | PsycopgSQL | PsycopgComposed"
parameters: "tuple[Any, ...] | dict[str, Any] | None"


Expand All @@ -113,23 +113,23 @@ def pipeline_supported() -> bool:
return False


def _compose_table_identifier(table: str) -> "psycopg_sql.Composed":
def _compose_table_identifier(table: str) -> "PsycopgComposed":
parts = [part for part in table.split(".") if part]
if not parts:
msg = "Table name must not be empty"
raise SQLSpecError(msg)
identifiers = [psycopg_sql.Identifier(part) for part in parts]
return psycopg_sql.SQL(".").join(identifiers)
identifiers = [PsycopgIdentifier(part) for part in parts]
return PsycopgSQL(".").join(identifiers)


def build_copy_from_command(table: str, columns: "list[str]") -> "psycopg_sql.Composed":
def build_copy_from_command(table: str, columns: "list[str]") -> "PsycopgComposed":
table_identifier = _compose_table_identifier(table)
column_sql = psycopg_sql.SQL(", ").join([psycopg_sql.Identifier(column) for column in columns])
return psycopg_sql.SQL("COPY {} ({}) FROM STDIN").format(table_identifier, column_sql)
column_sql = PsycopgSQL(", ").join([PsycopgIdentifier(column) for column in columns])
return PsycopgSQL("COPY {} ({}) FROM STDIN").format(table_identifier, column_sql)


def build_truncate_command(table: str) -> "psycopg_sql.Composed":
return psycopg_sql.SQL("TRUNCATE TABLE {}").format(_compose_table_identifier(table))
def build_truncate_command(table: str) -> "PsycopgComposed":
return PsycopgSQL("TRUNCATE TABLE {}").format(_compose_table_identifier(table))


def _identity(value: Any) -> Any:
Expand Down
9 changes: 5 additions & 4 deletions sqlspec/adapters/psycopg/driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,14 @@
from typing import TYPE_CHECKING, Any, cast

import psycopg
from psycopg import sql as psycopg_sql
from typing_extensions import LiteralString

from sqlspec.adapters.psycopg._typing import (
PsycopgAsyncConnection,
PsycopgAsyncCursor,
PsycopgAsyncSessionContext,
PsycopgComposed,
PsycopgSQL,
PsycopgSyncConnection,
PsycopgSyncCursor,
PsycopgSyncSessionContext,
Expand Down Expand Up @@ -111,7 +112,7 @@ def _prepare_pipeline_operations(self, stack: "StatementStack") -> "list[Prepare
operation_index=index,
operation=operation,
statement=sql_statement,
sql=cast("LiteralString | psycopg_sql.SQL", sql_text),
sql=cast("LiteralString | PsycopgSQL | PsycopgComposed", sql_text),
parameters=prepared_parameters,
)
)
Expand Down Expand Up @@ -396,7 +397,7 @@ def _raise_pending_exception(exception_ctx: "PsycopgSyncExceptionHandler") -> No
cursor = resource_stack.enter_context(self.with_cursor(self.connection))

try:
sql = cast("LiteralString | psycopg_sql.SQL", prepared.sql) # type: ignore[redundant-cast]
sql = cast("LiteralString | PsycopgSQL | PsycopgComposed", prepared.sql) # type: ignore[redundant-cast]
if prepared.parameters:
cursor.execute(sql, prepared.parameters)
else:
Expand Down Expand Up @@ -855,7 +856,7 @@ def _raise_pending_exception(exception_ctx: "PsycopgAsyncExceptionHandler") -> N
cursor = await resource_stack.enter_async_context(self.with_cursor(self.connection))

try:
sql = cast("LiteralString | psycopg_sql.SQL", prepared.sql) # type: ignore[redundant-cast]
sql = cast("LiteralString | PsycopgSQL | PsycopgComposed", prepared.sql) # type: ignore[redundant-cast]
if prepared.parameters:
await cursor.execute(sql, prepared.parameters)
else:
Expand Down
44 changes: 34 additions & 10 deletions sqlspec/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -968,20 +968,21 @@ class DatabaseConfigProtocol(ABC, Generic[ConnectionT, PoolT, DriverT]):

__slots__ = (
"_migration_commands",
"_migration_config",
"_migration_loader",
"_observability_runtime",
"_storage_capabilities",
"bind_key",
"connection_instance",
"driver_features",
"extension_config",
"migration_config",
"observability_config",
"statement_config",
)

_migration_loader: "SQLFileLoader"
_migration_commands: "SyncMigrationCommands[Any] | AsyncMigrationCommands[Any]"
_migration_config: "dict[str, Any] | MigrationConfig"
driver_type: "ClassVar[type[Any]]"
connection_type: "ClassVar[type[Any]]"
is_async: "ClassVar[bool]" = False
Expand All @@ -998,7 +999,6 @@ class DatabaseConfigProtocol(ABC, Generic[ConnectionT, PoolT, DriverT]):
bind_key: "str | None"
statement_config: "StatementConfig"
connection_instance: "PoolT | None"
migration_config: "dict[str, Any] | MigrationConfig"
extension_config: "ExtensionConfigs"
driver_features: "dict[str, Any]"
_storage_capabilities: "StorageCapabilities | None"
Expand All @@ -1022,6 +1022,20 @@ def __repr__(self) -> str:
])
return f"{type(self).__name__}({parts})"

@property
def migration_config(self) -> "dict[str, Any] | MigrationConfig":
"""Return the current migration configuration."""
return self._migration_config

@migration_config.setter
def migration_config(self, value: "dict[str, Any] | MigrationConfig | None") -> None:
"""Store migration configuration and refresh derived migration helpers."""
object.__setattr__(self, "_migration_config", dict(cast("dict[str, Any]", value) or {}))
if self._has_initialized_attribute("extension_config"):
self._ensure_extension_migrations()
if self._migration_components_ready():
self._initialize_migration_components()

def storage_capabilities(self) -> "StorageCapabilities":
"""Return cached storage capabilities for this configuration."""

Expand All @@ -1034,6 +1048,20 @@ def reset_storage_capabilities_cache(self) -> None:

self._storage_capabilities = None

def _has_initialized_attribute(self, attribute_name: str) -> bool:
"""Return whether a slot-backed attribute has been initialized."""
try:
object.__getattribute__(self, attribute_name)
except AttributeError:
return False
return True

def _migration_components_ready(self) -> bool:
"""Return whether migration helpers have already been initialized."""
return self._has_initialized_attribute("_migration_loader") and self._has_initialized_attribute(
"_migration_commands"
)

def _ensure_extension_migrations(self) -> None:
"""Auto-include extension migrations when extension_config has them configured.

Expand Down Expand Up @@ -1473,8 +1501,7 @@ def __init__(
self.connection_instance = connection_instance
self.connection_config = connection_config or {}
self.extension_config = extension_config or {}
self.migration_config: dict[str, Any] | MigrationConfig = migration_config or {}
self._ensure_extension_migrations()
self.migration_config = migration_config or {}
self._init_observability(observability_config)
self._initialize_migration_components()

Expand Down Expand Up @@ -1637,8 +1664,7 @@ def __init__(
self.connection_instance = connection_instance
self.connection_config = connection_config or {}
self.extension_config = extension_config or {}
self.migration_config: dict[str, Any] | MigrationConfig = migration_config or {}
self._ensure_extension_migrations()
self.migration_config = migration_config or {}
self._init_observability(observability_config)
self._initialize_migration_components()

Expand Down Expand Up @@ -1806,8 +1832,7 @@ def __init__(
self.connection_instance = connection_instance
self.connection_config = connection_config or {}
self.extension_config = extension_config or {}
self.migration_config: dict[str, Any] | MigrationConfig = migration_config or {}
self._ensure_extension_migrations()
self.migration_config = migration_config or {}
self._init_observability(observability_config)
self._initialize_migration_components()

Expand Down Expand Up @@ -2017,8 +2042,7 @@ def __init__(
self.connection_instance = connection_instance
self.connection_config = connection_config or {}
self.extension_config = extension_config or {}
self.migration_config: dict[str, Any] | MigrationConfig = migration_config or {}
self._ensure_extension_migrations()
self.migration_config = migration_config or {}
self._init_observability(observability_config)
self._initialize_migration_components()

Expand Down
Loading
Loading