Source code for meerkat.columns.numpy_column

from __future__ import annotations

import abc
import functools
import logging
import numbers
import os
import shutil
from mmap import mmap
from typing import Callable, Sequence

import numpy as np
import pandas as pd
import torch
from yaml.representer import Representer

from meerkat.block.abstract import BlockView
from meerkat.block.numpy_block import NumpyBlock
from meerkat.columns.abstract import AbstractColumn
from meerkat.writers.concat_writer import ConcatWriter

# Register PyYAML's `represent_name` as the representer for objects whose type
# is `abc.ABCMeta` (i.e. abstract-base classes), so that such classes can be
# passed to the YAML dumper. This is a module-level side effect on the shared
# `Representer` class.
Representer.add_representer(abc.ABCMeta, Representer.represent_name)

# Module-scoped logger named after this module.
logger = logging.getLogger(__name__)


def getattr_decorator(fn: Callable):
    """Wrap ``fn`` so that ``np.ndarray`` results come back as columns.

    Args:
        fn: The callable to wrap.

    Returns:
        A callable with the same metadata as ``fn`` (via ``functools.wraps``)
        that forwards all arguments to ``fn``. If the call produces an
        ``np.ndarray`` the array is wrapped in a ``NumpyArrayColumn``; any
        other result is returned untouched.
    """

    @functools.wraps(fn)
    def wrapped_call(*call_args, **call_kwargs):
        result = fn(*call_args, **call_kwargs)
        # Only raw arrays are re-wrapped; scalars and other objects pass through.
        return NumpyArrayColumn(result) if isinstance(result, np.ndarray) else result

    return wrapped_call
class NumpyArrayColumn(
    AbstractColumn,
    np.lib.mixins.NDArrayOperatorsMixin,
):
    """Column whose underlying data is a NumPy array.

    The column accepts a ``BlockView`` that references a ``NumpyBlock``, an
    ``np.memmap`` (kept as-is), or any other sequence (converted with
    ``np.asarray``). Arithmetic and comparison operators are provided by
    ``NDArrayOperatorsMixin`` and routed through ``__array_ufunc__``;
    attributes that the column itself does not define are looked up on the
    underlying array.
    """

    # Block type associated with this column class.
    block_class: type = NumpyBlock

    def __init__(
        self,
        data: Sequence,
        *args,
        **kwargs,
    ):
        """Create a column from ``data``.

        Args:
            data: A ``BlockView`` referencing a ``NumpyBlock``, an
                ``np.memmap``, or any sequence accepted by ``np.asarray``.
            *args: Forwarded to the parent constructor.
            **kwargs: Forwarded to the parent constructor.

        Raises:
            ValueError: If ``data`` is a ``BlockView`` whose block is not a
                ``NumpyBlock``.
        """
        if isinstance(data, BlockView):
            if not isinstance(data.block, NumpyBlock):
                raise ValueError(
                    "Cannot create `NumpyArrayColumn` from a `BlockView` not "
                    "referencing a `NumpyBlock`."
                )
        elif not isinstance(data, np.memmap):
            # memmaps are passed through untouched; everything else is
            # converted to an ndarray.
            data = np.asarray(data)
        super(NumpyArrayColumn, self).__init__(data=data, *args, **kwargs)

    # TODO (sabri): need to support str here
    _HANDLED_TYPES = (np.ndarray, numbers.Number)

    def __array_ufunc__(self, ufunc: np.ufunc, method, *inputs, **kwargs):
        """Apply a NumPy ufunc to the unwrapped data of the operands.

        Args:
            ufunc: The ufunc being invoked.
            method: Name of the ufunc method being called (e.g.
                ``"__call__"`` or ``"at"``).
            *inputs: Positional operands of the ufunc call.
            **kwargs: Keyword arguments of the ufunc call; an ``out`` entry,
                if present, is unwrapped as well.

        Returns:
            ``NotImplemented`` if any operand has an unsupported type; a tuple
            of columns when the ufunc returns a tuple; ``None`` for the
            ``"at"`` method; otherwise a clone of this column holding the
            result.
        """
        out = kwargs.get("out", ())
        for x in inputs + out:
            # Only support operations with instances of _HANDLED_TYPES.
            # Use ArrayLike instead of type(self) for isinstance to
            # allow subclasses that don't override __array_ufunc__ to
            # handle ArrayLike objects.
            if not isinstance(x, self._HANDLED_TYPES + (NumpyArrayColumn,)) and not (
                # support for at index
                method == "at"
                and isinstance(x, list)
            ):
                return NotImplemented

        # Defer to the implementation of the ufunc on unwrapped values.
        inputs = tuple(x.data if isinstance(x, NumpyArrayColumn) else x for x in inputs)
        if out:
            kwargs["out"] = tuple(
                x.data if isinstance(x, NumpyArrayColumn) else x for x in out
            )
        result = getattr(ufunc, method)(*inputs, **kwargs)

        if type(result) is tuple:
            # multiple return values
            return tuple(type(self)(x) for x in result)
        elif method == "at":
            # no return value
            return None
        else:
            # one return value
            return self._clone(data=result)

    def __getattr__(self, name):
        """Fall back to the underlying array for unknown attributes.

        Callable attributes are wrapped with ``getattr_decorator`` so that
        ``np.ndarray`` results are returned as ``NumpyArrayColumn``.

        Raises:
            AttributeError: If the attribute lookup fails.
        """
        try:
            # `object.__getattribute__` is used to fetch `data` without
            # re-entering this `__getattr__`.
            out = getattr(object.__getattribute__(self, "data"), name)
            if isinstance(out, Callable):
                return getattr_decorator(out)
            else:
                return out
        except AttributeError:
            raise AttributeError(
                f"'{self.__class__.__name__}' object has no attribute '{name}'"
            )

    @classmethod
    def from_array(cls, data: np.ndarray, *args, **kwargs):
        """Build a column from ``data``; equivalent to calling ``cls`` directly."""
        return cls(data=data, *args, **kwargs)

    def _set_batch(self, indices, values):
        """Assign ``values`` to the positions ``indices`` of the stored array."""
        self._data[indices] = values

    def _get(self, index, materialize: bool = True):
        """Index into the stored array.

        Args:
            index: Index to look up; normalized by
                ``NumpyBlock._convert_index`` first.
            materialize: Not used by this implementation.

        Returns:
            A clone of this column holding the selection when
            ``self._is_batch_index(index)`` is true, otherwise the raw indexed
            value.
        """
        index = NumpyBlock._convert_index(index)
        data = self._data[index]
        if self._is_batch_index(index):
            # only create a numpy array column
            return self._clone(data=data)
        else:
            return data

    def _copy_data(self) -> object:
        """Return a copy of the stored array."""
        return self._data.copy()

    def _view_data(self) -> object:
        """Return the stored array itself (no copy)."""
        return self._data

    @property
    def is_mmap(self):
        """Whether the data is a memmap backed directly by an ``mmap`` object."""
        # important to check if .base is a python mmap object, since a view of a mmap
        # is also a memmap object, but should not be symlinked or copied
        return isinstance(self.data, np.memmap) and isinstance(self.data.base, mmap)

    def _write_data(self, path: str, link: bool = True) -> None:
        """Write the data to ``<path>/data.npy``.

        Args:
            path: Directory in which ``data.npy`` is created.
            link: For memory-mapped data, symlink to the backing file when
                true; copy the backing file otherwise. Ignored for data that
                is not memory-mapped, which is saved with ``np.save``.
        """
        path = os.path.join(path, "data.npy")
        # important to check if .base is a python mmap object, since a view of a mmap
        # is also a memmap object, but should not be symlinked
        if self.is_mmap:
            if link:
                os.symlink(self.data.filename, path)
            else:
                shutil.copy(self.data.filename, path)
        else:
            np.save(path, self.data)

    @staticmethod
    def _read_data(path: str, mmap=False, *args, **kwargs) -> np.ndarray:
        """Load ``<path>/data.npy``.

        Args:
            path: Directory containing ``data.npy``.
            mmap: When true, load with ``mmap_mode="r"`` instead of reading
                the array into memory.
            *args: Accepted but not used.
            **kwargs: Accepted but not used.
        """
        data_path = os.path.join(path, "data.npy")
        if mmap:
            return np.load(data_path, mmap_mode="r")
        return np.load(data_path)

    @classmethod
    def concat(cls, columns: Sequence[NumpyArrayColumn]):
        """Concatenate the data of ``columns`` into a clone of the first column."""
        data = np.concatenate([c.data for c in columns])
        return columns[0]._clone(data=data)

    def is_equal(self, other: AbstractColumn) -> bool:
        """Return whether ``other`` has the same class and equal data.

        NaNs at matching positions are treated as equal. For dtypes on which
        ``np.isnan`` is undefined (e.g. strings or objects) the comparison
        falls back to plain element-wise equality instead of raising.
        """
        if other.__class__ != self.__class__:
            return False
        try:
            return np.array_equal(self.data, other.data, equal_nan=True)
        except TypeError:
            # `equal_nan=True` evaluates `np.isnan`, which raises TypeError for
            # non-numeric dtypes; such data cannot hold NaN, so a plain
            # comparison is sufficient.
            return np.array_equal(self.data, other.data)

    @classmethod
    def get_writer(cls, mmap: bool = False, template: AbstractColumn = None):
        """Return a writer for this column type.

        Args:
            mmap: When true, return a ``NumpyMemmapWriter``; otherwise a
                ``ConcatWriter`` producing ``NumpyArrayColumn``.
            template: Passed to ``ConcatWriter`` as ``template``; unused when
                ``mmap`` is true.
        """
        if mmap:
            # Imported here rather than at module level, as in the original.
            from meerkat.writers.numpy_writer import NumpyMemmapWriter

            return NumpyMemmapWriter()
        else:
            return ConcatWriter(template=template, output_type=NumpyArrayColumn)

    def _repr_cell(self, index) -> object:
        """Return the display value for the cell at ``index``.

        Columns with more than one dimension are summarized by the shape of a
        single cell; one-dimensional columns return the element itself.
        """
        if len(self.shape) > 1:
            return f"np.ndarray(shape={self.shape[1:]})"
        else:
            return self[index]

    def to_tensor(self) -> torch.Tensor:
        """Use `column.to_tensor()` instead of `torch.tensor(column)`, which is
        very slow."""
        # TODO (Sabri): understand why `torch.tensor(column)` is so slow
        return torch.tensor(self.data)

    def to_pandas(self) -> pd.Series:
        """Convert to a ``pd.Series``.

        One-dimensional data is wrapped directly; otherwise the parent
        implementation is used.
        """
        if len(self.shape) == 1:
            return pd.Series(self.data)
        else:
            # can only create a 1-D series
            return super().to_pandas()

    @classmethod
    def from_npy(
        cls,
        path,
        mmap_mode=None,
        allow_pickle=False,
        fix_imports=True,
        encoding="ASCII",
    ):
        """Create a column from a file loaded with ``np.load``.

        All arguments are forwarded to ``np.load`` unchanged.
        """
        data = np.load(
            path,
            mmap_mode=mmap_mode,
            allow_pickle=allow_pickle,
            fix_imports=fix_imports,
            encoding=encoding,
        )
        return cls(data)