Source code for meerkat.ml.prediction_column

from __future__ import annotations

from enum import Enum
from typing import Sequence, Union

import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F
from torch.distributions.categorical import Categorical

from meerkat.columns.tensor_column import TensorColumn

Columnable = Union[Sequence, np.ndarray, pd.Series, torch.Tensor]


class _ClassifierOutputType(Enum):
    LOGIT = 1, ("logit", "logits")
    PROBABILITY = 2, ("probability", "probabilities", "prob", "probs")
    PREDICTION = 3, ("prediction", "predictions", "pred", "preds")

    def __new__(cls, key_code, names):
        """
        Args:
            key_code (int): Enum value.
            names (tuple[str]): Accepted string aliases for this output type.
        """
        obj = object.__new__(cls)
        obj._value_ = key_code
        obj.names = names
        return obj

    @classmethod
    def get_ctype(cls, ctype):
        if isinstance(ctype, cls):
            return ctype
        if not isinstance(ctype, (int, str)):
            raise TypeError(
                f"Cannot get {cls} from ctype of type {type(ctype)}. Must be int or str"
            )

        for _ctype in cls:
            if (isinstance(ctype, int) and ctype == _ctype.value) or (
                ctype in _ctype.names
            ):
                return _ctype

        raise ValueError(f"Invalid ctype - no {cls} corresponding to `ctype={ctype}`")

    def readable_name(self):
        return self.names[0]
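
# Illustrative note (not part of the original source): ``get_ctype`` accepts the
# enum member itself, its integer value, or any of its string aliases, e.g.
#
#     _ClassifierOutputType.get_ctype("probs")  # -> _ClassifierOutputType.PROBABILITY
#     _ClassifierOutputType.get_ctype(1)        # -> _ClassifierOutputType.LOGIT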


class ClassificationOutputColumn(TensorColumn):
    def __init__(
        self,
        logits: Columnable = None,
        probs: Columnable = None,
        preds: Columnable = None,
        num_classes: int = None,
        multi_label: bool = False,
        one_hot: bool = None,
        threshold=0.5,
        *args,
        **kwargs,
    ):
        """Classification output handler initialized by one (and only one) of
        ``logits``, ``probs``, or ``preds``.

        Args:
            logits (array-like | tensor-like): Unnormalized scores (logits) over
                the categories.
            probs (array-like | tensor-like): Event probabilities.
            preds (array-like | tensor-like): Predictions. Can either be one-hot
                encoded (set ``one_hot=True``) or an array-like of categorical
                labels.
            num_classes (int): The number of categories.
            multi_label (bool): If ``True``, examples can have multiple categories.
            one_hot (bool): If ``True``, ``preds`` (if specified) is considered to
                be a one-hot encoded matrix. If ``True`` and ``multi_label=False``,
                ``preds`` will be auto-converted into categorical labels.
            threshold (float): The threshold for binarizing multi-label
                probabilities into predictions. Not relevant if
                ``multi_label=False``.
        """
        valid_data = [
            (x, ctype)
            for x, ctype in [
                (logits, _ClassifierOutputType.LOGIT),
                (probs, _ClassifierOutputType.PROBABILITY),
                (preds, _ClassifierOutputType.PREDICTION),
            ]
            if x is not None
        ]
        if len(valid_data) == 0:
            raise ValueError("Must specify one of `logits`, `probs`, or `preds`")
        if len(valid_data) > 1:
            raise ValueError(
                "Only one of `logits`, `probs`, or `preds` should be specified"
            )
        data, ctype = valid_data[0]

        if isinstance(data, TensorColumn):
            data = data.data  # unwrap the tensor out of the TensorColumn

        if ctype == _ClassifierOutputType.PREDICTION and not multi_label and one_hot:
            # Convert one-hot single-label predictions into categorical labels.
            data = torch.argmax(data, dim=1)
            one_hot = False

        self._ctype = _ClassifierOutputType.get_ctype(ctype)
        self.num_classes = num_classes
        self.multi_label = multi_label
        self.one_hot = one_hot
        self.threshold = threshold

        super().__init__(data=data, *args, **kwargs)

    def _set_data(self, data):
        # Infer the number of classes from the data when it is not provided.
        if self.num_classes is None:
            if self._ctype in (
                _ClassifierOutputType.LOGIT,
                _ClassifierOutputType.PROBABILITY,
            ):
                self.num_classes = data.shape[1]

            if self._ctype == _ClassifierOutputType.PREDICTION:
                if self.multi_label:
                    if self.one_hot is False or data.ndim < 2 or not _is_binary(data):
                        raise ValueError(
                            "Multi-label predictions must be one-hot encoded "
                            "and have at least 2 dimensions - (N, C, ...)"
                        )
                    self.one_hot = True
                self.num_classes = (
                    data.shape[1] if self.multi_label else torch.max(data) + 1
                )

        super(ClassificationOutputColumn, self)._set_data(data)
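
    # Illustrative usage sketch (not part of the original source; the shapes are
    # invented): the column is constructed from exactly one of ``logits``,
    # ``probs``, or ``preds``; ``num_classes`` is inferred from dim 1 when logits
    # or probabilities are given.
    #
    #     col = ClassificationOutputColumn(logits=torch.randn(10, 3))
    #     col.num_classes  # -> 3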

    def logits(self) -> ClassificationOutputColumn:
        if self._ctype in (
            _ClassifierOutputType.PROBABILITY,
            _ClassifierOutputType.PREDICTION,
        ):
            raise ValueError(
                f"Cannot convert from {self._ctype.readable_name()} to logits"
            )
        return self.view()

    def probabilities(self) -> ClassificationOutputColumn:
        if self._ctype == _ClassifierOutputType.LOGIT:
            probs = (
                torch.sigmoid(self.data)
                if self.multi_label
                else torch.softmax(self.data, dim=1)
            )
            return ClassificationOutputColumn(
                probs=probs,
                num_classes=self.num_classes,
                multi_label=self.multi_label,
                threshold=self.threshold,
            )

        if self._ctype == _ClassifierOutputType.PROBABILITY:
            return self.view()

        raise ValueError(
            f"Cannot convert from {self._ctype.readable_name()} to probabilities"
        )
    probs = probabilities

    def predictions(self) -> ClassificationOutputColumn:
        """Compute predictions."""
        if self._ctype == _ClassifierOutputType.PREDICTION:
            return self.view()

        probs = self.probabilities().data
        preds = (
            (probs >= self.threshold).type(torch.uint8)
            if self.multi_label
            else torch.argmax(probs, dim=1)
        )
        return ClassificationOutputColumn(
            preds=preds,
            num_classes=self.num_classes,
            multi_label=self.multi_label,
            threshold=self.threshold,
            # Multi-label predictions are always one-hot encoded.
            one_hot=self.multi_label,
        )
    preds = predictions
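
    # Illustrative sketch (invented values): for a multi-label column,
    # ``predictions()`` binarizes the probabilities against ``threshold`` and
    # keeps the result one-hot; for single-label it takes the argmax over classes.
    #
    #     multi = ClassificationOutputColumn(
    #         probs=torch.tensor([[0.9, 0.2], [0.4, 0.7]]),
    #         multi_label=True,
    #     )
    #     multi.predictions().data
    #     # -> tensor([[1, 0], [0, 1]], dtype=torch.uint8)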

    def bincount(self) -> TensorColumn:
        """Compute the count (cardinality) for each category.

        Categories that do not appear have a count of 0.

        If ``self.multi_label=True``, the bincount counts the total number of
        times each category is seen. If an example is marked with 2 categories,
        the bincount increments the count for both categories. Note, this means
        the counts can sum to more than the number of examples ``N``.

        Returns:
            TensorColumn: A 1D tensor of length ``self.num_classes``.
        """
        preds = self.predictions().data
        if (self.multi_label and preds.ndim > 2) or (
            not self.multi_label and preds.ndim > 1
        ):
            raise ValueError(
                "bincount of multi-dimensional tensors is not currently supported"
            )

        if preds.ndim == 1:
            return TensorColumn(torch.bincount(preds, minlength=self.num_classes))
        else:
            out = preds.sum(dim=0)
            if len(out) < self.num_classes:
                pad = (0,) * (2 * (out.ndim - 1)) + (0, self.num_classes - len(out))
                out = F.pad(out, pad=pad)
            return TensorColumn(out)
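
    # Illustrative sketch (invented values): for single-label outputs, the
    # bincount is a per-class histogram of the predictions, with zeros for
    # classes that never appear.
    #
    #     col = ClassificationOutputColumn(
    #         probs=torch.tensor([[0.7, 0.2, 0.1], [0.1, 0.1, 0.8], [0.2, 0.1, 0.7]])
    #     )
    #     col.bincount().data  # -> tensor([1, 0, 2])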

    def mode(self):
        """Compute the most common category among the predictions."""
        count = self.bincount()
        return torch.argmax(count.data)

    def entropy(self) -> TensorColumn:
        """Compute the entropy for each example.

        If ``self.multi_label`` is True, each category is treated as a binary
        classification problem, so there is an entropy calculation for each
        category as well. For example, if the probabilities are of shape
        ``(N, C)``, there will be ``NxC`` entropy values.

        In the multi-dimensional case, this returns the entropy for each element.
        For example, if the probabilities are of shape ``(N, C, A, B)``, there
        will be ``NxAxB`` entropy values.

        Returns:
            TensorColumn: Tensor of entropies.
        """
        probs = self.probabilities().data
        if self.multi_label:
            probs = torch.stack([probs, 1 - probs], dim=-1)
        elif probs.ndim > 2:
            # Move the class dimension to the end (channels last) so that
            # ``Categorical`` treats it as the event dimension.
            probs = probs.permute((0,) + tuple(range(2, probs.ndim)) + (1,))
        return TensorColumn(Categorical(probs=probs).entropy())
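
    # Illustrative sketch (invented values): for single-label probabilities of
    # shape (N, C), ``entropy()`` returns one value per example (natural log).
    #
    #     col = ClassificationOutputColumn(probs=torch.tensor([[0.5, 0.5], [1.0, 0.0]]))
    #     col.entropy().data  # -> tensor([0.6931, 0.0000])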

    @classmethod
    def _state_keys(cls) -> set:
        """List of attributes that describe the state of the object."""
        return super()._state_keys() | {
            "_ctype",
            "num_classes",
            "multi_label",
            "one_hot",
            "threshold",
        }


def _is_binary(tensor: torch.Tensor):
    """Return ``True`` if every value in ``tensor`` is either 0 or 1."""
    return torch.all(
        (tensor == tensor.type(torch.uint8)) & (tensor >= 0) & (tensor <= 1)
    )
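

# --- Illustrative usage sketch (not part of the original module) ---
# A minimal end-to-end example of the conversions above for a small,
# single-label problem; the values and shapes are invented for illustration.
if __name__ == "__main__":
    logits = torch.tensor([[2.0, 0.5, -1.0], [0.1, 0.2, 3.0]])
    col = ClassificationOutputColumn(logits=logits)

    probs = col.probabilities()  # softmax over the class dimension (dim=1)
    preds = col.predictions()    # argmax over the class dimension
    counts = col.bincount()      # per-class histogram of the predictions

    print(probs.data.sum(dim=1))  # each row sums to 1
    print(preds.data)             # tensor([0, 2])
    print(counts.data)            # tensor([1, 0, 1])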