# Source code for meerkat.ml.model

from typing import Dict, List

import torch

from meerkat.datapanel import DataPanel
from meerkat.ml.embedding_column import EmbeddingColumn
from meerkat.ml.metrics import compute_metric
from meerkat.ml.prediction_column import ClassificationOutputColumn


# TODO(Priya): Move some general functions here
class Model(torch.nn.Module):
    """Wrap a ``torch.nn.Module`` so it can be run over a meerkat ``DataPanel``.

    The class itself does not implement a forward pass: :meth:`forward`
    raises ``NotImplementedError``. ``_predict`` and ``_activation`` also call
    ``self.process_batch(batch, input_cols)``, which is not defined in this
    class as shown here.
    NOTE(review): presumably subclasses provide both -- confirm.
    """

    def __init__(
        self,
        # identifier: str,
        model: torch.nn.Module,
        # evaluation_fn=None,
        is_classifier: bool = None,
        task: str = None,
        device: str = None,
    ):
        """Store the wrapped model and its task/device configuration.

        Args:
            model: the underlying torch module.
            is_classifier: whether the model is a classifier. Only honoured
                when ``task`` is None (see the note in the body).
            task: name of the task. Required unless ``is_classifier`` is
                truthy.
            device: device string. When falsy, ``"cuda:0"`` is chosen if CUDA
                is available and ``"cpu"`` otherwise.

        Raises:
            ValueError: if ``task`` is None and ``is_classifier`` is falsy.
        """
        super().__init__()

        self.model = model

        # if evaluation_fn is not None:
        #     self.evaluate = evaluation_fn

        if task is None:
            if not is_classifier:
                raise ValueError("Task is required for non-classification models.")
        else:
            # NOTE(review): an explicit task overrides ``is_classifier`` even
            # when the caller passed True -- kept as-is, confirm this is
            # intended.
            is_classifier = False

        self.is_classifier = is_classifier
        self.task = task

        # Always record the device. Previously the attribute was only set when
        # no device was supplied, so an explicit ``device=...`` left
        # ``self.device`` undefined and ``activation()`` failed on first use.
        if not device:
            device = "cuda:0" if torch.cuda.is_available() else "cpu"
        self.device = device

    def forward(self, input_batch: Dict) -> Dict:
        """Run the forward pass; not implemented in this class."""
        raise NotImplementedError

    def _predict(self, batch, input_cols):
        """Process ``batch`` and return the output of :meth:`forward`.

        The parameter is called ``input_cols`` rather than ``input_columns``
        to avoid a naming conflict with ``map``.
        """
        # Process the batch to prepare input
        input_batch = self.process_batch(batch, input_cols)
        # Run forward pass
        return self.forward(input_batch)

    def _activation(self, batch, input_cols, activation_op):
        """Run a forward pass and return the activation held by ``activation_op``.

        Returns:
            A dict with the single key ``activation_<target_module>`` mapped
            to an ``EmbeddingColumn`` built from
            ``activation_op.extractor.activation`` (moved to CPU and detached).
        """
        # Process the batch
        input_batch = self.process_batch(batch, input_cols)
        # The forward output is discarded; the value of interest is read from
        # ``activation_op.extractor.activation`` afterwards.
        # NOTE(review): presumably populated by a hook during forward -- confirm.
        _ = self.forward(input_batch)

        return {
            f"activation_{activation_op.target_module}": EmbeddingColumn(
                activation_op.extractor.activation.cpu().detach()
            )
        }

    def activation(
        self,
        dataset: DataPanel,
        target_module: str,  # TODO(Priya): Support multiple activation layers
        input_columns: List[str],
        batch_size=32,
    ) -> EmbeddingColumn:
        """An Operation that stores model activations in a new Embedding column.

        Args:
            dataset (DataPanel): the meerkat DataPanel containing the model
                inputs.
            target_module (str): the name of the submodule of `model` (i.e. an
                intermediate layer) that outputs the activations we'd like to
                extract. For nested submodules, specify a path separated by
                "." (e.g. ``"block4.conv"``).
            input_columns (List[str]): columns containing model inputs.
            batch_size (int): batch size passed to ``dataset.map``.

        Returns:
            The ``activation_<target_module>`` column of the mapped output.
        """
        # To avoid circular dependency for now.
        from meerkat.ml.activation import ActivationOp

        # Get an activation operator
        activation_op = ActivationOp(self.model, target_module, self.device)

        activations = dataset.map(
            function=self._activation,
            is_batched_fn=True,
            batch_size=batch_size,
            output_type=EmbeddingColumn,
            input_cols=input_columns,
            activation_op=activation_op,
        )

        activation_col = activations[f"activation_{activation_op.target_module}"]

        # dataset.add_column(f"activation ({target_module})", activation_col)
        return activation_col

    def classification(
        self,
        dataset: DataPanel,
        input_columns: List[str],
        batch_size: int = 32,
        num_classes: int = None,
        multi_label: bool = False,
        one_hot: bool = None,
        threshold=0.5,
    ) -> DataPanel:
        """Run the model over ``dataset`` and collect classification outputs.

        ``num_classes``, ``multi_label``, ``one_hot`` and ``threshold`` are
        forwarded unchanged to ``ClassificationOutputColumn``.

        Side effect: columns named ``"logits"``, ``"probs"`` and ``"preds"``
        are also added to ``dataset``.

        Returns:
            A new DataPanel with columns ``"logits"``, ``"probs"`` and
            ``"preds"``.
        """
        # Handles outputs for classification tasks
        predictions = dataset.map(
            function=self._predict,
            is_batched_fn=True,
            batch_size=batch_size,
            output_type=ClassificationOutputColumn,
            input_cols=input_columns,
        )

        # TODO(Priya): How to pass other args of ClassificationOutputColumn above?
        output_col = ClassificationOutputColumn(
            logits=predictions["logits"].data,
            num_classes=num_classes,
            multi_label=multi_label,
            one_hot=one_hot,
            threshold=threshold,
        )

        output_dp = DataPanel(
            {
                "logits": output_col,
                "probs": output_col.probabilities(),
                "preds": output_col.predictions(),
            }
        )

        # probabilities()/predictions() are deliberately called again so the
        # columns attached to ``dataset`` are separate objects from those in
        # ``output_dp``, exactly as before.
        dataset.add_column("logits", output_col)
        dataset.add_column("probs", output_col.probabilities())
        dataset.add_column("preds", output_col.predictions())

        return output_dp

    def evaluate(
        self,
        dataset: DataPanel,
        target_column: List[str],  # str?
        pred_column: List[str],  # str?
        metrics: List[str],
        num_classes: int = None,
    ):
        """Compute each metric in ``metrics`` from prediction and label columns.

        Only the first entry of ``pred_column`` and of ``target_column`` is
        used.

        Args:
            dataset: DataPanel holding the prediction and label columns.
            target_column: list whose first element names the label column.
            pred_column: list whose first element names the prediction column.
            metrics: metric names passed one at a time to ``compute_metric``.
            num_classes: number of classes. When None it is taken from the
                prediction column if that is a ``ClassificationOutputColumn``.

        Returns:
            Dict mapping each metric name to the value from ``compute_metric``.

        Raises:
            ValueError: if ``num_classes`` is None, the prediction column is
                not a ``ClassificationOutputColumn`` and the model is a
                classifier.
        """
        preds = dataset[pred_column[0]]
        labels = dataset[target_column[0]]

        if num_classes is None:
            if isinstance(preds, ClassificationOutputColumn):
                num_classes = preds.num_classes
            elif self.is_classifier:
                # Implicit concatenation instead of a backslash continuation
                # inside the literal, which embedded the source indentation in
                # the message.
                raise ValueError(
                    "Must specify num_classes if column type "
                    "is not ClassificationOutputColumn"
                )

        return {
            metric: compute_metric(metric, preds.data, labels.data, num_classes)
            for metric in metrics
        }

    @staticmethod
    def remap_labels(output_dict: Dict, label_map: List[int]) -> Dict:
        """Permute the class dimension of the model outputs in place.

        For every entry whose last dimension equals ``len(label_map)``,
        column ``j`` of the result is column ``label_map[j]`` of the input.
        ``output_dict["pred"]`` is then mapped through the inverse
        permutation, so it stays consistent with the permuted columns.

        Example: with ``label_map = [1, 2, 0]`` the inverse permutation is
        ``[2, 0, 1]``, so model label 0 becomes label 2, model label 1
        becomes label 0 and model label 2 becomes label 1.

        NOTE(review): the previous docstring described this example as
        "model label 0 -> dataset label 1", which is the opposite direction
        from what the code computes. Behaviour is unchanged here; confirm
        which direction is intended.

        Args:
            output_dict: dict of tensors; must contain the key ``"pred"``.
            label_map: a list of class indices used for the remapping.

        Returns:
            The same ``output_dict`` object, modified in place.

        Raises:
            KeyError: if ``output_dict`` has no ``"pred"`` entry.
        """
        # Check the number of classes
        num_classes = len(label_map)

        # Remap the columns of all outputs that have # columns = num_classes.
        # Only existing keys are reassigned, so mutating during iteration is
        # safe.
        for key in output_dict:
            if output_dict[key].shape[-1] == num_classes:
                output_dict[key] = output_dict[key][..., label_map]

        # Remap the pred key: positions of label_map ordered by label value
        # (stable sort, so ties keep their original order).
        inverse_label_map = sorted(range(num_classes), key=lambda i: label_map[i])
        output_dict["pred"] = torch.tensor(inverse_label_map)[output_dict["pred"]]

        return output_dict

    @classmethod
    def _state_keys(cls) -> set:
        """List of attributes that describe the state of the object."""
        return {"model", "is_classifier", "task", "device"}