From 154635fc56b92602f78b8e7286837957d0cab110 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jules=20No=C3=ABl-Audoux?= <82194297+jules-noelaudoux@users.noreply.github.com> Date: Sun, 19 Apr 2026 23:29:59 +0200 Subject: [PATCH] fix: Vignette node resolution handling and batch processing Refactored the Vignette class to fix several critical bugs that prevented it from working correctly: - Switched meshgrid indexing to "xy" to fix transposed masks on non-square resolutions. - Removed a broken interpolation patch that caused runtime crashes on 4D/6D tensors. - Implemented proper batch dimension handling (B, H, W, C) for ComfyUI compatibility. - Fixed variable shadowing and updated the strength logic to utilize the full [0.0, 10.0] input range. - Added proper clamping to ensure output remains in valid [0.0, 1.0] image space. The node now works perfectly across any image resolution and batch size. --- post_processing_nodes.py | 23 +++++++++++++++-------- 1 file changed, 15 insertions(+), 8 deletions(-) diff --git a/post_processing_nodes.py b/post_processing_nodes.py index dfd9322..48a8452 100644 --- a/post_processing_nodes.py +++ b/post_processing_nodes.py @@ -1332,20 +1332,27 @@ def INPUT_TYPES(s): def apply_vignette(self, image: torch.Tensor, vignette: float): if vignette == 0: return (image,) - height, width, _ = image.shape[-3:] + + batch_size, height, width, _ = image.shape + + # Build the distance grid — use indexing="xy" so shape is (height, width) x = torch.linspace(-1, 1, width, device=image.device) y = torch.linspace(-1, 1, height, device=image.device) - X, Y = torch.meshgrid(x, y, indexing="ij") + X, Y = torch.meshgrid(x, y, indexing="xy") radius = torch.sqrt(X ** 2 + Y ** 2) - radius = radius / torch.amax(radius, dim=(0, 1), keepdim=True) - opacity = torch.tensor(vignette, device=image.device) - opacity = torch.clamp(opacity, 0.0, 1.0) - vignette = 1 - radius.unsqueeze(0).unsqueeze(-1) * opacity + # Normalize to [0, 1] + radius = radius / radius.amax() - vignette_image = torch.clamp(image * vignette, 0, 1) + # Scale by strength — allow full input range instead of clamping to 1 + strength = min(vignette, 10.0) + vignette_mask = torch.clamp(1.0 - radius * strength, 0.0, 1.0) - return (vignette_image,) + # Broadcast over batch and channels: (1, H, W, 1) + vignette_mask = vignette_mask.unsqueeze(0).unsqueeze(-1) + + result = torch.clamp(image * vignette_mask, 0, 1) + return (result,) def gaussian_kernel(kernel_size: int, sigma: float): x, y = torch.meshgrid(torch.linspace(-1, 1, kernel_size), torch.linspace(-1, 1, kernel_size), indexing="ij")