Source code for meerkat.cells.video

import os
from typing import Callable, Collection, Optional

import torch

from meerkat import AbstractCell
from meerkat.contrib.video_corruptions.utils import stderr_suppress
from meerkat.tools.lazy_loader import LazyLoader

cv2 = LazyLoader("cv2")
F = LazyLoader("torchvision.transforms.functional")


class VideoCell(AbstractCell):
    """Interface for loading video data.

    The constructor only records its arguments; the video file is opened
    and decoded when ``get`` is called.

    Args:
        filepath: Path to the video file on disk.
        transform: Optional callable applied to the stacked frame tensor
            before ``get`` returns it.
        time_dim: Index at which the time (frame) dimension is inserted
            when the per-frame tensors are stacked.

    Examples:
        # Load a video from "/path/to/video.mp4", where the time dimension
        # has index one
        >>> cell = VideoCell("/path/to/video.mp4", time_dim=1)
    """

    def __init__(
        self,
        filepath: str,
        transform: Optional[Callable] = None,
        time_dim: Optional[int] = 1,
    ):
        super().__init__()
        self.filepath = filepath
        self.time_dim = time_dim
        self.transform = transform

    def _read_all_frames(self):
        """Decode every frame of the video into one stacked tensor.

        Each frame is converted from BGR to RGB, passed through
        ``torchvision.transforms.functional.to_tensor`` and collected; the
        collected frames are then stacked along ``self.time_dim``.

        Returns:
            The stacked frames as a single ``torch.Tensor``.

        Raises:
            ValueError: If ``self.filepath`` does not point to a file.
            RuntimeError: If no frame could be decoded from the file.
        """
        if not os.path.isfile(self.filepath):
            raise ValueError(f"{self.filepath} is not a valid file!")

        cap = cv2.VideoCapture(self.filepath)
        frames = []
        try:
            while True:
                ret, img = cap.read()
                # ``ret`` is False once the stream is exhausted (or cannot be
                # read); stop here instead of relying on the colour
                # conversion failing on a missing frame.
                if not ret:
                    break
                try:
                    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
                except cv2.error:
                    # Best effort: a frame that cannot be converted ends
                    # decoding, keeping the frames read so far.
                    break
                frames.append(F.to_tensor(img))
        finally:
            # Always release the capture handle, even if decoding raised.
            cap.release()

        if not frames:
            # torch.stack would raise an opaque RuntimeError on an empty
            # list; keep the exception type but name the offending file.
            raise RuntimeError(f"No frames could be read from {self.filepath}!")

        return torch.stack(frames, dim=self.time_dim)

    def get(self):
        """Load the video and return its frames.

        Decoding runs inside ``stderr_suppress()``. If a ``transform`` was
        supplied, it is applied to the stacked frames before returning.

        Returns:
            The stacked frame tensor, or whatever ``self.transform`` returns
            when a transform is set.
        """
        with stderr_suppress():
            frames = self._read_all_frames()

        # TODO: support different decoders
        if self.transform is not None:
            frames = self.transform(frames)
        return frames

    @classmethod
    def _state_keys(cls) -> Collection:
        """Return the names of the attributes that make up the cell's state."""
        return {"filepath", "time_dim", "transform"}