Source code for meerkat.contrib.torchvision
import random
from PIL import Image
from torchvision.datasets.cifar import CIFAR10
import meerkat as mk
def get_torchvision_dataset(dataset_name, download_dir, is_train):
    """Load an arbitrary torchvision dataset by name (not implemented yet).

    Args:
        dataset_name: name of the torchvision dataset to load.
        download_dir: directory the dataset would be downloaded into.
        is_train: presumably selects the training split (forwarded as
            ``train=`` in the planned implementation).

    Raises:
        NotImplementedError: always; this loader is a placeholder.
    """
    # Raise rather than ``return NotImplemented``: that singleton is reserved
    # for binary special methods, and returning it from a plain function hands
    # callers a truthy placeholder instead of a dataset or a clear error.
    # TODO: look up the dataset class in ``torchvision.datasets`` by name and
    # instantiate it with ``root=download_dir, train=is_train, download=True``.
    raise NotImplementedError(
        f"get_torchvision_dataset is not implemented yet "
        f"(requested dataset {dataset_name!r})"
    )
def get_cifar10(
    download_dir: str,
    frac_val: float = 0.0,
    download: bool = True,
    seed: "int | None" = None,
):
    """Load CIFAR10 as a Meerkat DataPanel.

    The train and test splits are loaded and concatenated into a single
    DataPanel. Optionally, a random subset of the training rows is relabeled
    as a validation split.

    Args:
        download_dir: directory passed as ``root`` to torchvision's ``CIFAR10``.
        frac_val: fraction of the training set to relabel as ``"val"``; must
            lie in ``[0, 1]``. The number of validation rows is
            ``int(frac_val * n_train)``, so a fraction too small to select a
            single row yields no validation split.
        download: forwarded to ``CIFAR10`` as ``download``.
        seed: optional seed for choosing the validation rows. When ``None``,
            the module-level ``random`` functions are used (so a prior
            ``random.seed(...)`` still governs reproducibility).

    Returns:
        a DataPanel containing columns ``raw_image``, ``image``, ``label`` and
        ``split`` (each row labeled ``"train"``, ``"val"`` or ``"test"``).

    Raises:
        ValueError: if ``frac_val`` is outside ``[0, 1]``.
    """
    # Validate up front: otherwise frac_val > 1 fails deep inside
    # random.sample, and a negative value is silently treated as 0.
    if not 0.0 <= frac_val <= 1.0:
        raise ValueError(f"frac_val must be in [0, 1], got {frac_val}")
    # A dedicated generator makes the val split reproducible without
    # touching global random state; default keeps the original behavior.
    rng = random if seed is None else random.Random(seed)

    dps = []
    for split in ["train", "test"]:
        dataset = CIFAR10(
            root=download_dir,
            train=split == "train",
            download=download,
        )
        dp = mk.DataPanel(
            {
                "raw_image": dataset.data,
                "label": mk.TensorColumn(dataset.targets),
            }
        )
        n_rows = len(dp)
        # Only the training split is ever carved up into train/val.
        n_val = int(frac_val * n_rows) if split == "train" else 0
        if n_val > 0:
            val_indices = set(rng.sample(range(n_rows), n_val))
            dp["split"] = [
                "val" if i in val_indices else "train" for i in range(n_rows)
            ]
        else:
            dp["split"] = [split] * n_rows
        dps.append(dp)

    dp = mk.concat(dps)
    # Derive the PIL ``image`` column from the ``raw_image`` arrays.
    dp["image"] = mk.LambdaColumn(dp["raw_image"], Image.fromarray)
    return dp