From 42dd7cac43fd5f971ed314c0b9d8817dfbc2273b Mon Sep 17 00:00:00 2001 From: Matthew Johnson Date: Thu, 11 Aug 2022 19:39:50 -0700 Subject: [PATCH] simplify slicing jaxprs a little Co-authored-by: Sharad Vikram --- benchmarks/api_benchmark.py | 8 ++++++++ jax/_src/lax/slicing.py | 21 ++++++++++++--------- jax/_src/numpy/lax_numpy.py | 8 ++++---- 3 files changed, 24 insertions(+), 13 deletions(-) diff --git a/benchmarks/api_benchmark.py b/benchmarks/api_benchmark.py index c7c67aeb5..673f2ec0d 100644 --- a/benchmarks/api_benchmark.py +++ b/benchmarks/api_benchmark.py @@ -539,5 +539,13 @@ def bench_remat_eager_retracing_overheads_static_argnums(state): y.block_until_ready() +@google_benchmark.register +@google_benchmark.option.unit(google_benchmark.kMillisecond) +def bench_slicing_compilation(state): + x = jnp.arange(3) + while state: + jax.jit(lambda x: (x[0], x[1], x[2])).lower(x).compile() + + if __name__ == "__main__": google_benchmark.main() diff --git a/jax/_src/lax/slicing.py b/jax/_src/lax/slicing.py index d6dcb516d..b01083f98 100644 --- a/jax/_src/lax/slicing.py +++ b/jax/_src/lax/slicing.py @@ -50,6 +50,8 @@ Shape = core.Shape map, unsafe_map = safe_map, map zip, unsafe_zip = safe_zip, zip +_dtype = partial(dtypes.dtype, canonicalize=True) + def slice(operand: Array, start_indices: Sequence[int], limit_indices: Sequence[int], @@ -263,7 +265,7 @@ def gather(operand: Array, start_indices: Array, parsed_mode = GatherScatterMode.from_any(mode) if parsed_mode == GatherScatterMode.FILL_OR_DROP: if fill_value is None: - dtype = lax._dtype(operand) + dtype = _dtype(operand) if dtypes.issubdtype(dtype, np.inexact): fill_value = np.nan elif dtypes.issubdtype(dtype, np.signedinteger): @@ -2042,14 +2044,15 @@ def _dynamic_slice_indices(operand, start_indices: Any): if start_indices.ndim != 1: raise ValueError("Slice indices must be a 1D sequence, got {}" .format(start_indices.shape)) - start_indices = [i for i in start_indices] - return [np.asarray(i + d if i < 0 else i, lax._dtype(i)) - if isinstance(i, (int, np.integer)) and core.is_constant_dim(d) - else lax.select( - lax.lt(i, lax._const(i, 0)), - lax.add(i, lax.convert_element_type(core.dimension_as_value(d), lax._dtype(i))), - i) - for i, d in zip(start_indices, operand.shape)] + start_indices = list(start_indices) + result = [] + for i, d in zip(start_indices, operand.shape): + if isinstance(i, (int, np.integer)) and core.is_constant_dim(d): + result.append(lax.convert_element_type(i + d, _dtype(i)) if i < 0 else i) + else: + d = lax.convert_element_type(core.dimension_as_value(d), _dtype(i)) + result.append(lax.select(i < 0, i + d, i)) + return result # TODO(mattjj): getslice is a prototype for dynamic shapes, revise or remove it diff --git a/jax/_src/numpy/lax_numpy.py b/jax/_src/numpy/lax_numpy.py index ca0a7ffb1..248731773 100644 --- a/jax/_src/numpy/lax_numpy.py +++ b/jax/_src/numpy/lax_numpy.py @@ -3485,10 +3485,10 @@ def _normalize_index(index, axis_size): else: axis_size_val = lax.convert_element_type(core.dimension_as_value(axis_size), _dtype(index)) - return lax.select( - lax.lt(index, _lax_const(index, 0)), - lax.add(index, axis_size_val), - index) + if isinstance(index, (int, np.integer)): + return lax.add(index, axis_size_val) if index < 0 else index + else: + return lax.select(index < 0, lax.add(index, axis_size_val), index) TAKE_ALONG_AXIS_DOC = """