This repository was archived by the owner on May 17, 2024. It is now read-only.

Commit 0fbc4cf

Merge pull request #133 from datafold/fix_uuids

Fix UUIDs + small fix for presto

2 parents: bb12581 + c01ed65

File tree (7 files changed: 51 additions & 33 deletions)

  data_diff/databases/base.py
  data_diff/databases/mysql.py
  data_diff/databases/presto.py
  data_diff/diff_tables.py
  data_diff/sql.py
  tests/common.py
  tests/test_database_types.py

data_diff/databases/base.py

Lines changed: 0 additions & 2 deletions

@@ -1,4 +1,3 @@
-from uuid import UUID
 import math
 import sys
 import logging
@@ -16,7 +15,6 @@
     Integer,
     Decimal,
     Float,
-    PrecisionType,
     TemporalType,
     UnknownColType,
     Text,

data_diff/databases/mysql.py

Lines changed: 1 addition & 1 deletion

@@ -70,4 +70,4 @@ def normalize_number(self, value: str, coltype: FractionalType) -> str:
         return self.to_string(f"cast({value} as decimal(38, {coltype.precision}))")

     def normalize_uuid(self, value: str, coltype: ColType_UUID) -> str:
-        return f"CAST(TRIM({value}) AS char)"
+        return f"TRIM(CAST({value} AS char))"

data_diff/databases/presto.py

Lines changed: 1 addition & 0 deletions

@@ -27,6 +27,7 @@ class Presto(Database):
         "timestamp": Timestamp,
         # Numbers
         "integer": Integer,
+        "bigint": Integer,
         "real": Float,
         "double": Float,
         # Text

data_diff/diff_tables.py

Lines changed: 31 additions & 21 deletions

@@ -79,19 +79,27 @@ def __post_init__(self):
         if self.min_update is not None and self.max_update is not None and self.min_update >= self.max_update:
             raise ValueError("Error: min_update expected to be smaller than max_update!")

-    @property
-    def _key_column(self):
-        return self._quote_column(self.key_column)
-
     @property
     def _update_column(self):
         return self._quote_column(self.update_column)

-    def _quote_column(self, c):
+    def _quote_column(self, c: str) -> str:
         if self._schema:
             c = self._schema.get_key(c)  # Get the actual name. Might be case-insensitive.
         return self.database.quote(c)

+    def _normalize_column(self, name: str, template: str = None) -> str:
+        if not self._schema:
+            raise RuntimeError(
+                "Cannot compile query when the schema is unknown. Please use TableSegment.with_schema()."
+            )
+
+        col = self._quote_column(name)
+        if template is not None:
+            col = template % col  # Apply template using Python's string formatting
+
+        return self.database.normalize_value_by_type(col, self._schema[name])
+
     def with_schema(self) -> "TableSegment":
         "Queries the table schema from the database, and returns a new instance of TableSegmentWithSchema."
         if self._schema:
@@ -115,9 +123,9 @@ def with_schema(self) -> "TableSegment":

     def _make_key_range(self):
         if self.min_key is not None:
-            yield Compare("<=", Value(self.min_key), self._key_column)
+            yield Compare("<=", Value(self.min_key), self._quote_column(self.key_column))
         if self.max_key is not None:
-            yield Compare("<", self._key_column, Value(self.max_key))
+            yield Compare("<", self._quote_column(self.key_column), Value(self.max_key))

     def _make_update_range(self):
         if self.min_update is not None:
@@ -127,7 +135,7 @@ def _make_update_range(self):

     def _make_select(self, *, table=None, columns=None, where=None, group_by=None, order_by=None):
         if columns is None:
-            columns = [self._key_column]
+            columns = [self._normalize_column(self.key_column)]
         where = list(self._make_key_range()) + list(self._make_update_range()) + ([] if where is None else [where])
         order_by = None if order_by is None else [order_by]
         return Select(
@@ -184,14 +192,7 @@ def _relevant_columns(self) -> List[str]:

     @property
     def _relevant_columns_repr(self) -> List[str]:
-        if not self._schema:
-            raise RuntimeError(
-                "Cannot compile query when the schema is unknown. Please use TableSegment.with_schema()."
-            )
-        return [
-            self.database.normalize_value_by_type(self._quote_column(c), self._schema[c])
-            for c in self._relevant_columns
-        ]
+        return [self._normalize_column(c) for c in self._relevant_columns]

     def count(self) -> Tuple[int, int]:
         """Count how many rows are in the segment, in one pass."""
@@ -214,7 +215,13 @@ def count_and_checksum(self) -> Tuple[int, int]:

     def query_key_range(self) -> Tuple[int, int]:
         """Query database for minimum and maximum key. This is used for setting the initial bounds."""
-        select = self._make_select(columns=[Min(self._key_column), Max(self._key_column)])
+        # Normalizes the result (needed for UUIDs) after the min/max computation
+        select = self._make_select(
+            columns=[
+                self._normalize_column(self.key_column, "min(%s)"),
+                self._normalize_column(self.key_column, "max(%s)"),
+            ]
+        )
         min_key, max_key = self.database.query(select, tuple)

         if min_key is None or max_key is None:
@@ -296,13 +303,16 @@ def diff_tables(self, table1: TableSegment, table2: TableSegment) -> DiffResult:
         key_ranges = self._threaded_call("query_key_range", [table1, table2])
         mins, maxs = zip(*key_ranges)

-        key_type = table1._schema["id"]
-        key_type2 = table2._schema["id"]
+        key_type = table1._schema[table1.key_column]
+        key_type2 = table2._schema[table2.key_column]
         assert key_type.python_type is key_type2.python_type

         # We add 1 because our ranges are exclusive of the end (like in Python)
-        min_key = min(map(key_type.python_type, mins))
-        max_key = max(map(key_type.python_type, maxs)) + 1
+        try:
+            min_key = min(map(key_type.python_type, mins))
+            max_key = max(map(key_type.python_type, maxs)) + 1
+        except (TypeError, ValueError) as e:
+            raise type(e)(f"Cannot apply {key_type} to {mins}, {maxs}.") from e

         table1 = table1.new(min_key=min_key, max_key=max_key)
         table2 = table2.new(min_key=min_key, max_key=max_key)
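
To make the new mechanism concrete, here is a minimal self-contained sketch of how _normalize_column composes quoting, an optional SQL template, and type normalization. The quote and normalize_uuid functions below are simplified stand-ins for the Database methods (shown here with the MySQL UUID normalization from this commit), not the real data_diff implementations:

# Simplified stand-ins for database.quote() and normalize_value_by_type().
def quote(name: str) -> str:
    return f"`{name}`"

def normalize_uuid(expr: str) -> str:
    return f"TRIM(CAST({expr} AS char))"

def normalize_column(name: str, template: str = None) -> str:
    col = quote(name)
    if template is not None:
        col = template % col  # e.g. "min(%s)" wraps the quoted column in an aggregate
    return normalize_uuid(col)

print(normalize_column("id"))             # TRIM(CAST(`id` AS char))
print(normalize_column("id", "min(%s)"))  # TRIM(CAST(min(`id`) AS char))
print(normalize_column("id", "max(%s)"))  # TRIM(CAST(max(`id`) AS char))

This is why query_key_range passes "min(%s)" and "max(%s)": the aggregate is applied to the raw column, and the normalization (needed for UUID keys) is applied to the aggregated result.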

data_diff/sql.py

Lines changed: 1 addition & 1 deletion

@@ -124,7 +124,7 @@ def compile(self, c: Compiler):
             compiled_exprs = ", ".join(map(c.compile, self.exprs))
             expr = f"concat({compiled_exprs})"
         else:
-            expr ,= self.exprs
+            (expr,) = self.exprs
         expr = c.compile(expr)
         md5 = c.database.md5_to_int(expr)
         return f"sum({md5})"

tests/common.py

Lines changed: 4 additions & 2 deletions

@@ -17,10 +17,12 @@
 N_SAMPLES = int(os.environ.get("N_SAMPLES", DEFAULT_N_SAMPLES))
 BENCHMARK = os.environ.get("BENCHMARK", False)

+
 def get_git_revision_short_hash() -> str:
-    return subprocess.check_output(['git', 'rev-parse', '--short', 'HEAD']).decode('ascii').strip()
+    return subprocess.check_output(["git", "rev-parse", "--short", "HEAD"]).decode("ascii").strip()
+

-GIT_REVISION=get_git_revision_short_hash()
+GIT_REVISION = get_git_revision_short_hash()

 level = logging.ERROR
 if os.environ.get("LOG_LEVEL", False):

tests/test_database_types.py

Lines changed: 13 additions & 6 deletions

@@ -186,7 +186,7 @@ def __iter__(self):
     "int": [
         # "smallint",  # 2 bytes
         "int",  # 4 bytes
-        "bigint", # 8 bytes
+        "bigint",  # 8 bytes
     ],
     # https://www.postgresql.org/docs/current/datatype-datetime.html
     "datetime": [
@@ -214,7 +214,7 @@ def __iter__(self):
         # "smallint",  # 2 bytes
         # "mediumint",  # 3 bytes
         "int",  # 4 bytes
-        "bigint", # 8 bytes
+        "bigint",  # 8 bytes
     ],
     # https://dev.mysql.com/doc/refman/8.0/en/datetime.html
     "datetime": [
@@ -327,7 +327,7 @@ def __iter__(self):
         # "smallint",  # 2 bytes
         # "mediumint",  # 3 bytes
         "int",  # 4 bytes
-        "bigint", # 8 bytes
+        "bigint",  # 8 bytes
     ],
     "datetime": [
         "timestamp",
@@ -548,8 +548,12 @@ def test_types(self, source_db, target_db, source_type, target_type, type_catego
         _insert_to_table(dst_conn, dst_table, values_in_source, target_type)
         insertion_target_duration = time.time() - start

-        self.table = TableSegment(self.src_conn, src_table_path, "id", None, ("col",), case_sensitive=False)
-        self.table2 = TableSegment(self.dst_conn, dst_table_path, "id", None, ("col",), case_sensitive=False)
+        if type_category == "uuid":
+            self.table = TableSegment(self.src_conn, src_table_path, "col", None, ("id",), case_sensitive=False)
+            self.table2 = TableSegment(self.dst_conn, dst_table_path, "col", None, ("id",), case_sensitive=False)
+        else:
+            self.table = TableSegment(self.src_conn, src_table_path, "id", None, ("col",), case_sensitive=False)
+            self.table2 = TableSegment(self.dst_conn, dst_table_path, "id", None, ("col",), case_sensitive=False)

         start = time.time()
         self.assertEqual(N_SAMPLES, self.table.count())
@@ -595,7 +599,10 @@ def test_types(self, source_db, target_db, source_type, target_type, type_catego
         download_duration = time.time() - start
         expected = []
         self.assertEqual(expected, diff)
-        self.assertEqual(len(sample_values), differ.stats.get("rows_downloaded", 0))
+        if type_category == "uuid":
+            pass  # UUIDs aren't serial, so they mess with the first max_rows estimation.
+        else:
+            self.assertEqual(len(sample_values), differ.stats.get("rows_downloaded", 0))

         result = {
             "test": self._testMethodName,
