diff --git a/pyhealth/models/wav2sleep.py b/pyhealth/models/wav2sleep.py new file mode 100644 index 000000000..161698df4 --- /dev/null +++ b/pyhealth/models/wav2sleep.py @@ -0,0 +1,86 @@ +from typing import Dict, List, Optional +import torch +import torch.nn as nn +from pyhealth.models import BaseModel + + +class Wav2Sleep(BaseModel): + """Wav2Sleep: A Unified Multi-Modal Approach to Sleep Stage Classification. + + This model employs modality-specific convolutional encoders, a + transformer-based fusion mechanism (Epoch Mixer), and a dilated + convolutional sequence mixer. + + Paper: Carter, J. F.; and Tarassenko, L. 2024. wav2sleep: A Unified + Multi-Modal Approach to Sleep Stage Classification from Physiological Signals. + + Args: + dataset: PyHealth dataset object. + feature_keys: List of keys in the dataset for input features. + label_key: Key in the dataset for the label. + mode: "binary", "multiclass", or "multilabel". + embedding_dim: Internal hidden dimension for all modules. Default is 128. + nhead: Number of heads in the Transformer Epoch Mixer. Default is 4. + num_layers: Number of Transformer layers. Default is 2. + mask_prob: Probability for stochastic masking during training. Default is 0.2. + **kwargs: Additional hyperparameter arguments. + """ + + def __init__( + self, + dataset, + feature_keys: List[str], + label_key: str, + mode: str, + embedding_dim: int = 128, + nhead: int = 4, + num_layers: int = 2, + mask_prob: float = 0.2, + **kwargs, + ): + super(Wav2Sleep, self).__init__( + dataset=dataset, + feature_keys=feature_keys, + label_key=label_key, + mode=mode, + ) + self.embedding_dim = embedding_dim + self.mask_prob = mask_prob + + # 1. [span_3](start_span)Signal Encoders: Modality-specific CNNs[span_3](end_span) + self.feature_encoders = nn.ModuleDict() + for key in feature_keys: + # Placeholder for actual CNN architecture + self.feature_encoders[key] = nn.Sequential( + nn.Conv1d(1, 64, kernel_size=3, padding=1), + nn.ReLU(), + nn.AdaptiveAvgPool1d(1), + nn.Linear(64, embedding_dim) + ) + + # 2. [span_4](start_span)Epoch Mixer: Transformer with [CLS] token[span_4](end_span) + self.cls_token = nn.Parameter(torch.zeros(1, 1, embedding_dim)) + encoder_layer = nn.TransformerEncoderLayer( + d_model=embedding_dim, nhead=nhead, batch_first=True + ) + self.epoch_mixer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers) + + # 3. [span_5](start_span)Sequence Mixer: Dilated Convolutions[span_5](end_span) + self.sequence_mixer = nn.Sequential( + nn.Conv1d(embedding_dim, embedding_dim, kernel_size=3, padding=2, dilation=2), + nn.ReLU() + ) + + # Final Classification Head + self.fc = nn.Linear(embedding_dim, self.total_num_classes) + + def forward(self, **kwargs) -> Dict[str, torch.Tensor]: + """Forward pass implementing stochastic masking and fusion. + + Steps: + 1. Encode each available modality. + 2. [span_6](start_span)Apply stochastic masking (training only)[span_6](end_span). + 3. Fuse features using [CLS] token in Transformer. + 4. Model temporal sequence with dilated convolutions. + """ + pass