Fix batching formula of xmap

Turns out that once you insert multiple dimensions things become much
more tricky than in the case of batching a one-dimensional map. Also
strenghten our tests to make sure we don't depend too much on the
semantics of the einsum batching rule.
This commit is contained in:
Adam Paszke 2021-03-04 18:08:45 +00:00
parent 5ed0633d54
commit 6884f21b60
2 changed files with 47 additions and 17 deletions

View File

@ -686,17 +686,24 @@ def _batch_trace_process_xmap(self, primitive, f: lu.WrappedFun, tracers, params
for d, in_axis in zip(dims, params['in_axes']))
f, mapped_dims_out = batching.batch_subtrace(f, self.main, mapped_dims_in)
out_axes_thunk = params['out_axes_thunk']
def axis_after_insertion(axis, inserted_named_axes):
for inserted_axis in sorted(inserted_named_axes.values()):
if inserted_axis >= axis:
break
axis += 1
return axis
# NOTE: This assumes that the choice of the dimensions over which outputs
# are batched is entirely dependent on the function and not e.g. on the
# data or its shapes.
@as_hashable_function(closure=out_axes_thunk)
def new_out_axes_thunk():
return tuple(
fmap_dims(out_axes, lambda a: a + (d is not not_mapped and d <= a))
out_axes if d is not_mapped else
fmap_dims(out_axes, lambda a, nd=axis_after_insertion(d, out_axes): a + (nd <= a))
for out_axes, d in zip(out_axes_thunk(), mapped_dims_out()))
new_params = dict(params, in_axes=new_in_axes, out_axes_thunk=new_out_axes_thunk)
vals_out = primitive.bind(f, *vals, **new_params)
dims_out = tuple(d if d is not_mapped else d + sum(a < d for a in out_axes.values())
dims_out = tuple(d if d is not_mapped else axis_after_insertion(d, out_axes)
for d, out_axes in zip(mapped_dims_out(), out_axes_thunk()))
return [batching.BatchTracer(self, v, d) for v, d in zip(vals_out, dims_out)]
batching.BatchTrace.process_xmap = _batch_trace_process_xmap # type: ignore

View File

@ -35,13 +35,14 @@ import jax.scipy as jscipy
from jax import test_util as jtu
from jax import vmap
from jax import lax
from jax import core
from jax.core import NamedShape
from jax.experimental.maps import Mesh, mesh, xmap
from jax.lib import xla_bridge
from jax._src.util import curry, unzip2, split_list, prod
from jax._src.lax.lax import DotDimensionNumbers
from jax._src.lax.parallel import pgather
from jax.interpreters import pxla
from jax.interpreters import batching, pxla
from jax.config import config
config.parse_flags_with_absl()
@ -112,6 +113,23 @@ def powerset(s):
s = list(s)
return it.chain.from_iterable(it.combinations(s, r) for r in range(len(s)+1))
# -------------------- vmap test helpers --------------------
ensure_bdim_p = core.Primitive('ensure_bdim')
ensure_bdim_p.def_abstract_eval(lambda x, **kwargs: core.raise_to_shaped(x))
def _ensure_bdim_batcher(frame, vals_in, dims_in, axis_name, bdim):
v, = vals_in
d, = dims_in
assert d is not batching.not_mapped
return jnp.moveaxis(v, d, bdim), bdim
batching.collective_rules[ensure_bdim_p] = _ensure_bdim_batcher
batching.primitive_batchers[ensure_bdim_p] = lambda v, d: (v[0], d[0])
core.axis_substitution_rules[ensure_bdim_p] = partial(jax._src.lax.parallel._subst_all_names_in_param,
'axis_name')
def ensure_bdim(x, axis_name, bdim):
return ensure_bdim_p.bind(x, axis_name=(axis_name,), bdim=bdim)
# -------------------- Axis resources generation --------------------
AxisResources = Dict[str, Union[str, Tuple[str, ...]]]
@ -444,21 +462,26 @@ class XMapTest(XMapTestCase):
for vmap_dim_y in [*range(2 + len(xmap_dim_y)), None]:
if vmap_dim_x is None and vmap_dim_y is None:
continue
for vmap_dim_z in range(2 + len(xmap_axes)):
for vmap_as_xmap in [False, True]:
yield {"testcase_name":
f"_xin={(sorted(xmap_dim_x.items()), sorted(xmap_dim_y.items()))}_"
f"xout={sorted(xmap_dim_z.items())}_vin={(vmap_dim_x, vmap_dim_y)}_"
f"vout={vmap_dim_z}_vmap_as_xmap={vmap_as_xmap}",
"xmap_in_axes": (xmap_dim_x, xmap_dim_y),
"xmap_out_axes": xmap_dim_z,
"vmap_in_axes": (vmap_dim_x, vmap_dim_y),
"vmap_out_axes": vmap_dim_z,
"vmap_as_xmap": vmap_as_xmap}
for vmap_dim_result in range(3):
for vmap_dim_z in range(2 + len(xmap_axes)):
for vmap_as_xmap in [False, True]:
yield {"testcase_name":
f"_xin={(sorted(xmap_dim_x.items()), sorted(xmap_dim_y.items()))}_"
f"xout={sorted(xmap_dim_z.items())}_vin={(vmap_dim_x, vmap_dim_y)}_"
f"vout={vmap_dim_z}_vresult={vmap_dim_result}_vmap_as_xmap={vmap_as_xmap}",
"xmap_in_axes": (xmap_dim_x, xmap_dim_y),
"xmap_out_axes": xmap_dim_z,
"vmap_in_axes": (vmap_dim_x, vmap_dim_y),
"vmap_out_axes": vmap_dim_z,
"vmap_result_axis": vmap_dim_result,
"vmap_as_xmap": vmap_as_xmap}
@parameterized.named_parameters(jtu.cases_from_list(VmapOfXmapCases()))
@ignore_xmap_warning()
def testNestedMap(self, xmap_in_axes, xmap_out_axes, vmap_in_axes, vmap_out_axes, vmap_as_xmap):
def testNestedMap(self,
xmap_in_axes, xmap_out_axes,
vmap_in_axes, vmap_out_axes, vmap_result_axis,
vmap_as_xmap):
"""Test various vmap(xmap) and xmap(xmap) combinations.
The outer map always introduces a single dimension, the inner map introduces one or two.
@ -474,7 +497,7 @@ class XMapTest(XMapTestCase):
xind = ['n', 'k']
yind = ['k', 'm']
zind = ['n', 'm']
f = partial(jnp.einsum, 'nk,km->nm')
f = lambda x, y: ensure_bdim(jnp.einsum('nk,km->nm', x, y), 'v', vmap_result_axis)
for pos, name in sorted(xin_x.items()):
xshape.insert(pos, xmap_sizes[name])
@ -501,7 +524,7 @@ class XMapTest(XMapTestCase):
{vin_y: 'v'} if vin_y is not None else {}),
out_axes={vmap_out_axes: 'v'})
else:
do_vmap = partial(vmap, in_axes=vmap_in_axes, out_axes=vmap_out_axes)
do_vmap = partial(vmap, in_axes=vmap_in_axes, out_axes=vmap_out_axes, axis_name='v')
fm = do_vmap(xmap(f, in_axes=xmap_in_axes, out_axes=xmap_out_axes))
fref = partial(jnp.einsum, f"{''.join(xind)},{''.join(yind)}->{''.join(zind)}")