Skip to content
This repository was archived by the owner on May 17, 2024. It is now read-only.

Commit 991c844

Browse files
authored
Merge pull request #437 from dlawin/issue_404
support custom schemas
2 parents 5f56d9f + de2b7c9 commit 991c844

File tree

1 file changed

+22
-4
lines changed

1 file changed

+22
-4
lines changed

data_diff/dbt.py

Lines changed: 22 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,9 @@ def dbt_diff(
7777
config_prod_database = datadiff_variables.get("prod_database")
7878
config_prod_schema = datadiff_variables.get("prod_schema")
7979
datasource_id = datadiff_variables.get("datasource_id")
80+
custom_schemas = datadiff_variables.get("custom_schemas")
81+
# custom schemas is default dbt behavior, so default to True if the var doesn't exist
82+
custom_schemas = True if custom_schemas is None else custom_schemas
8083

8184
if not is_cloud:
8285
dbt_parser.set_connection()
@@ -87,7 +90,9 @@ def dbt_diff(
8790
)
8891

8992
for model in models:
90-
diff_vars = _get_diff_vars(dbt_parser, config_prod_database, config_prod_schema, model, datasource_id)
93+
diff_vars = _get_diff_vars(
94+
dbt_parser, config_prod_database, config_prod_schema, model, datasource_id, custom_schemas
95+
)
9196

9297
if is_cloud and len(diff_vars.primary_keys) > 0:
9398
_cloud_diff(diff_vars)
@@ -112,6 +117,7 @@ def _get_diff_vars(
112117
config_prod_schema: Optional[str],
113118
model,
114119
datasource_id: int,
120+
custom_schemas: bool,
115121
) -> DiffVars:
116122
dev_database = model.database
117123
dev_schema = model.schema_
@@ -120,6 +126,12 @@ def _get_diff_vars(
120126
prod_database = config_prod_database if config_prod_database else dev_database
121127
prod_schema = config_prod_schema if config_prod_schema else dev_schema
122128

129+
# if project has custom schemas (default)
130+
# need to construct the prod schema as <prod_target_schema>_<custom_schema>
131+
# https://docs.getdbt.com/docs/build/custom-schemas
132+
if custom_schemas and model.config.schema_:
133+
prod_schema = prod_schema + "_" + model.config.schema_
134+
123135
if dbt_parser.requires_upper:
124136
dev_qualified_list = [x.upper() for x in [dev_database, dev_schema, model.alias]]
125137
prod_qualified_list = [x.upper() for x in [prod_database, prod_schema, model.alias]]
@@ -128,16 +140,22 @@ def _get_diff_vars(
128140
dev_qualified_list = [dev_database, dev_schema, model.alias]
129141
prod_qualified_list = [prod_database, prod_schema, model.alias]
130142

131-
return DiffVars(dev_qualified_list, prod_qualified_list, primary_keys, datasource_id, dbt_parser.connection, dbt_parser.threads)
143+
return DiffVars(
144+
dev_qualified_list, prod_qualified_list, primary_keys, datasource_id, dbt_parser.connection, dbt_parser.threads
145+
)
132146

133147

134148
def _local_diff(diff_vars: DiffVars) -> None:
135149
column_diffs_str = ""
136150
dev_qualified_string = ".".join(diff_vars.dev_path)
137151
prod_qualified_string = ".".join(diff_vars.prod_path)
138152

139-
table1 = connect_to_table(diff_vars.connection, dev_qualified_string, tuple(diff_vars.primary_keys), diff_vars.threads)
140-
table2 = connect_to_table(diff_vars.connection, prod_qualified_string, tuple(diff_vars.primary_keys), diff_vars.threads)
153+
table1 = connect_to_table(
154+
diff_vars.connection, dev_qualified_string, tuple(diff_vars.primary_keys), diff_vars.threads
155+
)
156+
table2 = connect_to_table(
157+
diff_vars.connection, prod_qualified_string, tuple(diff_vars.primary_keys), diff_vars.threads
158+
)
141159

142160
table1_columns = list(table1.get_schema())
143161
try:

0 commit comments

Comments
 (0)