Add axis_sizes to xmap

Right now, all axis sizes have to be inferred from the arguments passed to
xmap, which is unnecessarily strict. This change lets users specify explicit
sizes, allowing them to handle, e.g., empty dicts that were supposed to
contain mapped arguments.
This commit is contained in:
Adam Paszke 2021-02-02 15:39:23 +00:00
parent 9840f34e31
commit a8b1f5f78f
3 changed files with 84 additions and 52 deletions

View File

@ -210,6 +210,8 @@ def _prepare_axes(axes, arg_name):
def xmap(fun: Callable,
in_axes,
out_axes,
*,
axis_sizes: Dict[AxisName, int] = {},
axis_resources: Dict[AxisName, Union[ResourceAxisName, Tuple[ResourceAxisName, ...]]] = {},
backend: Optional[str] = None):
"""Assign a positional signature to a program that uses named array axes.
@ -298,6 +300,9 @@ def xmap(fun: Callable,
as in ``in_axes``. Note that ``out_axes`` can also be a prefix of the return
container structure, in which case the mapping is repeated for all arrays
in the collapsed subtree.
axis_sizes: A dict mapping axis names to their sizes. All axes defined by xmap
have to appear either in ``in_axes`` or ``axis_sizes``. Sizes of axes
that appear in ``in_axes`` are inferred from arguments whenever possible.
axis_resources: A dictionary mapping the axes introduced in this
:py:func:`xmap` to one or more resource axes. Any array that has in its
shape an axis with some resources assigned will be partitioned over the
@ -374,28 +379,33 @@ def xmap(fun: Callable,
if isinstance(out_axes, list):
out_axes = tuple(out_axes)
in_axes, in_axes_entries = _prepare_axes(in_axes, "in_axes")
if in_axes == (): # Allow empty argument lists
in_axes, in_axes_entries = (), []
else:
in_axes, in_axes_entries = _prepare_axes(in_axes, "in_axes")
out_axes, out_axes_entries = _prepare_axes(out_axes, "out_axes")
axis_sizes_names = set(axis_sizes.keys())
in_axes_names = set(it.chain(*(spec.keys() for spec in in_axes_entries)))
defined_names = axis_sizes_names | in_axes_names
out_axes_names = set(it.chain(*(spec.keys() for spec in out_axes_entries)))
normalized_axis_resources: Dict[AxisName, Tuple[ResourceAxisName, ...]] = \
{axis: (resources if isinstance(resources, tuple) else (resources,))
for axis, resources in axis_resources.items()}
for axis in in_axes_names:
for axis in defined_names:
normalized_axis_resources.setdefault(axis, ())
frozen_axis_resources = FrozenDict(normalized_axis_resources)
necessary_resources = set(it.chain(*frozen_axis_resources.values()))
axes_with_resources = set(frozen_axis_resources.keys())
if axes_with_resources > in_axes_names:
if axes_with_resources > defined_names:
raise ValueError(f"All axes that were assigned resources have to appear in "
f"in_axes, but the following are missing: "
f"{axes_with_resources - in_axes_names}")
if out_axes_names > in_axes_names:
f"in_axes or axis_sizes, but the following are missing: "
f"{axes_with_resources - defined_names}")
if out_axes_names > defined_names:
raise ValueError(f"All axis names appearing in out_axes must also appear in "
f"in_axes, but the following are missing: "
f"{out_axes_names - in_axes_names}")
f"in_axes or axis_sizes, but the following are missing: "
f"{out_axes_names - defined_names}")
for axis, resources in frozen_axis_resources.items():
if len(set(resources)) != len(resources):
@ -422,13 +432,14 @@ def xmap(fun: Callable,
out_axes_thunk = HashableFunction(
lambda: tuple(flatten_axes("xmap out_axes", out_tree(), out_axes)),
closure=out_axes)
axis_sizes = _get_axis_sizes(args_flat, in_axes_flat)
frozen_axis_sizes = FrozenDict(_get_axis_sizes(args_flat, in_axes_flat, axis_sizes))
assert set(frozen_axis_sizes.keys()) == set(frozen_axis_resources.keys())
out_flat = xmap_p.bind(
fun_flat, *args_flat,
name=getattr(fun, '__name__', '<unnamed function>'),
in_axes=tuple(in_axes_flat),
out_axes_thunk=out_axes_thunk,
axis_sizes=FrozenDict(axis_sizes),
axis_sizes=frozen_axis_sizes,
axis_resources=frozen_axis_resources,
resource_env=resource_env,
backend=backend)
@ -448,7 +459,7 @@ def make_xmap_callable(fun: lu.WrappedFun,
in_axes, out_axes_thunk, axis_sizes,
axis_resources, resource_env, backend,
*in_avals):
plan = EvaluationPlan.from_axis_resources(axis_resources, resource_env)
plan = EvaluationPlan.from_axis_resources(axis_resources, resource_env, axis_sizes)
# TODO: Making axis substitution final style would allow us to avoid
# tracing to jaxpr here
@ -483,22 +494,31 @@ def make_xmap_callable(fun: lu.WrappedFun,
class EvaluationPlan(NamedTuple):
"""Encapsulates preprocessing common to top-level xmap invocations and its translation rule."""
resource_env: ResourceEnv
axis_sizes: Dict[AxisName, int]
physical_axis_resources: Dict[AxisName, Tuple[ResourceAxisName, ...]]
axis_subst: Dict[AxisName, Tuple[ResourceAxisName, ...]]
@classmethod
def from_axis_resources(cls, axis_resources: Dict[AxisName, Tuple[ResourceAxisName, ...]], resource_env):
def from_axis_resources(cls,
axis_resources: Dict[AxisName, Tuple[ResourceAxisName, ...]],
resource_env: ResourceEnv,
axis_sizes: Dict[AxisName, int]):
# TODO: Support sequential resources
physical_axis_resources = axis_resources # NB: We only support physical resources at the moment
axis_subst = {name: axes + (fresh_resource_name(name),) for name, axes in axis_resources.items()}
return cls(physical_axis_resources, axis_subst)
return cls(resource_env, axis_sizes, physical_axis_resources, axis_subst)
def vectorize(self, f: lu.WrappedFun, in_axes, out_axes):
resource_shape = self.resource_env.shape
for naxis, raxes in self.axis_subst.items():
vaxis = raxes[-1]
paxes, vaxis = raxes[:-1], raxes[-1]
map_in_axes = tuple(unsafe_map(lambda spec: spec.get(naxis, None), in_axes))
map_out_axes = tuple(unsafe_map(lambda spec: spec.get(naxis, None), out_axes))
f = pxla.vtile(f, map_in_axes, map_out_axes, tile_size=None, axis_name=vaxis)
paxes_size = int(np.prod([resource_shape[paxis] for paxis in paxes], dtype=np.int64))
assert self.axis_sizes[naxis] % paxes_size == 0
tile_size = self.axis_sizes[naxis] // paxes_size
f = pxla.vtile(f, map_in_axes, map_out_axes, tile_size=tile_size, axis_name=vaxis)
return f
def to_mesh_axes(self, in_axes, out_axes):
@ -624,7 +644,7 @@ def _xmap_translation_rule_replica(c, axis_env,
call_jaxpr, name,
in_axes, out_axes, axis_sizes,
axis_resources, resource_env, backend):
plan = EvaluationPlan.from_axis_resources(axis_resources, resource_env)
plan = EvaluationPlan.from_axis_resources(axis_resources, resource_env, axis_sizes)
local_mesh = resource_env.physical_mesh.local_mesh
local_mesh_shape = local_mesh.shape
@ -756,8 +776,10 @@ def _insert_aval_axes(aval, axes: AxisNamePos, axis_sizes):
# TODO: pmap has some very fancy error messages for this function!
def _get_axis_sizes(args_flat: Iterable[Any], in_axes_flat: Iterable[AxisNamePos]):
axis_sizes: Dict[AxisName, int] = {}
def _get_axis_sizes(args_flat: Iterable[Any],
in_axes_flat: Iterable[AxisNamePos],
axis_sizes: Dict[AxisName, int]):
axis_sizes = dict(axis_sizes)
for arg, in_axes in zip(args_flat, in_axes_flat):
for name, dim in in_axes.items():
if name in axis_sizes:

View File

@ -1534,9 +1534,6 @@ def vtile(f_flat,
in_axes_flat: Tuple[Optional[int], ...],
out_axes_flat: Tuple[Optional[int], ...],
tile_size: Optional[int], axis_name):
if tile_size == 1:
return f_flat
@curry
def tile_axis(arg, axis: Optional[int], tile_size):
if axis is None:

View File

@ -82,6 +82,8 @@ def with_mesh(named_shape: MeshSpec) -> Generator[None, None, None]:
with mesh(mesh_devices, axis_names):
yield
def with_mesh_from_kwargs(f):
  """Decorate ``f`` so each call runs under the mesh given by its ``mesh`` kwarg.

  On every call the wrapper reads ``kwargs['mesh']``, passes it to
  ``with_mesh``, applies the returned object to ``f`` as a decorator and then
  invokes the result with the original positional and keyword arguments
  (``mesh`` itself is still forwarded to ``f``).

  Args:
    f: The callable to wrap. It must always be invoked with a ``mesh`` keyword
      argument; a call without one raises ``KeyError``.

  Returns:
    A wrapper around ``f`` that keeps ``f``'s ``__name__``, ``__doc__`` and
    other metadata. Preserving the name matters for tooling that derives
    identifiers from the decorated function (e.g. test-name generation by
    decorators stacked on top of this one).
  """
  # Imported locally so the helper does not depend on which names the
  # surrounding module already pulls in from functools.
  from functools import wraps

  @wraps(f)
  def wrapper(*args, **kwargs):
    # The mesh is resolved at call time, not at decoration time, because it is
    # only known once the (parameterized) arguments are supplied.
    return with_mesh(kwargs['mesh'])(f)(*args, **kwargs)

  return wrapper
class XMapTest(jtu.JaxTestCase):
def setUp(self):
@ -202,29 +204,27 @@ class XMapTest(jtu.JaxTestCase):
('OneToOne', (('x', 2), ('y', 2)), (('a', 'y'), ('b', 'x'))),
('Multiple', (('x', 2), ('y', 2), ('z', 2)), (('a', 'y'), ('b', ('x', 'z')))),
))
@with_mesh_from_kwargs
@ignore_xmap_warning()
def testNestedMesh(self, mesh, axis_resources):
@with_mesh(mesh)
def run_test():
@partial(xmap, in_axes={1: 'a'}, out_axes=({0: 'a'}, {}),
axis_resources=dict([axis_resources[0]]))
def f(x):
y = x * 2
@partial(xmap, in_axes={0: 'b'}, out_axes=({1: 'b'}, {}),
axis_resources=dict([axis_resources[1]]))
def h(y):
return jnp.sin(y), lax.psum(y, ('a', 'b'))
return h(y)
@partial(xmap, in_axes={1: 'a'}, out_axes=({0: 'a'}, {}),
axis_resources=dict([axis_resources[0]]))
def f(x):
y = x * 2
@partial(xmap, in_axes={0: 'b'}, out_axes=({1: 'b'}, {}),
axis_resources=dict([axis_resources[1]]))
def h(y):
return jnp.sin(y), lax.psum(y, ('a', 'b'))
return h(y)
xshape = (4, 2, 5)
x = jnp.arange(np.prod(xshape)).reshape(xshape)
y = f(x)
self.assertAllClose(y, (jnp.sin(x * 2).transpose((1, 2, 0)), (x * 2).sum((0, 1))))
self.assertEqual(y[0].sharding_spec.sharding,
(pxla.Chunked(2), pxla.NoSharding(), pxla.NoSharding()))
self.assertEqual(y[0].sharding_spec.mesh_mapping,
(pxla.Replicated(2), pxla.ShardedAxis(0)) + (pxla.Replicated(2),) * (len(mesh) - 2))
run_test()
xshape = (4, 2, 5)
x = jnp.arange(np.prod(xshape)).reshape(xshape)
y = f(x)
self.assertAllClose(y, (jnp.sin(x * 2).transpose((1, 2, 0)), (x * 2).sum((0, 1))))
self.assertEqual(y[0].sharding_spec.sharding,
(pxla.Chunked(2), pxla.NoSharding(), pxla.NoSharding()))
self.assertEqual(y[0].sharding_spec.mesh_mapping,
(pxla.Replicated(2), pxla.ShardedAxis(0)) + (pxla.Replicated(2),) * (len(mesh) - 2))
@parameterized.named_parameters(
{"testcase_name": name, "mesh": mesh, "axis_resources": axis_resources}
@ -232,23 +232,36 @@ class XMapTest(jtu.JaxTestCase):
('', (), ()),
('Mesh', (('x', 2),), (('i', 'x'),))
))
@with_mesh_from_kwargs
@ignore_xmap_warning()
def testMultipleCalls(self, mesh, axis_resources):
def f(x, y):
assert x.shape == y.shape == (3, 5)
return jnp.tensordot(x, y, axes=([1], [1]))
@with_mesh(mesh)
def run_test():
f_mapped = xmap(f,
in_axes=(['i', ...], ['j', ...]),
out_axes=['i', 'j', ...],
axis_resources=dict(axis_resources))
x = jnp.arange(30).reshape(2, 3, 5)
expected = jnp.einsum('imk,jnk->ijmn', x, x)
for i in range(10):
self.assertAllClose(f_mapped(x, x), expected)
run_test()
f_mapped = xmap(f,
in_axes=(['i', ...], ['j', ...]),
out_axes=['i', 'j', ...],
axis_resources=dict(axis_resources))
x = jnp.arange(30).reshape(2, 3, 5)
expected = jnp.einsum('imk,jnk->ijmn', x, x)
for i in range(10):
self.assertAllClose(f_mapped(x, x), expected)
@parameterized.named_parameters(
  {"testcase_name": case_name, "mesh": mesh_spec, "axis_resources": resources}
  for case_name, mesh_spec, resources in (
    ('', (), ()),
    ('Mesh', (('x', 2),), (('i', 'x'),)),
  ))
@with_mesh_from_kwargs
@ignore_xmap_warning()
def testAxisSizes(self, mesh, axis_resources):
  # The mapped function takes no arguments, so the size of axis 'i' cannot be
  # inferred from inputs and has to come from the explicit axis_sizes entry.
  # `mesh` is not read here; it is consumed by the with_mesh_from_kwargs
  # decorator above.
  index_along_i = xmap(
      lambda: lax.axis_index('i'),
      in_axes=(),
      out_axes=['i', ...],
      axis_sizes={'i': 6},
      axis_resources=dict(axis_resources))
  result = index_along_i()
  # Each position along 'i' should hold its own index: 0, 1, ..., 5.
  expected = jnp.arange(6, dtype=result.dtype)
  self.assertAllClose(result, expected)
def VmapOfXmapCases():
xmap_in_axes = ([{}] +