Source code for meerkat.tools.utils

from collections import defaultdict
from typing import Callable, Dict, List, Optional, Sequence

import yaml
from yaml.constructor import ConstructorError


class MeerkatLoader(yaml.FullLoader):
    """YAML loader that is allowed to import ``meerkat`` modules on demand.

    PyYaml does not load unimported modules for safety reasons. This loader
    relaxes that restriction only for dotted names under the ``meerkat.``
    package. When the base ``FullLoader`` lookup fails with a
    ``ConstructorError``, the relevant module is imported and the lookup is
    retried. Any other name re-raises the original ``ConstructorError``.
    """

    def find_python_module(self, name: str, mark, unsafe=False):
        """Resolve a module name from a YAML tag, importing ``meerkat`` modules.

        Args:
            name: Dotted module name taken from the YAML tag.
            mark: PyYAML position marker, forwarded to the base class.
            unsafe: Forwarded unchanged to ``FullLoader.find_python_module``.

        Returns:
            Whatever the base class ``find_python_module`` returns.

        Raises:
            ConstructorError: If the base lookup fails and ``name`` is not
                under the ``meerkat.`` package.
            ImportError: If a ``meerkat.`` module cannot be imported.
        """
        try:
            return super().find_python_module(name=name, mark=mark, unsafe=unsafe)
        except ConstructorError:
            if not name.startswith("meerkat."):
                # Keep the base loader's safety behaviour for foreign modules.
                raise
        # Import the meerkat module ourselves, then let the base class
        # resolve it now that it is loaded.
        __import__(name)
        return super().find_python_module(name=name, mark=mark, unsafe=unsafe)

    def find_python_name(self, name: str, mark, unsafe=False):
        """Resolve a dotted object name from a YAML tag, importing ``meerkat``.

        Names under the legacy ``meerkat.nn`` package are first rewritten to
        ``meerkat.ml`` for backwards compatibility with old serialized files.

        Args:
            name: Dotted ``module.attribute`` name taken from the YAML tag.
                A name without a dot is looked up in ``builtins``.
            mark: PyYAML position marker, forwarded to the base class.
            unsafe: Forwarded unchanged to ``FullLoader.find_python_name``.

        Returns:
            Whatever the base class ``find_python_name`` returns.

        Raises:
            ConstructorError: If the base lookup fails and ``name`` is not
                under the ``meerkat.`` package.
            ImportError: If the containing ``meerkat.`` module cannot be
                imported.
        """
        legacy_prefix, current_prefix = "meerkat.nn", "meerkat.ml"
        # Rewrite only the package prefix itself. A substring match plus
        # str.replace would also mangle unrelated names (e.g. "meerkat.nnx").
        if name == legacy_prefix or name.startswith(legacy_prefix + "."):
            name = current_prefix + name[len(legacy_prefix):]

        # Module that has to be imported for the retry to succeed.
        module_name = name.rsplit(".", 1)[0] if "." in name else "builtins"

        try:
            return super().find_python_name(name=name, mark=mark, unsafe=unsafe)
        except ConstructorError:
            if not name.startswith("meerkat."):
                # Keep the base loader's safety behaviour for foreign names.
                raise
        __import__(module_name)
        return super().find_python_name(name=name, mark=mark, unsafe=unsafe)
def convert_to_batch_column_fn(
    function: Callable, with_indices: bool, materialize: bool = True, **kwargs
):
    """Lift a per-element function into one that processes a whole column batch.

    Args:
        function: Called once per element of the batch. When ``with_indices``
            is true, it receives that element's index as its second positional
            argument. It is followed by any extra ``*args`` / ``**kwargs``
            given to the batched callable.
        with_indices: If true, the returned callable has the signature
            ``(batch, indices, *args, **kwargs)``. Otherwise it is
            ``(batch, *args, **kwargs)``.
        materialize: If true, elements are read with ``batch[i]``. Otherwise
            they are read through ``batch.lz[i]``.
        **kwargs: Accepted but not used.

    Returns:
        A batched callable whose result depends on ``function``'s first result:

        - a plain ``dict`` mapping each key to the list of per-element values,
          when that first result is a ``dict``;
        - otherwise, a list of the per-element results;
        - ``None`` for an empty batch.
    """

    def _function(batch: Sequence, indices: Optional[List[int]], *args, **kwargs):
        collected = None
        for position in range(len(batch)):
            # Read the element eagerly, or through the lazy ``lz`` indexer.
            element = batch.lz[position] if not materialize else batch[position]
            if with_indices:
                result = function(element, indices[position], *args, **kwargs)
            else:
                result = function(element, *args, **kwargs)

            is_mapping = isinstance(result, dict)
            if position == 0:
                # The very first result fixes the container type for the batch.
                collected = defaultdict(list) if is_mapping else []
            if not is_mapping:
                collected.append(result)
                continue
            for key in result.keys():
                collected[key].append(result[key])

        if not isinstance(collected, dict):
            return collected
        # Hand back a plain dict rather than the internal defaultdict.
        return dict(collected)

    if with_indices:
        return _function

    def _drop_indices(batch, *args, **kwargs):
        # Callers that do not track indices get a signature without that slot.
        return _function(batch, None, *args, **kwargs)

    return _drop_indices
def convert_to_batch_fn(
    function: Callable, with_indices: bool, materialize: bool = True, **kwargs
):
    """Batch a function that applies to a single example.

    The returned callable takes a batch given as a mapping from column name
    to column. For every row it rebuilds a ``{column_name: value}`` example
    dict, applies ``function`` to it, and collects the results.

    Args:
        function: Called once per row with the example dict. When
            ``with_indices`` is true, it receives the row index as its second
            positional argument. It is followed by any extra ``*args`` /
            ``**kwargs`` given to the batched callable.
        with_indices: If true, the returned callable has the signature
            ``(batch, indices, *args, **kwargs)``. Otherwise it is
            ``(batch, *args, **kwargs)``.
        materialize: If true, row values are read with ``column[i]``.
            Otherwise they are read through ``column.lz[i]``.
        **kwargs: Accepted but not used.

    Returns:
        A batched callable whose result depends on ``function``'s first result:

        - a plain ``dict`` mapping each key to the list of per-row values,
          when that first result is a ``dict``;
        - otherwise, a list of the per-row results;
        - ``None`` when the batch has no rows, including a batch with no
          columns at all.
    """
    _no_columns = object()

    def _function(
        batch: Dict[str, List], indices: Optional[List[int]], *args, **kwargs
    ):
        # The row count is taken from the first column; other columns are
        # assumed (not checked) to have the same length. A batch without
        # columns is treated as empty instead of raising IndexError.
        first_column = next(iter(batch.values()), _no_columns)
        batch_size = 0 if first_column is _no_columns else len(first_column)

        outputs = None
        for i in range(batch_size):
            example = {k: v[i] if materialize else v.lz[i] for k, v in batch.items()}
            if with_indices:
                output = function(example, indices[i], *args, **kwargs)
            else:
                output = function(example, *args, **kwargs)

            if i == 0:
                # The first result decides the output container type.
                outputs = defaultdict(list) if isinstance(output, dict) else []
            if isinstance(output, dict):
                for k in output.keys():
                    outputs[k].append(output[k])
            else:
                outputs.append(output)

        # Hand back a plain dict rather than the internal defaultdict.
        return dict(outputs) if isinstance(outputs, dict) else outputs

    if with_indices:
        # Just return the function as is
        return _function
    # Wrap to supply the missing indices argument
    return lambda batch, *args, **kwargs: _function(batch, None, *args, **kwargs)