import logging
from typing import Callable, Dict, Mapping, Optional, Union
import numpy as np
from tqdm.auto import tqdm
from meerkat.provenance import capture_provenance
logger = logging.getLogger(__name__)
class MappableMixin:
    """Mixin that provides a batched :meth:`map` over the elements of a
    column or datapanel.

    Intended for cooperative multiple inheritance: ``__init__`` forwards all
    arguments to the next class in the MRO.
    """

    def __init__(self, *args, **kwargs):
        # Cooperative super() call so sibling mixins/base classes still
        # receive their constructor arguments.
        super().__init__(*args, **kwargs)
[docs] @capture_provenance()
def map(
self,
function: Optional[Callable] = None,
with_indices: bool = False,
is_batched_fn: bool = False,
batch_size: Optional[int] = 1,
drop_last_batch: bool = False,
num_workers: Optional[int] = 0,
output_type: Union[type, Dict[str, type]] = None,
materialize: bool = True,
pbar: bool = False,
mmap: bool = False,
mmap_path: str = None,
flush_size: int = None,
**kwargs,
):
# TODO (sabri): add materialize?
from meerkat.columns.abstract import AbstractColumn
from meerkat.datapanel import DataPanel
"""Map a function over the elements of the column."""
# Return if `self` has no examples
if not len(self):
logger.info("Dataset empty, returning None.")
return None
if not is_batched_fn:
# Convert to a batch function
function = self._convert_to_batch_fn(
function, with_indices=with_indices, materialize=materialize, **kwargs
)
is_batched_fn = True
logger.info(f"Converting `function` {function} to a batched function.")
# Run the map
logger.info("Running `map`, the dataset will be left unchanged.")
for i, batch in tqdm(
enumerate(
self.batch(
batch_size=batch_size,
drop_last_batch=drop_last_batch,
num_workers=num_workers,
materialize=materialize
# TODO: collate=batched was commented out in list_column
)
),
total=(len(self) // batch_size)
+ int(not drop_last_batch and len(self) % batch_size != 0),
disable=not pbar,
):
# Calculate the start and end indexes for the batch
start_index = i * batch_size
end_index = min(len(self), (i + 1) * batch_size)
# Use the first batch for setup
if i == 0:
# Get some information about the function
function_properties = self._inspect_function(
function,
with_indices,
is_batched_fn,
batch,
range(start_index, end_index),
materialize=materialize,
**kwargs,
)
# Pull out information
output = function_properties.output
dtype = function_properties.output_dtype
is_mapping = isinstance(output, Mapping)
is_type_mapping = isinstance(output_type, Mapping)
if not is_mapping and is_type_mapping:
raise ValueError(
"output_type is a mapping but function output is not a mapping"
)
writers = {}
for key, curr_output in (
output.items() if is_mapping else [("0", output)]
):
curr_output_type = (
type(AbstractColumn.from_data(curr_output))
if output_type is None
or (is_type_mapping and key not in output_type.keys())
else output_type[key]
if is_type_mapping
else output_type
)
writer = curr_output_type.get_writer(
mmap=mmap,
template=(
curr_output.copy()
if isinstance(curr_output, AbstractColumn)
else None
),
)
# Setup for writing to a certain output column
# TODO: support optionally memmapping only some columns
if mmap:
if not hasattr(curr_output, "shape"):
curr_output = np.array(curr_output)
# Assumes first dimension of output is the batch dimension.
shape = (len(self), *curr_output.shape[1:])
# Construct the mmap file path
if mmap_path is None:
mmap_path = self.logdir / f"{hash(function)}" / key
# Open the output writer
writer.open(str(mmap_path), dtype, shape=shape)
else:
# Create an empty dict or list for the outputs
writer.open()
writers[key] = writer
else:
# Run `function` on the batch
output = (
function(
batch,
range(i * batch_size, min(len(self), (i + 1) * batch_size)),
**kwargs,
)
if with_indices
else function(batch, **kwargs)
)
# Append the output
if output is not None:
if isinstance(output, Mapping):
if set(output.keys()) != set(writers.keys()):
raise ValueError(
"Map function must return same keys for each batch."
)
for k, writer in writers.items():
writer.write(output[k])
else:
writers["0"].write(output)
# intermittently flush
if flush_size is not None and ((i + 1) % flush_size == 0):
for writer in writers.values():
writer.flush()
# Check if we are returning a special output type
outputs = {key: writer.finalize() for key, writer in writers.items()}
if not is_mapping:
outputs = outputs["0"]
else:
# TODO (arjundd): This is duck type. We should probably make this
# class signature explicit.
outputs = (
self._clone(data=outputs)
if isinstance(self, DataPanel)
else DataPanel.from_batch(outputs)
)
outputs._visible_columns = None
return outputs