88 from typing_extensions import Self
99
1010import numpy as np
11+ from numbers import Integral , Real
1112from sklearn .base import BaseEstimator
1213from sklearn .linear_model ._base import LinearClassifierMixin
14+ from sklearn .metrics .pairwise import pairwise_kernels
1315from sklearn .utils .extmath import safe_sparse_dot
1416from sklearn .utils .multiclass import check_classification_targets , type_of_target
15- from sklearn .utils .validation import _check_sample_weight
17+ from sklearn .utils .validation import _check_sample_weight , check_is_fitted
18+ from sklearn .utils ._param_validation import Interval , StrOptions
1619
1720from ._utils import (
1821 SKLEARN_V1_6_OR_LATER ,
@@ -38,28 +41,46 @@ class SEFR(LinearClassifierMixin, BaseEstimator):
3841 Parameters
3942 ----------
4043 fit_intercept : bool, default=True
41- Specifies if a constant (a.k.a. bias or intercept) should be
42- added to the decision function.
44+ Specifies if a constant (a.k.a. bias or intercept) should be
45+ added to the decision function.
46+
47+ kernel : {'linear', 'poly', 'rbf', 'sigmoid'} or callable, default='linear'
48+ Specifies the kernel type to be used in the algorithm.
49+ If a callable is given, it is used to pre-compute the kernel matrix.
50+
51+ gamma : float, default=None
52+ Kernel coefficient for 'rbf', 'poly' and 'sigmoid'. If None, then it is
53+ set to 1.0 / n_features.
54+
55+ degree : int, default=3
56+ Degree for 'poly' kernels. Ignored by other kernels.
57+
58+ coef0 : float, default=1
59+ Independent term in kernel function. It is only significant in 'poly' and 'sigmoid'.
4360
4461 Attributes
4562 ----------
4663 classes_ : ndarray of shape (n_classes, )
47- A list of class labels known to the classifier.
64+ A list of class labels known to the classifier.
4865
49- coef_ : ndarray of shape (1, n_features)
50- Coefficient of the features in the decision function.
66+ coef_ : ndarray of shape (1, n_features) or (1, n_samples)
67+ Coefficient of the features in the decision function. When `kernel` is not
68+ 'linear', the coefficients weight the training samples and the shape is (1, n_samples).
5169
5270 intercept_ : ndarray of shape (1,)
53- Intercept (a.k.a. bias) added to the decision function.
71+ Intercept (a.k.a. bias) added to the decision function.
5472
55- If `fit_intercept` is set to False, the intercept is set to zero.
73+ If `fit_intercept` is set to False, the intercept is set to zero.
5674
5775 n_features_in_ : int
58- Number of features seen during :term:`fit`.
76+ Number of features seen during :term:`fit`.
5977
6078 feature_names_in_ : ndarray of shape (`n_features_in_`,)
61- Names of features seen during :term:`fit`. Defined only when `X`
62- has feature names that are all strings.
79+ Names of features seen during :term:`fit`. Defined only when `X`
80+ has feature names that are all strings.
81+
82+ X_fit_ : ndarray of shape (n_samples, n_features)
83+ The training data seen during :term:`fit`. It is always stored, and is used to compute the kernel matrix when `kernel` is not 'linear'.
6384
6485 Notes
6586 -----
@@ -70,22 +91,35 @@ class SEFR(LinearClassifierMixin, BaseEstimator):
7091 >>> from linearboost import SEFR
7192 >>> from sklearn.datasets import load_breast_cancer
7293 >>> X, y = load_breast_cancer(return_X_y=True)
73- >>> clf = SEFR().fit(X, y)
94+ >>> clf = SEFR(kernel='rbf' ).fit(X, y)
7495 >>> clf.predict(X[:2, :])
7596 array([0, 0])
76- >>> clf.predict_proba(X[:2, :])
77- array([[1.00...e+000, 2.04...e-154],
78- [1.00...e+000, 1.63...e-165]])
7997 >>> clf.score(X, y)
80- 0.86 ...
98+ 0.89 ...
8199 """
82100
    # Declarative constraints for the constructor parameters, keyed by
    # parameter name; each value lists the accepted types, ranges or options.
    # Presumably consumed by the `_fit_context` validation that wraps `fit`.
    _parameter_constraints: dict = {
        "fit_intercept": ["boolean"],
        # Either one of the named kernels or a user-supplied callable.
        "kernel": [StrOptions({"linear", "poly", "rbf", "sigmoid"}), callable],
        # Non-negative real number, or None.
        "gamma": [Interval(Real, 0, None, closed="left"), None],
        # Integer >= 1, or None.
        # NOTE(review): None is accepted for `degree` and `coef0`, but the
        # class docstring documents them as int/float with defaults 3 and 1.
        # Confirm that None is meaningful for the kernels that use them.
        "degree": [Interval(Integral, 1, None, closed="left"), None],
        "coef0": [Real, None],
    }
86108
87- def __init__ (self , * , fit_intercept = True ):
    def __init__(
        self,
        *,
        fit_intercept=True,
        kernel="linear",
        gamma=None,
        degree=3,
        coef0=1,
    ):
        # All arguments are keyword-only and are stored verbatim under the
        # same attribute names; no validation or conversion happens here.
        self.fit_intercept = fit_intercept
        # Kernel configuration: `kernel` selects the kernel (name or
        # callable), while `gamma`, `degree` and `coef0` are its parameters.
        self.kernel = kernel
        self.gamma = gamma
        self.degree = degree
        self.coef0 = coef0
89123
90124 if SKLEARN_V1_6_OR_LATER :
91125
@@ -145,6 +179,23 @@ def _check_X_y(self, X, y) -> tuple[np.ndarray, np.ndarray]:
145179
146180 return X , y
147181
182+ def _get_kernel_matrix (self , X , Y = None ):
183+ if Y is None :
184+ Y = self .X_fit_
185+
186+ if callable (self .kernel ):
187+ return self .kernel (X , Y )
188+ else :
189+ return pairwise_kernels (
190+ X ,
191+ Y ,
192+ metric = self .kernel ,
193+ filter_params = True ,
194+ gamma = self .gamma ,
195+ degree = self .degree ,
196+ coef0 = self .coef0 ,
197+ )
198+
148199 @_fit_context (prefer_skip_nested_validation = True )
149200 def fit (self , X , y , sample_weight = None ) -> Self :
150201 """
@@ -153,27 +204,33 @@ def fit(self, X, y, sample_weight=None) -> Self:
153204 Parameters
154205 ----------
155206 X : {array-like, sparse matrix} of shape (n_samples, n_features)
156- Training vector, where `n_samples` is the number of samples and
157- `n_features` is the number of features.
207+ Training vector, where `n_samples` is the number of samples and
208+ `n_features` is the number of features.
158209
159210 y : array-like of shape (n_samples,)
160- Target vector relative to X.
211+ Target vector relative to X.
161212
162213 sample_weight : array-like of shape (n_samples,) default=None
163- Array of weights that are assigned to individual samples.
164- If not provided, then each sample is given unit weight.
214+ Array of weights that are assigned to individual samples.
215+ If not provided, then each sample is given unit weight.
165216
166217 Returns
167218 -------
168219 self
169- Fitted estimator.
220+ Fitted estimator.
170221 """
171222 _check_n_features (self , X = X , reset = True )
172223 _check_feature_names (self , X = X , reset = True )
173224
174225 X , y = self ._check_X_y (X , y )
226+ self .X_fit_ = X
175227 self .classes_ , y_ = np .unique (y , return_inverse = True )
176228
229+ if self .kernel == "linear" :
230+ K = X
231+ else :
232+ K = self ._get_kernel_matrix (X )
233+
177234 pos_labels = y_ == 1
178235 neg_labels = y_ == 0
179236
@@ -193,13 +250,13 @@ def fit(self, X, y, sample_weight=None) -> Self:
193250 if np .all (pos_sample_weight == 0 ) or np .all (neg_sample_weight == 0 ):
194251 raise ValueError ("SEFR requires 2 classes; got only 1 class." )
195252
196- avg_pos = np .average (X [pos_labels , :], axis = 0 , weights = pos_sample_weight )
197- avg_neg = np .average (X [neg_labels , :], axis = 0 , weights = neg_sample_weight )
253+ avg_pos = np .average (K [pos_labels , :], axis = 0 , weights = pos_sample_weight )
254+ avg_neg = np .average (K [neg_labels , :], axis = 0 , weights = neg_sample_weight )
198255 self .coef_ = (avg_pos - avg_neg ) / (avg_pos + avg_neg + 1e-7 )
199256 self .coef_ = np .reshape (self .coef_ , (1 , - 1 ))
200257
201258 if self .fit_intercept :
202- scores = safe_sparse_dot (X , self .coef_ .T , dense_output = True )
259+ scores = safe_sparse_dot (K , self .coef_ .T , dense_output = True )
203260 pos_score_avg = np .average (
204261 scores [pos_labels ][:, 0 ], weights = pos_sample_weight
205262 )
@@ -217,6 +274,17 @@ def fit(self, X, y, sample_weight=None) -> Self:
217274
218275 return self
219276
277+ def decision_function (self , X ):
278+ check_is_fitted (self )
279+ X = self ._check_X (X )
280+ if self .kernel == "linear" :
281+ K = X
282+ else :
283+ K = self ._get_kernel_matrix (X )
284+ return (
285+ safe_sparse_dot (K , self .coef_ .T , dense_output = True ) + self .intercept_
286+ ).ravel ()
287+
220288 def predict_proba (self , X ):
221289 """
222290 Probability estimates.
@@ -227,16 +295,22 @@ def predict_proba(self, X):
227295 Parameters
228296 ----------
229297 X : array-like of shape (n_samples, n_features)
230- Vector to be scored, where `n_samples` is the number of samples and
231- `n_features` is the number of features.
298+ Vector to be scored, where `n_samples` is the number of samples and
299+ `n_features` is the number of features.
232300
233301 Returns
234302 -------
235303 T : array-like of shape (n_samples, n_classes)
236- Returns the probability of the sample for each class in the model,
237- where classes are ordered as they are in ``self.classes_``.
304+ Returns the probability of the sample for each class in the model,
305+ where classes are ordered as they are in ``self.classes_``.
238306 """
239- score = self .decision_function (X ) / np .linalg .norm (self .coef_ )
307+ check_is_fitted (self )
308+ norm_coef = np .linalg .norm (self .coef_ )
309+ if norm_coef == 0 :
310+ # Handle the case of a zero-norm coefficient vector to avoid division by zero
311+ score = self .decision_function (X )
312+ else :
313+ score = self .decision_function (X ) / norm_coef
240314 proba = 1.0 / (1.0 + np .exp (- score ))
241315 return np .column_stack ((1.0 - proba , proba ))
242316
@@ -250,13 +324,13 @@ def predict_log_proba(self, X):
250324 Parameters
251325 ----------
252326 X : array-like of shape (n_samples, n_features)
253- Vector to be scored, where `n_samples` is the number of samples and
254- `n_features` is the number of features.
327+ Vector to be scored, where `n_samples` is the number of samples and
328+ `n_features` is the number of features.
255329
256330 Returns
257331 -------
258332 T : array-like of shape (n_samples, n_classes)
259- Returns the log-probability of the sample for each class in the
260- model, where classes are ordered as they are in ``self.classes_``.
333+ Returns the log-probability of the sample for each class in the
334+ model, where classes are ordered as they are in ``self.classes_``.
261335 """
262336 return np .log (self .predict_proba (X ))
0 commit comments