Skip to content
This repository was archived by the owner on May 17, 2024. It is now read-only.

Commit d0d0e2f

Browse files
authored
Merge branch 'master' into oracle-via-oracledb
2 parents 30bd116 + fef52d0 commit d0d0e2f

File tree

9 files changed

+59
-34
lines changed

9 files changed

+59
-34
lines changed

data_diff/__main__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -253,8 +253,8 @@ def write_usage(self, prog: str, args: str = "", prefix: Optional[str] = None) -
253253
"--select",
254254
"-s",
255255
default=None,
256-
metavar="PATH",
257-
help="select dbt resources to compare using dbt selection syntax.",
256+
metavar="SELECTION or MODEL_NAME",
257+
help="--select dbt resources to compare using dbt selection syntax in dbt versions >= 1.5.\nIn versions < 1.5, it will naively search for a model with MODEL_NAME as the name.",
258258
)
259259
@click.option(
260260
"--state",

data_diff/dbt_parser.py

Lines changed: 24 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,8 @@
2020
DataDiffDbtRunResultsVersionError,
2121
DataDiffDbtSelectNoMatchingModelsError,
2222
DataDiffDbtSelectUnexpectedError,
23-
DataDiffDbtSelectVersionTooLowError,
2423
DataDiffDbtSnowflakeSetConnectionError,
24+
DataDiffSimpleSelectNotFound,
2525
)
2626

2727
from .utils import getLogger, get_from_dict_with_raise
@@ -167,9 +167,11 @@ def get_models(self, dbt_selection: Optional[str] = None):
167167
"data-diff is using a dbt-core version < 1.5, update the environment's dbt-core version via pip install 'dbt-core>=1.5' in order to use `--select`"
168168
)
169169
else:
170-
raise DataDiffDbtSelectVersionTooLowError(
171-
f"The `--select` feature requires dbt >= 1.5, but your project's manifest.json is from dbt v{dbt_version}. Please follow these steps to use the `--select` feature: \n 1. Update your dbt-core version via pip install 'dbt-core>=1.5'. Details: https://docs.getdbt.com/docs/core/pip-install#change-dbt-core-versions \n 2. Execute any `dbt` command (`run`, `compile`, `build`) to create a new manifest.json."
170+
# Naively get node named <dbt_selection>
171+
logger.warning(
172+
f"Full `--select` support requires dbt >= 1.5. Naively searching for a single model with name: '{dbt_selection}'."
172173
)
174+
return self.get_simple_model_selection(dbt_selection)
173175
else:
174176
return self.get_run_results_models()
175177

@@ -209,6 +211,25 @@ def get_dbt_selection_models(self, dbt_selection: str) -> List[str]:
209211
logger.debug(str(results))
210212
raise DataDiffDbtSelectUnexpectedError("Encountered an unexpected error while finding `--select` models")
211213

214+
def get_simple_model_selection(self, dbt_selection: str):
215+
model_nodes = dict(filter(lambda item: item[0].startswith("model."), self.dev_manifest_obj.nodes.items()))
216+
217+
model_unique_key_list = [k for k, v in model_nodes.items() if v.name == dbt_selection]
218+
219+
# name *should* always be unique, but just in case:
220+
if len(model_unique_key_list) > 1:
221+
logger.warning(
222+
f"Found more than one model with name '{dbt_selection}' {model_unique_key_list}, using the first one."
223+
)
224+
elif len(model_unique_key_list) < 1:
225+
raise DataDiffSimpleSelectNotFound(
226+
f"Did not find a model node with name '{dbt_selection}' in the manifest."
227+
)
228+
229+
model = model_nodes.get(model_unique_key_list[0])
230+
231+
return [model]
232+
212233
def get_run_results_models(self):
213234
with open(self.project_dir / RUN_RESULTS_PATH) as run_results:
214235
logger.info(f"Parsing file {RUN_RESULTS_PATH}")

data_diff/errors.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -42,10 +42,6 @@ class DataDiffDbtCoreNoRunnerError(Exception):
4242
"Raised when the manifest version >= 1.5, but the dbt-core package is < 1.5. This is an edge case most likely to occur in development."
4343

4444

45-
class DataDiffDbtSelectVersionTooLowError(Exception):
46-
"Raised when attempting to use `--select` with a dbt-core version < 1.5."
47-
48-
4945
class DataDiffCustomSchemaNoConfigError(Exception):
5046
"Raised when a model has a custom schema, but there is no prod_custom_schema config. (And not using --state)."
5147

@@ -68,3 +64,7 @@ class DataDiffCloudDiffFailed(Exception):
6864

6965
class DataDiffCloudDiffTimedOut(Exception):
7066
"Raised when using --cloud and the diff did not return finish before the timeout value."
67+
68+
69+
class DataDiffSimpleSelectNotFound(Exception):
70+
"Raised when using --select on dbt < 1.5 and a model node is not found in the manifest."

data_diff/sqeleton/databases/_connect.py

Lines changed: 15 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -106,7 +106,7 @@ def load_mixins(self, *abstract_mixins: AbstractMixin) -> Self:
106106
database_by_scheme = {k: db.load_mixins(*abstract_mixins) for k, db in self.database_by_scheme.items()}
107107
return type(self)(database_by_scheme)
108108

109-
def connect_to_uri(self, db_uri: str, thread_count: Optional[int] = 1) -> Database:
109+
def connect_to_uri(self, db_uri: str, thread_count: Optional[int] = 1, **kwargs) -> Database:
110110
"""Connect to the given database uri
111111
112112
thread_count determines the max number of worker threads per database,
@@ -149,7 +149,7 @@ def connect_to_uri(self, db_uri: str, thread_count: Optional[int] = 1) -> Databa
149149
conn_dict = config["database"][database]
150150
except KeyError:
151151
raise ValueError(f"Cannot find database config named '{database}'.")
152-
return self.connect_with_dict(conn_dict, thread_count)
152+
return self.connect_with_dict(conn_dict, thread_count, **kwargs)
153153

154154
try:
155155
matcher = self.match_uri_path[scheme]
@@ -174,7 +174,7 @@ def connect_to_uri(self, db_uri: str, thread_count: Optional[int] = 1) -> Databa
174174

175175
if scheme == "bigquery":
176176
kw["project"] = dsn.host
177-
return cls(**kw)
177+
return cls(**kw, **kwargs)
178178

179179
if scheme == "snowflake":
180180
kw["account"] = dsn.host
@@ -194,13 +194,13 @@ def connect_to_uri(self, db_uri: str, thread_count: Optional[int] = 1) -> Databa
194194
kw = {k: v for k, v in kw.items() if v is not None}
195195

196196
if issubclass(cls, ThreadedDatabase):
197-
db = cls(thread_count=thread_count, **kw)
197+
db = cls(thread_count=thread_count, **kw, **kwargs)
198198
else:
199-
db = cls(**kw)
199+
db = cls(**kw, **kwargs)
200200

201201
return self._connection_created(db)
202202

203-
def connect_with_dict(self, d, thread_count):
203+
def connect_with_dict(self, d, thread_count, **kwargs):
204204
d = dict(d)
205205
driver = d.pop("driver")
206206
try:
@@ -210,17 +210,19 @@ def connect_with_dict(self, d, thread_count):
210210

211211
cls = matcher.database_cls
212212
if issubclass(cls, ThreadedDatabase):
213-
db = cls(thread_count=thread_count, **d)
213+
db = cls(thread_count=thread_count, **d, **kwargs)
214214
else:
215-
db = cls(**d)
215+
db = cls(**d, **kwargs)
216216

217217
return self._connection_created(db)
218218

219219
def _connection_created(self, db):
220220
"Nop function to be overridden by subclasses."
221221
return db
222222

223-
def __call__(self, db_conf: Union[str, dict], thread_count: Optional[int] = 1, shared: bool = True) -> Database:
223+
def __call__(
224+
self, db_conf: Union[str, dict], thread_count: Optional[int] = 1, shared: bool = True, **kwargs
225+
) -> Database:
224226
"""Connect to a database using the given database configuration.
225227
226228
Configuration can be given either as a URI string, or as a dict of {option: value}.
@@ -234,6 +236,8 @@ def __call__(self, db_conf: Union[str, dict], thread_count: Optional[int] = 1, s
234236
db_conf (str | dict): The configuration for the database to connect. URI or dict.
235237
thread_count (int, optional): Size of the threadpool. Ignored by cloud databases. (default: 1)
236238
shared (bool): Whether to cache and return the same connection for the same db_conf. (default: True)
239+
bigquery_credentials (google.oauth2.credentials.Credentials): Custom Google oAuth2 credential for BigQuery.
240+
(default: None)
237241
238242
Note: For non-cloud databases, a low thread-pool size may be a performance bottleneck.
239243
@@ -263,9 +267,9 @@ def __call__(self, db_conf: Union[str, dict], thread_count: Optional[int] = 1, s
263267
return conn
264268

265269
if isinstance(db_conf, str):
266-
conn = self.connect_to_uri(db_conf, thread_count)
270+
conn = self.connect_to_uri(db_conf, thread_count, **kwargs)
267271
elif isinstance(db_conf, dict):
268-
conn = self.connect_with_dict(db_conf, thread_count)
272+
conn = self.connect_with_dict(db_conf, thread_count, **kwargs)
269273
else:
270274
raise TypeError(f"db configuration must be a URI string or a dictionary. Instead got '{db_conf}'.")
271275

data_diff/sqeleton/databases/bigquery.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -210,8 +210,8 @@ class BigQuery(Database):
210210
CONNECT_URI_PARAMS = ["dataset"]
211211
dialect = Dialect()
212212

213-
def __init__(self, project, *, dataset, **kw):
214-
credentials = None
213+
def __init__(self, project, *, dataset, bigquery_credentials=None, **kw):
214+
credentials = bigquery_credentials
215215
bigquery = import_bigquery()
216216

217217
keyfile = kw.pop("keyfile", None)

data_diff/version.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
__version__ = "0.8.3"
1+
__version__ = "0.8.4"

poetry.lock

Lines changed: 3 additions & 3 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pyproject.toml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[tool.poetry]
22
name = "data-diff"
3-
version = "0.8.3"
3+
version = "0.8.4"
44
description = "Command-line tool and Python library to efficiently diff rows across two different databases."
55
authors = ["Datafold <data-diff@datafold.com>"]
66
license = "MIT"
@@ -37,7 +37,7 @@ trino = {version="^0.314.0", optional=true}
3737
presto-python-client = {version="*", optional=true}
3838
clickhouse-driver = {version="*", optional=true}
3939
duckdb = {version="*", optional=true}
40-
dbt-artifacts-parser = {version="^0.4.0"}
40+
dbt-artifacts-parser = {version="^0.4.2"}
4141
dbt-core = {version="^1.0.0"}
4242
keyring = "*"
4343
tabulate = "^0.9.0"
@@ -59,7 +59,7 @@ presto-python-client = "*"
5959
clickhouse-driver = "*"
6060
vertica-python = "*"
6161
duckdb = "^0.7.0"
62-
dbt-artifacts-parser = "^0.4.0"
62+
dbt-artifacts-parser = "^0.4.2"
6363
dbt-core = "^1.0.0"
6464
# google-cloud-bigquery = "*"
6565
# databricks-sql-connector = "*"

tests/test_dbt_parser.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@
1010
DataDiffDbtProfileNotFoundError,
1111
DataDiffDbtRedshiftPasswordOnlyError,
1212
DataDiffDbtRunResultsVersionError,
13-
DataDiffDbtSelectVersionTooLowError,
1413
DataDiffDbtSnowflakeSetConnectionError,
1514
)
1615

@@ -56,17 +55,18 @@ def test_get_models(self):
5655
mock_self.get_dbt_selection_models.assert_called_once_with(selection)
5756
self.assertEqual(models, mock_return_value)
5857

59-
def test_get_models_unsupported_manifest_version(self):
58+
def test_get_models_simple_select(self):
6059
mock_self = Mock()
6160
mock_self.project_dir = Path()
6261
mock_self.dbt_version = "1.4.0"
6362
selection = "model+"
6463
mock_return_value = Mock()
65-
mock_self.get_dbt_selection_models.return_value = mock_return_value
64+
mock_self.get_simple_model_selection.return_value = mock_return_value
6665

67-
with self.assertRaises(DataDiffDbtSelectVersionTooLowError):
68-
_ = DbtParser.get_models(mock_self, selection)
66+
models = DbtParser.get_models(mock_self, selection)
6967
mock_self.get_dbt_selection_models.assert_not_called()
68+
mock_self.get_simple_model_selection.assert_called_with(selection)
69+
self.assertEqual(models, mock_return_value)
7070

7171
def test_get_models_no_runner(self):
7272
mock_self = Mock()

0 commit comments

Comments
 (0)