Mirror of https://github.com/ROCm/jax.git (synced 2025-04-19 05:16:06 +00:00)
finish gather batching rule, pair w/ @hawkinsp

commit cccc0304fd (parent 6dfe2d6e36)

jax/lax.py: 21 lines changed
@@ -2111,19 +2111,34 @@ def _gather_batching_rule(batched_args, batch_dims, dimension_numbers,
                   slice_sizes=slice_sizes), 0

   else:
-    # TODO(mattjj,phawkins): this code is wrong
+    # get rid of scalar index case (noticing our start_indices.ndim is
+    # incremented by one compared to the original user code)
+    if dimension_numbers.index_vector_dim == start_indices.ndim - 1:
+      start_indices = reshape(start_indices, start_indices.shape + (1,))
+
     # move our batch dimensions to the front to preserve sanity
     operand = batching.move_dim_to_front(operand, operand_bdim)
     start_indices = batching.move_dim_to_front(start_indices, start_indices_bdim)

+    # Example: user code had start_indices shape (3, 4, 5) and index_vector_dim
+    # of 2, and we have to deal with start_indices shape (7, 3, 4, 5). We
+    # transform that to an index_vector_dim of 3, and a start_indices of shape
+    # (7, 3, 4, 6) where we concatenated an iota that counts along our batch
+    # dimension to the front of the ndindex.
+    index_vector_dim = dimension_numbers.index_vector_dim + 1
+    counts = broadcasted_iota(start_indices.dtype, start_indices.shape, 0)
+    start_indices = concatenate([counts, start_indices], index_vector_dim)
+
+    slice_sizes = (1,) + slice_sizes
+    collapsed_slice_dims = (0,) + tuple(onp.add(1, dimension_numbers.collapsed_slice_dims))
     offset_dims = tuple(onp.add(1, dimension_numbers.offset_dims))
-    start_index_map = tuple(onp.add(1, dimension_numbers.start_index_map))
+    start_index_map = (0,) + tuple(onp.add(1, dimension_numbers.start_index_map))

     dnums = GatherDimensionNumbers(
         offset_dims=offset_dims,
         collapsed_slice_dims=collapsed_slice_dims,
         start_index_map=start_index_map,
-        index_vector_dim=dimension_numbers.index_vector_dim + 1)
+        index_vector_dim=index_vector_dim)
     return gather(operand, start_indices, dimension_numbers=dnums,
                   slice_sizes=slice_sizes), 0
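[Note] The heart of the new rule is the iota trick described in the comment: each index vector gets its batch position prepended, and start_index_map gains a leading 0 so that component selects along the batch axis, where slice_sizes=(1,)+... and collapsed_slice_dims=(0,)+... take a size-1 slice and collapse it away. A minimal plain-NumPy sketch of that idea, not part of the commit (shapes and values are made up for illustration):

import numpy as onp

# batch of two 5-element operands, and per-batch-element start indices:
# two scalar gathers per batch element
operand = onp.arange(2 * 5).reshape(2, 5)
start_indices = onp.array([[0, 2], [1, 3]])

# prepend an iota that counts along the batch dimension, so each index
# vector becomes (batch_index, original_index)
counts = onp.broadcast_to(onp.arange(2)[:, None], start_indices.shape)
augmented = onp.stack([counts, start_indices], axis=-1)   # shape (2, 2, 2)

# one un-vmapped "gather" with the augmented index vectors reproduces the
# per-batch-element gathers
ans = operand[augmented[..., 0], augmented[..., 1]]
expected = onp.stack([operand[b, start_indices[b]] for b in range(2)])
assert onp.array_equal(ans, expected)

The (0,) prepended to start_index_map plays the role of augmented[..., 0] here: it routes the new leading index component to the operand's batch dimension, so batch element b can only gather from its own slice.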
@@ -601,44 +601,43 @@ class BatchingTest(jtu.JaxTestCase):
                               for i in range(idxs.shape[axis])])
     self.assertAllClose(ans, expected, check_dtypes=False)

-  # TODO(mattjj,phawkins): finish this batching rule once and for all...
-  # @parameterized.named_parameters(
-  #     {"testcase_name": "_shape={}_op_axis={}_idxs_axis={}_idxs={}_dnums={}_slice_sizes={}".format(
-  #         jtu.format_shape_dtype_string(shape, dtype), op_axis, idxs_axis, idxs,
-  #         dnums, slice_sizes),
-  #      "op_axis": op_axis, "idxs_axis": idxs_axis, "shape": shape, "dtype":
-  #      dtype, "idxs": idxs, "dnums": dnums, "slice_sizes": slice_sizes,
-  #      "rng": rng, "rng_idx": rng_idx}
-  #     for dtype in [onp.float32, onp.int32]
-  #     for op_axis, idxs_axis, shape, idxs, dnums, slice_sizes in [
-  #         (0, 0, (2, 5), onp.array([[0, 2], [1, 3]]), lax.GatherDimensionNumbers(
-  #             offset_dims=(), collapsed_slice_dims=(0,), start_index_map=(0,),
-  #             index_vector_dim=1), (1,)),
-  #         (1, 1, (10, 2), onp.array([[0, 0, 0], [0, 2, 1]]).T,
-  #          lax.GatherDimensionNumbers(
-  #             offset_dims=(1,), collapsed_slice_dims=(), start_index_map=(0,),
-  #             index_vector_dim=1), (2,)),
-  #         (0, 1, (2, 10, 5,), onp.array([[0, 2, 1], [0, 3, 3]]).T,
-  #          lax.GatherDimensionNumbers(
-  #             offset_dims=(1,), collapsed_slice_dims=(0,), start_index_map=(0,),
-  #             index_vector_dim=1), (1, 3)),
-  #     ]
-  #     for rng_idx in [jtu.rand_int(max(shape))]
-  #     for rng in [jtu.rand_default()])
-  # @jtu.skip_on_devices("tpu")  # TODO(b/123834001): re-enable when fixed
-  # def testGatherBatchedBoth(self, op_axis, idxs_axis, shape, dtype, idxs, dnums,
-  #                           slice_sizes, rng, rng_idx):
-  #   fun = partial(lax.gather, dimension_numbers=dnums, slice_sizes=slice_sizes)
-  #   operand = rng(shape, dtype)
-  #   assert operand.shape[op_axis] == idxs.shape[idxs_axis]
-  #   ans = vmap(fun, (op_axis, idxs_axis))(operand, idxs)
-  #   expected = onp.stack([fun(operand[(slice(None),) * op_axis + (i,)],
-  #                             idxs[(slice(None),) * idxs_axis + (i,)])
-  #                         for i in range(idxs.shape[idxs_axis])])
-  #   self.assertAllClose(ans, expected, check_dtypes=False)
+  @parameterized.named_parameters(
+      {"testcase_name": "_shape={}_op_axis={}_idxs_axis={}_idxs={}_dnums={}_slice_sizes={}".format(
+          jtu.format_shape_dtype_string(shape, dtype), op_axis, idxs_axis, idxs,
+          dnums, slice_sizes),
+       "op_axis": op_axis, "idxs_axis": idxs_axis, "shape": shape, "dtype":
+       dtype, "idxs": idxs, "dnums": dnums, "slice_sizes": slice_sizes,
+       "rng": rng, "rng_idx": rng_idx}
+      for dtype in [onp.float32, onp.int32]
+      for op_axis, idxs_axis, shape, idxs, dnums, slice_sizes in [
+          (0, 0, (2, 5), onp.array([[0, 2], [1, 3]]), lax.GatherDimensionNumbers(
+              offset_dims=(), collapsed_slice_dims=(0,), start_index_map=(0,),
+              index_vector_dim=1), (1,)),
+          (1, 1, (10, 2), onp.array([[0, 0, 0], [0, 2, 1]]).T,
+           lax.GatherDimensionNumbers(
+              offset_dims=(1,), collapsed_slice_dims=(), start_index_map=(0,),
+              index_vector_dim=1), (2,)),
+          (0, 1, (2, 10, 5,), onp.array([[0, 2, 1], [0, 3, 3]]).T,
+           lax.GatherDimensionNumbers(
+              offset_dims=(1,), collapsed_slice_dims=(0,), start_index_map=(0,),
+              index_vector_dim=1), (1, 3)),
+      ]
+      for rng_idx in [jtu.rand_int(max(shape))]
+      for rng in [jtu.rand_default()])
+  @jtu.skip_on_devices("tpu")  # TODO(b/123834001): re-enable when fixed
+  def testGatherBatchedBoth(self, op_axis, idxs_axis, shape, dtype, idxs, dnums,
+                            slice_sizes, rng, rng_idx):
+    fun = partial(lax.gather, dimension_numbers=dnums, slice_sizes=slice_sizes)
+    operand = rng(shape, dtype)
+    assert operand.shape[op_axis] == idxs.shape[idxs_axis]
+    ans = vmap(fun, (op_axis, idxs_axis))(operand, idxs)
+    expected = onp.stack([fun(operand[(slice(None),) * op_axis + (i,)],
+                              idxs[(slice(None),) * idxs_axis + (i,)])
+                          for i in range(idxs.shape[idxs_axis])])
+    self.assertAllClose(ans, expected, check_dtypes=False)

   def testNumpyIndexing1(self):
-    a = np.arange(2*3*4).reshape((2, 3, 4))
+    a = np.arange(2 * 3 * 4).reshape((2, 3, 4))
     ind = onp.array([[0, 1],
                      [2, 0]])
     def f(a, ind):
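[Note] The re-enabled test asserts that vmapping the gather over both the operand and the indices agrees with an explicit per-example loop. Below is a hedged sketch of the same equivalence, ported to the current jax.lax.gather API, which no longer has an index_vector_dim field (the index vector is implicitly the last axis of start_indices); the shapes mirror the test's first parameter set, but the values are illustrative:

from functools import partial

import numpy as onp
from jax import lax, vmap

# unbatched gather: two scalar lookups into a length-5 vector
dnums = lax.GatherDimensionNumbers(
    offset_dims=(), collapsed_slice_dims=(0,), start_index_map=(0,))
fun = partial(lax.gather, dimension_numbers=dnums, slice_sizes=(1,))

operand = onp.arange(10.0).reshape(2, 5)    # batch axis 0
idxs = onp.array([[[0], [2]], [[1], [3]]])  # shape (2, 2, 1), batch axis 0

ans = vmap(fun, (0, 0))(operand, idxs)      # batched along both arguments
expected = onp.stack([fun(operand[i], idxs[i]) for i in range(2)])
assert onp.allclose(ans, expected)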
@@ -648,7 +647,7 @@ class BatchingTest(jtu.JaxTestCase):
     assert onp.all(ans == expected)

   def testNumpyIndexing2(self):
-    a = np.arange(2*3*4).reshape((2, 3, 4))
+    a = np.arange(2 * 3 * 4).reshape((2, 3, 4))
     def f(a):
       inds = np.array([0, 2])
       return a[:, inds]
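[Note] testNumpyIndexing2 reaches the same machinery through NumPy-style advanced indexing, which jax.numpy lowers to gather. The hunk truncates the test body, so the vmap call and the reference computation below are assumptions; a rough sketch of the pattern, following the tests' convention that np is jax.numpy and onp is plain NumPy:

import numpy as onp          # plain NumPy, the tests' onp
import jax.numpy as np       # the tests' np is jax.numpy
from jax import vmap

a = np.arange(2 * 3 * 4).reshape((2, 3, 4))

def f(a):
  inds = np.array([0, 2])
  return a[:, inds]          # advanced indexing lowers to lax.gather

ans = vmap(f)(a)             # batch over the leading axis, hitting the rule above
expected = onp.stack([f(a[i]) for i in range(a.shape[0])])
assert onp.all(onp.asarray(ans) == expected)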