Source code for shifthappens.benchmark

"""Base functions to register new tasks to the benchmark and evaluate models.

To register a new task, decorate a task class inherited from
:py:class:`shifthappens.tasks.base.Task` with :py:meth:`shifthappens.benchmark.register_task`.

To evaluate a model on all the registered tasks, run
:py:meth:`shifthappens.benchmark.evaluate_model`.
"""

import dataclasses
import os
from typing import Dict
from typing import Optional
from typing import Set
from typing import Tuple
from typing import Type

import shifthappens.utils as sh_utils
from shifthappens.models import Model
from shifthappens.task_data import task_metadata
from shifthappens.task_data.task_registration import TaskRegistration
from shifthappens.tasks.base import Task
from shifthappens.tasks.task_result import TaskResult

# Names exported by `from shifthappens.benchmark import *`.
__all__ = ["evaluate_model", "register_task", "get_registered_tasks"]
# Registry of all tasks added via `register_task`. It is read by
# `get_registered_tasks` / `get_task_registrations` and shrunk by
# `unregister_task`. Being a set, its iteration order is not guaranteed.
__registered_tasks: Set[TaskRegistration] = set()

def get_registered_tasks() -> Tuple[Type[Task], ...]:
    """Return every task class currently registered with the benchmark.

    Returns:
        A tuple holding the class of each registered task. The order follows
        iteration over the internal registry set, so it is not guaranteed to
        be stable.
    """
    task_classes = (registration.cls for registration in __registered_tasks)
    return tuple(task_classes)
def get_task_registrations() -> Tuple[TaskRegistration, ...]:
    """Return the registrations of all tasks currently part of the benchmark.

    Returns:
        A tuple snapshot of the registry. Each entry pairs a registered task
        class (``.cls``) with the metadata (``.metadata``) attached to it by
        :py:meth:`register_task`.
    """
    return (*__registered_tasks,)
def register_task(*, name: str, relative_data_folder: str, standalone: bool = True):
    """Register a task class inherited from :py:class:`shifthappens.tasks.base.Task` as part of the benchmark.

    Args:
        name: Name of the task (can contain spaces or special characters).
        relative_data_folder: Name of the folder, relative to the root data
            folder of the benchmark, in which the data for this task will be
            saved.
        standalone: Whether this task is meaningful as a standalone task
            (``True``) or only relevant as part of a collection of tasks
            (``False``). :py:meth:`shifthappens.benchmark.evaluate_model`
            skips tasks that are not standalone.

    Returns:
        A class decorator. It attaches a
        :py:class:`shifthappens.task_data.task_metadata.TaskMetadata` to the
        decorated class, adds the class to the registry, and returns the
        class itself. Decorating a class that is already registered is a
        no-op that still returns the class.

    Raises:
        AssertionError: If ``relative_data_folder`` is not a valid path name.
            When the decorator is applied, also if the class is not a
            subclass of :py:class:`shifthappens.tasks.base.Task`, is not
            itself a dataclass, or already has the internal metadata
            attribute. These checks are ``assert`` statements, so they are
            skipped when Python runs with ``-O``.

    Examples:
        >>> @shifthappens.benchmark.register_task(
        ...     name="CustomTask",
        ...     relative_data_folder="path_to_store_task_data",
        ...     standalone=True,
        ... )
        ... @dataclasses.dataclass
        ... class CustomTaskClass(Task):
        ...     ...
    """
    assert sh_utils.is_pathname_valid(
        relative_data_folder
    ), "relative_data_folder must only contain valid characters for a path"

    def _inner_register_task(cls: Type[Task], /):
        assert issubclass(cls, Task)

        # Make sure the class was not registered before. A decorator must
        # always hand back the class: a bare `return` would yield None and
        # rebind the decorated class name to None.
        if any(registration.cls == cls for registration in __registered_tasks):
            return cls

        # Check whether the task is marked as a dataclass, i.e. it defines its
        # own dataclass-fields attribute and does not merely inherit that of
        # the base Task. `dataclasses._FIELDS` is a private constant holding
        # the name of that attribute.
        fields_attribute = getattr(dataclasses, "_FIELDS")
        assert getattr(cls, fields_attribute) is not getattr(
            Task, fields_attribute
        ), "Tasks need to be dataclasses (i.e. add a @dataclasses.dataclass() decorator)"

        # Check that the class does not define any fields the benchmark uses
        # internally (hasattr also sees attributes inherited from base classes).
        forbidden_fields = [task_metadata._TASK_METADATA_FIELD]
        for forbidden_field in forbidden_fields:
            assert not hasattr(
                cls, forbidden_field
            ), f"Tasks must not have an attribute called `{forbidden_field}`"

        # Attach the metadata to the class definition.
        metadata = task_metadata.TaskMetadata(
            name=name,
            relative_data_folder=relative_data_folder,
            standalone=standalone,
        )
        setattr(cls, task_metadata._TASK_METADATA_FIELD, metadata)

        # Finally, register the class.
        registration = TaskRegistration(cls, metadata=metadata)
        __registered_tasks.add(registration)
        return cls

    return _inner_register_task
def unregister_task(cls: Type[Task]):
    """Remove a task class from the task registry.

    Args:
        cls: The task class previously registered via :py:meth:`register_task`.

    Raises:
        ValueError: If ``cls`` is not currently registered.
    """
    registration = next(
        (entry for entry in __registered_tasks if entry.cls == cls), None
    )
    if registration is None:
        raise ValueError(f"Task `{cls}` is not registered.")
    __registered_tasks.remove(registration)
def evaluate_model(
    model: Model, data_root: str
) -> Dict[task_metadata.TaskMetadata, Optional[TaskResult]]:
    """Run all registered standalone tasks of the benchmark on a model.

    Every flavour of every registered task whose metadata marks it as
    standalone is set up and evaluated on ``model``. Tasks registered with
    ``standalone=False`` are skipped.

    Args:
        model: Model to evaluate.
        data_root: Folder where individual tasks can store their data. Each
            task receives the subfolder named by its ``relative_data_folder``.

    Returns:
        Associates the :py:class:`shifthappens.task_data.task_metadata.TaskMetadata`
        of each evaluated task flavour with the value returned by that task's
        ``evaluate`` method: a
        :py:class:`shifthappens.tasks.task_result.TaskResult`, or ``None``.

    Examples:
        >>> import shifthappens.benchmark
        >>> from shifthappens.models.torchvision import ResNet18
        >>> # import existing model or create a custom one inherited from
        >>> # shifthappens.models.base.Model and ModelMixin's
        >>> model = ResNet18()
        >>> shifthappens.benchmark.evaluate_model(model, "path_to_store_tasks_data")
    """
    results: Dict[task_metadata.TaskMetadata, Optional[TaskResult]] = {}
    for task_registration in get_task_registrations():
        registration_metadata = task_registration.metadata
        if not registration_metadata.standalone:
            continue
        task_data_root = os.path.join(
            data_root, registration_metadata.relative_data_folder
        )
        for task in task_registration.cls.iterate_flavours(data_root=task_data_root):
            task.setup()
            # Key by the metadata stored on the flavoured task instance rather
            # than the registration's metadata; presumably this differs per
            # flavour, so flavours of one task get separate entries.
            flavored_task_metadata = getattr(task, task_metadata._TASK_METADATA_FIELD)
            results[flavored_task_metadata] = task.evaluate(model)
    return results