from __future__ import annotations

import abc
import logging
import pathlib
import reprlib
from copy import copy
from typing import Any, Callable, List, Optional, Sequence, Tuple, Union

import numpy as np
import pandas as pd
import torch

import meerkat.config
from meerkat.mixins.blockable import BlockableMixin
from meerkat.mixins.cloneable import CloneableMixin
from meerkat.mixins.collate import CollateMixin
from meerkat.mixins.inspect_fn import FunctionInspectorMixin
from meerkat.mixins.io import ColumnIOMixin
from meerkat.mixins.lambdable import LambdaMixin
from meerkat.mixins.mapping import MappableMixin
from meerkat.mixins.materialize import MaterializationMixin
from meerkat.provenance import ProvenanceMixin, capture_provenance
from meerkat.tools.utils import convert_to_batch_column_fn
# Module-level logger, named after this module; used for column creation and
# `filter` progress messages below.
logger = logging.getLogger(__name__)
class AbstractColumn(
    BlockableMixin,
    CloneableMixin,
    CollateMixin,
    ColumnIOMixin,
    FunctionInspectorMixin,
    LambdaMixin,
    MappableMixin,
    MaterializationMixin,
    ProvenanceMixin,
    abc.ABC,
):
    """An abstract class for Meerkat columns."""

    # Backing storage for the column; assigned through `_set_data`.
    _data: Sequence = None

    # Path to a log directory.
    logdir: pathlib.Path = pathlib.Path.home() / "meerkat/"
    # Create the directory. NOTE(review): this runs when the class body is
    # executed, i.e. as a side effect of importing this module.
    logdir.mkdir(parents=True, exist_ok=True)

    def __init__(
        self,
        data: Sequence = None,
        collate_fn: Callable = None,
        formatter: Callable = None,
        *args,
        **kwargs,
    ):
        """Create a column.

        Args:
            data (Sequence, optional): The data backing the column. For blockable
                column types, block views are unpacked first. Defaults to None.
            collate_fn (Callable, optional): Forwarded to the mixin constructors;
                presumably the function `collate` uses to combine a list of cells
                into a batch. Defaults to None.
            formatter (Callable, optional): Cell formatter used when the column
                is rendered to HTML. Defaults to `_get_default_formatter()`.
            *args: Forwarded to the mixin constructors.
            **kwargs: Forwarded to the mixin constructors.
        """
        # Assign to data
        self._set_data(data)

        super().__init__(*args, collate_fn=collate_fn, **kwargs)

        self._formatter = (
            formatter if formatter is not None else self._get_default_formatter()
        )

        # Log creation
        logger.info(f"Created `{self.__class__.__name__}` with {len(self)} rows.")

    def __repr__(self):
        return f"{self.__class__.__name__}({reprlib.repr(self.data)})"

    def __str__(self):
        return f"{self.__class__.__name__}({reprlib.repr(self.data)})"

    def streamlit(self):
        """Return the ``(series, formatter)`` pair produced by `_repr_pandas_`."""
        return self._repr_pandas_()

    def _set_data(self, data):
        """Store ``data`` as the backing data, unpacking block views first when
        this column type is blockable."""
        if self.is_blockable():
            data = self._unpack_block_view(data)
        self._data = data

    @property
    def data(self):
        """Get the underlying data."""
        return self._data

    @data.setter
    def data(self, value):
        self._set_data(value)

    @property
    def metadata(self):
        """Column metadata; the base implementation has none."""
        return {}

    @classmethod
    def _state_keys(cls) -> set:
        """List of attributes that describe the state of the object."""
        return {"_collate_fn", "_formatter"}

    def _get_cell(self, index: int, materialize: bool = True) -> Any:
        """Get a single cell from the column.

        Args:
            index (int): This is an index into ALL rows, not just visible rows. In
                other words, we assume that the index passed in has already been
                remapped via `_remap_index`, if `self.visible_rows` is not `None`.
            materialize (bool, optional): Materialize and return the object. This
                argument is used by subclasses of `AbstractColumn` that hold data
                in an unmaterialized format; the base implementation ignores it.
                Defaults to True.
        """
        return self._data[index]

    def _get_batch(
        self, indices: np.ndarray, materialize: bool = True
    ) -> AbstractColumn:
        """Get a batch of cells from the column and collate them.

        Args:
            indices (np.ndarray): Indices into ALL rows, not just visible rows. In
                other words, we assume that the indices passed in have already
                been remapped via `_remap_index`, if `self.visible_rows` is not
                `None`.
            materialize (bool, optional): Forwarded to `_get_cell` for every
                index. Defaults to True.
        """
        return self.collate(
            [self._get_cell(int(i), materialize=materialize) for i in indices]
        )

    def _get(self, index, materialize: bool = True, _data: np.ndarray = None):
        """Resolve ``index`` to a single cell or to a new column.

        Args:
            index: Any index accepted by `_translate_index`.
            materialize (bool): Forwarded to `_get_cell` / `_get_batch`.
            _data: Pre-fetched data. When given, it is used instead of reading
                from this column.

        Returns:
            The cell for an integer index, or a clone of this column wrapping the
            selected batch for an array index.
        """
        index = self._translate_index(index)
        if isinstance(index, int):
            if _data is None:
                _data = self._get_cell(index, materialize=materialize)
            return _data

        elif isinstance(index, np.ndarray):
            # support for blocks
            if _data is None:
                _data = self._get_batch(index, materialize=materialize)
            return self._clone(data=_data)

        # NOTE(review): `_translate_index` passes a length-1 tuple through
        # unchanged, so such an index lands here and `None` is returned —
        # confirm whether this is intended.
        return None

    # @capture_provenance()
    def __getitem__(self, index):
        return self._get(index, materialize=True)

    def _set_cell(self, index, value):
        """Write ``value`` to the single cell at ``index``."""
        self._data[index] = value

    def _set_batch(self, indices: np.ndarray, values):
        """Write each of ``values`` to the matching position in ``indices``."""
        for index, value in zip(indices, values):
            self._set_cell(int(index), value)

    def _set(self, index, value):
        """Assign ``value`` at ``index`` (any index accepted by `_translate_index`).

        Raises:
            ValueError: If the translated index is neither an int nor a sequence.
        """
        index = self._translate_index(index)
        if isinstance(index, int):
            self._set_cell(index, value)
        elif isinstance(index, (Sequence, np.ndarray)):
            self._set_batch(index, value)
        else:
            raise ValueError(f"Object of type {type(index)} is not a valid index")

    def __setitem__(self, index, value):
        self._set(index, value)

    def _is_batch_index(self, index):
        """Return True when ``index`` selects a batch rather than a single cell."""
        # np.ndarray indexed with a tuple of length 1 does not return an np.ndarray
        # so we match this behavior
        return not (
            isinstance(index, (int, np.integer))
            or (isinstance(index, tuple) and len(index) == 1)
        )

    def _translate_index(self, index):
        """Normalize ``index`` into the form `_get` / `_set` understand.

        Returns:
            For a single-element index: a plain ``int`` (or the unchanged
            length-1 tuple). Otherwise: a 1-D ``np.ndarray`` of indices.

        Raises:
            TypeError: If an array index has more than one axis, or ``index`` is
                of an unsupported type.
        """
        # `index` should return a single element
        if not self._is_batch_index(index):
            if isinstance(index, np.integer):
                # NumPy integer scalars (e.g. elements of `np.where(...)[0]`) are
                # converted so the downstream `isinstance(index, int)` checks match.
                return int(index)
            return index

        if isinstance(index, pd.Series):
            index = index.values

        if torch.is_tensor(index):
            # detach/cpu so tensors that require grad or live on an accelerator can
            # be converted too; both are no-ops for a plain CPU tensor
            index = index.detach().cpu().numpy()

        if isinstance(index, (tuple, list)):
            index = np.array(index)

        # `index` should return a batch
        if isinstance(index, slice):
            # int or slice index => standard list slicing
            # TODO (sabri): get rid of the np.arange here, very slow for large columns
            return np.arange(self.full_length())[index]

        if isinstance(index, np.ndarray):
            if index.ndim != 1:
                raise TypeError(
                    "`np.ndarray` index must have 1 axis, not {}".format(
                        len(index.shape)
                    )
                )
            if index.dtype == bool:
                # boolean mask -> positions of the True entries
                return np.where(index)[0]
            return index

        if isinstance(index, AbstractColumn):
            # TODO (sabri): get rid of the np.arange here, very slow for large columns
            return np.arange(self.full_length())[index]

        raise TypeError("Object of type {} is not a valid index".format(type(index)))

    @staticmethod
    def _convert_to_batch_fn(
        function: Callable, with_indices: bool, materialize: bool = True, **kwargs
    ) -> Callable:
        """Thin wrapper around `convert_to_batch_column_fn`."""
        return convert_to_batch_column_fn(
            function=function,
            with_indices=with_indices,
            materialize=materialize,
            **kwargs,
        )

    def __len__(self):
        return self.full_length()

    def full_length(self):
        """Length of the underlying data (0 when there is no data)."""
        if self._data is None:
            return 0
        return len(self._data)

    def _repr_cell_(self, index) -> object:
        raise NotImplementedError

    def _get_default_formatter(self) -> Callable:
        # can't implement this as a class level property because then it will treat
        # the formatter as a method
        return None

    @property
    def formatter(self) -> Callable:
        """The cell formatter used by `_repr_html_` (may be None)."""
        return self._formatter

    def _repr_pandas_(
        self, max_rows: int = None
    ) -> Tuple[pd.Series, Optional[Callable]]:
        """Build a Series of cell representations for display.

        When the column is longer than ``max_rows``, only the first and last
        ``max_rows // 2`` rows are kept, separated by one placeholder row that
        is filled with row 0's representation.

        Args:
            max_rows (int, optional): Row limit. Defaults to
                `meerkat.config.DisplayOptions.max_rows`.

        Returns:
            A ``(series, formatter)`` tuple.
        """
        if max_rows is None:
            max_rows = meerkat.config.DisplayOptions.max_rows

        num_rows = len(self)
        if num_rows > max_rows:
            half = max_rows // 2
            positions = [*range(half), 0, *range(num_rows - half, num_rows)]
        else:
            positions = range(num_rows)

        # NOTE(review): this calls `_repr_cell`, while the stub on this class is
        # named `_repr_cell_`; subclasses must define `_repr_cell` — confirm.
        col = pd.Series([self._repr_cell(idx) for idx in positions])
        return col, self.formatter

    def _repr_html_(self, max_rows: int = None):
        """Render the column as an HTML table."""
        # pd.Series objects do not implement _repr_html_
        if max_rows is None:
            max_rows = meerkat.config.DisplayOptions.max_rows

        num_rows = len(self)
        if num_rows > max_rows:
            half = max_rows // 2
            # row labels mirror the positions selected by `_repr_pandas_`
            pd_index = np.concatenate(
                (
                    np.arange(half),
                    np.zeros(1),
                    np.arange(num_rows - half, num_rows),
                ),
            )
        else:
            pd_index = np.arange(num_rows)

        col_name = f"({self.__class__.__name__})"
        col, formatter = self._repr_pandas_(max_rows=max_rows)
        df = col.to_frame(name=col_name)
        df = df.set_index(pd_index.astype(int))
        return df.to_html(
            max_rows=max_rows,
            formatters={col_name: formatter},
            escape=False,
        )

    @capture_provenance()
    def filter(
        self,
        function: Callable,
        with_indices=False,
        input_columns: Optional[Union[str, List[str]]] = None,
        is_batched_fn: bool = False,
        batch_size: Optional[int] = 1,
        drop_last_batch: bool = False,
        num_workers: Optional[int] = 0,
        materialize: bool = True,
        pbar: bool = False,
        **kwargs,
    ) -> Optional[AbstractColumn]:
        """Filter the elements of the column using a function.

        ``function`` is mapped over the column (via `map`). The rows for which it
        returns a truthy value are returned as a lazily indexed column.
        ``input_columns`` is accepted but currently unused.

        Returns:
            ``self`` if the column is empty, otherwise the filtered column.

        Raises:
            AssertionError: If `_inspect_function` reports that ``function``
                does not return booleans.
        """
        # Return if `self` has no examples
        if not len(self):
            logger.info("Dataset empty, returning it .")
            return self

        # Get some information about the function
        function_properties = self._inspect_function(
            function,
            with_indices,
            is_batched_fn=is_batched_fn,
            materialize=materialize,
            **kwargs,
        )
        if not function_properties.bool_output:
            # Explicit raise (not `assert`) so the check survives `python -O`;
            # AssertionError is kept for callers that already catch it.
            raise AssertionError("function must return boolean.")

        # Map to get the boolean outputs and indices
        logger.info("Running `filter`, a new dataset will be returned.")
        outputs = self.map(
            function=function,
            with_indices=with_indices,
            # input_columns=input_columns,
            is_batched_fn=is_batched_fn,
            batch_size=batch_size,
            drop_last_batch=drop_last_batch,
            num_workers=num_workers,
            materialize=materialize,
            pbar=pbar,
            **kwargs,
        )
        indices = np.where(outputs)[0]
        return self.lz[indices]

    def append(self, column: AbstractColumn) -> None:
        # TODO(Sabri): implement a naive `ComposedColumn` for generic append and
        # implement specific ones for ListColumn, NumpyColumn etc.
        raise NotImplementedError

    @staticmethod
    def concat(columns: Sequence[AbstractColumn]) -> None:
        # TODO(Sabri): implement a naive `ComposedColumn` for generic append and
        # implement specific ones for ListColumn, NumpyColumn etc.
        raise NotImplementedError

    def is_equal(self, other: AbstractColumn) -> bool:
        """Tests whether two columns are equal.

        Args:
            other (AbstractColumn): The column to compare against.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError()

    def batch(
        self,
        batch_size: int = 1,
        drop_last_batch: bool = False,
        collate: bool = True,
        num_workers: int = 0,
        materialize: bool = True,
        *args,
        **kwargs,
    ):
        """Batch the column.

        Args:
            batch_size: integer batch size
            drop_last_batch: drop the last batch if it's smaller than batch_size
            collate: whether to collate the returned batches (only honoured when
                neither `_get_batch` nor `_get` is overridden; otherwise batches
                come from `_get_batch`, which collates itself)
            num_workers: forwarded to `torch.utils.data.DataLoader`
            materialize: iterate over ``self`` if True, else over ``self.lz``

        Returns:
            A `torch.utils.data.DataLoader` yielding batches of data.
        """
        if (
            self._get_batch.__func__ == AbstractColumn._get_batch
            and self._get.__func__ == AbstractColumn._get
        ):
            # Default accessors: let the DataLoader fetch cells one at a time and
            # batch them itself.
            return torch.utils.data.DataLoader(
                self if materialize else self.lz,
                batch_size=batch_size,
                collate_fn=self.collate if collate else lambda x: x,
                drop_last=drop_last_batch,
                num_workers=num_workers,
                *args,
                **kwargs,
            )

        # Custom accessors: precompute index batches so each batch is fetched
        # with a single `self[indices]` call.
        batch_indices = []
        indices = np.arange(len(self))
        for i in range(0, len(self), batch_size):
            if drop_last_batch and i + batch_size > len(self):
                continue
            batch_indices.append(indices[i : i + batch_size])
        return torch.utils.data.DataLoader(
            self if materialize else self.lz,
            sampler=batch_indices,
            batch_size=None,
            batch_sampler=None,
            # The short final batch was already dropped above; DataLoader raises
            # ValueError if `drop_last=True` is combined with `batch_size=None`.
            drop_last=False,
            num_workers=num_workers,
            *args,
            **kwargs,
        )

    @classmethod
    def get_writer(cls, mmap: bool = False, template: AbstractColumn = None):
        """Return a `ConcatWriter` producing columns of this type.

        Raises:
            ValueError: If ``mmap`` is True (not supported for this column type).
        """
        from meerkat.writers.concat_writer import ConcatWriter

        if mmap:
            raise ValueError("Memmapping not supported with this column type.")
        return ConcatWriter(output_type=cls, template=template)

    # Types that `from_data` knows how to wrap in a column.
    Columnable = Union[Sequence, np.ndarray, pd.Series, torch.Tensor]

    @classmethod
    # @capture_provenance()
    def from_data(cls, data: Union[Columnable, AbstractColumn]):
        """Convert data to a meerkat column using the appropriate Column
        type.

        Dispatch order:
            - an existing column is returned as-is;
            - ``pd.Series`` -> `PandasSeriesColumn`;
            - tensor -> `TensorColumn`;
            - ``np.ndarray`` -> `NumpyArrayColumn`;
            - a non-empty ``Sequence`` is dispatched on the type of its first
              element;
            - any other ``Sequence`` -> `ListColumn`.

        Raises:
            ValueError: If ``data`` is none of the supported types.
        """
        if isinstance(data, AbstractColumn):
            # TODO: Need to make this view but should decide where to do it exactly
            return data  # .view()

        if isinstance(data, pd.Series):
            from .pandas_column import PandasSeriesColumn

            return PandasSeriesColumn(data)

        if torch.is_tensor(data):
            from .tensor_column import TensorColumn

            return TensorColumn(data)

        if isinstance(data, np.ndarray):
            from .numpy_column import NumpyArrayColumn

            return NumpyArrayColumn(data)

        if isinstance(data, Sequence):
            from ..cells.abstract import AbstractCell

            if len(data) != 0 and isinstance(data[0], AbstractCell):
                from .cell_column import CellColumn

                return CellColumn(data)

            if len(data) != 0 and isinstance(
                data[0], (int, float, bool, np.ndarray, np.generic)
            ):
                from .numpy_column import NumpyArrayColumn

                return NumpyArrayColumn(data)

            if len(data) != 0 and isinstance(data[0], str):
                from .pandas_column import PandasSeriesColumn

                return PandasSeriesColumn(data)

            if len(data) != 0 and torch.is_tensor(data[0]):
                from .tensor_column import TensorColumn

                return TensorColumn(data)

            from .list_column import ListColumn

            return ListColumn(data)

        raise ValueError(f"Cannot create column out of data of type {type(data)}")

    def head(self, n: int = 5) -> AbstractColumn:
        """Get the first `n` examples of the column."""
        return self.lz[:n]

    def tail(self, n: int = 5) -> AbstractColumn:
        """Get the last `n` examples of the column.

        ``tail(0)`` returns an empty column. A negative ``n`` returns all but the
        first ``abs(n)`` examples.
        """
        if n == 0:
            # `-0 == 0`, so `self.lz[-n:]` would select every row instead of none
            return self.lz[:0]
        return self.lz[-n:]

    def to_pandas(self) -> pd.Series:
        """Return the column's cells (read via `lz`) as a ``pd.Series``."""
        return pd.Series([self.lz[idx] for idx in range(len(self))])

    def _copy_data(self) -> object:
        """Shallow copy of the underlying data."""
        return copy(self._data)

    def _view_data(self) -> object:
        """The underlying data itself (no copy)."""
        return self._data

    @property
    def is_mmap(self):
        return False