fix reshape transpose bug (and add tests)

This version of reshape (taking a `dimensions` argument, which
effectively fuses in a transpose) seems only to be used in the JVP rule
for lax._reduce_prod (basically np.product), but its transpose rule was
totally busted and untested.
Matthew Johnson 2019-06-17 11:49:54 -07:00
parent 1dc4a4d05e
commit fef68deef6
2 changed files with 24 additions and 15 deletions
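
For context (not part of the commit): in jax.lax, reshape takes an optional `dimensions` permutation that is applied to the operand before the reshape, so the fused form is equivalent to an explicit transpose followed by a plain reshape. A minimal sketch, assuming a current JAX install:

import jax.numpy as jnp
from jax import lax

x = jnp.arange(6.).reshape(2, 3)

# Fused form: read the operand in dimension order (1, 0), then reshape.
fused = lax.reshape(x, (6,), dimensions=(1, 0))

# Unfused equivalent: explicit transpose, then a plain reshape.
unfused = lax.reshape(lax.transpose(x, (1, 0)), (6,))

assert jnp.array_equal(fused, unfused)  # both give [0., 3., 1., 4., 2., 5.]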


@@ -1600,7 +1600,7 @@ ad.defjvp(sub_p,
 ad.primitive_transposes[sub_p] = _sub_transpose
 
 mul_p = standard_binop([_num, _num], 'mul')
-ad.defbilinear_broadcasting(_brcast, mul_p, mul, mul) # TODO
+ad.defbilinear_broadcasting(_brcast, mul_p, mul, mul)
 
 
 def _safe_mul_translation_rule(c, x, y):
@@ -2272,11 +2272,11 @@ def _reshape_translation_rule(c, operand, new_sizes, dimensions, old_sizes):
   return c.Reshape(operand, new_sizes=new_sizes, dimensions=dimensions)
 
 def _reshape_transpose_rule(t, new_sizes, dimensions, old_sizes):
-  out = reshape(t, old_sizes)
   if dimensions is None:
-    return [out]
+    return [reshape(t, old_sizes)]
   else:
-    return [transpose(out, onp.argsort(dimensions))]
+    return [transpose(reshape(t, onp.take(old_sizes, dimensions)),
+                      onp.argsort(dimensions))]
 
 def _reshape_batch_rule(batched_args, batch_dims, new_sizes, dimensions, **unused):
   operand, = batched_args
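
Why the new rule is the right pullback (an explanatory sketch, not text from the commit): since the forward computation behaves like transpose-then-reshape, a cotangent of shape new_sizes has to be reshaped back to the permuted input shape, onp.take(old_sizes, dimensions), and only then un-permuted with onp.argsort(dimensions). The old code reshaped straight to old_sizes, which lays the values out wrongly and in general yields the wrong shape after the final transpose. A quick numerical check of the fused op against the unfused composition, assuming a current JAX install:

import jax
import jax.numpy as jnp
from jax import lax

old_sizes, new_sizes, perm = (2, 1, 4), (8,), (2, 0, 1)

fused = lambda v: lax.reshape(v, new_sizes, dimensions=perm)
unfused = lambda v: lax.reshape(lax.transpose(v, perm), new_sizes)

x = jnp.arange(8.).reshape(old_sizes)
t = jnp.arange(8.)  # a cotangent with the output shape

# Pull the cotangent back through both formulations; they should agree.
(dx_fused,) = jax.vjp(fused, x)[1](t)
(dx_unfused,) = jax.vjp(unfused, x)[1](t)
assert jnp.allclose(dx_fused, dx_unfused)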
@@ -3084,8 +3084,6 @@ ad.deflinear(reduce_sum_p, _reduce_sum_transpose_rule)
 batching.defreducer(reduce_sum_p)
 
 def _reduce_prod_shape_rule(operand, axes):
   return tuple(onp.delete(operand.shape, axes))
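
The commit message notes that the fused reshape is only reached from the JVP rule for lax._reduce_prod, so product-reduction gradients are what the broken transpose rule would have corrupted. A hedged sanity check, assuming jnp.prod lowers to that reduction as it did at the time; the analytic gradient of a product is prod(x) / x_i for nonzero entries:

import jax
import jax.numpy as jnp

x = jnp.array([2., 3., 4.])

# Reverse mode linearizes via the reduce_prod JVP and then transposes it,
# which (at the time of this commit) is the path through the fixed rule.
g = jax.grad(jnp.prod)(x)
assert jnp.allclose(g, jnp.prod(x) / x)  # [12., 8., 6.]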


@@ -1035,6 +1035,7 @@ class LaxTest(jtu.JaxTestCase):
        "dims": dims, "rng": rng}
       for init_val, op, dtypes in [
           (0, lax.add, default_dtypes),
+          (1, lax.mul, default_dtypes),
           (-onp.inf, lax.max, float_dtypes),
           (onp.iinfo(onp.int32).min, lax.max, [onp.int32]),
           (onp.iinfo(onp.int64).min, lax.max, [onp.int64]),
@@ -1754,20 +1755,29 @@ class LaxAutodiffTest(jtu.JaxTestCase):
     check_grads(broadcast_in_dim, (operand,), 2, ["fwd", "rev"], tol, tol, tol)
 
   @parameterized.named_parameters(jtu.cases_from_list(
-      {"testcase_name": "_inshape={}_outshape={}".format(
+      {"testcase_name": "_inshape={}_outshape={}_perm={}".format(
           jtu.format_shape_dtype_string(arg_shape, dtype),
-          jtu.format_shape_dtype_string(out_shape, dtype)),
+          jtu.format_shape_dtype_string(out_shape, dtype),
+          permutation),
        "arg_shape": arg_shape, "out_shape": out_shape, "dtype": dtype,
-       "rng": rng}
+       "rng": rng, "permutation": permutation}
       for dtype in float_dtypes
-      for arg_shape, out_shape in [
-          [(3, 4), (12,)], [(2, 1, 4), (8,)], [(2, 2, 4), (2, 8)]
+      for arg_shape, out_shape, permutation in [
+          [(3, 4), (12,), None],
+          [(2, 1, 4), (8,), None],
+          [(2, 2, 4), (2, 8), None],
+          [(3, 4), (12,), (0, 1)],
+          [(3, 4), (12,), (1, 0)],
+          [(2, 1, 4), (8,), (0, 2, 1)],
+          [(2, 1, 4), (8,), (2, 0, 1)],
+          [(2, 2, 4), (2, 8), (0, 2, 1)],
+          [(2, 2, 4), (2, 8), (2, 0, 1)],
       ]
       for rng in [jtu.rand_default()]))
-  def testReshapeGrad(self, arg_shape, out_shape, dtype, rng):
+  def testReshapeGrad(self, arg_shape, out_shape, permutation, dtype, rng):
     tol = 1e-2 if onp.finfo(dtype).bits == 32 else None
     operand = rng(arg_shape, dtype)
-    reshape = lambda x: lax.reshape(x, out_shape)
+    reshape = lambda x: lax.reshape(x, out_shape, permutation)
     check_grads(reshape, (operand,), 2, ["fwd", "rev"], tol, tol, tol)
 
   @parameterized.named_parameters(jtu.cases_from_list(
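
The same check the new test performs can be reproduced by hand with jax.test_util.check_grads; a small sketch mirroring one of the added permutation cases (the variable names here are ad hoc, not from the test):

import jax.numpy as jnp
from jax import lax
from jax.test_util import check_grads

arg_shape, out_shape, permutation = (2, 1, 4), (8,), (2, 0, 1)

x = jnp.linspace(0.1, 1.0, 8).reshape(arg_shape)
reshape = lambda v: lax.reshape(v, out_shape, permutation)

# Checks forward- and reverse-mode derivatives against numerical differences.
check_grads(reshape, (x,), order=2, modes=["fwd", "rev"])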
@@ -1916,9 +1926,9 @@ class LaxAutodiffTest(jtu.JaxTestCase):
        "dims": dims, "rng": rng}
       for init_val, op, dtypes in [
           (0, lax.add, inexact_dtypes),
+          (1, lax.mul, inexact_dtypes),
           (-onp.inf, lax.max, inexact_dtypes),
           (onp.inf, lax.min, inexact_dtypes),
-          (1, lax.mul, inexact_dtypes),
       ]
       for dtype in dtypes
       for shape, dims in [
@@ -1926,7 +1936,8 @@ class LaxAutodiffTest(jtu.JaxTestCase):
           [(3, 4, 5), (0,)],
           [(3, 4, 5), (1, 2)],
           [(3, 4, 5), (0, 2)],
-          [(3, 4, 5), (0, 1, 2)]
+          [(3, 4, 5), (0, 1, 2)],
+          [(3, 1), (1,)],
       ]
       for rng in [jtu.rand_small()]))
   def testReduceGrad(self, op, init_val, shape, dtype, dims, rng):