Improve speed of tracing dynamic_update_slice (#3247)

* Improve tracing performance of _dynamic_slice_indices

* More precisely preserve semantics of dynamic_slice_indices

* Use safe_map in dynamic_slice_indices
This commit is contained in:
Jamie Townsend 2020-06-02 14:37:32 +01:00 committed by GitHub
parent 38bfcee753
commit 3909875f9d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -46,7 +46,7 @@ from ..interpreters import pxla
from ..interpreters import ad
from ..interpreters import batching
from ..interpreters import masking
from ..util import curry, cache, safe_zip, unzip2, prod
from ..util import curry, cache, safe_zip, unzip2, prod, safe_map
from ..tree_util import build_tree, tree_unflatten, tree_map
from ..lib import pytree
from ..lib import xla_bridge
@ -5275,22 +5275,24 @@ def _check_shapelike(fun_name, arg_name, obj):
def _dynamic_slice_indices(operand, start_indices):
if not isinstance(start_indices, (tuple, list)):
if start_indices.ndim != 1:
raise ValueError("Slice indices must be a 1D sequence, got {}"
.format(start_indices.shape))
start_indices = [squeeze(slice(start_indices, [i], [i+1]), dimensions=(0,))
for i in range(operand.ndim)]
else:
start_indices = [onp.asarray(i, dtype=dtypes.int_) if isinstance(i, int)
else i for i in start_indices]
if len(start_indices) != operand.ndim:
msg = ("Length of slice indices must match number of operand dimensions ({} "
"vs {})")
raise ValueError(msg.format(len(start_indices), operand.shape))
# map int over operand.shape to raise any dynamic-shape errors
return [select(lt(i, _const(i, 0)), add(i, _const(i, int(d))), i)
for i, d in zip(start_indices, operand.shape)]
safe_map(int, operand.shape)
if not isinstance(start_indices, (tuple, list)):
if start_indices.ndim != 1:
raise ValueError("Slice indices must be a 1D sequence, got {}"
.format(start_indices.shape))
return select(lt(start_indices, _zeros(start_indices)),
add(start_indices, _const(start_indices, operand.shape)),
start_indices)
else:
return [onp.asarray(i + d if i < 0 else i, getattr(i, 'dtype', dtypes.int_))
if isinstance(i, (int, onp.integer))
else select(lt(i, _const(i, 0)), add(i, _const(i, d)), i)
for i, d in zip(start_indices, operand.shape)]