from __future__ import annotations
import os
from dataclasses import dataclass
from typing import Hashable, List, Sequence, Union
import numpy as np
import pandas as pd
import pyarrow as pa
import torch
from meerkat.block.ref import BlockRef
from meerkat.columns.numpy_column import NumpyArrayColumn
from meerkat.columns.tensor_column import TensorColumn
from .abstract import AbstractBlock, BlockIndex, BlockView
class ArrowBlock(AbstractBlock):
    """Block whose columns are stored together in one ``pyarrow.Table``.

    A column of the block is addressed by its name in the table; that name
    is used as the ``block_index`` of the ``BlockView`` objects produced by
    :meth:`from_column_data` and :meth:`from_block_data`.
    """

    @dataclass(eq=True, frozen=True)
    class Signature:
        """Hashable summary of a block: its row count and its block class."""

        nrows: int
        klass: type
        # mmap: bool

    def __init__(self, data: pa.Table, *args, **kwargs):
        """Wrap ``data``; remaining arguments are forwarded to the parent."""
        super().__init__(*args, **kwargs)
        self.data = data

    @property
    def signature(self) -> Hashable:
        """Return the :class:`Signature` (class and number of rows) of the block."""
        return self.Signature(klass=ArrowBlock, nrows=len(self.data))

    def _get_data(self, index: BlockIndex) -> pa.Array:
        """Return the arrow data stored in the table under ``index``."""
        return self.data[index]

    @classmethod
    def from_column_data(cls, data: pa.Array) -> BlockView:
        """Build a single-column block from ``data`` and return a view on it.

        The column is stored in the new table under the fixed name ``"col"``.
        """
        data = pa.Table.from_pydict({"col": data})
        block = cls(data)
        return BlockView(block=block, block_index="col")

    @classmethod
    def from_block_data(cls, data: pa.Table) -> List[BlockView]:
        """Wrap ``data`` in one block and return one view per table column."""
        block = cls(data)
        return [
            BlockView(block=block, block_index=column) for column in data.column_names
        ]

    @classmethod
    def _consolidate(
        cls,
        block_refs: Sequence[BlockRef],
    ) -> BlockRef:
        """Merge the columns of several block refs into one new block.

        Args:
            block_refs: refs whose columns should be gathered into one table.

        Returns:
            A ``BlockRef`` over the new block whose columns are clones of the
            originals, re-pointed at the new block's data.
        """
        table = pa.Table.from_pydict(
            # gather every column's arrow data, keyed by its name in the ref;
            # if a name occurs in several refs the last occurrence wins
            {
                name: ref.block.data[col._block_index]
                for ref in block_refs
                for name, col in ref.items()
            }
        )
        block = cls(table)

        # pull out the block columns from all the block_refs
        columns = {}
        for ref in block_refs:
            columns.update(ref)

        new_columns = {
            name: col._clone(data=block[name]) for name, col in columns.items()
        }
        return BlockRef(block=block, columns=new_columns)

    @staticmethod
    def _convert_index(index):
        """Convert list, tensor, series and column indices to NumPy arrays.

        Any other kind of index (e.g. ``int`` or ``slice``) is returned
        unchanged.
        """
        if isinstance(index, list):
            return np.array(index)

        if torch.is_tensor(index):
            # need to convert to numpy for boolean indexing; detach and move
            # to host memory first, otherwise ``.numpy()`` raises for tensors
            # that live on an accelerator or are tracked by autograd
            return index.detach().cpu().numpy()

        if isinstance(index, NumpyArrayColumn):
            return index.data

        if isinstance(index, TensorColumn):
            # need to convert to numpy for boolean indexing (see above)
            return index.data.detach().cpu().numpy()

        if isinstance(index, pd.Series):
            # need to convert to numpy for boolean indexing
            return index.values

        # imported inside the function, as in the original implementation
        from meerkat.columns.pandas_column import PandasSeriesColumn

        if isinstance(index, PandasSeriesColumn):
            return index.data.values

        return index

    def _get(
        self, index, block_ref: BlockRef, materialize: bool = True
    ) -> Union[BlockRef, dict]:
        """Select rows of the block for the columns in ``block_ref``.

        Args:
            index: an integer, a slice, or (after :meth:`_convert_index`) a
                boolean mask or an array of row positions.
            block_ref: the columns of this block to return.
            materialize: not used by this implementation.

        Returns:
            For an integer index, a ``dict`` mapping column name to the value
            in that row. Otherwise a ``BlockRef`` over a new ``ArrowBlock``
            holding the selected rows.
        """
        index = self._convert_index(index)
        # TODO: check if they're trying to index more than just the row dimension

        # NumPy integer scalars (e.g. ``np.int64``) are not ``int`` instances
        # but must be treated as single-row lookups as well
        if isinstance(index, (int, np.integer)):
            # if indexing a single row, we do not return a block manager, just a dict
            row = int(index)
            return {
                name: self.data[col._block_index][row]
                for name, col in block_ref.columns.items()
            }

        if isinstance(index, slice):
            data = self.data[index]
        elif index.dtype == bool:
            data = self.data.filter(pa.array(index))
        elif len(index) == 0:
            # ``pa.concat_tables`` rejects an empty sequence of tables, so an
            # empty selection is expressed as a zero-row slice that keeps the
            # schema of the block
            data = self.data.slice(0, 0)
        else:
            # we do not want to use ``data = self.data.take(index)``
            # because it can't handle ChunkedArrays that don't fit in an Array
            # https://issues.apache.org/jira/browse/ARROW-9773
            # TODO (Sabri): Huggingface gets around this in a similar manner but
            # applies the slices to the record batches, because this allows them to
            # do the batch lookup in numpy, which is faster than pure python, which
            # is presumably what Table.slice does
            # noqa E501, https://github.com/huggingface/datasets/blob/491dad8507792f6f51077867e22412af7cd5c2f1/src/datasets/table.py#L110
            data = pa.concat_tables(self.data.slice(i, 1) for i in index)

        block = self.__class__(data)

        columns = {
            name: col._clone(data=block[col._block_index])
            for name, col in block_ref.columns.items()
        }
        # note that the new block may share memory with the old block
        return BlockRef(block=block, columns=columns)

    @staticmethod
    def _write_table(path: str, table: pa.Table):
        """Write ``table`` to ``path`` in the Arrow IPC stream format.

        Returns:
            The summed ``nbytes`` of the record batches that were written.
        """
        # noqa E501, source: huggingface implementation https://github.com/huggingface/datasets/blob/92304b42cf0cc6edafc97832c07de767b81306a6/src/datasets/table.py#L50
        batches: List[pa.RecordBatch] = table.to_batches()
        with open(path, "wb") as sink:
            writer = pa.RecordBatchStreamWriter(sink=sink, schema=table.schema)
            try:
                for batch in batches:
                    writer.write_batch(batch)
            finally:
                # always finalize the stream, even if a batch fails to write
                writer.close()
        return sum(batch.nbytes for batch in batches)

    @staticmethod
    def _read_table(path: str, mmap: bool = False):
        """Read a table written by :meth:`_write_table`.

        Args:
            path: file to read.
            mmap: if true, read through a memory map of the file instead of a
                regular input stream.
        """
        if mmap:
            # the memory map is deliberately not closed here: the returned
            # table may reference the mapped memory
            return pa.ipc.open_stream(pa.memory_map(path)).read_all()

        # close the file handle once the table has been read
        with pa.input_stream(path) as source:
            return pa.ipc.open_stream(source).read_all()

    def _write_data(self, path: str):
        """Write the block's table to ``data.arrow`` inside directory ``path``."""
        self._write_table(os.path.join(path, "data.arrow"), self.data)

    @staticmethod
    def _read_data(path: str, mmap: bool = False):
        """Read the table stored as ``data.arrow`` inside directory ``path``."""
        return ArrowBlock._read_table(os.path.join(path, "data.arrow"), mmap=mmap)