fix bug in dot batching rule

This commit is contained in:
James Bradbury 2019-06-05 15:17:06 -07:00
parent a61d5f6e78
commit 9bc5f2aecf
2 changed files with 9 additions and 1 deletions

View File

@ -1913,7 +1913,7 @@ def _dot_batch_rule(batched_args, batch_dims):
if rbd is None:
assert lbd is not None
rhs = broadcast(rhs, (lhs.shape[lbd],))
rhs = broadcast(rhs, (lhs.shape[0],))
else:
rhs = batching.move_dim_to_front(rhs, rbd)
rhs_batch = (0,)

View File

@ -282,6 +282,14 @@ class BatchingTest(jtu.JaxTestCase):
expected = onp.einsum('ni,ni->n', xs, ys)
self.assertAllClose(ans, expected, check_dtypes=False)
def testDot3(self):
R = onp.random.RandomState(0).randn
xs = R(5, 8, 10)
ys = R(10, 1)
ans = vmap(np.dot, in_axes=(1, None))(xs, ys)
expected = onp.einsum('inj,jk->nik', xs, ys)
self.assertAllClose(ans, expected, check_dtypes=False)
def testPad(self):
R = onp.random.RandomState(0).randn