1+ import math
12from typing import Dict , Sequence
23import logging
34
1314 ColType ,
1415 UnknownColType ,
1516)
16- from .base import MD5_HEXDIGITS , CHECKSUM_HEXDIGITS , BaseDialect , Database , import_helper , parse_table_name
17+ from .base import MD5_HEXDIGITS , CHECKSUM_HEXDIGITS , BaseDialect , ThreadedDatabase , import_helper , parse_table_name
1718
1819
1920@import_helper (text = "You can install it using 'pip install databricks-sql-connector'" )
@@ -61,54 +62,57 @@ def normalize_timestamp(self, value: str, coltype: TemporalType) -> str:
6162 return f"date_format({ value } , 'yyyy-MM-dd HH:mm:ss.{ precision_format } ')"
6263
6364 def normalize_number (self , value : str , coltype : NumericType ) -> str :
64- return self .to_string (f"cast({ value } as decimal(38, { coltype .precision } ))" )
65+ value = f"cast({ value } as decimal(38, { coltype .precision } ))"
66+ if coltype .precision > 0 :
67+ value = f"format_number({ value } , { coltype .precision } )"
68+ return f"replace({ self .to_string (value )} , ',', '')"
6569
6670 def _convert_db_precision_to_digits (self , p : int ) -> int :
67- # Subtracting 1 due to wierd precision issues
68- return max (super ()._convert_db_precision_to_digits (p ) - 1 , 0 )
71+ # Subtracting 2 due to wierd precision issues
72+ return max (super ()._convert_db_precision_to_digits (p ) - 2 , 0 )
6973
7074
71- class Databricks (Database ):
75+ class Databricks (ThreadedDatabase ):
7276 dialect = Dialect ()
7377
74- def __init__ (
75- self ,
76- http_path : str ,
77- access_token : str ,
78- server_hostname : str ,
79- catalog : str = "hive_metastore" ,
80- schema : str = "default" ,
81- ** kwargs ,
82- ):
83- databricks = import_databricks ()
84-
85- self ._conn = databricks .sql .connect (
86- server_hostname = server_hostname , http_path = http_path , access_token = access_token
87- )
88-
78+ def __init__ (self , * , thread_count , ** kw ):
8979 logging .getLogger ("databricks.sql" ).setLevel (logging .WARNING )
9080
91- self .catalog = catalog
92- self .default_schema = schema
93- self . kwargs = kwargs
81+ self ._args = kw
82+ self .default_schema = kw . get ( " schema" , "hive_metastore" )
83+ super (). __init__ ( thread_count = thread_count )
9484
95- def _query (self , sql_code : str ) -> list :
96- "Uses the standard SQL cursor interface"
97- return self ._query_conn (self ._conn , sql_code )
85+ def create_connection (self ):
86+ databricks = import_databricks ()
87+
88+ try :
89+ return databricks .sql .connect (
90+ server_hostname = self ._args ["server_hostname" ],
91+ http_path = self ._args ["http_path" ],
92+ access_token = self ._args ["access_token" ],
93+ catalog = self ._args ["catalog" ],
94+ )
95+ except databricks .sql .exc .Error as e :
96+ raise ConnectionError (* e .args ) from e
9897
9998 def query_table_schema (self , path : DbPath ) -> Dict [str , tuple ]:
10099 # Databricks has INFORMATION_SCHEMA only for Databricks Runtime, not for Databricks SQL.
101100 # https://docs.databricks.com/spark/latest/spark-sql/language-manual/information-schema/columns.html
102101 # So, to obtain information about schema, we should use another approach.
103102
103+ conn = self .create_connection ()
104+
104105 schema , table = self ._normalize_table_path (path )
105- with self ._conn .cursor () as cursor :
106- cursor .columns (catalog_name = self .catalog , schema_name = schema , table_name = table )
107- rows = cursor .fetchall ()
106+ with conn .cursor () as cursor :
107+ cursor .columns (catalog_name = self ._args ["catalog" ], schema_name = schema , table_name = table )
108+ try :
109+ rows = cursor .fetchall ()
110+ finally :
111+ conn .close ()
108112 if not rows :
109113 raise RuntimeError (f"{ self .name } : Table '{ '.' .join (path )} ' does not exist, or has no columns" )
110114
111- d = {r .COLUMN_NAME : r for r in rows }
115+ d = {r .COLUMN_NAME : ( r . COLUMN_NAME , r . TYPE_NAME , r . DECIMAL_DIGITS , None , None ) for r in rows }
112116 assert len (d ) == len (rows )
113117 return d
114118
@@ -120,27 +124,26 @@ def _process_table_schema(
120124
121125 resulted_rows = []
122126 for row in rows :
123- row_type = "DECIMAL" if row . DATA_TYPE == 3 else row . TYPE_NAME
124- type_cls = self .TYPE_CLASSES .get (row_type , UnknownColType )
127+ row_type = "DECIMAL" if row [ 1 ]. startswith ( "DECIMAL" ) else row [ 1 ]
128+ type_cls = self .dialect . TYPE_CLASSES .get (row_type , UnknownColType )
125129
126130 if issubclass (type_cls , Integer ):
127- row = (row . COLUMN_NAME , row_type , None , None , 0 )
131+ row = (row [ 0 ] , row_type , None , None , 0 )
128132
129133 elif issubclass (type_cls , Float ):
130- numeric_precision = self . _convert_db_precision_to_digits (row . DECIMAL_DIGITS )
131- row = (row . COLUMN_NAME , row_type , None , numeric_precision , None )
134+ numeric_precision = math . ceil (row [ 2 ] / math . log ( 2 , 10 ) )
135+ row = (row [ 0 ] , row_type , None , numeric_precision , None )
132136
133137 elif issubclass (type_cls , Decimal ):
134- # TYPE_NAME has a format DECIMAL(x,y)
135- items = row .TYPE_NAME [8 :].rstrip (")" ).split ("," )
138+ items = row [1 ][8 :].rstrip (")" ).split ("," )
136139 numeric_precision , numeric_scale = int (items [0 ]), int (items [1 ])
137- row = (row . COLUMN_NAME , row_type , None , numeric_precision , numeric_scale )
140+ row = (row [ 0 ] , row_type , None , numeric_precision , numeric_scale )
138141
139142 elif issubclass (type_cls , Timestamp ):
140- row = (row . COLUMN_NAME , row_type , row . DECIMAL_DIGITS , None , None )
143+ row = (row [ 0 ] , row_type , row [ 2 ] , None , None )
141144
142145 else :
143- row = (row . COLUMN_NAME , row_type , None , None , None )
146+ row = (row [ 0 ] , row_type , None , None , None )
144147
145148 resulted_rows .append (row )
146149
@@ -153,9 +156,6 @@ def parse_table_name(self, name: str) -> DbPath:
153156 path = parse_table_name (name )
154157 return self ._normalize_table_path (path )
155158
156- def close (self ):
157- self ._conn .close ()
158-
    @property
    def is_autocommit(self) -> bool:
        """Always True: Databricks connections are treated as auto-committing."""
        return True
0 commit comments