"""Source code for ``meerkat.columns.arrow_column``: the PyArrow-backed ``ArrowArrayColumn``."""

from __future__ import annotations

import os
from typing import Sequence, Set

import pyarrow as pa
import torch
from pyarrow.compute import equal

from meerkat.block.abstract import BlockView
from meerkat.block.arrow_block import ArrowBlock
from meerkat.columns.abstract import AbstractColumn
from meerkat.errors import ImmutableError


class ArrowArrayColumn(
    AbstractColumn,
):
    """An immutable meerkat column whose data is a PyArrow ``Array``/``ChunkedArray``.

    Storage is managed by :class:`ArrowBlock` (see ``block_class``). Writes via
    ``_set`` always raise :class:`ImmutableError`.
    """

    block_class: type = ArrowBlock

    def __init__(
        self,
        data: Sequence,
        *args,
        **kwargs,
    ):
        """Create the column.

        Args:
            data: a ``BlockView`` over an ``ArrowBlock``, a ``pa.Array`` or
                ``pa.ChunkedArray``, or any other sequence, which is converted
                with ``pa.array``.
            *args: forwarded to ``AbstractColumn.__init__``.
            **kwargs: forwarded to ``AbstractColumn.__init__``.

        Raises:
            ValueError: if ``data`` is a ``BlockView`` whose block is not an
                ``ArrowBlock``.
        """
        if isinstance(data, BlockView):
            if not isinstance(data.block, ArrowBlock):
                raise ValueError(
                    "ArrowArrayColumn can only be initialized with ArrowBlock."
                )
        elif not isinstance(data, (pa.Array, pa.ChunkedArray)):
            data = pa.array(data)

        super().__init__(data=data, *args, **kwargs)

    def _get(self, index, materialize: bool = True):
        """Return the cell(s) selected by ``index``.

        ``index`` is first normalised by ``ArrowBlock._convert_index``. Slices
        and ints index the Arrow data directly. Any other index is expected to
        expose a ``dtype`` (presumably a NumPy array; confirm against
        ``_convert_index``). Boolean masks go through ``filter`` and everything
        else through ``take``.

        Returns a new column (via ``_clone``) when ``_is_batch_index`` says the
        index selects a batch, otherwise the raw Arrow value. ``materialize``
        is accepted for interface compatibility and is not used here.
        """
        index = ArrowBlock._convert_index(index)

        if isinstance(index, (slice, int)):
            data = self._data[index]
        elif index.dtype == bool:
            # Boolean mask: keep the rows where the mask is True.
            data = self._data.filter(pa.array(index))
        else:
            # Positional indices: gather rows in the order given.
            data = self._data.take(index)

        if self._is_batch_index(index):
            return self._clone(data=data)
        return data

    def _set(self, index, value):
        """Always raise: Arrow-backed columns cannot be modified in place."""
        raise ImmutableError("ArrowArrayColumn is immutable.")

    def _repr_cell(self, index) -> object:
        """Return the raw Arrow value at ``index`` for display."""
        return self.data[index]

    def is_equal(self, other: AbstractColumn) -> bool:
        """Return ``True`` iff ``other`` is the same column class with equal data.

        Uses Arrow's whole-array ``equals`` so the result is a single ``bool``
        (an elementwise ``pyarrow.compute.equal`` would yield a BooleanArray
        and raise on length mismatch).
        """
        if other.__class__ != self.__class__:
            return False

        left, right = self.data, other.data
        # Array.equals only accepts an Array and ChunkedArray.equals only a
        # ChunkedArray, so flatten the chunked side when the kinds differ.
        left_chunked = isinstance(left, pa.ChunkedArray)
        right_chunked = isinstance(right, pa.ChunkedArray)
        if left_chunked and not right_chunked:
            left = left.combine_chunks()
        elif right_chunked and not left_chunked:
            right = right.combine_chunks()
        return left.equals(right)

    @classmethod
    def _state_keys(cls) -> Set:
        """Return the state keys to persist; adds nothing beyond the parent's."""
        return super()._state_keys()

    def _write_data(self, path):
        """Write the column data to ``<path>/data.arrow`` as a one-column table named "0"."""
        table = pa.Table.from_arrays([self.data], names=["0"])
        ArrowBlock._write_table(os.path.join(path, "data.arrow"), table)

    @staticmethod
    def _read_data(path, mmap=False):
        """Read data written by ``_write_data`` from ``<path>/data.arrow``.

        Returns the table's "0" column. ``mmap`` is forwarded to
        ``ArrowBlock._read_table``.
        """
        table = ArrowBlock._read_table(os.path.join(path, "data.arrow"), mmap=mmap)
        return table["0"]

    @classmethod
    def concat(cls, columns: Sequence[ArrowArrayColumn]):
        """Concatenate ``columns`` end to end into a new column.

        Columns may be backed by ``pa.Array`` or ``pa.ChunkedArray`` (e.g.
        columns loaded with ``_read_data``). The result is produced by
        ``columns[0]._clone``.

        Raises:
            ValueError: if ``columns`` is empty.
        """
        if len(columns) == 0:
            raise ValueError("Cannot concatenate an empty sequence of columns.")

        chunks = []
        for column in columns:
            column_data = column.data
            # pa.concat_arrays only accepts Arrays, so unpack chunked data.
            if isinstance(column_data, pa.ChunkedArray):
                chunks.extend(column_data.chunks)
            else:
                chunks.append(column_data)

        if chunks:
            data = pa.concat_arrays(chunks)
        else:
            # Every input was a ChunkedArray with zero chunks.
            data = pa.array([], type=columns[0].data.type)
        return columns[0]._clone(data=data)

    def to_numpy(self):
        """Return the column data as a NumPy array, copying when required."""
        data = self.data
        if isinstance(data, pa.ChunkedArray):
            # ChunkedArray.to_numpy() already allows copies, and older pyarrow
            # releases do not accept the ``zero_copy_only`` keyword here.
            return data.to_numpy()
        return data.to_numpy(zero_copy_only=False)

    def to_tensor(self):
        """Return the column data as a ``torch.Tensor`` built from ``to_numpy``."""
        # Going through to_numpy avoids Array.to_numpy's default
        # zero_copy_only=True, which raises for nulls / non-primitive data.
        return torch.tensor(self.to_numpy())

    def to_pandas(self):
        """Return the column data converted with Arrow's ``to_pandas``."""
        return self.data.to_pandas()