Fix batching rule for gather where the batch dimension has size 0.

This commit is contained in:
Peter Hawkins 2020-12-14 22:27:34 -05:00
parent 34bc6ca987
commit 308e7f95b0
2 changed files with 9 additions and 2 deletions

View File

@ -4352,7 +4352,7 @@ def _gather_batching_rule(batched_args, batch_dims, *, dimension_numbers,
counts = broadcasted_iota(start_indices.dtype, tuple(count_shape), 0)
start_indices = concatenate([counts, start_indices], len(count_shape) - 1)
slice_sizes = (1,) + slice_sizes
slice_sizes = (_min(operand.shape[0], 1),) + slice_sizes
collapsed_slice_dims = (0,) + tuple(np.add(1, dimension_numbers.collapsed_slice_dims))
offset_dims = tuple(np.add(1, dimension_numbers.offset_dims))
start_index_map = (0,) + tuple(np.add(1, dimension_numbers.start_index_map))

View File

@ -74,7 +74,12 @@ class LaxVmapTest(jtu.JaxTestCase):
args = [rng(shape, dtype) for shape, dtype in zip(batched_shapes, dtypes)]
args_slice = args_slicer(args, bdims)
ans = api.vmap(op, bdims)(*args)
expected = np.stack([op(*args_slice(i)) for i in range(bdim_size)])
if bdim_size == 0:
args = [rng(shape, dtype) for shape, dtype in zip(shapes, dtypes)]
out = op(*args)
expected = np.zeros((0,) + out.shape, out.dtype)
else:
expected = np.stack([op(*args_slice(i)) for i in range(bdim_size)])
self.assertAllClose(ans, expected, rtol=rtol, atol=atol)
@parameterized.named_parameters(itertools.chain.from_iterable(
@ -642,6 +647,8 @@ class LaxVmapTest(jtu.JaxTestCase):
for bdims in all_bdims(shape, idxs.shape)))
def testGather(self, shape, dtype, idxs, dnums, slice_sizes, bdims):
fun = partial(lax.gather, dimension_numbers=dnums, slice_sizes=slice_sizes)
self._CheckBatching(fun, 0, bdims, [shape, idxs.shape], [dtype, idxs.dtype],
jtu.rand_default(self.rng()))
self._CheckBatching(fun, 5, bdims, [shape, idxs.shape], [dtype, idxs.dtype],
jtu.rand_default(self.rng()))