1- import copy
21from io import StringIO
32import json
43from pathlib import Path
54from parameterized import parameterized
65import unittest
7- from unittest .mock import MagicMock , Mock , patch
6+ from unittest .mock import Mock , patch
87
98from data_diff .cloud .datafold_api import (
109 TCloudApiDataSourceConfigSchema ,
1312 TCloudApiDataSourceTestResult ,
1413 TCloudDataSourceTestResult ,
1514 TDsConfig ,
16- TestDataSourceStatus ,
1715)
18- from data_diff .dbt_parser import DbtParser
1916from data_diff .cloud .data_source import (
2017 TDataSourceTestStage ,
2118 TestDataSourceStatus ,
2219 create_ds_config ,
2320 _check_data_source_exists ,
21+ _get_temp_schema ,
2422 _test_data_source ,
2523)
2624
2725
28- DATA_SOURCE_CONFIGS = [
29- TDsConfig (
26+ DATA_SOURCE_CONFIGS = {
27+ "snowflake" : TDsConfig (
3028 name = "ds_name" ,
3129 type = "snowflake" ,
3230 options = {
4038 float_tolerance = 0.000001 ,
4139 temp_schema = "database.temp_schema" ,
4240 ),
43- TDsConfig (
41+ "pg" : TDsConfig (
4442 name = "ds_name" ,
4543 type = "pg" ,
4644 options = {
5351 float_tolerance = 0.000001 ,
5452 temp_schema = "database.temp_schema" ,
5553 ),
56- TDsConfig (
54+ "bigquery" : TDsConfig (
5755 name = "ds_name" ,
5856 type = "bigquery" ,
5957 options = {
6058 "projectId" : "project_id" ,
61- "jsonKeyFile" : "some_string" ,
59+ "jsonKeyFile" : '{"key1": "value1"}' ,
6260 "location" : "US" ,
6361 },
6462 float_tolerance = 0.000001 ,
6563 temp_schema = "database.temp_schema" ,
6664 ),
67- TDsConfig (
65+ "databricks" : TDsConfig (
6866 name = "ds_name" ,
6967 type = "databricks" ,
7068 options = {
7674 float_tolerance = 0.000001 ,
7775 temp_schema = "database.temp_schema" ,
7876 ),
79- TDsConfig (
77+ "redshift" : TDsConfig (
8078 name = "ds_name" ,
8179 type = "redshift" ,
8280 options = {
8987 float_tolerance = 0.000001 ,
9088 temp_schema = "database.temp_schema" ,
9189 ),
92- TDsConfig (
90+ "postgres_aurora" : TDsConfig (
9391 name = "ds_name" ,
9492 type = "postgres_aurora" ,
9593 options = {
102100 float_tolerance = 0.000001 ,
103101 temp_schema = "database.temp_schema" ,
104102 ),
105- TDsConfig (
103+ "postgres_aws_rds" : TDsConfig (
106104 name = "ds_name" ,
107105 type = "postgres_aws_rds" ,
108106 options = {
115113 float_tolerance = 0.000001 ,
116114 temp_schema = "database.temp_schema" ,
117115 ),
118- ]
116+ }
119117
120118
121119def format_data_source_config_test (testcase_func , param_num , param ):
@@ -144,7 +142,23 @@ def setUp(self) -> None:
144142 self .api .get_data_source_schema_config .return_value = self .data_source_schema
145143 self .api .get_data_sources .return_value = self .data_sources
146144
147- @parameterized .expand ([(c ,) for c in DATA_SOURCE_CONFIGS ], name_func = format_data_source_config_test )
145+ @parameterized .expand ([(c ,) for c in DATA_SOURCE_CONFIGS .values ()], name_func = format_data_source_config_test )
146+ @patch ("data_diff.dbt_parser.DbtParser.__new__" )
147+ def test_get_temp_schema (self , config : TDsConfig , mock_dbt_parser ):
148+ diff_vars = {
149+ "prod_database" : "db" ,
150+ "prod_schema" : "schema" ,
151+ }
152+ mock_dbt_parser .get_datadiff_variables .return_value = diff_vars
153+ temp_schema = f'{ diff_vars ["prod_database" ]} .{ diff_vars ["prod_schema" ]} '
154+ if config .type == "snowflake" :
155+ temp_schema = temp_schema .upper ()
156+ elif config .type in {"pg" , "postgres_aurora" , "postgres_aws_rds" , "redshift" }:
157+ temp_schema = temp_schema .lower ()
158+
159+ assert _get_temp_schema (dbt_parser = mock_dbt_parser , db_type = config .type ) == temp_schema
160+
161+ @parameterized .expand ([(c ,) for c in DATA_SOURCE_CONFIGS .values ()], name_func = format_data_source_config_test )
148162 def test_create_ds_config (self , config : TDsConfig ):
149163 inputs = list (config .options .values ()) + [config .temp_schema , config .float_tolerance ]
150164 with patch ("rich.prompt.Console.input" , side_effect = map (str , inputs )):
@@ -155,8 +169,8 @@ def test_create_ds_config(self, config: TDsConfig):
155169 self .assertEqual (actual_config , config )
156170
157171 @patch ("data_diff.dbt_parser.DbtParser.__new__" )
158- def test_create_ds_config_from_dbt_profiles (self , mock_dbt_parser ):
159- config = DATA_SOURCE_CONFIGS [0 ]
172+ def test_create_snowflake_ds_config_from_dbt_profiles (self , mock_dbt_parser ):
173+ config = DATA_SOURCE_CONFIGS ["snowflake" ]
160174 mock_dbt_parser .get_connection_creds .return_value = (config .options ,)
161175 with patch ("rich.prompt.Console.input" , side_effect = ["y" , config .temp_schema , str (config .float_tolerance )]):
162176 actual_config = create_ds_config (
@@ -166,11 +180,78 @@ def test_create_ds_config_from_dbt_profiles(self, mock_dbt_parser):
166180 )
167181 self .assertEqual (actual_config , config )
168182
183+ @patch ("data_diff.dbt_parser.DbtParser.__new__" )
184+ def test_create_bigquery_ds_config_dbt_oauth (self , mock_dbt_parser ):
185+ config = DATA_SOURCE_CONFIGS ["bigquery" ]
186+ mock_dbt_parser .get_connection_creds .return_value = (config .options ,)
187+ with patch ("rich.prompt.Console.input" , side_effect = ["y" , config .temp_schema , str (config .float_tolerance )]):
188+ actual_config = create_ds_config (
189+ ds_config = self .db_type_data_source_schemas [config .type ],
190+ data_source_name = config .name ,
191+ dbt_parser = mock_dbt_parser ,
192+ )
193+ self .assertEqual (actual_config , config )
194+
195+ @patch ("data_diff.dbt_parser.DbtParser.__new__" )
196+ @patch ("data_diff.cloud.data_source._get_data_from_bigquery_json" )
197+ def test_create_bigquery_ds_config_dbt_service_account (self , mock_get_data_from_bigquery_json , mock_dbt_parser ):
198+ config = DATA_SOURCE_CONFIGS ["bigquery" ]
199+
200+ mock_get_data_from_bigquery_json .return_value = json .loads (config .options ["jsonKeyFile" ])
201+ mock_dbt_parser .get_connection_creds .return_value = (
202+ {
203+ "type" : "bigquery" ,
204+ "method" : "service-account" ,
205+ "project" : config .options ["projectId" ],
206+ "threads" : 1 ,
207+ "keyfile" : "/some/path" ,
208+ },
209+ )
210+
211+ with patch (
212+ "rich.prompt.Console.input" ,
213+ side_effect = ["y" , config .options ["location" ], config .temp_schema , str (config .float_tolerance )],
214+ ):
215+ actual_config = create_ds_config (
216+ ds_config = self .db_type_data_source_schemas [config .type ],
217+ data_source_name = config .name ,
218+ dbt_parser = mock_dbt_parser ,
219+ )
220+ self .assertEqual (actual_config , config )
221+
222+ @patch ("data_diff.dbt_parser.DbtParser.__new__" )
223+ def test_create_bigquery_ds_config_dbt_service_account_json (self , mock_dbt_parser ):
224+ config = DATA_SOURCE_CONFIGS ["bigquery" ]
225+
226+ mock_dbt_parser .get_connection_creds .return_value = (
227+ {
228+ "type" : "bigquery" ,
229+ "method" : "service-account-json" ,
230+ "project" : config .options ["projectId" ],
231+ "threads" : 1 ,
232+ "keyfile_json" : json .loads (config .options ["jsonKeyFile" ]),
233+ },
234+ )
235+
236+ with patch (
237+ "rich.prompt.Console.input" ,
238+ side_effect = ["y" , config .options ["location" ], config .temp_schema , str (config .float_tolerance )],
239+ ):
240+ actual_config = create_ds_config (
241+ ds_config = self .db_type_data_source_schemas [config .type ],
242+ data_source_name = config .name ,
243+ dbt_parser = mock_dbt_parser ,
244+ )
245+ self .assertEqual (actual_config , config )
246+
169247 @patch ("sys.stdout" , new_callable = StringIO )
170248 @patch ("data_diff.dbt_parser.DbtParser.__new__" )
171- def test_create_ds_config_from_dbt_profiles_one_param_passed_through_input (self , mock_dbt_parser , mock_stdout ):
172- config = DATA_SOURCE_CONFIGS [0 ]
173- options = copy .copy (config .options )
249+ def test_create_ds_snowflake_config_from_dbt_profiles_one_param_passed_through_input (
250+ self , mock_dbt_parser , mock_stdout
251+ ):
252+ config = DATA_SOURCE_CONFIGS ["snowflake" ]
253+ options = {** config .options , "type" : "snowflake" }
254+ options ["database" ] = options .pop ("default_db" )
174255 account = options .pop ("account" )
175256 mock_dbt_parser .get_connection_creds .return_value = (options ,)
176257 with patch (
0 commit comments