mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
Merge pull request #12219 from jakevdp:indexing-slice
PiperOrigin-RevId: 485946084
This commit is contained in:
commit
f4be5ab173
@ -676,6 +676,19 @@ def bench_slicing_compilation2(state):
|
||||
while state:
|
||||
jax.jit(lambda x: (x[:1], x[1:2], x[2:3])).lower(x).compile()
|
||||
|
||||
@google_benchmark.register
|
||||
@google_benchmark.option.unit(google_benchmark.kMillisecond)
|
||||
def bench_repeated_static_indexing(state):
|
||||
x = jnp.arange(500)
|
||||
while state:
|
||||
jax.block_until_ready([x[i] for i in range(500)])
|
||||
|
||||
@google_benchmark.register
|
||||
@google_benchmark.option.unit(google_benchmark.kMillisecond)
|
||||
def bench_repeated_static_slicing(state):
|
||||
x = jnp.arange(1000)
|
||||
while state:
|
||||
jax.block_until_ready([x[i:i + 2] for i in range(0, 1000, 2)])
|
||||
|
||||
def pjit_simple_benchmark(state, num_devices, num_args, cpp_jit, use_aot=False):
|
||||
spec = pjit_lib.PartitionSpec('x')
|
||||
|
@ -3776,10 +3776,9 @@ def _rewriting_take(arr, idx, indices_are_sorted=False, unique_indices=False,
|
||||
if (arr.ndim > 0 and isinstance(idx, (int, np.integer)) and
|
||||
not isinstance(idx, (bool, np.bool_)) and isinstance(arr.shape[0], int)):
|
||||
if 0 <= idx < arr.shape[0]:
|
||||
if _any(isinstance(d, core.Tracer) for d in arr.shape[1:]):
|
||||
return lax.dynamic_index_in_dim(arr, idx, keepdims=False)
|
||||
else:
|
||||
return lax.index_in_dim(arr, idx, keepdims=False)
|
||||
# Use dynamic rather than static index here to avoid slow repeated execution:
|
||||
# See https://github.com/google/jax/issues/12198
|
||||
return lax.dynamic_index_in_dim(arr, idx, keepdims=False)
|
||||
if (arr.ndim > 0 and isinstance(arr.shape[0], int) and
|
||||
isinstance(idx, slice) and
|
||||
(type(idx.start) is int or idx.start is None) and
|
||||
@ -3794,6 +3793,10 @@ def _rewriting_take(arr, idx, indices_are_sorted=False, unique_indices=False,
|
||||
if _any(isinstance(d, core.Tracer) for d in arr.shape[1:]):
|
||||
if step == 1: # TODO(mattjj, sharadmv): handle step != 1
|
||||
return lax.dynamic_slice_in_dim(arr, start, _max(0, stop - start), 0)
|
||||
elif step == 1:
|
||||
# Use dynamic rather than static slice here to avoid slow repeated execution:
|
||||
# See https://github.com/google/jax/issues/12198
|
||||
return lax.dynamic_slice_in_dim(arr, start, _max(0, stop - start), 0)
|
||||
else:
|
||||
return lax.slice_in_dim(arr, start, stop, step)
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user