Skip to content
This repository was archived by the owner on May 17, 2024. It is now read-only.

Commit 8e073e8

Browse files
committed
apply black
1 parent 3d9a079 commit 8e073e8

File tree

4 files changed

+61
-47
lines changed

4 files changed

+61
-47
lines changed

data_diff/cloud/data_source.py

Lines changed: 15 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -113,25 +113,25 @@ def _cast_value(value: str, type_: str) -> Union[bool, int, str]:
113113

114114

115115
def _get_data_from_bigquery_json(path: str):
116-
with open(path, 'r') as file:
116+
with open(path, "r") as file:
117117
return json.load(file)
118118

119119

120120
def _align_dbt_cred_params_with_datafold_params(dbt_creds: dict) -> dict:
121-
db_type = dbt_creds['type']
122-
if db_type == 'bigquery':
123-
method = dbt_creds['method']
124-
if method == 'service-account':
125-
data = _get_data_from_bigquery_json(path=dbt_creds['keyfile'])
126-
dbt_creds['jsonKeyFile'] = json.dumps(data)
127-
elif method == 'service-account-json':
128-
dbt_creds['jsonKeyFile'] = json.dumps(dbt_creds['keyfile_json'])
129-
dbt_creds['projectId'] = dbt_creds['project']
130-
elif db_type == 'snowflake':
131-
dbt_creds['default_db'] = dbt_creds['database']
132-
elif db_type == 'databricks':
133-
dbt_creds['http_password'] = dbt_creds['token']
134-
dbt_creds['database'] = dbt_creds.get('catalog')
121+
db_type = dbt_creds["type"]
122+
if db_type == "bigquery":
123+
method = dbt_creds["method"]
124+
if method == "service-account":
125+
data = _get_data_from_bigquery_json(path=dbt_creds["keyfile"])
126+
dbt_creds["jsonKeyFile"] = json.dumps(data)
127+
elif method == "service-account-json":
128+
dbt_creds["jsonKeyFile"] = json.dumps(dbt_creds["keyfile_json"])
129+
dbt_creds["projectId"] = dbt_creds["project"]
130+
elif db_type == "snowflake":
131+
dbt_creds["default_db"] = dbt_creds["database"]
132+
elif db_type == "databricks":
133+
dbt_creds["http_password"] = dbt_creds["token"]
134+
dbt_creds["database"] = dbt_creds.get("catalog")
135135
return dbt_creds
136136

137137

data_diff/cloud/datafold_api.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -193,9 +193,9 @@ def get_data_sources(self) -> List[TCloudApiDataSource]:
193193

194194
def create_data_source(self, config: TDsConfig) -> TCloudApiDataSource:
195195
payload = config.dict()
196-
if config.type == 'bigquery':
197-
json_string = payload['options']['jsonKeyFile'].encode('utf-8')
198-
payload['options']['jsonKeyFile'] = base64.b64encode(json_string).decode('utf-8')
196+
if config.type == "bigquery":
197+
json_string = payload["options"]["jsonKeyFile"].encode("utf-8")
198+
payload["options"]["jsonKeyFile"] = base64.b64encode(json_string).decode("utf-8")
199199
rv = self.make_post_request(url="api/v1/data_sources", payload=payload)
200200
return TCloudApiDataSource(**rv.json())
201201

@@ -255,7 +255,9 @@ def check_data_source_test_results(self, job_id: int) -> List[TCloudApiDataSourc
255255
status=item["result"]["code"].lower(),
256256
message=item["result"]["message"],
257257
outcome=item["result"]["outcome"],
258-
) if item["result"] is not None else None,
258+
)
259+
if item["result"] is not None
260+
else None,
259261
)
260262
for item in rv.json()["results"]
261263
]

data_diff/dbt_parser.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -145,7 +145,7 @@ def set_connection(self):
145145
"role": credentials.get("role"),
146146
"schema": credentials.get("schema"),
147147
"insecure_mode": credentials.get("insecure_mode", False),
148-
"client_session_keep_alive": credentials.get("client_session_keep_alive", False)
148+
"client_session_keep_alive": credentials.get("client_session_keep_alive", False),
149149
}
150150
self.threads = credentials.get("threads")
151151
self.requires_upper = True

tests/cloud/test_data_source.py

Lines changed: 39 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -170,7 +170,7 @@ def test_create_ds_config(self, config: TDsConfig):
170170

171171
@patch("data_diff.dbt_parser.DbtParser.__new__")
172172
def test_create_snowflake_ds_config_from_dbt_profiles(self, mock_dbt_parser):
173-
config = DATA_SOURCE_CONFIGS['snowflake']
173+
config = DATA_SOURCE_CONFIGS["snowflake"]
174174
mock_dbt_parser.get_connection_creds.return_value = (config.options,)
175175
with patch("rich.prompt.Console.input", side_effect=["y", config.temp_schema, str(config.float_tolerance)]):
176176
actual_config = create_ds_config(
@@ -182,7 +182,7 @@ def test_create_snowflake_ds_config_from_dbt_profiles(self, mock_dbt_parser):
182182

183183
@patch("data_diff.dbt_parser.DbtParser.__new__")
184184
def test_create_bigquery_ds_config_dbt_oauth(self, mock_dbt_parser):
185-
config = DATA_SOURCE_CONFIGS['bigquery']
185+
config = DATA_SOURCE_CONFIGS["bigquery"]
186186
mock_dbt_parser.get_connection_creds.return_value = (config.options,)
187187
with patch("rich.prompt.Console.input", side_effect=["y", config.temp_schema, str(config.float_tolerance)]):
188188
actual_config = create_ds_config(
@@ -195,18 +195,23 @@ def test_create_bigquery_ds_config_dbt_oauth(self, mock_dbt_parser):
195195
@patch("data_diff.dbt_parser.DbtParser.__new__")
196196
@patch("data_diff.cloud.data_source._get_data_from_bigquery_json")
197197
def test_create_bigquery_ds_config_dbt_service_account(self, mock_get_data_from_bigquery_json, mock_dbt_parser):
198-
config = DATA_SOURCE_CONFIGS['bigquery']
199-
200-
mock_get_data_from_bigquery_json.return_value = json.loads(config.options['jsonKeyFile'])
201-
mock_dbt_parser.get_connection_creds.return_value = {
202-
'type': 'bigquery',
203-
'method': 'service-account',
204-
'project': config.options['projectId'],
205-
'threads': 1,
206-
'keyfile': '/some/path'
207-
},
198+
config = DATA_SOURCE_CONFIGS["bigquery"]
199+
200+
mock_get_data_from_bigquery_json.return_value = json.loads(config.options["jsonKeyFile"])
201+
mock_dbt_parser.get_connection_creds.return_value = (
202+
{
203+
"type": "bigquery",
204+
"method": "service-account",
205+
"project": config.options["projectId"],
206+
"threads": 1,
207+
"keyfile": "/some/path",
208+
},
209+
)
208210

209-
with patch("rich.prompt.Console.input", side_effect=["y", config.options['location'], config.temp_schema, str(config.float_tolerance)]):
211+
with patch(
212+
"rich.prompt.Console.input",
213+
side_effect=["y", config.options["location"], config.temp_schema, str(config.float_tolerance)],
214+
):
210215
actual_config = create_ds_config(
211216
ds_config=self.db_type_data_source_schemas[config.type],
212217
data_source_name=config.name,
@@ -216,17 +221,22 @@ def test_create_bigquery_ds_config_dbt_service_account(self, mock_get_data_from_
216221

217222
@patch("data_diff.dbt_parser.DbtParser.__new__")
218223
def test_create_bigquery_ds_config_dbt_service_account_json(self, mock_dbt_parser):
219-
config = DATA_SOURCE_CONFIGS['bigquery']
220-
221-
mock_dbt_parser.get_connection_creds.return_value = {
222-
'type': 'bigquery',
223-
'method': 'service-account-json',
224-
'project': config.options['projectId'],
225-
'threads': 1,
226-
'keyfile_json': json.loads(config.options['jsonKeyFile'])
227-
},
224+
config = DATA_SOURCE_CONFIGS["bigquery"]
225+
226+
mock_dbt_parser.get_connection_creds.return_value = (
227+
{
228+
"type": "bigquery",
229+
"method": "service-account-json",
230+
"project": config.options["projectId"],
231+
"threads": 1,
232+
"keyfile_json": json.loads(config.options["jsonKeyFile"]),
233+
},
234+
)
228235

229-
with patch("rich.prompt.Console.input", side_effect=["y", config.options['location'], config.temp_schema, str(config.float_tolerance)]):
236+
with patch(
237+
"rich.prompt.Console.input",
238+
side_effect=["y", config.options["location"], config.temp_schema, str(config.float_tolerance)],
239+
):
230240
actual_config = create_ds_config(
231241
ds_config=self.db_type_data_source_schemas[config.type],
232242
data_source_name=config.name,
@@ -236,10 +246,12 @@ def test_create_bigquery_ds_config_dbt_service_account_json(self, mock_dbt_parse
236246

237247
@patch("sys.stdout", new_callable=StringIO)
238248
@patch("data_diff.dbt_parser.DbtParser.__new__")
239-
def test_create_ds_snowflake_config_from_dbt_profiles_one_param_passed_through_input(self, mock_dbt_parser, mock_stdout):
240-
config = DATA_SOURCE_CONFIGS['snowflake']
241-
options = {**config.options, 'type': 'snowflake'}
242-
options['database'] = options.pop('default_db')
249+
def test_create_ds_snowflake_config_from_dbt_profiles_one_param_passed_through_input(
250+
self, mock_dbt_parser, mock_stdout
251+
):
252+
config = DATA_SOURCE_CONFIGS["snowflake"]
253+
options = {**config.options, "type": "snowflake"}
254+
options["database"] = options.pop("default_db")
243255
account = options.pop("account")
244256
mock_dbt_parser.get_connection_creds.return_value = (options,)
245257
with patch(

0 commit comments

Comments
 (0)