diff --git a/docs/api/datasets.rst b/docs/api/datasets.rst index b02439d26..33cacc504 100644 --- a/docs/api/datasets.rst +++ b/docs/api/datasets.rst @@ -245,3 +245,4 @@ Available Datasets datasets/pyhealth.datasets.TCGAPRADDataset datasets/pyhealth.datasets.splitter datasets/pyhealth.datasets.utils + datasets/pyhealth.datasets.ptbxl diff --git a/docs/api/datasets/pyhealth.datasets.ptbxl.rst b/docs/api/datasets/pyhealth.datasets.ptbxl.rst new file mode 100644 index 000000000..dc43ce9ee --- /dev/null +++ b/docs/api/datasets/pyhealth.datasets.ptbxl.rst @@ -0,0 +1,7 @@ +pyhealth.datasets.ptbxl +======================= + +.. autoclass:: pyhealth.datasets.PTBXLDataset + :members: + :undoc-members: + :show-inheritance: \ No newline at end of file diff --git a/docs/api/tasks.rst b/docs/api/tasks.rst index 399b8f1aa..f63df4596 100644 --- a/docs/api/tasks.rst +++ b/docs/api/tasks.rst @@ -229,3 +229,4 @@ Available Tasks Mutation Pathogenicity (COSMIC) Cancer Survival Prediction (TCGA) Cancer Mutation Burden (TCGA) + PTB-XL MI Classification \ No newline at end of file diff --git a/docs/api/tasks/pyhealth.tasks.ptbxl_mi_classification.rst b/docs/api/tasks/pyhealth.tasks.ptbxl_mi_classification.rst new file mode 100644 index 000000000..a4495c3be --- /dev/null +++ b/docs/api/tasks/pyhealth.tasks.ptbxl_mi_classification.rst @@ -0,0 +1,7 @@ +pyhealth.tasks.ptbxl_mi_classification +====================================== + +.. 
autoclass:: pyhealth.tasks.PTBXLMIClassificationTask + :members: + :undoc-members: + :show-inheritance: \ No newline at end of file diff --git a/examples/ptbxl_mi_classification_cnn.py b/examples/ptbxl_mi_classification_cnn.py new file mode 100644 index 000000000..28f3d0723 --- /dev/null +++ b/examples/ptbxl_mi_classification_cnn.py @@ -0,0 +1,27 @@ +from pyhealth.datasets import PTBXLDataset +from pyhealth.tasks import PTBXLMIClassificationTask +import os + +def main(): + root = os.path.expanduser( + "~/Downloads/ptb-xl-a-large-publicly-available-electrocardiography-dataset-1.0.3" + ) + dataset = PTBXLDataset( + root=root, + dev=True, + use_high_resolution=False, # False -> records100, True -> records500 + ) + + task = PTBXLMIClassificationTask( + root=root, + signal_length=1000, # 10 seconds at 100 Hz + normalize=True, + ) + task_dataset = dataset.set_task(task) + + print(task_dataset[0]) + print(f"Number of samples: {len(task_dataset)}") + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/pyhealth/datasets/__init__.py b/pyhealth/datasets/__init__.py index 54e77670c..e00bb968c 100644 --- a/pyhealth/datasets/__init__.py +++ b/pyhealth/datasets/__init__.py @@ -90,3 +90,4 @@ def __init__(self, *args, **kwargs): save_processors, ) from .collate import collate_temporal +from .ptbxl import PTBXLDataset \ No newline at end of file diff --git a/pyhealth/datasets/ptbxl.py b/pyhealth/datasets/ptbxl.py new file mode 100644 index 000000000..00102d948 --- /dev/null +++ b/pyhealth/datasets/ptbxl.py @@ -0,0 +1,49 @@ +import os +from typing import Optional + +import dask.dataframe as dd +import pandas as pd + +from pyhealth.datasets import BaseDataset + + +class PTBXLDataset(BaseDataset): + """PTB-XL ECG dataset represented as an event table.""" + + def __init__( + self, + root: str, + dataset_name: Optional[str] = "PTBXL", + dev: bool = False, + cache_dir: Optional[str] = None, + num_workers: int = 1, + use_high_resolution: bool = False, + ): + 
self.use_high_resolution = use_high_resolution + super().__init__( + root=root, + tables=["ptbxl"], + dataset_name=dataset_name, + cache_dir=cache_dir, + num_workers=num_workers, + dev=dev, + ) + + def load_data(self) -> dd.DataFrame: + metadata_path = os.path.join(self.root, "ptbxl_database.csv") + df = pd.read_csv(metadata_path) + + record_path_col = "filename_hr" if self.use_high_resolution else "filename_lr" + + event_df = pd.DataFrame( + { + "patient_id": df["patient_id"].astype(str), + "event_type": "ptbxl", + "timestamp": pd.NaT, + "ptbxl/ecg_id": df["ecg_id"], + "ptbxl/record_path": df[record_path_col], + "ptbxl/scp_codes": df["scp_codes"], + } + ) + + return dd.from_pandas(event_df, npartitions=1) \ No newline at end of file diff --git a/pyhealth/tasks/__init__.py b/pyhealth/tasks/__init__.py index 797988377..0e9b70b15 100644 --- a/pyhealth/tasks/__init__.py +++ b/pyhealth/tasks/__init__.py @@ -66,3 +66,4 @@ VariantClassificationClinVar, ) from .patient_linkage_mimic3 import PatientLinkageMIMIC3Task +from .ptbxl_mi_classification import PTBXLMIClassificationTask \ No newline at end of file diff --git a/pyhealth/tasks/ptbxl_mi_classification.py b/pyhealth/tasks/ptbxl_mi_classification.py new file mode 100644 index 000000000..552563fa8 --- /dev/null +++ b/pyhealth/tasks/ptbxl_mi_classification.py @@ -0,0 +1,135 @@ +"""PTBXL MI classification task for PyHealth. + +This module defines a task that loads PTB-XL ECG records, maps SCP +diagnostic codes to myocardial infarction (MI) labels, and returns one +binary-labeled sample per record. +""" + +import ast +import os +from typing import Dict, List + +import numpy as np +import pandas as pd +import wfdb + +from pyhealth.tasks import BaseTask + + +class PTBXLMIClassificationTask(BaseTask): + """Task for classifying myocardial infarction (MI) in PTB-XL ECG records. + + This task converts the PTB-XL SCP diagnostic codes into a binary MI label + and loads the corresponding ECG signal for each record. 
+ + Attributes: + task_name (str): The name of the task. + input_schema (Dict[str, str]): Input schema mapping signal to tensor. + output_schema (Dict[str, str]): Output schema mapping label to binary. + """ + + task_name = "ptbxl_mi_classification" + input_schema = { + "signal": "tensor", + } + output_schema = { + "label": "binary", + } + + def __init__( + self, + root: str, + signal_length: int = 1000, + normalize: bool = True, + ): + """Initialize the PTBXL MI classification task. + + Args: + root: PTB-XL dataset root directory containing `scp_statements.csv`. + signal_length: Number of samples to use for each ECG signal. + normalize: Whether to z-score normalize each ECG channel. + """ + + self.root = root + self.signal_length = signal_length + self.normalize = normalize + + scp_path = os.path.join(self.root, "scp_statements.csv") + scp_df = pd.read_csv(scp_path, index_col=0) + self.mi_codes = set( + scp_df[scp_df["diagnostic_class"] == "MI"].index.astype(str).tolist() + ) + + def _load_ecg_signal(self, record_rel_path: str) -> np.ndarray: + """Loads a PTB-XL WFDB record and returns shape (12, signal_length).""" + record_path = os.path.join(self.root, record_rel_path) + + # WFDB expects the record path without file extension. + signal, _ = wfdb.rdsamp(record_path) + + # rdsamp returns shape (num_samples, num_channels) + signal = signal.T.astype(np.float32) # -> (channels, time) + + if self.normalize: + mean = signal.mean(axis=1, keepdims=True) + std = signal.std(axis=1, keepdims=True) + std = np.where(std < 1e-6, 1.0, std) + signal = (signal - mean) / std + + current_len = signal.shape[1] + if current_len >= self.signal_length: + signal = signal[:, : self.signal_length] + else: + pad_width = self.signal_length - current_len + signal = np.pad(signal, ((0, 0), (0, pad_width)), mode="constant") + + return signal + + def __call__(self, patient) -> List[Dict]: + """Generate PTB-XL MI samples from a patient record. 
+ + Args: + patient: Patient object containing PTB-XL event data. + + Returns: + A list of sample dictionaries with keys: + - patient_id + - visit_id + - record_id + - signal + - label + """ + + samples = [] + + rows = patient.data_source.to_dicts() + + for idx, row in enumerate(rows): + raw_label = row["ptbxl/scp_codes"] + record_rel_path = row["ptbxl/record_path"] + + try: + scp_codes = ( + ast.literal_eval(raw_label) + if isinstance(raw_label, str) + else raw_label + ) + except (ValueError, SyntaxError): + scp_codes = {} + + label = 1 if any(code in self.mi_codes for code in scp_codes.keys()) else 0 + signal = self._load_ecg_signal(record_rel_path) + + visit_id = str(row["ptbxl/ecg_id"]) + + samples.append( + { + "patient_id": patient.patient_id, + "visit_id": visit_id, + "record_id": idx + 1, + "signal": signal.tolist(), + "label": label, + } + ) + + return samples \ No newline at end of file diff --git a/tests/core/test_ptbxl_dataset.py b/tests/core/test_ptbxl_dataset.py new file mode 100644 index 000000000..111599494 --- /dev/null +++ b/tests/core/test_ptbxl_dataset.py @@ -0,0 +1,36 @@ +import os +import tempfile +import unittest + +from pyhealth.datasets import PTBXLDataset + + +class TestPTBXLDataset(unittest.TestCase): + def test_load_data_dev_mode(self): + with tempfile.TemporaryDirectory() as tmpdir: + csv_path = os.path.join(tmpdir, "ptbxl_database.csv") + + with open(csv_path, "w") as f: + f.write("ecg_id,patient_id,filename_lr,filename_hr,scp_codes\n") + f.write('1,100,records100/00000/00001_lr,records500/00000/00001_hr,"{\'MI\': 1}"\n') + f.write('2,101,records100/00000/00002_lr,records500/00000/00002_hr,"{\'NORM\': 1}"\n') + + dataset = PTBXLDataset( + root=tmpdir, + dev=True, + ) + + df = dataset.load_data().compute() + + self.assertEqual(len(df), 2) + self.assertIn("patient_id", df.columns) + self.assertIn("event_type", df.columns) + self.assertIn("ptbxl/ecg_id", df.columns) + self.assertIn("ptbxl/record_path", df.columns) + 
self.assertIn("ptbxl/scp_codes", df.columns) + self.assertEqual(str(df.iloc[0]["patient_id"]), "100") + self.assertEqual(df.iloc[0]["event_type"], "ptbxl") + + +if __name__ == "__main__": + unittest.main() \ No newline at end of file diff --git a/tests/core/test_ptbxl_mi_classification.py b/tests/core/test_ptbxl_mi_classification.py new file mode 100644 index 000000000..dee838d19 --- /dev/null +++ b/tests/core/test_ptbxl_mi_classification.py @@ -0,0 +1,63 @@ +import os +import tempfile +import unittest +from unittest.mock import patch + +import numpy as np +import pandas as pd +import polars as pl + +from pyhealth.tasks.ptbxl_mi_classification import PTBXLMIClassificationTask +from pyhealth.data import Patient + + +class TestPTBXLTask(unittest.TestCase): + @patch.object(PTBXLMIClassificationTask, "_load_ecg_signal") + def test_mi_label_extraction(self, mock_load_signal): + mock_load_signal.return_value = np.zeros((12, 1000), dtype=np.float32) + + with tempfile.TemporaryDirectory() as tmpdir: + scp_path = os.path.join(tmpdir, "scp_statements.csv") + + # minimal synthetic SCP mapping + scp_df = pd.DataFrame( + { + "diagnostic_class": ["MI", "NORM"], + }, + index=["IMI", "NORM"], + ) + scp_df.to_csv(scp_path) + + df = pd.DataFrame( + { + "patient_id": ["1", "1"], + "event_type": ["ptbxl", "ptbxl"], + "timestamp": [None, None], + "ptbxl/ecg_id": [100, 101], + "ptbxl/record_path": [ + "records100/00000/00001_lr", + "records100/00000/00002_lr", + ], + "ptbxl/scp_codes": [ + "{'IMI': 1}", + "{'NORM': 1}", + ], + } + ) + + patient = Patient( + patient_id="1", + data_source=pl.from_pandas(df), + ) + + task = PTBXLMIClassificationTask(root=tmpdir) + samples = task(patient) + + self.assertEqual(len(samples), 2) + self.assertEqual(samples[0]["label"], 1) + self.assertEqual(samples[1]["label"], 0) + self.assertEqual(np.array(samples[0]["signal"]).shape, (12, 1000)) + + +if __name__ == "__main__": + unittest.main() \ No newline at end of file