mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Propagate symbolic zeros instead of instantiating them.
This commit is contained in:
parent
7ee59e96d3
commit
cfdf1cd3e9
@ -2804,12 +2804,14 @@ def _scatter_add_transpose_rule(t, operand, scatter_indices, updates,
|
||||
update_jaxpr, update_consts, dimension_numbers,
|
||||
updates_shape):
|
||||
assert scatter_indices is not None
|
||||
if t is ad_util.zero:
|
||||
return [ad_util.zero, None, ad_util.zero]
|
||||
|
||||
operand_t = update_t = None
|
||||
if operand is None:
|
||||
operand_t = t
|
||||
|
||||
if updates is None:
|
||||
t = ad.instantiate_zeros(operand, t)
|
||||
gather_dnums = GatherDimensionNumbers(
|
||||
offset_dims=dimension_numbers.update_window_dims,
|
||||
collapsed_slice_dims=dimension_numbers.inserted_window_dims,
|
||||
|
Loading…
x
Reference in New Issue
Block a user