Make it possible to vmap xmapped functions

Or, perhaps more importantly, make it possible to nest xmaps that don't
specify any `axis_resources`. The math is a little tricky, so I've added
a fairly strong test that enumerates a wide range of ways of
interleaving vmapped and xmapped axes in both the inputs and the output.
Thanks to that, I've caught one very subtle bug in the dynamic tracing
rule for xmap (it sorted by dimension names instead of positional axes).
Adam Paszke 2021-01-18 10:42:06 +00:00
parent 040d268cf8
commit 6d2b307ced
2 changed files with 122 additions and 21 deletions


@@ -31,9 +31,10 @@ from ..api_util import flatten_fun_nokwargs, flatten_axes
from ..interpreters import partial_eval as pe
from ..interpreters import pxla
from ..interpreters import xla
from ..interpreters import batching
from ..lib import xla_bridge as xb
from ..lib import xla_client as xc
from .._src.util import safe_map, safe_zip, HashableFunction
from .._src.util import safe_map, safe_zip, HashableFunction, as_hashable_function, unzip2
from .._src.lax.parallel import _axis_index_translation_rule
map, unsafe_map = safe_map, map
@@ -252,7 +253,7 @@ def xmap(fun: Callable,
axis_sizes = _get_axis_sizes(args_flat, in_axes_flat)
out_flat = xmap_p.bind(
fun_flat, *args_flat,
name=fun.__name__,
name=getattr(fun, '__name__', '<unnamed function>'),
in_axes=tuple(in_axes_flat),
out_axes_thunk=out_axes_thunk,
axis_sizes=FrozenDict(axis_sizes),
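For context, an illustrative sketch (not part of the diff) of why the getattr fallback above is needed: some callables, for example functools.partial objects, have no __name__ attribute, so the old fun.__name__ lookup would raise an AttributeError.

import functools

def matmul(x, y):
  return x @ y

fun = functools.partial(matmul)
print(getattr(fun, '__name__', '<unnamed function>'))  # prints '<unnamed function>'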
@@ -305,7 +306,7 @@ def make_xmap_callable(fun: lu.WrappedFun,
*in_avals)
else:
# We have to trace again, because `f` is a linear function, so we can't just return it.
final_jaxpr, _, final_consts = pe.trace_to_jaxpr_final(f, in_avals)
final_jaxpr, out_avals, final_consts = pe.trace_to_jaxpr_final(f, in_avals)
return core.jaxpr_as_fun(core.ClosedJaxpr(final_jaxpr, final_consts))
class EvaluationPlan(NamedTuple):
@@ -403,6 +404,39 @@ def _dynamic_jaxpr_process_xmap(self, primitive, f, tracers, params):
  return out_tracers
pe.DynamicJaxprTrace.process_xmap = _dynamic_jaxpr_process_xmap  # type: ignore

def _batch_trace_process_xmap(self, primitive, f: lu.WrappedFun, tracers, params):
  not_mapped = batching.not_mapped
  vals, dims = unzip2((t.val, t.batch_dim) for t in tracers)
  assert primitive is xmap_p
  if all(dim is not_mapped for dim in dims):
    return primitive.bind(f, *vals, **params)
  else:
    assert len({x.shape[d] for x, d in zip(vals, dims) if d is not not_mapped}) == 1
    def fmap_dims(axes, f):
      return AxisNamePos((name, f(axis)) for name, axis in axes.items())
    new_in_axes = tuple(
        fmap_dims(in_axes, lambda a: a + (d is not not_mapped and d <= a))
        for d, in_axes in zip(dims, params['in_axes']))
    new_dims = tuple(
        d if d is not_mapped else d - sum(a < d for a in in_axis.values())
        for d, in_axis in zip(dims, params['in_axes']))
    f, dims_out = batching.batch_subtrace(f, self.main, new_dims)
    out_axes_thunk = params['out_axes_thunk']
    # NOTE: This assumes that the choice of the dimensions over which outputs
    # are batched is entirely dependent on the function and not e.g. on the
    # data or its shapes.
    @as_hashable_function(closure=out_axes_thunk)
    def new_out_axes_thunk():
      return tuple(
          fmap_dims(out_axes, lambda a: a + (d is not not_mapped and d <= a))
          for out_axes, d in zip(out_axes_thunk(), dims_out()))
    new_params = dict(params, in_axes=new_in_axes, out_axes_thunk=new_out_axes_thunk)
    vals_out = primitive.bind(f, *vals, **new_params)
    dims_out = tuple(d if d is not_mapped else d + sum(a < d for a in out_axes.values())
                     for d, out_axes in zip(dims_out(), out_axes_thunk()))
    return [batching.BatchTracer(self, v, d) for v, d in zip(vals_out, dims_out)]
batching.BatchTrace.process_xmap = _batch_trace_process_xmap  # type: ignore
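The index arithmetic in the rule above can be summarized with a small standalone sketch (illustration only, with made-up axis positions): inserting a batch dimension at position d pushes every named axis at position a with d <= a one slot to the right, while the batch dimension seen inside the xmapped call moves left by the number of named axes on its left.

not_mapped = None  # stand-in for batching.not_mapped

def shift(a, d):
  # Mirrors `a + (d is not not_mapped and d <= a)` from the rule above.
  return a + (d is not not_mapped and d <= a)

in_axes = {'i': 0, 'j': 2}   # hypothetical named axes at positions 0 and 2
d = 1                        # vmap batch dimension inserted at position 1

print({name: shift(a, d) for name, a in in_axes.items()})  # {'i': 0, 'j': 3}
print(d - sum(a < d for a in in_axes.values()))            # 0: batch dim position inside the call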
# -------- nested xmap handling --------
@@ -537,7 +571,7 @@ def _delete_aval_axes(aval, axes: AxisNamePos):
def _insert_aval_axes(aval, axes: AxisNamePos, axis_sizes):
  assert isinstance(aval, core.ShapedArray)
  shape = list(aval.shape)
  for name, axis in sorted(axes.items()):
  for name, axis in sorted(axes.items(), key=lambda x: x[1]):
    shape.insert(axis, axis_sizes[name])
  return core.ShapedArray(tuple(shape), aval.dtype)
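An illustrative sketch (made-up values, not part of the diff) of the subtle bug this key=lambda x: x[1] fixes: when the insertion positions are not in the same order as the axis names, sorting by name inserts axes at the wrong offsets.

axes = {'b': 0, 'a': 2}        # hypothetical: axis 'b' goes to dim 0, 'a' to dim 2
axis_sizes = {'a': 5, 'b': 7}
base_shape = [11, 13]

buggy = list(base_shape)
for name, pos in sorted(axes.items()):                      # alphabetical: 'a' before 'b'
  buggy.insert(pos, axis_sizes[name])

fixed = list(base_shape)
for name, pos in sorted(axes.items(), key=lambda x: x[1]):  # by target position
  fixed.insert(pos, axis_sizes[name])

print(buggy)  # [7, 11, 13, 5] -- axis 'a' ends up at dim 3 instead of dim 2
print(fixed)  # [7, 11, 5, 13]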


@@ -18,6 +18,7 @@ import functools
import itertools
import os
import unittest
from itertools import product, permutations
from unittest import SkipTest, skip, skipIf
import numpy as np
@@ -62,7 +63,6 @@ def tearDownModule():
os.environ["XLA_FLAGS"] = prev_xla_flags
xla_bridge.get_backend.cache_clear()
@curry
def with_mesh(named_shape, f):
if not named_shape:
@@ -178,22 +178,6 @@ class XMapTest(jtu.JaxTestCase):
python_should_be_executing = False
fm(x)
@skip("Need to implement vmap(xmap)")
@ignore_xmap_warning()
@with_mesh([('x', 2)])
def testNestedVectorize(self):
@partial(xmap, in_axes=[None, 'a', ...], out_axes=['a', ...], axis_resources={'a': 'x'})
def f(x):
y = x * 2
@partial(xmap, in_axes=['b', ...], out_axes=[None, 'b', ...])
def h(y):
return jnp.sin(y)
return h(y)
xshape = (4, 2, 5)
x = jnp.arange(np.prod(xshape)).reshape(xshape)
self.assertAllClose(f(x),
jnp.sin(x * 2).transpose((1, 2, 0)))
@skip("Need to implement vmap(xmap)")
@ignore_xmap_warning()
@with_mesh([('x', 2), ('y', 3)])
@@ -255,6 +239,89 @@ class XMapTest(jtu.JaxTestCase):
self.assertAllClose(f_mapped(x, x), expected)
run_test()
  def VmapOfXmapCases():
    xmap_in_axes = ([{}] +
                    [{i: 'x'} for i in range(3)] +
                    [{i: 'x', j: 'y'} for i in range(4) for j in range(4) if i != j])
    for xmap_dim_x, xmap_dim_y in product(xmap_in_axes, repeat=2):
      xmap_axes = sorted(set(xmap_dim_x.values()) | set(xmap_dim_y.values()))
      num_axes = len(xmap_axes)
      if xmap_axes is None:
        continue
      xmap_out_axes = [dict(zip(dims, xmap_axes))
                       for dims in permutations(range(2 + num_axes), num_axes)]
      for xmap_dim_z in xmap_out_axes:
        for vmap_dim_x in [*range(2 + len(xmap_dim_x)), None]:
          for vmap_dim_y in [*range(2 + len(xmap_dim_y)), None]:
            if vmap_dim_x is None and vmap_dim_y is None:
              continue
            for vmap_dim_z in range(2 + len(xmap_axes)):
              for vmap_as_xmap in [False, True]:
                yield {"testcase_name":
                           f"_xin={(sorted(xmap_dim_x.items()), sorted(xmap_dim_y.items()))}_"
                           f"xout={sorted(xmap_dim_z.items())}_vin={(vmap_dim_x, vmap_dim_y)}_"
                           f"vout={vmap_dim_z}_vmap_as_xmap={vmap_as_xmap}",
                       "xmap_in_axes": (xmap_dim_x, xmap_dim_y),
                       "xmap_out_axes": xmap_dim_z,
                       "vmap_in_axes": (vmap_dim_x, vmap_dim_y),
                       "vmap_out_axes": vmap_dim_z,
                       "vmap_as_xmap": vmap_as_xmap}

  @parameterized.named_parameters(jtu.cases_from_list(VmapOfXmapCases()))
  @ignore_xmap_warning()
  def testNestedMap(self, xmap_in_axes, xmap_out_axes, vmap_in_axes, vmap_out_axes, vmap_as_xmap):
    """Test various vmap(xmap) and xmap(xmap) combinations.
    The outer map always introduces a single dimension, the inner map introduces one or two.
    """
    (xin_x, xin_y) = xmap_in_axes
    (vin_x, vin_y) = vmap_in_axes
    vmap_size = 7
    xmap_sizes = {'x': 11, 'y': 13}
    xshape = [2, 3]
    yshape = [3, 5]
    zshape = [2, 5]
    xind = ['n', 'k']
    yind = ['k', 'm']
    zind = ['n', 'm']
    f = partial(jnp.einsum, 'nk,km->nm')

    for pos, name in sorted(xin_x.items()):
      xshape.insert(pos, xmap_sizes[name])
      xind.insert(pos, name)
    for pos, name in sorted(xin_y.items()):
      yshape.insert(pos, xmap_sizes[name])
      yind.insert(pos, name)
    for pos, name in sorted(xmap_out_axes.items()):
      zshape.insert(pos, xmap_sizes[name])
      zind.insert(pos, name)

    if vin_x is not None:
      xshape.insert(vin_x, vmap_size)
      xind.insert(vin_x, 'v')
    if vin_y is not None:
      yshape.insert(vin_y, vmap_size)
      yind.insert(vin_y, 'v')
    zshape.insert(vmap_out_axes, vmap_size)
    zind.insert(vmap_out_axes, 'v')

    if vmap_as_xmap:
      do_vmap = partial(xmap,
                        in_axes=({vin_x: 'v'} if vin_x is not None else {},
                                 {vin_y: 'v'} if vin_y is not None else {}),
                        out_axes={vmap_out_axes: 'v'})
    else:
      do_vmap = partial(vmap, in_axes=vmap_in_axes, out_axes=vmap_out_axes)

    fm = do_vmap(xmap(f, in_axes=xmap_in_axes, out_axes=xmap_out_axes))
    fref = partial(jnp.einsum, f"{''.join(xind)},{''.join(yind)}->{''.join(zind)}")
    rng = np.random.RandomState(0)
    x = rng.randn(*xshape)
    y = rng.randn(*yshape)
    self.assertAllClose(fm(x, y), fref(x, y))
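For concreteness, here is how the shape and index bookkeeping above plays out for one generated case (an illustration, not part of the diff), assuming xmap_in_axes = ({0: 'x'}, {}), xmap_out_axes = {2: 'x'}, vmap_in_axes = (0, None) and vmap_out_axes = 1.

import numpy as np

vmap_size, x_size = 7, 11
# x: base (2, 3) indexed 'nk'; the named axis 'x' is inserted at dim 0, then the vmap axis 'v' at dim 0.
xshape, xind = (vmap_size, x_size, 2, 3), 'vxnk'
# y: base (3, 5) indexed 'km'; this case adds no xmap or vmap axes to it.
yshape, yind = (3, 5), 'km'
# z: base (2, 5) indexed 'nm'; 'x' is inserted at dim 2, then 'v' at dim 1.
zind = 'nvmx'

rng = np.random.RandomState(0)
x, y = rng.randn(*xshape), rng.randn(*yshape)
zref = np.einsum(f'{xind},{yind}->{zind}', x, y)
print(zref.shape)  # (2, 7, 5, 11)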
class XMapTestSPMD(XMapTest):
"""Re-executes all tests with the SPMD partitioner enabled"""