Skip to content

Commit 451a85c

Browse files
committed
update docs of loss and network
update docs of loss and network
1 parent 450a848 commit 451a85c

37 files changed

+614
-438
lines changed

docs/source/pymic.loss.cls.rst

Lines changed: 2 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -4,34 +4,10 @@ pymic.loss.cls package
44
Submodules
55
----------
66

7-
pymic.loss.cls.ce module
7+
pymic.loss.cls.basic module
88
------------------------
99

10-
.. automodule:: pymic.loss.cls.ce
11-
:members:
12-
:undoc-members:
13-
:show-inheritance:
14-
15-
pymic.loss.cls.l1 module
16-
------------------------
17-
18-
.. automodule:: pymic.loss.cls.l1
19-
:members:
20-
:undoc-members:
21-
:show-inheritance:
22-
23-
pymic.loss.cls.mse module
24-
-------------------------
25-
26-
.. automodule:: pymic.loss.cls.mse
27-
:members:
28-
:undoc-members:
29-
:show-inheritance:
30-
31-
pymic.loss.cls.nll module
32-
-------------------------
33-
34-
.. automodule:: pymic.loss.cls.nll
10+
.. automodule:: pymic.loss.cls.basic
3511
:members:
3612
:undoc-members:
3713
:show-inheritance:

docs/source/pymic.loss.seg.rst

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,14 @@ pymic.loss.seg package
44
Submodules
55
----------
66

7+
pymic.loss.seg.abstract module
8+
------------------------
9+
10+
.. automodule:: pymic.loss.seg.abstract
11+
:members:
12+
:undoc-members:
13+
:show-inheritance:
14+
715
pymic.loss.seg.ce module
816
------------------------
917

docs/source/usage.fsl.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -209,7 +209,7 @@ hyper-parameters. For example, the following is a configuration for using ``2DUN
209209
bilinear = False
210210
deep_supervise= False
211211
212-
The ``SegNetDict`` in :mod:`pymic.net.neg_dict_seg` lists all the built-in network
212+
The ``SegNetDict`` in :mod:`pymic.net.net_dict_seg` lists all the built-in network
213213
structures currently implemented in PyMIC.
214214

215215
You can also define your own networks. To integrate your customized

pymic/io/image_read_write.py

Lines changed: 25 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -10,15 +10,12 @@ def load_nifty_volume_as_4d_array(filename):
1010
"""
1111
Read a nifty image and return a dictionay storing data array, origin,
1212
spacing and direction.\n
13-
output['data_array'] 4d array with shape [C, D, H, W];\n
14-
output['spacing'] a list of spacing in z, y, x axis;\n
15-
output['direction'] a 3x3 matrix for direction.
13+
output['data_array'] 4D array with shape [C, D, H, W];\n
14+
output['spacing'] A list of spacing in z, y, x axis;\n
15+
output['direction'] A 3x3 matrix for direction.
1616
17-
Args:
18-
filename (str): the input file name
19-
20-
Returns:
21-
dict: a dictionay storing data array, origin, spacing and direction.
17+
:param filename: (str) The input file name
18+
:return: A dictionay storing data array, origin, spacing and direction.
2219
"""
2320
img_obj = sitk.ReadImage(filename)
2421
data_array = sitk.GetArrayFromImage(img_obj)
@@ -43,15 +40,12 @@ def load_rgb_image_as_3d_array(filename):
4340
"""
4441
Read an RGB image and return a dictionay storing data array, origin,
4542
spacing and direction. \n
46-
output['data_array'] 3d array with shape [D, H, W]; \n
43+
output['data_array'] 3D array with shape [D, H, W]; \n
4744
output['spacing'] a list of spacing in z, y, x axis; \n
4845
output['direction'] a 3x3 matrix for direction.
4946
50-
Args:
51-
filename (str): the input file name
52-
53-
Returns:
54-
dict: a dictionay storing data array, origin, spacing and direction.
47+
:param filename: (str) The input file name
48+
:return: A dictionay storing data array, origin, spacing and direction.
5549
"""
5650
image = np.asarray(Image.open(filename))
5751
image_shape = image.shape
@@ -74,14 +68,11 @@ def load_rgb_image_as_3d_array(filename):
7468

7569
def load_image_as_nd_array(image_name):
7670
"""
77-
load an image and return a 4D array with shape [C, D, H, W],
71+
Load an image and return a 4D array with shape [C, D, H, W],
7872
or 3D array with shape [C, H, W].
7973
80-
Args:
81-
image_name (str): the image name.
82-
83-
Returns:
84-
dict: a dictionay storing data array, origin, spacing and direction.
74+
:param filename: (str) The input file name
75+
:return: A dictionay storing data array, origin, spacing and direction.
8576
"""
8677
if (image_name.endswith(".nii.gz") or image_name.endswith(".nii") or
8778
image_name.endswith(".mha")):
@@ -97,10 +88,9 @@ def save_array_as_nifty_volume(data, image_name, reference_name = None):
9788
"""
9889
Save a numpy array as nifty image
9990
100-
Args:
101-
data (numpy.ndarray): a numpy array with shape [Depth, Height, Width].\n
102-
image_name (str): the ouput file name.\n
103-
reference_name (str): file name of the reference image of which
91+
:param data: (numpy.ndarray) A numpy array with shape [Depth, Height, Width].
92+
:param image_name: (str) The ouput file name.
93+
:param reference_name: (str) File name of the reference image of which
10494
meta information is used.
10595
"""
10696
img = sitk.GetImageFromArray(data)
@@ -114,12 +104,11 @@ def save_array_as_nifty_volume(data, image_name, reference_name = None):
114104

115105
def save_array_as_rgb_image(data, image_name):
116106
"""
117-
Save a numpy array as rgb image
107+
Save a numpy array as rgb image.
118108
119-
Args:
120-
data (numpy.ndarray): a numpy array with shape [3, H, W] or
121-
[H, W, 3] or [H, W]. \n
122-
image_name (str): the output file name.
109+
:param data: (numpy.ndarray) A numpy array with shape [3, H, W] or
110+
[H, W, 3] or [H, W].
111+
:param image_name: (str) The output file name.
123112
"""
124113
data_dim = len(data.shape)
125114
if(data_dim == 3):
@@ -133,10 +122,9 @@ def save_nd_array_as_image(data, image_name, reference_name = None):
133122
"""
134123
Save a 3D or 2D numpy array as medical image or RGB image
135124
136-
Args:
137-
data (numpy.ndarray): a numpy array with shape [D, H, W] or [C, H, W]. \n
138-
image_name (str): the output file name. \n
139-
reference_name (str): file name of the reference image of which
125+
:param data: (numpy.ndarray) A numpy array with shape [3, H, W] or
126+
[H, W, 3] or [H, W].
127+
:param reference_name: (str) File name of the reference image of which
140128
meta information is used.
141129
"""
142130
data_dim = len(data.shape)
@@ -158,16 +146,14 @@ def rotate_nifty_volume_to_LPS(filename_or_image_dict, origin = None, direction
158146
'''
159147
Rotate the axis of a 3D volume to LPS
160148
161-
Args:
162-
filename_or_image_dict (str): filename of the nifty file (str) or image dictionary
149+
:param filename_or_image_dict: (str) Filename of the nifty file (str) or image dictionary
163150
returned by load_nifty_volume_as_4d_array. If supplied with the former,
164151
the flipped image data will be saved to override the original file.
165152
If supplied with the later, only flipped image data will be returned.\n
166-
origin (list or tuple): the origin of the image.\n
167-
direction (list or tuple): the direction of the image.
153+
:param origin: (list/tuple) The origin of the image.
154+
:param direction: (list or tuple) The direction of the image.
168155
169-
Returns:
170-
dict: a dictionary for image data and meta info, with ``data_array``,
156+
:return: A dictionary for image data and meta info, with ``data_array``,
171157
``origin``, ``direction`` and ``spacing``.
172158
'''
173159

pymic/io/nifty_dataset.py

Lines changed: 13 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -15,13 +15,12 @@ class NiftyDataset(Dataset):
1515
dimention order [C, D, H, W] for 3D images, and 3D tensors
1616
with dimention order [C, H, W] for 2D images.
1717
18-
Args:
19-
root_dir (str): Directory with all the images. \n
20-
csv_file (str): Path to the csv file with image names. \n
21-
modal_num (int): Number of modalities. \n
22-
with_label (bool): Load the data with segmentation ground truth or not. \n
23-
with_weight(bool): Load pixel-wise weight map or not. \n
24-
transform (list): list of transform to be applied on a sample.
18+
:param root_dir: (str) Directory with all the images.
19+
:param csv_file: (str) Path to the csv file with image names.
20+
:param modal_num: (int) Number of modalities.
21+
:param with_label: (bool) Load the data with segmentation ground truth or not.
22+
:param transform: (list) List of transforms to be applied on a sample.
23+
The built-in transforms can listed in :mod:`pymic.transform.trans_dict`.
2524
"""
2625
def __init__(self, root_dir, csv_file, modal_num = 1,
2726
with_label = False, transform=None):
@@ -93,13 +92,13 @@ class ClassificationDataset(NiftyDataset):
9392
dimention order [C, D, H, W] for 3D images, and 3D tensors
9493
with dimention order [C, H, W] for 2D images.
9594
96-
Args:
97-
root_dir (str): Directory with all the images. \n
98-
csv_file (str): Path to the csv file with image names. \n
99-
modal_num (int): Number of modalities. \n
100-
class_num (int): class number of the classificaiton task. \n
101-
with_label (bool): Load the data with segmentation ground truth or not. \n
102-
transform (list): list of transform to be applied on a sample.
95+
:param root_dir: (str) Directory with all the images.
96+
:param csv_file: (str) Path to the csv file with image names.
97+
:param modal_num: (int) Number of modalities.
98+
:param class_num: (int) Class number of the classificaiton task.
99+
:param with_label: (bool) Load the data with segmentation ground truth or not.
100+
:param transform: (list) List of transforms to be applied on a sample.
101+
The built-in transforms can listed in :mod:`pymic.transform.trans_dict`.
103102
"""
104103
def __init__(self, root_dir, csv_file, modal_num = 1, class_num = 2,
105104
with_label = False, transform=None):

pymic/loss/cls/basic.py

Lines changed: 107 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,107 @@
1+
# -*- coding: utf-8 -*-
2+
from __future__ import print_function, division
3+
4+
import torch
5+
import torch.nn as nn
6+
7+
class AbstractClassificationLoss(nn.Module):
8+
"""
9+
Abstract Classification Loss.
10+
"""
11+
def __init__(self, params = None):
12+
super(AbstractClassificationLoss, self).__init__()
13+
14+
def forward(self, loss_input_dict):
15+
"""
16+
The arguments should be written in the `loss_input_dict` dictionary, and it has the
17+
following fields.
18+
19+
:param prediction: A prediction with shape of [N, C] where C is the class number.
20+
:param ground_truth: The corresponding ground truth, with shape of [N, 1].
21+
22+
Note that `prediction` is the digit output of a network, before using softmax.
23+
"""
24+
pass
25+
26+
class CrossEntropyLoss(AbstractClassificationLoss):
27+
"""
28+
Standard Softmax-based CE loss.
29+
"""
30+
def __init__(self, params = None):
31+
super(CrossEntropyLoss, self).__init__(params)
32+
self.ce_loss = nn.CrossEntropyLoss()
33+
34+
def forward(self, loss_input_dict):
35+
predict = loss_input_dict['prediction']
36+
labels = loss_input_dict['ground_truth']
37+
loss = self.ce_loss(predict, labels)
38+
return loss
39+
40+
class SigmoidCELoss(AbstractClassificationLoss):
41+
"""
42+
Sigmoid-based CE loss.
43+
"""
44+
def __init__(self, params = None):
45+
super(SigmoidCELoss, self).__init__(params)
46+
47+
def forward(self, loss_input_dict):
48+
predict = loss_input_dict['prediction']
49+
labels = loss_input_dict['ground_truth']
50+
predict = nn.Sigmoid()(predict) * 0.999 + 5e-4
51+
loss = - labels * torch.log(predict) - (1 - labels) * torch.log( 1 - predict)
52+
loss = loss.mean()
53+
return loss
54+
55+
class L1Loss(AbstractClassificationLoss):
56+
"""
57+
L1 (MAE) loss for classification
58+
"""
59+
def __init__(self, params = None):
60+
super(L1Loss, self).__init__(params)
61+
self.l1_loss = nn.L1Loss()
62+
63+
def forward(self, loss_input_dict):
64+
predict = loss_input_dict['prediction']
65+
labels = loss_input_dict['ground_truth'][:, None] # reshape to N, 1
66+
softmax = nn.Softmax(dim = 1)
67+
predict = softmax(predict)
68+
num_class = list(predict.size())[1]
69+
data_type = 'float' if(predict.dtype is torch.float32) else 'double'
70+
soft_y = get_soft_label(labels, num_class, data_type)
71+
loss = self.l1_loss(predict, soft_y)
72+
return loss
73+
74+
class MSELoss(AbstractClassificationLoss):
75+
"""
76+
Mean Square Error loss for classification.
77+
"""
78+
def __init__(self, params = None):
79+
super(MSELoss, self).__init__(params)
80+
self.mse_loss = nn.MSELoss()
81+
82+
def forward(self, loss_input_dict):
83+
predict = loss_input_dict['prediction']
84+
labels = loss_input_dict['ground_truth'][:, None] # reshape to N, 1
85+
softmax = nn.Softmax(dim = 1)
86+
predict = softmax(predict)
87+
num_class = list(predict.size())[1]
88+
data_type = 'float' if(predict.dtype is torch.float32) else 'double'
89+
soft_y = get_soft_label(labels, num_class, data_type)
90+
loss = self.mse_loss(predict, soft_y)
91+
return loss
92+
93+
class NLLLoss(AbstractClassificationLoss):
94+
"""
95+
The negative log likelihood loss for classification.
96+
"""
97+
def __init__(self, params = None):
98+
super(NLLLoss, self).__init__(params)
99+
self.nll_loss = nn.NLLLoss()
100+
101+
def forward(self, loss_input_dict):
102+
predict = loss_input_dict['prediction']
103+
labels = loss_input_dict['ground_truth']
104+
logsoft = nn.LogSoftmax(dim = 1)
105+
predict = logsoft(predict)
106+
loss = self.nll_loss(predict, labels)
107+
return loss

pymic/loss/cls/ce.py

Lines changed: 0 additions & 42 deletions
This file was deleted.

0 commit comments

Comments
 (0)