from __future__ import annotations
import logging
import os
import warnings
from typing import Callable, Collection, Sequence, Union
import dill
import numpy as np
import yaml
from meerkat.block.abstract import BlockView
from meerkat.block.deferred_block import DeferredBlock, DeferredCellOp, DeferredOp
from meerkat.cells.abstract import AbstractCell
from meerkat.columns.abstract import Column
from meerkat.errors import ConcatWarning, ImmutableError
from meerkat.tools.lazy_loader import LazyLoader
# Handle to ``PIL.Image`` obtained through ``LazyLoader`` (presumably defers the
# real import until the module is first used — confirm against ``LazyLoader``).
Image = LazyLoader("PIL.Image")
# Module-level logger named after this module.
logger = logging.getLogger(__name__)
class DeferredCell(AbstractCell):
    """Cell wrapping a single ``DeferredCellOp``.

    Calling the cell, or calling :meth:`get`, returns whatever the wrapped
    op's ``_get()`` returns.
    """

    def __init__(self, data: DeferredCellOp):
        """Store the deferred op backing this cell."""
        self._data = data

    @property
    def data(self) -> object:
        """The deferred op wrapped by this cell."""
        return self._data

    def get(self, *args, **kwargs):
        """Return the result of the wrapped op's ``_get()``.

        Positional and keyword arguments are accepted but not forwarded.
        """
        deferred = self.data
        return deferred._get()

    def __eq__(self, other):
        """Cells are equal when they share a class and their data compare equal."""
        if not (other.__class__ == self.__class__):
            return False
        return self.data == other.data

    def __repr__(self):
        """Show the class name and the wrapped function's qualified name."""
        fn = self.data.fn
        # Fall back to ``repr(fn)`` for callables without ``__qualname__``.
        label = getattr(fn, "__qualname__", repr(fn))
        cls_name = self.__class__.__qualname__
        return f"{cls_name}(fn={label})"

    def __call__(self):
        """Return the result of the wrapped op's ``_get()``."""
        deferred = self.data
        return deferred._get()
class DeferredColumn(Column):
    """Column whose data is a ``DeferredOp`` (or a ``BlockView`` onto one).

    Integer indexing without materialization yields a :class:`DeferredCell`;
    array indexing without materialization yields a clone of this column over
    the selected data. With ``materialize=True`` the underlying data's
    ``_get`` result is returned (integer index) or collated into a new column
    (array index). Assignment into the column is rejected.
    """

    block_class: type = DeferredBlock

    def __init__(
        self,
        data: Union[DeferredOp, BlockView],
        output_type: type = None,
        *args,
        **kwargs,
    ):
        """Create the column.

        Args:
            data: The deferred op (or block view) backing the column.
            output_type: Optional type of the materialized output; consulted
                by :meth:`_get_default_formatter` to pick a formatter.
            *args: Forwarded to ``Column.__init__``.
            **kwargs: Forwarded to ``Column.__init__``.
        """
        # Set before calling the parent constructor, as in the original code.
        self._output_type = output_type
        super().__init__(data, *args, **kwargs)

    def __call__(self):
        """Materialize every row of the column."""
        # TODO(Sabri): Make this a more efficient call
        return self._get(index=np.arange(len(self)), materialize=True)

    def _set(self, index, value):
        """Reject assignment.

        Raises:
            ImmutableError: Always.
        """
        raise ImmutableError("LambdaColumn is immutable.")

    @property
    def fn(self) -> Callable:
        """Subclasses like `ImageColumn` should be able to implement their own
        version."""
        return self.data.fn

    def _create_cell(self, data: object) -> DeferredCell:
        """Wrap ``data`` in a :class:`DeferredCell`."""
        return DeferredCell(data=data)

    def _get(self, index, materialize: bool = False, _data: np.ndarray = None):
        """Index into the column.

        Args:
            index: Row selector; passed through ``self._translate_index``
                before use.
            materialize: Whether to return computed values rather than
                deferred cells/columns.
            _data: Not used by this implementation (presumably kept for
                signature compatibility with the base class — confirm).

        Returns:
            For an ``int`` index: the value returned by the underlying data's
            ``_get`` if ``materialize`` is true, else a :class:`DeferredCell`
            wrapping it. For an ``np.ndarray`` index: a new column built from
            the collated data if ``materialize`` is true, else a clone of this
            column over the selected data. Any other translated index type
            falls through and yields ``None``.
        """
        index = self._translate_index(index)
        data = self.data._get(index=index, materialize=materialize)

        if isinstance(index, int):
            if materialize:
                return data
            return self._create_cell(data=data)

        if isinstance(index, np.ndarray):
            # support for blocks
            if materialize:
                # materialize could change the data in unknown ways, cannot clone
                return self.__class__.from_data(data=self.collate(data))
            return self._clone(data=data)

        # NOTE(review): other index types silently produce ``None`` here;
        # behavior kept as-is because ``_translate_index`` is not visible.
        return None

    @classmethod
    def _state_keys(cls) -> Collection:
        """Return the parent's state keys plus ``"_output_type"``."""
        return super()._state_keys() | {"_output_type"}

    @staticmethod
    def concat(columns: Sequence[DeferredColumn]):
        """Concatenate deferred columns into one.

        Emits a single ``ConcatWarning`` if any column's ``fn`` differs from
        that of the first column. The result is a clone of the first column
        over the concatenated data, so ``columns`` must be non-empty.
        """
        for c in columns:
            if c.fn != columns[0].fn:
                warnings.warn(
                    ConcatWarning("Concatenating LambdaColumns with different `fn`.")
                )
                # One warning is enough; stop at the first mismatch.
                break

        return columns[0]._clone(data=DeferredOp.concat([c.data for c in columns]))

    def _write_data(self, path):
        """Write the underlying data to ``<path>/data``."""
        return self.data.write(os.path.join(path, "data"))

    def is_equal(self, other: Column) -> bool:
        """Return whether ``other`` has the same class and equal underlying data."""
        if other.__class__ != self.__class__:
            return False
        return self.data.is_equal(other.data)

    @staticmethod
    def _read_data(path: str):
        """Read the column's ``DeferredOp`` from ``path``.

        First tries ``DeferredOp.read`` on ``<path>/data``. If that raises
        ``KeyError``, falls back to the legacy on-disk layout
        (``<path>/data/meta.yaml`` plus ``<path>/state.dill``) and warns.

        Raises:
            ValueError: If the legacy metadata's ``dtype`` is not a ``Column``
                subclass.
        """
        data_path = os.path.join(path, "data")
        try:
            return DeferredOp.read(path=data_path)
        except KeyError:
            # TODO(Sabri): Remove this in a future version, once we no longer need to
            # support old DataFrames.
            warnings.warn(
                "Reading a LambdaColumn stored in a format that will soon be"
                " deprecated. Please re-write the column to the new format."
            )

            # SECURITY NOTE: ``yaml.load`` with ``FullLoader`` and ``dill.load``
            # can construct arbitrary objects / run arbitrary code. Only read
            # columns that come from a trusted source.
            # Context managers ensure the file handles are closed (the previous
            # code left both files open).
            with open(os.path.join(data_path, "meta.yaml")) as meta_file:
                meta = yaml.load(meta_file, Loader=yaml.FullLoader)

            if not issubclass(meta["dtype"], Column):
                raise ValueError(
                    "Support for LambdaColumns based on a DataFrame is deprecated."
                )
            col = Column.read(data_path)

            with open(os.path.join(path, "state.dill"), "rb") as state_file:
                state = dill.load(state_file)

            return DeferredOp(
                args=[col],
                kwargs={},
                fn=state["fn"],
                is_batched_fn=False,
                batch_size=1,
            )

    def _get_default_formatter(self) -> Callable:
        """Pick a formatter: ``PILImageFormatter`` for PIL image output,
        otherwise ``BasicFormatter``."""
        # Imported inside the method, as in the original code.
        from meerkat.interactive.formatter import BasicFormatter, PILImageFormatter

        if self._output_type == Image.Image:
            return PILImageFormatter()
        return BasicFormatter()

    def _repr_cell(self, idx):
        """Return the item at ``idx`` for use in the column's representation."""
        return self[idx]