Source code for meerkat.columns.scalar.arrow

from __future__ import annotations

import os
from typing import Sequence, Set

import pyarrow as pa
import torch
from pyarrow.compute import equal

from meerkat.block.abstract import BlockView
from meerkat.block.arrow_block import ArrowBlock
from meerkat.errors import ImmutableError

from ..abstract import Column
from .abstract import ScalarColumn


class ArrowScalarColumn(ScalarColumn):
    """Immutable scalar column whose values are stored in a PyArrow array.

    The underlying data is a ``pa.Array`` or ``pa.ChunkedArray``. Any other
    input handed to the constructor is converted with ``pa.array``.
    """

    block_class: type = ArrowBlock

    def __init__(
        self,
        data: Sequence,
        *args,
        **kwargs,
    ):
        """Create the column from ``data``.

        Args:
            data: A ``BlockView`` over an ``ArrowBlock``, a ``pa.Array``, a
                ``pa.ChunkedArray``, a ``torch.Tensor`` or any other object
                accepted by ``pa.array``.
            *args: Forwarded unchanged to the parent constructor.
            **kwargs: Forwarded unchanged to the parent constructor.

        Raises:
            ValueError: If ``data`` is a ``BlockView`` whose block is not an
                ``ArrowBlock``.
        """
        if isinstance(data, BlockView):
            if not isinstance(data.block, ArrowBlock):
                raise ValueError(
                    "ArrowArrayColumn can only be initialized with ArrowBlock."
                )
        elif not isinstance(data, (pa.Array, pa.ChunkedArray)):
            # Arrow cannot construct an array from a torch.Tensor.
            if isinstance(data, torch.Tensor):
                data = data.numpy()
            data = pa.array(data)

        super().__init__(*args, data=data, **kwargs)

    def _get(self, index, materialize: bool = True):
        """Return the element(s) selected by ``index``.

        ``index`` is first normalised by ``ArrowBlock._convert_index``. A
        batch index yields a clone of this column holding the selected data;
        any other index yields a plain Python object. ``materialize`` is
        accepted but not used by this implementation.
        """
        index = ArrowBlock._convert_index(index)

        if isinstance(index, (slice, int)):
            data = self._data[index]
        elif index.dtype == bool:
            # Boolean mask: keep the positions where the mask is true.
            data = self._data.filter(pa.array(index))
        else:
            # Remaining case: gather the listed positions.
            data = self._data.take(index)

        if self._is_batch_index(index):
            return self._clone(data=data)

        # Convert to Python object for consistency with other ScalarColumn
        # implementations.
        return data.as_py()

    def _set(self, index, value):
        """Always raise ``ImmutableError``; in-place assignment is rejected."""
        raise ImmutableError("ArrowArrayColumn is immutable.")

    def _repr_cell(self, index) -> object:
        """Return the raw entry of the underlying data at ``index``."""
        return self.data[index]

    def is_equal(self, other: Column) -> bool:
        """Return whether ``other`` is the same column class with equal data.

        The comparison is delegated to ``pa.ChunkedArray.equals`` so that a
        single ``bool`` is returned. An element-wise comparison would return
        an Arrow boolean array instead, whose truthiness does not reflect
        whether the values match.

        Args:
            other: The column to compare against.

        Returns:
            ``False`` if ``other`` is not exactly of this class; otherwise the
            result of Arrow's ``equals`` on the two underlying arrays.
        """
        if other.__class__ != self.__class__:
            return False

        left, right = self.data, other.data
        # Wrap plain arrays so that an Array can be compared with a
        # ChunkedArray through the same ``equals`` method.
        if isinstance(left, pa.Array):
            left = pa.chunked_array([left])
        if isinstance(right, pa.Array):
            right = pa.chunked_array([right])
        return bool(left.equals(right))

    @classmethod
    def _state_keys(cls) -> Set:
        """Return the state keys of the parent class unchanged."""
        return super()._state_keys()

    def _write_data(self, path):
        """Write the data as a one-column table to ``<path>/data.arrow``."""
        table = pa.Table.from_arrays([self.data], names=["0"])
        ArrowBlock._write_table(os.path.join(path, "data.arrow"), table)

    @staticmethod
    def _read_data(path, mmap=False):
        """Read ``<path>/data.arrow`` and return its column named ``"0"``.

        ``mmap`` is forwarded to ``ArrowBlock._read_table``.
        """
        table = ArrowBlock._read_table(os.path.join(path, "data.arrow"), mmap=mmap)
        return table["0"]

    @classmethod
    def concat(cls, columns: Sequence[ArrowScalarColumn]):
        """Concatenate ``columns`` into a clone of the first column.

        Chunked arrays contribute each of their chunks; plain arrays are used
        as they are.

        Raises:
            ValueError: If a column's data is neither a ``pa.Array`` nor a
                ``pa.ChunkedArray``.
        """
        arrays = []
        for c in columns:
            if isinstance(c.data, pa.Array):
                arrays.append(c.data)
            elif isinstance(c.data, pa.ChunkedArray):
                arrays.extend(c.data.chunks)
            else:
                raise ValueError(f"Unexpected type {type(c.data)}")

        data = pa.concat_arrays(arrays)
        return columns[0]._clone(data=data)

    def to_numpy(self):
        """Return the result of ``to_numpy()`` on the underlying Arrow data."""
        return self.data.to_numpy()

    def to_tensor(self):
        """Return a ``torch.Tensor`` built from the NumPy form of the data."""
        return torch.tensor(self.data.to_numpy())

    def to_pandas(self, allow_objects: bool = False):
        """Return the result of ``to_pandas()`` on the underlying Arrow data.

        ``allow_objects`` is accepted but not used by this implementation.
        """
        return self.data.to_pandas()

    def to_arrow(self) -> pa.Array:
        """Return the underlying Arrow data without copying or converting."""
        return self.data