# Source code for shifthappens.data.imagenet

"""
This module provides functionality to access the predictions of a model for the ImageNet validation set.
Further, the predictions can be cached and loaded from the cache to reduce computational costs.
For path configuration, see :py:mod:`shifthappens.config`.
"""
import os
import shutil

import numpy as np
import torchvision.datasets as tv_datasets
import torchvision.transforms as tv_transforms

import shifthappens.config
import shifthappens.data.torch as sh_data_torch
from shifthappens.data.base import DataLoader


def _check_imagenet_folder():
    """
    Validate the configured ImageNet validation folder.

    Checks that ``shifthappens.config.imagenet_validation_path`` is set, points
    to an existing directory, and contains exactly 1000 sub-directories
    (one per ImageNet class).

    Raises:
        AssertionError: If any of the checks fails. The checks use explicit
            ``raise`` statements instead of ``assert`` so they still run when
            Python is started with ``-O``; the exception type is unchanged.
    """
    root = shifthappens.config.imagenet_validation_path
    if root is None:
        raise AssertionError(
            "shifthappens.config.imagenet_validation_path is not specified."
        )
    # isdir (not exists): a regular file at this path would otherwise pass and
    # make os.listdir below fail with NotADirectoryError.
    if not os.path.isdir(root):
        raise AssertionError(
            "You have specified an incorrect path to the ImageNet validation set. "
            "Files not found at location specified in shifthappens.config.imagenet_validation_path."
        )
    n_class_folders = sum(
        1 for entry in os.listdir(root) if os.path.isdir(os.path.join(root, entry))
    )
    if n_class_folders != 1000:
        raise AssertionError(
            f"{root} folder contains {n_class_folders} sub-folders, "
            "but the ImageNet validation set has 1000 classes."
        )


def get_imagenet_validation_loader(max_batch_size=128) -> DataLoader:
    """
    Build a :py:class:`shifthappens.data.base.DataLoader` over the ImageNet validation set.

    The location of the validation set must be configured through
    :py:attr:`shifthappens.config.imagenet_validation_path <shifthappens.config.Config.imagenet_validation_path>`.

    Args:
        max_batch_size: Upper bound on the number of samples per batch.

    Returns:
        Data loader over the ImageNet validation images.
    """
    _check_imagenet_folder()

    def _chw_to_hwc(image):
        # ToTensor produces channel-first tensors; move channels to the last axis.
        return image.permute(1, 2, 0)

    pipeline = tv_transforms.Compose(
        [tv_transforms.ToTensor(), tv_transforms.Lambda(_chw_to_hwc)]
    )
    image_folder = tv_datasets.ImageFolder(
        root=shifthappens.config.imagenet_validation_path, transform=pipeline
    )
    # NOTE(review): ImagesOnlyTorchDataset presumably drops the labels and
    # IndexedTorchDataset adds sample indices — confirm in shifthappens.data.torch.
    dataset = sh_data_torch.IndexedTorchDataset(
        sh_data_torch.ImagesOnlyTorchDataset(image_folder)
    )
    return DataLoader(dataset, max_batch_size=max_batch_size)
def get_cached_predictions(cls) -> dict:
    """
    Load cached model predictions on the ImageNet validation set.

    Predictions are read from ``<cache_directory_path>/<model class name>/``,
    one ``.npy`` file per result field.

    Note that :py:attr:`shifthappens.config.cache_directory_path <shifthappens.config.Config.cache_directory_path>`
    must be specified.

    Args:
        cls: Model instance. The name of its class (``cls.__class__.__name__``)
            is used as the cache folder name.

    Returns:
        Dictionary mapping each result name (the file name without its
        ``.npy`` extension) to the array loaded from that file.

    Raises:
        AssertionError: If the cache directory is not configured, or the
            model's cache folder is missing or empty.
    """
    cache_root = shifthappens.config.cache_directory_path
    if cache_root is None:
        raise AssertionError(
            "Cannot get cached model results. "
            "shifthappens.config.cache_directory_path is not specified."
        )
    load_path = os.path.join(cache_root, cls.__class__.__name__, "")
    if not os.path.exists(load_path):
        raise AssertionError(
            f"Cannot get cached model results. {load_path} folder not found."
        )
    file_names = os.listdir(load_path)
    if not file_names:
        raise AssertionError(
            f"Cannot get cached model results. {load_path} folder is empty."
        )

    result_dict = dict()
    for file_name in file_names:
        # splitext removes exactly the extension; str.rstrip(".npy") would strip
        # any trailing '.', 'n', 'p', 'y' characters ("entropy.npy" -> "entro").
        result_name, extension = os.path.splitext(file_name)
        if extension != ".npy":
            # Not a cache entry; skip instead of failing inside np.load.
            continue
        result_dict[result_name] = np.load(os.path.join(load_path, file_name))
    return result_dict
def cache_predictions(cls, imagenet_validation_result):
    """
    Save model predictions on the ImageNet validation set to the cache.

    Any existing cache folder for the model's class is deleted and recreated.
    Then every field listed in ``imagenet_validation_result.__slots__`` whose
    value is not ``None`` is written with :py:func:`numpy.save` to
    ``<cache_directory_path>/<model class name>/<field>.npy``.

    Note that :py:attr:`shifthappens.config.cache_directory_path <shifthappens.config.Config.cache_directory_path>`
    must be specified.

    Args:
        cls: Model instance. The name of its class (``cls.__class__.__name__``)
            is used as the cache folder name.
        imagenet_validation_result (ModelResult): Model's predictions on the
            ImageNet validation set.

    Raises:
        AssertionError: If the cache directory is not configured.
    """
    cache_root = shifthappens.config.cache_directory_path
    if cache_root is None:
        raise AssertionError(
            "Cannot cache model results. shifthappens.config.cache_directory_path is not specified."
        )
    save_path = os.path.join(cache_root, cls.__class__.__name__, "")
    # Start from an empty folder so files from a previous run are not mixed in.
    if os.path.exists(save_path):
        shutil.rmtree(save_path)
    os.makedirs(save_path)
    for result_type in imagenet_validation_result.__slots__:
        field_name = str(result_type)
        result = getattr(imagenet_validation_result, field_name)
        if result is not None:
            # np.save appends the ".npy" extension to a path that lacks it.
            np.save(os.path.join(save_path, field_name), result)
def is_cached(cls) -> bool:
    """
    Check whether predictions for the model's class are cached.

    Note that :py:attr:`shifthappens.config.cache_directory_path <shifthappens.config.Config.cache_directory_path>`
    must be specified.

    Args:
        cls: Model instance. The name of its class (``cls.__class__.__name__``)
            is used as the cache folder name.

    Returns:
        ``True`` if the model's cache folder exists and is non-empty,
        ``False`` otherwise.

    Raises:
        AssertionError: If the cache directory is not configured.
    """
    cache_root = shifthappens.config.cache_directory_path
    if cache_root is None:
        raise AssertionError(
            "Cannot find cached model results. "
            "shifthappens.config.cache_directory_path path is not specified. "
        )
    load_path = os.path.join(cache_root, cls.__class__.__name__, "")
    try:
        cached_files = os.listdir(load_path)
    except (FileNotFoundError, NotADirectoryError):
        cached_files = []
    # An empty folder (e.g. left behind if cache_predictions stops after
    # os.makedirs) is not a usable cache: get_cached_predictions rejects it.
    if not cached_files:
        print(f"There is no cached model results on ImageNet at {load_path}.")
        return False
    return True
def load_imagenet_targets() -> np.ndarray:
    """
    Returns the ground-truth labels of the ImageNet validation set.

    The labels are the class indices that
    :py:class:`torchvision.datasets.ImageFolder` assigns to the images under
    ``shifthappens.config.imagenet_validation_path``.

    Returns:
        Array of integer class indices, one per validation image.

    Raises:
        AssertionError: If the ImageNet validation folder is misconfigured.
    """
    _check_imagenet_folder()
    targets = tv_datasets.ImageFolder(
        root=shifthappens.config.imagenet_validation_path
    ).targets
    # ImageFolder.targets is a plain Python list; convert it so the return
    # value actually matches the annotated np.ndarray type.
    return np.asarray(targets)