Finish the gather batching rule (pair-programmed with @hawkinsp)

This commit is contained in:
Matthew Johnson 2019-02-11 09:28:21 -08:00
parent 6dfe2d6e36
commit cccc0304fd
2 changed files with 54 additions and 40 deletions

View File

@@ -2111,19 +2111,34 @@ def _gather_batching_rule(batched_args, batch_dims, dimension_numbers,
slice_sizes=slice_sizes), 0
else:
# TODO(mattjj,phawkins): this code is wrong
# get rid of scalar index case (noticing our start_indices.ndim is
# incremented by one compared to the original user code)
if dimension_numbers.index_vector_dim == start_indices.ndim - 1:
start_indices = reshape(start_indices, start_indices.shape + (1,))
# move our batch dimensions to the front to preserve sanity
operand = batching.move_dim_to_front(operand, operand_bdim)
start_indices = batching.move_dim_to_front(start_indices, start_indices_bdim)
# Example: user code had start_indices shape (3, 4, 5) and index_vector_dim
# of 2, and we have to deal with start_indices shape (7, 3, 4, 5). We
# transform that to an index_vector_dim of 3, and a start_indices of shape
# (7, 3, 4, 6) where we concatenated an iota that counts along our batch
# dimension to the front of the ndindex.
index_vector_dim = dimension_numbers.index_vector_dim + 1
counts = broadcasted_iota(start_indices.dtype, start_indices.shape, 0)
start_indices = concatenate([counts, start_indices], index_vector_dim)
slice_sizes = (1,) + slice_sizes
collapsed_slice_dims = (0,) + tuple(onp.add(1, dimension_numbers.collapsed_slice_dims))
offset_dims = tuple(onp.add(1, dimension_numbers.offset_dims))
start_index_map = tuple(onp.add(1, dimension_numbers.start_index_map))
start_index_map = (0,) + tuple(onp.add(1, dimension_numbers.start_index_map))
dnums = GatherDimensionNumbers(
offset_dims=offset_dims,
collapsed_slice_dims=collapsed_slice_dims,
start_index_map=start_index_map,
index_vector_dim=dimension_numbers.index_vector_dim + 1)
index_vector_dim=index_vector_dim)
return gather(operand, start_indices, dimension_numbers=dnums,
slice_sizes=slice_sizes), 0

View File

@@ -601,44 +601,43 @@ class BatchingTest(jtu.JaxTestCase):
for i in range(idxs.shape[axis])])
self.assertAllClose(ans, expected, check_dtypes=False)
# TODO(mattjj,phawkins): finish this batching rule once and for all...
# @parameterized.named_parameters(
# {"testcase_name": "_shape={}_op_axis={}_idxs_axis={}_idxs={}_dnums={}_slice_sizes={}".format(
# jtu.format_shape_dtype_string(shape, dtype), op_axis, idxs_axis, idxs,
# dnums, slice_sizes),
# "op_axis": op_axis, "idxs_axis": idxs_axis, "shape": shape, "dtype":
# dtype, "idxs": idxs, "dnums": dnums, "slice_sizes": slice_sizes,
# "rng": rng, "rng_idx": rng_idx}
# for dtype in [onp.float32, onp.int32]
# for op_axis, idxs_axis, shape, idxs, dnums, slice_sizes in [
# (0, 0, (2, 5), onp.array([[0, 2], [1, 3]]), lax.GatherDimensionNumbers(
# offset_dims=(), collapsed_slice_dims=(0,), start_index_map=(0,),
# index_vector_dim=1), (1,)),
# (1, 1, (10, 2), onp.array([[0, 0, 0], [0, 2, 1]]).T,
# lax.GatherDimensionNumbers(
# offset_dims=(1,), collapsed_slice_dims=(), start_index_map=(0,),
# index_vector_dim=1), (2,)),
# (0, 1, (2, 10, 5,), onp.array([[0, 2, 1], [0, 3, 3]]).T,
# lax.GatherDimensionNumbers(
# offset_dims=(1,), collapsed_slice_dims=(0,), start_index_map=(0,),
# index_vector_dim=1), (1, 3)),
# ]
# for rng_idx in [jtu.rand_int(max(shape))]
# for rng in [jtu.rand_default()])
# @jtu.skip_on_devices("tpu") # TODO(b/123834001): re-enable when fixed
# def testGatherBatchedBoth(self, op_axis, idxs_axis, shape, dtype, idxs, dnums,
# slice_sizes, rng, rng_idx):
# fun = partial(lax.gather, dimension_numbers=dnums, slice_sizes=slice_sizes)
# operand = rng(shape, dtype)
# assert operand.shape[op_axis] == idxs.shape[idxs_axis]
# ans = vmap(fun, (op_axis, idxs_axis))(operand, idxs)
# expected = onp.stack([fun(operand[(slice(None),) * op_axis + (i,)],
# idxs[(slice(None),) * idxs_axis + (i,)])
# for i in range(idxs.shape[idxs_axis])])
# self.assertAllClose(ans, expected, check_dtypes=False)
@parameterized.named_parameters(
{"testcase_name": "_shape={}_op_axis={}_idxs_axis={}_idxs={}_dnums={}_slice_sizes={}".format(
jtu.format_shape_dtype_string(shape, dtype), op_axis, idxs_axis, idxs,
dnums, slice_sizes),
"op_axis": op_axis, "idxs_axis": idxs_axis, "shape": shape, "dtype":
dtype, "idxs": idxs, "dnums": dnums, "slice_sizes": slice_sizes,
"rng": rng, "rng_idx": rng_idx}
for dtype in [onp.float32, onp.int32]
for op_axis, idxs_axis, shape, idxs, dnums, slice_sizes in [
(0, 0, (2, 5), onp.array([[0, 2], [1, 3]]), lax.GatherDimensionNumbers(
offset_dims=(), collapsed_slice_dims=(0,), start_index_map=(0,),
index_vector_dim=1), (1,)),
(1, 1, (10, 2), onp.array([[0, 0, 0], [0, 2, 1]]).T,
lax.GatherDimensionNumbers(
offset_dims=(1,), collapsed_slice_dims=(), start_index_map=(0,),
index_vector_dim=1), (2,)),
(0, 1, (2, 10, 5,), onp.array([[0, 2, 1], [0, 3, 3]]).T,
lax.GatherDimensionNumbers(
offset_dims=(1,), collapsed_slice_dims=(0,), start_index_map=(0,),
index_vector_dim=1), (1, 3)),
]
for rng_idx in [jtu.rand_int(max(shape))]
for rng in [jtu.rand_default()])
@jtu.skip_on_devices("tpu") # TODO(b/123834001): re-enable when fixed
def testGatherBatchedBoth(self, op_axis, idxs_axis, shape, dtype, idxs, dnums,
slice_sizes, rng, rng_idx):
fun = partial(lax.gather, dimension_numbers=dnums, slice_sizes=slice_sizes)
operand = rng(shape, dtype)
assert operand.shape[op_axis] == idxs.shape[idxs_axis]
ans = vmap(fun, (op_axis, idxs_axis))(operand, idxs)
expected = onp.stack([fun(operand[(slice(None),) * op_axis + (i,)],
idxs[(slice(None),) * idxs_axis + (i,)])
for i in range(idxs.shape[idxs_axis])])
self.assertAllClose(ans, expected, check_dtypes=False)
def testNumpyIndexing1(self):
a = np.arange(2*3*4).reshape((2, 3, 4))
a = np.arange(2 * 3 * 4).reshape((2, 3, 4))
ind = onp.array([[0, 1],
[2, 0]])
def f(a, ind):
@@ -648,7 +647,7 @@ class BatchingTest(jtu.JaxTestCase):
assert onp.all(ans == expected)
def testNumpyIndexing2(self):
a = np.arange(2*3*4).reshape((2, 3, 4))
a = np.arange(2 * 3 * 4).reshape((2, 3, 4))
def f(a):
inds = np.array([0, 2])
return a[:, inds]