Handle 0D convolutions correctly in shape rule. (#1972)

Peter Hawkins, 2020-01-09 14:36:37 -05:00, committed by GitHub
parent 327dca8f76
commit facbe0d76a
2 changed files with 12 additions and 1 deletion


@@ -4331,7 +4331,8 @@ def conv_shape_tuple(lhs_shape, rhs_shape, strides, pads):
     msg = "Wrong number of explicit pads for convolution: expected {}, got {}."
     raise TypeError(msg.format(len(lhs_shape) - 2, len(pads)))
 
-  lhs_padded = onp.add(lhs_shape[2:], onp.add(*zip(*pads)))
+  lhs_padded = onp.add(lhs_shape[2:], onp.sum(onp.array(pads).reshape(-1, 2),
+                                              axis=1))
   out_space = onp.floor_divide(
       onp.subtract(lhs_padded, rhs_shape[2:]), strides) + 1
   out_space = onp.maximum(0, out_space)
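
A minimal sketch, not part of the commit, of why this fix matters: for a 0D convolution, pads is empty, and the old expression onp.add(*zip(*pads)) ends up calling onp.add with no arguments, which fails; the new expression yields an empty array instead. All names below are illustrative.

# Sketch (illustrative, not from the commit): compare the old and new
# ways of computing total padding per spatial dimension.
import numpy as onp

pads_2d = ((1, 2), (3, 4))  # (lo, hi) padding for two spatial dims
pads_0d = ()                # 0D convolution: no spatial dims at all

def total_pads_new(pads):
  # New expression: reshape to (num_dims, 2) and sum lo + hi per dim.
  return onp.sum(onp.array(pads).reshape(-1, 2), axis=1)

print(total_pads_new(pads_2d))  # [3 7]
print(total_pads_new(pads_0d))  # [] -- empty array, as a 0D conv needs

# Old expression: works for the 2D case...
print(onp.add(*zip(*pads_2d)))  # [3 7]
# ...but for pads_0d, zip(*()) is empty, so onp.add receives no
# arguments and raises an error instead of returning an empty array.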


@@ -489,6 +489,16 @@ class LaxTest(jtu.JaxTestCase):
 
   # TODO(mattjj): test conv_general_dilated against numpy
+
+  def testConv0DIsDot(self):
+    rng = jtu.rand_default()
+    def args_maker():
+      return [rng((10, 5), onp.float32), rng((5, 7), onp.float32)]
+    jnp_fun = partial(lax.conv_general_dilated, window_strides=(),
+                      padding='VALID', dimension_numbers=('NC', 'IO', 'NC'))
+    self._CompileAndCheck(jnp_fun, args_maker, check_dtypes=True)
+    self._CheckAgainstNumpy(jnp_fun, onp.dot, args_maker)
+
   @staticmethod
   def _conv_transpose_via_grad(data, kernel, strides, padding,
                                rhs_dilation=None, dimension_numbers=None):
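
The property the new test checks can also be seen with the public jax API directly: with no spatial dimensions and 'NC'/'IO'/'NC' layouts, conv_general_dilated reduces to a plain matrix product. A standalone sketch, assuming the current jax.random API rather than the jtu test harness used above:

# Standalone sketch of the tested property; the RNG setup here is
# illustrative and replaces the test harness's rand_default().
import jax
import jax.numpy as jnp
import numpy as onp

key_lhs, key_rhs = jax.random.split(jax.random.PRNGKey(0))
lhs = jax.random.normal(key_lhs, (10, 5))  # (batch, features)
rhs = jax.random.normal(key_rhs, (5, 7))   # (in_features, out_features)

# A 0D convolution: empty window_strides and no spatial dimensions in
# the ('NC', 'IO', 'NC') dimension numbers.
out = jax.lax.conv_general_dilated(
    lhs, rhs, window_strides=(), padding='VALID',
    dimension_numbers=('NC', 'IO', 'NC'))

assert out.shape == (10, 7)
assert onp.allclose(out, jnp.dot(lhs, rhs), atol=1e-5)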