# Source code for meerkat.mixins.collate
from collections.abc import Callable
from typing import List
def identity_collate(batch: List):
    """Return *batch* unchanged.

    Used as the fallback collate function by ``CollateMixin`` when no
    ``collate_fn`` is supplied.

    Args:
        batch: The items to collate.

    Returns:
        The very same ``batch`` object that was passed in (no copy is made).
    """
    return batch
class CollateMixin:
    """Mixin that stores a collate function and exposes it via ``collate``.

    The function is kept on ``self._collate_fn`` and can be read or replaced
    through the ``collate_fn`` property.
    """

    def __init__(self, collate_fn: Callable = None, *args, **kwargs):
        """Initialise the mixin.

        Args:
            collate_fn: Callable used by ``collate``. When ``None``,
                ``identity_collate`` is used instead.
            *args: Forwarded to the next ``__init__`` in the MRO.
            **kwargs: Forwarded to the next ``__init__`` in the MRO.
        """
        # Cooperative init: pass the remaining arguments along the MRO so the
        # mixin can be combined with other base classes.
        super().__init__(*args, **kwargs)

        if collate_fn is not None:
            self._collate_fn = collate_fn
        else:
            self._collate_fn = identity_collate

    @property
    def collate_fn(self):
        """Method used to collate."""
        return self._collate_fn

    @collate_fn.setter
    def collate_fn(self, value):
        # Stored as given; no validation or fallback is applied here.
        self._collate_fn = value

    def collate(self, *args, **kwargs):
        """Collate data.

        All positional and keyword arguments are forwarded unchanged to the
        configured collate function, and its result is returned.
        """
        return self._collate_fn(*args, **kwargs)