Skip to content

Commit 4891f1d

Browse files
committed
refactor: distribution methods output types
1 parent f1fb486 commit 4891f1d

File tree

17 files changed

+1032
-207
lines changed

17 files changed

+1032
-207
lines changed

rework_pysatl_mpest/core/mixture.py

Lines changed: 13 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -255,7 +255,7 @@ def remove_component(self, component_idx: int):
255255
self._cached_weights = None
256256
self._sorted_pairs_cache = None
257257

258-
def pdf(self, X: ArrayLike) -> NDArray[DType]:
258+
def pdf(self, X: ArrayLike) -> DType | NDArray[DType]:
259259
"""Probability Density Function of the mixture.
260260
261261
The PDF is computed as the weighted sum of the PDFs of its
@@ -268,15 +268,16 @@ def pdf(self, X: ArrayLike) -> NDArray[DType]:
268268
269269
Returns
270270
-------
271-
NDArray[DType]
271+
DType | NDArray[DType]
272272
The PDF values corresponding to each point in :attr:`X`.
273+
Returns a scalar when given a scalar and an array when given an array.
273274
"""
274275

275276
X = np.asarray(X, dtype=self.dtype)
276277
component_pdfs = np.array([comp.pdf(X) for comp in self.components])
277-
return np.asarray(np.dot(self.weights, component_pdfs))
278+
return np.dot(self.weights, component_pdfs)
278279

279-
def lpdf(self, X: ArrayLike) -> NDArray[DType]:
280+
def lpdf(self, X: ArrayLike) -> DType | NDArray[DType]:
280281
"""Logarithms of the Probability Density Function.
281282
282283
Parameters
@@ -286,15 +287,18 @@ def lpdf(self, X: ArrayLike) -> NDArray[DType]:
286287
287288
Returns
288289
-------
289-
NDArray[DType]
290+
DType | NDArray[DType]
290291
The log-PDF values corresponding to each point in :attr:`X`.
292+
Returns a scalar when given a scalar and an array when given an array.
291293
"""
292294

293-
X = np.atleast_1d(X).astype(self.dtype)
295+
X = np.asarray(X, dtype=self.dtype)
294296
component_lpdfs = np.array([comp.lpdf(X) for comp in self.components])
295-
log_weights = self.log_weights
296-
log_terms = log_weights[:, np.newaxis] + component_lpdfs
297-
return logsumexp(log_terms, axis=0) # type: ignore
297+
broadcast_shape = (self.n_components,) + (1,) * X.ndim
298+
log_weights = self.log_weights.reshape(broadcast_shape)
299+
log_terms = log_weights + component_lpdfs
300+
301+
return logsumexp(log_terms, axis=0)
298302

299303
def loglikelihood(self, X: ArrayLike) -> DType:
300304
"""Log-likelihood of the complete data :attr:`X`.

rework_pysatl_mpest/distributions/beta.py

Lines changed: 53 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -109,10 +109,12 @@ def pdf(self, X):
109109
110110
Returns
111111
-------
112-
NDArray[DType]
112+
DType | NDArray[DType]
113113
The PDF values corresponding to each point in :attr:`X`.
114+
Returns a scalar when given a scalar and an array when given an array.
114115
115116
"""
117+
116118
X = np.asarray(X, dtype=self.dtype)
117119
return np.exp(self.lpdf(X))
118120

@@ -137,13 +139,16 @@ def ppf(self, P):
137139
138140
Returns
139141
-------
140-
NDArray[DType]
142+
DType | NDArray[DType]
141143
The PPF values corresponding to each probability in :attr:`P`.
144+
Returns a scalar when given a scalar and an array when given an array.
142145
"""
146+
147+
is_scalar = np.isscalar(P)
143148
P = np.asarray(P, dtype=self.dtype)
144149
dtype = self.dtype
145150

146-
return np.where(
151+
result = np.where(
147152
(P >= 0) & (P <= 1),
148153
(
149154
self.left_border
@@ -152,6 +157,10 @@ def ppf(self, P):
152157
dtype(np.nan),
153158
)
154159

160+
if is_scalar:
161+
return result[()]
162+
return result
163+
155164
def lpdf(self, X):
156165
"""Log of the Probability Density Function (LPDF).
157166
@@ -177,8 +186,9 @@ def lpdf(self, X):
177186
178187
Returns
179188
-------
180-
NDArray[DType]
189+
DType | NDArray[DType]
181190
The log-PDF values corresponding to each point in :attr:`X`.
191+
Returns a scalar when given a scalar and an array when given an array.
182192
"""
183193

184194
X = np.asarray(X, dtype=self.dtype)
@@ -189,7 +199,7 @@ def lpdf(self, X):
189199
log_pdf_standard = beta_dist.logpdf(Z, self.alpha, self.beta).astype(dtype)
190200
result = log_pdf_standard - np.log(self.right_border - self.left_border)
191201

192-
return np.atleast_1d(result)
202+
return result
193203

194204
def _dlog_alpha(self, X):
195205
"""Partial derivative of the lpdf w.r.t. the :attr:`alpha` parameter.
@@ -214,22 +224,28 @@ def _dlog_alpha(self, X):
214224
215225
Returns
216226
-------
217-
NDArray[DType]
227+
DType | NDArray[DType]
218228
The gradient of the lpdf with respect to :attr:`alpha` for each point in :attr:`X`.
229+
Returns a scalar when given a scalar and an array when given an array.
219230
"""
220231

232+
is_scalar = np.isscalar(X)
221233
X = np.asarray(X, dtype=self.dtype)
222234
dtype = self.dtype
223235

224236
in_bounds = (self.left_border < X) & (self.right_border >= X)
225-
return np.where(
237+
result = np.where(
226238
in_bounds,
227239
np.log(X - self.left_border)
228240
- np.log(self.right_border - self.left_border)
229241
- (dtype(digamma(self.alpha)) - dtype(digamma(self.alpha + self.beta))),
230242
dtype(0.0),
231243
)
232244

245+
if is_scalar:
246+
return result[()]
247+
return result
248+
233249
def _dlog_beta(self, X):
234250
"""Partial derivative of the lpdf w.r.t. the :attr:`beta` parameter.
235251
@@ -253,22 +269,28 @@ def _dlog_beta(self, X):
253269
254270
Returns
255271
-------
256-
NDArray[DType]
272+
DType | NDArray[DType]
257273
The gradient of the lpdf with respect to :attr:`beta` for each point in :attr:`X`.
274+
Returns a scalar when given a scalar and an array when given an array.
258275
"""
259276

277+
is_scalar = np.isscalar(X)
260278
X = np.asarray(X, dtype=self.dtype)
261279
dtype = self.dtype
262280

263281
in_bounds = (self.left_border < X) & (self.right_border >= X)
264-
return np.where(
282+
result = np.where(
265283
in_bounds,
266284
np.log(self.right_border - X)
267285
- np.log(self.right_border - self.left_border)
268286
- (dtype(digamma(self.beta)) - dtype(digamma(self.alpha + self.beta))),
269287
dtype(0.0),
270288
)
271289

290+
if is_scalar:
291+
return result[()]
292+
return result
293+
272294
def _dlog_left_border(self, X):
273295
"""Partial derivative of the lpdf w.r.t. the :attr:`left_border` parameter.
274296
@@ -290,15 +312,17 @@ def _dlog_left_border(self, X):
290312
291313
Returns
292314
-------
293-
NDArray[DType]
315+
DType | NDArray[DType]
294316
The gradient of the lpdf with respect to :attr:`left_border` for each point in :attr:`X`.
317+
Returns a scalar when given a scalar and an array when given an array.
295318
"""
296319

320+
is_scalar = np.isscalar(X)
297321
X = np.asarray(X, dtype=self.dtype)
298322
dtype = self.dtype
299323

300324
in_bounds = (self.left_border < X) & (self.right_border >= X)
301-
return np.where(
325+
result = np.where(
302326
in_bounds,
303327
(
304328
((self.alpha + self.beta - dtype(1)) / (self.right_border - self.left_border))
@@ -307,6 +331,10 @@ def _dlog_left_border(self, X):
307331
dtype(0.0),
308332
)
309333

334+
if is_scalar:
335+
return result[()]
336+
return result
337+
310338
def _dlog_right_border(self, X):
311339
"""Partial derivative of the lpdf w.r.t. the :attr:`right_border` parameter.
312340
@@ -328,14 +356,17 @@ def _dlog_right_border(self, X):
328356
329357
Returns
330358
-------
331-
NDArray[DType]
359+
DType | NDArray[DType]
332360
The gradient of the lpdf with respect to :attr:`right_border` for each point in :attr:`X`.
361+
Returns a scalar when given a scalar and an array when given an array.
333362
"""
363+
364+
is_scalar = np.isscalar(X)
334365
X = np.asarray(X, dtype=self.dtype)
335366
dtype = self.dtype
336367

337368
in_bounds = (self.left_border < X) & (self.right_border >= X)
338-
return np.where(
369+
result = np.where(
339370
in_bounds,
340371
(
341372
((self.beta - dtype(1)) / (self.right_border - X))
@@ -344,6 +375,10 @@ def _dlog_right_border(self, X):
344375
dtype(0.0),
345376
)
346377

378+
if is_scalar:
379+
return result[()]
380+
return result
381+
347382
def log_gradients(self, X):
348383
"""Calculates the gradients of the log-PDF w.r.t. its parameters.
349384
@@ -361,7 +396,10 @@ def log_gradients(self, X):
361396
and each column corresponds to the gradient with respect to a
362397
specific optimizable parameter. The order of columns corresponds
363398
to the sorted order of :attr:`self.params_to_optimize`.
399+
Returns a 1D array if X is a scalar.
364400
"""
401+
402+
is_scalar = np.isscalar(X)
365403
X = np.asarray(X, dtype=self.dtype)
366404

367405
gradient_calculators = {
@@ -378,6 +416,8 @@ def log_gradients(self, X):
378416

379417
gradients = [gradient_calculators[param](X) for param in optimizable_params]
380418

419+
if is_scalar:
420+
return np.array(gradients)
381421
return np.stack(gradients, axis=1)
382422

383423
def generate(self, size: int):

rework_pysatl_mpest/distributions/cauchy.py

Lines changed: 35 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -82,14 +82,13 @@ def pdf(self, X):
8282
8383
Returns
8484
-------
85-
NDArray[DType]
85+
DType | NDArray[DType]
8686
The PDF values corresponding to each point in :attr:`X`.
87+
Returns a scalar when given a scalar and an array when given an array.
8788
"""
8889

8990
X = np.asarray(X, dtype=self.dtype)
90-
dtype = self.dtype
91-
92-
return dtype(1.0) / (dtype(np.pi) * self.scale * (dtype(1.0) + ((X - self.loc) / self.scale) ** 2))
91+
return np.exp(self.lpdf(X))
9392

9493
def ppf(self, P):
9594
"""Percent Point Function (PPF) or quantile function.
@@ -110,13 +109,16 @@ def ppf(self, P):
110109
111110
Returns
112111
-------
113-
NDArray[DType]
112+
DType | NDArray[DType]
114113
The PPF values corresponding to each probability in :attr:`P`.
114+
Returns a scalar when given a scalar and an array when given an array.
115115
"""
116+
117+
is_scalar = np.isscalar(P)
116118
P = np.asarray(P, dtype=self.dtype)
117119
dtype = self.dtype
118120

119-
return np.where(
121+
result = np.where(
120122
(P >= 0) & (P <= 1),
121123
np.where(
122124
(P == 0) | (P == 1),
@@ -125,6 +127,9 @@ def ppf(self, P):
125127
),
126128
dtype(np.nan),
127129
)
130+
if is_scalar:
131+
return result[()]
132+
return result
128133

129134
def lpdf(self, X):
130135
"""Log of the Probability Density Function (LPDF).
@@ -145,8 +150,9 @@ def lpdf(self, X):
145150
146151
Returns
147152
-------
148-
NDArray[DType]
153+
DType | NDArray[DType]
149154
The log-PDF values corresponding to each point in :attr:`X`.
155+
Returns a scalar when given a scalar and an array when given an array.
150156
"""
151157

152158
X = np.asarray(X, dtype=self.dtype)
@@ -179,14 +185,20 @@ def _dlog_loc(self, X):
179185
180186
Returns
181187
-------
182-
NDArray[DType]
188+
DType | NDArray[DType]
183189
The gradient of the lpdf with respect to :attr:`loc` for each point in :attr:`X`.
190+
Returns a scalar when given a scalar and an array when given an array.
184191
"""
185192

193+
is_scalar = np.isscalar(X)
186194
X = np.asarray(X, dtype=self.dtype)
187195
dtype = self.dtype
188196

189-
return (dtype(2) * X - dtype(2) * self.loc) / (self.scale**2 + X**2 - dtype(2) * self.loc * X + self.loc**2)
197+
result = (dtype(2) * X - dtype(2) * self.loc) / (self.scale**2 + X**2 - dtype(2) * self.loc * X + self.loc**2)
198+
199+
if is_scalar:
200+
return result[()]
201+
return result
190202

191203
def _dlog_scale(self, X):
192204
"""Partial derivative of the lpdf w.r.t. the :attr:`scale` parameter.
@@ -208,16 +220,23 @@ def _dlog_scale(self, X):
208220
209221
Returns
210222
-------
211-
NDArray[DType]
223+
DType | NDArray[DType]
212224
The gradient of the lpdf with respect to :attr:`scale` for each point in :attr:`X`.
225+
Returns a scalar when given a scalar and an array when given an array.
213226
"""
227+
228+
is_scalar = np.isscalar(X)
214229
X = np.asarray(X, dtype=self.dtype)
215230
dtype = self.dtype
216231

217-
return (-(self.scale**2) + X**2 - dtype(2) * self.loc * X + self.loc**2) / (
232+
result = (-(self.scale**2) + X**2 - dtype(2) * self.loc * X + self.loc**2) / (
218233
self.scale**3 + self.scale * (X**2) - dtype(2) * self.loc * self.scale * X + self.scale * self.loc**2
219234
)
220235

236+
if is_scalar:
237+
return result[()]
238+
return result
239+
221240
def log_gradients(self, X):
222241
"""Calculates the gradients of the log-PDF w.r.t. its parameters.
223242
@@ -235,7 +254,10 @@ def log_gradients(self, X):
235254
and each column corresponds to the gradient with respect to a
236255
specific optimizable parameter. The order of columns corresponds
237256
to the sorted order of :attr:`self.params_to_optimize`.
257+
Returns a 1D array if X is a scalar.
238258
"""
259+
260+
is_scalar = np.isscalar(X)
239261
X = np.asarray(X, dtype=self.dtype)
240262

241263
gradient_calculators = {
@@ -250,6 +272,8 @@ def log_gradients(self, X):
250272

251273
gradients = [gradient_calculators[param](X) for param in optimizable_params]
252274

275+
if is_scalar:
276+
return np.array(gradients)
253277
return np.stack(gradients, axis=1)
254278

255279
def generate(self, size: int):

0 commit comments

Comments
 (0)