From 81689ee40c2bb4719fb865c4fe7b4505ccac9044 Mon Sep 17 00:00:00 2001 From: meilame-tayebjee Date: Tue, 25 Nov 2025 15:17:39 +0000 Subject: [PATCH] fix(ClassifHead): enable Sequential as custom net --- .../model/components/classification_head.py | 38 +++++++++++++++++-- 1 file changed, 35 insertions(+), 3 deletions(-) diff --git a/torchTextClassifiers/model/components/classification_head.py b/torchTextClassifiers/model/components/classification_head.py index 4297d27..1c0d506 100644 --- a/torchTextClassifiers/model/components/classification_head.py +++ b/torchTextClassifiers/model/components/classification_head.py @@ -11,11 +11,43 @@ def __init__( num_classes: Optional[int] = None, net: Optional[nn.Module] = None, ): + """ + Classification head for text classification tasks. + It is a nn.Module that can either be a simple Linear layer or a custom neural network module. + + Args: + input_dim (int, optional): Dimension of the input features. Required if net is not provided. + num_classes (int, optional): Number of output classes. Required if net is not provided. + net (nn.Module, optional): Custom neural network module to be used as the classification head. + If provided, input_dim and num_classes are inferred from this module. + Should be either an nn.Sequential with first and last layers being Linears or nn.Linear. + """ super().__init__() if net is not None: - self.net = net - self.input_dim = net.in_features - self.num_classes = net.out_features + # --- Custom net should either be a Sequential or a Linear --- + if not (isinstance(net, nn.Sequential) or isinstance(net, nn.Linear)): + raise ValueError("net must be an nn.Sequential when provided.") + + # --- If Sequential, Check first and last layers are Linear --- + + if isinstance(net, nn.Sequential): + first = net[0] + last = net[-1] + + if not isinstance(first, nn.Linear): + raise TypeError(f"First layer must be nn.Linear, got {type(first).__name__}.") + + if not isinstance(last, nn.Linear): + raise TypeError(f"Last layer must be nn.Linear, got {type(last).__name__}.") + + # --- Extract features --- + self.input_dim = first.in_features + self.num_classes = last.out_features + self.net = net + else: # if not Sequential, it is a Linear + self.input_dim = net.in_features + self.num_classes = net.out_features + else: assert ( input_dim is not None and num_classes is not None