mirror of
https://github.com/ROCm/jax.git
synced 2025-04-18 21:06:06 +00:00
avoid generating trivial lax.slice operations
This commit is contained in:
parent
0e749f29ef
commit
b4e83ee751
13
jax/lax.py
13
jax/lax.py
@ -556,10 +556,15 @@ def slice(operand, start_indices, limit_indices, strides=None):
|
||||
<https://www.tensorflow.org/xla/operation_semantics#slice>`_
|
||||
operator.
|
||||
"""
|
||||
return slice_p.bind(operand, start_indices=tuple(start_indices),
|
||||
limit_indices=tuple(limit_indices),
|
||||
strides=None if strides is None else tuple(strides),
|
||||
operand_shape=operand.shape)
|
||||
if (onp.all(onp.equal(start_indices, 0))
|
||||
and onp.all(onp.equal(limit_indices, operand.shape))
|
||||
and strides is None):
|
||||
return operand
|
||||
else:
|
||||
return slice_p.bind(operand, start_indices=tuple(start_indices),
|
||||
limit_indices=tuple(limit_indices),
|
||||
strides=None if strides is None else tuple(strides),
|
||||
operand_shape=operand.shape)
|
||||
|
||||
def dynamic_slice(operand, start_indices, slice_sizes):
|
||||
"""Wraps XLA's `DynamicSlice
|
||||
|
Loading…
x
Reference in New Issue
Block a user