diff --git a/docs/api/datasets.rst b/docs/api/datasets.rst
index b02439d26..bef9e5d61 100644
--- a/docs/api/datasets.rst
+++ b/docs/api/datasets.rst
@@ -226,6 +226,7 @@ Available Datasets
datasets/pyhealth.datasets.MIMIC4Dataset
datasets/pyhealth.datasets.MedicalTranscriptionsDataset
datasets/pyhealth.datasets.CardiologyDataset
+ datasets/pyhealth.datasets.Cardiology2Dataset
datasets/pyhealth.datasets.eICUDataset
datasets/pyhealth.datasets.ISRUCDataset
datasets/pyhealth.datasets.MIMICExtractDataset
diff --git a/docs/api/datasets/pyhealth.datasets.Cardiology2Dataset.rst b/docs/api/datasets/pyhealth.datasets.Cardiology2Dataset.rst
new file mode 100644
index 000000000..10a829a51
--- /dev/null
+++ b/docs/api/datasets/pyhealth.datasets.Cardiology2Dataset.rst
@@ -0,0 +1,11 @@
+pyhealth.datasets.Cardiology2Dataset
+=====================================
+
+The PhysioNet/Computing in Cardiology Challenge 2020 dataset of 12-lead ECG recordings.
+
+For more information, refer to the `PhysioNet page <https://physionet.org/content/challenge-2020/1.0.2/>`_.
+
+.. autoclass:: pyhealth.datasets.Cardiology2Dataset
+ :members:
+ :undoc-members:
+ :show-inheritance:
diff --git a/docs/api/tasks.rst b/docs/api/tasks.rst
index d85d04bc3..74d0a7ef8 100644
--- a/docs/api/tasks.rst
+++ b/docs/api/tasks.rst
@@ -209,6 +209,7 @@ Available Tasks
In-Hospital Mortality (MIMIC-IV)
MIMIC-III ICD-9 Coding
Cardiology Detection
+ Cardiology Multilabel Classification
COVID-19 CXR Classification
DKA Prediction (MIMIC-IV)
Drug Recommendation
diff --git a/docs/api/tasks/pyhealth.tasks.CardiologyMultilabelClassification.rst b/docs/api/tasks/pyhealth.tasks.CardiologyMultilabelClassification.rst
new file mode 100644
index 000000000..6e2343b5f
--- /dev/null
+++ b/docs/api/tasks/pyhealth.tasks.CardiologyMultilabelClassification.rst
@@ -0,0 +1,12 @@
+pyhealth.tasks.CardiologyMultilabelClassification
+==================================================
+
+Multi-label ECG classification over 24 SNOMED-CT diagnosis codes from the
+PhysioNet/Computing in Cardiology Challenge 2020 dataset. The task follows
+the benchmark protocol of `Nonaka & Seita (2021) <https://proceedings.mlr.press/v149/nonaka21a.html>`_,
+evaluated with macro-averaged ROC-AUC.
+
+.. autoclass:: pyhealth.tasks.CardiologyMultilabelClassification
+ :members:
+ :undoc-members:
+ :show-inheritance:
diff --git a/examples/cardiology_multilabel.ipynb b/examples/cardiology_multilabel.ipynb
new file mode 100644
index 000000000..fe8d26c49
--- /dev/null
+++ b/examples/cardiology_multilabel.ipynb
@@ -0,0 +1,603 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "id": "72da6a5d",
+ "metadata": {},
+ "source": [
+ "# Cardiology Multilabel Classification\n",
+ "\n",
+    "This notebook shows how to build a multilabel ECG classification pipeline with the PhysioNet 2020 cardiology dataset in PyHealth. It loads the dataset, applies the `CardiologyMultilabelClassification` task, inspects example samples, and includes a sanity check that verifies a model can run a forward pass and backpropagate on a tiny 5-sample subset.\n",
+ "\n",
+ "Download link: https://physionet.org/content/challenge-2020/1.0.2/\n",
+ "You'll need to run the following in terminal:\n",
+ "\n",
+ "`cd ~/data && wget -r -N -c -np https://physionet.org/files/challenge-2020/1.0.2/`"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "2c153769",
+ "metadata": {},
+ "source": [
+ "## 1. Install Deps"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "71cd4a61",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "%pip install -q -e ..\n",
+ "%load_ext autoreload\n",
+ "%autoreload 2"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "26615505",
+ "metadata": {},
+ "source": [
+ "## 2. Load the Dataset"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "6eb9ddba",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from pathlib import Path\n",
+ "\n",
+ "DATA_ROOT = str(Path.home() / \"data\" / \"physionet.org\" / \"files\" / \"challenge-2020\" / \"1.0.2\" / \"training\")\n",
+ "print(f\"DATA_ROOT = {DATA_ROOT}\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "1e2444ec",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Optional: clear the cache to force a full rebuild\n",
+ "# If you decide to change any of the core code this will be necessary to pick up changes\n",
+ "import shutil, os\n",
+ "cache_dir = \"/tmp/pyhealth_cardiology\"\n",
+ "if os.path.exists(cache_dir):\n",
+ " shutil.rmtree(cache_dir)\n",
+ " print(f\"Cleared cache: {cache_dir}\")\n",
+ "else:\n",
+ " print(f\"No cache found at {cache_dir}\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "a09a77de",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from pyhealth.datasets import Cardiology2Dataset\n",
+ "\n",
+ "dataset = Cardiology2Dataset(\n",
+ " root=DATA_ROOT,\n",
+ " chosen_dataset=[1, 1, 0, 0, 0, 0], # Only load cpsc_2018 datasets\n",
+ " dev=True,\n",
+ " cache_dir=\"/tmp/pyhealth_cardiology\"\n",
+ ")\n",
+ "\n",
+ "dataset.stats()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "8c0e5c91",
+ "metadata": {},
+ "source": [
+ "## 3. Apply the Multilabel Classification Task"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "d14b2002",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from pyhealth.tasks import CardiologyMultilabelClassification\n",
+ "\n",
+ "task = CardiologyMultilabelClassification()\n",
+ "sample_dataset = dataset.set_task(task)\n",
+ "\n",
+ "print(f\"Total samples: {len(sample_dataset)}\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "a8d42d18",
+ "metadata": {},
+ "source": [
+ "## 4. Inspect Sample"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "f68361cc",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "sample = sample_dataset[0]\n",
+ "print(f\"keys: {list(sample.keys())}\")\n",
+ "print(f\"patient_id: {sample['patient_id']}\")\n",
+ "print(f\"visit_id: {sample['visit_id']}\")\n",
+ "print(f\"signal: {sample['signal']}\")\n",
+ "print(f\"labels: {sample['labels']}\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "5d2290c9-1e4b-4a6b-ad2c-ce1d75f59a45",
+ "metadata": {},
+ "source": [
+    "## 5. Sanity Check: ResNet-18 Training Pipeline"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "8f3f79f1-f3f0-44f6-80c4-a8c6fcaee5af",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# %% [markdown]\n",
+ "# # Cardiology Multilabel Classification with ResNet-18\n",
+ "#\n",
+ "# This notebook extends the cardiology multilabel example to:\n",
+ "# 1. load the PhysioNet 2020 cardiology dataset\n",
+ "# 2. apply the Cardiology multilabel task\n",
+ "# 3. split the data into train/val/test sets\n",
+ "# 4. convert ECG windows into pseudo-images for torchvision ResNet-18\n",
+ "# 5. train and evaluate a ResNet-18 baseline\n",
+ "# 6. run a simple lead ablation experiment\n",
+ "\n",
+ "# %% [markdown]\n",
+ "# ## 1. Install Deps\n",
+ "\n",
+ "# %%\n",
+ "%pip install -q -e ..\n",
+ "%load_ext autoreload\n",
+ "%autoreload 2\n",
+ "\n",
+ "# %% [markdown]\n",
+ "# ## 2. Imports and Config\n",
+ "\n",
+ "# %%\n",
+ "from pathlib import Path\n",
+ "import copy\n",
+ "import numpy as np\n",
+ "import pandas as pd\n",
+ "import torch\n",
+ "\n",
+ "from pyhealth.datasets import (\n",
+ " Cardiology2Dataset,\n",
+ " create_sample_dataset,\n",
+ " get_dataloader,\n",
+ " split_by_patient,\n",
+ ")\n",
+ "from pyhealth.models import TorchvisionModel\n",
+ "from pyhealth.tasks import CardiologyMultilabelClassification\n",
+ "from pyhealth.trainer import Trainer\n",
+ "\n",
+ "SEED = 42\n",
+ "BATCH_SIZE = 32\n",
+ "EPOCHS = 5\n",
+ "LR = 1e-3\n",
+ "TRAIN_RATIOS = [0.7, 0.1, 0.2]\n",
+ "USE_PRETRAINED = False\n",
+ "DEVICE = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
+ "STEPS_PER_EPOCH = 500\n",
+ "\n",
+ "torch.manual_seed(SEED)\n",
+ "np.random.seed(SEED)\n",
+ "\n",
+ "DATA_ROOT = str(\n",
+ " Path.home() / \"data\" / \"physionet.org\" / \"files\" / \"challenge-2020\" / \"1.0.2\" / \"training\"\n",
+ ")\n",
+ "CACHE_DIR = \"/tmp/pyhealth_cardiology\"\n",
+ "\n",
+ "print(\"DATA_ROOT =\", DATA_ROOT)\n",
+ "print(\"DEVICE =\", DEVICE)\n",
+ "\n",
+ "# %% [markdown]\n",
+ "# ## 3. Load the Dataset\n",
+ "#\n",
+ "# Set `DEV = False` for the real experiment.\n",
+ "# Set `DEV = True` if you only want a quick smoke test.\n",
+ "\n",
+ "# %%\n",
+ "DEV = False\n",
+ "\n",
+ "dataset = Cardiology2Dataset(\n",
+ " root=DATA_ROOT,\n",
+ " chosen_dataset=[1, 1, 0, 0, 0, 0], # adjust if you want more subsets\n",
+ " dev=DEV,\n",
+ " cache_dir=CACHE_DIR,\n",
+ ")\n",
+ "\n",
+ "dataset.stats()\n",
+ "\n",
+ "# %% [markdown]\n",
+ "# ## 4. Apply the Cardiology Multilabel Task\n",
+ "\n",
+ "# %%\n",
+ "task = CardiologyMultilabelClassification()\n",
+ "sample_dataset = dataset.set_task(task)\n",
+ "\n",
+ "print(\"Total windowed samples:\", len(sample_dataset))\n",
+ "print(\"Number of labels:\", sample_dataset.output_processors[\"labels\"].size())\n",
+ "\n",
+ "# %% [markdown]\n",
+ "# ## 5. Inspect One Sample\n",
+ "\n",
+ "# %%\n",
+ "sample = sample_dataset[0]\n",
+ "print(\"keys :\", list(sample.keys()))\n",
+ "print(\"patient_id:\", sample[\"patient_id\"])\n",
+ "print(\"visit_id :\", sample[\"visit_id\"])\n",
+ "print(\"signal shape:\", np.asarray(sample[\"signal\"]).shape)\n",
+ "print(\"labels :\", sample[\"labels\"])\n",
+ "\n",
+ "# %% [markdown]\n",
+ "# ## 6. Split into Train / Val / Test\n",
+ "#\n",
+ "# We split by patient to reduce leakage across windows from the same recording/patient.\n",
+ "\n",
+ "# %%\n",
+ "train_dataset, val_dataset, test_dataset = split_by_patient(\n",
+ " sample_dataset,\n",
+ " ratios=TRAIN_RATIOS,\n",
+ " seed=SEED,\n",
+ ")\n",
+ "\n",
+ "print(\"Train samples:\", len(train_dataset))\n",
+ "print(\"Val samples :\", len(val_dataset))\n",
+ "print(\"Test samples :\", len(test_dataset))\n",
+ "\n",
+ "# %% [markdown]\n",
+ "# ## 7. Convert ECG Windows to Pseudo-Images for torchvision\n",
+ "#\n",
+ "# Each ECG window starts as `(12, 1250)`.\n",
+ "# We convert it to `(1, 12, 1250)`, which lets the PyHealth torchvision wrapper\n",
+ "# repeat the single channel to 3 channels internally for ResNet-18.\n",
+ "\n",
+ "# %%\n",
+ "def _multihot_to_label_list(labels, labels_processor):\n",
+ " # MultiLabelProcessor expects SNOMED code strings; task samples already use tensors.\n",
+ " if isinstance(labels, list):\n",
+ " return labels\n",
+ " idx_to_code = {i: code for code, i in labels_processor.label_vocab.items()}\n",
+ " if hasattr(labels, \"detach\"):\n",
+ " flat = labels.detach().cpu().reshape(-1)\n",
+ " return [\n",
+ " idx_to_code[int(j)]\n",
+ " for j in range(flat.numel())\n",
+ " if float(flat[j]) > 0.5\n",
+ " ]\n",
+ " flat = np.asarray(labels, dtype=np.float64).reshape(-1)\n",
+ " return [idx_to_code[int(j)] for j in range(flat.size) if flat[j] > 0.5]\n",
+ "\n",
+ "\n",
+ "def to_resnet_ready_samples(dataset_split, labels_processor):\n",
+ " converted = []\n",
+ " for i in range(len(dataset_split)):\n",
+ " s = copy.deepcopy(dataset_split[i])\n",
+ " signal = np.asarray(s[\"signal\"], dtype=np.float32)\n",
+ "\n",
+ " # Per-window z-score normalization\n",
+ " signal = (signal - signal.mean()) / (signal.std() + 1e-8)\n",
+ "\n",
+ " # Convert from (12, 1250) to (1, 12, 1250)\n",
+ " s[\"signal\"] = np.expand_dims(signal, axis=0)\n",
+ " s[\"labels\"] = _multihot_to_label_list(s[\"labels\"], labels_processor)\n",
+ " converted.append(s)\n",
+ " return converted\n",
+ "\n",
+ "\n",
+ "_labels_processor = sample_dataset.output_processors[\"labels\"]\n",
+ "train_samples_img = to_resnet_ready_samples(train_dataset, _labels_processor)\n",
+ "val_samples_img = to_resnet_ready_samples(val_dataset, _labels_processor)\n",
+ "test_samples_img = to_resnet_ready_samples(test_dataset, _labels_processor)\n",
+ "\n",
+ "shared_output_processors = sample_dataset.output_processors\n",
+ "\n",
+ "train_img_dataset = create_sample_dataset(\n",
+ " samples=train_samples_img,\n",
+ " input_schema={\"signal\": \"tensor\"},\n",
+ " output_schema=task.output_schema,\n",
+ " dataset_name=\"cardiology_resnet_train\",\n",
+ " output_processors=shared_output_processors,\n",
+ ")\n",
+ "\n",
+ "val_img_dataset = create_sample_dataset(\n",
+ " samples=val_samples_img,\n",
+ " input_schema={\"signal\": \"tensor\"},\n",
+ " output_schema=task.output_schema,\n",
+ " dataset_name=\"cardiology_resnet_val\",\n",
+ " output_processors=shared_output_processors,\n",
+ ")\n",
+ "\n",
+ "test_img_dataset = create_sample_dataset(\n",
+ " samples=test_samples_img,\n",
+ " input_schema={\"signal\": \"tensor\"},\n",
+ " output_schema=task.output_schema,\n",
+ " dataset_name=\"cardiology_resnet_test\",\n",
+ " output_processors=shared_output_processors,\n",
+ ")\n",
+ "\n",
+ "batch = next(iter(get_dataloader(train_img_dataset, batch_size=4, shuffle=False)))\n",
+ "print(\"ResNet input batch shape:\", tuple(batch[\"signal\"].shape))\n",
+ "print(\"Target batch shape :\", tuple(batch[\"labels\"].shape))\n",
+ "\n",
+ "# %% [markdown]\n",
+ "# ## 8. Build ResNet-18 from torchvision\n",
+ "\n",
+ "# %%\n",
+ "weights = \"DEFAULT\" if USE_PRETRAINED else None\n",
+ "\n",
+ "model = TorchvisionModel(\n",
+ " dataset=train_img_dataset,\n",
+ " model_name=\"resnet18\",\n",
+ " model_config={\"weights\": weights},\n",
+ ")\n",
+ "\n",
+ "model\n",
+ "\n",
+ "# %% [markdown]\n",
+ "# ## 9. Build Dataloaders\n",
+ "\n",
+ "# %%\n",
+ "train_loader = get_dataloader(train_img_dataset, batch_size=BATCH_SIZE, shuffle=True)\n",
+ "val_loader = get_dataloader(val_img_dataset, batch_size=BATCH_SIZE, shuffle=False)\n",
+ "test_loader = get_dataloader(test_img_dataset, batch_size=BATCH_SIZE, shuffle=False)\n",
+ "\n",
+ "_n_train_batches = len(train_loader)\n",
+ "if STEPS_PER_EPOCH is not None:\n",
+ " print(\n",
+ " f\"train_loader: {_n_train_batches} batches/epoch; \"\n",
+ " f\"steps_per_epoch={STEPS_PER_EPOCH}\"\n",
+ " )\n",
+ "else:\n",
+ " print(\n",
+ " f\"train_loader: {_n_train_batches} batches/epoch (full pass; slow on CPU)\"\n",
+ " )\n",
+ "\n",
+ "# %% [markdown]\n",
+ "# ## 10. Train ResNet-18\n",
+ "#\n",
+ "# The paper description mentions macro ROC-AUC, so we monitor that here.\n",
+ "\n",
+ "# %%\n",
+ "metrics = [\n",
+ " \"roc_auc_macro\",\n",
+ " \"pr_auc_macro\",\n",
+ " \"f1_macro\",\n",
+ " \"jaccard_macro\",\n",
+ " \"hamming_loss\",\n",
+ "]\n",
+ "\n",
+ "trainer = Trainer(\n",
+ " model=model,\n",
+ " metrics=metrics,\n",
+ " device=DEVICE,\n",
+ " output_path=\"./output\",\n",
+ " exp_name=\"cardiology_multilabel_resnet18\",\n",
+ ")\n",
+ "\n",
+ "trainer.train(\n",
+ " train_dataloader=train_loader,\n",
+ " val_dataloader=val_loader,\n",
+ " epochs=EPOCHS,\n",
+ " optimizer_params={\"lr\": LR},\n",
+ " steps_per_epoch=STEPS_PER_EPOCH,\n",
+ " monitor=\"roc_auc_macro\",\n",
+ " monitor_criterion=\"max\",\n",
+ " load_best_model_at_last=True,\n",
+ ")\n",
+ "\n",
+ "# %% [markdown]\n",
+ "# ## 11. Evaluate on Validation and Test Sets\n",
+ "\n",
+ "# %%\n",
+ "val_results = trainer.evaluate(val_loader)\n",
+ "test_results = trainer.evaluate(test_loader)\n",
+ "\n",
+ "print(\"Validation Results\")\n",
+ "print(pd.Series(val_results).sort_index())\n",
+ "\n",
+ "print(\"\\nTest Results\")\n",
+ "print(pd.Series(test_results).sort_index())\n",
+ "\n",
+ "# %% [markdown]\n",
+ "# ## 12. Optional: Compare Against the Paper\n",
+ "#\n",
+ "# Fill in the paper numbers once you confirm the exact table/setting you want to reproduce.\n",
+ "\n",
+ "# %%\n",
+ "paper_results = {\n",
+ " # \"roc_auc_macro\": ...,\n",
+ " # \"pr_auc_macro\": ...,\n",
+ "}\n",
+ "\n",
+ "if paper_results:\n",
+ " comparison = pd.DataFrame(\n",
+ " {\n",
+ " \"paper\": pd.Series(paper_results),\n",
+ " \"our_resnet18\": pd.Series(test_results),\n",
+ " }\n",
+ " )\n",
+ " comparison[\"delta\"] = comparison[\"our_resnet18\"] - comparison[\"paper\"]\n",
+ " display(comparison.sort_index())\n",
+ "else:\n",
+ " print(\"Add the paper metrics here once you identify the exact comparison table.\")\n",
+ "\n",
+ "# %% [markdown]\n",
+ "# ## 13. Ablation: 12-Lead vs 1-Lead\n",
+ "#\n",
+ "# This reruns the same pipeline with a different `leads` argument in the task.\n",
+ "\n",
+ "# %%\n",
+ "def run_resnet_experiment(\n",
+ " leads,\n",
+ " experiment_name,\n",
+ " dev=DEV,\n",
+ " epochs=EPOCHS,\n",
+ " batch_size=BATCH_SIZE,\n",
+ " lr=LR,\n",
+ " use_pretrained=USE_PRETRAINED,\n",
+ " steps_per_epoch=STEPS_PER_EPOCH,\n",
+ "):\n",
+ " task = CardiologyMultilabelClassification(leads=leads)\n",
+ "\n",
+ " base_dataset = Cardiology2Dataset(\n",
+ " root=DATA_ROOT,\n",
+ " chosen_dataset=[1, 1, 0, 0, 0, 0],\n",
+ " dev=dev,\n",
+ " cache_dir=CACHE_DIR,\n",
+ " )\n",
+ "\n",
+ " full_dataset = base_dataset.set_task(task)\n",
+ "\n",
+ " train_ds, val_ds, test_ds = split_by_patient(\n",
+ " full_dataset,\n",
+ " ratios=TRAIN_RATIOS,\n",
+ " seed=SEED,\n",
+ " )\n",
+ "\n",
+ " labels_processor = full_dataset.output_processors[\"labels\"]\n",
+ "\n",
+ " def convert(split_ds):\n",
+ " out = []\n",
+ " for i in range(len(split_ds)):\n",
+ " s = copy.deepcopy(split_ds[i])\n",
+ " signal = np.asarray(s[\"signal\"], dtype=np.float32)\n",
+ " signal = (signal - signal.mean()) / (signal.std() + 1e-8)\n",
+ " s[\"signal\"] = np.expand_dims(signal, axis=0)\n",
+ " s[\"labels\"] = _multihot_to_label_list(s[\"labels\"], labels_processor)\n",
+ " out.append(s)\n",
+ " return out\n",
+ "\n",
+ " shared_output_processors = full_dataset.output_processors\n",
+ "\n",
+ " train_img = create_sample_dataset(\n",
+ " samples=convert(train_ds),\n",
+ " input_schema={\"signal\": \"tensor\"},\n",
+ " output_schema=task.output_schema,\n",
+ " dataset_name=f\"{experiment_name}_train\",\n",
+ " output_processors=shared_output_processors,\n",
+ " )\n",
+ " val_img = create_sample_dataset(\n",
+ " samples=convert(val_ds),\n",
+ " input_schema={\"signal\": \"tensor\"},\n",
+ " output_schema=task.output_schema,\n",
+ " dataset_name=f\"{experiment_name}_val\",\n",
+ " output_processors=shared_output_processors,\n",
+ " )\n",
+ " test_img = create_sample_dataset(\n",
+ " samples=convert(test_ds),\n",
+ " input_schema={\"signal\": \"tensor\"},\n",
+ " output_schema=task.output_schema,\n",
+ " dataset_name=f\"{experiment_name}_test\",\n",
+ " output_processors=shared_output_processors,\n",
+ " )\n",
+ "\n",
+ " train_loader = get_dataloader(train_img, batch_size=batch_size, shuffle=True)\n",
+ " val_loader = get_dataloader(val_img, batch_size=batch_size, shuffle=False)\n",
+ " test_loader = get_dataloader(test_img, batch_size=batch_size, shuffle=False)\n",
+ "\n",
+ " weights = \"DEFAULT\" if use_pretrained else None\n",
+ " model = TorchvisionModel(\n",
+ " dataset=train_img,\n",
+ " model_name=\"resnet18\",\n",
+ " model_config={\"weights\": weights},\n",
+ " )\n",
+ "\n",
+ " trainer = Trainer(\n",
+ " model=model,\n",
+ " metrics=metrics,\n",
+ " device=DEVICE,\n",
+ " output_path=\"./output\",\n",
+ " exp_name=experiment_name,\n",
+ " )\n",
+ "\n",
+ " trainer.train(\n",
+ " train_dataloader=train_loader,\n",
+ " val_dataloader=val_loader,\n",
+ " epochs=epochs,\n",
+ " optimizer_params={\"lr\": lr},\n",
+ " steps_per_epoch=steps_per_epoch,\n",
+ " monitor=\"roc_auc_macro\",\n",
+ " monitor_criterion=\"max\",\n",
+ " load_best_model_at_last=True,\n",
+ " )\n",
+ "\n",
+ " return trainer.evaluate(test_loader)\n",
+ "\n",
+ "# %%\n",
+ "ablation_results = {\n",
+ " \"12-lead\": run_resnet_experiment(\n",
+ " leads=list(range(12)),\n",
+ " experiment_name=\"cardiology_resnet18_12lead\",\n",
+ " ),\n",
+ " \"1-lead\": run_resnet_experiment(\n",
+ " leads=[0],\n",
+ " experiment_name=\"cardiology_resnet18_1lead\",\n",
+ " ),\n",
+ "}\n",
+ "\n",
+ "ablation_df = pd.DataFrame(ablation_results).T\n",
+ "display(ablation_df)\n",
+ "\n",
+ "# %% [markdown]\n",
+ "# ## 14. Clean Up Temporary In-Memory Datasets\n",
+ "\n",
+ "# %%\n",
+ "for ds in [train_img_dataset, val_img_dataset, test_img_dataset]:\n",
+ " ds.close()\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "03f898b7-4cc6-4fc3-be40-6863b279d5b2",
+ "metadata": {},
+ "outputs": [],
+ "source": []
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3.13 (PyHealth)",
+ "language": "python",
+ "name": "pyhealth"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.13.12"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}
diff --git a/examples/cardiology_multilabel_resnet_lead_ablation.py b/examples/cardiology_multilabel_resnet_lead_ablation.py
new file mode 100644
index 000000000..416efa524
--- /dev/null
+++ b/examples/cardiology_multilabel_resnet_lead_ablation.py
@@ -0,0 +1,135 @@
+"""
+Spatial Feature Ablation Study: 12-Lead Clinical ECG vs. 1-Lead Wearable ECG
+
+This script demonstrates how varying the spatial feature dimensions (number of ECG leads)
+affects the input shape and predictive framework of a PyHealth model.
+
+1. Clinical Context & Objective
+-------------------------------
+Standard clinical ECGs utilize 12 distinct leads to capture the electrical activity of the
+heart from multiple spatial angles. However, modern wearable devices (like smartwatches)
+typically only capture a single lead (equivalent to Lead I).
+
+This ablation study benchmarks the structural impact of transitioning from a 12-lead to a
+1-lead setup. By isolating the 'leads' parameter in the CardiologyMultilabelClassification
+task, we evaluate how PyHealth's native ResNet architecture adapts to the loss of spatial
+projection vectors.
+
+2. Experimental Setup
+---------------------
+- Dataset: Synthetic data generated to mimic the PhysioNet/CinC Challenge 2020 format.
+- Task Configuration 1 (Baseline): All 12 leads utilized. Input shape is (12, 1250).
+- Task Configuration 2 (Ablation): Only Lead I (index 0) utilized. Input shape is (1, 1250).
+- Model: PyHealth's native ResNet, initialized dynamically based on the dataset's feature space.
+
+3. Expected Findings
+--------------------
+While models trained on 1-lead data might maintain robust performance for rhythm-based
+abnormalities (like Atrial Fibrillation), their performance is expected to degrade significantly
+for morphology-based diagnoses that rely on spatial axes, such as Bundle Branch Blocks (LBBB/RBBB)
+or Axis Deviations.
+"""
+
+import os
+import shutil
+import numpy as np
+import pandas as pd
+from scipy.io import savemat
+
+# PyHealth Imports
+from pyhealth.datasets import Cardiology2Dataset, get_dataloader
+from pyhealth.tasks import CardiologyMultilabelClassification
+from pyhealth.models import ResNet
+
+def generate_synthetic_data(root_dir: str, num_patients: int = 3):
+ """Generates synthetic .mat and .hea files to simulate the PhysioNet dataset."""
+ patient_dir = os.path.join(root_dir, "cpsc_2018", "g1")
+ os.makedirs(patient_dir, exist_ok=True)
+
+ # 164934002 = T wave abnormality, 426783006 = Sinus rhythm
+ sample_dx = "426783006,164934002"
+
+ for i in range(num_patients):
+ mat_path = os.path.join(patient_dir, f"A{i:04d}.mat")
+ hea_path = os.path.join(patient_dir, f"A{i:04d}.hea")
+
+ # 12 leads, 10 seconds at 500Hz = 5000 samples
+ synthetic_signal = np.random.randn(12, 5000)
+ savemat(mat_path, {"val": synthetic_signal})
+
+ with open(hea_path, "w") as f:
+ f.write(f"A{i:04d} 12 500 5000\n")
+ f.write("# Age: 63\n")
+ f.write("# Sex: Male\n")
+ f.write(f"# Dx: {sample_dx}\n")
+
+def run_ablation_experiment():
+ print("Initializing Spatial Feature Ablation Study...")
+ SYNTHETIC_ROOT = "/tmp/synthetic_cardiology_data"
+ CACHE_DIR = "/tmp/pyhealth_cache_ablation"
+
+ if os.path.exists(SYNTHETIC_ROOT):
+ shutil.rmtree(SYNTHETIC_ROOT)
+ generate_synthetic_data(SYNTHETIC_ROOT)
+
+ results = []
+
+ # Define our two configurations for the ablation study
+ configs = {
+ "12-Lead (Clinical)": list(range(12)),
+ "1-Lead (Wearable)": [0]
+ }
+
+ for setup_name, leads in configs.items():
+ # 1. Load Dataset (using dev=True to minimize overhead)
+ dataset = Cardiology2Dataset(
+ root=SYNTHETIC_ROOT,
+ chosen_dataset=[1, 0, 0, 0, 0, 0],
+ cache_dir=CACHE_DIR,
+ dev=True
+ )
+
+ # 2. Apply Task with varying feature dimensions
+ task = CardiologyMultilabelClassification(leads=leads)
+ sample_dataset = dataset.set_task(task)
+
+ # 3. Initialize PyHealth Dataloader
+ dataloader = get_dataloader(sample_dataset, batch_size=2, dev=True)
+ batch = next(iter(dataloader))
+
+ # 4. Initialize native PyHealth Model
+ # ResNet automatically adapts to the feature dimension defined in the dataset
+ model = ResNet(
+ dataset=sample_dataset,
+ feature_keys=["signal"],
+ label_key="labels",
+ mode="multilabel"
+ )
+
+ # 5. Forward pass through the PyHealth model
+ out = model(**batch)
+
+ # Record structural findings
+ signal_shape = batch["signal"].shape
+ results.append({
+ "Configuration": setup_name,
+ "Input Channels": signal_shape[1],
+ "Batch Input Shape": tuple(signal_shape),
+ "Loss Output Type": type(out["loss"]).__name__,
+ "Logits Shape": tuple(out["y_prob"].shape),
+ "Model Parameters": sum(p.numel() for p in model.parameters())
+ })
+
+ # Clean up synthetic data
+ shutil.rmtree(SYNTHETIC_ROOT)
+
+ # Output tabular findings
+ df = pd.DataFrame(results)
+ print("\n" + "="*80)
+ print("ABLATION STUDY: PYHEALTH RESNET FEATURE VARIATION RESULTS")
+ print("="*80)
+ print(df.to_string(index=False))
+ print("="*80)
+
+if __name__ == "__main__":
+ run_ablation_experiment()
\ No newline at end of file
diff --git a/pyhealth/datasets/__init__.py b/pyhealth/datasets/__init__.py
index 7ac05f259..43d6faa38 100644
--- a/pyhealth/datasets/__init__.py
+++ b/pyhealth/datasets/__init__.py
@@ -48,6 +48,7 @@ def __init__(self, *args, **kwargs):
from .base_dataset import BaseDataset
from .cardiology import CardiologyDataset
+from .cardiology2 import Cardiology2Dataset
from .chestxray14 import ChestXray14Dataset
from .clinvar import ClinVarDataset
from .cosmic import COSMICDataset
diff --git a/pyhealth/datasets/cardiology2.py b/pyhealth/datasets/cardiology2.py
new file mode 100644
index 000000000..733062534
--- /dev/null
+++ b/pyhealth/datasets/cardiology2.py
@@ -0,0 +1,276 @@
+"""
+PyHealth dataset for the PhysioNet/Computing in Cardiology Challenge 2020.
+
+Dataset link:
+ https://physionet.org/content/challenge-2020/1.0.2/
+
+Dataset paper: (please cite if you use this dataset)
+ Perez Alday EA, Gu A, Shah AJ, Robichaux C, Wong AI, Liu C, Liu F,
+ Rad AB, Elola A, Seyedi S, Li Q, Sharma A, Clifford GD, Reyna MA.
+ Classification of 12-lead ECGs: the PhysioNet/Computing in Cardiology Challenge 2020.
+ Physiol Meas. 2020 Nov 11. http://doi.org/10.1088/1361-6579/abc960.
+
+Dataset resource:
+ Perez Alday, E. A., Gu, A., Shah, A., Liu, C., Sharma, A., Seyedi, S.,
+ Bahrami Rad, A., Reyna, M., & Clifford, G. (2022). Classification of
+ 12-lead ECGs: The PhysioNet/Computing in Cardiology Challenge 2020 (version 1.0.2).
+ PhysioNet. RRID:SCR_007345. https://doi.org/10.13026/dvyd-kd57
+
+PhysioNet:
+ Goldberger, A., Amaral, L., Glass, L., Hausdorff, J., Ivanov, P. C., Mark, R.,
+ ... & Stanley, H. E. (2000). PhysioBank, PhysioToolkit, and PhysioNet:
+ Components of a new research resource for complex physiologic signals.
+ Circulation [Online]. 101 (23), pp. e215–e220. RRID:SCR_007345.
+
+Author:
+ John Ma (jm119@illinois.edu)
+"""
+
+import logging
+import os
+from pathlib import Path
+from typing import Dict, List, Optional
+
+import pandas as pd
+
+from pyhealth.datasets import BaseDataset
+
+logger = logging.getLogger(__name__)
+
+SUBDATASET_NAMES: List[str] = [
+ "cpsc_2018",
+ "cpsc_2018_extra",
+ "georgia",
+ "ptb",
+ "ptb-xl",
+ "st_petersburg_incart",
+]
+
+class Cardiology2Dataset(BaseDataset):
+ """Dataset class for the PhysioNet/CinC Challenge 2020 12-lead ECG data.
+
+ The dataset bundles six sub-collections of 12-lead ECG recordings stored as
+ MATLAB '.mat' files with companion '.hea' header files containing
+ SNOMED-CT diagnosis codes, patient sex, and patient age.
+
+ Dataset is available at:
+ https://physionet.org/content/challenge-2020/1.0.2/
+
+ Args:
+ root (str): Root directory of the raw data, e.g.
+ '"/data/physionet.org/files/challenge-2020/1.0.2/training"'.
+ chosen_dataset (List[int]): Binary list of length 6 indicating which
+ sub-datasets to include. Indices correspond to:
+ '["cpsc_2018", "cpsc_2018_extra", "georgia", "ptb", "ptb-xl", "st_petersburg_incart"]'.
+ Default: '[1, 1, 1, 1, 1, 1]' (all six).
+ config_path (Optional[str]): Path to the YAML config file. Defaults to
+ the bundled 'configs/cardiology.yaml'.
+
+ Attributes:
+ classes (List[str]): Union of common SNOMED-CT diagnosis codes across
+ five symptom categories (AR, BBBFB, AD, CD, WA).
+ chosen_dataset (List[int]): The sub-dataset selection mask.
+
+ Examples:
+ >>> from pyhealth.datasets import Cardiology2Dataset
+ >>> dataset = Cardiology2Dataset(
+ ... root="/data/physionet.org/files/challenge-2020/1.0.2/training",
+ ... )
+ >>> dataset.stats()
+ """
+
+ """
+ Classes:
+ Source: https://github.com/physionetchallenges/evaluation-2020/blob/master/dx_mapping_scored.csv
+ """
+ classes: List[str] = [
+ "270492004",
+ "164889003",
+ "164890007",
+ "426627000",
+ "713427006",
+ "713426002",
+ "445118002",
+ "39732003",
+ "164909002",
+ "251146004",
+ "698252002",
+ "10370003",
+ "284470004",
+ "427172004",
+ "164947007",
+ "111975006",
+ "164917005",
+ "47665007",
+ "59118001",
+ "427393009",
+ "426177001",
+ "426783006",
+ "427084000",
+ "63593006",
+ "164934002",
+ "59931005",
+ "17338001",
+ ]
+
+ def __init__(
+ self,
+ root: str,
+ chosen_dataset: List[int] = [1, 1, 1, 1, 1, 1],
+ config_path: Optional[str] = str(
+ Path(__file__).parent / "configs" / "cardiology.yaml"
+ ),
+ **kwargs,
+ ) -> None:
+ if len(chosen_dataset) != 6 or not all(v in (0, 1) for v in chosen_dataset):
+ raise ValueError(
+ "chosen_dataset must be a binary list of length 6, e.g. [1,1,1,1,1,1]"
+ )
+
+ self.chosen_dataset = chosen_dataset
+ self._index_data(root)
+ super().__init__(
+ root=root,
+ tables=["cardiology"],
+ dataset_name="Cardiology",
+ config_path=config_path,
+ **kwargs,
+ )
+
+ @property
+ def default_task(self):
+ """Returns the default multi-label ECG classification task.
+
+ Returns:
+ CardiologyMultilabelClassification: the default task.
+
+ Example::
+ >>> dataset = Cardiology2Dataset(root="...")
+ >>> task = dataset.default_task
+ """
+ from pyhealth.tasks import CardiologyMultilabelClassification
+ return CardiologyMultilabelClassification()
+
+ def _index_data(self, root: str) -> None:
+ """Scans all .hea files and writes a flat metadata CSV.
+
+ For each recording the following fields are extracted from the header:
+
+ - patient_id: "{dataset_idx}_{patient_idx}"
+ - signal_path: absolute path to the .mat file
+ - dx: comma-separated SNOMED-CT diagnosis codes
+ - sex: patient sex string (e.g. "Male")
+ - age: patient age string (e.g. 63)
+        - chosen_dataset: name of the sub-dataset this recording belongs to (e.g. "cpsc_2018")
+
+ The resulting table is written to '{root}/cardiology-metadata-pyhealth.csv'
+
+ Args:
+ root (str): Root directory of the raw data.
+
+ Raises:
+ FileNotFoundError: If 'root' does not exist.
+ """
+ if not os.path.exists(root):
+ raise FileNotFoundError(f"Dataset root does not exist: {root}")
+
+ out_path = os.path.join(root, "cardiology-metadata-pyhealth.csv")
+ if os.path.isfile(out_path):
+ logger.info(f"Found existing metadata index: {out_path}")
+ logger.info(f"Overwriting existing metadata index...")
+
+ active_datasets = [
+ (idx, name)
+ for idx, (name, flag) in enumerate(
+ zip(SUBDATASET_NAMES, self.chosen_dataset)
+ )
+ if flag
+ ]
+
+ rows = []
+ for dataset_idx, name in active_datasets:
+ dataset_dir = os.path.join(root, name)
+ patient_dirs = sorted(
+ d for d in os.listdir(dataset_dir)
+ if os.path.isdir(os.path.join(dataset_dir, d))
+ )
+
+ for patient_idx, patient_dir in enumerate(patient_dirs):
+ patient_root = os.path.join(dataset_dir, patient_dir)
+ pid = f"{dataset_idx}_{patient_idx}"
+
+ for record in self._collect_recordings(patient_root):
+ record["patient_id"] = pid
+ record["chosen_dataset"] = name
+ rows.append(record)
+
+ df = pd.DataFrame(rows)
+ df.to_csv(out_path, index=False)
+ logger.info(
+ f"Wrote metadata index with {len(df)} recordings to {out_path}"
+ )
+
+ def _collect_recordings(self, patient_dir: str) -> List[Dict]:
+ """Collects metadata for all recordings in a patient directory.
+
+ Finds every '.hea' file, checks for a matching '.mat' file,
+ and parses the header to extract diagnosis codes and demographics.
+
+ Args:
+ patient_dir (str): Absolute path to a patient directory.
+
+ Returns:
+ List[Dict]: One dict per valid recording with keys
+ 'signal_path', 'dx', 'sex', and 'age'.
+ """
+ records = []
+ hea_files = [
+ f for f in os.listdir(patient_dir) if f.endswith(".hea")
+ ]
+ for hea_file in hea_files:
+ file_name = hea_file[:-4]
+ mat_path = os.path.join(patient_dir, file_name + ".mat")
+ hea_path = os.path.join(patient_dir, hea_file)
+
+ if not os.path.isfile(mat_path):
+ logger.debug(f"No matching .mat for {hea_path}, skipping")
+ continue
+
+ dx, sex, age = self._parse_header(hea_path)
+ records.append({
+ "signal_path": mat_path,
+ "dx": dx,
+ "sex": sex,
+ "age": age,
+ })
+ return records
+
+ @staticmethod
+ def _parse_header(hea_path: str):
+ """Parses Dx, Sex, and Age from a PhysioNet 2020 .hea header file.
+
+ The last few lines of each '.hea' file follow this format:
+
+ # Age: 63
+ # Sex: Male
+ # Dx: 426783006,164934002
+
+ Args:
+ hea_path (str): Path to the '.hea' file.
+
+ Returns:
+ Tuple[str, str, str]: '(dx, sex, age)' as raw strings.
+ 'dx' is a comma-separated SNOMED-CT code string.
+ """
+ dx = sex = age = ""
+ with open(hea_path, "r") as f:
+ for line in f:
+ line = line.strip()
+ if line.startswith("# Dx:"):
+ dx = line.split(":", 1)[1].strip()
+ elif line.startswith("# Sex:"):
+ sex = line.split(":", 1)[1].strip()
+ elif line.startswith("# Age:"):
+ age = line.split(":", 1)[1].strip()
+
+ return dx, sex, age
diff --git a/pyhealth/datasets/configs/cardiology.yaml b/pyhealth/datasets/configs/cardiology.yaml
new file mode 100644
index 000000000..b70f8813f
--- /dev/null
+++ b/pyhealth/datasets/configs/cardiology.yaml
@@ -0,0 +1,13 @@
+# Author: John Ma (jm119@illinois.edu)
+version: "1.0"
+tables:
+  cardiology:
+    file_path: "cardiology-metadata-pyhealth.csv"
+    patient_id: "patient_id"
+    timestamp: null  # recordings carry no per-event timestamp column
+    attributes:
+      - "signal_path"
+      - "dx"
+      - "sex"
+      - "age"
+      - "chosen_dataset"
diff --git a/pyhealth/tasks/__init__.py b/pyhealth/tasks/__init__.py
index 2f4294a19..eaec67908 100644
--- a/pyhealth/tasks/__init__.py
+++ b/pyhealth/tasks/__init__.py
@@ -9,6 +9,9 @@
cardiology_isCD_fn,
cardiology_isWA_fn,
)
+from .cardiology_multilabel_classification import (
+ CardiologyMultilabelClassification,
+)
from .chestxray14_binary_classification import ChestXray14BinaryClassification
from .chestxray14_multilabel_classification import ChestXray14MultilabelClassification
from .covid19_cxr_classification import COVID19CXRClassification
diff --git a/pyhealth/tasks/cardiology_multilabel_classification.py b/pyhealth/tasks/cardiology_multilabel_classification.py
new file mode 100644
index 000000000..8e75e267f
--- /dev/null
+++ b/pyhealth/tasks/cardiology_multilabel_classification.py
@@ -0,0 +1,142 @@
+"""
+PyHealth task for multi-label ECG classification using the CardiologyDataset.
+
+Author:
+ John Ma (jm119@illinois.edu)
+"""
+
+import logging
+import os
+from typing import Dict, List, Optional
+
+import numpy as np
+from scipy.io import loadmat
+
+from pyhealth.data import Event, Patient
+from pyhealth.tasks import BaseTask
+
+logger = logging.getLogger(__name__)
+
+
+class CardiologyMultilabelClassification(BaseTask):
+ """Multi-label ECG classification task for the CardiologyDataset.
+
+ Each 2.5-second ECG window is labeled with all SNOMED-CT diagnosis codes
+ present in the recording header, spanning five symptom categories:
+ Arrhythmias (AR), Bundle Branch & Fascicular Blocks (BBBFB), Axis
+ Deviations (AD), Conduction Delays (CD), and Wave Abnormalities (WA).
+ The label space is defined by :data:`Cardiology2Dataset.classes` (24 codes).
+
+ The task follows the multi-label benchmark protocol of Nonaka & Seita
+ (2021), evaluated with macro-averaged ROC-AUC over all label classes.
+
+ Attributes:
+ task_name (str): The name of the task.
+ input_schema (Dict[str, str]): The schema for the task input.
+ output_schema (Dict[str, str]): The schema for the task output.
+ epoch_sec (float): Window length in seconds. Default 2.5.
+ shift (float): Sliding window step in seconds. Default 1.25.
+ leads (Optional[List[int]]): List of indices representing which ECG leads to keep.
+ For example, [0] for Lead I only. Defaults to None (keeps all 12 leads).
+
+ Examples:
+ >>> from pyhealth.datasets import Cardiology2Dataset
+ >>> from pyhealth.tasks import CardiologyMultilabelClassification
+ >>> dataset = Cardiology2Dataset(
+ ... root="/data/physionet.org/files/challenge-2020/1.0.2/training",
+ ... )
+ >>> task = CardiologyMultilabelClassification()
+ >>> sample_dataset = dataset.set_task(task)
+ """
+
+ task_name: str = "CardiologyMultilabelClassification"
+ input_schema: Dict[str, str] = {"signal": "tensor"}
+ output_schema: Dict[str, str] = {"labels": "multilabel"}
+
+ def __init__(
+ self,
+ epoch_sec: float = 2.5,
+ shift: float = 1.25,
+ leads: Optional[List[int]] = None,
+ ) -> None:
+ """Initializes the task.
+
+ Args:
+ epoch_sec (float): Length of each sliding window in seconds. Default 2.5.
+ shift (float): Step size of the sliding window in seconds. Default 1.25.
+ leads (Optional[List[int]]): List of indices representing which ECG leads to keep.
+ For example, [0] for Lead I only. Defaults to None (keeps all 12 leads).
+ """
+ self.epoch_sec = epoch_sec
+ self.shift = shift
+ self.leads = leads if leads is not None else list(range(12)) # Default to all 12 leads
+
+ def __call__(self, patient: Patient) -> List[Dict]:
+ """Generates multi-label classification samples for a single patient.
+
+ For each ECG recording event, the raw signal is loaded from the
+ '.mat' file and sliced into overlapping windows of 'epoch_sec'
+ seconds with a step of 'shift' seconds.
+
+ Args:
+ patient (Patient): A Patient object produced by
+ :class:`~pyhealth.datasets.Cardiology2Dataset`. Each event of
+ type 'cardiology' must contain 'signal_path' and 'dx'
+ attributes.
+
+ Returns:
+ List[Dict]: One dict per epoch window, each containing:
+ - 'patient_id' (str): patient identifier.
+ - 'visit_id' (str): stem of the signal filename.
+ - 'signal' (np.ndarray): epoch array of shape
+ '(12, epoch_sec * 500)'.
+ - 'labels' (List[str]): SNOMED-CT codes present in this
+ recording (filtered to 'Cardiology2Dataset.classes').
+ """
+ from pyhealth.datasets import Cardiology2Dataset
+
+ events: List[Event] = patient.get_events(event_type="cardiology")
+ samples = []
+
+ known_codes = set(Cardiology2Dataset.classes)
+ fs = 500
+ epoch_samples = int(fs * self.epoch_sec)
+ shift_samples = int(fs * self.shift)
+
+ # Convert leads to a numpy array for robust advanced indexing
+ lead_indices = np.array(self.leads, dtype=int)
+
+ for event in events:
+ signal_path = event["signal_path"]
+ dx_raw = event["dx"]
+
+ labels: List[str] = [
+ code for code in dx_raw.split(",") if code.strip() in known_codes
+ ]
+
+ try:
+ X: np.ndarray = loadmat(signal_path)["val"]
+ except Exception as exc:
+ logger.warning(f"Failed to load {signal_path}: {exc}")
+ continue
+
+ if X.shape[1] < epoch_samples: # if the signal is too short, skip
+ continue
+
+ visit_id = os.path.splitext(os.path.basename(signal_path))[0]
+ n_windows = (X.shape[1] - epoch_samples) // shift_samples + 1
+
+ for i in range(n_windows):
+ # Apply the spatial ablation (lead slicing)
+ epoch = X[lead_indices, shift_samples * i : shift_samples * i + epoch_samples]
+
+ samples.append(
+ {
+ "patient_id": patient.patient_id,
+ "visit_id": visit_id,
+ "signal": epoch,
+ "labels": labels,
+ }
+ )
+
+ return samples
diff --git a/tests/core/test_cardiology_multilabel.py b/tests/core/test_cardiology_multilabel.py
new file mode 100644
index 000000000..b76264d93
--- /dev/null
+++ b/tests/core/test_cardiology_multilabel.py
@@ -0,0 +1,184 @@
+"""Unit tests for the Cardiology2Dataset and CardiologyMultilabelClassification."""
+
+# TestCardiology2Dataset covers the dataset
+# TestCardiologyMultilabelClassification covers the task
+import csv
+from pathlib import Path
+import tempfile
+import unittest
+from unittest.mock import patch
+
+import numpy as np
+
+from pyhealth.datasets import Cardiology2Dataset
+from pyhealth.tasks import CardiologyMultilabelClassification
+
+
+class TestCardiology2Dataset(unittest.TestCase):
+    def _write_recording(
+        self,
+        patient_dir: Path,
+        record_name: str,
+        dx: str,
+        sex: str = "Male",
+        age: str = "63",
+        signal_length: int = 2500,
+    ) -> None:
+        patient_dir.mkdir(parents=True, exist_ok=True)
+        (patient_dir / f"{record_name}.mat").write_bytes(b"")
+        (patient_dir / f"{record_name}.hea").write_text(
+            "\n".join(
+                [
+                    f"{record_name} 12 500 {signal_length} 16 0 0 0 0",
+                    f"# Age: {age}",
+                    f"# Sex: {sex}",
+                    f"# Dx: {dx}",
+                ]
+            )
+            + "\n"
+        )
+
+    def test_invalid_chosen_dataset_raises(self):
+        with tempfile.TemporaryDirectory() as tmp:
+            with self.assertRaises(ValueError):
+                Cardiology2Dataset(root=tmp, chosen_dataset=[1, 0, 1])  # needs 6 flags
+
+    def test_dataset_indexes_metadata_and_default_task(self):
+        with tempfile.TemporaryDirectory() as tmp:
+            root = Path(tmp)
+            self._write_recording(
+                root / "cpsc_2018" / "patient_a",
+                "A0001",
+                dx="164889003,427172004",
+                sex="Female",
+                age="54",
+            )
+            self._write_recording(
+                root / "cpsc_2018" / "patient_a",
+                "A0002",
+                dx="426627000",
+                sex="Female",
+                age="54",
+                signal_length=1500,
+            )
+            self._write_recording(
+                root / "georgia" / "patient_b",
+                "E0001",
+                dx="713427006",
+                sex="Male",
+                age="61",
+            )
+
+            cache_dir = root / "cache"
+            dataset = Cardiology2Dataset(
+                root=str(root),
+                chosen_dataset=[1, 0, 1, 0, 0, 0],  # enable cpsc_2018 and georgia only
+                cache_dir=str(cache_dir),
+            )
+
+            metadata_path = root / "cardiology-metadata-pyhealth.csv"
+            self.assertTrue(metadata_path.exists())
+
+            with metadata_path.open(newline="") as f:
+                metadata = list(csv.DictReader(f))
+            self.assertEqual(len(metadata), 3)  # one row per recording
+            self.assertCountEqual(
+                [row["chosen_dataset"] for row in metadata],
+                ["cpsc_2018", "cpsc_2018", "georgia"],
+            )
+            self.assertIn("signal_path", metadata[0])
+            self.assertIn("dx", metadata[0])
+            self.assertIn("sex", metadata[0])
+            self.assertIn("age", metadata[0])
+
+            self.assertEqual(len(dataset.unique_patient_ids), 2)
+            patient = dataset.get_patient("0_0")
+            events = patient.get_events(event_type="cardiology")
+
+            self.assertEqual(len(events), 2)
+            self.assertEqual(events[0]["patient_id"], "0_0")
+            self.assertEqual(events[0]["sex"], "Female")
+            self.assertEqual(events[0]["age"], "54")
+            self.assertEqual(events[0]["chosen_dataset"], "cpsc_2018")
+            self.assertTrue(str(events[0]["signal_path"]).endswith(".mat"))
+
+            self.assertIsInstance(
+                dataset.default_task, CardiologyMultilabelClassification
+            )
+
+
+class TestCardiologyMultilabelClassification(unittest.TestCase):
+    def _write_recording(
+        self,
+        patient_dir: Path,
+        record_name: str,
+        dx: str,
+        signal_length: int,
+    ) -> None:
+        patient_dir.mkdir(parents=True, exist_ok=True)
+        (patient_dir / f"{record_name}.mat").write_bytes(b"")
+        (patient_dir / f"{record_name}.hea").write_text(
+            "\n".join(
+                [
+                    f"{record_name} 12 500 {signal_length} 16 0 0 0 0",
+                    "# Age: 63",
+                    "# Sex: Male",
+                    f"# Dx: {dx}",
+                ]
+            )
+            + "\n"
+        )
+
+    def test_task_generates_windowed_samples_and_filters_labels(self):
+        with tempfile.TemporaryDirectory() as tmp:
+            root = Path(tmp)
+            patient_dir = root / "cpsc_2018" / "patient_a"
+            self._write_recording(
+                patient_dir,
+                "A0001",
+                dx="164889003,427172004,999999999",  # 999999999 is not a known class
+                signal_length=2500,
+            )
+            self._write_recording(
+                patient_dir,
+                "A0002",
+                dx="164889003",
+                signal_length=1000,  # shorter than one 1250-sample window
+            )
+
+            dataset = Cardiology2Dataset(
+                root=str(root),
+                chosen_dataset=[1, 0, 0, 0, 0, 0],
+                cache_dir=str(root / "cache"),
+            )
+            patient = dataset.get_patient("0_0")
+            task = CardiologyMultilabelClassification(
+                epoch_sec=2.5,
+                shift=1.25,
+                leads=[0, 2, 4],  # spatial ablation: keep 3 of the 12 leads
+            )
+
+            fake_signal = np.arange(12 * 2500, dtype=np.float32).reshape(12, 2500)
+            with patch(
+                "pyhealth.tasks.cardiology_multilabel_classification.loadmat",
+                side_effect=[{"val": fake_signal}, {"val": fake_signal[:, :1000]}],
+            ):
+                samples = task(patient)
+
+            self.assertEqual(len(samples), 3)  # (2500 - 1250) // 625 + 1; A0002 skipped
+            for sample in samples:
+                self.assertEqual(sample["patient_id"], "0_0")
+                self.assertEqual(sample["visit_id"], "A0001")
+                self.assertEqual(sample["signal"].shape, (3, 1250))
+                self.assertEqual(sample["labels"], ["164889003", "427172004"])
+
+    def test_task_schema_attributes(self):
+        task = CardiologyMultilabelClassification(leads=[0])
+        self.assertEqual(task.task_name, "CardiologyMultilabelClassification")
+        self.assertEqual(task.input_schema, {"signal": "tensor"})
+        self.assertEqual(task.output_schema, {"labels": "multilabel"})
+        self.assertEqual(task.leads, [0])
+
+
+if __name__ == "__main__":  # allow running this test module directly
+    unittest.main()