[MLIR] Support all fill_modes in GPU MLIR lowering for scatter_add.

PiperOrigin-RevId: 415617659
This commit is contained in:
Peter Hawkins 2021-12-10 14:56:37 -08:00 committed by jax authors
parent 3969eec0e0
commit 53318a2a7a

View File

@ -1993,7 +1993,12 @@ def _scatter_add_lower_gpu(ctx, avals_in, avals_out, operand, indices, updates,
dimension_numbers=dimension_numbers,
indices_are_sorted=indices_are_sorted,
unique_indices=unique_indices, mode=mode)
assert mode == GatherScatterMode.PROMISE_IN_BOUNDS, mode
if mode == GatherScatterMode.CLIP:
clip_fn = mlir.lower_fun(_clamp_scatter_indices, multiple_results=False)
(indices,), = clip_fn(ctx, avals_in, None, operand, indices, updates,
dnums=dimension_numbers)
aval_out, = avals_out
dnums = dimension_numbers
scatter_dnums = mhlo.ScatterDimensionNumbers.get(