@@ -77,6 +77,9 @@ def dbt_diff(
7777 config_prod_database = datadiff_variables .get ("prod_database" )
7878 config_prod_schema = datadiff_variables .get ("prod_schema" )
7979 datasource_id = datadiff_variables .get ("datasource_id" )
80+ custom_schemas = datadiff_variables .get ("custom_schemas" )
81+ # custom schemas is default dbt behavior, so default to True if the var doesn't exist
82+ custom_schemas = True if custom_schemas is None else custom_schemas
8083
8184 if not is_cloud :
8285 dbt_parser .set_connection ()
@@ -87,7 +90,9 @@ def dbt_diff(
8790 )
8891
8992 for model in models :
90- diff_vars = _get_diff_vars (dbt_parser , config_prod_database , config_prod_schema , model , datasource_id )
93+ diff_vars = _get_diff_vars (
94+ dbt_parser , config_prod_database , config_prod_schema , model , datasource_id , custom_schemas
95+ )
9196
9297 if is_cloud and len (diff_vars .primary_keys ) > 0 :
9398 _cloud_diff (diff_vars )
@@ -112,6 +117,7 @@ def _get_diff_vars(
112117 config_prod_schema : Optional [str ],
113118 model ,
114119 datasource_id : int ,
120+ custom_schemas : bool ,
115121) -> DiffVars :
116122 dev_database = model .database
117123 dev_schema = model .schema_
@@ -120,6 +126,12 @@ def _get_diff_vars(
120126 prod_database = config_prod_database if config_prod_database else dev_database
121127 prod_schema = config_prod_schema if config_prod_schema else dev_schema
122128
129+ # if project has custom schemas (default)
130+ # need to construct the prod schema as <prod_target_schema>_<custom_schema>
131+ # https://docs.getdbt.com/docs/build/custom-schemas
132+ if custom_schemas and model .config .schema_ :
133+ prod_schema = prod_schema + "_" + model .config .schema_
134+
123135 if dbt_parser .requires_upper :
124136 dev_qualified_list = [x .upper () for x in [dev_database , dev_schema , model .alias ]]
125137 prod_qualified_list = [x .upper () for x in [prod_database , prod_schema , model .alias ]]
@@ -128,16 +140,22 @@ def _get_diff_vars(
128140 dev_qualified_list = [dev_database , dev_schema , model .alias ]
129141 prod_qualified_list = [prod_database , prod_schema , model .alias ]
130142
131- return DiffVars (dev_qualified_list , prod_qualified_list , primary_keys , datasource_id , dbt_parser .connection , dbt_parser .threads )
143+ return DiffVars (
144+ dev_qualified_list , prod_qualified_list , primary_keys , datasource_id , dbt_parser .connection , dbt_parser .threads
145+ )
132146
133147
134148def _local_diff (diff_vars : DiffVars ) -> None :
135149 column_diffs_str = ""
136150 dev_qualified_string = "." .join (diff_vars .dev_path )
137151 prod_qualified_string = "." .join (diff_vars .prod_path )
138152
139- table1 = connect_to_table (diff_vars .connection , dev_qualified_string , tuple (diff_vars .primary_keys ), diff_vars .threads )
140- table2 = connect_to_table (diff_vars .connection , prod_qualified_string , tuple (diff_vars .primary_keys ), diff_vars .threads )
153+ table1 = connect_to_table (
154+ diff_vars .connection , dev_qualified_string , tuple (diff_vars .primary_keys ), diff_vars .threads
155+ )
156+ table2 = connect_to_table (
157+ diff_vars .connection , prod_qualified_string , tuple (diff_vars .primary_keys ), diff_vars .threads
158+ )
141159
142160 table1_columns = list (table1 .get_schema ())
143161 try :
0 commit comments