mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
simplify slicing jaxprs a little
Co-authored-by: Sharad Vikram <sharad.vikram@gmail.com>
This commit is contained in:
parent
7cd81ca1a8
commit
42dd7cac43
@ -539,5 +539,13 @@ def bench_remat_eager_retracing_overheads_static_argnums(state):
|
||||
y.block_until_ready()
|
||||
|
||||
|
||||
@google_benchmark.register
|
||||
@google_benchmark.option.unit(google_benchmark.kMillisecond)
|
||||
def bench_slicing_compilation(state):
|
||||
x = jnp.arange(3)
|
||||
while state:
|
||||
jax.jit(lambda x: (x[0], x[1], x[2])).lower(x).compile()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
google_benchmark.main()
|
||||
|
@ -50,6 +50,8 @@ Shape = core.Shape
|
||||
map, unsafe_map = safe_map, map
|
||||
zip, unsafe_zip = safe_zip, zip
|
||||
|
||||
_dtype = partial(dtypes.dtype, canonicalize=True)
|
||||
|
||||
|
||||
def slice(operand: Array, start_indices: Sequence[int],
|
||||
limit_indices: Sequence[int],
|
||||
@ -263,7 +265,7 @@ def gather(operand: Array, start_indices: Array,
|
||||
parsed_mode = GatherScatterMode.from_any(mode)
|
||||
if parsed_mode == GatherScatterMode.FILL_OR_DROP:
|
||||
if fill_value is None:
|
||||
dtype = lax._dtype(operand)
|
||||
dtype = _dtype(operand)
|
||||
if dtypes.issubdtype(dtype, np.inexact):
|
||||
fill_value = np.nan
|
||||
elif dtypes.issubdtype(dtype, np.signedinteger):
|
||||
@ -2042,14 +2044,15 @@ def _dynamic_slice_indices(operand, start_indices: Any):
|
||||
if start_indices.ndim != 1:
|
||||
raise ValueError("Slice indices must be a 1D sequence, got {}"
|
||||
.format(start_indices.shape))
|
||||
start_indices = [i for i in start_indices]
|
||||
return [np.asarray(i + d if i < 0 else i, lax._dtype(i))
|
||||
if isinstance(i, (int, np.integer)) and core.is_constant_dim(d)
|
||||
else lax.select(
|
||||
lax.lt(i, lax._const(i, 0)),
|
||||
lax.add(i, lax.convert_element_type(core.dimension_as_value(d), lax._dtype(i))),
|
||||
i)
|
||||
for i, d in zip(start_indices, operand.shape)]
|
||||
start_indices = list(start_indices)
|
||||
result = []
|
||||
for i, d in zip(start_indices, operand.shape):
|
||||
if isinstance(i, (int, np.integer)) and core.is_constant_dim(d):
|
||||
result.append(lax.convert_element_type(i + d, _dtype(i)) if i < 0 else i)
|
||||
else:
|
||||
d = lax.convert_element_type(core.dimension_as_value(d), _dtype(i))
|
||||
result.append(lax.select(i < 0, i + d, i))
|
||||
return result
|
||||
|
||||
|
||||
# TODO(mattjj): getslice is a prototype for dynamic shapes, revise or remove it
|
||||
|
@ -3485,10 +3485,10 @@ def _normalize_index(index, axis_size):
|
||||
else:
|
||||
axis_size_val = lax.convert_element_type(core.dimension_as_value(axis_size),
|
||||
_dtype(index))
|
||||
return lax.select(
|
||||
lax.lt(index, _lax_const(index, 0)),
|
||||
lax.add(index, axis_size_val),
|
||||
index)
|
||||
if isinstance(index, (int, np.integer)):
|
||||
return lax.add(index, axis_size_val) if index < 0 else index
|
||||
else:
|
||||
return lax.select(index < 0, lax.add(index, axis_size_val), index)
|
||||
|
||||
|
||||
TAKE_ALONG_AXIS_DOC = """
|
||||
|
Loading…
x
Reference in New Issue
Block a user