"""DataPanel class."""
from __future__ import annotations
import logging
import os
import pathlib
from typing import (
Callable,
Dict,
Iterable,
List,
Mapping,
Optional,
Sequence,
Set,
Tuple,
Union,
)
import cytoolz as tz
import datasets
import dill
import numpy as np
import pandas as pd
import pyarrow as pa
import torch
import yaml
from pandas._libs import lib
import meerkat
from meerkat.block.manager import BlockManager
from meerkat.columns.abstract import AbstractColumn
from meerkat.columns.cell_column import CellColumn
from meerkat.mixins.cloneable import CloneableMixin
from meerkat.mixins.inspect_fn import FunctionInspectorMixin
from meerkat.mixins.lambdable import LambdaMixin
from meerkat.mixins.mapping import MappableMixin
from meerkat.mixins.materialize import MaterializationMixin
from meerkat.provenance import ProvenanceMixin, capture_provenance
from meerkat.tools.utils import MeerkatLoader, convert_to_batch_fn
logger = logging.getLogger(__name__)
Example = Dict
Batch = Dict[str, Union[List, AbstractColumn]]
BatchOrDataset = Union[Batch, "DataPanel"]
class DataPanel(
CloneableMixin,
FunctionInspectorMixin,
LambdaMixin,
MappableMixin,
MaterializationMixin,
ProvenanceMixin,
):
"""Meerkat DataPanel class."""
# Path to a log directory
logdir: pathlib.Path = pathlib.Path.home() / "meerkat/"
    # Create the log directory if it doesn't exist
logdir.mkdir(parents=True, exist_ok=True)
def __init__(
self,
data: Union[dict, list, datasets.Dataset] = None,
*args,
**kwargs,
):
super(DataPanel, self).__init__(
*args,
**kwargs,
)
logger.debug("Creating DataPanel.")
self.data = data
def _repr_pandas_(self, max_rows: int = None):
if max_rows is None:
max_rows = meerkat.config.DisplayOptions.max_rows
df, formatters = self.data._repr_pandas_(max_rows=max_rows)
rename = {k: f"{k} ({v.__class__.__name__})" for k, v in self.items()}
return (
df[self.columns].rename(columns=rename),
{rename[k]: v for k, v in formatters.items()},
)
def _repr_html_(self, max_rows: int = None):
if max_rows is None:
max_rows = meerkat.config.DisplayOptions.max_rows
df, formatters = self._repr_pandas_(max_rows=max_rows)
return df.to_html(formatters=formatters, max_rows=max_rows, escape=False)
    def streamlit(self):
return self._repr_pandas_()
def __repr__(self):
return (
f"{self.__class__.__name__}" f"(nrows: {self.nrows}, ncols: {self.ncols})"
)
def __len__(self):
return self.nrows
def __contains__(self, item):
return item in self.columns
@property
def data(self) -> BlockManager:
"""Get the underlying data (excluding invisible rows).
To access underlying data with invisible rows, use `_data`.
"""
return self._data
def _set_data(self, value: Union[BlockManager, Mapping] = None):
if isinstance(value, BlockManager):
self._data = value
elif isinstance(value, Mapping):
self._data = BlockManager.from_dict(value)
elif isinstance(value, Sequence):
if not isinstance(value[0], Mapping):
raise ValueError(
"Cannot set DataPanel `data` to a Sequence containing object of "
f" type {type(value[0])}. Must be a Sequence of Mapping."
)
gen = (list(x.keys()) for x in value)
columns = lib.fast_unique_multiple_list_gen(gen)
data = lib.dicts_to_array(value, columns=columns)
            # Unpack the 2D object array into a dict of lists, one per column
data = {column: list(data[:, idx]) for idx, column in enumerate(columns)}
self._data = BlockManager.from_dict(data)
elif value is None:
self._data = BlockManager()
else:
raise ValueError(
f"Cannot set DataPanel `data` to object of type {type(value)}."
)
@data.setter
def data(self, value):
self._set_data(value)
@property
def columns(self):
"""Column names in the DataPanel."""
return list(self.data.keys())
@property
def nrows(self):
"""Number of rows in the DataPanel."""
if self.ncols == 0:
return 0
return self.data.nrows
@property
def ncols(self):
"""Number of rows in the DataPanel."""
return self.data.ncols
@property
def shape(self):
"""Shape of the DataPanel (num_rows, num_columns)."""
return self.nrows, self.ncols
    def add_column(
self, name: str, data: AbstractColumn.Columnable, overwrite=False
) -> None:
"""Add a column to the DataPanel."""
        assert isinstance(
            name, str
        ), f"Column name must be of type `str`, not `{type(name)}`."
assert (name not in self.columns) or overwrite, (
f"Column with name `{name}` already exists, "
f"set `overwrite=True` to overwrite."
)
if name in self.columns:
self.remove_column(name)
column = AbstractColumn.from_data(data)
assert len(column) == len(self) or len(self.columns) == 0, (
f"`add_column` failed. "
f"Values length {len(column)} != dataset length {len(self)}."
)
# Add the column
self.data[name] = column
logger.info(f"Added column `{name}` with length `{len(column)}`.")
    def remove_column(self, column: str) -> None:
"""Remove a column from the dataset."""
assert column in self.columns, f"Column `{column}` does not exist."
# Remove the column
del self.data[column]
logger.info(f"Removed column `{column}`.")
    @capture_provenance(capture_args=["axis"])
def append(
self,
dp: DataPanel,
axis: Union[str, int] = "rows",
suffixes: Tuple[str] = None,
overwrite: bool = False,
) -> DataPanel:
"""Append a batch of data to the dataset.
`example_or_batch` must have the same columns as the dataset
(regardless of what columns are visible).
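
        Example (illustrative; `dp1` and `dp2` are assumed to share columns):
            >>> combined = dp1.append(dp2, axis="rows")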
"""
return meerkat.concat(
[self, dp], axis=axis, suffixes=suffixes, overwrite=overwrite
)
    def head(self, n: int = 5) -> DataPanel:
"""Get the first `n` examples of the DataPanel."""
return self.lz[:n]
    def tail(self, n: int = 5) -> DataPanel:
"""Get the last `n` examples of the DataPanel."""
return self.lz[-n:]
def _get(self, index, materialize: bool = False):
if isinstance(index, str):
# str index => column selection (AbstractColumn)
if index in self.columns:
return self.data[index]
raise KeyError(f"Column `{index}` does not exist.")
elif isinstance(index, int):
# int index => single row (dict)
return {
k: self.data[k]._get(index, materialize=materialize)
for k in self.columns
}
# cases where `index` returns a datapanel
index_type = None
if isinstance(index, slice):
# slice index => multiple row selection (DataPanel)
index_type = "row"
        elif isinstance(index, (tuple, list)) and len(index):
# tuple or list index => multiple row selection (DataPanel)
if isinstance(index[0], str):
index_type = "column"
else:
index_type = "row"
elif isinstance(index, np.ndarray):
if len(index.shape) != 1:
raise ValueError(
"Index must have 1 axis, not {}".format(len(index.shape))
)
# numpy array index => multiple row selection (DataPanel)
index_type = "row"
elif torch.is_tensor(index):
if len(index.shape) != 1:
raise ValueError(
"Index must have 1 axis, not {}".format(len(index.shape))
)
# torch tensor index => multiple row selection (DataPanel)
index_type = "row"
elif isinstance(index, pd.Series):
index_type = "row"
elif isinstance(index, AbstractColumn):
# column index => multiple row selection (DataPanel)
index_type = "row"
else:
raise TypeError("Invalid index type: {}".format(type(index)))
if index_type == "column":
if not set(index).issubset(self.columns):
missing_cols = set(index) - set(self.columns)
raise KeyError(f"DataPanel does not have columns {missing_cols}")
dp = self._clone(data=self.data[index])
return dp
elif index_type == "row": # pragma: no cover
return self._clone(
data=self.data.apply("_get", index=index, materialize=materialize)
)
# @capture_provenance(capture_args=[])
def __getitem__(self, index):
return self._get(index, materialize=True)
def __setitem__(self, index, value):
self.add_column(name=index, data=value, overwrite=True)
    def consolidate(self):
self.data.consolidate()
    @classmethod
def from_huggingface(cls, *args, **kwargs):
"""Load a Huggingface dataset as a DataPanel.

        Use this to replace `datasets.load_dataset`, so

        >>> dict_of_datasets = datasets.load_dataset('boolq')

        becomes

        >>> dict_of_datapanels = DataPanel.from_huggingface('boolq')
"""
# Load the dataset
dataset = datasets.load_dataset(*args, **kwargs)
        if isinstance(dataset, dict):
            return {
                name: cls.from_arrow(split._data)
                for name, split in dataset.items()
            }
else:
return cls.from_arrow(dataset._data)
    @classmethod
@capture_provenance()
def from_jsonl(
cls,
json_path: str,
) -> DataPanel:
"""Load a dataset from a .jsonl file on disk, where each line of the
json file consists of a single example."""
return cls.from_pandas(pd.read_json(json_path, orient="records", lines=True))
    @classmethod
@capture_provenance()
def from_batch(
cls,
batch: Batch,
) -> DataPanel:
"""Convert a batch to a Dataset."""
return cls(batch)
    @classmethod
@capture_provenance()
def from_batches(
cls,
batches: Sequence[Batch],
) -> DataPanel:
"""Convert a list of batches to a dataset."""
return cls.from_batch(
tz.merge_with(
tz.compose(list, tz.concat),
*batches,
),
)
    @classmethod
@capture_provenance()
def from_dict(
cls,
d: Dict,
) -> DataPanel:
"""Convert a dictionary to a dataset.
Alias for Dataset.from_batch(..).
"""
return cls.from_batch(
batch=d,
)
    @classmethod
@capture_provenance()
def from_pandas(
cls,
df: pd.DataFrame,
):
"""Create a Dataset from a pandas DataFrame."""
# column names must be str in meerkat
df = df.rename(mapper=str, axis="columns")
return cls.from_batch(
df.to_dict("series"),
)
    @classmethod
@capture_provenance()
def from_arrow(
cls,
table: pa.Table,
):
"""Create a Dataset from a pandas DataFrame."""
from meerkat.block.arrow_block import ArrowBlock
from meerkat.columns.arrow_column import ArrowArrayColumn
block_views = ArrowBlock.from_block_data(table)
return cls.from_batch(
{view.block_index: ArrowArrayColumn(view) for view in block_views}
)
    @classmethod
@capture_provenance(capture_args=["filepath"])
def from_csv(cls, filepath: str, *args, **kwargs):
"""Create a Dataset from a csv file.
Args:
filepath (str): The file path or buffer to load from.
Same as :func:`pandas.read_csv`.
*args: Argument list for :func:`pandas.read_csv`.
**kwargs: Keyword arguments for :func:`pandas.read_csv`.
Returns:
DataPanel: The constructed datapanel.
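
        Example (illustrative; `data.csv` is a hypothetical file):
            >>> dp = DataPanel.from_csv("data.csv")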
"""
return cls.from_pandas(pd.read_csv(filepath, *args, **kwargs))
    @classmethod
@capture_provenance()
def from_feather(
cls,
path: str,
):
"""Create a Dataset from a feather file."""
return cls.from_batch(
pd.read_feather(path).to_dict("list"),
)
    @capture_provenance()
def to_pandas(self) -> pd.DataFrame:
"""Convert a Dataset to a pandas DataFrame."""
return pd.DataFrame(
{
name: column.to_pandas().reset_index(drop=True)
for name, column in self.items()
}
)
    def to_jsonl(self, path: str) -> None:
"""Save a Dataset to a jsonl file."""
self.to_pandas().to_json(path, lines=True, orient="records")
def _get_collate_fns(self, columns: Iterable[str] = None):
columns = self.data.keys() if columns is None else columns
return {name: self.data[name].collate for name in columns}
def _collate(self, batch: List):
batch = tz.merge_with(list, *batch)
column_to_collate = self._get_collate_fns(batch.keys())
new_batch = {}
for name, values in batch.items():
new_batch[name] = column_to_collate[name](values)
dp = self._clone(data=new_batch)
return dp
@staticmethod
def _convert_to_batch_fn(
function: Callable, with_indices: bool, materialize: bool = True, **kwargs
    ) -> Callable:
return convert_to_batch_fn(
function=function,
with_indices=with_indices,
materialize=materialize,
**kwargs,
)
    def batch(
self,
batch_size: int = 1,
drop_last_batch: bool = False,
num_workers: int = 0,
materialize: bool = True,
shuffle: bool = False,
*args,
**kwargs,
):
"""Batch the dataset.

        Args:
            batch_size: integer batch size
            drop_last_batch: drop the last batch if it is smaller than
                `batch_size`
            num_workers: number of workers to use in the `DataLoader`
            shuffle: whether to shuffle the rows before batching

        Returns:
            batches of data
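
        Example (illustrative):
            >>> for batch in dp.batch(batch_size=32):
            ...     pass  # each `batch` is a DataPanel of up to 32 rows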
"""
cell_columns, batch_columns = [], []
from meerkat.columns.lambda_column import LambdaColumn
for name, column in self.items():
if isinstance(column, (CellColumn, LambdaColumn)) and materialize:
cell_columns.append(name)
else:
batch_columns.append(name)
indices = np.arange(len(self))
if shuffle:
indices = np.random.permutation(indices)
if batch_columns:
batch_indices = []
for i in range(0, len(self), batch_size):
if drop_last_batch and i + batch_size > len(self):
continue
batch_indices.append(indices[i : i + batch_size])
batch_dl = torch.utils.data.DataLoader(
self[batch_columns] if materialize else self[batch_columns].lz,
sampler=batch_indices,
batch_size=None,
batch_sampler=None,
drop_last=drop_last_batch,
num_workers=num_workers,
*args,
**kwargs,
)
if cell_columns:
dp = self[cell_columns] if not shuffle else self[cell_columns].lz[indices]
cell_dl = torch.utils.data.DataLoader(
dp if materialize else dp.lz,
batch_size=batch_size,
collate_fn=self._collate,
drop_last=drop_last_batch,
num_workers=num_workers,
*args,
**kwargs,
)
if batch_columns and cell_columns:
for cell_batch, batch_batch in zip(cell_dl, batch_dl):
yield self._clone(data={**cell_batch.data, **batch_batch.data})
elif batch_columns:
for batch_batch in batch_dl:
yield batch_batch
elif cell_columns:
for cell_batch in cell_dl:
yield cell_batch
    @capture_provenance(capture_args=["with_indices"])
def update(
self,
function: Optional[Callable] = None,
with_indices: bool = False,
input_columns: Optional[Union[str, List[str]]] = None,
is_batched_fn: bool = False,
batch_size: Optional[int] = 1,
remove_columns: Optional[List[str]] = None,
num_workers: int = 0,
output_type: Union[type, Dict[str, type]] = None,
mmap: bool = False,
mmap_path: str = None,
materialize: bool = True,
pbar: bool = False,
**kwargs,
) -> DataPanel:
"""Update the columns of the dataset."""
# TODO(karan): make this fn go faster
# most of the time is spent on the merge, speed it up further
# Return if `self` has no examples
if not len(self):
logger.info("Dataset empty, returning None.")
return self
# Get some information about the function
dp = self[input_columns] if input_columns is not None else self
function_properties = dp._inspect_function(
function, with_indices, is_batched_fn, materialize=materialize, **kwargs
)
assert (
function_properties.dict_output
), f"`function` {function} must return dict."
if not is_batched_fn:
# Convert to a batch function
function = convert_to_batch_fn(
function, with_indices=with_indices, materialize=materialize, **kwargs
)
logger.info(f"Converting `function` {function} to batched function.")
# Update always returns a new dataset
logger.info("Running update, a new dataset will be returned.")
# Copy the ._data dict with a reference to the actual columns
new_dp = self.view()
# Calculate the values for the new columns using a .map()
output = new_dp.map(
function=function,
with_indices=with_indices,
is_batched_fn=True,
batch_size=batch_size,
num_workers=num_workers,
output_type=output_type,
input_columns=input_columns,
mmap=mmap,
mmap_path=mmap_path,
materialize=materialize,
pbar=pbar,
**kwargs,
)
# Add new columns for the update
for col, vals in output.data.items():
new_dp.add_column(col, vals, overwrite=True)
# Remove columns
if remove_columns:
for col in remove_columns:
new_dp.remove_column(col)
logger.info(f"Removed columns {remove_columns}.")
return new_dp
    @capture_provenance()
def map(
self,
function: Optional[Callable] = None,
with_indices: bool = False,
input_columns: Optional[Union[str, List[str]]] = None,
is_batched_fn: bool = False,
batch_size: Optional[int] = 1,
drop_last_batch: bool = False,
num_workers: int = 0,
output_type: Union[type, Dict[str, type]] = None,
mmap: bool = False,
mmap_path: str = None,
materialize: bool = True,
pbar: bool = False,
**kwargs,
) -> Optional[Union[Dict, List, AbstractColumn]]:
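        """Map `function` over the DataPanel and return its outputs.

        Example (illustrative; assumes a `text` column exists):
            >>> lengths = dp.map(lambda x: len(x["text"]))
        """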
input_columns = self.columns if input_columns is None else input_columns
dp = self[input_columns]
return super(DataPanel, dp).map(
function=function,
with_indices=with_indices,
is_batched_fn=is_batched_fn,
batch_size=batch_size,
drop_last_batch=drop_last_batch,
num_workers=num_workers,
output_type=output_type,
mmap=mmap,
mmap_path=mmap_path,
materialize=materialize,
pbar=pbar,
**kwargs,
)
    @capture_provenance(capture_args=["function"])
def filter(
self,
function: Optional[Callable] = None,
with_indices=False,
input_columns: Optional[Union[str, List[str]]] = None,
is_batched_fn: bool = False,
batch_size: Optional[int] = 1,
drop_last_batch: bool = False,
num_workers: int = 0,
materialize: bool = True,
pbar: bool = False,
**kwargs,
) -> Optional[DataPanel]:
"""Filter operation on the DataPanel."""
# Return if `self` has no examples
if not len(self):
logger.info("DataPanel empty, returning None.")
return None
# Get some information about the function
dp = self[input_columns] if input_columns is not None else self
function_properties = dp._inspect_function(
function,
with_indices,
is_batched_fn=is_batched_fn,
materialize=materialize,
**kwargs,
)
assert function_properties.bool_output, "function must return boolean."
# Map to get the boolean outputs and indices
logger.info("Running `filter`, a new DataPanel will be returned.")
outputs = self.map(
function=function,
with_indices=with_indices,
input_columns=input_columns,
is_batched_fn=is_batched_fn,
batch_size=batch_size,
drop_last_batch=drop_last_batch,
num_workers=num_workers,
materialize=materialize,
pbar=pbar,
**kwargs,
)
indices = np.where(outputs)[0]
# filter returns a new datapanel
return self.lz[indices]
    def merge(
self,
right: meerkat.DataPanel,
how: str = "inner",
on: Union[str, List[str]] = None,
left_on: Union[str, List[str]] = None,
right_on: Union[str, List[str]] = None,
sort: bool = False,
suffixes: Sequence[str] = ("_x", "_y"),
validate=None,
):
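        """Merge this DataPanel with `right`, mirroring `pandas.merge`.

        Example (illustrative; both panels are assumed to share an `id` column):
            >>> merged = dp_left.merge(dp_right, how="inner", on="id")
        """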
from meerkat import merge
return merge(
self,
right,
how=how,
on=on,
left_on=left_on,
right_on=right_on,
sort=sort,
suffixes=suffixes,
validate=validate,
)
    def items(self):
for name in self.columns:
yield name, self.data[name]
    def keys(self):
return self.columns
    def values(self):
for name in self.columns:
yield self.data[name]
    @classmethod
def read(
cls,
path: str,
*args,
**kwargs,
) -> DataPanel:
"""Load a DataPanel stored on disk."""
# Load the metadata
metadata = dict(
yaml.load(open(os.path.join(path, "meta.yaml")), Loader=MeerkatLoader)
)
state = dill.load(open(os.path.join(path, "state.dill"), "rb"))
dp = cls.__new__(cls)
dp._set_state(state)
        # Load the block manager
mgr_dir = os.path.join(path, "mgr")
if os.path.exists(mgr_dir):
data = BlockManager.read(mgr_dir, **kwargs)
else:
            # backwards compatibility with pre-manager DataPanels
data = {
name: dtype.read(os.path.join(path, "columns", name), *args, **kwargs)
for name, dtype in metadata["column_dtypes"].items()
}
dp._set_data(data)
return dp
    def write(
self,
path: str,
) -> None:
"""Save a DataPanel to disk."""
# Make all the directories to the path
os.makedirs(path, exist_ok=True)
# Get the DataPanel state
state = self._get_state()
# Get the metadata
metadata = {
"dtype": type(self),
"column_dtypes": {name: type(col) for name, col in self.data.items()},
"len": len(self),
}
# write the block manager
mgr_dir = os.path.join(path, "mgr")
self.data.write(mgr_dir)
# Write the state
state_path = os.path.join(path, "state.dill")
dill.dump(state, open(state_path, "wb"))
# Save the metadata as a yaml file
metadata_path = os.path.join(path, "meta.yaml")
yaml.dump(metadata, open(metadata_path, "w"))
@classmethod
def _state_keys(cls) -> Set[str]:
"""List of attributes that describe the state of the object."""
return set()
def _view_data(self) -> object:
return self.data.view()
def _copy_data(self) -> object:
return self.data.copy()
def __finalize__(self, *args, **kwargs):
return self