Add support for vmap of scatter where indices but not updates are batched.

This commit is contained in:
Peter Hawkins 2019-05-29 17:13:46 -04:00
parent 117749b754
commit 0293ecbb5e
2 changed files with 10 additions and 2 deletions

View File

@ -2860,8 +2860,10 @@ def _scatter_batching_rule(
operand_bdim = 0
if scatter_indices_bdim is not None and updates_bdim is None:
raise NotImplementedError # TODO(mattjj,phawkins)
elif scatter_indices_bdim is None and updates_bdim is not None:
updates = broadcast(updates, (size,))
updates_bdim = 0
if scatter_indices_bdim is None and updates_bdim is not None:
updates = batching.move_dim_to_front(updates, updates_bdim)
inserted_window_dims = tuple(onp.add(1, dimension_numbers.inserted_window_dims))
update_window_dims = (0,) + tuple(onp.add(1, dimension_numbers.update_window_dims))

View File

@ -31,6 +31,7 @@ from jax.api import vmap
from jax.core import unit
from jax.interpreters import partial_eval as pe
from jax.util import partial, curry
import jax.ops
from jax.config import config
config.parse_flags_with_absl()
@ -911,6 +912,11 @@ class BatchingTest(jtu.JaxTestCase):
self.assertAllClose(result, onp.array([1, 2]), check_dtypes=False)
self.assertEqual((), empty_tuple)
def testIndexAddBatchedIndexesOnly(self):
f = lambda x, idx, y: jax.ops.index_add(x, jax.ops.index[idx], y)
result = vmap(f, (None, 0, None))(onp.zeros((10,)), onp.arange(10,), 1.)
self.assertAllClose(result, onp.eye(10), check_dtypes=False)
if __name__ == '__main__':
absltest.main()