Skip to content
This repository was archived by the owner on May 17, 2024. It is now read-only.

Commit 3c0ec0f

Browse files
committed
align dbt cred params with datafold cred params
1 parent 0b1b54e commit 3c0ec0f

File tree

2 files changed

+33
-3
lines changed

2 files changed

+33
-3
lines changed

data_diff/cloud/data_source.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import json
12
import time
23
from typing import List, Optional, Union, overload
34

@@ -111,6 +112,29 @@ def _cast_value(value: str, type_: str) -> Union[bool, int, str]:
111112
return value
112113

113114

115+
def _get_data_from_bigquery_json(path: str):
116+
with open(path, 'r') as file:
117+
return json.load(file)
118+
119+
120+
def _align_dbt_cred_params_with_datafold_params(dbt_creds: dict) -> dict:
121+
db_type = dbt_creds['type']
122+
if db_type == 'bigquery':
123+
method = dbt_creds['method']
124+
if method == 'service-account':
125+
data = _get_data_from_bigquery_json(path=dbt_creds['keyfile'])
126+
dbt_creds['jsonKeyFile'] = json.dumps(data)
127+
elif method == 'service-account-json':
128+
dbt_creds['jsonKeyFile'] = json.dumps(dbt_creds['keyfile_json'])
129+
dbt_creds['projectId'] = dbt_creds['project']
130+
elif db_type == 'snowflake':
131+
dbt_creds['default_db'] = dbt_creds['database']
132+
elif db_type == 'databricks':
133+
dbt_creds['http_password'] = dbt_creds['token']
134+
dbt_creds['database'] = dbt_creds.get('catalog')
135+
return dbt_creds
136+
137+
114138
def _parse_ds_credentials(
115139
ds_config: TCloudApiDataSourceConfigSchema, only_basic_settings: bool = True, dbt_parser: Optional[DbtParser] = None
116140
):
@@ -120,6 +144,7 @@ def _parse_ds_credentials(
120144
use_dbt_data = Confirm.ask("Would you like to extract database credentials from dbt profiles.yml?")
121145
try:
122146
creds = dbt_parser.get_connection_creds()[0]
147+
creds = _align_dbt_cred_params_with_datafold_params(dbt_creds=creds)
123148
except Exception as e:
124149
rich.print(f"[red]Cannot parse database credentials from dbt profiles.yml. Reason: {e}")
125150

data_diff/cloud/datafold_api.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import base64
12
import dataclasses
23
import enum
34
import time
@@ -159,7 +160,7 @@ class TCloudDataSourceTestResult(pydantic.BaseModel):
159160
class TCloudApiDataSourceTestResult(pydantic.BaseModel):
160161
name: str
161162
status: str
162-
result: TCloudDataSourceTestResult
163+
result: Optional[TCloudDataSourceTestResult]
163164

164165

165166
@dataclasses.dataclass
@@ -191,7 +192,11 @@ def get_data_sources(self) -> List[TCloudApiDataSource]:
191192
return [TCloudApiDataSource(**item) for item in rv.json()]
192193

193194
def create_data_source(self, config: TDsConfig) -> TCloudApiDataSource:
194-
rv = self.make_post_request(url="api/v1/data_sources", payload=config.dict())
195+
payload = config.dict()
196+
if config.type == 'bigquery':
197+
json_string = payload['options']['jsonKeyFile'].encode('utf-8')
198+
payload['options']['jsonKeyFile'] = base64.b64encode(json_string).decode('utf-8')
199+
rv = self.make_post_request(url="api/v1/data_sources", payload=payload)
195200
return TCloudApiDataSource(**rv.json())
196201

197202
def get_data_source_schema_config(
@@ -250,7 +255,7 @@ def check_data_source_test_results(self, job_id: int) -> List[TCloudApiDataSourc
250255
status=item["result"]["code"].lower(),
251256
message=item["result"]["message"],
252257
outcome=item["result"]["outcome"],
253-
),
258+
) if item["result"] is not None else None,
254259
)
255260
for item in rv.json()["results"]
256261
]

0 commit comments

Comments
 (0)