Skip to content
This repository was archived by the owner on May 17, 2024. It is now read-only.

Commit 3d9a079

Browse files
committed
update unit tests
1 parent 3c0ec0f commit 3d9a079

File tree

1 file changed

+73
-21
lines changed

1 file changed

+73
-21
lines changed

tests/cloud/test_data_source.py

Lines changed: 73 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,9 @@
1-
import copy
21
from io import StringIO
32
import json
43
from pathlib import Path
54
from parameterized import parameterized
65
import unittest
7-
from unittest.mock import MagicMock, Mock, patch
6+
from unittest.mock import Mock, patch
87

98
from data_diff.cloud.datafold_api import (
109
TCloudApiDataSourceConfigSchema,
@@ -13,9 +12,7 @@
1312
TCloudApiDataSourceTestResult,
1413
TCloudDataSourceTestResult,
1514
TDsConfig,
16-
TestDataSourceStatus,
1715
)
18-
from data_diff.dbt_parser import DbtParser
1916
from data_diff.cloud.data_source import (
2017
TDataSourceTestStage,
2118
TestDataSourceStatus,
@@ -26,8 +23,8 @@
2623
)
2724

2825

29-
DATA_SOURCE_CONFIGS = [
30-
TDsConfig(
26+
DATA_SOURCE_CONFIGS = {
27+
"snowflake": TDsConfig(
3128
name="ds_name",
3229
type="snowflake",
3330
options={
@@ -41,7 +38,7 @@
4138
float_tolerance=0.000001,
4239
temp_schema="database.temp_schema",
4340
),
44-
TDsConfig(
41+
"pg": TDsConfig(
4542
name="ds_name",
4643
type="pg",
4744
options={
@@ -54,18 +51,18 @@
5451
float_tolerance=0.000001,
5552
temp_schema="database.temp_schema",
5653
),
57-
TDsConfig(
54+
"bigquery": TDsConfig(
5855
name="ds_name",
5956
type="bigquery",
6057
options={
6158
"projectId": "project_id",
62-
"jsonKeyFile": "some_string",
59+
"jsonKeyFile": '{"key1": "value1"}',
6360
"location": "US",
6461
},
6562
float_tolerance=0.000001,
6663
temp_schema="database.temp_schema",
6764
),
68-
TDsConfig(
65+
"databricks": TDsConfig(
6966
name="ds_name",
7067
type="databricks",
7168
options={
@@ -77,7 +74,7 @@
7774
float_tolerance=0.000001,
7875
temp_schema="database.temp_schema",
7976
),
80-
TDsConfig(
77+
"redshift": TDsConfig(
8178
name="ds_name",
8279
type="redshift",
8380
options={
@@ -90,7 +87,7 @@
9087
float_tolerance=0.000001,
9188
temp_schema="database.temp_schema",
9289
),
93-
TDsConfig(
90+
"postgres_aurora": TDsConfig(
9491
name="ds_name",
9592
type="postgres_aurora",
9693
options={
@@ -103,7 +100,7 @@
103100
float_tolerance=0.000001,
104101
temp_schema="database.temp_schema",
105102
),
106-
TDsConfig(
103+
"postgres_aws_rds": TDsConfig(
107104
name="ds_name",
108105
type="postgres_aws_rds",
109106
options={
@@ -116,7 +113,7 @@
116113
float_tolerance=0.000001,
117114
temp_schema="database.temp_schema",
118115
),
119-
]
116+
}
120117

121118

122119
def format_data_source_config_test(testcase_func, param_num, param):
@@ -145,7 +142,7 @@ def setUp(self) -> None:
145142
self.api.get_data_source_schema_config.return_value = self.data_source_schema
146143
self.api.get_data_sources.return_value = self.data_sources
147144

148-
@parameterized.expand([(c,) for c in DATA_SOURCE_CONFIGS], name_func=format_data_source_config_test)
145+
@parameterized.expand([(c,) for c in DATA_SOURCE_CONFIGS.values()], name_func=format_data_source_config_test)
149146
@patch("data_diff.dbt_parser.DbtParser.__new__")
150147
def test_get_temp_schema(self, config: TDsConfig, mock_dbt_parser):
151148
diff_vars = {
@@ -161,7 +158,7 @@ def test_get_temp_schema(self, config: TDsConfig, mock_dbt_parser):
161158

162159
assert _get_temp_schema(dbt_parser=mock_dbt_parser, db_type=config.type) == temp_schema
163160

164-
@parameterized.expand([(c,) for c in DATA_SOURCE_CONFIGS], name_func=format_data_source_config_test)
161+
@parameterized.expand([(c,) for c in DATA_SOURCE_CONFIGS.values()], name_func=format_data_source_config_test)
165162
def test_create_ds_config(self, config: TDsConfig):
166163
inputs = list(config.options.values()) + [config.temp_schema, config.float_tolerance]
167164
with patch("rich.prompt.Console.input", side_effect=map(str, inputs)):
@@ -172,8 +169,8 @@ def test_create_ds_config(self, config: TDsConfig):
172169
self.assertEqual(actual_config, config)
173170

174171
@patch("data_diff.dbt_parser.DbtParser.__new__")
175-
def test_create_ds_config_from_dbt_profiles(self, mock_dbt_parser):
176-
config = DATA_SOURCE_CONFIGS[0]
172+
def test_create_snowflake_ds_config_from_dbt_profiles(self, mock_dbt_parser):
173+
config = DATA_SOURCE_CONFIGS['snowflake']
177174
mock_dbt_parser.get_connection_creds.return_value = (config.options,)
178175
with patch("rich.prompt.Console.input", side_effect=["y", config.temp_schema, str(config.float_tolerance)]):
179176
actual_config = create_ds_config(
@@ -183,11 +180,66 @@ def test_create_ds_config_from_dbt_profiles(self, mock_dbt_parser):
183180
)
184181
self.assertEqual(actual_config, config)
185182

183+
@patch("data_diff.dbt_parser.DbtParser.__new__")
184+
def test_create_bigquery_ds_config_dbt_oauth(self, mock_dbt_parser):
185+
config = DATA_SOURCE_CONFIGS['bigquery']
186+
mock_dbt_parser.get_connection_creds.return_value = (config.options,)
187+
with patch("rich.prompt.Console.input", side_effect=["y", config.temp_schema, str(config.float_tolerance)]):
188+
actual_config = create_ds_config(
189+
ds_config=self.db_type_data_source_schemas[config.type],
190+
data_source_name=config.name,
191+
dbt_parser=mock_dbt_parser,
192+
)
193+
self.assertEqual(actual_config, config)
194+
195+
@patch("data_diff.dbt_parser.DbtParser.__new__")
196+
@patch("data_diff.cloud.data_source._get_data_from_bigquery_json")
197+
def test_create_bigquery_ds_config_dbt_service_account(self, mock_get_data_from_bigquery_json, mock_dbt_parser):
198+
config = DATA_SOURCE_CONFIGS['bigquery']
199+
200+
mock_get_data_from_bigquery_json.return_value = json.loads(config.options['jsonKeyFile'])
201+
mock_dbt_parser.get_connection_creds.return_value = {
202+
'type': 'bigquery',
203+
'method': 'service-account',
204+
'project': config.options['projectId'],
205+
'threads': 1,
206+
'keyfile': '/some/path'
207+
},
208+
209+
with patch("rich.prompt.Console.input", side_effect=["y", config.options['location'], config.temp_schema, str(config.float_tolerance)]):
210+
actual_config = create_ds_config(
211+
ds_config=self.db_type_data_source_schemas[config.type],
212+
data_source_name=config.name,
213+
dbt_parser=mock_dbt_parser,
214+
)
215+
self.assertEqual(actual_config, config)
216+
217+
@patch("data_diff.dbt_parser.DbtParser.__new__")
218+
def test_create_bigquery_ds_config_dbt_service_account_json(self, mock_dbt_parser):
219+
config = DATA_SOURCE_CONFIGS['bigquery']
220+
221+
mock_dbt_parser.get_connection_creds.return_value = {
222+
'type': 'bigquery',
223+
'method': 'service-account-json',
224+
'project': config.options['projectId'],
225+
'threads': 1,
226+
'keyfile_json': json.loads(config.options['jsonKeyFile'])
227+
},
228+
229+
with patch("rich.prompt.Console.input", side_effect=["y", config.options['location'], config.temp_schema, str(config.float_tolerance)]):
230+
actual_config = create_ds_config(
231+
ds_config=self.db_type_data_source_schemas[config.type],
232+
data_source_name=config.name,
233+
dbt_parser=mock_dbt_parser,
234+
)
235+
self.assertEqual(actual_config, config)
236+
186237
@patch("sys.stdout", new_callable=StringIO)
187238
@patch("data_diff.dbt_parser.DbtParser.__new__")
188-
def test_create_ds_config_from_dbt_profiles_one_param_passed_through_input(self, mock_dbt_parser, mock_stdout):
189-
config = DATA_SOURCE_CONFIGS[0]
190-
options = copy.copy(config.options)
239+
def test_create_ds_snowflake_config_from_dbt_profiles_one_param_passed_through_input(self, mock_dbt_parser, mock_stdout):
240+
config = DATA_SOURCE_CONFIGS['snowflake']
241+
options = {**config.options, 'type': 'snowflake'}
242+
options['database'] = options.pop('default_db')
191243
account = options.pop("account")
192244
mock_dbt_parser.get_connection_creds.return_value = (options,)
193245
with patch(

0 commit comments

Comments
 (0)