index_take in terms of gather, delete index_untake

(cf. #304)
This commit is contained in:
Matthew Johnson 2019-02-02 09:22:37 -08:00
parent fec5f47596
commit 9f3060a0e6
3 changed files with 20 additions and 92 deletions

View File

@ -255,54 +255,19 @@ def scatter_add(operand, scatter_indices, updates, dimension_numbers=None):
update_consts=consts, dimension_numbers=dimension_numbers,
updates_shape=updates.shape)
def index_take(src, idxs, axes):
  """Bind the `index_take` primitive for `src` indexed by `idxs` along `axes`.

  `_index_take`, partially applied to `axes`, is traced to a jaxpr over
  abstract versions of `src` and of every array in `idxs`. The resulting
  jaxpr and its constants are passed to the primitive as parameters, together
  with `axes` (converted to a tuple) and the shape of `src`.

  `idxs` is concatenated onto a one-element tuple, so it must itself be a
  tuple of index arrays.
  """
  operands = (src,) + idxs
  abstract_operands = [_abstractify(operand) for operand in operands]
  traceable = partial(_index_take, axes)
  jaxpr, _, consts = pe.trace_unwrapped_to_jaxpr(traceable, abstract_operands)
  return index_take_p.bind(
      src, *idxs,
      axes=tuple(axes),
      input_shape=src.shape,
      jaxpr=jaxpr,
      consts=consts)
def _index_take(axes, src, *idxs):
  """Collect one slice of `src` per index tuple using an explicit loop.

  The number of slices is the leading dimension of the first index array.
  On iteration `row`, the `row`-th entry of each array in `idxs` gives the
  start position along the matching entry of `axes`; a slice of extent 1
  along those axes (full extent elsewhere) is read from `src` and written to
  row `row` of the output.

  The output starts as zeros shaped `(n,) + src.shape` with `axes` removed.
  """
  num_indices = idxs[0].shape[0]
  # Slice shape: like `src`, but extent 1 along every indexed axis.
  unit_sizes = subvals(src.shape, zip(axes, [1] * len(axes)))
  result_shape = (num_indices,) + tuple(onp.delete(src.shape, axes))

  def step(row, carry):
    operand, index_vecs, acc = carry
    # Lazily pick the `row`-th scalar from each index array.
    picked = (dynamic_index_in_dim(vec, row, 0, False) for vec in index_vecs)
    origin = subvals([0] * operand.ndim, zip(axes, picked))
    piece = dynamic_slice(operand, origin, unit_sizes)
    # Reshape the slice to a single output row before writing it.
    piece = reshape(piece, (1,) + acc.shape[1:])
    write_at = [row] + [0] * (acc.ndim - 1)
    acc = dynamic_update_slice(acc, piece, write_at)
    return operand, index_vecs, acc

  zeros = full_like(src, 0, shape=result_shape)
  _, _, gathered = fori_loop(0, num_indices, step, (src, idxs, zeros))
  return gathered
def index_untake(src, dst, idxs, axes):
  """Bind the `index_untake` primitive for `src`, `dst`, `idxs` and `axes`.

  `_index_untake`, partially applied to `axes`, is traced to a jaxpr over
  abstract versions of `src`, `dst` and every array in `idxs`. The jaxpr and
  its constants are passed to the primitive as parameters along with `axes`
  (converted to a tuple).

  `idxs` is concatenated onto a two-element tuple, so it must itself be a
  tuple of index arrays.
  """
  operands = (src, dst) + idxs
  abstract_operands = [_abstractify(operand) for operand in operands]
  traceable = partial(_index_untake, axes)
  jaxpr, _, consts = pe.trace_unwrapped_to_jaxpr(traceable, abstract_operands)
  return index_untake_p.bind(
      src, dst, *idxs,
      axes=tuple(axes),
      jaxpr=jaxpr,
      consts=consts)
def _index_untake(axes, src, dst, *idxs):
  """Add each row of `src` into an indexed slice of `dst` using a loop.

  The number of iterations is the leading dimension of the first index array.
  On iteration `row`, row `row` of `src` is reshaped to a slice of `dst` with
  extent 1 along `axes`, added to the slice of `dst` that starts at the
  `row`-th entries of `idxs` along those axes, and the sum is written back.
  Because each step reads the current contents of `dst`, rows that target the
  same position accumulate.

  Returns the updated `dst`.
  """
  num_updates = idxs[0].shape[0]
  # Slice shape: like `dst`, but extent 1 along every indexed axis.
  unit_sizes = subvals(dst.shape, zip(axes, [1] * len(axes)))

  def step(row, carry):
    updates, target, index_vecs = carry
    row_start = [row] + [0] * (updates.ndim - 1)
    row_vals = dynamic_slice(updates, row_start, (1,) + updates.shape[1:])
    row_vals = reshape(row_vals,
                       subvals(target.shape, zip(axes, [1] * len(axes))))
    # Lazily pick the `row`-th scalar from each index array.
    picked = (dynamic_index_in_dim(vec, row, 0, False) for vec in index_vecs)
    origin = subvals([0] * target.ndim, zip(axes, picked))
    summed = add(row_vals, dynamic_slice(target, origin, unit_sizes))
    target = dynamic_update_slice(target, summed, origin)
    return updates, target, index_vecs

  _, accumulated, _ = fori_loop(0, num_updates, step, (src, dst, idxs))
  return accumulated
# Gather-based body of index_take(src, idxs, axes).
# Turn each index array into a column of shape (i.shape[0], 1) and join the
# columns along dimension 1, giving one row of start indices per output slice.
indices = concatenate([reshape(i, [i.shape[0], 1]) for i in idxs], 1)
# Every gathered slice spans the full extent of `src`, except extent 1 along
# each indexed axis.
slice_sizes = list(src.shape)
for ax in axes:
slice_sizes[ax] = 1
slice_sizes = tuple(slice_sizes)
# Output dimensions 1 .. (src.ndim - number of index columns) are declared as
# the offset dimensions.
offset_dims = tuple(range(1, src.ndim - indices.shape[1] + 1))
dnums = GatherDimensionNumbers(
offset_dims=offset_dims,
# The indexed axes are passed both as the slice dimensions to collapse and as
# the operand dimensions that the index columns address.
collapsed_slice_dims=axes,
start_index_map=axes,
# Matches the layout built above: index columns lie along dimension 1.
index_vector_dim=1)
return gather(src, indices, dimension_numbers=dnums, slice_sizes=slice_sizes)
def transpose(operand, permutation):
permutation = tuple(permutation)

View File

@ -1665,6 +1665,8 @@ def take_along_axis(arr, indices, axis):
elif ndim(arr) == 1:
return lax.index_take(arr, (indices,), (0,))
else:
# TODO(mattjj): if we lower directly to lax.gather here, we might be able to
# avoid the reshape on the output.
all_indices = [lax.broadcasted_iota(_dtype(indices), shape(indices), i)
for i in range(ndim(arr))]
all_indices[axis] = indices
@ -1737,6 +1739,8 @@ def _rewriting_take(arr, idx, axis=0):
elif isinstance(idx, tuple) and _all(onp.ndim(elt) == 0 for elt in idx):
canonical_idx = _canonicalize_tuple_index(arr, idx)
result, axis = arr, 0
# TODO(mattjj): could generate a single HLO here, rather than one for each
# elt in canonical_idx. For example, x[0, :, 0] generates three HLOs now.
for elt in (elt for elt in canonical_idx if elt is not None):
result = _rewriting_take(result, elt, axis=axis)
axis += isinstance(elt, slice) # advance axis index if not eliminated
@ -1765,6 +1769,8 @@ def _rewriting_take(arr, idx, axis=0):
idx = [idx]
flat_idx = tuple([mod(ravel(x), arr.shape[i]) for i, x in enumerate(idx)])
# TODO(mattjj): if we instead lower directly to lax.gather, we can probably
# eliminate the reshape here.
out = lax.index_take(arr, flat_idx, tuple(range(len(idx))))
return lax.reshape(out, idx[0].shape + _shape(arr)[len(idx):])
@ -1782,6 +1788,8 @@ def _rewriting_take(arr, idx, axis=0):
flat_idx = tuple(mod(ravel(x), arr_sliced.shape[i])
for i, x in zip(axes, idx_advanced))
# TODO(mattjj): if we instead lower directly to lax.gather, we can probably
# eliminate the reshape here.
out = lax.index_take(arr_sliced, flat_idx, axes)
shape_suffix = tuple(onp.delete(_shape(arr_sliced), axes))
out = lax.reshape(out, idx_advanced[0].shape + shape_suffix)

View File

@ -1344,28 +1344,6 @@ class LaxTest(jtu.JaxTestCase):
fun = lambda src, idxs: lax.index_take(src, idxs, axes)
self._CompileAndCheck(fun, args_maker, check_dtypes=True)
# One test case per combination of dtype (from default_dtypes) and
# [dst_shape, idxs, axes] entry below, each with a single rng.
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_dst_shape={}_idxs={}_axes={}".format(
jtu.format_shape_dtype_string(dst_shape, dtype), idxs, axes),
"dst_shape": dst_shape, "dtype": dtype, "idxs": idxs, "axes": axes,
"rng": rng}
for dtype in default_dtypes
# Each entry is [dst_shape, idxs, axes], with one index array per axis.
for dst_shape, idxs, axes in [
[(3, 4, 5), (onp.array([0, 2, 1]),), (0,)],
[(3, 4, 5), (onp.array([-1, -2]),), (0,)],
[(3, 4, 5), (onp.array([0, 2]), onp.array([1, 3])), (0, 1)],
[(3, 4, 5), (onp.array([0, 2]), onp.array([1, 3])), (0, 2)],
]
for rng in [jtu.rand_default()]))
def testIndexUntake(self, dst_shape, dtype, idxs, axes, rng):
# We call lax.index_take to get the shapes right
src_shape = lax.index_take(rng(dst_shape, dtype), idxs, axes).shape
# The index arrays passed to the function under test are drawn from rng;
# only the shapes and dtypes of the parameterized `idxs` are reused.
ridxs = lambda: tuple(rng(e.shape, e.dtype) for e in idxs)
args_maker = lambda: [rng(src_shape, dtype), rng(dst_shape, dtype), ridxs()]
fun = lambda src, dst, idxs: lax.index_untake(src, dst, idxs, axes)
# NOTE(review): _CompileAndCheck is defined outside this view; presumably it
# compares compiled and uncompiled results of `fun` -- confirm.
self._CompileAndCheck(fun, args_maker, check_dtypes=True)
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_shape={}_idxs={}_dnums={}_slice_sizes={}".format(
jtu.format_shape_dtype_string(shape, dtype), idxs, dnums,
@ -2107,29 +2085,6 @@ class LaxAutodiffTest(jtu.JaxTestCase):
index_take = lambda src: lax.index_take(src, idxs, axes)
check_grads(index_take, (src,), 2, 1e-2, 1e-2, 1e-2)
# One test case per combination of dtype (from float_dtypes) and
# [dst_shape, idxs, axes] entry below, each with a single rng.
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_dst_shape={}_idxs={}_axes={}".format(
jtu.format_shape_dtype_string(dst_shape, dtype), idxs, axes),
"dst_shape": dst_shape, "dtype": dtype, "idxs": idxs, "axes": axes,
"rng": rng}
for dtype in float_dtypes
# Each entry is [dst_shape, idxs, axes], with one index array per axis.
for dst_shape, idxs, axes in [
[(3, 4, 5), (onp.array([0, 2, 1]),), (0,)],
[(3, 4, 5), (onp.array([-1, -2]),), (0,)],
[(3, 4, 5), (onp.array([0, 2]), onp.array([1, 3])), (0, 1)],
[(3, 4, 5), (onp.array([0, 2]), onp.array([1, 3])), (0, 2)],
]
for rng in [jtu.rand_default()]))
def testIndexUntakeGrad(self, dst_shape, dtype, idxs, axes, rng):
# We call lax.index_take to get the shapes right
src_shape = lax.index_take(rng(dst_shape, dtype), idxs, axes).shape
# Replace the parameterized index arrays with rng draws of the same shapes
# and dtypes; the indices are held fixed while src and dst are differentiated.
idxs = tuple(rng(e.shape, e.dtype) for e in idxs)
src = rng(src_shape, dtype)
dst = rng(dst_shape, dtype)
index_untake = lambda src, dst: lax.index_untake(src, dst, idxs, axes)
# NOTE(review): check_grads is defined outside this view; the meaning of the
# positional arguments (2, 1e-2, 1e-2, 1e-2) should be confirmed there.
check_grads(index_untake, (src, dst), 2, 1e-2, 1e-2, 1e-2)
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_shape={}_idxs={}_dnums={}_slice_sizes={}".format(
jtu.format_shape_dtype_string(shape, dtype), idxs, dnums,