mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Fix shape error when taking JVP of reduce-prod over size 0 axis. (#3729)
This commit is contained in:
parent
0d81e988d8
commit
a9da06ce75
@ -4284,9 +4284,9 @@ def _reduce_prod_jvp_rule(primals, tangents, *, axes):
|
||||
paddings[axis] = (0, 1, 0)
|
||||
x2 = pad(x2, _const(x, 1), paddings)
|
||||
x = x1 * x2
|
||||
shape = list(x.shape)
|
||||
del shape[axis]
|
||||
return reshape(x, shape)
|
||||
if x.shape[axis] == 0:
|
||||
return full(input_shape[non_axes], _one(x))
|
||||
return squeeze(x, (axis,))
|
||||
|
||||
return api.jvp(_reduce_prod_tree, (operand,), (tangent,))
|
||||
|
||||
|
@ -651,6 +651,7 @@ class LaxAutodiffTest(jtu.JaxTestCase):
|
||||
[(3, 4, 5), (0, 2)],
|
||||
[(3, 4, 5), (0, 1, 2)],
|
||||
[(3, 1), (1,)],
|
||||
[(3, 0, 5), (1,)],
|
||||
]))
|
||||
def testReduceGrad(self, op, init_val, shape, dtype, dims, rng_factory):
|
||||
rng = rng_factory(self.rng())
|
||||
@ -664,7 +665,8 @@ class LaxAutodiffTest(jtu.JaxTestCase):
|
||||
eps = (1.0 if dtypes.finfo(dtype).bits == 16 and op is lax.add else
|
||||
1e-1 if dtype == dtypes.bfloat16 else
|
||||
1e-2 if dtypes.finfo(dtype).bits == 32 else None)
|
||||
check_grads(reduce, (operand,), 2, ["fwd", "rev"], tol, tol, eps)
|
||||
if op not in (lax.max, lax.min) or all(d > 0 for d in shape):
|
||||
check_grads(reduce, (operand,), 2, ["fwd", "rev"], tol, tol, eps)
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
{"testcase_name": "_op={}_dtype={}_padding={}"
|
||||
|
@ -469,7 +469,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
|
||||
def testOp(self, np_op, jnp_op, rng_factory, shapes, dtypes, check_dtypes,
|
||||
tolerance, inexact):
|
||||
np_op = jtu.ignore_warning(category=RuntimeWarning,
|
||||
message="invalid value.*")(np_op)
|
||||
message="invalid value.*")(np_op)
|
||||
|
||||
rng = rng_factory(self.rng())
|
||||
args_maker = self._GetArgsMaker(rng, shapes, dtypes, np_arrays=False)
|
||||
|
Loading…
x
Reference in New Issue
Block a user