Avoid instantiating zeros in dynamic_slice/gather transpose rules.

This commit is contained in:
Peter Hawkins 2019-06-24 13:44:49 -04:00
parent a8cf0cd36d
commit c1bec691c5

View File

@ -2630,7 +2630,7 @@ def _dynamic_slice_jvp_rule(g, operand, start_indices, slice_sizes,
def _dynamic_slice_transpose_rule(t, operand, start_indices, slice_sizes,
operand_shape):
assert operand is None
zeros = full(operand_shape, 0, dtype=_dtype(t))
zeros = full(operand_shape, tie_in(t, _zero(t)))
return [dynamic_update_slice(zeros, t, start_indices), ad_util.zero]
def _dynamic_slice_batching_rule(batched_args, batch_dims, slice_sizes,
@ -2798,7 +2798,7 @@ def _gather_transpose_rule(t, operand, start_indices, dimension_numbers,
assert operand is None
if t is ad_util.zero:
return [ad_util.zero, ad_util.zero]
zeros = full(operand_shape, 0, dtype=t.dtype)
zeros = full(operand_shape, tie_in(t, _zero(t)))
scatter_dnums = ScatterDimensionNumbers(
update_window_dims=dimension_numbers.offset_dims,
inserted_window_dims=dimension_numbers.collapsed_slice_dims,