fix transpose batching rule bug, add tests

This commit is contained in:
Matthew Johnson 2019-02-12 07:26:32 -08:00
parent 98ace13a98
commit adaea811fc
2 changed files with 23 additions and 1 deletions

View File

@ -1722,7 +1722,7 @@ def _transpose_shape_rule(operand, permutation):
def _transpose_batch_rule(batched_args, batch_dims, permutation):
operand, = batched_args
bdim, = batch_dims
perm = tuple(onp.insert(onp.add(permutation, 1), bdim, 0))
perm = (bdim,) + tuple(i if i < bdim else i+1 for i in permutation)
return transpose(operand, perm), 0
transpose_p = standard_primitive(_transpose_shape_rule, _input_dtype,

View File

@ -771,6 +771,28 @@ class BatchingTest(jtu.JaxTestCase):
expected = onp.stack([f(a[:, i, :]) for i in range(a.shape[1])], axis=1)
assert onp.all(ans == expected)
def testTranspose(self):
x = onp.arange(4 * 3 * 3).reshape((4, 3, 3))
ans = vmap(lambda x: x + x.T)(x)
expected = x + onp.swapaxes(x, -1, -2)
self.assertAllClose(ans, expected, check_dtypes=False)
def testTransposePermutation(self):
x = onp.arange(6 * 3 * 4 * 5).reshape((6, 3, 4, 5))
ans = vmap(lambda x: np.transpose(x, (1, 0, 2)))(x)
expected = onp.transpose(x, (0, 2, 1, 3))
self.assertAllClose(ans, expected, check_dtypes=False)
x = onp.arange(6 * 3 * 4 * 5).reshape((6, 3, 4, 5))
ans = vmap(lambda x: np.transpose(x, (1, 2, 0)))(x)
expected = onp.transpose(x, (0, 2, 3, 1))
self.assertAllClose(ans, expected, check_dtypes=False)
x = onp.arange(6 * 3 * 4 * 5).reshape((3, 4, 6, 5))
ans = vmap(lambda x: np.transpose(x, (1, 2, 0)), in_axes=2)(x)
expected = onp.transpose(x, (2, 1, 3, 0))
self.assertAllClose(ans, expected, check_dtypes=False)
if __name__ == '__main__':
absltest.main()