@@ -109,10 +109,12 @@ def pdf(self, X):
109109
110110 Returns
111111 -------
112- NDArray[DType]
112+ DType | NDArray[DType]
113113 The PDF values corresponding to each point in :attr:`X`.
114+ Return a scalar when given a scalar, and to return an array when given an array.
114115
115116 """
117+
116118 X = np .asarray (X , dtype = self .dtype )
117119 return np .exp (self .lpdf (X ))
118120
@@ -137,13 +139,16 @@ def ppf(self, P):
137139
138140 Returns
139141 -------
140- NDArray[DType]
142+ DType | NDArray[DType]
141143 The PPF values corresponding to each probability in :attr:`P`.
144+ Return a scalar when given a scalar, and to return an array when given an array.
142145 """
146+
147+ is_scalar = np .isscalar (P )
143148 P = np .asarray (P , dtype = self .dtype )
144149 dtype = self .dtype
145150
146- return np .where (
151+ result = np .where (
147152 (P >= 0 ) & (P <= 1 ),
148153 (
149154 self .left_border
@@ -152,6 +157,10 @@ def ppf(self, P):
152157 dtype (np .nan ),
153158 )
154159
160+ if is_scalar :
161+ return result [()]
162+ return result
163+
155164 def lpdf (self , X ):
156165 """Log of the Probability Density Function (LPDF).
157166
@@ -177,8 +186,9 @@ def lpdf(self, X):
177186
178187 Returns
179188 -------
180- NDArray[DType]
189+ DType | NDArray[DType]
181190 The log-PDF values corresponding to each point in :attr:`X`.
191+ Return a scalar when given a scalar, and to return an array when given an array.
182192 """
183193
184194 X = np .asarray (X , dtype = self .dtype )
@@ -189,7 +199,7 @@ def lpdf(self, X):
189199 log_pdf_standard = beta_dist .logpdf (Z , self .alpha , self .beta ).astype (dtype )
190200 result = log_pdf_standard - np .log (self .right_border - self .left_border )
191201
192- return np . atleast_1d ( result )
202+ return result
193203
194204 def _dlog_alpha (self , X ):
195205 """Partial derivative of the lpdf w.r.t. the :attr:`alpha` parameter.
@@ -214,22 +224,28 @@ def _dlog_alpha(self, X):
214224
215225 Returns
216226 -------
217- NDArray[DType]
227+ DType | NDArray[DType]
218228 The gradient of the lpdf with respect to :attr:`alpha` for each point in :attr:`X`.
229+ Return a scalar when given a scalar, and to return an array when given an array.
219230 """
220231
232+ is_scalar = np .isscalar (X )
221233 X = np .asarray (X , dtype = self .dtype )
222234 dtype = self .dtype
223235
224236 in_bounds = (self .left_border < X ) & (self .right_border >= X )
225- return np .where (
237+ result = np .where (
226238 in_bounds ,
227239 np .log (X - self .left_border )
228240 - np .log (self .right_border - self .left_border )
229241 - (dtype (digamma (self .alpha )) - dtype (digamma (self .alpha + self .beta ))),
230242 dtype (0.0 ),
231243 )
232244
245+ if is_scalar :
246+ return result [()]
247+ return result
248+
233249 def _dlog_beta (self , X ):
234250 """Partial derivative of the lpdf w.r.t. the :attr:`beta` parameter.
235251
@@ -253,22 +269,28 @@ def _dlog_beta(self, X):
253269
254270 Returns
255271 -------
256- NDArray[DType]
272+ DType | NDArray[DType]
257273 The gradient of the lpdf with respect to :attr:`beta` for each point in :attr:`X`.
274+ Return a scalar when given a scalar, and to return an array when given an array.
258275 """
259276
277+ is_scalar = np .isscalar (X )
260278 X = np .asarray (X , dtype = self .dtype )
261279 dtype = self .dtype
262280
263281 in_bounds = (self .left_border < X ) & (self .right_border >= X )
264- return np .where (
282+ result = np .where (
265283 in_bounds ,
266284 np .log (self .right_border - X )
267285 - np .log (self .right_border - self .left_border )
268286 - (dtype (digamma (self .beta )) - dtype (digamma (self .alpha + self .beta ))),
269287 dtype (0.0 ),
270288 )
271289
290+ if is_scalar :
291+ return result [()]
292+ return result
293+
272294 def _dlog_left_border (self , X ):
273295 """Partial derivative of the lpdf w.r.t. the :attr:`left_border` parameter.
274296
@@ -290,15 +312,17 @@ def _dlog_left_border(self, X):
290312
291313 Returns
292314 -------
293- NDArray[DType]
315+ DType | NDArray[DType]
294316 The gradient of the lpdf with respect to :attr:`left_border` for each point in :attr:`X`.
317+ Return a scalar when given a scalar, and to return an array when given an array.
295318 """
296319
320+ is_scalar = np .isscalar (X )
297321 X = np .asarray (X , dtype = self .dtype )
298322 dtype = self .dtype
299323
300324 in_bounds = (self .left_border < X ) & (self .right_border >= X )
301- return np .where (
325+ result = np .where (
302326 in_bounds ,
303327 (
304328 ((self .alpha + self .beta - dtype (1 )) / (self .right_border - self .left_border ))
@@ -307,6 +331,10 @@ def _dlog_left_border(self, X):
307331 dtype (0.0 ),
308332 )
309333
334+ if is_scalar :
335+ return result [()]
336+ return result
337+
310338 def _dlog_right_border (self , X ):
311339 """Partial derivative of the lpdf w.r.t. the :attr:`right_border` parameter.
312340
@@ -328,14 +356,17 @@ def _dlog_right_border(self, X):
328356
329357 Returns
330358 -------
331- NDArray[DType]
359+ DType | NDArray[DType]
332360 The gradient of the lpdf with respect to :attr:`right_border` for each point in :attr:`X`.
361+ Return a scalar when given a scalar, and to return an array when given an array.
333362 """
363+
364+ is_scalar = np .isscalar (X )
334365 X = np .asarray (X , dtype = self .dtype )
335366 dtype = self .dtype
336367
337368 in_bounds = (self .left_border < X ) & (self .right_border >= X )
338- return np .where (
369+ result = np .where (
339370 in_bounds ,
340371 (
341372 ((self .beta - dtype (1 )) / (self .right_border - X ))
@@ -344,6 +375,10 @@ def _dlog_right_border(self, X):
344375 dtype (0.0 ),
345376 )
346377
378+ if is_scalar :
379+ return result [()]
380+ return result
381+
347382 def log_gradients (self , X ):
348383 """Calculates the gradients of the log-PDF w.r.t. its parameters.
349384
@@ -361,7 +396,10 @@ def log_gradients(self, X):
361396 and each column corresponds to the gradient with respect to a
362397 specific optimizable parameter. The order of columns corresponds
363398 to the sorted order of :attr:`self.params_to_optimize`.
399+ Returns a 1D array if X is a scalar.
364400 """
401+
402+ is_scalar = np .isscalar (X )
365403 X = np .asarray (X , dtype = self .dtype )
366404
367405 gradient_calculators = {
@@ -378,6 +416,8 @@ def log_gradients(self, X):
378416
379417 gradients = [gradient_calculators [param ](X ) for param in optimizable_params ]
380418
419+ if is_scalar :
420+ return np .array (gradients )
381421 return np .stack (gradients , axis = 1 )
382422
383423 def generate (self , size : int ):
0 commit comments