Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions sqlit/domains/connections/providers/adapters/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -417,6 +417,22 @@ def quote_identifier(self, name: str) -> str:
"""Quote an identifier (table name, column name, etc.)."""
pass

def qualified_name(self, database: str | None, schema: str | None, name: str) -> str:
"""Build a quoted qualified identifier, skipping empty segments.

Default handles SQL Server-style `[db].[schema].[name]`, PostgreSQL-
style `"schema"."name"`, and single-part `"name"` by omitting any
empty/None component. Dialects that want different composition
(e.g. MySQL, which has no schemas within databases) can override.
"""
parts: list[str] = []
if database:
parts.append(self.quote_identifier(database))
if schema:
parts.append(self.quote_identifier(schema))
parts.append(self.quote_identifier(name))
return ".".join(parts)

@abstractmethod
def build_select_query(self, table: str, limit: int, database: str | None = None, schema: str | None = None) -> str:
"""Build a SELECT query with limit.
Expand Down
12 changes: 12 additions & 0 deletions sqlit/domains/connections/providers/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,18 @@ def build_select_query(self, table: str, limit: int, database: str | None = None

def format_table_name(self, schema: str | None, table: str) -> str: ...

def qualified_name(self, database: str | None, schema: str | None, name: str) -> str:
"""Build a quoted qualified identifier for use in SQL.

Called when completing tables/views across multiple databases. Each
dialect decides how many segments to emit:
- SQL Server / generic: `[db].[schema].[name]`
- PostgreSQL: `"schema"."name"` (databases are isolated)
- MySQL/MariaDB: `` `db`.`name` `` (no schemas within databases)
- SQLite: `"name"`
"""
...


@runtime_checkable
class SchemaInspector(Protocol):
Expand Down
36 changes: 6 additions & 30 deletions sqlit/domains/query/ui/mixins/autocomplete_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -406,13 +406,7 @@ def work() -> None:
if single_db:
full_name = table_name
else:
quoted_db = dialect.quote_identifier(database) if database else ""
quoted_schema = dialect.quote_identifier(schema_name)
quoted_table = dialect.quote_identifier(table_name)
if database:
full_name = f"{quoted_db}.{quoted_schema}.{quoted_table}"
else:
full_name = f"{quoted_schema}.{quoted_table}"
full_name = dialect.qualified_name(database, schema_name, table_name)

display_name = dialect.format_table_name(schema_name, table_name)
metadata = [
Expand Down Expand Up @@ -495,13 +489,7 @@ def work() -> None:
if single_db:
full_name = view_name
else:
quoted_db = dialect.quote_identifier(database) if database else ""
quoted_schema = dialect.quote_identifier(schema_name)
quoted_view = dialect.quote_identifier(view_name)
if database:
full_name = f"{quoted_db}.{quoted_schema}.{quoted_view}"
else:
full_name = f"{quoted_schema}.{quoted_view}"
full_name = dialect.qualified_name(database, schema_name, view_name)

display_name = dialect.format_table_name(schema_name, view_name)
metadata = [
Expand Down Expand Up @@ -816,14 +804,8 @@ async def run_db_call(fn: Any, *args: Any, **kwargs: Any) -> Any:
# Single database - use simple table name
schema_cache["tables"].append(table_name)
else:
# Multiple databases - use full qualifier [db].[schema].[table]
quoted_db = dialect.quote_identifier(database) if database else ""
quoted_schema = dialect.quote_identifier(schema_name)
quoted_table = dialect.quote_identifier(table_name)
if database:
full_name = f"{quoted_db}.{quoted_schema}.{quoted_table}"
else:
full_name = f"{quoted_schema}.{quoted_table}"
# Multiple databases - use qualified identifier
full_name = dialect.qualified_name(database, schema_name, table_name)
schema_cache["tables"].append(full_name)
# Keep metadata for column loading (multiple keys for flexible lookup)
display_name = dialect.format_table_name(schema_name, table_name)
Expand All @@ -843,14 +825,8 @@ async def run_db_call(fn: Any, *args: Any, **kwargs: Any) -> Any:
# Single database - use simple view name
schema_cache["views"].append(view_name)
else:
# Multiple databases - use full qualifier [db].[schema].[view]
quoted_db = dialect.quote_identifier(database) if database else ""
quoted_schema = dialect.quote_identifier(schema_name)
quoted_view = dialect.quote_identifier(view_name)
if database:
full_name = f"{quoted_db}.{quoted_schema}.{quoted_view}"
else:
full_name = f"{quoted_schema}.{quoted_view}"
# Multiple databases - use qualified identifier
full_name = dialect.qualified_name(database, schema_name, view_name)
schema_cache["views"].append(full_name)
# Keep metadata for column loading (multiple keys for flexible lookup)
display_name = dialect.format_table_name(schema_name, view_name)
Expand Down
16 changes: 14 additions & 2 deletions sqlit/domains/query/ui/mixins/autocomplete_suggestions.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,13 +74,25 @@ def _get_autocomplete_suggestions(self: AutocompleteMixinHost, text: str, cursor
alias_map = self._build_alias_map(text)
table_refs = extract_table_refs(text)
loading: set[str] = getattr(self, "_columns_loading", set())
table_metadata = getattr(self, "_table_metadata", {}) or {}

def needs_column_load(key: str) -> bool:
    """Decide whether `key` still needs its columns fetched.

    Only a known table (present in table_metadata) with no cached
    columns qualifies. Returning Loading... for an unknown key would
    wedge forever because the loader skips unknown tables silently.
    """
    return key not in columns and key in table_metadata

for suggestion in suggestions:
if suggestion.type == SuggestionType.COLUMN:
# Check if any tables need column loading
for ref in table_refs:
table_key = ref.name.lower()
if table_key not in columns and table_key not in loading:
if needs_column_load(table_key) and table_key not in loading:
self._load_columns_for_table(table_key)
return ["Loading..."]
elif table_key in loading:
Expand All @@ -92,7 +104,7 @@ def _get_autocomplete_suggestions(self: AutocompleteMixinHost, text: str, cursor
scope_lower = scope.lower()
table_key = alias_map.get(scope_lower, scope_lower)

if table_key not in columns and table_key not in loading:
if needs_column_load(table_key) and table_key not in loading:
self._load_columns_for_table(table_key)
return ["Loading..."]
elif table_key in loading:
Expand Down
155 changes: 155 additions & 0 deletions tests/unit/test_autocomplete_multidb.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,155 @@
"""Regression tests for MySQL autocomplete in multi-database scenarios (#151).

Two narrow bugs are covered:

1. Qualified identifiers for databases without schemas (MySQL/MariaDB) used
to render as `db`.``.`table` — an empty-backticked middle segment. The
qualifying logic is now a Dialect method so each adapter owns its own
composition rule.

2. Autocomplete returned a permanent "Loading..." sentinel whenever the
table reference in the query didn't resolve to anything in the schema
cache (e.g. the user typed `SELECT * FROM shop.cu`: `shop` was treated
as an alias and the loader spun forever for an unknown key).
"""

from __future__ import annotations

import pytest


# --------------------------------------------------------------------------
# Bug 1: Dialect.qualified_name
# --------------------------------------------------------------------------


def _get_dialect(db_type: str):
    """Look up the dialect object for the given provider key."""
    from sqlit.domains.connections.providers.catalog import get_provider

    provider = get_provider(db_type)
    return provider.dialect


def test_mysql_qualified_name_is_two_part() -> None:
    """MySQL has no schema-within-database; the qualified form is
    `db`.`table`, NOT `db`.``.`table`."""
    mysql = _get_dialect("mysql")
    result = mysql.qualified_name("shop", "", "customers")
    assert result == "`shop`.`customers`"


def test_mariadb_qualified_name_is_two_part() -> None:
    """MariaDB mirrors MySQL: a two-part `db`.`table` qualification."""
    result = _get_dialect("mariadb").qualified_name("shop", "", "customers")
    assert result == "`shop`.`customers`"


def test_postgresql_qualified_name_uses_schema_only() -> None:
    """PostgreSQL databases are isolated; only schema.table makes sense
    for cross-reference within the connected database."""
    pg = _get_dialect("postgresql")
    # With no db segment supplied, only schema and name are emitted.
    result = pg.qualified_name(None, "public", "users")
    assert result == '"public"."users"'


def test_sqlserver_qualified_name_is_three_part() -> None:
    """SQL Server explicitly uses [db].[schema].[table]."""
    mssql = _get_dialect("mssql")
    expected = "[app].[dbo].[Users]"
    assert mssql.qualified_name("app", "dbo", "Users") == expected


def test_qualified_name_skips_empty_segments_everywhere() -> None:
    """Regardless of dialect, empty db/schema segments must be omitted,
    never rendered as empty-quoted placeholders."""
    empty_pairs = ("``", '""', "[]")
    for db_type in ("mysql", "postgresql", "mssql", "sqlite"):
        dialect = _get_dialect(db_type)
        # Bare table name: a single quoted segment, no dot joins.
        bare = dialect.qualified_name(None, None, "t")
        assert "t" in bare, f"{db_type}: {bare}"
        assert "." not in bare, f"{db_type} unexpected joins: {bare}"
        # An empty schema must never produce an empty quote pair (``/""/[]).
        out = dialect.qualified_name(None, "", "t")
        for marker in empty_pairs:
            assert marker not in out, f"{db_type} emitted empty-quoted segment: {out}"


def test_qualified_name_escapes_embedded_quote_chars() -> None:
    """Each dialect must escape its own quote char, not leak raw input."""
    cases = (
        ("mysql", "app`evil", "app``evil"),
        ("postgresql", 'app"evil', 'app""evil'),
        ("mssql", "app]evil", "app]]evil"),
    )
    for db_type, payload, expected_substr in cases:
        out = _get_dialect(db_type).qualified_name(None, None, payload)
        assert expected_substr in out, f"{db_type}: {out}"


# --------------------------------------------------------------------------
# Bug 2: stuck Loading... for unknown table references
# --------------------------------------------------------------------------


class _SchemaHost:
"""Minimal stand-in for AutocompleteMixinHost — just enough to drive
`_get_autocomplete_suggestions`."""

def __init__(self, tables, metadata, columns=None, loading=None):
self._schema_cache = {
"tables": tables,
"views": [],
"columns": columns or {},
"procedures": [],
}
self._table_metadata = metadata
self._columns_loading = loading or set()
self.load_calls: list[str] = []

def _load_columns_for_table(self, table_name: str) -> None:
self.load_calls.append(table_name)

def _build_alias_map(self, text: str) -> dict:
return {}


def test_unknown_table_ref_does_not_stick_on_loading() -> None:
    """`SELECT * FROM shop.cu` parses `shop` as an ALIAS_COLUMN scope. It
    isn't a real table. Before the fix the completion engine called
    _load_columns_for_table('shop') and returned `Loading...`; the loader
    skipped unknown keys silently, so the sentinel never cleared."""
    from sqlit.domains.query.ui.mixins.autocomplete_suggestions import AutocompleteSuggestionsMixin

    metadata = {
        "customers": ("", "customers", "shop"),
        "shop.customers": ("", "customers", "shop"),
        "orders": ("", "orders", "shop"),
        "products": ("", "products", "shop"),
    }
    host = _SchemaHost(
        tables=["customers", "orders", "products"],
        metadata=metadata,
    )
    get_suggestions = AutocompleteSuggestionsMixin._get_autocomplete_suggestions.__get__(host)

    query = "SELECT * FROM shop.cu"
    result = get_suggestions(query, len(query))

    assert result != ["Loading..."], (
        "unknown table key must not pin Loading... — loader never clears it"
    )
    assert "shop" not in host.load_calls


def test_known_table_ref_still_triggers_loading_on_first_call() -> None:
    """Sanity check: the fix must not regress legit lazy loading."""
    from sqlit.domains.query.ui.mixins.autocomplete_suggestions import AutocompleteSuggestionsMixin

    host = _SchemaHost(
        tables=["customers"],
        metadata={"customers": ("", "customers", None)},
        columns={},
    )
    get_suggestions = AutocompleteSuggestionsMixin._get_autocomplete_suggestions.__get__(host)

    query = "SELECT * FROM customers WHERE em"
    result = get_suggestions(query, len(query))

    # First touch of a known-but-unloaded table must kick off the loader.
    assert result == ["Loading..."]
    assert "customers" in host.load_calls
Loading