Source code for shifthappens.tasks.task_result
"""Class for representing the results of a single task."""
from typing import Dict
from typing import Tuple
from typing import Union
from .metrics import Metric
[docs]class TaskResult:
"""Contains the results of a result, which can be arbitrary metrics.
At least one of these metrics must be references as a summary metric.
Args:
summary_metrics: Associates :py:class:`shifthappens.tasks.metrics.Metric` values
to the name of metrics calculated by the task.
metrics: Metrics' names and their values.
Examples:
>>> @dataclasses.dataclass
>>> class CustomTask(Task):
>>> ...
>>> def _evaluate(self, model: shifthappens.models.base.Model) -> DataLoader:
>>> ...
>>> return TaskResult(
>>> your_robustness_metric=your_robustness_metric,
>>> your_calibration_metric=your_calibration_metric,
>>> your_custom_metric=1.0 - your_custom_metric,
>>> summary_metrics={
>>> Metric.Robustness: "your_robustness_metric",
>>> Metric.Calibration: "your_calibration_metric"},
>>> )
>>> ...
"""
def __init__(
self,
*,
summary_metrics: Dict[Metric, Union[str, Tuple[str, ...]]],
**metrics: Union[float, int],
):
# validate that metrics referenced in summary metrics exist
for sm in summary_metrics:
assert isinstance(sm, Metric), "Invalid summary metric key."
smv = summary_metrics[sm]
if isinstance(smv, str):
tms: Tuple[str, ...] = (smv,)
elif isinstance(smv, tuple):
tms = smv
else:
raise ValueError(
f"Value for metric key `{sm}` is neither str nor tuple of str."
)
for tm in tms:
assert tm in metrics
self._metrics = metrics
self.summary_metrics = summary_metrics
def __getitem__(self, item) -> float:
return self._metrics[item]
def __getattr__(self, item) -> float:
if item in self._metrics:
return self[item]
else:
return super().__getattribute__(item)