Source code for meerkat.mixins.inspect_fn

from collections.abc import Callable, Mapping, Sequence
from types import SimpleNamespace

import numpy as np
import torch


class FunctionInspectorMixin:
    """Mixin that dry-runs a user function on a small sample and describes its output.

    The host class is expected to support ``self[...]`` indexing and a ``self.lz``
    accessor (both are used to draw the default sample); ``self.all_columns`` is
    consulted when present. These requirements come from usage in
    ``_inspect_function``, not from a declared interface.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    @staticmethod
    def _first_item(output):
        """Return ``output[0]``, or ``None`` if ``output`` is empty or unsized.

        ``None``, Python/NumPy scalars and 0-d arrays/tensors raise on ``len`` or
        ``[0]``; they are treated as having no first element instead of crashing.
        """
        try:
            return output[0] if len(output) > 0 else None
        except (TypeError, IndexError):
            return None

    @classmethod
    def _infer_output_attribute(cls, output, name: str):
        """Return ``output.<name>``, else the first element's ``<name>``, else ``None``."""
        if hasattr(output, name):
            return getattr(output, name)
        return getattr(cls._first_item(output), name, None)

    @staticmethod
    def _is_boolean(
        value, array_types=(np.ndarray,), tensor_types=(torch.Tensor,)
    ) -> bool:
        """Return True for a bool scalar or a bool-dtype array/tensor.

        Args:
            value: object to classify.
            array_types: types whose ``dtype`` is compared against ``bool``
                (NumPy convention).
            tensor_types: types whose ``dtype`` is compared against ``torch.bool``.
        """
        return (
            isinstance(value, (bool, np.bool_))
            or (isinstance(value, array_types) and value.dtype == bool)
            or (isinstance(value, tensor_types) and value.dtype == torch.bool)
        )

    def _inspect_function(
        self,
        function: Callable,
        with_indices: bool = False,
        is_batched_fn: bool = False,
        data=None,
        indices=None,
        materialize=True,
        **kwargs,
    ) -> SimpleNamespace:
        """Run ``function`` once on a small sample and describe what it returns.

        Args:
            function: the user function to probe.
            with_indices: if True, call ``function(data, indices, **kwargs)``;
                otherwise call ``function(data, **kwargs)``.
            is_batched_fn: if True, the default sample is the first two rows
                (``[:2]``) and the default indices are ``range(2)``; otherwise
                the sample is row ``0`` and the index is ``0``.
            data: sample to call ``function`` on; drawn from ``self`` if None.
            indices: indices passed when ``with_indices``; defaulted if None.
            materialize: when drawing the default sample, index ``self`` if True,
                or ``self.lz`` if False.
            **kwargs: forwarded to ``function``.

        Returns:
            SimpleNamespace with fields:
                output: the raw return value of ``function``.
                dict_output: output is a ``Mapping``.
                no_output: output is ``None``.
                bool_output: output is a bool scalar or bool array/tensor/column,
                    or (batched only) a sequence whose first element is one.
                list_output: output is a Sequence, AbstractColumn, tensor or
                    ndarray that was not classified as a top-level bool.
                updates_existing_column: False only when the output is a
                    Mapping, ``self`` has ``all_columns`` and no output key is an
                    existing column; True otherwise.
                existing_columns_updated: set of output keys that are existing
                    columns (an empty list when this was not computed).
                output_shape / output_dtype: ``shape`` / ``dtype`` of the output,
                    else of its first element, else None.
        """
        # Draw a small default sample: two rows for batched functions, one otherwise.
        if data is None:
            source = self if materialize else self.lz
            data = source[:2] if is_batched_fn else source[0]
        if indices is None:
            indices = range(2) if is_batched_fn else 0

        # Batched and unbatched functions share the same call signature.
        if with_indices:
            output = function(data, indices, **kwargs)
        else:
            output = function(data, **kwargs)

        # lazy import to avoid circular dependency
        from meerkat.columns.abstract import AbstractColumn
        from meerkat.columns.numpy_column import NumpyArrayColumn
        from meerkat.columns.tensor_column import TensorColumn

        dict_output = no_output = bool_output = list_output = False
        # Only meaningful when `function` returns a dict used to update a DataPanel.
        updates_existing_column = True
        existing_columns_updated = []

        if isinstance(output, Mapping):
            dict_output = True
            # Only a host exposing `all_columns` (a DataPanel) has columns that
            # the returned keys could overwrite.
            if hasattr(self, "all_columns"):
                existing_columns_updated = set(self.all_columns).intersection(
                    output.keys()
                )
                if not existing_columns_updated:
                    updates_existing_column = False
        elif output is None:
            no_output = True
        elif self._is_boolean(
            output,
            array_types=(np.ndarray, NumpyArrayColumn),
            tensor_types=(torch.Tensor, TensorColumn),
        ):
            bool_output = True
        elif isinstance(output, (Sequence, AbstractColumn, torch.Tensor, np.ndarray)):
            list_output = True
            # A batched function may return one bool per example; judge by the
            # first element (an empty output is not treated as boolean).
            if is_batched_fn and self._is_boolean(self._first_item(output)):
                bool_output = True

        # Record shape/dtype for non-dict, non-None, non-Python-bool outputs.
        # The helper never calls len() on unsized values, so scalar outputs
        # simply yield None instead of raising TypeError.
        output_shape = output_dtype = None
        if not (dict_output or no_output or isinstance(output, bool)):
            output_shape = self._infer_output_attribute(output, "shape")
            output_dtype = self._infer_output_attribute(output, "dtype")

        return SimpleNamespace(
            output=output,
            dict_output=dict_output,
            no_output=no_output,
            bool_output=bool_output,
            list_output=list_output,
            updates_existing_column=updates_existing_column,
            existing_columns_updated=existing_columns_updated,
            output_shape=output_shape,
            output_dtype=output_dtype,
        )