Mirror of https://github.com/ROCm/jax.git
Handle 0D convolutions correctly in shape rule. (#1972)

parent 327dca8f76
commit facbe0d76a
@@ -4331,7 +4331,8 @@ def conv_shape_tuple(lhs_shape, rhs_shape, strides, pads):
     msg = "Wrong number of explicit pads for convolution: expected {}, got {}."
     raise TypeError(msg.format(len(lhs_shape) - 2, len(pads)))
 
-  lhs_padded = onp.add(lhs_shape[2:], onp.add(*zip(*pads)))
+  lhs_padded = onp.add(lhs_shape[2:], onp.sum(onp.array(pads).reshape(-1, 2),
+                                              axis=1))
   out_space = onp.floor_divide(
     onp.subtract(lhs_padded, rhs_shape[2:]), strides) + 1
   out_space = onp.maximum(0, out_space)
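Why the change: a 0D convolution has no spatial dimensions, so `pads` is an empty sequence. The old expression `onp.add(*zip(*pads))` then calls `onp.add` with no arguments and raises a TypeError, while the reshape-and-sum form yields an empty per-dimension padding total. A minimal standalone sketch (not part of the commit) illustrating the difference:

import numpy as onp

pads = []  # 0D convolution: no spatial dims, so no (lo, hi) padding pairs

# Old expression: zip(*[]) yields no arguments, so onp.add(*zip(*pads))
# becomes onp.add(), which raises a TypeError.

# New expression: an empty array reshaped to (-1, 2) has shape (0, 2), and
# summing over axis=1 gives the empty per-dimension total that the shape
# rule expects.
total_pad = onp.sum(onp.array(pads).reshape(-1, 2), axis=1)
print(total_pad.shape)  # (0,)

# For nonempty pads the two forms agree, e.g. [(1, 2), (3, 4)] -> [3, 7].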
@@ -489,6 +489,16 @@ class LaxTest(jtu.JaxTestCase):
 
   # TODO(mattjj): test conv_general_dilated against numpy
 
+  def testConv0DIsDot(self):
+    rng = jtu.rand_default()
+    def args_maker():
+      return [rng((10, 5), onp.float32), rng((5, 7), onp.float32)]
+    jnp_fun = partial(lax.conv_general_dilated, window_strides=(),
+                      padding='VALID', dimension_numbers=('NC', 'IO', 'NC'))
+    self._CompileAndCheck(jnp_fun, args_maker, check_dtypes=True)
+    self._CheckAgainstNumpy(jnp_fun, onp.dot, args_maker)
+
+
   @staticmethod
   def _conv_transpose_via_grad(data, kernel, strides, padding,
                                rhs_dilation=None, dimension_numbers=None):
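With this fix, a 0D conv_general_dilated degenerates to a matrix product, which is exactly what the new test checks. A hedged usage sketch (assuming a JAX build that includes this commit; `onp` is the `import numpy as onp` alias used throughout this codebase):

import numpy as onp
from jax import lax

x = onp.ones((10, 5), onp.float32)  # (batch, features), 'NC' layout
w = onp.ones((5, 7), onp.float32)   # (in features, out features), 'IO' layout

# With no spatial dims, empty window_strides and an 'NC'/'IO'/'NC' layout
# make the convolution equivalent to a plain matrix product x @ w.
y = lax.conv_general_dilated(x, w, window_strides=(), padding='VALID',
                             dimension_numbers=('NC', 'IO', 'NC'))
assert y.shape == (10, 7)
assert onp.allclose(y, onp.dot(x, w))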