Fix shape error when taking JVP of reduce-prod over size 0 axis. (#3729)

This commit is contained in:
Peter Hawkins 2020-07-13 09:43:19 -04:00 committed by GitHub
parent 0d81e988d8
commit a9da06ce75
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 7 additions and 5 deletions

View File

@ -4284,9 +4284,9 @@ def _reduce_prod_jvp_rule(primals, tangents, *, axes):
paddings[axis] = (0, 1, 0)
x2 = pad(x2, _const(x, 1), paddings)
x = x1 * x2
shape = list(x.shape)
del shape[axis]
return reshape(x, shape)
if x.shape[axis] == 0:
return full(input_shape[non_axes], _one(x))
return squeeze(x, (axis,))
return api.jvp(_reduce_prod_tree, (operand,), (tangent,))

View File

@ -651,6 +651,7 @@ class LaxAutodiffTest(jtu.JaxTestCase):
[(3, 4, 5), (0, 2)],
[(3, 4, 5), (0, 1, 2)],
[(3, 1), (1,)],
[(3, 0, 5), (1,)],
]))
def testReduceGrad(self, op, init_val, shape, dtype, dims, rng_factory):
rng = rng_factory(self.rng())
@ -664,7 +665,8 @@ class LaxAutodiffTest(jtu.JaxTestCase):
eps = (1.0 if dtypes.finfo(dtype).bits == 16 and op is lax.add else
1e-1 if dtype == dtypes.bfloat16 else
1e-2 if dtypes.finfo(dtype).bits == 32 else None)
check_grads(reduce, (operand,), 2, ["fwd", "rev"], tol, tol, eps)
if op not in (lax.max, lax.min) or all(d > 0 for d in shape):
check_grads(reduce, (operand,), 2, ["fwd", "rev"], tol, tol, eps)
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_op={}_dtype={}_padding={}"

View File

@ -469,7 +469,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
def testOp(self, np_op, jnp_op, rng_factory, shapes, dtypes, check_dtypes,
tolerance, inexact):
np_op = jtu.ignore_warning(category=RuntimeWarning,
message="invalid value.*")(np_op)
message="invalid value.*")(np_op)
rng = rng_factory(self.rng())
args_maker = self._GetArgsMaker(rng, shapes, dtypes, np_arrays=False)