# Source code for meerkat.columns.audio_column
from typing import Callable
import torch
from ..display import audio_file_formatter
from ..tools.lazy_loader import LazyLoader
from .file_column import FileColumn
# Module-level handle to torchaudio wrapped in ``LazyLoader`` (from
# ``..tools.lazy_loader``). NOTE(review): presumably this defers the actual
# ``import torchaudio`` until first attribute access (e.g. ``torchaudio.load``),
# so this module can be imported without torchaudio's import cost -- confirm
# against LazyLoader's implementation.
torchaudio = LazyLoader("torchaudio")
class AudioColumn(FileColumn):
    """A lambda column where each cell represents an audio file on disk.

    The underlying data is a `PandasSeriesColumn` of strings, where each string is
    the path to an audio file. The column materializes the audio into memory when
    indexed. If the column is lazy indexed with the ``lz`` indexer, the audio is not
    materialized and a ``FileCell`` or an ``AudioColumn`` is returned instead.

    Args:
        data (Sequence[str]): A list of filepaths to audio files.
        transform (callable): A function applied to each loaded waveform tensor
            (e.g. a function that trims, normalizes or resamples the audio).

            .. warning::
                In order for the column to be serializable, the transform function
                must be pickleable.
        loader (callable): A callable with signature
            ``def loader(filepath: str) -> torch.Tensor:``. This class provides
            :meth:`default_loader`, which reads the file with ``torchaudio.load``
            and returns only the waveform tensor.

            .. warning::
                In order for the column to be serializable with ``write()``, the
                loader function must be pickleable.
        base_dir (str): A base directory that the paths in ``data`` are relative
            to. If ``None``, the paths are assumed to be absolute.
    """

    @staticmethod
    def _get_default_formatter() -> Callable:
        """Return the formatter used to display cells of this column."""
        return audio_file_formatter

    @classmethod
    def default_loader(cls, *args, **kwargs):
        """Load an audio file and return its waveform tensor.

        All arguments are forwarded unchanged to ``torchaudio.load``, which returns
        a ``(waveform, sample_rate)`` pair. Only the waveform (element ``0``) is
        returned; the sample rate is discarded.
        """
        return torchaudio.load(*args, **kwargs)[0]

    def _repr_cell(self, idx):
        """Return the cell at ``idx`` for display purposes.

        Uses the lazy ``lz`` indexer so that representing a cell does not
        materialize (load) the audio file.
        """
        return self.lz[idx]

    def collate(self, batch):
        """Combine variable-length waveforms into a single zero-padded tensor.

        Args:
            batch: A sequence of waveform tensors. Assumed (TODO confirm) to be
                ``(channels, time)`` as returned by :meth:`default_loader`; all
                waveforms must share the same number of channels but may differ
                in length.

        Returns:
            torch.Tensor: A tensor of shape ``(batch, channels, max_time)`` in which
            shorter waveforms are padded with zeros at the end of the time axis.
        """
        # pad_sequence pads along the leading dimension, so move time to the
        # front: (channels, time) -> (time, channels).
        tensors = [b.t() for b in batch]
        # Stack with zero padding: -> (batch, max_time, channels).
        tensors = torch.nn.utils.rnn.pad_sequence(tensors, batch_first=True)
        # Restore the channel-first layout: -> (batch, channels, max_time).
        tensors = tensors.transpose(1, -1)
        return tensors