Source code for meerkat.tools.utils

from collections import defaultdict
from functools import reduce
from typing import Callable, Dict, List, Optional, Sequence

import numpy as np
import pandas as pd
import torch
import yaml
from yaml.constructor import ConstructorError


def nested_getattr(obj, attr, *args):
    """Resolve a dotted attribute path such as ``"a.b.c"`` starting at ``obj``.

    Each component of ``attr`` (split on ``"."``) is looked up in turn with
    ``getattr``. Any extra positional argument is forwarded to every
    ``getattr`` call. Passing a default therefore makes each step fall back
    to that default instead of raising ``AttributeError``.

    Source: https://stackoverflow.com/questions/31174295/getattr-and-setattr-on-nested-subobjects-chained-properties  # noqa: E501
    """
    current = obj
    for part in attr.split("."):
        current = getattr(current, part, *args)
    return current
class MeerkatLoader(yaml.FullLoader):
    """PyYaml does not load unimported modules for safety reasons. We want to
    allow importing only meerkat modules.

    When the base ``FullLoader`` lookup raises ``ConstructorError``, this
    loader handles names under the ``meerkat.`` namespace specially: it
    imports the relevant module and retries the lookup once. Any other name
    re-raises the original error.
    """

    def find_python_module(self, name: str, mark, unsafe=False):
        """Resolve module ``name``, importing it first if it is a meerkat module.

        ``mark`` and ``unsafe`` are forwarded unchanged to the base
        implementation.
        """
        try:
            return super().find_python_module(name=name, mark=mark, unsafe=unsafe)
        except ConstructorError as err:
            # Only meerkat modules may be imported on demand.
            if not name.startswith("meerkat."):
                raise err
            __import__(name)
        # Retry now that the module has been imported.
        return super().find_python_module(name=name, mark=mark, unsafe=unsafe)

    def find_python_name(self, name: str, mark, unsafe=False):
        """Resolve dotted ``name``, importing its meerkat module first if needed.

        Any occurrence of ``"meerkat.nn"`` in ``name`` is rewritten to
        ``"meerkat.ml"`` before lookup. ``mark`` and ``unsafe`` are forwarded
        unchanged to the base implementation.
        """
        if "meerkat.nn" in name:
            # backwards compatibility with old name
            name = name.replace("meerkat.nn", "meerkat.ml")
        # The module part is everything before the last dot.
        module_name = name.rsplit(".", 1)[0] if "." in name else "builtins"
        try:
            return super().find_python_name(name=name, mark=mark, unsafe=unsafe)
        except ConstructorError as err:
            if not name.startswith("meerkat."):
                raise err
            __import__(module_name)
        # Retry now that the defining module has been imported.
        return super().find_python_name(name=name, mark=mark, unsafe=unsafe)
def convert_to_batch_column_fn(
    function: Callable, with_indices: bool, materialize: bool = True, **kwargs
):
    """Batch a function that applies to an example.

    The returned callable walks ``batch`` element by element. For each
    element it calls ``function`` with:

    - the element;
    - then its index, when ``with_indices`` is true;
    - then any extra ``*args`` / ``**kwargs``.

    When ``materialize`` is false, elements are read via ``batch.lz[i]``
    instead of ``batch[i]``.

    Outputs are gathered according to the type of the FIRST output:

    - a ``dict`` first output yields a ``dict`` mapping each key to the list
      of its values;
    - any other first output yields a ``list`` of outputs;
    - an empty batch yields ``None``.

    Signature of the returned callable:

    - ``(batch, indices, *args, **kwargs)`` when ``with_indices`` is true;
    - ``(batch, *args, **kwargs)`` otherwise.

    Extra ``**kwargs`` given to this factory are not used.
    """

    def _element(batch, position):
        # Pick the materialized or lazy accessor for one element.
        return batch[position] if materialize else batch.lz[position]

    def _function(batch: Sequence, indices: Optional[List[int]], *args, **kwargs):
        collected = None
        for position in range(len(batch)):
            element = _element(batch, position)
            index_arg = (indices[position],) if with_indices else ()
            result = function(element, *index_arg, *args, **kwargs)

            if position == 0:
                # The first output decides whether we gather into a dict or a list.
                collected = defaultdict(list) if isinstance(result, dict) else []

            if isinstance(result, dict):
                for key in result.keys():
                    collected[key].append(result[key])
            else:
                collected.append(result)

        return dict(collected) if isinstance(collected, dict) else collected

    if not with_indices:
        # Callers will not pass indices, so supply ``None`` on their behalf.
        def _without_indices(batch, *args, **kwargs):
            return _function(batch, None, *args, **kwargs)

        return _without_indices
    return _function
def convert_to_batch_fn(
    function: Callable, with_indices: bool, materialize: bool = True, **kwargs
):
    """Batch a function that applies to an example.

    ``batch`` is a mapping from column name to column. The batch size is
    taken from the column under the first key, so an empty mapping raises
    ``IndexError``.

    For each position ``i`` an example ``{key: column[i]}`` is assembled.
    When ``materialize`` is false, ``column.lz[i]`` is used instead. The
    example is passed to ``function``, followed by its index when
    ``with_indices`` is true, then any extra ``*args`` / ``**kwargs``.

    Outputs are gathered according to the type of the FIRST output:

    - a ``dict`` first output yields a ``dict`` mapping each key to the list
      of its values;
    - any other first output yields a ``list`` of outputs;
    - a zero-length batch yields ``None``.

    Signature of the returned callable:

    - ``(batch, indices, *args, **kwargs)`` when ``with_indices`` is true;
    - ``(batch, *args, **kwargs)`` otherwise.

    Extra ``**kwargs`` given to this factory are not used.
    """

    def _example(batch, position):
        # Build the ``position``-th example as a {column name: value} dict.
        return {
            key: column[position] if materialize else column.lz[position]
            for key, column in batch.items()
        }

    def _function(
        batch: Dict[str, List], indices: Optional[List[int]], *args, **kwargs
    ):
        first_key = list(batch.keys())[0]
        num_rows = len(batch[first_key])

        collected = None
        for position in range(num_rows):
            example = _example(batch, position)
            index_arg = (indices[position],) if with_indices else ()
            result = function(example, *index_arg, *args, **kwargs)

            if position == 0:
                # The first output decides whether we gather into a dict or a list.
                collected = defaultdict(list) if isinstance(result, dict) else []

            if isinstance(result, dict):
                for key in result.keys():
                    collected[key].append(result[key])
            else:
                collected.append(result)

        return dict(collected) if isinstance(collected, dict) else collected

    if not with_indices:
        # Callers will not pass indices, so supply ``None`` on their behalf.
        def _without_indices(batch, *args, **kwargs):
            return _function(batch, None, *args, **kwargs)

        return _without_indices
    return _function
def translate_index(index, length: int):
    """Normalize ``index`` into a single-element index or an array of positions.

    Single-element indices are returned unchanged. These are Python ints,
    NumPy integer scalars, and length-1 tuples. The length-1 tuple case
    mirrors NumPy, where indexing with a 1-tuple does not return an
    ``np.ndarray``.

    Every other index is treated as a batch index and converted as follows:

    - ``pd.Series`` -> its ``.values``;
    - ``torch.Tensor`` -> a NumPy array (detached and moved to CPU first);
    - ``list`` / ``tuple`` -> ``np.array``; an empty sequence becomes an empty
      integer array so it is still a valid NumPy index;
    - ``slice`` -> ``np.arange`` of the positions it selects among ``length``
      elements;
    - 1-D boolean ``np.ndarray`` -> the positions where it is ``True``;
    - any other 1-D ``np.ndarray`` -> returned unchanged;
    - ``AbstractColumn`` -> ``np.arange(length)[index]``.

    Args:
        index: the index to translate.
        length: number of elements in the object being indexed; used to
            resolve slices and column indices.

    Returns:
        ``index`` itself for single-element indices. Otherwise, the
        translated batch index described above.

    Raises:
        TypeError: if an array index does not have exactly one axis, or if
            ``index`` is of an unsupported type.
    """

    def _is_batch_index(index):
        # np.ndarray indexed with a tuple of length 1 does not return an
        # np.ndarray, so we match this behavior. NumPy integer scalars (e.g.
        # np.int64 obtained by iterating an array) are not ``int`` instances
        # but select a single element just like Python ints.
        return not (
            isinstance(index, (int, np.integer))
            or (isinstance(index, tuple) and len(index) == 1)
        )

    # `index` should return a single element
    if not _is_batch_index(index):
        return index

    if isinstance(index, pd.Series):
        index = index.values

    if torch.is_tensor(index):
        # ``Tensor.numpy()`` refuses tensors that require grad or are not on
        # the CPU, so detach and move to CPU first (no-ops otherwise).
        index = index.detach().cpu().numpy()

    if isinstance(index, (tuple, list)):
        index = np.array(index)
        if index.size == 0:
            # np.array([]) is float64, which NumPy rejects as an index array.
            index = index.astype(np.intp)

    # `index` should return a batch
    if isinstance(index, slice):
        # int or slice index => standard list slicing
        return np.arange(*index.indices(length))

    if isinstance(index, np.ndarray):
        if len(index.shape) != 1:
            raise TypeError(
                "`np.ndarray` index must have 1 axis, not {}".format(len(index.shape))
            )
        if index.dtype == bool:
            return np.where(index)[0]
        return index

    # Project-local import kept inside the function as in the original
    # (presumably to avoid an import cycle); it is only needed for this
    # last branch.
    from ..columns.abstract import AbstractColumn

    if isinstance(index, AbstractColumn):
        # TODO (sabri): get rid of the np.arange here, very slow for large columns
        return np.arange(length)[index]

    raise TypeError("Object of type {} is not a valid index".format(type(index)))