Made scatter-transpose more efficient

This commit is contained in:
Patrick Kidger 2023-06-14 22:25:35 -07:00
parent bfe8acb31e
commit 0f39d59ef0

View File

@ -1981,12 +1981,10 @@ def _scatter_transpose_rule(t, operand, indices, updates, *,
operand_t = update_t = None
if ad.is_undefined_primal(operand):
# Zero out gradient entries that correspond to updated indices.
mask = scatter(lax._ones(t, dtype=np.bool_), indices,
lax.full(updates_shape, False),
dimension_numbers=dimension_numbers,
indices_are_sorted=indices_are_sorted,
unique_indices=True, mode=mode)
operand_t = lax.select(mask, t, lax._zeros(t))
operand_t = scatter(t, indices, lax.full(updates_shape, 0, dtype=t.dtype),
dimension_numbers=dimension_numbers,
indices_are_sorted=indices_are_sorted,
unique_indices=True, mode=mode)
if ad.is_undefined_primal(updates):
gather_dnums = GatherDimensionNumbers(