mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Address review comment.
This commit is contained in:
parent
e57a5c42c5
commit
e28e73b38f
@ -2735,9 +2735,10 @@ def _dynamic_slice_jvp(primals, tangents, slice_sizes, operand_shape):
|
||||
def _dynamic_slice_transpose_rule(t, operand, *start_indices, slice_sizes=None,
|
||||
operand_shape=None):
|
||||
assert operand is None
|
||||
assert all(s is not None for s in start_indices)
|
||||
zeros = full(operand_shape, tie_in(t, _zero(t)))
|
||||
return ([dynamic_update_slice(zeros, t, start_indices)] +
|
||||
[ad_util.zero] * len(start_indices))
|
||||
[None] * len(start_indices))
|
||||
|
||||
def _batch_dynamic_slice_indices(indices, bdims):
|
||||
size = next((x.shape[i] for x, i in zip(indices, bdims) if i is not None), -1)
|
||||
|
Loading…
x
Reference in New Issue
Block a user