Source code for meerkat.contrib.gqa

import json
import os
from typing import Dict, Mapping

from tqdm import tqdm

import meerkat as mk


def crop_object(row: Mapping[str, object]):
    """Crop a (where possible) square region around an object's bounding box.

    The box given by ``x``, ``y``, ``w`` and ``h`` is padded equally on both
    sides of its shorter dimension until that dimension matches the longer
    one, and the result is clipped to the image bounds. Clipping can make the
    final region non-square.

    Args:
        row: Mapping with the keys ``"image"``, ``"x"``, ``"y"``, ``"w"`` and
            ``"h"``. The image must expose ``width``, ``height`` and
            ``crop(box)`` (presumably a PIL image -- confirm against caller).

    Returns:
        The value returned by ``row["image"].crop(box)`` for the
        ``(left, upper, right, lower)`` box computed here.
    """
    img = row["image"]
    length = max(row["h"], row["w"])

    # Padding needed on each side to grow a dimension up to ``length``;
    # it is zero for whichever dimension is already the longer one.
    pad_x = (length - row["w"]) / 2
    pad_y = (length - row["h"]) / 2

    box = (
        max(row["x"] - pad_x, 0),
        max(row["y"] - pad_y, 0),
        min(row["x"] + row["w"] + pad_x, img.width),
        min(row["y"] + row["h"] + pad_y, img.height),
    )
    return img.crop(box)
def build_gqa_dps(dataset_dir: str, write: bool = False) -> Dict[str, mk.DataPanel]:
    """Build DataPanels for GQA images, objects, relations and attributes.

    Reads ``train_sceneGraphs.json`` and ``val_sceneGraphs.json`` from
    ``dataset_dir`` and flattens the nested scene graphs into four tables.

    Args:
        dataset_dir: Directory holding the ``*_sceneGraphs.json`` files.
            Image paths are formed as ``<dataset_dir>/images/<image_id>.jpg``.
        write: If True, the resulting DataPanels are also written to
            ``dataset_dir`` via ``write_gqa_dps``.

    Returns:
        A dict with the keys ``"images"``, ``"objects"``, ``"relations"`` and
        ``"attributes"``, each mapping to a DataPanel.

    Raises:
        KeyError: If a scene graph lacks ``"objects"`` or an object lacks
            ``"relations"`` or ``"attributes"``.
    """
    objects = []
    images = []
    relations = []
    attributes = []
    for split in ["train", "val"]:
        print(f"Loading {split} scene graphs...")
        graph_path = os.path.join(dataset_dir, f"{split}_sceneGraphs.json")
        # Explicit encoding: JSON is UTF-8 and must not depend on the locale.
        with open(graph_path, encoding="utf-8") as f:
            graphs = json.load(f)

        for image_id, graph in tqdm(graphs.items()):
            image_id = int(image_id)  # convert to int for faster filtering and joins
            # The nested collections are popped so that what remains in
            # ``obj`` / ``graph`` can be spread into a flat row below.
            for object_id, obj in graph.pop("objects").items():
                object_id = int(
                    object_id
                )  # convert to int for faster filtering and joins
                for relation in obj.pop("relations"):
                    relations.append(
                        {
                            "subject_object_id": object_id,
                            "object_id": int(relation["object"]),
                            "name": relation["name"],
                        }
                    )
                for attribute in obj.pop("attributes"):
                    attributes.append({"object_id": object_id, "attribute": attribute})
                objects.append({"object_id": object_id, "image_id": image_id, **obj})
            images.append({"image_id": image_id, **graph})

    # prepare DataPanels
    print("Preparing DataPanels...")
    image_dp = mk.DataPanel(images)
    image_dp["image"] = mk.ImageColumn(
        image_dp["image_id"].map(
            lambda x: os.path.join(dataset_dir, "images", f"{x}.jpg")
        )
    )

    # Attach each object's image and the image dimensions to its row.
    object_dp = mk.DataPanel(objects).merge(
        image_dp[["image_id", "image", "height", "width"]], on="image_id"
    )
    object_dp["object_image"] = object_dp.to_lambda(crop_object)

    # filter out objects with no width or height
    object_dp = object_dp.lz[(object_dp["h"] != 0) & (object_dp["w"] != 0)]
    # filter out objects whose top-left corner lies outside the image
    # (only x/y are checked here, not the full extent of the box)
    object_dp = object_dp.lz[
        (object_dp["x"] < object_dp["width"]) & (object_dp["y"] < object_dp["height"])
    ]

    dps = {
        "images": image_dp,
        "objects": object_dp,
        "relations": mk.DataPanel(relations),
        "attributes": mk.DataPanel(attributes),
    }
    if write:
        write_gqa_dps(dps=dps, dataset_dir=dataset_dir)
    return dps
def read_gqa_dps(dataset_dir: str) -> Dict[str, mk.DataPanel]:
    """Read the GQA DataPanels stored in ``dataset_dir``.

    Args:
        dataset_dir: Directory containing ``attributes.mk``, ``relations.mk``,
            ``objects.mk`` and ``images.mk``.

    Returns:
        A dict with the keys ``"attributes"``, ``"relations"``, ``"objects"``
        and ``"images"``, each mapping to the DataPanel read from
        ``<dataset_dir>/<key>.mk``.
    """
    return {
        key: mk.DataPanel.read(os.path.join(dataset_dir, f"{key}.mk"))
        for key in ["attributes", "relations", "objects", "images"]
    }
def write_gqa_dps(dps: Mapping[str, mk.DataPanel], dataset_dir: str):
    """Write each DataPanel in ``dps`` to ``dataset_dir``.

    Args:
        dps: Mapping from a name to a DataPanel. Each entry is written by
            calling its ``write`` method with ``<dataset_dir>/<name>.mk``.
        dataset_dir: Directory under which the ``.mk`` paths are formed.
    """
    for key, dp in dps.items():
        dp.write(os.path.join(dataset_dir, f"{key}.mk"))