@@ -1249,46 +1249,44 @@ def __init__(
12491249 super ().__init__ (np .mean , ksize = ksize , ** kwargs )
12501250
12511251
1252- #TODO ***AL*** revise MaxPooling - torch, typing, docstring, unit test
12531252class MaxPooling (Pool ):
12541253 """Apply max-pooling to images.
12551254
1256- This class reduces the resolution of an image by dividing it into
1257- non-overlapping blocks of size `ksize` and applying the max function to
1258- each block. The result is a downsampled image where each pixel value
1255+ `MaxPooling` reduces the resolution of an image by dividing it into
1256+ non-overlapping blocks of size `ksize` and applying the ` max` function
1257+ to each block. The result is a downsampled image where each pixel value
12591258 represents the maximum value within the corresponding block of the
1260- original image.
1261- This is useful for reducing the size of an image while retaining the
1262- most significant features.
1259+ original image. This is useful for reducing the size of an image while
1260+ retaining the most significant features.
1261+
1262+ If the backend is NumPy, the downsampling is performed using
1263+ `skimage.measure.block_reduce`.
1264+
1265+ If the backend is PyTorch, the downsampling is performed using
1266+ `torch.nn.functional.max_pool2d`.
12631267
12641268 Parameters
12651269 ----------
12661270 ksize: int
12671271 Size of the pooling kernel.
1268- cval: number
1269- Value to pad edges with if necessary. Default 0.
1270- func_kwargs: dict
1272+ **kwargs: Any
12711273 Additional parameters sent to the pooling function.
12721274
12731275 Examples
12741276 --------
12751277 >>> import deeptrack as dt
1276- >>> import numpy as np
1278+
12771279 Create an input image:
1280+ >>> import numpy as np
1281+ >>>
12781282 >>> input_image = np.random.rand(32, 32)
12791283
1280- Define a max-pooling feature:
1284+ Define and use a max-pooling feature:
1285+
12811286 >>> max_pooling = dt.MaxPooling(ksize=8)
12821287 >>> output_image = max_pooling(input_image)
1283- >>> print(output_image.shape)
1284- (8, 8)
1285-
1286- Notes
1287- -----
1288- Calling this feature returns a `np.ndarray` by default. If
1289- `store_properties` is set to `True`, the returned array will be
1290- automatically wrapped in an `Image` object. This behavior is handled
1291- internally and does not affect the return type of the `get()` method.
1288+ >>> output_image.shape
1289+ (4, 4)
12921290
12931291 """
12941292
@@ -1312,6 +1310,115 @@ def __init__(
13121310
13131311 super ().__init__ (np .max , ksize = ksize , ** kwargs )
13141312
def get(
    self: MaxPooling,
    image: NDArray[Any] | torch.Tensor,
    ksize: int = 3,
    **kwargs: Any,
) -> NDArray[Any] | torch.Tensor:
    """Max-pool the input using the implementation for the active backend.

    Queries the backend via `.get_backend()` and dispatches to
    `._get_numpy()` when it is `"numpy"` or to `._get_torch()` when it is
    `"torch"`.

    Parameters
    ----------
    image: array or tensor
        Input array or tensor to be pooled.
    ksize: int
        Kernel size of the pooling operation. Defaults to 3.
    **kwargs: Any
        Additional keyword arguments, forwarded unchanged to the
        backend-specific method.

    Returns
    -------
    array or tensor
        The pooled input as returned by the backend-specific method,
        i.e., an `NDArray` or a `torch.Tensor` depending on the backend.

    Raises
    ------
    NotImplementedError
        If the backend is neither `"numpy"` nor `"torch"`.

    """

    # Query the backend once so that the dispatch and the error message
    # below always refer to the same value, and so that the error path does
    # not depend on a separate `backend` attribute being available.
    backend = self.get_backend()

    if backend == "numpy":
        return self._get_numpy(image, ksize, **kwargs)

    if backend == "torch":
        return self._get_torch(image, ksize, **kwargs)

    raise NotImplementedError(f"Backend {backend} not supported")
1346+
def _get_numpy(
    self: MaxPooling,
    image: NDArray[Any],
    ksize: int = 3,
    **kwargs: Any,
) -> NDArray[Any]:
    """Max-pooling with the NumPy backend enabled.

    Hands the input array to scikit-image's `block_reduce()` through
    `utils.safe_call()`, with `np.max` as the reduction function and
    `ksize` as the block size.

    Parameters
    ----------
    image: array
        Input array to be pooled.
    ksize: int
        Kernel size of the pooling operation, passed as `block_size`.
        Defaults to 3.
    **kwargs: Any
        Additional keyword arguments passed on to `utils.safe_call()`
        together with the arguments above.

    Returns
    -------
    array
        The value returned by `utils.safe_call()`; annotated as a NumPy
        array holding the pooled image.

    """

    # The reduction routine is invoked indirectly via `utils.safe_call`.
    block_reducer = skimage.measure.block_reduce

    return utils.safe_call(
        block_reducer,
        image=image,
        func=np.max,
        block_size=ksize,
        **kwargs,
    )
1379+
1380+ def _get_torch (
1381+ self : MaxPooling ,
1382+ image : torch .Tensor ,
1383+ ksize : int = 3 ,
1384+ ** kwargs : Any ,
1385+ ) -> torch .Tensor :
1386+ """Max-pooling with the PyTorch backend enabled.
1387+
1388+
1389+ Returns the result of the tensor passed to a PyTorch max
1390+ pooling layer.
1391+
1392+ Parameters
1393+ ----------
1394+ image: torch.Tensor
1395+ Input tensor to be pooled.
1396+ ksize: int
1397+ Kernel size of the pooling operation.
1398+
1399+ Returns
1400+ -------
1401+ torch.Tensor
1402+ The pooled image as a `torch.Tensor`.
1403+
1404+ """
1405+
1406+ # If input tensor is 2D
1407+ if len (image .shape ) == 2 :
1408+ # Add batch dimension for max-pooling
1409+ expanded_image = image .unsqueeze (0 )
1410+
1411+ pooled_image = torch .nn .functional .max_pool2d (
1412+ expanded_image , kernel_size = ksize ,
1413+ )
1414+ # Remove the expanded dim
1415+ return pooled_image .squeeze (0 )
1416+
1417+ return torch .nn .functional .max_pool2d (
1418+ image ,
1419+ kernel_size = ksize ,
1420+ )
1421+
13151422
13161423class MinPooling (Pool ):
13171424 """Apply min-pooling to images.
0 commit comments