inline and remove scatter_mlir rules

This commit is contained in:
Roy Frostig 2023-05-12 15:29:34 -07:00
parent cc54b6e6ad
commit f18bff5371
3 changed files with 38 additions and 24 deletions

View File

@ -2016,16 +2016,39 @@ batching.primitive_batchers[scatter_p] = (
partial(_scatter_batching_rule, scatter))
def _scatter_lower_opaque(ctx, operand, indices, updates, *,
update_jaxpr, update_consts, dimension_numbers,
unique_indices, indices_are_sorted, mode):
aval_x, aval_indices, aval_updates = ctx.avals_in
aval_y, = ctx.avals_out
elt_shape = aval_x.dtype._rules.physical_element_aval(aval_x.dtype).shape
trailing_window_dims = [aval_updates.ndim + i for i in range(len(elt_shape))]
dimension_numbers = dimension_numbers._replace(
update_window_dims=(*dimension_numbers.update_window_dims,
*trailing_window_dims))
scatter_lower = partial(
_scatter_lower, update_jaxpr=update_jaxpr, update_consts=update_consts,
dimension_numbers=dimension_numbers, unique_indices=unique_indices,
indices_are_sorted=indices_are_sorted, mode=mode)
res, = mlir.delegate_lowering(
ctx, scatter_lower, operand, indices, updates,
avals_in=[core.physical_aval(aval_x), aval_indices,
core.physical_aval(aval_updates)],
avals_out=[core.physical_aval(aval_y)])
return res
def _scatter_lower(ctx, operand, indices, updates, *,
update_jaxpr, update_consts, dimension_numbers,
indices_are_sorted, unique_indices, mode):
aval_out, = ctx.avals_out
if dtypes.is_opaque_dtype(aval_out.dtype):
return [aval_out.dtype._rules.scatter_mlir(
ctx, ctx.avals_in, aval_out, operand, indices, updates,
return [_scatter_lower_opaque(
ctx, operand, indices, updates,
update_jaxpr=update_jaxpr, update_consts=update_consts,
dimension_numbers=dimension_numbers, unique_indices=unique_indices,
indices_are_sorted=indices_are_sorted, mode=mode)]
if mode == GatherScatterMode.CLIP:
clip_fn = mlir.lower_fun(_clamp_scatter_indices, multiple_results=False)
(indices,), = clip_fn(ctx.replace(avals_out=None), operand, indices,

View File

@ -521,28 +521,6 @@ class KeyTyRules:
# element-type-polymorphic primitive lowering rules
@staticmethod
def scatter_mlir(ctx, avals_in, aval_out, x, indices, updates, *,
update_jaxpr, update_consts, dimension_numbers,
unique_indices, indices_are_sorted, mode):
aval_x, aval_indices, aval_updates = avals_in
aval_y = aval_out
key_shape = aval_x.dtype.impl.key_shape
trailing_window_dims = [aval_updates.ndim + i for i in range(len(key_shape))]
dimension_numbers = dimension_numbers._replace(
update_window_dims=(*dimension_numbers.update_window_dims, *trailing_window_dims))
scatter_lower = partial(
lax_internal.slicing._scatter_lower, update_jaxpr=update_jaxpr,
update_consts=update_consts, dimension_numbers=dimension_numbers,
unique_indices=unique_indices, indices_are_sorted=indices_are_sorted,
mode=mode)
res, = mlir.delegate_lowering(
ctx, scatter_lower, x, indices, updates,
avals_in=[keys_aval_to_base_arr_aval(aval_x), aval_indices,
keys_aval_to_base_arr_aval(aval_updates)],
avals_out=[keys_aval_to_base_arr_aval(aval_y)])
return res
def _comparison_mlir(direction, reduction_op, identity,
ctx, avals_in, aval_out, x, y, **kwargs):
aval_x, aval_y = avals_in

View File

@ -3139,6 +3139,19 @@ class CustomElementTypesTest(jtu.JaxTestCase):
self.assertIsInstance(ys, FooArray)
self.assertEqual(ys.shape, (3, 2, 1))
@parameterized.parameters([
(0,),
(slice(1),),
(np.array([0, 2]),),
(np.array([False, True, True]),)
])
def test_scatter(self, idx):
k = jax.jit(lambda: make(()))()
ks = jax.jit(lambda: make((3,)))()
ys = jax.jit(lambda x, y: x.at[idx].set(y))(ks, k)
self.assertIsInstance(ys, FooArray)
self.assertEqual(ys.shape, (3,))
def test_select(self):
ks = jax.jit(lambda: make((3,)))()
cs = jnp.array([True, False, False])