Source code for shifthappens.benchmark

"""Base functions to register new tasks to the benchmark and evaluate models.

To register a new task, decorate a task class inherited from
:py:class:`shifthappens.tasks.base.Task` with the
:py:meth:`shifthappens.benchmark.register_task` function.

To evaluate a model on all registered tasks, run
:py:meth:`shifthappens.benchmark.evaluate_model`.
"""

import dataclasses
import os
from typing import Dict
from typing import Optional
from typing import Set
from typing import Tuple
from typing import Type

import shifthappens.utils as sh_utils
from shifthappens.models import Model
from shifthappens.task_data import task_metadata
from shifthappens.task_data.task_registration import TaskRegistration
from shifthappens.tasks.base import Task
from shifthappens.tasks.task_result import TaskResult

# Names exported as the public API of this module.
__all__ = ["evaluate_model", "register_task", "get_registered_tasks"]
# Module-level registry: one TaskRegistration per registered task class.
# Entries are added by register_task and removed by unregister_task.
__registered_tasks: Set[TaskRegistration] = set()


def get_registered_tasks() -> Tuple[Type[Task], ...]:
    """Return the classes of all tasks currently registered in the benchmark.

    Returns:
        A tuple holding the task class of every registration in the
        module-level registry. The registry is a set, so the order of the
        tuple is not defined.
    """
    return tuple(registration.cls for registration in __registered_tasks)
def get_task_registrations() -> Tuple[TaskRegistration, ...]:
    """Return a snapshot of all task registrations of the benchmark.

    Returns:
        A tuple with every ``TaskRegistration`` currently stored in the
        module-level registry. Because it is a copy, the registry can be
        modified while the returned tuple is being iterated.
    """
    return (*__registered_tasks,)
def register_task(*, name: str, relative_data_folder: str, standalone: bool = True):
    """Registers a task class inherited from :py:class:`shifthappens.tasks.base.Task`
    as a task of the benchmark.

    Args:
        name: Name of the task (can contain spaces or special characters).
        relative_data_folder: Name of the folder in which the data for this
            task will be saved, relative to the root folder of the benchmark.
        standalone: Whether this task is meaningful as a standalone task or
            is only relevant as part of a collection of tasks.

    Returns:
        A class decorator. It attaches the task metadata to the class, adds
        the class to the registry and returns the class itself. If the class
        is already registered, it is returned unchanged.

    Raises:
        AssertionError: If ``relative_data_folder`` is not a valid path name,
            if the decorated class is not a subclass of ``Task``, is not a
            dataclass, or already defines the attribute the benchmark uses
            to store the task metadata.

    Examples:
        >>> @shifthappens.benchmark.register_task(
        ...     name="CustomTask",
        ...     relative_data_folder="path_to_store_task_data",
        ...     standalone=True,
        ... )
        ... @dataclasses.dataclass
        ... class CustomTaskClass(Task):
        ...     ...
    """

    assert sh_utils.is_pathname_valid(
        relative_data_folder
    ), "relative_data_folder must only contain valid characters for a path"

    def _inner_register_task(cls: Type[Task], /):
        assert issubclass(cls, Task)

        # Make sure the class was not registered before. The class must still
        # be returned: this function is used as a decorator, so returning
        # None would rebind the decorated name to None.
        if cls in [t.cls for t in __registered_tasks]:
            return cls

        # check whether the task is marked as a dataclass
        # (i.e. defines its own _FIELDS attribute and does not use that of the base task)
        assert getattr(cls, getattr(dataclasses, "_FIELDS")) is not getattr(
            Task, getattr(dataclasses, "_FIELDS")
        ), "Tasks need to be dataclasses (i.e. add a @dataclasses.dataclass() decorator)"

        # check that the class did not define any fields the benchmark uses internally
        forbidden_fields = [task_metadata._TASK_METADATA_FIELD]
        for forbidden_field in forbidden_fields:
            assert not hasattr(
                cls, forbidden_field
            ), f"Tasks must not have an attribute called `{forbidden_field}`"

        # add metadata to class definition
        metadata = task_metadata.TaskMetadata(
            name=name,
            relative_data_folder=relative_data_folder,
            standalone=standalone,
        )
        setattr(cls, task_metadata._TASK_METADATA_FIELD, metadata)

        # finally register class
        registration = TaskRegistration(cls, metadata=metadata)
        __registered_tasks.add(registration)
        return cls

    return _inner_register_task
def unregister_task(cls: Type[Task]):
    """Remove the registration of a task class from the task registry.

    Args:
        cls: Task class whose registration should be removed.

    Raises:
        ValueError: If no registration for ``cls`` exists.
    """
    # NOTE(review): only the registry entry is removed; the metadata attribute
    # that register_task sets on the class is left in place.
    found = next(
        (entry for entry in __registered_tasks if entry.cls == cls),
        None,
    )
    if found is None:
        raise ValueError(f"Task `{cls}` is not registered.")
    __registered_tasks.remove(found)
def evaluate_model(
    model: Model, data_root: str
) -> Dict[task_metadata.TaskMetadata, Optional[TaskResult]]:
    """
    Runs all registered tasks of the benchmark
    which are supported by the supplied model.

    Args:
        model: Model to evaluate.
        data_root: Folder where individual tasks can store their data. Each
            task receives the sub-folder given by its ``relative_data_folder``
            metadata entry.

    Returns:
        Associates :py:class:`shifthappens.task_data.task_metadata.TaskMetadata`
        with the respective :py:class:`shifthappens.tasks.task_result.TaskResult`.
        Only tasks registered as ``standalone`` are evaluated; there is one
        entry per task flavour.

    Examples:
        >>> import shifthappens.benchmark
        >>> from shifthappens.models.torchvision import ResNet18
        >>> # import existing model or create a custom one inherited from
        >>> # shifthappens.models.base.Model and ModelMixin's
        >>> model = ResNet18()
        >>> shifthappens.benchmark.evaluate_model(model, "path_to_store_tasks_data")
    """
    # Keys are the metadata objects stored on the flavoured task instances,
    # not the TaskRegistration objects the registry holds.
    results: Dict[task_metadata.TaskMetadata, Optional[TaskResult]] = {}

    for task_registration in get_task_registrations():
        # Tasks that are only meaningful as part of a collection are skipped.
        if not task_registration.metadata.standalone:
            continue

        task_data_root = os.path.join(
            data_root, task_registration.metadata.relative_data_folder
        )
        for task in task_registration.cls.iterate_flavours(data_root=task_data_root):
            task.setup()
            flavored_task_metadata = getattr(task, task_metadata._TASK_METADATA_FIELD)
            results[flavored_task_metadata] = task.evaluate(model)

    return results