diff --git a/biapy/config/config.py b/biapy/config/config.py index 82b28b8a9..27d302a42 100644 --- a/biapy/config/config.py +++ b/biapy/config/config.py @@ -1162,12 +1162,14 @@ def __init__(self, job_dir: str, job_identifier: str): # * Semantic segmentation: 'unet', 'resunet', 'resunet++', 'attention_unet', 'multiresunet', 'seunet', 'resunet_se', 'unetr', 'unext_v1', 'unext_v2' # * Instance segmentation: 'unet', 'resunet', 'resunet++', 'attention_unet', 'multiresunet', 'seunet', 'resunet_se', 'unetr', 'unext_v1', 'unext_v2' # * Detection: 'unet', 'resunet', 'resunet++', 'attention_unet', 'multiresunet', 'seunet', 'resunet_se', 'unetr', 'unext_v1', 'unext_v2' - # * Denoising: 'unet', 'resunet', 'resunet++', 'attention_unet', 'seunet', 'resunet_se', 'unext_v1', 'unext_v2' + # * Denoising: 'unet', 'resunet', 'resunet++', 'attention_unet', 'seunet', 'resunet_se', 'unext_v1', 'unext_v2', 'nafnet' # * Super-resolution: 'edsr', 'rcan', 'dfcan', 'wdsr', 'unet', 'resunet', 'resunet++', 'seunet', 'resunet_se', 'attention_unet', 'multiresunet', 'unext_v1', 'unext_v2' # * Self-supervision: 'unet', 'resunet', 'resunet++', 'attention_unet', 'multiresunet', 'seunet', 'resunet_se', 'unetr', 'edsr', 'rcan', 'dfcan', 'wdsr', 'vit', 'mae', 'unext_v1', 'unext_v2' # * Classification: 'simple_cnn', 'vit', 'efficientnet_b[0-7]' (only 2D) # * Image to image: 'edsr', 'rcan', 'dfcan', 'wdsr', 'unet', 'resunet', 'resunet++', 'seunet', 'resunet_se', 'attention_unet', 'unetr', 'multiresunet', 'unext_v1', 'unext_v2' _C.MODEL.ARCHITECTURE = "unet" + # Architecture of the network. Possible values are: + # * 'patchgan' # Number of feature maps on each level of the network. _C.MODEL.FEATURE_MAPS = [16, 32, 64, 128, 256] # Values to make the dropout with. Set to 0 to prevent dropout. When using it with 'ViT' or 'unetr' @@ -1306,6 +1308,27 @@ def __init__(self, job_dir: str, job_identifier: str): # Whether to use a pretrained version of STUNet on ImageNet _C.MODEL.STUNET.PRETRAINED = False + # NafNet + _C.MODEL.NAFNET = CN() + # Base number of channels (width) used in the first layer and base levels. + _C.MODEL.NAFNET.WIDTH = 16 + # Number of NAFBlocks stacked at the bottleneck (deepest level). + _C.MODEL.NAFNET.MIDDLE_BLK_NUM = 12 + # Number of NAFBlocks assigned to each downsampling level of the encoder. + _C.MODEL.NAFNET.ENC_BLK_NUMS = [2, 2, 4, 8] + # Number of NAFBlocks assigned to each upsampling level of the decoder. + _C.MODEL.NAFNET.DEC_BLK_NUMS = [2, 2, 2, 2] + # Channel expansion factor for the depthwise convolution within the gating unit. + _C.MODEL.NAFNET.DW_EXPAND = 2 + # Expansion factor for the hidden layer within the feed-forward network. + _C.MODEL.NAFNET.FFN_EXPAND = 2 + # Discriminator architecture + _C.MODEL.NAFNET.ARCHITECTURE_D = "patchgan" + # Discriminator PATCHGAN + _C.MODEL.NAFNET.PATCHGAN = CN() + # Number of initial convolutional filters in the first layer of the discriminator. + _C.MODEL.NAFNET.PATCHGAN.BASE_FILTERS = 64 + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # Loss # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ @@ -1371,7 +1394,21 @@ def __init__(self, job_dir: str, job_identifier: str): _C.LOSS.CONTRAST.MEMORY_SIZE = 5000 _C.LOSS.CONTRAST.PROJ_DIM = 256 _C.LOSS.CONTRAST.PIXEL_UPD_FREQ = 10 - + + # Fine-grained GAN composition. Set any weight to 0.0 to disable that term. + # Used when LOSS.TYPE == "COMPOSED_GAN". + _C.LOSS.COMPOSED_GAN = CN() + # Weight for adversarial BCE term. + _C.LOSS.COMPOSED_GAN.LAMBDA_GAN = 1.0 + # Weight for L1 reconstruction term. + _C.LOSS.COMPOSED_GAN.LAMBDA_RECON = 10.0 + # Weight for MSE reconstruction term. + _C.LOSS.COMPOSED_GAN.DELTA_MSE = 0.0 + # Weight for VGG perceptual term. + _C.LOSS.COMPOSED_GAN.ALPHA_PERCEPTUAL = 0.0 + # Weight for SSIM term. + _C.LOSS.COMPOSED_GAN.GAMMA_SSIM = 1.0 + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # Training phase # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ @@ -1379,14 +1416,14 @@ def __init__(self, job_dir: str, job_identifier: str): _C.TRAIN.ENABLE = False # Enable verbosity _C.TRAIN.VERBOSE = False - # Optimizer to use. Possible values: "SGD", "ADAM" or "ADAMW" - _C.TRAIN.OPTIMIZER = "SGD" - # Learning rate - _C.TRAIN.LR = 1.0e-4 + # Optimizer(s) to use. Possible values: "SGD", "ADAM" or "ADAMW". + _C.TRAIN.OPTIMIZER = ["SGD"] + # Learning rate(s). + _C.TRAIN.LR = [1.0e-4] # Weight decay _C.TRAIN.W_DECAY = 0.02 # Coefficients used for computing running averages of gradient and its square. Used in ADAM and ADAMW optmizers - _C.TRAIN.OPT_BETAS = (0.9, 0.999) + _C.TRAIN.OPT_BETAS = [(0.9, 0.999)] # Batch size _C.TRAIN.BATCH_SIZE = 2 # If memory or # gpus is limited, use this variable to maintain the effective batch size, which is diff --git a/biapy/data/generators/__init__.py b/biapy/data/generators/__init__.py index ca072706b..783009fa2 100644 --- a/biapy/data/generators/__init__.py +++ b/biapy/data/generators/__init__.py @@ -246,7 +246,7 @@ def create_train_val_augmentors( dic["zflip"] = cfg.AUGMENTOR.ZFLIP if cfg.PROBLEM.TYPE == "INSTANCE_SEG": dic["instance_problem"] = True - elif cfg.PROBLEM.TYPE == "DENOISING": + elif cfg.PROBLEM.TYPE == "DENOISING" and cfg.MODEL.ARCHITECTURE != 'nafnet': dic["n2v"] = True dic["n2v_perc_pix"] = cfg.PROBLEM.DENOISING.N2V_PERC_PIX dic["n2v_manipulator"] = cfg.PROBLEM.DENOISING.N2V_MANIPULATOR @@ -293,7 +293,7 @@ def create_train_val_augmentors( ) if cfg.PROBLEM.TYPE == "INSTANCE_SEG": dic["instance_problem"] = True - elif cfg.PROBLEM.TYPE == "DENOISING": + elif cfg.PROBLEM.TYPE == "DENOISING" and cfg.MODEL.ARCHITECTURE != 'nafnet': dic["n2v"] = True dic["n2v_perc_pix"] = cfg.PROBLEM.DENOISING.N2V_PERC_PIX dic["n2v_manipulator"] = cfg.PROBLEM.DENOISING.N2V_MANIPULATOR diff --git a/biapy/engine/__init__.py b/biapy/engine/__init__.py index 8991943f7..700058c4a 100644 --- a/biapy/engine/__init__.py +++ b/biapy/engine/__init__.py @@ -21,7 +21,7 @@ def prepare_optimizer( cfg: CN, model_without_ddp: nn.Module | nn.parallel.DistributedDataParallel, steps_per_epoch: int, -) -> Tuple[Optimizer, Scheduler | None]: +) -> Tuple[list[Optimizer], list[Scheduler | None]]: """ Create and configure the optimizer and learning rate scheduler for the given model. @@ -40,50 +40,76 @@ def prepare_optimizer( Returns ------- - optimizer : Optimizer - Configured optimizer for the model. - lr_scheduler : Scheduler or None - Configured learning rate scheduler, or None if not specified. + optimizers : List[Optimizer] + Configured optimizers for the models. + lr_schedulers : List[Scheduler | None] + Configured learning rate schedulers, or None if not specified. """ - lr = cfg.TRAIN.LR if cfg.TRAIN.LR_SCHEDULER.NAME != "warmupcosine" else cfg.TRAIN.LR_SCHEDULER.MIN_LR - opt_args = {} - if cfg.TRAIN.OPTIMIZER in ["ADAM", "ADAMW"]: - opt_args["betas"] = cfg.TRAIN.OPT_BETAS - optimizer = timm.optim.create_optimizer_v2( - model_without_ddp, - opt=cfg.TRAIN.OPTIMIZER, - lr=lr, - weight_decay=cfg.TRAIN.W_DECAY, - **opt_args, - ) - print(optimizer) - - # Learning rate schedulers - lr_scheduler = None - if cfg.TRAIN.LR_SCHEDULER.NAME != "": - if cfg.TRAIN.LR_SCHEDULER.NAME == "reduceonplateau": - lr_scheduler = ReduceLROnPlateau( - optimizer, - patience=cfg.TRAIN.LR_SCHEDULER.REDUCEONPLATEAU_PATIENCE, - factor=cfg.TRAIN.LR_SCHEDULER.REDUCEONPLATEAU_FACTOR, - min_lr=cfg.TRAIN.LR_SCHEDULER.MIN_LR, - ) - elif cfg.TRAIN.LR_SCHEDULER.NAME == "warmupcosine": - lr_scheduler = WarmUpCosineDecayScheduler( - lr=cfg.TRAIN.LR, - min_lr=cfg.TRAIN.LR_SCHEDULER.MIN_LR, - warmup_epochs=cfg.TRAIN.LR_SCHEDULER.WARMUP_COSINE_DECAY_EPOCHS, - epochs=cfg.TRAIN.EPOCHS, - ) - elif cfg.TRAIN.LR_SCHEDULER.NAME == "onecycle": - lr_scheduler = OneCycleLR( - optimizer, - cfg.TRAIN.LR, - epochs=cfg.TRAIN.EPOCHS, - steps_per_epoch=steps_per_epoch, - ) - - return optimizer, lr_scheduler + + optimizers = [] + lr_schedulers = [] + + if hasattr(model_without_ddp, 'discriminator') and model_without_ddp.discriminator is not None: + param_groups = [ + # Generator + [p for n, p in model_without_ddp.named_parameters() if not n.startswith("discriminator.")], # should this be and p.requires_grad, same below? + # Discriminator + [p for p in model_without_ddp.discriminator.parameters()] + ] + else: + param_groups = [[p for p in model_without_ddp.parameters()]] + + ## Not quite sure if this is the best place to do this + if len(cfg.TRAIN.OPTIMIZER) != len(param_groups): + raise ValueError( + f"Configuration mismatch: You requested {len(cfg.TRAIN.OPTIMIZER)} optimizers, " + f"but the model has {len(param_groups)} parameter group(s). " + f"Check your TRAIN.OPTIMIZER list in the config." + ) + + for i in range(len(cfg.TRAIN.OPTIMIZER)): + lr = cfg.TRAIN.LR if cfg.TRAIN.LR_SCHEDULER.NAME != "warmupcosine" else [cfg.TRAIN.LR_SCHEDULER.MIN_LR] * len(cfg.TRAIN.LR) + opt_args = {} + if cfg.TRAIN.OPTIMIZER[i] in ["ADAM", "ADAMW"]: + opt_args["betas"] = cfg.TRAIN.OPT_BETAS[i] if i < len(cfg.TRAIN.OPT_BETAS) else cfg.TRAIN.OPT_BETAS[0] + optimizer = timm.optim.create_optimizer_v2( + param_groups[i], + opt=cfg.TRAIN.OPTIMIZER[i], + lr=lr[i], + weight_decay=cfg.TRAIN.W_DECAY, + **opt_args, + ) + print(optimizer) + optimizers.append(optimizer) + + # Learning rate schedulers + lr_scheduler = None + if cfg.TRAIN.LR_SCHEDULER.NAME != "": + if cfg.TRAIN.LR_SCHEDULER.NAME == "reduceonplateau": + lr_scheduler = ReduceLROnPlateau( + optimizer, + patience=cfg.TRAIN.LR_SCHEDULER.REDUCEONPLATEAU_PATIENCE, + factor=cfg.TRAIN.LR_SCHEDULER.REDUCEONPLATEAU_FACTOR, + min_lr=cfg.TRAIN.LR_SCHEDULER.MIN_LR, + ) + elif cfg.TRAIN.LR_SCHEDULER.NAME == "warmupcosine": + lr_scheduler = WarmUpCosineDecayScheduler( + lr=cfg.TRAIN.LR[i], + min_lr=cfg.TRAIN.LR_SCHEDULER.MIN_LR, + warmup_epochs=cfg.TRAIN.LR_SCHEDULER.WARMUP_COSINE_DECAY_EPOCHS, + epochs=cfg.TRAIN.EPOCHS, + ) + elif cfg.TRAIN.LR_SCHEDULER.NAME == "onecycle": + lr_scheduler = OneCycleLR( + optimizer, + cfg.TRAIN.LR[i], + epochs=cfg.TRAIN.EPOCHS, + steps_per_epoch=steps_per_epoch, + ) + + lr_schedulers.append(lr_scheduler) + + return optimizers, lr_schedulers def build_callbacks(cfg: CN) -> EarlyStopping | None: diff --git a/biapy/engine/base_workflow.py b/biapy/engine/base_workflow.py index 4f28f200d..68eafa849 100644 --- a/biapy/engine/base_workflow.py +++ b/biapy/engine/base_workflow.py @@ -1009,6 +1009,19 @@ def train(self): self.plot_values["loss"].append(train_stats["loss"]) if self.val_generator: self.plot_values["val_loss"].append(test_stats["loss"]) + extra_loss_keys = [k for k in train_stats if "loss" in k and k != "loss"] + for loss_key in extra_loss_keys: + val_loss_key = f"val_{loss_key}" + + if loss_key not in self.plot_values: + self.plot_values[loss_key] = [] + if self.val_generator: + self.plot_values[val_loss_key] = [] + + # Append the values + self.plot_values[loss_key].append(train_stats[loss_key]) + if self.val_generator: + self.plot_values[val_loss_key].append(test_stats.get(loss_key, 0.0)) for i in range(len(self.train_metric_names)): self.plot_values[self.train_metric_names[i]].append(train_stats[self.train_metric_names[i]]) if self.val_generator: diff --git a/biapy/engine/check_configuration.py b/biapy/engine/check_configuration.py index f4b5d5b6a..a6da41bc6 100644 --- a/biapy/engine/check_configuration.py +++ b/biapy/engine/check_configuration.py @@ -1242,7 +1242,7 @@ def sort_key(item): ], "LOSS.CLASS_REBALANCE not in ['none', 'auto'] for INSTANCE_SEG workflow" elif cfg.PROBLEM.TYPE == "DENOISING": loss = "MSE" if cfg.LOSS.TYPE == "" else cfg.LOSS.TYPE - assert loss == "MSE", "LOSS.TYPE must be 'MSE'" + assert loss in ["MSE", "COMPOSED_GAN"], "LOSS.TYPE must be in ['MSE', 'COMPOSED_GAN'] for DENOISING" elif cfg.PROBLEM.TYPE == "CLASSIFICATION": loss = "CE" if cfg.LOSS.TYPE == "" else cfg.LOSS.TYPE assert loss == "CE", "LOSS.TYPE must be 'CE'" @@ -1797,12 +1797,19 @@ def sort_key(item): #### Denoising #### elif cfg.PROBLEM.TYPE == "DENOISING": - if cfg.DATA.TEST.LOAD_GT: - raise ValueError( - "Denoising is made in an unsupervised way so there is no ground truth required. Disable 'DATA.TEST.LOAD_GT'" - ) - if not check_value(cfg.PROBLEM.DENOISING.N2V_PERC_PIX): - raise ValueError("PROBLEM.DENOISING.N2V_PERC_PIX not in [0, 1] range") + if cfg.PROBLEM.DENOISING.LOAD_GT_DATA or cfg.LOSS.TYPE == "COMPOSED_GAN": + if not cfg.DATA.TRAIN.GT_PATH and not cfg.DATA.TRAIN.INPUT_ZARR_MULTIPLE_DATA: + raise ValueError( + "Supervised denoising (e.g., with COMPOSED_GAN or LOAD_GT_DATA=True) " + "requires ground truth. 'DATA.TRAIN.GT_PATH' must be provided." + ) + else: + if cfg.DATA.TEST.LOAD_GT: + raise ValueError( + "Denoising is made in an unsupervised way so there is no ground truth required. Disable 'DATA.TEST.LOAD_GT'" + ) + if not check_value(cfg.PROBLEM.DENOISING.N2V_PERC_PIX): + raise ValueError("PROBLEM.DENOISING.N2V_PERC_PIX not in [0, 1] range") if cfg.MODEL.SOURCE == "torchvision": raise ValueError("'MODEL.SOURCE' as 'torchvision' is not available in denoising workflow") @@ -2341,6 +2348,7 @@ def sort_key(item): "hrnet48", "hrnet64", "stunet", + "nafnet", ], "MODEL.ARCHITECTURE not in ['unet', 'resunet', 'resunet++', 'attention_unet', 'multiresunet', 'seunet', 'simple_cnn', 'efficientnet_b[0-7]', 'unetr', 'edsr', 'rcan', 'dfcan', 'wdsr', 'vit', 'mae', 'unext_v1', 'unext_v2', 'hrnet18', 'hrnet32', 'hrnet48', 'hrnet64', 'stunet']" if ( model_arch @@ -2514,6 +2522,7 @@ def sort_key(item): "hrnet48", "hrnet64", "stunet", + "nafnet", ]: raise ValueError( "Architectures available for {} are: ['unet', 'resunet', 'resunet++', 'seunet', 'attention_unet', 'resunet_se', 'unetr', 'multiresunet', 'unext_v1', 'unext_v2', 'hrnet18', 'hrnet32', 'hrnet48', 'hrnet64', 'stunet']".format( @@ -2714,11 +2723,26 @@ def sort_key(item): assert cfg.MODEL.OUT_CHECKPOINT_FORMAT in ["pth", "safetensors"], "MODEL.OUT_CHECKPOINT_FORMAT not in ['pth', 'safetensors']" ### Train ### - assert cfg.TRAIN.OPTIMIZER in [ - "SGD", - "ADAM", - "ADAMW", - ], "TRAIN.OPTIMIZER not in ['SGD', 'ADAM', 'ADAMW']" + if not isinstance(cfg.TRAIN.OPTIMIZER, list): + raise ValueError("'TRAIN.OPTIMIZER' must be a list") + if not isinstance(cfg.TRAIN.LR, list): + raise ValueError("'TRAIN.LR' must be a list") + if not isinstance(cfg.TRAIN.OPT_BETAS, list): + raise ValueError("'TRAIN.OPT_BETAS' must be a list") + if len(cfg.TRAIN.OPTIMIZER) != len(cfg.TRAIN.LR): + raise ValueError("'TRAIN.OPTIMIZER' and 'TRAIN.LR' must have the same length") + print(cfg.TRAIN.OPT_BETAS) + print(len(cfg.TRAIN.OPT_BETAS)) + if len(cfg.TRAIN.OPT_BETAS) not in [1, len(cfg.TRAIN.OPTIMIZER)]: + raise ValueError("'TRAIN.OPT_BETAS' must have length 1 or match 'TRAIN.OPTIMIZER' length") + + for beta_pair in cfg.TRAIN.OPT_BETAS: + if not isinstance(beta_pair, (list, tuple)) or len(beta_pair) != 2: + raise ValueError("Each entry in 'TRAIN.OPT_BETAS' must be a tuple/list of length 2") + + for opt in cfg.TRAIN.OPTIMIZER: + if opt not in ["SGD", "ADAM", "ADAMW"]: + raise ValueError("'TRAIN.OPTIMIZER' values must be in ['SGD', 'ADAM', 'ADAMW']") if cfg.TRAIN.ENABLE and cfg.TRAIN.LR_SCHEDULER.NAME != "": if cfg.TRAIN.LR_SCHEDULER.NAME not in [ diff --git a/biapy/engine/denoising.py b/biapy/engine/denoising.py index 9cac8d574..a22bcb20a 100644 --- a/biapy/engine/denoising.py +++ b/biapy/engine/denoising.py @@ -26,7 +26,7 @@ from biapy.engine.base_workflow import Base_Workflow from biapy.data.data_manipulation import save_tif from biapy.utils.misc import to_pytorch_format, is_main_process, MetricLogger -from biapy.engine.metrics import n2v_loss_mse, loss_encapsulation +from biapy.engine.metrics import n2v_loss_mse, loss_encapsulation, ComposedGANLoss class Denoising_Workflow(Base_Workflow): @@ -166,6 +166,8 @@ def define_metrics(self): # print("Overriding 'LOSS.TYPE' to set it to N2V loss (masked MSE)") if self.cfg.LOSS.TYPE == "MSE": self.loss = loss_encapsulation(n2v_loss_mse) + elif self.cfg.LOSS.TYPE == "COMPOSED_GAN": + self.loss = ComposedGANLoss(cfg=self.cfg, device=self.device) super().define_metrics() @@ -232,7 +234,11 @@ def metric_calculation( with torch.no_grad(): for i, metric in enumerate(list_to_use): - val = metric(_output.contiguous(), _targets[:, _output.shape[1]:].contiguous()) + if _targets.shape[1] == _output.shape[1]: + target_for_metric = _targets.contiguous() + else: + target_for_metric = _targets[:, _output.shape[1]:].contiguous() + val = metric(_output.contiguous(), target_for_metric) val = val.item() if not torch.isnan(val) else 0 out_metrics[list_names_to_use[i]] = val diff --git a/biapy/engine/metrics.py b/biapy/engine/metrics.py index 32dab4283..09d4df49a 100644 --- a/biapy/engine/metrics.py +++ b/biapy/engine/metrics.py @@ -18,6 +18,8 @@ import torch.nn.functional as F import torch.nn as nn from typing import Optional, List, Tuple, Dict +from torchvision import transforms +from torchvision.models import vgg16, VGG16_Weights def jaccard_index_numpy(y_true, y_pred): """ @@ -2162,4 +2164,224 @@ def forward( loss = loss / B iou = iou / B - return loss + prediction.sum() * 0, float(iou), "IoU" # keep graph identical to originals \ No newline at end of file + return loss + prediction.sum() * 0, float(iou), "IoU" # keep graph identical to originals + +class VGGLoss(nn.Module): + """Perceptual loss based on VGG16 feature activations. + + This loss compares intermediate VGG feature maps of prediction and target + images using an L1 distance. It is commonly used as a perceptual term in + image-to-image GAN training. + + Notes + ----- + - Uses pretrained ``torchvision.models.vgg16`` features up to layer ``:16``. + - Supports both 2D `(B, C, H, W)` and 3D `(B, C, D, H, W)` tensors. + For 3D inputs, depth is folded into batch to reuse 2D VGG. + - Single-channel inputs are replicated to 3 channels before VGG. + + References + ---------- + - Johnson et al., "Perceptual Losses for Real-Time Style Transfer and + Super-Resolution", ECCV 2016. + https://arxiv.org/abs/1603.08155 + - Implementation adapted for this project from: + https://github.com/GolpedeRemo37/NafNet-in-AI4Life-Microscopy-Supervised-Denoising-Challenge + """ + + def __init__(self, device): + """Initialize VGG perceptual loss. + + Parameters + ---------- + device : torch.device + Device where VGG features and loss operations are executed. + """ + super().__init__() + self.vgg = vgg16(weights=VGG16_Weights.IMAGENET1K_V1).features[:16].eval().to(device) + for param in self.vgg.parameters(): + param.requires_grad = False + self.loss = nn.L1Loss() + self.preprocess = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) + + def forward(self, pred, target): + """Compute perceptual distance between prediction and target. + + Parameters + ---------- + pred : torch.Tensor or dict + Predicted image tensor. If dict, prediction is taken from ``pred['pred']``. + target : torch.Tensor or dict + Target image tensor. If dict, target is taken from ``target['pred']``. + + Returns + ------- + torch.Tensor + Scalar perceptual loss value (L1 over VGG features). + """ + if isinstance(pred, dict): + pred = pred["pred"] + if isinstance(target, dict): + target = target["pred"] + + # If 3D, fold Depth (dim 2) into Batch (dim 0) -> (B*D, C, H, W) + if pred.dim() == 5: + B, C, D, H, W = pred.shape + pred = pred.permute(0, 2, 1, 3, 4).reshape(B * D, C, H, W) + target = target.permute(0, 2, 1, 3, 4).reshape(B * D, C, H, W) + + # 2D behavior remains identical + if pred.shape[1] == 1: + pred = pred.repeat(1, 3, 1, 1) + target = target.repeat(1, 3, 1, 1) + + pred = self.preprocess(pred) + target = self.preprocess(target) + pred_vgg = self.vgg(pred) + target_vgg = self.vgg(target) + return self.loss(pred_vgg, target_vgg) + +class ComposedGANLoss(nn.Module): + """Weighted composite loss for generator and discriminator training. + + This class combines multiple objectives for GAN-based image restoration: + + - Adversarial BCE term + - L1 reconstruction term + - MSE reconstruction term + - VGG perceptual term + - SSIM term + + Each term is controlled by configuration weights under + ``LOSS.COMPOSED_GAN``. Heavy components (VGG/SSIM modules) are created only + when their weight is greater than zero. + + References + ---------- + - Isola et al., "Image-to-Image Translation with Conditional Adversarial + Networks", CVPR 2017 (pix2pix). + https://arxiv.org/abs/1611.07004 + - Generator family inspiration (NAFNet/NAFSSR): + Chu et al., "NAFSSR: Stereo Image Super-Resolution Using NAFNet", + CVPR Workshops 2022. + https://openaccess.thecvf.com/content/CVPR2022W/NTIRE/html/Chu_NAFSSR_Stereo_Image_Super-Resolution_Using_NAFNet_CVPRW_2022_paper.html + - Structural/perceptual metrics are implemented with torchmetrics + (e.g., ``torchmetrics.image.StructuralSimilarityIndexMeasure``). + - Implementation adapted for this project from: + https://github.com/GolpedeRemo37/NafNet-in-AI4Life-Microscopy-Supervised-Denoising-Challenge + """ + + def __init__(self, cfg, device): + """Initialize composed GAN loss from configuration. + + Parameters + ---------- + cfg : yacs.config.CfgNode + Global configuration node. Uses ``cfg.LOSS.COMPOSED_GAN`` weights. + device : torch.device + Device where loss terms are computed. + """ + super().__init__() + self.device = device + self.w_gan = cfg.LOSS.COMPOSED_GAN.LAMBDA_GAN + self.w_l1 = cfg.LOSS.COMPOSED_GAN.LAMBDA_RECON + self.w_vgg = cfg.LOSS.COMPOSED_GAN.ALPHA_PERCEPTUAL + self.w_ssim = cfg.LOSS.COMPOSED_GAN.GAMMA_SSIM + self.w_mse = cfg.LOSS.COMPOSED_GAN.DELTA_MSE + + # Dont load the vgg if not + if self.w_vgg > 0: + self.vgg = VGGLoss(device) + if self.w_ssim > 0: + self.ssim = StructuralSimilarityIndexMeasure(data_range=1.0).to(device) + + # Standard lightweight losses are always initialized + self.l1 = nn.L1Loss() + self.mse = nn.MSELoss() + self.bce = nn.BCEWithLogitsLoss() + + def forward_generator(self, pred, target, d_fake): + """Compute weighted generator loss. + + Parameters + ---------- + pred : torch.Tensor or dict + Generator prediction. If dict, reads ``pred['pred']``. + target : torch.Tensor or dict + Ground-truth target. If dict, reads ``target['pred']``. + d_fake : torch.Tensor + Discriminator logits for generated samples. + + Returns + ------- + torch.Tensor + Scalar generator loss as weighted sum of active terms. + """ + # Dict extraction + if isinstance(pred, dict): pred = pred["pred"] + if isinstance(target, dict): target = target["pred"] + + # NaN Band-aid + pred = torch.nan_to_num(pred, nan=0.0, posinf=1.0, neginf=-1.0) + target = torch.nan_to_num(target, nan=0.0, posinf=1.0, neginf=-1.0) + + total_loss = torch.tensor(0.0, device=self.device) + + # 2. Dynamically build the loss based on config weights + if self.w_l1 > 0: + total_loss += self.w_l1 * self.l1(pred, target) + + if self.w_mse > 0: + total_loss += self.w_mse * self.mse(pred, target) + + if self.w_vgg > 0: + total_loss += self.w_vgg * self.vgg(pred, target) + + if self.w_ssim > 0: + # SSIM requires 4D tensors. Safely route 3D to 2D slices. + if pred.dim() == 5: + B, C, D, H, W = pred.shape + pred_ssim = pred.permute(0, 2, 1, 3, 4).reshape(B * D, C, H, W) + target_ssim = target.permute(0, 2, 1, 3, 4).reshape(B * D, C, H, W) + total_loss += self.w_ssim * (1.0 - self.ssim(pred_ssim, target_ssim)) + else: + total_loss += self.w_ssim * (1.0 - self.ssim(pred, target)) + + if self.w_gan > 0: + total_loss += self.w_gan * self.bce(d_fake, torch.ones_like(d_fake)) + + # NaN Safety Check + if torch.isnan(total_loss): + print("Warning: NaN detected in generator loss. Returning zero loss.") + total_loss = torch.tensor(0.0, requires_grad=True).to(self.device) + + return total_loss + + def forward_discriminator(self, d_real, d_fake): + """Compute discriminator adversarial loss. + + Uses BCE with one-sided label smoothing for real logits. + + Parameters + ---------- + d_real : torch.Tensor + Discriminator logits for real samples. + d_fake : torch.Tensor + Discriminator logits for generated samples. + + Returns + ------- + torch.Tensor + Scalar discriminator loss. + """ + # Calculate Adversarial Loss for Discriminator + real_loss = self.bce(d_real, torch.full_like(d_real, 0.9)) # Label smoothing (0.9 instead of 1.0) + fake_loss = self.bce(d_fake, torch.zeros_like(d_fake)) + total_loss = (real_loss + fake_loss) / 2.0 + + # NaN Safety Check + if torch.isnan(total_loss): + print("Warning: NaN detected in discriminator loss. Returning zero loss.") + total_loss = torch.tensor(0.0, requires_grad=True).to(self.device) + + return total_loss \ No newline at end of file diff --git a/biapy/engine/train_engine.py b/biapy/engine/train_engine.py index e4b05e3b9..54b689e2e 100644 --- a/biapy/engine/train_engine.py +++ b/biapy/engine/train_engine.py @@ -29,11 +29,11 @@ def train_one_epoch( metric_function: Callable, prepare_targets: Callable, data_loader: DataLoader, - optimizer: Optimizer, + optimizer: list[Optimizer], device: torch.device, epoch: int, log_writer: Optional[TensorboardLogger] = None, - lr_scheduler: Optional[Scheduler] = None, + lr_scheduler: Optional[list[Optional[Scheduler]]] = None, verbose: bool = False, memory_bank: Optional[MemoryBank] = None, total_iters: int=0, @@ -61,7 +61,7 @@ def train_one_epoch( Function to prepare targets for loss/metrics. data_loader : DataLoader Training data loader. - optimizer : Optimizer + optimizer : List[Optimizer] Optimizer for model parameters. device : torch.device Device to use. @@ -69,7 +69,7 @@ def train_one_epoch( Current epoch number. log_writer : TensorboardLogger, optional Logger for TensorBoard. - lr_scheduler : Scheduler, optional + lr_scheduler : List[Scheduler], optional Learning rate scheduler. verbose : bool, optional Verbosity flag. @@ -89,16 +89,21 @@ def train_one_epoch( """ # Switch to training mode model.train(True) + has_discriminator = hasattr(model, "discriminator") and model.discriminator is not None + lr_scheduler = [None] * len(optimizer) if lr_scheduler is None else lr_scheduler # Ensure correct order of each epoch info by adding loss first metric_logger = MetricLogger(delimiter=" ", verbose=verbose) - metric_logger.add_meter("loss", SmoothedValue()) + for i in range(len(optimizer)): + loss_name = "loss" if i == 0 else f"loss_{i}" + metric_logger.add_meter(loss_name, SmoothedValue()) # Set up the header for logging header = "Epoch: [{}]".format(epoch + 1) print_freq = 10 - optimizer.zero_grad() + for opt in optimizer: + opt.zero_grad() for step, (batch, targets) in enumerate(metric_logger.log_every(data_loader, print_freq, header)): @@ -107,10 +112,10 @@ def train_one_epoch( if ( epoch % cfg.TRAIN.ACCUM_ITER == 0 and cfg.TRAIN.LR_SCHEDULER.NAME == "warmupcosine" - and lr_scheduler - and isinstance(lr_scheduler, WarmUpCosineDecayScheduler) ): - lr_scheduler.adjust_learning_rate(optimizer, step / len(data_loader) + epoch) + for sched, opt in zip(lr_scheduler, optimizer): + if sched and isinstance(sched, WarmUpCosineDecayScheduler): + sched.adjust_learning_rate(opt, step / len(data_loader) + epoch) # Gather inputs targets = prepare_targets(targets, batch) @@ -125,7 +130,21 @@ def train_one_epoch( outputs = model_call_func(batch, is_train=True) # Loss function call - if memory_bank is not None: + losses = [] + if has_discriminator and len(optimizer) > 1: + fake_img = outputs["pred"] if isinstance(outputs, dict) else outputs + fake_img = torch.clamp(fake_img, 0, 1) + + d_fake_for_g = model.discriminator(fake_img) + loss_g = loss_function.forward_generator(fake_img, targets, d_fake_for_g) + losses.append(loss_g) + + d_real = model.discriminator(targets) + d_fake = model.discriminator(fake_img.detach()) + loss_d = loss_function.forward_discriminator(d_real, d_fake) + losses.append(loss_d) + + elif memory_bank is not None: if total_iters + step >= contrast_warmup_iters: with_embed = True else: @@ -144,20 +163,23 @@ def train_one_epoch( memory_bank.dequeue_and_enqueue( outputs['key'], targets.detach(), ) + losses.append(loss) else: loss = loss_function(outputs, targets) + losses.append(loss) # Separate metric if precalculated inside the loss (e.g. Embedding loss) precalculated_metric, precalculated_metric_name = None, None - if isinstance(loss, tuple): - precalculated_metric = loss[1] - precalculated_metric_name = loss[2] - loss = loss[0] + if isinstance(losses[0], tuple): + precalculated_metric = losses[0][1] + precalculated_metric_name = losses[0][2] + losses[0] = losses[0][0] - loss_value = loss.item() - if not math.isfinite(loss_value): - print("Loss is {}, stopping training".format(loss_value)) - sys.exit(1) + for l_val in losses: + loss_value = l_val.item() + if not math.isfinite(loss_value): + print("Loss is {}, stopping training".format(loss_value)) + sys.exit(1) # Calculate the metrics if precalculated_metric is None: @@ -166,32 +188,43 @@ def train_one_epoch( metric_logger.meters[precalculated_metric_name].update(precalculated_metric) # Forward pass scaling the loss - loss /= cfg.TRAIN.ACCUM_ITER if (step + 1) % cfg.TRAIN.ACCUM_ITER == 0: - loss.backward() - optimizer.step() # update weight - optimizer.zero_grad() - if lr_scheduler and isinstance(lr_scheduler, OneCycleLR) and cfg.TRAIN.LR_SCHEDULER.NAME == "onecycle": - lr_scheduler.step() + for i, (opt, loss_tensor) in enumerate(zip(optimizer, losses)): + loss_tensor = loss_tensor / cfg.TRAIN.ACCUM_ITER + + if has_discriminator and i == 1: + opt.zero_grad() + + loss_tensor.backward() + opt.step() # update weight + opt.zero_grad() + + if lr_scheduler[i] and isinstance(lr_scheduler[i], OneCycleLR) and cfg.TRAIN.LR_SCHEDULER.NAME == "onecycle": + lr_scheduler[i].step() if device.type != "cpu": getattr(torch, device.type).synchronize() # Update loss in loggers - metric_logger.update(loss=loss_value) - loss_value_reduce = all_reduce_mean(loss_value) - if log_writer: - log_writer.update(loss=loss_value_reduce, head="loss") + for i, loss_tensor in enumerate(losses): + loss_name = "loss" if i == 0 else f"loss_{i}" + val = loss_tensor.item() * cfg.TRAIN.ACCUM_ITER + metric_logger.update(**{loss_name: val}) + loss_value_reduce = all_reduce_mean(val) + if log_writer: + log_writer.update(head="loss", **{loss_name: loss_value_reduce}) # Update lr in loggers - max_lr = 0.0 - for group in optimizer.param_groups: - max_lr = max(max_lr, group["lr"]) - if step == 0: - metric_logger.add_meter("lr", SmoothedValue(window_size=1, fmt="{value:.6f}")) - metric_logger.update(lr=max_lr) - if log_writer: - log_writer.update(lr=max_lr, head="opt") + for i, opt in enumerate(optimizer): + lr_name = "lr" if i == 0 else f"lr_{i}" + max_lr = 0.0 + for group in opt.param_groups: + max_lr = max(max_lr, group["lr"]) + if step == 0: + metric_logger.add_meter(lr_name, SmoothedValue(window_size=1, fmt="{value:.6f}")) + metric_logger.update(**{lr_name: max_lr}) + if log_writer: + log_writer.update(head="opt", **{lr_name: max_lr}) # Gather the stats from all processes metric_logger.synchronize_between_processes() @@ -209,7 +242,7 @@ def evaluate( prepare_targets: Callable, epoch: int, data_loader: DataLoader, - lr_scheduler: Optional[Scheduler] = None, + lr_scheduler: Optional[list[Optional[Scheduler]]] = None, memory_bank: Optional[MemoryBank] = None, ): """ @@ -246,9 +279,13 @@ def evaluate( dict Dictionary of averaged metrics for the validation set. """ + has_discriminator = hasattr(model, "discriminator") and model.discriminator is not None # Ensure correct order of each epoch info by adding loss first metric_logger = MetricLogger(delimiter=" ") - metric_logger.add_meter("loss", SmoothedValue()) + num_losses = 2 if has_discriminator else 1 + for i in range(num_losses): + loss_name = "loss" if i == 0 else f"loss_{i}" + metric_logger.add_meter(loss_name, SmoothedValue()) header = "Epoch: [{}]".format(epoch + 1) # Switch to evaluation mode @@ -261,10 +298,24 @@ def evaluate( targets = prepare_targets(targets, images) # Pass the images through the model - outputs = model_call_func(images, is_train=True) + outputs = model_call_func(images, is_train=True) # Im not Undertanding why is this True? # Loss function call - if memory_bank is not None: + losses = [] + if has_discriminator: + fake_img = outputs["pred"] if isinstance(outputs, dict) else outputs + fake_img = torch.clamp(fake_img, 0, 1) + + d_fake_for_g = model.discriminator(fake_img) + loss_g = loss_function.forward_generator(fake_img, targets, d_fake_for_g) + losses.append(loss_g) + + d_real = model.discriminator(targets) + d_fake = model.discriminator(fake_img.detach()) + loss_d = loss_function.forward_discriminator(d_real, d_fake) + losses.append(loss_d) + + elif memory_bank is not None: with_embed = False outputs = { @@ -274,22 +325,24 @@ def evaluate( 'pixel_queue': memory_bank.pixel_queue, 'segment_queue': memory_bank.segment_queue, } - loss = loss_function(outputs, targets, with_embed=with_embed) + losses.append(loss) else: loss = loss_function(outputs, targets) + losses.append(loss) # Separate metric if precalculated inside the loss (e.g. Embedding loss) precalculated_metric, precalculated_metric_name = None, None - if isinstance(loss, tuple): - precalculated_metric = loss[1] - precalculated_metric_name = loss[2] - loss = loss[0] - - loss_value = loss.item() - if not math.isfinite(loss_value): - print("Loss is {}, stopping training".format(loss_value)) - sys.exit(1) + if isinstance(losses[0], tuple): + precalculated_metric = losses[0][1] + precalculated_metric_name = losses[0][2] + losses[0] = losses[0][0] + + for l_val in losses: + loss_value = l_val.item() + if not math.isfinite(loss_value): + print("Loss is {}, stopping training".format(loss_value)) + sys.exit(1) # Calculate the metrics if precalculated_metric is not None: @@ -298,7 +351,9 @@ def evaluate( metric_function(outputs, targets, metric_logger=metric_logger) # Update loss in loggers - metric_logger.update(loss=loss) + for i, loss_tensor in enumerate(losses): + loss_name = "loss" if i == 0 else f"loss_{i}" + metric_logger.update(**{loss_name: loss_tensor.item()}) # Gather the stats from all processes metric_logger.synchronize_between_processes() @@ -306,10 +361,10 @@ def evaluate( print("[Val] averaged stats:", metric_logger) # Apply reduceonplateau scheduler if the global validation has been reduced - if ( - lr_scheduler - and isinstance(lr_scheduler, ReduceLROnPlateau) - and cfg.TRAIN.LR_SCHEDULER.NAME == "reduceonplateau" - ): - lr_scheduler.step(metric_logger.meters["loss"].global_avg, epoch=epoch) + if lr_scheduler and cfg.TRAIN.LR_SCHEDULER.NAME == "reduceonplateau": + for i, sched in enumerate(lr_scheduler): + if sched and isinstance(sched, ReduceLROnPlateau): + loss_name = "loss" if i == 0 else f"loss_{i}" + sched.step(metric_logger.meters[loss_name].global_avg, epoch=epoch) + return {k: meter.global_avg for k, meter in metric_logger.meters.items()} diff --git a/biapy/models/__init__.py b/biapy/models/__init__.py index 5a0f98b23..e34740190 100644 --- a/biapy/models/__init__.py +++ b/biapy/models/__init__.py @@ -366,6 +366,22 @@ def build_model( ) model = MaskedAutoencoderViT(**args) # type: ignore callable_model = MaskedAutoencoderViT # type: ignore + elif modelname == "nafnet": + args = dict( + img_channel=cfg.DATA.PATCH_SIZE[-1], + width=cfg.MODEL.NAFNET.WIDTH, + middle_blk_num=cfg.MODEL.NAFNET.MIDDLE_BLK_NUM, + enc_blk_nums=cfg.MODEL.NAFNET.ENC_BLK_NUMS, + dec_blk_nums=cfg.MODEL.NAFNET.DEC_BLK_NUMS, + drop_out_rate=cfg.MODEL.DROPOUT_VALUES[0], + dw_expand=cfg.MODEL.NAFNET.DW_EXPAND, + ffn_expand=cfg.MODEL.NAFNET.FFN_EXPAND, + discriminator_arch=cfg.MODEL.NAFNET.ARCHITECTURE_D, + patchgan_base_filters=cfg.MODEL.NAFNET.PATCHGAN.BASE_FILTERS, + ) + callable_model = NAFNet # type: ignore + model = callable_model(**args) # type: ignore + # Check the network created model.to(device) if cfg.PROBLEM.NDIM == "2D": @@ -405,7 +421,6 @@ def build_model( return model, str(callable_model.__name__), collected_sources, all_import_lines, scanned_files, args, network_stride # type: ignore - def init_embedding_output(model: nn.Module, n_sigma: int = 2): """ Initialize the output layer of the model for embedding. diff --git a/biapy/models/nafnet.py b/biapy/models/nafnet.py new file mode 100644 index 000000000..38b7f0cc0 --- /dev/null +++ b/biapy/models/nafnet.py @@ -0,0 +1,359 @@ +"""NAFNet model components and GAN discriminator builder utilities. + +This module provides: + +1. Lightweight building blocks (`SimpleGate`, `LayerNorm2d`, `NAFBlock`) used + by NAFNet. +2. The `NAFNet` encoder-decoder model for image restoration / image-to-image + workflows. +3. A discriminator builder function used by GAN-based training setups in BiaPy. + +Compared with traditional restoration backbones, NAFNet simplifies nonlinear +design while preserving strong reconstruction quality via gated depthwise blocks +and residual scaling. + +Reference +--------- +`Simple Baselines for Image Restoration `_. + +Related Work +------------ +The generator design is also inspired by the NAFSSR family: +`NAFSSR: Stereo Image Super-Resolution Using NAFNet +. +Implementation adapted for this project from: +https://github.com/GolpedeRemo37/NafNet-in-AI4Life-Microscopy-Supervised-Denoising-Challenge +Citation +-------- +Chu, Xiaojie and Chen, Liangyu and Yu, Wenqing. "NAFSSR: Stereo Image +Super-Resolution Using NAFNet." CVPR Workshops, 2022. +""" + +import torch +import torch.nn as nn +import torch.nn.functional as F +from yacs.config import CfgNode as CN +from torchinfo import summary + +from biapy.models.patchgan import PatchGANDiscriminator + +class SimpleGate(nn.Module): + """Simple channel-gating operator used in NAF blocks. + + The input tensor is split into two equal channel groups and both parts are + multiplied element-wise. + """ + + def forward(self, x): + """Apply channel-wise gating. + + Parameters + ---------- + x : torch.Tensor + Input tensor with shape `(N, C, H, W)` where `C` must be divisible + by 2. + + Returns + ------- + torch.Tensor + Tensor with shape `(N, C/2, H, W)` obtained by multiplying both + channel chunks element-wise. + """ + x1, x2 = x.chunk(2, dim=1) + return x1 * x2 + + +class LayerNorm2d(nn.Module): + """Layer normalization over channel dimension for 2D features. + + This normalization computes mean and variance across channels for each + spatial position and applies learned affine parameters. + """ + + def __init__(self, channels, eps=1e-6): + """Initialize layer normalization parameters. + + Parameters + ---------- + channels : int + Number of channels in the input tensor. + eps : float, optional + Numerical stability constant added to the variance. + """ + super(LayerNorm2d, self).__init__() + self.register_parameter('weight', nn.Parameter(torch.ones(channels))) + self.register_parameter('bias', nn.Parameter(torch.zeros(channels))) + self.eps = eps + + def forward(self, x): + """Normalize each spatial position across channels. + + Parameters + ---------- + x : torch.Tensor + Input tensor with shape `(N, C, H, W)`. + + Returns + ------- + torch.Tensor + Normalized tensor with same shape as input. + """ + N, C, H, W = x.size() + mu = x.mean(1, keepdim=True) + var = (x - mu).pow(2).mean(1, keepdim=True) + y = (x - mu) / (var + self.eps).sqrt() + y = self.weight.view(1, C, 1, 1) * y + self.bias.view(1, C, 1, 1) + return y + + +class NAFBlock(nn.Module): + """Core NAFNet residual block. + + The block combines: + 1. Layer normalization. + 2. Pointwise + depthwise convolutions. + 3. `SimpleGate` and simplified channel attention. + 4. A lightweight FFN branch. + 5. Two residual scaling parameters (`beta`, `gamma`). + """ + + def __init__(self, c, DW_Expand=2, FFN_Expand=2, drop_out_rate=0.): + """Initialize one NAF block. + + Parameters + ---------- + c : int + Number of input/output channels in the block. + DW_Expand : int, optional + Expansion ratio for the depthwise branch before gating. + FFN_Expand : int, optional + Expansion ratio for the feed-forward branch. + drop_out_rate : float, optional + Dropout probability used in both residual branches. + """ + super().__init__() + dw_channel = c * DW_Expand + self.conv1 = nn.Conv2d(in_channels=c, out_channels=dw_channel, kernel_size=1, padding=0, stride=1, groups=1, bias=True) + self.conv2 = nn.Conv2d(in_channels=dw_channel, out_channels=dw_channel, kernel_size=3, padding=1, stride=1, groups=dw_channel, bias=True) + self.conv3 = nn.Conv2d(in_channels=dw_channel // 2, out_channels=c, kernel_size=1, padding=0, stride=1, groups=1, bias=True) + + # Simplified Channel Attention + self.sca = nn.Sequential( + nn.AdaptiveAvgPool2d(1), + nn.Conv2d(in_channels=dw_channel // 2, out_channels=dw_channel // 2, kernel_size=1, padding=0, stride=1, groups=1, bias=True), + ) + + # SimpleGate + self.sg = SimpleGate() + + ffn_channel = FFN_Expand * c + self.conv4 = nn.Conv2d(in_channels=c, out_channels=ffn_channel, kernel_size=1, padding=0, stride=1, groups=1, bias=True) + self.conv5 = nn.Conv2d(in_channels=ffn_channel // 2, out_channels=c, kernel_size=1, padding=0, stride=1, groups=1, bias=True) + + self.norm1 = LayerNorm2d(c) + self.norm2 = LayerNorm2d(c) + + self.dropout1 = nn.Dropout(drop_out_rate) if drop_out_rate > 0. else nn.Identity() + self.dropout2 = nn.Dropout(drop_out_rate) if drop_out_rate > 0. else nn.Identity() + + self.beta = nn.Parameter(torch.zeros((1, c, 1, 1)), requires_grad=True) + self.gamma = nn.Parameter(torch.zeros((1, c, 1, 1)), requires_grad=True) + + def forward(self, inp): + """Apply the NAF block transformation. + + Parameters + ---------- + inp : torch.Tensor + Input feature map with shape `(N, C, H, W)`. + + Returns + ------- + torch.Tensor + Output feature map with the same shape as `inp`. + """ + x = inp + + x = self.norm1(x) + + x = self.conv1(x) + x = self.conv2(x) + x = self.sg(x) + x = x * self.sca(x) + x = self.conv3(x) + + x = self.dropout1(x) + + y = inp + x * self.beta + + x = self.conv4(self.norm2(y)) + x = self.sg(x) + x = self.conv5(x) + + x = self.dropout2(x) + + return y + x * self.gamma + +class NAFNet(nn.Module): + """NAFNet encoder-decoder architecture for image restoration. + + The model follows a U-shaped design with: + 1. Intro and ending convolutions. + 2. Multiple encoder stages with downsampling. + 3. Bottleneck NAF blocks. + 4. Decoder stages with PixelShuffle upsampling and skip connections. + """ + + def __init__( + self, + img_channel=3, + width=16, + middle_blk_num=1, + enc_blk_nums=[], + dec_blk_nums=[], + drop_out_rate=0.0, + dw_expand=2, + ffn_expand=2, + discriminator_arch=None, + patchgan_base_filters=64, + ): + """Initialize a NAFNet model. + + Parameters + ---------- + img_channel : int, optional + Number of input/output image channels. + width : int, optional + Base number of channels. + middle_blk_num : int, optional + Number of NAF blocks in the bottleneck. + enc_blk_nums : list[int], optional + Number of NAF blocks per encoder stage. + dec_blk_nums : list[int], optional + Number of NAF blocks per decoder stage. + drop_out_rate : float, optional + Dropout probability used inside blocks. + dw_expand : int, optional + Expansion ratio for depthwise branch. + ffn_expand : int, optional + Expansion ratio for feed-forward branch. + + Notes + ----- + Spatial padding is handled in `check_image_size` to ensure dimensions are + divisible by the encoder downsampling factor. + """ + super().__init__() + + self.intro = nn.Conv2d(in_channels=img_channel, out_channels=width, kernel_size=3, padding=1, stride=1, groups=1, bias=True) + self.ending = nn.Conv2d(in_channels=width, out_channels=img_channel, kernel_size=3, padding=1, stride=1, groups=1, bias=True) + + self.encoders = nn.ModuleList() + self.decoders = nn.ModuleList() + self.middle_blks = nn.ModuleList() + self.ups = nn.ModuleList() + self.downs = nn.ModuleList() + + chan = width + for num in enc_blk_nums: + self.encoders.append( + nn.Sequential( + # Pass the new parameters into the NAFBlock + *[NAFBlock(chan, DW_Expand=dw_expand, FFN_Expand=ffn_expand, drop_out_rate=drop_out_rate) for _ in range(num)] + ) + ) + self.downs.append( + nn.Conv2d(chan, 2*chan, 2, 2) + ) + chan = chan * 2 + + self.middle_blks = nn.Sequential( + *[NAFBlock(chan, DW_Expand=dw_expand, FFN_Expand=ffn_expand, drop_out_rate=drop_out_rate) for _ in range(middle_blk_num)] + ) + + for num in dec_blk_nums: + self.ups.append( + nn.Sequential( + nn.Conv2d(chan, chan * 2, 1, bias=False), + nn.PixelShuffle(2) + ) + ) + chan = chan // 2 + self.decoders.append( + nn.Sequential( + *[NAFBlock(chan, DW_Expand=dw_expand, FFN_Expand=ffn_expand, drop_out_rate=drop_out_rate) for _ in range(num)] + ) + ) + + self.padder_size = 2 ** len(self.encoders) + + discriminator = None + if discriminator_arch == "patchgan": + discriminator = PatchGANDiscriminator( + in_channels=img_channel, + base_filters=patchgan_base_filters, + ) + + self.discriminator = discriminator + + + def forward(self, inp): + """Run a forward pass through NAFNet. + + Parameters + ---------- + inp : torch.Tensor + Input image tensor with shape `(N, C, H, W)`. + + Notes + ----- + The input is internally padded to satisfy the downsampling factor and + then cropped back to original size at the end of the forward pass. + + Returns + ------- + torch.Tensor + Restored image with original spatial size `(H, W)`. + """ + B, C, H, W = inp.shape + inp = self.check_image_size(inp) + + x = self.intro(inp) + + encs = [] + + for encoder, down in zip(self.encoders, self.downs): + x = encoder(x) + encs.append(x) + x = down(x) + + x = self.middle_blks(x) + + for decoder, up, enc_skip in zip(self.decoders, self.ups, encs[::-1]): + x = up(x) + x = x + enc_skip + x = decoder(x) + + x = self.ending(x) + x = x + inp + + return x[:, :, :H, :W] + + def check_image_size(self, x): + """Pad image so height/width are divisible by internal stride. + + Parameters + ---------- + x : torch.Tensor + Input tensor with shape `(N, C, H, W)`. + + Returns + ------- + torch.Tensor + Padded tensor compatible with encoder/decoder downsampling. + """ + _, _, h, w = x.size() + mod_pad_h = (self.padder_size - h % self.padder_size) % self.padder_size + mod_pad_w = (self.padder_size - w % self.padder_size) % self.padder_size + x = F.pad(x, (0, mod_pad_w, 0, mod_pad_h)) + return x diff --git a/biapy/models/patchgan.py b/biapy/models/patchgan.py new file mode 100644 index 000000000..b7e736ddf --- /dev/null +++ b/biapy/models/patchgan.py @@ -0,0 +1,95 @@ +"""PatchGAN discriminator model used in image-to-image GAN training. + +This module implements a convolutional discriminator that predicts realism at +the patch level instead of producing a single global score. Patch-level +classification is commonly used in conditional GAN pipelines because it +emphasizes local texture and edge consistency, which is especially useful in +restoration and translation tasks. + +Classes +------- +PatchGANDiscriminator + Multi-layer convolutional discriminator with strided downsampling blocks and + a final 1-channel logits map. + +Notes +----- +The output tensor shape is `(N, 1, H_patch, W_patch)`, where each spatial value +acts as a local real/fake logit for a receptive-field patch in the input image. + +Implementation adapted for this project from: +https://github.com/GolpedeRemo37/NafNet-in-AI4Life-Microscopy-Supervised-Denoising-Challenge + +""" + +import torch.nn as nn + + +class PatchGANDiscriminator(nn.Module): + """PatchGAN discriminator based on strided convolutional blocks. + + Parameters + ---------- + in_channels : int, optional + Number of channels in the input image. + base_filters : int, optional + Number of filters in the first discriminator block. Each subsequent + block doubles this value. + + Notes + ----- + The architecture follows a typical PatchGAN design: + 1. Four convolutional downsampling blocks. + 2. Batch normalization on all blocks except the first one. + 3. LeakyReLU activations. + 4. Final convolution producing a patch-logits map. + """ + + def __init__(self, in_channels=1, base_filters=64): + super(PatchGANDiscriminator, self).__init__() + + def discriminator_block(in_filters, out_filters, normalization=True): + """Create one discriminator stage. + + Parameters + ---------- + in_filters : int + Number of input channels. + out_filters : int + Number of output channels. + normalization : bool, optional + Whether to include BatchNorm after convolution. + + Returns + ------- + list[nn.Module] + Layers composing one stage of the discriminator. + """ + layers = [nn.Conv2d(in_filters, out_filters, 4, stride=2, padding=1)] + if normalization: + layers.append(nn.BatchNorm2d(out_filters)) + layers.append(nn.LeakyReLU(0.2, inplace=True)) + return layers + + self.model = nn.Sequential( + *discriminator_block(in_channels, base_filters, normalization=False), + *discriminator_block(base_filters, base_filters * 2), + *discriminator_block(base_filters * 2, base_filters * 4), + *discriminator_block(base_filters * 4, base_filters * 8), + nn.Conv2d(base_filters * 8, 1, 4, stride=1, padding=1) + ) + + def forward(self, img): + """Run a forward pass through the discriminator. + + Parameters + ---------- + img : torch.Tensor + Input tensor with shape `(N, C, H, W)`. + + Returns + ------- + torch.Tensor + Patch-wise realism logits with shape `(N, 1, H_patch, W_patch)`. + """ + return self.model(img) \ No newline at end of file diff --git a/biapy/utils/misc.py b/biapy/utils/misc.py index 0ce41e73b..9af6c2f6a 100644 --- a/biapy/utils/misc.py +++ b/biapy/utils/misc.py @@ -325,7 +325,7 @@ def save_model(cfg, biapy_version, jobname, epoch, model_without_ddp, optimizer, The current epoch number. model_without_ddp : nn.Module The model instance, typically the unwrapped model if using DistributedDataParallel. - optimizer : torch.optim.Optimizer + optimizer : List[torch.optim.Optimizer] The optimizer's state. model_build_kwargs : Optional[Dict], optional Keyword arguments used to build the model, useful for re-instantiating @@ -346,11 +346,15 @@ def save_model(cfg, biapy_version, jobname, epoch, model_without_ddp, optimizer, to_save = { "model_build_kwargs": model_build_kwargs, "model": model_without_ddp.state_dict(), - "optimizer": optimizer.state_dict(), + "optimizer": [opt.state_dict() for opt in optimizer], # should i check if none? if i leave empty it uses default. "epoch": epoch, "cfg": cfg, "biapy_version": biapy_version, } + + # For Gan Models + if hasattr(model_without_ddp, 'discriminator'): + to_save["discriminator_state_dict"] = model_without_ddp.discriminator.state_dict() save_on_master(to_save, checkpoint_path) if len(checkpoint_paths) > 0: @@ -454,8 +458,8 @@ def load_model_checkpoint(cfg, jobname, model_without_ddp, device, optimizer=Non The model instance (unwrapped if DDP is used) to load weights into. device : torch.device The device to map the loaded checkpoint to. - optimizer : Optional[torch.optim.Optimizer], optional - The optimizer instance to load state into. If None, optimizer state is not loaded. + optimizer : Optional[List[torch.optim.Optimizer]], optional + The list of optimizer instances to load state into. If None, optimizer state is not loaded. Defaults to None. just_extract_checkpoint_info : bool, optional If True, only the configuration (`cfg`) and BiaPy version from the checkpoint @@ -550,13 +554,27 @@ def load_model_checkpoint(cfg, jobname, model_without_ddp, device, optimizer=Non print("Model weights loaded!") + if "discriminator_state_dict" in checkpoint: + if hasattr(model_without_ddp, 'discriminator') and model_without_ddp.discriminator is not None: + # We use strict=False just in case there are minor architecture changes + model_without_ddp.discriminator.load_state_dict(checkpoint["discriminator_state_dict"], strict=False) + print("Discriminator weights loaded!") + if cfg.MODEL.LOAD_CHECKPOINT_ONLY_WEIGHTS: return start_epoch, resume # Load also opt, epoch and scaler info if "optimizer" in checkpoint and optimizer is not None: - optimizer.load_state_dict(checkpoint["optimizer"], strict=False) - print("Optimizer info loaded!") + # im leaving this for prior non list optimizers for backward compatibility, + checkpoint_optimizer = checkpoint["optimizer"] + if isinstance(checkpoint_optimizer, dict): + checkpoint_optimizer = [checkpoint_optimizer] + + loaded_optimizers = 0 + for opt, opt_state in zip(optimizer, checkpoint_optimizer): + opt.load_state_dict(opt_state) + loaded_optimizers += 1 + print(f"Optimizer info loaded for {loaded_optimizers}/{len(optimizer)} optimizer(s)!") if "epoch" in checkpoint: start_epoch = checkpoint["epoch"] diff --git a/biapy/utils/scripts/run_checks.py b/biapy/utils/scripts/run_checks.py index 0252a2661..7416af2c1 100644 --- a/biapy/utils/scripts/run_checks.py +++ b/biapy/utils/scripts/run_checks.py @@ -2274,7 +2274,7 @@ def runjob(test_info, results_folder, yaml_file, biapy_folder, multigpu=False, b biapy_config['TRAIN']['PATIENCE'] = -1 biapy_config['TRAIN']['LR_SCHEDULER'] = {} biapy_config['TRAIN']['LR_SCHEDULER']['NAME'] = 'onecycle' - biapy_config['TRAIN']['LR'] = 0.001 + biapy_config['TRAIN']['LR'] = [0.001] biapy_config['TRAIN']['BATCH_SIZE'] = 16 biapy_config['MODEL']['ARCHITECTURE'] = 'resunet' @@ -3334,7 +3334,7 @@ def runjob(test_info, results_folder, yaml_file, biapy_folder, multigpu=False, b biapy_config['TRAIN']['EPOCHS'] = 100 biapy_config['TRAIN']['BATCH_SIZE'] = 1 biapy_config['TRAIN']['PATIENCE'] = 20 - biapy_config['TRAIN']['LR'] = 0.0001 + biapy_config['TRAIN']['LR'] = [0.0001] biapy_config['TRAIN']['LR_SCHEDULER'] = {} biapy_config['TRAIN']['LR_SCHEDULER']['NAME'] = 'warmupcosine' biapy_config['TRAIN']['LR_SCHEDULER']['MIN_LR'] = 5.E-6 @@ -3449,8 +3449,8 @@ def runjob(test_info, results_folder, yaml_file, biapy_folder, multigpu=False, b biapy_config['TRAIN']['ENABLE'] = True biapy_config['TRAIN']['EPOCHS'] = 80 biapy_config['TRAIN']['PATIENCE'] = -1 - biapy_config['TRAIN']['OPTIMIZER'] = "ADAMW" - biapy_config['TRAIN']['LR'] = 1.E-4 + biapy_config['TRAIN']['OPTIMIZER'] = ["ADAMW"] + biapy_config['TRAIN']['LR'] = [1.E-4] biapy_config['TRAIN']['LR_SCHEDULER'] = {} biapy_config['TRAIN']['LR_SCHEDULER']['NAME'] = 'warmupcosine' biapy_config['TRAIN']['LR_SCHEDULER']['MIN_LR'] = 5.E-6 diff --git a/biapy/utils/util.py b/biapy/utils/util.py index be1e6e6a5..a1eca0745 100644 --- a/biapy/utils/util.py +++ b/biapy/utils/util.py @@ -76,18 +76,21 @@ def create_plots(results, metrics, job_id, chartOutDir): os.environ["QT_QPA_PLATFORM"] = "offscreen" # Loss - plt.plot(results["loss"]) - if "val_loss" in results: - plt.plot(results["val_loss"]) - plt.title("Model JOBID=" + job_id + " loss") - plt.ylabel("Value") - plt.xlabel("Epoch") - if "val_loss" in results: - plt.legend(["Train loss", "Val. loss"], loc="upper left") - else: - plt.legend(["Train loss"], loc="upper left") - plt.savefig(os.path.join(chartOutDir, job_id + "_loss.png")) - plt.clf() + loss_keys = ["loss"] + [k for k in results if "loss" in k and k not in ["loss", "val_loss"]] + for loss_key in loss_keys: + val_loss_key = f"val_{loss_key}" + plt.plot(results[loss_key]) + if val_loss_key in results: + plt.plot(results[val_loss_key]) + plt.title("Model JOBID=" + job_id + " " + loss_key) + plt.ylabel("Value") + plt.xlabel("Epoch") + if val_loss_key in results: + plt.legend([f"Train {loss_key}", f"Val. {loss_key}"], loc="upper left") + else: + plt.legend([f"Train {loss_key}"], loc="upper left") + plt.savefig(os.path.join(chartOutDir, job_id + "_" + loss_key + ".png")) + plt.clf() # Metric for i in range(len(metrics)):