|
15 | 15 | _generate_heatmap_plot, |
16 | 16 | _generate_line_plot, |
17 | 17 | _generate_scatter_plot, |
| 18 | + _generate_violin_plot, |
18 | 19 | _get_cardinality, |
19 | 20 | _get_max_between_datasets, |
20 | 21 | _get_min_between_datasets, |
@@ -1004,6 +1005,47 @@ def test__generate_box_plot_title_one_dataset_only(px_mock): |
1004 | 1005 | assert fig_real == mock_figure |
1005 | 1006 |
|
1006 | 1007 |
|
| 1008 | +@patch('sdmetrics.visualization.px') |
| 1009 | +def test__generate_violin_plot(px_mock): |
| 1010 | + """Test the ``_generate_violin_plot`` method.""" |
| 1011 | + # Setup |
| 1012 | + real_column = pd.DataFrame({ |
| 1013 | + 'col1': [1, 2, 3, 4], |
| 1014 | + 'col2': ['a', 'b', 'c', 'd'], |
| 1015 | + 'Data': ['Real'] * 4, |
| 1016 | + }) |
| 1017 | + synthetic_column = pd.DataFrame({ |
| 1018 | + 'col1': [1, 2, 4, 5], |
| 1019 | + 'col2': ['a', 'b', 'c', 'd'], |
| 1020 | + 'Data': ['Synthetic'] * 4, |
| 1021 | + }) |
| 1022 | + columns = ['col1', 'col2'] |
| 1023 | + all_data = pd.concat([real_column, synthetic_column], axis=0, ignore_index=True) |
| 1024 | + |
| 1025 | + mock_figure = Mock() |
| 1026 | + px_mock.violin.return_value = mock_figure |
| 1027 | + |
| 1028 | + # Run |
| 1029 | + fig = _generate_violin_plot(all_data, columns) |
| 1030 | + |
| 1031 | + # Assert |
| 1032 | + px_mock.violin.assert_called_once_with( |
| 1033 | + DataFrameMatcher(all_data), |
| 1034 | + x='col1', |
| 1035 | + y='col2', |
| 1036 | + box=False, |
| 1037 | + violinmode='overlay', |
| 1038 | + color='Data', |
| 1039 | + color_discrete_map={'Real': '#000036', 'Synthetic': '#01E0C9'}, |
| 1040 | + ) |
| 1041 | + mock_figure.update_layout.assert_called_once_with( |
| 1042 | + title="Real vs. Synthetic Data for columns 'col1' and 'col2'", |
| 1043 | + plot_bgcolor='#F5F5F8', |
| 1044 | + font={'size': 18}, |
| 1045 | + ) |
| 1046 | + assert fig == mock_figure |
| 1047 | + |
| 1048 | + |
1007 | 1049 | def test_get_column_pair_plot_invalid_column_names(): |
1008 | 1050 | """Test ``get_column_pair_plot`` method with invalid ``column_names``.""" |
1009 | 1051 | # Setup |
@@ -1049,12 +1091,34 @@ def test_get_column_pair_plot_invalid_plot_type(): |
1049 | 1091 |
|
1050 | 1092 | # Run and Assert |
1051 | 1093 | match = re.escape( |
1052 | | - "Invalid plot_type 'distplot'. Please use one of ['box', 'heatmap', 'scatter', None]." |
| 1094 | + "Invalid plot_type 'distplot'. Please use one of ['box', 'heatmap', 'scatter'," |
| 1095 | + " 'violin', None]." |
1053 | 1096 | ) |
1054 | 1097 | with pytest.raises(ValueError, match=match): |
1055 | 1098 | get_column_pair_plot(real_data, synthetic_data, columns, plot_type='distplot') |
1056 | 1099 |
|
1057 | 1100 |
|
| 1101 | +@patch('sdmetrics.visualization._generate_violin_plot') |
| 1102 | +def test_get_column_pair_plot_violin(mock__generate_violin_plot): |
| 1103 | + """Test ``get_column_pair_plot`` with ``plot_type`` set to ``violin``.""" |
| 1104 | + # Setup |
| 1105 | + columns = ['amount', 'price'] |
| 1106 | + real_data = pd.DataFrame({'amount': [1, 2, 3], 'price': [4, 5, 6]}) |
| 1107 | + synthetic_data = pd.DataFrame({'amount': [1.0, 2.0, 3.0], 'price': [4.0, 5.0, 6.0]}) |
| 1108 | + all_data = pd.concat([real_data, synthetic_data], axis=0, ignore_index=True) |
| 1109 | + all_data['Data'] = ['Real'] * 3 + ['Synthetic'] * 3 |
| 1110 | + |
| 1111 | + # Run |
| 1112 | + fig = get_column_pair_plot(real_data, synthetic_data, columns, plot_type='violin') |
| 1113 | + |
| 1114 | + # Assert |
| 1115 | + mock__generate_violin_plot.assert_called_once_with( |
| 1116 | + DataFrameMatcher(all_data), |
| 1117 | + ['amount', 'price'], |
| 1118 | + ) |
| 1119 | + assert fig == mock__generate_violin_plot.return_value |
| 1120 | + |
| 1121 | + |
1058 | 1122 | @patch('sdmetrics.visualization._generate_scatter_plot') |
1059 | 1123 | def test_get_column_pair_plot_plot_type_none_continuous_data(mock__generate_scatter_plot): |
1060 | 1124 | """Test ``get_column_pair_plot`` with continuous data and ``plot_type`` ``None``.""" |
|
0 commit comments