Implement lax.reduce() batching rule case where batch dimensions differ between operands.

Peter Hawkins 2021-11-03 09:36:31 -04:00
parent 31c2da70e6
commit ba4db33f48
2 changed files with 39 additions and 15 deletions
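
For context, here is a minimal sketch (not part of the commit; shapes, names, and axes are chosen purely for illustration) of the kind of program this change enables: vmapping a variadic lax.reduce whose operands are batched along different axes, which previously hit the NotImplementedError removed below.

import jax
import jax.numpy as jnp
import numpy as np
from jax import lax

def op(a, b):
  # Variadic reduction: sum the first operands, multiply the second.
  x1, y1 = a
  x2, y2 = b
  return x1 + x2, y1 * y2

x = jnp.ones((7, 3, 4))   # batched along axis 0
y = jnp.ones((3, 7, 4))   # batched along axis 1, i.e. a different batch dim
init_vals = (np.float32(0), np.float32(1))
f = lambda x, y: lax.reduce((x, y), init_vals, op, (1,))
sums, prods = jax.vmap(f, in_axes=(0, 1))(x, y)  # each output has shape (7, 3)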


@@ -5596,19 +5596,15 @@ def _reduce_batch_rule(batched_args, batch_dims, *, computation, jaxpr,
   operand_bdims, init_value_bdims = split_list(batch_dims, [num_operands])
   if all(init_value_bdim is batching.not_mapped
          for init_value_bdim in init_value_bdims):
-    # Assume all batch dims are the same for each of the operands
-    # TODO(sharadmv): handle the case when batch dims are different across
-    # operands or when some are unbatched
-    if not all(operand_bdim is not batching.not_mapped for operand_bdim in operand_bdims):
-      raise NotImplementedError
-    if not all(operand_bdim == operand_bdims[0] for operand_bdim in operand_bdims):
-      raise NotImplementedError
-    operand_bdim = operand_bdims[0]
-    new_dimensions = [d + bool(d >= operand_bdim) for d in dimensions]
-    new_operand_bdim = operand_bdim - int(np.sum(np.less(dimensions, operand_bdim)))
-    new_operand_bdims = [new_operand_bdim] * num_operands
+    size = next(x.shape[ax] for x, ax in zip(batched_args, batch_dims)
+                if ax is not None)
+    operands = [batching.bdim_at_front(arg, bdim, size)
+                for arg, bdim in zip(operands, operand_bdims)]
+    new_dimensions = [d + 1 for d in dimensions]
+    new_operand_bdims = [0] * num_operands
     return reduce_p.bind(*(operands + init_values),
-                         computation=computation, dimensions=tuple(new_dimensions),
+                         computation=computation,
+                         dimensions=tuple(new_dimensions),
                          consts=consts,
                          jaxpr=jaxpr), new_operand_bdims
   else:
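
The strategy of the new rule: move every operand's batch axis to the front (broadcasting any unbatched operand to the common batch size), after which each reduced dimension of the unbatched call simply shifts by one and every output carries its batch dimension at axis 0. The following is a rough, illustrative stand-in for the batching.bdim_at_front helper used above, assuming None marks an unbatched operand (the real code uses batching.not_mapped):

import jax.numpy as jnp

def bdim_at_front_sketch(x, bdim, size):
  if bdim is None:
    # Unbatched operand: broadcast to the common batch size at axis 0.
    return jnp.broadcast_to(x, (size,) + x.shape)
  # Batched operand: move its batch axis to position 0.
  return jnp.moveaxis(x, bdim, 0)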


@@ -69,7 +69,7 @@ def args_slicer(args, bdims):
 class LaxVmapTest(jtu.JaxTestCase):
 
   def _CheckBatching(self, op, bdim_size, bdims, shapes, dtypes, rng,
-                     rtol=None, atol=None):
+                     rtol=None, atol=None, multiple_results=False):
     batched_shapes = map(partial(add_bdim, bdim_size), bdims, shapes)
     args = [rng(shape, dtype) for shape, dtype in zip(batched_shapes, dtypes)]
     args_slice = args_slicer(args, bdims)
@@ -77,9 +77,16 @@ class LaxVmapTest(jtu.JaxTestCase):
     if bdim_size == 0:
       args = [rng(shape, dtype) for shape, dtype in zip(shapes, dtypes)]
       out = op(*args)
-      expected = np.zeros((0,) + out.shape, out.dtype)
+      if not multiple_results:
+        expected = np.zeros((0,) + out.shape, out.dtype)
+      else:
+        expected = [np.zeros((0,) + o.shape, o.dtype) for o in out]
     else:
-      expected = np.stack([op(*args_slice(i)) for i in range(bdim_size)])
+      outs = [op(*args_slice(i)) for i in range(bdim_size)]
+      if not multiple_results:
+        expected = np.stack(outs)
+      else:
+        expected = [np.stack(xs) for xs in zip(*outs)]
     self.assertAllClose(ans, expected, rtol=rtol, atol=atol)
 
   @parameterized.named_parameters(itertools.chain.from_iterable(
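
When multiple_results is set in the _CheckBatching changes above, each per-slice call returns a tuple, so the expected values are built by transposing the list of per-slice tuples and stacking each output separately. A tiny illustration of that bookkeeping (values made up):

import numpy as np

outs = [(np.full(3, i), np.ones(3)) for i in range(5)]  # one (sum, prod) pair per batch slice
expected = [np.stack(xs) for xs in zip(*outs)]          # two arrays, each of shape (5, 3)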
@@ -481,6 +488,27 @@ class LaxVmapTest(jtu.JaxTestCase):
     fun = lambda operand: lax.reduce(operand, init_val, op, dims)
     self._CheckBatching(fun, 5, bdims, (shape,), (dtype,), rng)
 
+  @parameterized.named_parameters(jtu.cases_from_list(
+      {"testcase_name": "_inshape={}_reducedims={}_bdims={}"
+       .format(jtu.format_shape_dtype_string(shape, dtype), dims, bdims),
+       "shape": shape, "dtype": dtype, "dims": dims, "bdims": bdims}
+      for dtype in default_dtypes
+      for shape, dims in [
+          [(3, 4, 5), (0,)], [(3, 4, 5), (1, 2)],
+          [(3, 4, 5), (0, 2)], [(3, 4, 5), (0, 1, 2)]
+      ]
+      for bdims in all_bdims(shape, shape)))
+  def testVariadicReduce(self, shape, dtype, dims, bdims):
+    def op(a, b):
+      x1, y1 = a
+      x2, y2 = b
+      return x1 + x2, y1 * y2
+    rng = jtu.rand_small(self.rng())
+    init_val = tuple(np.asarray([0, 1], dtype=dtype))
+    fun = lambda x, y: lax.reduce((x, y), init_val, op, dims)
+    self._CheckBatching(fun, 5, bdims, (shape, shape), (dtype, dtype), rng,
+                        multiple_results=True)
+
   @parameterized.named_parameters(jtu.cases_from_list(
       {"testcase_name": "_op={}_inshape={}_reducedims={}_bdims={}"
        .format(op.__name__, jtu.format_shape_dtype_string(shape, dtype), dim,