generalize select batch rule (fixes #311)

This commit is contained in:
Matthew Johnson 2019-02-03 09:27:03 -08:00
parent 1ab4a2ea54
commit fe96c15d49

View File

@ -1755,17 +1755,15 @@ def _select_transpose_rule(t, pred, on_true, on_false):
select(pred, zeros, t) if on_false is None else None]
def _select_batch_rule(batched_args, batch_dims, **unused_kwargs):
oprand, on_true, on_false, = batched_args
pred, on_true, on_false, = batched_args
pred_bdim, ot_bdim, of_bdim = batch_dims
if (ot_bdim not in {None, pred_bdim}) or (of_bdim not in {None, pred_bdim}):
raise NotImplementedError # TODO(schsam, mattjj): Handle more cases.
# TODO(schsam, mattjj): Switch to using broadcast_in_dim.
ot = _ones(oprand) * on_true
of = _ones(oprand) * on_false
return select(oprand, ot, of), pred_bdim
size = next(x.shape[i] for x, i in zip(batched_args, batch_dims)
if i is not None)
pred = batching.bdim_at_front(pred, pred_bdim, size)
on_true = batching.bdim_at_front(on_true, ot_bdim, size)
on_false = batching.bdim_at_front(on_false, of_bdim, size)
return select(pred, on_true, on_false), 0
select_p = standard_primitive(_select_shape_rule, _select_dtype_rule, 'select')
ad.defjvp(select_p,