diff --git a/post_processing_nodes.py b/post_processing_nodes.py index dfd9322..48a8452 100644 --- a/post_processing_nodes.py +++ b/post_processing_nodes.py @@ -1332,20 +1332,27 @@ def INPUT_TYPES(s): def apply_vignette(self, image: torch.Tensor, vignette: float): if vignette == 0: return (image,) - height, width, _ = image.shape[-3:] + + batch_size, height, width, _ = image.shape + + # Build the distance grid — use indexing="xy" so shape is (height, width) x = torch.linspace(-1, 1, width, device=image.device) y = torch.linspace(-1, 1, height, device=image.device) - X, Y = torch.meshgrid(x, y, indexing="ij") + X, Y = torch.meshgrid(x, y, indexing="xy") radius = torch.sqrt(X ** 2 + Y ** 2) - radius = radius / torch.amax(radius, dim=(0, 1), keepdim=True) - opacity = torch.tensor(vignette, device=image.device) - opacity = torch.clamp(opacity, 0.0, 1.0) - vignette = 1 - radius.unsqueeze(0).unsqueeze(-1) * opacity + # Normalize to [0, 1] + radius = radius / radius.amax() - vignette_image = torch.clamp(image * vignette, 0, 1) + # Scale by strength — allow full input range instead of clamping to 1 + strength = min(vignette, 10.0) + vignette_mask = torch.clamp(1.0 - radius * strength, 0.0, 1.0) - return (vignette_image,) + # Broadcast over batch and channels: (1, H, W, 1) + vignette_mask = vignette_mask.unsqueeze(0).unsqueeze(-1) + + result = torch.clamp(image * vignette_mask, 0, 1) + return (result,) def gaussian_kernel(kernel_size: int, sigma: float): x, y = torch.meshgrid(torch.linspace(-1, 1, kernel_size), torch.linspace(-1, 1, kernel_size), indexing="ij")