mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
fix bug in dot batching rule
This commit is contained in:
parent
a61d5f6e78
commit
9bc5f2aecf
@ -1913,7 +1913,7 @@ def _dot_batch_rule(batched_args, batch_dims):
|
||||
|
||||
if rbd is None:
|
||||
assert lbd is not None
|
||||
rhs = broadcast(rhs, (lhs.shape[lbd],))
|
||||
rhs = broadcast(rhs, (lhs.shape[0],))
|
||||
else:
|
||||
rhs = batching.move_dim_to_front(rhs, rbd)
|
||||
rhs_batch = (0,)
|
||||
|
@ -282,6 +282,14 @@ class BatchingTest(jtu.JaxTestCase):
|
||||
expected = onp.einsum('ni,ni->n', xs, ys)
|
||||
self.assertAllClose(ans, expected, check_dtypes=False)
|
||||
|
||||
def testDot3(self):
|
||||
R = onp.random.RandomState(0).randn
|
||||
xs = R(5, 8, 10)
|
||||
ys = R(10, 1)
|
||||
ans = vmap(np.dot, in_axes=(1, None))(xs, ys)
|
||||
expected = onp.einsum('inj,jk->nik', xs, ys)
|
||||
self.assertAllClose(ans, expected, check_dtypes=False)
|
||||
|
||||
def testPad(self):
|
||||
R = onp.random.RandomState(0).randn
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user