mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
inline and remove scatter_mlir
rules
This commit is contained in:
parent
cc54b6e6ad
commit
f18bff5371
@ -2016,16 +2016,39 @@ batching.primitive_batchers[scatter_p] = (
|
||||
partial(_scatter_batching_rule, scatter))
|
||||
|
||||
|
||||
def _scatter_lower_opaque(ctx, operand, indices, updates, *,
|
||||
update_jaxpr, update_consts, dimension_numbers,
|
||||
unique_indices, indices_are_sorted, mode):
|
||||
aval_x, aval_indices, aval_updates = ctx.avals_in
|
||||
aval_y, = ctx.avals_out
|
||||
elt_shape = aval_x.dtype._rules.physical_element_aval(aval_x.dtype).shape
|
||||
trailing_window_dims = [aval_updates.ndim + i for i in range(len(elt_shape))]
|
||||
dimension_numbers = dimension_numbers._replace(
|
||||
update_window_dims=(*dimension_numbers.update_window_dims,
|
||||
*trailing_window_dims))
|
||||
scatter_lower = partial(
|
||||
_scatter_lower, update_jaxpr=update_jaxpr, update_consts=update_consts,
|
||||
dimension_numbers=dimension_numbers, unique_indices=unique_indices,
|
||||
indices_are_sorted=indices_are_sorted, mode=mode)
|
||||
res, = mlir.delegate_lowering(
|
||||
ctx, scatter_lower, operand, indices, updates,
|
||||
avals_in=[core.physical_aval(aval_x), aval_indices,
|
||||
core.physical_aval(aval_updates)],
|
||||
avals_out=[core.physical_aval(aval_y)])
|
||||
return res
|
||||
|
||||
|
||||
def _scatter_lower(ctx, operand, indices, updates, *,
|
||||
update_jaxpr, update_consts, dimension_numbers,
|
||||
indices_are_sorted, unique_indices, mode):
|
||||
aval_out, = ctx.avals_out
|
||||
if dtypes.is_opaque_dtype(aval_out.dtype):
|
||||
return [aval_out.dtype._rules.scatter_mlir(
|
||||
ctx, ctx.avals_in, aval_out, operand, indices, updates,
|
||||
return [_scatter_lower_opaque(
|
||||
ctx, operand, indices, updates,
|
||||
update_jaxpr=update_jaxpr, update_consts=update_consts,
|
||||
dimension_numbers=dimension_numbers, unique_indices=unique_indices,
|
||||
indices_are_sorted=indices_are_sorted, mode=mode)]
|
||||
|
||||
if mode == GatherScatterMode.CLIP:
|
||||
clip_fn = mlir.lower_fun(_clamp_scatter_indices, multiple_results=False)
|
||||
(indices,), = clip_fn(ctx.replace(avals_out=None), operand, indices,
|
||||
|
@ -521,28 +521,6 @@ class KeyTyRules:
|
||||
|
||||
# element-type-polymorphic primitive lowering rules
|
||||
|
||||
@staticmethod
|
||||
def scatter_mlir(ctx, avals_in, aval_out, x, indices, updates, *,
|
||||
update_jaxpr, update_consts, dimension_numbers,
|
||||
unique_indices, indices_are_sorted, mode):
|
||||
aval_x, aval_indices, aval_updates = avals_in
|
||||
aval_y = aval_out
|
||||
key_shape = aval_x.dtype.impl.key_shape
|
||||
trailing_window_dims = [aval_updates.ndim + i for i in range(len(key_shape))]
|
||||
dimension_numbers = dimension_numbers._replace(
|
||||
update_window_dims=(*dimension_numbers.update_window_dims, *trailing_window_dims))
|
||||
scatter_lower = partial(
|
||||
lax_internal.slicing._scatter_lower, update_jaxpr=update_jaxpr,
|
||||
update_consts=update_consts, dimension_numbers=dimension_numbers,
|
||||
unique_indices=unique_indices, indices_are_sorted=indices_are_sorted,
|
||||
mode=mode)
|
||||
res, = mlir.delegate_lowering(
|
||||
ctx, scatter_lower, x, indices, updates,
|
||||
avals_in=[keys_aval_to_base_arr_aval(aval_x), aval_indices,
|
||||
keys_aval_to_base_arr_aval(aval_updates)],
|
||||
avals_out=[keys_aval_to_base_arr_aval(aval_y)])
|
||||
return res
|
||||
|
||||
def _comparison_mlir(direction, reduction_op, identity,
|
||||
ctx, avals_in, aval_out, x, y, **kwargs):
|
||||
aval_x, aval_y = avals_in
|
||||
|
@ -3139,6 +3139,19 @@ class CustomElementTypesTest(jtu.JaxTestCase):
|
||||
self.assertIsInstance(ys, FooArray)
|
||||
self.assertEqual(ys.shape, (3, 2, 1))
|
||||
|
||||
@parameterized.parameters([
|
||||
(0,),
|
||||
(slice(1),),
|
||||
(np.array([0, 2]),),
|
||||
(np.array([False, True, True]),)
|
||||
])
|
||||
def test_scatter(self, idx):
|
||||
k = jax.jit(lambda: make(()))()
|
||||
ks = jax.jit(lambda: make((3,)))()
|
||||
ys = jax.jit(lambda x, y: x.at[idx].set(y))(ks, k)
|
||||
self.assertIsInstance(ys, FooArray)
|
||||
self.assertEqual(ys.shape, (3,))
|
||||
|
||||
def test_select(self):
|
||||
ks = jax.jit(lambda: make((3,)))()
|
||||
cs = jnp.array([True, False, False])
|
||||
|
Loading…
x
Reference in New Issue
Block a user