Source code for meerkat.columns.lambda_column

from __future__ import annotations

import logging
import os
import warnings
from typing import Callable, Collection, Sequence, Union

import dill
import numpy as np
import yaml

from meerkat.block.abstract import BlockView
from meerkat.block.lambda_block import LambdaBlock, LambdaCellOp, LambdaOp
from meerkat.cells.abstract import AbstractCell
from meerkat.columns.abstract import AbstractColumn
from meerkat.display import lambda_cell_formatter
from meerkat.errors import ConcatWarning, ImmutableError
from meerkat.tools.lazy_loader import LazyLoader

Image = LazyLoader("PIL.Image")


logger = logging.getLogger(__name__)


class LambdaCell(AbstractCell):
    """Cell that wraps a single ``LambdaCellOp``."""

    def __init__(self, data: LambdaCellOp):
        # The op is stored as-is; nothing is evaluated at construction time.
        self._data = data

    @property
    def data(self) -> object:
        """Return the op wrapped by this cell."""
        return self._data

    def get(self, *args, **kwargs):
        """Return the result of the wrapped op's ``_get()``.

        Positional and keyword arguments are accepted but not forwarded.
        """
        op = self.data
        return op._get()

    def __eq__(self, other):
        """Cells are equal when they share a class and their ops compare equal."""
        if not (other.__class__ == self.__class__):
            return False
        return self.data == other.data

    def __repr__(self):
        """Show the class name and the wrapped function's qualified name."""
        wrapped_fn = self.data.fn
        # Functions without a ``__qualname__`` are shown via their ``repr``.
        fallback = repr(wrapped_fn)
        label = getattr(wrapped_fn, "__qualname__", fallback)
        cls_name = self.__class__.__qualname__
        return f"{cls_name}(fn={label})"
class LambdaColumn(AbstractColumn):
    """Immutable column whose data is a ``LambdaOp`` (or a ``BlockView``).

    Indexing is delegated to ``self.data._get``; writes through ``_set``
    always raise ``ImmutableError``.
    """

    block_class: type = LambdaBlock

    def __init__(
        self,
        data: Union[LambdaOp, BlockView],
        output_type: type = None,
        *args,
        **kwargs,
    ):
        """Create the column.

        Args:
            data: The op (or block view) backing this column.
            output_type: Stored on the instance as ``_output_type`` and
                included in the column's state keys.
            *args: Forwarded to ``AbstractColumn.__init__``.
            **kwargs: Forwarded to ``AbstractColumn.__init__``.
        """
        super(LambdaColumn, self).__init__(data, *args, **kwargs)
        self._output_type = output_type

    def _set(self, index, value):
        """Reject every write: the column is immutable.

        Raises:
            ImmutableError: Always.
        """
        raise ImmutableError("LambdaColumn is immutable.")

    @property
    def fn(self) -> Callable:
        """Return the function of the underlying op.

        Subclasses like `ImageColumn` should be able to implement their
        own version.
        """
        return self.data.fn

    def _create_cell(self, data: object) -> LambdaCell:
        """Wrap ``data`` in a ``LambdaCell``."""
        return LambdaCell(data=data)

    def _get(self, index, materialize: bool = True, _data: np.ndarray = None):
        """Fetch the item(s) at ``index``.

        Args:
            index: Index to fetch; it is first passed through
                ``self._translate_index``.
            materialize: Forwarded to ``self.data._get``. It also selects
                the return form (see Returns).
            _data: Unused by this implementation.

        Returns:
            For an ``int`` index: the fetched data when ``materialize`` is
            true, otherwise a ``LambdaCell`` wrapping it. For an
            ``np.ndarray`` index: a new column built from the collated data
            when ``materialize`` is true, otherwise a clone of this column
            over the fetched data. Any other translated index type falls
            through and returns ``None``.
        """
        index = self._translate_index(index)
        data = self.data._get(index=index, materialize=materialize)

        if isinstance(index, int):
            if materialize:
                return data
            return self._create_cell(data=data)

        if isinstance(index, np.ndarray):
            # support for blocks
            if materialize:
                # materialize could change the data in unknown ways, cannot clone
                return self.__class__.from_data(data=self.collate(data))
            return self._clone(data=data)

    @classmethod
    def _state_keys(cls) -> Collection:
        """Return the parent's state keys plus ``"_output_type"``."""
        return super()._state_keys() | {"_output_type"}

    @staticmethod
    def concat(columns: Sequence[LambdaColumn]):
        """Concatenate ``columns`` into one column cloned from the first.

        A single ``ConcatWarning`` is emitted if any column's ``fn``
        differs from the first column's ``fn``.
        """
        if any(c.fn != columns[0].fn for c in columns):
            warnings.warn(
                ConcatWarning("Concatenating LambdaColumns with different `fn`.")
            )
        return columns[0]._clone(data=LambdaOp.concat([c.data for c in columns]))

    def _write_data(self, path):
        """Write the underlying op to ``<path>/data``."""
        # TODO (Sabri): avoid redundant writes in dataframes
        return self.data.write(os.path.join(path, "data"))

    def is_equal(self, other: AbstractColumn) -> bool:
        """Return whether ``other`` has the same class and equal data."""
        if other.__class__ != self.__class__:
            return False
        return self.data.is_equal(other.data)

    @staticmethod
    def _read_data(path: str):
        """Read the op stored under ``path``.

        The current format is read with ``LambdaOp.read``. If that raises
        ``KeyError``, the data is assumed to be in the older on-disk format
        and is rebuilt from ``data/meta.yaml`` and ``state.dill``.

        Raises:
            ValueError: If the legacy metadata's ``dtype`` is not an
                ``AbstractColumn`` subclass.
        """
        try:
            return LambdaOp.read(path=os.path.join(path, "data"))
        except KeyError:
            # TODO(Sabri): Remove this in a future version, once we no longer need to
            # support old DataPanels.
            warnings.warn(
                "Reading a LambdaColumn stored in a format that will soon be"
                " deprecated. Please re-write the column to the new format."
            )

            # NOTE(security): `yaml.load` and `dill.load` can execute code
            # embedded in the files; only read columns from trusted sources.
            with open(os.path.join(path, "data", "meta.yaml")) as meta_file:
                meta = yaml.load(meta_file, Loader=yaml.FullLoader)

            if not issubclass(meta["dtype"], AbstractColumn):
                raise ValueError(
                    "Support for LambdaColumns based on a DataPanel is deprecated."
                )
            col = AbstractColumn.read(os.path.join(path, "data"))

            with open(os.path.join(path, "state.dill"), "rb") as state_file:
                state = dill.load(state_file)

            return LambdaOp(
                args=[col],
                kwargs={},
                fn=state["fn"],
                is_batched_fn=False,
                batch_size=1,
            )

    @staticmethod
    def _get_default_formatter() -> Callable:
        """Return the default formatter, ``lambda_cell_formatter``."""
        return lambda_cell_formatter

    def _repr_cell(self, idx):
        """Return ``self.lz[idx]`` as the value used to represent the cell."""
        return self.lz[idx]