@@ -110,9 +110,6 @@ class ModernBertConfig(PreTrainedConfig):
             the model will be compiled if 1) `triton` is installed, 2) the model is not on MPS, 3) the model is not
             shared between devices, and 4) the model is not resized after initialization. If `True`, then the model may
             be faster in some scenarios.
-        repad_logits_with_grad (`bool`, *optional*, defaults to `False`):
-            When True, ModernBertForMaskedLM keeps track of the logits' gradient when repadding for output. This only
-            applies when using Flash Attention 2 with passed labels. Otherwise output logits always have a gradient.
 
     Examples:
 
@@ -168,7 +165,6 @@ def __init__(
         sparse_prediction: Optional[bool] = False,
         sparse_pred_ignore_index: Optional[int] = -100,
         reference_compile: Optional[bool] = None,
-        repad_logits_with_grad: Optional[bool] = False,
         **kwargs,
     ):
         self.vocab_size = vocab_size
@@ -197,7 +193,6 @@ def __init__(
         self.sparse_prediction = sparse_prediction
         self.sparse_pred_ignore_index = sparse_pred_ignore_index
         self.reference_compile = reference_compile
-        self.repad_logits_with_grad = repad_logits_with_grad
 
         if self.classifier_pooling not in ["cls", "mean"]:
             raise ValueError(
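
For context, a minimal usage sketch (not part of the diff above) of how the config looks after this change. It assumes a `transformers` version that ships ModernBERT; the remaining keyword arguments are taken directly from the hunks above, and the printed result is what we would expect once the field is no longer set in `__init__`:

# Minimal sketch, assuming `transformers` with ModernBERT support is installed.
# After this PR, `repad_logits_with_grad` is no longer a declared config field,
# so ModernBertConfig.__init__ no longer sets it as an attribute.
from transformers import ModernBertConfig

config = ModernBertConfig(
    sparse_prediction=False,
    sparse_pred_ignore_index=-100,
    reference_compile=None,  # auto-decide compilation, per the docstring kept above
)
print("repad_logits_with_grad" in config.to_dict())  # expected: False after this change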