Skip to content

Commit 4a79b7a

Browse files
AL/math/MaxPooling (#407)
* added maxpooling docs and unit tests * Update test_math.py * Implemented feedback from Mirja * Added type and shape check * ú * Added shape handling for len(dim) = 2 * Update test_math with len(dim) = 2 * type hints * implemented xp for tests * Update math.py * Update test_math.py * Update math.py * Update test_math.py * Update test_math.py * Update math.py * Update math.py * u * Update math.py --------- Co-authored-by: Giovanni Volpe <giovanni.volpe@physics.gu.se>
1 parent dc302ba commit 4a79b7a

File tree

2 files changed

+141
-23
lines changed

2 files changed

+141
-23
lines changed

deeptrack/math.py

Lines changed: 128 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1249,46 +1249,44 @@ def __init__(
12491249
super().__init__(np.mean, ksize=ksize, **kwargs)
12501250

12511251

1252-
#TODO ***AL*** revise MaxPooling - torch, typing, docstring, unit test
12531252
class MaxPooling(Pool):
12541253
"""Apply max-pooling to images.
12551254
1256-
This class reduces the resolution of an image by dividing it into
1257-
non-overlapping blocks of size `ksize` and applying the max function to
1258-
each block. The result is a downsampled image where each pixel value
1255+
`MaxPooling` reduces the resolution of an image by dividing it into
1256+
non-overlapping blocks of size `ksize` and applying the `max` function
1257+
to each block. The result is a downsampled image where each pixel value
12591258
represents the maximum value within the corresponding block of the
1260-
original image.
1261-
This is useful for reducing the size of an image while retaining the
1262-
most significant features.
1259+
original image. This is useful for reducing the size of an image while
1260+
retaining the most significant features.
1261+
1262+
If the backend is NumPy, the downsampling is performed using
1263+
`skimage.measure.block_reduce`.
1264+
1265+
If the backend is PyTorch, the downsampling is performed using
1266+
`torch.nn.functional.max_pool2d`.
12631267
12641268
Parameters
12651269
----------
12661270
ksize: int
12671271
Size of the pooling kernel.
1268-
cval: number
1269-
Value to pad edges with if necessary. Default 0.
1270-
func_kwargs: dict
1272+
**kwargs: Any
12711273
Additional parameters sent to the pooling function.
12721274
12731275
Examples
12741276
--------
12751277
>>> import deeptrack as dt
1276-
>>> import numpy as np
1278+
12771279
Create an input image:
1280+
>>> import numpy as np
1281+
>>>
12781282
>>> input_image = np.random.rand(32, 32)
12791283
1280-
Define a max-pooling feature:
1284+
Define and use a max-pooling feature:
1285+
12811286
>>> max_pooling = dt.MaxPooling(ksize=8)
12821287
>>> output_image = max_pooling(input_image)
1283-
>>> print(output_image.shape)
1284-
(8, 8)
1285-
1286-
Notes
1287-
-----
1288-
Calling this feature returns a `np.ndarray` by default. If
1289-
`store_properties` is set to `True`, the returned array will be
1290-
automatically wrapped in an `Image` object. This behavior is handled
1291-
internally and does not affect the return type of the `get()` method.
1288+
>>> output_image.shape
1289+
(4, 4)
12921290
12931291
"""
12941292

@@ -1312,6 +1310,115 @@ def __init__(
13121310

13131311
super().__init__(np.max, ksize=ksize, **kwargs)
13141312

1313+
def get(
1314+
self: MaxPooling,
1315+
image: NDArray[Any] | torch.Tensor,
1316+
ksize: int=3,
1317+
**kwargs: Any,
1318+
) -> NDArray[Any] | torch.Tensor:
1319+
"""Max-pooling of input.
1320+
1321+
Checks the current backend and chooses the appropriate function to pool
1322+
the input image, either `._get_torch()` or `._get_numpy()`.
1323+
1324+
Parameters
1325+
----------
1326+
image: array or tensor
1327+
Input array or tensor be pooled.
1328+
ksize: int
1329+
Kernel size of the pooling operation.
1330+
1331+
Returns
1332+
-------
1333+
array or tensor
1334+
The pooled input as `NDArray` or `torch.Tensor` depending on
1335+
the backend.
1336+
1337+
"""
1338+
1339+
if self.get_backend() == "numpy":
1340+
return self._get_numpy(image, ksize, **kwargs)
1341+
1342+
if self.get_backend() == "torch":
1343+
return self._get_torch(image, ksize, **kwargs)
1344+
1345+
raise NotImplementedError(f"Backend {self.backend} not supported")
1346+
1347+
def _get_numpy(
1348+
self: MaxPooling,
1349+
image: NDArray[Any],
1350+
ksize: int=3,
1351+
**kwargs: Any,
1352+
) -> NDArray[Any]:
1353+
"""Max-pooling pooling with the NumPy backend enabled.
1354+
1355+
Returns the result of the input array passed to the scikit image
1356+
`block_reduce()` function with `np.max()` as the pooling function.
1357+
1358+
Parameters
1359+
----------
1360+
image: array
1361+
Input array to be pooled.
1362+
ksize: int
1363+
Kernel size of the pooling operation.
1364+
1365+
Returns
1366+
-------
1367+
array
1368+
The pooled image as a NumPy array.
1369+
1370+
"""
1371+
1372+
return utils.safe_call(
1373+
skimage.measure.block_reduce,
1374+
image=image,
1375+
func=np.max,
1376+
block_size=ksize,
1377+
**kwargs,
1378+
)
1379+
1380+
def _get_torch(
1381+
self: MaxPooling,
1382+
image: torch.Tensor,
1383+
ksize: int=3,
1384+
**kwargs: Any,
1385+
) -> torch.Tensor:
1386+
"""Max-pooling with the PyTorch backend enabled.
1387+
1388+
1389+
Returns the result of the tensor passed to a PyTorch max
1390+
pooling layer.
1391+
1392+
Parameters
1393+
----------
1394+
image: torch.Tensor
1395+
Input tensor to be pooled.
1396+
ksize: int
1397+
Kernel size of the pooling operation.
1398+
1399+
Returns
1400+
-------
1401+
torch.Tensor
1402+
The pooled image as a `torch.Tensor`.
1403+
1404+
"""
1405+
1406+
# If input tensor is 2D
1407+
if len(image.shape) == 2:
1408+
# Add batch dimension for max-pooling
1409+
expanded_image = image.unsqueeze(0)
1410+
1411+
pooled_image = torch.nn.functional.max_pool2d(
1412+
expanded_image, kernel_size=ksize,
1413+
)
1414+
# Remove the expanded dim
1415+
return pooled_image.squeeze(0)
1416+
1417+
return torch.nn.functional.max_pool2d(
1418+
image,
1419+
kernel_size=ksize,
1420+
)
1421+
13151422

13161423
class MinPooling(Pool):
13171424
"""Apply min-pooling to images.

deeptrack/tests/test_math.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -78,10 +78,21 @@ def test_Blur(self):
7878
#input_image = xp.asarray(np.array([[1, 2], [3, 4]], dtype=float))
7979
#expected_output = xp.asarray(np.array([[1, 1.5], [2, 2.5]]))
8080

81-
#eature = math.Blur(filter_function=uniform_filter, size=2)
81+
#feature = math.Blur(filter_function=uniform_filter, size=2)
8282
#blurred_image = feature.resolve(input_image)
8383
#self.assertTrue(xp.all(blurred_image == expected_output))
8484

85+
86+
def test_MaxPooling(self):
87+
input_image = xp.asarray([[1, 2, 3, 4], [5, 6, 7, 8]], dtype=float)
88+
feature = math.MaxPooling(ksize=2)
89+
pooled_image = feature.resolve(input_image)
90+
91+
expected = xp.asarray([[6.0, 8.0]], dtype=float)
92+
93+
self.assertTrue(xp.all(pooled_image == expected))
94+
self.assertEqual(pooled_image.shape, (1, 2))
95+
8596
def test_MinPooling(self):
8697
input_image = xp.asarray([[1, 2, 3, 4], [5, 6, 7, 8]], dtype=float)
8798
feature = math.MinPooling(ksize=2)
@@ -123,7 +134,7 @@ def test_MaxPooling(self):
123134
input_image = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
124135
feature = math.MaxPooling(ksize=2)
125136
pooled_image = feature.resolve(input_image)
126-
self.assertTrue(np.all(pooled_image == [[5, 6], [8, 9]]))
137+
self.assertTrue(xp.all(pooled_image == xp.asarray([[5, 6], [8, 9]]) ) )
127138

128139
def test_MinPooling(self):
129140
input_image = np.array([[1, 2, 3, 4], [5, 6, 7, 8]])

0 commit comments

Comments
 (0)