"""Input transforms for the WILDS datasets (meerkat.contrib.wilds.transforms)."""
import torch
from meerkat.tools.lazy_loader import LazyLoader
transforms = LazyLoader("torchvision.transforms")
transformers = LazyLoader("transformers")
def initialize_transform(transform_name, config, dataset):
    """Build the input transform selected by ``transform_name``.

    Returns ``None`` when no transform is requested, otherwise delegates to
    the matching factory. Raises ``ValueError`` for an unrecognized name.
    """
    if transform_name is None:
        return None
    if transform_name == "bert":
        return initialize_bert_transform(config)
    if transform_name == "image_base":
        return initialize_image_base_transform(config, dataset)
    if transform_name == "image_resize_and_center_crop":
        return initialize_image_resize_and_center_crop_transform(config, dataset)
    if transform_name == "poverty_train":
        return initialize_poverty_train_transform()
    raise ValueError(f"{transform_name} not recognized")
def initialize_bert_transform(config):
    """Return a ``text -> tensor`` transform backed by a HuggingFace tokenizer.

    The returned callable tokenizes ``text`` padded/truncated to
    ``config.max_token_length`` and stacks the tokenizer fields along a new
    trailing dimension, producing a (max_token_length, n_fields) tensor.

    Raises:
        ValueError: if ``config.model`` is not one of the two supported
            BERT variants.
    """
    assert "bert" in config.model
    assert config.max_token_length is not None
    tokenizer = getBertTokenizer(config.model)

    def transform(text):
        tokens = tokenizer(
            text,
            padding="max_length",
            truncation=True,
            max_length=config.max_token_length,
            return_tensors="pt",
        )
        if config.model == "bert-base-uncased":
            # BERT needs token_type_ids in addition to ids + attention mask.
            x = torch.stack(
                (
                    tokens["input_ids"],
                    tokens["attention_mask"],
                    tokens["token_type_ids"],
                ),
                dim=2,
            )
        elif config.model == "distilbert-base-uncased":
            x = torch.stack((tokens["input_ids"], tokens["attention_mask"]), dim=2)
        else:
            # Fix: other "bert"-containing names (e.g. "bert-large-uncased")
            # previously fell through and crashed with UnboundLocalError on
            # ``x``; fail with a clear message instead.
            raise ValueError(f"Model: {config.model} not recognized.")
        x = torch.squeeze(x, dim=0)  # First shape dim is always 1
        return x

    return transform
def getBertTokenizer(model):
    """Load the fast HuggingFace tokenizer matching ``model``.

    Only ``bert-base-uncased`` and ``distilbert-base-uncased`` are supported;
    any other name raises ``ValueError``.
    """
    if model == "bert-base-uncased":
        return transformers.BertTokenizerFast.from_pretrained(model)
    if model == "distilbert-base-uncased":
        return transformers.DistilBertTokenizerFast.from_pretrained(model)
    raise ValueError(f"Model: {model} not recognized.")
def initialize_image_base_transform(config, dataset):
    """Base image pipeline: optional center-crop/resize, then tensor
    conversion with ImageNet mean/std normalization."""
    steps = []
    native_res = dataset.original_resolution
    # Non-square originals are first center-cropped to a square whose side is
    # the shorter dimension.
    if native_res is not None and min(native_res) != max(native_res):
        steps.append(transforms.CenterCrop(min(native_res)))
    # fmow is deliberately excluded from resizing here.
    if config.target_resolution is not None and config.dataset != "fmow":
        steps.append(transforms.Resize(config.target_resolution))
    steps.append(transforms.ToTensor())
    steps.append(
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    )
    return transforms.Compose(steps)
def initialize_image_resize_and_center_crop_transform(config, dataset):
    """Resizes the image to a slightly larger square then crops the center."""
    assert dataset.original_resolution is not None
    assert config.resize_scale is not None
    # Upscale each native dimension by resize_scale before cropping.
    resized = tuple(
        int(dim * config.resize_scale) for dim in dataset.original_resolution
    )
    crop_to = (
        config.target_resolution
        if config.target_resolution is not None
        else dataset.original_resolution
    )
    steps = [
        transforms.Resize(resized),
        transforms.CenterCrop(crop_to),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ]
    return transforms.Compose(steps)
def initialize_poverty_train_transform():
    """Training augmentation for the Poverty dataset.

    Flips and color-jitters only the first three channels of the input
    (stored BGR per the original comment): they are swapped to RGB for the
    PIL-based augmentations, then written back as BGR. NOTE(review): the
    returned transform mutates its input tensor in place; any channels beyond
    the first three pass through untouched.
    """
    augment_rgb = transforms.Compose(
        [
            transforms.ToPILImage(),
            transforms.RandomHorizontalFlip(),
            transforms.RandomVerticalFlip(),
            transforms.ColorJitter(
                brightness=0.8, contrast=0.8, saturation=0.8, hue=0.1
            ),
            transforms.ToTensor(),
        ]
    )

    def _augment(img):
        # bgr -> rgb, augment, rgb -> bgr, written back in place.
        img[:3] = augment_rgb(img[:3][[2, 1, 0]])[[2, 1, 0]]
        return img

    return transforms.Lambda(_augment)