mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
[MLIR] Support all fill_modes in GPU MLIR lowering for scatter_add.
PiperOrigin-RevId: 415617659
This commit is contained in:
parent
3969eec0e0
commit
53318a2a7a
@ -1993,7 +1993,12 @@ def _scatter_add_lower_gpu(ctx, avals_in, avals_out, operand, indices, updates,
|
||||
dimension_numbers=dimension_numbers,
|
||||
indices_are_sorted=indices_are_sorted,
|
||||
unique_indices=unique_indices, mode=mode)
|
||||
assert mode == GatherScatterMode.PROMISE_IN_BOUNDS, mode
|
||||
|
||||
if mode == GatherScatterMode.CLIP:
|
||||
clip_fn = mlir.lower_fun(_clamp_scatter_indices, multiple_results=False)
|
||||
(indices,), = clip_fn(ctx, avals_in, None, operand, indices, updates,
|
||||
dnums=dimension_numbers)
|
||||
|
||||
aval_out, = avals_out
|
||||
dnums = dimension_numbers
|
||||
scatter_dnums = mhlo.ScatterDimensionNumbers.get(
|
||||
|
Loading…
x
Reference in New Issue
Block a user