Skip to content

Commit 94e9e17

Browse files
authored
6007 reverse_indexing for PILReader (#6008)
Fixes #6007

### Description

- reverse_indexing = False: to support consistency with PIL/torchvision

```py
img = LoadImage(image_only=True, ensure_channel_first=True, reverse_indexing=False)("MONAI-logo_color.png")  # PILReader
torchvision.utils.save_image(img, "MONAI-logo_color_torchvision.png", normalize=True)
```

- reverse_indexing = True: to support consistency with other backends in monai

```py
img = LoadImage(image_only=True, ensure_channel_first=True, reader="PILReader", reverse_indexing=True)(filename)  # PIL backend
img_1 = LoadImage(image_only=True, ensure_channel_first=True, reader="ITKReader")(filename)  # itk backend
np.testing.assert_allclose(img, img_1)  # true
```

### Types of changes

<!--- Put an `x` in all the boxes that apply, and remove the not applicable items -->
- [x] Non-breaking change (fix or new feature that would not break existing functionality).
- [ ] Breaking change (fix or new feature that would cause existing functionality to change).
- [x] New tests added to cover the changes.
- [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`.
- [x] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`.
- [x] In-line docstrings updated.
- [ ] Documentation updated, tested `make html` command in the `docs/` folder.

---------

Signed-off-by: Wenqi Li <wenqil@nvidia.com>
1 parent f4902b2 commit 94e9e17

File tree

4 files changed

+19
-14
lines changed

4 files changed

+19
-14
lines changed

monai/data/image_reader.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1141,13 +1141,17 @@ class PILReader(ImageReader):
11411141
Args:
11421142
converter: additional function to convert the image data after `read()`.
11431143
for example, use `converter=lambda image: image.convert("LA")` to convert image format.
1144+
reverse_indexing: whether to swap axis 0 and 1 after loading the array. This is enabled by default,
1145+
so that output of the reader is consistent with the other readers. Set this option to ``False`` to use
1146+
the PIL backend's original spatial axes convention.
11441147
kwargs: additional args for `Image.open` API in `read()`, more details about available args:
11451148
https://pillow.readthedocs.io/en/stable/reference/Image.html#PIL.Image.open
11461149
"""
11471150

1148-
def __init__(self, converter: Callable | None = None, **kwargs):
1151+
def __init__(self, converter: Callable | None = None, reverse_indexing: bool = True, **kwargs):
11491152
super().__init__()
11501153
self.converter = converter
1154+
self.reverse_indexing = reverse_indexing
11511155
self.kwargs = kwargs
11521156

11531157
def verify_suffix(self, filename: Sequence[PathLike] | PathLike) -> bool:
@@ -1194,8 +1198,8 @@ def get_data(self, img) -> tuple[np.ndarray, dict]:
11941198
It computes `spatial_shape` and stores it in meta dict.
11951199
When loading a list of files, they are stacked together at a new dimension as the first dimension,
11961200
and the metadata of the first image is used to represent the output metadata.
1197-
Note that it will swap axis 0 and 1 after loading the array because the `HW` definition in PIL
1198-
is different from other common medical packages.
1201+
Note that by default `self.reverse_indexing` is set to ``True``, which swaps axis 0 and 1 after loading
1202+
the array because the spatial axes definition in PIL is different from other common medical packages.
11991203
12001204
Args:
12011205
img: a PIL Image object loaded from a file or a list of PIL Image objects.
@@ -1207,7 +1211,7 @@ def get_data(self, img) -> tuple[np.ndarray, dict]:
12071211
for i in ensure_tuple(img):
12081212
header = self._get_meta_dict(i)
12091213
header[MetaKeys.SPATIAL_SHAPE] = self._get_spatial_shape(i)
1210-
data = np.moveaxis(np.asarray(i), 0, 1)
1214+
data = np.moveaxis(np.asarray(i), 0, 1) if self.reverse_indexing else np.asarray(i)
12111215
img_array.append(data)
12121216
header[MetaKeys.ORIGINAL_CHANNEL_DIM] = (
12131217
"no_channel" if len(data.shape) == len(header[MetaKeys.SPATIAL_SHAPE]) else -1

monai/transforms/io/array.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -116,9 +116,9 @@ class LoadImage(Transform):
116116
- Current default readers: (nii, nii.gz -> NibabelReader), (png, jpg, bmp -> PILReader),
117117
(npz, npy -> NumpyReader), (nrrd -> NrrdReader), (DICOM file -> ITKReader).
118118
119-
Please note that for png, jpg, bmp, and other 2D formats, readers often swap axis 0 and 1 after
120-
loading the array because the `HW` definition for non-medical specific file formats is different
121-
from other common medical packages.
119+
Please note that for png, jpg, bmp, and other 2D formats, readers by default swap axis 0 and 1 after
120+
loading the array with ``reverse_indexing`` set to ``True`` because the spatial axes definition
121+
for non-medical specific file formats is different from other common medical packages.
122122
123123
See also:
124124

monai/transforms/io/dictionary.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -53,9 +53,9 @@ class LoadImaged(MapTransform):
5353
- Current default readers: (nii, nii.gz -> NibabelReader), (png, jpg, bmp -> PILReader),
5454
(npz, npy -> NumpyReader), (dcm, DICOM series and others -> ITKReader).
5555
56-
Please note that for png, jpg, bmp, and other 2D formats, readers often swap axis 0 and 1 after
57-
loading the array because the `HW` definition for non-medical specific file formats is different
58-
from other common medical packages.
56+
Please note that for png, jpg, bmp, and other 2D formats, readers by default swap axis 0 and 1 after
57+
loading the array with ``reverse_indexing`` set to ``True`` because the spatial axes definition
58+
for non-medical specific file formats is different from other common medical packages.
5959
6060
Note:
6161

tests/test_pil_reader.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525

2626
TEST_CASE_2 = [(128, 128, 3), ["test_image.png"], (128, 128, 3), (128, 128)]
2727

28-
TEST_CASE_3 = [(128, 128, 4), ["test_image.png"], (128, 128, 4), (128, 128)]
28+
TEST_CASE_3 = [(128, 128, 4), ["test_image.png"], (128, 128, 4), (128, 128), False]
2929

3030
TEST_CASE_4 = [(128, 128), ["test_image1.png", "test_image2.png", "test_image3.png"], (3, 128, 128), (128, 128)]
3131

@@ -38,20 +38,21 @@
3838

3939
class TestPNGReader(unittest.TestCase):
4040
@parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4, TEST_CASE_5, TEST_CASE_6])
41-
def test_shape_value(self, data_shape, filenames, expected_shape, meta_shape):
41+
def test_shape_value(self, data_shape, filenames, expected_shape, meta_shape, reverse=True):
4242
test_image = np.random.randint(0, 256, size=data_shape)
4343
with tempfile.TemporaryDirectory() as tempdir:
4444
for i, name in enumerate(filenames):
4545
filenames[i] = os.path.join(tempdir, name)
4646
Image.fromarray(test_image.astype("uint8")).save(filenames[i])
47-
reader = PILReader(mode="r")
47+
reader = PILReader(mode="r", reverse_indexing=reverse)
4848
result = reader.get_data(reader.read(filenames))
4949
# load image by PIL and compare the result
5050
test_image = np.asarray(Image.open(filenames[0]))
5151

5252
self.assertTupleEqual(tuple(result[1]["spatial_shape"]), meta_shape)
5353
self.assertTupleEqual(result[0].shape, expected_shape)
54-
test_image = np.moveaxis(test_image, 0, 1)
54+
if reverse:
55+
test_image = np.moveaxis(test_image, 0, 1)
5556
if result[0].shape == test_image.shape:
5657
np.testing.assert_allclose(result[0], test_image)
5758
else:

0 commit comments

Comments
 (0)