Source code for ngclearn.utils.patch_utils

"""
Image/tensor patching utility routines.
"""
import numpy as np
from jax import numpy as jnp
from sklearn.feature_extraction.image import extract_patches_2d


def patch_with_stride(X, patch, stride):
    """
    Cuts a 2D array into full-size patches by sliding a window over it with
    a fixed step along each axis.

    Windows are visited row by row (vertical index outermost). Only windows
    that fit entirely inside ``X`` are produced; the number of window
    positions per axis is ``(size - patch) // stride + 1``.

    Args:
        X: array to cut patches from, indexed as (H, W)

        patch: (patch_height, patch_width)

        stride: (vertical_step, horizontal_step); both must be >= 1

    Returns:
        jnp.array: stacked patches, shape (num_patches, patch_height,
        patch_width). If no window fits, the result is ``jnp.array([])``.

    Raises:
        ValueError: if either stride component is smaller than 1
    """
    patch_h, patch_w = patch[0], patch[1]
    step_h, step_w = stride[0], stride[1]
    if step_h < 1 or step_w < 1:
        # a zero step would emit the same window over and over
        raise ValueError(
            "stride components must be >= 1, got {}".format((step_h, step_w))
        )

    height, width = X.shape[0], X.shape[1]
    # number of window positions that keep the patch fully inside X; this
    # depends on the stride, not only on how many patches tile the axis
    n_rows = (height - patch_h) // step_h + 1 if height >= patch_h else 0
    n_cols = (width - patch_w) // step_w + 1 if width >= patch_w else 0

    result = []
    for row_idx in range(n_rows):
        top = step_h * row_idx
        for col_idx in range(n_cols):
            left = step_w * col_idx
            result.append(X[top: top + patch_h, left: left + patch_w])
    return jnp.array(result)
def patch_with_overlap(X, patch, overlap):
    """
    Cuts a 2D array into patches where neighboring patches share
    ``overlap`` rows/columns.

    The overlap is converted into a step of ``patch - overlap`` along each
    axis and the extraction itself is delegated to ``patch_with_stride``.

    Args:
        X: array to cut patches from, indexed as (H, W)

        patch: (patch_height, patch_width)

        overlap: (height_overlap, width_overlap); each must be smaller than
            the matching patch dimension

    Returns:
        whatever ``patch_with_stride(X, patch, stride)`` returns for the
        derived stride

    Raises:
        ValueError: if the overlap is not smaller than the patch along
            either axis (the resulting step would not advance)
    """
    # the step between neighboring patches is measured against the patch
    # size, not against the image size
    stride = (patch[0] - overlap[0], patch[1] - overlap[1])
    if stride[0] < 1 or stride[1] < 1:
        raise ValueError(
            "overlap {} must be smaller than patch {} along both axes".format(
                tuple(overlap), tuple(patch)
            )
        )
    return patch_with_stride(X, patch, stride)
class Create_Patches:
    """
    Creates small (optionally overlapping) patches out of a 2D image.

    Neighboring patches are ``patch - overlap`` pixels apart along each axis.
    The patch grid (``nh_patches`` x ``nw_patches``) covers every position at
    which a full patch fits inside the image.

    Args:
        img: 2D array of size (H, W)

        patch_shape: (height_patch, width_patch)

        overlap_shape: (height_overlap, width_overlap); each entry must be
            smaller than the matching patch dimension

    Raises:
        ValueError: if the overlap is not smaller than the patch along
            either axis
    """

    def __init__(self, img, patch_shape, overlap_shape):
        self.img = img
        self.height_patch, self.width_patch = patch_shape
        self.height_over, self.width_over = overlap_shape
        if (self.height_over >= self.height_patch
                or self.width_over >= self.width_patch):
            # the step between patches would be zero or negative
            raise ValueError(
                "overlap_shape must be smaller than patch_shape along both axes"
            )
        self._update_grid()

    @staticmethod
    def _count_windows(length, window, step):
        """
        Number of positions at which a window of size ``window``, moved in
        increments of ``step``, still fits inside an axis of size ``length``.
        """
        if length < window:
            return 0
        return (length - window) // step + 1

    def _update_grid(self):
        """
        Refreshes the cached image size and the patch-grid dimensions from
        the current ``self.img``.
        """
        self.height, self.width = self.img.shape
        self.nh_patches = self._count_windows(
            self.height, self.height_patch,
            self.height_patch - self.height_over)
        self.nw_patches = self._count_windows(
            self.width, self.width_patch,
            self.width_patch - self.width_over)

    def _add_frame(self):
        """
        Surrounds the image with a zero-valued frame (in place).

        A margin of (height_patch - height_overlap) rows is added above and
        below, and a margin of (width_patch - width_overlap) columns is added
        left and right, so the image grows from (H, W) to
        (H + 2 * (height_patch - height_overlap),
        W + 2 * (width_patch - width_overlap)). The patch grid is recomputed
        afterwards.
        """
        pad_h = self.height_patch - self.height_over
        pad_w = self.width_patch - self.width_over
        # hstack appends columns, so its margin is the width-wise one
        side = jnp.zeros((self.img.shape[0], pad_w))
        self.img = np.hstack((side, self.img, side))
        # vstack appends rows, so its margin is the height-wise one
        cap = jnp.zeros((pad_h, self.img.shape[1]))
        self.img = np.vstack((cap, self.img, cap))
        self._update_grid()

    def create_patches(self, add_frame=False, center=True):
        """
        Creates small patches out of the image based on the configured patch
        and overlap shapes.

        Note that framing and centering overwrite ``self.img``, so calling
        this method repeatedly with ``add_frame=True`` keeps growing the
        stored image.

        Keyword Args:
            add_frame: if True, a zero frame is added around the image
                before patches are cut

            center: if True, the mean over axis 0 (one mean per image column)
                is subtracted from the image before patches are cut

        Returns:
            jnp.array: Array containing the patches, shape:
            (num_patches, patch_height, patch_width), ordered row by row
        """
        if add_frame:
            self._add_frame()

        if center:
            mu = np.mean(self.img, axis=0, keepdims=True)
            self.img = self.img - mu

        step_h = self.height_patch - self.height_over
        step_w = self.width_patch - self.width_over
        full_shape = (self.height_patch, self.width_patch)

        result = []
        for row_idx in range(self.nh_patches):
            top = row_idx * step_h
            for col_idx in range(self.nw_patches):
                left = col_idx * step_w
                window = self.img[top: top + self.height_patch,
                                  left: left + self.width_patch]
                # guard against a grid that was overridden from outside
                if window.shape == full_shape:
                    result.append(window)
        return jnp.array(result)
def generate_patch_set(x_batch, patch_size=(8, 8), max_patches=50, center=True, seed=1234, vis_mode=False): ## scikit
    """
    Generates a set of patches from an array/list of flattened image arrays.
    This uses scikit-learn's patch creation function (``extract_patches_2d``)
    to draw (pH x pW) patches from every image in the batch.

    Each row of the batch is reshaped to a square image before patches are
    drawn, so the row length must be a perfect square. The same ``seed`` is
    handed to scikit-learn for every image.

    Args:
        x_batch: the array of flattened image arrays to sample from, one
            image per row

        patch_size: a 2-tuple of the form (pH = patch height, pW = patch width)

        max_patches: maximum number of patches to extract/generate from each
            source image (forwarded to scikit-learn)

        center: centers each patch by subtracting the patch mean (per-patch)

        seed: seed to control the random state of internal patch sampling

        vis_mode: if True, the subtracted per-patch means are returned
            alongside the patches

    Returns:
        an array (D x (pH * pW)), where each row is a flattened patch sample;
        if ``vis_mode`` is True, a tuple (patches, mu) where mu holds the
        per-patch means (or 0 if ``center`` is False)
    """
    _x_batch = np.array(x_batch)
    px = py = int(np.sqrt(_x_batch.shape[1])) # get image shape of the data

    # gather the per-image patch sets first and join them once at the end;
    # concatenating inside the loop would re-copy the growing batch each time
    patch_sets = []
    for flat_img in _x_batch:
        patches = extract_patches_2d(
            flat_img.reshape(px, py), patch_size,
            max_patches=max_patches, random_state=seed
        )
        # flatten each patch in the set
        patch_sets.append(np.reshape(patches, (len(patches), -1)))

    if patch_sets:
        p_batch = np.concatenate(patch_sets, axis=0)
    else:
        # batch without images -> well-formed, empty patch matrix
        p_batch = np.zeros(
            (0, patch_size[0] * patch_size[1]), dtype=_x_batch.dtype
        )

    mu = 0
    if center: ## center patches by subtracting out their means
        mu = np.mean(p_batch, axis=1, keepdims=True)
        p_batch = p_batch - mu

    if vis_mode:
        return jnp.array(p_batch), mu
    return jnp.array(p_batch)
def generate_pacthify_patch_set(x_batch_, patch_size=(5, 5), center=True): ## patchify
    """
    Generates a set of patches from an array/list of flattened image arrays.
    This uses the patchify library to cut each image into a regular
    (non-random) grid of patches.

    This is a patchify-specific function (only use it if you have patchify
    installed); the library is imported when the function is called.

    Each row of the batch is reshaped to a square image before patches are
    cut, so the row length must be a perfect square. The step between
    patches is ``patch_size[0]`` along both image axes.

    Args:
        x_batch_: the array of flattened image arrays to cut patches from,
            one image per row

        patch_size: a 2-tuple of the form (pH = patch height, pW = patch width)

        center: centers each patch by subtracting the patch mean (per-patch)

    Returns:
        an array (D x (pH * pW)), where each row is a flattened patch sample
    """
    import patchify as ptch  # optional dependency, needed only by this routine

    x_batch = np.array(x_batch_)
    px = py = int(np.sqrt(x_batch.shape[1])) # get image shape of the data
    pch_x = patch_size[0]
    pch_y = patch_size[1]

    per_image = []
    for flat_img in x_batch:
        img = np.expand_dims(flat_img.reshape(px, py), axis=2)
        # NOTE(review): the step uses pch_x for both axes, so non-square
        # patches are not tiled edge-to-edge horizontally -- confirm intent
        windows = ptch.patchify(img, (pch_x, pch_y, 1), step=pch_x) # step = stride
        # windows is laid out as (n_rows, n_cols, 1, pch_x, pch_y, 1); flatten
        # it explicitly rather than squeezing, since squeezing would also drop
        # a patch-grid axis of length 1
        per_image.append(windows.reshape(-1, pch_x * pch_y))

    patchBatch = jnp.concatenate(per_image, axis=0)
    if center:
        mu = jnp.mean(patchBatch, axis=1, keepdims=True)
        patchBatch = patchBatch - mu
    return patchBatch