Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 22 additions & 0 deletions pyhealth/processors/base_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,3 +92,25 @@ def process(self, samples: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
List of processed sample dictionaries.
"""
pass

class VocabMixin(ABC):
    """
    Interface for feature processors that own an editable vocabulary.

    Concrete processors implement ``remove``, ``retain`` and ``add`` to
    change the set of tokens they hold.
    """

    @abstractmethod
    def remove(self, vocabularies: set[str]):
        """Drop every token listed in ``vocabularies`` from the processor."""

    @abstractmethod
    def retain(self, vocabularies: set[str]):
        """Keep only the tokens listed in ``vocabularies``."""

    @abstractmethod
    def add(self, vocabularies: set[str]):
        """Insert the tokens listed in ``vocabularies`` into the processor."""
31 changes: 26 additions & 5 deletions pyhealth/processors/deep_nested_sequence_processor.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
from typing import Any, Dict, List
from typing import Any, Dict, List, Iterable

import torch

from . import register_processor
from .base_processor import FeatureProcessor
from .base_processor import FeatureProcessor, VocabMixin


@register_processor("deep_nested_sequence")
class DeepNestedSequenceProcessor(FeatureProcessor):
class DeepNestedSequenceProcessor(FeatureProcessor, VocabMixin):
"""
Feature processor for deeply nested categorical sequences with vocabulary.

Expand Down Expand Up @@ -51,7 +51,7 @@ def __init__(self):
self._max_middle_len = 1 # Maximum length of middle sequences (e.g. visits)
self._max_inner_len = 1 # Maximum length of inner sequences (e.g. codes per visit)

def fit(self, samples: List[Dict[str, Any]], field: str) -> None:
def fit(self, samples: Iterable[Dict[str, Any]], field: str) -> None:
"""Build vocabulary and determine maximum sequence lengths.

Args:
Expand Down Expand Up @@ -86,6 +86,27 @@ def fit(self, samples: List[Dict[str, Any]], field: str) -> None:
self._max_middle_len = max(1, max_middle_len)
self._max_inner_len = max(1, max_inner_len)

def remove(self, vocabularies: set[str]):
    """Remove the given tokens from the vocabulary and re-index it.

    Surviving tokens keep their current relative order and are renumbered
    contiguously from 1. ``<pad>`` and ``<unk>`` are never removed, even if
    they appear in ``vocabularies``.

    Args:
        vocabularies: Tokens to drop. Tokens not in the vocabulary are
            ignored.
    """
    # NOTE(review): ``<unk>`` is pinned at -1 as in the original
    # implementation; confirm this matches what fit() assigns.
    excluded = set(vocabularies) | {"<pad>", "<unk>"}
    # Iterate the insertion-ordered dict instead of a set so the resulting
    # indices are reproducible across runs (set order of str depends on
    # hash randomisation).
    survivors = [token for token in self.code_vocab if token not in excluded]
    self.code_vocab = {"<pad>": 0, "<unk>": -1}
    for index, token in enumerate(survivors, start=1):
        self.code_vocab[token] = index

def retain(self, vocabularies: set[str]):
    """Keep only the given tokens in the vocabulary and re-index it.

    Retained tokens keep their current relative order and are renumbered
    contiguously from 1. ``<pad>`` and ``<unk>`` are always kept at their
    reserved indices, whether or not they appear in ``vocabularies``.

    Args:
        vocabularies: Tokens to keep. Tokens not currently in the
            vocabulary are ignored (nothing is added).
    """
    # NOTE(review): ``<unk>`` is pinned at -1 as in the original
    # implementation; confirm this matches what fit() assigns.
    # Strip the special tokens from the request: otherwise asking to retain
    # "<pad>"/"<unk>" would overwrite their reserved indices below.
    wanted = set(vocabularies) - {"<pad>", "<unk>"}
    # Iterate the insertion-ordered dict so indices are reproducible.
    survivors = [token for token in self.code_vocab if token in wanted]
    self.code_vocab = {"<pad>": 0, "<unk>": -1}
    for index, token in enumerate(survivors, start=1):
        self.code_vocab[token] = index

def add(self, vocabularies: set[str]):
    """Add the given tokens to the vocabulary and re-index it.

    Existing tokens keep their current relative order; tokens that are new
    are appended after them in sorted order. All regular tokens are numbered
    contiguously from 1. ``<pad>`` and ``<unk>`` stay at their reserved
    indices and are ignored if they appear in ``vocabularies``.

    Args:
        vocabularies: Tokens to add. Tokens already present are left in
            place.
    """
    # NOTE(review): ``<unk>`` is pinned at -1 as in the original
    # implementation; confirm this matches what fit() assigns.
    specials = {"<pad>", "<unk>"}
    # The special tokens must be stripped from the existing keys as well as
    # from the request, otherwise they are re-inserted with a regular index
    # and lose their reserved positions.
    existing = [token for token in self.code_vocab if token not in specials]
    fresh = (set(vocabularies) - specials) - set(existing)
    # key=str keeps the ordering deterministic without failing on
    # non-string tokens.
    ordered = existing + sorted(fresh, key=str)
    self.code_vocab = {"<pad>": 0, "<unk>": -1}
    for index, token in enumerate(ordered, start=1):
        self.code_vocab[token] = index

def process(self, value: List[List[List[Any]]]) -> torch.Tensor:
"""Process deep nested sequence into padded 3D tensor.

Expand Down Expand Up @@ -209,7 +230,7 @@ def __init__(self, forward_fill: bool = True):
self._max_inner_len = 1 # Maximum length of inner sequences (values per visit)
self.forward_fill = forward_fill

def fit(self, samples: List[Dict[str, Any]], field: str) -> None:
def fit(self, samples: Iterable[Dict[str, Any]], field: str) -> None:
"""Determine maximum sequence lengths.

Args:
Expand Down
22 changes: 20 additions & 2 deletions pyhealth/processors/nested_sequence_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,11 @@
import torch

from . import register_processor
from .base_processor import FeatureProcessor
from .base_processor import FeatureProcessor, VocabMixin


@register_processor("nested_sequence")
class NestedSequenceProcessor(FeatureProcessor):
class NestedSequenceProcessor(FeatureProcessor, VocabMixin):
"""
Feature processor for nested categorical sequences with vocabulary.

Expand Down Expand Up @@ -86,6 +86,24 @@ def fit(self, samples: Iterable[Dict[str, Any]], field: str) -> None:
# (-1 because <unk> is already in vocab)
self.code_vocab["<unk>"] = len(self.code_vocab) - 1

def remove(self, vocabularies: set[str]):
    """Remove the given tokens from the vocabulary and re-index it.

    After the call ``<pad>`` is index 0, the surviving tokens follow in
    their current relative order, and ``<unk>`` takes the last index.
    ``<pad>`` and ``<unk>`` are never removed.

    Args:
        vocabularies: Tokens to drop. Tokens not in the vocabulary are
            ignored.
    """
    excluded = set(vocabularies) | {"<pad>", "<unk>"}
    # Iterate the insertion-ordered dict instead of a set so the resulting
    # indices are reproducible across runs (set order of str depends on
    # hash randomisation).
    survivors = [token for token in self.code_vocab if token not in excluded]
    ordered = ["<pad>"] + survivors + ["<unk>"]
    self.code_vocab = {token: index for index, token in enumerate(ordered)}

def retain(self, vocabularies: set[str]):
    """Keep only the given tokens in the vocabulary and re-index it.

    After the call ``<pad>`` is index 0, the retained tokens follow in
    their current relative order, and ``<unk>`` takes the last index.
    ``<pad>`` and ``<unk>`` are always kept.

    Args:
        vocabularies: Tokens to keep. Tokens not currently in the
            vocabulary are ignored (nothing is added).
    """
    # Strip the special tokens from the request: they are re-added
    # explicitly below, and a duplicate would leave a gap in the index
    # range and move "<pad>" off index 0.
    wanted = set(vocabularies) - {"<pad>", "<unk>"}
    # Iterate the insertion-ordered dict so indices are reproducible.
    survivors = [token for token in self.code_vocab if token in wanted]
    ordered = ["<pad>"] + survivors + ["<unk>"]
    self.code_vocab = {token: index for index, token in enumerate(ordered)}

def add(self, vocabularies: set[str]):
    """Add the given tokens to the vocabulary and re-index it.

    After the call ``<pad>`` is index 0, existing tokens follow in their
    current relative order, new tokens come next in sorted order, and
    ``<unk>`` takes the last index. ``<pad>`` and ``<unk>`` in
    ``vocabularies`` are ignored.

    Args:
        vocabularies: Tokens to add. Tokens already present are left in
            place.
    """
    specials = {"<pad>", "<unk>"}
    # The special tokens must be stripped from the existing keys as well as
    # from the request: they are re-added explicitly below, and a duplicate
    # would leave gaps in the index range and move "<pad>" off index 0.
    existing = [token for token in self.code_vocab if token not in specials]
    fresh = (set(vocabularies) - specials) - set(existing)
    # key=str keeps the ordering deterministic without failing on
    # non-string tokens.
    ordered = ["<pad>"] + existing + sorted(fresh, key=str) + ["<unk>"]
    self.code_vocab = {token: index for index, token in enumerate(ordered)}

def process(self, value: List[List[Any]]) -> torch.Tensor:
"""Process nested sequence into padded 2D tensor.

Expand Down
22 changes: 20 additions & 2 deletions pyhealth/processors/sequence_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,11 @@
import torch

from . import register_processor
from .base_processor import FeatureProcessor
from .base_processor import FeatureProcessor, VocabMixin


@register_processor("sequence")
class SequenceProcessor(FeatureProcessor):
class SequenceProcessor(FeatureProcessor, VocabMixin):
"""
Feature processor for encoding categorical sequences (e.g., medical codes) into numerical indices.

Expand Down Expand Up @@ -48,6 +48,24 @@ def process(self, value: Any) -> torch.Tensor:
indices.append(self.code_vocab["<unk>"])

return torch.tensor(indices, dtype=torch.long)

def remove(self, vocabularies: set[str]):
    """Remove the given tokens from the vocabulary and re-index it.

    After the call ``<pad>`` is index 0, the surviving tokens follow in
    their current relative order, and ``<unk>`` takes the last index.
    ``<pad>`` and ``<unk>`` are never removed.

    Args:
        vocabularies: Tokens to drop. Tokens not in the vocabulary are
            ignored.
    """
    excluded = set(vocabularies) | {"<pad>", "<unk>"}
    # Iterate the insertion-ordered dict instead of a set so the resulting
    # indices are reproducible across runs (set order of str depends on
    # hash randomisation).
    survivors = [token for token in self.code_vocab if token not in excluded]
    ordered = ["<pad>"] + survivors + ["<unk>"]
    self.code_vocab = {token: index for index, token in enumerate(ordered)}

def retain(self, vocabularies: set[str]):
    """Keep only the given tokens in the vocabulary and re-index it.

    After the call ``<pad>`` is index 0, the retained tokens follow in
    their current relative order, and ``<unk>`` takes the last index.
    ``<pad>`` and ``<unk>`` are always kept.

    Args:
        vocabularies: Tokens to keep. Tokens not currently in the
            vocabulary are ignored (nothing is added).
    """
    # Strip the special tokens from the request: they are re-added
    # explicitly below, and a duplicate would leave a gap in the index
    # range and move "<pad>" off index 0.
    wanted = set(vocabularies) - {"<pad>", "<unk>"}
    # Iterate the insertion-ordered dict so indices are reproducible.
    survivors = [token for token in self.code_vocab if token in wanted]
    ordered = ["<pad>"] + survivors + ["<unk>"]
    self.code_vocab = {token: index for index, token in enumerate(ordered)}

def add(self, vocabularies: set[str]):
    """Add the given tokens to the vocabulary and re-index it.

    After the call ``<pad>`` is index 0, existing tokens follow in their
    current relative order, new tokens come next in sorted order, and
    ``<unk>`` takes the last index. ``<pad>`` and ``<unk>`` in
    ``vocabularies`` are ignored.

    Args:
        vocabularies: Tokens to add. Tokens already present are left in
            place.
    """
    specials = {"<pad>", "<unk>"}
    # The special tokens must be stripped from the existing keys as well as
    # from the request: they are re-added explicitly below, and a duplicate
    # would leave gaps in the index range and move "<pad>" off index 0.
    existing = [token for token in self.code_vocab if token not in specials]
    fresh = (set(vocabularies) - specials) - set(existing)
    # key=str keeps the ordering deterministic without failing on
    # non-string tokens.
    ordered = ["<pad>"] + existing + sorted(fresh, key=str) + ["<unk>"]
    self.code_vocab = {token: index for index, token in enumerate(ordered)}

def size(self):
    """Return how many keys ``code_vocab`` currently holds."""
    vocabulary = self.code_vocab
    return len(vocabulary)
Expand Down
22 changes: 20 additions & 2 deletions pyhealth/processors/stagenet_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,11 @@
import torch

from . import register_processor
from .base_processor import FeatureProcessor
from .base_processor import FeatureProcessor, VocabMixin


@register_processor("stagenet")
class StageNetProcessor(FeatureProcessor):
class StageNetProcessor(FeatureProcessor, VocabMixin):
"""
Feature processor for StageNet CODE inputs with coupled value/time data.

Expand Down Expand Up @@ -122,6 +122,24 @@ def fit(self, samples: Iterable[Dict[str, Any]], field: str) -> None:
# Since <unk> is already in the vocab dict, we use _next_index
self.code_vocab["<unk>"] = self._next_index

def remove(self, vocabularies: set[str]):
    """Remove the given tokens from the vocabulary and re-index it.

    After the call ``<pad>`` is index 0, the surviving tokens follow in
    their current relative order, and ``<unk>`` takes the last index.
    ``<pad>`` and ``<unk>`` are never removed.

    Args:
        vocabularies: Tokens to drop. Tokens not in the vocabulary are
            ignored.
    """
    # NOTE(review): fit() also tracks ``self._next_index``; it is not
    # refreshed here - confirm nothing reads it after the vocabulary is
    # edited.
    excluded = set(vocabularies) | {"<pad>", "<unk>"}
    # Iterate the insertion-ordered dict instead of a set so the resulting
    # indices are reproducible across runs (set order of str depends on
    # hash randomisation).
    survivors = [token for token in self.code_vocab if token not in excluded]
    ordered = ["<pad>"] + survivors + ["<unk>"]
    self.code_vocab = {token: index for index, token in enumerate(ordered)}

def retain(self, vocabularies: set[str]):
    """Keep only the given tokens in the vocabulary and re-index it.

    After the call ``<pad>`` is index 0, the retained tokens follow in
    their current relative order, and ``<unk>`` takes the last index.
    ``<pad>`` and ``<unk>`` are always kept.

    Args:
        vocabularies: Tokens to keep. Tokens not currently in the
            vocabulary are ignored (nothing is added).
    """
    # NOTE(review): fit() also tracks ``self._next_index``; it is not
    # refreshed here - confirm nothing reads it after the vocabulary is
    # edited.
    # Strip the special tokens from the request: they are re-added
    # explicitly below, and a duplicate would leave a gap in the index
    # range and move "<pad>" off index 0.
    wanted = set(vocabularies) - {"<pad>", "<unk>"}
    # Iterate the insertion-ordered dict so indices are reproducible.
    survivors = [token for token in self.code_vocab if token in wanted]
    ordered = ["<pad>"] + survivors + ["<unk>"]
    self.code_vocab = {token: index for index, token in enumerate(ordered)}

def add(self, vocabularies: set[str]):
    """Add the given tokens to the vocabulary and re-index it.

    After the call ``<pad>`` is index 0, existing tokens follow in their
    current relative order, new tokens come next in sorted order, and
    ``<unk>`` takes the last index. ``<pad>`` and ``<unk>`` in
    ``vocabularies`` are ignored.

    Args:
        vocabularies: Tokens to add. Tokens already present are left in
            place.
    """
    # NOTE(review): fit() also tracks ``self._next_index``; it is not
    # refreshed here - confirm nothing reads it after the vocabulary is
    # edited.
    specials = {"<pad>", "<unk>"}
    # The special tokens must be stripped from the existing keys as well as
    # from the request: they are re-added explicitly below, and a duplicate
    # would leave gaps in the index range and move "<pad>" off index 0.
    existing = [token for token in self.code_vocab if token not in specials]
    fresh = (set(vocabularies) - specials) - set(existing)
    # key=str keeps the ordering deterministic without failing on
    # non-string tokens.
    ordered = ["<pad>"] + existing + sorted(fresh, key=str) + ["<unk>"]
    self.code_vocab = {token: index for index, token in enumerate(ordered)}

def process(
self, value: Tuple[Optional[List], List]
) -> Tuple[Optional[torch.Tensor], torch.Tensor]:
Expand Down
Loading