Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 16 additions & 2 deletions boruta/boruta_py.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,13 @@
import numpy as np
import scipy as sp
from sklearn.utils import check_random_state, check_X_y
from sklearn.base import TransformerMixin, BaseEstimator
from sklearn.base import BaseEstimator
from sklearn.feature_selection import SelectorMixin
from sklearn.utils.validation import check_is_fitted
import warnings


class BorutaPy(BaseEstimator, TransformerMixin):
class BorutaPy(BaseEstimator, SelectorMixin):
"""
Improved Python implementation of the Boruta R package.

Expand Down Expand Up @@ -287,11 +289,19 @@ def _fit(self, X, y):
# check input params
self._check_params(X, y)

feature_names = getattr(X, "columns", None)
if feature_names is not None:
self.feature_names_in_ = np.asarray(feature_names, dtype=object)
else:
self.feature_names_in_ = None

if not isinstance(X, np.ndarray):
X = self._validate_pandas_input(X)
if not isinstance(y, np.ndarray):
y = self._validate_pandas_input(y)

self.n_features_in_ = X.shape[1]

self.random_state = check_random_state(self.random_state)

early_stopping = False
Expand Down Expand Up @@ -465,6 +475,10 @@ def _set_n_estimators(self, n_estimators):
)
return self

def _get_support_mask(self):
    """Return the fitted feature-selection mask stored in ``support_``.

    The class derives from ``SelectorMixin``; this accessor supplies the
    mask that the mixin's public selection helpers work from.

    Raises
    ------
    NotFittedError
        If ``support_`` has not been set yet, i.e. the estimator has not
        been fitted.
    """
    # Fail early with sklearn's standard "not fitted" error instead of an
    # AttributeError on the attribute access below.
    check_is_fitted(self, "support_")
    mask = self.support_
    return mask

def _get_tree_num(self, n_feat):
depth = None
try:
Expand Down
59 changes: 58 additions & 1 deletion boruta/test/test_boruta.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import pandas as pd
import pytest
from sklearn.ensemble import RandomForestClassifier
from sklearn.exceptions import NotFittedError
from sklearn.tree import DecisionTreeClassifier, ExtraTreeClassifier

from boruta import BorutaPy
Expand Down Expand Up @@ -68,6 +69,62 @@ def test_dataframe_is_returned(Xy):
assert isinstance(bt.transform(X_df, return_df=True), pd.DataFrame)


def test_selector_mixin_get_support_requires_fit():
    """Calling ``get_support`` before ``fit`` must raise NotFittedError."""
    unfitted_selector = BorutaPy(RandomForestClassifier())
    with pytest.raises(NotFittedError):
        unfitted_selector.get_support()


def test_selector_mixin_get_support_matches_mask(Xy):
    """``get_support`` must agree with ``support_`` as mask and as indices."""
    features, target = Xy
    selector = BorutaPy(RandomForestClassifier())
    selector.fit(features, target)

    # Boolean form: identical to the fitted mask.
    boolean_mask = selector.get_support()
    assert np.array_equal(boolean_mask, selector.support_)

    # Index form: positions of the True entries of the fitted mask.
    selected_positions = selector.get_support(indices=True)
    assert np.array_equal(selected_positions, np.where(selector.support_)[0])


def test_selector_mixin_inverse_transform_restores_selected_features(Xy):
    """Round-trip ``transform`` -> ``inverse_transform`` on the fitted data.

    The reconstructed array must have the original shape, carry the original
    values in the selected columns, and be zero in every rejected column.
    """
    features, target = Xy
    selector = BorutaPy(RandomForestClassifier())
    selector.fit(features, target)

    reduced = selector.transform(features)
    restored = selector.inverse_transform(reduced)

    assert restored.shape == features.shape
    assert np.allclose(
        restored[:, selector.support_],
        features[:, selector.support_],
    )

    # Only check the zero fill when at least one feature was rejected.
    if (~selector.support_).any():
        assert np.allclose(restored[:, ~selector.support_], 0)


def test_selector_mixin_get_feature_names_out_requires_fit():
    """Calling ``get_feature_names_out`` before ``fit`` must raise NotFittedError."""
    unfitted_selector = BorutaPy(RandomForestClassifier())
    with pytest.raises(NotFittedError):
        unfitted_selector.get_feature_names_out()


def test_selector_mixin_get_feature_names_out_returns_selected_names(Xy):
    """``get_feature_names_out`` must return names of the selected features only.

    Three sources of names are checked: the generated ``x<i>`` defaults,
    names passed explicitly by the caller, and column labels of a DataFrame
    used for fitting.
    """
    features, target = Xy
    selector = BorutaPy(RandomForestClassifier())
    selector.fit(features, target)

    # 1) Fitted on a plain array, no names supplied: expect "x<i>" defaults.
    expected_default = np.array(
        [f"x{i}" for i in np.where(selector.support_)[0]]
    )
    assert np.array_equal(selector.get_feature_names_out(), expected_default)

    # 2) Caller-supplied names are filtered by the support mask.
    custom_names = np.array(
        [f"feature_{i}" for i in range(features.shape[1])]
    )
    selected_names = selector.get_feature_names_out(custom_names)
    assert np.array_equal(selected_names, custom_names[selector.support_])

    # 3) Fitted on a DataFrame: the column labels are used.
    columns = [f"col_{i}" for i in range(features.shape[1])]
    features_df = pd.DataFrame(features, columns=columns)
    df_selector = BorutaPy(RandomForestClassifier())
    df_selector.fit(features_df, target)
    assert np.array_equal(
        df_selector.get_feature_names_out(),
        np.array(columns)[df_selector.support_],
    )


@pytest.mark.parametrize("tree", [ExtraTreeClassifier(), DecisionTreeClassifier()])
def test_boruta_with_decision_trees(tree, Xy):
msg = (
Expand All @@ -80,4 +137,4 @@ def test_boruta_with_decision_trees(tree, Xy):
with pytest.raises(ValueError) as record:
bt.fit(X, y)

assert str(record.value) == msg
assert str(record.value) == msg