From 6141684e17cc4e20d865e2ee09d33c26726710e1 Mon Sep 17 00:00:00 2001
From: yiweny
Date: Sun, 14 Apr 2024 02:01:32 +0000
Subject: [PATCH 1/9] Add Causal ML benchmark

---
 examples/bcauss.py                         |  6 ++
 torch_frame/data/multi_embedding_tensor.py |  2 +-
 torch_frame/datasets/__init__.py           |  2 +
 torch_frame/datasets/jobs.py               | 67 ++++++++++++++++++++
 torch_frame/nn/decoder/__init__.py         |  5 +-
 torch_frame/nn/decoder/mlpdecoder.py       | 44 ++++++++++++++
 torch_frame/nn/models/bcauss.py            | 39 +++++++++++++
 7 files changed, 161 insertions(+), 4 deletions(-)
 create mode 100644 examples/bcauss.py
 create mode 100644 torch_frame/datasets/jobs.py
 create mode 100644 torch_frame/nn/decoder/mlpdecoder.py
 create mode 100644 torch_frame/nn/models/bcauss.py

diff --git a/examples/bcauss.py b/examples/bcauss.py
new file mode 100644
index 000000000..49138723f
--- /dev/null
+++ b/examples/bcauss.py
@@ -0,0 +1,6 @@
+import os.path as osp
+
+from torch_frame.datasets import Jobs
+
+path = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', "jobs")
+dataset = Jobs(root=path)
diff --git a/torch_frame/data/multi_embedding_tensor.py b/torch_frame/data/multi_embedding_tensor.py
index b2031b186..aa53d2f6e 100644
--- a/torch_frame/data/multi_embedding_tensor.py
+++ b/torch_frame/data/multi_embedding_tensor.py
@@ -218,7 +218,7 @@ def _empty(self, dim: int) -> MultiEmbeddingTensor:
         Returns:
             MultiEmbeddingTensor: An empty :class:`MultiEmbeddingTensor`.
                 Note that if :obj:`dim=0`, it will return with the original
-                offset tensor.
+                offset tensor.git
         """
         return MultiEmbeddingTensor(
             num_rows=0 if dim == 0 else self.num_rows,
diff --git a/torch_frame/datasets/__init__.py b/torch_frame/datasets/__init__.py
index f838f1ce9..28ad64da9 100644
--- a/torch_frame/datasets/__init__.py
+++ b/torch_frame/datasets/__init__.py
@@ -18,6 +18,7 @@
 from .amazon_fine_food_reviews import AmazonFineFoodReviews
 from .diamond_images import DiamondImages
 from .huggingface_dataset import HuggingFaceDatasetDict
+from .jobs import Jobs
 
 real_world_datasets = [
     'Titanic',
@@ -36,6 +37,7 @@
     'Mercari',
     'AmazonFineFoodReviews',
     'DiamondImages',
+    'jobs',
 ]
 
 synthetic_datasets = [
diff --git a/torch_frame/datasets/jobs.py b/torch_frame/datasets/jobs.py
new file mode 100644
index 000000000..940ae5568
--- /dev/null
+++ b/torch_frame/datasets/jobs.py
@@ -0,0 +1,67 @@
+import pandas as pd
+
+import torch_frame
+
+
+class Jobs(torch_frame.data.Dataset):
+    r"""The `Jobs
+    `_
+    dataset from Lalonde.
+    treatment indicator (1 if treated, 0 if not treated), age,
+    education, Black (1 if black, 0 otherwise), Hispanic
+    (1 if Hispanic, 0 otherwise), married (1 if married, 0 otherwise),
+    nodegree (1 if no degree, 0 otherwise), RE74 (earnings in 1974),
+    RE75 (earnings in 1975), and RE78 (earnings in 1978).
+    """
+    lalonde_treated = 'https://users.nber.org/~rdehejia/data/nsw_treated.txt'
+    lalonde_control = 'https://users.nber.org/~rdehejia/data/nsw_control.txt'
+    psid = 'https://users.nber.org/~rdehejia/data/psid_controls.txt'  # noqa
+
+    def __init__(self, root: str):
+        # National Supported Work Demonstration
+        nsw_treated = self.download_url(Jobs.lalonde_treated, root)
+        nsw_control = self.download_url(Jobs.lalonde_control, root)
+        # Population Survey of Income Dynamics
+        psid = self.download_url(Jobs.psid, root)
+        names = [
+            'treated', 'age', 'education', 'Black', 'Hispanic', 'married',
+            'nodegree', 'RE75', 'RE78'
+        ]
+
+        nsw_treated_df = pd.read_csv(
+            nsw_treated,
+            sep='\s+',  # noqa
+            names=names)
+        assert (nsw_treated_df['treated'] == 1).all()
+        nsw_treated_df['source'] = 'nsw'
+
+        nsw_control_df = pd.read_csv(
+            nsw_control,
+            sep='\s+',  # noqa
+            names=names)
+        assert (nsw_control_df['treated'] == 0).all()
+        nsw_control_df['source'] = 'nsw'
+
+        names.insert(7, 'RE74')
+
+        psid_df = pd.read_csv(psid, sep='\s+', names=names)  # noqa
+        assert (psid_df['treated'] == 0).all()
+        psid_df['source'] = 'psid'
+        psid_df = psid_df.drop('RE74', axis=1)
+
+        df = pd.concat([nsw_treated_df, nsw_control_df, psid_df], axis=0)
+        df['target'] = df['RE78'] != 0
+
+        col_to_stype = {
+            'treated': torch_frame.categorical,
+            'age': torch_frame.numerical,
+            'education': torch_frame.categorical,
+            'Black': torch_frame.categorical,
+            'Hispanic': torch_frame.categorical,
+            'married': torch_frame.categorical,
+            'nodegree': torch_frame.categorical,
+            'RE75': torch_frame.numerical,
+            'target': torch_frame.categorical,
+        }
+
+        super().__init__(df, col_to_stype, target_col='target')
diff --git a/torch_frame/nn/decoder/__init__.py b/torch_frame/nn/decoder/__init__.py
index 065233d9b..8322448a3 100644
--- a/torch_frame/nn/decoder/__init__.py
+++ b/torch_frame/nn/decoder/__init__.py
@@ -2,9 +2,8 @@
 from .decoder import Decoder
 from .trompt_decoder import TromptDecoder
 from .excelformer_decoder import ExcelFormerDecoder
+from .mlpdecoder import MLPDecoder
 
 __all__ = classes = [
-    'Decoder',
-    'TromptDecoder',
-    'ExcelFormerDecoder',
+    'Decoder', 'TromptDecoder', 'ExcelFormerDecoder', 'MLPDecoder'
 ]
diff --git a/torch_frame/nn/decoder/mlpdecoder.py b/torch_frame/nn/decoder/mlpdecoder.py
new file mode 100644
index 000000000..bd37e0b91
--- /dev/null
+++ b/torch_frame/nn/decoder/mlpdecoder.py
@@ -0,0 +1,44 @@
+from torch import Tensor
+from torch.nn import Linear, ReLU
+
+from torch_frame.nn.decoder import Decoder
+
+
+class MLPDecoder(Decoder):
+    def __init__(
+        self,
+        in_channels: int,
+        hidden_channels: int,
+        out_channels: int,
+        num_cols: int,
+    ) -> None:
+        super().__init__()
+        self.in_channels = in_channels
+        self.out_channels = out_channels
+        self.activation = ReLU()
+        self.lin_1 = Linear(num_cols, hidden_channels)
+        self.lin_2 = Linear(hidden_channels, hidden_channels)
+        self.lin_3 = Linear(hidden_channels, out_channels)
+        self.reset_parameters()
+
+    def reset_parameters(self) -> None:
+        self.lin_1.reset_parameters()
+        self.lin_2.reset_parameters()
+        self.lin_3.reset_parameters()
+
+    def forward(self, x: Tensor) -> Tensor:
+        r"""Transforming :obj:`x` into output predictions.
+
+        Args:
+            x (Tensor): Input column-wise tensor of shape
+                [batch_size, num_cols, in_channels]
+
+        Returns:
+            Tensor: [batch_size, out_channels].
+        """
+        x = self.lin_1(x)
+        x = self.activation(x)
+        x = self.lin_2(x)
+        x = self.activation(x)
+        x = self.lin_3(x)
+        return x
diff --git a/torch_frame/nn/models/bcauss.py b/torch_frame/nn/models/bcauss.py
new file mode 100644
index 000000000..141cb73c3
--- /dev/null
+++ b/torch_frame/nn/models/bcauss.py
@@ -0,0 +1,39 @@
+import torch
+import torch.nn.functional as F
+from torch.nn import Module, Parameter
+
+from torch_frame.nn.decoder import MLPDecoder
+from torch_frame.nn.models import MLP
+
+
+class EpsilonLayer(Module):
+    def __init__(self):
+        super().__init__()
+        self.epsilon = Parameter(torch.randn(1, 1))
+
+    def reset_parameters(self):
+        self.epsilon.reset_parameters()
+
+    def forward(self, t):
+        return F.sigmoid(self.epsilon * torch.ones_like(t)[:, 0:1])
+
+
+class BCAUSS(Module):
+    def __init__(self):
+        super().__init__()
+        self.mlp = MLP()
+        self.epsilon = EpsilonLayer()
+        # decoder for treatment group
+        self.treatment_decoder = MLPDecoder()
+        # decoder for control group
+        self.control_decoder = MLPDecoder()
+
+    def forward(self, x, t):
+        r"""T stands for treatment and y stands for output."""
+        out = self.mlp(x)
+        if t == 0:
+            out = self.control_decoder(out)
+        else:
+            out = self.treatment_decoder(out)
+        penalty = self.epsilon(out)
+        return out + penalty

From 2346679dffc52f5882858f653b53cd80d0b60c0a Mon Sep 17 00:00:00 2001
From: yiweny
Date: Sun, 14 Apr 2024 04:45:49 +0000
Subject: [PATCH 2/9] bcauss

---
 examples/bcauss.py                          |  67 ++++++++++
 test/nn/decoder/test_excelformer_decoder.py |   2 +-
 torch_frame/data/multi_embedding_tensor.py  |   2 +-
 torch_frame/nn/decoder/__init__.py          |   5 +-
 torch_frame/nn/decoder/mlpdecoder.py        |  44 -------
 torch_frame/nn/models/__init__.py           |  10 +-
 torch_frame/nn/models/bcauss.py             | 131 ++++++++++++++++----
 7 files changed, 182 insertions(+), 79 deletions(-)
 delete mode 100644 torch_frame/nn/decoder/mlpdecoder.py

diff --git a/examples/bcauss.py b/examples/bcauss.py
index 49138723f..eb25f3992 100644
--- a/examples/bcauss.py
+++ b/examples/bcauss.py
@@ -1,6 +1,73 @@
+import argparse
 import os.path as osp
 
+import torch
+from torch.optim.lr_scheduler import ExponentialLR
+from tqdm import tqdm
+
+from torch_frame import stype
+from torch_frame.data import DataLoader
 from torch_frame.datasets import Jobs
+from torch_frame.nn.models import BCAUSS
+
+parser = argparse.ArgumentParser()
+parser.add_argument("--channels", type=int, default=256)
+parser.add_argument("--num_layers", type=int, default=4)
+parser.add_argument("--batch_size", type=int, default=512)
+parser.add_argument("--lr", type=float, default=0.0001)
+parser.add_argument("--epochs", type=int, default=50)
+parser.add_argument("--seed", type=int, default=0)
+parser.add_argument("--compile", action="store_true")
+args = parser.parse_args()
 
 path = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', "jobs")
 dataset = Jobs(root=path)
+
+torch.manual_seed(args.seed)
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+dataset.materialize(path=osp.join(path, "data.pt"))
+
+dataset = dataset.shuffle()
+tensor_frame = dataset.tensor_frame
+train_loader = DataLoader(tensor_frame, batch_size=args.batch_size,
+                          shuffle=True)
+
+model = BCAUSS(
+    channels=tensor_frame.num_cols - 1,
+    hidden_channels=200,
+    decoder_hidden_channels=100,
+    out_channels=1,
+    col_stats=dataset.col_stats,
+    col_names_dict=tensor_frame.col_names_dict,
+).to(device)
+
+optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
+lr_scheduler = ExponentialLR(optimizer, gamma=0.95)
+
+is_classification = True
+
+
+def train(epoch: int) -> float:
+    model.train()
+    loss_accum = total_count = 0
+
+    for tf in tqdm(train_loader, desc=f'Epoch: {epoch}'):
+        tf = tf.to(device)
+        treatment_idx = tf.col_names_dict[stype.categorical].index('treated')
+        out, balance_score, treated_mask = model.forward(
+            tf, treatment_index=treatment_idx)
+        loss = torch.mean(treated_mask *
+                          torch.square(tf.y - out.squeeze(-1))) + balance_score
+        optimizer.zero_grad()
+        loss.backward()
+        loss_accum += float(loss) * len(out)
+        total_count += len(out)
+        optimizer.step()
+    return loss_accum / total_count
+
+
+for epoch in range(1, args.epochs + 1):
+    train_loss = train(epoch)
+
+    print(f'Train Loss: {train_loss:.4f}\n')
diff --git a/test/nn/decoder/test_excelformer_decoder.py b/test/nn/decoder/test_excelformer_decoder.py
index 2d92a116d..aee4c2f8a 100644
--- a/test/nn/decoder/test_excelformer_decoder.py
+++ b/test/nn/decoder/test_excelformer_decoder.py
@@ -5,7 +5,7 @@
 
 def test_excelformer_decoder():
     batch_size = 10
-    num_cols = 18
+    num_cols = 8
     in_channels = 8
     out_channels = 3
     x = torch.randn(batch_size, num_cols, in_channels)
diff --git a/torch_frame/data/multi_embedding_tensor.py b/torch_frame/data/multi_embedding_tensor.py
index aa53d2f6e..b2031b186 100644
--- a/torch_frame/data/multi_embedding_tensor.py
+++ b/torch_frame/data/multi_embedding_tensor.py
@@ -218,7 +218,7 @@ def _empty(self, dim: int) -> MultiEmbeddingTensor:
         Returns:
             MultiEmbeddingTensor: An empty :class:`MultiEmbeddingTensor`.
                 Note that if :obj:`dim=0`, it will return with the original
-                offset tensor.git
+                offset tensor.
         """
         return MultiEmbeddingTensor(
             num_rows=0 if dim == 0 else self.num_rows,
diff --git a/torch_frame/nn/decoder/__init__.py b/torch_frame/nn/decoder/__init__.py
index 8322448a3..2c64a258e 100644
--- a/torch_frame/nn/decoder/__init__.py
+++ b/torch_frame/nn/decoder/__init__.py
@@ -2,8 +2,5 @@
 from .decoder import Decoder
 from .trompt_decoder import TromptDecoder
 from .excelformer_decoder import ExcelFormerDecoder
-from .mlpdecoder import MLPDecoder
 
-__all__ = classes = [
-    'Decoder', 'TromptDecoder', 'ExcelFormerDecoder', 'MLPDecoder'
-]
+__all__ = classes = ['Decoder', 'TromptDecoder', 'ExcelFormerDecoder']
diff --git a/torch_frame/nn/decoder/mlpdecoder.py b/torch_frame/nn/decoder/mlpdecoder.py
deleted file mode 100644
index bd37e0b91..000000000
--- a/torch_frame/nn/decoder/mlpdecoder.py
+++ /dev/null
@@ -1,44 +0,0 @@
-from torch import Tensor
-from torch.nn import Linear, ReLU
-
-from torch_frame.nn.decoder import Decoder
-
-
-class MLPDecoder(Decoder):
-    def __init__(
-        self,
-        in_channels: int,
-        hidden_channels: int,
-        out_channels: int,
-        num_cols: int,
-    ) -> None:
-        super().__init__()
-        self.in_channels = in_channels
-        self.out_channels = out_channels
-        self.activation = ReLU()
-        self.lin_1 = Linear(num_cols, hidden_channels)
-        self.lin_2 = Linear(hidden_channels, hidden_channels)
-        self.lin_3 = Linear(hidden_channels, out_channels)
-        self.reset_parameters()
-
-    def reset_parameters(self) -> None:
-        self.lin_1.reset_parameters()
-        self.lin_2.reset_parameters()
-        self.lin_3.reset_parameters()
-
-    def forward(self, x: Tensor) -> Tensor:
-        r"""Transforming :obj:`x` into output predictions.
-
-        Args:
-            x (Tensor): Input column-wise tensor of shape
-                [batch_size, num_cols, in_channels]
-
-        Returns:
-            Tensor: [batch_size, out_channels].
-        """
-        x = self.lin_1(x)
-        x = self.activation(x)
-        x = self.lin_2(x)
-        x = self.activation(x)
-        x = self.lin_3(x)
-        return x
diff --git a/torch_frame/nn/models/__init__.py b/torch_frame/nn/models/__init__.py
index ef3c8c4c2..754c07402 100644
--- a/torch_frame/nn/models/__init__.py
+++ b/torch_frame/nn/models/__init__.py
@@ -6,13 +6,9 @@
 from .resnet import ResNet
 from .tab_transformer import TabTransformer
 from .mlp import MLP
+from .bcauss import BCAUSS
 
 __all__ = classes = [
-    'Trompt',
-    'FTTransformer',
-    'ExcelFormer',
-    'TabNet',
-    'ResNet',
-    'TabTransformer',
-    'MLP',
+    'Trompt', 'FTTransformer', 'ExcelFormer', 'TabNet', 'ResNet',
+    'TabTransformer', 'MLP', 'BCAUSS'
 ]
diff --git a/torch_frame/nn/models/bcauss.py b/torch_frame/nn/models/bcauss.py
index 141cb73c3..4e5855d63 100644
--- a/torch_frame/nn/models/bcauss.py
+++ b/torch_frame/nn/models/bcauss.py
@@ -1,39 +1,126 @@
+from typing import Any, Tuple
+
 import torch
 import torch.nn.functional as F
-from torch.nn import Module, Parameter
+from torch import Tensor
+from torch.nn import Linear, Module, Parameter, ReLU, Sequential
 
-from torch_frame.nn.decoder import MLPDecoder
-from torch_frame.nn.models import MLP
+import torch_frame
+from torch_frame import TensorFrame, categorical, numerical
+from torch_frame.data.stats import StatType
 
 
-class EpsilonLayer(Module):
-    def __init__(self):
+class MLPBlock(Module):
+    def __init__(
+        self,
+        in_channels: int,
+        hidden_channels: int,
+        out_channels: int,
+    ) -> None:
         super().__init__()
-        self.epsilon = Parameter(torch.randn(1, 1))
+        self.model = Sequential(*[
+            Linear(in_channels, hidden_channels),
+            ReLU(),
+            Linear(hidden_channels, hidden_channels),
+            ReLU(),
+            Linear(hidden_channels, out_channels)
+        ])
+        self.reset_parameters()
+
+    def reset_parameters(self) -> None:
+        for block in self.model:
+            if isinstance(block, Linear):
+                block.reset_parameters()
+
+    def forward(self, x: Tensor) -> Tensor:
+        r"""Transforming :obj:`x` into output predictions.
+
+        Args:
+            x (Tensor): Input column-wise tensor of shape
+                [batch_size, num_cols, in_channels]
+
+        Returns:
+            Tensor: [batch_size, out_channels].
+        """
+        return self.model(x)
+
+
+class BalanceScoreEstimator(Module):
+    def __init__(self, channels, out_channels):
+        super().__init__()
+        self.model = Linear(channels, out_channels)
 
     def reset_parameters(self):
-        self.epsilon.reset_parameters()
+        self.model.reset_parameters()
 
-    def forward(self, t):
-        return F.sigmoid(self.epsilon * torch.ones_like(t)[:, 0:1])
+    def forward(self, x: Tensor):
+        return F.sigmoid(self.model(x))
 
 
 class BCAUSS(Module):
-    def __init__(self):
+    def __init__(self, channels: int, hidden_channels: int,
+                 decoder_hidden_channels: int, out_channels: int,
+                 col_stats: dict[str, dict[StatType, Any]],
+                 col_names_dict: dict[torch_frame.stype, list[str]]):
+
         super().__init__()
-        self.mlp = MLP()
-        self.epsilon = EpsilonLayer()
+        numerical_stats_list = [
+            col_stats[col_name]
+            for col_name in col_names_dict[torch_frame.numerical]
+        ]
+        mean = torch.tensor(
+            [stats[StatType.MEAN] for stats in numerical_stats_list])
+        self.register_buffer("mean", mean)
+        std = (torch.tensor(
+            [stats[StatType.STD] for stats in numerical_stats_list]) + 1e-6)
+        self.register_buffer("std", std)
+        self.representation_learner = MLPBlock(channels, hidden_channels,
+                                               hidden_channels)
+        self.balance_score_learner = BalanceScoreEstimator(
+            hidden_channels, out_channels)
         # decoder for treatment group
-        self.treatment_decoder = MLPDecoder()
+        self.treatment_decoder = MLPBlock(hidden_channels,
+                                          decoder_hidden_channels,
+                                          out_channels)
         # decoder for control group
-        self.control_decoder = MLPDecoder()
+        self.control_decoder = MLPBlock(hidden_channels,
+                                        decoder_hidden_channels, out_channels)
+        self.epsilon = Parameter(torch.randn(1, 1))
+        self.reset_parameters()
 
-    def forward(self, x, t):
+    def reset_parameters(self):
+        self.representation_learner.reset_parameters()
+        self.balance_score_learner.reset_parameters()
+        self.treatment_decoder.reset_parameters()
+        self.control_decoder.reset_parameters()
+
+    def forward(self, tf: TensorFrame,
+                treatment_index: int) -> Tuple[Tensor, Tensor, Tensor]:
         r"""T stands for treatment and y stands for output."""
-        out = self.mlp(x)
-        if t == 0:
-            out = self.control_decoder(out)
-        else:
-            out = self.treatment_decoder(out)
-        penalty = self.epsilon(out)
-        return out + penalty
+        feat_cat = tf.feat_dict[categorical]
+        feat_num = (tf.feat_dict[numerical] - self.mean) / self.std
+        assert isinstance(feat_cat, Tensor)
+        assert isinstance(feat_num, Tensor)
+        x = torch.cat([feat_cat, feat_num], dim=1)
+        t = x[:, treatment_index].clone()
+        # Swap the treatment col with the last column of x
+        x[:, treatment_index] = x[:, -1]
+        x[:, -1] = t
+        # Remove the treatment column
+        x = x[:, :-1]
+
+        out = self.representation_learner(x)  # batch_size, hidden_channels
+        treated_mask = t == 1
+        treated = out[treated_mask, :]
+        control = out[~treated_mask, :]
+        pred = torch.zeros((len(x), 1), dtype=torch.float32, device=x.device)
+        pred[~treated_mask, :] = self.control_decoder(control)
+        pred[treated_mask, :] = self.treatment_decoder(treated)
+        penalty = self.balance_score_learner(out)
+        treated_weight = treated_mask.unsqueeze(-1) / penalty
+        control_weight = ~treated_mask.unsqueeze(-1) / penalty
+        balance_score = torch.mean(
+            torch.square(
+                torch.sum(treated_weight * x) / torch.sum(treated_weight) -
+                torch.sum(control_weight * x) / torch.sum(control_weight)))
+        return pred, self.epsilon * balance_score, treated_mask

From 69a8f6b39d33c8fa3cb2e0bd1d6e19365be55d9e Mon Sep 17 00:00:00 2001
From: yiweny
Date: Sun, 14 Apr 2024 08:26:06 +0000
Subject: [PATCH 3/9] fix code

---
 examples/bcauss.py              |  85 +++++++++++++++----
 torch_frame/datasets/jobs.py    | 146 +++++++++++++++++++++---------
 torch_frame/nn/models/bcauss.py |  50 ++++++-----
 3 files changed, 203 insertions(+), 78 deletions(-)

diff --git a/examples/bcauss.py b/examples/bcauss.py
index eb25f3992..a4dd78254 100644
--- a/examples/bcauss.py
+++ b/examples/bcauss.py
@@ -1,27 +1,32 @@
+"""Reported (reproduced) E_ATT of BCAUSS based on Table 1 of the paper.
+BAUSS + in_sample 0.02 (0.0284)
+"""
 import argparse
+import copy
 import os.path as osp
 
 import torch
 from torch.optim.lr_scheduler import ExponentialLR
 from tqdm import tqdm
 
-from torch_frame import stype
-from torch_frame.data import DataLoader
+from torch_frame import TensorFrame, stype
+from torch_frame.data import DataLoader, Dataset
 from torch_frame.datasets import Jobs
 from torch_frame.nn.models import BCAUSS
 
 parser = argparse.ArgumentParser()
-parser.add_argument("--channels", type=int, default=256)
-parser.add_argument("--num_layers", type=int, default=4)
-parser.add_argument("--batch_size", type=int, default=512)
-parser.add_argument("--lr", type=float, default=0.0001)
-parser.add_argument("--epochs", type=int, default=50)
-parser.add_argument("--seed", type=int, default=0)
+parser.add_argument("--batch_size", type=int, default=3000)
+parser.add_argument("--lr", type=float, default=0.00001)
+parser.add_argument("--epochs", type=int, default=2)
+parser.add_argument("--seed", type=int, default=2)
 parser.add_argument("--compile", action="store_true")
+parser.add_argument("--feature_engineering", action="store_true",
+                    default=True)
 args = parser.parse_args()
 
 path = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', "jobs")
-dataset = Jobs(root=path)
+dataset = Jobs(root=path, feature_engineering=args.feature_engineering)
+ATT = dataset.get_att()
+print(f"ATT is {ATT}")
 
 torch.manual_seed(args.seed)
 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
@@ -29,20 +34,53 @@
 dataset.materialize(path=osp.join(path, "data.pt"))
 
 dataset = dataset.shuffle()
-tensor_frame = dataset.tensor_frame
-train_loader = DataLoader(tensor_frame, batch_size=args.batch_size,
+"""
+if dataset.split_col is None:
+    train_dataset, val_dataset, test_dataset = dataset[:0.62], dataset[
+        0.62:0.80], dataset[0.80:]
+else:
+    train_dataset, val_dataset, test_dataset = dataset.split()
+"""
+train_dataset = dataset
+treated_df = dataset.df[(dataset.df['source'] == 1)
+                        & (dataset.df['treated'] == 1)]
+treated_eval_dataset = Dataset(treated_df, dataset.col_to_stype,
+                               target_col='target')
+control_df = copy.deepcopy(treated_df)
+control_df['treated'] = 0
+control_eval_dataset = Dataset(control_df, dataset.col_to_stype,
+                               target_col='target')
+
+treated_eval_dataset.materialize(path=osp.join(path, "treated_eval_data.pt"))
+control_eval_dataset.materialize(path=osp.join(path, "control_eval_data.pt"))
+
+train_tensor_frame = train_dataset.tensor_frame
+treatment_idx = train_tensor_frame.col_names_dict[stype.categorical].index(
+    'treated')
+# val_tensor_frame = val_dataset.tensor_frame
+# test_tensor_frame = test_dataset.tensor_frame
+treated_eval_tensor_frame = treated_eval_dataset.tensor_frame
+# This is a bad hack. Currently the materialization logic would override 1's
+# to 0's due to 1's being the popular class
+treated_eval_tensor_frame.feat_dict[stype.categorical][:, treatment_idx] = 1.
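+# The control frame is left untouched: its 'treated' column is already all
+# zeros, so the override described above does not affect it.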
+control_eval_tensor_frame = control_eval_dataset.tensor_frame
+
+train_loader = DataLoader(train_tensor_frame, batch_size=args.batch_size,
                           shuffle=True)
+# val_loader = DataLoader(val_tensor_frame, batch_size=args.batch_size)
+# test_loader = DataLoader(test_tensor_frame, batch_size=args.batch_size)
 
 model = BCAUSS(
-    channels=tensor_frame.num_cols - 1,
+    channels=train_tensor_frame.num_cols - 1,
     hidden_channels=200,
     decoder_hidden_channels=100,
     out_channels=1,
-    col_stats=dataset.col_stats,
-    col_names_dict=tensor_frame.col_names_dict,
+    col_stats=dataset.col_stats if not args.feature_engineering else None,
+    col_names_dict=train_tensor_frame.col_names_dict
+    if not args.feature_engineering else None,
 ).to(device)
 
-optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
+optimizer = torch.optim.SGD(model.parameters(), lr=args.lr)
 lr_scheduler = ExponentialLR(optimizer, gamma=0.95)
 
 is_classification = True
@@ -54,7 +92,6 @@ def train(epoch: int) -> float:
 
     for tf in tqdm(train_loader, desc=f'Epoch: {epoch}'):
         tf = tf.to(device)
-        treatment_idx = tf.col_names_dict[stype.categorical].index('treated')
         out, balance_score, treated_mask = model.forward(
             tf, treatment_index=treatment_idx)
         loss = torch.mean(treated_mask *
@@ -67,7 +104,21 @@ def train(epoch: int) -> float:
     return loss_accum / total_count
 
 
+@torch.no_grad()
+def test(treated: TensorFrame, control: TensorFrame) -> float:
+    model.eval()
+
+    treated = treated.to(device)
+    treated_effect, _, _ = model(treated, treatment_idx)
+
+    control = control.to(device)
+    control_effect, _, _ = model(control, treatment_idx)
+
+    return torch.abs(ATT - torch.mean(treated_effect - control_effect))
+
+
 for epoch in range(1, args.epochs + 1):
     train_loss = train(epoch)
+    error = test(treated_eval_tensor_frame, control_eval_tensor_frame)
 
-    print(f'Train Loss: {train_loss:.4f}\n')
+    print(f'Train Loss: {train_loss:.4f} Error_ATT: {error:.4f},\n')
diff --git a/torch_frame/datasets/jobs.py b/torch_frame/datasets/jobs.py
index 940ae5568..8a6c2c469 100644
--- a/torch_frame/datasets/jobs.py
+++ b/torch_frame/datasets/jobs.py
@@ -1,6 +1,8 @@
+import numpy as np
 import pandas as pd
 
 import torch_frame
+from torch_frame.utils.split import SPLIT_TO_NUM
 
 
 class Jobs(torch_frame.data.Dataset):
@@ -15,53 +17,113 @@ class Jobs(torch_frame.data.Dataset):
     """
     lalonde_treated = 'https://users.nber.org/~rdehejia/data/nsw_treated.txt'
     lalonde_control = 'https://users.nber.org/~rdehejia/data/nsw_control.txt'
-    psid = 'https://users.nber.org/~rdehejia/data/psid_controls.txt'  # noqa
+    psid = 'https://users.nber.org/~rdehejia/data/psid_controls.txt'
+    train = 'https://www.fredjo.com/files/jobs_DW_bin.new.10.train.npz'
+    test = 'https://www.fredjo.com/files/jobs_DW_bin.new.10.test.npz'
 
-    def __init__(self, root: str):
-        # National Supported Work Demonstration
-        nsw_treated = self.download_url(Jobs.lalonde_treated, root)
-        nsw_control = self.download_url(Jobs.lalonde_control, root)
-        # Population Survey of Income Dynamics
-        psid = self.download_url(Jobs.psid, root)
-        names = [
-            'treated', 'age', 'education', 'Black', 'Hispanic', 'married',
-            'nodegree', 'RE75', 'RE78'
-        ]
+    def __init__(self, root: str, feature_engineering: bool = False):
+        if feature_engineering:
+            train = self.download_url(Jobs.train, root)
+            test = self.download_url(Jobs.test, root)
+            train_np = np.load(train)
+            test_np = np.load(test)
+            train_data = np.concatenate([
+                train_np.f.t[:, 0].reshape(-1, 1), train_np.f.x[:, :, 0],
+                train_np.f.e[:, 0].reshape(-1, 1), train_np.f.yf[:, 0].reshape(
+                    -1, 1)
+            ], axis=1)
+            test_data = np.concatenate([
+                test_np.f.t[:, 0].reshape(-1, 1), test_np.f.x[:, :, 0],
+                test_np.f.e[:, 0].reshape(-1, 1), test_np.f.yf[:, 0].reshape(
+                    -1, 1)
+            ], axis=1)
+            train_df = pd.DataFrame(
+                train_data, columns=['treated'] +
+                [f'Col_{i}'
+                 for i in range(train_np.f.x.shape[1])] + ['source', 'target'])
+            train_df['split'] = SPLIT_TO_NUM['train']
+            test_df = pd.DataFrame(
+                test_data, columns=['treated'] +
+                [f'Col_{i}'
+                 for i in range(train_np.f.x.shape[1])] + ['source', 'target'])
+            test_df['split'] = SPLIT_TO_NUM['test']
+            df = pd.concat([train_df, test_df], axis=0)
+            col_to_stype = {
+                'treated': torch_frame.categorical,
+                'Col_0': torch_frame.numerical,
+                'Col_1': torch_frame.numerical,
+                'Col_2': torch_frame.categorical,
+                'Col_3': torch_frame.categorical,
+                'Col_4': torch_frame.categorical,
+                'Col_5': torch_frame.categorical,
+                'Col_6': torch_frame.numerical,
+                'Col_7': torch_frame.numerical,
+                'Col_8': torch_frame.numerical,
+                'Col_9': torch_frame.numerical,
+                'Col_10': torch_frame.numerical,
+                'Col_11': torch_frame.numerical,
+                'Col_12': torch_frame.numerical,
+                'Col_13': torch_frame.categorical,
+                'Col_14': torch_frame.categorical,
+                'Col_15': torch_frame.numerical,
+                'Col_16': torch_frame.categorical,
+                'target': torch_frame.categorical
+            }
+            super().__init__(df, col_to_stype, target_col='target',
+                             split_col='split')
+        else:
+            # National Supported Work Demonstration
+            nsw_treated = self.download_url(Jobs.lalonde_treated, root)
+            nsw_control = self.download_url(Jobs.lalonde_control, root)
+            # Population Survey of Income Dynamics
+            psid = self.download_url(Jobs.psid, root)
+            names = [
+                'treated', 'age', 'education', 'Black', 'Hispanic', 'married',
+                'nodegree', 'RE75', 'RE78'
+            ]
 
-        nsw_treated_df = pd.read_csv(
-            nsw_treated,
-            sep='\s+',  # noqa
-            names=names)
-        assert (nsw_treated_df['treated'] == 1).all()
-        nsw_treated_df['source'] = 'nsw'
+            nsw_treated_df = pd.read_csv(
+                nsw_treated,
+                sep='\s+',  # noqa
+                names=names)
+            assert (nsw_treated_df['treated'] == 1).all()
+            nsw_treated_df['source'] = 'nsw'
 
-        nsw_control_df = pd.read_csv(
-            nsw_control,
-            sep='\s+',  # noqa
-            names=names)
-        assert (nsw_control_df['treated'] == 0).all()
-        nsw_control_df['source'] = 'nsw'
+            nsw_control_df = pd.read_csv(
+                nsw_control,
+                sep='\s+',  # noqa
+                names=names)
+            assert (nsw_control_df['treated'] == 0).all()
+            nsw_control_df['source'] = 1
 
-        names.insert(7, 'RE74')
+            names.insert(7, 'RE74')
 
-        psid_df = pd.read_csv(psid, sep='\s+', names=names)  # noqa
-        assert (psid_df['treated'] == 0).all()
-        psid_df['source'] = 'psid'
-        psid_df = psid_df.drop('RE74', axis=1)
+            psid_df = pd.read_csv(psid, sep='\s+', names=names)  # noqa
+            assert (psid_df['treated'] == 0).all()
+            psid_df['source'] = 0
+            psid_df = psid_df.drop('RE74', axis=1)
 
-        df = pd.concat([nsw_treated_df, nsw_control_df, psid_df], axis=0)
-        df['target'] = df['RE78'] != 0
+            df = pd.concat([nsw_treated_df, nsw_control_df, psid_df], axis=0)
+            df['target'] = df['RE78'] != 0
 
-        col_to_stype = {
-            'treated': torch_frame.categorical,
-            'age': torch_frame.numerical,
-            'education': torch_frame.categorical,
-            'Black': torch_frame.categorical,
-            'Hispanic': torch_frame.categorical,
-            'married': torch_frame.categorical,
-            'nodegree': torch_frame.categorical,
-            'RE75': torch_frame.numerical,
-            'target': torch_frame.categorical,
-        }
+            col_to_stype = {
+                'treated': torch_frame.categorical,
+                'age': torch_frame.numerical,
+                'education': torch_frame.categorical,
+                'Black': torch_frame.categorical,
+                'Hispanic': torch_frame.categorical,
+                'married': torch_frame.categorical,
+                'nodegree': torch_frame.categorical,
+                'RE75': torch_frame.numerical,
+                'target': torch_frame.categorical,
+            }
 
-        super().__init__(df, col_to_stype, target_col='target')
+            super().__init__(df, col_to_stype, target_col='target')
+        self.df = df
+
+    def get_att(self):
+        df = self.df[self.df['source'] == 1]
+        treated = df[df['treated'] == 1]
+        control = df[df['treated'] == 0]
+        return sum(treated['target']) / len(treated) - sum(
+            control['target']) / len(control)
diff --git a/torch_frame/nn/models/bcauss.py b/torch_frame/nn/models/bcauss.py
index 4e5855d63..e9811629a 100644
--- a/torch_frame/nn/models/bcauss.py
+++ b/torch_frame/nn/models/bcauss.py
@@ -3,7 +3,7 @@
 import torch
 import torch.nn.functional as F
 from torch import Tensor
-from torch.nn import Linear, Module, Parameter, ReLU, Sequential
+from torch.nn import Linear, Module, ReLU, Sequential
 
 import torch_frame
 from torch_frame import TensorFrame, categorical, numerical
@@ -58,22 +58,30 @@ def forward(self, x: Tensor):
 
 
 class BCAUSS(Module):
-    def __init__(self, channels: int, hidden_channels: int,
-                 decoder_hidden_channels: int, out_channels: int,
-                 col_stats: dict[str, dict[StatType, Any]],
-                 col_names_dict: dict[torch_frame.stype, list[str]]):
+    def __init__(
+        self,
+        channels: int,
+        hidden_channels: int,
+        decoder_hidden_channels: int,
+        out_channels: int,
+        col_stats: dict[str, dict[StatType, Any]] | None,
+        col_names_dict: dict[torch_frame.stype, list[str]] | None,
+        epsilon: float = 0.5,
+    ):
 
         super().__init__()
-        numerical_stats_list = [
-            col_stats[col_name]
-            for col_name in col_names_dict[torch_frame.numerical]
-        ]
-        mean = torch.tensor(
-            [stats[StatType.MEAN] for stats in numerical_stats_list])
-        self.register_buffer("mean", mean)
-        std = (torch.tensor(
-            [stats[StatType.STD] for stats in numerical_stats_list]) + 1e-6)
-        self.register_buffer("std", std)
+        if col_stats is not None and col_names_dict is not None:
+            numerical_stats_list = [
+                col_stats[col_name]
+                for col_name in col_names_dict[torch_frame.numerical]
+            ]
+            mean = torch.tensor(
+                [stats[StatType.MEAN] for stats in numerical_stats_list])
+            self.register_buffer("mean", mean)
+            std = (torch.tensor(
+                [stats[StatType.STD]
+                 for stats in numerical_stats_list]) + 1e-6)
+            self.register_buffer("std", std)
         self.representation_learner = MLPBlock(channels, hidden_channels,
                                                hidden_channels)
         self.balance_score_learner = BalanceScoreEstimator(
@@ -85,7 +93,7 @@ def __init__(self, channels: int, hidden_channels: int,
         # decoder for control group
         self.control_decoder = MLPBlock(hidden_channels,
                                         decoder_hidden_channels, out_channels)
-        self.epsilon = Parameter(torch.randn(1, 1))
+        self.epsilon = epsilon
         self.reset_parameters()
 
     def reset_parameters(self):
@@ -98,7 +106,9 @@ def forward(self, tf: TensorFrame,
                 treatment_index: int) -> Tuple[Tensor, Tensor, Tensor]:
         r"""T stands for treatment and y stands for output."""
         feat_cat = tf.feat_dict[categorical]
-        feat_num = (tf.feat_dict[numerical] - self.mean) / self.std
+        feat_num = tf.feat_dict[numerical]
+        if hasattr(self, 'mean'):
+            feat_num = (feat_num - self.mean) / self.std
         assert isinstance(feat_cat, Tensor)
         assert isinstance(feat_num, Tensor)
         x = torch.cat([feat_cat, feat_num], dim=1)
@@ -121,6 +131,8 @@ def forward(self, tf: TensorFrame,
         control_weight = ~treated_mask.unsqueeze(-1) / penalty
         balance_score = torch.mean(
             torch.square(
-                torch.sum(treated_weight * x) / torch.sum(treated_weight) -
-                torch.sum(control_weight * x) / torch.sum(control_weight)))
+                torch.sum(treated_weight * x, dim=0) /
+                torch.sum(treated_weight + 0.01) -
+                torch.sum(control_weight * x, dim=0) /
+                (torch.sum(control_weight + 0.01))))
         return pred, self.epsilon * balance_score, treated_mask

From dbc2628fcca190a7bb1f09c49905a01f35e2b38f Mon Sep 17 00:00:00 2001
From: yiweny
Date: Mon, 15 Apr 2024 08:02:07 +0000
Subject: [PATCH 4/9] add out of distribution metric

---
 examples/bcauss.py              | 115 +++++++++++++++++++++--------
 torch_frame/datasets/jobs.py    |  21 +++---
 torch_frame/nn/models/bcauss.py |   4 +-
 3 files changed, 99 insertions(+), 41 deletions(-)

diff --git a/examples/bcauss.py b/examples/bcauss.py
index a4dd78254..bee3d0719 100644
--- a/examples/bcauss.py
+++ b/examples/bcauss.py
@@ -1,5 +1,6 @@
 """Reported (reproduced) E_ATT of BCAUSS based on Table 1 of the paper.
 BAUSS + in_sample 0.02 (0.0284)
+BAUSS + out_of_sample 0.0290 (0.05 +/- 0.02)
 """
 import argparse
 import copy
@@ -19,12 +20,12 @@
 parser.add_argument("--lr", type=float, default=0.00001)
 parser.add_argument("--epochs", type=int, default=2)
 parser.add_argument("--seed", type=int, default=2)
-parser.add_argument("--compile", action="store_true")
-parser.add_argument("--feature_engineering", action="store_true",
-                    default=True)
+parser.add_argument("--feature-engineering", action="store_true", default=True)
+parser.add_argument("--out-of-distribution", action="store_true", default=True)
 args = parser.parse_args()
 
 path = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', "jobs")
-dataset = Jobs(root=path, feature_engineering=args.feature_engineering)
+dataset = Jobs(root=path, dehejia=False)
 ATT = dataset.get_att()
 print(f"ATT is {ATT}")
 
@@ -34,36 +35,75 @@
 dataset.materialize(path=osp.join(path, "data.pt"))
 
 dataset = dataset.shuffle()
-"""
-if dataset.split_col is None:
-    train_dataset, val_dataset, test_dataset = dataset[:0.62], dataset[
-        0.62:0.80], dataset[0.80:]
-else:
-    train_dataset, val_dataset, test_dataset = dataset.split()
-"""
-train_dataset = dataset
-treated_df = dataset.df[(dataset.df['source'] == 1)
-                        & (dataset.df['treated'] == 1)]
-treated_eval_dataset = Dataset(treated_df, dataset.col_to_stype,
-                               target_col='target')
-control_df = copy.deepcopy(treated_df)
-control_df['treated'] = 0
-control_eval_dataset = Dataset(control_df, dataset.col_to_stype,
-                               target_col='target')
-
-treated_eval_dataset.materialize(path=osp.join(path, "treated_eval_data.pt"))
-control_eval_dataset.materialize(path=osp.join(path, "control_eval_data.pt"))
+if args.out_of_distribution:
+    if dataset.split_col is None:
+        train_dataset, val_dataset, test_dataset = dataset[:0.62], dataset[
+            0.62:0.80], dataset[0.80:]
+    else:
+        train_dataset, _, test_dataset = dataset.split()
+        train_dataset, val_dataset = train_dataset[:0.775], dataset[0.775:]
+    # Calculating the validation dataset
+    treated_df = val_dataset.df[(val_dataset.df['source'] == 1)
+                                & (val_dataset.df['treated'] == 1)]
+    treated_val_dataset = Dataset(treated_df, dataset.col_to_stype,
+                                  target_col='target')
+    control_df = copy.deepcopy(treated_df)
+    control_df['treated'] = 0
+    control_val_dataset = Dataset(control_df, dataset.col_to_stype,
+                                  target_col='target')
+
+    treated_val_dataset.materialize(path=osp.join(path, "treated_val_data.pt"))
+    control_val_dataset.materialize(path=osp.join(path, "control_val_data.pt"))
+    # Calculating the evaluation dataset
+    treated_df = test_dataset.df[(test_dataset.df['source'] == 1)
+                                 & (test_dataset.df['treated'] == 1)]
+    treated_test_dataset = Dataset(treated_df, dataset.col_to_stype,
+                                   target_col='target')
+    control_df = copy.deepcopy(treated_df)
+    control_df['treated'] = 0
+    control_test_dataset = Dataset(control_df, dataset.col_to_stype,
+                                   target_col='target')
+
+    treated_test_dataset.materialize(
+        path=osp.join(path, "treated_test_data.pt"))
+    control_test_dataset.materialize(
+        path=osp.join(path, "control_test_data.pt"))
+else:
+    train_dataset = dataset
+
+    # Calculating the evaluation dataset
+    treated_df = dataset.df[(dataset.df['source'] == 1)
+                            & (dataset.df['treated'] == 1)]
+    treated_test_dataset = Dataset(treated_df, dataset.col_to_stype,
+                                   target_col='target')
+    control_df = copy.deepcopy(treated_df)
+    control_df['treated'] = 0
+    control_test_dataset = Dataset(control_df, dataset.col_to_stype,
+                                   target_col='target')
+
+    treated_test_dataset.materialize(
+        path=osp.join(path, "treated_eval_data.pt"))
+    control_test_dataset.materialize(
+        path=osp.join(path, "control_eval_data.pt"))
 
 train_tensor_frame = train_dataset.tensor_frame
 treatment_idx = train_tensor_frame.col_names_dict[stype.categorical].index(
     'treated')
-# val_tensor_frame = val_dataset.tensor_frame
-# test_tensor_frame = test_dataset.tensor_frame
-treated_eval_tensor_frame = treated_eval_dataset.tensor_frame
+if args.out_of_distribution:
+    val_tensor_frame = val_dataset.tensor_frame
+    test_tensor_frame = test_dataset.tensor_frame
+    treated_val_tensor_frame = treated_val_dataset.tensor_frame
+    # This is a bad hack. Currently the materialization logic would override 1's
+    # to 0's due to 1's being the popular class
+    treated_val_tensor_frame.feat_dict[stype.categorical][:,
+                                                          treatment_idx] = 1.
+    control_val_tensor_frame = control_val_dataset.tensor_frame
+
+treated_test_tensor_frame = treated_test_dataset.tensor_frame
 # This is a bad hack. Currently the materialization logic would override 1's
 # to 0's due to 1's being the popular class
-treated_eval_tensor_frame.feat_dict[stype.categorical][:, treatment_idx] = 1.
-control_eval_tensor_frame = control_eval_dataset.tensor_frame
+treated_test_tensor_frame.feat_dict[stype.categorical][:, treatment_idx] = 1.
+control_test_tensor_frame = control_test_dataset.tensor_frame
 
 train_loader = DataLoader(train_tensor_frame, batch_size=args.batch_size,
                           shuffle=True)
@@ -105,7 +145,7 @@ def train(epoch: int) -> float:
 
 
 @torch.no_grad()
-def test(treated: TensorFrame, control: TensorFrame) -> float:
+def eval(treated: TensorFrame, control: TensorFrame) -> float:
     model.eval()
 
     treated = treated.to(device)
@@ -117,8 +157,23 @@ def test(treated: TensorFrame, control: TensorFrame) -> float:
     return torch.abs(ATT - torch.mean(treated_effect - control_effect))
 
 
+best_val_metric = float('inf')
+best_test_metric = float('inf')
+
 for epoch in range(1, args.epochs + 1):
     train_loss = train(epoch)
-    error = test(treated_eval_tensor_frame, control_eval_tensor_frame)
-
-    print(f'Train Loss: {train_loss:.4f} Error_ATT: {error:.4f},\n')
+    error = eval(treated_test_tensor_frame, control_test_tensor_frame)
+    if args.out_of_distribution:
+        val_error = eval(treated_val_tensor_frame, control_val_tensor_frame)
+        if val_error < best_val_metric:
+            best_val_metric = val_error
+            best_test_metric = error
+        print(
+            f'Train Loss: {train_loss:.4f} Val Error_ATT: {val_error:.4f},\n'
+            f' Error_ATT: {error:.4f}\n')
+    else:
+        print(f'Train Loss: {train_loss:.4f} Error_ATT: {error:.4f},\n')
+
+if args.out_of_distribution:
+    print(f'Best Val Error: {best_val_metric:.4f}, '
+          f'Best Test Error: {best_test_metric:.4f}')
diff --git a/torch_frame/datasets/jobs.py b/torch_frame/datasets/jobs.py
index 8a6c2c469..e2f2d7fa9 100644
--- a/torch_frame/datasets/jobs.py
+++ b/torch_frame/datasets/jobs.py
@@ -14,6 +14,8 @@ class Jobs(torch_frame.data.Dataset):
     (1 if Hispanic, 0 otherwise), married (1 if married, 0 otherwise),
     nodegree (1 if no degree, 0 otherwise), RE74 (earnings in 1974),
     RE75 (earnings in 1975), and RE78 (earnings in 1978).
+
+    Dehejia features:
     """
     lalonde_treated = 'https://users.nber.org/~rdehejia/data/nsw_treated.txt'
     lalonde_control = 'https://users.nber.org/~rdehejia/data/nsw_control.txt'
@@ -21,21 +23,22 @@ class Jobs(torch_frame.data.Dataset):
     train = 'https://www.fredjo.com/files/jobs_DW_bin.new.10.train.npz'
     test = 'https://www.fredjo.com/files/jobs_DW_bin.new.10.test.npz'
 
-    def __init__(self, root: str, feature_engineering: bool = False):
-        if feature_engineering:
+    def __init__(self, root: str, dehejia: bool = False):
+        if not dehejia:
+            split = 0
             train = self.download_url(Jobs.train, root)
             test = self.download_url(Jobs.test, root)
             train_np = np.load(train)
             test_np = np.load(test)
             train_data = np.concatenate([
-                train_np.f.t[:, 0].reshape(-1, 1), train_np.f.x[:, :, 0],
-                train_np.f.e[:, 0].reshape(-1, 1), train_np.f.yf[:, 0].reshape(
-                    -1, 1)
+                train_np.f.t[:, split].reshape(-1, 1),
+                train_np.f.x[:, :, split], train_np.f.e[:, split].reshape(
+                    -1, 1), train_np.f.yf[:, split].reshape(-1, 1)
             ], axis=1)
             test_data = np.concatenate([
-                test_np.f.t[:, 0].reshape(-1, 1), test_np.f.x[:, :, 0],
-                test_np.f.e[:, 0].reshape(-1, 1), test_np.f.yf[:, 0].reshape(
-                    -1, 1)
+                test_np.f.t[:, split].reshape(-1, 1), test_np.f.x[:, :, split],
+                test_np.f.e[:, split].reshape(
+                    -1, 1), test_np.f.yf[:, split].reshape(-1, 1)
             ], axis=1)
             train_df = pd.DataFrame(
                 train_data, columns=['treated'] +
@@ -87,7 +90,7 @@ def __init__(self, root: str, dehejia: bool = False):
                 sep='\s+',  # noqa
                 names=names)
             assert (nsw_treated_df['treated'] == 1).all()
-            nsw_treated_df['source'] = 'nsw'
+            nsw_treated_df['source'] = 1
 
             nsw_control_df = pd.read_csv(
                 nsw_control,
diff --git a/torch_frame/nn/models/bcauss.py b/torch_frame/nn/models/bcauss.py
index e9811629a..ea55fd5af 100644
--- a/torch_frame/nn/models/bcauss.py
+++ b/torch_frame/nn/models/bcauss.py
@@ -127,8 +127,8 @@ def forward(self, tf: TensorFrame,
         pred[~treated_mask, :] = self.control_decoder(control)
         pred[treated_mask, :] = self.treatment_decoder(treated)
         penalty = self.balance_score_learner(out)
-        treated_weight = treated_mask.unsqueeze(-1) / penalty
-        control_weight = ~treated_mask.unsqueeze(-1) / penalty
+        treated_weight = treated_mask.unsqueeze(-1) / (penalty + 0.01)
+        control_weight = ~treated_mask.unsqueeze(-1) / (penalty + 0.01)
         balance_score = torch.mean(
             torch.square(
                 torch.sum(treated_weight * x, dim=0) /

From 0b55b9d8fb621a20ebccdac7d72f3265f2b47b05 Mon Sep 17 00:00:00 2001
From: yiweny
Date: Thu, 25 Apr 2024 06:55:03 +0000
Subject: [PATCH 5/9] add cfr with maximum mean discrepancy

---
 examples/bcauss.py                 |  26 +++--
 examples/causalml.py               | 181 +++++++++++++++++++++++++++++
 torch_frame/datasets/__init__.py   |   4 +-
 torch_frame/datasets/ihdp.py       |  93 +++++++++++++++
 torch_frame/datasets/jobs.py       |  66 ++++++-----
 torch_frame/nn/decoder/__init__.py |   6 +-
 torch_frame/nn/models/__init__.py  |   3 +-
 torch_frame/nn/models/bcauss.py    |  40 ++++++-
 torch_frame/nn/models/cfr.py       | 115 ++++++++++++++++++
 9 files changed, 487 insertions(+), 47 deletions(-)
 create mode 100644 examples/causalml.py
 create mode 100644 torch_frame/datasets/ihdp.py
 create mode 100644 torch_frame/nn/models/cfr.py

diff --git a/examples/bcauss.py b/examples/bcauss.py
index bee3d0719..36b7209a9 100644
--- a/examples/bcauss.py
+++ b/examples/bcauss.py
@@ -1,6 +1,6 @@
-"""Reported (reproduced) E_ATT of BCAUSS based on Table 1 of the paper.
-BAUSS + in_sample 0.02 (0.0284)
-BAUSS + out_of_sample 0.0290 (0.05 +/- 0.02)
+"""Reported (reproduced) E_ATT of BCAUSS based on Table 1 of the paper.
+BCAUSS + in_sample 0.02 (0.0284).
+BCAUSS + out_of_sample 0.05 +/- 0.02 (0.0290).
 """
 import argparse
 import copy
@@ -16,16 +16,16 @@
 from torch_frame.nn.models import BCAUSS
 
 parser = argparse.ArgumentParser()
-parser.add_argument("--batch_size", type=int, default=3000)
+parser.add_argument("--batch_size", type=int, default=1024)
 parser.add_argument("--lr", type=float, default=0.00001)
-parser.add_argument("--epochs", type=int, default=2)
+parser.add_argument("--epochs", type=int, default=5)
 parser.add_argument("--seed", type=int, default=2)
-parser.add_argument("--feature-engineering", action="store_true", default=True)
-parser.add_argument("--out-of-distribution", action="store_true", default=True)
+parser.add_argument("--feature-engineering", action="store_true",
+                    default=True)
+parser.add_argument("--out-of-distribution", action="store_true",
+                    default=True)
 args = parser.parse_args()
 
 path = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', "jobs")
-dataset = Jobs(root=path, dehejia=False)
+dataset = Jobs(root=path, feature_engineering=args.feature_engineering)
 ATT = dataset.get_att()
 print(f"ATT is {ATT}")
@@ -93,8 +93,8 @@
     val_tensor_frame = val_dataset.tensor_frame
     test_tensor_frame = test_dataset.tensor_frame
     treated_val_tensor_frame = treated_val_dataset.tensor_frame
-    # This is a bad hack. Currently the materialization logic would override 1's
-    # to 0's due to 1's being the popular class
+    # This is a bad hack. Currently the materialization logic would override
+    # 1's to 0's due to 1's being the popular class
     treated_val_tensor_frame.feat_dict[stype.categorical][:,
                                                           treatment_idx] = 1.
     control_val_tensor_frame = control_val_dataset.tensor_frame
@@ -132,10 +132,12 @@ def train(epoch: int) -> float:
 
     for tf in tqdm(train_loader, desc=f'Epoch: {epoch}'):
         tf = tf.to(device)
-        out, balance_score, treated_mask = model.forward(
-            tf, treatment_index=treatment_idx)
-        loss = torch.mean(treated_mask *
-                          torch.square(tf.y - out.squeeze(-1))) + balance_score
+        out, balance_score, treated_mask = model(tf,
+                                                 treatment_index=treatment_idx)
+        loss = (
+            (torch.sum(treated_mask * torch.square(tf.y - out.squeeze(-1))) +
+             torch.sum(~treated_mask * torch.square(tf.y - out.squeeze(-1)))) /
+            len(treated_mask) + balance_score)
         optimizer.zero_grad()
         loss.backward()
         loss_accum += float(loss) * len(out)
diff --git a/examples/causalml.py b/examples/causalml.py
new file mode 100644
index 000000000..6ff692470
--- /dev/null
+++ b/examples/causalml.py
@@ -0,0 +1,181 @@
+import argparse
+import copy
+import os.path as osp
+
+import torch
+from torch.optim.lr_scheduler import ExponentialLR
+from tqdm import tqdm
+
+from torch_frame import TensorFrame, stype
+from torch_frame.data import DataLoader, Dataset
+from torch_frame.datasets import IHDP
+from torch_frame.nn.models import CFR
+
+parser = argparse.ArgumentParser()
+parser.add_argument("--batch_size", type=int, default=200)
+parser.add_argument("--lr", type=float, default=0.001)
+parser.add_argument("--epochs", type=int, default=300)
+parser.add_argument("--seed", type=int, default=2)
+parser.add_argument("--feature-engineering", action="store_true", default=True)
+parser.add_argument("--out-of-distribution", action="store_true", default=True)
+parser.add_argument("--lambda-reg", type=float, default=0.01,
+                    help="l2 normalization score")
+args = parser.parse_args()
+
+path = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', "ihdp")
+dataset = IHDP(root=path, split_num=0)
+print(dataset.get_att())
+ATT = dataset.get_att()
+print(f"ATT is {ATT}")
+
+torch.manual_seed(args.seed)
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+dataset.materialize(path=osp.join(path, "data.pt"))
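+# IHDP bundles 1000 simulated realizations of the outcomes; split_num picks
+# one of them (see torch_frame/datasets/ihdp.py).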
+
+dataset = dataset.shuffle()
+if args.out_of_distribution:
+    if dataset.split_col is None:
+        train_dataset, val_dataset, test_dataset = dataset[:0.62], dataset[
+            0.62:0.80], dataset[0.80:]
+    else:
+        train_dataset, _, test_dataset = dataset.split()
+        train_dataset, val_dataset = train_dataset[:0.775], dataset[0.775:]
+    # Calculating the validation dataset
+    treated_df = val_dataset.df[(val_dataset.df['treated'] == 1)]
+    treated_val_dataset = Dataset(treated_df, dataset.col_to_stype,
+                                  target_col='target')
+    control_df = copy.deepcopy(treated_df)
+    control_df['treated'] = 0
+    control_val_dataset = Dataset(control_df, dataset.col_to_stype,
+                                  target_col='target')
+
+    treated_val_dataset.materialize(path=osp.join(path, "treated_val_data.pt"))
+    control_val_dataset.materialize(path=osp.join(path, "control_val_data.pt"))
+    # Calculating the evaluation dataset
+    treated_df = test_dataset.df[(test_dataset.df['treated'] == 1)]
+    treated_test_dataset = Dataset(treated_df, dataset.col_to_stype,
+                                   target_col='target')
+    control_df = copy.deepcopy(treated_df)
+    control_df['treated'] = 0
+    control_test_dataset = Dataset(control_df, dataset.col_to_stype,
+                                   target_col='target')
+
+    treated_test_dataset.materialize(
+        path=osp.join(path, "treated_test_data.pt"))
+    control_test_dataset.materialize(
+        path=osp.join(path, "control_test_data.pt"))
+else:
+    train_dataset = dataset
+
+    # Calculating the evaluation dataset
+    treated_df = dataset.df[(dataset.df['treated'] == 1)]
+    treated_test_dataset = Dataset(treated_df, dataset.col_to_stype,
+                                   target_col='target')
+    control_df = copy.deepcopy(treated_df)
+    control_df['treated'] = 0
+    control_test_dataset = Dataset(control_df, dataset.col_to_stype,
+                                   target_col='target')
+
+    treated_test_dataset.materialize(
+        path=osp.join(path, "treated_eval_data.pt"))
+    control_test_dataset.materialize(
+        path=osp.join(path, "control_eval_data.pt"))
+
+train_tensor_frame = train_dataset.tensor_frame
+treatment_idx = train_tensor_frame.col_names_dict[stype.categorical].index(
+    'treated')
+if args.out_of_distribution:
+    val_tensor_frame = val_dataset.tensor_frame
+    test_tensor_frame = test_dataset.tensor_frame
+    treated_val_tensor_frame = treated_val_dataset.tensor_frame
+    # This is a bad hack. Currently the materialization logic would override
+    # 1's to 0's due to 1's being the popular class
+    treated_val_tensor_frame.feat_dict[stype.categorical][:,
+                                                          treatment_idx] = 1.
+    control_val_tensor_frame = control_val_dataset.tensor_frame
+
+treated_test_tensor_frame = treated_test_dataset.tensor_frame
+# This is a bad hack. Currently the materialization logic would override 1's
+# to 0's due to 1's being the popular class
+treated_test_tensor_frame.feat_dict[stype.categorical][:, treatment_idx] = 1.
+control_test_tensor_frame = control_test_dataset.tensor_frame + +train_loader = DataLoader(train_tensor_frame, batch_size=args.batch_size, + shuffle=True) +# val_loader = DataLoader(val_tensor_frame, batch_size=args.batch_size) +# test_loader = DataLoader(test_tensor_frame, batch_size=args.batch_size) + +model = CFR( + channels=train_tensor_frame.num_cols - 1, + hidden_channels=200, + decoder_hidden_channels=100, + out_channels=1, + col_stats=dataset.col_stats if not args.feature_engineering else None, + col_names_dict=train_tensor_frame.col_names_dict + if not args.feature_engineering else None, +).to(device) + +optimizer = torch.optim.SGD(model.parameters(), lr=args.lr) +lr_scheduler = ExponentialLR(optimizer, gamma=0.95) + +is_classification = True + + +def train(epoch: int) -> float: + model.train() + loss_accum = total_count = 0 + + for tf in tqdm(train_loader, desc=f'Epoch: {epoch}'): + tf = tf.to(device) + out, ipm = model(tf, treatment_index=treatment_idx) + treatment_val = tf.feat_dict[stype.categorical][:, treatment_idx] + avg_treatment = torch.sum(treatment_val) / len(treatment_val) + w_val = treatment_val / (2 * avg_treatment) + (1 - treatment_val) / ( + 2 - 2 * avg_treatment) + loss = torch.mean(w_val * (tf.y - out.squeeze(-1))) + ipm + optimizer.zero_grad() + loss.backward() + for name, param in model.named_parameters(): + if name.startswith('treatment_decoder') or name.startswith( + 'control_decoder'): + loss += args.lambda_reg * torch.sum(param**2) + loss_accum += float(loss) * len(out) + total_count += len(out) + optimizer.step() + return loss_accum / total_count + + +@torch.no_grad() +def eval(treated: TensorFrame, control: TensorFrame) -> float: + model.eval() + + treated = treated.to(device) + treated_effect, _ = model(treated, treatment_idx) + + control = control.to(device) + control_effect, _ = model(control, treatment_idx) + + return torch.abs(ATT - torch.mean(treated_effect - control_effect)) + + +best_val_metric = float('inf') +best_test_metric = float('inf') + +for epoch in range(1, args.epochs + 1): + train_loss = train(epoch) + error = eval(treated_test_tensor_frame, control_test_tensor_frame) + if args.out_of_distribution: + val_error = eval(treated_val_tensor_frame, control_val_tensor_frame) + if val_error < best_val_metric: + best_val_metric = val_error + best_test_metric = error + print( + f'Train Loss: {train_loss:.4f} Val Error_ATT: {val_error:.4f},\n' + f' Error_ATT: {error:.4f}\n') + else: + print(f'Train Loss: {train_loss:.4f} Error_ATT: {error:.4f},\n') + +if args.out_of_distribution: + print(f'Best Val Error: {best_val_metric:.4f}, ' + f'Best Test Error: {best_test_metric:.4f}') diff --git a/torch_frame/datasets/__init__.py b/torch_frame/datasets/__init__.py index 28ad64da9..052cb21f3 100644 --- a/torch_frame/datasets/__init__.py +++ b/torch_frame/datasets/__init__.py @@ -19,6 +19,7 @@ from .diamond_images import DiamondImages from .huggingface_dataset import HuggingFaceDatasetDict from .jobs import Jobs +from .ihdp import IHDP real_world_datasets = [ 'Titanic', @@ -37,7 +38,8 @@ 'Mercari', 'AmazonFineFoodReviews', 'DiamondImages', - 'jobs', + 'Jobs', + 'IHDP', ] synthetic_datasets = [ diff --git a/torch_frame/datasets/ihdp.py b/torch_frame/datasets/ihdp.py new file mode 100644 index 000000000..bf1347f0b --- /dev/null +++ b/torch_frame/datasets/ihdp.py @@ -0,0 +1,93 @@ +import os.path as osp +import zipfile + +import numpy as np +import pandas as pd + +import torch_frame +from torch_frame.utils.split import SPLIT_TO_NUM + + +class 
IHDP(torch_frame.data.Dataset): + train_url = 'https://www.fredjo.com/files/ihdp_npci_1-1000.train.npz.zip' + test_url = 'https://www.fredjo.com/files/ihdp_npci_1-1000.test.npz.zip' + + def __init__(self, root: str, split_num: int = 0): + train_path = self.download_url(self.train_url, root) + test_path = self.download_url(self.test_url, root) + self.split_num = split_num + folder_path = osp.dirname(train_path) + with zipfile.ZipFile(train_path, 'r') as zip_ref: + zip_ref.extractall(folder_path) + with zipfile.ZipFile(test_path, 'r') as zip_ref: + zip_ref.extractall(folder_path) + train_np = np.load(osp.join(folder_path, 'ihdp_npci_1-1000.train.npz')) + test_np = np.load(osp.join(folder_path, 'ihdp_npci_1-1000.test.npz')) + self.train_np = train_np + self.test_np = test_np + train_data = np.concatenate([ + train_np.f.t[:, split_num].reshape(-1, 1), + train_np.f.x[:, :, split_num], train_np.f.yf[:, split_num].reshape( + -1, 1) + ], axis=1) + test_data = np.concatenate([ + test_np.f.t[:, split_num].reshape(-1, 1), + test_np.f.x[:, :, split_num], test_np.f.yf[:, split_num].reshape( + -1, 1) + ], axis=1) + train_df = pd.DataFrame( + train_data, columns=['treated'] + + [f'Col_{i}' for i in range(train_np.f.x.shape[1])] + ['target']) + train_df['split'] = SPLIT_TO_NUM['train'] + test_df = pd.DataFrame( + test_data, columns=['treated'] + + [f'Col_{i}' for i in range(train_np.f.x.shape[1])] + ['target']) + test_df['split'] = SPLIT_TO_NUM['test'] + df = pd.concat([train_df, test_df], axis=0) + col_to_stype = { + 'treated': torch_frame.categorical, + 'Col_0': torch_frame.numerical, + 'Col_1': torch_frame.numerical, + 'Col_2': torch_frame.numerical, + 'Col_3': torch_frame.numerical, + 'Col_4': torch_frame.numerical, + 'Col_5': torch_frame.numerical, + 'Col_6': torch_frame.categorical, + 'Col_7': torch_frame.categorical, + 'Col_8': torch_frame.categorical, + 'Col_9': torch_frame.categorical, + 'Col_10': torch_frame.categorical, + 'Col_11': torch_frame.categorical, + 'Col_12': torch_frame.categorical, + 'Col_13': torch_frame.categorical, + 'Col_14': torch_frame.categorical, + 'Col_15': torch_frame.categorical, + 'Col_16': torch_frame.categorical, + 'Col_17': torch_frame.categorical, + 'Col_18': torch_frame.categorical, + 'Col_19': torch_frame.categorical, + 'Col_20': torch_frame.categorical, + 'Col_21': torch_frame.categorical, + 'Col_22': torch_frame.categorical, + 'Col_23': torch_frame.categorical, + 'Col_24': torch_frame.categorical, + 'target': torch_frame.numerical, + } + super().__init__(df, col_to_stype=col_to_stype, target_col='target', + split_col='split') + + def get_att(self): + r"""Obtain the ATT(true Average Treatment effect on Treated)). + + Returns: + float: The ATT score from the original randomized experiments. + """ + mu1 = np.concatenate([ + self.train_np.f.mu1[:, self.split_num], + self.test_np.f.mu1[:, self.split_num] + ], axis=0) + mu0 = np.concatenate([ + self.train_np.f.mu0[:, self.split_num], + self.test_np.f.mu0[:, self.split_num] + ], axis=0) + return np.mean(mu1) - np.mean(mu0) diff --git a/torch_frame/datasets/jobs.py b/torch_frame/datasets/jobs.py index e2f2d7fa9..9eb525ce5 100644 --- a/torch_frame/datasets/jobs.py +++ b/torch_frame/datasets/jobs.py @@ -6,28 +6,38 @@ class Jobs(torch_frame.data.Dataset): - r"""The `Jobs - `_ - dataset from Lalonde. 
- treatment indicator (1 if treated, 0 if not treated), age, - education, Black (1 if black, 0 otherwise), Hispanic - (1 if Hispanic, 0 otherwise), married (1 if married, 0 otherwise), - nodegree (1 if no degree, 0 otherwise), RE74 (earnings in 1974), - RE75 (earnings in 1975), and RE78 (earnings in 1978). + r"""The Jobs dataset from "Evaluating the Econometric + Evaluations of Training Programs with Experimental Data" + by Robert Lalonde. There are two versions of the data. One + version is the `Dehejia subsampe. + `_. The version + is a subsample of Lalonde's original dataset because it includes + one more feature--RE74 (earnings in 1974, two years prior the + treatment). The use of more than one year of pretreatment + earnings is key in accurately estimating the treatment effect, + because many people who volunteer for training programs experience + a drop in their earnings just prior to entering the training program. - Dehejia features: + Another version is a version containing additional columns obtained + from feature engineering, from + `Dr.Johansson's website _`. + + The target in the dataset is index to the target tensor. The target + tensor is a :obj:`Tensor` of size (num_rows, 2), where the first + column represents the target and the second column represents the + treatment. """ - lalonde_treated = 'https://users.nber.org/~rdehejia/data/nsw_treated.txt' - lalonde_control = 'https://users.nber.org/~rdehejia/data/nsw_control.txt' - psid = 'https://users.nber.org/~rdehejia/data/psid_controls.txt' - train = 'https://www.fredjo.com/files/jobs_DW_bin.new.10.train.npz' - test = 'https://www.fredjo.com/files/jobs_DW_bin.new.10.test.npz' + dehejia_treated_url = 'https://users.nber.org/~rdehejia/data/nswre74_treated.txt' # noqa + dehejia_control_url = 'https://users.nber.org/~rdehejia/data/nswre74_control.txt' # noqa + psid_url = 'https://users.nber.org/~rdehejia/data/psid_controls.txt' + train_url = 'https://www.fredjo.com/files/jobs_DW_bin.new.10.train.npz' + test_url = 'https://www.fredjo.com/files/jobs_DW_bin.new.10.test.npz' - def __init__(self, root: str, dehejia: bool = False): - if not dehejia: + def __init__(self, root: str, feature_engineering: bool = False): + if feature_engineering: split = 0 - train = self.download_url(Jobs.train, root) - test = self.download_url(Jobs.test, root) + train = self.download_url(self.train_url, root) + test = self.download_url(self.test_url, root) train_np = np.load(train) test_np = np.load(test) train_data = np.concatenate([ @@ -76,13 +86,13 @@ def __init__(self, root: str, dehejia: bool = False): split_col='split') else: # National Supported Work Demonstration - nsw_treated = self.download_url(Jobs.lalonde_treated, root) - nsw_control = self.download_url(Jobs.lalonde_control, root) + nsw_treated = self.download_url(self.dehejia_treated_url, root) + nsw_control = self.download_url(self.dehejia_control_url, root) # Population Survey of Income Dynamics - psid = self.download_url(Jobs.psid, root) + psid = self.download_url(self.psid_url, root) names = [ 'treated', 'age', 'education', 'Black', 'Hispanic', 'married', - 'nodegree', 'RE75', 'RE78' + 'nodegree', 'RE74', 'RE75', 'RE78' ] nsw_treated_df = pd.read_csv( @@ -99,12 +109,9 @@ def __init__(self, root: str, dehejia: bool = False): assert (nsw_control_df['treated'] == 0).all() nsw_control_df['source'] = 1 - names.insert(7, 'RE74') - psid_df = pd.read_csv(psid, sep='\s+', names=names) # noqa assert (psid_df['treated'] == 0).all() psid_df['source'] = 0 - psid_df = psid_df.drop('RE74', axis=1) df = 
diff --git a/torch_frame/nn/decoder/__init__.py b/torch_frame/nn/decoder/__init__.py
index 2c64a258e..065233d9b 100644
--- a/torch_frame/nn/decoder/__init__.py
+++ b/torch_frame/nn/decoder/__init__.py
@@ -3,4 +3,8 @@
 from .trompt_decoder import TromptDecoder
 from .excelformer_decoder import ExcelFormerDecoder

-__all__ = classes = ['Decoder', 'TromptDecoder', 'ExcelFormerDecoder']
+__all__ = classes = [
+    'Decoder',
+    'TromptDecoder',
+    'ExcelFormerDecoder',
+]
diff --git a/torch_frame/nn/models/__init__.py b/torch_frame/nn/models/__init__.py
index 754c07402..4268e7843 100644
--- a/torch_frame/nn/models/__init__.py
+++ b/torch_frame/nn/models/__init__.py
@@ -7,8 +7,9 @@
 from .tab_transformer import TabTransformer
 from .mlp import MLP
 from .bcauss import BCAUSS
+from .cfr import CFR

 __all__ = classes = [
     'Trompt', 'FTTransformer', 'ExcelFormer', 'TabNet', 'ResNet',
-    'TabTransformer', 'MLP', 'BCAUSS'
+    'TabTransformer', 'MLP', 'BCAUSS', 'CFR'
 ]
The column + names are sorted based on the ordering that appear in + :obj:`tensor_frame.feat_dict`. Available as + :obj:`tensor_frame.col_names_dict`. + epsilon (float): Constant weighting factor that controls the relative + importance of balance score w.r.t. squared factual loss + (default: :obj:`0.5`). + """ def __init__( self, channels: int, @@ -66,7 +96,7 @@ def __init__( out_channels: int, col_stats: dict[str, dict[StatType, Any]] | None, col_names_dict: dict[torch_frame.stype, list[str]] | None, - epsilon: float = 0.5, + epsilon: float = 1.0, ): super().__init__() @@ -123,7 +153,7 @@ def forward(self, tf: TensorFrame, treated_mask = t == 1 treated = out[treated_mask, :] control = out[~treated_mask, :] - pred = torch.zeros((len(x), 1), dtype=torch.float32, device=x.device) + pred = torch.zeros((len(x), 1), dtype=x.dtype, device=x.device) pred[~treated_mask, :] = self.control_decoder(control) pred[treated_mask, :] = self.treatment_decoder(treated) penalty = self.balance_score_learner(out) diff --git a/torch_frame/nn/models/cfr.py b/torch_frame/nn/models/cfr.py new file mode 100644 index 000000000..5efad4f6e --- /dev/null +++ b/torch_frame/nn/models/cfr.py @@ -0,0 +1,115 @@ +from typing import Any, Tuple + +import torch +from torch import Tensor +from torch.nn import Linear, Module, ReLU, Sequential + +import torch_frame +from torch_frame import TensorFrame, categorical, numerical +from torch_frame.data.stats import StatType + + +class MLPBlock(Module): + def __init__( + self, + in_channels: int, + hidden_channels: int, + out_channels: int, + ) -> None: + super().__init__() + self.model = Sequential(*[ + Linear(in_channels, hidden_channels), + ReLU(), + Linear(hidden_channels, hidden_channels), + ReLU(), + Linear(hidden_channels, out_channels) + ]) + + def reset_parameters(self) -> None: + for block in self.model: + if isinstance(block, Linear): + block.reset_parameters() + + def forward(self, x: Tensor) -> Tensor: + r"""Transforming :obj:`x` into output predictions. + + Args: + x (Tensor): Input column-wise tensor of shape + [batch_size, num_cols, in_channels] + + Returns: + Tensor: [batch_size, out_channels]. 
diff --git a/torch_frame/nn/models/cfr.py b/torch_frame/nn/models/cfr.py
new file mode 100644
index 000000000..5efad4f6e
--- /dev/null
+++ b/torch_frame/nn/models/cfr.py
@@ -0,0 +1,115 @@
+from typing import Any, Tuple
+
+import torch
+from torch import Tensor
+from torch.nn import Linear, Module, ReLU, Sequential
+
+import torch_frame
+from torch_frame import TensorFrame, categorical, numerical
+from torch_frame.data.stats import StatType
+
+
+class MLPBlock(Module):
+    def __init__(
+        self,
+        in_channels: int,
+        hidden_channels: int,
+        out_channels: int,
+    ) -> None:
+        super().__init__()
+        self.model = Sequential(*[
+            Linear(in_channels, hidden_channels),
+            ReLU(),
+            Linear(hidden_channels, hidden_channels),
+            ReLU(),
+            Linear(hidden_channels, out_channels)
+        ])
+
+    def reset_parameters(self) -> None:
+        for block in self.model:
+            if isinstance(block, Linear):
+                block.reset_parameters()
+
+    def forward(self, x: Tensor) -> Tensor:
+        r"""Transforms :obj:`x` into output predictions.
+
+        Args:
+            x (Tensor): Input tensor of shape [batch_size, in_channels]
+
+        Returns:
+            Tensor: [batch_size, out_channels].
+        """
+        return self.model(x)
+
+
+class CFR(Module):
+    def __init__(
+        self,
+        channels: int,
+        hidden_channels: int,
+        decoder_hidden_channels: int,
+        out_channels: int,
+        col_stats: dict[str, dict[StatType, Any]] | None,
+        col_names_dict: dict[torch_frame.stype, list[str]] | None,
+        epsilon: float = 1.0,
+    ):
+
+        super().__init__()
+        if col_stats is not None and col_names_dict is not None:
+            numerical_stats_list = [
+                col_stats[col_name]
+                for col_name in col_names_dict[torch_frame.numerical]
+            ]
+            mean = torch.tensor(
+                [stats[StatType.MEAN] for stats in numerical_stats_list])
+            self.register_buffer("mean", mean)
+            std = (torch.tensor(
+                [stats[StatType.STD]
+                 for stats in numerical_stats_list]) + 1e-6)
+            self.register_buffer("std", std)
+        self.representation_learner = MLPBlock(channels, hidden_channels,
+                                               hidden_channels)
+        # decoder for the treatment group
+        self.treatment_decoder = MLPBlock(hidden_channels,
+                                          decoder_hidden_channels,
+                                          out_channels)
+        # decoder for the control group
+        self.control_decoder = MLPBlock(hidden_channels,
+                                        decoder_hidden_channels, out_channels)
+        self.epsilon = epsilon
+        self.reset_parameters()
+
+    def reset_parameters(self):
+        self.representation_learner.reset_parameters()
+        self.treatment_decoder.reset_parameters()
+        self.control_decoder.reset_parameters()
+
+    def forward(self, tf: TensorFrame,
+                treatment_index: int) -> Tuple[Tensor, Tensor]:
+        r"""Returns the outcome predictions and the epsilon-weighted IPM
+        penalty, where :obj:`t` denotes the treatment and :obj:`y` the
+        outcome."""
+        feat_cat = tf.feat_dict[categorical]
+        feat_num = tf.feat_dict[numerical]
+        if hasattr(self, 'mean'):
+            feat_num = (feat_num - self.mean) / self.std
+        assert isinstance(feat_cat, Tensor)
+        assert isinstance(feat_num, Tensor)
+        x = torch.cat([feat_cat, feat_num], dim=1)
+        t = x[:, treatment_index].clone()
+        # Swap the treatment column with the last column of x
+        x[:, treatment_index] = x[:, -1]
+        x[:, -1] = t
+        # Remove the treatment column
+        x = x[:, :-1]
+
+        out = self.representation_learner(x)  # [batch_size, hidden_channels]
+        treated_mask = t == 1
+        treated = out[treated_mask, :]
+        control = out[~treated_mask, :]
+        pred = torch.zeros((len(x), 1), dtype=x.dtype, device=x.device)
+        pred[~treated_mask, :] = self.control_decoder(control)
+        pred[treated_mask, :] = self.treatment_decoder(treated)
+        treated_mean = torch.mean(treated, dim=0)
+        control_mean = torch.mean(control, dim=0)
+        ipm = 2 * torch.norm(treated_mean - control_mean, p=2)
+        return pred, self.epsilon * ipm
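The `ipm` term computed at the end of `CFR.forward` is a linear-MMD-style penalty: twice the Euclidean distance between the mean treated and mean control representations. A self-contained sketch with hypothetical shapes:

    import torch

    treated = torch.randn(32, 64)  # representations of treated units
    control = torch.randn(48, 64)  # representations of control units
    ipm = 2 * torch.norm(treated.mean(dim=0) - control.mean(dim=0), p=2)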
From 607af163fd3574c87e08b4b4b85471ee1304e2a4 Mon Sep 17 00:00:00 2001
From: yiweny
Date: Tue, 30 Apr 2024 17:32:08 +0000
Subject: [PATCH 6/9] add script

---
 examples/.causalml.py.swp       | Bin 0 -> 16384 bytes
 examples/causalml.py            | 287 ++++++++++++++++++++------------
 torch_frame/datasets/ihdp.py    |  11 +-
 torch_frame/nn/models/bcauss.py |   4 +-
 4 files changed, 186 insertions(+), 116 deletions(-)
 create mode 100644 examples/.causalml.py.swp

diff --git a/examples/.causalml.py.swp b/examples/.causalml.py.swp
new file mode 100644
index 0000000000000000000000000000000000000000..3bd2be04cd3bab5e4c8e3d381f7362de1e25fe99
GIT binary patch
[base85-encoded Vim swap-file payload omitted]
diff --git a/examples/causalml.py b/examples/causalml.py
--- a/examples/causalml.py
+++ b/examples/causalml.py
@@ ... @@ def train(epoch: int) -> float:
     for tf in tqdm(train_loader, desc=f'Epoch: {epoch}'):
         tf = tf.to(device)
-        out, ipm = model(tf, treatment_index=treatment_idx)
-        treatment_val = tf.feat_dict[stype.categorical][:, treatment_idx]
-        avg_treatment = torch.sum(treatment_val) / len(treatment_val)
-        w_val = treatment_val / (2 * avg_treatment) + (1 - treatment_val) / (
-            2 - 2 * avg_treatment)
-        loss = torch.mean(w_val * (tf.y - out.squeeze(-1))) + ipm
+        if args.model == 'cfr-mdd':
+            out, ipm = model(tf, treatment_index=treatment_idx)
+            treatment_val = tf.feat_dict[stype.categorical][:, treatment_idx]
+            avg_treatment = torch.sum(treatment_val) / len(treatment_val)
+            w_val = treatment_val / (2 * avg_treatment) + (
+                1 - treatment_val) / (2 - 2 * avg_treatment)
+            rmse = torch.sqrt(torch.mean(torch.square(tf.y -
+                                                      out.squeeze(-1))))
+            loss = torch.mean(w_val * rmse) + 0.3 * ipm
+        else:
+            out, balance_score = model(tf, treatment_index=treatment_idx)
+            treated_mask = tf.feat_dict[stype.categorical][:,
+                                                           treatment_idx] == 1
+            loss = (
+                (torch.sum(treated_mask * torch.square(tf.y -
+                                                       out.squeeze(-1))) +
+                 torch.sum(
+                     ~treated_mask * torch.square(tf.y - out.squeeze(-1)))) /
+                len(treated_mask) + balance_score)
         optimizer.zero_grad()
         loss.backward()
-        for name, param in model.named_parameters():
-            if name.startswith('treatment_decoder') or name.startswith(
-                    'control_decoder'):
-                loss += args.lambda_reg * torch.sum(param**2)
+        if args.model == 'cfr-mdd':
+            for name, param in model.named_parameters():
+                if name.startswith('treatment_decoder') or name.startswith(
+                        'control_decoder'):
+                    loss += args.lambda_reg * torch.sum(param**2)
         loss_accum += float(loss) * len(out)
         total_count += len(out)
         optimizer.step()
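The `w_val` expression in the hunk above reweights units so that the treated and control groups contribute equal total weight to the factual loss. A standalone sketch of the same weighting with hypothetical indicators:

    import torch

    # w = t / (2u) + (1 - t) / (2(1 - u)), where u is the treated fraction.
    t = torch.tensor([1., 0., 0., 1., 0.])
    u = t.mean()
    w = t / (2 * u) + (1 - t) / (2 - 2 * u)
    # Both groups now carry the same total weight.
    assert torch.isclose(w[t == 1].sum(), w[t == 0].sum())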
@@ -147,35 +190,59 @@ def train(epoch: int) -> float:

 @torch.no_grad()
-def eval(treated: TensorFrame, control: TensorFrame) -> float:
+def eval(factual: TensorFrame, counterfactual: TensorFrame) -> float:
     model.eval()
-    treated = treated.to(device)
-    treated_effect, _ = model(treated, treatment_idx)
+    factual = factual.to(device)
+    factual_effect, _ = model(factual, treatment_idx)
+    rmse_fact = torch.sqrt(torch.mean(torch.square(factual.y -
+                                                   factual_effect)))

-    control = control.to(device)
-    control_effect, _ = model(control, treatment_idx)
+    counterfactual = counterfactual.to(device)
+    counterfactual_effect, _ = model(counterfactual, treatment_idx)
+    rmse_cfact = torch.sqrt(
+        torch.mean(torch.square(counterfactual.y - counterfactual_effect)))
+    eff_pred = counterfactual_effect - factual_effect
+    t = factual.feat_dict[stype.categorical][:, treatment_idx]
+    eff_pred[t > 0] = -eff_pred[t > 0]  # f(x, 1) - f(x, 0)
+    ate_pred = torch.mean(eff_pred.squeeze(-1))
+    bias_ate = torch.abs(ate_pred - ATT)

-    return torch.abs(ATT - torch.mean(treated_effect - control_effect))
+    pehe = torch.sqrt(torch.mean(torch.square(eff_pred - ATT)))
+    return rmse_fact, rmse_cfact, bias_ate, pehe


-best_val_metric = float('inf')
-best_test_metric = float('inf')
+best_val_error = float('inf')
+best_test_error = float('inf')

 for epoch in range(1, args.epochs + 1):
     train_loss = train(epoch)
-    error = eval(treated_test_tensor_frame, control_test_tensor_frame)
-    if args.out_of_distribution:
-        val_error = eval(treated_val_tensor_frame, control_val_tensor_frame)
-        if val_error < best_val_metric:
-            best_val_metric = val_error
-            best_test_metric = error
-        print(
-            f'Train Loss: {train_loss:.4f} Val Error_ATT: {val_error:.4f},\n'
-            f' Error_ATT: {error:.4f}\n')
-    else:
-        print(f'Train Loss: {train_loss:.4f} Error_ATT: {error:.4f},\n')
-
-if args.out_of_distribution:
-    print(f'Best Val Error: {best_val_metric:.4f}, '
-          f'Best Test Error: {best_test_metric:.4f}')
+    train_rmse, train_rmse_cfact, train_error, train_pehe = eval(
+        train_tensor_frame, counterfactual_train_tensor_frame)
+    val_rmse, val_rmse_cfact, val_error, val_pehe = eval(
+        val_tensor_frame, counterfactual_val_tensor_frame)
+    test_rmse, test_rmse_cfact, test_error, test_pehe = eval(
+        test_tensor_frame, counterfactual_test_tensor_frame)
+    within_rmse, within_rmse_cfact, within_error, within_pehe = eval(
+        within_sample_tensor_frame, counterfactual_within_sample_tensor_frame)
+
+    if val_error < best_val_error:
+        best_val_error = within_error
+        best_test_error = test_error
+        best_val_pehe = within_pehe
+        best_test_pehe = test_pehe
+    print(
+        f'Train Loss: {train_loss:.4f} Train Factual RMSE: {train_rmse:.4f} '
+        f'Train Counterfactual RMSE: {train_rmse_cfact:.4f}, \n'
+        f'Val Factual RMSE: {val_rmse:.4f} '
+        f'Val Counterfactual RMSE: {val_rmse_cfact:.4f}, '
+        f'Within Sample PEHE: {within_pehe:.4f}, '
+        f'Within Sample Error: {within_error:.4f}, \n'
+        f'Test Factual RMSE: {test_rmse:.4f} '
+        f'Test Counterfactual RMSE: {test_rmse_cfact:.4f}, '
+        f'Test PEHE: {test_pehe:.4f}, Test Error: {test_error:.4f}, \n')
+
+print(f'Best Within Sample Error: {best_val_error:.4f}, '
+      f'Best Within Sample PEHE: {best_val_pehe:.4f}, \n'
+      f'Best Test Error: {best_test_error:.4f}, '
+      f'Best Test PEHE: {best_test_pehe:.4f}')
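`eval` now reports four numbers: factual RMSE, counterfactual RMSE, absolute ATE bias, and PEHE. Note that this script measures PEHE against the scalar ground-truth effect rather than against per-unit effects. A standalone sketch of the last two metrics with hypothetical inputs:

    import torch

    eff_pred = torch.randn(100)  # hypothetical predicted f(x, 1) - f(x, 0)
    true_effect = 4.0            # hypothetical ground-truth average effect
    bias_ate = torch.abs(eff_pred.mean() - true_effect)
    pehe = torch.sqrt(torch.mean(torch.square(eff_pred - true_effect)))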
diff --git a/torch_frame/datasets/ihdp.py b/torch_frame/datasets/ihdp.py
index bf1347f0b..6dd8f7d0d 100644
--- a/torch_frame/datasets/ihdp.py
+++ b/torch_frame/datasets/ihdp.py
@@ -9,6 +9,7 @@


 class IHDP(torch_frame.data.Dataset):
+    r"""The counterfactual target is generated with kNN."""
     train_url = 'https://www.fredjo.com/files/ihdp_npci_1-1000.train.npz.zip'
     test_url = 'https://www.fredjo.com/files/ihdp_npci_1-1000.test.npz.zip'
@@ -28,20 +29,22 @@ def __init__(self, root: str, split_num: int = 0):
         train_data = np.concatenate([
             train_np.f.t[:, split_num].reshape(-1, 1),
             train_np.f.x[:, :, split_num], train_np.f.yf[:, split_num].reshape(
-                -1, 1)
+                -1, 1), train_np.f.ycf[:, split_num].reshape(-1, 1)
         ], axis=1)
         test_data = np.concatenate([
             test_np.f.t[:, split_num].reshape(-1, 1),
             test_np.f.x[:, :, split_num], test_np.f.yf[:, split_num].reshape(
-                -1, 1)
+                -1, 1), test_np.f.ycf[:, split_num].reshape(-1, 1)
         ], axis=1)
         train_df = pd.DataFrame(
             train_data, columns=['treated'] +
-            [f'Col_{i}' for i in range(train_np.f.x.shape[1])] + ['target'])
+            [f'Col_{i}' for i in range(train_np.f.x.shape[1])] + ['target'] +
+            ['counterfactual_target'])
         train_df['split'] = SPLIT_TO_NUM['train']
         test_df = pd.DataFrame(
             test_data, columns=['treated'] +
-            [f'Col_{i}' for i in range(train_np.f.x.shape[1])] + ['target'])
+            [f'Col_{i}' for i in range(train_np.f.x.shape[1])] + ['target'] +
+            ['counterfactual_target'])
         test_df['split'] = SPLIT_TO_NUM['test']
         df = pd.concat([train_df, test_df], axis=0)
         col_to_stype = {
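With the `ycf` column added above, every IHDP unit carries both a factual and a counterfactual outcome, so the true per-unit effect is recoverable. A sketch with hypothetical arrays mirroring the npz layout:

    import numpy as np

    n = 1000                        # hypothetical number of units
    t = np.random.randint(0, 2, n)  # treatment indicator
    yf = np.random.randn(n)         # factual outcome (observed arm)
    ycf = np.random.randn(n)        # counterfactual outcome (other arm)
    y1 = np.where(t == 1, yf, ycf)  # outcome under treatment
    y0 = np.where(t == 0, yf, ycf)  # outcome under control
    true_effect = y1 - y0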
diff --git a/torch_frame/nn/models/bcauss.py b/torch_frame/nn/models/bcauss.py
index 6bd34e73a..ff66d42eb 100644
--- a/torch_frame/nn/models/bcauss.py
+++ b/torch_frame/nn/models/bcauss.py
@@ -133,7 +133,7 @@ def reset_parameters(self):
         self.control_decoder.reset_parameters()

     def forward(self, tf: TensorFrame,
-                treatment_index: int) -> Tuple[Tensor, Tensor, Tensor]:
+                treatment_index: int) -> Tuple[Tensor, Tensor]:
         r"""T stands for treatment and y stands for output."""
         feat_cat = tf.feat_dict[categorical]
         feat_num = tf.feat_dict[numerical]
@@ -165,4 +165,4 @@ def forward(self, tf: TensorFrame,
                 torch.sum(treated_weight + 0.01) -
                 torch.sum(control_weight * x, dim=0) /
                 (torch.sum(control_weight + 0.01))))
-        return pred, self.epsilon * balance_score, treated_mask
+        return pred, self.epsilon * balance_score

From 30f35ec6dd8b1c7245d271c71e3bc116b59a773e Mon Sep 17 00:00:00 2001
From: yiweny
Date: Tue, 30 Apr 2024 17:46:19 +0000
Subject: [PATCH 7/9] clean up script

---
 examples/causalml.py         | 34 ++++++++++++++++++----------------
 torch_frame/nn/models/cfr.py |  2 +-
 2 files changed, 19 insertions(+), 17 deletions(-)

diff --git a/examples/causalml.py b/examples/causalml.py
index 2a19384c2..a388ed0fd 100644
--- a/examples/causalml.py
+++ b/examples/causalml.py
@@ -28,9 +28,8 @@

 path = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', "ihdp")
 dataset = IHDP(root=path, split_num=args.split_num)
-print(dataset.get_att())
-ATT = dataset.get_att()
-print(f"ATT is {ATT}")
+ATE = dataset.get_att()
+print(f"True Average Treatment Effect is {ATE}")

 torch.manual_seed(args.seed)
 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
@@ -39,8 +38,8 @@

 dataset = dataset.shuffle()
 within_sample_dataset, _, test_dataset = dataset.split()
-# Validation is to compute within distribution metric
-# Test is to compute out of distribution metric
+# The train and validation sets are used to compute the within-distribution
+# metric; the test set is used to compute the out-of-distribution metric.
 train_dataset, val_dataset = within_sample_dataset[:0.7], within_sample_dataset[0.7:]
@@ -150,8 +149,6 @@

 optimizer = torch.optim.SGD(model.parameters(), lr=args.lr)
 lr_scheduler = ExponentialLR(optimizer, gamma=0.8)

-is_classification = True
-

@@ ... @@ def train(epoch: int) -> float:
             w_val = treatment_val / (2 * avg_treatment) + (
                 1 - treatment_val) / (2 - 2 * avg_treatment)
             rmse = torch.sqrt(torch.mean(torch.square(tf.y - out.squeeze(-1))))
-            loss = torch.mean(w_val * rmse) + 0.3 * ipm
+            loss = torch.mean(w_val * rmse) + ipm
         else:
             out, balance_score = model(tf, treatment_index=treatment_idx)
             treated_mask = tf.feat_dict[stype.categorical][:,
@@ -190,30 +187,34 @@


 @torch.no_grad()
-def eval(factual: TensorFrame, counterfactual: TensorFrame) -> float:
+def eval(factual: TensorFrame, counterfactual: TensorFrame):
     model.eval()
     factual = factual.to(device)
     factual_effect, _ = model(factual, treatment_idx)
+    # RMSE for factual predictions
     rmse_fact = torch.sqrt(torch.mean(torch.square(factual.y -
                                                    factual_effect)))

     counterfactual = counterfactual.to(device)
     counterfactual_effect, _ = model(counterfactual, treatment_idx)
+    # RMSE for counterfactual predictions
     rmse_cfact = torch.sqrt(
         torch.mean(torch.square(counterfactual.y - counterfactual_effect)))
     eff_pred = counterfactual_effect - factual_effect
     t = factual.feat_dict[stype.categorical][:, treatment_idx]
     eff_pred[t > 0] = -eff_pred[t > 0]  # f(x, 1) - f(x, 0)
     ate_pred = torch.mean(eff_pred.squeeze(-1))
-    bias_ate = torch.abs(ate_pred - ATT)
+    bias_ate = torch.abs(ate_pred - ATE)

-    pehe = torch.sqrt(torch.mean(torch.square(eff_pred - ATT)))
+    pehe = torch.sqrt(torch.mean(torch.square(eff_pred - ATE)))
     return rmse_fact, rmse_cfact, bias_ate, pehe


 best_val_error = float('inf')
 best_test_error = float('inf')
+best_val_pehe = float('inf')
+best_test_pehe = float('inf')

 for epoch in range(1, args.epochs + 1):
     train_loss = train(epoch)
@@ -238,11 +239,12 @@ def eval(factual: TensorFrame, counterfactual: TensorFrame):
         f'Val Counterfactual RMSE: {val_rmse_cfact:.4f}, '
         f'Within Sample PEHE: {within_pehe:.4f}, '
         f'Within Sample Error: {within_error:.4f}, \n'
-        f'Test Factual RMSE: {test_rmse:.4f} '
-        f'Test Counterfactual RMSE: {test_rmse_cfact:.4f}, '
-        f'Test PEHE: {test_pehe:.4f}, Test Error: {test_error:.4f}, \n')
+        f'Out of Distribution Factual RMSE: {test_rmse:.4f} '
+        f'Out of Distribution Counterfactual RMSE: {test_rmse_cfact:.4f}, '
+        f'Out of Distribution PEHE: {test_pehe:.4f}, '
+        f'Out of Distribution Error: {test_error:.4f}, \n')

 print(f'Best Within Sample Error: {best_val_error:.4f}, '
       f'Best Within Sample PEHE: {best_val_pehe:.4f}, \n'
-      f'Best Test Error: {best_test_error:.4f}, '
-      f'Best Test PEHE: {best_test_pehe:.4f}')
+      f'Best Out of Distribution Error: {best_test_error:.4f}, '
+      f'Best Out of Distribution PEHE: {best_test_pehe:.4f}')
diff --git a/torch_frame/nn/models/cfr.py b/torch_frame/nn/models/cfr.py
index 5efad4f6e..205b12df3 100644
--- a/torch_frame/nn/models/cfr.py
+++ b/torch_frame/nn/models/cfr.py
@@ -52,7 +52,7 @@ def __init__(
         out_channels: int,
         col_stats: dict[str, dict[StatType, Any]] | None,
         col_names_dict: dict[torch_frame.stype, list[str]] | None,
-        epsilon: float = 1.0,
+        epsilon: float = 0.3,
     ):

         super().__init__()

From a2b79d7f0c80b145856e8d5a85e20346742f42b3 Mon Sep 17 00:00:00 2001
From: yiweny
Date: Tue, 30 Apr 2024 17:56:58 +0000
Subject: [PATCH 8/9] remove unnecessary files

---
 examples/.causalml.py.swp | Bin 16384 -> 0 bytes
 1 file changed, 0 insertions(+), 0 deletions(-)
 delete mode 100644 examples/.causalml.py.swp

diff --git a/examples/.causalml.py.swp b/examples/.causalml.py.swp
deleted file mode 100644
index 3bd2be04cd3bab5e4c8e3d381f7362de1e25fe99..0000000000000000000000000000000000000000
GIT binary patch
[base85-encoded Vim swap-file payload omitted]
From: yiweny
Date: Sun, 12 May 2024 08:35:23 +0000
Subject: [PATCH 9/9] on pehe

---
 examples/causalml.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/examples/causalml.py b/examples/causalml.py
index a388ed0fd..e8e80a5be 100644
--- a/examples/causalml.py
+++ b/examples/causalml.py
@@ -227,7 +227,7 @@ def eval(factual: TensorFrame, counterfactual: TensorFrame):
     within_rmse, within_rmse_cfact, within_error, within_pehe = eval(
         within_sample_tensor_frame, counterfactual_within_sample_tensor_frame)

-    if val_error < best_val_error:
+    if within_pehe < best_val_pehe:
         best_val_error = within_error
         best_test_error = test_error
         best_val_pehe = within_pehe
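The one-line change above switches model selection from the validation ATE error to within-sample PEHE. In sketch form, over a hypothetical list of per-epoch metrics, the rule is simply:

    history = [
        {'within_pehe': 2.1, 'test_error': 0.9},  # hypothetical epoch metrics
        {'within_pehe': 1.7, 'test_error': 0.8},
    ]
    best = min(history, key=lambda m: m['within_pehe'])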