mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Made scatter-transpose more efficient
This commit is contained in:
parent
bfe8acb31e
commit
0f39d59ef0
@ -1981,12 +1981,10 @@ def _scatter_transpose_rule(t, operand, indices, updates, *,
|
||||
operand_t = update_t = None
|
||||
if ad.is_undefined_primal(operand):
|
||||
# Zero out gradient entries that correspond to updated indices.
|
||||
mask = scatter(lax._ones(t, dtype=np.bool_), indices,
|
||||
lax.full(updates_shape, False),
|
||||
dimension_numbers=dimension_numbers,
|
||||
indices_are_sorted=indices_are_sorted,
|
||||
unique_indices=True, mode=mode)
|
||||
operand_t = lax.select(mask, t, lax._zeros(t))
|
||||
operand_t = scatter(t, indices, lax.full(updates_shape, 0, dtype=t.dtype),
|
||||
dimension_numbers=dimension_numbers,
|
||||
indices_are_sorted=indices_are_sorted,
|
||||
unique_indices=True, mode=mode)
|
||||
|
||||
if ad.is_undefined_primal(updates):
|
||||
gather_dnums = GatherDimensionNumbers(
|
||||
|
Loading…
x
Reference in New Issue
Block a user