Source code for meerkat.mixins.io

import os

import dill
import yaml

from meerkat.tools.utils import MeerkatLoader


class ColumnIOMixin:
    """Mixin that saves a column to a directory and loads it back.

    A saved column is a directory containing three files:

    * ``meta.yaml``  -- the dictionary returned by ``_get_meta``
    * ``state.dill`` -- the object returned by ``self._get_state()``
    * ``data.dill``  -- ``self.data``

    The host class is expected to supply ``_data``/``data``, ``metadata``,
    ``__len__``, ``_get_state``, ``_set_state`` and ``_set_data``; none of
    them are defined in this mixin.

    NOTE: reading uses ``dill`` (and ``yaml.load`` with a custom loader).
    Unpickling can execute arbitrary code, so only read directories that come
    from a trusted source.
    """

    def write(self, path: str, *args, **kwargs) -> dict:
        """Write the column to the directory ``path``.

        Args:
            path: Directory to write into; created (with parents) if missing.
            *args: Forwarded to ``self._write_data``.
            **kwargs: Forwarded to ``self._write_data``.

        Returns:
            The metadata dictionary that was dumped to ``meta.yaml``.

        Raises:
            AssertionError: If the instance has no ``_data`` attribute.
        """
        assert hasattr(
            self, "_data"
        ), f"{self.__class__.__name__} requires `self.data`"

        # Make all the directories to the path
        os.makedirs(path, exist_ok=True)

        metadata = self._get_meta()

        # Write the state, then the data
        self._write_state(path)
        self._write_data(path, *args, **kwargs)

        # Save the metadata as a yaml file. The context manager guarantees
        # the file is flushed and closed before we return.
        metadata_path = os.path.join(path, "meta.yaml")
        with open(metadata_path, "w") as f:
            yaml.dump(metadata, f)

        return metadata

    def _get_meta(self):
        """Return the metadata dictionary stored in ``meta.yaml``.

        Contains the concrete column class under ``"dtype"``, the column
        length under ``"len"``, and every entry of ``self.metadata`` (entries
        of ``self.metadata`` win on key collisions because they are unpacked
        last).
        """
        return {
            "dtype": type(self),
            "len": len(self),
            **self.metadata,
        }

    def _write_data(self, path):
        """Serialize ``self.data`` with dill to ``<path>/data.dill``."""
        data_path = os.path.join(path, "data.dill")
        with open(data_path, "wb") as f:
            dill.dump(self.data, f)

    def _write_state(self, path):
        """Serialize ``self._get_state()`` with dill to ``<path>/state.dill``."""
        state = self._get_state()
        state_path = os.path.join(path, "state.dill")
        with open(state_path, "wb") as f:
            dill.dump(state, f)

    @classmethod
    def read(
        cls, path: str, _data: object = None, _meta: object = None, *args, **kwargs
    ) -> object:
        """Load a column previously saved with ``write``.

        Args:
            path: Directory the column was written to.
            _data: If not ``None``, used as the column data instead of
                reading it from disk.
            _meta: If not ``None``, used as the metadata instead of reading
                ``meta.yaml``.
            *args: Forwarded to the column type's ``_read_data``.
            **kwargs: Forwarded to the column type's ``_read_data``.

        Returns:
            A new instance of the class recorded under ``"dtype"`` in the
            metadata (not necessarily ``cls``), built with ``__new__`` and
            populated through ``_set_state`` and ``_set_data``.

        Raises:
            AssertionError: If ``path`` does not exist.
        """
        # Assert that the path exists
        assert os.path.exists(path), f"`path` {path} does not exist."

        # Load in the metadata
        if _meta is None:
            # NOTE(review): MeerkatLoader restores "dtype" as a class, so it
            # is presumably not a safe loader -- only read trusted paths.
            with open(os.path.join(path, "meta.yaml")) as f:
                meta = dict(yaml.load(f, Loader=MeerkatLoader))
        else:
            meta = _meta

        # The concrete class decides how its state and data are read.
        col_type = meta["dtype"]

        # Load states
        state = col_type._read_state(path)
        if _data is None:
            data = col_type._read_data(path, *args, **kwargs)
        else:
            data = _data

        # Bypass __init__: the instance is rebuilt from the saved state/data.
        col = col_type.__new__(col_type)
        col._set_state(state)
        col._set_data(data)

        return col

    @staticmethod
    def _read_state(path: str):
        """Load the object stored in ``<path>/state.dill``.

        If unpickling fails with ``ModuleNotFoundError`` and the payload
        mentions ``meerkat.nn``, those references are rewritten to
        ``meerkat.ml`` and the load is retried. Any other missing module
        re-raises the original error.
        """
        state_path = os.path.join(path, "state.dill")
        try:
            with open(state_path, "rb") as f:
                return dill.load(f)
        except ModuleNotFoundError:
            with open(state_path, "rb") as f:
                dill_str = f.read()
            if b"meerkat.nn" not in dill_str:
                # Not the known module rename: surface the original error
                # rather than returning nothing.
                raise
            # backwards compatibility
            # TODO (Sabri): remove this in a future release
            dill_str = dill_str.replace(b"meerkat.nn", b"meerkat.ml")
            return dill.loads(dill_str)

    @staticmethod
    def _read_data(path: str, *args, **kwargs):
        """Load the object stored in ``<path>/data.dill``.

        ``*args`` and ``**kwargs`` are accepted and ignored here.
        """
        with open(os.path.join(path, "data.dill"), "rb") as f:
            return dill.load(f)