86 changes: 86 additions & 0 deletions pyhealth/models/wav2sleep.py
@@ -0,0 +1,86 @@
from typing import Dict, List
import torch
import torch.nn as nn
from pyhealth.models import BaseModel


class Wav2Sleep(BaseModel):
"""Wav2Sleep: A Unified Multi-Modal Approach to Sleep Stage Classification.

This model employs modality-specific convolutional encoders, a
transformer-based fusion mechanism (Epoch Mixer), and a dilated
convolutional sequence mixer.

Paper: Carter, J. F.; and Tarassenko, L. 2024. wav2sleep: A Unified
Multi-Modal Approach to Sleep Stage Classification from Physiological Signals.

Args:
dataset: PyHealth dataset object.
feature_keys: List of keys in the dataset for input features.
label_key: Key in the dataset for the label.
mode: "binary", "multiclass", or "multilabel".
embedding_dim: Internal hidden dimension for all modules. Default is 128.
nhead: Number of heads in the Transformer Epoch Mixer. Default is 4.
num_layers: Number of Transformer layers. Default is 2.
mask_prob: Probability for stochastic masking during training. Default is 0.2.
**kwargs: Additional hyperparameter arguments.
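
    Examples:
        A hypothetical instantiation; ``sleep_dataset`` and the key names
        below are illustrative, not fixed by this model:

        >>> model = Wav2Sleep(
        ...     dataset=sleep_dataset,
        ...     feature_keys=["ECG", "ABD"],
        ...     label_key="sleep_stage",
        ...     mode="multiclass",
        ... )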
"""

def __init__(
self,
dataset,
feature_keys: List[str],
label_key: str,
mode: str,
embedding_dim: int = 128,
nhead: int = 4,
num_layers: int = 2,
mask_prob: float = 0.2,
**kwargs,
):
        super().__init__(
dataset=dataset,
feature_keys=feature_keys,
label_key=label_key,
mode=mode,
)
self.embedding_dim = embedding_dim
self.mask_prob = mask_prob

        # 1. Signal Encoders: modality-specific CNNs.
        self.feature_encoders = nn.ModuleDict()
        for key in feature_keys:
            # Placeholder for the actual CNN architecture; assumes each
            # modality arrives as a single-channel 1D signal.
            self.feature_encoders[key] = nn.Sequential(
                nn.Conv1d(1, 64, kernel_size=3, padding=1),
                nn.ReLU(),
                nn.AdaptiveAvgPool1d(1),
                nn.Flatten(start_dim=1),  # (N, 64, 1) -> (N, 64) before the Linear
                nn.Linear(64, embedding_dim),
            )

        # 2. Epoch Mixer: Transformer encoder with a [CLS] fusion token.
self.cls_token = nn.Parameter(torch.zeros(1, 1, embedding_dim))
encoder_layer = nn.TransformerEncoderLayer(
d_model=embedding_dim, nhead=nhead, batch_first=True
)
self.epoch_mixer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)

        # 3. Sequence Mixer: dilated convolutions over the epoch axis.
        self.sequence_mixer = nn.Sequential(
            # kernel_size=3 with dilation=2 and padding=2 preserves sequence length.
            nn.Conv1d(embedding_dim, embedding_dim, kernel_size=3, padding=2, dilation=2),
            nn.ReLU(),
        )

# Final Classification Head
self.fc = nn.Linear(embedding_dim, self.total_num_classes)

    def forward(self, **kwargs) -> Dict[str, torch.Tensor]:
        """Forward pass implementing stochastic masking and fusion.

        Steps:
            1. Encode each available modality.
            2. Apply stochastic masking (training only).
            3. Fuse modality features via the [CLS] token in the Transformer.
            4. Model the temporal sequence with dilated convolutions.
        """
        raise NotImplementedError
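
A minimal sketch of what the `forward` body could look like, given the modules defined above. It assumes each feature tensor in `kwargs` has shape `(batch, num_epochs, signal_length)`, that labels under `self.label_key` are integer class indices, and that the `BaseModel` helpers `get_loss_function()` and `prepare_y_prob()` are available as in other PyHealth models; none of these assumptions are confirmed by this diff.

```python
# Sketch only: tensor shapes and the BaseModel loss/probability helpers
# (get_loss_function, prepare_y_prob) are assumptions, not part of this diff.
def forward(self, **kwargs) -> Dict[str, torch.Tensor]:
    embeddings = []
    for key in self.feature_keys:
        x = kwargs[key].float()  # assumed shape: (batch, num_epochs, length)
        b, t, l = x.shape
        # 1. Encode every epoch of this modality independently.
        z = self.feature_encoders[key](x.reshape(b * t, 1, l))  # (b*t, dim)
        # 2. Stochastic masking: randomly drop a whole modality in training.
        if self.training and torch.rand(()) < self.mask_prob:
            z = torch.zeros_like(z)
        embeddings.append(z)
    # 3. Fuse modalities per epoch: prepend [CLS], run the Epoch Mixer,
    #    and keep the [CLS] output as the fused epoch embedding.
    tokens = torch.stack(embeddings, dim=1)  # (b*t, n_modalities, dim)
    cls = self.cls_token.expand(tokens.size(0), -1, -1)
    fused = self.epoch_mixer(torch.cat([cls, tokens], dim=1))[:, 0]
    # 4. Mix information across epochs with the dilated convolutions.
    seq = fused.reshape(b, t, -1).transpose(1, 2)   # (b, dim, t) for Conv1d
    seq = self.sequence_mixer(seq).transpose(1, 2)  # back to (b, t, dim)
    logits = self.fc(seq)  # (b, t, num_classes)
    y_true = kwargs[self.label_key]  # assumed integer labels of shape (b, t)
    loss = self.get_loss_function()(
        logits.reshape(-1, logits.size(-1)), y_true.reshape(-1)
    )
    return {
        "loss": loss,
        "y_prob": self.prepare_y_prob(logits),
        "y_true": y_true,
        "logit": logits,
    }
```

Masking entire modalities, rather than individual timesteps, matches the stated goal of a single model that can still classify sleep stages when only a subset of signals is available at inference time.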