@@ -554,6 +554,63 @@ def test_diff_no_prod_configs(
554554 mock_local_diff .assert_not_called ()
555555 mock_print .assert_not_called ()
556556
557+ @patch ("data_diff.dbt._get_diff_vars" )
558+ @patch ("data_diff.dbt._local_diff" )
559+ @patch ("data_diff.dbt._cloud_diff" )
560+ @patch ("data_diff.dbt.DbtParser.__new__" )
561+ @patch ("data_diff.dbt.rich.print" )
562+ def test_diff_only_prod_db (self , mock_print , mock_dbt_parser , mock_cloud_diff , mock_local_diff , mock_get_diff_vars ):
563+ mock_dbt_parser_inst = Mock ()
564+ mock_dbt_parser .return_value = mock_dbt_parser_inst
565+ mock_model = Mock ()
566+ expected_dbt_vars_dict = {
567+ "prod_database" : "prod_db" ,
568+ "datasource_id" : 1 ,
569+ }
570+ mock_dbt_parser_inst .get_models .return_value = [mock_model ]
571+ mock_dbt_parser_inst .get_datadiff_variables .return_value = expected_dbt_vars_dict
572+ expected_diff_vars = DiffVars (["dev" ], ["prod" ], ["pks" ], 123 , None )
573+ mock_get_diff_vars .return_value = expected_diff_vars
574+ dbt_diff (is_cloud = False )
575+
576+ mock_dbt_parser_inst .get_models .assert_called_once ()
577+ mock_dbt_parser_inst .set_project_dict .assert_called_once ()
578+ mock_dbt_parser_inst .set_connection .assert_called_once ()
579+ mock_cloud_diff .assert_not_called ()
580+ mock_local_diff .assert_called_once_with (expected_diff_vars )
581+ mock_print .assert_called_once ()
582+
583+ @patch ("data_diff.dbt._get_diff_vars" )
584+ @patch ("data_diff.dbt._local_diff" )
585+ @patch ("data_diff.dbt._cloud_diff" )
586+ @patch ("data_diff.dbt.DbtParser.__new__" )
587+ @patch ("data_diff.dbt.rich.print" )
588+ def test_diff_only_prod_schema (
589+ self , mock_print , mock_dbt_parser , mock_cloud_diff , mock_local_diff , mock_get_diff_vars
590+ ):
591+ mock_dbt_parser_inst = Mock ()
592+ mock_dbt_parser .return_value = mock_dbt_parser_inst
593+ mock_model = Mock ()
594+ expected_dbt_vars_dict = {
595+ "datasource_id" : 1 ,
596+ "prod_schema" : "prod_schema" ,
597+ }
598+
599+ mock_dbt_parser_inst .get_models .return_value = [mock_model ]
600+ mock_dbt_parser_inst .get_datadiff_variables .return_value = expected_dbt_vars_dict
601+ expected_diff_vars = DiffVars (["dev" ], ["prod" ], ["pks" ], 123 , None )
602+ mock_get_diff_vars .return_value = expected_diff_vars
603+ with self .assertRaises (ValueError ):
604+ dbt_diff (is_cloud = False )
605+
606+ mock_dbt_parser_inst .get_models .assert_called_once ()
607+ mock_dbt_parser_inst .set_project_dict .assert_called_once ()
608+ mock_dbt_parser_inst .set_connection .assert_called_once ()
609+ mock_dbt_parser_inst .get_primary_keys .assert_not_called ()
610+ mock_cloud_diff .assert_not_called ()
611+ mock_local_diff .assert_not_called ()
612+ mock_print .assert_not_called ()
613+
557614 @patch ("data_diff.dbt._get_diff_vars" )
558615 @patch ("data_diff.dbt._local_diff" )
559616 @patch ("data_diff.dbt._cloud_diff" )
0 commit comments