Source code for meerkat.columns.lambda_column

from __future__ import annotations

import logging
import os
import warnings
from typing import Callable, Collection, Mapping, Sequence, Union

import numpy as np
import yaml

import meerkat as mk
from meerkat.cells.abstract import AbstractCell
from meerkat.columns.abstract import AbstractColumn
from meerkat.datapanel import DataPanel
from meerkat.display import lambda_cell_formatter
from meerkat.errors import ConcatWarning
from meerkat.tools.lazy_loader import LazyLoader

Image = LazyLoader("PIL.Image")


logger = logging.getLogger(__name__)


[docs]class LambdaCell(AbstractCell): def __init__( self, fn: callable = None, data: any = None, ): self.fn = fn self._data = data @property def data(self) -> object: """Get the data associated with this cell.""" return self._data
[docs] def get(self, *args, **kwargs): if isinstance(self.data, AbstractCell): return self.fn(self.data.get()) elif isinstance(self.data, Mapping): return self.fn( { k: v.get() if isinstance(v, AbstractCell) else v for k, v in self.data.items() } ) else: return self.fn(self.data)
def __eq__(self, other): return ( (other.__class__ == self.__class__) and (self.data == other.data) and (self.fn == other.fn) ) def __repr__(self): name = getattr(self.fn, "__qualname__", repr(self.fn)) return f"LambdaCell(fn={name})"
[docs]class LambdaColumn(AbstractColumn): def __init__( self, data: Union[DataPanel, AbstractColumn], fn: callable = None, output_type: type = None, *args, **kwargs, ): super(LambdaColumn, self).__init__(data.view(), *args, **kwargs) if fn is not None: self.fn = fn self._output_type = output_type def _set(self, index, value): raise ValueError("Cannot setitem on a `LambdaColumn`.")
[docs] def fn(self, data: object): """Subclasses like `ImageColumn` should be able to implement their own version.""" raise NotImplementedError
def _create_cell(self, data: object) -> LambdaCell: return LambdaCell(fn=self.fn, data=data) def _get_cell(self, index: int, materialize: bool = True): if materialize: return self.fn(self._data._get(index, materialize=True)) else: return self._create_cell(data=self._data._get(index, materialize=False)) def _get_batch(self, indices: np.ndarray, materialize: bool = True): if materialize: # if materializing, return a batch (by default, a list of objects returned # by `.get`, otherwise the batch format specified by `self.collate`) data = self.collate( [self._get_cell(int(i), materialize=True) for i in indices] ) if self._output_type is not None: data = self._output_type(data) return data else: return self._data.lz[indices] def _get(self, index, materialize: bool = True, _data: np.ndarray = None): index = self._translate_index(index) if isinstance(index, int): if _data is None: _data = self._get_cell(index, materialize=materialize) return _data elif isinstance(index, np.ndarray): # support for blocks if _data is None: _data = self._get_batch(index, materialize=materialize) if materialize: # materialize could change the data in unknown ways, cannot clone return self.__class__.from_data(data=_data) else: return self._clone(data=_data) @classmethod def _state_keys(cls) -> Collection: return super()._state_keys() | {"fn", "_output_type"}
[docs] @staticmethod def concat(columns: Sequence[LambdaColumn]): for c in columns: if c.fn != columns[0].fn: warnings.warn( ConcatWarning("Concatenating LambdaColumns with different `fn`.") ) break return columns[0]._clone(mk.concat([c._data for c in columns]))
def _write_data(self, path): # TODO (Sabri): avoid redundant writes in dataframes return self.data.write(os.path.join(path, "data"))
[docs] def is_equal(self, other: AbstractColumn) -> bool: if other.__class__ != self.__class__: return False if self.fn != other.fn: return False return self.data.is_equal(other.data)
@staticmethod def _read_data(path: str): meta = yaml.load( open(os.path.join(path, "data", "meta.yaml")), Loader=yaml.FullLoader, ) if issubclass(meta["dtype"], AbstractColumn): return AbstractColumn.read(os.path.join(path, "data")) else: return DataPanel.read(os.path.join(path, "data")) @staticmethod def _get_default_formatter() -> Callable: return lambda_cell_formatter def _repr_cell(self, idx): return self.lz[idx]