diff --git a/tests/test_pipeline.py b/tests/test_pipeline.py
index 765e289..dabe76e 100644
--- a/tests/test_pipeline.py
+++ b/tests/test_pipeline.py
@@ -197,6 +197,28 @@ def run_full_pipeline(
         explain_with_captum=True,
     )
 
+    # Test label attention assertions
+    if label_attention_enabled:
+        assert predictions["label_attention_attributions"] is not None, (
+            "Label attention attributions should not be None when label_attention_enabled is True"
+        )
+        label_attention_attributions = predictions["label_attention_attributions"]
+        expected_shape = (
+            len(sample_text_data),  # batch_size
+            model_params["n_head"],  # n_head
+            model_params["num_classes"],  # num_classes
+            tokenizer.output_dim,  # seq_len
+        )
+        assert label_attention_attributions.shape == expected_shape, (
+            f"Label attention attributions shape mismatch. "
+            f"Expected {expected_shape}, got {label_attention_attributions.shape}"
+        )
+    else:
+        # When label attention is not enabled, the attributions should be None
+        assert predictions.get("label_attention_attributions") is None, (
+            "Label attention attributions should be None when label_attention_enabled is False"
+        )
+
     # Test explainability functions
     text_idx = 0
     text = sample_text_data[text_idx]