Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions monai/losses/dice.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ def __init__(
batch: bool = False,
weight: Sequence[float] | float | int | torch.Tensor | None = None,
soft_label: bool = False,
ignore_index: int | None = None,
) -> None:
"""
Args:
Expand Down Expand Up @@ -100,6 +101,7 @@ def __init__(
The value/values should be no less than 0. Defaults to None.
soft_label: whether the target contains non-binary values (soft labels) or not.
If True a soft label formulation of the loss will be used.
ignore_index: class index to ignore from the loss computation.

Raises:
TypeError: When ``other_act`` is not an ``Optional[Callable]``.
Expand All @@ -122,6 +124,7 @@ def __init__(
self.smooth_nr = float(smooth_nr)
self.smooth_dr = float(smooth_dr)
self.batch = batch
self.ignore_index = ignore_index
weight = torch.as_tensor(weight) if weight is not None else None
self.register_buffer("class_weight", weight)
self.class_weight: None | torch.Tensor
Expand Down Expand Up @@ -163,6 +166,10 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
if self.other_act is not None:
input = self.other_act(input)

mask: torch.Tensor | None = None
if self.ignore_index is not None:
mask = (target != self.ignore_index).float()

if self.to_onehot_y:
if n_pred_ch == 1:
warnings.warn("single channel prediction, `to_onehot_y=True` ignored.")
Expand All @@ -180,6 +187,10 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
if target.shape != input.shape:
raise AssertionError(f"ground truth has different shape ({target.shape}) from input ({input.shape})")

if mask is not None:
input = input * mask
target = target * mask

# reducing only spatial dimensions (not batch nor channels)
reduce_axis: list[int] = torch.arange(2, len(input.shape)).tolist()
if self.batch:
Expand Down
8 changes: 8 additions & 0 deletions monai/losses/focal_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ def __init__(
weight: Sequence[float] | float | int | torch.Tensor | None = None,
reduction: LossReduction | str = LossReduction.MEAN,
use_softmax: bool = False,
ignore_index: int | None = None,
) -> None:
"""
Args:
Expand All @@ -99,6 +100,7 @@ def __init__(

use_softmax: whether to use softmax to transform the original logits into probabilities.
If True, softmax is used. If False, sigmoid is used. Defaults to False.
ignore_index: class index to ignore from the loss computation.

Example:
>>> import torch
Expand All @@ -124,6 +126,7 @@ def __init__(
weight = torch.as_tensor(weight) if weight is not None else None
self.register_buffer("class_weight", weight)
self.class_weight: None | torch.Tensor
self.ignore_index = ignore_index

def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
"""
Expand Down Expand Up @@ -161,6 +164,11 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
if target.shape != input.shape:
raise ValueError(f"ground truth has different shape ({target.shape}) from input ({input.shape})")

if self.ignore_index is not None:
mask = (target != self.ignore_index).float()
input = input * mask
target = target * mask

loss: torch.Tensor | None = None
input = input.float()
target = target.float()
Expand Down
16 changes: 16 additions & 0 deletions monai/losses/tversky.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ def __init__(
smooth_dr: float = 1e-5,
batch: bool = False,
soft_label: bool = False,
ignore_index: int | None = None,
) -> None:
"""
Args:
Expand All @@ -77,6 +78,7 @@ def __init__(
before any `reduction`.
soft_label: whether the target contains non-binary values (soft labels) or not.
If True a soft label formulation of the loss will be used.
ignore_index: class index to ignore from the loss computation.

Raises:
TypeError: When ``other_act`` is not an ``Optional[Callable]``.
Expand All @@ -101,6 +103,7 @@ def __init__(
self.smooth_dr = float(smooth_dr)
self.batch = batch
self.soft_label = soft_label
self.ignore_index = ignore_index

def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
"""
Expand Down Expand Up @@ -129,8 +132,21 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
if n_pred_ch == 1:
warnings.warn("single channel prediction, `to_onehot_y=True` ignored.")
else:
original_target = target
target = one_hot(target, num_classes=n_pred_ch)

if self.ignore_index is not None:
mask_src = original_target if self.to_onehot_y and n_pred_ch > 1 else target

if mask_src.shape[1] == 1:
mask = (mask_src != self.ignore_index).to(input.dtype)
else:
# Fallback for cases where target is already one-hot
mask = (1.0 - mask_src[:, self.ignore_index : self.ignore_index + 1]).to(input.dtype)

input = input * mask
target = target * mask

if not self.include_background:
if n_pred_ch == 1:
warnings.warn("single channel prediction, `include_background=False` ignored.")
Expand Down
125 changes: 102 additions & 23 deletions monai/losses/unified_focal_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,48 +39,76 @@ def __init__(
gamma: float = 0.75,
epsilon: float = 1e-7,
reduction: LossReduction | str = LossReduction.MEAN,
ignore_index: int | None = None,
) -> None:
"""
Args:
to_onehot_y: whether to convert `y` into the one-hot format. Defaults to False.
delta : weight of the background. Defaults to 0.7.
gamma : value of the exponent gamma in the definition of the Focal loss. Defaults to 0.75.
epsilon : a very small number added for numerical stability (smoothing value). Defaults to 1e-7.
ignore_index: class index to ignore from the loss computation.
"""
super().__init__(reduction=LossReduction(reduction).value)
self.to_onehot_y = to_onehot_y
self.delta = delta
self.gamma = gamma
self.epsilon = epsilon
self.ignore_index = ignore_index

def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
n_pred_ch = y_pred.shape[1]

# Save original for masking
original_y_true = y_true if self.ignore_index is not None else None

if self.to_onehot_y:
if n_pred_ch == 1:
warnings.warn("single channel prediction, `to_onehot_y=True` ignored.")
else:
if self.ignore_index is not None:
# Replace ignore_index with valid class before one_hot
y_true = torch.where(y_true == self.ignore_index, torch.tensor(0, device=y_true.device), y_true)
y_true = one_hot(y_true, num_classes=n_pred_ch)

if y_true.shape != y_pred.shape:
raise ValueError(f"ground truth has different shape ({y_true.shape}) from input ({y_pred.shape})")

# clip the prediction to avoid NaN
# Build mask after one_hot conversion
mask = torch.ones_like(y_true)
if self.ignore_index is not None:
if original_y_true is not None and self.to_onehot_y:
# Use original labels to build spatial mask
spatial_mask = (original_y_true != self.ignore_index).float()
elif self.ignore_index < y_true.shape[1]:
# For already one-hot: use ignored class channel
spatial_mask = 1.0 - y_true[:, self.ignore_index : self.ignore_index + 1]
else:
# For sentinel values: any valid channel
spatial_mask = (y_true.sum(dim=1, keepdim=True) > 0).float()
mask = spatial_mask.expand_as(y_true)
y_pred = y_pred * mask
y_true = y_true * mask

y_pred = torch.clamp(y_pred, self.epsilon, 1.0 - self.epsilon)
axis = list(range(2, len(y_pred.shape)))

# Calculate true positives (tp), false negatives (fn) and false positives (fp)
tp = torch.sum(y_true * y_pred, dim=axis)
fn = torch.sum(y_true * (1 - y_pred), dim=axis)
fp = torch.sum((1 - y_true) * y_pred, dim=axis)
fn = torch.sum(y_true * (1 - y_pred) * mask, dim=axis)
fp = torch.sum((1 - y_true) * y_pred * mask, dim=axis)
dice_class = (tp + self.epsilon) / (tp + self.delta * fn + (1 - self.delta) * fp + self.epsilon)

# Calculate losses separately for each class, enhancing both classes
back_dice = 1 - dice_class[:, 0]
fore_dice = (1 - dice_class[:, 1]) * torch.pow(1 - dice_class[:, 1], -self.gamma)
fore_dice = torch.pow(1 - dice_class[:, 1], 1 - self.gamma)

# Average class scores
loss = torch.mean(torch.stack([back_dice, fore_dice], dim=-1))
loss = torch.stack([back_dice, fore_dice], dim=-1)
if self.reduction == LossReduction.MEAN.value:
return torch.mean(loss)
if self.reduction == LossReduction.SUM.value:
return torch.sum(loss)
return loss


Expand All @@ -103,27 +131,36 @@ def __init__(
gamma: float = 2,
epsilon: float = 1e-7,
reduction: LossReduction | str = LossReduction.MEAN,
ignore_index: int | None = None,
):
"""
Args:
to_onehot_y : whether to convert `y` into the one-hot format. Defaults to False.
delta : weight of the background. Defaults to 0.7.
gamma : value of the exponent gamma in the definition of the Focal loss. Defaults to 2.
epsilon : a very small number added for numerical stability (smoothing value). Defaults to 1e-7.
ignore_index: class index to ignore from the loss computation.
"""
super().__init__(reduction=LossReduction(reduction).value)
self.to_onehot_y = to_onehot_y
self.delta = delta
self.gamma = gamma
self.epsilon = epsilon
self.ignore_index = ignore_index

def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
n_pred_ch = y_pred.shape[1]

# Save original for masking
original_y_true = y_true if self.ignore_index is not None else None

if self.to_onehot_y:
if n_pred_ch == 1:
warnings.warn("single channel prediction, `to_onehot_y=True` ignored.")
else:
if self.ignore_index is not None:
# Replace ignore_index with valid class before one_hot
y_true = torch.where(y_true == self.ignore_index, torch.tensor(0, device=y_true.device), y_true)
y_true = one_hot(y_true, num_classes=n_pred_ch)

if y_true.shape != y_pred.shape:
Expand All @@ -132,13 +169,36 @@ def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
y_pred = torch.clamp(y_pred, self.epsilon, 1.0 - self.epsilon)
cross_entropy = -y_true * torch.log(y_pred)

# Build mask from original labels if available
spatial_mask: torch.Tensor | None = None
if self.ignore_index is not None:
if original_y_true is not None and self.to_onehot_y:
spatial_mask = (original_y_true != self.ignore_index).float()
elif self.ignore_index < y_true.shape[1]:
spatial_mask = 1.0 - y_true[:, self.ignore_index : self.ignore_index + 1]
else:
spatial_mask = (y_true.sum(dim=1, keepdim=True) > 0).float()

if spatial_mask is not None:
cross_entropy = cross_entropy * spatial_mask.expand_as(cross_entropy)

back_ce = torch.pow(1 - y_pred[:, 0], self.gamma) * cross_entropy[:, 0]
back_ce = (1 - self.delta) * back_ce

fore_ce = cross_entropy[:, 1]
fore_ce = self.delta * fore_ce

loss = torch.mean(torch.sum(torch.stack([back_ce, fore_ce], dim=1), dim=1))
loss = torch.stack([back_ce, fore_ce], dim=1) # [B, 2, H, W]

if self.reduction == LossReduction.MEAN.value:
if self.ignore_index is not None and spatial_mask is not None:
# Apply mask to loss, then average over valid elements only
# loss has shape [B, 2, H, W], spatial_mask has shape [B, 1, H, W]
masked_loss = loss * spatial_mask.expand_as(loss)
return masked_loss.sum() / (spatial_mask.expand_as(loss).sum().clamp(min=1e-5))
return loss.mean()
if self.reduction == LossReduction.SUM.value:
return loss.sum()
return loss


Expand All @@ -162,6 +222,7 @@ def __init__(
gamma: float = 0.5,
delta: float = 0.7,
reduction: LossReduction | str = LossReduction.MEAN,
ignore_index: int | None = None,
):
"""
Args:
Expand All @@ -170,8 +231,7 @@ def __init__(
weight : weight for each loss function. Defaults to 0.5.
gamma : value of the exponent gamma in the definition of the Focal loss. Defaults to 0.5.
delta : weight of the background. Defaults to 0.7.


ignore_index: class index to ignore from the loss computation.

Example:
>>> import torch
Expand All @@ -187,10 +247,12 @@ def __init__(
self.gamma = gamma
self.delta = delta
self.weight: float = weight
self.asy_focal_loss = AsymmetricFocalLoss(gamma=self.gamma, delta=self.delta)
self.asy_focal_tversky_loss = AsymmetricFocalTverskyLoss(gamma=self.gamma, delta=self.delta)
self.asy_focal_loss = AsymmetricFocalLoss(gamma=self.gamma, delta=self.delta, ignore_index=ignore_index)
self.asy_focal_tversky_loss = AsymmetricFocalTverskyLoss(
gamma=self.gamma, delta=self.delta, ignore_index=ignore_index
)
self.ignore_index = ignore_index

# TODO: Implement this function to support multiple classes segmentation
def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
"""
Args:
Expand All @@ -207,25 +269,42 @@ def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
ValueError: When ``num_classes`` is invalid
ValueError: When the number of classes entered does not match the expected number
"""
if y_pred.shape != y_true.shape:
raise ValueError(f"ground truth has different shape ({y_true.shape}) from input ({y_pred.shape})")

if len(y_pred.shape) != 4 and len(y_pred.shape) != 5:
raise ValueError(f"input shape must be 4 or 5, but got {y_pred.shape}")

# Transform binary inputs to 2-channel space
if y_pred.shape[1] == 1:
y_pred = one_hot(y_pred, num_classes=self.num_classes)
y_true = one_hot(y_true, num_classes=self.num_classes)
y_pred = torch.cat([1 - y_pred, y_pred], dim=1)

if torch.max(y_true) != self.num_classes - 1:
raise ValueError(f"Please make sure the number of classes is {self.num_classes - 1}")

n_pred_ch = y_pred.shape[1]
# Move one_hot conversion OUTSIDE the if y_pred.shape[1] == 1 block
if self.to_onehot_y:
if n_pred_ch == 1:
warnings.warn("single channel prediction, `to_onehot_y=True` ignored.")
if self.ignore_index is not None:
mask = (y_true != self.ignore_index).float()
y_true_clean = torch.where(y_true == self.ignore_index, 0, y_true)
y_true = one_hot(y_true_clean, num_classes=self.num_classes)
# Keep the channel-wise mask
y_true = y_true * mask
else:
y_true = one_hot(y_true, num_classes=n_pred_ch)
y_true = one_hot(y_true, num_classes=self.num_classes)

# Check if shapes match
if y_true.shape[1] == 1 and y_pred.shape[1] == 2:
if self.ignore_index is not None:
# Create mask for valid pixels
mask = (y_true != self.ignore_index).float()
# Set ignore_index values to 0 before conversion
y_true_clean = y_true * mask
# Convert to 2-channel
y_true = torch.cat([1 - y_true_clean, y_true_clean], dim=1)
# Apply mask to both channels so ignored pixels are all zeros
y_true = y_true * mask
else:
y_true = torch.cat([1 - y_true, y_true], dim=1)

if y_true.shape != y_pred.shape:
raise ValueError(f"ground truth has different shape ({y_true.shape}) from input ({y_pred.shape})")
if self.ignore_index is None and torch.max(y_true) > self.num_classes - 1:
raise ValueError(f"Invalid class index found. Maximum class should be {self.num_classes - 1}")

asy_focal_loss = self.asy_focal_loss(y_pred, y_true)
asy_focal_tversky_loss = self.asy_focal_tversky_loss(y_pred, y_true)
Expand Down
Loading
Loading