Mirror of https://github.com/ROCm/jax.git, synced 2025-04-19 05:16:06 +00:00
fix reshape transpose bug (and add tests)
This version of reshape (taking a `dimensions` argument, which effectively fuses in a transpose) seems only to be used in the JVP rule for lax._reduce_prod (basically np.product), but its transpose rule was totally busted and untested.
commit fef68deef6
parent 1dc4a4d05e
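For context, `lax.reshape` with a `dimensions` argument permutes the operand's axes before reshaping, i.e. it behaves like `reshape(transpose(x, dimensions), new_sizes)`. The sketch below is not part of the commit and is written against a current JAX install (today's module names rather than the `onp`-era aliases used in the diff); it just checks that equivalence:

# Minimal sketch of the semantics the transpose rule must invert:
# reshape-with-dimensions == transpose first, then plain reshape.
import numpy as np
import jax.numpy as jnp
from jax import lax

x = jnp.arange(24.0).reshape(2, 3, 4)
dims = (2, 0, 1)            # axis permutation fused into the reshape
new_sizes = (4, 6)

fused = lax.reshape(x, new_sizes, dimensions=dims)
composed = lax.reshape(lax.transpose(x, dims), new_sizes)
assert np.array_equal(fused, composed)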
@@ -1600,7 +1600,7 @@ ad.defjvp(sub_p,
 ad.primitive_transposes[sub_p] = _sub_transpose
 
 mul_p = standard_binop([_num, _num], 'mul')
-ad.defbilinear_broadcasting(_brcast, mul_p, mul, mul) # TODO
+ad.defbilinear_broadcasting(_brcast, mul_p, mul, mul)
 
 
 def _safe_mul_translation_rule(c, x, y):
@@ -2272,11 +2272,11 @@ def _reshape_translation_rule(c, operand, new_sizes, dimensions, old_sizes):
   return c.Reshape(operand, new_sizes=new_sizes, dimensions=dimensions)
 
 def _reshape_transpose_rule(t, new_sizes, dimensions, old_sizes):
-  out = reshape(t, old_sizes)
   if dimensions is None:
-    return [out]
+    return [reshape(t, old_sizes)]
   else:
-    return [transpose(out, onp.argsort(dimensions))]
+    return [transpose(reshape(t, onp.take(old_sizes, dimensions)),
+                      onp.argsort(dimensions))]
 
 def _reshape_batch_rule(batched_args, batch_dims, new_sizes, dimensions, **unused):
   operand, = batched_args
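The fix follows from viewing the primitive as the linear map x -> reshape(transpose(x, dimensions), new_sizes): its transpose must first undo the reshape, back to the permuted shape onp.take(old_sizes, dimensions), and then undo the permutation with onp.argsort(dimensions). The old rule reshaped the cotangent straight to old_sizes, which reinterprets elements in the wrong order whenever the permutation is nontrivial. Below is a hedged check of the corrected rule, mine rather than the commit's, using today's public jax.vjp instead of the internal ad machinery:

# Verify that the VJP of the fused primitive matches both the explicit
# transpose-then-reshape composition and the formula in the new rule.
import numpy as np
import jax
from jax import lax

old_sizes, dims, new_sizes = (2, 3, 4), (2, 0, 1), (4, 6)
fused = lambda x: lax.reshape(x, new_sizes, dimensions=dims)
composed = lambda x: lax.reshape(lax.transpose(x, dims), new_sizes)

x = np.arange(24.0, dtype=np.float32).reshape(old_sizes)
ct = np.arange(24.0, dtype=np.float32).reshape(new_sizes)  # arbitrary cotangent

_, vjp_fused = jax.vjp(fused, x)
_, vjp_composed = jax.vjp(composed, x)
np.testing.assert_allclose(vjp_fused(ct)[0], vjp_composed(ct)[0])

# The corrected rule computes exactly this:
manual = np.transpose(ct.reshape(np.take(old_sizes, dims)), np.argsort(dims))
np.testing.assert_allclose(vjp_fused(ct)[0], manual)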
@@ -3084,8 +3084,6 @@ ad.deflinear(reduce_sum_p, _reduce_sum_transpose_rule)
 batching.defreducer(reduce_sum_p)
 
 
-
-
 def _reduce_prod_shape_rule(operand, axes):
   return tuple(onp.delete(operand.shape, axes))
 
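As the commit message notes, the only in-tree caller of reshape-with-dimensions is the JVP rule for the product reduction, so the bug surfaced as incorrect gradients of reductions like `jnp.prod`; the test additions below cover that path. A small illustration, not from the commit, comparing against the closed-form gradient:

# Gradient of a product reduction, checked against prod(x, axis=0) / x,
# which is the analytic derivative for nonzero entries.
import numpy as np
import jax
import jax.numpy as jnp

x = np.array([[1.0, 2.0, 3.0],
              [4.0, 5.0, 6.0]], dtype=np.float32)

g = jax.grad(lambda a: jnp.prod(a, axis=0).sum())(x)
np.testing.assert_allclose(g, np.prod(x, axis=0) / x, rtol=1e-5)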
@@ -1035,6 +1035,7 @@ class LaxTest(jtu.JaxTestCase):
        "dims": dims, "rng": rng}
       for init_val, op, dtypes in [
           (0, lax.add, default_dtypes),
+          (1, lax.mul, default_dtypes),
           (-onp.inf, lax.max, float_dtypes),
           (onp.iinfo(onp.int32).min, lax.max, [onp.int32]),
           (onp.iinfo(onp.int64).min, lax.max, [onp.int64]),
@@ -1754,20 +1755,29 @@ class LaxAutodiffTest(jtu.JaxTestCase):
     check_grads(broadcast_in_dim, (operand,), 2, ["fwd", "rev"], tol, tol, tol)
 
   @parameterized.named_parameters(jtu.cases_from_list(
-      {"testcase_name": "_inshape={}_outshape={}".format(
+      {"testcase_name": "_inshape={}_outshape={}_perm={}".format(
           jtu.format_shape_dtype_string(arg_shape, dtype),
-          jtu.format_shape_dtype_string(out_shape, dtype)),
+          jtu.format_shape_dtype_string(out_shape, dtype),
+          permutation),
        "arg_shape": arg_shape, "out_shape": out_shape, "dtype": dtype,
-       "rng": rng}
+       "rng": rng, "permutation": permutation}
       for dtype in float_dtypes
-      for arg_shape, out_shape in [
-        [(3, 4), (12,)], [(2, 1, 4), (8,)], [(2, 2, 4), (2, 8)]
+      for arg_shape, out_shape, permutation in [
+        [(3, 4), (12,), None],
+        [(2, 1, 4), (8,), None],
+        [(2, 2, 4), (2, 8), None],
+        [(3, 4), (12,), (0, 1)],
+        [(3, 4), (12,), (1, 0)],
+        [(2, 1, 4), (8,), (0, 2, 1)],
+        [(2, 1, 4), (8,), (2, 0, 1)],
+        [(2, 2, 4), (2, 8), (0, 2, 1)],
+        [(2, 2, 4), (2, 8), (2, 0, 1)],
       ]
       for rng in [jtu.rand_default()]))
-  def testReshapeGrad(self, arg_shape, out_shape, dtype, rng):
+  def testReshapeGrad(self, arg_shape, out_shape, permutation, dtype, rng):
     tol = 1e-2 if onp.finfo(dtype).bits == 32 else None
     operand = rng(arg_shape, dtype)
-    reshape = lambda x: lax.reshape(x, out_shape)
+    reshape = lambda x: lax.reshape(x, out_shape, permutation)
     check_grads(reshape, (operand,), 2, ["fwd", "rev"], tol, tol, tol)
 
   @parameterized.named_parameters(jtu.cases_from_list(
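A standalone version of one of the new testReshapeGrad cases above (my rephrasing; it uses the public `jax.test_util.check_grads` with its current keyword signature rather than the older positional one in the diff):

# Numerically check forward- and reverse-mode gradients of a reshape that
# fuses the nontrivial permutation (2, 0, 1), the case the old rule got wrong.
import numpy as np
from jax import lax
from jax.test_util import check_grads

arg_shape, out_shape, permutation = (2, 1, 4), (8,), (2, 0, 1)
operand = np.random.RandomState(0).randn(*arg_shape).astype(np.float32)

reshape = lambda x: lax.reshape(x, out_shape, permutation)
check_grads(reshape, (operand,), order=2, modes=["fwd", "rev"],
            atol=1e-2, rtol=1e-2)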
@@ -1916,9 +1926,9 @@ class LaxAutodiffTest(jtu.JaxTestCase):
        "dims": dims, "rng": rng}
       for init_val, op, dtypes in [
           (0, lax.add, inexact_dtypes),
-          (1, lax.mul, inexact_dtypes),
           (-onp.inf, lax.max, inexact_dtypes),
           (onp.inf, lax.min, inexact_dtypes),
+          (1, lax.mul, inexact_dtypes),
       ]
       for dtype in dtypes
       for shape, dims in [
@@ -1926,7 +1936,8 @@ class LaxAutodiffTest(jtu.JaxTestCase):
         [(3, 4, 5), (0,)],
         [(3, 4, 5), (1, 2)],
         [(3, 4, 5), (0, 2)],
-        [(3, 4, 5), (0, 1, 2)]
+        [(3, 4, 5), (0, 1, 2)],
+        [(3, 1), (1,)],
       ]
       for rng in [jtu.rand_small()]))
   def testReduceGrad(self, op, init_val, shape, dtype, dims, rng):