From 6f8543248db86479d80b337630d1f55c096ca588 Mon Sep 17 00:00:00 2001 From: Ibai Date: Mon, 23 Mar 2026 22:12:56 +0100 Subject: [PATCH 1/5] Gan implementation first pass --- biapy/config/config.py | 50 ++++++++- biapy/data/generators/__init__.py | 4 +- biapy/engine/__init__.py | 121 ++++++++++++-------- biapy/engine/base_workflow.py | 71 ++++++++++-- biapy/engine/check_configuration.py | 21 ++-- biapy/engine/denoising.py | 59 +++++++++- biapy/engine/metrics.py | 115 ++++++++++++++++++- biapy/engine/train_engine.py | 156 ++++++++++++++++++++++++-- biapy/models/__init__.py | 83 ++++++++++++++ biapy/models/nafnet.py | 165 ++++++++++++++++++++++++++++ biapy/models/patchgan.py | 23 ++++ biapy/utils/misc.py | 27 ++++- 12 files changed, 811 insertions(+), 84 deletions(-) create mode 100644 biapy/models/nafnet.py create mode 100644 biapy/models/patchgan.py diff --git a/biapy/config/config.py b/biapy/config/config.py index 82b28b8a9..f5f73c361 100644 --- a/biapy/config/config.py +++ b/biapy/config/config.py @@ -1162,12 +1162,15 @@ def __init__(self, job_dir: str, job_identifier: str): # * Semantic segmentation: 'unet', 'resunet', 'resunet++', 'attention_unet', 'multiresunet', 'seunet', 'resunet_se', 'unetr', 'unext_v1', 'unext_v2' # * Instance segmentation: 'unet', 'resunet', 'resunet++', 'attention_unet', 'multiresunet', 'seunet', 'resunet_se', 'unetr', 'unext_v1', 'unext_v2' # * Detection: 'unet', 'resunet', 'resunet++', 'attention_unet', 'multiresunet', 'seunet', 'resunet_se', 'unetr', 'unext_v1', 'unext_v2' - # * Denoising: 'unet', 'resunet', 'resunet++', 'attention_unet', 'seunet', 'resunet_se', 'unext_v1', 'unext_v2' + # * Denoising: 'unet', 'resunet', 'resunet++', 'attention_unet', 'seunet', 'resunet_se', 'unext_v1', 'unext_v2', 'nafnet' # * Super-resolution: 'edsr', 'rcan', 'dfcan', 'wdsr', 'unet', 'resunet', 'resunet++', 'seunet', 'resunet_se', 'attention_unet', 'multiresunet', 'unext_v1', 'unext_v2' # * Self-supervision: 'unet', 'resunet', 'resunet++', 'attention_unet', 'multiresunet', 'seunet', 'resunet_se', 'unetr', 'edsr', 'rcan', 'dfcan', 'wdsr', 'vit', 'mae', 'unext_v1', 'unext_v2' # * Classification: 'simple_cnn', 'vit', 'efficientnet_b[0-7]' (only 2D) # * Image to image: 'edsr', 'rcan', 'dfcan', 'wdsr', 'unet', 'resunet', 'resunet++', 'seunet', 'resunet_se', 'attention_unet', 'unetr', 'multiresunet', 'unext_v1', 'unext_v2' _C.MODEL.ARCHITECTURE = "unet" + # Architecture of the network. Possible values are: + # * 'patchgan' + _C.MODEL.ARCHITECTURE_D = "patchgan" # Number of feature maps on each level of the network. _C.MODEL.FEATURE_MAPS = [16, 32, 64, 128, 256] # Values to make the dropout with. Set to 0 to prevent dropout. When using it with 'ViT' or 'unetr' @@ -1306,6 +1309,26 @@ def __init__(self, job_dir: str, job_identifier: str): # Whether to use a pretrained version of STUNet on ImageNet _C.MODEL.STUNET.PRETRAINED = False + # NafNet + _C.MODEL.NAFNET = CN() + # Base number of channels (width) used in the first layer and base levels. + _C.MODEL.NAFNET.WIDTH = 16 + # Number of NAFBlocks stacked at the bottleneck (deepest level). + _C.MODEL.NAFNET.MIDDLE_BLK_NUM = 12 + # Number of NAFBlocks assigned to each downsampling level of the encoder. + _C.MODEL.NAFNET.ENC_BLK_NUMS = [2, 2, 4, 8] + # Number of NAFBlocks assigned to each upsampling level of the decoder. + _C.MODEL.NAFNET.DEC_BLK_NUMS = [2, 2, 2, 2] + # Channel expansion factor for the depthwise convolution within the gating unit. 
+ _C.MODEL.NAFNET.DW_EXPAND = 2 + # Expansion factor for the hidden layer within the feed-forward network. + _C.MODEL.NAFNET.FFN_EXPAND = 2 + + # Discriminator PATCHGAN + _C.MODEL.PATCHGAN = CN() + # Number of initial convolutional filters in the first layer of the discriminator. + _C.MODEL.PATCHGAN.BASE_FILTERS = 64 + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # Loss # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ @@ -1371,7 +1394,24 @@ def __init__(self, job_dir: str, job_identifier: str): _C.LOSS.CONTRAST.MEMORY_SIZE = 5000 _C.LOSS.CONTRAST.PROJ_DIM = 256 _C.LOSS.CONTRAST.PIXEL_UPD_FREQ = 10 - + + # Fine-grained GAN composition. Set any weight to 0.0 to disable that term. + # Used when LOSS.TYPE == "COMPOSED_GAN". + _C.LOSS.COMPOSED_GAN = CN() + # Weight for adversarial BCE term. + _C.LOSS.COMPOSED_GAN.LAMBDA_GAN = 1.0 + # Weight for L1 reconstruction term. + _C.LOSS.COMPOSED_GAN.LAMBDA_RECON = 10.0 + # Weight for MSE reconstruction term. + _C.LOSS.COMPOSED_GAN.DELTA_MSE = 0.0 + # Weight for VGG perceptual term. + _C.LOSS.COMPOSED_GAN.ALPHA_PERCEPTUAL = 0.0 + # Weight for SSIM term. + _C.LOSS.COMPOSED_GAN.GAMMA_SSIM = 1.0 + + # Backward-compatible alias for previous naming. + _C.LOSS.GAN = CN() + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # Training phase # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ @@ -1381,12 +1421,18 @@ def __init__(self, job_dir: str, job_identifier: str): _C.TRAIN.VERBOSE = False # Optimizer to use. Possible values: "SGD", "ADAM" or "ADAMW" _C.TRAIN.OPTIMIZER = "SGD" + # Optimizer to use. Possible values: "SGD", "ADAM" or "ADAMW" for GAN discriminator + _C.TRAIN.OPTIMIZER_D = "SGD" # Learning rate _C.TRAIN.LR = 1.0e-4 + # Learning rate for GAN discriminator + _C.TRAIN.LR_D = 1.0e-4 # Weight decay _C.TRAIN.W_DECAY = 0.02 # Coefficients used for computing running averages of gradient and its square. Used in ADAM and ADAMW optmizers _C.TRAIN.OPT_BETAS = (0.9, 0.999) + # Coefficients used for computing running averages of gradient and its square. 
Used in ADAM and ADAMW optmizers for GANS discriminator + _C.TRAIN.OPT_BETAS_D = (0.5, 0.999) # Batch size _C.TRAIN.BATCH_SIZE = 2 # If memory or # gpus is limited, use this variable to maintain the effective batch size, which is diff --git a/biapy/data/generators/__init__.py b/biapy/data/generators/__init__.py index ca072706b..dcc7720f0 100644 --- a/biapy/data/generators/__init__.py +++ b/biapy/data/generators/__init__.py @@ -246,7 +246,7 @@ def create_train_val_augmentors( dic["zflip"] = cfg.AUGMENTOR.ZFLIP if cfg.PROBLEM.TYPE == "INSTANCE_SEG": dic["instance_problem"] = True - elif cfg.PROBLEM.TYPE == "DENOISING": + elif cfg.PROBLEM.TYPE == "DENOISING" and cfg.LOSS.TYPE != "COMPOSED_GAN": dic["n2v"] = True dic["n2v_perc_pix"] = cfg.PROBLEM.DENOISING.N2V_PERC_PIX dic["n2v_manipulator"] = cfg.PROBLEM.DENOISING.N2V_MANIPULATOR @@ -293,7 +293,7 @@ def create_train_val_augmentors( ) if cfg.PROBLEM.TYPE == "INSTANCE_SEG": dic["instance_problem"] = True - elif cfg.PROBLEM.TYPE == "DENOISING": + elif cfg.PROBLEM.TYPE == "DENOISING" and cfg.LOSS.TYPE != "COMPOSED_GAN": dic["n2v"] = True dic["n2v_perc_pix"] = cfg.PROBLEM.DENOISING.N2V_PERC_PIX dic["n2v_manipulator"] = cfg.PROBLEM.DENOISING.N2V_MANIPULATOR diff --git a/biapy/engine/__init__.py b/biapy/engine/__init__.py index 8991943f7..1873d9ac7 100644 --- a/biapy/engine/__init__.py +++ b/biapy/engine/__init__.py @@ -21,6 +21,7 @@ def prepare_optimizer( cfg: CN, model_without_ddp: nn.Module | nn.parallel.DistributedDataParallel, steps_per_epoch: int, + is_gan: bool = False, ) -> Tuple[Optimizer, Scheduler | None]: """ Create and configure the optimizer and learning rate scheduler for the given model. @@ -33,57 +34,89 @@ def prepare_optimizer( ---------- cfg : YACS CN object Configuration object with optimizer and scheduler settings. - model_without_ddp : nn.Module or nn.parallel.DistributedDataParallel + model_without_ddp : nn.Module or nn.parallel.DistributedDataParallel or dict The model to optimize. steps_per_epoch : int Number of steps (batches) per training epoch. + is_gan : bool, optional + Whether to create optimizer/scheduler pairs for GAN generator and discriminator. Returns ------- - optimizer : Optimizer - Configured optimizer for the model. - lr_scheduler : Scheduler or None - Configured learning rate scheduler, or None if not specified. + optimizer : Optimizer or dict + Configured optimizer for the model or dict with generator/discriminator optimizers in GAN mode. + lr_scheduler : Scheduler or None or dict + Configured scheduler for the model or dict with generator/discriminator schedulers in GAN mode. 
""" - lr = cfg.TRAIN.LR if cfg.TRAIN.LR_SCHEDULER.NAME != "warmupcosine" else cfg.TRAIN.LR_SCHEDULER.MIN_LR - opt_args = {} - if cfg.TRAIN.OPTIMIZER in ["ADAM", "ADAMW"]: - opt_args["betas"] = cfg.TRAIN.OPT_BETAS - optimizer = timm.optim.create_optimizer_v2( - model_without_ddp, - opt=cfg.TRAIN.OPTIMIZER, - lr=lr, - weight_decay=cfg.TRAIN.W_DECAY, - **opt_args, - ) - print(optimizer) - - # Learning rate schedulers - lr_scheduler = None - if cfg.TRAIN.LR_SCHEDULER.NAME != "": - if cfg.TRAIN.LR_SCHEDULER.NAME == "reduceonplateau": - lr_scheduler = ReduceLROnPlateau( - optimizer, - patience=cfg.TRAIN.LR_SCHEDULER.REDUCEONPLATEAU_PATIENCE, - factor=cfg.TRAIN.LR_SCHEDULER.REDUCEONPLATEAU_FACTOR, - min_lr=cfg.TRAIN.LR_SCHEDULER.MIN_LR, - ) - elif cfg.TRAIN.LR_SCHEDULER.NAME == "warmupcosine": - lr_scheduler = WarmUpCosineDecayScheduler( - lr=cfg.TRAIN.LR, - min_lr=cfg.TRAIN.LR_SCHEDULER.MIN_LR, - warmup_epochs=cfg.TRAIN.LR_SCHEDULER.WARMUP_COSINE_DECAY_EPOCHS, - epochs=cfg.TRAIN.EPOCHS, - ) - elif cfg.TRAIN.LR_SCHEDULER.NAME == "onecycle": - lr_scheduler = OneCycleLR( - optimizer, - cfg.TRAIN.LR, - epochs=cfg.TRAIN.EPOCHS, - steps_per_epoch=steps_per_epoch, - ) - - return optimizer, lr_scheduler + def _make_scheduler(optimizer: Optimizer, lr_value: float) -> Scheduler | None: + lr_scheduler = None + if cfg.TRAIN.LR_SCHEDULER.NAME != "": + if cfg.TRAIN.LR_SCHEDULER.NAME == "reduceonplateau": + lr_scheduler = ReduceLROnPlateau( + optimizer, + patience=cfg.TRAIN.LR_SCHEDULER.REDUCEONPLATEAU_PATIENCE, + factor=cfg.TRAIN.LR_SCHEDULER.REDUCEONPLATEAU_FACTOR, + min_lr=cfg.TRAIN.LR_SCHEDULER.MIN_LR, + ) + elif cfg.TRAIN.LR_SCHEDULER.NAME == "warmupcosine": + lr_scheduler = WarmUpCosineDecayScheduler( + lr=lr_value, + min_lr=cfg.TRAIN.LR_SCHEDULER.MIN_LR, + warmup_epochs=cfg.TRAIN.LR_SCHEDULER.WARMUP_COSINE_DECAY_EPOCHS, + epochs=cfg.TRAIN.EPOCHS, + ) + elif cfg.TRAIN.LR_SCHEDULER.NAME == "onecycle": + lr_scheduler = OneCycleLR( + optimizer, + lr_value, + epochs=cfg.TRAIN.EPOCHS, + steps_per_epoch=steps_per_epoch, + ) + return lr_scheduler + + def _make_optimizer(model: nn.Module | nn.parallel.DistributedDataParallel, train_cfg: dict): + lr_value = train_cfg["lr"] + opt_name = train_cfg["optimizer"] + betas = train_cfg["betas"] + w_decay = train_cfg["weight_decay"] + + lr = lr_value if cfg.TRAIN.LR_SCHEDULER.NAME != "warmupcosine" else cfg.TRAIN.LR_SCHEDULER.MIN_LR + opt_args = {} + if opt_name in ["ADAM", "ADAMW"]: + opt_args["betas"] = betas + + optimizer = timm.optim.create_optimizer_v2( + model, + opt=opt_name, + lr=lr, + weight_decay=w_decay, + **opt_args, + ) + print(optimizer) + lr_scheduler = _make_scheduler(optimizer, lr_value) + return optimizer, lr_scheduler + + g_train_cfg = { + "lr": cfg.TRAIN.LR, + "optimizer": cfg.TRAIN.OPTIMIZER, + "betas": cfg.TRAIN.OPT_BETAS, + "weight_decay": cfg.TRAIN.W_DECAY, + } + + if not is_gan: + return _make_optimizer(model_without_ddp, g_train_cfg) + + d_train_cfg = { + "lr": cfg.TRAIN.LR_D, + "optimizer": cfg.TRAIN.OPTIMIZER_D, + "betas": cfg.TRAIN.OPT_BETAS_D, + "weight_decay": cfg.TRAIN.W_DECAY, + } + + optimizer_g, scheduler_g = _make_optimizer(model_without_ddp["generator"], g_train_cfg) + optimizer_d, scheduler_d = _make_optimizer(model_without_ddp["discriminator"], d_train_cfg) + + return {"generator": optimizer_g, "discriminator": optimizer_d}, {"generator": scheduler_g, "discriminator": scheduler_d,} def build_callbacks(cfg: CN) -> EarlyStopping | None: diff --git a/biapy/engine/base_workflow.py b/biapy/engine/base_workflow.py index 
4f28f200d..5f2b48464 100644 --- a/biapy/engine/base_workflow.py +++ b/biapy/engine/base_workflow.py @@ -837,6 +837,11 @@ def prepare_logging_tool(self): self.plot_values = {} self.plot_values["loss"] = [] self.plot_values["val_loss"] = [] + if getattr(self, "is_gan_mode", False): + self.plot_values["loss_g"] = [] + self.plot_values["loss_d"] = [] + self.plot_values["val_loss_g"] = [] + self.plot_values["val_loss_d"] = [] for i in range(len(self.train_metric_names)): self.plot_values[self.train_metric_names[i]] = [] self.plot_values["val_" + self.train_metric_names[i]] = [] @@ -853,9 +858,28 @@ def train(self): assert ( self.start_epoch is not None and self.model is not None and self.model_without_ddp is not None and self.loss ) - self.optimizer, self.lr_scheduler = prepare_optimizer( - self.cfg, self.model_without_ddp, len(self.train_generator) - ) + is_gan_mode = getattr(self, "is_gan_mode", False) + if is_gan_mode: + discriminator_wo_ddp = getattr(self, "discriminator_without_ddp", None) + if discriminator_wo_ddp is None: + raise ValueError("GAN mode requires a discriminator model before training") + optimizers, schedulers = prepare_optimizer( + self.cfg, + { + "generator": self.model_without_ddp, + "discriminator": discriminator_wo_ddp, + }, + len(self.train_generator), + is_gan=True, + ) + self.optimizer = optimizers["generator"] + self.lr_scheduler = schedulers["generator"] + self.optimizer_d = optimizers["discriminator"] + self.lr_scheduler_d = schedulers["discriminator"] + else: + self.optimizer, self.lr_scheduler = prepare_optimizer( + self.cfg, self.model_without_ddp, len(self.train_generator) + ) contrast_init_iter = 0 if self.cfg.LOSS.CONTRAST.ENABLE: @@ -910,6 +934,9 @@ def train(self): memory_bank=self.memory_bank, total_iters=total_iters, contrast_warmup_iters=contrast_init_iter, + model_d=getattr(self, "discriminator", None) if is_gan_mode else None, + optimizer_d=getattr(self, "optimizer_d", None) if is_gan_mode else None, + lr_scheduler_d=getattr(self, "lr_scheduler_d", None) if is_gan_mode else None, ) total_iters += iterations_done @@ -929,6 +956,8 @@ def train(self): epoch=epoch + 1, model_build_kwargs=self.model_build_kwargs, extension=self.cfg.MODEL.OUT_CHECKPOINT_FORMAT, + discriminator_without_ddp=getattr(self, "discriminator_without_ddp", None) if is_gan_mode else None, + optimizer_d=getattr(self, "optimizer_d", None) if is_gan_mode else None, ) # Validation @@ -944,24 +973,29 @@ def train(self): data_loader=self.val_generator, lr_scheduler=self.lr_scheduler, memory_bank=self.memory_bank, + model_d=getattr(self, "discriminator", None) if is_gan_mode else None, + lr_scheduler_d=getattr(self, "lr_scheduler_d", None) if is_gan_mode else None, + device=self.device if is_gan_mode else None, ) + val_loss_key = "loss_g" if is_gan_mode else "loss" + # Save checkpoint is val loss improved - if test_stats["loss"] < self.val_best_loss: + if test_stats[val_loss_key] < self.val_best_loss: f = os.path.join( self.cfg.PATHS.CHECKPOINT, "{}-checkpoint-best.pth".format(self.job_identifier), ) print( "Val loss improved from {} to {}, saving model to {}".format( - self.val_best_loss, test_stats["loss"], f + self.val_best_loss, test_stats[val_loss_key], f ) ) m = " " for i in range(len(self.val_best_metric)): self.val_best_metric[i] = test_stats[self.train_metric_names[i]] m += f"{self.train_metric_names[i]}: {self.val_best_metric[i]:.4f} " - self.val_best_loss = test_stats["loss"] + self.val_best_loss = test_stats[val_loss_key] if is_main_process(): self.checkpoint_path = save_model( 
@@ -973,12 +1007,14 @@ def train(self): epoch="best", model_build_kwargs=self.model_build_kwargs, extension=self.cfg.MODEL.OUT_CHECKPOINT_FORMAT, + discriminator_without_ddp=getattr(self, "discriminator_without_ddp", None) if is_gan_mode else None, + optimizer_d=getattr(self, "optimizer_d", None) if is_gan_mode else None, ) print(f"[Val] best loss: {self.val_best_loss:.4f} best " + m) # Store validation stats if self.log_writer: - self.log_writer.update(test_loss=test_stats["loss"], head="perf", step=epoch) + self.log_writer.update(test_loss=test_stats[val_loss_key], head="perf", step=epoch) for i in range(len(self.train_metric_names)): self.log_writer.update( test_iou=test_stats[self.train_metric_names[i]], @@ -1006,9 +1042,20 @@ def train(self): f.write(json.dumps(log_stats) + "\n") # Create training plot - self.plot_values["loss"].append(train_stats["loss"]) + train_loss_key = "loss_g" if is_gan_mode else "loss" + val_loss_key = "loss_g" if is_gan_mode else "loss" + self.plot_values["loss"].append(train_stats[train_loss_key]) if self.val_generator: - self.plot_values["val_loss"].append(test_stats["loss"]) + self.plot_values["val_loss"].append(test_stats[val_loss_key]) + if is_gan_mode: + if "loss_g" in self.plot_values: + self.plot_values["loss_g"].append(train_stats.get("loss_g", 0)) + if "loss_d" in self.plot_values: + self.plot_values["loss_d"].append(train_stats.get("loss_d", 0)) + if self.val_generator and "val_loss_g" in self.plot_values: + self.plot_values["val_loss_g"].append(test_stats.get("loss_g", 0)) + if self.val_generator and "val_loss_d" in self.plot_values: + self.plot_values["val_loss_d"].append(test_stats.get("loss_d", 0)) for i in range(len(self.train_metric_names)): self.plot_values[self.train_metric_names[i]].append(train_stats[self.train_metric_names[i]]) if self.val_generator: @@ -1024,7 +1071,8 @@ def train(self): ) if self.val_generator and self.early_stopping: - self.early_stopping(test_stats["loss"]) + val_loss_key = "loss_g" if is_gan_mode else "loss" + self.early_stopping(test_stats[val_loss_key]) if self.early_stopping.early_stop: print("Early stopping") break @@ -1043,7 +1091,8 @@ def train(self): self.total_training_time_str = str(datetime.timedelta(seconds=int(total_time))) print("Training time: {}".format(self.total_training_time_str)) - self.train_metrics_message += ("Train loss: {}\n".format(train_stats["loss"])) + train_loss_key = "loss_g" if is_gan_mode else "loss" + self.train_metrics_message += ("Train loss: {}\n".format(train_stats[train_loss_key])) for i in range(len(self.train_metric_names)): self.train_metrics_message += ("Train {}: {}\n".format(self.train_metric_names[i], train_stats[self.train_metric_names[i]])) if self.val_generator: diff --git a/biapy/engine/check_configuration.py b/biapy/engine/check_configuration.py index f4b5d5b6a..0374e17c7 100644 --- a/biapy/engine/check_configuration.py +++ b/biapy/engine/check_configuration.py @@ -1242,7 +1242,7 @@ def sort_key(item): ], "LOSS.CLASS_REBALANCE not in ['none', 'auto'] for INSTANCE_SEG workflow" elif cfg.PROBLEM.TYPE == "DENOISING": loss = "MSE" if cfg.LOSS.TYPE == "" else cfg.LOSS.TYPE - assert loss == "MSE", "LOSS.TYPE must be 'MSE'" + assert loss in ["MSE", "COMPOSED_GAN"], "LOSS.TYPE must be in ['MSE', 'COMPOSED_GAN'] for DENOISING" elif cfg.PROBLEM.TYPE == "CLASSIFICATION": loss = "CE" if cfg.LOSS.TYPE == "" else cfg.LOSS.TYPE assert loss == "CE", "LOSS.TYPE must be 'CE'" @@ -1797,12 +1797,16 @@ def sort_key(item): #### Denoising #### elif cfg.PROBLEM.TYPE == "DENOISING": - if 
cfg.DATA.TEST.LOAD_GT: - raise ValueError( - "Denoising is made in an unsupervised way so there is no ground truth required. Disable 'DATA.TEST.LOAD_GT'" - ) - if not check_value(cfg.PROBLEM.DENOISING.N2V_PERC_PIX): - raise ValueError("PROBLEM.DENOISING.N2V_PERC_PIX not in [0, 1] range") + if cfg.LOSS.TYPE == "COMPOSED_GAN": + if not cfg.DATA.TRAIN.GT_PATH and not cfg.DATA.TRAIN.INPUT_ZARR_MULTIPLE_DATA: + raise ValueError("Denoising with COMPOSED_GAN is supervised. 'DATA.TRAIN.GT_PATH' is required.") + else: + if cfg.DATA.TEST.LOAD_GT: + raise ValueError( + "Denoising is made in an unsupervised way so there is no ground truth required. Disable 'DATA.TEST.LOAD_GT'" + ) + if not check_value(cfg.PROBLEM.DENOISING.N2V_PERC_PIX): + raise ValueError("PROBLEM.DENOISING.N2V_PERC_PIX not in [0, 1] range") if cfg.MODEL.SOURCE == "torchvision": raise ValueError("'MODEL.SOURCE' as 'torchvision' is not available in denoising workflow") @@ -2341,6 +2345,7 @@ def sort_key(item): "hrnet48", "hrnet64", "stunet", + "nafnet", ], "MODEL.ARCHITECTURE not in ['unet', 'resunet', 'resunet++', 'attention_unet', 'multiresunet', 'seunet', 'simple_cnn', 'efficientnet_b[0-7]', 'unetr', 'edsr', 'rcan', 'dfcan', 'wdsr', 'vit', 'mae', 'unext_v1', 'unext_v2', 'hrnet18', 'hrnet32', 'hrnet48', 'hrnet64', 'stunet']" if ( model_arch @@ -2514,6 +2519,7 @@ def sort_key(item): "hrnet48", "hrnet64", "stunet", + "nafnet", ]: raise ValueError( "Architectures available for {} are: ['unet', 'resunet', 'resunet++', 'seunet', 'attention_unet', 'resunet_se', 'unetr', 'multiresunet', 'unext_v1', 'unext_v2', 'hrnet18', 'hrnet32', 'hrnet48', 'hrnet64', 'stunet']".format( @@ -2976,6 +2982,7 @@ def compare_configurations_without_model(actual_cfg, old_cfg, header_message="", "DATA.PATCH_SIZE", "PROBLEM.INSTANCE_SEG.DATA_CHANNELS", "PROBLEM.SUPER_RESOLUTION.UPSCALING", + "MODEL.ARCHITECTURE_D", "DATA.N_CLASSES", ] diff --git a/biapy/engine/denoising.py b/biapy/engine/denoising.py index 9cac8d574..b5c0efac8 100644 --- a/biapy/engine/denoising.py +++ b/biapy/engine/denoising.py @@ -24,9 +24,10 @@ merge_3D_data_with_overlap, ) from biapy.engine.base_workflow import Base_Workflow +from biapy.models import build_discriminator from biapy.data.data_manipulation import save_tif -from biapy.utils.misc import to_pytorch_format, is_main_process, MetricLogger -from biapy.engine.metrics import n2v_loss_mse, loss_encapsulation +from biapy.utils.misc import to_pytorch_format, MetricLogger +from biapy.engine.metrics import n2v_loss_mse, loss_encapsulation, ComposedGANLoss class Denoising_Workflow(Base_Workflow): @@ -72,10 +73,20 @@ def __init__(self, cfg, job_identifier, device, system_dict, args, **kwargs): # From now on, no modification of the cfg will be allowed self.cfg.freeze() + self.is_gan_mode = str(cfg.LOSS.TYPE).upper() == "COMPOSED_GAN" + self.discriminator = None + self.discriminator_without_ddp = None + self.optimizer_d = None + self.lr_scheduler_d = None + # Workflow specific training variables - self.mask_path = cfg.DATA.TRAIN.GT_PATH if cfg.PROBLEM.DENOISING.LOAD_GT_DATA else None + if self.is_gan_mode: + self.mask_path = cfg.DATA.TRAIN.GT_PATH + self.load_Y_val = True + else: + self.mask_path = cfg.DATA.TRAIN.GT_PATH if cfg.PROBLEM.DENOISING.LOAD_GT_DATA else None + self.load_Y_val = cfg.PROBLEM.DENOISING.LOAD_GT_DATA self.is_y_mask = False - self.load_Y_val = cfg.PROBLEM.DENOISING.LOAD_GT_DATA self.norm_module.mask_norm = "as_image" self.test_norm_module.mask_norm = "as_image" @@ -166,6 +177,8 @@ def define_metrics(self): # 
print("Overriding 'LOSS.TYPE' to set it to N2V loss (masked MSE)") if self.cfg.LOSS.TYPE == "MSE": self.loss = loss_encapsulation(n2v_loss_mse) + elif self.cfg.LOSS.TYPE == "COMPOSED_GAN": + self.loss = ComposedGANLoss(cfg=self.cfg, device=self.device) super().define_metrics() @@ -232,7 +245,11 @@ def metric_calculation( with torch.no_grad(): for i, metric in enumerate(list_to_use): - val = metric(_output.contiguous(), _targets[:, _output.shape[1]:].contiguous()) + if self.is_gan_mode: + target_for_metric = _targets.contiguous() + else: + target_for_metric = _targets[:, _output.shape[1]:].contiguous() + val = metric(_output.contiguous(), target_for_metric) val = val.item() if not torch.isnan(val) else 0 out_metrics[list_names_to_use[i]] = val @@ -328,6 +345,36 @@ def process_test_sample(self): verbose=self.cfg.TEST.VERBOSE, ) + def prepare_model(self): + """Build generator model and discriminator when running denoising in GAN mode.""" + super().prepare_model() + + # It is not a GAN model, or we alredy have a discriminator loaded from checkpoint + if not self.is_gan_mode or self.discriminator is not None: + return + + print("#######################") + print("# Build Discriminator #") + print("#######################") + self.discriminator = build_discriminator(self.cfg, self.device) + self.discriminator_without_ddp = self.discriminator + + if self.args.distributed: + self.discriminator = torch.nn.parallel.DistributedDataParallel( + self.discriminator, + device_ids=[self.args.gpu], + find_unused_parameters=False, + ) + self.discriminator_without_ddp = self.discriminator.module + + if self.cfg.MODEL.SOURCE == "biapy" and self.cfg.MODEL.LOAD_CHECKPOINT and self.checkpoint_path: + checkpoint = torch.load(self.checkpoint_path, map_location=self.device) + if "discriminator_state_dict" in checkpoint: + self.discriminator_without_ddp.load_state_dict(checkpoint["discriminator_state_dict"]) + print("Discriminator weights loaded successfully.") + else: + print("Warning: 'discriminator_state_dict' not found in checkpoint.") + def torchvision_model_call(self, in_img: torch.Tensor, is_train: bool = False) -> torch.Tensor | None: """ Call a regular Pytorch model. @@ -933,4 +980,4 @@ def get_value_manipulation(n2v_manipulator, n2v_neighborhood_radius): Callable Value manipulation function. 
""" - return eval("pm_{0}({1})".format(n2v_manipulator, str(n2v_neighborhood_radius))) + return eval("pm_{0}({1})".format(n2v_manipulator, str(n2v_neighborhood_radius))) \ No newline at end of file diff --git a/biapy/engine/metrics.py b/biapy/engine/metrics.py index 32dab4283..80bb21db1 100644 --- a/biapy/engine/metrics.py +++ b/biapy/engine/metrics.py @@ -18,6 +18,8 @@ import torch.nn.functional as F import torch.nn as nn from typing import Optional, List, Tuple, Dict +from torchvision import transforms +from torchvision.models import vgg16, VGG16_Weights def jaccard_index_numpy(y_true, y_pred): """ @@ -2162,4 +2164,115 @@ def forward( loss = loss / B iou = iou / B - return loss + prediction.sum() * 0, float(iou), "IoU" # keep graph identical to originals \ No newline at end of file + return loss + prediction.sum() * 0, float(iou), "IoU" # keep graph identical to originals + +class VGGLoss(nn.Module): + def __init__(self, device): + super().__init__() + self.vgg = vgg16(weights=VGG16_Weights.IMAGENET1K_V1).features[:16].eval().to(device) + for param in self.vgg.parameters(): + param.requires_grad = False + self.loss = nn.L1Loss() + self.preprocess = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) + + def forward(self, pred, target): + if isinstance(pred, dict): + pred = pred["pred"] + if isinstance(target, dict): + target = target["pred"] + + # If 3D, fold Depth (dim 2) into Batch (dim 0) -> (B*D, C, H, W) + if pred.dim() == 5: + B, C, D, H, W = pred.shape + pred = pred.permute(0, 2, 1, 3, 4).reshape(B * D, C, H, W) + target = target.permute(0, 2, 1, 3, 4).reshape(B * D, C, H, W) + + # 2D behavior remains identical + if pred.shape[1] == 1: + pred = pred.repeat(1, 3, 1, 1) + target = target.repeat(1, 3, 1, 1) + + pred = self.preprocess(pred) + target = self.preprocess(target) + pred_vgg = self.vgg(pred) + target_vgg = self.vgg(target) + return self.loss(pred_vgg, target_vgg) + +class ComposedGANLoss(nn.Module): + """ + Dynamic composite loss for GANs. + Only instantiates heavy loss models (like VGG) if their config weight is > 0. + """ + def __init__(self, cfg, device): + super().__init__() + self.device = device + self.w_gan = cfg.LOSS.COMPOSED_GAN.LAMBDA_GAN + self.w_l1 = cfg.LOSS.COMPOSED_GAN.LAMBDA_RECON + self.w_vgg = cfg.LOSS.COMPOSED_GAN.ALPHA_PERCEPTUAL + self.w_ssim = cfg.LOSS.COMPOSED_GAN.GAMMA_SSIM + self.w_mse = cfg.LOSS.COMPOSED_GAN.DELTA_MSE + + # Dont load the vgg if not + if self.w_vgg > 0: + self.vgg = VGGLoss(device) + if self.w_ssim > 0: + self.ssim = StructuralSimilarityIndexMeasure(data_range=1.0).to(device) + + # Standard lightweight losses are always initialized + self.l1 = nn.L1Loss() + self.mse = nn.MSELoss() + self.bce = nn.BCEWithLogitsLoss() + + def forward_generator(self, pred, target, d_fake): + # Dict extraction + if isinstance(pred, dict): pred = pred["pred"] + if isinstance(target, dict): target = target["pred"] + + # NaN Band-aid + pred = torch.nan_to_num(pred, nan=0.0, posinf=1.0, neginf=-1.0) + target = torch.nan_to_num(target, nan=0.0, posinf=1.0, neginf=-1.0) + + total_loss = torch.tensor(0.0, device=self.device) + + # 2. Dynamically build the loss based on config weights + if self.w_l1 > 0: + total_loss += self.w_l1 * self.l1(pred, target) + + if self.w_mse > 0: + total_loss += self.w_mse * self.mse(pred, target) + + if self.w_vgg > 0: + total_loss += self.w_vgg * self.vgg(pred, target) + + if self.w_ssim > 0: + # SSIM requires 4D tensors. Safely route 3D to 2D slices. 
+ if pred.dim() == 5: + B, C, D, H, W = pred.shape + pred_ssim = pred.permute(0, 2, 1, 3, 4).reshape(B * D, C, H, W) + target_ssim = target.permute(0, 2, 1, 3, 4).reshape(B * D, C, H, W) + total_loss += self.w_ssim * (1.0 - self.ssim(pred_ssim, target_ssim)) + else: + total_loss += self.w_ssim * (1.0 - self.ssim(pred, target)) + + if self.w_gan > 0: + total_loss += self.w_gan * self.bce(d_fake, torch.ones_like(d_fake)) + + # NaN Safety Check + if torch.isnan(total_loss): + print("Warning: NaN detected in generator loss. Returning zero loss.") + total_loss = torch.tensor(0.0, requires_grad=True).to(self.device) + + return total_loss + + def forward_discriminator(self, d_real, d_fake): + # Calculate Adversarial Loss for Discriminator + real_loss = self.bce(d_real, torch.full_like(d_real, 0.9)) # Label smoothing (0.9 instead of 1.0) + fake_loss = self.bce(d_fake, torch.zeros_like(d_fake)) + total_loss = (real_loss + fake_loss) / 2.0 + + # NaN Safety Check + if torch.isnan(total_loss): + print("Warning: NaN detected in discriminator loss. Returning zero loss.") + total_loss = torch.tensor(0.0, requires_grad=True).to(self.device) + + return total_loss \ No newline at end of file diff --git a/biapy/engine/train_engine.py b/biapy/engine/train_engine.py index e4b05e3b9..c1939cabc 100644 --- a/biapy/engine/train_engine.py +++ b/biapy/engine/train_engine.py @@ -38,6 +38,9 @@ def train_one_epoch( memory_bank: Optional[MemoryBank] = None, total_iters: int=0, contrast_warmup_iters: int=0, + model_d: Optional[nn.Module | nn.parallel.DistributedDataParallel] = None, + optimizer_d: Optional[Optimizer] = None, + lr_scheduler_d: Optional[Scheduler] = None, ): """ Train the model for one epoch. @@ -87,18 +90,28 @@ def train_one_epoch( int Number of steps (batches) processed. """ + is_gan = model_d is not None and optimizer_d is not None + # Switch to training mode model.train(True) + if is_gan: + model_d.train(True) # Ensure correct order of each epoch info by adding loss first metric_logger = MetricLogger(delimiter=" ", verbose=verbose) - metric_logger.add_meter("loss", SmoothedValue()) + if is_gan: + metric_logger.add_meter("loss_g", SmoothedValue()) + metric_logger.add_meter("loss_d", SmoothedValue()) + else: + metric_logger.add_meter("loss", SmoothedValue()) # Set up the header for logging header = "Epoch: [{}]".format(epoch + 1) print_freq = 10 optimizer.zero_grad() + if is_gan: + optimizer_d.zero_grad() for step, (batch, targets) in enumerate(metric_logger.log_every(data_loader, print_freq, header)): @@ -111,6 +124,9 @@ def train_one_epoch( and isinstance(lr_scheduler, WarmUpCosineDecayScheduler) ): lr_scheduler.adjust_learning_rate(optimizer, step / len(data_loader) + epoch) + if is_gan and lr_scheduler_d and isinstance(lr_scheduler_d, WarmUpCosineDecayScheduler): + lr_scheduler_d.adjust_learning_rate(optimizer_d, step / len(data_loader) + epoch) + # Gather inputs targets = prepare_targets(targets, batch) @@ -121,6 +137,86 @@ def train_one_epoch( f" Input: {batch.shape[1:-1]} vs PATCH_SIZE: {cfg.DATA.PATCH_SIZE[:-1]}" ) + if is_gan: + assert model_d is not None and optimizer_d is not None + + if ( + torch.isnan(batch).any() + or torch.isinf(batch).any() + or torch.isnan(targets).any() + or torch.isinf(targets).any() + ): + print("Warning: NaN or Inf detected in input. 
Skipping batch.") + continue + + # Phase 1: discriminator update + optimizer_d.zero_grad() + fake_img = model_call_func(batch, is_train=True) + if isinstance(fake_img, dict): + fake_img = fake_img["pred"] + fake_img = torch.clamp(fake_img, 0, 1) + + d_real = model_d(targets) + d_fake = model_d(fake_img.detach()) + loss_d = loss_function.forward_discriminator(d_real, d_fake) + + if torch.isnan(loss_d) or torch.isinf(loss_d): + print("Warning: NaN or Inf detected in discriminator loss. Skipping batch.") + continue + + loss_d.backward() + optimizer_d.step() + + if lr_scheduler_d and isinstance(lr_scheduler_d, OneCycleLR) and cfg.TRAIN.LR_SCHEDULER.NAME == "onecycle": + lr_scheduler_d.step() + + # Phase 2: generator update + optimizer.zero_grad() + outputs = model_call_func(batch, is_train=True) + if isinstance(outputs, dict): + outputs = outputs["pred"] + outputs = torch.clamp(outputs, 0, 1) + + d_fake_for_g = model_d(outputs) + loss = loss_function.forward_generator(outputs, targets, d_fake_for_g) + + if torch.isnan(loss) or torch.isinf(loss): + print("Warning: NaN or Inf detected in generator loss. Skipping batch.") + continue + + loss.backward() + optimizer.step() + + if lr_scheduler and isinstance(lr_scheduler, OneCycleLR) and cfg.TRAIN.LR_SCHEDULER.NAME == "onecycle": + lr_scheduler.step() + + metric_function(outputs, targets, metric_logger=metric_logger) + + loss_g_value = loss.item() + loss_d_value = loss_d.item() + metric_logger.update(loss_g=loss_g_value, loss_d=loss_d_value) + + if log_writer: + log_writer.update(loss_g=all_reduce_mean(loss_g_value), head="loss") + log_writer.update(loss_d=all_reduce_mean(loss_d_value), head="loss") + + max_lr_g = 0.0 + max_lr_d = 0.0 + for group in optimizer.param_groups: + max_lr_g = max(max_lr_g, group["lr"]) + for group in optimizer_d.param_groups: + max_lr_d = max(max_lr_d, group["lr"]) + + if step == 0: + metric_logger.add_meter("lr_g", SmoothedValue(window_size=1, fmt="{value:.6f}")) + metric_logger.add_meter("lr_d", SmoothedValue(window_size=1, fmt="{value:.6f}")) + + metric_logger.update(lr_g=max_lr_g, lr_d=max_lr_d) + if log_writer: + log_writer.update(lr_g=max_lr_g, head="opt") + log_writer.update(lr_d=max_lr_d, head="opt") + continue + # Pass the images through the model outputs = model_call_func(batch, is_train=True) @@ -193,6 +289,12 @@ def train_one_epoch( if log_writer: log_writer.update(lr=max_lr, head="opt") + if is_gan and cfg.TRAIN.LR_SCHEDULER.NAME not in ["reduceonplateau", "onecycle", "warmupcosine"]: + if lr_scheduler: + lr_scheduler.step() + if lr_scheduler_d: + lr_scheduler_d.step() + # Gather the stats from all processes metric_logger.synchronize_between_processes() print("[Train] averaged stats:", metric_logger) @@ -211,6 +313,9 @@ def evaluate( data_loader: DataLoader, lr_scheduler: Optional[Scheduler] = None, memory_bank: Optional[MemoryBank] = None, + model_d: Optional[nn.Module | nn.parallel.DistributedDataParallel] = None, + lr_scheduler_d: Optional[Scheduler] = None, + device: Optional[torch.device] = None, ): """ Evaluate the model on the validation set. @@ -246,13 +351,21 @@ def evaluate( dict Dictionary of averaged metrics for the validation set. 
""" + is_gan = model_d is not None + # Ensure correct order of each epoch info by adding loss first metric_logger = MetricLogger(delimiter=" ") - metric_logger.add_meter("loss", SmoothedValue()) + if is_gan: + metric_logger.add_meter("loss_g", SmoothedValue()) + metric_logger.add_meter("loss_d", SmoothedValue()) + else: + metric_logger.add_meter("loss", SmoothedValue()) header = "Epoch: [{}]".format(epoch + 1) # Switch to evaluation mode model.eval() + if is_gan: + model_d.eval() for batch in metric_logger.log_every(data_loader, 10, header): # Gather inputs @@ -260,9 +373,30 @@ def evaluate( targets = batch[1] targets = prepare_targets(targets, images) + if is_gan: + assert model_d is not None + outputs = model_call_func(images, is_train=False) + if isinstance(outputs, dict): + outputs = outputs["pred"] + outputs = torch.clamp(outputs, 0, 1) + + d_fake_val = model_d(outputs) + d_real_val = model_d(targets) + loss = loss_function.forward_generator(outputs, targets, d_fake_val) + loss_d = loss_function.forward_discriminator(d_real_val, d_fake_val) + + loss_value = loss.item() + if not math.isfinite(loss_value): + print(f"Validation loss is {loss_value}, skipping batch.") + continue + + metric_function(outputs, targets, metric_logger=metric_logger) + metric_logger.update(loss_g=loss_value, loss_d=loss_d.item()) + continue + # Pass the images through the model outputs = model_call_func(images, is_train=True) - + # Loss function call if memory_bank is not None: with_embed = False @@ -285,7 +419,7 @@ def evaluate( precalculated_metric = loss[1] precalculated_metric_name = loss[2] loss = loss[0] - + loss_value = loss.item() if not math.isfinite(loss_value): print("Loss is {}, stopping training".format(loss_value)) @@ -306,10 +440,12 @@ def evaluate( print("[Val] averaged stats:", metric_logger) # Apply reduceonplateau scheduler if the global validation has been reduced - if ( - lr_scheduler - and isinstance(lr_scheduler, ReduceLROnPlateau) - and cfg.TRAIN.LR_SCHEDULER.NAME == "reduceonplateau" - ): - lr_scheduler.step(metric_logger.meters["loss"].global_avg, epoch=epoch) + if cfg.TRAIN.LR_SCHEDULER.NAME == "reduceonplateau": + if is_gan: + if lr_scheduler and isinstance(lr_scheduler, ReduceLROnPlateau): + lr_scheduler.step(metric_logger.meters["loss_g"].global_avg, epoch=epoch) + if lr_scheduler_d and isinstance(lr_scheduler_d, ReduceLROnPlateau): + lr_scheduler_d.step(metric_logger.meters["loss_d"].global_avg, epoch=epoch) + elif lr_scheduler and isinstance(lr_scheduler, ReduceLROnPlateau): + lr_scheduler.step(metric_logger.meters["loss"].global_avg, epoch=epoch) return {k: meter.global_avg for k, meter in metric_logger.meters.items()} diff --git a/biapy/models/__init__.py b/biapy/models/__init__.py index 5a0f98b23..238903f77 100644 --- a/biapy/models/__init__.py +++ b/biapy/models/__init__.py @@ -366,6 +366,20 @@ def build_model( ) model = MaskedAutoencoderViT(**args) # type: ignore callable_model = MaskedAutoencoderViT # type: ignore + elif modelname == "nafnet": + args = dict( + img_channel=cfg.DATA.PATCH_SIZE[-1], + width=cfg.MODEL.NAFNET.WIDTH, + middle_blk_num=cfg.MODEL.NAFNET.MIDDLE_BLK_NUM, + enc_blk_nums=cfg.MODEL.NAFNET.ENC_BLK_NUMS, + dec_blk_nums=cfg.MODEL.NAFNET.DEC_BLK_NUMS, + drop_out_rate=cfg.MODEL.DROPOUT_VALUES[0], + dw_expand=cfg.MODEL.NAFNET.DW_EXPAND, + ffn_expand=cfg.MODEL.NAFNET.FFN_EXPAND + ) + callable_model = NAFNet # type: ignore + model = callable_model(**args) # type: ignore + # Check the network created model.to(device) if cfg.PROBLEM.NDIM == "2D": @@ -405,6 +419,75 
@@ def build_model( return model, str(callable_model.__name__), collected_sources, all_import_lines, scanned_files, args, network_stride # type: ignore +def build_discriminator(cfg: CN, device: torch.device): + """ + Build selected model. + + Parameters + ---------- + cfg : YACS CN object + Configuration. + + device : Torch device + Using device. Most commonly "cpu" or "cuda" for GPU, but also potentially "mps", + "xpu", "xla" or "meta". + + Returns + ------- + """ + # 1. Standardize name and Import the module + modelname = str(cfg.MODEL.ARCHITECTURE_D).lower() + + print("###############") + print(f"# Build {modelname.upper()} Disc #") + print("###############") + + # Dynamic import like build_model + mdl = import_module("biapy.models." + modelname) + + names = [x for x in mdl.__dict__ if not x.startswith("_")] + globals().update({k: getattr(mdl, k) for k in names}) + + # 2. Model building block + if modelname == "patchgan": + args = dict( + in_channels=cfg.DATA.PATCH_SIZE[-1], + base_filters=cfg.MODEL.PATCHGAN.BASE_FILTERS + ) + callable_model = PatchGANDiscriminator # type: ignore + else: + raise ValueError(f"Discriminator {modelname} is not implemented or registered.") + + # Instantiate + model = callable_model(**args) + model.to(device) + + # 3. Summary Logic + if cfg.PROBLEM.NDIM == "2D": + sample_size = ( + 1, + cfg.DATA.PATCH_SIZE[2], + cfg.DATA.PATCH_SIZE[0], + cfg.DATA.PATCH_SIZE[1], + ) + else: + sample_size = ( + 1, + cfg.DATA.PATCH_SIZE[3], + cfg.DATA.PATCH_SIZE[0], + cfg.DATA.PATCH_SIZE[1], + cfg.DATA.PATCH_SIZE[2], + ) + + summary( + model, + input_size=sample_size, + col_names=("input_size", "output_size", "num_params"), + depth=10, + device=device.type, + ) + + return model def init_embedding_output(model: nn.Module, n_sigma: int = 2): """ diff --git a/biapy/models/nafnet.py b/biapy/models/nafnet.py new file mode 100644 index 000000000..868f1fdf1 --- /dev/null +++ b/biapy/models/nafnet.py @@ -0,0 +1,165 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +import math + +class SimpleGate(nn.Module): + def forward(self, x): + x1, x2 = x.chunk(2, dim=1) + return x1 * x2 + +class LayerNorm2d(nn.Module): + def __init__(self, channels, eps=1e-6): + super(LayerNorm2d, self).__init__() + self.register_parameter('weight', nn.Parameter(torch.ones(channels))) + self.register_parameter('bias', nn.Parameter(torch.zeros(channels))) + self.eps = eps + + def forward(self, x): + N, C, H, W = x.size() + mu = x.mean(1, keepdim=True) + var = (x - mu).pow(2).mean(1, keepdim=True) + y = (x - mu) / (var + self.eps).sqrt() + y = self.weight.view(1, C, 1, 1) * y + self.bias.view(1, C, 1, 1) + return y + +class NAFBlock(nn.Module): + def __init__(self, c, DW_Expand=2, FFN_Expand=2, drop_out_rate=0.): + super().__init__() + dw_channel = c * DW_Expand + self.conv1 = nn.Conv2d(in_channels=c, out_channels=dw_channel, kernel_size=1, padding=0, stride=1, groups=1, bias=True) + self.conv2 = nn.Conv2d(in_channels=dw_channel, out_channels=dw_channel, kernel_size=3, padding=1, stride=1, groups=dw_channel, bias=True) + self.conv3 = nn.Conv2d(in_channels=dw_channel // 2, out_channels=c, kernel_size=1, padding=0, stride=1, groups=1, bias=True) + + # Simplified Channel Attention + self.sca = nn.Sequential( + nn.AdaptiveAvgPool2d(1), + nn.Conv2d(in_channels=dw_channel // 2, out_channels=dw_channel // 2, kernel_size=1, padding=0, stride=1, groups=1, bias=True), + ) + + # SimpleGate + self.sg = SimpleGate() + + ffn_channel = FFN_Expand * c + self.conv4 = nn.Conv2d(in_channels=c, 
out_channels=ffn_channel, kernel_size=1, padding=0, stride=1, groups=1, bias=True) + self.conv5 = nn.Conv2d(in_channels=ffn_channel // 2, out_channels=c, kernel_size=1, padding=0, stride=1, groups=1, bias=True) + + self.norm1 = LayerNorm2d(c) + self.norm2 = LayerNorm2d(c) + + self.dropout1 = nn.Dropout(drop_out_rate) if drop_out_rate > 0. else nn.Identity() + self.dropout2 = nn.Dropout(drop_out_rate) if drop_out_rate > 0. else nn.Identity() + + self.beta = nn.Parameter(torch.zeros((1, c, 1, 1)), requires_grad=True) + self.gamma = nn.Parameter(torch.zeros((1, c, 1, 1)), requires_grad=True) + + def forward(self, inp): + x = inp + + x = self.norm1(x) + + x = self.conv1(x) + x = self.conv2(x) + x = self.sg(x) + x = x * self.sca(x) + x = self.conv3(x) + + x = self.dropout1(x) + + y = inp + x * self.beta + + x = self.conv4(self.norm2(y)) + x = self.sg(x) + x = self.conv5(x) + + x = self.dropout2(x) + + return y + x * self.gamma + +class NAFNet(nn.Module): + def __init__( + self, + img_channel=3, + width=16, + middle_blk_num=1, + enc_blk_nums=[], + dec_blk_nums=[], + drop_out_rate=0.0, + dw_expand=2, + ffn_expand=2 + ): + super().__init__() + + self.intro = nn.Conv2d(in_channels=img_channel, out_channels=width, kernel_size=3, padding=1, stride=1, groups=1, bias=True) + self.ending = nn.Conv2d(in_channels=width, out_channels=img_channel, kernel_size=3, padding=1, stride=1, groups=1, bias=True) + + self.encoders = nn.ModuleList() + self.decoders = nn.ModuleList() + self.middle_blks = nn.ModuleList() + self.ups = nn.ModuleList() + self.downs = nn.ModuleList() + + chan = width + for num in enc_blk_nums: + self.encoders.append( + nn.Sequential( + # Pass the new parameters into the NAFBlock + *[NAFBlock(chan, DW_Expand=dw_expand, FFN_Expand=ffn_expand, drop_out_rate=drop_out_rate) for _ in range(num)] + ) + ) + self.downs.append( + nn.Conv2d(chan, 2*chan, 2, 2) + ) + chan = chan * 2 + + self.middle_blks = nn.Sequential( + *[NAFBlock(chan, DW_Expand=dw_expand, FFN_Expand=ffn_expand, drop_out_rate=drop_out_rate) for _ in range(middle_blk_num)] + ) + + for num in dec_blk_nums: + self.ups.append( + nn.Sequential( + nn.Conv2d(chan, chan * 2, 1, bias=False), + nn.PixelShuffle(2) + ) + ) + chan = chan // 2 + self.decoders.append( + nn.Sequential( + *[NAFBlock(chan, DW_Expand=dw_expand, FFN_Expand=ffn_expand, drop_out_rate=drop_out_rate) for _ in range(num)] + ) + ) + + self.padder_size = 2 ** len(self.encoders) + + def forward(self, inp): + B, C, H, W = inp.shape + inp = self.check_image_size(inp) + + x = self.intro(inp) + + encs = [] + + for encoder, down in zip(self.encoders, self.downs): + x = encoder(x) + encs.append(x) + x = down(x) + + x = self.middle_blks(x) + + for decoder, up, enc_skip in zip(self.decoders, self.ups, encs[::-1]): + x = up(x) + x = x + enc_skip + x = decoder(x) + + x = self.ending(x) + x = x + inp + + return x[:, :, :H, :W] + + def check_image_size(self, x): + _, _, h, w = x.size() + mod_pad_h = (self.padder_size - h % self.padder_size) % self.padder_size + mod_pad_w = (self.padder_size - w % self.padder_size) % self.padder_size + x = F.pad(x, (0, mod_pad_w, 0, mod_pad_h)) + return x \ No newline at end of file diff --git a/biapy/models/patchgan.py b/biapy/models/patchgan.py new file mode 100644 index 000000000..bf8fa8933 --- /dev/null +++ b/biapy/models/patchgan.py @@ -0,0 +1,23 @@ +import torch.nn as nn + +class PatchGANDiscriminator(nn.Module): + def __init__(self, in_channels=1, base_filters=64): + super(PatchGANDiscriminator, self).__init__() + + def 
discriminator_block(in_filters, out_filters, normalization=True): + layers = [nn.Conv2d(in_filters, out_filters, 4, stride=2, padding=1)] + if normalization: + layers.append(nn.BatchNorm2d(out_filters)) + layers.append(nn.LeakyReLU(0.2, inplace=True)) + return layers + + self.model = nn.Sequential( + *discriminator_block(in_channels, base_filters, normalization=False), + *discriminator_block(base_filters, base_filters * 2), + *discriminator_block(base_filters * 2, base_filters * 4), + *discriminator_block(base_filters * 4, base_filters * 8), + nn.Conv2d(base_filters * 8, 1, 4, stride=1, padding=1) + ) + + def forward(self, img): + return self.model(img) \ No newline at end of file diff --git a/biapy/utils/misc.py b/biapy/utils/misc.py index 0ce41e73b..9c4dffa35 100644 --- a/biapy/utils/misc.py +++ b/biapy/utils/misc.py @@ -305,7 +305,19 @@ def get_grad_norm_(parameters, norm_type: float = 2.0) -> torch.Tensor: return total_norm -def save_model(cfg, biapy_version, jobname, epoch, model_without_ddp, optimizer, model_build_kwargs=None, extension="pth"): +def save_model( + cfg, + biapy_version, + jobname, + epoch, + model_without_ddp, + optimizer, + model_build_kwargs=None, + extension="pth", + discriminator_without_ddp=None, + optimizer_d=None, + extra_checkpoint_items=None, +): """ Save the model checkpoint to the specified path. @@ -333,6 +345,12 @@ def save_model(cfg, biapy_version, jobname, epoch, model_without_ddp, optimizer, extension : str, optional The file extension for the checkpoint file. Options are 'pth' (native PyTorch format) or 'safetensors' (https://github.com/huggingface/safetensors). Defaults to "pth". + discriminator_without_ddp : Optional[nn.Module], optional + Optional discriminator model to include in checkpoints for GAN training. + optimizer_d : Optional[torch.optim.Optimizer], optional + Optional discriminator optimizer state to include in checkpoints for GAN training. + extra_checkpoint_items : Optional[dict], optional + Additional custom fields to append to the checkpoint payload. 
Returns ------- @@ -352,6 +370,13 @@ def save_model(cfg, biapy_version, jobname, epoch, model_without_ddp, optimizer, "biapy_version": biapy_version, } + if discriminator_without_ddp is not None: + to_save["discriminator_state_dict"] = discriminator_without_ddp.state_dict() + if optimizer_d is not None: + to_save["optimizer_d_state_dict"] = optimizer_d.state_dict() + if extra_checkpoint_items: + to_save.update(extra_checkpoint_items) + save_on_master(to_save, checkpoint_path) if len(checkpoint_paths) > 0: return checkpoint_paths[0] From 3dfb5329796ecd65ff042c20d474f3a00cefeb49 Mon Sep 17 00:00:00 2001 From: Ibai Date: Wed, 8 Apr 2026 01:39:49 +0200 Subject: [PATCH 2/5] added nafnet as a model --- biapy/config/config.py | 27 +-- biapy/data/generators/__init__.py | 4 +- biapy/engine/__init__.py | 103 +++++----- biapy/engine/base_workflow.py | 84 +++----- biapy/engine/check_configuration.py | 33 ++- biapy/engine/denoising.py | 51 +---- biapy/engine/metrics.py | 115 ++++++++++- biapy/engine/train_engine.py | 305 ++++++++++------------------ biapy/models/__init__.py | 74 +------ biapy/models/nafnet.py | 200 +++++++++++++++++- biapy/models/patchgan.py | 72 +++++++ biapy/utils/misc.py | 57 +++--- biapy/utils/scripts/run_checks.py | 8 +- 13 files changed, 637 insertions(+), 496 deletions(-) diff --git a/biapy/config/config.py b/biapy/config/config.py index f5f73c361..27d302a42 100644 --- a/biapy/config/config.py +++ b/biapy/config/config.py @@ -1170,7 +1170,6 @@ def __init__(self, job_dir: str, job_identifier: str): _C.MODEL.ARCHITECTURE = "unet" # Architecture of the network. Possible values are: # * 'patchgan' - _C.MODEL.ARCHITECTURE_D = "patchgan" # Number of feature maps on each level of the network. _C.MODEL.FEATURE_MAPS = [16, 32, 64, 128, 256] # Values to make the dropout with. Set to 0 to prevent dropout. When using it with 'ViT' or 'unetr' @@ -1323,11 +1322,12 @@ def __init__(self, job_dir: str, job_identifier: str): _C.MODEL.NAFNET.DW_EXPAND = 2 # Expansion factor for the hidden layer within the feed-forward network. _C.MODEL.NAFNET.FFN_EXPAND = 2 - + # Discriminator architecture + _C.MODEL.NAFNET.ARCHITECTURE_D = "patchgan" # Discriminator PATCHGAN - _C.MODEL.PATCHGAN = CN() + _C.MODEL.NAFNET.PATCHGAN = CN() # Number of initial convolutional filters in the first layer of the discriminator. - _C.MODEL.PATCHGAN.BASE_FILTERS = 64 + _C.MODEL.NAFNET.PATCHGAN.BASE_FILTERS = 64 # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # Loss @@ -1409,9 +1409,6 @@ def __init__(self, job_dir: str, job_identifier: str): # Weight for SSIM term. _C.LOSS.COMPOSED_GAN.GAMMA_SSIM = 1.0 - # Backward-compatible alias for previous naming. - _C.LOSS.GAN = CN() - # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # Training phase # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ @@ -1419,20 +1416,14 @@ def __init__(self, job_dir: str, job_identifier: str): _C.TRAIN.ENABLE = False # Enable verbosity _C.TRAIN.VERBOSE = False - # Optimizer to use. Possible values: "SGD", "ADAM" or "ADAMW" - _C.TRAIN.OPTIMIZER = "SGD" - # Optimizer to use. Possible values: "SGD", "ADAM" or "ADAMW" for GAN discriminator - _C.TRAIN.OPTIMIZER_D = "SGD" - # Learning rate - _C.TRAIN.LR = 1.0e-4 - # Learning rate for GAN discriminator - _C.TRAIN.LR_D = 1.0e-4 + # Optimizer(s) to use. Possible values: "SGD", "ADAM" or "ADAMW". + _C.TRAIN.OPTIMIZER = ["SGD"] + # Learning rate(s). 
+ _C.TRAIN.LR = [1.0e-4] # Weight decay _C.TRAIN.W_DECAY = 0.02 # Coefficients used for computing running averages of gradient and its square. Used in ADAM and ADAMW optmizers - _C.TRAIN.OPT_BETAS = (0.9, 0.999) - # Coefficients used for computing running averages of gradient and its square. Used in ADAM and ADAMW optmizers for GANS discriminator - _C.TRAIN.OPT_BETAS_D = (0.5, 0.999) + _C.TRAIN.OPT_BETAS = [(0.9, 0.999)] # Batch size _C.TRAIN.BATCH_SIZE = 2 # If memory or # gpus is limited, use this variable to maintain the effective batch size, which is diff --git a/biapy/data/generators/__init__.py b/biapy/data/generators/__init__.py index dcc7720f0..783009fa2 100644 --- a/biapy/data/generators/__init__.py +++ b/biapy/data/generators/__init__.py @@ -246,7 +246,7 @@ def create_train_val_augmentors( dic["zflip"] = cfg.AUGMENTOR.ZFLIP if cfg.PROBLEM.TYPE == "INSTANCE_SEG": dic["instance_problem"] = True - elif cfg.PROBLEM.TYPE == "DENOISING" and cfg.LOSS.TYPE != "COMPOSED_GAN": + elif cfg.PROBLEM.TYPE == "DENOISING" and cfg.MODEL.ARCHITECTURE != 'nafnet': dic["n2v"] = True dic["n2v_perc_pix"] = cfg.PROBLEM.DENOISING.N2V_PERC_PIX dic["n2v_manipulator"] = cfg.PROBLEM.DENOISING.N2V_MANIPULATOR @@ -293,7 +293,7 @@ def create_train_val_augmentors( ) if cfg.PROBLEM.TYPE == "INSTANCE_SEG": dic["instance_problem"] = True - elif cfg.PROBLEM.TYPE == "DENOISING" and cfg.LOSS.TYPE != "COMPOSED_GAN": + elif cfg.PROBLEM.TYPE == "DENOISING" and cfg.MODEL.ARCHITECTURE != 'nafnet': dic["n2v"] = True dic["n2v_perc_pix"] = cfg.PROBLEM.DENOISING.N2V_PERC_PIX dic["n2v_manipulator"] = cfg.PROBLEM.DENOISING.N2V_MANIPULATOR diff --git a/biapy/engine/__init__.py b/biapy/engine/__init__.py index 1873d9ac7..0656e8e4c 100644 --- a/biapy/engine/__init__.py +++ b/biapy/engine/__init__.py @@ -21,8 +21,7 @@ def prepare_optimizer( cfg: CN, model_without_ddp: nn.Module | nn.parallel.DistributedDataParallel, steps_per_epoch: int, - is_gan: bool = False, -) -> Tuple[Optimizer, Scheduler | None]: +) -> Tuple[list[Optimizer], list[Scheduler | None]]: """ Create and configure the optimizer and learning rate scheduler for the given model. @@ -34,21 +33,54 @@ def prepare_optimizer( ---------- cfg : YACS CN object Configuration object with optimizer and scheduler settings. - model_without_ddp : nn.Module or nn.parallel.DistributedDataParallel or dict + model_without_ddp : nn.Module or nn.parallel.DistributedDataParallel The model to optimize. steps_per_epoch : int Number of steps (batches) per training epoch. - is_gan : bool, optional - Whether to create optimizer/scheduler pairs for GAN generator and discriminator. Returns ------- - optimizer : Optimizer or dict - Configured optimizer for the model or dict with generator/discriminator optimizers in GAN mode. - lr_scheduler : Scheduler or None or dict - Configured scheduler for the model or dict with generator/discriminator schedulers in GAN mode. + optimizer : List[Optimizer] + Configured optimizers for the models. + lr_scheduler : Scheduler or None + Configured learning rate schedulers, or None if not specified. 
""" - def _make_scheduler(optimizer: Optimizer, lr_value: float) -> Scheduler | None: + + optimizers = [] + lr_schedulers = [] + + if hasattr(model_without_ddp, 'discriminator') and model_without_ddp.discriminator is not None: + param_groups = [ + [p for n, p in model_without_ddp.named_parameters() if not n.startswith("discriminator.")], # Generator + model_without_ddp.discriminator.parameters() # Discriminator + ] + else: + param_groups = [model_without_ddp.parameters()] + + ## Not quite sure if this is the best place to do this + if len(cfg.TRAIN.OPTIMIZER) != len(param_groups): + raise ValueError( + f"Configuration mismatch: You requested {len(cfg.TRAIN.OPTIMIZER)} optimizers, " + f"but the model has {len(param_groups)} parameter group(s). " + f"Check your TRAIN.OPTIMIZER list in the config." + ) + + for i in range(len(cfg.TRAIN.OPTIMIZER)): + lr = cfg.TRAIN.LR if cfg.TRAIN.LR_SCHEDULER.NAME != "warmupcosine" else cfg.TRAIN.LR_SCHEDULER.MIN_LR + opt_args = {} + if cfg.TRAIN.OPTIMIZER[i] in ["ADAM", "ADAMW"]: + opt_args["betas"] = cfg.TRAIN.OPT_BETAS[i] if i < len(cfg.TRAIN.OPT_BETAS) else cfg.TRAIN.OPT_BETAS[0] + optimizer = timm.optim.create_optimizer_v2( + param_groups[i], + opt=cfg.TRAIN.OPTIMIZER[i], + lr=lr, + weight_decay=cfg.TRAIN.W_DECAY, + **opt_args, + ) + print(optimizer) + optimizers.append(optimizer) + + # Learning rate schedulers lr_scheduler = None if cfg.TRAIN.LR_SCHEDULER.NAME != "": if cfg.TRAIN.LR_SCHEDULER.NAME == "reduceonplateau": @@ -60,7 +92,7 @@ def _make_scheduler(optimizer: Optimizer, lr_value: float) -> Scheduler | None: ) elif cfg.TRAIN.LR_SCHEDULER.NAME == "warmupcosine": lr_scheduler = WarmUpCosineDecayScheduler( - lr=lr_value, + lr=cfg.TRAIN.LR[i], min_lr=cfg.TRAIN.LR_SCHEDULER.MIN_LR, warmup_epochs=cfg.TRAIN.LR_SCHEDULER.WARMUP_COSINE_DECAY_EPOCHS, epochs=cfg.TRAIN.EPOCHS, @@ -68,55 +100,14 @@ def _make_scheduler(optimizer: Optimizer, lr_value: float) -> Scheduler | None: elif cfg.TRAIN.LR_SCHEDULER.NAME == "onecycle": lr_scheduler = OneCycleLR( optimizer, - lr_value, + cfg.TRAIN.LR[i], epochs=cfg.TRAIN.EPOCHS, steps_per_epoch=steps_per_epoch, ) - return lr_scheduler - - def _make_optimizer(model: nn.Module | nn.parallel.DistributedDataParallel, train_cfg: dict): - lr_value = train_cfg["lr"] - opt_name = train_cfg["optimizer"] - betas = train_cfg["betas"] - w_decay = train_cfg["weight_decay"] + + lr_schedulers.append(lr_scheduler) - lr = lr_value if cfg.TRAIN.LR_SCHEDULER.NAME != "warmupcosine" else cfg.TRAIN.LR_SCHEDULER.MIN_LR - opt_args = {} - if opt_name in ["ADAM", "ADAMW"]: - opt_args["betas"] = betas - - optimizer = timm.optim.create_optimizer_v2( - model, - opt=opt_name, - lr=lr, - weight_decay=w_decay, - **opt_args, - ) - print(optimizer) - lr_scheduler = _make_scheduler(optimizer, lr_value) - return optimizer, lr_scheduler - - g_train_cfg = { - "lr": cfg.TRAIN.LR, - "optimizer": cfg.TRAIN.OPTIMIZER, - "betas": cfg.TRAIN.OPT_BETAS, - "weight_decay": cfg.TRAIN.W_DECAY, - } - - if not is_gan: - return _make_optimizer(model_without_ddp, g_train_cfg) - - d_train_cfg = { - "lr": cfg.TRAIN.LR_D, - "optimizer": cfg.TRAIN.OPTIMIZER_D, - "betas": cfg.TRAIN.OPT_BETAS_D, - "weight_decay": cfg.TRAIN.W_DECAY, - } - - optimizer_g, scheduler_g = _make_optimizer(model_without_ddp["generator"], g_train_cfg) - optimizer_d, scheduler_d = _make_optimizer(model_without_ddp["discriminator"], d_train_cfg) - - return {"generator": optimizer_g, "discriminator": optimizer_d}, {"generator": scheduler_g, "discriminator": scheduler_d,} + return optimizers, lr_schedulers def 
build_callbacks(cfg: CN) -> EarlyStopping | None: diff --git a/biapy/engine/base_workflow.py b/biapy/engine/base_workflow.py index 5f2b48464..68eafa849 100644 --- a/biapy/engine/base_workflow.py +++ b/biapy/engine/base_workflow.py @@ -837,11 +837,6 @@ def prepare_logging_tool(self): self.plot_values = {} self.plot_values["loss"] = [] self.plot_values["val_loss"] = [] - if getattr(self, "is_gan_mode", False): - self.plot_values["loss_g"] = [] - self.plot_values["loss_d"] = [] - self.plot_values["val_loss_g"] = [] - self.plot_values["val_loss_d"] = [] for i in range(len(self.train_metric_names)): self.plot_values[self.train_metric_names[i]] = [] self.plot_values["val_" + self.train_metric_names[i]] = [] @@ -858,28 +853,9 @@ def train(self): assert ( self.start_epoch is not None and self.model is not None and self.model_without_ddp is not None and self.loss ) - is_gan_mode = getattr(self, "is_gan_mode", False) - if is_gan_mode: - discriminator_wo_ddp = getattr(self, "discriminator_without_ddp", None) - if discriminator_wo_ddp is None: - raise ValueError("GAN mode requires a discriminator model before training") - optimizers, schedulers = prepare_optimizer( - self.cfg, - { - "generator": self.model_without_ddp, - "discriminator": discriminator_wo_ddp, - }, - len(self.train_generator), - is_gan=True, - ) - self.optimizer = optimizers["generator"] - self.lr_scheduler = schedulers["generator"] - self.optimizer_d = optimizers["discriminator"] - self.lr_scheduler_d = schedulers["discriminator"] - else: - self.optimizer, self.lr_scheduler = prepare_optimizer( - self.cfg, self.model_without_ddp, len(self.train_generator) - ) + self.optimizer, self.lr_scheduler = prepare_optimizer( + self.cfg, self.model_without_ddp, len(self.train_generator) + ) contrast_init_iter = 0 if self.cfg.LOSS.CONTRAST.ENABLE: @@ -934,9 +910,6 @@ def train(self): memory_bank=self.memory_bank, total_iters=total_iters, contrast_warmup_iters=contrast_init_iter, - model_d=getattr(self, "discriminator", None) if is_gan_mode else None, - optimizer_d=getattr(self, "optimizer_d", None) if is_gan_mode else None, - lr_scheduler_d=getattr(self, "lr_scheduler_d", None) if is_gan_mode else None, ) total_iters += iterations_done @@ -956,8 +929,6 @@ def train(self): epoch=epoch + 1, model_build_kwargs=self.model_build_kwargs, extension=self.cfg.MODEL.OUT_CHECKPOINT_FORMAT, - discriminator_without_ddp=getattr(self, "discriminator_without_ddp", None) if is_gan_mode else None, - optimizer_d=getattr(self, "optimizer_d", None) if is_gan_mode else None, ) # Validation @@ -973,29 +944,24 @@ def train(self): data_loader=self.val_generator, lr_scheduler=self.lr_scheduler, memory_bank=self.memory_bank, - model_d=getattr(self, "discriminator", None) if is_gan_mode else None, - lr_scheduler_d=getattr(self, "lr_scheduler_d", None) if is_gan_mode else None, - device=self.device if is_gan_mode else None, ) - val_loss_key = "loss_g" if is_gan_mode else "loss" - # Save checkpoint is val loss improved - if test_stats[val_loss_key] < self.val_best_loss: + if test_stats["loss"] < self.val_best_loss: f = os.path.join( self.cfg.PATHS.CHECKPOINT, "{}-checkpoint-best.pth".format(self.job_identifier), ) print( "Val loss improved from {} to {}, saving model to {}".format( - self.val_best_loss, test_stats[val_loss_key], f + self.val_best_loss, test_stats["loss"], f ) ) m = " " for i in range(len(self.val_best_metric)): self.val_best_metric[i] = test_stats[self.train_metric_names[i]] m += f"{self.train_metric_names[i]}: {self.val_best_metric[i]:.4f} " - 
self.val_best_loss = test_stats[val_loss_key] + self.val_best_loss = test_stats["loss"] if is_main_process(): self.checkpoint_path = save_model( @@ -1007,14 +973,12 @@ def train(self): epoch="best", model_build_kwargs=self.model_build_kwargs, extension=self.cfg.MODEL.OUT_CHECKPOINT_FORMAT, - discriminator_without_ddp=getattr(self, "discriminator_without_ddp", None) if is_gan_mode else None, - optimizer_d=getattr(self, "optimizer_d", None) if is_gan_mode else None, ) print(f"[Val] best loss: {self.val_best_loss:.4f} best " + m) # Store validation stats if self.log_writer: - self.log_writer.update(test_loss=test_stats[val_loss_key], head="perf", step=epoch) + self.log_writer.update(test_loss=test_stats["loss"], head="perf", step=epoch) for i in range(len(self.train_metric_names)): self.log_writer.update( test_iou=test_stats[self.train_metric_names[i]], @@ -1042,20 +1006,22 @@ def train(self): f.write(json.dumps(log_stats) + "\n") # Create training plot - train_loss_key = "loss_g" if is_gan_mode else "loss" - val_loss_key = "loss_g" if is_gan_mode else "loss" - self.plot_values["loss"].append(train_stats[train_loss_key]) + self.plot_values["loss"].append(train_stats["loss"]) if self.val_generator: - self.plot_values["val_loss"].append(test_stats[val_loss_key]) - if is_gan_mode: - if "loss_g" in self.plot_values: - self.plot_values["loss_g"].append(train_stats.get("loss_g", 0)) - if "loss_d" in self.plot_values: - self.plot_values["loss_d"].append(train_stats.get("loss_d", 0)) - if self.val_generator and "val_loss_g" in self.plot_values: - self.plot_values["val_loss_g"].append(test_stats.get("loss_g", 0)) - if self.val_generator and "val_loss_d" in self.plot_values: - self.plot_values["val_loss_d"].append(test_stats.get("loss_d", 0)) + self.plot_values["val_loss"].append(test_stats["loss"]) + extra_loss_keys = [k for k in train_stats if "loss" in k and k != "loss"] + for loss_key in extra_loss_keys: + val_loss_key = f"val_{loss_key}" + + if loss_key not in self.plot_values: + self.plot_values[loss_key] = [] + if self.val_generator: + self.plot_values[val_loss_key] = [] + + # Append the values + self.plot_values[loss_key].append(train_stats[loss_key]) + if self.val_generator: + self.plot_values[val_loss_key].append(test_stats.get(loss_key, 0.0)) for i in range(len(self.train_metric_names)): self.plot_values[self.train_metric_names[i]].append(train_stats[self.train_metric_names[i]]) if self.val_generator: @@ -1071,8 +1037,7 @@ def train(self): ) if self.val_generator and self.early_stopping: - val_loss_key = "loss_g" if is_gan_mode else "loss" - self.early_stopping(test_stats[val_loss_key]) + self.early_stopping(test_stats["loss"]) if self.early_stopping.early_stop: print("Early stopping") break @@ -1091,8 +1056,7 @@ def train(self): self.total_training_time_str = str(datetime.timedelta(seconds=int(total_time))) print("Training time: {}".format(self.total_training_time_str)) - train_loss_key = "loss_g" if is_gan_mode else "loss" - self.train_metrics_message += ("Train loss: {}\n".format(train_stats[train_loss_key])) + self.train_metrics_message += ("Train loss: {}\n".format(train_stats["loss"])) for i in range(len(self.train_metric_names)): self.train_metrics_message += ("Train {}: {}\n".format(self.train_metric_names[i], train_stats[self.train_metric_names[i]])) if self.val_generator: diff --git a/biapy/engine/check_configuration.py b/biapy/engine/check_configuration.py index 0374e17c7..a6da41bc6 100644 --- a/biapy/engine/check_configuration.py +++ b/biapy/engine/check_configuration.py @@ 
-1797,9 +1797,12 @@ def sort_key(item): #### Denoising #### elif cfg.PROBLEM.TYPE == "DENOISING": - if cfg.LOSS.TYPE == "COMPOSED_GAN": + if cfg.PROBLEM.DENOISING.LOAD_GT_DATA or cfg.LOSS.TYPE == "COMPOSED_GAN": if not cfg.DATA.TRAIN.GT_PATH and not cfg.DATA.TRAIN.INPUT_ZARR_MULTIPLE_DATA: - raise ValueError("Denoising with COMPOSED_GAN is supervised. 'DATA.TRAIN.GT_PATH' is required.") + raise ValueError( + "Supervised denoising (e.g., with COMPOSED_GAN or LOAD_GT_DATA=True) " + "requires ground truth. 'DATA.TRAIN.GT_PATH' must be provided." + ) else: if cfg.DATA.TEST.LOAD_GT: raise ValueError( @@ -2720,11 +2723,26 @@ def sort_key(item): assert cfg.MODEL.OUT_CHECKPOINT_FORMAT in ["pth", "safetensors"], "MODEL.OUT_CHECKPOINT_FORMAT not in ['pth', 'safetensors']" ### Train ### - assert cfg.TRAIN.OPTIMIZER in [ - "SGD", - "ADAM", - "ADAMW", - ], "TRAIN.OPTIMIZER not in ['SGD', 'ADAM', 'ADAMW']" + if not isinstance(cfg.TRAIN.OPTIMIZER, list): + raise ValueError("'TRAIN.OPTIMIZER' must be a list") + if not isinstance(cfg.TRAIN.LR, list): + raise ValueError("'TRAIN.LR' must be a list") + if not isinstance(cfg.TRAIN.OPT_BETAS, list): + raise ValueError("'TRAIN.OPT_BETAS' must be a list") + if len(cfg.TRAIN.OPTIMIZER) != len(cfg.TRAIN.LR): + raise ValueError("'TRAIN.OPTIMIZER' and 'TRAIN.LR' must have the same length") + print(cfg.TRAIN.OPT_BETAS) + print(len(cfg.TRAIN.OPT_BETAS)) + if len(cfg.TRAIN.OPT_BETAS) not in [1, len(cfg.TRAIN.OPTIMIZER)]: + raise ValueError("'TRAIN.OPT_BETAS' must have length 1 or match 'TRAIN.OPTIMIZER' length") + + for beta_pair in cfg.TRAIN.OPT_BETAS: + if not isinstance(beta_pair, (list, tuple)) or len(beta_pair) != 2: + raise ValueError("Each entry in 'TRAIN.OPT_BETAS' must be a tuple/list of length 2") + + for opt in cfg.TRAIN.OPTIMIZER: + if opt not in ["SGD", "ADAM", "ADAMW"]: + raise ValueError("'TRAIN.OPTIMIZER' values must be in ['SGD', 'ADAM', 'ADAMW']") if cfg.TRAIN.ENABLE and cfg.TRAIN.LR_SCHEDULER.NAME != "": if cfg.TRAIN.LR_SCHEDULER.NAME not in [ @@ -2982,7 +3000,6 @@ def compare_configurations_without_model(actual_cfg, old_cfg, header_message="", "DATA.PATCH_SIZE", "PROBLEM.INSTANCE_SEG.DATA_CHANNELS", "PROBLEM.SUPER_RESOLUTION.UPSCALING", - "MODEL.ARCHITECTURE_D", "DATA.N_CLASSES", ] diff --git a/biapy/engine/denoising.py b/biapy/engine/denoising.py index b5c0efac8..a22bcb20a 100644 --- a/biapy/engine/denoising.py +++ b/biapy/engine/denoising.py @@ -24,9 +24,8 @@ merge_3D_data_with_overlap, ) from biapy.engine.base_workflow import Base_Workflow -from biapy.models import build_discriminator from biapy.data.data_manipulation import save_tif -from biapy.utils.misc import to_pytorch_format, MetricLogger +from biapy.utils.misc import to_pytorch_format, is_main_process, MetricLogger from biapy.engine.metrics import n2v_loss_mse, loss_encapsulation, ComposedGANLoss @@ -73,20 +72,10 @@ def __init__(self, cfg, job_identifier, device, system_dict, args, **kwargs): # From now on, no modification of the cfg will be allowed self.cfg.freeze() - self.is_gan_mode = str(cfg.LOSS.TYPE).upper() == "COMPOSED_GAN" - self.discriminator = None - self.discriminator_without_ddp = None - self.optimizer_d = None - self.lr_scheduler_d = None - # Workflow specific training variables - if self.is_gan_mode: - self.mask_path = cfg.DATA.TRAIN.GT_PATH - self.load_Y_val = True - else: - self.mask_path = cfg.DATA.TRAIN.GT_PATH if cfg.PROBLEM.DENOISING.LOAD_GT_DATA else None - self.load_Y_val = cfg.PROBLEM.DENOISING.LOAD_GT_DATA + self.mask_path = cfg.DATA.TRAIN.GT_PATH if 
cfg.PROBLEM.DENOISING.LOAD_GT_DATA else None self.is_y_mask = False + self.load_Y_val = cfg.PROBLEM.DENOISING.LOAD_GT_DATA self.norm_module.mask_norm = "as_image" self.test_norm_module.mask_norm = "as_image" @@ -245,7 +234,7 @@ def metric_calculation( with torch.no_grad(): for i, metric in enumerate(list_to_use): - if self.is_gan_mode: + if _targets.shape[1] == _output.shape[1]: target_for_metric = _targets.contiguous() else: target_for_metric = _targets[:, _output.shape[1]:].contiguous() @@ -345,36 +334,6 @@ def process_test_sample(self): verbose=self.cfg.TEST.VERBOSE, ) - def prepare_model(self): - """Build generator model and discriminator when running denoising in GAN mode.""" - super().prepare_model() - - # It is not a GAN model, or we alredy have a discriminator loaded from checkpoint - if not self.is_gan_mode or self.discriminator is not None: - return - - print("#######################") - print("# Build Discriminator #") - print("#######################") - self.discriminator = build_discriminator(self.cfg, self.device) - self.discriminator_without_ddp = self.discriminator - - if self.args.distributed: - self.discriminator = torch.nn.parallel.DistributedDataParallel( - self.discriminator, - device_ids=[self.args.gpu], - find_unused_parameters=False, - ) - self.discriminator_without_ddp = self.discriminator.module - - if self.cfg.MODEL.SOURCE == "biapy" and self.cfg.MODEL.LOAD_CHECKPOINT and self.checkpoint_path: - checkpoint = torch.load(self.checkpoint_path, map_location=self.device) - if "discriminator_state_dict" in checkpoint: - self.discriminator_without_ddp.load_state_dict(checkpoint["discriminator_state_dict"]) - print("Discriminator weights loaded successfully.") - else: - print("Warning: 'discriminator_state_dict' not found in checkpoint.") - def torchvision_model_call(self, in_img: torch.Tensor, is_train: bool = False) -> torch.Tensor | None: """ Call a regular Pytorch model. @@ -980,4 +939,4 @@ def get_value_manipulation(n2v_manipulator, n2v_neighborhood_radius): Callable Value manipulation function. """ - return eval("pm_{0}({1})".format(n2v_manipulator, str(n2v_neighborhood_radius))) \ No newline at end of file + return eval("pm_{0}({1})".format(n2v_manipulator, str(n2v_neighborhood_radius))) diff --git a/biapy/engine/metrics.py b/biapy/engine/metrics.py index 80bb21db1..09d4df49a 100644 --- a/biapy/engine/metrics.py +++ b/biapy/engine/metrics.py @@ -2167,7 +2167,36 @@ def forward( return loss + prediction.sum() * 0, float(iou), "IoU" # keep graph identical to originals class VGGLoss(nn.Module): + """Perceptual loss based on VGG16 feature activations. + + This loss compares intermediate VGG feature maps of prediction and target + images using an L1 distance. It is commonly used as a perceptual term in + image-to-image GAN training. + + Notes + ----- + - Uses pretrained ``torchvision.models.vgg16`` features up to layer ``:16``. + - Supports both 2D `(B, C, H, W)` and 3D `(B, C, D, H, W)` tensors. + For 3D inputs, depth is folded into batch to reuse 2D VGG. + - Single-channel inputs are replicated to 3 channels before VGG. + + References + ---------- + - Johnson et al., "Perceptual Losses for Real-Time Style Transfer and + Super-Resolution", ECCV 2016. + https://arxiv.org/abs/1603.08155 + - Implementation adapted for this project from: + https://github.com/GolpedeRemo37/NafNet-in-AI4Life-Microscopy-Supervised-Denoising-Challenge + """ + def __init__(self, device): + """Initialize VGG perceptual loss. 
+ + Parameters + ---------- + device : torch.device + Device where VGG features and loss operations are executed. + """ super().__init__() self.vgg = vgg16(weights=VGG16_Weights.IMAGENET1K_V1).features[:16].eval().to(device) for param in self.vgg.parameters(): @@ -2176,6 +2205,20 @@ def __init__(self, device): self.preprocess = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) def forward(self, pred, target): + """Compute perceptual distance between prediction and target. + + Parameters + ---------- + pred : torch.Tensor or dict + Predicted image tensor. If dict, prediction is taken from ``pred['pred']``. + target : torch.Tensor or dict + Target image tensor. If dict, target is taken from ``target['pred']``. + + Returns + ------- + torch.Tensor + Scalar perceptual loss value (L1 over VGG features). + """ if isinstance(pred, dict): pred = pred["pred"] if isinstance(target, dict): @@ -2199,11 +2242,45 @@ def forward(self, pred, target): return self.loss(pred_vgg, target_vgg) class ComposedGANLoss(nn.Module): + """Weighted composite loss for generator and discriminator training. + + This class combines multiple objectives for GAN-based image restoration: + + - Adversarial BCE term + - L1 reconstruction term + - MSE reconstruction term + - VGG perceptual term + - SSIM term + + Each term is controlled by configuration weights under + ``LOSS.COMPOSED_GAN``. Heavy components (VGG/SSIM modules) are created only + when their weight is greater than zero. + + References + ---------- + - Isola et al., "Image-to-Image Translation with Conditional Adversarial + Networks", CVPR 2017 (pix2pix). + https://arxiv.org/abs/1611.07004 + - Generator family inspiration (NAFNet/NAFSSR): + Chu et al., "NAFSSR: Stereo Image Super-Resolution Using NAFNet", + CVPR Workshops 2022. + https://openaccess.thecvf.com/content/CVPR2022W/NTIRE/html/Chu_NAFSSR_Stereo_Image_Super-Resolution_Using_NAFNet_CVPRW_2022_paper.html + - Structural/perceptual metrics are implemented with torchmetrics + (e.g., ``torchmetrics.image.StructuralSimilarityIndexMeasure``). + - Implementation adapted for this project from: + https://github.com/GolpedeRemo37/NafNet-in-AI4Life-Microscopy-Supervised-Denoising-Challenge """ - Dynamic composite loss for GANs. - Only instantiates heavy loss models (like VGG) if their config weight is > 0. - """ + def __init__(self, cfg, device): + """Initialize composed GAN loss from configuration. + + Parameters + ---------- + cfg : yacs.config.CfgNode + Global configuration node. Uses ``cfg.LOSS.COMPOSED_GAN`` weights. + device : torch.device + Device where loss terms are computed. + """ super().__init__() self.device = device self.w_gan = cfg.LOSS.COMPOSED_GAN.LAMBDA_GAN @@ -2224,6 +2301,22 @@ def __init__(self, cfg, device): self.bce = nn.BCEWithLogitsLoss() def forward_generator(self, pred, target, d_fake): + """Compute weighted generator loss. + + Parameters + ---------- + pred : torch.Tensor or dict + Generator prediction. If dict, reads ``pred['pred']``. + target : torch.Tensor or dict + Ground-truth target. If dict, reads ``target['pred']``. + d_fake : torch.Tensor + Discriminator logits for generated samples. + + Returns + ------- + torch.Tensor + Scalar generator loss as weighted sum of active terms. 
+ """ # Dict extraction if isinstance(pred, dict): pred = pred["pred"] if isinstance(target, dict): target = target["pred"] @@ -2265,6 +2358,22 @@ def forward_generator(self, pred, target, d_fake): return total_loss def forward_discriminator(self, d_real, d_fake): + """Compute discriminator adversarial loss. + + Uses BCE with one-sided label smoothing for real logits. + + Parameters + ---------- + d_real : torch.Tensor + Discriminator logits for real samples. + d_fake : torch.Tensor + Discriminator logits for generated samples. + + Returns + ------- + torch.Tensor + Scalar discriminator loss. + """ # Calculate Adversarial Loss for Discriminator real_loss = self.bce(d_real, torch.full_like(d_real, 0.9)) # Label smoothing (0.9 instead of 1.0) fake_loss = self.bce(d_fake, torch.zeros_like(d_fake)) diff --git a/biapy/engine/train_engine.py b/biapy/engine/train_engine.py index c1939cabc..7be23222a 100644 --- a/biapy/engine/train_engine.py +++ b/biapy/engine/train_engine.py @@ -29,18 +29,15 @@ def train_one_epoch( metric_function: Callable, prepare_targets: Callable, data_loader: DataLoader, - optimizer: Optimizer, + optimizer: list[Optimizer], device: torch.device, epoch: int, log_writer: Optional[TensorboardLogger] = None, - lr_scheduler: Optional[Scheduler] = None, + lr_scheduler: Optional[list[Optional[Scheduler]]] = None, verbose: bool = False, memory_bank: Optional[MemoryBank] = None, total_iters: int=0, contrast_warmup_iters: int=0, - model_d: Optional[nn.Module | nn.parallel.DistributedDataParallel] = None, - optimizer_d: Optional[Optimizer] = None, - lr_scheduler_d: Optional[Scheduler] = None, ): """ Train the model for one epoch. @@ -64,7 +61,7 @@ def train_one_epoch( Function to prepare targets for loss/metrics. data_loader : DataLoader Training data loader. - optimizer : Optimizer + optimizer : List[Optimizer] Optimizer for model parameters. device : torch.device Device to use. @@ -72,7 +69,7 @@ def train_one_epoch( Current epoch number. log_writer : TensorboardLogger, optional Logger for TensorBoard. - lr_scheduler : Scheduler, optional + lr_scheduler : List[Scheduler], optional Learning rate scheduler. verbose : bool, optional Verbosity flag. @@ -90,28 +87,23 @@ def train_one_epoch( int Number of steps (batches) processed. 
""" - is_gan = model_d is not None and optimizer_d is not None - # Switch to training mode model.train(True) - if is_gan: - model_d.train(True) + has_discriminator = hasattr(model, "discriminator") and model.discriminator is not None + lr_scheduler = [None] * len(optimizer) if lr_scheduler is None else lr_scheduler # Ensure correct order of each epoch info by adding loss first metric_logger = MetricLogger(delimiter=" ", verbose=verbose) - if is_gan: - metric_logger.add_meter("loss_g", SmoothedValue()) - metric_logger.add_meter("loss_d", SmoothedValue()) - else: - metric_logger.add_meter("loss", SmoothedValue()) + for i in range(len(optimizer)): + loss_name = "loss" if i == 0 else f"loss_{i}" + metric_logger.add_meter(loss_name, SmoothedValue()) # Set up the header for logging header = "Epoch: [{}]".format(epoch + 1) print_freq = 10 - optimizer.zero_grad() - if is_gan: - optimizer_d.zero_grad() + for opt in optimizer: + opt.zero_grad() for step, (batch, targets) in enumerate(metric_logger.log_every(data_loader, print_freq, header)): @@ -120,13 +112,10 @@ def train_one_epoch( if ( epoch % cfg.TRAIN.ACCUM_ITER == 0 and cfg.TRAIN.LR_SCHEDULER.NAME == "warmupcosine" - and lr_scheduler - and isinstance(lr_scheduler, WarmUpCosineDecayScheduler) ): - lr_scheduler.adjust_learning_rate(optimizer, step / len(data_loader) + epoch) - if is_gan and lr_scheduler_d and isinstance(lr_scheduler_d, WarmUpCosineDecayScheduler): - lr_scheduler_d.adjust_learning_rate(optimizer_d, step / len(data_loader) + epoch) - + for sched, opt in zip(lr_scheduler, optimizer): + if sched and isinstance(sched, WarmUpCosineDecayScheduler): + sched.adjust_learning_rate(opt, step / len(data_loader) + epoch) # Gather inputs targets = prepare_targets(targets, batch) @@ -137,91 +126,25 @@ def train_one_epoch( f" Input: {batch.shape[1:-1]} vs PATCH_SIZE: {cfg.DATA.PATCH_SIZE[:-1]}" ) - if is_gan: - assert model_d is not None and optimizer_d is not None - - if ( - torch.isnan(batch).any() - or torch.isinf(batch).any() - or torch.isnan(targets).any() - or torch.isinf(targets).any() - ): - print("Warning: NaN or Inf detected in input. Skipping batch.") - continue - - # Phase 1: discriminator update - optimizer_d.zero_grad() - fake_img = model_call_func(batch, is_train=True) - if isinstance(fake_img, dict): - fake_img = fake_img["pred"] - fake_img = torch.clamp(fake_img, 0, 1) - - d_real = model_d(targets) - d_fake = model_d(fake_img.detach()) - loss_d = loss_function.forward_discriminator(d_real, d_fake) - - if torch.isnan(loss_d) or torch.isinf(loss_d): - print("Warning: NaN or Inf detected in discriminator loss. Skipping batch.") - continue - - loss_d.backward() - optimizer_d.step() - - if lr_scheduler_d and isinstance(lr_scheduler_d, OneCycleLR) and cfg.TRAIN.LR_SCHEDULER.NAME == "onecycle": - lr_scheduler_d.step() - - # Phase 2: generator update - optimizer.zero_grad() - outputs = model_call_func(batch, is_train=True) - if isinstance(outputs, dict): - outputs = outputs["pred"] - outputs = torch.clamp(outputs, 0, 1) - - d_fake_for_g = model_d(outputs) - loss = loss_function.forward_generator(outputs, targets, d_fake_for_g) - - if torch.isnan(loss) or torch.isinf(loss): - print("Warning: NaN or Inf detected in generator loss. 
Skipping batch.") - continue - - loss.backward() - optimizer.step() - - if lr_scheduler and isinstance(lr_scheduler, OneCycleLR) and cfg.TRAIN.LR_SCHEDULER.NAME == "onecycle": - lr_scheduler.step() - - metric_function(outputs, targets, metric_logger=metric_logger) - - loss_g_value = loss.item() - loss_d_value = loss_d.item() - metric_logger.update(loss_g=loss_g_value, loss_d=loss_d_value) - - if log_writer: - log_writer.update(loss_g=all_reduce_mean(loss_g_value), head="loss") - log_writer.update(loss_d=all_reduce_mean(loss_d_value), head="loss") - - max_lr_g = 0.0 - max_lr_d = 0.0 - for group in optimizer.param_groups: - max_lr_g = max(max_lr_g, group["lr"]) - for group in optimizer_d.param_groups: - max_lr_d = max(max_lr_d, group["lr"]) - - if step == 0: - metric_logger.add_meter("lr_g", SmoothedValue(window_size=1, fmt="{value:.6f}")) - metric_logger.add_meter("lr_d", SmoothedValue(window_size=1, fmt="{value:.6f}")) - - metric_logger.update(lr_g=max_lr_g, lr_d=max_lr_d) - if log_writer: - log_writer.update(lr_g=max_lr_g, head="opt") - log_writer.update(lr_d=max_lr_d, head="opt") - continue - # Pass the images through the model outputs = model_call_func(batch, is_train=True) # Loss function call - if memory_bank is not None: + losses = [] + if has_discriminator and len(optimizer) > 1: + fake_img = outputs["pred"] if isinstance(outputs, dict) else outputs + fake_img = torch.clamp(fake_img, 0, 1) + + d_fake_for_g = model.discriminator(fake_img) + loss_g = loss_function.forward_generator(fake_img, targets, d_fake_for_g) + losses.append(loss_g) + + d_real = model.discriminator(targets) + d_fake = model.discriminator(fake_img.detach()) + loss_d = loss_function.forward_discriminator(d_real, d_fake) + losses.append(loss_d) + + elif memory_bank is not None: if total_iters + step >= contrast_warmup_iters: with_embed = True else: @@ -240,20 +163,23 @@ def train_one_epoch( memory_bank.dequeue_and_enqueue( outputs['key'], targets.detach(), ) + losses.append(loss) else: loss = loss_function(outputs, targets) + losses.append(loss) # Separate metric if precalculated inside the loss (e.g. 
Embedding loss) precalculated_metric, precalculated_metric_name = None, None - if isinstance(loss, tuple): - precalculated_metric = loss[1] - precalculated_metric_name = loss[2] - loss = loss[0] + if isinstance(losses[0], tuple): + precalculated_metric = losses[0][1] + precalculated_metric_name = losses[0][2] + losses[0] = losses[0][0] - loss_value = loss.item() - if not math.isfinite(loss_value): - print("Loss is {}, stopping training".format(loss_value)) - sys.exit(1) + for l_val in losses: + loss_value = l_val.item() + if not math.isfinite(loss_value): + print("Loss is {}, stopping training".format(loss_value)) + sys.exit(1) # Calculate the metrics if precalculated_metric is None: @@ -262,38 +188,45 @@ def train_one_epoch( metric_logger.meters[precalculated_metric_name].update(precalculated_metric) # Forward pass scaling the loss - loss /= cfg.TRAIN.ACCUM_ITER if (step + 1) % cfg.TRAIN.ACCUM_ITER == 0: - loss.backward() - optimizer.step() # update weight - optimizer.zero_grad() - if lr_scheduler and isinstance(lr_scheduler, OneCycleLR) and cfg.TRAIN.LR_SCHEDULER.NAME == "onecycle": - lr_scheduler.step() + for i, (opt, loss_tensor) in enumerate(zip(optimizer, losses)): + loss_tensor = loss_tensor / cfg.TRAIN.ACCUM_ITER + + loss_tensor.backward() + opt.step() # update weight + opt.zero_grad() + + if lr_scheduler[i] and isinstance(lr_scheduler[i], OneCycleLR) and cfg.TRAIN.LR_SCHEDULER.NAME == "onecycle": + lr_scheduler[i].step() if device.type != "cpu": getattr(torch, device.type).synchronize() # Update loss in loggers - metric_logger.update(loss=loss_value) - loss_value_reduce = all_reduce_mean(loss_value) - if log_writer: - log_writer.update(loss=loss_value_reduce, head="loss") + for i, loss_tensor in enumerate(losses): + loss_name = "loss" if i == 0 else f"loss_{i}" + val = loss_tensor.item() * cfg.TRAIN.ACCUM_ITER + metric_logger.update(**{loss_name: val}) + loss_value_reduce = all_reduce_mean(val) + if log_writer: + log_writer.update(head="loss", **{loss_name: loss_value_reduce}) # Update lr in loggers - max_lr = 0.0 - for group in optimizer.param_groups: - max_lr = max(max_lr, group["lr"]) - if step == 0: - metric_logger.add_meter("lr", SmoothedValue(window_size=1, fmt="{value:.6f}")) - metric_logger.update(lr=max_lr) - if log_writer: - log_writer.update(lr=max_lr, head="opt") - - if is_gan and cfg.TRAIN.LR_SCHEDULER.NAME not in ["reduceonplateau", "onecycle", "warmupcosine"]: - if lr_scheduler: - lr_scheduler.step() - if lr_scheduler_d: - lr_scheduler_d.step() + for i, opt in enumerate(optimizer): + lr_name = "lr" if i == 0 else f"lr_{i}" + max_lr = 0.0 + for group in opt.param_groups: + max_lr = max(max_lr, group["lr"]) + if step == 0: + metric_logger.add_meter(lr_name, SmoothedValue(window_size=1, fmt="{value:.6f}")) + metric_logger.update(**{lr_name: max_lr}) + if log_writer: + log_writer.update(head="opt", **{lr_name: max_lr}) + + if cfg.TRAIN.LR_SCHEDULER.NAME not in ["reduceonplateau", "onecycle", "warmupcosine"]: + for sched in lr_scheduler: + if sched: + sched.step() # Gather the stats from all processes metric_logger.synchronize_between_processes() @@ -311,11 +244,8 @@ def evaluate( prepare_targets: Callable, epoch: int, data_loader: DataLoader, - lr_scheduler: Optional[Scheduler] = None, + lr_scheduler: Optional[list[Optional[Scheduler]]] = None, memory_bank: Optional[MemoryBank] = None, - model_d: Optional[nn.Module | nn.parallel.DistributedDataParallel] = None, - lr_scheduler_d: Optional[Scheduler] = None, - device: Optional[torch.device] = None, ): """ Evaluate the 
model on the validation set. @@ -351,21 +281,17 @@ def evaluate( dict Dictionary of averaged metrics for the validation set. """ - is_gan = model_d is not None - + has_discriminator = hasattr(model, "discriminator") and model.discriminator is not None # Ensure correct order of each epoch info by adding loss first metric_logger = MetricLogger(delimiter=" ") - if is_gan: - metric_logger.add_meter("loss_g", SmoothedValue()) - metric_logger.add_meter("loss_d", SmoothedValue()) - else: - metric_logger.add_meter("loss", SmoothedValue()) + num_losses = 2 if has_discriminator and len(lr_scheduler) > 1 else 1 + for i in range(num_losses): + loss_name = "loss" if i == 0 else f"loss_{i}" + metric_logger.add_meter(loss_name, SmoothedValue()) header = "Epoch: [{}]".format(epoch + 1) # Switch to evaluation mode model.eval() - if is_gan: - model_d.eval() for batch in metric_logger.log_every(data_loader, 10, header): # Gather inputs @@ -373,34 +299,25 @@ def evaluate( targets = batch[1] targets = prepare_targets(targets, images) - if is_gan: - assert model_d is not None - outputs = model_call_func(images, is_train=False) - if isinstance(outputs, dict): - outputs = outputs["pred"] - outputs = torch.clamp(outputs, 0, 1) - - d_fake_val = model_d(outputs) - d_real_val = model_d(targets) - loss = loss_function.forward_generator(outputs, targets, d_fake_val) - loss_d = loss_function.forward_discriminator(d_real_val, d_fake_val) - - loss_value = loss.item() - if not math.isfinite(loss_value): - print(f"Validation loss is {loss_value}, skipping batch.") - continue - - metric_function(outputs, targets, metric_logger=metric_logger) - metric_logger.update(loss_g=loss_value, loss_d=loss_d.item()) - continue - # Pass the images through the model - outputs = model_call_func(images, is_train=True) - + outputs = model_call_func(images, is_train=True) # Im not Undertanding why is this True? + # Loss function call - if memory_bank is not None: - with_embed = False + losses = [] + if has_discriminator and len(lr_scheduler) > 1: + fake_img = outputs["pred"] if isinstance(outputs, dict) else outputs + fake_img = torch.clamp(fake_img, 0, 1) + + d_fake_for_g = model.discriminator(fake_img) + loss_g = loss_function.forward_generator(fake_img, targets, d_fake_for_g) + losses.append(loss_g) + d_real = model.discriminator(targets) + d_fake = model.discriminator(fake_img.detach()) + loss_d = loss_function.forward_discriminator(d_real, d_fake) + losses.append(loss_d) + + elif memory_bank is not None: outputs = { "pred": outputs["pred"], "embed": outputs["embed"], @@ -408,22 +325,24 @@ def evaluate( 'pixel_queue': memory_bank.pixel_queue, 'segment_queue': memory_bank.segment_queue, } - loss = loss_function(outputs, targets, with_embed=with_embed) + losses.append(loss) else: loss = loss_function(outputs, targets) + losses.append(loss) # Separate metric if precalculated inside the loss (e.g. 
Embedding loss) precalculated_metric, precalculated_metric_name = None, None - if isinstance(loss, tuple): - precalculated_metric = loss[1] - precalculated_metric_name = loss[2] - loss = loss[0] + if isinstance(losses[0], tuple): + precalculated_metric = losses[0][1] + precalculated_metric_name = losses[0][2] + losses[0] = losses[0][0] - loss_value = loss.item() - if not math.isfinite(loss_value): - print("Loss is {}, stopping training".format(loss_value)) - sys.exit(1) + for l_val in losses: + loss_value = l_val.item() + if not math.isfinite(loss_value): + print("Loss is {}, stopping training".format(loss_value)) + sys.exit(1) # Calculate the metrics if precalculated_metric is not None: @@ -432,7 +351,9 @@ def evaluate( metric_function(outputs, targets, metric_logger=metric_logger) # Update loss in loggers - metric_logger.update(loss=loss) + for i, loss_tensor in enumerate(losses): + loss_name = "loss" if i == 0 else f"loss_{i}" + metric_logger.update(**{loss_name: loss_tensor.item()}) # Gather the stats from all processes metric_logger.synchronize_between_processes() @@ -441,11 +362,9 @@ def evaluate( # Apply reduceonplateau scheduler if the global validation has been reduced if cfg.TRAIN.LR_SCHEDULER.NAME == "reduceonplateau": - if is_gan: - if lr_scheduler and isinstance(lr_scheduler, ReduceLROnPlateau): - lr_scheduler.step(metric_logger.meters["loss_g"].global_avg, epoch=epoch) - if lr_scheduler_d and isinstance(lr_scheduler_d, ReduceLROnPlateau): - lr_scheduler_d.step(metric_logger.meters["loss_d"].global_avg, epoch=epoch) - elif lr_scheduler and isinstance(lr_scheduler, ReduceLROnPlateau): - lr_scheduler.step(metric_logger.meters["loss"].global_avg, epoch=epoch) + for i, sched in enumerate(lr_scheduler): + if sched and isinstance(sched, ReduceLROnPlateau): + loss_name = "loss" if i == 0 else f"loss_{i}" + sched.step(metric_logger.meters[loss_name].global_avg, epoch=epoch) + return {k: meter.global_avg for k, meter in metric_logger.meters.items()} diff --git a/biapy/models/__init__.py b/biapy/models/__init__.py index 238903f77..e34740190 100644 --- a/biapy/models/__init__.py +++ b/biapy/models/__init__.py @@ -375,7 +375,9 @@ def build_model( dec_blk_nums=cfg.MODEL.NAFNET.DEC_BLK_NUMS, drop_out_rate=cfg.MODEL.DROPOUT_VALUES[0], dw_expand=cfg.MODEL.NAFNET.DW_EXPAND, - ffn_expand=cfg.MODEL.NAFNET.FFN_EXPAND + ffn_expand=cfg.MODEL.NAFNET.FFN_EXPAND, + discriminator_arch=cfg.MODEL.NAFNET.ARCHITECTURE_D, + patchgan_base_filters=cfg.MODEL.NAFNET.PATCHGAN.BASE_FILTERS, ) callable_model = NAFNet # type: ignore model = callable_model(**args) # type: ignore @@ -419,76 +421,6 @@ def build_model( return model, str(callable_model.__name__), collected_sources, all_import_lines, scanned_files, args, network_stride # type: ignore -def build_discriminator(cfg: CN, device: torch.device): - """ - Build selected model. - - Parameters - ---------- - cfg : YACS CN object - Configuration. - - device : Torch device - Using device. Most commonly "cpu" or "cuda" for GPU, but also potentially "mps", - "xpu", "xla" or "meta". - - Returns - ------- - """ - # 1. Standardize name and Import the module - modelname = str(cfg.MODEL.ARCHITECTURE_D).lower() - - print("###############") - print(f"# Build {modelname.upper()} Disc #") - print("###############") - - # Dynamic import like build_model - mdl = import_module("biapy.models." + modelname) - - names = [x for x in mdl.__dict__ if not x.startswith("_")] - globals().update({k: getattr(mdl, k) for k in names}) - - # 2. 
Model building block - if modelname == "patchgan": - args = dict( - in_channels=cfg.DATA.PATCH_SIZE[-1], - base_filters=cfg.MODEL.PATCHGAN.BASE_FILTERS - ) - callable_model = PatchGANDiscriminator # type: ignore - else: - raise ValueError(f"Discriminator {modelname} is not implemented or registered.") - - # Instantiate - model = callable_model(**args) - model.to(device) - - # 3. Summary Logic - if cfg.PROBLEM.NDIM == "2D": - sample_size = ( - 1, - cfg.DATA.PATCH_SIZE[2], - cfg.DATA.PATCH_SIZE[0], - cfg.DATA.PATCH_SIZE[1], - ) - else: - sample_size = ( - 1, - cfg.DATA.PATCH_SIZE[3], - cfg.DATA.PATCH_SIZE[0], - cfg.DATA.PATCH_SIZE[1], - cfg.DATA.PATCH_SIZE[2], - ) - - summary( - model, - input_size=sample_size, - col_names=("input_size", "output_size", "num_params"), - depth=10, - device=device.type, - ) - - return model - def init_embedding_output(model: nn.Module, n_sigma: int = 2): """ Initialize the output layer of the model for embedding. diff --git a/biapy/models/nafnet.py b/biapy/models/nafnet.py index 868f1fdf1..38b7f0cc0 100644 --- a/biapy/models/nafnet.py +++ b/biapy/models/nafnet.py @@ -1,21 +1,103 @@ +"""NAFNet model components and GAN discriminator builder utilities. + +This module provides: + +1. Lightweight building blocks (`SimpleGate`, `LayerNorm2d`, `NAFBlock`) used + by NAFNet. +2. The `NAFNet` encoder-decoder model for image restoration / image-to-image + workflows. +3. A discriminator builder function used by GAN-based training setups in BiaPy. + +Compared with traditional restoration backbones, NAFNet simplifies nonlinear +design while preserving strong reconstruction quality via gated depthwise blocks +and residual scaling. + +Reference +--------- +`Simple Baselines for Image Restoration `_. + +Related Work +------------ +The generator design is also inspired by the NAFSSR family: +`NAFSSR: Stereo Image Super-Resolution Using NAFNet +. +Implementation adapted for this project from: +https://github.com/GolpedeRemo37/NafNet-in-AI4Life-Microscopy-Supervised-Denoising-Challenge +Citation +-------- +Chu, Xiaojie and Chen, Liangyu and Yu, Wenqing. "NAFSSR: Stereo Image +Super-Resolution Using NAFNet." CVPR Workshops, 2022. +""" + import torch import torch.nn as nn import torch.nn.functional as F -import math +from yacs.config import CfgNode as CN +from torchinfo import summary + +from biapy.models.patchgan import PatchGANDiscriminator class SimpleGate(nn.Module): + """Simple channel-gating operator used in NAF blocks. + + The input tensor is split into two equal channel groups and both parts are + multiplied element-wise. + """ + def forward(self, x): + """Apply channel-wise gating. + + Parameters + ---------- + x : torch.Tensor + Input tensor with shape `(N, C, H, W)` where `C` must be divisible + by 2. + + Returns + ------- + torch.Tensor + Tensor with shape `(N, C/2, H, W)` obtained by multiplying both + channel chunks element-wise. + """ x1, x2 = x.chunk(2, dim=1) return x1 * x2 + class LayerNorm2d(nn.Module): + """Layer normalization over channel dimension for 2D features. + + This normalization computes mean and variance across channels for each + spatial position and applies learned affine parameters. + """ + def __init__(self, channels, eps=1e-6): + """Initialize layer normalization parameters. + + Parameters + ---------- + channels : int + Number of channels in the input tensor. + eps : float, optional + Numerical stability constant added to the variance. 
+ """ super(LayerNorm2d, self).__init__() self.register_parameter('weight', nn.Parameter(torch.ones(channels))) self.register_parameter('bias', nn.Parameter(torch.zeros(channels))) self.eps = eps def forward(self, x): + """Normalize each spatial position across channels. + + Parameters + ---------- + x : torch.Tensor + Input tensor with shape `(N, C, H, W)`. + + Returns + ------- + torch.Tensor + Normalized tensor with same shape as input. + """ N, C, H, W = x.size() mu = x.mean(1, keepdim=True) var = (x - mu).pow(2).mean(1, keepdim=True) @@ -23,8 +105,32 @@ def forward(self, x): y = self.weight.view(1, C, 1, 1) * y + self.bias.view(1, C, 1, 1) return y + class NAFBlock(nn.Module): + """Core NAFNet residual block. + + The block combines: + 1. Layer normalization. + 2. Pointwise + depthwise convolutions. + 3. `SimpleGate` and simplified channel attention. + 4. A lightweight FFN branch. + 5. Two residual scaling parameters (`beta`, `gamma`). + """ + def __init__(self, c, DW_Expand=2, FFN_Expand=2, drop_out_rate=0.): + """Initialize one NAF block. + + Parameters + ---------- + c : int + Number of input/output channels in the block. + DW_Expand : int, optional + Expansion ratio for the depthwise branch before gating. + FFN_Expand : int, optional + Expansion ratio for the feed-forward branch. + drop_out_rate : float, optional + Dropout probability used in both residual branches. + """ super().__init__() dw_channel = c * DW_Expand self.conv1 = nn.Conv2d(in_channels=c, out_channels=dw_channel, kernel_size=1, padding=0, stride=1, groups=1, bias=True) @@ -54,6 +160,18 @@ def __init__(self, c, DW_Expand=2, FFN_Expand=2, drop_out_rate=0.): self.gamma = nn.Parameter(torch.zeros((1, c, 1, 1)), requires_grad=True) def forward(self, inp): + """Apply the NAF block transformation. + + Parameters + ---------- + inp : torch.Tensor + Input feature map with shape `(N, C, H, W)`. + + Returns + ------- + torch.Tensor + Output feature map with the same shape as `inp`. + """ x = inp x = self.norm1(x) @@ -77,6 +195,15 @@ def forward(self, inp): return y + x * self.gamma class NAFNet(nn.Module): + """NAFNet encoder-decoder architecture for image restoration. + + The model follows a U-shaped design with: + 1. Intro and ending convolutions. + 2. Multiple encoder stages with downsampling. + 3. Bottleneck NAF blocks. + 4. Decoder stages with PixelShuffle upsampling and skip connections. + """ + def __init__( self, img_channel=3, @@ -86,8 +213,36 @@ def __init__( dec_blk_nums=[], drop_out_rate=0.0, dw_expand=2, - ffn_expand=2 + ffn_expand=2, + discriminator_arch=None, + patchgan_base_filters=64, ): + """Initialize a NAFNet model. + + Parameters + ---------- + img_channel : int, optional + Number of input/output image channels. + width : int, optional + Base number of channels. + middle_blk_num : int, optional + Number of NAF blocks in the bottleneck. + enc_blk_nums : list[int], optional + Number of NAF blocks per encoder stage. + dec_blk_nums : list[int], optional + Number of NAF blocks per decoder stage. + drop_out_rate : float, optional + Dropout probability used inside blocks. + dw_expand : int, optional + Expansion ratio for depthwise branch. + ffn_expand : int, optional + Expansion ratio for feed-forward branch. + + Notes + ----- + Spatial padding is handled in `check_image_size` to ensure dimensions are + divisible by the encoder downsampling factor. 
+ """ super().__init__() self.intro = nn.Conv2d(in_channels=img_channel, out_channels=width, kernel_size=3, padding=1, stride=1, groups=1, bias=True) @@ -131,8 +286,35 @@ def __init__( ) self.padder_size = 2 ** len(self.encoders) + + discriminator = None + if discriminator_arch == "patchgan": + discriminator = PatchGANDiscriminator( + in_channels=img_channel, + base_filters=patchgan_base_filters, + ) + + self.discriminator = discriminator + def forward(self, inp): + """Run a forward pass through NAFNet. + + Parameters + ---------- + inp : torch.Tensor + Input image tensor with shape `(N, C, H, W)`. + + Notes + ----- + The input is internally padded to satisfy the downsampling factor and + then cropped back to original size at the end of the forward pass. + + Returns + ------- + torch.Tensor + Restored image with original spatial size `(H, W)`. + """ B, C, H, W = inp.shape inp = self.check_image_size(inp) @@ -158,8 +340,20 @@ def forward(self, inp): return x[:, :, :H, :W] def check_image_size(self, x): + """Pad image so height/width are divisible by internal stride. + + Parameters + ---------- + x : torch.Tensor + Input tensor with shape `(N, C, H, W)`. + + Returns + ------- + torch.Tensor + Padded tensor compatible with encoder/decoder downsampling. + """ _, _, h, w = x.size() mod_pad_h = (self.padder_size - h % self.padder_size) % self.padder_size mod_pad_w = (self.padder_size - w % self.padder_size) % self.padder_size x = F.pad(x, (0, mod_pad_w, 0, mod_pad_h)) - return x \ No newline at end of file + return x diff --git a/biapy/models/patchgan.py b/biapy/models/patchgan.py index bf8fa8933..b7e736ddf 100644 --- a/biapy/models/patchgan.py +++ b/biapy/models/patchgan.py @@ -1,10 +1,70 @@ +"""PatchGAN discriminator model used in image-to-image GAN training. + +This module implements a convolutional discriminator that predicts realism at +the patch level instead of producing a single global score. Patch-level +classification is commonly used in conditional GAN pipelines because it +emphasizes local texture and edge consistency, which is especially useful in +restoration and translation tasks. + +Classes +------- +PatchGANDiscriminator + Multi-layer convolutional discriminator with strided downsampling blocks and + a final 1-channel logits map. + +Notes +----- +The output tensor shape is `(N, 1, H_patch, W_patch)`, where each spatial value +acts as a local real/fake logit for a receptive-field patch in the input image. + +Implementation adapted for this project from: +https://github.com/GolpedeRemo37/NafNet-in-AI4Life-Microscopy-Supervised-Denoising-Challenge + +""" + import torch.nn as nn + class PatchGANDiscriminator(nn.Module): + """PatchGAN discriminator based on strided convolutional blocks. + + Parameters + ---------- + in_channels : int, optional + Number of channels in the input image. + base_filters : int, optional + Number of filters in the first discriminator block. Each subsequent + block doubles this value. + + Notes + ----- + The architecture follows a typical PatchGAN design: + 1. Four convolutional downsampling blocks. + 2. Batch normalization on all blocks except the first one. + 3. LeakyReLU activations. + 4. Final convolution producing a patch-logits map. + """ + def __init__(self, in_channels=1, base_filters=64): super(PatchGANDiscriminator, self).__init__() def discriminator_block(in_filters, out_filters, normalization=True): + """Create one discriminator stage. + + Parameters + ---------- + in_filters : int + Number of input channels. 
+ out_filters : int + Number of output channels. + normalization : bool, optional + Whether to include BatchNorm after convolution. + + Returns + ------- + list[nn.Module] + Layers composing one stage of the discriminator. + """ layers = [nn.Conv2d(in_filters, out_filters, 4, stride=2, padding=1)] if normalization: layers.append(nn.BatchNorm2d(out_filters)) @@ -20,4 +80,16 @@ def discriminator_block(in_filters, out_filters, normalization=True): ) def forward(self, img): + """Run a forward pass through the discriminator. + + Parameters + ---------- + img : torch.Tensor + Input tensor with shape `(N, C, H, W)`. + + Returns + ------- + torch.Tensor + Patch-wise realism logits with shape `(N, 1, H_patch, W_patch)`. + """ return self.model(img) \ No newline at end of file diff --git a/biapy/utils/misc.py b/biapy/utils/misc.py index 9c4dffa35..56aa1c8c4 100644 --- a/biapy/utils/misc.py +++ b/biapy/utils/misc.py @@ -305,19 +305,7 @@ def get_grad_norm_(parameters, norm_type: float = 2.0) -> torch.Tensor: return total_norm -def save_model( - cfg, - biapy_version, - jobname, - epoch, - model_without_ddp, - optimizer, - model_build_kwargs=None, - extension="pth", - discriminator_without_ddp=None, - optimizer_d=None, - extra_checkpoint_items=None, -): +def save_model(cfg, biapy_version, jobname, epoch, model_without_ddp, optimizer, model_build_kwargs=None, extension="pth"): """ Save the model checkpoint to the specified path. @@ -337,7 +325,7 @@ def save_model( The current epoch number. model_without_ddp : nn.Module The model instance, typically the unwrapped model if using DistributedDataParallel. - optimizer : torch.optim.Optimizer + optimizer : List[torch.optim.Optimizer] The optimizer's state. model_build_kwargs : Optional[Dict], optional Keyword arguments used to build the model, useful for re-instantiating @@ -345,12 +333,6 @@ def save_model( extension : str, optional The file extension for the checkpoint file. Options are 'pth' (native PyTorch format) or 'safetensors' (https://github.com/huggingface/safetensors). Defaults to "pth". - discriminator_without_ddp : Optional[nn.Module], optional - Optional discriminator model to include in checkpoints for GAN training. - optimizer_d : Optional[torch.optim.Optimizer], optional - Optional discriminator optimizer state to include in checkpoints for GAN training. - extra_checkpoint_items : Optional[dict], optional - Additional custom fields to append to the checkpoint payload. Returns ------- @@ -364,18 +346,15 @@ def save_model( to_save = { "model_build_kwargs": model_build_kwargs, "model": model_without_ddp.state_dict(), - "optimizer": optimizer.state_dict(), + "optimizer": [opt.state_dict() for opt in optimizer] if optimizer else None, "epoch": epoch, "cfg": cfg, "biapy_version": biapy_version, } - - if discriminator_without_ddp is not None: - to_save["discriminator_state_dict"] = discriminator_without_ddp.state_dict() - if optimizer_d is not None: - to_save["optimizer_d_state_dict"] = optimizer_d.state_dict() - if extra_checkpoint_items: - to_save.update(extra_checkpoint_items) + + # For Gan Models + if hasattr(model_without_ddp, 'discriminator'): + to_save["discriminator_state_dict"] = model_without_ddp.discriminator.state_dict() save_on_master(to_save, checkpoint_path) if len(checkpoint_paths) > 0: @@ -479,8 +458,8 @@ def load_model_checkpoint(cfg, jobname, model_without_ddp, device, optimizer=Non The model instance (unwrapped if DDP is used) to load weights into. device : torch.device The device to map the loaded checkpoint to. 
- optimizer : Optional[torch.optim.Optimizer], optional - The optimizer instance to load state into. If None, optimizer state is not loaded. + optimizer : Optional[List[torch.optim.Optimizer]], optional + The list of optimizer instances to load state into. If None, optimizer state is not loaded. Defaults to None. just_extract_checkpoint_info : bool, optional If True, only the configuration (`cfg`) and BiaPy version from the checkpoint @@ -575,13 +554,27 @@ def load_model_checkpoint(cfg, jobname, model_without_ddp, device, optimizer=Non print("Model weights loaded!") + if "discriminator_state_dict" in checkpoint: + if hasattr(model_without_ddp, 'discriminator') and model_without_ddp.discriminator is not None: + # We use strict=False just in case there are minor architecture changes + model_without_ddp.discriminator.load_state_dict(checkpoint["discriminator_state_dict"], strict=False) + print("Discriminator weights loaded!") + if cfg.MODEL.LOAD_CHECKPOINT_ONLY_WEIGHTS: return start_epoch, resume # Load also opt, epoch and scaler info if "optimizer" in checkpoint and optimizer is not None: - optimizer.load_state_dict(checkpoint["optimizer"], strict=False) - print("Optimizer info loaded!") + # im leaving this for prior non list optimizers for backward compatibility, + checkpoint_optimizer = checkpoint["optimizer"] + if isinstance(checkpoint_optimizer, dict): + checkpoint_optimizer = [checkpoint_optimizer] + + loaded_optimizers = 0 + for opt, opt_state in zip(optimizer, checkpoint_optimizer): + opt.load_state_dict(opt_state) + loaded_optimizers += 1 + print(f"Optimizer info loaded for {loaded_optimizers}/{len(optimizer)} optimizer(s)!") if "epoch" in checkpoint: start_epoch = checkpoint["epoch"] diff --git a/biapy/utils/scripts/run_checks.py b/biapy/utils/scripts/run_checks.py index 0252a2661..7416af2c1 100644 --- a/biapy/utils/scripts/run_checks.py +++ b/biapy/utils/scripts/run_checks.py @@ -2274,7 +2274,7 @@ def runjob(test_info, results_folder, yaml_file, biapy_folder, multigpu=False, b biapy_config['TRAIN']['PATIENCE'] = -1 biapy_config['TRAIN']['LR_SCHEDULER'] = {} biapy_config['TRAIN']['LR_SCHEDULER']['NAME'] = 'onecycle' - biapy_config['TRAIN']['LR'] = 0.001 + biapy_config['TRAIN']['LR'] = [0.001] biapy_config['TRAIN']['BATCH_SIZE'] = 16 biapy_config['MODEL']['ARCHITECTURE'] = 'resunet' @@ -3334,7 +3334,7 @@ def runjob(test_info, results_folder, yaml_file, biapy_folder, multigpu=False, b biapy_config['TRAIN']['EPOCHS'] = 100 biapy_config['TRAIN']['BATCH_SIZE'] = 1 biapy_config['TRAIN']['PATIENCE'] = 20 - biapy_config['TRAIN']['LR'] = 0.0001 + biapy_config['TRAIN']['LR'] = [0.0001] biapy_config['TRAIN']['LR_SCHEDULER'] = {} biapy_config['TRAIN']['LR_SCHEDULER']['NAME'] = 'warmupcosine' biapy_config['TRAIN']['LR_SCHEDULER']['MIN_LR'] = 5.E-6 @@ -3449,8 +3449,8 @@ def runjob(test_info, results_folder, yaml_file, biapy_folder, multigpu=False, b biapy_config['TRAIN']['ENABLE'] = True biapy_config['TRAIN']['EPOCHS'] = 80 biapy_config['TRAIN']['PATIENCE'] = -1 - biapy_config['TRAIN']['OPTIMIZER'] = "ADAMW" - biapy_config['TRAIN']['LR'] = 1.E-4 + biapy_config['TRAIN']['OPTIMIZER'] = ["ADAMW"] + biapy_config['TRAIN']['LR'] = [1.E-4] biapy_config['TRAIN']['LR_SCHEDULER'] = {} biapy_config['TRAIN']['LR_SCHEDULER']['NAME'] = 'warmupcosine' biapy_config['TRAIN']['LR_SCHEDULER']['MIN_LR'] = 5.E-6 From cf27ff0d2ecf8b18658273aa6fdbc7932bec9213 Mon Sep 17 00:00:00 2001 From: Ibai Date: Wed, 8 Apr 2026 11:19:04 +0200 Subject: [PATCH 3/5] small modifications for comments and cleanup --- 
biapy/engine/__init__.py | 4 ++-- biapy/engine/train_engine.py | 5 ----- biapy/utils/misc.py | 2 +- 3 files changed, 3 insertions(+), 8 deletions(-) diff --git a/biapy/engine/__init__.py b/biapy/engine/__init__.py index 0656e8e4c..8a93aa2aa 100644 --- a/biapy/engine/__init__.py +++ b/biapy/engine/__init__.py @@ -40,9 +40,9 @@ def prepare_optimizer( Returns ------- - optimizer : List[Optimizer] + optimizers : List[Optimizer] Configured optimizers for the models. - lr_scheduler : Scheduler or None + lr_schedulers : List[Scheduler | None] Configured learning rate schedulers, or None if not specified. """ diff --git a/biapy/engine/train_engine.py b/biapy/engine/train_engine.py index 7be23222a..b89c8c8e8 100644 --- a/biapy/engine/train_engine.py +++ b/biapy/engine/train_engine.py @@ -223,11 +223,6 @@ def train_one_epoch( if log_writer: log_writer.update(head="opt", **{lr_name: max_lr}) - if cfg.TRAIN.LR_SCHEDULER.NAME not in ["reduceonplateau", "onecycle", "warmupcosine"]: - for sched in lr_scheduler: - if sched: - sched.step() - # Gather the stats from all processes metric_logger.synchronize_between_processes() print("[Train] averaged stats:", metric_logger) diff --git a/biapy/utils/misc.py b/biapy/utils/misc.py index 56aa1c8c4..9af6c2f6a 100644 --- a/biapy/utils/misc.py +++ b/biapy/utils/misc.py @@ -346,7 +346,7 @@ def save_model(cfg, biapy_version, jobname, epoch, model_without_ddp, optimizer, to_save = { "model_build_kwargs": model_build_kwargs, "model": model_without_ddp.state_dict(), - "optimizer": [opt.state_dict() for opt in optimizer] if optimizer else None, + "optimizer": [opt.state_dict() for opt in optimizer], # should i check if none? if i leave empty it uses default. "epoch": epoch, "cfg": cfg, "biapy_version": biapy_version, From 60524a9d2122bddd880e044295806ca0e255c58f Mon Sep 17 00:00:00 2001 From: Ibai Date: Wed, 8 Apr 2026 12:12:28 +0200 Subject: [PATCH 4/5] small modification on losses --- biapy/engine/__init__.py | 12 +++++++----- biapy/engine/train_engine.py | 9 +++++++-- biapy/utils/util.py | 27 +++++++++++++++------------ 3 files changed, 29 insertions(+), 19 deletions(-) diff --git a/biapy/engine/__init__.py b/biapy/engine/__init__.py index 8a93aa2aa..700058c4a 100644 --- a/biapy/engine/__init__.py +++ b/biapy/engine/__init__.py @@ -51,11 +51,13 @@ def prepare_optimizer( if hasattr(model_without_ddp, 'discriminator') and model_without_ddp.discriminator is not None: param_groups = [ - [p for n, p in model_without_ddp.named_parameters() if not n.startswith("discriminator.")], # Generator - model_without_ddp.discriminator.parameters() # Discriminator + # Generator + [p for n, p in model_without_ddp.named_parameters() if not n.startswith("discriminator.")], # should this be and p.requires_grad, same below? 
+ # Discriminator + [p for p in model_without_ddp.discriminator.parameters()] ] else: - param_groups = [model_without_ddp.parameters()] + param_groups = [[p for p in model_without_ddp.parameters()]] ## Not quite sure if this is the best place to do this if len(cfg.TRAIN.OPTIMIZER) != len(param_groups): @@ -66,14 +68,14 @@ def prepare_optimizer( ) for i in range(len(cfg.TRAIN.OPTIMIZER)): - lr = cfg.TRAIN.LR if cfg.TRAIN.LR_SCHEDULER.NAME != "warmupcosine" else cfg.TRAIN.LR_SCHEDULER.MIN_LR + lr = cfg.TRAIN.LR if cfg.TRAIN.LR_SCHEDULER.NAME != "warmupcosine" else [cfg.TRAIN.LR_SCHEDULER.MIN_LR] * len(cfg.TRAIN.LR) opt_args = {} if cfg.TRAIN.OPTIMIZER[i] in ["ADAM", "ADAMW"]: opt_args["betas"] = cfg.TRAIN.OPT_BETAS[i] if i < len(cfg.TRAIN.OPT_BETAS) else cfg.TRAIN.OPT_BETAS[0] optimizer = timm.optim.create_optimizer_v2( param_groups[i], opt=cfg.TRAIN.OPTIMIZER[i], - lr=lr, + lr=lr[i], weight_decay=cfg.TRAIN.W_DECAY, **opt_args, ) diff --git a/biapy/engine/train_engine.py b/biapy/engine/train_engine.py index b89c8c8e8..46089f85d 100644 --- a/biapy/engine/train_engine.py +++ b/biapy/engine/train_engine.py @@ -192,6 +192,9 @@ def train_one_epoch( for i, (opt, loss_tensor) in enumerate(zip(optimizer, losses)): loss_tensor = loss_tensor / cfg.TRAIN.ACCUM_ITER + if has_discriminator and i == 1: + opt.zero_grad() + loss_tensor.backward() opt.step() # update weight opt.zero_grad() @@ -279,7 +282,7 @@ def evaluate( has_discriminator = hasattr(model, "discriminator") and model.discriminator is not None # Ensure correct order of each epoch info by adding loss first metric_logger = MetricLogger(delimiter=" ") - num_losses = 2 if has_discriminator and len(lr_scheduler) > 1 else 1 + num_losses = 2 if has_discriminator else 1 for i in range(num_losses): loss_name = "loss" if i == 0 else f"loss_{i}" metric_logger.add_meter(loss_name, SmoothedValue()) @@ -299,7 +302,7 @@ def evaluate( # Loss function call losses = [] - if has_discriminator and len(lr_scheduler) > 1: + if has_discriminator: fake_img = outputs["pred"] if isinstance(outputs, dict) else outputs fake_img = torch.clamp(fake_img, 0, 1) @@ -313,6 +316,8 @@ def evaluate( losses.append(loss_d) elif memory_bank is not None: + with_embed = False + outputs = { "pred": outputs["pred"], "embed": outputs["embed"], diff --git a/biapy/utils/util.py b/biapy/utils/util.py index be1e6e6a5..a1eca0745 100644 --- a/biapy/utils/util.py +++ b/biapy/utils/util.py @@ -76,18 +76,21 @@ def create_plots(results, metrics, job_id, chartOutDir): os.environ["QT_QPA_PLATFORM"] = "offscreen" # Loss - plt.plot(results["loss"]) - if "val_loss" in results: - plt.plot(results["val_loss"]) - plt.title("Model JOBID=" + job_id + " loss") - plt.ylabel("Value") - plt.xlabel("Epoch") - if "val_loss" in results: - plt.legend(["Train loss", "Val. loss"], loc="upper left") - else: - plt.legend(["Train loss"], loc="upper left") - plt.savefig(os.path.join(chartOutDir, job_id + "_loss.png")) - plt.clf() + loss_keys = ["loss"] + [k for k in results if "loss" in k and k not in ["loss", "val_loss"]] + for loss_key in loss_keys: + val_loss_key = f"val_{loss_key}" + plt.plot(results[loss_key]) + if val_loss_key in results: + plt.plot(results[val_loss_key]) + plt.title("Model JOBID=" + job_id + " " + loss_key) + plt.ylabel("Value") + plt.xlabel("Epoch") + if val_loss_key in results: + plt.legend([f"Train {loss_key}", f"Val. 
{loss_key}"], loc="upper left") + else: + plt.legend([f"Train {loss_key}"], loc="upper left") + plt.savefig(os.path.join(chartOutDir, job_id + "_" + loss_key + ".png")) + plt.clf() # Metric for i in range(len(metrics)): From 67ba7882c6ea7af37719e906c395c60e73d06ec4 Mon Sep 17 00:00:00 2001 From: Ibai Date: Wed, 8 Apr 2026 12:35:12 +0200 Subject: [PATCH 5/5] null check --- biapy/engine/train_engine.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/biapy/engine/train_engine.py b/biapy/engine/train_engine.py index 46089f85d..54b689e2e 100644 --- a/biapy/engine/train_engine.py +++ b/biapy/engine/train_engine.py @@ -361,7 +361,7 @@ def evaluate( print("[Val] averaged stats:", metric_logger) # Apply reduceonplateau scheduler if the global validation has been reduced - if cfg.TRAIN.LR_SCHEDULER.NAME == "reduceonplateau": + if lr_scheduler and cfg.TRAIN.LR_SCHEDULER.NAME == "reduceonplateau": for i, sched in enumerate(lr_scheduler): if sched and isinstance(sched, ReduceLROnPlateau): loss_name = "loss" if i == 0 else f"loss_{i}"