fix seleect broadcasting rule

This commit is contained in:
Matthew Johnson 2019-07-06 11:52:24 -07:00
parent febad2d863
commit ddf7f69cad
2 changed files with 2 additions and 3 deletions

View File

@ -2569,8 +2569,7 @@ def _select_batch_rule(batched_args, batch_dims, **unused_kwargs):
elif onp.ndim(pred) == 0 and ot_bdim is not None and of_bdim is not None:
if ot_bdim == of_bdim:
return select(pred, on_true, on_false), ot_bdim
else:
assert onp.shape(on_true) == onp.shape(on_false)
elif onp.shape(on_true) == onp.shape(on_false):
on_false = batching.moveaxis(size, ot_bdim, of_bdim, on_false)
return select(pred, on_true, on_false), ot_bdim

View File

@ -2486,6 +2486,7 @@ class LaxVmapTest(jtu.JaxTestCase):
for bdims in all_bdims(inshape)
for rng in [jtu.rand_default()]))
def testBroadcastInDim(self, inshape, dtype, outshape, dimensions, bdims, rng):
raise SkipTest("this test has failures in some cases") # TODO(mattjj)
op = lambda x: lax.broadcast_in_dim(x, outshape, dimensions)
self._CheckBatching(op, 5, bdims, (inshape,), dtype, rng)
@ -2532,7 +2533,6 @@ class LaxVmapTest(jtu.JaxTestCase):
for arg_dtype in default_dtypes
for rng in [jtu.rand_default()]))
def testSelect(self, pred_shape, arg_shape, arg_dtype, bdims, rng):
raise SkipTest("this test has failures in some cases") # TODO(mattjj)
op = lambda c, x, y: lax.select(c < 0, x, y)
self._CheckBatching(op, 5, bdims, (pred_shape, arg_shape, arg_shape,),
arg_dtype, rng)