11import os
22
33from pathlib import Path
4-
54from data_diff .cloud .datafold_api import TCloudApiDataSource
5+ from data_diff .cloud .datafold_api import TCloudApiOrgMeta
66from data_diff .diff_tables import Algorithm
77from .test_cli import run_datadiff_cli
88
@@ -569,6 +569,7 @@ def test_local_diff_no_diffs(self, mock_diff_tables):
569569 @patch ("data_diff.dbt.os.environ" )
570570 @patch ("data_diff.dbt.DatafoldAPI" )
571571 def test_cloud_diff (self , mock_api , mock_os_environ , mock_print ):
572+ org_meta = TCloudApiOrgMeta (org_id = 1 , org_name = "" , user_id = 1 )
572573 expected_api_key = "an_api_key"
573574 dev_qualified_list = ["dev_db" , "dev_schema" , "dev_table" ]
574575 prod_qualified_list = ["prod_db" , "prod_schema" , "prod_table" ]
@@ -591,7 +592,7 @@ def test_cloud_diff(self, mock_api, mock_os_environ, mock_print):
591592 exclude_columns = [],
592593 )
593594
594- _cloud_diff (diff_vars , expected_datasource_id , api = mock_api )
595+ _cloud_diff (diff_vars , expected_datasource_id , org_meta = org_meta , api = mock_api )
595596
596597 mock_api .create_data_diff .assert_called_once ()
597598 self .assertEqual (mock_print .call_count , 2 )
@@ -613,8 +614,16 @@ def test_cloud_diff(self, mock_api, mock_os_environ, mock_print):
613614 @patch ("data_diff.dbt.rich.print" )
614615 @patch ("data_diff.dbt.DatafoldAPI" )
615616 def test_diff_is_cloud (
616- self , mock_api , mock_print , mock_dbt_parser , mock_cloud_diff , mock_local_diff , mock_get_diff_vars , mock_initialize_api ,
617+ self ,
618+ mock_api ,
619+ mock_print ,
620+ mock_dbt_parser ,
621+ mock_cloud_diff ,
622+ mock_local_diff ,
623+ mock_get_diff_vars ,
624+ mock_initialize_api ,
617625 ):
626+ org_meta = TCloudApiOrgMeta (org_id = 1 , org_name = "" , user_id = 1 )
618627 connection = {}
619628 threads = None
620629 where = "a_string"
@@ -627,6 +636,8 @@ def test_diff_is_cloud(
627636 mock_model = Mock ()
628637 mock_api .get_data_source .return_value = TCloudApiDataSource (id = 1 , type = "snowflake" , name = "snowflake" )
629638 mock_initialize_api .return_value = mock_api
639+ mock_api .get_org_meta .return_value = org_meta
640+
630641 mock_dbt_parser .return_value = mock_dbt_parser_inst
631642 mock_dbt_parser_inst .get_models .return_value = [mock_model ]
632643 mock_dbt_parser_inst .get_datadiff_variables .return_value = expected_dbt_vars_dict
@@ -649,7 +660,7 @@ def test_diff_is_cloud(
649660
650661 mock_initialize_api .assert_called_once ()
651662 mock_api .get_data_source .assert_called_once_with (1 )
652- mock_cloud_diff .assert_called_once_with (diff_vars , 1 , mock_api )
663+ mock_cloud_diff .assert_called_once_with (diff_vars , 1 , mock_api , org_meta )
653664 mock_local_diff .assert_not_called ()
654665 mock_print .assert_called_once ()
655666
@@ -663,20 +674,20 @@ def test_diff_is_cloud(
663674 def test_diff_is_cloud_no_ds_id (
664675 self , _ , mock_print , mock_dbt_parser , mock_cloud_diff , mock_local_diff , mock_get_diff_vars , mock_initialize_api
665676 ):
677+ org_meta = TCloudApiOrgMeta (org_id = 1 , org_name = "" , user_id = 1 )
666678 connection = {}
667679 threads = None
668680 where = "a_string"
669- host = "a_host"
670- api_key = "a_api_key"
671681 mock_dbt_parser_inst = Mock ()
672682 mock_model = Mock ()
673683 expected_dbt_vars_dict = {
674684 "prod_database" : "prod_db" ,
675685 "prod_schema" : "prod_schema" ,
676686 }
687+ mock_api = Mock ()
688+ mock_initialize_api .return_value = mock_api
689+ mock_api .get_org_meta .return_value = org_meta
677690
678- api = DatafoldAPI (api_key = api_key , host = host )
679- mock_initialize_api .return_value = api
680691 mock_dbt_parser .return_value = mock_dbt_parser_inst
681692 mock_dbt_parser_inst .get_models .return_value = [mock_model ]
682693 mock_dbt_parser_inst .get_datadiff_variables .return_value = expected_dbt_vars_dict
@@ -827,8 +838,18 @@ def test_diff_only_prod_schema(
827838 @patch ("data_diff.dbt.rich.print" )
828839 @patch ("data_diff.dbt.DatafoldAPI" )
829840 def test_diff_is_cloud_no_pks (
830- self , mock_api , mock_print , mock_dbt_parser , mock_cloud_diff , mock_local_diff , mock_get_diff_vars , mock_initialize_api
841+ self ,
842+ mock_api ,
843+ mock_print ,
844+ mock_dbt_parser ,
845+ mock_cloud_diff ,
846+ mock_local_diff ,
847+ mock_get_diff_vars ,
848+ mock_initialize_api ,
831849 ):
850+ mock_dbt_parser_inst = Mock ()
851+ mock_dbt_parser .return_value = mock_dbt_parser_inst
852+ mock_model = Mock ()
832853 connection = {}
833854 threads = None
834855 where = "a_string"
@@ -837,11 +858,8 @@ def test_diff_is_cloud_no_pks(
837858 "prod_schema" : "prod_schema" ,
838859 "datasource_id" : 1 ,
839860 }
840- mock_dbt_parser_inst = Mock ()
841- mock_dbt_parser .return_value = mock_dbt_parser_inst
842- mock_model = Mock ()
861+ mock_api = Mock ()
843862 mock_initialize_api .return_value = mock_api
844- mock_api .get_data_source .return_value = TCloudApiDataSource (id = 1 , type = "snowflake" , name = "snowflake" )
845863
846864 mock_dbt_parser_inst .get_models .return_value = [mock_model ]
847865 mock_dbt_parser_inst .get_datadiff_variables .return_value = expected_dbt_vars_dict
0 commit comments