From 79b946525d0c285c0a4783d21069b02747d6848c Mon Sep 17 00:00:00 2001 From: 4r14nn4 Date: Thu, 27 Nov 2025 17:47:29 +0000 Subject: [PATCH 1/2] Add celeba (WIP) --- conceptarium/conf/dataset/_commons.yaml | 2 + conceptarium/conf/dataset/celeba.yaml | 18 ++ torch_concepts/data/backbone.py | 72 +++++- torch_concepts/data/datamodules/celeba.py | 83 ++++++ torch_concepts/data/datasets/__init__.py | 2 + torch_concepts/data/datasets/celeba.py | 293 +++++++++++++++++----- torch_concepts/data/splitters/__init__.py | 2 + torch_concepts/data/splitters/standard.py | 103 ++++++++ 8 files changed, 508 insertions(+), 67 deletions(-) create mode 100644 conceptarium/conf/dataset/celeba.yaml create mode 100644 torch_concepts/data/datamodules/celeba.py create mode 100644 torch_concepts/data/splitters/standard.py diff --git a/conceptarium/conf/dataset/_commons.yaml b/conceptarium/conf/dataset/_commons.yaml index a27a2e2..9562a9a 100644 --- a/conceptarium/conf/dataset/_commons.yaml +++ b/conceptarium/conf/dataset/_commons.yaml @@ -1,5 +1,7 @@ batch_size: 512 +seed: ${seed} + val_size: 0.1 test_size: 0.2 diff --git a/conceptarium/conf/dataset/celeba.yaml b/conceptarium/conf/dataset/celeba.yaml new file mode 100644 index 0000000..0944d83 --- /dev/null +++ b/conceptarium/conf/dataset/celeba.yaml @@ -0,0 +1,18 @@ +defaults: + - _commons + - _self_ + +_target_: torch_concepts.data.datamodules.celeba.CelebADataModule + +name: celeba + +backbone: resnet18 +precompute_embs: true +force_recompute: false + +# Task label - which CelebA attribute to predict +task_label: [Attractive] + +# all CelebA attributes are binary facial features +# label_descriptions can be added here if needed +label_descriptions: null \ No newline at end of file diff --git a/torch_concepts/data/backbone.py b/torch_concepts/data/backbone.py index 86ec3c5..de34df7 100644 --- a/torch_concepts/data/backbone.py +++ b/torch_concepts/data/backbone.py @@ -12,9 +12,57 @@ logger = logging.getLogger(__name__) +def choose_backbone(name: str): + """Choose a backbone model by name. + + Args: + name (str): Name of the backbone model (e.g., 'resnet18', 'vit_b_16'). + + Returns: + tuple: (backbone model, transforms) - The backbone model and its preprocessing transforms. + + Raises: + ValueError: If the backbone name is not recognized. + + Example: + >>> backbone, transforms = choose_backbone('resnet18') + >>> print(backbone) + ResNet(...) 
+ """ + from torchvision.models import ( + resnet18, resnet50, vit_b_16, vit_l_16, + ResNet18_Weights, ResNet50_Weights, + ViT_B_16_Weights, ViT_L_16_Weights + ) + + if name == 'resnet18': + weights = ResNet18_Weights.DEFAULT + model = resnet18(weights=weights) + transforms = weights.transforms() + backbone = nn.Sequential(*list(model.children())[:-1]) # Remove final FC layer + elif name == 'resnet50': + weights = ResNet50_Weights.DEFAULT + model = resnet50(weights=weights) + transforms = weights.transforms() + backbone = nn.Sequential(*list(model.children())[:-1]) + elif name == 'vit_b_16': + weights = ViT_B_16_Weights.DEFAULT + model = vit_b_16(weights=weights) + transforms = weights.transforms() + backbone = nn.Sequential(*list(model.children())[:-1]) + elif name == 'vit_l_16': + weights = ViT_L_16_Weights.DEFAULT + model = vit_l_16(weights=weights) + transforms = weights.transforms() + backbone = nn.Sequential(*list(model.children())[:-1]) + else: + raise ValueError(f"Backbone '{name}' is not recognized.") + + return backbone, transforms + def compute_backbone_embs( dataset, - backbone: nn.Module, + backbone: str, batch_size: int = 512, workers: int = 0, device: str = None, @@ -28,7 +76,7 @@ def compute_backbone_embs( Args: dataset: Dataset with __getitem__ returning dict with 'x' key or 'inputs'.'x' nested key. - backbone (nn.Module): Feature extraction model (e.g., ResNet encoder). + backbone (str): Backbone model name for feature extraction (e.g., 'resnet18'). batch_size (int, optional): Batch size for processing. Defaults to 512. workers (int, optional): Number of DataLoader workers. Defaults to 0. device (str, optional): Device to use ('cpu', 'cuda', 'cuda:0', etc.). @@ -52,11 +100,12 @@ def compute_backbone_embs( device = torch.device(device) # Store original training state to restore later - was_training = backbone.training + #was_training = backbone.training # Move backbone to device and set to eval mode - backbone = backbone.to(device) - backbone.eval() + backbone_model, transforms = choose_backbone(backbone) + backbone_model = backbone_model.to(device) + backbone_model.eval() # Create dataloader dataloader = DataLoader( @@ -78,20 +127,21 @@ def compute_backbone_embs( x = batch['inputs']['x'].to(device) else: x = batch['x'].to(device) - embeddings = backbone(x) # Forward pass through backbone + + embeddings = backbone_model(transforms(x)) # Forward pass through backbone embeddings_list.append(embeddings.cpu()) # Move back to CPU and store all_embeddings = torch.cat(embeddings_list, dim=0) # Concatenate all embeddings # Restore original training state - if was_training: - backbone.train() + #if was_training: + # backbone.train() return all_embeddings def get_backbone_embs(path: str, dataset, - backbone, + backbone: str, batch_size, force_recompute=False, workers=0, @@ -105,7 +155,7 @@ def get_backbone_embs(path: str, Args: path (str): File path for saving/loading embeddings (.pt file). dataset: Dataset to extract embeddings from. - backbone: Backbone model for feature extraction. + backbone: Backbone model name for feature extraction. batch_size: Batch size for computation. force_recompute (bool, optional): Recompute even if cached. Defaults to False. workers (int, optional): Number of DataLoader workers. Defaults to 0. 
@@ -130,7 +180,7 @@ def get_backbone_embs(path: str,
     if not os.path.exists(path) or force_recompute:
         # compute
         embs = compute_backbone_embs(dataset,
-                                     backbone,
+                                     backbone=backbone,
                                      batch_size=batch_size,
                                      workers=workers,
                                      device=device,
diff --git a/torch_concepts/data/datamodules/celeba.py b/torch_concepts/data/datamodules/celeba.py
new file mode 100644
index 0000000..1d56855
--- /dev/null
+++ b/torch_concepts/data/datamodules/celeba.py
@@ -0,0 +1,83 @@
+from ..datasets import CelebADataset
+
+from ..base.datamodule import ConceptDataModule
+from ...typing import BackboneType
+from ..splitters import StandardSplitter, RandomSplitter
+
+
+class CelebADataModule(ConceptDataModule):
+    """DataModule for the CelebA dataset.
+
+    Handles data loading, splitting, and batching for CelebA
+    with support for concept-based learning.
+
+    Args:
+        seed: Random seed for reproducibility.
+        name: Dataset identifier (default: 'celeba').
+        root: Root directory for the dataset.
+        val_size: Validation set size (fraction or absolute count).
+        test_size: Test set size (fraction or absolute count).
+        ftune_size: Fine-tuning set size (fraction or absolute count).
+        ftune_val_size: Fine-tuning validation set size (fraction or absolute count).
+        batch_size: Batch size for dataloaders.
+        backbone: Model backbone to use (if applicable).
+        precompute_embs: Whether to precompute backbone embeddings.
+        force_recompute: Whether to recompute cached embeddings.
+        task_label: List of attributes to use as task labels.
+        concept_subset: Subset of concepts to use. If None, uses all concepts.
+        label_descriptions: Dictionary mapping concept names to descriptions.
+        splitter: Splitting strategy: 'standard' uses the official splits,
+            any other value falls back to a random split.
+        workers: Number of workers for dataloaders.
+        DATA_ROOT: Root directory for data storage.
+    """
+
+    def __init__(
+        self,
+        seed: int,  # seed for reproducibility
+        name: str,  # dataset identifier
+        root: str,  # root directory for dataset
+        val_size: int | float = 0.1,
+        test_size: int | float = 0.2,
+        ftune_size: int | float = 0.0,
+        ftune_val_size: int | float = 0.0,
+        batch_size: int = 512,
+        backbone: BackboneType = None,
+        precompute_embs: bool = True,
+        force_recompute: bool = False,
+        task_label: list | None = None,
+        concept_subset: list | None = None,
+        label_descriptions: dict | None = None,
+        splitter: str = "standard",
+        workers: int = 0,
+        DATA_ROOT = None,
+        **kwargs
+    ):
+
+        dataset = CelebADataset(
+            name=name,
+            root=root,
+            transform=None,
+            task_label=task_label,
+            class_attributes=task_label,
+            concept_subset=concept_subset,
+            label_descriptions=label_descriptions
+        )
+
+        # choose the splitting strategy
+        if splitter == "standard":
+            splitter = StandardSplitter()
+        else:
+            splitter = RandomSplitter(
+                val_size=val_size,
+                test_size=test_size
+            )
+
+        super().__init__(
+            dataset=dataset,
+            val_size=val_size,
+            test_size=test_size,
+            batch_size=batch_size,
+            backbone=backbone,
+            precompute_embs=precompute_embs,
+            force_recompute=force_recompute,
+            workers=workers,
+            splitter=splitter
+        )
diff --git a/torch_concepts/data/datasets/__init__.py b/torch_concepts/data/datasets/__init__.py
index d194122..2696e22 100644
--- a/torch_concepts/data/datasets/__init__.py
+++ b/torch_concepts/data/datasets/__init__.py
@@ -1,9 +1,11 @@
 from .bnlearn import BnLearnDataset
 from .toy import ToyDataset, CompletenessDataset
+from .celeba import CelebADataset
 
 
 __all__: list[str] = [
     "BnLearnDataset",
     "ToyDataset",
     "CompletenessDataset",
+    "CelebADataset",
 ]
diff --git a/torch_concepts/data/datasets/celeba.py b/torch_concepts/data/datasets/celeba.py
index 0273e73..635e01c 100644
--- a/torch_concepts/data/datasets/celeba.py
+++ b/torch_concepts/data/datasets/celeba.py
@@ -1,72 +1,253 @@
-
+import os
 import torch
+import pandas as pd
+import numpy as np
+import logging
+from typing import List, Optional, Union
+from tqdm import tqdm
+from datasets import load_dataset
+from torchvision.transforms import Compose
+from torch_concepts import Annotations, AxisAnnotation
+from torch_concepts.data.base import ConceptDataset
+from glob import glob
 
-from torchvision.datasets import CelebA
-from typing import List
-
-
-class CelebADataset(CelebA):
-    """
-    The CelebA dataset is a large-scale face attributes dataset with more than
-    200K celebrity images, each with 40 attribute annotations. This class
-    extends the CelebA dataset to extract concept and task attributes based on
-    class attributes.
-
-    The dataset can be downloaded from the official
-    website: http://mmlab.ie.cuhk.edu.hk/projects/CelebA.html.
+logger = logging.getLogger(__name__)
 
-    Attributes:
-        root: The root directory where the dataset is stored.
-        split: The split of the dataset to use. Default is 'train'.
+
+class CelebADataset(ConceptDataset):
+    """Dataset class for CelebA.
+
+    CelebA is a large-scale face attributes dataset with more than 200K
+    celebrity images, each with 40 binary attribute annotations. This class
+    loads the data from the HuggingFace Hub (flwrlabs/celeba) and exposes it
+    through the ConceptDataset framework. The dataset is also available from
+    the official website: http://mmlab.ie.cuhk.edu.hk/projects/CelebA.html.
+
+    Args:
+        name: Dataset identifier, also used to build a default root path.
+        root: Root directory where the dataset is stored or will be downloaded.
         transform: The transformations to apply to the images. Default is None.
-        download: Whether to download the dataset if it does not exist. Default
-            is False.
-        class_attributes: The class attributes to use for the task. Default is
-            None.
+        task_label: The attribute(s) to use for the task. Default is 'Attractive'.
+        class_attributes: Alias for task_label; takes precedence if both are given.
+        concept_subset: Optional subset of concept labels to use.
+        label_descriptions: Optional dict mapping concept names to descriptions.
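+
+    Example:
+        A minimal usage sketch (the root path and attribute name are
+        illustrative; aligned CelebA images are 178x218 RGB and samples are
+        returned as CHW float tensors in [0, 1]):
+
+        >>> ds = CelebADataset(name='celeba', root='./data/celeba',
+        ...                    task_label=['Smiling'])
+        >>> sample = ds[0]
+        >>> sample['inputs']['x'].shape
+        torch.Size([3, 218, 178])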
""" + def __init__( - self, root: str, split: str = 'train', - transform = None, - download: bool = False, - class_attributes: List[str] = None, + self, + name: str, + root: str, + transform: Union[Compose, torch.nn.Module] = None, + task_label: Optional[List[str]] = None, + class_attributes: Optional[List[str]] = None, # Alias for task_label + concept_subset: Optional[list] = None, + label_descriptions: Optional[dict] = None, ): - super(CelebADataset, self).__init__( - root, - split=split, - target_type="attr", - transform=transform, - download=download, - ) + self.name = name + self.transform = transform - # Set the class attributes - if class_attributes is None: - # Default to 'Attractive' if no class_attributes provided - self.class_idx = [self.attr_names.index('Attractive')] + # If root is not provided, create a local folder automatically + if root is None: + root = os.path.join(os.getcwd(), 'data', self.name) + + self.root = root + + # Support both task_label and class_attributes (class_attributes takes precedence) + if class_attributes is not None: + self.task_label = class_attributes if isinstance(class_attributes, list) else [class_attributes] + elif task_label is not None: + self.task_label = task_label if isinstance(task_label, list) else [task_label] else: - # Use the provided class attributes - self.class_idx = [ - self.attr_names.index(attr) for attr in class_attributes - ] + self.task_label = ['Attractive'] + + self.label_descriptions = label_descriptions + + # These will be set during build/load + self.concept_attr_names = [] + self.task_attr_names = [] + + # Load data and annotations + input_data, concepts, annotations, graph = self.load() + + # Initialize parent class + super().__init__( + input_data=input_data, + concepts=concepts, + annotations=annotations, + graph=graph, + concept_names_subset=concept_subset, + ) + + @property + def raw_filenames(self) -> List[str]: + """List of raw filenames that must be present to skip downloading.""" + # find the directory of downloaded data (if any) + path_base_file = os.path.join(self.root, "flwrlabs___celeba", "*", "*", "*", "dataset_info.json") + matches = glob(path_base_file) + if len(matches)==0: + return ["__nonexistent_file__"] + d = os.path.dirname(matches[0]) + + # eliminate self.root (it is added by default later). 
+        d = d.replace(self.root + "/", "")
+        base_file = matches[0].replace(self.root + "/", "")
+
+        n_train_files = 19
+        n_valid_files = 3
+        n_test_files = 3
+
+        # shard names follow the HuggingFace convention,
+        # e.g. celeba-train-00000-of-00019.arrow
+        train_files = [
+            os.path.join(d, f"celeba-train-{i:05d}-of-{n_train_files:05d}.arrow")
+            for i in range(n_train_files)
+        ]
+        valid_files = [
+            os.path.join(d, f"celeba-valid-{i:05d}-of-{n_valid_files:05d}.arrow")
+            for i in range(n_valid_files)
+        ]
+        test_files = [
+            os.path.join(d, f"celeba-test-{i:05d}-of-{n_test_files:05d}.arrow")
+            for i in range(n_test_files)
+        ]
 
-        self.attr_names = [string for string in self.attr_names if string]
+        return [base_file] + train_files + valid_files + test_files
 
-        # Determine concept and task attribute names based on class attributes
-        self.concept_attr_names = [
-            attr for i, attr in enumerate(self.attr_names)
-            if i not in self.class_idx
+
+    @property
+    def processed_filenames(self) -> List[str]:
+        """List of processed filenames that will be created during the build step."""
+        return [
+            "images.pt",
+            "concepts.h5",
+            "annotations.pt",
+            "split_mapping.h5",
         ]
-        self.task_attr_names = [self.attr_names[i] for i in self.class_idx]
+
+    def download(self):
+        """Download raw data files from HuggingFace and save them to the root directory."""
+        logger.info("Downloading CelebA dataset from HuggingFace...")
+        load_dataset(
+            "flwrlabs/celeba",
+            cache_dir=self.root
+        )
+        logger.info(f"CelebA dataset downloaded and saved to {self.root}.")
+
+    def build(self):
+        """Build the processed dataset from all splits (train, valid, test) concatenated."""
+        self.maybe_download()
+
+        # --- Load data ---
+        logger.info(f"Building CelebA dataset from raw files in {self.root_dir}...")
+        ds = load_dataset(path=self.root)
+
+        # --- Construct input_data, concepts, annotations ---
+        # construct concept_names
+        concept_names = list(ds['train'].features.keys())
+        concept_names.remove('image')
+        concept_names.remove('celeb_id')
+
+        # process each split; the split keys below are the ones exposed by
+        # flwrlabs/celeba (cf. the shard names in raw_filenames), and
+        # StandardSplitter relies on these exact names in the split mapping
+        split_indices = []
+        all_images = []
+        all_concepts_df = pd.DataFrame()
+        for split_name in ['train', 'valid', 'test']:
+            split_dataset = ds[split_name]
+            corrupted_images = []
+            # extract images
+            for idx in tqdm(range(len(split_dataset)), desc=f"  {split_name} images", unit="img"):
+                try:
+                    img = split_dataset[idx]['image']
+                    arr = np.array(img)
+                    all_images.append(torch.from_numpy(arr))
+                except OSError as e:
+                    logger.warning(f"Skipping image at index {idx} in split {split_name} due to error: {e}")
+                    corrupted_images.append(idx)
+
+            if corrupted_images:
+                logger.warning(f"Skipping {len(corrupted_images)} corrupted images in split {split_name}.")
+                # remove corrupted indices from split_dataset so concepts stay aligned with images
+                split_dataset = split_dataset.select([i for i in range(len(split_dataset)) if i not in corrupted_images])
+
+            # extract split_indices
+            split_indices.extend([split_name] * len(split_dataset))
+
+            # extract concepts for this split
+            split_concepts = split_dataset.to_pandas()[concept_names]
+            all_concepts_df = pd.concat([all_concepts_df, split_concepts], ignore_index=True)
+
+        assert all_concepts_df.columns.tolist() == concept_names, "Concept names do not match."
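+
+        # NOTE: everything is kept in memory and serialized as single files;
+        # at roughly 218*178*3 bytes per image, the full ~200K-image dataset
+        # needs on the order of 20 GB of RAM here (rough estimate).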
+
+        # combine all data
+        input_data = torch.stack(all_images)
+        logger.info(f"Input data shape: {input_data.shape}")
+        concepts = all_concepts_df.astype(int)
+        logger.info(f"Concepts shape: {concepts.shape}")
+
+        # create annotations (all CelebA attributes are binary, hence cardinality 2)
+        cardinalities = tuple([2] * len(concept_names))
+        annotations = Annotations({
+            1: AxisAnnotation(
+                labels=concept_names,
+                cardinalities=cardinalities,
+                metadata={name: {'type': 'discrete'} for name in concept_names}
+            )
+        })
+
+        # --- Save processed data ---
+        logger.info(f"Saving images, concepts, annotations and split mapping to {self.root}")
+        torch.save(input_data, self.processed_paths[0])
+        concepts.to_hdf(self.processed_paths[1], key="concepts", mode="w")
+        torch.save(annotations, self.processed_paths[2])
+        # split_mapping.h5 is the fourth entry of processed_filenames, i.e. index 3
+        pd.Series(split_indices).to_hdf(self.processed_paths[3], key="split_mapping", mode="w")
+
+    def load_raw(self):
+        """Load the processed files (all splits concatenated)."""
+        self.maybe_build()  # Ensures build() is called if needed
+
+        logger.info(f"Loading dataset from {self.root_dir}")
+        input_data = torch.load(self.processed_paths[0])
+        concepts = pd.read_hdf(self.processed_paths[1], "concepts")
+        annotations = torch.load(self.processed_paths[2])
+        graph = None
+
+        return input_data, concepts, annotations, graph
+
+    def load(self):
+        """Load and optionally preprocess the dataset."""
+        inputs, concepts, annotations, graph = self.load_raw()
+
+        # Add any additional preprocessing here if needed
+        # For most cases, just return raw data
+
+        return inputs, concepts, annotations, graph
+
+    def __getitem__(self, item):
+        """
+        Get a single sample from the dataset.
+
+        Args:
+            item (int): Index of the sample to retrieve.
 
-    def __getitem__(self, index: int):
-        image, attributes = super(CelebADataset, self).__getitem__(index)
+        Returns:
+            dict: Dictionary containing 'inputs' and 'concepts' sub-dictionaries.
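+
+        Example (shapes are illustrative, cf. the class docstring):
+            >>> sample = dataset[0]
+            >>> sorted(sample.keys())
+            ['concepts', 'inputs']
+            >>> sample['inputs']['x'].dtype
+            torch.float32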
+        """
+        # Get raw input data and concepts
+        x = self.input_data[item]
+        x = x.permute(2, 0, 1).float() / 255.0  # HWC uint8 -> CHW float in [0, 1]
+        c = self.concepts[item]
 
-        # Extract the target (y) based on the class index
-        y = torch.stack([attributes[i] for i in self.class_idx])
+        # TODO: handle missing values with masks
 
-        # Extract concept attributes, excluding the class attributes
-        concept_attributes = torch.stack([
-            attributes[i] for i in range(len(attributes))
-            if i not in self.class_idx
-        ])
+        # Create sample dictionary
+        sample = {
+            'inputs': {'x': x},  # input data: multiple inputs can be stored in a dict
+            'concepts': {'c': c},  # concepts: multiple concepts can be stored in a dict
+            # TODO: add scalers when these are set
+            # also check if batch transforms work correctly inside the model training loop
+            # 'transforms': {'x': self.scalers.get('input', None),
+            #                'c': self.scalers.get('concepts', None)}
+        }
 
-        return image, concept_attributes, y
+        return sample
\ No newline at end of file
diff --git a/torch_concepts/data/splitters/__init__.py b/torch_concepts/data/splitters/__init__.py
index 6d68c58..ddaca24 100644
--- a/torch_concepts/data/splitters/__init__.py
+++ b/torch_concepts/data/splitters/__init__.py
@@ -1,8 +1,10 @@
 from .random import RandomSplitter
 from .coloring import ColoringSplitter
+from .standard import StandardSplitter
 
 
 __all__: list[str] = [
     "RandomSplitter",
     "ColoringSplitter",
+    "StandardSplitter",
 ]
diff --git a/torch_concepts/data/splitters/standard.py b/torch_concepts/data/splitters/standard.py
new file mode 100644
index 0000000..0095517
--- /dev/null
+++ b/torch_concepts/data/splitters/standard.py
@@ -0,0 +1,103 @@
+"""Standard data splitting for train/validation/test splits.
+
+This module provides StandardSplitter for dividing datasets into the
+standard train/val/test splits provided by the dataset authors.
+"""
+
+import logging
+
+import pandas as pd
+
+from ..base.dataset import ConceptDataset
+from ..base.splitter import Splitter
+
+logger = logging.getLogger(__name__)
+
+
+class StandardSplitter(Splitter):
+    """Standard splitting strategy for datasets.
+
+    Divides a dataset into train, validation, and test splits using the
+    predefined split mapping shipped with the dataset. No randomness is
+    involved: the split sizes are fixed by the dataset authors.
+
+    Example:
+        >>> # sizes are determined by the dataset's predefined splits
+        >>> splitter = StandardSplitter()
+        >>> splitter.fit(dataset)
+        >>> print(f"Train: {splitter.train_len}, Val: {splitter.val_len}, Test: {splitter.test_len}")
+        Train: 700, Val: 100, Test: 200
+    """
+
+    def __init__(
+        self
+    ):
+        """Initialize the StandardSplitter.
+
+        The splitter takes no arguments: split membership is read from the
+        dataset's stored split mapping in fit().
+        """
+        super().__init__()
+
+    def fit(self, dataset: ConceptDataset) -> None:
+        """Split the dataset into train/val/test sets based on standard splits.
+
+        Args:
+            dataset: The ConceptDataset to split.
+
+        Raises:
+            ValueError: If the dataset does not provide standard splits.
+        """
+
+        # Load standard splits from dataset if available
+        if any("split_mapping" in path for path in dataset.processed_paths):
+            split_series = pd.read_hdf(
+                next(path for path in dataset.processed_paths if "split_mapping" in path), key="split_mapping"
+            )
+            # split names must match those stored by the dataset's build step
+            # (e.g. CelebADataset stores 'train', 'valid' and 'test')
+            train_idxs = split_series[split_series == "train"].index.tolist()
+            val_idxs = split_series[split_series == "valid"].index.tolist()
+            test_idxs = split_series[split_series == "test"].index.tolist()
+
+            # Store indices
+            self.set_indices(
+                train=train_idxs,
+                val=val_idxs,
+                test=test_idxs
+            )
+
+            self._fitted = True
+
+            logger.info("StandardSplitter uses predefined splits provided by the dataset authors. "
+                        f"Train size: {self.train_len}, "
+                        f"Val size: {self.val_len}, "
+                        f"Test size: {self.test_len}")
+        else:
+            raise ValueError("Dataset does not provide standard splits.")
+
+    def __repr__(self) -> str:
+        return (
+            f"{self.__class__.__name__}("
+            f"train_size={self.train_len}, "
+            f"val_size={self.val_len}, "
+            f"test_size={self.test_len})"
+        )

From cf2a096eeadcdc4752edf66fd46c520bc8725c06 Mon Sep 17 00:00:00 2001
From: edogab33
Date: Fri, 28 Nov 2025 18:23:16 +0100
Subject: [PATCH 2/2] Clean backbone code

---
 torch_concepts/data/backbone.py | 74 ++++-----------------------------
 1 file changed, 7 insertions(+), 67 deletions(-)

diff --git a/torch_concepts/data/backbone.py b/torch_concepts/data/backbone.py
index de34df7..a1fea4e 100644
--- a/torch_concepts/data/backbone.py
+++ b/torch_concepts/data/backbone.py
@@ -6,59 +6,12 @@
 import os
 import torch
 import logging
-from torch import nn
 from torch.utils.data import DataLoader
+from torchvision.models import get_model, get_model_weights
 from tqdm import tqdm
 
-logger = logging.getLogger(__name__)
 
-def choose_backbone(name: str):
-    """Choose a backbone model by name.
-
-    Args:
-        name (str): Name of the backbone model (e.g., 'resnet18', 'vit_b_16').
-
-    Returns:
-        tuple: (backbone model, transforms) - The backbone model and its preprocessing transforms.
-
-    Raises:
-        ValueError: If the backbone name is not recognized.
-
-    Example:
-        >>> backbone, transforms = choose_backbone('resnet18')
-        >>> print(backbone)
-        ResNet(...)
-    """
-    from torchvision.models import (
-        resnet18, resnet50, vit_b_16, vit_l_16,
-        ResNet18_Weights, ResNet50_Weights,
-        ViT_B_16_Weights, ViT_L_16_Weights
-    )
-
-    if name == 'resnet18':
-        weights = ResNet18_Weights.DEFAULT
-        model = resnet18(weights=weights)
-        transforms = weights.transforms()
-        backbone = nn.Sequential(*list(model.children())[:-1])  # Remove final FC layer
-    elif name == 'resnet50':
-        weights = ResNet50_Weights.DEFAULT
-        model = resnet50(weights=weights)
-        transforms = weights.transforms()
-        backbone = nn.Sequential(*list(model.children())[:-1])
-    elif name == 'vit_b_16':
-        weights = ViT_B_16_Weights.DEFAULT
-        model = vit_b_16(weights=weights)
-        transforms = weights.transforms()
-        backbone = nn.Sequential(*list(model.children())[:-1])
-    elif name == 'vit_l_16':
-        weights = ViT_L_16_Weights.DEFAULT
-        model = vit_l_16(weights=weights)
-        transforms = weights.transforms()
-        backbone = nn.Sequential(*list(model.children())[:-1])
-    else:
-        raise ValueError(f"Backbone '{name}' is not recognized.")
-
-    return backbone, transforms
+logger = logging.getLogger(__name__)
 
 def compute_backbone_embs(
     dataset,
@@ -99,13 +52,9 @@ def compute_backbone_embs(
         device = 'cuda' if torch.cuda.is_available() else 'cpu'
     device = torch.device(device)
 
-    # Store original training state to restore later
-    #was_training = backbone.training
-
-    # Move backbone to device and set to eval mode
-    backbone_model, transforms = choose_backbone(backbone)
-    backbone_model = backbone_model.to(device)
-    backbone_model.eval()
+    weights = get_model_weights(backbone).DEFAULT  # best available weights for this architecture
+    backbone_model = get_model(backbone, weights=weights).to(device).eval()
+    # NOTE: get_model returns the full network (classification head included)
+    preprocess = weights.transforms()
 
     # Create dataloader
     dataloader = DataLoader(
@@ -122,21 +71,12 @@ def compute_backbone_embs(
     with torch.no_grad():
         iterator = tqdm(dataloader, desc="Extracting embeddings") if verbose else dataloader
         for batch in iterator:
-            # Handle both {'x': tensor} and {'inputs': {'x': tensor}} structures
-            if 'inputs' in batch:
-                x = batch['inputs']['x'].to(device)
-            else:
-                x = batch['x'].to(device)
-
-            embeddings = backbone_model(transforms(x))  # Forward pass through backbone
+            x = batch['inputs']['x'].to(device)
+            embeddings = backbone_model(preprocess(x))  # Forward pass through backbone
             embeddings_list.append(embeddings.cpu())  # Move back to CPU and store
 
     all_embeddings = torch.cat(embeddings_list, dim=0)  # Concatenate all embeddings
 
-    # Restore original training state
-    #if was_training:
-    #    backbone.train()
-
     return all_embeddings
 
 def get_backbone_embs(path: str,