mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Avoid instantiating zeros in dynamic_slice/gather transpose rules.
This commit is contained in:
parent
a8cf0cd36d
commit
c1bec691c5
@ -2630,7 +2630,7 @@ def _dynamic_slice_jvp_rule(g, operand, start_indices, slice_sizes,
|
||||
def _dynamic_slice_transpose_rule(t, operand, start_indices, slice_sizes,
|
||||
operand_shape):
|
||||
assert operand is None
|
||||
zeros = full(operand_shape, 0, dtype=_dtype(t))
|
||||
zeros = full(operand_shape, tie_in(t, _zero(t)))
|
||||
return [dynamic_update_slice(zeros, t, start_indices), ad_util.zero]
|
||||
|
||||
def _dynamic_slice_batching_rule(batched_args, batch_dims, slice_sizes,
|
||||
@ -2798,7 +2798,7 @@ def _gather_transpose_rule(t, operand, start_indices, dimension_numbers,
|
||||
assert operand is None
|
||||
if t is ad_util.zero:
|
||||
return [ad_util.zero, ad_util.zero]
|
||||
zeros = full(operand_shape, 0, dtype=t.dtype)
|
||||
zeros = full(operand_shape, tie_in(t, _zero(t)))
|
||||
scatter_dnums = ScatterDimensionNumbers(
|
||||
update_window_dims=dimension_numbers.offset_dims,
|
||||
inserted_window_dims=dimension_numbers.collapsed_slice_dims,
|
||||
|
Loading…
x
Reference in New Issue
Block a user