@@ -425,20 +425,27 @@ def test_local_diff(self, mock_diff_tables):
425425 mock_diff = MagicMock ()
426426 mock_diff_tables .return_value = mock_diff
427427 mock_diff .__iter__ .return_value = [1 , 2 , 3 ]
428+ threads = None
429+ where = "a_string"
428430 dev_qualified_list = ["dev_db" , "dev_schema" , "dev_table" ]
429431 prod_qualified_list = ["prod_db" , "prod_schema" , "prod_table" ]
430432 expected_keys = ["key" ]
431- diff_vars = DiffVars (dev_qualified_list , prod_qualified_list , expected_keys , mock_connection , None )
433+ diff_vars = DiffVars (dev_qualified_list , prod_qualified_list , expected_keys , mock_connection , threads , where )
432434 with patch ("data_diff.dbt.connect_to_table" , side_effect = [mock_table1 , mock_table2 ]) as mock_connect :
433435 _local_diff (diff_vars )
434436
435437 mock_diff_tables .assert_called_once_with (
436- mock_table1 , mock_table2 , threaded = True , algorithm = Algorithm .JOINDIFF , extra_columns = ANY
438+ mock_table1 ,
439+ mock_table2 ,
440+ threaded = True ,
441+ algorithm = Algorithm .JOINDIFF ,
442+ extra_columns = ANY ,
443+ where = where ,
437444 )
438445 self .assertEqual (len (mock_diff_tables .call_args [1 ]["extra_columns" ]), 2 )
439446 self .assertEqual (mock_connect .call_count , 2 )
440- mock_connect .assert_any_call (mock_connection , "." .join (dev_qualified_list ), tuple (expected_keys ), None )
441- mock_connect .assert_any_call (mock_connection , "." .join (prod_qualified_list ), tuple (expected_keys ), None )
447+ mock_connect .assert_any_call (mock_connection , "." .join (dev_qualified_list ), tuple (expected_keys ), threads )
448+ mock_connect .assert_any_call (mock_connection , "." .join (prod_qualified_list ), tuple (expected_keys ), threads )
442449 mock_diff .get_stats_string .assert_called_once ()
443450
444451 @patch ("data_diff.dbt.diff_tables" )
@@ -455,12 +462,14 @@ def test_local_diff_no_diffs(self, mock_diff_tables):
455462 dev_qualified_list = ["dev_db" , "dev_schema" , "dev_table" ]
456463 prod_qualified_list = ["prod_db" , "prod_schema" , "prod_table" ]
457464 expected_keys = ["primary_key_column" ]
458- diff_vars = DiffVars (dev_qualified_list , prod_qualified_list , expected_keys , mock_connection , None )
465+ threads = None
466+ where = "a_string"
467+ diff_vars = DiffVars (dev_qualified_list , prod_qualified_list , expected_keys , mock_connection , threads , where )
459468 with patch ("data_diff.dbt.connect_to_table" , side_effect = [mock_table1 , mock_table2 ]) as mock_connect :
460469 _local_diff (diff_vars )
461470
462471 mock_diff_tables .assert_called_once_with (
463- mock_table1 , mock_table2 , threaded = True , algorithm = Algorithm .JOINDIFF , extra_columns = ANY
472+ mock_table1 , mock_table2 , threaded = True , algorithm = Algorithm .JOINDIFF , extra_columns = ANY , where = where
464473 )
465474 self .assertEqual (len (mock_diff_tables .call_args [1 ]["extra_columns" ]), 2 )
466475 self .assertEqual (mock_connect .call_count , 2 )
@@ -479,7 +488,10 @@ def test_cloud_diff(self, mock_api, mock_os_environ, mock_print):
479488 prod_qualified_list = ["prod_db" , "prod_schema" , "prod_table" ]
480489 expected_datasource_id = 1
481490 expected_primary_keys = ["primary_key_column" ]
482- diff_vars = DiffVars (dev_qualified_list , prod_qualified_list , expected_primary_keys , None , None )
491+ connection = None
492+ threads = None
493+ where = "a_string"
494+ diff_vars = DiffVars (dev_qualified_list , prod_qualified_list , expected_primary_keys , connection , threads , where )
483495 _cloud_diff (diff_vars , expected_datasource_id , api = mock_api )
484496
485497 mock_api .create_data_diff .assert_called_once ()
@@ -491,6 +503,8 @@ def test_cloud_diff(self, mock_api, mock_os_environ, mock_print):
491503 self .assertEqual (payload .table1 , prod_qualified_list )
492504 self .assertEqual (payload .table2 , dev_qualified_list )
493505 self .assertEqual (payload .pk_columns , expected_primary_keys )
506+ self .assertEqual (payload .filter1 , where )
507+ self .assertEqual (payload .filter2 , where )
494508
495509 @patch ("data_diff.dbt._initialize_api" )
496510 @patch ("data_diff.dbt._get_diff_vars" )
@@ -512,11 +526,14 @@ def test_diff_is_cloud(
512526 api_key = "a_api_key"
513527 api = DatafoldAPI (api_key = api_key , host = host )
514528 mock_initialize_api .return_value = api
529+ connection = None
530+ threads = None
531+ where = "a_string"
515532
516533 mock_dbt_parser .return_value = mock_dbt_parser_inst
517534 mock_dbt_parser_inst .get_models .return_value = [mock_model ]
518535 mock_dbt_parser_inst .get_datadiff_variables .return_value = expected_dbt_vars_dict
519- expected_diff_vars = DiffVars (["dev" ], ["prod" ], ["pks" ], None , None )
536+ expected_diff_vars = DiffVars (["dev" ], ["prod" ], ["pks" ], connection , threads , where )
520537 mock_get_diff_vars .return_value = expected_diff_vars
521538 dbt_diff (is_cloud = True )
522539 mock_dbt_parser_inst .get_models .assert_called_once ()
@@ -547,11 +564,14 @@ def test_diff_is_cloud_no_ds_id(
547564 api_key = "a_api_key"
548565 api = DatafoldAPI (api_key = api_key , host = host )
549566 mock_initialize_api .return_value = api
567+ connection = None
568+ threads = None
569+ where = "a_string"
550570
551571 mock_dbt_parser .return_value = mock_dbt_parser_inst
552572 mock_dbt_parser_inst .get_models .return_value = [mock_model ]
553573 mock_dbt_parser_inst .get_datadiff_variables .return_value = expected_dbt_vars_dict
554- expected_diff_vars = DiffVars (["dev" ], ["prod" ], ["pks" ], None , None )
574+ expected_diff_vars = DiffVars (["dev" ], ["prod" ], ["pks" ], connection , threads , where )
555575 mock_get_diff_vars .return_value = expected_diff_vars
556576
557577 with self .assertRaises (ValueError ):
@@ -579,7 +599,10 @@ def test_diff_is_not_cloud(self, mock_print, mock_dbt_parser, mock_cloud_diff, m
579599 }
580600 mock_dbt_parser_inst .get_models .return_value = [mock_model ]
581601 mock_dbt_parser_inst .get_datadiff_variables .return_value = expected_dbt_vars_dict
582- expected_diff_vars = DiffVars (["dev" ], ["prod" ], ["pks" ], None , None )
602+ connection = None
603+ threads = None
604+ where = "a_string"
605+ expected_diff_vars = DiffVars (["dev" ], ["prod" ], ["pks" ], connection , threads , where )
583606 mock_get_diff_vars .return_value = expected_diff_vars
584607 dbt_diff (is_cloud = False )
585608
@@ -606,7 +629,10 @@ def test_diff_no_prod_configs(
606629
607630 mock_dbt_parser_inst .get_models .return_value = [mock_model ]
608631 mock_dbt_parser_inst .get_datadiff_variables .return_value = expected_dbt_vars_dict
609- expected_diff_vars = DiffVars (["dev" ], ["prod" ], ["pks" ], None , None )
632+ connection = None
633+ threads = None
634+ where = "a_string"
635+ expected_diff_vars = DiffVars (["dev" ], ["prod" ], ["pks" ], connection , threads , where )
610636 mock_get_diff_vars .return_value = expected_diff_vars
611637 with self .assertRaises (ValueError ):
612638 dbt_diff (is_cloud = False )
@@ -633,7 +659,10 @@ def test_diff_only_prod_db(self, mock_print, mock_dbt_parser, mock_cloud_diff, m
633659 }
634660 mock_dbt_parser_inst .get_models .return_value = [mock_model ]
635661 mock_dbt_parser_inst .get_datadiff_variables .return_value = expected_dbt_vars_dict
636- expected_diff_vars = DiffVars (["dev" ], ["prod" ], ["pks" ], None , None )
662+ connection = None
663+ threads = None
664+ where = "a_string"
665+ expected_diff_vars = DiffVars (["dev" ], ["prod" ], ["pks" ], connection , threads , where )
637666 mock_get_diff_vars .return_value = expected_diff_vars
638667 dbt_diff (is_cloud = False )
639668
@@ -661,7 +690,10 @@ def test_diff_only_prod_schema(
661690
662691 mock_dbt_parser_inst .get_models .return_value = [mock_model ]
663692 mock_dbt_parser_inst .get_datadiff_variables .return_value = expected_dbt_vars_dict
664- expected_diff_vars = DiffVars (["dev" ], ["prod" ], ["pks" ], None , None )
693+ connection = None
694+ threads = None
695+ where = "a_string"
696+ expected_diff_vars = DiffVars (["dev" ], ["prod" ], ["pks" ], connection , threads , where )
665697 mock_get_diff_vars .return_value = expected_diff_vars
666698 with self .assertRaises (ValueError ):
667699 dbt_diff (is_cloud = False )
@@ -697,7 +729,10 @@ def test_diff_is_cloud_no_pks(
697729
698730 mock_dbt_parser_inst .get_models .return_value = [mock_model ]
699731 mock_dbt_parser_inst .get_datadiff_variables .return_value = expected_dbt_vars_dict
700- expected_diff_vars = DiffVars (["dev" ], ["prod" ], [], None , None )
732+ connection = None
733+ threads = None
734+ where = "a_string"
735+ expected_diff_vars = DiffVars (["dev" ], ["prod" ], [], connection , threads , where )
701736 mock_get_diff_vars .return_value = expected_diff_vars
702737 dbt_diff (is_cloud = True )
703738
@@ -727,8 +762,10 @@ def test_diff_not_is_cloud_no_pks(
727762
728763 mock_dbt_parser_inst .get_models .return_value = [mock_model ]
729764 mock_dbt_parser_inst .get_datadiff_variables .return_value = expected_dbt_vars_dict
730-
731- expected_diff_vars = DiffVars (["dev" ], ["prod" ], [], None , None )
765+ connection = None
766+ threads = None
767+ where = "a_string"
768+ expected_diff_vars = DiffVars (["dev" ], ["prod" ], [], connection , threads , where )
732769 mock_get_diff_vars .return_value = expected_diff_vars
733770 dbt_diff (is_cloud = False )
734771 mock_dbt_parser_inst .get_models .assert_called_once ()
@@ -749,6 +786,7 @@ def test_get_diff_vars_replace_custom_schema(self):
749786 mock_dbt_parser = Mock ()
750787 mock_dbt_parser .get_pk_from_model .return_value = primary_keys
751788 mock_dbt_parser .requires_upper = False
789+ mock_model .meta = None
752790
753791 diff_vars = _get_diff_vars (mock_dbt_parser , prod_database , prod_schema , "prod_<custom_schema>" , mock_model )
754792
@@ -773,6 +811,7 @@ def test_get_diff_vars_static_custom_schema(self):
773811 mock_dbt_parser = Mock ()
774812 mock_dbt_parser .get_pk_from_model .return_value = primary_keys
775813 mock_dbt_parser .requires_upper = False
814+ mock_model .meta = None
776815
777816 diff_vars = _get_diff_vars (mock_dbt_parser , prod_database , prod_schema , "prod" , mock_model )
778817
@@ -796,6 +835,7 @@ def test_get_diff_vars_no_custom_schema_on_model(self):
796835 mock_dbt_parser = Mock ()
797836 mock_dbt_parser .get_pk_from_model .return_value = primary_keys
798837 mock_dbt_parser .requires_upper = False
838+ mock_model .meta = None
799839
800840 diff_vars = _get_diff_vars (mock_dbt_parser , prod_database , prod_schema , "prod" , mock_model )
801841
@@ -817,6 +857,7 @@ def test_get_diff_vars_match_dev_schema(self):
817857 mock_dbt_parser = Mock ()
818858 mock_dbt_parser .get_pk_from_model .return_value = primary_keys
819859 mock_dbt_parser .requires_upper = False
860+ mock_model .meta = None
820861
821862 diff_vars = _get_diff_vars (mock_dbt_parser , prod_database , None , None , mock_model )
822863
@@ -844,3 +885,75 @@ def test_get_diff_custom_schema_no_config_exception(self):
844885 _get_diff_vars (mock_dbt_parser , prod_database , prod_schema , None , mock_model )
845886
846887 mock_dbt_parser .get_pk_from_model .assert_called_once ()
888+
889+ def test_get_diff_vars_meta_where (self ):
890+ mock_model = Mock ()
891+ prod_database = "a_prod_db"
892+ primary_keys = ["a_primary_key" ]
893+ mock_model .database = "a_dev_db"
894+ mock_model .schema_ = "a_schema"
895+ mock_model .config .schema_ = None
896+ mock_model .alias = "a_model_name"
897+ mock_dbt_parser = Mock ()
898+ mock_dbt_parser .get_pk_from_model .return_value = primary_keys
899+ mock_dbt_parser .requires_upper = False
900+ where = "a filter"
901+ mock_model .meta = {"datafold" : {"datadiff" : {"filter" : where }}}
902+
903+ diff_vars = _get_diff_vars (mock_dbt_parser , prod_database , None , None , mock_model )
904+
905+ assert diff_vars .dev_path == [mock_model .database , mock_model .schema_ , mock_model .alias ]
906+ assert diff_vars .prod_path == [prod_database , mock_model .schema_ , mock_model .alias ]
907+ assert diff_vars .primary_keys == primary_keys
908+ assert diff_vars .connection == mock_dbt_parser .connection
909+ assert diff_vars .threads == mock_dbt_parser .threads
910+ self .assertEqual (diff_vars .where_filter , where )
911+ mock_dbt_parser .get_pk_from_model .assert_called_once ()
912+
913+ def test_get_diff_vars_meta_unrelated (self ):
914+ mock_model = Mock ()
915+ prod_database = "a_prod_db"
916+ primary_keys = ["a_primary_key" ]
917+ mock_model .database = "a_dev_db"
918+ mock_model .schema_ = "a_schema"
919+ mock_model .config .schema_ = None
920+ mock_model .alias = "a_model_name"
921+ mock_dbt_parser = Mock ()
922+ mock_dbt_parser .get_pk_from_model .return_value = primary_keys
923+ mock_dbt_parser .requires_upper = False
924+ where = None
925+ mock_model .meta = {"key" : "value" }
926+
927+ diff_vars = _get_diff_vars (mock_dbt_parser , prod_database , None , None , mock_model )
928+
929+ assert diff_vars .dev_path == [mock_model .database , mock_model .schema_ , mock_model .alias ]
930+ assert diff_vars .prod_path == [prod_database , mock_model .schema_ , mock_model .alias ]
931+ assert diff_vars .primary_keys == primary_keys
932+ assert diff_vars .connection == mock_dbt_parser .connection
933+ assert diff_vars .threads == mock_dbt_parser .threads
934+ self .assertEqual (diff_vars .where_filter , where )
935+ mock_dbt_parser .get_pk_from_model .assert_called_once ()
936+
937+ def test_get_diff_vars_meta_none (self ):
938+ mock_model = Mock ()
939+ prod_database = "a_prod_db"
940+ primary_keys = ["a_primary_key" ]
941+ mock_model .database = "a_dev_db"
942+ mock_model .schema_ = "a_schema"
943+ mock_model .config .schema_ = None
944+ mock_model .alias = "a_model_name"
945+ mock_dbt_parser = Mock ()
946+ mock_dbt_parser .get_pk_from_model .return_value = primary_keys
947+ mock_dbt_parser .requires_upper = False
948+ where = None
949+ mock_model .meta = None
950+
951+ diff_vars = _get_diff_vars (mock_dbt_parser , prod_database , None , None , mock_model )
952+
953+ assert diff_vars .dev_path == [mock_model .database , mock_model .schema_ , mock_model .alias ]
954+ assert diff_vars .prod_path == [prod_database , mock_model .schema_ , mock_model .alias ]
955+ assert diff_vars .primary_keys == primary_keys
956+ assert diff_vars .connection == mock_dbt_parser .connection
957+ assert diff_vars .threads == mock_dbt_parser .threads
958+ self .assertEqual (diff_vars .where_filter , where )
959+ mock_dbt_parser .get_pk_from_model .assert_called_once ()
0 commit comments