add more special cases of select batching rule

This commit is contained in:
Matthew Johnson 2019-02-03 14:00:51 -08:00
parent 44cffd0053
commit bf7a438c94

View File

@ -1774,6 +1774,13 @@ def _select_batch_rule(batched_args, batch_dims, **unused_kwargs):
assert onp.ndim(pred) == 1
pred = broadcast_in_dim(pred, on_true.shape, [pred_bdim])
return select(pred, on_true, on_false), pred_bdim
elif onp.ndim(pred) == 0 and ot_bdim is not None and of_bdim is not None:
if ot_bdim == of_bdim:
return select(pred, on_true, on_false), ot_bdim
else:
assert onp.shape(on_true) == onp.shape(on_false)
on_false = batching.moveaxis(size, ot_bdim, of_bdim, on_false)
return select(pred, on_true, on_false), ot_bdim
pred = batching.bdim_at_front(pred, pred_bdim, size, force_broadcast=True)
on_true = batching.bdim_at_front(on_true, ot_bdim, size, force_broadcast=True)