from __future__ import annotations
import abc
import functools
import logging
import numbers
import os
from typing import Callable, Sequence
import numpy as np
import pandas as pd
import torch
from pandas.core.accessor import CachedAccessor
from pandas.core.arrays.categorical import CategoricalAccessor
from pandas.core.dtypes.common import (
is_categorical_dtype,
is_datetime64_dtype,
is_datetime64tz_dtype,
is_period_dtype,
is_timedelta64_dtype,
)
from pandas.core.dtypes.generic import ABCSeries
from pandas.core.indexes.accessors import (
CombinedDatetimelikeProperties,
DatetimeProperties,
PeriodProperties,
TimedeltaProperties,
)
from pandas.core.strings import StringMethods
from yaml.representer import Representer
from meerkat.block.abstract import BlockView
from meerkat.block.pandas_block import PandasBlock
from meerkat.columns.abstract import AbstractColumn
# Let PyYAML serialize abc.ABCMeta-based classes (e.g. column classes) by name.
Representer.add_representer(abc.ABCMeta, Representer.represent_name)
# Module-level logger, per the project-wide logging convention.
logger = logging.getLogger(__name__)
def getattr_decorator(fn: Callable):
    """Wrap ``fn`` so pandas return values come back as Meerkat objects.

    A ``pd.Series`` result is wrapped in a :class:`PandasSeriesColumn`, a
    ``pd.DataFrame`` result is converted to a ``DataPanel`` (with column
    names coerced to ``str``), and any other result is passed through
    unchanged.
    """

    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        out = fn(*args, **kwargs)
        if isinstance(out, pd.Series):
            return PandasSeriesColumn(out)
        elif isinstance(out, pd.DataFrame):
            # deferred import to avoid a circular dependency at module load
            from meerkat import DataPanel

            # column names must be str in meerkat
            out = out.rename(mapper=str, axis="columns")
            return DataPanel.from_pandas(out)
        else:
            return out

    return wrapper
class _ReturnColumnMixin:
def __getattribute__(self, name):
try:
attr = super().__getattribute__(name)
if isinstance(attr, Callable):
return getattr_decorator(attr)
elif isinstance(attr, pd.Series):
return PandasSeriesColumn(attr)
elif isinstance(attr, pd.DataFrame):
from meerkat import DataPanel
return DataPanel.from_pandas(attr)
else:
return attr
except AttributeError:
raise AttributeError(
f"'{self.__class__.__name__}' object has no attribute '{name}'"
)
class _MeerkatStringMethods(_ReturnColumnMixin, StringMethods):
    """``.str`` accessor whose results are returned as Meerkat objects."""
    pass
class _MeerkatDatetimeProperties(_ReturnColumnMixin, DatetimeProperties):
    """``.dt`` accessor (datetime dtypes) returning Meerkat objects."""
    pass
class _MeerkatTimedeltaProperties(_ReturnColumnMixin, TimedeltaProperties):
    """``.dt`` accessor (timedelta dtype) returning Meerkat objects."""
    pass
class _MeerkatPeriodProperties(_ReturnColumnMixin, PeriodProperties):
    """``.dt`` accessor (period dtype) returning Meerkat objects."""
    pass
class _MeerkatCategoricalAccessor(_ReturnColumnMixin, CategoricalAccessor):
    """``.cat`` accessor whose results are returned as Meerkat objects."""
    pass
class _MeerkatCombinedDatetimelikeProperties(CombinedDatetimelikeProperties):
    """Dispatching ``.dt`` accessor that picks the Meerkat variant by dtype."""

    def __new__(cls, data: pd.Series):
        # CombinedDatetimelikeProperties isn't really instantiated. Instead
        # we need to choose which parent (datetime or timedelta) is
        # appropriate. Since we're checking the dtypes anyway, we'll just
        # do all the validation here.
        if not isinstance(data, ABCSeries):
            raise TypeError(
                f"cannot convert an object of type {type(data)} to a datetimelike index"
            )

        # For categorical series, operate on the categories' dtype while
        # remembering the original series.
        orig = data if is_categorical_dtype(data.dtype) else None
        if orig is not None:
            data = data._constructor(
                orig.array,
                name=orig.name,
                copy=False,
                dtype=orig._values.categories.dtype,
            )

        dtype = data.dtype
        if is_datetime64_dtype(dtype) or is_datetime64tz_dtype(dtype):
            return _MeerkatDatetimeProperties(data, orig)
        if is_timedelta64_dtype(dtype):
            return _MeerkatTimedeltaProperties(data, orig)
        if is_period_dtype(dtype):
            return _MeerkatPeriodProperties(data, orig)
        raise AttributeError("Can only use .dt accessor with datetimelike values")
class PandasSeriesColumn(
    AbstractColumn,
    np.lib.mixins.NDArrayOperatorsMixin,
):
    """A Meerkat column backed by a ``pd.Series``.

    Supports numpy ufuncs via ``NDArrayOperatorsMixin`` and exposes the
    pandas ``str``/``dt``/``cat`` accessors, with results converted back
    into Meerkat objects.
    """

    block_class: type = PandasBlock

    # Types (besides this column class) accepted by `__array_ufunc__`.
    _HANDLED_TYPES = (np.ndarray, numbers.Number, str)

    # pandas-style accessors whose results are wrapped as Meerkat objects.
    str = CachedAccessor("str", _MeerkatStringMethods)
    dt = CachedAccessor("dt", _MeerkatCombinedDatetimelikeProperties)
    cat = CachedAccessor("cat", _MeerkatCategoricalAccessor)
    # plot = CachedAccessor("plot", pandas.plotting.PlotAccessor)
    # sparse = CachedAccessor("sparse", SparseAccessor)

    def _set_data(self, data: object):
        """Validate and store the backing data, coercing it to ``pd.Series``.

        Raises:
            ValueError: if ``data`` is a ``BlockView`` that does not
                reference a ``PandasBlock``.
        """
        if isinstance(data, BlockView):
            if not isinstance(data.block, PandasBlock):
                raise ValueError(
                    "Cannot create `PandasSeriesColumn` from a `BlockView` not "
                    "referencing a `PandasBlock`."
                )
        elif not isinstance(data, pd.Series):
            data = pd.Series(data)
        super(PandasSeriesColumn, self)._set_data(data)

    def __array_ufunc__(self, ufunc, method, *inputs, **kwargs):
        """Apply a numpy ufunc to the unwrapped series, rewrapping results."""
        out = kwargs.get("out", ())
        for x in inputs + out:
            # Only support operations with instances of _HANDLED_TYPES.
            # Use ArrayLike instead of type(self) for isinstance to
            # allow subclasses that don't override __array_ufunc__ to
            # handle ArrayLike objects.
            if not isinstance(x, self._HANDLED_TYPES + (PandasSeriesColumn,)):
                return NotImplemented

        # Defer to the implementation of the ufunc on unwrapped values.
        inputs = tuple(
            x.data if isinstance(x, PandasSeriesColumn) else x for x in inputs
        )
        if out:
            kwargs["out"] = tuple(
                x.data if isinstance(x, PandasSeriesColumn) else x for x in out
            )
        result = getattr(ufunc, method)(*inputs, **kwargs)

        if type(result) is tuple:
            # multiple return values
            return tuple(type(self)(x) for x in result)  # pragma: no cover
        elif method == "at":
            # no return value
            return None  # pragma: no cover
        else:
            # one return value
            return type(self)(result)

    def __getattr__(self, name):
        """Delegate unknown attributes to the underlying ``pd.Series``."""
        if name == "__getstate__" or name == "__setstate__":
            # for pickle, it's important to raise an attribute error if
            # __getstate__ or __setstate__ is called. Without this, pickle
            # will use the __setstate__ and __getstate__ of the underlying
            # pandas Series
            raise AttributeError()
        try:
            out = getattr(object.__getattribute__(self, "data"), name)
            if isinstance(out, Callable):
                # wrap so pandas return values come back as Meerkat objects
                return getattr_decorator(out)
            else:
                return out
        except AttributeError:
            raise AttributeError(
                f"'{self.__class__.__name__}' object has no attribute '{name}'"
            )

    @classmethod
    def from_array(cls, data: np.ndarray, *args, **kwargs):
        """Construct a column from an array-like of values."""
        return cls(data=data, *args, **kwargs)

    def _get(self, index, materialize: bool = True):
        """Select by position: a batch index yields a new column, a scalar
        index yields the raw cell value."""
        index = self.block_class._convert_index(index)
        data = self._data.iloc[index]
        if self._is_batch_index(index):
            # batch index: wrap the selected sub-series in a new column
            return self._clone(data=data)
        else:
            return data

    def _set_cell(self, index, value):
        self._data.iloc[index] = value

    def _set_batch(self, indices, values):
        self._data.iloc[indices] = values

    @classmethod
    def concat(cls, columns: Sequence[PandasSeriesColumn]):
        """Concatenate columns into one, cloned from the first column."""
        data = pd.concat([c.data for c in columns])
        return columns[0]._clone(data=data)

    def _write_data(self, path: str) -> None:
        """Pickle the backing series to ``<path>/data.pd``."""
        data_path = os.path.join(path, "data.pd")
        self.data.to_pickle(data_path)

    @staticmethod
    def _read_data(
        path: str,
    ):
        """Load a pickled series from ``<path>/data.pd``."""
        data_path = os.path.join(path, "data.pd")
        return pd.read_pickle(data_path)

    def _repr_cell(self, index) -> object:
        return self[index]

    def to_tensor(self) -> torch.Tensor:
        """Use `column.to_tensor()` instead of `torch.tensor(column)`, which is
        very slow.

        Raises:
            ValueError: if the column's dtype is non-numeric.
        """
        dtype = self.data.values.dtype
        if not np.issubdtype(dtype, np.number):
            raise ValueError(
                f"Cannot convert `PandasSeriesColumn` with dtype={dtype} to tensor."
            )
        # TODO (Sabri): understand why `torch.tensor(column)` is so slow
        return torch.tensor(self.data.values)

    def is_equal(self, other: AbstractColumn) -> bool:
        """Return True iff ``other`` is the same column class with
        element-wise equal values."""
        if other.__class__ != self.__class__:
            return False
        return (self.data.values == other.data.values).all()

    def to_pandas(self) -> pd.Series:
        """Return the underlying ``pd.Series`` (no copy)."""
        return self.data