1313from data_diff .queries import Expr , Compiler , table , Select , SKIP , Explain
1414from .database_types import (
1515 AbstractDatabase ,
16+ AbstractDialect ,
17+ AbstractMixin_MD5 ,
18+ AbstractMixin_NormalizeValue ,
1619 ColType ,
1720 Integer ,
1821 Decimal ,
@@ -99,6 +102,116 @@ def apply_query(callback: Callable[[str], Any], sql_code: Union[str, ThreadLocal
99102 return callback (sql_code )
100103
101104
105+ class BaseDialect (AbstractDialect , AbstractMixin_MD5 , AbstractMixin_NormalizeValue ):
106+ SUPPORTS_PRIMARY_KEY = False
107+ TYPE_CLASSES : Dict [str , type ] = {}
108+
109+ def offset_limit (self , offset : Optional [int ] = None , limit : Optional [int ] = None ):
110+ if offset :
111+ raise NotImplementedError ("No support for OFFSET in query" )
112+
113+ return f"LIMIT { limit } "
114+
115+ def concat (self , items : List [str ]) -> str :
116+ assert len (items ) > 1
117+ joined_exprs = ", " .join (items )
118+ return f"concat({ joined_exprs } )"
119+
120+ def is_distinct_from (self , a : str , b : str ) -> str :
121+ return f"{ a } is distinct from { b } "
122+
123+ def timestamp_value (self , t : DbTime ) -> str :
124+ return f"'{ t .isoformat ()} '"
125+
126+ def normalize_uuid (self , value : str , coltype : ColType_UUID ) -> str :
127+ if isinstance (coltype , String_UUID ):
128+ return f"TRIM({ value } )"
129+ return self .to_string (value )
130+
131+ def random (self ) -> str :
132+ return "RANDOM()"
133+
134+ def explain_as_text (self , query : str ) -> str :
135+ return f"EXPLAIN { query } "
136+
137+ def _constant_value (self , v ):
138+ if v is None :
139+ return "NULL"
140+ elif isinstance (v , str ):
141+ return f"'{ v } '"
142+ elif isinstance (v , datetime ):
143+ # TODO use self.timestamp_value
144+ return f"timestamp '{ v } '"
145+ elif isinstance (v , UUID ):
146+ return f"'{ v } '"
147+ return repr (v )
148+
149+ def constant_values (self , rows ) -> str :
150+ values = ", " .join ("(%s)" % ", " .join (self ._constant_value (v ) for v in row ) for row in rows )
151+ return f"VALUES { values } "
152+
153+ def type_repr (self , t ) -> str :
154+ if isinstance (t , str ):
155+ return t
156+ return {
157+ int : "INT" ,
158+ str : "VARCHAR" ,
159+ bool : "BOOLEAN" ,
160+ float : "FLOAT" ,
161+ datetime : "TIMESTAMP" ,
162+ }[t ]
163+
164+ def _parse_type_repr (self , type_repr : str ) -> Optional [Type [ColType ]]:
165+ return self .TYPE_CLASSES .get (type_repr )
166+
167+ def parse_type (
168+ self ,
169+ table_path : DbPath ,
170+ col_name : str ,
171+ type_repr : str ,
172+ datetime_precision : int = None ,
173+ numeric_precision : int = None ,
174+ numeric_scale : int = None ,
175+ ) -> ColType :
176+ """ """
177+
178+ cls = self ._parse_type_repr (type_repr )
179+ if not cls :
180+ return UnknownColType (type_repr )
181+
182+ if issubclass (cls , TemporalType ):
183+ return cls (
184+ precision = datetime_precision if datetime_precision is not None else DEFAULT_DATETIME_PRECISION ,
185+ rounds = self .ROUNDS_ON_PREC_LOSS ,
186+ )
187+
188+ elif issubclass (cls , Integer ):
189+ return cls ()
190+
191+ elif issubclass (cls , Decimal ):
192+ if numeric_scale is None :
193+ numeric_scale = 0 # Needed for Oracle.
194+ return cls (precision = numeric_scale )
195+
196+ elif issubclass (cls , Float ):
197+ # assert numeric_scale is None
198+ return cls (
199+ precision = self ._convert_db_precision_to_digits (
200+ numeric_precision if numeric_precision is not None else DEFAULT_NUMERIC_PRECISION
201+ )
202+ )
203+
204+ elif issubclass (cls , (Text , Native_UUID )):
205+ return cls ()
206+
207+ raise TypeError (f"Parsing { type_repr } returned an unknown type '{ cls } '." )
208+
209+ def _convert_db_precision_to_digits (self , p : int ) -> int :
210+ """Convert from binary precision, used by floats, to decimal precision."""
211+ # See: https://en.wikipedia.org/wiki/Single-precision_floating-point_format
212+ return math .floor (math .log (2 ** p , 10 ))
213+
214+
102215class Database (AbstractDatabase ):
103216 """Base abstract class for databases.
104217
@@ -107,10 +220,10 @@ class Database(AbstractDatabase):
107220 Instanciated using :meth:`~data_diff.connect`
108221 """
109222
110- TYPE_CLASSES : Dict [str , type ] = {}
111223 default_schema : str = None
224+ dialect : AbstractDialect = None
225+
112226 SUPPORTS_ALPHANUMS = True
113- SUPPORTS_PRIMARY_KEY = False
114227 SUPPORTS_UNIQUE_CONSTAINT = False
115228
116229 _interactive = False
@@ -169,56 +282,6 @@ def query(self, sql_ast: Union[Expr, Generator], res_type: type = list):
169282 def enable_interactive (self ):
170283 self ._interactive = True
171284
172- def _convert_db_precision_to_digits (self , p : int ) -> int :
173- """Convert from binary precision, used by floats, to decimal precision."""
174- # See: https://en.wikipedia.org/wiki/Single-precision_floating-point_format
175- return math .floor (math .log (2 ** p , 10 ))
176-
177- def _parse_type_repr (self , type_repr : str ) -> Optional [Type [ColType ]]:
178- return self .TYPE_CLASSES .get (type_repr )
179-
180- def _parse_type (
181- self ,
182- table_path : DbPath ,
183- col_name : str ,
184- type_repr : str ,
185- datetime_precision : int = None ,
186- numeric_precision : int = None ,
187- numeric_scale : int = None ,
188- ) -> ColType :
189- """ """
190-
191- cls = self ._parse_type_repr (type_repr )
192- if not cls :
193- return UnknownColType (type_repr )
194-
195- if issubclass (cls , TemporalType ):
196- return cls (
197- precision = datetime_precision if datetime_precision is not None else DEFAULT_DATETIME_PRECISION ,
198- rounds = self .ROUNDS_ON_PREC_LOSS ,
199- )
200-
201- elif issubclass (cls , Integer ):
202- return cls ()
203-
204- elif issubclass (cls , Decimal ):
205- if numeric_scale is None :
206- numeric_scale = 0 # Needed for Oracle.
207- return cls (precision = numeric_scale )
208-
209- elif issubclass (cls , Float ):
210- # assert numeric_scale is None
211- return cls (
212- precision = self ._convert_db_precision_to_digits (
213- numeric_precision if numeric_precision is not None else DEFAULT_NUMERIC_PRECISION
214- )
215- )
216-
217- elif issubclass (cls , (Text , Native_UUID )):
218- return cls ()
219-
220- raise TypeError (f"Parsing { type_repr } returned an unknown type '{ cls } '." )
221-
222285 def select_table_schema (self , path : DbPath ) -> str :
223286 schema , table = self ._normalize_table_path (path )
224287
@@ -257,7 +320,9 @@ def _process_table_schema(
257320 ):
258321 accept = {i .lower () for i in filter_columns }
259322
260- col_dict = {row [0 ]: self ._parse_type (path , * row ) for name , row in raw_schema .items () if name .lower () in accept }
323+ col_dict = {
324+ row [0 ]: self .dialect .parse_type (path , * row ) for name , row in raw_schema .items () if name .lower () in accept
325+ }
261326
262327 self ._refine_coltypes (path , col_dict , where )
263328
@@ -274,7 +339,7 @@ def _refine_coltypes(self, table_path: DbPath, col_dict: Dict[str, ColType], whe
274339 if not text_columns :
275340 return
276341
277- fields = [self .normalize_uuid (self .quote (c ), String_UUID ()) for c in text_columns ]
342+ fields = [self .dialect . normalize_uuid (self . dialect .quote (c ), String_UUID ()) for c in text_columns ]
278343 samples_by_row = self .query (table (* table_path ).select (* fields ).where (where or SKIP ).limit (sample_size ), list )
279344 if not samples_by_row :
280345 raise ValueError (f"Table { table_path } is empty." )
@@ -321,58 +386,6 @@ def _normalize_table_path(self, path: DbPath) -> DbPath:
321386 def parse_table_name (self , name : str ) -> DbPath :
322387 return parse_table_name (name )
323388
324- def offset_limit (self , offset : Optional [int ] = None , limit : Optional [int ] = None ):
325- if offset :
326- raise NotImplementedError ("No support for OFFSET in query" )
327-
328- return f"LIMIT { limit } "
329-
330- def concat (self , items : List [str ]) -> str :
331- assert len (items ) > 1
332- joined_exprs = ", " .join (items )
333- return f"concat({ joined_exprs } )"
334-
335- def is_distinct_from (self , a : str , b : str ) -> str :
336- return f"{ a } is distinct from { b } "
337-
338- def timestamp_value (self , t : DbTime ) -> str :
339- return f"'{ t .isoformat ()} '"
340-
341- def normalize_uuid (self , value : str , coltype : ColType_UUID ) -> str :
342- if isinstance (coltype , String_UUID ):
343- return f"TRIM({ value } )"
344- return self .to_string (value )
345-
346- def random (self ) -> str :
347- return "RANDOM()"
348-
349- def _constant_value (self , v ):
350- if v is None :
351- return "NULL"
352- elif isinstance (v , str ):
353- return f"'{ v } '"
354- elif isinstance (v , datetime ):
355- # TODO use self.timestamp_value
356- return f"timestamp '{ v } '"
357- elif isinstance (v , UUID ):
358- return f"'{ v } '"
359- return repr (v )
360-
361- def constant_values (self , rows ) -> str :
362- values = ", " .join ("(%s)" % ", " .join (self ._constant_value (v ) for v in row ) for row in rows )
363- return f"VALUES { values } "
364-
365- def type_repr (self , t ) -> str :
366- if isinstance (t , str ):
367- return t
368- return {
369- int : "INT" ,
370- str : "VARCHAR" ,
371- bool : "BOOLEAN" ,
372- float : "FLOAT" ,
373- datetime : "TIMESTAMP" ,
374- }[t ]
375-
376389 def _query_cursor (self , c , sql_code : str ):
377390 assert isinstance (sql_code , str ), sql_code
378391 try :
@@ -389,9 +402,6 @@ def _query_conn(self, conn, sql_code: Union[str, ThreadLocalInterpreter]) -> lis
389402 callback = partial (self ._query_cursor , c )
390403 return apply_query (callback , sql_code )
391404
392- def explain_as_text (self , query : str ) -> str :
393- return f"EXPLAIN { query } "
394-
395405
396406class ThreadedDatabase (Database ):
397407 """Access the database through singleton threads.
0 commit comments