-import copy
 from io import StringIO
 import json
 from pathlib import Path
 from parameterized import parameterized
 import unittest
-from unittest.mock import MagicMock, Mock, patch
+from unittest.mock import Mock, patch
 
 from data_diff.cloud.datafold_api import (
     TCloudApiDataSourceConfigSchema,
     TCloudApiDataSourceTestResult,
     TCloudDataSourceTestResult,
     TDsConfig,
-    TestDataSourceStatus,
 )
-from data_diff.dbt_parser import DbtParser
 from data_diff.cloud.data_source import (
     TDataSourceTestStage,
     TestDataSourceStatus,
 )
 
 
-DATA_SOURCE_CONFIGS = [
-    TDsConfig(
+DATA_SOURCE_CONFIGS = {
+    "snowflake": TDsConfig(
         name="ds_name",
         type="snowflake",
         options={
         float_tolerance=0.000001,
         temp_schema="database.temp_schema",
     ),
-    TDsConfig(
+    "pg": TDsConfig(
         name="ds_name",
         type="pg",
         options={
         float_tolerance=0.000001,
         temp_schema="database.temp_schema",
     ),
-    TDsConfig(
+    "bigquery": TDsConfig(
         name="ds_name",
         type="bigquery",
         options={
             "projectId": "project_id",
-            "jsonKeyFile": "some_string",
+            "jsonKeyFile": '{"key1": "value1"}',
             "location": "US",
         },
         float_tolerance=0.000001,
         temp_schema="database.temp_schema",
     ),
-    TDsConfig(
+    "databricks": TDsConfig(
         name="ds_name",
         type="databricks",
         options={
         float_tolerance=0.000001,
         temp_schema="database.temp_schema",
     ),
-    TDsConfig(
+    "redshift": TDsConfig(
         name="ds_name",
         type="redshift",
         options={
         float_tolerance=0.000001,
         temp_schema="database.temp_schema",
     ),
-    TDsConfig(
+    "postgres_aurora": TDsConfig(
         name="ds_name",
         type="postgres_aurora",
         options={
         float_tolerance=0.000001,
         temp_schema="database.temp_schema",
     ),
-    TDsConfig(
+    "postgres_aws_rds": TDsConfig(
         name="ds_name",
         type="postgres_aws_rds",
         options={
         float_tolerance=0.000001,
         temp_schema="database.temp_schema",
     ),
-]
+}
 
 
 def format_data_source_config_test(testcase_func, param_num, param):
@@ -145,7 +142,7 @@ def setUp(self) -> None:
         self.api.get_data_source_schema_config.return_value = self.data_source_schema
         self.api.get_data_sources.return_value = self.data_sources
 
-    @parameterized.expand([(c,) for c in DATA_SOURCE_CONFIGS], name_func=format_data_source_config_test)
+    @parameterized.expand([(c,) for c in DATA_SOURCE_CONFIGS.values()], name_func=format_data_source_config_test)
     @patch("data_diff.dbt_parser.DbtParser.__new__")
     def test_get_temp_schema(self, config: TDsConfig, mock_dbt_parser):
         diff_vars = {
@@ -161,7 +158,7 @@ def test_get_temp_schema(self, config: TDsConfig, mock_dbt_parser):
 
         assert _get_temp_schema(dbt_parser=mock_dbt_parser, db_type=config.type) == temp_schema
 
-    @parameterized.expand([(c,) for c in DATA_SOURCE_CONFIGS], name_func=format_data_source_config_test)
+    @parameterized.expand([(c,) for c in DATA_SOURCE_CONFIGS.values()], name_func=format_data_source_config_test)
     def test_create_ds_config(self, config: TDsConfig):
         inputs = list(config.options.values()) + [config.temp_schema, config.float_tolerance]
         with patch("rich.prompt.Console.input", side_effect=map(str, inputs)):
@@ -172,8 +169,8 @@ def test_create_ds_config(self, config: TDsConfig):
         self.assertEqual(actual_config, config)
 
     @patch("data_diff.dbt_parser.DbtParser.__new__")
-    def test_create_ds_config_from_dbt_profiles(self, mock_dbt_parser):
-        config = DATA_SOURCE_CONFIGS[0]
+    def test_create_snowflake_ds_config_from_dbt_profiles(self, mock_dbt_parser):
+        config = DATA_SOURCE_CONFIGS['snowflake']
         mock_dbt_parser.get_connection_creds.return_value = (config.options,)
         with patch("rich.prompt.Console.input", side_effect=["y", config.temp_schema, str(config.float_tolerance)]):
             actual_config = create_ds_config(
@@ -183,11 +180,66 @@ def test_create_ds_config_from_dbt_profiles(self, mock_dbt_parser):
         )
         self.assertEqual(actual_config, config)
 
+    @patch("data_diff.dbt_parser.DbtParser.__new__")
+    def test_create_bigquery_ds_config_dbt_oauth(self, mock_dbt_parser):
+        config = DATA_SOURCE_CONFIGS['bigquery']
+        mock_dbt_parser.get_connection_creds.return_value = (config.options,)
+        with patch("rich.prompt.Console.input", side_effect=["y", config.temp_schema, str(config.float_tolerance)]):
+            actual_config = create_ds_config(
+                ds_config=self.db_type_data_source_schemas[config.type],
+                data_source_name=config.name,
+                dbt_parser=mock_dbt_parser,
+            )
+        self.assertEqual(actual_config, config)
+
+    @patch("data_diff.dbt_parser.DbtParser.__new__")
+    @patch("data_diff.cloud.data_source._get_data_from_bigquery_json")
+    def test_create_bigquery_ds_config_dbt_service_account(self, mock_get_data_from_bigquery_json, mock_dbt_parser):
+        config = DATA_SOURCE_CONFIGS['bigquery']
+
+        mock_get_data_from_bigquery_json.return_value = json.loads(config.options['jsonKeyFile'])
+        mock_dbt_parser.get_connection_creds.return_value = {
+            'type': 'bigquery',
+            'method': 'service-account',
+            'project': config.options['projectId'],
+            'threads': 1,
+            'keyfile': '/some/path'
+        },
+
+        with patch("rich.prompt.Console.input", side_effect=["y", config.options['location'], config.temp_schema, str(config.float_tolerance)]):
+            actual_config = create_ds_config(
+                ds_config=self.db_type_data_source_schemas[config.type],
+                data_source_name=config.name,
+                dbt_parser=mock_dbt_parser,
+            )
+        self.assertEqual(actual_config, config)
+
+    @patch("data_diff.dbt_parser.DbtParser.__new__")
+    def test_create_bigquery_ds_config_dbt_service_account_json(self, mock_dbt_parser):
+        config = DATA_SOURCE_CONFIGS['bigquery']
+
+        mock_dbt_parser.get_connection_creds.return_value = {
+            'type': 'bigquery',
+            'method': 'service-account-json',
+            'project': config.options['projectId'],
+            'threads': 1,
+            'keyfile_json': json.loads(config.options['jsonKeyFile'])
+        },
+
+        with patch("rich.prompt.Console.input", side_effect=["y", config.options['location'], config.temp_schema, str(config.float_tolerance)]):
+            actual_config = create_ds_config(
+                ds_config=self.db_type_data_source_schemas[config.type],
+                data_source_name=config.name,
+                dbt_parser=mock_dbt_parser,
+            )
+        self.assertEqual(actual_config, config)
+
     @patch("sys.stdout", new_callable=StringIO)
     @patch("data_diff.dbt_parser.DbtParser.__new__")
-    def test_create_ds_config_from_dbt_profiles_one_param_passed_through_input(self, mock_dbt_parser, mock_stdout):
-        config = DATA_SOURCE_CONFIGS[0]
-        options = copy.copy(config.options)
+    def test_create_ds_snowflake_config_from_dbt_profiles_one_param_passed_through_input(self, mock_dbt_parser, mock_stdout):
+        config = DATA_SOURCE_CONFIGS['snowflake']
+        options = {**config.options, 'type': 'snowflake'}
+        options['database'] = options.pop('default_db')
         account = options.pop("account")
         mock_dbt_parser.get_connection_creds.return_value = (options,)
         with patch(