Skip to content

Commit b8587ba

Browse files
committed
make release-tag: Merge branch 'main' into stable
2 parents 3c5bacd + 072f64a commit b8587ba

File tree

6 files changed

+106
-7
lines changed

6 files changed

+106
-7
lines changed

HISTORY.md

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,11 @@
11
# History
22

3+
## v0.21.0 - 2025-05-29
4+
5+
### New Features
6+
7+
* Add a violin plot visualization to compare a pair of columns - Issue [#759](https://github.com/sdv-dev/SDMetrics/issues/759) by @R-Palazzo
8+
39
## v0.20.1 - 2025-04-14
410

511
### Bugs Fixed

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -141,7 +141,7 @@ convention = 'google'
141141
add-ignore = ['D107', 'D407', 'D417']
142142

143143
[tool.bumpversion]
144-
current_version = "0.20.1"
144+
current_version = "0.21.0.dev1"
145145
parse = '(?P<major>\d+)\.(?P<minor>\d+)\.(?P<patch>\d+)(\.(?P<release>[a-z]+)(?P<candidate>\d+))?'
146146
serialize = [
147147
'{major}.{minor}.{patch}.{release}{candidate}',

sdmetrics/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
__author__ = 'MIT Data To AI Lab'
66
__email__ = 'dailabmit@gmail.com'
7-
__version__ = '0.20.1'
7+
__version__ = '0.21.0.dev1'
88

99
import sys
1010
import warnings as python_warnings

sdmetrics/visualization.py

Lines changed: 31 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -153,6 +153,33 @@ def _generate_box_plot(all_data, columns):
153153
return fig
154154

155155

156+
def _generate_violin_plot(data, columns):
157+
"""Return a violin plot for a given column pair."""
158+
fig = px.violin(
159+
data,
160+
x=columns[0],
161+
y=columns[1],
162+
box=False,
163+
violinmode='overlay',
164+
color='Data',
165+
color_discrete_map={
166+
'Real': PlotConfig.DATACEBO_DARK,
167+
'Synthetic': PlotConfig.DATACEBO_GREEN,
168+
},
169+
)
170+
171+
unique_values = data['Data'].unique()
172+
title = ' vs. '.join(unique_values)
173+
title += f" Data for columns '{columns[0]}' and '{columns[1]}'"
174+
fig.update_layout(
175+
title=title,
176+
plot_bgcolor=PlotConfig.BACKGROUND_COLOR,
177+
font={'size': PlotConfig.FONT_SIZE},
178+
)
179+
180+
return fig
181+
182+
156183
def _generate_scatter_plot(all_data, columns):
157184
"""Generate a scatter plot for column pair plot.
158185
@@ -615,10 +642,10 @@ def get_column_pair_plot(real_data, synthetic_data, column_names, plot_type=None
615642
)
616643
synthetic_data = synthetic_data[column_names]
617644

618-
if plot_type not in ['box', 'heatmap', 'scatter', None]:
645+
if plot_type not in ['box', 'heatmap', 'scatter', 'violin', None]:
619646
raise ValueError(
620647
f"Invalid plot_type '{plot_type}'. Please use one of "
621-
"['box', 'heatmap', 'scatter', None]."
648+
"['box', 'heatmap', 'scatter', 'violin', None]."
622649
)
623650

624651
if plot_type is None:
@@ -654,6 +681,8 @@ def get_column_pair_plot(real_data, synthetic_data, column_names, plot_type=None
654681
return _generate_scatter_plot(all_data, column_names)
655682
elif plot_type == 'heatmap':
656683
return _generate_heatmap_plot(all_data, column_names)
684+
elif plot_type == 'violin':
685+
return _generate_violin_plot(all_data, column_names)
657686

658687
return _generate_box_plot(all_data, column_names)
659688

static_code_analysis.txt

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
Run started:2025-04-14 17:22:41.156789
1+
Run started:2025-05-29 18:58:21.274050
22

33
Test results:
44
>> Issue: [B101:assert_used] Use of assert detected. The enclosed code will be removed when compiling to optimised byte code.
@@ -117,7 +117,7 @@ Test results:
117117
--------------------------------------------------
118118

119119
Code scanned:
120-
Total lines of code: 11344
120+
Total lines of code: 11374
121121
Total lines skipped (#nosec): 0
122122
Total potential issues skipped due to specifically being disabled (e.g., #nosec BXXX): 0
123123

tests/unit/test_visualization.py

Lines changed: 65 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
_generate_heatmap_plot,
1616
_generate_line_plot,
1717
_generate_scatter_plot,
18+
_generate_violin_plot,
1819
_get_cardinality,
1920
_get_max_between_datasets,
2021
_get_min_between_datasets,
@@ -1004,6 +1005,47 @@ def test__generate_box_plot_title_one_dataset_only(px_mock):
10041005
assert fig_real == mock_figure
10051006

10061007

1008+
@patch('sdmetrics.visualization.px')
1009+
def test__generate_violin_plot(px_mock):
1010+
"""Test the ``_generate_violin_plot`` method."""
1011+
# Setup
1012+
real_column = pd.DataFrame({
1013+
'col1': [1, 2, 3, 4],
1014+
'col2': ['a', 'b', 'c', 'd'],
1015+
'Data': ['Real'] * 4,
1016+
})
1017+
synthetic_column = pd.DataFrame({
1018+
'col1': [1, 2, 4, 5],
1019+
'col2': ['a', 'b', 'c', 'd'],
1020+
'Data': ['Synthetic'] * 4,
1021+
})
1022+
columns = ['col1', 'col2']
1023+
all_data = pd.concat([real_column, synthetic_column], axis=0, ignore_index=True)
1024+
1025+
mock_figure = Mock()
1026+
px_mock.violin.return_value = mock_figure
1027+
1028+
# Run
1029+
fig = _generate_violin_plot(all_data, columns)
1030+
1031+
# Assert
1032+
px_mock.violin.assert_called_once_with(
1033+
DataFrameMatcher(all_data),
1034+
x='col1',
1035+
y='col2',
1036+
box=False,
1037+
violinmode='overlay',
1038+
color='Data',
1039+
color_discrete_map={'Real': '#000036', 'Synthetic': '#01E0C9'},
1040+
)
1041+
mock_figure.update_layout.assert_called_once_with(
1042+
title="Real vs. Synthetic Data for columns 'col1' and 'col2'",
1043+
plot_bgcolor='#F5F5F8',
1044+
font={'size': 18},
1045+
)
1046+
assert fig == mock_figure
1047+
1048+
10071049
def test_get_column_pair_plot_invalid_column_names():
10081050
"""Test ``get_column_pair_plot`` method with invalid ``column_names``."""
10091051
# Setup
@@ -1049,12 +1091,34 @@ def test_get_column_pair_plot_invalid_plot_type():
10491091

10501092
# Run and Assert
10511093
match = re.escape(
1052-
"Invalid plot_type 'distplot'. Please use one of ['box', 'heatmap', 'scatter', None]."
1094+
"Invalid plot_type 'distplot'. Please use one of ['box', 'heatmap', 'scatter',"
1095+
" 'violin', None]."
10531096
)
10541097
with pytest.raises(ValueError, match=match):
10551098
get_column_pair_plot(real_data, synthetic_data, columns, plot_type='distplot')
10561099

10571100

1101+
@patch('sdmetrics.visualization._generate_violin_plot')
1102+
def test_get_column_pair_plot_violin(mock__generate_violin_plot):
1103+
"""Test ``get_column_pair_plot`` with ``plot_type`` set to ``violin``."""
1104+
# Setup
1105+
columns = ['amount', 'price']
1106+
real_data = pd.DataFrame({'amount': [1, 2, 3], 'price': [4, 5, 6]})
1107+
synthetic_data = pd.DataFrame({'amount': [1.0, 2.0, 3.0], 'price': [4.0, 5.0, 6.0]})
1108+
all_data = pd.concat([real_data, synthetic_data], axis=0, ignore_index=True)
1109+
all_data['Data'] = ['Real'] * 3 + ['Synthetic'] * 3
1110+
1111+
# Run
1112+
fig = get_column_pair_plot(real_data, synthetic_data, columns, plot_type='violin')
1113+
1114+
# Assert
1115+
mock__generate_violin_plot.assert_called_once_with(
1116+
DataFrameMatcher(all_data),
1117+
['amount', 'price'],
1118+
)
1119+
assert fig == mock__generate_violin_plot.return_value
1120+
1121+
10581122
@patch('sdmetrics.visualization._generate_scatter_plot')
10591123
def test_get_column_pair_plot_plot_type_none_continuous_data(mock__generate_scatter_plot):
10601124
"""Test ``get_column_pair_plot`` with continuous data and ``plot_type`` ``None``."""

0 commit comments

Comments
 (0)