Mirror of https://github.com/ROCm/jax.git, synced 2025-04-19 05:16:06 +00:00
initial transpose conv implementation
This commit is contained in:
parent b0e9650789
commit 797d411eeb

jax/lax.py | 82 lines changed
@@ -1229,6 +1229,88 @@ def conv_with_general_padding(lhs, rhs, window_strides, padding,
                              rhs_dilation=rhs_dilation)


def _conv_transpose_padding(k, s, padding):
  """Calculate before and after padding for a dim of transposed convolution.

  Args:
    k: int: kernel dimension.
    s: int: dimension stride value.
    padding: 'same' or 'valid' padding mode for original forward conv.

  Returns:
    2-tuple: ints: before and after padding for transposed convolution.
  """
  if padding.lower() == 'same':
    pad_len = k + s - 2
    if s > k - 1:
      pad_a = k - 1
    else:
      pad_a = int(onp.ceil(pad_len / 2))
  elif padding.lower() == 'valid':
    pad_len = k + s - 2 + max(k - s, 0)
    pad_a = k - 1
  else:
    raise ValueError('Padding mode must be `same` or `valid`.')
  pad_b = pad_len - pad_a
  return pad_a, pad_b
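
A few worked values for the padding rule above, traced by hand from the branches (illustrative only, not part of the diff):

# _conv_transpose_padding(3, 2, 'same')   -> (2, 1)   pad_len = 3, pad_a = ceil(3 / 2) = 2
# _conv_transpose_padding(3, 3, 'same')   -> (2, 2)   s > k - 1, so pad_a = k - 1 = 2
# _conv_transpose_padding(3, 2, 'valid')  -> (2, 2)   pad_len = 3 + 2 - 2 + max(3 - 2, 0) = 4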


def _flip_axes(x, axes):
  """Flip ndarray 'x' along each axis specified in axes tuple."""
  for axis in axes:
    x = onp.flip(x, axis)
  return x


def conv_transpose(data, kernel, strides, padding, dimension_numbers=None):
  """Convenience wrapper for calculating the N-d convolution transpose.

  This function directly calculates convT rather than indirectly calculating
  the gradient (transpose) of a forward convolution.

  Args:
    data: a rank `n+2` dimensional input array.
    kernel: a rank `n+2` dimensional array of kernel weights.
    strides: sequence of `n` integers, sets fractional stride.
    padding: 'same', 'valid' will set as transpose of corresponding forward
      conv, or a sequence of `n` integer 2-tuples describing before-and-after
      padding for each `n` spatial dimension.
    dimension_numbers: tuple of dimension descriptors as in
      lax.conv_general_dilated. Defaults to tensorflow convention.

  Returns:
    Transposed N-d convolution, with padding following the conventions of the
    corresponding keras and tensorflow conv-transpose operators.
  """
  assert len(data.shape) == len(kernel.shape) and len(data.shape) > 2
  ndims = len(data.shape)
  one = (1,) * (ndims - 2)
  # Set dimensional layout defaults if not specified.
  if dimension_numbers is None:
    if ndims == 3:
      dimension_numbers = ('NHC', 'HIO', 'NHC')
    elif ndims == 4:
      dimension_numbers = ('NHWC', 'HWIO', 'NHWC')
    elif ndims == 5:
      dimension_numbers = ('NHWDC', 'HWDIO', 'NHWDC')
    else:
      raise ValueError('No 4+ dimensional dimension_number defaults.')
  dn = conv_dimension_numbers(data.shape, kernel.shape, dimension_numbers)
  k_shape = onp.take(kernel.shape, dn.rhs_spec)
  k_sdims = k_shape[2:]
  # Calculate correct output shape given padding and strides.
  if padding.lower() in {'same', 'valid'}:
    pads = [_conv_transpose_padding(k, s, padding)
            for k, s in zip(k_sdims.tolist(), strides)]
  else:
    pads = padding
  # transposed conv = flipped kernel plus LHS dilation
  kernel_t = _flip_axes(kernel, onp.array(dn.rhs_spec)[2:])
  # flip input/output channel axes
  kernel_t = onp.swapaxes(kernel_t, dn.rhs_spec[0], dn.rhs_spec[1])
  return conv_general_dilated(data, kernel_t, one, pads, strides, one, dn)
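
A minimal usage sketch for the wrapper above (not part of the diff), assuming the NHWC/HWIO defaults and illustrative shapes with equal input/output feature counts, chosen here only for simplicity:

import numpy as onp
from jax import lax

data = onp.ones((1, 8, 8, 4), onp.float32)    # NHWC: batch 1, 8x8 spatial, 4 features
kernel = onp.ones((3, 3, 4, 4), onp.float32)  # HWIO: 3x3 window, 4 in / 4 out features
out = lax.conv_transpose(data, kernel, (2, 2), 'SAME')
# With 'SAME' padding each spatial dim should be upsampled by its stride,
# following the keras/tensorflow convention: out.shape == (1, 16, 16, 4).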


def full_like(x, fill_value, dtype=None, shape=None):
  """Create a full array like np.full based on the example array `x`.
@@ -445,6 +445,67 @@ class LaxTest(jtu.JaxTestCase):

  # TODO(mattjj): test conv_general_dilated against numpy

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name":
       "_lhs_shape={}_rhs_shape={}_strides={}_padding={}".format(
           jtu.format_shape_dtype_string(lhs_shape, dtype),
           jtu.format_shape_dtype_string(rhs_shape, dtype), strides, padding),
       "lhs_shape": lhs_shape, "rhs_shape": rhs_shape, "dtype": dtype,
       "strides": strides, "padding": padding, "rng": rng, 'dspec': dspec}
      for lhs_shape, rhs_shape in [
          ((b, 9, 10, i), (3, 3, i, j))
          for b, i, j in itertools.product([2, 3], repeat=3)]
      for dtype in [onp.float32]
      for strides in [(1, 1), (1, 2), (2, 1), (2, 2), (3, 3)]
      for padding in ["VALID", "SAME"]
      for dspec in [('NHWC', 'HWIO', 'NHWC'),]
      for rng in [jtu.rand_small()]))
  def testConvTranspose(self, lhs_shape, rhs_shape, dtype, strides,
                        padding, dspec, rng):
    def deconv_output_length(input_length, filter_size, padding, stride=0):
      if padding.lower() == 'valid':
        length = input_length * stride + max(filter_size - stride, 0)
      elif padding.lower() == 'same':
        length = input_length * stride
      return length
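    # Illustrative values for this formula (not part of the test): with
    # input_length=9, filter_size=3, stride=2, 'valid' gives
    # 9 * 2 + max(3 - 2, 0) = 19 and 'same' gives 9 * 2 = 18.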
    def inv_permutation(p):
      return [x[0] for x in sorted(enumerate(p), key=lambda x: x[1])]

    def conv_transpose_via_grad(data, kernel, strides, padding,
                                dimension_numbers=dspec):
      assert len(data.shape) == len(kernel.shape)
      ndims = len(data.shape)
      nspatial = ndims - 2
      one = (1,) * nspatial
      dn = lax.conv_dimension_numbers(data.shape, kernel.shape,
                                      dimension_numbers)
      in_shape = onp.take(data.shape, dn.lhs_spec)
      in_sdims = in_shape[2:]
      k_shape = onp.take(kernel.shape, dn.rhs_spec)
      k_sdims = k_shape[2:]
      o_sdims = [deconv_output_length(in_sdims[i], k_sdims[i], padding,
                                      stride=strides[i])
                 for i in range(nspatial)]
      o_shape = [in_shape[0], k_shape[1]] + o_sdims
      o_layout = onp.take(onp.array(o_shape), inv_permutation(dn.out_spec))
      placeholder = onp.ones(o_layout, data.dtype)
      conv = lambda x: lax.conv_general_dilated(x, kernel, strides, padding,
                                                one, one, dn)
      _, g = api.vjp(conv, placeholder)
      return g(data)[0]

    args_maker = lambda: [rng(lhs_shape, dtype), rng(rhs_shape, dtype)]

    def fun(lhs, rhs):
      return lax.conv_transpose(lhs, rhs, strides, padding,
                                dimension_numbers=dspec)

    def fun_via_grad(lhs, rhs):
      return conv_transpose_via_grad(lhs, rhs, strides, padding,
                                     dimension_numbers=dspec)

    # self._CompileAndCheck(fun, args_maker, check_dtypes=True)
    self._CheckAgainstNumpy(fun, fun_via_grad, args_maker)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_lhs_shape={}_rhs_shape={}".format(
          jtu.format_shape_dtype_string(lhs_shape, dtype),