simplify slicing jaxprs a little

Co-authored-by: Sharad Vikram <sharad.vikram@gmail.com>
This commit is contained in:
Matthew Johnson 2022-08-11 19:39:50 -07:00
parent 7cd81ca1a8
commit 42dd7cac43
3 changed files with 24 additions and 13 deletions

View File

@ -539,5 +539,13 @@ def bench_remat_eager_retracing_overheads_static_argnums(state):
y.block_until_ready()
@google_benchmark.register
@google_benchmark.option.unit(google_benchmark.kMillisecond)
def bench_slicing_compilation(state):
x = jnp.arange(3)
while state:
jax.jit(lambda x: (x[0], x[1], x[2])).lower(x).compile()
if __name__ == "__main__":
google_benchmark.main()

View File

@ -50,6 +50,8 @@ Shape = core.Shape
map, unsafe_map = safe_map, map
zip, unsafe_zip = safe_zip, zip
_dtype = partial(dtypes.dtype, canonicalize=True)
def slice(operand: Array, start_indices: Sequence[int],
limit_indices: Sequence[int],
@ -263,7 +265,7 @@ def gather(operand: Array, start_indices: Array,
parsed_mode = GatherScatterMode.from_any(mode)
if parsed_mode == GatherScatterMode.FILL_OR_DROP:
if fill_value is None:
dtype = lax._dtype(operand)
dtype = _dtype(operand)
if dtypes.issubdtype(dtype, np.inexact):
fill_value = np.nan
elif dtypes.issubdtype(dtype, np.signedinteger):
@ -2042,14 +2044,15 @@ def _dynamic_slice_indices(operand, start_indices: Any):
if start_indices.ndim != 1:
raise ValueError("Slice indices must be a 1D sequence, got {}"
.format(start_indices.shape))
start_indices = [i for i in start_indices]
return [np.asarray(i + d if i < 0 else i, lax._dtype(i))
if isinstance(i, (int, np.integer)) and core.is_constant_dim(d)
else lax.select(
lax.lt(i, lax._const(i, 0)),
lax.add(i, lax.convert_element_type(core.dimension_as_value(d), lax._dtype(i))),
i)
for i, d in zip(start_indices, operand.shape)]
start_indices = list(start_indices)
result = []
for i, d in zip(start_indices, operand.shape):
if isinstance(i, (int, np.integer)) and core.is_constant_dim(d):
result.append(lax.convert_element_type(i + d, _dtype(i)) if i < 0 else i)
else:
d = lax.convert_element_type(core.dimension_as_value(d), _dtype(i))
result.append(lax.select(i < 0, i + d, i))
return result
# TODO(mattjj): getslice is a prototype for dynamic shapes, revise or remove it

View File

@ -3485,10 +3485,10 @@ def _normalize_index(index, axis_size):
else:
axis_size_val = lax.convert_element_type(core.dimension_as_value(axis_size),
_dtype(index))
return lax.select(
lax.lt(index, _lax_const(index, 0)),
lax.add(index, axis_size_val),
index)
if isinstance(index, (int, np.integer)):
return lax.add(index, axis_size_val) if index < 0 else index
else:
return lax.select(index < 0, lax.add(index, axis_size_val), index)
TAKE_ALONG_AXIS_DOC = """