Propagate symbolic zeros instead of instantiating them.

This commit is contained in:
Peter Hawkins 2019-05-28 15:41:27 -04:00
parent 7ee59e96d3
commit cfdf1cd3e9

View File

@ -2804,12 +2804,14 @@ def _scatter_add_transpose_rule(t, operand, scatter_indices, updates,
update_jaxpr, update_consts, dimension_numbers,
updates_shape):
assert scatter_indices is not None
if t is ad_util.zero:
return [ad_util.zero, None, ad_util.zero]
operand_t = update_t = None
if operand is None:
operand_t = t
if updates is None:
t = ad.instantiate_zeros(operand, t)
gather_dnums = GatherDimensionNumbers(
offset_dims=dimension_numbers.update_window_dims,
collapsed_slice_dims=dimension_numbers.inserted_window_dims,