import unittest

import numpy as np

from pyhealth.metrics import binary_metrics_fn


class TestBinaryMetrics(unittest.TestCase):
    """Test cases for binary classification metrics."""

    def setUp(self):
        """Set up synthetic binary classification data."""
        np.random.seed(42)
        self.y_true = np.array([0, 0, 1, 1, 0, 1])
        self.y_prob = np.array([0.1, 0.4, 0.35, 0.8, 0.2, 0.9])

    def test_default_metrics(self):
        """Test that default metrics (pr_auc, roc_auc, f1) are returned."""
        result = binary_metrics_fn(self.y_true, self.y_prob)
        self.assertIn("pr_auc", result)
        self.assertIn("roc_auc", result)
        self.assertIn("f1", result)
        self.assertEqual(len(result), 3)

    def test_accuracy(self):
        """Test accuracy metric with known values."""
        result = binary_metrics_fn(
            self.y_true, self.y_prob, metrics=["accuracy"],
        )
        self.assertIn("accuracy", result)
        self.assertIsInstance(result["accuracy"], float)
        self.assertGreaterEqual(result["accuracy"], 0.0)
        self.assertLessEqual(result["accuracy"], 1.0)

    def test_all_classification_metrics(self):
        """Test that all supported classification metrics compute."""
        all_metrics = [
            "pr_auc", "roc_auc", "accuracy", "balanced_accuracy",
            "f1", "precision", "recall", "cohen_kappa", "jaccard",
        ]
        result = binary_metrics_fn(
            self.y_true, self.y_prob, metrics=all_metrics,
        )
        for metric in all_metrics:
            self.assertIn(metric, result)
            self.assertIsInstance(result[metric], float)

    @unittest.skip(
        "ece_confidence_binary expects 2D arrays but binary_metrics_fn "
        "passes 1D - see calibration.py:150"
    )
    def test_calibration_metrics(self):
        """Test ECE and adaptive ECE metrics."""
        result = binary_metrics_fn(
            self.y_true, self.y_prob, metrics=["ECE", "ECE_adapt"],
        )
        self.assertIn("ECE", result)
        self.assertIn("ECE_adapt", result)
        self.assertGreaterEqual(result["ECE"], 0.0)
        self.assertGreaterEqual(result["ECE_adapt"], 0.0)

    def test_perfect_predictions(self):
        """Test metrics with perfect predictions."""
        y_true = np.array([0, 0, 1, 1])
        y_prob = np.array([0.0, 0.0, 1.0, 1.0])
        result = binary_metrics_fn(
            y_true, y_prob, metrics=["accuracy", "f1"],
        )
        self.assertEqual(result["accuracy"], 1.0)
        self.assertEqual(result["f1"], 1.0)

    def test_custom_threshold(self):
        """Test that custom threshold changes predictions."""
        result_low = binary_metrics_fn(
            self.y_true, self.y_prob,
            metrics=["accuracy"], threshold=0.3,
        )
        result_high = binary_metrics_fn(
            self.y_true, self.y_prob,
            metrics=["accuracy"], threshold=0.9,
        )
        self.assertIsInstance(result_low["accuracy"], float)
        self.assertIsInstance(result_high["accuracy"], float)
        # For this fixture threshold=0.3 predicts [0, 1, 1, 1, 0, 1]
        # (accuracy 5/6) while threshold=0.9 predicts at most one positive
        # (accuracy 4/6 under ">=", 3/6 under ">"), so the two accuracies
        # must differ under either threshold convention.  The previous
        # isinstance-only checks could never fail.
        self.assertNotEqual(
            result_low["accuracy"], result_high["accuracy"],
        )

    def test_metric_values_in_range(self):
        """Test that all metric values are in valid ranges."""
        all_metrics = [
            "pr_auc", "roc_auc", "accuracy", "balanced_accuracy",
            "f1", "precision", "recall", "jaccard",
        ]
        result = binary_metrics_fn(
            self.y_true, self.y_prob, metrics=all_metrics,
        )
        for metric in all_metrics:
            self.assertGreaterEqual(
                result[metric], 0.0, f"{metric} below 0",
            )
            self.assertLessEqual(
                result[metric], 1.0, f"{metric} above 1",
            )

    def test_unknown_metric_raises(self):
        """Test that unknown metric name raises ValueError."""
        with self.assertRaises(ValueError):
            binary_metrics_fn(
                self.y_true, self.y_prob, metrics=["nonexistent"],
            )

    def test_single_metric(self):
        """Test requesting a single metric."""
        result = binary_metrics_fn(
            self.y_true, self.y_prob, metrics=["roc_auc"],
        )
        self.assertEqual(len(result), 1)
        self.assertIn("roc_auc", result)


if __name__ == "__main__":
    unittest.main()
import unittest

import torch

from pyhealth.datasets import create_sample_dataset, get_dataloader
from pyhealth.models.embedding import EmbeddingModel


def _collect_features(dataset, batch, keys):
    """Split a collated batch into value tensors and optional masks.

    Each input processor's ``schema()`` names the tuple slots of its
    collated feature.  This helper pulls out the ``"value"`` entry for
    every requested key and the ``"mask"`` entry when the processor
    provides one, replacing the extraction loop previously duplicated
    in every forward-pass test.

    Args:
        dataset: sample dataset whose ``input_processors`` describe the keys.
        batch: one collated batch produced by ``get_dataloader``.
        keys: iterable of feature names to extract.

    Returns:
        Tuple ``(inputs, masks)`` of dicts keyed by feature name; ``masks``
        only contains keys whose schema includes a ``"mask"`` slot.
    """
    inputs, masks = {}, {}
    for key in keys:
        feature = batch[key]
        # A lone tensor means a single-slot schema; normalize to a tuple
        # so schema-index lookups work uniformly.
        if isinstance(feature, torch.Tensor):
            feature = (feature,)
        schema = dataset.input_processors[key].schema()
        inputs[key] = feature[schema.index("value")]
        if "mask" in schema:
            masks[key] = feature[schema.index("mask")]
    return inputs, masks


class TestEmbeddingModelSequence(unittest.TestCase):
    """Test EmbeddingModel with sequence processor inputs."""

    def setUp(self):
        """Set up test data and model with sequence inputs."""
        torch.manual_seed(42)
        self.samples = [
            {
                "patient_id": "patient-0",
                "visit_id": "visit-0",
                "diagnoses": ["A", "B", "C"],
                "procedures": ["X", "Y"],
                "label": 1,
            },
            {
                "patient_id": "patient-1",
                "visit_id": "visit-0",
                "diagnoses": ["D", "E"],
                "procedures": ["Y"],
                "label": 0,
            },
        ]

        self.input_schema = {
            "diagnoses": "sequence",
            "procedures": "sequence",
        }
        self.output_schema = {"label": "binary"}

        self.dataset = create_sample_dataset(
            samples=self.samples,
            input_schema=self.input_schema,
            output_schema=self.output_schema,
            dataset_name="test",
        )

        self.model = EmbeddingModel(
            dataset=self.dataset, embedding_dim=32,
        )

    def test_initialization(self):
        """Test that the EmbeddingModel initializes correctly."""
        self.assertIsInstance(self.model, EmbeddingModel)
        self.assertEqual(self.model.embedding_dim, 32)
        self.assertIn("diagnoses", self.model.embedding_layers)
        self.assertIn("procedures", self.model.embedding_layers)

    def test_embedding_layers_are_correct_type(self):
        """Test that sequence inputs use nn.Embedding layers."""
        self.assertIsInstance(
            self.model.embedding_layers["diagnoses"], torch.nn.Embedding,
        )
        self.assertIsInstance(
            self.model.embedding_layers["procedures"], torch.nn.Embedding,
        )

    def test_forward_output_shapes(self):
        """Test that forward pass produces correct output shapes."""
        loader = get_dataloader(self.dataset, batch_size=2, shuffle=False)
        data_batch = next(iter(loader))
        inputs, masks = _collect_features(
            self.dataset, data_batch, ["diagnoses", "procedures"],
        )

        with torch.no_grad():
            embedded = self.model(inputs, masks=masks)

        self.assertIn("diagnoses", embedded)
        self.assertIn("procedures", embedded)
        self.assertEqual(embedded["diagnoses"].shape[-1], 32)
        self.assertEqual(embedded["procedures"].shape[-1], 32)
        self.assertEqual(embedded["diagnoses"].shape[0], 2)

    def test_forward_with_output_mask(self):
        """Test that forward pass returns masks when requested."""
        loader = get_dataloader(self.dataset, batch_size=2, shuffle=False)
        data_batch = next(iter(loader))
        inputs, masks = _collect_features(
            self.dataset, data_batch, ["diagnoses", "procedures"],
        )

        with torch.no_grad():
            embedded, out_masks = self.model(
                inputs, masks=masks, output_mask=True,
            )

        self.assertIsInstance(embedded, dict)
        self.assertIsInstance(out_masks, dict)
        self.assertIn("diagnoses", out_masks)
        self.assertIn("procedures", out_masks)

    def test_gradients_flow(self):
        """Test that gradients flow through the embedding layers."""
        loader = get_dataloader(self.dataset, batch_size=2, shuffle=False)
        data_batch = next(iter(loader))
        inputs, masks = _collect_features(
            self.dataset, data_batch, ["diagnoses", "procedures"],
        )

        embedded = self.model(inputs, masks=masks)
        loss = sum(v.sum() for v in embedded.values())
        loss.backward()

        has_gradient = any(
            p.requires_grad and p.grad is not None
            for p in self.model.parameters()
        )
        self.assertTrue(has_gradient)


class TestEmbeddingModelTensor(unittest.TestCase):
    """Test EmbeddingModel with tensor processor inputs."""

    def setUp(self):
        """Set up test data and model with tensor inputs."""
        torch.manual_seed(42)
        self.samples = [
            {
                "patient_id": "patient-0",
                "visit_id": "visit-0",
                "labs": [1.0, 2.0, 3.0],
                "label": 1,
            },
            {
                "patient_id": "patient-1",
                "visit_id": "visit-0",
                "labs": [4.0, 5.0, 6.0],
                "label": 0,
            },
        ]

        self.input_schema = {"labs": "tensor"}
        self.output_schema = {"label": "binary"}

        self.dataset = create_sample_dataset(
            samples=self.samples,
            input_schema=self.input_schema,
            output_schema=self.output_schema,
            dataset_name="test",
        )

        self.model = EmbeddingModel(
            dataset=self.dataset, embedding_dim=16,
        )

    def test_initialization(self):
        """Test that tensor inputs use nn.Linear layers."""
        self.assertIn("labs", self.model.embedding_layers)
        self.assertIsInstance(
            self.model.embedding_layers["labs"], torch.nn.Linear,
        )

    def test_forward_output_shape(self):
        """Test that forward pass produces correct output shapes."""
        loader = get_dataloader(self.dataset, batch_size=2, shuffle=False)
        data_batch = next(iter(loader))
        inputs, _ = _collect_features(self.dataset, data_batch, ["labs"])

        with torch.no_grad():
            embedded = self.model(inputs)

        self.assertIn("labs", embedded)
        self.assertEqual(embedded["labs"].shape[-1], 16)
        self.assertEqual(embedded["labs"].shape[0], 2)


class TestEmbeddingModelMultiHot(unittest.TestCase):
    """Test EmbeddingModel with multi_hot processor inputs."""

    def setUp(self):
        """Set up test data and model with multi-hot inputs."""
        torch.manual_seed(42)
        self.samples = [
            {
                "patient_id": "patient-0",
                "visit_id": "visit-0",
                "demographics": ["asian", "male"],
                "label": 1,
            },
            {
                "patient_id": "patient-1",
                "visit_id": "visit-0",
                "demographics": ["white", "female"],
                "label": 0,
            },
        ]

        self.input_schema = {"demographics": "multi_hot"}
        self.output_schema = {"label": "binary"}

        self.dataset = create_sample_dataset(
            samples=self.samples,
            input_schema=self.input_schema,
            output_schema=self.output_schema,
            dataset_name="test",
        )

        self.model = EmbeddingModel(
            dataset=self.dataset, embedding_dim=16,
        )

    def test_initialization(self):
        """Test that multi-hot inputs use nn.Linear layers."""
        self.assertIn("demographics", self.model.embedding_layers)
        self.assertIsInstance(
            self.model.embedding_layers["demographics"], torch.nn.Linear,
        )

    def test_forward_output_shape(self):
        """Test that forward pass produces correct output shapes."""
        loader = get_dataloader(self.dataset, batch_size=2, shuffle=False)
        data_batch = next(iter(loader))
        inputs, _ = _collect_features(
            self.dataset, data_batch, ["demographics"],
        )

        with torch.no_grad():
            embedded = self.model(inputs)

        self.assertIn("demographics", embedded)
        self.assertEqual(embedded["demographics"].shape[-1], 16)
        self.assertEqual(embedded["demographics"].shape[0], 2)


class TestEmbeddingModelNestedSequence(unittest.TestCase):
    """Test EmbeddingModel with nested_sequence processor inputs."""

    def setUp(self):
        """Set up test data and model with nested sequence inputs."""
        torch.manual_seed(42)
        self.samples = [
            {
                "patient_id": "patient-0",
                "visit_id": "visit-0",
                "conditions": [["A", "B"], ["C", "D", "E"]],
                "label": 1,
            },
            {
                "patient_id": "patient-1",
                "visit_id": "visit-1",
                "conditions": [["F"], ["G", "H"]],
                "label": 0,
            },
        ]

        self.input_schema = {"conditions": "nested_sequence"}
        self.output_schema = {"label": "binary"}

        self.dataset = create_sample_dataset(
            samples=self.samples,
            input_schema=self.input_schema,
            output_schema=self.output_schema,
            dataset_name="test",
        )

        self.model = EmbeddingModel(
            dataset=self.dataset, embedding_dim=16,
        )

    def test_initialization(self):
        """Test that nested sequence inputs use nn.Embedding layers."""
        self.assertIn("conditions", self.model.embedding_layers)
        self.assertIsInstance(
            self.model.embedding_layers["conditions"], torch.nn.Embedding,
        )

    def test_forward_output_shape(self):
        """Test that forward pass produces correct output shapes."""
        loader = get_dataloader(self.dataset, batch_size=2, shuffle=False)
        data_batch = next(iter(loader))
        inputs, masks = _collect_features(
            self.dataset, data_batch, ["conditions"],
        )

        with torch.no_grad():
            embedded = self.model(inputs, masks=masks)

        self.assertIn("conditions", embedded)
        self.assertEqual(embedded["conditions"].shape[-1], 16)
        self.assertEqual(embedded["conditions"].shape[0], 2)


class TestEmbeddingModelMixedInputs(unittest.TestCase):
    """Test EmbeddingModel with mixed processor types."""

    def setUp(self):
        """Set up test data and model with mixed input types."""
        torch.manual_seed(42)
        self.samples = [
            {
                "patient_id": "patient-0",
                "visit_id": "visit-0",
                "diagnoses": ["A", "B", "C"],
                "labs": [1.0, 2.0, 3.0],
                "label": 1,
            },
            {
                "patient_id": "patient-1",
                "visit_id": "visit-0",
                "diagnoses": ["D", "E"],
                "labs": [4.0, 5.0, 6.0],
                "label": 0,
            },
        ]

        self.input_schema = {
            "diagnoses": "sequence",
            "labs": "tensor",
        }
        self.output_schema = {"label": "binary"}

        self.dataset = create_sample_dataset(
            samples=self.samples,
            input_schema=self.input_schema,
            output_schema=self.output_schema,
            dataset_name="test",
        )

        self.model = EmbeddingModel(
            dataset=self.dataset, embedding_dim=32,
        )

    def test_mixed_initialization(self):
        """Test that mixed inputs use appropriate layer types."""
        self.assertIsInstance(
            self.model.embedding_layers["diagnoses"], torch.nn.Embedding,
        )
        self.assertIsInstance(
            self.model.embedding_layers["labs"], torch.nn.Linear,
        )

    def test_mixed_forward(self):
        """Test forward pass with mixed input types."""
        loader = get_dataloader(self.dataset, batch_size=2, shuffle=False)
        data_batch = next(iter(loader))
        inputs, masks = _collect_features(
            self.dataset, data_batch, ["diagnoses", "labs"],
        )

        with torch.no_grad():
            embedded = self.model(inputs, masks=masks)

        self.assertEqual(embedded["diagnoses"].shape[-1], 32)
        self.assertEqual(embedded["labs"].shape[-1], 32)

    def test_custom_embedding_dim(self):
        """Test EmbeddingModel with a custom embedding dimension."""
        model = EmbeddingModel(dataset=self.dataset, embedding_dim=64)
        self.assertEqual(model.embedding_dim, 64)

        loader = get_dataloader(self.dataset, batch_size=2, shuffle=False)
        data_batch = next(iter(loader))
        inputs, masks = _collect_features(
            self.dataset, data_batch, ["diagnoses", "labs"],
        )

        with torch.no_grad():
            embedded = model(inputs, masks=masks)

        self.assertEqual(embedded["diagnoses"].shape[-1], 64)
        self.assertEqual(embedded["labs"].shape[-1], 64)


if __name__ == "__main__":
    unittest.main()
import unittest

import torch

from pyhealth.models import GAN


class TestGAN32(unittest.TestCase):
    """Tests for the GAN model configured for 32x32 images."""

    def setUp(self):
        """Build a small single-channel 32x32 GAN."""
        torch.manual_seed(42)
        self.model = GAN(input_channel=1, input_size=32, hidden_dim=16)

    def test_initialization(self):
        """The model exposes a generator, a discriminator and hidden_dim."""
        self.assertIsInstance(self.model, GAN)
        self.assertEqual(self.model.hidden_dim, 16)
        self.assertIsNotNone(self.model.discriminator)
        self.assertIsNotNone(self.model.generator)

    def test_discriminator_output_shape(self):
        """Discriminating a batch of two images yields a (2, 1) tensor."""
        images = torch.randn(2, 1, 32, 32)
        with torch.no_grad():
            scores = self.model.discriminate(images)
        self.assertEqual(scores.shape, (2, 1))

    def test_discriminator_output_range(self):
        """Discriminator scores behave like probabilities in [0, 1]."""
        images = torch.randn(2, 1, 32, 32)
        with torch.no_grad():
            scores = self.model.discriminate(images)
        self.assertTrue(torch.all(scores >= 0))
        self.assertTrue(torch.all(scores <= 1))

    def test_generate_fake_shape(self):
        """Generated batches carry the requested batch and channel dims."""
        with torch.no_grad():
            samples = self.model.generate_fake(n_samples=3, device="cpu")
        self.assertEqual(samples.shape[0], 3)
        self.assertEqual(samples.shape[1], 1)

    def test_generate_fake_pixel_range(self):
        """Generated pixel intensities stay within [0, 1]."""
        with torch.no_grad():
            samples = self.model.generate_fake(n_samples=2, device="cpu")
        self.assertTrue(torch.all(samples >= 0))
        self.assertTrue(torch.all(samples <= 1))

    def test_sampling_shape(self):
        """Latent noise has shape (n_samples, hidden_dim, 1, 1)."""
        noise = self.model.sampling(n_samples=4, device="cpu")
        self.assertEqual(noise.shape, (4, 16, 1, 1))

    def test_discriminator_backward(self):
        """Backprop through the discriminator populates its gradients."""
        images = torch.randn(2, 1, 32, 32)
        self.model.discriminate(images).mean().backward()

        trainable = [
            p for p in self.model.discriminator.parameters()
            if p.requires_grad
        ]
        self.assertTrue(any(p.grad is not None for p in trainable))

    def test_generator_backward(self):
        """Backprop through generated samples populates generator grads."""
        samples = self.model.generate_fake(n_samples=2, device="cpu")
        samples.mean().backward()

        trainable = [
            p for p in self.model.generator.parameters()
            if p.requires_grad
        ]
        self.assertTrue(any(p.grad is not None for p in trainable))


class TestGAN64(unittest.TestCase):
    """Tests for the GAN model configured for 64x64 images."""

    def setUp(self):
        """Build a small single-channel 64x64 GAN."""
        torch.manual_seed(42)
        self.model = GAN(input_channel=1, input_size=64, hidden_dim=16)

    def test_discriminator_output_shape(self):
        """Discriminating a batch of two images yields a (2, 1) tensor."""
        images = torch.randn(2, 1, 64, 64)
        with torch.no_grad():
            scores = self.model.discriminate(images)
        self.assertEqual(scores.shape, (2, 1))

    def test_generate_fake_shape(self):
        """Generated batches carry the requested batch and channel dims."""
        with torch.no_grad():
            samples = self.model.generate_fake(n_samples=2, device="cpu")
        self.assertEqual(samples.shape[0], 2)
        self.assertEqual(samples.shape[1], 1)

    def test_end_to_end(self):
        """Generated images can be fed straight back to the discriminator."""
        with torch.no_grad():
            samples = self.model.generate_fake(n_samples=2, device="cpu")
            scores = self.model.discriminate(samples)
        self.assertEqual(scores.shape, (2, 1))


class TestGAN128(unittest.TestCase):
    """Tests for the GAN model configured for 128x128 RGB images."""

    def setUp(self):
        """Build a small three-channel 128x128 GAN."""
        torch.manual_seed(42)
        self.model = GAN(input_channel=3, input_size=128, hidden_dim=16)

    def test_discriminator_output_shape(self):
        """Discriminating a batch of two images yields a (2, 1) tensor."""
        images = torch.randn(2, 3, 128, 128)
        with torch.no_grad():
            scores = self.model.discriminate(images)
        self.assertEqual(scores.shape, (2, 1))

    def test_generate_fake_shape(self):
        """Generated batches carry the requested batch and channel dims."""
        with torch.no_grad():
            samples = self.model.generate_fake(n_samples=2, device="cpu")
        self.assertEqual(samples.shape[0], 2)
        self.assertEqual(samples.shape[1], 3)

    def test_multichannel_end_to_end(self):
        """Generate-then-discriminate works with multi-channel images."""
        with torch.no_grad():
            samples = self.model.generate_fake(n_samples=2, device="cpu")
            scores = self.model.discriminate(samples)
        self.assertEqual(scores.shape, (2, 1))
        self.assertTrue(torch.all(scores >= 0))
        self.assertTrue(torch.all(scores <= 1))


if __name__ == "__main__":
    unittest.main()
import unittest

import torch

from pyhealth.processors.label_processor import (
    BinaryLabelProcessor,
    MultiClassLabelProcessor,
)


class TestBinaryLabelProcessor(unittest.TestCase):
    """Unit tests for BinaryLabelProcessor."""

    @staticmethod
    def _fitted(labels):
        """Return a BinaryLabelProcessor fitted on the given label values."""
        proc = BinaryLabelProcessor()
        proc.fit([{"label": value} for value in labels], "label")
        return proc

    def test_fit_with_int_labels(self):
        """Fitting on 0/1 integers builds the identity vocabulary."""
        proc = self._fitted([0, 1, 0])
        self.assertEqual(proc.label_vocab, {0: 0, 1: 1})

    def test_fit_with_bool_labels(self):
        """Fitting on booleans maps False to 0 and True to 1."""
        proc = self._fitted([True, False])
        self.assertEqual(proc.label_vocab, {False: 0, True: 1})

    def test_fit_with_string_labels(self):
        """Fitting on two distinct strings yields a two-entry vocabulary."""
        proc = self._fitted(["yes", "no", "yes"])
        self.assertEqual(len(proc.label_vocab), 2)

    def test_fit_non_binary_raises(self):
        """Seeing three or more classes must raise ValueError."""
        proc = BinaryLabelProcessor()
        rows = [{"label": 0}, {"label": 1}, {"label": 2}]
        with self.assertRaises(ValueError):
            proc.fit(rows, "label")

    def test_process_returns_tensor(self):
        """process() yields a float32 tensor of shape (1,)."""
        proc = self._fitted([0, 1])
        out = proc.process(1)
        self.assertIsInstance(out, torch.Tensor)
        self.assertEqual(out.dtype, torch.float32)
        self.assertEqual(out.shape, (1,))

    def test_process_correct_values(self):
        """process() maps each label to its vocabulary index."""
        proc = self._fitted([0, 1])
        self.assertEqual(proc.process(0).item(), 0.0)
        self.assertEqual(proc.process(1).item(), 1.0)

    def test_size(self):
        """Binary labels report an output size of 1."""
        proc = self._fitted([0, 1])
        self.assertEqual(proc.size(), 1)

    def test_schema(self):
        """The processor schema is the single-slot ('value',)."""
        self.assertEqual(BinaryLabelProcessor().schema(), ("value",))

    def test_is_not_token(self):
        """Binary labels are not token-based features."""
        self.assertFalse(BinaryLabelProcessor().is_token())


class TestMultiClassLabelProcessor(unittest.TestCase):
    """Unit tests for MultiClassLabelProcessor."""

    @staticmethod
    def _fitted(labels):
        """Return a MultiClassLabelProcessor fitted on the given labels."""
        proc = MultiClassLabelProcessor()
        proc.fit([{"label": value} for value in labels], "label")
        return proc

    def test_fit_with_int_labels(self):
        """Sequential integer labels map onto themselves."""
        proc = self._fitted([0, 1, 2])
        self.assertEqual(proc.label_vocab, {0: 0, 1: 1, 2: 2})

    def test_fit_with_string_labels(self):
        """Three distinct strings produce a three-entry vocabulary."""
        proc = self._fitted(["cat", "dog", "bird"])
        self.assertEqual(len(proc.label_vocab), 3)

    def test_process_returns_tensor(self):
        """process() yields a torch tensor."""
        proc = self._fitted([0, 1, 2])
        self.assertIsInstance(proc.process(1), torch.Tensor)

    def test_process_correct_mapping(self):
        """The fitted labels cover exactly the indices {0, 1, 2}."""
        proc = self._fitted(["cat", "dog", "bird"])
        seen = {proc.process(name).item() for name in ("cat", "dog", "bird")}
        self.assertEqual(seen, {0, 1, 2})

    def test_size(self):
        """size() reports the number of distinct classes."""
        proc = self._fitted([0, 1, 2])
        self.assertEqual(proc.size(), 3)

    def test_schema(self):
        """The processor schema is the single-slot ('value',)."""
        self.assertEqual(MultiClassLabelProcessor().schema(), ("value",))


if __name__ == "__main__":
    unittest.main()
import unittest

import numpy as np

from pyhealth.metrics import multiclass_metrics_fn


class TestMulticlassMetrics(unittest.TestCase):
    """Unit tests for multiclass classification metrics."""

    def setUp(self):
        """Create a small three-class problem with soft predictions."""
        np.random.seed(42)
        self.y_true = np.array([0, 1, 2, 2, 0, 1])
        self.y_prob = np.array([
            [0.8, 0.1, 0.1],
            [0.1, 0.8, 0.1],
            [0.1, 0.1, 0.8],
            [0.4, 0.3, 0.3],
            [0.7, 0.2, 0.1],
            [0.2, 0.6, 0.2],
        ])

    def _check_unit_interval(self, computed, names):
        """Assert every named metric is present and lies in [0, 1]."""
        for name in names:
            self.assertIn(name, computed)
            self.assertGreaterEqual(computed[name], 0.0)
            self.assertLessEqual(computed[name], 1.0)

    def test_default_metrics(self):
        """Omitting the metrics list yields accuracy plus both f1s."""
        computed = multiclass_metrics_fn(self.y_true, self.y_prob)
        self.assertIn("accuracy", computed)
        self.assertIn("f1_macro", computed)
        self.assertIn("f1_micro", computed)
        self.assertEqual(len(computed), 3)

    def test_accuracy(self):
        """Accuracy is a float inside the unit interval."""
        computed = multiclass_metrics_fn(
            self.y_true, self.y_prob, metrics=["accuracy"],
        )
        self.assertIn("accuracy", computed)
        self.assertIsInstance(computed["accuracy"], float)
        self.assertGreaterEqual(computed["accuracy"], 0.0)
        self.assertLessEqual(computed["accuracy"], 1.0)

    def test_all_f1_variants(self):
        """Every F1 averaging mode computes a value in [0, 1]."""
        requested = ["f1_micro", "f1_macro", "f1_weighted"]
        computed = multiclass_metrics_fn(
            self.y_true, self.y_prob, metrics=requested,
        )
        self._check_unit_interval(computed, requested)

    def test_roc_auc_variants(self):
        """Every ROC AUC averaging mode computes a value in [0, 1]."""
        requested = [
            "roc_auc_macro_ovo", "roc_auc_macro_ovr",
            "roc_auc_weighted_ovo", "roc_auc_weighted_ovr",
        ]
        computed = multiclass_metrics_fn(
            self.y_true, self.y_prob, metrics=requested,
        )
        self._check_unit_interval(computed, requested)

    def test_jaccard_variants(self):
        """Every Jaccard averaging mode computes a value in [0, 1]."""
        requested = ["jaccard_micro", "jaccard_macro", "jaccard_weighted"]
        computed = multiclass_metrics_fn(
            self.y_true, self.y_prob, metrics=requested,
        )
        self._check_unit_interval(computed, requested)

    def test_calibration_metrics(self):
        """ECE, adaptive ECE and top-1 Brier are non-negative floats."""
        requested = ["ECE", "ECE_adapt", "brier_top1"]
        computed = multiclass_metrics_fn(
            self.y_true, self.y_prob, metrics=requested,
        )
        for name in requested:
            self.assertIn(name, computed)
            self.assertIsInstance(computed[name], float)
            self.assertGreaterEqual(computed[name], 0.0)

    def test_classwise_ece(self):
        """Classwise ECE variants are present and non-negative."""
        requested = ["cwECEt", "cwECEt_adapt"]
        computed = multiclass_metrics_fn(
            self.y_true, self.y_prob, metrics=requested,
        )
        for name in requested:
            self.assertIn(name, computed)
            self.assertGreaterEqual(computed[name], 0.0)

    def test_cohen_kappa(self):
        """Cohen's kappa lies in its natural [-1, 1] range."""
        computed = multiclass_metrics_fn(
            self.y_true, self.y_prob, metrics=["cohen_kappa"],
        )
        self.assertIn("cohen_kappa", computed)
        self.assertGreaterEqual(computed["cohen_kappa"], -1.0)
        self.assertLessEqual(computed["cohen_kappa"], 1.0)

    def test_perfect_predictions(self):
        """One-hot probabilities on the true classes give accuracy 1."""
        labels = np.array([0, 1, 2])
        probs = np.array([
            [1.0, 0.0, 0.0],
            [0.0, 1.0, 0.0],
            [0.0, 0.0, 1.0],
        ])
        computed = multiclass_metrics_fn(
            labels, probs, metrics=["accuracy"],
        )
        self.assertEqual(computed["accuracy"], 1.0)

    def test_unknown_metric_raises(self):
        """An unrecognized metric name raises ValueError."""
        with self.assertRaises(ValueError):
            multiclass_metrics_fn(
                self.y_true, self.y_prob, metrics=["nonexistent"],
            )

    def test_hits_and_rank_metrics(self):
        """hits@n expands to HITS@k keys; mean_rank adds MRR too."""
        computed = multiclass_metrics_fn(
            self.y_true, self.y_prob,
            metrics=["hits@n", "mean_rank"],
        )
        self.assertIn("HITS@1", computed)
        self.assertIn("HITS@5", computed)
        self.assertIn("HITS@10", computed)
        self.assertIn("mean_rank", computed)
        self.assertIn("mean_reciprocal_rank", computed)


if __name__ == "__main__":
    unittest.main()
import unittest

import numpy as np

from pyhealth.metrics import regression_metrics_fn


class TestRegressionMetrics(unittest.TestCase):
    """Unit tests for reconstruction/regression metrics."""

    def setUp(self):
        """Create a small ground-truth / reconstruction pair."""
        np.random.seed(42)
        self.x = np.array([1.0, 2.0, 3.0, 4.0, 5.0])
        self.x_rec = np.array([1.1, 2.2, 2.8, 4.1, 4.9])

    def test_default_metrics(self):
        """Without a metrics list, kl_divergence, mse and mae come back."""
        out = regression_metrics_fn(self.x, self.x_rec)
        for name in ("kl_divergence", "mse", "mae"):
            self.assertIn(name, out)
        self.assertEqual(len(out), 3)

    def test_mse(self):
        """MSE is a non-negative float."""
        out = regression_metrics_fn(self.x, self.x_rec, metrics=["mse"])
        self.assertIn("mse", out)
        self.assertIsInstance(out["mse"], float)
        self.assertGreaterEqual(out["mse"], 0.0)

    def test_mae(self):
        """MAE is a non-negative float."""
        out = regression_metrics_fn(self.x, self.x_rec, metrics=["mae"])
        self.assertIn("mae", out)
        self.assertIsInstance(out["mae"], float)
        self.assertGreaterEqual(out["mae"], 0.0)

    def test_kl_divergence(self):
        """KL divergence is computed and returned as a float."""
        out = regression_metrics_fn(
            self.x, self.x_rec, metrics=["kl_divergence"],
        )
        self.assertIn("kl_divergence", out)
        self.assertIsInstance(out["kl_divergence"], float)

    def test_perfect_reconstruction(self):
        """Identical input and reconstruction give zero MSE and MAE."""
        values = np.array([1.0, 2.0, 3.0])
        out = regression_metrics_fn(values, values, metrics=["mse", "mae"])
        self.assertAlmostEqual(out["mse"], 0.0)
        self.assertAlmostEqual(out["mae"], 0.0)

    def test_shape_mismatch_raises(self):
        """Arrays of different lengths must raise ValueError."""
        truth = np.array([1.0, 2.0, 3.0])
        recon = np.array([1.0, 2.0])
        with self.assertRaises(ValueError):
            regression_metrics_fn(truth, recon)

    def test_unknown_metric_raises(self):
        """An unrecognized metric name raises ValueError."""
        with self.assertRaises(ValueError):
            regression_metrics_fn(
                self.x, self.x_rec, metrics=["nonexistent"],
            )

    def test_2d_arrays(self):
        """2D inputs are accepted (handled via flattening)."""
        truth = np.array([[1.0, 2.0], [3.0, 4.0]])
        recon = np.array([[1.1, 2.1], [3.1, 4.1]])
        out = regression_metrics_fn(truth, recon, metrics=["mse", "mae"])
        self.assertIn("mse", out)
        self.assertIn("mae", out)
        self.assertGreater(out["mse"], 0.0)

    def test_single_metric(self):
        """Requesting one metric yields a single-entry dict."""
        out = regression_metrics_fn(self.x, self.x_rec, metrics=["mae"])
        self.assertEqual(len(out), 1)


if __name__ == "__main__":
    unittest.main()