Source code for ngclearn.utils.masks.multiblock2d

# %%
import math
import numpy as np
from multiprocessing import Value

class MaskCollator(object):  # Adapted from the Meta JEPA code-base to ngc-learn compliance
    """
    A mechanism for generating/creating patch masks, generally for
    self-supervised learning. One internal mask generator is built per entry
    of `cfgs_mask`.

    Args:
        cfgs_mask: iterable of mapping-like configurations; each one is read
            with `.get` for the keys 'spatial_scale', 'aspect_ratio',
            'num_blocks' and (optionally) 'max_keep'

        crop_size: dimensions of crop

        patch_size: dimensions of patches to create
    """

    def __init__(self, cfgs_mask, crop_size=(224, 224), patch_size=(16, 16),):
        super().__init__()
        # One generator per configuration, kept in configuration order.
        self.mask_generators = [
            _MaskGenerator(
                crop_size=crop_size,
                patch_size=patch_size,
                pred_mask_scale=cfg.get('spatial_scale'),
                aspect_ratio=cfg.get('aspect_ratio'),
                npred=cfg.get('num_blocks'),
                max_keep=cfg.get('max_keep', None),
            )
            for cfg in cfgs_mask
        ]

    def step(self):
        """
        Steps this collator forward one step by calling `step()` on every
        mask generator it holds. Nothing is returned.
        """
        for generator in self.mask_generators:
            generator.step()

    def __call__(self, batch):
        """
        Produces masks for a batch by calling every mask generator with the
        batch length.

        Args:
            batch: any sized collection; only `len(batch)` is used

        Returns:
            list of encoder masks, list of predictor masks (each list holds
            one entry per mask generator, in generator order)
        """
        n_samples = len(batch)
        enc_groups, pred_groups = [], []
        for generator in self.mask_generators:
            enc, pred = generator(n_samples)
            enc_groups.append(enc)
            pred_groups.append(pred)
        return enc_groups, pred_groups
class _MaskGenerator(object):
    """
    Samples rectangular predictor blocks over a 2D grid of patches and
    returns, per sample, the flat patch indices covered by the blocks
    (predictor mask) and the remaining flat patch indices (encoder mask).

    Args:
        crop_size: (height, width) of the crop; a single number is expanded
            to a square

        patch_size: (height, width) of one patch; a single number is expanded
            to a square

        pred_mask_scale: (min, max) fraction of the patch grid covered by one
            predictor block

        aspect_ratio: (min, max) height/width ratio of a predictor block

        npred: number of predictor blocks sampled per sample

        max_keep: optional upper bound on the number of encoder indices kept
            per sample (None disables the bound)
    """

    def __init__(
            self, crop_size=(224, 224), patch_size=(16, 16),
            pred_mask_scale=(0.2, 0.8), aspect_ratio=(0.3, 3.0), npred=1,
            max_keep=None
    ):
        super().__init__()
        # Both sizes accept a scalar (expanded to a square) or a 2-sequence;
        # previously only crop_size was expanded and patch_size=16 crashed.
        crop_size = self._as_pair(crop_size)
        patch_size = self._as_pair(patch_size)
        self.crop_size = crop_size
        # Size of the patch grid, in patches.
        self.height = crop_size[0] // patch_size[0]
        self.width = crop_size[1] // patch_size[1]
        self.patch_size = patch_size
        self.aspect_ratio = aspect_ratio
        self.pred_mask_scale = pred_mask_scale
        self.npred = npred
        self.max_keep = max_keep
        self._itr_counter = Value('i', -1)  # collator is shared across worker processes

    @staticmethod
    def _as_pair(size):
        """Return `size` as a tuple; a non-sequence value is duplicated into a pair."""
        if isinstance(size, (tuple, list)):
            return tuple(size)
        return (size,) * 2

    def step(self):
        """
        Advances the iteration counter by one (wrapping at 2**16) while
        holding its lock.

        Returns:
            the new counter value
        """
        counter = self._itr_counter
        with counter.get_lock():
            counter.value = (counter.value + 1) % 2**16
            current = counter.value
        return current

    def _sample_block_size(self, rng: np.random.RandomState, scale, aspect_ratio_scale):
        """
        Samples the (height, width) of a block, measured in patches.

        Args:
            rng: random state used for both draws

            scale: (min, max) fraction of the patch grid the block covers

            aspect_ratio_scale: (min, max) height/width ratio of the block

        Returns:
            (h, w), each capped at the grid size; either value can round
            down to 0 for very small scales
        """
        # -- Sample spatial block mask scale
        min_s, max_s = scale
        spatial_mask_scale = min_s + rng.random() * (max_s - min_s)
        spatial_num_keep = int(self.height * self.width * spatial_mask_scale)

        # -- Sample block aspect-ratio
        min_ar, max_ar = aspect_ratio_scale
        aspect_ratio = min_ar + rng.random() * (max_ar - min_ar)

        # -- Compute block height and width (given scale and aspect-ratio)
        h = int(round(math.sqrt(spatial_num_keep * aspect_ratio)))
        w = int(round(math.sqrt(spatial_num_keep / aspect_ratio)))
        return (min(h, self.height), min(w, self.width))

    def _sample_block_mask(self, b_size, rng: np.random.RandomState):
        """
        Places one block of size `b_size` at a random position on the grid.

        Args:
            b_size: (h, w) of the block in patches

            rng: random state used to draw the top-left corner

        Returns:
            int32 array of shape (height, width) holding 0 inside the block
            and 1 everywhere else
        """
        h, w = b_size
        top = rng.randint(0, self.height - h + 1)
        left = rng.randint(0, self.width - w + 1)
        mask = np.ones((self.height, self.width), dtype=np.int32)
        mask[top:top + h, left:left + w] = 0
        return mask

    def __call__(self, batch_size):
        """
        Create encoder and predictor masks when collating imgs into a batch:

        | 1. sample the pred block size from a random state seeded with the
        |    current step counter
        | 2. sample several pred block locations for each sample
        | 3. return pred masks and complement (enc mask)

        Args:
            batch_size: number of samples to place w/in a generate batch

        Returns:
            collated encoder masks, collated predictor masks (arrays with one
            row of flat patch indices per sample; rows are truncated to the
            shortest row so that every row has the same length)

        Raises:
            ValueError: if the sampled block covers the whole patch grid (or
                the grid has no patches), since no encoder context could
                ever be left over
        """
        seed = self.step()
        rng = np.random.RandomState(seed)
        p_size = self._sample_block_size(
            rng=rng,
            scale=self.pred_mask_scale,
            aspect_ratio_scale=self.aspect_ratio,
        )

        n_patches = self.height * self.width
        # A block as large as the grid leaves an empty context on every
        # attempt, so the resampling loop below would never terminate.
        if batch_size > 0 and p_size[0] * p_size[1] >= n_patches:
            raise ValueError(
                "predictor block of size {} covers the entire {}x{} patch grid; "
                "no encoder context can be sampled".format(
                    p_size, self.height, self.width
                )
            )

        collated_masks_pred, collated_masks_enc = [], []
        min_keep_enc = min_keep_pred = n_patches
        for _ in range(batch_size):
            empty_context = True
            while empty_context:
                # Create a mask for this sample: the product of the block
                # masks is 0 wherever any block landed.
                mask_e = np.ones((self.height, self.width), dtype=np.int32)
                for _block in range(self.npred):
                    mask_e *= self._sample_block_mask(p_size, rng)
                mask_e = mask_e.flatten()

                mask_p = np.where(mask_e == 0)[0]
                mask_e = np.where(mask_e != 0)[0]

                # Resample when the blocks happened to cover every patch.
                empty_context = len(mask_e) == 0
                if not empty_context:
                    min_keep_pred = min(min_keep_pred, len(mask_p))
                    min_keep_enc = min(min_keep_enc, len(mask_e))
                    collated_masks_pred.append(mask_p)
                    collated_masks_enc.append(mask_e)

        if self.max_keep is not None:
            min_keep_enc = min(min_keep_enc, self.max_keep)

        # Truncate arrays to the minimum length to create uniform arrays
        collated_masks_pred = np.array(
            [cm[:min_keep_pred] for cm in collated_masks_pred]
        )
        collated_masks_enc = np.array(
            [cm[:min_keep_enc] for cm in collated_masks_enc]
        )
        return collated_masks_enc, collated_masks_pred