mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
fix transpose batching rule bug, add tests
This commit is contained in:
parent
98ace13a98
commit
adaea811fc
@ -1722,7 +1722,7 @@ def _transpose_shape_rule(operand, permutation):
|
||||
def _transpose_batch_rule(batched_args, batch_dims, permutation):
|
||||
operand, = batched_args
|
||||
bdim, = batch_dims
|
||||
perm = tuple(onp.insert(onp.add(permutation, 1), bdim, 0))
|
||||
perm = (bdim,) + tuple(i if i < bdim else i+1 for i in permutation)
|
||||
return transpose(operand, perm), 0
|
||||
|
||||
transpose_p = standard_primitive(_transpose_shape_rule, _input_dtype,
|
||||
|
@ -771,6 +771,28 @@ class BatchingTest(jtu.JaxTestCase):
|
||||
expected = onp.stack([f(a[:, i, :]) for i in range(a.shape[1])], axis=1)
|
||||
assert onp.all(ans == expected)
|
||||
|
||||
def testTranspose(self):
|
||||
x = onp.arange(4 * 3 * 3).reshape((4, 3, 3))
|
||||
ans = vmap(lambda x: x + x.T)(x)
|
||||
expected = x + onp.swapaxes(x, -1, -2)
|
||||
self.assertAllClose(ans, expected, check_dtypes=False)
|
||||
|
||||
def testTransposePermutation(self):
|
||||
x = onp.arange(6 * 3 * 4 * 5).reshape((6, 3, 4, 5))
|
||||
ans = vmap(lambda x: np.transpose(x, (1, 0, 2)))(x)
|
||||
expected = onp.transpose(x, (0, 2, 1, 3))
|
||||
self.assertAllClose(ans, expected, check_dtypes=False)
|
||||
|
||||
x = onp.arange(6 * 3 * 4 * 5).reshape((6, 3, 4, 5))
|
||||
ans = vmap(lambda x: np.transpose(x, (1, 2, 0)))(x)
|
||||
expected = onp.transpose(x, (0, 2, 3, 1))
|
||||
self.assertAllClose(ans, expected, check_dtypes=False)
|
||||
|
||||
x = onp.arange(6 * 3 * 4 * 5).reshape((3, 4, 6, 5))
|
||||
ans = vmap(lambda x: np.transpose(x, (1, 2, 0)), in_axes=2)(x)
|
||||
expected = onp.transpose(x, (2, 1, 3, 0))
|
||||
self.assertAllClose(ans, expected, check_dtypes=False)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
absltest.main()
|
||||
|
Loading…
x
Reference in New Issue
Block a user