mirror of
https://github.com/ROCm/jax.git
synced 2025-04-18 21:06:06 +00:00
Improve speed of tracing dynamic_update_slice (#3247)
* Improve tracing performance of _dynamic_slice_indices * More precisely preserve semantics of dynamic_slice_indices * Use safe_map in dynamic_slice_indices
This commit is contained in:
parent
38bfcee753
commit
3909875f9d
@ -46,7 +46,7 @@ from ..interpreters import pxla
|
||||
from ..interpreters import ad
|
||||
from ..interpreters import batching
|
||||
from ..interpreters import masking
|
||||
from ..util import curry, cache, safe_zip, unzip2, prod
|
||||
from ..util import curry, cache, safe_zip, unzip2, prod, safe_map
|
||||
from ..tree_util import build_tree, tree_unflatten, tree_map
|
||||
from ..lib import pytree
|
||||
from ..lib import xla_bridge
|
||||
@ -5275,22 +5275,24 @@ def _check_shapelike(fun_name, arg_name, obj):
|
||||
|
||||
|
||||
def _dynamic_slice_indices(operand, start_indices):
|
||||
if not isinstance(start_indices, (tuple, list)):
|
||||
if start_indices.ndim != 1:
|
||||
raise ValueError("Slice indices must be a 1D sequence, got {}"
|
||||
.format(start_indices.shape))
|
||||
start_indices = [squeeze(slice(start_indices, [i], [i+1]), dimensions=(0,))
|
||||
for i in range(operand.ndim)]
|
||||
else:
|
||||
start_indices = [onp.asarray(i, dtype=dtypes.int_) if isinstance(i, int)
|
||||
else i for i in start_indices]
|
||||
if len(start_indices) != operand.ndim:
|
||||
msg = ("Length of slice indices must match number of operand dimensions ({} "
|
||||
"vs {})")
|
||||
raise ValueError(msg.format(len(start_indices), operand.shape))
|
||||
# map int over operand.shape to raise any dynamic-shape errors
|
||||
return [select(lt(i, _const(i, 0)), add(i, _const(i, int(d))), i)
|
||||
for i, d in zip(start_indices, operand.shape)]
|
||||
safe_map(int, operand.shape)
|
||||
if not isinstance(start_indices, (tuple, list)):
|
||||
if start_indices.ndim != 1:
|
||||
raise ValueError("Slice indices must be a 1D sequence, got {}"
|
||||
.format(start_indices.shape))
|
||||
return select(lt(start_indices, _zeros(start_indices)),
|
||||
add(start_indices, _const(start_indices, operand.shape)),
|
||||
start_indices)
|
||||
else:
|
||||
return [onp.asarray(i + d if i < 0 else i, getattr(i, 'dtype', dtypes.int_))
|
||||
if isinstance(i, (int, onp.integer))
|
||||
else select(lt(i, _const(i, 0)), add(i, _const(i, d)), i)
|
||||
for i, d in zip(start_indices, operand.shape)]
|
||||
|
||||
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user