mirror of
https://github.com/ROCm/jax.git
synced 2025-04-18 04:46:06 +00:00
Implement lax.reduce() batching rule case where batch dimensions differ between operands.
This commit is contained in:
parent
31c2da70e6
commit
ba4db33f48
@ -5596,19 +5596,15 @@ def _reduce_batch_rule(batched_args, batch_dims, *, computation, jaxpr,
|
||||
operand_bdims, init_value_bdims = split_list(batch_dims, [num_operands])
|
||||
if all(init_value_bdim is batching.not_mapped
|
||||
for init_value_bdim in init_value_bdims):
|
||||
# Assume all batch dims are the same for each of the operands
|
||||
# TODO(sharadmv): handle the case when batch dims are different across
|
||||
# operands or when some are unbatched
|
||||
if not all(operand_bdim is not batching.not_mapped for operand_bdim in operand_bdims):
|
||||
raise NotImplementedError
|
||||
if not all(operand_bdim == operand_bdims[0] for operand_bdim in operand_bdims):
|
||||
raise NotImplementedError
|
||||
operand_bdim = operand_bdims[0]
|
||||
new_dimensions = [d + bool(d >= operand_bdim) for d in dimensions]
|
||||
new_operand_bdim = operand_bdim - int(np.sum(np.less(dimensions, operand_bdim)))
|
||||
new_operand_bdims = [new_operand_bdim] * num_operands
|
||||
size = next(x.shape[ax] for x, ax in zip(batched_args, batch_dims)
|
||||
if ax is not None)
|
||||
operands = [batching.bdim_at_front(arg, bdim, size)
|
||||
for arg, bdim in zip(operands, operand_bdims)]
|
||||
new_dimensions = [d + 1 for d in dimensions]
|
||||
new_operand_bdims = [0] * num_operands
|
||||
return reduce_p.bind(*(operands + init_values),
|
||||
computation=computation, dimensions=tuple(new_dimensions),
|
||||
computation=computation,
|
||||
dimensions=tuple(new_dimensions),
|
||||
consts=consts,
|
||||
jaxpr=jaxpr), new_operand_bdims
|
||||
else:
|
||||
|
@ -69,7 +69,7 @@ def args_slicer(args, bdims):
|
||||
class LaxVmapTest(jtu.JaxTestCase):
|
||||
|
||||
def _CheckBatching(self, op, bdim_size, bdims, shapes, dtypes, rng,
|
||||
rtol=None, atol=None):
|
||||
rtol=None, atol=None, multiple_results=False):
|
||||
batched_shapes = map(partial(add_bdim, bdim_size), bdims, shapes)
|
||||
args = [rng(shape, dtype) for shape, dtype in zip(batched_shapes, dtypes)]
|
||||
args_slice = args_slicer(args, bdims)
|
||||
@ -77,9 +77,16 @@ class LaxVmapTest(jtu.JaxTestCase):
|
||||
if bdim_size == 0:
|
||||
args = [rng(shape, dtype) for shape, dtype in zip(shapes, dtypes)]
|
||||
out = op(*args)
|
||||
expected = np.zeros((0,) + out.shape, out.dtype)
|
||||
if not multiple_results:
|
||||
expected = np.zeros((0,) + out.shape, out.dtype)
|
||||
else:
|
||||
expected = [np.zeros((0,) + o.shape, o.dtype) for o in out]
|
||||
else:
|
||||
expected = np.stack([op(*args_slice(i)) for i in range(bdim_size)])
|
||||
outs = [op(*args_slice(i)) for i in range(bdim_size)]
|
||||
if not multiple_results:
|
||||
expected = np.stack(outs)
|
||||
else:
|
||||
expected = [np.stack(xs) for xs in zip(*outs)]
|
||||
self.assertAllClose(ans, expected, rtol=rtol, atol=atol)
|
||||
|
||||
@parameterized.named_parameters(itertools.chain.from_iterable(
|
||||
@ -481,6 +488,27 @@ class LaxVmapTest(jtu.JaxTestCase):
|
||||
fun = lambda operand: lax.reduce(operand, init_val, op, dims)
|
||||
self._CheckBatching(fun, 5, bdims, (shape,), (dtype,), rng)
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
{"testcase_name": "_inshape={}_reducedims={}_bdims={}"
|
||||
.format(jtu.format_shape_dtype_string(shape, dtype), dims, bdims),
|
||||
"shape": shape, "dtype": dtype, "dims": dims, "bdims": bdims}
|
||||
for dtype in default_dtypes
|
||||
for shape, dims in [
|
||||
[(3, 4, 5), (0,)], [(3, 4, 5), (1, 2)],
|
||||
[(3, 4, 5), (0, 2)], [(3, 4, 5), (0, 1, 2)]
|
||||
]
|
||||
for bdims in all_bdims(shape, shape)))
|
||||
def testVariadicReduce(self, shape, dtype, dims, bdims):
|
||||
def op(a, b):
|
||||
x1, y1 = a
|
||||
x2, y2 = b
|
||||
return x1 + x2, y1 * y2
|
||||
rng = jtu.rand_small(self.rng())
|
||||
init_val = tuple(np.asarray([0, 1], dtype=dtype))
|
||||
fun = lambda x, y: lax.reduce((x, y), init_val, op, dims)
|
||||
self._CheckBatching(fun, 5, bdims, (shape, shape), (dtype, dtype), rng,
|
||||
multiple_results=True)
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
{"testcase_name": "_op={}_inshape={}_reducedims={}_bdims={}"
|
||||
.format(op.__name__, jtu.format_shape_dtype_string(shape, dtype), dim,
|
||||
|
Loading…
x
Reference in New Issue
Block a user