Source code for meerkat.datasets.expw

import os
import subprocess

import pandas as pd

import meerkat as mk

from ..abstract import DatasetBuilder
from ..info import DatasetInfo
from ..registry import datasets
from ..utils import download_google_drive


@datasets.register()
class expw(DatasetBuilder):
    """Dataset builder for the Expression in-the-Wild (ExpW) dataset.

    ``download`` fetches the dataset folder from Google Drive and extracts the
    image archives found in ``<dataset_dir>/image`` with ``7za``. ``build``
    reads the face annotations from ``<dataset_dir>/label/label.lst`` and
    returns a DataPanel with one row per (deduplicated) annotation line.
    """

    # Google Drive folder id for each supported dataset version.
    VERSION_TO_GDRIVE_ID = {"main": "19Eb_WiTsWelYv7Faff0L5Lmo1zv0vzwR"}
    VERSIONS = ["main"]

    info = DatasetInfo(
        name="expw",
        full_name="Expression in-the-Wild",
        description=(
            "Expression in-the-Wild (ExpW) is a facial expression recognition "
            "dataset of faces collected from web images. Each face is annotated "
            "with a bounding box and one of seven basic expression categories: "
            "angry, disgust, fear, happy, sad, surprise, or neutral."
        ),
        homepage="http://mmlab.ie.cuhk.edu.hk/projects/socialrelation/index.html",
        tags=["image", "classification"],
    )

    def build(self):
        """Load the face annotations and attach lazily-built image columns.

        Reads ``label/label.lst`` (space-delimited, no header row) from
        ``self.dataset_dir``, drops exact duplicate rows, and checks that every
        referenced image file exists under ``image/origin``.

        Returns:
            mk.DataPanel containing the label columns plus:
              * ``example_id``: image name without ``.jpg`` joined with the
                face id, e.g. ``"foo_0"``;
              * ``image``: an ``mk.ImageColumn`` over ``image/origin/<name>``;
              * ``face_image``: a lambda column applying ``crop`` to the image
                and the face box columns.

        Raises:
            ValueError: if any image listed in ``label.lst`` is missing on disk.
        """
        image_dir = os.path.join(self.dataset_dir, "image/origin")
        df = pd.read_csv(
            os.path.join(self.dataset_dir, "label/label.lst"),
            delimiter=" ",
            names=[
                "image_name",
                "face_id_in_image",
                "face_box_top",
                "face_box_left",
                "face_box_right",
                "face_box_bottom",
                # (sic) misspelling kept on purpose: callers may select this
                # column by its existing name.
                "face_box_cofidence",
                "expression_label",
            ],
        )
        # Exact duplicate annotation lines would otherwise produce duplicate
        # examples with identical example_ids.
        df = df.drop_duplicates()

        # Ensure all image files are downloaded. An image can contain several
        # faces, so each distinct file is checked only once.
        missing = [
            name
            for name in df["image_name"].unique()
            if not os.path.exists(os.path.join(image_dir, name))
        ]
        if missing:
            raise ValueError(
                "Some images are not downloaded to expected directory: "
                f"{image_dir}. Verify download. "
                f"({len(missing)} missing, e.g. {missing[0]!r})"
            )

        dp = mk.DataPanel.from_pandas(df)

        # Remove the file extension and append the face id so that every face
        # in the same image gets a distinct id.
        dp["example_id"] = (
            dp["image_name"].str.replace(".jpg", "", regex=False)
            + "_"
            + dp["face_id_in_image"].astype(str)
        )
        dp["image"] = mk.ImageColumn.from_filepaths(
            "image/origin/" + dp["image_name"], base_dir=self.dataset_dir
        )
        dp["face_image"] = dp[
            "image",
            "face_box_top",
            "face_box_left",
            "face_box_right",
            "face_box_bottom",
        ].to_lambda(crop)
        return dp

    def download(self):
        """Download the dataset folder and extract the image archives.

        The Google Drive folder for ``self.version`` is downloaded into
        ``self.dataset_dir``. Every regular file directly inside
        ``<dataset_dir>/image`` is then passed to ``7za x``, which extracts it
        into that same directory.

        Extraction is best-effort: ``7za``'s return code is not checked. A
        failed extraction surfaces later as the missing-images ``ValueError``
        raised by ``build``.
        """
        gdrive_id = self.VERSION_TO_GDRIVE_ID[self.version]
        download_google_drive(id=gdrive_id, dst=self.dataset_dir, is_folder=True)

        image_dir = os.path.join(self.dataset_dir, "image")
        os.makedirs(image_dir, exist_ok=True)
        # Sorted so archives are processed in a deterministic order. For
        # numbered volumes (``*.001``, ``*.002``, ...) the first volume then
        # comes first.
        for file in sorted(os.listdir(image_dir)):
            archive = os.path.join(image_dir, file)
            # Skip directories, e.g. ``origin/`` left by a previous extraction.
            if not os.path.isfile(archive):
                continue
            # ``-y`` answers "yes" to overwrite prompts. Without it, re-running
            # the download over an existing extraction blocks on stdin.
            subprocess.call(["7za", "x", "-y", archive, "-o" + image_dir])
def crop(row: dict):
    """Return the face region of a row's image.

    Args:
        row: mapping with an ``"image"`` entry exposing a ``crop(box)`` method
            (presumably a PIL image; confirm against the ImageColumn cell
            type), plus the ``face_box_left``, ``face_box_top``,
            ``face_box_right`` and ``face_box_bottom`` coordinates.

    Returns:
        Whatever ``row["image"].crop`` returns for the box
        ``(left, top, right, bottom)``.
    """
    image = row["image"]
    # PIL's Image.crop expects the box as (left, upper, right, lower).
    box_keys = ("face_box_left", "face_box_top", "face_box_right", "face_box_bottom")
    box = tuple(row[key] for key in box_keys)
    return image.crop(box)