Skip to content
This repository was archived by the owner on May 17, 2024. It is now read-only.

Commit 735f523

Browse files
authored
Merge pull request #271 from datafold/refactor_dialect
Refactor dialect
2 parents e07bb90 + b1a1453 commit 735f523

24 files changed

+710
-569
lines changed

data_diff/config.py

Lines changed: 3 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -39,7 +39,9 @@ def _apply_config(config: Dict[str, Any], run_name: str, kw: Dict[str, Any]):
3939
try:
4040
args = run_args.pop(index)
4141
except KeyError:
42-
raise ConfigParseError(f"Could not find source #{index}: Expecting a key of '{index}' containing '.database' and '.table'.")
42+
raise ConfigParseError(
43+
f"Could not find source #{index}: Expecting a key of '{index}' containing '.database' and '.table'."
44+
)
4345
for attr in ("database", "table"):
4446
if attr not in args:
4547
raise ConfigParseError(f"Running 'run.{run_name}': Connection #{index} is missing attribute '{attr}'.")

data_diff/databases/base.py

Lines changed: 119 additions & 109 deletions
Original file line number | Diff line number | Diff line change
@@ -13,6 +13,9 @@
1313
from data_diff.queries import Expr, Compiler, table, Select, SKIP, Explain
1414
from .database_types import (
1515
AbstractDatabase,
16+
AbstractDialect,
17+
AbstractMixin_MD5,
18+
AbstractMixin_NormalizeValue,
1619
ColType,
1720
Integer,
1821
Decimal,
@@ -99,6 +102,116 @@ def apply_query(callback: Callable[[str], Any], sql_code: Union[str, ThreadLocal
99102
return callback(sql_code)
100103

101104

105+
class BaseDialect(AbstractDialect, AbstractMixin_MD5, AbstractMixin_NormalizeValue):
106+
SUPPORTS_PRIMARY_KEY = False
107+
TYPE_CLASSES: Dict[str, type] = {}
108+
109+
def offset_limit(self, offset: Optional[int] = None, limit: Optional[int] = None):
110+
if offset:
111+
raise NotImplementedError("No support for OFFSET in query")
112+
113+
return f"LIMIT {limit}"
114+
115+
def concat(self, items: List[str]) -> str:
116+
assert len(items) > 1
117+
joined_exprs = ", ".join(items)
118+
return f"concat({joined_exprs})"
119+
120+
def is_distinct_from(self, a: str, b: str) -> str:
121+
return f"{a} is distinct from {b}"
122+
123+
def timestamp_value(self, t: DbTime) -> str:
124+
return f"'{t.isoformat()}'"
125+
126+
def normalize_uuid(self, value: str, coltype: ColType_UUID) -> str:
127+
if isinstance(coltype, String_UUID):
128+
return f"TRIM({value})"
129+
return self.to_string(value)
130+
131+
def random(self) -> str:
132+
return "RANDOM()"
133+
134+
def explain_as_text(self, query: str) -> str:
135+
return f"EXPLAIN {query}"
136+
137+
def _constant_value(self, v):
138+
if v is None:
139+
return "NULL"
140+
elif isinstance(v, str):
141+
return f"'{v}'"
142+
elif isinstance(v, datetime):
143+
# TODO use self.timestamp_value
144+
return f"timestamp '{v}'"
145+
elif isinstance(v, UUID):
146+
return f"'{v}'"
147+
return repr(v)
148+
149+
def constant_values(self, rows) -> str:
150+
values = ", ".join("(%s)" % ", ".join(self._constant_value(v) for v in row) for row in rows)
151+
return f"VALUES {values}"
152+
153+
def type_repr(self, t) -> str:
154+
if isinstance(t, str):
155+
return t
156+
return {
157+
int: "INT",
158+
str: "VARCHAR",
159+
bool: "BOOLEAN",
160+
float: "FLOAT",
161+
datetime: "TIMESTAMP",
162+
}[t]
163+
164+
def _parse_type_repr(self, type_repr: str) -> Optional[Type[ColType]]:
165+
return self.TYPE_CLASSES.get(type_repr)
166+
167+
def parse_type(
168+
self,
169+
table_path: DbPath,
170+
col_name: str,
171+
type_repr: str,
172+
datetime_precision: int = None,
173+
numeric_precision: int = None,
174+
numeric_scale: int = None,
175+
) -> ColType:
176+
""" """
177+
178+
cls = self._parse_type_repr(type_repr)
179+
if not cls:
180+
return UnknownColType(type_repr)
181+
182+
if issubclass(cls, TemporalType):
183+
return cls(
184+
precision=datetime_precision if datetime_precision is not None else DEFAULT_DATETIME_PRECISION,
185+
rounds=self.ROUNDS_ON_PREC_LOSS,
186+
)
187+
188+
elif issubclass(cls, Integer):
189+
return cls()
190+
191+
elif issubclass(cls, Decimal):
192+
if numeric_scale is None:
193+
numeric_scale = 0 # Needed for Oracle.
194+
return cls(precision=numeric_scale)
195+
196+
elif issubclass(cls, Float):
197+
# assert numeric_scale is None
198+
return cls(
199+
precision=self._convert_db_precision_to_digits(
200+
numeric_precision if numeric_precision is not None else DEFAULT_NUMERIC_PRECISION
201+
)
202+
)
203+
204+
elif issubclass(cls, (Text, Native_UUID)):
205+
return cls()
206+
207+
raise TypeError(f"Parsing {type_repr} returned an unknown type '{cls}'.")
208+
209+
def _convert_db_precision_to_digits(self, p: int) -> int:
210+
"""Convert from binary precision, used by floats, to decimal precision."""
211+
# See: https://en.wikipedia.org/wiki/Single-precision_floating-point_format
212+
return math.floor(math.log(2**p, 10))
213+
214+
102215
class Database(AbstractDatabase):
103216
"""Base abstract class for databases.
104217
@@ -107,10 +220,10 @@ class Database(AbstractDatabase):
107220
Instanciated using :meth:`~data_diff.connect`
108221
"""
109222

110-
TYPE_CLASSES: Dict[str, type] = {}
111223
default_schema: str = None
224+
dialect: AbstractDialect = None
225+
112226
SUPPORTS_ALPHANUMS = True
113-
SUPPORTS_PRIMARY_KEY = False
114227
SUPPORTS_UNIQUE_CONSTAINT = False
115228

116229
_interactive = False
@@ -169,56 +282,6 @@ def query(self, sql_ast: Union[Expr, Generator], res_type: type = list):
169282
def enable_interactive(self):
170283
self._interactive = True
171284

172-
def _convert_db_precision_to_digits(self, p: int) -> int:
173-
"""Convert from binary precision, used by floats, to decimal precision."""
174-
# See: https://en.wikipedia.org/wiki/Single-precision_floating-point_format
175-
return math.floor(math.log(2**p, 10))
176-
177-
def _parse_type_repr(self, type_repr: str) -> Optional[Type[ColType]]:
178-
return self.TYPE_CLASSES.get(type_repr)
179-
180-
def _parse_type(
181-
self,
182-
table_path: DbPath,
183-
col_name: str,
184-
type_repr: str,
185-
datetime_precision: int = None,
186-
numeric_precision: int = None,
187-
numeric_scale: int = None,
188-
) -> ColType:
189-
""" """
190-
191-
cls = self._parse_type_repr(type_repr)
192-
if not cls:
193-
return UnknownColType(type_repr)
194-
195-
if issubclass(cls, TemporalType):
196-
return cls(
197-
precision=datetime_precision if datetime_precision is not None else DEFAULT_DATETIME_PRECISION,
198-
rounds=self.ROUNDS_ON_PREC_LOSS,
199-
)
200-
201-
elif issubclass(cls, Integer):
202-
return cls()
203-
204-
elif issubclass(cls, Decimal):
205-
if numeric_scale is None:
206-
numeric_scale = 0 # Needed for Oracle.
207-
return cls(precision=numeric_scale)
208-
209-
elif issubclass(cls, Float):
210-
# assert numeric_scale is None
211-
return cls(
212-
precision=self._convert_db_precision_to_digits(
213-
numeric_precision if numeric_precision is not None else DEFAULT_NUMERIC_PRECISION
214-
)
215-
)
216-
217-
elif issubclass(cls, (Text, Native_UUID)):
218-
return cls()
219-
220-
raise TypeError(f"Parsing {type_repr} returned an unknown type '{cls}'.")
221-
222285
def select_table_schema(self, path: DbPath) -> str:
223286
schema, table = self._normalize_table_path(path)
224287

@@ -257,7 +320,9 @@ def _process_table_schema(
257320
):
258321
accept = {i.lower() for i in filter_columns}
259322

260-
col_dict = {row[0]: self._parse_type(path, *row) for name, row in raw_schema.items() if name.lower() in accept}
323+
col_dict = {
324+
row[0]: self.dialect.parse_type(path, *row) for name, row in raw_schema.items() if name.lower() in accept
325+
}
261326

262327
self._refine_coltypes(path, col_dict, where)
263328

@@ -274,7 +339,7 @@ def _refine_coltypes(self, table_path: DbPath, col_dict: Dict[str, ColType], whe
274339
if not text_columns:
275340
return
276341

277-
fields = [self.normalize_uuid(self.quote(c), String_UUID()) for c in text_columns]
342+
fields = [self.dialect.normalize_uuid(self.dialect.quote(c), String_UUID()) for c in text_columns]
278343
samples_by_row = self.query(table(*table_path).select(*fields).where(where or SKIP).limit(sample_size), list)
279344
if not samples_by_row:
280345
raise ValueError(f"Table {table_path} is empty.")
@@ -321,58 +386,6 @@ def _normalize_table_path(self, path: DbPath) -> DbPath:
321386
def parse_table_name(self, name: str) -> DbPath:
322387
return parse_table_name(name)
323388

324-
def offset_limit(self, offset: Optional[int] = None, limit: Optional[int] = None):
325-
if offset:
326-
raise NotImplementedError("No support for OFFSET in query")
327-
328-
return f"LIMIT {limit}"
329-
330-
def concat(self, items: List[str]) -> str:
331-
assert len(items) > 1
332-
joined_exprs = ", ".join(items)
333-
return f"concat({joined_exprs})"
334-
335-
def is_distinct_from(self, a: str, b: str) -> str:
336-
return f"{a} is distinct from {b}"
337-
338-
def timestamp_value(self, t: DbTime) -> str:
339-
return f"'{t.isoformat()}'"
340-
341-
def normalize_uuid(self, value: str, coltype: ColType_UUID) -> str:
342-
if isinstance(coltype, String_UUID):
343-
return f"TRIM({value})"
344-
return self.to_string(value)
345-
346-
def random(self) -> str:
347-
return "RANDOM()"
348-
349-
def _constant_value(self, v):
350-
if v is None:
351-
return "NULL"
352-
elif isinstance(v, str):
353-
return f"'{v}'"
354-
elif isinstance(v, datetime):
355-
# TODO use self.timestamp_value
356-
return f"timestamp '{v}'"
357-
elif isinstance(v, UUID):
358-
return f"'{v}'"
359-
return repr(v)
360-
361-
def constant_values(self, rows) -> str:
362-
values = ", ".join("(%s)" % ", ".join(self._constant_value(v) for v in row) for row in rows)
363-
return f"VALUES {values}"
364-
365-
def type_repr(self, t) -> str:
366-
if isinstance(t, str):
367-
return t
368-
return {
369-
int: "INT",
370-
str: "VARCHAR",
371-
bool: "BOOLEAN",
372-
float: "FLOAT",
373-
datetime: "TIMESTAMP",
374-
}[t]
375-
376389
def _query_cursor(self, c, sql_code: str):
377390
assert isinstance(sql_code, str), sql_code
378391
try:
@@ -389,9 +402,6 @@ def _query_conn(self, conn, sql_code: Union[str, ThreadLocalInterpreter]) -> lis
389402
callback = partial(self._query_cursor, c)
390403
return apply_query(callback, sql_code)
391404

392-
def explain_as_text(self, query: str) -> str:
393-
return f"EXPLAIN {query}"
394-
395405

396406
class ThreadedDatabase(Database):
397407
"""Access the database through singleton threads.

0 commit comments

Comments
 (0)