-
Notifications
You must be signed in to change notification settings - Fork 1
Libraries Upgrade et al. #150
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: marcelo/tabsyn-ensemble
Are you sure you want to change the base?
Changes from all commits
1739f18
8a1173a
c5eeb50
8bd398e
a9265a1
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -139,8 +139,8 @@ def save_results_and_plot_roc_curve( | |
| plt.figure(figsize=(8, 6)) | ||
| plt.plot(fpr, tpr, color="darkorange", lw=2, label=f"ROC curve (AUC = {roc_auc:.4f})") | ||
| plt.plot([0, 1], [0, 1], color="navy", lw=2, linestyle="--") | ||
| plt.xlim([0.0, 1.0]) | ||
| plt.ylim([0.0, 1.05]) | ||
|
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Mypy complained about this. |
||
| plt.xlim((0.0, 1.0)) | ||
| plt.ylim((0.0, 1.05)) | ||
| plt.xlabel("False Positive Rate") | ||
| plt.ylabel("True Positive Rate") | ||
| plt.title("ROC Curve") | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -127,9 +127,15 @@ def make_dataset_from_df_with_loaded( | |
| table_metadata, | ||
| is_target_conditioned, | ||
| ) | ||
| numerical_features = {DataSplit.TRAIN.value: data[numerical_column_names].values.astype(np.float32)} | ||
| categorical_features = {DataSplit.TRAIN.value: data[categorical_column_names].to_numpy()} | ||
| targets = {DataSplit.TRAIN.value: data[[table_metadata.target_column_name]].values.astype(np.float32)} | ||
| numerical_features: dict[str, np.ndarray] = { | ||
| DataSplit.TRAIN.value: data[numerical_column_names].values.astype(np.float32) | ||
| } | ||
| categorical_features: dict[str, np.ndarray] = { | ||
| DataSplit.TRAIN.value: data[categorical_column_names].to_numpy(dtype=np.str_) | ||
| } | ||
| targets: dict[str, np.ndarray] = { | ||
| DataSplit.TRAIN.value: data[[table_metadata.target_column_name]].values.astype(np.float32) | ||
| } | ||
|
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Mypy complained about the typing. |
||
|
|
||
| if len(categorical_column_names) > 0: | ||
| all_categorical_features = categorical_features[DataSplit.TRAIN.value] | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -18,6 +18,7 @@ | |
|
|
||
| import numpy as np | ||
| import pandas as pd | ||
| from pandas.api.types import is_object_dtype, is_string_dtype | ||
| from sklearn.preprocessing import MinMaxScaler, OrdinalEncoder | ||
|
|
||
| from midst_toolkit.common.logger import log | ||
|
|
@@ -191,9 +192,11 @@ def get_categorical_columns(dataframe: pd.DataFrame, threshold: int) -> list[str | |
| categorical_variables: list[str] = [] | ||
|
|
||
| for column_name in dataframe.columns: | ||
| # If dtype is an object (as str columns are), assume categorical | ||
| if dataframe[column_name].dtype == "object" or ( | ||
| is_column_type_numerical(dataframe, column_name) and dataframe[column_name].nunique() <= threshold | ||
| # If dtype is an object or string type, assume categorical | ||
| if ( | ||
| is_string_dtype(dataframe[column_name]) | ||
| or is_object_dtype(dataframe[column_name]) | ||
| or (is_column_type_numerical(dataframe, column_name) and dataframe[column_name].nunique() <= threshold) | ||
|
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. You'll see this throughout: pandas' typing for strings has changed quite a bit. There is now a separate string dtype you can use, alongside the object dtype, so this check catches both. |
||
| ): | ||
| categorical_variables.append(column_name) | ||
|
|
||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -93,6 +93,12 @@ def compute( | |
| self.meta_info, real_data, synthetic_data, holdout_data | ||
| ) | ||
|
|
||
| # Make sure the categorical columns are preprocessed and encoded before calling compute | ||
| self.validate_dataframe_dtypes(real_data) | ||
| self.validate_dataframe_dtypes(synthetic_data) | ||
| if holdout_data is not None: | ||
| self.validate_dataframe_dtypes(holdout_data) | ||
|
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Added these validations to make sure that string types are not sneaking into our metrics computations. These were added wherever strings would cause issues. |
||
|
|
||
| real_data_train_tensor = torch.tensor(real_data.to_numpy()).to(self.device) | ||
| real_data_test_tensor = torch.tensor(holdout_data.to_numpy()).to(self.device) | ||
| synthetic_data_tensor = torch.tensor(synthetic_data.to_numpy()).to(self.device) | ||
|
|
@@ -192,6 +198,10 @@ def compute(self, real_data: pd.DataFrame, synthetic_data: pd.DataFrame) -> dict | |
| real_data_tensor = torch.tensor(real_data.to_numpy()).to(self.device) | ||
| synthetic_data_tensor = torch.tensor(synthetic_data.to_numpy()).to(self.device) | ||
|
|
||
| # Make sure the categorical columns are preprocessed and encoded before calling compute | ||
| self.validate_dataframe_dtypes(real_data) | ||
| self.validate_dataframe_dtypes(synthetic_data) | ||
|
|
||
| dcr_synthetic_to_real = [] | ||
| dcr_real_to_real = [] | ||
|
|
||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -125,6 +125,12 @@ def compute( | |
| else: | ||
| raise ValueError(f"Unrecognized EpsilonIdentifiabilityNorm Option: {self.norm}") | ||
|
|
||
| # Make sure the categorical columns are preprocessed and encoded before calling compute | ||
| self.validate_dataframe_dtypes(filtered_real_data) | ||
| self.validate_dataframe_dtypes(filtered_synthetic_data) | ||
| if filtered_holdout_data is not None: | ||
| self.validate_dataframe_dtypes(filtered_holdout_data) | ||
|
|
||
| self.syntheval_metric = EpsilonIdentifiability( | ||
| real_data=filtered_real_data, | ||
| synt_data=filtered_synthetic_data, | ||
|
|
@@ -134,6 +140,7 @@ def compute( | |
| do_preprocessing=False, | ||
| verbose=False, | ||
| nn_dist=self.norm.value, | ||
| plot_figures=False, | ||
|
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. The new SynthEval generates plots by default, which was very annoying. This shuts them off. |
||
| ) | ||
| result = self.syntheval_metric.evaluate() | ||
| result["epsilon_identifiability_risk"] = result.pop("eps_risk") | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -91,6 +91,7 @@ def compute(self, real_data: pd.DataFrame, synthetic_data: pd.DataFrame) -> dict | |
| num_cols=self.numerical_columns, | ||
| do_preprocessing=False, | ||
| verbose=False, | ||
| plot_figures=False, | ||
| ) | ||
|
|
||
| return self.syntheval_metric.evaluate(self.confidence_level.value) | ||
| return self.syntheval_metric.evaluate(ci="sem", confidence=self.confidence_level.value) | ||
|
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. The upgrade changed this method's signature a bit, so we're forcing consistency here. |
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The new SynthEval emits this JSON file, and after looking at their code, it cannot be disabled...