mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Add support for vmap of scatter where indices but not updates are batched.
This commit is contained in:
parent
117749b754
commit
0293ecbb5e
@ -2860,8 +2860,10 @@ def _scatter_batching_rule(
|
||||
operand_bdim = 0
|
||||
|
||||
if scatter_indices_bdim is not None and updates_bdim is None:
|
||||
raise NotImplementedError # TODO(mattjj,phawkins)
|
||||
elif scatter_indices_bdim is None and updates_bdim is not None:
|
||||
updates = broadcast(updates, (size,))
|
||||
updates_bdim = 0
|
||||
|
||||
if scatter_indices_bdim is None and updates_bdim is not None:
|
||||
updates = batching.move_dim_to_front(updates, updates_bdim)
|
||||
inserted_window_dims = tuple(onp.add(1, dimension_numbers.inserted_window_dims))
|
||||
update_window_dims = (0,) + tuple(onp.add(1, dimension_numbers.update_window_dims))
|
||||
|
@ -31,6 +31,7 @@ from jax.api import vmap
|
||||
from jax.core import unit
|
||||
from jax.interpreters import partial_eval as pe
|
||||
from jax.util import partial, curry
|
||||
import jax.ops
|
||||
|
||||
from jax.config import config
|
||||
config.parse_flags_with_absl()
|
||||
@ -911,6 +912,11 @@ class BatchingTest(jtu.JaxTestCase):
|
||||
self.assertAllClose(result, onp.array([1, 2]), check_dtypes=False)
|
||||
self.assertEqual((), empty_tuple)
|
||||
|
||||
def testIndexAddBatchedIndexesOnly(self):
|
||||
f = lambda x, idx, y: jax.ops.index_add(x, jax.ops.index[idx], y)
|
||||
result = vmap(f, (None, 0, None))(onp.zeros((10,)), onp.arange(10,), 1.)
|
||||
self.assertAllClose(result, onp.eye(10), check_dtypes=False)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
absltest.main()
|
||||
|
Loading…
x
Reference in New Issue
Block a user