
Commit 0526960

Add unit tests for AnomalyDINO. Change the distance computation from cdist to matmul so it works with half-precision tensors.
Signed-off-by: Alfie Roddan <228966941+alfieroddanintel@users.noreply.github.com>
1 parent: 7d1216c

File tree: 3 files changed, +124 / −7 lines


src/anomalib/models/image/anomaly_dino/torch_model.py

Lines changed: 19 additions & 7 deletions
@@ -251,16 +251,24 @@ def forward(self, input_tensor: torch.Tensor) -> torch.Tensor | InferenceBatch:
             self.embedding_store.append(features)
             return torch.tensor(0.0, device=device, requires_grad=True)
 
+        # Check that the memory bank isn't empty at inference
+        if self.memory_bank.numel() == 0:
+            msg = "Memory bank is empty. Run the model in training mode and call `fit()` before inference."
+            raise RuntimeError(msg)
+
+        # Ensure dtype consistency
+        if features.dtype != self.memory_bank.dtype:
+            features = features.to(self.memory_bank.dtype)
+
         # Inference
         # L2-normalized distances
         # memory_bank : [M, D], features : [Q, D]
 
-        # Compute pairwise distances [Q, M]
-        dists = torch.cdist(features, self.memory_bank, p=2)
-
-        # Convert L2 to cosine distance
-        # (since both vectors are normalized, divide by 2)
-        dists = dists / 2.0
+        # Compute cosine distance using matrix multiplication;
+        # both features and memory_bank are already L2-normalized.
+        # torch.cdist does not support half precision, but matmul does.
+        similarity = torch.matmul(features, self.memory_bank.T)  # [Q, M]
+        dists = (torch.ones_like(similarity) - similarity).clamp(min=0.0, max=2.0)  # cosine distance ∈ [0, 2]
 
         # Get top-k nearest neighbors
         k = max(1, self.num_neighbours)
@@ -270,7 +278,11 @@ def forward(self, input_tensor: torch.Tensor) -> torch.Tensor | InferenceBatch:
         min_dists = topk_vals.mean(dim=1) if k > 1 else topk_vals.squeeze(1)
 
         # Vectorized reconstruction
-        distances_full = torch.zeros((b, grid_size[0] * grid_size[1]), device=device)
+        distances_full = torch.zeros(
+            (b, grid_size[0] * grid_size[1]),
+            device=device,
+            dtype=min_dists.dtype,
+        )
         batch_idx, patch_idx = torch.nonzero(masks, as_tuple=True)
         distances_full[batch_idx, patch_idx] = min_dists
 
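
For context, here is a minimal standalone sketch (not part of the commit) of the cosine-distance-via-matmul idea the diff relies on: for L2-normalized rows, 1 − a·b is the cosine distance, and the matmul path stays in float16, which torch.cdist does not handle.

import torch

# Standalone sketch (not part of the commit): cosine distance via matmul for
# L2-normalized rows. The matmul path works directly on float16 tensors,
# which is the motivation for dropping torch.cdist in the diff above.
features = torch.nn.functional.normalize(torch.randn(4, 64), dim=1).half()      # [Q, D]
memory_bank = torch.nn.functional.normalize(torch.randn(10, 64), dim=1).half()  # [M, D]

similarity = torch.matmul(features, memory_bank.T)     # [Q, M] cosine similarity
cos_dist = (1.0 - similarity).clamp(min=0.0, max=2.0)  # cosine distance in [0, 2]

# Relation to the old cdist-based path: for unit vectors,
# ||a - b||^2 = 2 - 2 * cos(a, b), so the L2 distance equals sqrt(2 * cosine distance).
l2 = torch.cdist(features.float(), memory_bank.float(), p=2)
assert torch.allclose(l2, (2.0 * cos_dist.float()).sqrt(), atol=1e-2)

Note that the two formulations differ in scale: the old path returned ‖a − b‖ / 2 = sqrt(2 − 2s) / 2, while the new one returns 1 − s. Both are 0 for identical normalized vectors and grow monotonically as similarity drops.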

Lines changed: 4 additions & 0 deletions
@@ -0,0 +1,4 @@
+# Copyright (C) 2025 Intel Corporation
+# SPDX-License-Identifier: Apache-2.0
+
+"""Unit tests for AnomalyDINO."""
Lines changed: 101 additions & 0 deletions
@@ -0,0 +1,101 @@
+# Copyright (C) 2025 Intel Corporation
+# SPDX-License-Identifier: Apache-2.0
+
+"""Unit tests for the AnomalyDINO torch model."""
+
+import numpy as np
+import pytest
+import torch
+from _pytest.monkeypatch import MonkeyPatch
+
+from anomalib.models.image.anomaly_dino.torch_model import AnomalyDINOModel
+
+
+class TestAnomalyDINOModel:
+    """Test the AnomalyDINO torch model."""
+
+    @staticmethod
+    def test_initialization_defaults() -> None:
+        """Test initialization with default arguments."""
+        model = AnomalyDINOModel()
+        assert model.encoder_name.startswith("dinov2")
+        assert model.memory_bank.numel() == 0
+
+    @staticmethod
+    def test_invalid_encoder_name_raises() -> None:
+        """Test that invalid encoder names raise an error."""
+        with pytest.raises(ValueError, match="Encoder must be dinov2"):
+            _ = AnomalyDINOModel(encoder_name="resnet50")
+
+    @staticmethod
+    def test_fit_raises_without_embeddings() -> None:
+        """Test that fit raises when no embeddings have been collected."""
+        model = AnomalyDINOModel()
+        with pytest.raises(ValueError, match="No embeddings collected"):
+            model.fit()
+
+    @staticmethod
+    def test_forward_train_adds_embeddings(monkeypatch: MonkeyPatch) -> None:
+        """Test training mode collects embeddings into store."""
+        model = AnomalyDINOModel()
+        model.train()
+
+        fake_features = torch.randn(2, 8, 128)
+        monkeypatch.setattr(model, "extract_features", lambda _: fake_features)
+
+        x = torch.randn(2, 3, 224, 224)
+        output = model(x)
+        assert torch.is_tensor(output)
+        assert output.requires_grad
+        assert len(model.embedding_store) == 1
+        assert model.embedding_store[0].ndim == 2
+
+    @staticmethod
+    def test_forward_eval_raises_with_empty_memory_bank(monkeypatch: MonkeyPatch) -> None:
+        """Test that inference raises an error when memory bank is empty."""
+        model = AnomalyDINOModel()
+        model.eval()
+
+        fake_features = torch.randn(1, 16, 64)
+        monkeypatch.setattr(model, "extract_features", lambda _: fake_features)
+        model.register_buffer("memory_bank", torch.empty(0, 64))
+
+        x = torch.randn(1, 3, 224, 224)
+        with pytest.raises(RuntimeError, match="Memory bank is empty"):
+            _ = model(x)
+
+    @staticmethod
+    def test_compute_background_masks_runs() -> None:
+        """Test that background mask computation produces boolean masks."""
+        b, h, w, d = 2, 8, 8, 16
+        features = np.random.randn(b, h * w, d).astype(np.float32)  # noqa: NPY002
+        masks = AnomalyDINOModel.compute_background_masks(features, (h, w))
+        assert masks.shape == (b, h * w)
+        assert masks.dtype == bool
+
+    @staticmethod
+    def test_mean_top1p_computation() -> None:
+        """Test that mean_top1p returns expected shape and value."""
+        distances = torch.arange(0, 100, dtype=torch.float32).view(1, -1)
+        result = AnomalyDINOModel.mean_top1p(distances)
+        assert result.shape == (1, 1)
+        assert torch.allclose(result, torch.tensor([[99.0]]))
+
+    @staticmethod
+    def test_forward_half_precision_eval(monkeypatch: MonkeyPatch) -> None:
+        """Test inference in half precision (float16) using matmul cosine distance."""
+        model = AnomalyDINOModel().half()
+        model.eval()
+
+        fake_features = torch.randn(1, 16, 64, dtype=torch.float16)
+        monkeypatch.setattr(model, "extract_features", lambda _: fake_features)
+        monkeypatch.setattr(model.anomaly_map_generator, "__call__", lambda x, __: x)
+
+        model.register_buffer("memory_bank", torch.randn(16, 64, dtype=torch.float16))
+        x = torch.randn(1, 3, 224, 224, dtype=torch.float16)
+        out = model(x)
+
+        assert hasattr(out, "pred_score")
+        assert out.pred_score.shape == (1, 1)
+        # outputs should be float16-safe with matmul
+        assert out.pred_score.dtype == torch.float16
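
As a reading aid for test_mean_top1p_computation, here is a hedged sketch of the behaviour the assertion pins down (this is not the library's implementation): the image score is the mean of the top 1% of patch distances, with at least one patch always kept.

import torch

# Hypothetical sketch, not the anomalib implementation: the behaviour implied by
# test_mean_top1p_computation above -- average the largest 1% of patch distances
# (at least one patch) to get a [B, 1] image score.
def mean_top1p_sketch(distances: torch.Tensor) -> torch.Tensor:
    num_patches = distances.shape[1]
    k = max(1, int(0.01 * num_patches))         # top 1%, never fewer than one patch
    topk_vals, _ = distances.topk(k, dim=1)     # largest distances per image
    return topk_vals.mean(dim=1, keepdim=True)  # [B, 1]

# Reproduces the expectation in the test: distances 0..99 -> score 99.0
scores = mean_top1p_sketch(torch.arange(0, 100, dtype=torch.float32).view(1, -1))
assert torch.allclose(scores, torch.tensor([[99.0]]))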
