Skip to content

Commit e0a3a3f

Browse files
authored
Allow me to turn off or control any subsampling done within the quality report (#792)
1 parent 49cc559 commit e0a3a3f

File tree

8 files changed

+38
-6
lines changed

8 files changed

+38
-6
lines changed

sdmetrics/reports/base_report.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
import tqdm
1515

1616
from sdmetrics._utils_metadata import _convert_datetime_column, _validate_metadata
17+
from sdmetrics.reports.utils import DEFAULT_NUM_ROWS_SUBSAMPLE
1718
from sdmetrics.visualization import set_plotly_config
1819

1920

@@ -27,6 +28,7 @@ def __init__(self):
2728
self._overall_score = None
2829
self.is_generated = False
2930
self._properties = {}
31+
self.num_rows_subsample = DEFAULT_NUM_ROWS_SUBSAMPLE
3032
self.report_info = {
3133
'report_type': self.__class__.__name__,
3234
'generated_date': None,
@@ -163,6 +165,7 @@ def generate(self, real_data, synthetic_data, metadata, verbose=True):
163165
f'({ind + 1}/{len(self._properties)}) Evaluating {property_name}'
164166
)
165167

168+
self._properties[property_name].num_rows_subsample = self.num_rows_subsample
166169
score = self._properties[property_name].get_score(
167170
real_data, synthetic_data, metadata, progress_bar=progress_bar
168171
)

sdmetrics/reports/multi_table/_properties/base.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@
33
import numpy as np
44
import pandas as pd
55

6+
from sdmetrics.reports.utils import DEFAULT_NUM_ROWS_SUBSAMPLE
7+
68

79
class BaseMultiTableProperty:
810
"""Base class for multi table properties.
@@ -26,6 +28,7 @@ def __init__(self):
2628
self._properties = {}
2729
self.is_computed = False
2830
self.details = pd.DataFrame()
31+
self.num_rows_subsample = DEFAULT_NUM_ROWS_SUBSAMPLE
2932

3033
def _get_num_iterations(self, metadata):
3134
"""Get the number of iterations for the property."""

sdmetrics/reports/single_table/_properties/base.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22

33
import pandas as pd
44

5+
from sdmetrics.reports.utils import DEFAULT_NUM_ROWS_SUBSAMPLE
6+
57

68
class BaseSingleTableProperty:
79
"""Base class for single table properties.
@@ -14,6 +16,7 @@ class BaseSingleTableProperty:
1416

1517
def __init__(self):
1618
self.details = pd.DataFrame()
19+
self.num_rows_subsample = DEFAULT_NUM_ROWS_SUBSAMPLE
1720

1821
def _compute_average(self):
1922
"""Average the scores for each column."""

sdmetrics/reports/single_table/_properties/column_pair_trends.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,6 @@
1010
from sdmetrics.reports.single_table._properties import BaseSingleTableProperty
1111
from sdmetrics.reports.utils import PlotConfig
1212

13-
DEFAULT_NUM_ROWS_SUBSAMPLE = 50000
14-
1513

1614
class ColumnPairTrends(BaseSingleTableProperty):
1715
"""Column pair trends property.
@@ -30,6 +28,7 @@ class ColumnPairTrends(BaseSingleTableProperty):
3028
}
3129

3230
def __init__(self):
31+
super().__init__()
3332
self._columns_datetime_conversion_failed = {}
3433
self._columns_discretization_failed = {}
3534

@@ -276,10 +275,12 @@ def _generate_details(
276275
)
277276

278277
metric_params = {}
279-
if (metric == ContingencySimilarity) and (
280-
max(len(col_real), len(col_synthetic)) > DEFAULT_NUM_ROWS_SUBSAMPLE
278+
if (
279+
self.num_rows_subsample
280+
and (metric == ContingencySimilarity)
281+
and (max(len(col_real), len(col_synthetic)) > self.num_rows_subsample)
281282
):
282-
metric_params['num_rows_subsample'] = DEFAULT_NUM_ROWS_SUBSAMPLE
283+
metric_params['num_rows_subsample'] = self.num_rows_subsample
283284

284285
try:
285286
error = self._preprocessing_failed(

sdmetrics/reports/utils.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717

1818
CONTINUOUS_SDTYPES = ['numerical', 'datetime']
1919
DISCRETE_SDTYPES = ['categorical', 'boolean']
20+
DEFAULT_NUM_ROWS_SUBSAMPLE = 50000
2021

2122

2223
class PlotConfig:

tests/integration/reports/single_table/test_quality_report.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,7 @@ def test_report_end_to_end(self):
7777
key: val for key, val in metadata['columns'].items() if key in column_names
7878
}
7979
report = QualityReport()
80+
report.num_rows_subsample = None
8081

8182
# Run
8283
generate_start_time = time.time()
@@ -141,7 +142,8 @@ def test_report_end_to_end(self):
141142
report.get_details('Column Pair Trends'), expected_details_cpt
142143
)
143144
assert report.get_score() == 0.8393750143888287
144-
145+
assert report._properties['Column Shapes'].num_rows_subsample is None
146+
assert report._properties['Column Pair Trends'].num_rows_subsample is None
145147
report_info = report.get_info()
146148
assert report_info == report.report_info
147149

@@ -183,6 +185,8 @@ def test_with_large_dataset(self):
183185
# Assert
184186
cpt_report_1 = report_1.get_properties().iloc[1]['Score']
185187
cpt_report_2 = report_2.get_properties().iloc[1]['Score']
188+
assert report_1._properties['Column Pair Trends'].num_rows_subsample == 50000
189+
assert report_2._properties['Column Pair Trends'].num_rows_subsample == 50000
186190
assert score_1_run_1 != score_1_run_2
187191
assert np.isclose(score_1_run_1, score_1_run_2, atol=0.001)
188192
assert np.isclose(report_2.get_score(), score_1_run_1, atol=0.001)

tests/unit/reports/multi_table/test_base_multi_table_report.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88

99
from sdmetrics.demos import load_demo
1010
from sdmetrics.reports.multi_table.base_multi_table_report import BaseMultiTableReport
11+
from sdmetrics.reports.utils import DEFAULT_NUM_ROWS_SUBSAMPLE
1112

1213

1314
class TestBaseReport:
@@ -21,6 +22,7 @@ def test__init__(self):
2122
assert report.is_generated is False
2223
assert report._properties == {}
2324
assert report.table_names == []
25+
assert report.num_rows_subsample == DEFAULT_NUM_ROWS_SUBSAMPLE
2426

2527
def test__validate_data_format(self):
2628
"""Test the ``_validate_data_format`` method.

tests/unit/reports/test_base_report.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,9 +9,21 @@
99

1010
from sdmetrics.demos import load_demo
1111
from sdmetrics.reports.base_report import BaseReport
12+
from sdmetrics.reports.utils import DEFAULT_NUM_ROWS_SUBSAMPLE
1213

1314

1415
class TestBaseReport:
16+
def test__init__(self):
17+
"""Test the initialization of the BaseReport class."""
18+
# Run
19+
base_report = BaseReport()
20+
21+
# Assert
22+
assert base_report._overall_score is None
23+
assert not base_report.is_generated
24+
assert base_report._properties == {}
25+
assert base_report.num_rows_subsample == DEFAULT_NUM_ROWS_SUBSAMPLE
26+
1527
def test__validate_data_format(self):
1628
"""Test the ``_validate_data_format`` method.
1729
@@ -268,6 +280,7 @@ def test_generate(self, version_mock, time_mock, datetime_mock):
268280
version_mock.return_value = 'version'
269281

270282
base_report = BaseReport()
283+
base_report.num_rows_subsample = 1000
271284
mock_validate = Mock()
272285
mock__print_results = Mock()
273286
base_report._print_results = mock__print_results
@@ -292,9 +305,11 @@ def test_generate(self, version_mock, time_mock, datetime_mock):
292305
base_report._properties['Property 1'].get_score.assert_called_with(
293306
real_data, synthetic_data, metadata, progress_bar=None
294307
)
308+
assert base_report._properties['Property 1'].num_rows_subsample == 1000
295309
base_report._properties['Property 2'].get_score.assert_called_with(
296310
real_data, synthetic_data, metadata, progress_bar=None
297311
)
312+
assert base_report._properties['Property 2'].num_rows_subsample == 1000
298313
expected_info = {
299314
'report_type': 'BaseReport',
300315
'generated_date': '2020-01-05',

0 commit comments

Comments
 (0)