Source code for meerkat.contrib.wilds.config

"""WILDS configuration defaults and operations.

All default configurations are taken from the WILDS repository:
https://github.com/p-lambda/wilds/blob/ca7e4fd345aa4f998b8322191b083444d85e2a08/examples/configs/datasets.py
"""

base_config = {
    "split_scheme": None,
    "dataset_kwargs": {},
    "download": False,
    "frac": 1.0,
    "version": None,
    "loader_kwargs": {},
    "train_loader": None,
    "uniform_over_groups": None,
    "distinct_groups": None,
    "n_groups_per_batch": None,
    "batch_size": None,
    "eval_loader": "standard",
    "model": None,
    "model_kwargs": {},
    "train_transform": None,
    "eval_transform": None,
    "target_resolution": None,
    "resize_scale": None,
    "max_token_length": None,
    "loss_function": None,
    "groupby_fields": None,
    "group_dro_step_size": None,
    "coral_penalty_weight": None,
    "irm_lambda": None,
    "irm_penalty_anneal_iters": None,
    "algo_log_metric": None,
    "val_metric": None,
    "val_metric_decreasing": None,
    "n_epochs": None,
    "optimizer": None,
    "lr": None,
    "weight_decay": None,
    "max_grad_norm": None,
    "optimizer_kwargs": {},
    "scheduler": None,
    "scheduler_kwargs": {},
    "scheduler_metric_split": "val",
    "scheduler_metric_name": None,
    "process_outputs_function": None,
    "evaluate_all_splits": True,
    "eval_splits": [],
    "eval_only": False,
    "eval_epoch": None,
    "device": 0,
    "seed": 0,
    "log_dir": "./logs",
    "log_every": 50,
    "save_step": None,
    "save_best": True,
    "save_last": True,
    "no_group_logging": None,
    "use_wandb": False,
    "progress_bar": False,
    "resume": False,
}

dataset_defaults = {
    "amazon": {
        "split_scheme": "official",
        "model": "distilbert-base-uncased",
        "train_transform": "bert",
        "eval_transform": "bert",
        "max_token_length": 512,
        "loss_function": "cross_entropy",
        "algo_log_metric": "accuracy",
        "batch_size": 8,
        "lr": 1e-5,
        "weight_decay": 0.01,
        "n_epochs": 3,
        "n_groups_per_batch": 2,
        "irm_lambda": 1.0,
        "coral_penalty_weight": 1.0,
        "loader_kwargs": {
            "num_workers": 1,
            "pin_memory": True,
        },
        "process_outputs_function": "multiclass_logits_to_pred",
    },
    "bdd100k": {
        "split_scheme": "official",
        "model": "resnet50",
        "model_kwargs": {"pretrained": True},
        "loss_function": "multitask_bce",
        "val_metric": "acc_all",
        "val_metric_decreasing": False,
        "optimizer": "SGD",
        "optimizer_kwargs": {"momentum": 0.9},
        "batch_size": 32,
        "lr": 0.001,
        "weight_decay": 0.0001,
        "n_epochs": 10,
        "algo_log_metric": "multitask_binary_accuracy",
        "train_transform": "image_base",
        "eval_transform": "image_base",
        "process_outputs_function": "binary_logits_to_pred",
    },
    "camelyon17": {
        "split_scheme": "official",
        "model": "densenet121",
        "model_kwargs": {"pretrained": False},
        "train_transform": "image_base",
        "eval_transform": "image_base",
        "target_resolution": (96, 96),
        "loss_function": "cross_entropy",
        "groupby_fields": ["hospital"],
        "val_metric": "acc_avg",
        "val_metric_decreasing": False,
        "optimizer": "SGD",
        "optimizer_kwargs": {"momentum": 0.9},
        "scheduler": None,
        "batch_size": 32,
        "lr": 0.001,
        "weight_decay": 0.01,
        "n_epochs": 5,
        "n_groups_per_batch": 2,
        "irm_lambda": 1.0,
        "coral_penalty_weight": 0.1,
        "algo_log_metric": "accuracy",
        "process_outputs_function": "multiclass_logits_to_pred",
    },
    "celebA": {
        "split_scheme": "official",
        "model": "resnet50",
        "model_kwargs": {"pretrained": True},
        "train_transform": "image_base",
        "eval_transform": "image_base",
        "loss_function": "cross_entropy",
        "groupby_fields": ["male", "y"],
        "val_metric": "acc_wg",
        "val_metric_decreasing": False,
        "optimizer": "SGD",
        "optimizer_kwargs": {"momentum": 0.9},
        "scheduler": None,
        "batch_size": 64,
        "lr": 0.001,
        "weight_decay": 0.0,
        "n_epochs": 200,
        "algo_log_metric": "accuracy",
        "process_outputs_function": "multiclass_logits_to_pred",
    },
    "civilcomments": {
        "split_scheme": "official",
        "model": "distilbert-base-uncased",
        "train_transform": "bert",
        "eval_transform": "bert",
        "loss_function": "cross_entropy",
        "groupby_fields": ["black", "y"],
        "val_metric": "acc_wg",
        "val_metric_decreasing": False,
        "batch_size": 16,
        "lr": 1e-5,
        "weight_decay": 0.01,
        "n_epochs": 5,
        "algo_log_metric": "accuracy",
        "max_token_length": 300,
        "irm_lambda": 1.0,
        "coral_penalty_weight": 10.0,
        "loader_kwargs": {
            "num_workers": 1,
            "pin_memory": True,
        },
        "process_outputs_function": "multiclass_logits_to_pred",
    },
    "fmow": {
        "split_scheme": "official",
        "dataset_kwargs": {
            "oracle_training_set": False,
            "seed": 111,
            "use_ood_val": True,
        },
        "model": "densenet121",
        "model_kwargs": {"pretrained": True},
        "train_transform": "image_base",
        "eval_transform": "image_base",
        "loss_function": "cross_entropy",
        "groupby_fields": [
            "year",
        ],
        "val_metric": "acc_worst_region",
        "val_metric_decreasing": False,
        "optimizer": "Adam",
        "scheduler": "StepLR",
        "scheduler_kwargs": {"gamma": 0.96},
        "batch_size": 64,
        "lr": 0.0001,
        "weight_decay": 0.0,
        "n_epochs": 50,
        "n_groups_per_batch": 8,
        "irm_lambda": 1.0,
        "coral_penalty_weight": 0.1,
        "algo_log_metric": "accuracy",
        "process_outputs_function": "multiclass_logits_to_pred",
    },
    "iwildcam": {
        "loss_function": "cross_entropy",
        "val_metric": "F1-macro_all",
        "model_kwargs": {"pretrained": True},
        "train_transform": "image_base",
        "eval_transform": "image_base",
        "target_resolution": (448, 448),
        "val_metric_decreasing": False,
        "algo_log_metric": "accuracy",
        "model": "resnet50",
        "lr": 3e-5,
        "weight_decay": 0.0,
        "batch_size": 16,
        "n_epochs": 12,
        "optimizer": "Adam",
        "split_scheme": "official",
        "scheduler": None,
        "groupby_fields": [
            "location",
        ],
        "n_groups_per_batch": 2,
        "irm_lambda": 1.0,
        "coral_penalty_weight": 10.0,
        "no_group_logging": True,
        "process_outputs_function": "multiclass_logits_to_pred",
    },
    "ogb-molpcba": {
        "split_scheme": "official",
        "model": "gin-virtual",
        "model_kwargs": {"dropout": 0.5},  # include pretrained
        "loss_function": "multitask_bce",
        "groupby_fields": [
            "scaffold",
        ],
        "val_metric": "ap",
        "val_metric_decreasing": False,
        "optimizer": "Adam",
        "batch_size": 32,
        "lr": 1e-03,
        "weight_decay": 0.0,
        "n_epochs": 100,
        "n_groups_per_batch": 4,
        "irm_lambda": 1.0,
        "coral_penalty_weight": 0.1,
        "no_group_logging": True,
        "process_outputs_function": None,
        "algo_log_metric": "multitask_binary_accuracy",
    },
    "py150": {
        "split_scheme": "official",
        "model": "code-gpt-py",
        "loss_function": "lm_cross_entropy",
        "val_metric": "acc",
        "val_metric_decreasing": False,
        "optimizer": "AdamW",
        "optimizer_kwargs": {"eps": 1e-8},
        "lr": 8e-5,
        "weight_decay": 0.0,
        "n_epochs": 3,
        "batch_size": 6,
        "groupby_fields": [
            "repo",
        ],
        "n_groups_per_batch": 2,
        "irm_lambda": 1.0,
        "coral_penalty_weight": 1.0,
        "no_group_logging": True,
        "algo_log_metric": "multitask_accuracy",
        "process_outputs_function": "multiclass_logits_to_pred",
    },
    "poverty": {
        "split_scheme": "official",
        "dataset_kwargs": {
            "no_nl": False,
            "fold": "A",
            "oracle_training_set": False,
            "use_ood_val": True,
        },
        "model": "resnet18_ms",
        "model_kwargs": {"num_channels": 8},
        "train_transform": "poverty_train",
        "eval_transform": None,
        "loss_function": "mse",
        "groupby_fields": [
            "country",
        ],
        "val_metric": "r_wg",
        "val_metric_decreasing": False,
        "algo_log_metric": "mse",
        "optimizer": "Adam",
        "scheduler": "StepLR",
        "scheduler_kwargs": {"gamma": 0.96},
        "batch_size": 64,
        "lr": 0.001,
        "weight_decay": 0.0,
        "n_epochs": 200,
        "n_groups_per_batch": 8,
        "irm_lambda": 1.0,
        "coral_penalty_weight": 0.1,
        "process_outputs_function": None,
    },
    "waterbirds": {
        "split_scheme": "official",
        "model": "resnet50",
        "train_transform": "image_resize_and_center_crop",
        "eval_transform": "image_resize_and_center_crop",
        "resize_scale": 256.0 / 224.0,
        "model_kwargs": {"pretrained": True},
        "loss_function": "cross_entropy",
        "groupby_fields": ["background", "y"],
        "val_metric": "acc_wg",
        "val_metric_decreasing": False,
        "algo_log_metric": "accuracy",
        "optimizer": "SGD",
        "optimizer_kwargs": {"momentum": 0.9},
        "scheduler": None,
        "batch_size": 128,
        "lr": 1e-5,
        "weight_decay": 1.0,
        "n_epochs": 300,
        "process_outputs_function": "multiclass_logits_to_pred",
    },
    "yelp": {
        "split_scheme": "official",
        "model": "bert-base-uncased",
        "train_transform": "bert",
        "eval_transform": "bert",
        "max_token_length": 512,
        "loss_function": "cross_entropy",
        "algo_log_metric": "accuracy",
        "batch_size": 8,
        "lr": 2e-6,
        "weight_decay": 0.01,
        "n_epochs": 3,
        "n_groups_per_batch": 2,
        "process_outputs_function": "multiclass_logits_to_pred",
    },
    "sqf": {
        "split_scheme": "all_race",
        "model": "logistic_regression",
        "train_transform": None,
        "eval_transform": None,
        "model_kwargs": {"in_features": 104},
        "loss_function": "cross_entropy",
        "groupby_fields": ["y"],
        "val_metric": "precision_at_global_recall_all",
        "val_metric_decreasing": False,
        "algo_log_metric": "accuracy",
        "optimizer": "Adam",
        "optimizer_kwargs": {},
        "scheduler": None,
        "batch_size": 4,
        "lr": 5e-5,
        "weight_decay": 0,
        "n_epochs": 4,
        "process_outputs_function": None,
    },
}

##########################################
#   Split-specific defaults for Amazon   #
##########################################

amazon_split_defaults = {
    "official": {
        "groupby_fields": ["user"],
        "val_metric": "10th_percentile_acc",
        "val_metric_decreasing": False,
        "no_group_logging": True,
    },
    "user": {
        "groupby_fields": ["user"],
        "val_metric": "10th_percentile_acc",
        "val_metric_decreasing": False,
        "no_group_logging": True,
    },
    "time": {
        "groupby_fields": ["year"],
        "val_metric": "acc_avg",
        "val_metric_decreasing": False,
    },
    "time_baseline": {
        "groupby_fields": ["year"],
        "val_metric": "acc_avg",
        "val_metric_decreasing": False,
    },
}

user_baseline_splits = [
    "A1CNQTCRQ35IMM_baseline",
    "A1NE43T0OM6NNX_baseline",
    "A1UH21GLZTYYR5_baseline",
    "A20EEWWSFMZ1PN_baseline",
    "A219Y76LD1VP4N_baseline",
    "A37BRR2L8PX3R2_baseline",
    "A3JVZY05VLMYEM_baseline",
    "A9Q28YTLYREO7_baseline",
    "ASVY5XSYJ1XOE_baseline",
    "AV6QDP8Q0ONK4_baseline",
]
for split in user_baseline_splits:
    amazon_split_defaults[split] = {
        "groupby_fields": ["user"],
        "val_metric": "acc_avg",
        "val_metric_decreasing": False,
    }

category_splits = [
    "arts_crafts_and_sewing_generalization",
    "automotive_generalization",
    "books,movies_and_tv,home_and_kitchen,electronics_generalization",
    "books_generalization",
    "category_subpopulation",
    "cds_and_vinyl_generalization",
    "cell_phones_and_accessories_generalization",
    "clothing_shoes_and_jewelry_generalization",
    "digital_music_generalization",
    "electronics_generalization",
    "grocery_and_gourmet_food_generalization",
    "home_and_kitchen_generalization",
    "industrial_and_scientific_generalization",
    "kindle_store_generalization",
    "luxury_beauty_generalization",
    "movies_and_tv,books,home_and_kitchen_generalization",
    "movies_and_tv,books_generalization",
    "movies_and_tv_generalization",
    "musical_instruments_generalization",
    "office_products_generalization",
    "patio_lawn_and_garden_generalization",
    "pet_supplies_generalization",
    "prime_pantry_generalization",
    "sports_and_outdoors_generalization",
    "tools_and_home_improvement_generalization",
    "toys_and_games_generalization",
    "video_games_generalization",
]
for split in category_splits:
    amazon_split_defaults[split] = {
        "groupby_fields": ["category"],
        "val_metric": "acc_avg",
        "val_metric_decreasing": False,
    }

########################################
#   Split-specific defaults for Yelp   #
########################################

yelp_split_defaults = {
    "official": {
        "groupby_fields": ["year"],
        "val_metric": "acc_avg",
        "val_metric_decreasing": False,
    },
    "user": {
        "groupby_fields": ["user"],
        "val_metric": "10th_percentile_acc",
        "val_metric_decreasing": False,
        "no_group_logging": True,
    },
    "time": {
        "groupby_fields": ["year"],
        "val_metric": "acc_avg",
        "val_metric_decreasing": False,
    },
    "time_baseline": {
        "groupby_fields": ["year"],
        "val_metric": "acc_avg",
        "val_metric_decreasing": False,
    },
}

###############################
#   Split-specific defaults   #
###############################

split_defaults = {
    "amazon": amazon_split_defaults,
    "yelp": yelp_split_defaults,
}

loader_defaults = {
    "loader_kwargs": {
        "num_workers": 4,
        "pin_memory": True,
    },
    "n_groups_per_batch": 4,
}


def populate_defaults(config):
    """Populates hyperparameters with defaults implied by choices of other
    hyperparameters."""
    assert config.dataset is not None, "dataset must be specified"

    # implied defaults from choice of dataset
    config = populate_config(config, dataset_defaults[config.dataset])

    # implied defaults from choice of split
    if (
        config.dataset in split_defaults
        and config.split_scheme in split_defaults[config.dataset]
    ):
        config = populate_config(
            config, split_defaults[config.dataset][config.split_scheme]
        )

    # implied defaults from choice of loader
    config = populate_config(config, loader_defaults)

    # basic checks
    required_fields = [
        "split_scheme",
    ]
    for field in required_fields:
        assert (
            getattr(config, field) is not None
        ), f"Must manually specify {field} for this setup."

    return config


def populate_config(config, template: dict, force_compatibility=False):
    """Populates missing (key, val) pairs in config with (key, val) in template.

    Example usage: populate config with defaults.

    Args:
        - config: namespace
        - template: dict
        - force_compatibility: option to raise errors if config.key != template[key]
    """
    if template is None:
        return config

    d_config = vars(config)
    for key, val in template.items():
        if not isinstance(val, dict):  # config[key] expected to be a non-indexable value
            if key not in d_config or d_config[key] is None:
                d_config[key] = val
            elif d_config[key] != val and force_compatibility:
                raise ValueError(f"Argument {key} must be set to {val}")
        else:  # config[key] expected to be a kwargs dict
            if key not in d_config:
                d_config[key] = {}
            for kwargs_key, kwargs_val in val.items():
                if (
                    kwargs_key not in d_config[key]
                    or d_config[key][kwargs_key] is None
                ):
                    d_config[key][kwargs_key] = kwargs_val
                elif d_config[key][kwargs_key] != kwargs_val and force_compatibility:
                    raise ValueError(
                        f"Argument {key}[{kwargs_key}] must be set to {kwargs_val}"
                    )
    return config
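

# Usage sketch (added for illustration; not part of the upstream WILDS
# defaults). It shows how populate_defaults layers dataset-, split-, and
# loader-level defaults onto a bare config namespace. The argparse.Namespace
# object and the "camelyon17" dataset choice are assumptions for this example
# only; any attribute-style config object works.
if __name__ == "__main__":
    from argparse import Namespace
    from copy import deepcopy

    # deepcopy so the nested kwargs dicts inside base_config are not mutated
    config = Namespace(**deepcopy(base_config))
    config.dataset = "camelyon17"

    config = populate_defaults(config)
    print(config.model)          # densenet121, from dataset_defaults
    print(config.batch_size)     # 32, from dataset_defaults
    print(config.loader_kwargs)  # {'num_workers': 4, 'pin_memory': True}, from loader_defaults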