initial transpose conv implementation

Anselm Levskaya 2019-04-09 15:06:46 -07:00
parent b0e9650789
commit 797d411eeb
2 changed files with 143 additions and 0 deletions


@@ -1229,6 +1229,88 @@ def conv_with_general_padding(lhs, rhs, window_strides, padding,
                                 rhs_dilation=rhs_dilation)
def _conv_transpose_padding(k, s, padding):
  """Calculate before and after padding for a dim of transposed convolution.

  Args:
    k: int: kernel dimension.
    s: int: dimension stride value.
    padding: 'same' or 'valid' padding mode for original forward conv.

  Returns:
    2-tuple: ints: before and after padding for transposed convolution.
  """
  if padding.lower() == 'same':
    pad_len = k + s - 2
    if s > k - 1:
      pad_a = k - 1
    else:
      pad_a = int(onp.ceil(pad_len / 2))
  elif padding.lower() == 'valid':
    pad_len = k + s - 2 + max(k - s, 0)
    pad_a = k - 1
  else:
    raise ValueError('Padding mode must be `same` or `valid`.')
  pad_b = pad_len - pad_a
  return pad_a, pad_b
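
# For example, with a 3-wide kernel and stride 2 (a numeric sketch of the
# rules above, assuming Python 3 division in the ceil):
#   _conv_transpose_padding(3, 2, 'same')  == (2, 1)  # pad_len=3, pad_a=ceil(3/2)=2
#   _conv_transpose_padding(3, 2, 'valid') == (2, 2)  # pad_len=4, pad_a=k-1=2
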
def _flip_axes(x, axes):
  """Flip ndarray 'x' along each axis specified in axes tuple."""
  for axis in axes:
    x = onp.flip(x, axis)
  return x

def conv_transpose(data, kernel, strides, padding, dimension_numbers=None):
  """Convenience wrapper for calculating the N-d convolution transpose.

  This function directly calculates the transposed convolution rather than
  indirectly computing it as the gradient (transpose) of a forward
  convolution.

  Args:
    data: a rank `n+2` dimensional input array.
    kernel: a rank `n+2` dimensional array of kernel weights.
    strides: sequence of `n` integers, sets fractional stride.
    padding: 'same' or 'valid' sets the padding to the transpose of the
      corresponding forward conv; otherwise a sequence of `n` integer
      2-tuples describing before-and-after padding for each of the `n`
      spatial dimensions.
    dimension_numbers: tuple of dimension descriptors as in
      lax.conv_general_dilated. Defaults to the tensorflow convention.

  Returns:
    Transposed N-d convolution, with padding following the conventions of the
    corresponding keras and tensorflow conv-transpose operators.
  """
  assert len(data.shape) == len(kernel.shape) and len(data.shape) > 2
  ndims = len(data.shape)
  one = (1,) * (ndims - 2)
  # Set dimensional layout defaults if not specified.
  if dimension_numbers is None:
    if ndims == 3:
      dimension_numbers = ('NHC', 'HIO', 'NHC')
    elif ndims == 4:
      dimension_numbers = ('NHWC', 'HWIO', 'NHWC')
    elif ndims == 5:
      dimension_numbers = ('NHWDC', 'HWDIO', 'NHWDC')
    else:
      raise ValueError('No dimension_numbers defaults for inputs with more '
                       'than 3 spatial dimensions.')
  dn = conv_dimension_numbers(data.shape, kernel.shape, dimension_numbers)
  k_shape = onp.take(kernel.shape, dn.rhs_spec)
  k_sdims = k_shape[2:]
  # Calculate correct output shape given padding and strides.
  if isinstance(padding, str) and padding.lower() in {'same', 'valid'}:
    pads = [_conv_transpose_padding(k, s, padding)
            for k, s in zip(k_sdims.tolist(), strides)]
  else:
    pads = padding
  # A transposed conv is a forward conv with a spatially flipped kernel and
  # LHS (input) dilation equal to the forward conv's stride.
  kernel_t = _flip_axes(kernel, onp.array(dn.rhs_spec)[2:])
  # Swap the kernel's input/output channel axes.
  kernel_t = onp.swapaxes(kernel_t, dn.rhs_spec[0], dn.rhs_spec[1])
  return conv_general_dilated(data, kernel_t, one, pads, strides, one, dn)
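
# Example usage (an illustrative sketch, with shapes chosen here for the
# example): with the NHWC/HWIO defaults above and matching channel counts,
#   lhs = onp.ones((1, 8, 8, 4), onp.float32)  # NHWC input
#   rhs = onp.ones((3, 3, 4, 4), onp.float32)  # HWIO kernel
#   conv_transpose(lhs, rhs, (2, 2), 'SAME').shape == (1, 16, 16, 4)
# 'SAME' upsamples each spatial dim by its stride: the LHS-dilated input has
# (8 - 1) * 2 + 1 = 15 entries, padding (2, 1) brings it to 18, and a 3-wide
# window then yields 18 - 3 + 1 = 16.
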
def full_like(x, fill_value, dtype=None, shape=None):
  """Create a full array like np.full based on the example array `x`.


@@ -445,6 +445,67 @@ class LaxTest(jtu.JaxTestCase):
  # TODO(mattjj): test conv_general_dilated against numpy

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name":
       "_lhs_shape={}_rhs_shape={}_strides={}_padding={}".format(
           jtu.format_shape_dtype_string(lhs_shape, dtype),
           jtu.format_shape_dtype_string(rhs_shape, dtype), strides, padding),
       "lhs_shape": lhs_shape, "rhs_shape": rhs_shape, "dtype": dtype,
       "strides": strides, "padding": padding, "rng": rng, "dspec": dspec}
      for lhs_shape, rhs_shape in [
          ((b, 9, 10, i), (3, 3, i, j))
          for b, i, j in itertools.product([2, 3], repeat=3)]
      for dtype in [onp.float32]
      for strides in [(1, 1), (1, 2), (2, 1), (2, 2), (3, 3)]
      for padding in ["VALID", "SAME"]
      for dspec in [('NHWC', 'HWIO', 'NHWC'),]
      for rng in [jtu.rand_small()]))
  def testConvTranspose(self, lhs_shape, rhs_shape, dtype, strides,
                        padding, dspec, rng):

    def deconv_output_length(input_length, filter_size, padding, stride=0):
      if padding.lower() == 'valid':
        length = input_length * stride + max(filter_size - stride, 0)
      elif padding.lower() == 'same':
        length = input_length * stride
      return length
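    # E.g. input_length=8, filter_size=3, stride=2 gives
    #   'valid': 8 * 2 + max(3 - 2, 0) = 17
    #   'same':  8 * 2 = 16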

    def inv_permutation(p):
      return [x[0] for x in sorted(enumerate(p), key=lambda x: x[1])]

    def conv_transpose_via_grad(data, kernel, strides, padding,
                                dimension_numbers=dspec):
      assert len(data.shape) == len(kernel.shape)
      ndims = len(data.shape)
      nspatial = ndims - 2
      one = (1,) * nspatial
      dn = lax.conv_dimension_numbers(data.shape, kernel.shape,
                                      dimension_numbers)
      in_shape = onp.take(data.shape, dn.lhs_spec)
      in_sdims = in_shape[2:]
      k_shape = onp.take(kernel.shape, dn.rhs_spec)
      k_sdims = k_shape[2:]
      o_sdims = [deconv_output_length(in_sdims[i], k_sdims[i], padding,
                                      stride=strides[i])
                 for i in range(nspatial)]
      o_shape = [in_shape[0], k_shape[1]] + o_sdims
      o_layout = onp.take(onp.array(o_shape), inv_permutation(dn.out_spec))
      placeholder = onp.ones(o_layout, data.dtype)
      conv = lambda x: lax.conv_general_dilated(x, kernel, strides, padding,
                                                one, one, dn)
      _, g = api.vjp(conv, placeholder)
      return g(data)[0]
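    # The vjp of the forward conv is by definition the transposed conv, so
    # conv_transpose_via_grad serves as a reference implementation here.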

    args_maker = lambda: [rng(lhs_shape, dtype), rng(rhs_shape, dtype)]

    def fun(lhs, rhs):
      return lax.conv_transpose(lhs, rhs, strides, padding,
                                dimension_numbers=dspec)

    def fun_via_grad(lhs, rhs):
      return conv_transpose_via_grad(lhs, rhs, strides, padding,
                                     dimension_numbers=dspec)

    # self._CompileAndCheck(fun, args_maker, check_dtypes=True)
    self._CheckAgainstNumpy(fun, fun_via_grad, args_maker)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_lhs_shape={}_rhs_shape={}".format(
          jtu.format_shape_dtype_string(lhs_shape, dtype),