Add support for 0d transpose convolution (#3643)

* Allow 0d transpose convolution

* Add test for 0d conv transpose

* remove whitespace
This commit is contained in:
Roman Novak 2020-07-02 14:38:35 -07:00 committed by GitHub
parent d10cf0e38f
commit 4442c333ce
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 41 additions and 2 deletions

View File

@ -1487,12 +1487,14 @@ def conv_transpose(lhs: Array, rhs: Array, strides: Sequence[int],
Transposed N-d convolution, with output padding following the conventions of
keras.layers.Conv2DTranspose.
"""
assert len(lhs.shape) == len(rhs.shape) and len(lhs.shape) > 2
assert len(lhs.shape) == len(rhs.shape) and len(lhs.shape) >= 2
ndims = len(lhs.shape)
one = (1,) * (ndims - 2)
# Set dimensional layout defaults if not specified.
if dimension_numbers is None:
if ndims == 3:
if ndims == 2:
dimension_numbers = ('NC', 'IO', 'NC')
elif ndims == 3:
dimension_numbers = ('NHC', 'HIO', 'NHC')
elif ndims == 4:
dimension_numbers = ('NHWC', 'HWIO', 'NHWC')

View File

@ -662,6 +662,43 @@ class LaxTest(jtu.JaxTestCase):
# NB: below just checks for agreement, we're not calling numpy.
self._CheckAgainstNumpy(fun, fun_via_grad, args_maker)
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name":
"_lhs_shape={}_rhs_shape={}_strides={}_padding={}_rhs_dilation={}".format(
jtu.format_shape_dtype_string(lhs_shape, dtype),
jtu.format_shape_dtype_string(rhs_shape, dtype), strides, padding, rhs_dilation),
"lhs_shape": lhs_shape, "rhs_shape": rhs_shape, "dtype": dtype,
"strides": strides, "padding": padding, "rhs_dilation": rhs_dilation,
"rng_factory": rng_factory, 'dspec': dspec}
for lhs_shape, rhs_shape in [
((b, i), (i, j))
for b, i, j in itertools.product([2,3],[2,3],[2,3])]
for dtype in float_dtypes
for strides in [()]
for padding in ["VALID", "SAME"]
for dspec in [('NC', 'IO', 'NC'),]
for rhs_dilation in [None, ()]
for rng_factory in [jtu.rand_small]))
def testConvTranspose0D(self, lhs_shape, rhs_shape, dtype, strides,
padding, dspec, rhs_dilation, rng_factory):
rng = rng_factory(self.rng())
args_maker = lambda: [rng(lhs_shape, dtype), rng(rhs_shape, dtype)]
def fun(lhs, rhs):
return lax.conv_transpose(lhs, rhs, strides, padding,
dimension_numbers=dspec,
rhs_dilation=rhs_dilation,
transpose_kernel=False)
def fun_via_grad(lhs, rhs):
rhs_t = self._transpose_conv_kernel(lhs, rhs, dimension_numbers=dspec)
return self._conv_transpose_via_grad(lhs, rhs_t, strides, padding,
rhs_dilation=rhs_dilation,
dimension_numbers=dspec)
# NB: below just checks for agreement, we're not calling numpy.
self._CheckAgainstNumpy(fun, fun_via_grad, args_maker)
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_lhs_shape={}_rhs_shape={}_precision={}".format(
jtu.format_shape_dtype_string(lhs_shape, dtype),