Source code for meerkat.datasets.inaturalist

import json
import os
from typing import List

import pandas as pd
from torchvision.datasets.utils import download_and_extract_archive

import meerkat as mk

IMAGES_URLS = {
    "train": "https://ml-inat-competition-datasets.s3.amazonaws.com/2021/train.tar.gz",
    "val": "https://ml-inat-competition-datasets.s3.amazonaws.com/2021/val.tar.gz",
    "test": "https://ml-inat-competition-datasets.s3.amazonaws.com/2021/public_test.tar.gz",  # noqa: E501
}

INFO_URLS = {
    "train": "https://ml-inat-competition-datasets.s3.amazonaws.com/2021/train.json.tar.gz",  # noqa: E501
    "val": "https://ml-inat-competition-datasets.s3.amazonaws.com/2021/val.json.tar.gz",
    "test": "https://ml-inat-competition-datasets.s3.amazonaws.com/2021/public_test.json.tar.gz",  # noqa: E501
}


[docs]def build_inaturalist_dp( dataset_dir: str, download: bool = True, splits: List[str] = None ) -> mk.DataPanel: """Build a DataPanel from the inaturalist dataset. Args: dataset_dir: The directory to store the dataset in. download: Whether to download the dataset if it does not yet exist. splits: A list of splits to include. Defaults to all splits. """ if splits is None: splits = ["train", "test", "val"] dps = [] for split in splits: if not os.path.exists(os.path.join(dataset_dir, split)) and download: download_and_extract_archive( IMAGES_URLS[split], download_root=dataset_dir, remove_finished=True ) if not os.path.exists(os.path.join(dataset_dir, f"{split}.json")) and download: download_and_extract_archive( INFO_URLS[split], download_root=dataset_dir, remove_finished=True ) with open(os.path.join(dataset_dir, f"{split}.json"), "r") as f: info = json.load(f) dp = mk.DataPanel(info["images"]) # need to rename "id" so there aren't conflicts with the "id" in other # datapanels (see annotations below) dp["image_id"] = dp["id"] dp.remove_column("id") # add image column dp["image"] = mk.ImageColumn(dp["file_name"], base_dir=dataset_dir) dp["date"] = pd.to_datetime(dp["date"]) # add annotations for each image if split != "test": # no annotations for test set annotation_dp = mk.DataPanel(info["annotations"]) annotation_dp["annotation_id"] = annotation_dp["id"] annotation_dp.remove_column("id") dp = dp.merge(annotation_dp, on="image_id") # join on the category table to get the category name category_dp = mk.DataPanel(info["categories"]) category_dp["category_id"] = category_dp["id"] category_dp.remove_column("id") dp = dp.merge(category_dp, on="category_id") dps.append(dp) return mk.concat(dps) if len(dps) > 1 else dps[0]