Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/code_checks.yml
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ jobs:
enable-cache: true

- name: "Set up Python"
uses: actions/setup-python@a309ff8b426b58ec0e2a45f0f869d46889d02405
uses: actions/setup-python@ece7cb06caefa5fff74198d8649806c4678c61a1
with:
python-version-file: ".python-version"

Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/docs.yml
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ jobs:
enable-cache: true

- name: Set up Python
uses: actions/setup-python@v6.2.0
uses: actions/setup-python@v6.3.0
with:
python-version-file: ".python-version"

Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/integration_tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ jobs:
enable-cache: true

- name: "Set up Python"
uses: actions/setup-python@a309ff8b426b58ec0e2a45f0f869d46889d02405
uses: actions/setup-python@ece7cb06caefa5fff74198d8649806c4678c61a1
with:
python-version-file: ".python-version"

Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/publish.yml
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ jobs:
enable-cache: true

- name: "Set up Python"
uses: actions/setup-python@a309ff8b426b58ec0e2a45f0f869d46889d02405
uses: actions/setup-python@ece7cb06caefa5fff74198d8649806c4678c61a1
with:
python-version-file: ".python-version"

Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/unit_tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ jobs:
enable-cache: true

- name: "Set up Python"
uses: actions/setup-python@a309ff8b426b58ec0e2a45f0f869d46889d02405
uses: actions/setup-python@ece7cb06caefa5fff74198d8649806c4678c61a1
with:
python-version-file: ".python-version"

Expand Down
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,9 @@ tests/integration/attacks/ensemble/assets/workspace
tests/integration/assets/tabsyn/processed_data
tests/integration/assets/tabsyn/results

# Emitted SynthEval analysis config file during metric creation. Unfortunately cannot be turned off...
SE_analysis_config.json

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

New Syntheval emits this json file and it cannot be disabled after looking at their code...


# Training Logs
*.err
*.out
Expand Down
3 changes: 3 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -137,3 +137,6 @@ max-doc-length = 119
markers = [
"integration_test: marks tests as integration tests",
]
env = [
"OMP_NUM_THREADS=1", # Forces single threading in tests to avoid segfaults due to nested spawning
]
4 changes: 2 additions & 2 deletions src/midst_toolkit/attacks/tartan_federer/data_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,8 +139,8 @@ def save_results_and_plot_roc_curve(
plt.figure(figsize=(8, 6))
plt.plot(fpr, tpr, color="darkorange", lw=2, label=f"ROC curve (AUC = {roc_auc:.4f})")
plt.plot([0, 1], [0, 1], color="navy", lw=2, linestyle="--")
plt.xlim([0.0, 1.0])
plt.ylim([0.0, 1.05])

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Mypy complained about this.

plt.xlim((0.0, 1.0))
plt.ylim((0.0, 1.05))
plt.xlabel("False Positive Rate")
plt.ylabel("True Positive Rate")
plt.title("ROC Curve")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -127,9 +127,15 @@ def make_dataset_from_df_with_loaded(
table_metadata,
is_target_conditioned,
)
numerical_features = {DataSplit.TRAIN.value: data[numerical_column_names].values.astype(np.float32)}
categorical_features = {DataSplit.TRAIN.value: data[categorical_column_names].to_numpy()}
targets = {DataSplit.TRAIN.value: data[[table_metadata.target_column_name]].values.astype(np.float32)}
numerical_features: dict[str, np.ndarray] = {
DataSplit.TRAIN.value: data[numerical_column_names].values.astype(np.float32)
}
categorical_features: dict[str, np.ndarray] = {
DataSplit.TRAIN.value: data[categorical_column_names].to_numpy(dtype=np.str_)
}
targets: dict[str, np.ndarray] = {
DataSplit.TRAIN.value: data[[table_metadata.target_column_name]].values.astype(np.float32)
}

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Mypy complained about the typing.


if len(categorical_column_names) > 0:
all_categorical_features = categorical_features[DataSplit.TRAIN.value]
Expand Down
9 changes: 6 additions & 3 deletions src/midst_toolkit/data_processing/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

import numpy as np
import pandas as pd
from pandas.api.types import is_object_dtype, is_string_dtype
from sklearn.preprocessing import MinMaxScaler, OrdinalEncoder

from midst_toolkit.common.logger import log
Expand Down Expand Up @@ -191,9 +192,11 @@ def get_categorical_columns(dataframe: pd.DataFrame, threshold: int) -> list[str
categorical_variables: list[str] = []

for column_name in dataframe.columns:
# If dtype is an object (as str columns are), assume categorical
if dataframe[column_name].dtype == "object" or (
is_column_type_numerical(dataframe, column_name) and dataframe[column_name].nunique() <= threshold
# If dtype is an object or string type, assume categorical
if (
is_string_dtype(dataframe[column_name])
or is_object_dtype(dataframe[column_name])
or (is_column_type_numerical(dataframe, column_name) and dataframe[column_name].nunique() <= threshold)

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You'll see this throughout, but Pandas typing for strings has changed quite a bit. Now they have a separate string type you can use, along with the object type. So this catches that.

):
categorical_variables.append(column_name)

Expand Down
19 changes: 19 additions & 0 deletions src/midst_toolkit/evaluation/metrics_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,25 @@ def compute(self, real_data: pd.DataFrame, synthetic_data: pd.DataFrame) -> dict
"""
raise NotImplementedError("Inheriting class must define compute")

def validate_dataframe_dtypes(self, dataframe: pd.DataFrame) -> None:
"""
Validates that the dataframe does not contain string types. This is a requirement for many metrics in this
library, which require categorical columns to be preprocessed and encoded prior to computation.

Args:
dataframe: dataframe to validate.

Raises:
ValueError: If the dataframe contains string types.
"""
any_string_dtypes = any(
(isinstance(dtype, pd.StringDtype) or dtype.name == "str") for dtype in dataframe.dtypes
)
if any_string_dtypes:
raise ValueError(
"Dataframe contains string types. Categorical columns must be preprocessed and encoded prior to computation."
)


class SynthEvalMetric(MetricBase, ABC):
def __init__(
Expand Down
10 changes: 10 additions & 0 deletions src/midst_toolkit/evaluation/privacy/distance_closest_record.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,12 @@ def compute(
self.meta_info, real_data, synthetic_data, holdout_data
)

# Make sure the categorical columns are preprocessed and encoded before calling compute
self.validate_dataframe_dtypes(real_data)
self.validate_dataframe_dtypes(synthetic_data)
if holdout_data is not None:
self.validate_dataframe_dtypes(holdout_data)

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added these validations to make sure that string types are not sneaking into our metrics computations. These were added wherever strings would cause issues.


real_data_train_tensor = torch.tensor(real_data.to_numpy()).to(self.device)
real_data_test_tensor = torch.tensor(holdout_data.to_numpy()).to(self.device)
synthetic_data_tensor = torch.tensor(synthetic_data.to_numpy()).to(self.device)
Expand Down Expand Up @@ -192,6 +198,10 @@ def compute(self, real_data: pd.DataFrame, synthetic_data: pd.DataFrame) -> dict
real_data_tensor = torch.tensor(real_data.to_numpy()).to(self.device)
synthetic_data_tensor = torch.tensor(synthetic_data.to_numpy()).to(self.device)

# Make sure the categorical columns are preprocessed and encoded before calling compute
self.validate_dataframe_dtypes(real_data)
self.validate_dataframe_dtypes(synthetic_data)

dcr_synthetic_to_real = []
dcr_real_to_real = []

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,12 @@ def compute(
else:
raise ValueError(f"Unrecognized EpsilonIdentifiabilityNorm Option: {self.norm}")

# Make sure the categorical columns are preprocessed and encoded before calling compute
self.validate_dataframe_dtypes(filtered_real_data)
self.validate_dataframe_dtypes(filtered_synthetic_data)
if filtered_holdout_data is not None:
self.validate_dataframe_dtypes(filtered_holdout_data)

self.syntheval_metric = EpsilonIdentifiability(
real_data=filtered_real_data,
synt_data=filtered_synthetic_data,
Expand All @@ -134,6 +140,7 @@ def compute(
do_preprocessing=False,
verbose=False,
nn_dist=self.norm.value,
plot_figures=False,

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In the new SynthEval, they generated plots by default, which was very annoying. This shuts them off.

)
result = self.syntheval_metric.evaluate()
result["epsilon_identifiability_risk"] = result.pop("eps_risk")
Expand Down
5 changes: 5 additions & 0 deletions src/midst_toolkit/evaluation/privacy/hitting_rate.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,10 @@ def compute(self, real_data: pd.DataFrame, synthetic_data: pd.DataFrame) -> dict
filtered_real_data = real_data[self.all_columns]
filtered_synthetic_data = synthetic_data[self.all_columns]

# Make sure the categorical columns are preprocessed and encoded before calling compute
self.validate_dataframe_dtypes(filtered_real_data)
self.validate_dataframe_dtypes(filtered_synthetic_data)

self.syntheval_metric = SynthEvalHittingRate(
real_data=filtered_real_data,
synt_data=filtered_synthetic_data,
Expand All @@ -86,6 +90,7 @@ def compute(self, real_data: pd.DataFrame, synthetic_data: pd.DataFrame) -> dict
num_cols=self.numerical_columns,
do_preprocessing=False,
verbose=False,
plot_figures=False,
)
result = self.syntheval_metric.evaluate(self.hitting_threshold)
result["hitting_rate"] = result.pop("hit rate")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,12 @@ def compute(
self.meta_info, real_data, synthetic_data, holdout_data
)

# Make sure the categorical columns are preprocessed and encoded before calling compute
self.validate_dataframe_dtypes(real_data)
self.validate_dataframe_dtypes(synthetic_data)
if holdout_data is not None:
self.validate_dataframe_dtypes(holdout_data)

synthetic_data_tensor = torch.tensor(synthetic_data.to_numpy()).to(self.device)
real_data_tensor = torch.tensor(real_data.to_numpy()).to(self.device)
mean_nndr, nndr_standard_error = self._compute_mean_nearest_neighbor_distance_ratio(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,7 @@ def compute(self, real_data: pd.DataFrame, synthetic_data: pd.DataFrame) -> dict
num_cols=self.numerical_columns,
do_preprocessing=False,
verbose=False,
plot_figures=False,
)

return self.syntheval_metric.evaluate(self.confidence_level.value)
return self.syntheval_metric.evaluate(ci="sem", confidence=self.confidence_level.value)

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

New upgrade changed this method signature a bit. So we're forcing consistency here.

Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ def compute(self, real_data: pd.DataFrame, synthetic_data: pd.DataFrame) -> dict
num_cols=self.numerical_columns,
do_preprocessing=False,
verbose=False,
plot_figures=False,
)

return self.syntheval_metric.evaluate(mixed_corr=self.compute_mixed_correlations, return_mats=False)
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import pandas as pd
from syntheval.metrics.utility.metric_dimensionwise_means import MetricClassName as SynthEvalDwm
from syntheval.metrics.utility.metric_dimensionwise_means import DimensionWiseMeans

from midst_toolkit.evaluation.metrics_base import SynthEvalMetric

Expand Down Expand Up @@ -27,13 +27,14 @@ def compute(self, real_data: pd.DataFrame, synthetic_data: pd.DataFrame) -> dict
if self.do_preprocess:
real_data, synthetic_data = self.preprocess(real_data, synthetic_data)

self.syntheval_metric = SynthEvalDwm(
self.syntheval_metric = DimensionWiseMeans(
real_data=real_data,
synt_data=synthetic_data,
cat_cols=self.categorical_columns,
num_cols=self.numerical_columns,
do_preprocessing=False,
verbose=False,
plot_figures=False,
)

result = self.syntheval_metric.evaluate()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@ def compute(self, real_data: pd.DataFrame, synthetic_data: pd.DataFrame) -> dict
num_cols=self.numerical_columns,
do_preprocessing=False,
verbose=False,
plot_figures=False,
)

return self.syntheval_metric.evaluate(sig_lvl=self.significance_level, n_perms=self.permutations)
Loading