From d101d910671240a00a6a3897bf2d52472cf2dafc Mon Sep 17 00:00:00 2001 From: meilame-tayebjee Date: Mon, 24 Nov 2025 15:39:04 +0000 Subject: [PATCH 1/7] feat: support list of lists labels "ragged_multiclass" --- torchTextClassifiers/dataset/dataset.py | 44 +++++++++++++++++++++++-- 1 file changed, 41 insertions(+), 3 deletions(-) diff --git a/torchTextClassifiers/dataset/dataset.py b/torchTextClassifiers/dataset/dataset.py index 9e7b764..4892f7c 100644 --- a/torchTextClassifiers/dataset/dataset.py +++ b/torchTextClassifiers/dataset/dataset.py @@ -1,3 +1,4 @@ +import logging import os from typing import List, Union @@ -8,6 +9,7 @@ from torchTextClassifiers.tokenizers import BaseTokenizer os.environ["TOKENIZERS_PARALLELISM"] = "false" +logger = logging.getLogger(__name__) class TextClassificationDataset(Dataset): @@ -16,7 +18,8 @@ def __init__( texts: List[str], categorical_variables: Union[List[List[int]], np.array, None], tokenizer: BaseTokenizer, - labels: Union[List[int], None] = None, + labels: Union[List[int], List[List[int]], np.array, None] = None, + ragged_multilabel: bool = False, ): self.categorical_variables = categorical_variables @@ -32,6 +35,23 @@ def __init__( self.texts = texts self.tokenizer = tokenizer self.labels = labels + self.ragged_multilabel = ragged_multilabel + + if self.ragged_multilabel and self.labels is not None: + max_value = int(max(max(row) for row in labels if row)) + self.num_classes = max_value + 1 + + if max_value == 1: + try: + labels = np.array(labels) + logger.critical( + """ragged_multilabel set to True but max label value is 1 and all samples have the same number of labels. + If your labels are already one-hot encoded, set ragged_multilabel to False. Otherwise computations are likely to be wrong.""" + ) + except ValueError: + logger.warning( + "ragged_multilabel set to True but max label value is 1. If your labels are already one-hot encoded, set ragged_multilabel to False. Otherwise computations are likely to be wrong." + ) def __len__(self): return len(self.texts) @@ -59,10 +79,28 @@ def __getitem__(self, idx): ) def collate_fn(self, batch): - text, *categorical_vars, y = zip(*batch) + text, *categorical_vars, labels = zip(*batch) if self.labels is not None: - labels_tensor = torch.tensor(y, dtype=torch.long) + if self.ragged_multilabel: + # Pad labels to the max length in the batch + labels_padded = torch.nn.utils.rnn.pad_sequence( + [torch.tensor(label) for label in labels], + batch_first=True, + padding_value=-1, # use impossible class + ).int() + + labels_tensor = torch.zeros(labels_padded.size(0), 6).float() + mask = labels_padded != -1 + + batch_size = labels_padded.size(0) + rows = torch.arange(batch_size).unsqueeze(1).expand_as(labels_padded)[mask] + cols = labels_padded[mask] + + labels_tensor[rows, cols] = 1 + + else: + labels_tensor = torch.tensor(labels, dtype=torch.long) else: labels_tensor = None From 2074731dec1bf3d883660695ebefee9cc4c6c86d Mon Sep 17 00:00:00 2001 From: meilame-tayebjee Date: Mon, 24 Nov 2025 15:42:35 +0000 Subject: [PATCH 2/7] feat: add ragged_multilabel support --- torchTextClassifiers/torchTextClassifiers.py | 47 ++++++++++++++------ 1 file changed, 33 insertions(+), 14 deletions(-) diff --git a/torchTextClassifiers/torchTextClassifiers.py b/torchTextClassifiers/torchTextClassifiers.py index 66b285d..0aa2070 100644 --- a/torchTextClassifiers/torchTextClassifiers.py +++ b/torchTextClassifiers/torchTextClassifiers.py @@ -100,6 +100,7 @@ def __init__( self, tokenizer: BaseTokenizer, model_config: ModelConfig, + ragged_multilabel: bool = False, ): """Initialize the torchTextClassifiers instance. @@ -124,6 +125,7 @@ def __init__( self.model_config = model_config self.tokenizer = tokenizer + self.ragged_multilabel = ragged_multilabel if hasattr(self.tokenizer, "trained"): if not self.tokenizer.trained: @@ -249,6 +251,11 @@ def train( if training_config.optimizer_params is not None: optimizer_params.update(training_config.optimizer_params) + if training_config.loss is torch.nn.CrossEntropyLoss and self.ragged_multilabel: + logger.warning( + "⚠️ You have set ragged_multilabel to True but are using CrossEntropyLoss. We would recommend to use torch.nn.BCEWithLogitsLoss for multilabel classification tasks." + ) + self.lightning_module = TextClassificationModule( model=self.pytorch_model, loss=training_config.loss, @@ -271,12 +278,14 @@ def train( categorical_variables=X_train["categorical_variables"], # None if no cat vars tokenizer=self.tokenizer, labels=y_train, + ragged_multilabel=self.ragged_multilabel, ) val_dataset = TextClassificationDataset( texts=X_val["text"], categorical_variables=X_val["categorical_variables"], # None if no cat vars tokenizer=self.tokenizer, labels=y_val, + ragged_multilabel=self.ragged_multilabel, ) train_dataloader = train_dataset.create_dataloader( @@ -352,7 +361,7 @@ def _check_XY(self, X: np.ndarray, Y: np.ndarray) -> Tuple[np.ndarray, np.ndarra X = self._check_X(X) Y = self._check_Y(Y) - if X["text"].shape[0] != Y.shape[0]: + if X["text"].shape[0] != len(Y): raise ValueError("X_train and y_train must have the same number of observations.") return X, Y @@ -422,22 +431,32 @@ def _check_X(self, X: np.ndarray) -> np.ndarray: return {"text": text, "categorical_variables": categorical_variables} def _check_Y(self, Y): - assert isinstance(Y, np.ndarray), "Y must be a numpy array of shape (N,) or (N,1)." - assert len(Y.shape) == 1 or ( - len(Y.shape) == 2 and Y.shape[1] == 1 - ), "Y must be a numpy array of shape (N,) or (N,1)." + if self.ragged_multilabel: + assert isinstance( + Y, list + ), "Y must be a list of lists for ragged multilabel classification." + for row in Y: + assert isinstance(row, list), "Each element of Y must be a list of labels." - try: - Y = Y.astype(int) - except ValueError: - logger.error("Y must be castable in integer format.") + return Y - if Y.max() >= self.num_classes or Y.min() < 0: - raise ValueError( - f"Y contains class labels outside the range [0, {self.num_classes - 1}]." - ) + else: + assert isinstance(Y, np.ndarray), "Y must be a numpy array of shape (N,) or (N,1)." + assert len(Y.shape) == 1 or ( + len(Y.shape) == 2 and Y.shape[1] == 1 + ), "Y must be a numpy array of shape (N,) or (N,1)." + + try: + Y = Y.astype(int) + except ValueError: + logger.error("Y must be castable in integer format.") + + if Y.max() >= self.num_classes or Y.min() < 0: + raise ValueError( + f"Y contains class labels outside the range [0, {self.num_classes - 1}]." + ) - return Y + return Y def predict( self, From 0b5c9839e04e80cbd53660f670e71cb723bb8e6a Mon Sep 17 00:00:00 2001 From: meilame-tayebjee Date: Mon, 24 Nov 2025 15:42:49 +0000 Subject: [PATCH 3/7] doc: add multilabel support --- README.md | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index 95fa297..6772d79 100644 --- a/README.md +++ b/README.md @@ -4,11 +4,12 @@ A unified, extensible framework for text classification with categorical variabl ## 🚀 Features -- **Mixed input support**: Handle text data alongside categorical variables seamlessly. +- **Complex input support**: Handle text data alongside categorical variables seamlessly. - **Unified yet highly customizable**: - Use any tokenizer from HuggingFace or the original fastText's ngram tokenizer. - Manipulate the components (`TextEmbedder`, `CategoricalVariableNet`, `ClassificationHead`) to easily create custom architectures - including **self-attention**. All of them are `torch.nn.Module` ! - The `TextClassificationModel` class combines these components and can be extended for custom behavior. +- **Multiclass / multilabel classification support**: Support for both multiclass (only one label is true) and multi-label (several labels can be true) classification tasks. - **PyTorch Lightning**: Automated training with callbacks, early stopping, and logging - **Easy experimentation**: Simple API for training, evaluating, and predicting with minimal code: - The `torchTextClassifiers` wrapper class orchestrates the tokenizer and the model for you From 75cc435a2a75116ed91130227d5342b1be0d02ed Mon Sep 17 00:00:00 2001 From: meilame-tayebjee Date: Mon, 24 Nov 2025 18:02:31 +0000 Subject: [PATCH 4/7] doc!: add a multilabel example --- notebooks/multilabel_classification.ipynb | 350 ++++++++++++++++++++++ 1 file changed, 350 insertions(+) create mode 100644 notebooks/multilabel_classification.ipynb diff --git a/notebooks/multilabel_classification.ipynb b/notebooks/multilabel_classification.ipynb new file mode 100644 index 0000000..f9af9a4 --- /dev/null +++ b/notebooks/multilabel_classification.ipynb @@ -0,0 +1,350 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "0", + "metadata": {}, + "source": [ + "# Multilabel classification" + ] + }, + { + "cell_type": "markdown", + "id": "1", + "metadata": {}, + "source": [ + "In **multilabel classification**, each instance can be assigned multiple labels simultaneously. This is different from multiclass classification, where each instance is assigned to one and only one class from a set of classes.\n", + "\n", + "This notebook shows how to use torchTextClassifiers to perform multilabel classification." + ] + }, + { + "cell_type": "markdown", + "id": "2", + "metadata": {}, + "source": [ + "## Ragged-lists approach" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "3", + "metadata": {}, + "outputs": [], + "source": [ + "import numpy as np\n", + "import torch\n", + "\n", + "from torchTextClassifiers import ModelConfig, TrainingConfig, torchTextClassifiers\n", + "from torchTextClassifiers.dataset import TextClassificationDataset\n", + "from torchTextClassifiers.model import TextClassificationModel, TextClassificationModule\n", + "from torchTextClassifiers.model.components import (\n", + " AttentionConfig,\n", + " CategoricalVariableNet,\n", + " ClassificationHead,\n", + " TextEmbedder,\n", + " TextEmbedderConfig,\n", + ")\n", + "from torchTextClassifiers.tokenizers import HuggingFaceTokenizer\n", + "\n", + "%load_ext autoreload\n", + "%autoreload 2" + ] + }, + { + "cell_type": "markdown", + "id": "4", + "metadata": {}, + "source": [ + "Let's use fake data.\n", + "\n", + "Look at `labels`: it is a list of lists, where each inner list contains the labels for the corresponding instance.\n", + "\n", + "We're indeed in a multilabel classification setting, where each instance can have multiple labels." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5", + "metadata": {}, + "outputs": [], + "source": [ + "sample_text_data = [\n", + " \"This is a positive example\",\n", + " \"This is a negative example\",\n", + " \"Another positive case\",\n", + " \"Another negative case\",\n", + " \"Good example here\",\n", + " \"Bad example here\",\n", + "]\n", + "\n", + "labels = [[0, 1, 5], [0, 4], [1, 5], [0, 1, 4], [1, 5], [0]]" + ] + }, + { + "cell_type": "markdown", + "id": "6", + "metadata": {}, + "source": [ + "Note that `labels` is not a nice object to manipulate: each inner list has different lengths. You can not convert it to a tensor or a numpy array directly.\n", + "\n", + "This is called a *jagged array* or *ragged array*.\n", + "\n", + "Yet, you do not need to change anything: torchTextClassifiers can handle this kind of data directly." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "7", + "metadata": {}, + "outputs": [], + "source": [ + "labels = np.array(labels) # This does not work !" + ] + }, + { + "cell_type": "markdown", + "id": "8", + "metadata": {}, + "source": [ + "Let's import a pre-trained tokenizer from HuggingFace." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "9", + "metadata": {}, + "outputs": [], + "source": [ + "tokenizer = HuggingFaceTokenizer.load_from_pretrained(\n", + " \"google-bert/bert-base-uncased\", output_dim=126\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "10", + "metadata": {}, + "source": [ + "And create our input numpy array." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "11", + "metadata": {}, + "outputs": [], + "source": [ + "X = np.array(\n", + " sample_text_data\n", + ")\n", + "\n", + "print(X.shape)\n", + "\n", + "Y = labels # only for the sake of clarity, but it remains a ragged array here" + ] + }, + { + "cell_type": "markdown", + "id": "12", + "metadata": {}, + "source": [ + "We initialize a very simple model, no categorical features, no attention, just text input and multilabel output.\n", + "\n", + "In this setting, we advise to use `torch.nn.BCEWithLogitsLoss()` as loss function in the training config. \n", + "\n", + "Each label is treated as a separate (but not independent, because we output the joint prediction vector) binary classification problem (where we try to estimate the probability of inclusion), whereas in the default setting (multiclass classification) the model uses `torch.nn.CrossEntropyLoss()`, that implies a *competition* among classes.\n", + "\n", + "Note that we won't enforce this change of loss and if you do not specify it, the default loss (CrossEntropyLoss) will be used." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "13", + "metadata": {}, + "outputs": [], + "source": [ + "embedding_dim = 96\n", + "n_layers = 2\n", + "n_head = 4\n", + "n_kv_head = n_head\n", + "sequence_len = tokenizer.output_dim\n", + "num_classes = max(max(label_list) for label_list in labels) + 1\n", + "\n", + "model_config = ModelConfig(\n", + " embedding_dim=embedding_dim,\n", + " num_classes=num_classes,\n", + ")\n", + "\n", + "training_config = TrainingConfig(\n", + " lr=1e-3,\n", + " batch_size=4,\n", + " num_epochs=1,\n", + " loss=torch.nn.BCEWithLogitsLoss(), # change the loss here\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "14", + "metadata": {}, + "source": [ + "Here, do not forget to set `ragged_multilabel=True` !" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "15", + "metadata": {}, + "outputs": [], + "source": [ + "ttc = torchTextClassifiers(\n", + " tokenizer=tokenizer,\n", + " model_config=model_config,\n", + " ragged_multilabel=True, # This is key !\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "16", + "metadata": {}, + "source": [ + "And you can train !" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "17", + "metadata": {}, + "outputs": [], + "source": [ + "ttc.train(\n", + " X_train=X,\n", + " y_train=Y,\n", + " training_config=training_config,\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "18", + "metadata": {}, + "source": [ + "What happens behind the hood, is that we efficiently convert your ragged lists of labels into a binary matrix, where each row corresponds to an instance and each column to a label. A value of 1 indicates the presence of a label for an instance, while 0 indicates its absence: **it is a one-hot version** of your ragged lists.\n", + "\n", + "You can have a look [here](../torchTextClassifiers/dataset/dataset.py#L85)." + ] + }, + { + "cell_type": "markdown", + "id": "19", + "metadata": {}, + "source": [ + "## One-hot / multidimensional output approach" + ] + }, + { + "cell_type": "markdown", + "id": "20", + "metadata": {}, + "source": [ + "You can also choose to directly provide a one-hot / multidimensional array as labels.\n", + "\n", + "For each sample, you have a vector of size equal to the number of labels, with 1s and 0s indicating the presence or absence of each label - or float values between 0 and 1, indicating the ground truth probability of each label.\n", + "\n", + "You do not have ragged lists anymore: **set `ragged_multilabel=False`** in the ``ttc`` initialization (it is very important, otherwise it will interpret it as a bag of labels as previously ! - we will throw a warning if we detect that your labels are one-hot encoded while you set `ragged_multilabel=True`, but we won't enforce anything).\n", + "\n", + "Also, convert your labels to a numpy array - it is possible now !" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "21", + "metadata": {}, + "outputs": [], + "source": [ + "# We put 1s here, but it could be any float value (probabilities...)\n", + "labels = [[1., 1., 0., 0., 0., 1.],\n", + " [1., 0., 0., 0., 1., 0.],\n", + " [0., 1., 0., 0., 0., 1.],\n", + " [1., 1., 0., 0., 1., 0.],\n", + " [0., 1., 0., 0., 0., 1.],\n", + " [1., 0., 0., 0., 1., 0.]]\n", + "Y = np.array(labels)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "22", + "metadata": {}, + "outputs": [], + "source": [ + "ttc = torchTextClassifiers(\n", + " tokenizer=tokenizer,\n", + " model_config=model_config,\n", + ") # We removed the ragged_multilabel flag here, it is False by default" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "23", + "metadata": {}, + "outputs": [], + "source": [ + "ttc.train(\n", + " X_train=X,\n", + " y_train=Y,\n", + " training_config=training_config,\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "24", + "metadata": {}, + "source": [ + "As discussed, you can also put probabilities in `labels`. \n", + "\n", + "In this case, once again, you can use:\n", + "\n", + "- `torch.nn.BCEWithLogitsLoss()` as loss function in the training config, if you are in a multilabel setting.\n", + "- `torch.nn.CrossEntropyLoss()` as loss function in the training config, if you are in a *soft* multiclass setting (i.e. each instance has only one label, but you provide probabilities instead of class indices). Normally, your ground truth probabilities should sum to 1 for each instance in this case.\n", + "\n", + "We won't enforce anything that PyTorch does not enforce, so make sure to choose the right loss function for your task." + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": ".venv", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.13.8" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} From d1987ef7c6a81fc48ea69541da1997dbf1c06380 Mon Sep 17 00:00:00 2001 From: meilame-tayebjee Date: Mon, 24 Nov 2025 18:02:52 +0000 Subject: [PATCH 5/7] fix: missing variable --- notebooks/example.ipynb | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/notebooks/example.ipynb b/notebooks/example.ipynb index cd32c28..6712468 100644 --- a/notebooks/example.ipynb +++ b/notebooks/example.ipynb @@ -430,9 +430,10 @@ "outputs": [], "source": [ "# test the TextEmbedder: it takes as input a tensor of token ids and outputs a tensor of embeddings\n", + "\n", "text_embedder_output = text_embedder(input_ids=batch[\"input_ids\"], attention_mask=batch[\"attention_mask\"])\n", "\n", - "print(\"TextEmbedder input: \", text_embedder_input.input_ids)\n", + "print(\"TextEmbedder input: \", batch[\"input_ids\"])\n", "print(\"TextEmbedder output shape: \", text_embedder_output.shape)" ] }, From c9c57900ae6c293c1367f01bfc2a5efb32489868 Mon Sep 17 00:00:00 2001 From: meilame-tayebjee Date: Mon, 24 Nov 2025 18:03:22 +0000 Subject: [PATCH 6/7] fix: problem of type with BCEwithLogitsLoss --- torchTextClassifiers/dataset/dataset.py | 2 +- torchTextClassifiers/model/lightning.py | 4 ++++ 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/torchTextClassifiers/dataset/dataset.py b/torchTextClassifiers/dataset/dataset.py index 4892f7c..c4f6a83 100644 --- a/torchTextClassifiers/dataset/dataset.py +++ b/torchTextClassifiers/dataset/dataset.py @@ -100,7 +100,7 @@ def collate_fn(self, batch): labels_tensor[rows, cols] = 1 else: - labels_tensor = torch.tensor(labels, dtype=torch.long) + labels_tensor = torch.tensor(labels) else: labels_tensor = None diff --git a/torchTextClassifiers/model/lightning.py b/torchTextClassifiers/model/lightning.py index e432082..ac94eff 100644 --- a/torchTextClassifiers/model/lightning.py +++ b/torchTextClassifiers/model/lightning.py @@ -76,6 +76,10 @@ def training_step(self, batch, batch_idx: int) -> torch.Tensor: targets = batch["labels"] outputs = self.forward(batch) + + if isinstance(self.loss, torch.nn.BCEWithLogitsLoss): + targets = targets.float() + loss = self.loss(outputs, targets) self.log("train_loss", loss, on_epoch=True, on_step=True, prog_bar=True) accuracy = self.accuracy_fn(outputs, targets) From b923c8d6e6cf1a6d36e70dc85e6bbe8ec62b2b91 Mon Sep 17 00:00:00 2001 From: meilame-tayebjee Date: Mon, 24 Nov 2025 18:03:51 +0000 Subject: [PATCH 7/7] chore: X_val, y_val are optional --- torchTextClassifiers/torchTextClassifiers.py | 56 ++++++++++++-------- 1 file changed, 33 insertions(+), 23 deletions(-) diff --git a/torchTextClassifiers/torchTextClassifiers.py b/torchTextClassifiers/torchTextClassifiers.py index 0aa2070..fb14fcb 100644 --- a/torchTextClassifiers/torchTextClassifiers.py +++ b/torchTextClassifiers/torchTextClassifiers.py @@ -184,9 +184,9 @@ def train( self, X_train: np.ndarray, y_train: np.ndarray, - X_val: np.ndarray, - y_val: np.ndarray, training_config: TrainingConfig, + X_val: Optional[np.ndarray] = None, + y_val: Optional[np.ndarray] = None, verbose: bool = False, ) -> None: """Train the classifier using PyTorch Lightning. @@ -224,7 +224,14 @@ def train( """ # Input validation X_train, y_train = self._check_XY(X_train, y_train) - X_val, y_val = self._check_XY(X_val, y_val) + + if X_val is not None: + assert y_val is not None, "y_val must be provided if X_val is provided." + if y_val is not None: + assert X_val is not None, "X_val must be provided if y_val is provided." + + if X_val is not None and y_val is not None: + X_val, y_val = self._check_XY(X_val, y_val) if ( X_train["categorical_variables"] is not None @@ -277,40 +284,43 @@ def train( texts=X_train["text"], categorical_variables=X_train["categorical_variables"], # None if no cat vars tokenizer=self.tokenizer, - labels=y_train, - ragged_multilabel=self.ragged_multilabel, - ) - val_dataset = TextClassificationDataset( - texts=X_val["text"], - categorical_variables=X_val["categorical_variables"], # None if no cat vars - tokenizer=self.tokenizer, - labels=y_val, + labels=y_train.tolist(), ragged_multilabel=self.ragged_multilabel, ) - train_dataloader = train_dataset.create_dataloader( batch_size=training_config.batch_size, num_workers=training_config.num_workers, shuffle=True, **training_config.dataloader_params if training_config.dataloader_params else {}, ) - val_dataloader = val_dataset.create_dataloader( - batch_size=training_config.batch_size, - num_workers=training_config.num_workers, - shuffle=False, - **training_config.dataloader_params if training_config.dataloader_params else {}, - ) + + if X_val is not None and y_val is not None: + val_dataset = TextClassificationDataset( + texts=X_val["text"], + categorical_variables=X_val["categorical_variables"], # None if no cat vars + tokenizer=self.tokenizer, + labels=y_val, + ragged_multilabel=self.ragged_multilabel, + ) + val_dataloader = val_dataset.create_dataloader( + batch_size=training_config.batch_size, + num_workers=training_config.num_workers, + shuffle=False, + **training_config.dataloader_params if training_config.dataloader_params else {}, + ) + else: + val_dataloader = None # Setup trainer callbacks = [ ModelCheckpoint( - monitor="val_loss", + monitor="val_loss" if val_dataloader is not None else "train_loss", save_top_k=1, save_last=False, mode="min", ), EarlyStopping( - monitor="val_loss", + monitor="val_loss" if val_dataloader is not None else "train_loss", patience=training_config.patience_early_stopping, mode="min", ), @@ -442,9 +452,9 @@ def _check_Y(self, Y): else: assert isinstance(Y, np.ndarray), "Y must be a numpy array of shape (N,) or (N,1)." - assert len(Y.shape) == 1 or ( - len(Y.shape) == 2 and Y.shape[1] == 1 - ), "Y must be a numpy array of shape (N,) or (N,1)." + assert ( + len(Y.shape) == 1 or len(Y.shape) == 2 + ), "Y must be a numpy array of shape (N,) or (N, num_labels)." try: Y = Y.astype(int)