"""
Calculation toolbox that drives conv/deconv operations in the ngc-learn
convolution components sub-branch; this contains routines/co-routines
for `ngclearn.components.synapses.convolution`.
"""
import numpy as np
from jax import jit, numpy as jnp, random, nn, lax
from functools import partial
from jax._src import core
@partial(jit, static_argnums=[1])
def _pad(x, padding):
"""
Jit-i-fied padding function.
Args:
x (ndarray): The input array to be padded.
padding (tuple): A tuple containing the amounts of padding to apply to each dimension;
Format: (pad_bottom, pad_top, pad_left, pad_right).
Returns:
_x (ndarray): The padded array.
"""
# Unpack the padding amounts for each dimension
pad_bottom, pad_top, pad_left, pad_right = padding
# Apply padding to the input array 'x'
# Padding is applied as:
# [[0, 0], # No padding to the batch dimension
# [pad_bottom, pad_top], # Padding for the height dimension
# [pad_left, pad_right], # Padding for the width dimension
# [0, 0]] # No padding to the channel dimension
_x = jnp.pad(x,
[[0, 0],
[pad_bottom, pad_top],
[pad_left, pad_right],
[0, 0]], mode="constant").astype(jnp.float32) #To ensure all variables are of type float32
# Return the padded array
return _x
# [docs] -- doc-page link marker, not code
@jit
def rot180(tensor):
    """
    Rotate a 4D tensor by 180 degrees over its first two axes and swap its
    last two axes.

    Args:
        tensor (ndarray): the 4D tensor to rotate (presumably a kernel in the
            'HWIO' layout used by the conv routines in this file, i.e.,
            height, width, input channels, output channels)

    Returns:
        ndarray: the tensor reversed along axes 0 and 1, with axes 2 and 3
            exchanged
    """
    # exchange the last two axes (axis order 0, 1, 3, 2)
    swapped = jnp.transpose(tensor, axes=(0, 1, 3, 2))
    # reversing both leading axes amounts to a 180-degree rotation; this acts
    # on different axes than the swap above, so the order does not matter
    return jnp.flip(swapped, axis=(0, 1))
# [docs] -- doc-page link marker, not code
@partial(jit, static_argnums=[2, 3, 4])
def get_same_conv_padding(lhs, rhs, stride_size=1, rhs_dilation=(1, 1),
                          lhs_dilation=(1, 1)):
    """
    Compute the explicit per-axis padding that corresponds to "SAME" padding
    for a 2D convolution with 'NHWC' inputs and an 'HWIO' kernel.

    Args:
        lhs (ndarray): the input/image tensor; only its shape is used

        rhs (ndarray): the kernel tensor; only its shape is used

        stride_size (int): stride applied along both spatial axes
            (Default: 1)

        rhs_dilation (tuple): kernel dilation factor per spatial axis
            (Default: (1, 1))

        lhs_dilation (tuple): accepted for signature compatibility; it is not
            used in this calculation (Default: (1, 1))

    Returns:
        list: one (pad_low, pad_high) pair per spatial axis (since this
            routine is jit-compiled, the entries come back as JAX scalars)
    """
    window_strides = (stride_size, stride_size)
    dimension_numbers = ('NHWC', 'HWIO', 'NHWC')
    dnums = lax.conv_dimension_numbers(lhs.shape, rhs.shape, dimension_numbers)
    lhs_perm, rhs_perm, _ = dnums
    # Shapes are static Python integers even while tracing, so the shape
    # arithmetic is done with NumPy; applying jnp.take to a shape tuple is
    # rejected by current JAX and would otherwise yield traced values
    lhs_spatial = np.take(lhs.shape, lhs_perm)[2:]
    rhs_spatial = np.take(rhs.shape, rhs_perm)[2:]
    # effective extent of a dilated kernel axis: 1 + r * (k - 1); 0 if k is 0
    effective_rhs_shape = [0 if k == 0 else 1 + r * (int(k) - 1)
                           for k, r in zip(rhs_spatial, rhs_dilation)]
    return lax.padtype_to_pads(lhs_spatial, effective_rhs_shape,
                               window_strides, "SAME")
# [docs] -- doc-page link marker, not code
@partial(jit, static_argnums=[2, 3, 4])
def get_valid_conv_padding(lhs, rhs, stride_size=1, rhs_dilation=(1, 1),
                           lhs_dilation=(1, 1)):
    """
    Compute the explicit per-axis padding that corresponds to "VALID" padding
    for a 2D convolution with 'NHWC' inputs and an 'HWIO' kernel.

    Args:
        lhs (ndarray): the input/image tensor; only its shape is used

        rhs (ndarray): the kernel tensor; only its shape is used

        stride_size (int): stride applied along both spatial axes
            (Default: 1)

        rhs_dilation (tuple): kernel dilation factor per spatial axis
            (Default: (1, 1))

        lhs_dilation (tuple): accepted for signature compatibility; it is not
            used in this calculation (Default: (1, 1))

    Returns:
        list: one (pad_low, pad_high) pair per spatial axis (since this
            routine is jit-compiled, the entries come back as JAX scalars)
    """
    window_strides = (stride_size, stride_size)
    dimension_numbers = ('NHWC', 'HWIO', 'NHWC')
    dnums = lax.conv_dimension_numbers(lhs.shape, rhs.shape, dimension_numbers)
    lhs_perm, rhs_perm, _ = dnums
    # Shapes are static Python integers even while tracing, so the shape
    # arithmetic is done with NumPy; applying jnp.take to a shape tuple is
    # rejected by current JAX and would otherwise yield traced values
    lhs_spatial = np.take(lhs.shape, lhs_perm)[2:]
    rhs_spatial = np.take(rhs.shape, rhs_perm)[2:]
    # effective extent of a dilated kernel axis: 1 + r * (k - 1); 0 if k is 0
    effective_rhs_shape = [0 if k == 0 else 1 + r * (int(k) - 1)
                           for k, r in zip(rhs_spatial, rhs_dilation)]
    return lax.padtype_to_pads(lhs_spatial, effective_rhs_shape,
                               window_strides, "VALID")
def _conv_same_transpose_padding(inputs, output, kernel, stride):
"""
Calculate the padding for a transpose convolution operation to achieve 'SAME' padding.
Args:
inputs (int): The size of the input.
output (int): The size of the output.
kernel (int): The size of the convolution kernel.
stride (int): The stride length of the convolution.
Returns:
tuple: The padding for the height and width dimensions.
"""
pad_len = output - ((stride - 1) * (inputs - 1) + inputs - (kernel - 1))
if stride > kernel:
pad_a = kernel - 1
else:
pad_a = int(np.ceil(pad_len / 2))
pad_b = pad_len - pad_a
return ((pad_a, pad_b), (pad_a, pad_b))
## Better-optimized version of _conv_same_transpose_padding (kept commented out)
#@jit
# def _conv_same_transpose_padding(inputs, output, kernel, stride):
# """
# Calculate the padding for a transpose convolution operation to achieve 'same' padding.
#
# Parameters:
# inputs (int): The size of the input.
# output (int): The size of the output.
# kernel (int): The size of the convolution kernel.
# stride (int): The stride length of the convolution.
#
# Returns:
# tuple: The padding for the height and width dimensions.
# """
# # Calculate the total padding length required
# pad_len = output - ((stride - 1) * (inputs - 1) + inputs - (kernel - 1))
#
# # Determine padding based on stride and kernel size
# pad_a = jnp.where(stride > kernel, kernel - 1, jnp.ceil(pad_len / 2).astype(int))
# pad_b = pad_len - pad_a
#
# # Return the padding for height and width as tuples
# return ((pad_a, pad_b), (pad_a, pad_b))
def _conv_valid_transpose_padding(inputs, output, kernel, stride):
"""
Calculate the padding for a transpose convolution operation to achieve 'VALID' padding.
Args:
inputs (int): The size of the input.
output (int): The size of the output.
kernel (int): The size of the convolution kernel.
stride (int): The stride length of the convolution.
Returns:
tuple: The padding for the height and width dimensions.
"""
pad_len = output - ((stride - 1) * (inputs - 1) + inputs - (kernel - 1))
pad_a = kernel - 1
pad_b = pad_len - pad_a
return ((pad_a, pad_b), (pad_a, pad_b))
# ## Optimized version (version 2) of _conv_valid_transpose_padding
# @jit
# def _conv_valid_transpose_padding(inputs, output, kernel, stride):
# """
# Calculate the padding for a transpose convolution operation to achieve 'valid' padding.
#
# Parameters:
# inputs (int): The size of the input.
# output (int): The size of the output.
# kernel (int): The size of the convolution kernel.
# stride (int): The stride length of the convolution.
#
# Returns:
# tuple: The padding for the height and width dimensions.
# """
# # Calculate the total padding length required
# pad_len = output - ((stride - 1) * (inputs - 1) + inputs - (kernel - 1))
#
# # Set the padding value
# pad_a = kernel - 1
# pad_b = pad_len - pad_a
#
# # Ensure values are cast to integers
# pad_a = jnp.int32(pad_a)
# pad_b = jnp.int32(pad_b)
#
# # Return the padding for height and width as tuples
# return ((pad_a, pad_b), (pad_a, pad_b))
# @jit
# def _deconv_valid_transpose_padding(inputs, output, kernel, stride):
# """
# Calculate the padding for a transpose deconvolution operation to achieve 'valid' padding.
#
# Parameters:
# inputs (int): The size of the input.
# output (int): The size of the output.
# kernel (int): The size of the deconvolution kernel.
# stride (int): The stride length of the deconvolution.
#
# Returns:
# tuple: The padding for the height and width dimensions.
# """
# # Calculate the total padding length required
# pad_len = output - ((stride - 1) * (inputs - 1) + inputs - (kernel - 1))
#
# # Set the padding value
# pad_a = output - 1
# pad_b = pad_len - pad_a
#
# # Ensure values are cast to integers
# pad_a = jnp.int32(pad_a)
# pad_b = jnp.int32(pad_b)
#
# # Return the padding for height and width as tuples
# return ((pad_a, pad_b), (pad_a, pad_b))
def _deconv_valid_transpose_padding(inputs, output, kernel, stride):
"""
Calculate the padding for a transpose deconvolution operation to achieve 'VALID' padding.
Args::
inputs (int): The size of the input.
output (int: The size of the output.
kernel (int): The size of the deconvolution kernel.
stride (int): The stride length of the deconvolution.
Returns:
tuple: The padding for the height and width dimensions.
"""
pad_len = output - ((stride - 1) * (inputs - 1) + inputs - (kernel - 1))
pad_a = output - 1
pad_b = pad_len - pad_a
return ((pad_a, pad_b), (pad_a, pad_b))
# @jit
# def _deconv_same_transpose_padding(inputs, output, kernel, stride):
# """
# Calculate the padding for a transpose deconvolution operation to achieve 'same' padding.
#
# Parameters:
# inputs (int): The size of the input.
# output (int): The size of the output.
# kernel (int): The size of the deconvolution kernel.
# stride (int): The stride length of the deconvolution.
#
# Returns:
# tuple: The padding for the height and width dimensions.
# """
# # Calculate the total padding length required
# pad_len = output - ((stride - 1) * (inputs - 1) + inputs - (kernel - 1))
#
# # Determine padding values
# pad_a = jnp.where(stride >= output - 1, output - 1, jnp.ceil(pad_len / 2)).astype(int)
# pad_b = pad_len - pad_a
#
# # Return the padding for height and width as tuples
# return ((pad_a, pad_b), (pad_a, pad_b))
def _deconv_same_transpose_padding(inputs, output, kernel, stride):
"""
Calculate the padding for a transpose deconvolution operation to achieve 'SAME' padding.
Args:
inputs (int): The size of the input.
output (int): The size of the output.
kernel (int): The size of the deconvolution kernel.
stride (int): The stride length of the deconvolution.
Returns:
tuple: The padding for the height and width dimensions.
"""
pad_len = output - ((stride - 1) * (inputs - 1) + inputs - (kernel - 1))
if stride >= output - 1:
pad_a = output - 1
else:
pad_a = int(np.ceil(pad_len / 2)) # int(jnp.ceil(pad_len / 2))
pad_b = pad_len - pad_a
return ((pad_a, pad_b), (pad_a, pad_b))
# [docs] -- doc-page link marker, not code
@partial(jit, static_argnums=[2, 3, 4])
def deconv2d(inputs, filters, stride_size=1, rhs_dilation=(1, 1),
             padding=((0, 0), (0, 0))):
    """
    Jit-compiled 2D transposed convolution (Deconv2D); a thin wrapper around
    `lax.conv_transpose` with 'NHWC' inputs/outputs and an 'HWIO' kernel.

    Args:
        inputs (ndarray): the image tensor (lhs of the operation)

        filters (ndarray): the kernel tensor (rhs of the operation)

        stride_size (int): stride applied along both spatial axes
            (Default: 1)

        rhs_dilation (tuple): kernel dilation per spatial axis
            (Default: (1, 1))

        padding: padding argument handed unchanged to `lax.conv_transpose`
            (Default: ((0, 0), (0, 0)))

    Returns:
        ndarray: the output of the transposed convolution
    """
    strides = (stride_size, stride_size)
    layout = ('NHWC', 'HWIO', 'NHWC')
    return lax.conv_transpose(inputs, filters, strides, padding,
                              rhs_dilation=rhs_dilation,
                              dimension_numbers=layout)
# [docs] -- doc-page link marker, not code
@partial(jit, static_argnums=[2, 3, 4, 5])
def conv2d(inputs, filters, stride_size=1, rhs_dilation=(1, 1),
           lhs_dilation=(1, 1), padding=((0, 0), (0, 0))):
    """
    Jit-compiled 2D convolution (Conv2D); a thin wrapper around
    `lax.conv_general_dilated` with 'NHWC' inputs/outputs and an 'HWIO' kernel.

    Args:
        inputs (ndarray): the image tensor (lhs of the operation)

        filters (ndarray): the kernel tensor (rhs of the operation)

        stride_size (int): stride applied along both spatial axes
            (Default: 1)

        rhs_dilation (tuple): kernel dilation per spatial axis
            (Default: (1, 1))

        lhs_dilation (tuple): image dilation per spatial axis
            (Default: (1, 1))

        padding: padding argument handed unchanged to
            `lax.conv_general_dilated` (Default: ((0, 0), (0, 0)))

    Returns:
        ndarray: the output of the convolution
    """
    strides = (stride_size, stride_size)
    layout = ('NHWC', 'HWIO', 'NHWC')
    return lax.conv_general_dilated(inputs, filters, strides, padding,
                                    lhs_dilation=lhs_dilation,
                                    rhs_dilation=rhs_dilation,
                                    dimension_numbers=layout)
################################################################################
## ngc-learn convolution calculations
# @partial(jit, static_argnums=[2, 3, 4])
# def calc_dK_conv(x, d_out, delta_shape, stride_size=1, padding=((0, 0), (0, 0))):
# _x = x
# deX, deY = delta_shape
# if deX > 0:
# ## apply a pre-computation trimming step ("negative padding")
# _x = x[:, 0:x.shape[1]-deX, 0:x.shape[2]-deY, :]
# return _calc_dK_conv(_x, d_out, stride_size=stride_size, padding=padding)
# [docs] -- doc-page link marker, not code
@partial(jit, static_argnums=[2, 3, 4])
def calc_dK_conv(x, d_out, delta_shape, stride_size=1, padding=((0, 0), (0, 0))):
    """
    Compute the kernel gradient (dK) of a convolution, after first trimming
    the input along axes 1 and 2 ("negative padding").

    Args:
        x (ndarray): the input array

        d_out (ndarray): the gradient with respect to the output

        delta_shape (tuple): static (deX, deY) amounts by which axes 1 and 2
            of `x` are shortened before the gradient is computed

        stride_size (int): the stride size for the convolution (Default: 1)

        padding (tuple): padding forwarded to the gradient computation
            (Default: ((0, 0), (0, 0)))

    Returns:
        ndarray: the gradient with respect to the kernel
    """
    trim_rows, trim_cols = delta_shape
    # keep the leading part of each trimmed axis; slice bounds beyond the
    # array extent are clamped, so non-positive amounts leave x untrimmed
    rows_kept = x.shape[1] - trim_rows
    cols_kept = x.shape[2] - trim_cols
    trimmed = x[:, :rows_kept, :cols_kept, :]
    return _calc_dK_conv(trimmed, d_out, stride_size=stride_size, padding=padding)
@partial(jit, static_argnums=[2, 3])
def _calc_dK_conv(x, d_out, stride_size=1, padding=((0, 0), (0, 0))):
    """
    Compute the kernel gradient of a convolution by convolving the
    axis-permuted input with the axis-permuted output gradient.

    Args:
        x (ndarray): the input array

        d_out (ndarray): the gradient with respect to the output

        stride_size (int): forward-pass stride; it is applied here as the
            dilation of the (permuted) output gradient, while the convolution
            itself runs with stride 1 (Default: 1)

        padding (tuple): padding forwarded to `conv2d`
            (Default: ((0, 0), (0, 0)))

    Returns:
        ndarray: the gradient with respect to the kernel, cast to float32 and
            permuted with axes [1, 2, 0, 3]
    """
    # exchange the first and last axes of the input
    x_swapped = jnp.transpose(x, axes=[3, 1, 2, 0])
    # move the leading axis of the output gradient to position 2 so that it
    # can serve as the kernel of the convolution below
    grad_as_kernel = jnp.transpose(d_out, axes=[1, 2, 0, 3])
    dilation = (stride_size, stride_size)
    dW = conv2d(inputs=x_swapped, filters=grad_as_kernel, stride_size=1,
                padding=padding, rhs_dilation=dilation)
    dW = dW.astype(jnp.float32)
    return jnp.transpose(dW, axes=[1, 2, 0, 3])
################################################################################
# input update computation
# [docs] -- doc-page link marker, not code
@partial(jit, static_argnums=[2, 3, 4])
def calc_dX_conv(K, d_out, delta_shape, stride_size=1, anti_padding=None):
    """
    Wrapper that computes the input gradient (dX) of a convolution from the
    output gradient and the kernel.

    Args:
        K (ndarray): the convolution kernel

        d_out (ndarray): the gradient with respect to the output

        delta_shape (tuple): static (deX, deY) pair; it is unpacked here but
            does not currently influence the result

        stride_size (int): the stride size of the convolution (Default: 1)

        anti_padding: padding forwarded unchanged to `_calc_dX_conv`
            (Default: None)

    Returns:
        ndarray: the gradient with respect to the input
    """
    _delta_x, _delta_y = delta_shape  # unpacked only; values are unused
    return _calc_dX_conv(K, d_out, stride_size=stride_size,
                         anti_padding=anti_padding)
@partial(jit, static_argnums=[2, 3])
def _calc_dX_conv(K, d_out, stride_size=1, anti_padding=None):
    """
    Compute the input gradient (dX) of a convolution by running a transposed
    convolution of the output gradient with the 180-degree-rotated kernel.

    Args:
        K (ndarray): the convolution kernel

        d_out (ndarray): the gradient with respect to the output

        stride_size (int): the stride size of the convolution (Default: 1)

        anti_padding: padding forwarded unchanged to `deconv2d`
            (Default: None)

    Returns:
        ndarray: the gradient with respect to the input, cast to float32
    """
    # NOTE(review): the default anti_padding=None is passed straight through
    # to deconv2d -- confirm that callers always supply an explicit padding
    K_T = rot180(K)  # flip axes 0/1 and swap axes 2/3 of the kernel
    dx = deconv2d(d_out, filters=K_T, stride_size=stride_size,
                  padding=anti_padding)
    return dx.astype(jnp.float32)
################################################################################
## ngc-learn deconvolution calculations
# @partial(jit, static_argnums=[2, 3, 4, 5])
# def calc_dK_deconv(x, d_out, delta_shape, stride_size=1, out_size =2, padding="SAME"):
# _x = x
# deX, deY = delta_shape
# if deX > 0:
# ## apply a pre-computation trimming step ("negative padding")
# _x = x[:, 0:x.shape[1]-deX, 0:x.shape[2]-deY, :]
# return _calc_dK_deconv(_x, d_out, stride_size=stride_size, out_size = out_size)
# [docs] -- doc-page link marker, not code
@partial(jit, static_argnums=[2, 3, 4, 5])
def calc_dK_deconv(x, d_out, delta_shape, stride_size=1, out_size=2, padding="SAME"):
    """
    Compute the kernel gradient (dK) of a deconvolution, after first trimming
    the input along axes 1 and 2 ("negative padding").

    Args:
        x (ndarray): the input array

        d_out (ndarray): the gradient with respect to the output

        delta_shape (tuple): static (deX, deY) amounts by which axes 1 and 2
            of `x` are shortened before the gradient is computed

        stride_size (int): the stride size for the deconvolution (Default: 1)

        out_size (int): the output size for the deconvolution (Default: 2)

        padding (str): padding mode forwarded to the gradient computation
            (Default: "SAME")

    Returns:
        ndarray: the gradient with respect to the kernel
    """
    trim_rows, trim_cols = delta_shape
    # keep the leading part of each trimmed axis; slice bounds beyond the
    # array extent are clamped, so non-positive amounts leave x untrimmed
    rows_kept = x.shape[1] - trim_rows
    cols_kept = x.shape[2] - trim_cols
    trimmed = x[:, :rows_kept, :cols_kept, :]
    return _calc_dK_deconv(trimmed, d_out, stride_size=stride_size,
                           out_size=out_size, padding=padding)
@partial(jit, static_argnums=[2, 3, 4])
def _calc_dK_deconv(x, d_out, stride_size=1, out_size=2, padding="SAME"):
    """
    Compute the kernel gradient of a deconvolution by running a transposed
    convolution of the axis-permuted input with the axis-permuted output
    gradient.

    Args:
        x (ndarray): the input array

        d_out (ndarray): the gradient with respect to the output

        stride_size (int): the stride size for the deconvolution (Default: 1)

        out_size (int): the output size handed to the padding calculation
            (Default: 2)

        padding (str): either "SAME" or "VALID" (Default: "SAME")

    Returns:
        ndarray: the gradient with respect to the kernel, permuted with axes
            [1, 2, 0, 3]

    Raises:
        ValueError: if `padding` is neither "SAME" nor "VALID"
    """
    # exchange the first and last axes of the input
    xT = jnp.transpose(x, axes=[3, 1, 2, 0])
    # move the leading axis of the output gradient to position 2 so that it
    # can serve as the kernel of the transposed convolution below
    d_out_T = jnp.transpose(d_out, axes=[1, 2, 0, 3])
    if padding == "VALID":
        pad_args = _deconv_valid_transpose_padding(xT.shape[1], out_size,
                                                   d_out_T.shape[1], stride_size)
    elif padding == "SAME":
        pad_args = _deconv_same_transpose_padding(xT.shape[1], out_size,
                                                  d_out_T.shape[1], stride_size)
    else:
        # fail with a clear message instead of hitting an unbound `pad_args`
        raise ValueError(
            f"padding must be \"SAME\" or \"VALID\", got {padding!r}")
    dW = deconv2d(inputs=xT, filters=d_out_T, stride_size=stride_size,
                  padding=pad_args)
    dW = jnp.transpose(dW, axes=[1, 2, 0, 3])
    return dW
################################################################################
# input update computation
# [docs] -- doc-page link marker, not code
@partial(jit, static_argnums=[2, 3, 4])
def calc_dX_deconv(K, d_out, delta_shape, stride_size=1, padding=((0, 0), (0, 0))):
    """
    Wrapper that computes the input gradient (dX) of a deconvolution from the
    output gradient (d_out) and the kernel (K).

    Args:
        K (ndarray): the kernel

        d_out (ndarray): the gradient with respect to the output

        delta_shape (tuple): static (deX, deY) pair; it is unpacked here but
            does not currently influence the result

        stride_size (int): the stride size for the deconvolution (Default: 1)

        padding (tuple): padding forwarded unchanged to `_calc_dX_deconv`
            (Default: ((0, 0), (0, 0)))

    Returns:
        dx (ndarray): the gradient with respect to the input
    """
    _delta_x, _delta_y = delta_shape  # unpacked only; values are unused
    return _calc_dX_deconv(K, d_out, stride_size=stride_size, padding=padding)
@partial(jit, static_argnums=[2, 3])
def _calc_dX_deconv(K, d_out, stride_size=1, padding=((0, 0), (0, 0))):
    """
    Compute the input gradient (dX) of a deconvolution by convolving the
    output gradient (d_out) with the 180-degree-rotated kernel (K).

    Args:
        K (ndarray): the kernel

        d_out (ndarray): the gradient with respect to the output

        stride_size (int): the stride handed to `conv2d` (Default: 1)

        padding (tuple): padding handed to `conv2d`
            (Default: ((0, 0), (0, 0)))

    Returns:
        dx (ndarray): the gradient with respect to the input
    """
    # rot180 reverses axes 0 and 1 of the kernel and swaps its axes 2 and 3
    K_T = rot180(K)
    # the input gradient of a deconvolution is a plain convolution of the
    # output gradient with the rotated kernel
    dx = conv2d(d_out,
                filters=K_T,
                stride_size=stride_size,
                padding=padding)
    return dx
################################################################################