mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
add more special cases of select batching rule
This commit is contained in:
parent
44cffd0053
commit
bf7a438c94
@ -1774,6 +1774,13 @@ def _select_batch_rule(batched_args, batch_dims, **unused_kwargs):
|
||||
assert onp.ndim(pred) == 1
|
||||
pred = broadcast_in_dim(pred, on_true.shape, [pred_bdim])
|
||||
return select(pred, on_true, on_false), pred_bdim
|
||||
elif onp.ndim(pred) == 0 and ot_bdim is not None and of_bdim is not None:
|
||||
if ot_bdim == of_bdim:
|
||||
return select(pred, on_true, on_false), ot_bdim
|
||||
else:
|
||||
assert onp.shape(on_true) == onp.shape(on_false)
|
||||
on_false = batching.moveaxis(size, ot_bdim, of_bdim, on_false)
|
||||
return select(pred, on_true, on_false), ot_bdim
|
||||
|
||||
pred = batching.bdim_at_front(pred, pred_bdim, size, force_broadcast=True)
|
||||
on_true = batching.bdim_at_front(on_true, ot_bdim, size, force_broadcast=True)
|
||||
|
Loading…
x
Reference in New Issue
Block a user