Source code for meerkat.columns.file_column

from __future__ import annotations

import logging
import os
import urllib.request
from typing import Collection, Sequence
from urllib.error import HTTPError
from urllib.parse import urlparse

from meerkat.columns.abstract import AbstractColumn
from meerkat.columns.lambda_column import LambdaCell, LambdaColumn
from meerkat.columns.pandas_column import PandasSeriesColumn
from meerkat.tools.lazy_loader import LazyLoader

folder = LazyLoader("torchvision.datasets.folder")

logger = logging.getLogger(__name__)


class FileLoaderMixin:
    """Mixin that supplies ``fn``, which loads (and optionally transforms) a file.

    The host class is expected to provide three attributes, all of which are
    read by ``fn``: ``loader``, ``transform`` and ``base_dir``.
    """

    def fn(self, filepath: str):
        """Load the file at ``filepath`` and return the (transformed) result.

        Args:
            filepath (str): Path to the file. It is joined onto
                ``self.base_dir`` when that attribute is not ``None``;
                otherwise it is passed to the loader unchanged.

        Returns:
            Whatever ``self.loader`` returns for the resolved path, passed
            through ``self.transform`` when a transform is set.
        """
        target = filepath
        if self.base_dir is not None:
            target = os.path.join(self.base_dir, filepath)

        loaded = self.loader(target)
        if self.transform is None:
            return loaded
        return self.transform(loaded)


class FileCell(FileLoaderMixin, LambdaCell):
    """A cell that refers to a single file by path.

    Args:
        transform (callable): Optional function applied to the object
            returned by ``loader``.
        loader (callable): Callable that takes a filepath and returns the
            loaded object. Defaults to
            ``torchvision.datasets.folder.default_loader``.
        data (str): Path to the file.
        base_dir (str): Directory that ``data`` is relative to. If ``None``,
            ``data`` is used as-is.
    """

    def __init__(
        self,
        transform: callable = None,
        loader: callable = None,
        data: str = None,
        base_dir: str = None,
    ):
        self.loader = self.default_loader if loader is None else loader
        self.transform = transform
        # Stored on ``_data``; the methods below read it back through
        # ``self.data`` (presumably a property provided by ``LambdaCell`` --
        # TODO confirm).
        self._data = data
        self.base_dir = base_dir

    @classmethod
    def default_loader(cls, *args, **kwargs):
        """Load a file with ``torchvision.datasets.folder.default_loader``.

        Mirrors ``FileColumn.default_loader`` so that a cell constructed
        without an explicit ``loader`` has a fallback to resolve to.
        """
        return folder.default_loader(*args, **kwargs)

    @property
    def absolute_path(self):
        """The file path, joined onto ``base_dir`` when one is set."""
        return (
            os.path.join(self.base_dir, self.data)
            if self.base_dir is not None
            else self.data
        )

    def __eq__(self, other):
        """Return True when ``other`` is the same class and refers to the
        same file (path and base directory) with the same loader and
        transform."""
        return (
            (other.__class__ == self.__class__)
            and (self.data == other.data)
            and (self.transform == other.transform)
            and (self.loader == other.loader)
            # Identical relative paths under different base directories
            # resolve to different files, so they must not compare equal.
            and (self.base_dir == other.base_dir)
        )

    def __repr__(self):
        transform = getattr(self.transform, "__qualname__", repr(self.transform))
        # Show at most the last two path components, prefixed with ".../"
        # when leading components were dropped.
        dirs = self.data.split("/")
        short_path = ("" if len(dirs) <= 2 else ".../") + "/".join(dirs[-2:])
        # NOTE(review): the "." before "(" looks unintentional; left as-is
        # because callers or tests may depend on the exact text.
        return f"{self.__class__.__name__}.({short_path}, transform={transform})"
class FileColumn(FileLoaderMixin, LambdaColumn):
    """A column where each cell represents a file stored on disk or the web.

    The underlying data is a `PandasSeriesColumn` of strings, where each
    string is the path to a file. The column materializes the files into
    memory when indexed. If the column is lazy indexed with the ``lz``
    indexer, the files are not materialized and a ``FileCell`` or a
    ``FileColumn`` is returned instead.

    Args:
        data (Sequence[str]): A list of filepaths to images.
        transform (callable): A function that transforms the image (e.g.
            ``torchvision.transforms.functional.center_crop``).

            .. warning::
                In order for the column to be serializable, the transform
                function must be pickleable.

        loader (callable): A callable with signature
            ``def loader(filepath: str) -> PIL.Image:``. Defaults to
            ``torchvision.datasets.folder.default_loader``.

            .. warning::
                In order for the column to be serializable with ``write()``,
                the loader function must be pickleable.

        base_dir (str): A base directory that the paths in ``data`` are
            relative to. If ``None``, the paths are assumed to be absolute.
    """

    def __init__(
        self,
        data: Sequence[str] = None,
        transform: callable = None,
        loader: callable = None,
        base_dir: str = None,
        *args,
        **kwargs,
    ):
        # Wrap anything that is not already a PandasSeriesColumn so the
        # parent constructor always receives one.
        if not isinstance(data, PandasSeriesColumn):
            data = PandasSeriesColumn(data)
        super().__init__(data, *args, **kwargs)
        self.loader = self.default_loader if loader is None else loader
        self.transform = transform
        self.base_dir = base_dir

    def _create_cell(self, data: object) -> FileCell:
        """Build a ``FileCell`` for ``data`` that shares this column's
        loader, transform and base directory."""
        return FileCell(
            data=data,
            loader=self.loader,
            transform=self.transform,
            base_dir=self.base_dir,
        )

    @classmethod
    def from_filepaths(
        cls,
        filepaths: Sequence[str],
        loader: callable = None,
        transform: callable = None,
        base_dir: str = None,
        *args,
        **kwargs,
    ):
        """Alternate constructor that builds the column from ``filepaths``.

        The arguments are forwarded to the class constructor, with
        ``filepaths`` passed as ``data``.
        """
        # Star-args are written first because that is where Python binds
        # them regardless of their position in the call.
        return cls(
            *args,
            data=filepaths,
            loader=loader,
            transform=transform,
            base_dir=base_dir,
            **kwargs,
        )

    @classmethod
    def default_loader(cls, *args, **kwargs):
        """Load a file with ``torchvision.datasets.folder.default_loader``."""
        return folder.default_loader(*args, **kwargs)

    @classmethod
    def _state_keys(cls) -> Collection:
        """Return the parent's state keys plus ``transform``, ``loader`` and
        ``base_dir``, with ``fn`` removed."""
        return (super()._state_keys() | {"transform", "loader", "base_dir"}) - {"fn"}

    def _set_state(self, state: dict):
        """Restore state, defaulting ``base_dir`` to ``None`` when absent."""
        state["base_dir"] = state.get("base_dir", None)  # backwards compatibility
        super()._set_state(state)

    def is_equal(self, other: AbstractColumn) -> bool:
        """Return True when ``other`` is the same class with an equal loader,
        transform, base directory and underlying path data."""
        return (
            (other.__class__ == self.__class__)
            and (self.loader == other.loader)
            and (self.transform == other.transform)
            # The same relative paths under a different base directory point
            # at different files, so base_dir is part of equality.
            and (self.base_dir == other.base_dir)
            and self.data.is_equal(other.data)
        )
def download_image(url: str, cache_dir: str):
    """Download ``url`` into ``cache_dir`` (if not already cached) and load it.

    The local path is ``cache_dir`` joined with the URL's network location
    and path, so a file is downloaded at most once per cache directory.

    Args:
        url (str): URL of the file to fetch.
        cache_dir (str): Directory under which downloaded files are stored.

    Returns:
        The result of ``folder.default_loader`` on the cached file, or
        ``None`` if the download failed with an ``HTTPError`` or
        ``ConnectionResetError``.
    """
    parse = urlparse(url)
    local_path = os.path.join(cache_dir, parse.netloc + parse.path)

    if not os.path.exists(local_path):
        os.makedirs(os.path.dirname(local_path), exist_ok=True)
        try:
            urllib.request.urlretrieve(url, local_path)
        except (HTTPError, ConnectionResetError):
            logger.warning(f"Could not download {url}. Skipping.")
            # A connection dropped mid-transfer can leave a truncated file
            # behind. Remove it so the next call retries the download
            # instead of loading a corrupt cache entry.
            try:
                os.remove(local_path)
            except FileNotFoundError:
                pass
            return None

    return folder.default_loader(local_path)


class Downloader:
    """Callable that fetches URLs into a fixed cache directory.

    Args:
        cache_dir (str): Directory used to cache downloads. It is created if
            it does not exist.
        downloader (callable): Callable with signature
            ``downloader(url, cache_dir)``. Defaults to ``download_image``.
    """

    def __init__(
        self,
        cache_dir: str,
        downloader: callable = None,
    ):
        self.cache_dir = cache_dir
        os.makedirs(cache_dir, exist_ok=True)
        # Always bind ``self.downloader``; otherwise a caller-supplied
        # downloader is dropped and ``__call__`` raises AttributeError.
        self.downloader = download_image if downloader is None else downloader

    def __call__(self, url: str):
        """Invoke the configured downloader on ``url`` with the cache dir."""
        return self.downloader(url, self.cache_dir)