from __future__ import annotations
import abc
import functools
import logging
import numbers
import os
from typing import Any, Callable, List, Sequence, Union
import numpy as np
import pandas as pd
import pyarrow as pa
import torch
from pandas.core.accessor import CachedAccessor
from pandas.core.arrays.categorical import CategoricalAccessor
from pandas.core.dtypes.common import (
is_categorical_dtype,
is_datetime64_dtype,
is_datetime64tz_dtype,
is_period_dtype,
is_timedelta64_dtype,
)
from pandas.core.dtypes.generic import ABCSeries
from pandas.core.indexes.accessors import (
CombinedDatetimelikeProperties,
DatetimeProperties,
PeriodProperties,
TimedeltaProperties,
)
from pandas.core.strings import StringMethods
from yaml.representer import Representer
from meerkat.block.abstract import BlockView
from meerkat.block.pandas_block import PandasBlock
from meerkat.columns.abstract import Column
from meerkat.interactive.formatter import Formatter
from meerkat.mixins.aggregate import AggregationError
from .abstract import ScalarColumn
Representer.add_representer(abc.ABCMeta, Representer.represent_name)
logger = logging.getLogger(__name__)
def getattr_decorator(fn: Callable):
@functools.wraps(fn)
def wrapper(*args, **kwargs):
out = fn(*args, **kwargs)
if isinstance(out, pd.Series):
return PandasScalarColumn(out)
elif isinstance(out, pd.DataFrame):
from meerkat import DataFrame
# column names must be str in meerkat
out = out.rename(mapper=str, axis="columns")
return DataFrame.from_pandas(out)
else:
return out
return wrapper
class _ReturnColumnMixin:
def __getattribute__(self, name):
try:
attr = super().__getattribute__(name)
if isinstance(attr, Callable):
return getattr_decorator(attr)
elif isinstance(attr, pd.Series):
return PandasScalarColumn(attr)
elif isinstance(attr, pd.DataFrame):
from meerkat import DataFrame
return DataFrame.from_pandas(attr)
else:
return attr
except AttributeError:
raise AttributeError(
f"'{self.__class__.__name__}' object has no attribute '{name}'"
)
class _MeerkatStringMethods(_ReturnColumnMixin, StringMethods):
pass
class _MeerkatDatetimeProperties(_ReturnColumnMixin, DatetimeProperties):
pass
class _MeerkatTimedeltaProperties(_ReturnColumnMixin, TimedeltaProperties):
pass
class _MeerkatPeriodProperties(_ReturnColumnMixin, PeriodProperties):
pass
class _MeerkatCategoricalAccessor(_ReturnColumnMixin, CategoricalAccessor):
pass
class _MeerkatCombinedDatetimelikeProperties(CombinedDatetimelikeProperties):
def __new__(cls, data: pd.Series):
# CombinedDatetimelikeProperties isn't really instantiated. Instead
# we need to choose which parent (datetime or timedelta) is
# appropriate. Since we're checking the dtypes anyway, we'll just
# do all the validation here.
if not isinstance(data, ABCSeries):
raise TypeError(
f"cannot convert an object of type {type(data)} to a datetimelike index"
)
orig = data if is_categorical_dtype(data.dtype) else None
if orig is not None:
data = data._constructor(
orig.array,
name=orig.name,
copy=False,
dtype=orig._values.categories.dtype,
)
if is_datetime64_dtype(data.dtype):
obj = _MeerkatDatetimeProperties(data, orig)
elif is_datetime64tz_dtype(data.dtype):
obj = _MeerkatDatetimeProperties(data, orig)
elif is_timedelta64_dtype(data.dtype):
obj = _MeerkatTimedeltaProperties(data, orig)
elif is_period_dtype(data.dtype):
obj = _MeerkatPeriodProperties(data, orig)
else:
raise AttributeError("Can only use .dt accessor with datetimelike values")
return obj
[docs]class PandasScalarColumn(
ScalarColumn,
np.lib.mixins.NDArrayOperatorsMixin,
):
block_class: type = PandasBlock
_HANDLED_TYPES = (np.ndarray, numbers.Number, str)
str = CachedAccessor("str", _MeerkatStringMethods)
dt = CachedAccessor("dt", _MeerkatCombinedDatetimelikeProperties)
cat = CachedAccessor("cat", _MeerkatCategoricalAccessor)
# plot = CachedAccessor("plot", pandas.plotting.PlotAccessor)
# sparse = CachedAccessor("sparse", SparseAccessor)
def _set_data(self, data: object):
if isinstance(data, BlockView):
if not isinstance(data.block, PandasBlock):
raise ValueError(
"Cannot create `PandasSeriesColumn` from a `BlockView` not "
"referencing a `PandasBlock`."
)
elif isinstance(data, pd.Series):
# Force the index to be contiguous so that comparisons between different
# pandas series columns are always possible.
data = data.reset_index(drop=True)
else:
data = pd.Series(data)
super(PandasScalarColumn, self)._set_data(data)
def __array_ufunc__(self, ufunc, method, *inputs, **kwargs):
out = kwargs.get("out", ())
for x in inputs + out:
# Only support operations with instances of _HANDLED_TYPES.
# Use ArrayLike instead of type(self) for isinstance to
# allow subclasses that don't override __array_ufunc__ to
# handle ArrayLike objects.
if not isinstance(x, self._HANDLED_TYPES + (PandasScalarColumn,)):
return NotImplemented
# Defer to the implementation of the ufunc on unwrapped values.
inputs = tuple(
x.data if isinstance(x, PandasScalarColumn) else x for x in inputs
)
if out:
kwargs["out"] = tuple(
x.data if isinstance(x, PandasScalarColumn) else x for x in out
)
result = getattr(ufunc, method)(*inputs, **kwargs)
if type(result) is tuple:
# multiple return values
return tuple(type(self)(x) for x in result) # pragma: no cover
elif method == "at":
# no return value
return None # pragma: no cover
else:
# one return value
return type(self)(result)
def __getattr__(self, name):
if name == "__getstate__" or name == "__setstate__":
# for pickle, it's important to raise an attribute error if __getstate__
# or __setstate__ is called. Without this, pickle will use the __setstate__
# and __getstate__ of the underlying pandas Series
raise AttributeError()
try:
out = getattr(object.__getattribute__(self, "data"), name)
if isinstance(out, Callable):
return getattr_decorator(out)
else:
return out
except AttributeError:
raise AttributeError(
f"'{self.__class__.__name__}' object has no attribute '{name}'"
)
@classmethod
def from_array(cls, data: np.ndarray, *args, **kwargs):
return cls(data=data, *args, **kwargs)
def _get(self, index, materialize: bool = True):
index = self.block_class._convert_index(index)
data = self._data.iloc[index]
if self._is_batch_index(index):
# only create a numpy array column
return self._clone(data=data)
else:
return data
def _set_cell(self, index, value):
self._data.iloc[index] = value
def _set_batch(self, indices, values):
self._data.iloc[indices] = values
@classmethod
def concat(cls, columns: Sequence[PandasScalarColumn]):
data = pd.concat([c.data for c in columns])
return columns[0]._clone(data=data)
def _write_data(self, path: str) -> None:
data_path = os.path.join(path, "data.pd")
self.data.to_pickle(data_path)
@staticmethod
def _read_data(
path: str,
):
data_path = os.path.join(path, "data.pd")
# Load in the data
return pd.read_pickle(data_path)
def _repr_cell(self, index) -> object:
return self[index]
def _get_default_formatter(self) -> Formatter:
# can't implement this as a class level property because then it will treat
# the formatter as a method
from meerkat.interactive.formatter import BasicFormatter
if len(self) == 0:
return BasicFormatter()
if self.dtype == object:
return BasicFormatter(dtype="str")
if self.dtype == pd.StringDtype:
return BasicFormatter(dtype="str")
cell = self[0]
if isinstance(cell, np.generic):
return BasicFormatter(dtype=type(cell.item()).__name__)
return BasicFormatter()
def _is_valid_primary_key(self):
return self.data.is_unique
def _keyidx_to_posidx(self, keyidx: Any) -> int:
# TODO(sabri): when we implement indices, we should use them here if we have
# one
where_result = np.where(self.data == keyidx)
if len(where_result[0]) == 0:
raise KeyError(f"keyidx {keyidx} not found in column.")
posidx = where_result[0][0]
return int(posidx)
def _keyidxs_to_posidxs(self, keyidxs: Sequence[Any]) -> np.ndarray:
posidxs = np.where(np.isin(self.data, keyidxs))[0]
diff = np.setdiff1d(keyidxs, self.data[posidxs])
if len(diff) > 0:
raise KeyError(f"Key indexes {diff} not found in column.")
return posidxs
[docs] def sort(
self, ascending: Union[bool, List[bool]] = True, kind: str = "quicksort"
) -> PandasScalarColumn:
"""Return a sorted view of the column.
Args:
ascending (Union[bool, List[bool]]): Whether to sort in ascending or
descending order. If a list, must be the same length as `by`. Defaults
to True.
kind (str): The kind of sort to use. Defaults to 'quicksort'. Options
include 'quicksort', 'mergesort', 'heapsort', 'stable'.
Return:
AbstractColumn: A view of the column with the sorted data.
"""
# calls argsort() function to retrieve ordered indices
sorted_index = self.argsort(ascending, kind)
return self[sorted_index]
[docs] def argsort(
self, ascending: bool = True, kind: str = "quicksort"
) -> PandasScalarColumn:
"""Return indices that would sorted the column.
Args:
ascending (Union[bool, List[bool]]): Whether to sort in ascending or
descending order. If a list, must be the same length as `by`. Defaults
to True.
kind (str): The kind of sort to use. Defaults to 'quicksort'. Options
include 'quicksort', 'mergesort', 'heapsort', 'stable'.
Return:
PandasSeriesColumn: A view of the column with the sorted data.
For now! Raises error when shape of input array is more than one error.
"""
num_columns = len(self.shape)
# Raise error if array has more than one column
if num_columns > 1:
raise Exception("No implementation for array with more than one column.")
# returns indices of descending order of array
if not ascending:
return (-1 * self.data).argsort(kind=kind)
# returns indices of ascending order of array
return self.data.argsort(kind=kind)
[docs] def to_tensor(self) -> torch.Tensor:
"""Use `column.to_tensor()` instead of `torch.tensor(column)`, which is
very slow."""
dtype = self.data.values.dtype
if not np.issubdtype(dtype, np.number):
raise ValueError(
f"Cannot convert `PandasSeriesColumn` with dtype={dtype} to tensor."
)
# TODO (Sabri): understand why `torch.tensor(column)` is so slow
return torch.tensor(self.data.values)
[docs] def to_numpy(self) -> torch.Tensor:
return self.values
[docs] def to_pandas(self, allow_objects: bool = False) -> pd.Series:
return self.data.reset_index(drop=True)
[docs] def to_arrow(self) -> pa.Array:
return pa.array(self.data.values)
[docs] def is_equal(self, other: Column) -> bool:
if other.__class__ != self.__class__:
return False
return (self.data.values == other.data.values).all()
def mean(self, skipna: bool = True):
try:
return self.data.mean(skipna=skipna)
except TypeError:
raise AggregationError(
"Cannot apply mean aggregation to Pandas Series with "
f" dtype '{self.data.dtype}'."
)
PandasSeriesColumn = PandasScalarColumn