avoid generating trivial lax.slice operations

This commit is contained in:
Matthew Johnson 2019-03-21 07:27:08 -07:00
parent 0e749f29ef
commit b4e83ee751

View File

@ -556,10 +556,15 @@ def slice(operand, start_indices, limit_indices, strides=None):
<https://www.tensorflow.org/xla/operation_semantics#slice>`_
operator.
"""
return slice_p.bind(operand, start_indices=tuple(start_indices),
limit_indices=tuple(limit_indices),
strides=None if strides is None else tuple(strides),
operand_shape=operand.shape)
if (onp.all(onp.equal(start_indices, 0))
and onp.all(onp.equal(limit_indices, operand.shape))
and strides is None):
return operand
else:
return slice_p.bind(operand, start_indices=tuple(start_indices),
limit_indices=tuple(limit_indices),
strides=None if strides is None else tuple(strides),
operand_shape=operand.shape)
def dynamic_slice(operand, start_indices, slice_sizes):
"""Wraps XLA's `DynamicSlice