diff --git a/examples/bcauss.py b/examples/bcauss.py
new file mode 100644
index 000000000..36b7209a9
--- /dev/null
+++ b/examples/bcauss.py
@@ -0,0 +1,181 @@
+"""Reported (and reproduced) E_ATT of BCAUSS, based on Table 1 of the paper:
+BCAUSS in-sample: 0.02 (reproduced: 0.0284).
+BCAUSS out-of-sample: 0.05 +/- 0.02 (reproduced: 0.0290).
+"""
+import argparse
+import copy
+import os.path as osp
+
+import torch
+from torch.optim.lr_scheduler import ExponentialLR
+from tqdm import tqdm
+
+from torch_frame import TensorFrame, stype
+from torch_frame.data import DataLoader, Dataset
+from torch_frame.datasets import Jobs
+from torch_frame.nn.models import BCAUSS
+
+parser = argparse.ArgumentParser()
+parser.add_argument("--batch_size", type=int, default=1024)
+parser.add_argument("--lr", type=float, default=0.00001)
+parser.add_argument("--epochs", type=int, default=5)
+parser.add_argument("--seed", type=int, default=2)
+parser.add_argument("--feature-engineering", action="store_true",
+                    default=True)
+parser.add_argument("--out-of-distribution", action="store_true",
+                    default=True)
+args = parser.parse_args()
+
+path = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', "jobs")
+dataset = Jobs(root=path, feature_engineering=args.feature_engineering)
+ATT = dataset.get_att()
+print(f"ATT is {ATT}")
+
+torch.manual_seed(args.seed)
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+dataset.materialize(path=osp.join(path, "data.pt"))
+
+dataset = dataset.shuffle()
+if args.out_of_distribution:
+    if dataset.split_col is None:
+        train_dataset, val_dataset, test_dataset = (dataset[:0.62],
+                                                    dataset[0.62:0.80],
+                                                    dataset[0.80:])
+    else:
+        train_dataset, _, test_dataset = dataset.split()
+        train_dataset, val_dataset = (train_dataset[:0.775],
+                                      train_dataset[0.775:])
+    # Construct the validation dataset from the randomized (source == 1)
+    # treated rows, plus a copy of those rows with the treatment zeroed out.
+    treated_df = val_dataset.df[(val_dataset.df['source'] == 1)
+                                & (val_dataset.df['treated'] == 1)]
+    treated_val_dataset = Dataset(treated_df, dataset.col_to_stype,
+                                  target_col='target')
+    control_df = copy.deepcopy(treated_df)
+    control_df['treated'] = 0
+    control_val_dataset = Dataset(control_df, dataset.col_to_stype,
+                                  target_col='target')
+
+    treated_val_dataset.materialize(
+        path=osp.join(path, "treated_val_data.pt"))
+    control_val_dataset.materialize(
+        path=osp.join(path, "control_val_data.pt"))
+    # Construct the evaluation dataset in the same way
+    treated_df = test_dataset.df[(test_dataset.df['source'] == 1)
+                                 & (test_dataset.df['treated'] == 1)]
+    treated_test_dataset = Dataset(treated_df, dataset.col_to_stype,
+                                   target_col='target')
+    control_df = copy.deepcopy(treated_df)
+    control_df['treated'] = 0
+    control_test_dataset = Dataset(control_df, dataset.col_to_stype,
+                                   target_col='target')
+
+    treated_test_dataset.materialize(
+        path=osp.join(path, "treated_test_data.pt"))
+    control_test_dataset.materialize(
+        path=osp.join(path, "control_test_data.pt"))
+else:
+    train_dataset = dataset
+
+    # Construct the evaluation dataset
+    treated_df = dataset.df[(dataset.df['source'] == 1)
+                            & (dataset.df['treated'] == 1)]
+    treated_test_dataset = Dataset(treated_df, dataset.col_to_stype,
+                                   target_col='target')
+    control_df = copy.deepcopy(treated_df)
+    control_df['treated'] = 0
+    control_test_dataset = Dataset(control_df, dataset.col_to_stype,
+                                   target_col='target')
+
+    treated_test_dataset.materialize(
+        path=osp.join(path, "treated_eval_data.pt"))
+    control_test_dataset.materialize(
+        path=osp.join(path, "control_eval_data.pt"))
+
+train_tensor_frame = train_dataset.tensor_frame
+treatment_idx = train_tensor_frame.col_names_dict[stype.categorical].index(
+    'treated')
+if args.out_of_distribution:
+    val_tensor_frame = val_dataset.tensor_frame
+    test_tensor_frame = test_dataset.tensor_frame
+    treated_val_tensor_frame = treated_val_dataset.tensor_frame
+    # Workaround: materialization re-indexes categories by frequency, which
+    # maps the all-ones 'treated' column to 0's. Restore the intended
+    # treatment assignment explicitly.
+    treated_val_tensor_frame.feat_dict[stype.categorical][:,
+                                                          treatment_idx] = 1.
+    control_val_tensor_frame = control_val_dataset.tensor_frame
+
+treated_test_tensor_frame = treated_test_dataset.tensor_frame
+# Workaround: see the note above on categorical re-indexing during
+# materialization.
+treated_test_tensor_frame.feat_dict[stype.categorical][:, treatment_idx] = 1.
+control_test_tensor_frame = control_test_dataset.tensor_frame
+
+train_loader = DataLoader(train_tensor_frame, batch_size=args.batch_size,
+                          shuffle=True)
+# val_loader = DataLoader(val_tensor_frame, batch_size=args.batch_size)
+# test_loader = DataLoader(test_tensor_frame, batch_size=args.batch_size)
+
+model = BCAUSS(
+    channels=train_tensor_frame.num_cols - 1,
+    hidden_channels=200,
+    decoder_hidden_channels=100,
+    out_channels=1,
+    col_stats=dataset.col_stats if not args.feature_engineering else None,
+    col_names_dict=train_tensor_frame.col_names_dict
+    if not args.feature_engineering else None,
+).to(device)
+
+optimizer = torch.optim.SGD(model.parameters(), lr=args.lr)
+lr_scheduler = ExponentialLR(optimizer, gamma=0.95)
+
+is_classification = True
+
+
+def train(epoch: int) -> float:
+    model.train()
+    loss_accum = total_count = 0
+
+    for tf in tqdm(train_loader, desc=f'Epoch: {epoch}'):
+        tf = tf.to(device)
+        out, balance_score = model(tf, treatment_index=treatment_idx)
+        treated_mask = tf.feat_dict[stype.categorical][:, treatment_idx] == 1
+        # Squared factual loss over treated and control rows, plus the
+        # (epsilon-weighted) self-supervised balance score.
+        loss = (
+            (torch.sum(treated_mask * torch.square(tf.y - out.squeeze(-1))) +
+             torch.sum(~treated_mask * torch.square(tf.y - out.squeeze(-1))))
+            / len(treated_mask) + balance_score)
+        optimizer.zero_grad()
+        loss.backward()
+        loss_accum += float(loss) * len(out)
+        total_count += len(out)
+        optimizer.step()
+    return loss_accum / total_count
+
+
+@torch.no_grad()
+def eval(treated: TensorFrame, control: TensorFrame) -> float:
+    model.eval()
+
+    treated = treated.to(device)
+    treated_effect, _ = model(treated, treatment_idx)
+
+    control = control.to(device)
+    control_effect, _ = model(control, treatment_idx)
+
+    # E_ATT: absolute deviation of the estimated ATT from the true ATT.
+    return torch.abs(ATT - torch.mean(treated_effect - control_effect))
+
+
+best_val_metric = float('inf')
+best_test_metric = float('inf')
+
+for epoch in range(1, args.epochs + 1):
+    train_loss = train(epoch)
+    lr_scheduler.step()
+    error = eval(treated_test_tensor_frame, control_test_tensor_frame)
+    if args.out_of_distribution:
+        val_error = eval(treated_val_tensor_frame, control_val_tensor_frame)
+        if val_error < best_val_metric:
+            best_val_metric = val_error
+            best_test_metric = error
+        print(f'Train Loss: {train_loss:.4f}, '
+              f'Val Error_ATT: {val_error:.4f}, '
+              f'Test Error_ATT: {error:.4f}')
+    else:
+        print(f'Train Loss: {train_loss:.4f}, Error_ATT: {error:.4f}')
+
+if args.out_of_distribution:
+    print(f'Best Val Error: {best_val_metric:.4f}, '
+          f'Best Test Error: {best_test_metric:.4f}')
diff --git a/examples/causalml.py b/examples/causalml.py
new file mode 100644
index 000000000..e8e80a5be
--- /dev/null
+++ b/examples/causalml.py
@@ -0,0 +1,250 @@
+import argparse
+import copy
+import os.path as osp
+
+import torch
+from torch.optim.lr_scheduler import ExponentialLR
+from tqdm import tqdm
+
+from torch_frame import TensorFrame, stype
+from torch_frame.data import DataLoader, Dataset
+from torch_frame.datasets import IHDP
+from torch_frame.nn.models import BCAUSS, CFR
+
+parser = argparse.ArgumentParser()
+parser.add_argument("--batch_size", type=int, default=200)
+parser.add_argument("--lr", type=float, default=0.001)
+parser.add_argument("--epochs", type=int, default=3000)
+parser.add_argument("--seed", type=int, default=2)
+parser.add_argument('--model', type=str, default='cfr-mmd',
+                    choices=["bcauss", "cfr-mmd"])
+parser.add_argument("--feature-engineering", action="store_true",
+                    default=True)
+parser.add_argument("--out-of-distribution", action="store_true",
+                    default=False)
+parser.add_argument("--lambda-reg", type=float, default=0.01,
+                    help="L2 regularization strength")
+parser.add_argument("--split-num", type=int, default=0)
+args = parser.parse_args()
+
+path = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', "ihdp")
+dataset = IHDP(root=path, split_num=args.split_num)
+ATE = dataset.get_att()
+print(f"True Average Treatment Effect is {ATE}")
+
+torch.manual_seed(args.seed)
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+dataset.materialize(path=osp.join(path, f"data_{args.split_num}.pt"))
+
+dataset = dataset.shuffle()
+within_sample_dataset, _, test_dataset = dataset.split()
+# The train and validation sets are used to compute the within-sample
+# metrics; the test set is used for the out-of-distribution metrics.
+train_dataset, val_dataset = (within_sample_dataset[:0.7],
+                              within_sample_dataset[0.7:])
+
+# The counterfactual sets are used to predict outcomes under the flipped
+# treatment assignment.
+counterfactual_train_df = copy.deepcopy(train_dataset.df)
+counterfactual_train_df['treated'] = 1 - counterfactual_train_df['treated']
+counterfactual_train_df['target'] = counterfactual_train_df[
+    'counterfactual_target']
+counterfactual_train_dataset = Dataset(counterfactual_train_df,
+                                       dataset.col_to_stype,
+                                       target_col='target')
+
+counterfactual_train_dataset.materialize(
+    path=osp.join(path, f"counterfactual_train_data_{args.split_num}.pt"))
+
+counterfactual_within_sample_df = copy.deepcopy(within_sample_dataset.df)
+counterfactual_within_sample_df[
+    'treated'] = 1 - counterfactual_within_sample_df['treated']
+counterfactual_within_sample_df['target'] = counterfactual_within_sample_df[
+    'counterfactual_target']
+counterfactual_within_sample_dataset = Dataset(counterfactual_within_sample_df,
+                                               dataset.col_to_stype,
+                                               target_col='target')
+
+counterfactual_within_sample_dataset.materialize(path=osp.join(
+    path, f"counterfactual_within_sample_data_{args.split_num}.pt"))
+
+counterfactual_val_df = copy.deepcopy(val_dataset.df)
+counterfactual_val_df['treated'] = 1 - counterfactual_val_df['treated']
+counterfactual_val_df['target'] = counterfactual_val_df[
+    'counterfactual_target']
+counterfactual_val_dataset = Dataset(counterfactual_val_df,
+                                     dataset.col_to_stype,
+                                     target_col='target')
+
+counterfactual_val_dataset.materialize(
+    path=osp.join(path, f"counterfactual_val_data_{args.split_num}.pt"))
+
+counterfactual_test_df = copy.deepcopy(test_dataset.df)
+counterfactual_test_df['treated'] = 1 - counterfactual_test_df['treated']
+counterfactual_test_df['target'] = counterfactual_test_df[
+    'counterfactual_target']
+counterfactual_test_dataset = Dataset(counterfactual_test_df,
+                                      dataset.col_to_stype,
+                                      target_col='target')
+
+counterfactual_test_dataset.materialize(
+    path=osp.join(path, f"counterfactual_test_data_{args.split_num}.pt"))
+
+train_tensor_frame = train_dataset.tensor_frame
+
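+# Sanity check (added for illustration): every counterfactual row should
+# carry the flipped treatment assignment of its factual counterpart.
+assert (counterfactual_train_df['treated'].values +
+        train_dataset.df['treated'].values == 1).all()
+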
+treatment_idx = train_tensor_frame.col_names_dict[stype.categorical].index(
+    'treated')
+assert torch.all(
+    torch.tensor(train_dataset.df['treated'].values, dtype=torch.long) ==
+    train_tensor_frame.feat_dict[stype.categorical][:, treatment_idx])
+
+# Workaround: materialization re-indexes categories by frequency, so the
+# flipped 'treated' columns are restored explicitly from the dataframes.
+counterfactual_train_tensor_frame = counterfactual_train_dataset.tensor_frame
+counterfactual_train_tensor_frame.feat_dict[
+    stype.categorical][:, treatment_idx] = torch.tensor(
+        counterfactual_train_df['treated'].values)
+val_tensor_frame = val_dataset.tensor_frame
+counterfactual_val_tensor_frame = counterfactual_val_dataset.tensor_frame
+counterfactual_val_tensor_frame.feat_dict[
+    stype.categorical][:, treatment_idx] = torch.tensor(
+        counterfactual_val_df['treated'].values)
+
+within_sample_tensor_frame = within_sample_dataset.tensor_frame
+counterfactual_within_sample_tensor_frame = (
+    counterfactual_within_sample_dataset.tensor_frame)
+counterfactual_within_sample_tensor_frame.feat_dict[
+    stype.categorical][:, treatment_idx] = torch.tensor(
+        counterfactual_within_sample_df['treated'].values)
+
+test_tensor_frame = test_dataset.tensor_frame
+counterfactual_test_tensor_frame = counterfactual_test_dataset.tensor_frame
+counterfactual_test_tensor_frame.feat_dict[
+    stype.categorical][:, treatment_idx] = torch.tensor(
+        counterfactual_test_df['treated'].values)
+
+train_loader = DataLoader(train_tensor_frame, batch_size=args.batch_size,
+                          shuffle=True)
+# val_loader = DataLoader(val_tensor_frame, batch_size=args.batch_size)
+# test_loader = DataLoader(test_tensor_frame, batch_size=args.batch_size)
+
+if args.model == 'cfr-mmd':
+    model = CFR(
+        channels=train_tensor_frame.num_cols - 1,
+        hidden_channels=200,
+        decoder_hidden_channels=100,
+        out_channels=1,
+        col_stats=dataset.col_stats if not args.feature_engineering else None,
+        col_names_dict=train_tensor_frame.col_names_dict
+        if not args.feature_engineering else None,
+    ).to(device)
+else:
+    model = BCAUSS(
+        channels=train_tensor_frame.num_cols - 1,
+        hidden_channels=200,
+        decoder_hidden_channels=100,
+        out_channels=1,
+        col_stats=dataset.col_stats if not args.feature_engineering else None,
+        col_names_dict=train_tensor_frame.col_names_dict
+        if not args.feature_engineering else None,
+    ).to(device)
+
+optimizer = torch.optim.SGD(model.parameters(), lr=args.lr)
+lr_scheduler = ExponentialLR(optimizer, gamma=0.8)
+
+
+def train(epoch: int) -> float:
+    model.train()
+    loss_accum = total_count = 0
+
+    for tf in tqdm(train_loader, desc=f'Epoch: {epoch}'):
+        tf = tf.to(device)
+        if args.model == 'cfr-mmd':
+            out, ipm = model(tf, treatment_index=treatment_idx)
+            treatment_val = tf.feat_dict[stype.categorical][:, treatment_idx]
+            avg_treatment = torch.sum(treatment_val) / len(treatment_val)
+            # Re-weighting factors w_i = t_i / (2u) + (1 - t_i) / (2(1 - u)),
+            # where u is the fraction of treated samples in the batch.
+            w_val = treatment_val / (2 * avg_treatment) + (
+                1 - treatment_val) / (2 - 2 * avg_treatment)
+            # Weighted squared factual loss plus the IPM penalty.
+            loss = torch.mean(
+                w_val * torch.square(tf.y - out.squeeze(-1))) + ipm
+            # L2 regularization on the two outcome decoders, applied before
+            # the backward pass so that it actually affects the gradients.
+            for name, param in model.named_parameters():
+                if name.startswith(('treatment_decoder', 'control_decoder')):
+                    loss = loss + args.lambda_reg * torch.sum(param**2)
+        else:
+            out, balance_score = model(tf, treatment_index=treatment_idx)
+            treated_mask = tf.feat_dict[stype.categorical][:,
+                                                           treatment_idx] == 1
+            # Squared factual loss over treated and control rows, plus the
+            # self-supervised balance score.
+            loss = ((torch.sum(treated_mask *
+                               torch.square(tf.y - out.squeeze(-1))) +
+                     torch.sum(~treated_mask *
+                               torch.square(tf.y - out.squeeze(-1)))) /
+                    len(treated_mask) + balance_score)
+        optimizer.zero_grad()
+        loss.backward()
+        loss_accum += float(loss) * len(out)
+        total_count += len(out)
+        optimizer.step()
+    return loss_accum / total_count
+
+
+@torch.no_grad()
+def eval(factual: TensorFrame, counterfactual: TensorFrame):
+    model.eval()
+
+    factual = factual.to(device)
+    factual_effect, _ = model(factual, treatment_idx)
+    factual_effect = factual_effect.squeeze(-1)
+    # RMSE of the factual predictions
+    rmse_fact = torch.sqrt(torch.mean(torch.square(factual.y -
+                                                   factual_effect)))
+
+    counterfactual = counterfactual.to(device)
+    counterfactual_effect, _ = model(counterfactual, treatment_idx)
+    counterfactual_effect = counterfactual_effect.squeeze(-1)
+    # RMSE of the counterfactual predictions
+    rmse_cfact = torch.sqrt(
+        torch.mean(torch.square(counterfactual.y - counterfactual_effect)))
+    eff_pred = counterfactual_effect - factual_effect
+    t = factual.feat_dict[stype.categorical][:, treatment_idx]
+    eff_pred[t > 0] = -eff_pred[t > 0]  # f(x, 1) - f(x, 0)
+    ate_pred = torch.mean(eff_pred)
+    bias_ate = torch.abs(ate_pred - ATE)
+
+    pehe = torch.sqrt(torch.mean(torch.square(eff_pred - ATE)))
+    return rmse_fact, rmse_cfact, bias_ate, pehe
+
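+
+# Note (added): PEHE is defined against per-sample true effects. The scalar
+# ATE is used above as a stand-in, since eval() only receives tensor frames.
+# With the per-sample noiseless effects available (hypothetical `true_ite`,
+# i.e. mu1 - mu0 per row), the usual estimate would be:
+#   pehe = torch.sqrt(torch.mean(torch.square(eff_pred - true_ite)))
+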
+best_val_error = float('inf')
+best_test_error = float('inf')
+best_val_pehe = float('inf')
+best_test_pehe = float('inf')
+
+for epoch in range(1, args.epochs + 1):
+    train_loss = train(epoch)
+    lr_scheduler.step()
+    train_rmse, train_rmse_cfact, train_error, train_pehe = eval(
+        train_tensor_frame, counterfactual_train_tensor_frame)
+    val_rmse, val_rmse_cfact, val_error, val_pehe = eval(
+        val_tensor_frame, counterfactual_val_tensor_frame)
+    test_rmse, test_rmse_cfact, test_error, test_pehe = eval(
+        test_tensor_frame, counterfactual_test_tensor_frame)
+    within_rmse, within_rmse_cfact, within_error, within_pehe = eval(
+        within_sample_tensor_frame, counterfactual_within_sample_tensor_frame)
+
+    if within_pehe < best_val_pehe:
+        best_val_error = within_error
+        best_test_error = test_error
+        best_val_pehe = within_pehe
+        best_test_pehe = test_pehe
+    print(f'Train Loss: {train_loss:.4f}, '
+          f'Train Factual RMSE: {train_rmse:.4f}, '
+          f'Train Counterfactual RMSE: {train_rmse_cfact:.4f},\n'
+          f'Val Factual RMSE: {val_rmse:.4f}, '
+          f'Val Counterfactual RMSE: {val_rmse_cfact:.4f}, '
+          f'Within Sample PEHE: {within_pehe:.4f}, '
+          f'Within Sample Error: {within_error:.4f},\n'
+          f'Out of Distribution Factual RMSE: {test_rmse:.4f}, '
+          f'Out of Distribution Counterfactual RMSE: {test_rmse_cfact:.4f}, '
+          f'Out of Distribution PEHE: {test_pehe:.4f}, '
+          f'Out of Distribution Error: {test_error:.4f}\n')
+
+print(f'Best Within Sample Error: {best_val_error:.4f}, '
+      f'Best Within Sample PEHE: {best_val_pehe:.4f},\n'
+      f'Best Out of Distribution Error: {best_test_error:.4f}, '
+      f'Best Out of Distribution PEHE: {best_test_pehe:.4f}')
diff --git a/test/nn/decoder/test_excelformer_decoder.py b/test/nn/decoder/test_excelformer_decoder.py
index 2d92a116d..aee4c2f8a 100644
--- a/test/nn/decoder/test_excelformer_decoder.py
+++ b/test/nn/decoder/test_excelformer_decoder.py
@@ -5,7 +5,7 @@
 def test_excelformer_decoder():
     batch_size = 10
-    num_cols = 18
+    num_cols = 8
     in_channels = 8
     out_channels = 3
     x = torch.randn(batch_size, num_cols, in_channels)
diff --git a/torch_frame/datasets/__init__.py b/torch_frame/datasets/__init__.py
index f838f1ce9..052cb21f3 100644
--- a/torch_frame/datasets/__init__.py
+++ b/torch_frame/datasets/__init__.py
@@ -18,6 +18,8 @@
 from .amazon_fine_food_reviews import AmazonFineFoodReviews
 from .diamond_images import DiamondImages
 from .huggingface_dataset import HuggingFaceDatasetDict
+from .jobs import Jobs
+from .ihdp import IHDP
 
 real_world_datasets = [
     'Titanic',
@@ -36,6 +38,8 @@
     'Mercari',
     'AmazonFineFoodReviews',
     'DiamondImages',
+    'Jobs',
+    'IHDP',
 ]
 
 synthetic_datasets = [
diff --git a/torch_frame/datasets/ihdp.py b/torch_frame/datasets/ihdp.py
new file mode 100644
index 000000000..6dd8f7d0d
--- /dev/null
+++ b/torch_frame/datasets/ihdp.py
@@ -0,0 +1,96 @@
+import os.path as osp
+import zipfile
+
+import numpy as np
+import pandas as pd
+
+import torch_frame
+from torch_frame.utils.split import SPLIT_TO_NUM
+
+
+class IHDP(torch_frame.data.Dataset):
+    r"""The IHDP (Infant Health and Development Program) benchmark for
+    treatment effect estimation. The counterfactual target is generated
+    with kNN.
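+
+    Example (illustrative usage; the root path is hypothetical):
+
+    .. code-block:: python
+
+        dataset = IHDP(root='data/ihdp', split_num=0)
+        dataset.materialize()
+        ate = dataset.get_att()  # true average treatment effect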
+    """
+    train_url = 'https://www.fredjo.com/files/ihdp_npci_1-1000.train.npz.zip'
+    test_url = 'https://www.fredjo.com/files/ihdp_npci_1-1000.test.npz.zip'
+
+    def __init__(self, root: str, split_num: int = 0):
+        train_path = self.download_url(self.train_url, root)
+        test_path = self.download_url(self.test_url, root)
+        self.split_num = split_num
+        folder_path = osp.dirname(train_path)
+        with zipfile.ZipFile(train_path, 'r') as zip_ref:
+            zip_ref.extractall(folder_path)
+        with zipfile.ZipFile(test_path, 'r') as zip_ref:
+            zip_ref.extractall(folder_path)
+        train_np = np.load(osp.join(folder_path, 'ihdp_npci_1-1000.train.npz'))
+        test_np = np.load(osp.join(folder_path, 'ihdp_npci_1-1000.test.npz'))
+        self.train_np = train_np
+        self.test_np = test_np
+        # Assemble columns: treatment, covariates, factual outcome and
+        # counterfactual outcome for the selected simulation split.
+        train_data = np.concatenate([
+            train_np.f.t[:, split_num].reshape(-1, 1),
+            train_np.f.x[:, :, split_num],
+            train_np.f.yf[:, split_num].reshape(-1, 1),
+            train_np.f.ycf[:, split_num].reshape(-1, 1),
+        ], axis=1)
+        test_data = np.concatenate([
+            test_np.f.t[:, split_num].reshape(-1, 1),
+            test_np.f.x[:, :, split_num],
+            test_np.f.yf[:, split_num].reshape(-1, 1),
+            test_np.f.ycf[:, split_num].reshape(-1, 1),
+        ], axis=1)
+        train_df = pd.DataFrame(
+            train_data, columns=['treated'] +
+            [f'Col_{i}' for i in range(train_np.f.x.shape[1])] +
+            ['target', 'counterfactual_target'])
+        train_df['split'] = SPLIT_TO_NUM['train']
+        test_df = pd.DataFrame(
+            test_data, columns=['treated'] +
+            [f'Col_{i}' for i in range(train_np.f.x.shape[1])] +
+            ['target', 'counterfactual_target'])
+        test_df['split'] = SPLIT_TO_NUM['test']
+        df = pd.concat([train_df, test_df], axis=0)
+        col_to_stype = {
+            'treated': torch_frame.categorical,
+            'Col_0': torch_frame.numerical,
+            'Col_1': torch_frame.numerical,
+            'Col_2': torch_frame.numerical,
+            'Col_3': torch_frame.numerical,
+            'Col_4': torch_frame.numerical,
+            'Col_5': torch_frame.numerical,
+            'Col_6': torch_frame.categorical,
+            'Col_7': torch_frame.categorical,
+            'Col_8': torch_frame.categorical,
+            'Col_9': torch_frame.categorical,
+            'Col_10': torch_frame.categorical,
+            'Col_11': torch_frame.categorical,
+            'Col_12': torch_frame.categorical,
+            'Col_13': torch_frame.categorical,
+            'Col_14': torch_frame.categorical,
+            'Col_15': torch_frame.categorical,
+            'Col_16': torch_frame.categorical,
+            'Col_17': torch_frame.categorical,
+            'Col_18': torch_frame.categorical,
+            'Col_19': torch_frame.categorical,
+            'Col_20': torch_frame.categorical,
+            'Col_21': torch_frame.categorical,
+            'Col_22': torch_frame.categorical,
+            'Col_23': torch_frame.categorical,
+            'Col_24': torch_frame.categorical,
+            'target': torch_frame.numerical,
+        }
+        super().__init__(df, col_to_stype=col_to_stype, target_col='target',
+                         split_col='split')
+
+    def get_att(self):
+        r"""Obtain the true average treatment effect of the simulation.
+
+        Returns:
+            float: The true average treatment effect, computed from the
+            noiseless potential outcomes :obj:`mu1` and :obj:`mu0`.
+        """
+        mu1 = np.concatenate([
+            self.train_np.f.mu1[:, self.split_num],
+            self.test_np.f.mu1[:, self.split_num]
+        ], axis=0)
+        mu0 = np.concatenate([
+            self.train_np.f.mu0[:, self.split_num],
+            self.test_np.f.mu0[:, self.split_num]
+        ], axis=0)
+        return np.mean(mu1) - np.mean(mu0)
diff --git a/torch_frame/datasets/jobs.py b/torch_frame/datasets/jobs.py
new file mode 100644
index 000000000..9eb525ce5
--- /dev/null
+++ b/torch_frame/datasets/jobs.py
@@ -0,0 +1,144 @@
+import numpy as np
+import pandas as pd
+
+import torch_frame
+from torch_frame.utils.split import SPLIT_TO_NUM
+
+
+class Jobs(torch_frame.data.Dataset):
+    r"""The Jobs dataset from "Evaluating the Econometric Evaluations of
+    Training Programs with Experimental Data" by Robert LaLonde. There are
+    two versions of the data. One version is the Dehejia subsample, a
+    subsample of LaLonde's original dataset that includes one more feature,
+    RE74 (earnings in 1974, two years prior to the treatment). Using more
+    than one year of pre-treatment earnings is key to accurately estimating
+    the treatment effect, because many people who volunteer for training
+    programs experience a drop in their earnings just prior to entering the
+    program.
+
+    The other version, containing additional columns obtained from feature
+    engineering, comes from Fredrik Johansson's website
+    (https://www.fredjo.com).
+
+    The binary :obj:`'target'` column indicates whether the subject reported
+    non-zero earnings in 1978 (RE78), and the :obj:`'treated'` column
+    indicates participation in the training program.
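+
+    Example (illustrative usage; the root path is hypothetical):
+
+    .. code-block:: python
+
+        dataset = Jobs(root='data/jobs', feature_engineering=True)
+        dataset.materialize()
+        att = dataset.get_att()  # ATT from the randomized NSW sample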
+    """
+    dehejia_treated_url = 'https://users.nber.org/~rdehejia/data/nswre74_treated.txt'  # noqa
+    dehejia_control_url = 'https://users.nber.org/~rdehejia/data/nswre74_control.txt'  # noqa
+    psid_url = 'https://users.nber.org/~rdehejia/data/psid_controls.txt'
+    train_url = 'https://www.fredjo.com/files/jobs_DW_bin.new.10.train.npz'
+    test_url = 'https://www.fredjo.com/files/jobs_DW_bin.new.10.test.npz'
+
+    def __init__(self, root: str, feature_engineering: bool = False):
+        if feature_engineering:
+            split = 0
+            train = self.download_url(self.train_url, root)
+            test = self.download_url(self.test_url, root)
+            train_np = np.load(train)
+            test_np = np.load(test)
+            # Assemble columns: treatment, covariates, randomized-sample
+            # indicator and factual outcome for split 0.
+            train_data = np.concatenate([
+                train_np.f.t[:, split].reshape(-1, 1),
+                train_np.f.x[:, :, split],
+                train_np.f.e[:, split].reshape(-1, 1),
+                train_np.f.yf[:, split].reshape(-1, 1),
+            ], axis=1)
+            test_data = np.concatenate([
+                test_np.f.t[:, split].reshape(-1, 1),
+                test_np.f.x[:, :, split],
+                test_np.f.e[:, split].reshape(-1, 1),
+                test_np.f.yf[:, split].reshape(-1, 1),
+            ], axis=1)
+            train_df = pd.DataFrame(
+                train_data, columns=['treated'] +
+                [f'Col_{i}' for i in range(train_np.f.x.shape[1])] +
+                ['source', 'target'])
+            train_df['split'] = SPLIT_TO_NUM['train']
+            test_df = pd.DataFrame(
+                test_data, columns=['treated'] +
+                [f'Col_{i}' for i in range(train_np.f.x.shape[1])] +
+                ['source', 'target'])
+            test_df['split'] = SPLIT_TO_NUM['test']
+            df = pd.concat([train_df, test_df], axis=0)
+            col_to_stype = {
+                'treated': torch_frame.categorical,
+                'Col_0': torch_frame.numerical,
+                'Col_1': torch_frame.numerical,
+                'Col_2': torch_frame.categorical,
+                'Col_3': torch_frame.categorical,
+                'Col_4': torch_frame.categorical,
+                'Col_5': torch_frame.categorical,
+                'Col_6': torch_frame.numerical,
+                'Col_7': torch_frame.numerical,
+                'Col_8': torch_frame.numerical,
+                'Col_9': torch_frame.numerical,
+                'Col_10': torch_frame.numerical,
+                'Col_11': torch_frame.numerical,
+                'Col_12': torch_frame.numerical,
+                'Col_13': torch_frame.categorical,
+                'Col_14': torch_frame.categorical,
+                'Col_15': torch_frame.numerical,
+                'Col_16': torch_frame.categorical,
+                'target': torch_frame.categorical,
+            }
+            super().__init__(df, col_to_stype, target_col='target',
+                             split_col='split')
+        else:
+            # National Supported Work Demonstration
+            nsw_treated = self.download_url(self.dehejia_treated_url, root)
+            nsw_control = self.download_url(self.dehejia_control_url, root)
+            # Panel Study of Income Dynamics
+            psid = self.download_url(self.psid_url, root)
+            names = [
+                'treated', 'age', 'education', 'Black', 'Hispanic', 'married',
+                'nodegree', 'RE74', 'RE75', 'RE78'
+            ]
+
+            nsw_treated_df = pd.read_csv(nsw_treated, sep=r'\s+', names=names)
+            assert (nsw_treated_df['treated'] == 1).all()
+            nsw_treated_df['source'] = 1
+
+            nsw_control_df = pd.read_csv(nsw_control, sep=r'\s+', names=names)
+            assert (nsw_control_df['treated'] == 0).all()
+            nsw_control_df['source'] = 1
+
+            psid_df = pd.read_csv(psid, sep=r'\s+', names=names)
+            assert (psid_df['treated'] == 0).all()
+            psid_df['source'] = 0
+
+            df = pd.concat([nsw_treated_df, nsw_control_df, psid_df], axis=0)
+            # Binary outcome: non-zero earnings in 1978.
+            df['target'] = df['RE78'] != 0
+
+            col_to_stype = {
+                'treated': torch_frame.categorical,
+                'age': torch_frame.numerical,
+                'education': torch_frame.numerical,
+                'Black': torch_frame.categorical,
+                'Hispanic': torch_frame.categorical,
+                'married': torch_frame.categorical,
+                'nodegree': torch_frame.categorical,
+                'RE74': torch_frame.numerical,
+                'RE75': torch_frame.numerical,
+                'target': torch_frame.categorical,
+            }
+            super().__init__(df, col_to_stype, target_col='target')
+        self.df = df
+
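+    # Note (added): the NSW rows (source == 1) come from a randomized
+    # experiment, so a simple difference in outcome means between its
+    # treated and control rows is an unbiased estimate of the ATT.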
+    def get_att(self):
+        r"""Obtain the true average treatment effect on the treated (ATT).
+
+        Returns:
+            float: The ATT estimate from the original randomized experiment.
+        """
+        df = self.df[self.df['source'] == 1]
+        treated = df[df['treated'] == 1]
+        control = df[df['treated'] == 0]
+        return treated['target'].mean() - control['target'].mean()
diff --git a/torch_frame/nn/models/__init__.py b/torch_frame/nn/models/__init__.py
index ef3c8c4c2..4268e7843 100644
--- a/torch_frame/nn/models/__init__.py
+++ b/torch_frame/nn/models/__init__.py
@@ -6,13 +6,17 @@
 from .resnet import ResNet
 from .tab_transformer import TabTransformer
 from .mlp import MLP
+from .bcauss import BCAUSS
+from .cfr import CFR
 
 __all__ = classes = [
     'Trompt',
     'FTTransformer',
     'ExcelFormer',
     'TabNet',
     'ResNet',
     'TabTransformer',
     'MLP',
+    'BCAUSS',
+    'CFR',
 ]
diff --git a/torch_frame/nn/models/bcauss.py b/torch_frame/nn/models/bcauss.py
new file mode 100644
index 000000000..ff66d42eb
--- /dev/null
+++ b/torch_frame/nn/models/bcauss.py
@@ -0,0 +1,168 @@
+from __future__ import annotations
+
+from typing import Any
+
+import torch
+from torch import Tensor
+from torch.nn import Linear, Module, ReLU, Sequential
+
+import torch_frame
+from torch_frame import TensorFrame, categorical, numerical
+from torch_frame.data.stats import StatType
+
+
+class MLPBlock(Module):
+    def __init__(
+        self,
+        in_channels: int,
+        hidden_channels: int,
+        out_channels: int,
+    ) -> None:
+        super().__init__()
+        self.model = Sequential(
+            Linear(in_channels, hidden_channels),
+            ReLU(),
+            Linear(hidden_channels, hidden_channels),
+            ReLU(),
+            Linear(hidden_channels, out_channels),
+        )
+
+    def reset_parameters(self) -> None:
+        for block in self.model:
+            if isinstance(block, Linear):
+                block.reset_parameters()
+
+    def forward(self, x: Tensor) -> Tensor:
+        r"""Transform :obj:`x` into output predictions.
+
+        Args:
+            x (Tensor): Input tensor of shape [batch_size, in_channels].
+
+        Returns:
+            Tensor: Output of shape [batch_size, out_channels].
+        """
+        return self.model(x)
+
+
+class BalanceScoreEstimator(Module):
+    r"""A linear propensity head mapping representations to scores in
+    (0, 1)."""
+    def __init__(self, in_channels: int, out_channels: int) -> None:
+        super().__init__()
+        self.model = Linear(in_channels, out_channels)
+
+    def reset_parameters(self) -> None:
+        self.model.reset_parameters()
+
+    def forward(self, x: Tensor) -> Tensor:
+        return torch.sigmoid(self.model(x))
+
+
+class BCAUSS(Module):
+    r"""The BCAUSS model introduced in the "Learning end-to-end patient
+    representations through self-supervised covariate balancing for causal
+    treatment effect estimation" paper.
+
+    .. note::
+
+        For an example of using BCAUSS, see :obj:`examples/bcauss.py`.
+
+    Args:
+        channels (int): Input channel dimensionality.
+        hidden_channels (int): Hidden channel dimensionality of the
+            representation learner.
+        decoder_hidden_channels (int): Hidden channel dimensionality of the
+            outcome decoders.
+        out_channels (int): Output channel dimensionality.
+        col_stats (dict[str, dict[:class:`torch_frame.data.stats.StatType`,
+            Any]], optional): A dictionary that maps column name into stats.
+            Available as :obj:`dataset.col_stats`.
+        col_names_dict (dict[:obj:`torch_frame.stype`, list[str]], optional):
+            A dictionary that maps stype to a list of column names. The
+            column names are sorted based on the ordering that appear in
+            :obj:`tensor_frame.feat_dict`. Available as
+            :obj:`tensor_frame.col_names_dict`.
+        epsilon (float): Constant weighting factor that controls the relative
+            importance of the balance score w.r.t. the squared factual loss.
+            (default: :obj:`1.0`)
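+
+    Example (illustrative; :obj:`tf` is assumed to be a materialized
+    :class:`TensorFrame` whose categorical features include a binary
+    treatment column):
+
+    .. code-block:: python
+
+        model = BCAUSS(channels=16, hidden_channels=200,
+                       decoder_hidden_channels=100, out_channels=1,
+                       col_stats=None, col_names_dict=None)
+        pred, balance_score = model(tf, treatment_index=0)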
+    """
+    def __init__(
+        self,
+        channels: int,
+        hidden_channels: int,
+        decoder_hidden_channels: int,
+        out_channels: int,
+        col_stats: dict[str, dict[StatType, Any]] | None,
+        col_names_dict: dict[torch_frame.stype, list[str]] | None,
+        epsilon: float = 1.0,
+    ) -> None:
+        super().__init__()
+        if col_stats is not None and col_names_dict is not None:
+            numerical_stats_list = [
+                col_stats[col_name]
+                for col_name in col_names_dict[torch_frame.numerical]
+            ]
+            mean = torch.tensor(
+                [stats[StatType.MEAN] for stats in numerical_stats_list])
+            self.register_buffer("mean", mean)
+            std = torch.tensor(
+                [stats[StatType.STD]
+                 for stats in numerical_stats_list]) + 1e-6
+            self.register_buffer("std", std)
+        self.representation_learner = MLPBlock(channels, hidden_channels,
+                                               hidden_channels)
+        self.balance_score_learner = BalanceScoreEstimator(
+            hidden_channels, out_channels)
+        # Decoder for the treatment group
+        self.treatment_decoder = MLPBlock(hidden_channels,
+                                          decoder_hidden_channels,
+                                          out_channels)
+        # Decoder for the control group
+        self.control_decoder = MLPBlock(hidden_channels,
+                                        decoder_hidden_channels, out_channels)
+        self.epsilon = epsilon
+        self.reset_parameters()
+
+    def reset_parameters(self) -> None:
+        self.representation_learner.reset_parameters()
+        self.balance_score_learner.reset_parameters()
+        self.treatment_decoder.reset_parameters()
+        self.control_decoder.reset_parameters()
+
+    def forward(self, tf: TensorFrame,
+                treatment_index: int) -> tuple[Tensor, Tensor]:
+        r"""Predict factual outcomes and compute the balance score.
+
+        Args:
+            tf (TensorFrame): Input tensor frame.
+            treatment_index (int): Index of the binary treatment column
+                within the categorical features.
+
+        Returns:
+            tuple[Tensor, Tensor]: Outcome predictions of shape
+            [batch_size, 1] and the epsilon-weighted balance score.
+        """
+        feat_cat = tf.feat_dict[categorical]
+        feat_num = tf.feat_dict[numerical]
+        if hasattr(self, 'mean'):
+            feat_num = (feat_num - self.mean) / self.std
+        assert isinstance(feat_cat, Tensor)
+        assert isinstance(feat_num, Tensor)
+        x = torch.cat([feat_cat, feat_num], dim=1)
+        t = x[:, treatment_index].clone()
+        # Swap the treatment column with the last column of x, then drop it.
+        x[:, treatment_index] = x[:, -1]
+        x[:, -1] = t
+        x = x[:, :-1]
+
+        out = self.representation_learner(x)  # [batch_size, hidden_channels]
+        treated_mask = t == 1
+        treated = out[treated_mask, :]
+        control = out[~treated_mask, :]
+        pred = torch.zeros((len(x), 1), dtype=x.dtype, device=x.device)
+        pred[~treated_mask, :] = self.control_decoder(control)
+        pred[treated_mask, :] = self.treatment_decoder(treated)
+        # Self-supervised balance score: squared distance between the
+        # weighted average covariates of the treated and the control group,
+        # with smoothed inverse-propensity weights t / pi for treated rows
+        # and (1 - t) / (1 - pi) for control rows.
+        penalty = self.balance_score_learner(out)
+        treated_weight = treated_mask.unsqueeze(-1) / (penalty + 0.01)
+        control_weight = ~treated_mask.unsqueeze(-1) / (1. - penalty + 0.01)
+        balance_score = torch.mean(
+            torch.square(
+                torch.sum(treated_weight * x, dim=0) /
+                (torch.sum(treated_weight) + 0.01) -
+                torch.sum(control_weight * x, dim=0) /
+                (torch.sum(control_weight) + 0.01)))
+        return pred, self.epsilon * balance_score
diff --git a/torch_frame/nn/models/cfr.py b/torch_frame/nn/models/cfr.py
new file mode 100644
index 000000000..205b12df3
--- /dev/null
+++ b/torch_frame/nn/models/cfr.py
@@ -0,0 +1,115 @@
+from __future__ import annotations
+
+from typing import Any
+
+import torch
+from torch import Tensor
+from torch.nn import Linear, Module, ReLU, Sequential
+
+import torch_frame
+from torch_frame import TensorFrame, categorical, numerical
+from torch_frame.data.stats import StatType
+
+
+class MLPBlock(Module):
+    def __init__(
+        self,
+        in_channels: int,
+        hidden_channels: int,
+        out_channels: int,
+    ) -> None:
+        super().__init__()
+        self.model = Sequential(
+            Linear(in_channels, hidden_channels),
+            ReLU(),
+            Linear(hidden_channels, hidden_channels),
+            ReLU(),
+            Linear(hidden_channels, out_channels),
+        )
+
+    def reset_parameters(self) -> None:
+        for block in self.model:
+            if isinstance(block, Linear):
+                block.reset_parameters()
+
+    def forward(self, x: Tensor) -> Tensor:
+        r"""Transform :obj:`x` into output predictions.
+
+        Args:
+            x (Tensor): Input tensor of shape [batch_size, in_channels].
+
+        Returns:
+            Tensor: Output of shape [batch_size, out_channels].
+        """
+        return self.model(x)
+
+
+class CFR(Module):
+    r"""The counterfactual regression (CFR) model from the `"Estimating
+    individual treatment effect: generalization bounds and algorithms"
+    <https://arxiv.org/abs/1606.03976>`_ paper, with a linear
+    maximum-mean-discrepancy-style penalty as the IPM term.
+
+    Args:
+        channels (int): Input channel dimensionality.
+        hidden_channels (int): Hidden channel dimensionality of the
+            representation learner.
+        decoder_hidden_channels (int): Hidden channel dimensionality of the
+            outcome decoders.
+        out_channels (int): Output channel dimensionality.
+        col_stats (dict[str, dict[:class:`torch_frame.data.stats.StatType`,
+            Any]], optional): A dictionary that maps column name into stats.
+            Available as :obj:`dataset.col_stats`.
+        col_names_dict (dict[:obj:`torch_frame.stype`, list[str]], optional):
+            A dictionary that maps stype to a list of column names.
+            Available as :obj:`tensor_frame.col_names_dict`.
+        epsilon (float): Constant weighting factor that controls the relative
+            importance of the IPM term w.r.t. the factual loss.
+            (default: :obj:`0.3`)
+    """
+    def __init__(
+        self,
+        channels: int,
+        hidden_channels: int,
+        decoder_hidden_channels: int,
+        out_channels: int,
+        col_stats: dict[str, dict[StatType, Any]] | None,
+        col_names_dict: dict[torch_frame.stype, list[str]] | None,
+        epsilon: float = 0.3,
+    ) -> None:
+        super().__init__()
+        if col_stats is not None and col_names_dict is not None:
+            numerical_stats_list = [
+                col_stats[col_name]
+                for col_name in col_names_dict[torch_frame.numerical]
+            ]
+            mean = torch.tensor(
+                [stats[StatType.MEAN] for stats in numerical_stats_list])
+            self.register_buffer("mean", mean)
+            std = torch.tensor(
+                [stats[StatType.STD]
+                 for stats in numerical_stats_list]) + 1e-6
+            self.register_buffer("std", std)
+        self.representation_learner = MLPBlock(channels, hidden_channels,
+                                               hidden_channels)
+        # Decoder for the treatment group
+        self.treatment_decoder = MLPBlock(hidden_channels,
+                                          decoder_hidden_channels,
+                                          out_channels)
+        # Decoder for the control group
+        self.control_decoder = MLPBlock(hidden_channels,
+                                        decoder_hidden_channels, out_channels)
+        self.epsilon = epsilon
+        self.reset_parameters()
+
+    def reset_parameters(self) -> None:
+        self.representation_learner.reset_parameters()
+        self.treatment_decoder.reset_parameters()
+        self.control_decoder.reset_parameters()
+
+    def forward(self, tf: TensorFrame,
+                treatment_index: int) -> tuple[Tensor, Tensor]:
+        r"""Predict factual outcomes and compute the IPM term.
+
+        Args:
+            tf (TensorFrame): Input tensor frame.
+            treatment_index (int): Index of the binary treatment column
+                within the categorical features.
+
+        Returns:
+            tuple[Tensor, Tensor]: Outcome predictions of shape
+            [batch_size, 1] and the epsilon-weighted IPM term.
+        """
+        feat_cat = tf.feat_dict[categorical]
+        feat_num = tf.feat_dict[numerical]
+        if hasattr(self, 'mean'):
+            feat_num = (feat_num - self.mean) / self.std
+        assert isinstance(feat_cat, Tensor)
+        assert isinstance(feat_num, Tensor)
+        x = torch.cat([feat_cat, feat_num], dim=1)
+        t = x[:, treatment_index].clone()
+        # Swap the treatment column with the last column of x, then drop it.
+        x[:, treatment_index] = x[:, -1]
+        x[:, -1] = t
+        x = x[:, :-1]
+
+        out = self.representation_learner(x)  # [batch_size, hidden_channels]
+        treated_mask = t == 1
+        treated = out[treated_mask, :]
+        control = out[~treated_mask, :]
+        pred = torch.zeros((len(x), 1), dtype=x.dtype, device=x.device)
+        pred[~treated_mask, :] = self.control_decoder(control)
+        pred[treated_mask, :] = self.treatment_decoder(treated)
+        # IPM term: distance between the mean treated and mean control
+        # representations (a linear MMD-style discrepancy).
+        treated_mean = torch.mean(treated, dim=0)
+        control_mean = torch.mean(control, dim=0)
+        ipm = 2 * torch.norm(treated_mean - control_mean, p=2)
+        return pred, self.epsilon * ipm
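+
+
+# Added note: for representation means m_t and m_c, the penalty above equals
+# 2 * ||m_t - m_c||_2. A quick standalone check (illustrative shapes):
+#
+#   h_t, h_c = torch.randn(8, 4), torch.randn(6, 4)
+#   ipm = 2 * torch.norm(h_t.mean(dim=0) - h_c.mean(dim=0), p=2)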