Source code for shifthappens.data.base
"""Base classes and helper functions for data handling (dataset, dataloader)."""
import abc
from typing import Iterator
from typing import List
from typing import Optional
from typing import Union
import numpy as np
[docs]class Dataset(abc.ABC):
"""
An abstract class representing an `iterable dataset <https://pytorch.org/docs/stable/data.html#iterable-style-datasets>`_.
Your iterable datasets should be inherited from this class.
"""
@abc.abstractmethod
def __iter__(self):
raise NotImplementedError
[docs]class IndexedDataset(Dataset):
"""
A class representing a `map-style dataset <https://pytorch.org/docs/stable/data.html#map-style-datasets>`_.
Your map-style datasets should be inherited from this class.
"""
@abc.abstractmethod
def __getitem__(self, item):
raise NotImplementedError
@abc.abstractmethod
def __len__(self):
raise NotImplementedError
def __iter__(self):
self.current_index = -1
return self
def __next__(self):
self.current_index += 1
if self.current_index >= len(self):
raise StopIteration
return self[self.current_index]
[docs]class DataLoader:
"""
Interface b/w model and task, implements restrictions
(e.g. max batch size) for models.
Args:
dataset: Dataset from which to load the data.
max_batch_size: How many samples allowed per batch to load.
"""
def __init__(self, dataset: Dataset, max_batch_size: Optional[int]):
self._dataset = dataset
self.__max_batch_size = max_batch_size
@property
def max_batch_size(self):
"""Maximum allowed batch size that the dataloader will satisfy for request
through the :py:func:`iterate` function."""
return self.__max_batch_size
[docs] def iterate(self, batch_size) -> Iterator[List[np.ndarray]]:
"""Iterate through the dataloader and return batches of data.
Args:
batch_size: Maximum batch size the function that requests the data can handle.
Yields:
The dataset split up in batches.
"""
if self.max_batch_size is not None:
batch_size = min(batch_size, self.max_batch_size)
batch = []
ds_iter = iter(self._dataset)
while (item := next(ds_iter, None)) is not None:
batch.append(item)
if len(batch) == batch_size:
yield batch
batch = []
if len(batch) != 0:
yield batch
[docs]def shuffle_data(
*, data: Union[List[np.ndarray], np.ndarray], seed: int
) -> Union[List[np.ndarray], np.ndarray]:
"""Randomly shuffles without replacement an :py:class:`numpy.ndarray`/list of
:py:class:`numpy.ndarray` objects with a fixed random seed.
Args:
data: Data to shuffle.
seed: Random seed.
"""
undo_list = False
if not isinstance(data, List):
undo_list = True
data = [
data,
]
assert np.all(
[len(data[0]) == len(it) for it in data]
), "All data arrays must have the same length"
rng = np.random.default_rng(seed=seed)
rnd_indxs = rng.choice(len(data[0]), size=len(data[0]), replace=False)
data = [it[rnd_indxs] for it in data]
if undo_list:
data = data[0]
return data