diff --git a/torchTextClassifiers/model/components/classification_head.py b/torchTextClassifiers/model/components/classification_head.py
index 1c0d506..bd21f52 100644
--- a/torchTextClassifiers/model/components/classification_head.py
+++ b/torchTextClassifiers/model/components/classification_head.py
@@ -24,6 +24,8 @@ def __init__(
         """
         super().__init__()
         if net is not None:
+            self.net = net
+
             # --- Custom net should either be a Sequential or a Linear ---
             if not (isinstance(net, nn.Sequential) or isinstance(net, nn.Linear)):
-                raise ValueError("net must be an nn.Sequential when provided.")
+                raise ValueError("net must be an nn.Sequential or nn.Linear when provided.")
@@ -43,7 +45,6 @@ def __init__(
                 # --- Extract features ---
                 self.input_dim = first.in_features
                 self.num_classes = last.out_features
-                self.net = net
             else:  # if not Sequential, it is a Linear
                 self.input_dim = net.in_features
                 self.num_classes = net.out_features
@@ -53,23 +54,8 @@ def __init__(
                 input_dim is not None and num_classes is not None
             ), "Either net or both input_dim and num_classes must be provided."
             self.net = nn.Linear(input_dim, num_classes)
-            self.input_dim, self.num_classes = self._get_linear_input_output_dims(self.net)
+            self.input_dim = input_dim
+            self.num_classes = num_classes
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         return self.net(x)
-
-    @staticmethod
-    def _get_linear_input_output_dims(module: nn.Module):
-        """
-        Returns (input_dim, output_dim) for any module containing Linear layers.
-        Works for Linear, Sequential, or nested models.
-        """
-        # Collect all Linear layers recursively
-        linears = [m for m in module.modules() if isinstance(m, nn.Linear)]
-
-        if not linears:
-            raise ValueError("No Linear layers found in the given module.")
-
-        input_dim = linears[0].in_features
-        output_dim = linears[-1].out_features
-        return input_dim, output_dim
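
After this change, `__init__` has two construction paths: a user-supplied `net` (either `nn.Sequential` or `nn.Linear`) whose dimensions are read off its `Linear` layers, or a default `nn.Linear` built from explicit `input_dim`/`num_classes`. Below is a minimal, self-contained sketch of the post-change behavior. The class name `ClassificationHead`, the keyword arguments, and the way the first/last `Linear` layers are located inside a `Sequential` are assumptions inferred from the file name and hunk context; the diff only shows fragments of `__init__` and `forward`.

```python
import torch
from torch import nn


class ClassificationHead(nn.Module):
    """Sketch of the constructor after this change (class/argument names assumed)."""

    def __init__(self, net=None, input_dim=None, num_classes=None):
        super().__init__()
        if net is not None:
            self.net = net  # assigned up front, covering both branches below

            # --- Custom net should either be a Sequential or a Linear ---
            if not isinstance(net, (nn.Sequential, nn.Linear)):
                raise ValueError("net must be an nn.Sequential or nn.Linear when provided.")

            if isinstance(net, nn.Sequential):
                # Assumption: dims come from the first/last Linear layers;
                # the Sequential must contain at least one Linear.
                linears = [m for m in net.modules() if isinstance(m, nn.Linear)]
                self.input_dim = linears[0].in_features
                self.num_classes = linears[-1].out_features
            else:  # if not Sequential, it is a Linear
                self.input_dim = net.in_features
                self.num_classes = net.out_features
        else:
            assert (
                input_dim is not None and num_classes is not None
            ), "Either net or both input_dim and num_classes must be provided."
            self.net = nn.Linear(input_dim, num_classes)
            # Taken directly from the arguments instead of being re-derived
            # by the removed _get_linear_input_output_dims helper.
            self.input_dim = input_dim
            self.num_classes = num_classes

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.net(x)


# Both construction paths:
head_a = ClassificationHead(input_dim=128, num_classes=10)
head_b = ClassificationHead(net=nn.Sequential(nn.Linear(128, 64), nn.ReLU(), nn.Linear(64, 10)))
print(head_b(torch.randn(4, 128)).shape)  # torch.Size([4, 10])
```

Moving `self.net = net` to the top of the branch also means the plain-`Linear` path now sets `self.net`; in the pre-change hunks the assignment appears only inside the `Sequential` branch, so `forward` would plausibly have failed for a bare `nn.Linear`. Dropping `_get_linear_input_output_dims` avoids re-deriving dimensions the caller already supplied.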