Merge pull request #12219 from jakevdp:indexing-slice

PiperOrigin-RevId: 485946084
This commit is contained in:
jax authors 2022-11-03 12:44:28 -07:00
commit f4be5ab173
2 changed files with 20 additions and 4 deletions

View File

@ -676,6 +676,19 @@ def bench_slicing_compilation2(state):
while state:
jax.jit(lambda x: (x[:1], x[1:2], x[2:3])).lower(x).compile()
@google_benchmark.register
@google_benchmark.option.unit(google_benchmark.kMillisecond)
def bench_repeated_static_indexing(state):
x = jnp.arange(500)
while state:
jax.block_until_ready([x[i] for i in range(500)])
@google_benchmark.register
@google_benchmark.option.unit(google_benchmark.kMillisecond)
def bench_repeated_static_slicing(state):
x = jnp.arange(1000)
while state:
jax.block_until_ready([x[i:i + 2] for i in range(0, 1000, 2)])
def pjit_simple_benchmark(state, num_devices, num_args, cpp_jit, use_aot=False):
spec = pjit_lib.PartitionSpec('x')

View File

@ -3776,10 +3776,9 @@ def _rewriting_take(arr, idx, indices_are_sorted=False, unique_indices=False,
if (arr.ndim > 0 and isinstance(idx, (int, np.integer)) and
not isinstance(idx, (bool, np.bool_)) and isinstance(arr.shape[0], int)):
if 0 <= idx < arr.shape[0]:
if _any(isinstance(d, core.Tracer) for d in arr.shape[1:]):
return lax.dynamic_index_in_dim(arr, idx, keepdims=False)
else:
return lax.index_in_dim(arr, idx, keepdims=False)
# Use dynamic rather than static index here to avoid slow repeated execution:
# See https://github.com/google/jax/issues/12198
return lax.dynamic_index_in_dim(arr, idx, keepdims=False)
if (arr.ndim > 0 and isinstance(arr.shape[0], int) and
isinstance(idx, slice) and
(type(idx.start) is int or idx.start is None) and
@ -3794,6 +3793,10 @@ def _rewriting_take(arr, idx, indices_are_sorted=False, unique_indices=False,
if _any(isinstance(d, core.Tracer) for d in arr.shape[1:]):
if step == 1: # TODO(mattjj, sharadmv): handle step != 1
return lax.dynamic_slice_in_dim(arr, start, _max(0, stop - start), 0)
elif step == 1:
# Use dynamic rather than static slice here to avoid slow repeated execution:
# See https://github.com/google/jax/issues/12198
return lax.dynamic_slice_in_dim(arr, start, _max(0, stop - start), 0)
else:
return lax.slice_in_dim(arr, start, stop, step)