mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
generalize select batch rule (fixes #311)
This commit is contained in:
parent
1ab4a2ea54
commit
fe96c15d49
16
jax/lax.py
16
jax/lax.py
@ -1755,17 +1755,15 @@ def _select_transpose_rule(t, pred, on_true, on_false):
|
||||
select(pred, zeros, t) if on_false is None else None]
|
||||
|
||||
def _select_batch_rule(batched_args, batch_dims, **unused_kwargs):
|
||||
oprand, on_true, on_false, = batched_args
|
||||
pred, on_true, on_false, = batched_args
|
||||
pred_bdim, ot_bdim, of_bdim = batch_dims
|
||||
|
||||
if (ot_bdim not in {None, pred_bdim}) or (of_bdim not in {None, pred_bdim}):
|
||||
raise NotImplementedError # TODO(schsam, mattjj): Handle more cases.
|
||||
|
||||
# TODO(schsam, mattjj): Switch to using broadcast_in_dim.
|
||||
ot = _ones(oprand) * on_true
|
||||
of = _ones(oprand) * on_false
|
||||
|
||||
return select(oprand, ot, of), pred_bdim
|
||||
size = next(x.shape[i] for x, i in zip(batched_args, batch_dims)
|
||||
if i is not None)
|
||||
pred = batching.bdim_at_front(pred, pred_bdim, size)
|
||||
on_true = batching.bdim_at_front(on_true, ot_bdim, size)
|
||||
on_false = batching.bdim_at_front(on_false, of_bdim, size)
|
||||
return select(pred, on_true, on_false), 0
|
||||
|
||||
select_p = standard_primitive(_select_shape_rule, _select_dtype_rule, 'select')
|
||||
ad.defjvp(select_p,
|
||||
|
Loading…
x
Reference in New Issue
Block a user