@@ -740,9 +740,6 @@ def test_diff_not_is_cloud_no_pks(
740740 "prod_schema" : "prod_schema" ,
741741 "datasource_id" : 1 ,
742742 }
743- host = "a_host"
744- url = "a_url"
745- api_key = "a_api_key"
746743
747744 mock_dbt_parser_inst .get_models .return_value = [mock_model ]
748745 mock_dbt_parser_inst .get_datadiff_variables .return_value = expected_dbt_vars_dict
@@ -756,42 +753,67 @@ def test_diff_not_is_cloud_no_pks(
756753 mock_local_diff .assert_not_called ()
757754 self .assertEqual (mock_print .call_count , 1 )
758755
759- def test_get_diff_vars_custom_schemas_prod_db_and_schema (self ):
756+ def test_get_diff_vars_replace_custom_schema (self ):
760757 mock_model = Mock ()
761758 prod_database = "a_prod_db"
762759 prod_schema = "a_prod_schema"
763760 primary_keys = ["a_primary_key" ]
764761 mock_model .database = "a_dev_db"
765- mock_model .schema_ = "a_custom_dev_schema "
762+ mock_model .schema_ = "a_custom_schema "
766763 mock_model .config .schema_ = mock_model .schema_
767764 mock_model .alias = "a_model_name"
768765 mock_dbt_parser = Mock ()
769766 mock_dbt_parser .get_pk_from_model .return_value = primary_keys
770767 mock_dbt_parser .requires_upper = False
771768
772- diff_vars = _get_diff_vars (mock_dbt_parser , "a_prod_db" , "a_prod_schema " , mock_model , custom_schemas = True )
769+ diff_vars = _get_diff_vars (mock_dbt_parser , prod_database , prod_schema , "prod_<custom_schema> " , mock_model )
773770
774771 assert diff_vars .dev_path == [mock_model .database , mock_model .schema_ , mock_model .alias ]
775- assert diff_vars .prod_path == [prod_database , prod_schema + "_ " + mock_model .schema_ , mock_model .alias ]
772+ assert diff_vars .prod_path == [prod_database , "prod_ " + mock_model .schema_ , mock_model .alias ]
776773 assert diff_vars .primary_keys == primary_keys
777774 assert diff_vars .connection == mock_dbt_parser .connection
778775 assert diff_vars .threads == mock_dbt_parser .threads
776+ assert prod_schema not in diff_vars .prod_path
777+
779778 mock_dbt_parser .get_pk_from_model .assert_called_once ()
780779
781- def test_get_diff_vars_false_custom_schemas_prod_db_and_schema (self ):
780+ def test_get_diff_vars_static_custom_schema (self ):
782781 mock_model = Mock ()
783782 prod_database = "a_prod_db"
784783 prod_schema = "a_prod_schema"
785784 primary_keys = ["a_primary_key" ]
786785 mock_model .database = "a_dev_db"
787- mock_model .schema_ = "a_custom_dev_schema "
786+ mock_model .schema_ = "a_custom_schema "
788787 mock_model .config .schema_ = mock_model .schema_
789788 mock_model .alias = "a_model_name"
790789 mock_dbt_parser = Mock ()
791- mock_dbt_parser .get_pk_from_model .return_value = [ "a_primary_key" ]
790+ mock_dbt_parser .get_pk_from_model .return_value = primary_keys
792791 mock_dbt_parser .requires_upper = False
793792
794- diff_vars = _get_diff_vars (mock_dbt_parser , "a_prod_db" , "a_prod_schema" , mock_model , custom_schemas = False )
793+ diff_vars = _get_diff_vars (mock_dbt_parser , prod_database , prod_schema , "prod" , mock_model )
794+
795+ assert diff_vars .dev_path == [mock_model .database , mock_model .schema_ , mock_model .alias ]
796+ assert diff_vars .prod_path == [prod_database , "prod" , mock_model .alias ]
797+ assert diff_vars .primary_keys == primary_keys
798+ assert diff_vars .connection == mock_dbt_parser .connection
799+ assert diff_vars .threads == mock_dbt_parser .threads
800+ assert prod_schema not in diff_vars .prod_path
801+ mock_dbt_parser .get_pk_from_model .assert_called_once ()
802+
803+ def test_get_diff_vars_no_custom_schema_on_model (self ):
804+ mock_model = Mock ()
805+ prod_database = "a_prod_db"
806+ prod_schema = "a_prod_schema"
807+ primary_keys = ["a_primary_key" ]
808+ mock_model .database = "a_dev_db"
809+ mock_model .schema_ = "a_custom_schema"
810+ mock_model .config .schema_ = None
811+ mock_model .alias = "a_model_name"
812+ mock_dbt_parser = Mock ()
813+ mock_dbt_parser .get_pk_from_model .return_value = primary_keys
814+ mock_dbt_parser .requires_upper = False
815+
816+ diff_vars = _get_diff_vars (mock_dbt_parser , prod_database , prod_schema , "prod" , mock_model )
795817
796818 assert diff_vars .dev_path == [mock_model .database , mock_model .schema_ , mock_model .alias ]
797819 assert diff_vars .prod_path == [prod_database , prod_schema , mock_model .alias ]
@@ -800,23 +822,41 @@ def test_get_diff_vars_false_custom_schemas_prod_db_and_schema(self):
800822 assert diff_vars .threads == mock_dbt_parser .threads
801823 mock_dbt_parser .get_pk_from_model .assert_called_once ()
802824
803- def test_get_diff_vars_false_custom_schemas_prod_db (self ):
825+ def test_get_diff_vars_match_dev_schema (self ):
804826 mock_model = Mock ()
805827 prod_database = "a_prod_db"
806828 primary_keys = ["a_primary_key" ]
807829 mock_model .database = "a_dev_db"
808- mock_model .schema_ = "a_custom_dev_schema "
809- mock_model .config .schema_ = mock_model . schema_
830+ mock_model .schema_ = "a_schema "
831+ mock_model .config .schema_ = None
810832 mock_model .alias = "a_model_name"
811833 mock_dbt_parser = Mock ()
812- mock_dbt_parser .get_pk_from_model .return_value = [ "a_primary_key" ]
834+ mock_dbt_parser .get_pk_from_model .return_value = primary_keys
813835 mock_dbt_parser .requires_upper = False
814836
815- diff_vars = _get_diff_vars (mock_dbt_parser , "a_prod_db" , None , mock_model , custom_schemas = False )
837+ diff_vars = _get_diff_vars (mock_dbt_parser , prod_database , None , None , mock_model )
816838
817839 assert diff_vars .dev_path == [mock_model .database , mock_model .schema_ , mock_model .alias ]
818840 assert diff_vars .prod_path == [prod_database , mock_model .schema_ , mock_model .alias ]
819841 assert diff_vars .primary_keys == primary_keys
820842 assert diff_vars .connection == mock_dbt_parser .connection
821843 assert diff_vars .threads == mock_dbt_parser .threads
822844 mock_dbt_parser .get_pk_from_model .assert_called_once ()
845+
846+ def test_get_diff_custom_schema_no_config_exception (self ):
847+ mock_model = Mock ()
848+ prod_database = "a_prod_db"
849+ prod_schema = "a_prod_schema"
850+ primary_keys = ["a_primary_key" ]
851+ mock_model .database = "a_dev_db"
852+ mock_model .schema_ = "a_schema"
853+ mock_model .config .schema_ = "a_custom_schema"
854+ mock_model .alias = "a_model_name"
855+ mock_dbt_parser = Mock ()
856+ mock_dbt_parser .get_pk_from_model .return_value = primary_keys
857+ mock_dbt_parser .requires_upper = False
858+
859+ with self .assertRaises (ValueError ):
860+ _get_diff_vars (mock_dbt_parser , prod_database , prod_schema , None , mock_model )
861+
862+ mock_dbt_parser .get_pk_from_model .assert_called_once ()
0 commit comments