Source code for shifthappens.models.torchvision

"""Model baselines from torchvision."""

from typing import Iterator
from typing import List

import numpy as np
import torch
import torchvision
from surgeon_pytorch import Inspect
from torch import nn
from torchvision.transforms import functional as tv_functional

import shifthappens.models.base as sh_models
import shifthappens.models.mixins as sh_mixins
from shifthappens.data.base import DataLoader


class __TorchvisionPreProcessingMixin:
    """Performs the default preprocessing for torchvision ImageNet classifiers."""

    def _pre_process(self, batch: List[np.ndarray], device: str) -> torch.Tensor:
        inputs = []
        for item in batch:
            assert isinstance(item, np.ndarray)
            item_t = torch.tensor(item.transpose((2, 0, 1)))
            item_t = tv_functional.resize(item_t, 256, antialias=True)
            item_t = tv_functional.center_crop(item_t, 224)
            item_t = tv_functional.normalize(
                item_t, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
            )
            inputs.append(item_t)
        inputs_t = torch.stack(inputs, 0)
        inputs_t = inputs_t.to(device)
        return inputs_t
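For reference, the per-image pipeline in ``_pre_process`` corresponds to the following standalone sketch. It is not part of the module and assumes float32 HWC inputs in [0, 1], which is what the normalization constants above expect:

import numpy as np
import torch
from torchvision import transforms

# Hypothetical standalone equivalent of the per-item transform above.
_imagenet_transform = transforms.Compose(
    [
        transforms.Resize(256, antialias=True),
        transforms.CenterCrop(224),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]
)

dummy = np.random.rand(300, 400, 3).astype(np.float32)  # HWC image in [0, 1]
item_t = torch.tensor(dummy.transpose((2, 0, 1)))  # HWC -> CHW
print(_imagenet_transform(item_t).shape)  # torch.Size([3, 224, 224])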
class __TorchvisionModel(
    sh_models.Model,
    __TorchvisionPreProcessingMixin,
    sh_mixins.LabelModelMixin,
    sh_mixins.ConfidenceModelMixin,
    sh_mixins.FeaturesModelMixin,
    sh_mixins.OODScoreModelMixin,
):
    """Wraps a torchvision model.

    Args:
        model: Pretrained torchvision model.
        feature_layer: Name of the layer used to extract features/representations.
        max_batch_size: Maximum number of samples to load per batch.
        device: Device to run the model on.
    """

    def __init__(
        self,
        model: nn.Module,
        feature_layer: str,
        max_batch_size: int,
        device: str = "cpu",
    ):
        super().__init__()
        assert not issubclass(
            type(model), torch.nn.DataParallel
        ), "Parallel models are not yet supported"
        self.model = model
        self.max_batch_size = max_batch_size
        self.device = device
        self.model.to(self.device)
        self.hooked_model = Inspect(self.model, layer=feature_layer)
        self.model.eval()

    @torch.no_grad()
    def _predict(
        self, input_dataloader: DataLoader, targets: sh_models.PredictionTargets
    ) -> Iterator[sh_models.ModelResult]:
        for batch in input_dataloader.iterate(self.max_batch_size):
            # pre-process batch
            inputs = self._pre_process(batch, self.device)
            logits, features = self.hooked_model(inputs)
            features = features.view(len(features), -1)
            logits, features = logits.cpu(), features.cpu()
            probabilities = torch.softmax(logits, -1)
            max_confidences, predictions = probabilities.max(1)
            yield sh_models.ModelResult(
                class_labels=predictions.numpy(),
                confidences=probabilities.numpy(),
                ood_scores=max_confidences.numpy(),
                features=features.numpy(),
            )
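Feature extraction relies on ``surgeon_pytorch.Inspect``, which wraps the model so that a forward pass returns both the final output and the requested intermediate activation. A minimal sketch of that behaviour, independent of this module (an untrained resnet18 is used since only the shapes matter here):

import torch
import torchvision
from surgeon_pytorch import Inspect

net = torchvision.models.resnet18().eval()  # random weights; shapes are all that matter here
hooked = Inspect(net, layer="avgpool")

with torch.no_grad():
    logits, features = hooked(torch.randn(2, 3, 224, 224))

print(logits.shape)    # torch.Size([2, 1000])
print(features.shape)  # torch.Size([2, 512, 1, 1]); _predict flattens this to (2, 512)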
class ResNet18(__TorchvisionModel):
    """ResNet18 network trained on the ImageNet 2012 train set from torchvision.

    See :py:func:`torchvision.models.resnet18` for details.

    Args:
        max_batch_size: Maximum number of samples to load per batch.
        device: Device to run the model on.
    """

    def __init__(
        self,
        max_batch_size: int = 16,
        device: str = "cpu",
    ):
        model = torchvision.models.resnet18(pretrained=True)
        feature_layer = "avgpool"
        super().__init__(model, feature_layer, max_batch_size, device)
class ResNet50(__TorchvisionModel):
    """ResNet50 network trained on the ImageNet 2012 train set from torchvision.

    See :py:func:`torchvision.models.resnet50` for details.

    Args:
        max_batch_size: Maximum number of samples to load per batch.
        device: Device to run the model on.
    """

    def __init__(
        self,
        max_batch_size: int = 16,
        device: str = "cpu",
    ):
        model = torchvision.models.resnet50(pretrained=True)
        feature_layer = "avgpool"
        super().__init__(model, feature_layer, max_batch_size, device)
class VGG16(__TorchvisionModel):
    """VGG16 network trained on the ImageNet 2012 train set from torchvision.

    See :py:func:`torchvision.models.vgg16` for details.

    Args:
        max_batch_size: Maximum number of samples to load per batch.
        device: Device to run the model on.
    """

    def __init__(
        self,
        max_batch_size: int = 16,
        device: str = "cpu",
    ):
        model = torchvision.models.vgg16(pretrained=True)
        feature_layer = "avgpool"
        super().__init__(model, feature_layer, max_batch_size, device)
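A hedged usage sketch for these wrappers: it assumes the benchmark exposes an ``evaluate_model(model, data_root)`` entry point in ``shifthappens.benchmark``, so check that module for the exact signature before relying on it:

import torch

import shifthappens.benchmark as sh_benchmark
from shifthappens.models.torchvision import ResNet18

model = ResNet18(
    max_batch_size=128,
    device="cuda:0" if torch.cuda.is_available() else "cpu",
)
# Assumed entry point: evaluate_model(model, data_root); verify against shifthappens.benchmark.
sh_benchmark.evaluate_model(model, "data")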