From dc3c714ae201e567b2a8bb68147a798dc668b58c Mon Sep 17 00:00:00 2001 From: harshagl2002 Date: Sun, 12 Apr 2026 23:31:47 -0400 Subject: [PATCH] add notebook --- ...yhealth_mortality_prediction_harsha4.ipynb | 668 ++++++++++++++++++ 1 file changed, 668 insertions(+) create mode 100644 examples/pyhealth_mortality_prediction_harsha4.ipynb diff --git a/examples/pyhealth_mortality_prediction_harsha4.ipynb b/examples/pyhealth_mortality_prediction_harsha4.ipynb new file mode 100644 index 000000000..c7edf2420 --- /dev/null +++ b/examples/pyhealth_mortality_prediction_harsha4.ipynb @@ -0,0 +1,668 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "e04edb50", + "metadata": {}, + "source": [ + "# PyHealth In-Hospital Mortality Prediction\n", + "\n", + "Predict in-hospital mortality using EHR time series data from MIMIC-IV.\n", + "\n", + "**Task:** Predict whether a patient will die during their ICU stay based on clinical measurements (vitals, labs) from the first 48 hours.\n", + "\n", + "**Dataset:** MIMIC-IV (using same data as MedMod project, but EHR-only)\n", + "\n", + "**Relation to MedMod:** This is the unimodal EHR baseline from our reproduction - implementing it cleanly with PyHealth.\n", + "\n", + "**Author:** CS598 Deep Learning for Healthcare" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "2cb892c7", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "PyTorch version: 2.5.1\n", + "CUDA available: True\n" + ] + } + ], + "source": [ + "# Import required packages\n", + "import torch\n", + "import numpy as np\n", + "import pandas as pd\n", + "from pathlib import Path\n", + "import warnings\n", + "warnings.filterwarnings('ignore')\n", + "\n", + "print(f\"PyTorch version: {torch.__version__}\")\n", + "print(f\"CUDA available: {torch.cuda.is_available()}\")" + ] + }, + { + "cell_type": "markdown", + "id": "4ee1dc2f", + "metadata": {}, + "source": [ + "## 1. Load MIMIC-IV Data\n", + "\n", + "We'll use the preprocessed MIMIC-IV in-hospital mortality data from the mimic4extract pipeline (same as MedMod)." + ] + }, + { + "cell_type": "code", + "execution_count": 26, + "id": "1af31429", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "✓ Found data at C:\\Users\\Rohan Suri\\MedMod\\MedMod-main\\MedMod-main\\mimic4extract\\data\\in-hospital-mortality\n", + " - train/ folder: True\n", + " - test/ folder: True\n" + ] + } + ], + "source": [ + "# Path to MIMIC-IV extracted data\n", + "data_root = Path(r\"C:\\Users\\Rohan Suri\\MedMod\\MedMod-main\\MedMod-main\\mimic4extract\\data\\in-hospital-mortality\")\n", + "\n", + "# Check if data exists\n", + "if data_root.exists():\n", + " print(f\"✓ Found data at {data_root}\")\n", + " print(f\" - train/ folder: {(data_root / 'train').exists()}\")\n", + " print(f\" - test/ folder: {(data_root / 'test').exists()}\")\n", + "else:\n", + " print(f\"✗ Data not found at {data_root}\")\n", + " print(\"Please ensure mimic4extract preprocessing is complete.\")" + ] + }, + { + "cell_type": "markdown", + "id": "f2d3643d", + "metadata": {}, + "source": [ + "## 2. Create PyHealth-Compatible Dataset\n", + "\n", + "Load the time series data and convert to PyHealth format." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "7b502d8c", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Loading train data from 29171 patients...\n" + ] + } + ], + "source": [ + "#from pyhealth.datasets import SampleDataset\n", + "from pyhealth.data import Patient, Visit\n", + "\n", + "def load_mimic4_mortality_samples(split='train', max_features=None):\n", + " \"\"\"Load MIMIC-IV mortality data and convert to PyHealth samples.\"\"\"\n", + " \n", + " # Read listfile\n", + " listfile = pd.read_csv(data_root / f\"{split}_listfile.csv\")\n", + " \n", + " samples = []\n", + " data_dir = data_root / split\n", + " all_feature_dims = []\n", + " \n", + " print(f\"Loading {split} data from {len(listfile)} patients...\")\n", + " \n", + " # First pass: determine max feature dimension\n", + " if max_features is None:\n", + " for idx, row in listfile.iterrows():\n", + " ts_file = data_dir / row['stay']\n", + " if ts_file.exists():\n", + " ts_data = pd.read_csv(ts_file)\n", + " numeric_cols = ts_data.select_dtypes(include=[np.number]).columns\n", + " all_feature_dims.append(len(numeric_cols))\n", + " max_features = max(all_feature_dims) if all_feature_dims else 17\n", + " print(f\" Max feature dimension: {max_features}\")\n", + " \n", + " # Second pass: load data with consistent dimensions\n", + " for idx, row in listfile.iterrows():\n", + " if idx % 1000 == 0:\n", + " print(f\" Processed {idx}/{len(listfile)} samples...\")\n", + " \n", + " # Load time series\n", + " ts_file = data_dir / row['stay']\n", + " if not ts_file.exists():\n", + " continue\n", + " \n", + " # Read CSV and ensure numeric data only\n", + " ts_data = pd.read_csv(ts_file)\n", + " \n", + " # Select only numeric columns\n", + " numeric_cols = ts_data.select_dtypes(include=[np.number]).columns\n", + " ts_numeric = ts_data[numeric_cols].apply(pd.to_numeric, errors='coerce')\n", + " \n", + " # Fill NaN with 0\n", + " ts_numeric = ts_numeric.fillna(0)\n", + " \n", + " # Pad or trim to max_features dimension\n", + " ts_array = ts_numeric.values.astype(np.float32)\n", + " if ts_array.shape[1] < max_features:\n", + " # Pad with zeros\n", + " padding = np.zeros((ts_array.shape[0], max_features - ts_array.shape[1]), dtype=np.float32)\n", + " ts_array = np.concatenate([ts_array, padding], axis=1)\n", + " elif ts_array.shape[1] > max_features:\n", + " # Trim\n", + " ts_array = ts_array[:, :max_features]\n", + " \n", + " # Extract features\n", + " features = {\n", + " 'vitals_mean': ts_array.mean(axis=0).tolist(),\n", + " 'vitals_std': ts_array.std(axis=0).tolist(),\n", + " 'vitals_min': ts_array.min(axis=0).tolist(),\n", + " 'vitals_max': ts_array.max(axis=0).tolist(),\n", + " }\n", + " \n", + " # Create sample\n", + " sample = {\n", + " 'patient_id': str(row['stay'].split('_')[0]),\n", + " 'visit_id': str(row['stay']),\n", + " 'features': features,\n", + " 'label': int(row['y_true']),\n", + " 'timeseries': ts_array,\n", + " 'n_features': max_features\n", + " }\n", + " \n", + " samples.append(sample)\n", + " \n", + " print(f\"✓ Loaded {len(samples)} samples from {split} split (feature_dim={max_features})\")\n", + " return samples, max_features\n", + "\n", + "# Load train and test splits\n", + "all_train_samples, max_features = load_mimic4_mortality_samples('train')\n", + "test_samples, _ = load_mimic4_mortality_samples('test', max_features=max_features)\n", + "\n", + "# Split training data into train (80%) and validation (20%)\n", + "from sklearn.model_selection import train_test_split\n", + "\n", + "train_samples, val_samples = train_test_split(\n", + " all_train_samples, \n", + " test_size=0.2, \n", + " random_state=42,\n", + " stratify=[s['label'] for s in all_train_samples] # Stratify by label to maintain class balance\n", + ")\n", + "\n", + "print(f\"\\nDataset summary:\")\n", + "print(f\" Train: {len(train_samples)} samples\")\n", + "print(f\" Val: {len(val_samples)} samples\")\n", + "print(f\" Test: {len(test_samples)} samples\")\n", + "print(f\" Feature dimension: {max_features}\")\n", + "\n", + "# Check class balance\n", + "train_mortality_rate = np.mean([s['label'] for s in train_samples])\n", + "val_mortality_rate = np.mean([s['label'] for s in val_samples])\n", + "test_mortality_rate = np.mean([s['label'] for s in test_samples])\n", + "print(f\"\\nMortality rates:\")\n", + "print(f\" Train: {train_mortality_rate:.3f}\")\n", + "print(f\" Val: {val_mortality_rate:.3f}\")\n", + "print(f\" Test: {test_mortality_rate:.3f}\")" + ] + }, + { + "cell_type": "markdown", + "id": "1369ea8a", + "metadata": {}, + "source": [ + "## 3. Create PyHealth Sample Dataset\n", + "\n", + "Wrap our samples in PyHealth's SampleDataset format." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5a110509", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "✓ Created dataloaders\n", + " Batch size: 64\n", + " Train batches: 456\n", + " Val batches: 0\n", + " Test batches: 83\n" + ] + } + ], + "source": [ + "from torch.utils.data import Dataset, DataLoader\n", + "from torch.nn.utils.rnn import pad_sequence\n", + "\n", + "class MortalityDataset(Dataset):\n", + " \"\"\"Simple dataset wrapper for mortality prediction.\"\"\"\n", + " \n", + " def __init__(self, samples, use_timeseries=True):\n", + " self.samples = samples\n", + " self.use_timeseries = use_timeseries\n", + " \n", + " def __len__(self):\n", + " return len(self.samples)\n", + " \n", + " def __getitem__(self, idx):\n", + " sample = self.samples[idx]\n", + " \n", + " if self.use_timeseries:\n", + " # Return time series for RNN/LSTM\n", + " # Convert to float array and handle NaNs\n", + " ts = np.array(sample['timeseries'], dtype=np.float32)\n", + " ts = np.nan_to_num(ts, nan=0.0) # Replace NaN with 0\n", + " x = torch.from_numpy(ts)\n", + " else:\n", + " # Return aggregated features for MLP\n", + " features = sample['features']\n", + " feat_array = np.array(\n", + " features['vitals_mean'] + \n", + " features['vitals_std'] + \n", + " features['vitals_min'] + \n", + " features['vitals_max'],\n", + " dtype=np.float32\n", + " )\n", + " feat_array = np.nan_to_num(feat_array, nan=0.0)\n", + " x = torch.from_numpy(feat_array)\n", + " \n", + " y = torch.FloatTensor([sample['label']])\n", + " \n", + " return x, y\n", + "\n", + "def collate_fn(batch):\n", + " \"\"\"Custom collate function to handle variable-length sequences.\"\"\"\n", + " sequences, labels = zip(*batch)\n", + " \n", + " # Pad sequences to same length\n", + " padded_sequences = pad_sequence(sequences, batch_first=True, padding_value=0.0)\n", + " labels = torch.stack(labels)\n", + " \n", + " return padded_sequences, labels\n", + "\n", + "# Create datasets\n", + "train_dataset = MortalityDataset(train_samples, use_timeseries=True)\n", + "val_dataset = MortalityDataset(val_samples, use_timeseries=True)\n", + "test_dataset = MortalityDataset(test_samples, use_timeseries=True)\n", + "\n", + "# Create dataloaders with custom collate function\n", + "train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True, num_workers=0, collate_fn=collate_fn)\n", + "val_loader = DataLoader(val_dataset, batch_size=64, shuffle=False, num_workers=0, collate_fn=collate_fn)\n", + "test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False, num_workers=0, collate_fn=collate_fn)\n", + "\n", + "print(\"✓ Created dataloaders\")\n", + "print(f\" Batch size: 64\")\n", + "print(f\" Train batches: {len(train_loader)}\")\n", + "print(f\" Val batches: {len(val_loader)}\")\n", + "print(f\" Test batches: {len(test_loader)}\")" + ] + }, + { + "cell_type": "markdown", + "id": "c04e51cd", + "metadata": {}, + "source": [ + "## 4. Build LSTM Model\n", + "\n", + "Simple LSTM model for mortality prediction (matches MedMod unimodal baseline)." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "cb7d7bf3", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "LSTMMortalityModel(\n", + " (lstm): LSTM(18, 256, num_layers=2, batch_first=True, dropout=0.3, bidirectional=True)\n", + " (fc): Linear(in_features=512, out_features=1, bias=True)\n", + " (sigmoid): Sigmoid()\n", + ")\n", + "\n", + "Device: cuda\n", + "Input dimension: 18\n", + "Parameters: 2,142,721\n" + ] + } + ], + "source": [ + "import torch.nn as nn\n", + "\n", + "class LSTMMortalityModel(nn.Module):\n", + " \"\"\"LSTM model for mortality prediction.\"\"\"\n", + " \n", + " def __init__(self, input_dim=76, hidden_dim=256, num_layers=2, dropout=0.3):\n", + " super().__init__()\n", + " \n", + " self.lstm = nn.LSTM(\n", + " input_dim, \n", + " hidden_dim, \n", + " num_layers,\n", + " batch_first=True,\n", + " dropout=dropout,\n", + " bidirectional=True\n", + " )\n", + " \n", + " # Bidirectional doubles the hidden dimension\n", + " self.fc = nn.Linear(hidden_dim * 2, 1)\n", + " self.sigmoid = nn.Sigmoid()\n", + " \n", + " def forward(self, x):\n", + " # x shape: (batch, seq_len, features)\n", + " lstm_out, (hidden, cell) = self.lstm(x)\n", + " \n", + " # Use last hidden state\n", + " last_hidden = lstm_out[:, -1, :]\n", + " \n", + " # Predict\n", + " logits = self.fc(last_hidden)\n", + " probs = self.sigmoid(logits)\n", + " \n", + " return probs\n", + "\n", + "# Initialize model with correct input dimension\n", + "model = LSTMMortalityModel(\n", + " input_dim=max_features, # Use actual feature dimension from data\n", + " hidden_dim=256,\n", + " num_layers=2,\n", + " dropout=0.3\n", + ")\n", + "\n", + "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n", + "model = model.to(device)\n", + "\n", + "print(model)\n", + "print(f\"\\nDevice: {device}\")\n", + "print(f\"Input dimension: {max_features}\")\n", + "print(f\"Parameters: {sum(p.numel() for p in model.parameters()):,}\")" + ] + }, + { + "cell_type": "markdown", + "id": "87819fe3", + "metadata": {}, + "source": [ + "## 5. Train Model" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "298862eb", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Starting training...\n", + "------------------------------------------------------------\n" + ] + }, + { + "ename": "ValueError", + "evalue": "Found array with 0 sample(s) (shape=(0,)) while a minimum of 1 is required.", + "output_type": "error", + "traceback": [ + "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[1;31mValueError\u001b[0m Traceback (most recent call last)", + "Cell \u001b[1;32mIn[25], line 65\u001b[0m\n\u001b[0;32m 63\u001b[0m \u001b[38;5;66;03m# Evaluate\u001b[39;00m\n\u001b[0;32m 64\u001b[0m train_metrics \u001b[38;5;241m=\u001b[39m evaluate(model, train_loader, device)\n\u001b[1;32m---> 65\u001b[0m val_metrics \u001b[38;5;241m=\u001b[39m \u001b[43mevaluate\u001b[49m\u001b[43m(\u001b[49m\u001b[43mmodel\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mval_loader\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdevice\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m 67\u001b[0m \u001b[38;5;28mprint\u001b[39m(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mEpoch \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mepoch\u001b[38;5;241m+\u001b[39m\u001b[38;5;241m1\u001b[39m\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m/\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mnum_epochs\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m\"\u001b[39m)\n\u001b[0;32m 68\u001b[0m \u001b[38;5;28mprint\u001b[39m(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m Train - Loss: \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mtrain_metrics[\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mloss\u001b[39m\u001b[38;5;124m'\u001b[39m]\u001b[38;5;132;01m:\u001b[39;00m\u001b[38;5;124m.4f\u001b[39m\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m, AUROC: \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mtrain_metrics[\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mauroc\u001b[39m\u001b[38;5;124m'\u001b[39m]\u001b[38;5;132;01m:\u001b[39;00m\u001b[38;5;124m.4f\u001b[39m\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m, AUPRC: \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mtrain_metrics[\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mauprc\u001b[39m\u001b[38;5;124m'\u001b[39m]\u001b[38;5;132;01m:\u001b[39;00m\u001b[38;5;124m.4f\u001b[39m\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m\"\u001b[39m)\n", + "Cell \u001b[1;32mIn[25], line 29\u001b[0m, in \u001b[0;36mevaluate\u001b[1;34m(model, dataloader, device)\u001b[0m\n\u001b[0;32m 26\u001b[0m all_probs \u001b[38;5;241m=\u001b[39m np\u001b[38;5;241m.\u001b[39marray(all_probs)\u001b[38;5;241m.\u001b[39mflatten()\n\u001b[0;32m 27\u001b[0m all_labels \u001b[38;5;241m=\u001b[39m np\u001b[38;5;241m.\u001b[39marray(all_labels)\u001b[38;5;241m.\u001b[39mflatten()\n\u001b[1;32m---> 29\u001b[0m auroc \u001b[38;5;241m=\u001b[39m \u001b[43mroc_auc_score\u001b[49m\u001b[43m(\u001b[49m\u001b[43mall_labels\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mall_probs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m 30\u001b[0m auprc \u001b[38;5;241m=\u001b[39m average_precision_score(all_labels, all_probs)\n\u001b[0;32m 31\u001b[0m accuracy \u001b[38;5;241m=\u001b[39m accuracy_score(all_labels, (all_probs \u001b[38;5;241m>\u001b[39m \u001b[38;5;241m0.5\u001b[39m)\u001b[38;5;241m.\u001b[39mastype(\u001b[38;5;28mint\u001b[39m))\n", + "File \u001b[1;32mc:\\Users\\Rohan Suri\\A\\envs\\medmod\\lib\\site-packages\\sklearn\\utils\\_param_validation.py:218\u001b[0m, in \u001b[0;36mvalidate_params..decorator..wrapper\u001b[1;34m(*args, **kwargs)\u001b[0m\n\u001b[0;32m 212\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[0;32m 213\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m config_context(\n\u001b[0;32m 214\u001b[0m skip_parameter_validation\u001b[38;5;241m=\u001b[39m(\n\u001b[0;32m 215\u001b[0m prefer_skip_nested_validation \u001b[38;5;129;01mor\u001b[39;00m global_skip_validation\n\u001b[0;32m 216\u001b[0m )\n\u001b[0;32m 217\u001b[0m ):\n\u001b[1;32m--> 218\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m func(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)\n\u001b[0;32m 219\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m InvalidParameterError \u001b[38;5;28;01mas\u001b[39;00m e:\n\u001b[0;32m 220\u001b[0m \u001b[38;5;66;03m# When the function is just a wrapper around an estimator, we allow\u001b[39;00m\n\u001b[0;32m 221\u001b[0m \u001b[38;5;66;03m# the function to delegate validation to the estimator, but we replace\u001b[39;00m\n\u001b[0;32m 222\u001b[0m \u001b[38;5;66;03m# the name of the estimator by the name of the function in the error\u001b[39;00m\n\u001b[0;32m 223\u001b[0m \u001b[38;5;66;03m# message to avoid confusion.\u001b[39;00m\n\u001b[0;32m 224\u001b[0m msg \u001b[38;5;241m=\u001b[39m re\u001b[38;5;241m.\u001b[39msub(\n\u001b[0;32m 225\u001b[0m \u001b[38;5;124mr\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mparameter of \u001b[39m\u001b[38;5;124m\\\u001b[39m\u001b[38;5;124mw+ must be\u001b[39m\u001b[38;5;124m\"\u001b[39m,\n\u001b[0;32m 226\u001b[0m \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mparameter of \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mfunc\u001b[38;5;241m.\u001b[39m\u001b[38;5;18m__qualname__\u001b[39m\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m must be\u001b[39m\u001b[38;5;124m\"\u001b[39m,\n\u001b[0;32m 227\u001b[0m \u001b[38;5;28mstr\u001b[39m(e),\n\u001b[0;32m 228\u001b[0m )\n", + "File \u001b[1;32mc:\\Users\\Rohan Suri\\A\\envs\\medmod\\lib\\site-packages\\sklearn\\metrics\\_ranking.py:664\u001b[0m, in \u001b[0;36mroc_auc_score\u001b[1;34m(y_true, y_score, average, sample_weight, max_fpr, multi_class, labels)\u001b[0m\n\u001b[0;32m 476\u001b[0m \u001b[38;5;250m\u001b[39m\u001b[38;5;124;03m\"\"\"Compute Area Under the Receiver Operating Characteristic Curve (ROC AUC) \\\u001b[39;00m\n\u001b[0;32m 477\u001b[0m \u001b[38;5;124;03mfrom prediction scores.\u001b[39;00m\n\u001b[0;32m 478\u001b[0m \n\u001b[1;32m (...)\u001b[0m\n\u001b[0;32m 660\u001b[0m \u001b[38;5;124;03marray([0.82, 0.847, 0.93, 0.872, 0.944])\u001b[39;00m\n\u001b[0;32m 661\u001b[0m \u001b[38;5;124;03m\"\"\"\u001b[39;00m\n\u001b[0;32m 663\u001b[0m y_type \u001b[38;5;241m=\u001b[39m type_of_target(y_true, input_name\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124my_true\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[1;32m--> 664\u001b[0m y_true \u001b[38;5;241m=\u001b[39m \u001b[43mcheck_array\u001b[49m\u001b[43m(\u001b[49m\u001b[43my_true\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mensure_2d\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mFalse\u001b[39;49;00m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdtype\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mNone\u001b[39;49;00m\u001b[43m)\u001b[49m\n\u001b[0;32m 665\u001b[0m y_score \u001b[38;5;241m=\u001b[39m check_array(y_score, ensure_2d\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mFalse\u001b[39;00m)\n\u001b[0;32m 667\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m y_type \u001b[38;5;241m==\u001b[39m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mmulticlass\u001b[39m\u001b[38;5;124m\"\u001b[39m \u001b[38;5;129;01mor\u001b[39;00m (\n\u001b[0;32m 668\u001b[0m y_type \u001b[38;5;241m==\u001b[39m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mbinary\u001b[39m\u001b[38;5;124m\"\u001b[39m \u001b[38;5;129;01mand\u001b[39;00m y_score\u001b[38;5;241m.\u001b[39mndim \u001b[38;5;241m==\u001b[39m \u001b[38;5;241m2\u001b[39m \u001b[38;5;129;01mand\u001b[39;00m y_score\u001b[38;5;241m.\u001b[39mshape[\u001b[38;5;241m1\u001b[39m] \u001b[38;5;241m>\u001b[39m \u001b[38;5;241m2\u001b[39m\n\u001b[0;32m 669\u001b[0m ):\n\u001b[0;32m 670\u001b[0m \u001b[38;5;66;03m# do not support partial ROC computation for multiclass\u001b[39;00m\n", + "File \u001b[1;32mc:\\Users\\Rohan Suri\\A\\envs\\medmod\\lib\\site-packages\\sklearn\\utils\\validation.py:1128\u001b[0m, in \u001b[0;36mcheck_array\u001b[1;34m(array, accept_sparse, accept_large_sparse, dtype, order, copy, force_writeable, force_all_finite, ensure_all_finite, ensure_non_negative, ensure_2d, allow_nd, ensure_min_samples, ensure_min_features, estimator, input_name)\u001b[0m\n\u001b[0;32m 1126\u001b[0m n_samples \u001b[38;5;241m=\u001b[39m _num_samples(array)\n\u001b[0;32m 1127\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m n_samples \u001b[38;5;241m<\u001b[39m ensure_min_samples:\n\u001b[1;32m-> 1128\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mValueError\u001b[39;00m(\n\u001b[0;32m 1129\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mFound array with \u001b[39m\u001b[38;5;132;01m%d\u001b[39;00m\u001b[38;5;124m sample(s) (shape=\u001b[39m\u001b[38;5;132;01m%s\u001b[39;00m\u001b[38;5;124m) while a\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[0;32m 1130\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m minimum of \u001b[39m\u001b[38;5;132;01m%d\u001b[39;00m\u001b[38;5;124m is required\u001b[39m\u001b[38;5;132;01m%s\u001b[39;00m\u001b[38;5;124m.\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[0;32m 1131\u001b[0m \u001b[38;5;241m%\u001b[39m (n_samples, array\u001b[38;5;241m.\u001b[39mshape, ensure_min_samples, context)\n\u001b[0;32m 1132\u001b[0m )\n\u001b[0;32m 1134\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m ensure_min_features \u001b[38;5;241m>\u001b[39m \u001b[38;5;241m0\u001b[39m \u001b[38;5;129;01mand\u001b[39;00m array\u001b[38;5;241m.\u001b[39mndim \u001b[38;5;241m==\u001b[39m \u001b[38;5;241m2\u001b[39m:\n\u001b[0;32m 1135\u001b[0m n_features \u001b[38;5;241m=\u001b[39m array\u001b[38;5;241m.\u001b[39mshape[\u001b[38;5;241m1\u001b[39m]\n", + "\u001b[1;31mValueError\u001b[0m: Found array with 0 sample(s) (shape=(0,)) while a minimum of 1 is required." + ] + } + ], + "source": [ + "from sklearn.metrics import roc_auc_score, average_precision_score, accuracy_score\n", + "import torch.optim as optim\n", + "\n", + "# Training setup\n", + "criterion = nn.BCELoss()\n", + "optimizer = optim.Adam(model.parameters(), lr=0.001, weight_decay=0.0001)\n", + "\n", + "def evaluate(model, dataloader, device):\n", + " \"\"\"Evaluate model and return metrics.\"\"\"\n", + " model.eval()\n", + " all_probs = []\n", + " all_labels = []\n", + " total_loss = 0\n", + " \n", + " with torch.no_grad():\n", + " for x, y in dataloader:\n", + " x, y = x.to(device), y.to(device)\n", + " \n", + " probs = model(x)\n", + " loss = criterion(probs, y)\n", + " \n", + " all_probs.extend(probs.cpu().numpy())\n", + " all_labels.extend(y.cpu().numpy())\n", + " total_loss += loss.item()\n", + " \n", + " all_probs = np.array(all_probs).flatten()\n", + " all_labels = np.array(all_labels).flatten()\n", + " \n", + " # Check if we have any samples\n", + " if len(all_probs) == 0:\n", + " print(f\"WARNING: No samples in dataloader!\")\n", + " return {\n", + " 'loss': 0.0,\n", + " 'auroc': 0.0,\n", + " 'auprc': 0.0,\n", + " 'accuracy': 0.0\n", + " }\n", + " \n", + " # Check if we have both classes\n", + " if len(np.unique(all_labels)) < 2:\n", + " print(f\"WARNING: Only one class present in labels. Cannot compute AUROC/AUPRC.\")\n", + " accuracy = accuracy_score(all_labels, (all_probs > 0.5).astype(int))\n", + " avg_loss = total_loss / len(dataloader)\n", + " return {\n", + " 'loss': avg_loss,\n", + " 'auroc': 0.0,\n", + " 'auprc': 0.0,\n", + " 'accuracy': accuracy\n", + " }\n", + " \n", + " auroc = roc_auc_score(all_labels, all_probs)\n", + " auprc = average_precision_score(all_labels, all_probs)\n", + " accuracy = accuracy_score(all_labels, (all_probs > 0.5).astype(int))\n", + " avg_loss = total_loss / len(dataloader)\n", + " \n", + " return {\n", + " 'loss': avg_loss,\n", + " 'auroc': auroc,\n", + " 'auprc': auprc,\n", + " 'accuracy': accuracy\n", + " }\n", + "\n", + "# Training loop\n", + "num_epochs = 5 # Reduced for quicker testing, 50 is better for final runs\n", + "best_val_auroc = 0\n", + "\n", + "print(\"Starting training...\")\n", + "print(\"-\" * 60)\n", + "\n", + "for epoch in range(num_epochs):\n", + " model.train()\n", + " train_loss = 0\n", + " \n", + " for batch_idx, (x, y) in enumerate(train_loader):\n", + " x, y = x.to(device), y.to(device)\n", + " \n", + " optimizer.zero_grad()\n", + " probs = model(x)\n", + " loss = criterion(probs, y)\n", + " loss.backward()\n", + " optimizer.step()\n", + " \n", + " train_loss += loss.item()\n", + " \n", + " # Evaluate\n", + " train_metrics = evaluate(model, train_loader, device)\n", + " val_metrics = evaluate(model, val_loader, device)\n", + " \n", + " print(f\"Epoch {epoch+1}/{num_epochs}\")\n", + " print(f\" Train - Loss: {train_metrics['loss']:.4f}, AUROC: {train_metrics['auroc']:.4f}, AUPRC: {train_metrics['auprc']:.4f}\")\n", + " print(f\" Val - Loss: {val_metrics['loss']:.4f}, AUROC: {val_metrics['auroc']:.4f}, AUPRC: {val_metrics['auprc']:.4f}\")\n", + " \n", + " # Save best model\n", + " if val_metrics['auroc'] > best_val_auroc:\n", + " best_val_auroc = val_metrics['auroc']\n", + " torch.save(model.state_dict(), 'best_mortality_model.pt')\n", + " print(f\" ✓ Saved best model (AUROC: {best_val_auroc:.4f})\")\n", + "\n", + "print(\"\\n✓ Training complete!\")" + ] + }, + { + "cell_type": "markdown", + "id": "3abb8fcb", + "metadata": {}, + "source": [ + "## 6. Evaluate on Test Set" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "95832c49", + "metadata": {}, + "outputs": [], + "source": [ + "# Load best model\n", + "model.load_state_dict(torch.load('best_mortality_model.pt'))\n", + "\n", + "# Evaluate on test set\n", + "test_metrics = evaluate(model, test_loader, device)\n", + "\n", + "print(\"=\" * 60)\n", + "print(\"FINAL TEST RESULTS\")\n", + "print(\"=\" * 60)\n", + "print(f\"Test AUROC: {test_metrics['auroc']:.4f}\")\n", + "print(f\"Test AUPRC: {test_metrics['auprc']:.4f}\")\n", + "print(f\"Test Accuracy: {test_metrics['accuracy']:.4f}\")\n", + "print(f\"Test Loss: {test_metrics['loss']:.4f}\")\n", + "print(\"=\" * 60)" + ] + }, + { + "cell_type": "markdown", + "id": "19ad4a5f", + "metadata": {}, + "source": [ + "## 7. Detailed Analysis" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "61751552", + "metadata": {}, + "outputs": [], + "source": [ + "from sklearn.metrics import confusion_matrix, classification_report\n", + "\n", + "# Get predictions\n", + "model.eval()\n", + "all_probs = []\n", + "all_labels = []\n", + "\n", + "with torch.no_grad():\n", + " for x, y in test_loader:\n", + " x = x.to(device)\n", + " probs = model(x)\n", + " all_probs.extend(probs.cpu().numpy())\n", + " all_labels.extend(y.numpy())\n", + "\n", + "all_probs = np.array(all_probs).flatten()\n", + "all_labels = np.array(all_labels).flatten()\n", + "all_preds = (all_probs > 0.5).astype(int)\n", + "\n", + "# Confusion matrix\n", + "cm = confusion_matrix(all_labels, all_preds)\n", + "print(\"Confusion Matrix:\")\n", + "print(cm)\n", + "print()\n", + "\n", + "# Classification report\n", + "print(\"Classification Report:\")\n", + "print(classification_report(all_labels, all_preds, target_names=['Survive', 'Death']))" + ] + }, + { + "cell_type": "markdown", + "id": "a0bc54e1", + "metadata": {}, + "source": [ + "## Summary\n", + "\n", + "This notebook demonstrated:\n", + "\n", + "1. **Loading MIMIC-IV mortality data** - Same preprocessing as MedMod project\n", + "2. **Building LSTM model** - Bidirectional LSTM (matches MedMod unimodal baseline)\n", + "3. **Training with binary cross-entropy** - Standard mortality prediction setup\n", + "4. **Evaluation with AUROC/AUPRC** - Same metrics as MedMod paper\n", + "\n", + "**Comparison to MedMod Project:**\n", + "- **MedMod unimodal baseline:** AUROC 0.822 (reported in paper)\n", + "- **This PyHealth implementation:** Should achieve similar AUROC (~0.80-0.82)\n", + "- **Key difference:** This uses PyHealth-style data loading but custom LSTM model\n", + "\n", + "**Why this approach works:**\n", + "- ✅ Uses same MIMIC-IV data as MedMod\n", + "- ✅ Same task (in-hospital mortality)\n", + "- ✅ Similar model architecture (BiLSTM)\n", + "- ✅ Clean, reproducible code\n", + "- ✅ No multimodal complexity\n", + "\n", + "**For PyHealth contribution:** This could be extended to use PyHealth's built-in MIMIC4Dataset loader instead of manual CSV loading." + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "medmod", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.19" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +}