Source code for meerkat.contrib.video_corruptions.transforms

import random
from typing import Optional, Tuple

import torch


class TemporalDownsampling(object):
    """Video transformation for performing temporal downsampling (i.e. reading
    in every Nth frame only). This can be used in tandem with VideoCell by
    passing it into the `transform` keyword in the constructor. Can be used
    with Compose in torchvision.

    When using with TemporalCrop, it is highly recommended to put
    TemporalDownsampling first, with TemporalCrop second.

    Arguments:
        downsample_factor (int): the factor by which the input video should be
            downsampled. Must be a strictly positive integer.
        time_dim (int): the time dimension of the input video.

    Examples:
        # Create a VideoCell from "/path/to/video.mp4" with time in dimension
        # one, showing every other frame
        >>> cell = VideoCell(
        ...     "/path/to/video.mp4",
        ...     time_dim=1,
        ...     transform=TemporalDownsampling(2, time_dim=1),
        ... )

    Note that time_dim in the TemporalDownsampling call must match the
    time_dim in the VideoCell constructor!
    """

    def __init__(self, downsample_factor: int, time_dim: Optional[int] = 1):
        self.downsample_factor = downsample_factor
        self.time_dim = time_dim

        if self.downsample_factor != int(self.downsample_factor):
            raise ValueError("Fractional downsampling not supported.")

        if self.downsample_factor < 1:
            raise ValueError(
                "Downsampling must be by a factor of 1 (no upsampling) or greater."
            )

    def __call__(self, video: torch.Tensor) -> torch.Tensor:
        video_length = video.size(self.time_dim)
        downsampled_indices = torch.arange(
            0, video_length, self.downsample_factor
        ).long()
        frames = torch.index_select(video, self.time_dim, downsampled_indices)
        return frames
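
# A minimal usage sketch for TemporalDownsampling, assuming a raw video tensor
# of shape (channels, time, height, width) with time in dimension 1; the shape
# is illustrative only.
#
#   >>> video = torch.randn(3, 32, 224, 224)      # 32 frames, time in dim 1
#   >>> TemporalDownsampling(2, time_dim=1)(video).shape
#   torch.Size([3, 16, 224, 224])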
class TemporalCrop(object):
    """Video transformation for performing "temporal cropping": the sampling
    of a pre-defined number of clips, each with a pre-defined length, from a
    full video. Can be used with Compose in torchvision.

    When used with TemporalDownsampling, it is highly recommended to put
    TemporalCrop after TemporalDownsampling. Since TemporalCrop can change the
    number of dimensions in the output tensor, due to clip selection, it is in
    fact recommended to put this transform at the end of a video
    transformation pipeline.

    Arguments:
        n_clips (int): the number of clips that should be sampled.
        clip_length (int): the length of each clip (in number of frames).
        time_dim (int): the index of the time dimension of the video.
        clip_spacing (Optional; default "equal"): how to choose starting
            locations for sampling clips. Keyword "equal" means that clip
            starting locations are sampled from each 1/n_clips segment of the
            video. The other option, "anywhere", places no restrictions on
            where clip starting locations can be sampled.
        padding_mode (Optional; default "loop"): behavior if a requested clip
            length would result in a clip exceeding the end of the video.
            Keyword "loop" results in a wrap-around to the start of the video.
            The other option, "freeze", repeats the final frame until the
            requested clip length is achieved.
        sample_starting_location (Optional; default False): whether to sample
            a starting location (usually used for training) for a clip. Can be
            used in tandem with "equal" during training to sample clips with
            random starting locations distributed across time. Redundant if
            `clip_spacing` is "anywhere".
        stack_clips (Optional; default True): whether to stack clips in a new
            dimension (used in 3D action recognition backbones), or stack
            clips by concatenating across the time dimension (used in 2D
            action recognition backbones). Output shape if True is
            (n_clips, *video_shape). If False, the output shape has the same
            number of dimensions as the original video, but the time dimension
            is extended by a factor of n_clips.

    Examples:
        # Create a VideoCell from "/path/to/video.mp4" with time in dimension
        # one, sampling 10 clips each of length 16, sampling clips equally
        # across the video
        >>> cell = VideoCell(
        ...     "/path/to/video.mp4",
        ...     time_dim=1,
        ...     transform=TemporalCrop(10, 16, time_dim=1),
        ... )
        # output shape: (10, n_channels, 16, H, W)

        # Create a VideoCell from "/path/to/video.mp4" with time in dimension
        # one, sampling 8 clips each of length 8, sampling clips from
        # arbitrary video locations and freezing the last frame if a clip
        # exceeds the video length
        >>> cell = VideoCell(
        ...     "/path/to/video.mp4",
        ...     time_dim=1,
        ...     transform=TemporalCrop(
        ...         8, 8, time_dim=1, clip_spacing="anywhere", padding_mode="freeze"
        ...     ),
        ... )
        # output shape: (8, n_channels, 8, H, W)

        # Create a VideoCell from "/path/to/video.mp4" with time in dimension
        # one, sampling one frame from each third of the video, concatenating
        # the frames in the time dimension
        >>> cell = VideoCell(
        ...     "/path/to/video.mp4",
        ...     time_dim=1,
        ...     transform=TemporalCrop(
        ...         3, 1, time_dim=1, clip_spacing="equal",
        ...         sample_starting_location=True, stack_clips=False,
        ...     ),
        ... )
        # output shape: (n_channels, 3, H, W)

    Note that time_dim in the TemporalCrop call must match the time_dim in
    the VideoCell constructor!
    """

    def __init__(
        self,
        n_clips: int,
        clip_length: int,
        time_dim: Optional[int] = 1,
        clip_spacing: Optional[str] = "equal",
        padding_mode: Optional[str] = "loop",
        sample_starting_location: Optional[bool] = False,
        stack_clips: Optional[bool] = True,
    ):
        self.n_clips = n_clips
        self.clip_length = clip_length
        self.time_dim = time_dim
        self.clip_spacing = clip_spacing
        self.padding_mode = padding_mode
        self.stack_clips = stack_clips
        self.sample_starting_location = sample_starting_location

        if clip_length != int(clip_length):
            raise ValueError("Clip length (# of frames per clip) must be an integer")
        if clip_length <= 0:
            raise ValueError(
                "Clip length (# of frames per clip) must be a positive integer."
            )
        if n_clips != int(n_clips):
            raise ValueError("Number of clips is not an integer")
        if n_clips <= 0:
            raise ValueError("Number of clips must be a positive integer")

        assert clip_spacing in ["equal", "anywhere"]
        assert padding_mode in ["loop", "freeze"]

    def _get_sampling_boundaries(
        self, video_length: int, clip_number: int
    ) -> Tuple[int, int]:
        if self.clip_spacing == "equal":
            start = int(clip_number / self.n_clips * video_length)
            end = int((clip_number + 1) / self.n_clips * video_length)
        else:  # self.clip_spacing == "anywhere"
            start = 0
            end = video_length
        return start, end

    def _build_indices(self, start: int, length: int) -> torch.LongTensor:
        vanilla_indices = torch.arange(start, start + self.clip_length)
        if vanilla_indices.max().item() >= length:
            if self.padding_mode == "loop":
                vanilla_indices %= length
            else:  # self.padding_mode == "freeze"
                vanilla_indices[vanilla_indices >= length] = length - 1
        return vanilla_indices

    def __call__(self, video: torch.Tensor) -> torch.Tensor:
        video_length = video.size(self.time_dim)
        clips = []
        for clip_number in range(self.n_clips):
            start, end = self._get_sampling_boundaries(video_length, clip_number)
            if self.sample_starting_location:
                first_frame = random.randint(start, end)
            else:
                first_frame = start
            indices = self._build_indices(first_frame, video_length)
            clip = torch.index_select(video, self.time_dim, indices)
            clips.append(clip)

        if self.stack_clips:
            # new dim for clips (n_clips, n_channels, duration, h, w)
            all_clips = torch.stack(clips, dim=0)
        else:
            # concat clips in time dimension (n_channels, n_clips * duration, h, w)
            all_clips = torch.cat(clips, dim=self.time_dim)
        return all_clips