mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
parent
fec5f47596
commit
9f3060a0e6
59
jax/lax.py
59
jax/lax.py
@ -255,54 +255,19 @@ def scatter_add(operand, scatter_indices, updates, dimension_numbers=None):
|
||||
update_consts=consts, dimension_numbers=dimension_numbers,
|
||||
updates_shape=updates.shape)
|
||||
|
||||
|
||||
def index_take(src, idxs, axes):
|
||||
pvals = [_abstractify(arg) for arg in (src,) + idxs]
|
||||
jaxpr, _, consts = pe.trace_unwrapped_to_jaxpr(partial(_index_take, axes), pvals)
|
||||
return index_take_p.bind(src, *idxs, axes=tuple(axes),
|
||||
input_shape=src.shape, jaxpr=jaxpr, consts=consts)
|
||||
|
||||
def _index_take(axes, src, *idxs):
|
||||
n = idxs[0].shape[0]
|
||||
slice_sizes = subvals(src.shape, zip(axes, [1] * len(axes)))
|
||||
|
||||
def body_fun(i, state):
|
||||
src, idxs, out = state
|
||||
src_ind = (dynamic_index_in_dim(x, i, 0, False) for x in idxs)
|
||||
start_indices = subvals([0] * src.ndim, zip(axes, src_ind))
|
||||
update = dynamic_slice(src, start_indices, slice_sizes)
|
||||
update = reshape(update, (1,) + out.shape[1:])
|
||||
out = dynamic_update_slice(out, update, [i] + [0] * (out.ndim - 1))
|
||||
return src, idxs, out
|
||||
|
||||
out = full_like(src, 0, shape=(n,) + tuple(onp.delete(src.shape, axes)))
|
||||
init_val = src, idxs, out
|
||||
_, _, out = fori_loop(0, n, body_fun, init_val)
|
||||
return out
|
||||
|
||||
def index_untake(src, dst, idxs, axes):
|
||||
pvals = [_abstractify(arg) for arg in (src, dst) + idxs]
|
||||
jaxpr, _, consts = pe.trace_unwrapped_to_jaxpr(partial(_index_untake, axes), pvals)
|
||||
return index_untake_p.bind(src, dst, *idxs, axes=tuple(axes),
|
||||
jaxpr=jaxpr, consts=consts)
|
||||
|
||||
def _index_untake(axes, src, dst, *idxs):
|
||||
n = idxs[0].shape[0]
|
||||
slice_sizes = subvals(dst.shape, zip(axes, [1] * len(axes)))
|
||||
|
||||
def body_fun(i, state):
|
||||
src, dst, idxs = state
|
||||
vals = dynamic_slice(src, [i] + [0] * (src.ndim - 1), (1,) + src.shape[1:])
|
||||
vals = reshape(vals, subvals(dst.shape, zip(axes, [1] * len(axes))))
|
||||
dst_ind = (dynamic_index_in_dim(x, i, 0, False) for x in idxs)
|
||||
start_indices = subvals([0] * dst.ndim, zip(axes, dst_ind))
|
||||
update = add(vals, dynamic_slice(dst, start_indices, slice_sizes))
|
||||
dst = dynamic_update_slice(dst, update, start_indices)
|
||||
return src, dst, idxs
|
||||
|
||||
init_val = src, dst, idxs
|
||||
_, dst, _ = fori_loop(0, n, body_fun, init_val)
|
||||
return dst
|
||||
indices = concatenate([reshape(i, [i.shape[0], 1]) for i in idxs], 1)
|
||||
slice_sizes = list(src.shape)
|
||||
for ax in axes:
|
||||
slice_sizes[ax] = 1
|
||||
slice_sizes = tuple(slice_sizes)
|
||||
offset_dims = tuple(range(1, src.ndim - indices.shape[1] + 1))
|
||||
dnums = GatherDimensionNumbers(
|
||||
offset_dims=offset_dims,
|
||||
collapsed_slice_dims=axes,
|
||||
start_index_map=axes,
|
||||
index_vector_dim=1)
|
||||
return gather(src, indices, dimension_numbers=dnums, slice_sizes=slice_sizes)
|
||||
|
||||
def transpose(operand, permutation):
|
||||
permutation = tuple(permutation)
|
||||
|
@ -1665,6 +1665,8 @@ def take_along_axis(arr, indices, axis):
|
||||
elif ndim(arr) == 1:
|
||||
return lax.index_take(arr, (indices,), (0,))
|
||||
else:
|
||||
# TODO(mattjj): if we lower directly to lax.gather here, we might be able to
|
||||
# avoid the reshape on the output.
|
||||
all_indices = [lax.broadcasted_iota(_dtype(indices), shape(indices), i)
|
||||
for i in range(ndim(arr))]
|
||||
all_indices[axis] = indices
|
||||
@ -1737,6 +1739,8 @@ def _rewriting_take(arr, idx, axis=0):
|
||||
elif isinstance(idx, tuple) and _all(onp.ndim(elt) == 0 for elt in idx):
|
||||
canonical_idx = _canonicalize_tuple_index(arr, idx)
|
||||
result, axis = arr, 0
|
||||
# TODO(mattjj): could generate a single HLO here, rather than one for each
|
||||
# elt in canonical idx. For example, x[0, :, 0] generates three HLOs now.
|
||||
for elt in (elt for elt in canonical_idx if elt is not None):
|
||||
result = _rewriting_take(result, elt, axis=axis)
|
||||
axis += isinstance(elt, slice) # advance axis index if not eliminated
|
||||
@ -1765,6 +1769,8 @@ def _rewriting_take(arr, idx, axis=0):
|
||||
idx = [idx]
|
||||
|
||||
flat_idx = tuple([mod(ravel(x), arr.shape[i]) for i, x in enumerate(idx)])
|
||||
# TODO(mattjj): if we instead lower directly to lax.gather, we can probably
|
||||
# eliminate the reshape here.
|
||||
out = lax.index_take(arr, flat_idx, tuple(range(len(idx))))
|
||||
return lax.reshape(out, idx[0].shape + _shape(arr)[len(idx):])
|
||||
|
||||
@ -1782,6 +1788,8 @@ def _rewriting_take(arr, idx, axis=0):
|
||||
|
||||
flat_idx = tuple(mod(ravel(x), arr_sliced.shape[i])
|
||||
for i, x in zip(axes, idx_advanced))
|
||||
# TODO(mattjj): if we instead lower directly to lax.gather, we can probably
|
||||
# eliminate the reshape here.
|
||||
out = lax.index_take(arr_sliced, flat_idx, axes)
|
||||
shape_suffix = tuple(onp.delete(_shape(arr_sliced), axes))
|
||||
out = lax.reshape(out, idx_advanced[0].shape + shape_suffix)
|
||||
|
@ -1344,28 +1344,6 @@ class LaxTest(jtu.JaxTestCase):
|
||||
fun = lambda src, idxs: lax.index_take(src, idxs, axes)
|
||||
self._CompileAndCheck(fun, args_maker, check_dtypes=True)
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
{"testcase_name": "_dst_shape={}_idxs={}_axes={}".format(
|
||||
jtu.format_shape_dtype_string(dst_shape, dtype), idxs, axes),
|
||||
"dst_shape": dst_shape, "dtype": dtype, "idxs": idxs, "axes": axes,
|
||||
"rng": rng}
|
||||
for dtype in default_dtypes
|
||||
for dst_shape, idxs, axes in [
|
||||
[(3, 4, 5), (onp.array([0, 2, 1]),), (0,)],
|
||||
[(3, 4, 5), (onp.array([-1, -2]),), (0,)],
|
||||
[(3, 4, 5), (onp.array([0, 2]), onp.array([1, 3])), (0, 1)],
|
||||
[(3, 4, 5), (onp.array([0, 2]), onp.array([1, 3])), (0, 2)],
|
||||
]
|
||||
for rng in [jtu.rand_default()]))
|
||||
def testIndexUntake(self, dst_shape, dtype, idxs, axes, rng):
|
||||
# We call lax.index_take to get the shapes right
|
||||
src_shape = lax.index_take(rng(dst_shape, dtype), idxs, axes).shape
|
||||
ridxs = lambda: tuple(rng(e.shape, e.dtype) for e in idxs)
|
||||
args_maker = lambda: [rng(src_shape, dtype), rng(dst_shape, dtype), ridxs()]
|
||||
fun = lambda src, dst, idxs: lax.index_untake(src, dst, idxs, axes)
|
||||
self._CompileAndCheck(fun, args_maker, check_dtypes=True)
|
||||
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
{"testcase_name": "_shape={}_idxs={}_dnums={}_slice_sizes={}".format(
|
||||
jtu.format_shape_dtype_string(shape, dtype), idxs, dnums,
|
||||
@ -2107,29 +2085,6 @@ class LaxAutodiffTest(jtu.JaxTestCase):
|
||||
index_take = lambda src: lax.index_take(src, idxs, axes)
|
||||
check_grads(index_take, (src,), 2, 1e-2, 1e-2, 1e-2)
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
{"testcase_name": "_dst_shape={}_idxs={}_axes={}".format(
|
||||
jtu.format_shape_dtype_string(dst_shape, dtype), idxs, axes),
|
||||
"dst_shape": dst_shape, "dtype": dtype, "idxs": idxs, "axes": axes,
|
||||
"rng": rng}
|
||||
for dtype in float_dtypes
|
||||
for dst_shape, idxs, axes in [
|
||||
[(3, 4, 5), (onp.array([0, 2, 1]),), (0,)],
|
||||
[(3, 4, 5), (onp.array([-1, -2]),), (0,)],
|
||||
[(3, 4, 5), (onp.array([0, 2]), onp.array([1, 3])), (0, 1)],
|
||||
[(3, 4, 5), (onp.array([0, 2]), onp.array([1, 3])), (0, 2)],
|
||||
]
|
||||
for rng in [jtu.rand_default()]))
|
||||
def testIndexUntakeGrad(self, dst_shape, dtype, idxs, axes, rng):
|
||||
# We call lax.index_take to get the shapes right
|
||||
src_shape = lax.index_take(rng(dst_shape, dtype), idxs, axes).shape
|
||||
|
||||
idxs = tuple(rng(e.shape, e.dtype) for e in idxs)
|
||||
src = rng(src_shape, dtype)
|
||||
dst = rng(dst_shape, dtype)
|
||||
index_untake = lambda src, dst: lax.index_untake(src, dst, idxs, axes)
|
||||
check_grads(index_untake, (src, dst), 2, 1e-2, 1e-2, 1e-2)
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
{"testcase_name": "_shape={}_idxs={}_dnums={}_slice_sizes={}".format(
|
||||
jtu.format_shape_dtype_string(shape, dtype), idxs, dnums,
|
||||
|
Loading…
x
Reference in New Issue
Block a user