From a12e00923d45cd9da1c85d78fd6be16c3518692e Mon Sep 17 00:00:00 2001
From: xnuohz
Date: Mon, 11 Dec 2023 23:05:01 +0800
Subject: [PATCH 1/6] update

---
 examples/tuned_gbdt.py             |  6 ++++++
 test/gbdt/test_gbdt.py             | 14 +++++++++++++-
 torch_frame/gbdt/gbdt.py           | 17 +++++++++++++++++
 torch_frame/gbdt/tuned_catboost.py |  4 ++++
 torch_frame/gbdt/tuned_lightgbm.py |  4 ++++
 torch_frame/gbdt/tuned_xgboost.py  |  4 ++++
 6 files changed, 48 insertions(+), 1 deletion(-)

diff --git a/examples/tuned_gbdt.py b/examples/tuned_gbdt.py
index f414dc7c0..f6d9a0a77 100644
--- a/examples/tuned_gbdt.py
+++ b/examples/tuned_gbdt.py
@@ -27,6 +27,7 @@
 import random
 
 import numpy as np
+import pandas as pd
 import torch
 
 from torch_frame.datasets import TabularBenchmark
@@ -88,6 +89,11 @@
     gbdt.tune(tf_train=train_dataset.tensor_frame,
               tf_val=val_dataset.tensor_frame, num_trials=20)
     gbdt.save(args.saved_model_path)
+    scores = pd.DataFrame({
+        'feature': dataset.feat_cols,
+        'importance': gbdt.feature_importance()
+    }).sort_values(by='importance', ascending=False)
+    print(scores)
 
 pred = gbdt.predict(tf_test=test_dataset.tensor_frame)
 score = gbdt.compute_metric(test_dataset.tensor_frame.y, pred)
diff --git a/test/gbdt/test_gbdt.py b/test/gbdt/test_gbdt.py
index 42787976c..0cb28f677 100644
--- a/test/gbdt/test_gbdt.py
+++ b/test/gbdt/test_gbdt.py
@@ -21,7 +21,7 @@
     [stype.numerical],
     [stype.categorical],
     [stype.text_embedded],
-    [stype.numerical, stype.numerical, stype.text_embedded],
+    [stype.numerical, stype.categorical, stype.text_embedded],
 ])
 @pytest.mark.parametrize('task_type_and_metric', [
     (TaskType.REGRESSION, Metric.RMSE),
@@ -76,7 +76,19 @@ def test_gbdt_with_save_load(gbdt_cls, stypes, task_type_and_metric):
     loaded_score = loaded_gbdt.compute_metric(dataset.tensor_frame.y, pred)
     dataset.tensor_frame.y = None
     loaded_pred = loaded_gbdt.predict(tf_test=dataset.tensor_frame)
+    # TODO: support more stypes
+    num_features = 0
+    for x in stypes:
+        if x == stype.numerical:
+            num_features += 3 * 1
+        elif x == stype.categorical:
+            num_features += 2 * 1
+        elif x == stype.text_embedded:
+            num_features += 2 * 8
+    assert (gbdt_cls == XGBoost
+            and len(gbdt.feature_importance()) <= num_features) or (len(
+                gbdt.feature_importance()) == num_features)
 
     assert torch.allclose(pred, loaded_pred, atol=1e-5)
     assert gbdt.metric == metric
     assert score == loaded_score
diff --git a/torch_frame/gbdt/gbdt.py b/torch_frame/gbdt/gbdt.py
index b2aafc5c5..8802eec2e 100644
--- a/torch_frame/gbdt/gbdt.py
+++ b/torch_frame/gbdt/gbdt.py
@@ -63,6 +63,10 @@ def _predict(self, tf_train: TensorFrame) -> Tensor:
     def _load(self, path: str) -> None:
         raise NotImplementedError
 
+    @abstractmethod
+    def _feature_importance(self) -> list:
+        raise NotImplementedError
+
     @property
     def is_fitted(self) -> bool:
         r"""Whether the GBDT is already fitted."""
@@ -135,6 +139,19 @@ def load(self, path: str) -> None:
         self._load(path)
         self._is_fitted = True
 
+    def feature_importance(self) -> list:
+        r"""Get GBDT's feature importance.
+
+        Returns:
+            scores (list): Feature importance.
+        """
+        if not self.is_fitted:
+            raise RuntimeError(
+                f"{self.__class__.__name__} is not yet fitted. Please run "
+                f"`tune()` first before attempting to get feature importance.")
+        scores = self._feature_importance()
+        return scores
+
     @torch.no_grad()
     def compute_metric(
         self,
diff --git a/torch_frame/gbdt/tuned_catboost.py b/torch_frame/gbdt/tuned_catboost.py
index 125d410b5..7d1c3ed8e 100644
--- a/torch_frame/gbdt/tuned_catboost.py
+++ b/torch_frame/gbdt/tuned_catboost.py
@@ -225,3 +225,7 @@ def _load(self, path: str) -> None:
 
         self.model = catboost.CatBoost()
         self.model.load_model(path)
+
+    def _feature_importance(self) -> list:
+        scores = self.model.feature_importances_
+        return scores
diff --git a/torch_frame/gbdt/tuned_lightgbm.py b/torch_frame/gbdt/tuned_lightgbm.py
index 94ade576f..9c8f3fff8 100644
--- a/torch_frame/gbdt/tuned_lightgbm.py
+++ b/torch_frame/gbdt/tuned_lightgbm.py
@@ -226,3 +226,7 @@ def _load(self, path: str) -> None:
         import lightgbm
 
         self.model = lightgbm.Booster(model_file=path)
+
+    def _feature_importance(self) -> list:
+        scores = self.model.feature_importance(importance_type='gain')
+        return scores.tolist()
diff --git a/torch_frame/gbdt/tuned_xgboost.py b/torch_frame/gbdt/tuned_xgboost.py
index b11ef29ef..7a15078e0 100644
--- a/torch_frame/gbdt/tuned_xgboost.py
+++ b/torch_frame/gbdt/tuned_xgboost.py
@@ -232,3 +232,7 @@ def _load(self, path: str) -> None:
         import xgboost
 
         self.model = xgboost.Booster(model_file=path)
+
+    def _feature_importance(self) -> list:
+        scores = self.model.get_score(importance_type='weight')
+        return list(scores.values())

From 05028aa5e441d9a489f85c04f5a141ba742fe029 Mon Sep 17 00:00:00 2001
From: xnuohz
Date: Fri, 15 Dec 2023 21:47:00 +0800
Subject: [PATCH 2/6] update

---
 torch_frame/gbdt/gbdt.py           |  6 +++---
 torch_frame/gbdt/tuned_lightgbm.py | 23 ++++++++++++++++++++---
 torch_frame/gbdt/tuned_xgboost.py  | 29 +++++++++++++++++++++++++++--
 3 files changed, 50 insertions(+), 8 deletions(-)

diff --git a/torch_frame/gbdt/gbdt.py b/torch_frame/gbdt/gbdt.py
index 8802eec2e..307665fd0 100644
--- a/torch_frame/gbdt/gbdt.py
+++ b/torch_frame/gbdt/gbdt.py
@@ -64,7 +64,7 @@ def _load(self, path: str) -> None:
         raise NotImplementedError
 
     @abstractmethod
-    def _feature_importance(self) -> list:
+    def _feature_importance(self, *args, **kwargs) -> list:
         raise NotImplementedError
 
     @property
@@ -139,7 +139,7 @@ def load(self, path: str) -> None:
         self._load(path)
         self._is_fitted = True
 
-    def feature_importance(self) -> list:
+    def feature_importance(self, *args, **kwargs) -> list:
         r"""Get GBDT's feature importance.
 
         Returns:
             scores (list): Feature importance.
@@ -149,7 +149,7 @@ def feature_importance(self) -> list:
             raise RuntimeError(
                 f"{self.__class__.__name__} is not yet fitted. Please run "
                 f"`tune()` first before attempting to get feature importance.")
-        scores = self._feature_importance()
+        scores = self._feature_importance(*args, **kwargs)
         return scores
 
     @torch.no_grad()
diff --git a/torch_frame/gbdt/tuned_lightgbm.py b/torch_frame/gbdt/tuned_lightgbm.py
index 9c8f3fff8..ef26106a2 100644
--- a/torch_frame/gbdt/tuned_lightgbm.py
+++ b/torch_frame/gbdt/tuned_lightgbm.py
@@ -1,6 +1,6 @@
 from __future__ import annotations
 
-from typing import Any
+from typing import Any, Optional
 
 import numpy as np
 import pandas as pd
@@ -227,6 +227,23 @@ def _load(self, path: str) -> None:
 
         self.model = lightgbm.Booster(model_file=path)
 
-    def _feature_importance(self) -> list:
-        scores = self.model.feature_importance(importance_type='gain')
+    def _feature_importance(
+        self,
+        importance_type: str = 'gain',
+        iteration: Optional[int] = None
+    ) -> list:
+        r"""Get feature importances.
+
+        Args:
+            importance_type (str): How the importance is calculated.
+                If "split", result contains numbers of times the feature is used in a model.
+                If "gain", result contains total gains of splits which use the feature.
+            iteration (int, optional): Limit number of iterations in the feature importance calculation.
+                If None, if the best iteration exists, it is used; otherwise, all trees are used.
+                If <= 0, all trees are used (no limits).
+
+        Returns:
+            list: Array with feature importances.
+        """
+        scores = self.model.feature_importance(importance_type=importance_type, iteration=iteration)
         return scores.tolist()
diff --git a/torch_frame/gbdt/tuned_xgboost.py b/torch_frame/gbdt/tuned_xgboost.py
index 7a15078e0..20fa4a7db 100644
--- a/torch_frame/gbdt/tuned_xgboost.py
+++ b/torch_frame/gbdt/tuned_xgboost.py
@@ -233,6 +233,31 @@ def _load(self, path: str) -> None:
 
         self.model = xgboost.Booster(model_file=path)
 
-    def _feature_importance(self) -> list:
-        scores = self.model.get_score(importance_type='weight')
+    def _feature_importance(self, importance_type: str = 'weight') -> list:
+        r"""Get feature importances.
+
+        Args:
+            importance_type (str): How the importance is calculated.
+                For tree model Importance type can be defined as:
+
+                * 'weight': the number of times a feature is used to split the data across all trees.
+                * 'gain': the average gain across all splits the feature is used in.
+                * 'cover': the average coverage across all splits the feature is used in.
+                * 'total_gain': the total gain across all splits the feature is used in.
+                * 'total_cover': the total coverage across all splits the feature is used in.
+
+            .. note::
+
+                For linear model, only "weight" is defined and it's the normalized coefficients
+                without bias.
+
+            .. note:: Zero-importance features will not be included
+
+                Keep in mind that this function does not include zero-importance feature, i.e.
+                those features that have not been used in any split conditions.
+
+        Returns:
+            list: Array with feature importances.
+        """
+        scores = self.model.get_score(importance_type=importance_type)
         return list(scores.values())

From 1c6ef5a9556a9fdc26792f984a7814ad06521fcd Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Fri, 15 Dec 2023 13:48:58 +0000
Subject: [PATCH 3/6] [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
---
 torch_frame/gbdt/tuned_lightgbm.py | 10 ++++------
 1 file changed, 4 insertions(+), 6 deletions(-)

diff --git a/torch_frame/gbdt/tuned_lightgbm.py b/torch_frame/gbdt/tuned_lightgbm.py
index ef26106a2..717efddf3 100644
--- a/torch_frame/gbdt/tuned_lightgbm.py
+++ b/torch_frame/gbdt/tuned_lightgbm.py
@@ -227,11 +227,8 @@ def _load(self, path: str) -> None:
 
         self.model = lightgbm.Booster(model_file=path)
 
-    def _feature_importance(
-        self,
-        importance_type: str = 'gain',
-        iteration: Optional[int] = None
-    ) -> list:
+    def _feature_importance(self, importance_type: str = 'gain',
+                            iteration: Optional[int] = None) -> list:
         r"""Get feature importances.
 
         Args:
@@ -245,5 +242,6 @@ def _feature_importance(
         Returns:
             list: Array with feature importances.
         """
-        scores = self.model.feature_importance(importance_type=importance_type, iteration=iteration)
+        scores = self.model.feature_importance(importance_type=importance_type,
+                                               iteration=iteration)
         return scores.tolist()

From 4f3ce7cba93a198f7658f29a91201b0cb5fd2a0b Mon Sep 17 00:00:00 2001
From: xnuohz
Date: Fri, 15 Dec 2023 21:54:57 +0800
Subject: [PATCH 4/6] update

---
 CHANGELOG.md | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 16ffe5de8..1016bd7ec 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -8,6 +8,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 ### Added
 
+- Added GBDTs feature importance ([#292](https://github.com/pyg-team/pytorch-frame/pull/292))
+
 ### Changed
 
 ### Deprecated

From 6d800330b0797ee24d0e53ae0ae6cb8e0135f4e8 Mon Sep 17 00:00:00 2001
From: xnuohz
Date: Fri, 15 Dec 2023 23:22:56 +0800
Subject: [PATCH 5/6] update

---
 torch_frame/gbdt/tuned_lightgbm.py | 15 ++++++++++-----
 torch_frame/gbdt/tuned_xgboost.py  | 27 ++++++++++++++++++---------
 2 files changed, 28 insertions(+), 14 deletions(-)

diff --git a/torch_frame/gbdt/tuned_lightgbm.py b/torch_frame/gbdt/tuned_lightgbm.py
index f58bb2bf0..f6333831c 100644
--- a/torch_frame/gbdt/tuned_lightgbm.py
+++ b/torch_frame/gbdt/tuned_lightgbm.py
@@ -233,15 +233,20 @@ def _feature_importance(self, importance_type: str = 'gain',
 
         Args:
             importance_type (str): How the importance is calculated.
-                If "split", result contains numbers of times the feature is used in a model.
-                If "gain", result contains total gains of splits which use the feature.
-            iteration (int, optional): Limit number of iterations in the feature importance calculation.
-                If None, if the best iteration exists, it is used; otherwise, all trees are used.
-                If <= 0, all trees are used (no limits).
+                If "split", result contains numbers of times the feature
+                is used in a model. If "gain", result contains total gains
+                of splits which use the feature.
+            iteration (int, optional): Limit number of iterations in the feature
+                importance calculation. If None, if the best iteration exists,
+                it is used; otherwise, all trees are used. If <= 0, all trees
+                are used (no limits).
 
         Returns:
             list: Array with feature importances.
         """
+        assert importance_type in [
+            'split', 'gain'
+        ], f'Expect split or gain, got {importance_type}.'
         scores = self.model.feature_importance(importance_type=importance_type,
                                                iteration=iteration)
         return scores.tolist()
diff --git a/torch_frame/gbdt/tuned_xgboost.py b/torch_frame/gbdt/tuned_xgboost.py
index 5e442620c..19e28199e 100644
--- a/torch_frame/gbdt/tuned_xgboost.py
+++ b/torch_frame/gbdt/tuned_xgboost.py
@@ -240,24 +240,33 @@ def _feature_importance(self, importance_type: str = 'weight') -> list:
             importance_type (str): How the importance is calculated.
                 For tree model Importance type can be defined as:
 
-                * 'weight': the number of times a feature is used to split the data across all trees.
-                * 'gain': the average gain across all splits the feature is used in.
-                * 'cover': the average coverage across all splits the feature is used in.
-                * 'total_gain': the total gain across all splits the feature is used in.
-                * 'total_cover': the total coverage across all splits the feature is used in.
+                * 'weight': the number of times a feature is used to split
+                  the data across all trees.
+                * 'gain': the average gain across all splits the feature
+                  is used in.
+                * 'cover': the average coverage across all splits the
+                  feature is used in.
+                * 'total_gain': the total gain across all splits the
+                  feature is used in.
+                * 'total_cover': the total coverage across all splits the
+                  feature is used in.
 
             .. note::
 
-                For linear model, only "weight" is defined and it's the normalized coefficients
-                without bias.
+                For linear model, only "weight" is defined and it's the
+                normalized coefficients without bias.
 
             .. note:: Zero-importance features will not be included
 
-                Keep in mind that this function does not include zero-importance feature, i.e.
-                those features that have not been used in any split conditions.
+                Keep in mind that this function does not include
+                zero-importance feature, i.e. those features that have not
+                been used in any split conditions.
 
         Returns:
             list: Array with feature importances.
         """
+        assert importance_type in [
+            'weight', 'gain', 'cover', 'total_gain', 'total_cover'
+        ], f'{importance_type} is not supported.'
         scores = self.model.get_score(importance_type=importance_type)
         return list(scores.values())

From 871c91f6ff658739aab5ca054e77f7e778545c4e Mon Sep 17 00:00:00 2001
From: xnuohz
Date: Wed, 20 Dec 2023 02:12:52 +0800
Subject: [PATCH 6/6] update

---
 examples/tuned_gbdt.py             | 12 +++++++-----
 test/gbdt/test_gbdt.py             | 17 +++++++++--------
 torch_frame/gbdt/tuned_lightgbm.py |  4 ++--
 3 files changed, 18 insertions(+), 15 deletions(-)

diff --git a/examples/tuned_gbdt.py b/examples/tuned_gbdt.py
index f6d9a0a77..9f2b2d5f0 100644
--- a/examples/tuned_gbdt.py
+++ b/examples/tuned_gbdt.py
@@ -40,6 +40,7 @@
 parser.add_argument('--dataset', type=str, default='eye_movements')
 parser.add_argument('--saved_model_path', type=str,
                     default='storage/gbdts.txt')
+parser.add_argument('--feature_importance', action='store_true')
 # Add this flag to match the reported number.
 parser.add_argument('--seed', type=int, default=0)
 args = parser.parse_args()
@@ -89,11 +90,12 @@
     gbdt.tune(tf_train=train_dataset.tensor_frame,
               tf_val=val_dataset.tensor_frame, num_trials=20)
     gbdt.save(args.saved_model_path)
-    scores = pd.DataFrame({
-        'feature': dataset.feat_cols,
-        'importance': gbdt.feature_importance()
-    }).sort_values(by='importance', ascending=False)
-    print(scores)
+    if args.feature_importance:
+        scores = pd.DataFrame({
+            'feature': dataset.feat_cols,
+            'importance': gbdt.feature_importance()
+        }).sort_values(by='importance', ascending=False)
+        print(scores)
 
 pred = gbdt.predict(tf_test=test_dataset.tensor_frame)
 score = gbdt.compute_metric(test_dataset.tensor_frame.y, pred)
diff --git a/test/gbdt/test_gbdt.py b/test/gbdt/test_gbdt.py
index 0cb28f677..2da03ff31 100644
--- a/test/gbdt/test_gbdt.py
+++ b/test/gbdt/test_gbdt.py
@@ -77,14 +77,15 @@ def test_gbdt_with_save_load(gbdt_cls, stypes, task_type_and_metric):
     dataset.tensor_frame.y = None
     loaded_pred = loaded_gbdt.predict(tf_test=dataset.tensor_frame)
     # TODO: support more stypes
-    num_features = 0
-    for x in stypes:
-        if x == stype.numerical:
-            num_features += 3 * 1
-        elif x == stype.categorical:
-            num_features += 2 * 1
-        elif x == stype.text_embedded:
-            num_features += 2 * 8
+    feat_dim = {
+        stype.numerical: 1,
+        stype.categorical: 1,
+        stype.embedding: 8,
+    }
+    num_features = sum([
+        feat_dim[feat_stype] * len(feat_list) for feat_stype, feat_list in
+        dataset.tensor_frame.col_names_dict.items()
+    ])
 
     assert (gbdt_cls == XGBoost
             and len(gbdt.feature_importance()) <= num_features) or (len(
diff --git a/torch_frame/gbdt/tuned_lightgbm.py b/torch_frame/gbdt/tuned_lightgbm.py
index f6333831c..79584497a 100644
--- a/torch_frame/gbdt/tuned_lightgbm.py
+++ b/torch_frame/gbdt/tuned_lightgbm.py
@@ -236,8 +236,8 @@ def _feature_importance(self, importance_type: str = 'gain',
                 If "split", result contains numbers of times the feature
                 is used in a model. If "gain", result contains total gains
                 of splits which use the feature.
-            iteration (int, optional): Limit number of iterations in the feature
-                importance calculation. If None, if the best iteration exists,
+            iteration (int, optional): Limit number of `iterations` in the feature
+                importance calculation. If None, if the best `iteration` exists,
                 it is used; otherwise, all trees are used. If <= 0, all trees
                 are used (no limits).
 
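
Usage note (not part of the patch series): the sketch below shows how the
feature_importance() API introduced above fits together once all six patches
are applied. The tune(), feature_importance(), and dataset.feat_cols calls
mirror examples/tuned_gbdt.py from the diffs; the dataset root, the
float-slice split, and the LightGBM constructor arguments are illustrative
assumptions rather than facts taken from the patches.

    import pandas as pd

    from torch_frame.datasets import TabularBenchmark
    from torch_frame.gbdt import LightGBM

    # Download and materialize a benchmark dataset (the root path is a
    # placeholder).
    dataset = TabularBenchmark(root='/tmp/eye_movements', name='eye_movements')
    dataset.materialize()
    # Float slicing into train/val subsets is assumed to behave as in the
    # tuned_gbdt.py example.
    train_dataset, val_dataset = dataset[:0.8], dataset[0.8:0.9]

    # Constructor arguments are assumed; the diffs only show tune() and
    # feature_importance().
    gbdt = LightGBM(task_type=dataset.task_type)
    gbdt.tune(tf_train=train_dataset.tensor_frame,
              tf_val=val_dataset.tensor_frame, num_trials=20)

    # Patch 2 forwards *args/**kwargs to the backend implementation, and
    # patch 5 restricts LightGBM to 'split' or 'gain', so the default
    # importance type can be overridden per call.
    scores = pd.DataFrame({
        'feature': dataset.feat_cols,
        'importance': gbdt.feature_importance(importance_type='split'),
    }).sort_values(by='importance', ascending=False)
    print(scores)

LightGBM is used here because its importance list always has one entry per
feature. Per the test in patch 1, XGBoost omits zero-importance features, so
its result can be shorter than len(dataset.feat_cols) and would not line up
with the feature columns one-to-one in the DataFrame above.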