Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 22 additions & 5 deletions pyhealth/models/logistic_regression.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,22 +10,33 @@


class LogisticRegression(BaseModel):
"""Logistic/Linear regression baseline model.
"""Logistic/Linear regression baseline model with optional L1 regularization.

This model uses embeddings from different input features and applies a single
linear transformation (no hidden layers or non-linearity) to produce predictions.

- For classification tasks: acts as logistic regression
- For regression tasks: acts as linear regression

The model automatically handles different input types through the EmbeddingModel,
pools sequence dimensions, concatenates all feature embeddings, and applies a
final linear layer.

L1 regularization (``l1_lambda > 0``) adds a sparsity-inducing penalty to the
weight vector during training, equivalent to scikit-learn's
``LogisticRegression(penalty='l1', C=C)`` with ``l1_lambda = 1 / (C * n_train)``.
This is the formulation used in Boag et al. (2018) "Racial Disparities and
Mistrust in End-of-Life Care" (MLHC 2018) to train interpersonal-feature
mistrust classifiers on MIMIC-III.

Args:
dataset: the dataset to train the model. It is used to query certain
information such as the set of all tokens.
embedding_dim: the embedding dimension. Default is 128.
l1_lambda: coefficient for the L1 weight penalty added to the loss.
            ``loss = task_loss + l1_lambda * ||W||_1``, where ``task_loss`` is
            BCE for classification tasks and the corresponding regression
            loss otherwise. Set to 0.0 (default) to disable regularization
            (backward-compatible). Equivalent to ``1 / (C * n_train)`` for
            sklearn's C-parameterised formulation.
**kwargs: other parameters (for compatibility).

Examples:
Expand Down Expand Up @@ -55,7 +66,7 @@ class LogisticRegression(BaseModel):
... dataset_name="test")
>>>
>>> from pyhealth.models import LogisticRegression
>>> model = LogisticRegression(dataset=dataset)
>>> model = LogisticRegression(dataset=dataset, l1_lambda=1e-4)
>>>
>>> from pyhealth.datasets import get_dataloader
>>> train_loader = get_dataloader(dataset, batch_size=2, shuffle=True)
Expand All @@ -64,7 +75,7 @@ class LogisticRegression(BaseModel):
>>> ret = model(**data_batch)
>>> print(ret)
{
'loss': tensor(0.6931, grad_fn=<BinaryCrossEntropyWithLogitsBackward0>),
'loss': tensor(0.6931, grad_fn=<AddBackward0>),
'y_prob': tensor([[0.5123],
[0.4987]], grad_fn=<SigmoidBackward0>),
'y_true': tensor([[1.],
Expand All @@ -80,10 +91,12 @@ def __init__(
self,
dataset: SampleDataset,
embedding_dim: int = 128,
l1_lambda: float = 0.0,
**kwargs,
):
super(LogisticRegression, self).__init__(dataset)
self.embedding_dim = embedding_dim
self.l1_lambda = l1_lambda

assert len(self.label_keys) == 1, "Only one label key is supported"
self.label_key = self.label_keys[0]
Expand Down Expand Up @@ -197,6 +210,10 @@ def forward(self, **kwargs) -> Dict[str, torch.Tensor]:
# Obtain y_true, loss, y_prob
y_true = kwargs[self.label_key].to(self.device)
loss = self.get_loss_function()(logits, y_true)
# L1 regularization on the final linear layer's weights (bias excluded),
# equivalent to sklearn's penalty='l1' with C = 1 / (l1_lambda * n_train).
if self.l1_lambda > 0.0:
loss = loss + self.l1_lambda * self.fc.weight.abs().sum()
y_prob = self.prepare_y_prob(logits)

results = {
Expand Down