from __future__ import annotations
import os
from typing import Sequence, Set
import pyarrow as pa
import torch
from pyarrow.compute import equal
from meerkat.block.abstract import BlockView
from meerkat.block.arrow_block import ArrowBlock
from meerkat.errors import ImmutableError
from ..abstract import Column
from .abstract import ScalarColumn
class ArrowScalarColumn(ScalarColumn):
    """Scalar column whose values are held in a pyarrow array.

    The underlying data is a ``pa.Array`` or ``pa.ChunkedArray``. The column
    is immutable: any attempt to assign to it raises ``ImmutableError``.
    """

    # Block type associated with this column class.
    block_class: type = ArrowBlock

    def __init__(
        self,
        data: Sequence,
        *args,
        **kwargs,
    ):
        """Create the column from Arrow data or anything Arrow can convert.

        Args:
            data: A ``BlockView`` over an ``ArrowBlock``, a ``pa.Array`` /
                ``pa.ChunkedArray`` (used as-is), a ``torch.Tensor``
                (converted through numpy), or any other sequence accepted by
                ``pa.array``.
            *args: Forwarded to the parent constructor.
            **kwargs: Forwarded to the parent constructor.

        Raises:
            ValueError: If ``data`` is a ``BlockView`` whose block is not an
                ``ArrowBlock``.
        """
        if isinstance(data, BlockView):
            if not isinstance(data.block, ArrowBlock):
                raise ValueError(
                    "ArrowScalarColumn can only be initialized with "
                    "ArrowBlock."
                )
        elif not isinstance(data, (pa.Array, pa.ChunkedArray)):
            # Arrow cannot construct an array from a torch.Tensor.
            if isinstance(data, torch.Tensor):
                # detach()/cpu() so tensors that require grad or live on an
                # accelerator can be converted; both are no-ops otherwise.
                data = data.detach().cpu().numpy()
            data = pa.array(data)

        # Same call as ``__init__(data=data, *args, **kwargs)``, written with
        # the positional arguments first for readability.
        super().__init__(*args, data=data, **kwargs)

    def _get(self, index, materialize: bool = True):
        """Return the element(s) selected by ``index``.

        Args:
            index: Any index accepted by ``ArrowBlock._convert_index``.
            materialize: Not used by this implementation.

        Returns:
            A new column (via ``_clone``) when ``_is_batch_index`` reports a
            batch index, otherwise the selected value as a plain Python
            object.
        """
        index = ArrowBlock._convert_index(index)

        if isinstance(index, (slice, int)):
            data = self._data[index]
        elif index.dtype == bool:
            # Boolean mask: keep the positions where the mask is true.
            data = self._data.filter(pa.array(index))
        else:
            # Anything else is treated as a collection of positions.
            data = self._data.take(index)

        if self._is_batch_index(index):
            return self._clone(data=data)

        # Convert to Python object for consistency with other ScalarColumn
        # implementations.
        return data.as_py()

    def _set(self, index, value):
        """Reject assignment; this column does not support mutation.

        Raises:
            ImmutableError: Always.
        """
        raise ImmutableError("ArrowScalarColumn is immutable.")

    def _repr_cell(self, index) -> object:
        """Return the entry of ``self.data`` at ``index`` for display."""
        return self.data[index]

    @staticmethod
    def _as_chunked_array(data) -> pa.ChunkedArray:
        """Return ``data`` as a ``pa.ChunkedArray``, wrapping a plain array."""
        if isinstance(data, pa.ChunkedArray):
            return data
        return pa.chunked_array([data])

    def is_equal(self, other: Column) -> bool:
        """Return whether ``other`` has the same class and equal contents.

        The comparison is over the whole array and yields a single ``bool``.
        Both sides are normalised to ``pa.ChunkedArray`` first so that a
        plain array can be compared with a chunked one.
        """
        if other.__class__ != self.__class__:
            return False
        mine = self._as_chunked_array(self.data)
        theirs = self._as_chunked_array(other.data)
        return mine.equals(theirs)

    @classmethod
    def _state_keys(cls) -> Set:
        """Return the state keys defined by the parent class, unchanged."""
        return super()._state_keys()

    def _write_data(self, path):
        """Write the column data to ``<path>/data.arrow``.

        The data is stored as a single-column table whose column is named
        ``"0"``; ``_read_data`` relies on that name.
        """
        table = pa.Table.from_arrays([self.data], names=["0"])
        ArrowBlock._write_table(os.path.join(path, "data.arrow"), table)

    @staticmethod
    def _read_data(path, mmap=False):
        """Read back the data written by ``_write_data``.

        Args:
            path: Directory containing ``data.arrow``.
            mmap: Forwarded to ``ArrowBlock._read_table``.

        Returns:
            The ``"0"`` column of the stored table.
        """
        table = ArrowBlock._read_table(
            os.path.join(path, "data.arrow"), mmap=mmap
        )
        return table["0"]

    @classmethod
    def concat(cls, columns: Sequence[ArrowScalarColumn]):
        """Concatenate ``columns`` into a single column.

        Chunked data is flattened into its chunks before concatenation. The
        result is produced by cloning the first column with the combined
        data.

        Raises:
            ValueError: If ``columns`` is empty, or if a column's data is
                neither a ``pa.Array`` nor a ``pa.ChunkedArray``.
        """
        if not columns:
            raise ValueError("Cannot concatenate an empty sequence of columns.")

        arrays = []
        for c in columns:
            if isinstance(c.data, pa.Array):
                arrays.append(c.data)
            elif isinstance(c.data, pa.ChunkedArray):
                arrays.extend(c.data.chunks)
            else:
                raise ValueError(f"Unexpected type {type(c.data)}")

        data = pa.concat_arrays(arrays)
        return columns[0]._clone(data=data)

    def to_numpy(self):
        """Return the column data as a numpy array.

        ``pa.Array.to_numpy`` refuses by default to make a copy, which makes
        it raise for data that cannot be viewed in place (for example arrays
        containing nulls). Copying is therefore allowed explicitly for plain
        arrays; chunked data is converted with the library defaults.
        """
        data = self.data
        if isinstance(data, pa.Array):
            return data.to_numpy(zero_copy_only=False)
        return data.to_numpy()

    def to_tensor(self):
        """Return the column data as a ``torch.Tensor`` (via numpy)."""
        return torch.tensor(self.to_numpy())

    def to_pandas(self, allow_objects: bool = False):
        """Return the column data converted with Arrow's ``to_pandas``.

        Args:
            allow_objects: Accepted but not used by this implementation.
        """
        return self.data.to_pandas()

    def to_arrow(self) -> pa.Array:
        """Return the underlying Arrow data without copying.

        NOTE(review): the data may be a ``pa.ChunkedArray`` (``concat``
        handles that case), so the annotation is narrower than what can be
        returned -- confirm before relying on it.
        """
        return self.data