mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Add axis_sizes to xmap
Right now, all axis sizes have to be inferred from the arguments to xmap, which is unnecessarily strict. This change lets users specify explicit sizes, allowing them to handle cases such as empty dicts that were supposed to contain mapped arguments.
This commit is contained in:
parent
9840f34e31
commit
a8b1f5f78f
@ -210,6 +210,8 @@ def _prepare_axes(axes, arg_name):
|
||||
def xmap(fun: Callable,
|
||||
in_axes,
|
||||
out_axes,
|
||||
*,
|
||||
axis_sizes: Dict[AxisName, int] = {},
|
||||
axis_resources: Dict[AxisName, Union[ResourceAxisName, Tuple[ResourceAxisName, ...]]] = {},
|
||||
backend: Optional[str] = None):
|
||||
"""Assign a positional signature to a program that uses named array axes.
|
||||
@ -298,6 +300,9 @@ def xmap(fun: Callable,
|
||||
as in ``in_axes``. Note that ``out_axes`` can also be a prefix of the return
|
||||
container structure, in which case the mapping is repeated for all arrays
|
||||
in the collapsed subtree.
|
||||
axis_sizes: A dict mapping axis names to their sizes. All axes defined by xmap
|
||||
have to appear either in ``in_axes`` or ``axis_sizes``. Sizes of axes
|
||||
that appear in ``in_axes`` are inferred from arguments whenever possible.
|
||||
axis_resources: A dictionary mapping the axes introduced in this
|
||||
:py:func:`xmap` to one or more resource axes. Any array that has in its
|
||||
shape an axis with some resources assigned will be partitioned over the
|
||||
@ -374,28 +379,33 @@ def xmap(fun: Callable,
|
||||
if isinstance(out_axes, list):
|
||||
out_axes = tuple(out_axes)
|
||||
|
||||
in_axes, in_axes_entries = _prepare_axes(in_axes, "in_axes")
|
||||
if in_axes == (): # Allow empty argument lists
|
||||
in_axes, in_axes_entries = (), []
|
||||
else:
|
||||
in_axes, in_axes_entries = _prepare_axes(in_axes, "in_axes")
|
||||
out_axes, out_axes_entries = _prepare_axes(out_axes, "out_axes")
|
||||
|
||||
axis_sizes_names = set(axis_sizes.keys())
|
||||
in_axes_names = set(it.chain(*(spec.keys() for spec in in_axes_entries)))
|
||||
defined_names = axis_sizes_names | in_axes_names
|
||||
out_axes_names = set(it.chain(*(spec.keys() for spec in out_axes_entries)))
|
||||
normalized_axis_resources: Dict[AxisName, Tuple[ResourceAxisName, ...]] = \
|
||||
{axis: (resources if isinstance(resources, tuple) else (resources,))
|
||||
for axis, resources in axis_resources.items()}
|
||||
for axis in in_axes_names:
|
||||
for axis in defined_names:
|
||||
normalized_axis_resources.setdefault(axis, ())
|
||||
frozen_axis_resources = FrozenDict(normalized_axis_resources)
|
||||
necessary_resources = set(it.chain(*frozen_axis_resources.values()))
|
||||
|
||||
axes_with_resources = set(frozen_axis_resources.keys())
|
||||
if axes_with_resources > in_axes_names:
|
||||
if axes_with_resources > defined_names:
|
||||
raise ValueError(f"All axes that were assigned resources have to appear in "
|
||||
f"in_axes, but the following are missing: "
|
||||
f"{axes_with_resources - in_axes_names}")
|
||||
if out_axes_names > in_axes_names:
|
||||
f"in_axes or axis_sizes, but the following are missing: "
|
||||
f"{axes_with_resources - defined_names}")
|
||||
if out_axes_names > defined_names:
|
||||
raise ValueError(f"All axis names appearing in out_axes must also appear in "
|
||||
f"in_axes, but the following are missing: "
|
||||
f"{out_axes_names - in_axes_names}")
|
||||
f"in_axes or axis_sizes, but the following are missing: "
|
||||
f"{out_axes_names - defined_names}")
|
||||
|
||||
for axis, resources in frozen_axis_resources.items():
|
||||
if len(set(resources)) != len(resources):
|
||||
@ -422,13 +432,14 @@ def xmap(fun: Callable,
|
||||
out_axes_thunk = HashableFunction(
|
||||
lambda: tuple(flatten_axes("xmap out_axes", out_tree(), out_axes)),
|
||||
closure=out_axes)
|
||||
axis_sizes = _get_axis_sizes(args_flat, in_axes_flat)
|
||||
frozen_axis_sizes = FrozenDict(_get_axis_sizes(args_flat, in_axes_flat, axis_sizes))
|
||||
assert set(frozen_axis_sizes.keys()) == set(frozen_axis_resources.keys())
|
||||
out_flat = xmap_p.bind(
|
||||
fun_flat, *args_flat,
|
||||
name=getattr(fun, '__name__', '<unnamed function>'),
|
||||
in_axes=tuple(in_axes_flat),
|
||||
out_axes_thunk=out_axes_thunk,
|
||||
axis_sizes=FrozenDict(axis_sizes),
|
||||
axis_sizes=frozen_axis_sizes,
|
||||
axis_resources=frozen_axis_resources,
|
||||
resource_env=resource_env,
|
||||
backend=backend)
|
||||
@ -448,7 +459,7 @@ def make_xmap_callable(fun: lu.WrappedFun,
|
||||
in_axes, out_axes_thunk, axis_sizes,
|
||||
axis_resources, resource_env, backend,
|
||||
*in_avals):
|
||||
plan = EvaluationPlan.from_axis_resources(axis_resources, resource_env)
|
||||
plan = EvaluationPlan.from_axis_resources(axis_resources, resource_env, axis_sizes)
|
||||
|
||||
# TODO: Making axis substitution final style would allow us to avoid
|
||||
# tracing to jaxpr here
|
||||
@ -483,22 +494,31 @@ def make_xmap_callable(fun: lu.WrappedFun,
|
||||
|
||||
class EvaluationPlan(NamedTuple):
|
||||
"""Encapsulates preprocessing common to top-level xmap invocations and its translation rule."""
|
||||
resource_env: ResourceEnv
|
||||
axis_sizes: Dict[AxisName, int]
|
||||
physical_axis_resources: Dict[AxisName, Tuple[ResourceAxisName, ...]]
|
||||
axis_subst: Dict[AxisName, Tuple[ResourceAxisName, ...]]
|
||||
|
||||
@classmethod
|
||||
def from_axis_resources(cls, axis_resources: Dict[AxisName, Tuple[ResourceAxisName, ...]], resource_env):
|
||||
def from_axis_resources(cls,
|
||||
axis_resources: Dict[AxisName, Tuple[ResourceAxisName, ...]],
|
||||
resource_env: ResourceEnv,
|
||||
axis_sizes: Dict[AxisName, int]):
|
||||
# TODO: Support sequential resources
|
||||
physical_axis_resources = axis_resources # NB: We only support physical resources at the moment
|
||||
axis_subst = {name: axes + (fresh_resource_name(name),) for name, axes in axis_resources.items()}
|
||||
return cls(physical_axis_resources, axis_subst)
|
||||
return cls(resource_env, axis_sizes, physical_axis_resources, axis_subst)
|
||||
|
||||
def vectorize(self, f: lu.WrappedFun, in_axes, out_axes):
|
||||
resource_shape = self.resource_env.shape
|
||||
for naxis, raxes in self.axis_subst.items():
|
||||
vaxis = raxes[-1]
|
||||
paxes, vaxis = raxes[:-1], raxes[-1]
|
||||
map_in_axes = tuple(unsafe_map(lambda spec: spec.get(naxis, None), in_axes))
|
||||
map_out_axes = tuple(unsafe_map(lambda spec: spec.get(naxis, None), out_axes))
|
||||
f = pxla.vtile(f, map_in_axes, map_out_axes, tile_size=None, axis_name=vaxis)
|
||||
paxes_size = int(np.prod([resource_shape[paxis] for paxis in paxes], dtype=np.int64))
|
||||
assert self.axis_sizes[naxis] % paxes_size == 0
|
||||
tile_size = self.axis_sizes[naxis] // paxes_size
|
||||
f = pxla.vtile(f, map_in_axes, map_out_axes, tile_size=tile_size, axis_name=vaxis)
|
||||
return f
|
||||
|
||||
def to_mesh_axes(self, in_axes, out_axes):
|
||||
@ -624,7 +644,7 @@ def _xmap_translation_rule_replica(c, axis_env,
|
||||
call_jaxpr, name,
|
||||
in_axes, out_axes, axis_sizes,
|
||||
axis_resources, resource_env, backend):
|
||||
plan = EvaluationPlan.from_axis_resources(axis_resources, resource_env)
|
||||
plan = EvaluationPlan.from_axis_resources(axis_resources, resource_env, axis_sizes)
|
||||
|
||||
local_mesh = resource_env.physical_mesh.local_mesh
|
||||
local_mesh_shape = local_mesh.shape
|
||||
@ -756,8 +776,10 @@ def _insert_aval_axes(aval, axes: AxisNamePos, axis_sizes):
|
||||
|
||||
|
||||
# TODO: pmap has some very fancy error messages for this function!
|
||||
def _get_axis_sizes(args_flat: Iterable[Any], in_axes_flat: Iterable[AxisNamePos]):
|
||||
axis_sizes: Dict[AxisName, int] = {}
|
||||
def _get_axis_sizes(args_flat: Iterable[Any],
|
||||
in_axes_flat: Iterable[AxisNamePos],
|
||||
axis_sizes: Dict[AxisName, int]):
|
||||
axis_sizes = dict(axis_sizes)
|
||||
for arg, in_axes in zip(args_flat, in_axes_flat):
|
||||
for name, dim in in_axes.items():
|
||||
if name in axis_sizes:
|
||||
|
@ -1534,9 +1534,6 @@ def vtile(f_flat,
|
||||
in_axes_flat: Tuple[Optional[int], ...],
|
||||
out_axes_flat: Tuple[Optional[int], ...],
|
||||
tile_size: Optional[int], axis_name):
|
||||
if tile_size == 1:
|
||||
return f_flat
|
||||
|
||||
@curry
|
||||
def tile_axis(arg, axis: Optional[int], tile_size):
|
||||
if axis is None:
|
||||
|
@ -82,6 +82,8 @@ def with_mesh(named_shape: MeshSpec) -> Generator[None, None, None]:
|
||||
with mesh(mesh_devices, axis_names):
|
||||
yield
|
||||
|
||||
def with_mesh_from_kwargs(f):
  """Wrap ``f`` so each call runs under the mesh given by its ``mesh`` kwarg.

  On every invocation the wrapper reads ``kwargs['mesh']``, passes it to
  ``with_mesh`` and uses the result to decorate ``f``, then calls the
  decorated function with the unchanged positional and keyword arguments
  (``mesh`` included). A call without a ``mesh`` keyword raises ``KeyError``.
  """
  def run_with_mesh(*args, **kwargs):
    # The mesh is looked up per call, not once at decoration time.
    mesh_decorator = with_mesh(kwargs['mesh'])
    return mesh_decorator(f)(*args, **kwargs)
  return run_with_mesh
|
||||
|
||||
class XMapTest(jtu.JaxTestCase):
|
||||
def setUp(self):
|
||||
@ -202,29 +204,27 @@ class XMapTest(jtu.JaxTestCase):
|
||||
('OneToOne', (('x', 2), ('y', 2)), (('a', 'y'), ('b', 'x'))),
|
||||
('Multiple', (('x', 2), ('y', 2), ('z', 2)), (('a', 'y'), ('b', ('x', 'z')))),
|
||||
))
|
||||
@with_mesh_from_kwargs
|
||||
@ignore_xmap_warning()
|
||||
def testNestedMesh(self, mesh, axis_resources):
|
||||
@with_mesh(mesh)
|
||||
def run_test():
|
||||
@partial(xmap, in_axes={1: 'a'}, out_axes=({0: 'a'}, {}),
|
||||
axis_resources=dict([axis_resources[0]]))
|
||||
def f(x):
|
||||
y = x * 2
|
||||
@partial(xmap, in_axes={0: 'b'}, out_axes=({1: 'b'}, {}),
|
||||
axis_resources=dict([axis_resources[1]]))
|
||||
def h(y):
|
||||
return jnp.sin(y), lax.psum(y, ('a', 'b'))
|
||||
return h(y)
|
||||
@partial(xmap, in_axes={1: 'a'}, out_axes=({0: 'a'}, {}),
|
||||
axis_resources=dict([axis_resources[0]]))
|
||||
def f(x):
|
||||
y = x * 2
|
||||
@partial(xmap, in_axes={0: 'b'}, out_axes=({1: 'b'}, {}),
|
||||
axis_resources=dict([axis_resources[1]]))
|
||||
def h(y):
|
||||
return jnp.sin(y), lax.psum(y, ('a', 'b'))
|
||||
return h(y)
|
||||
|
||||
xshape = (4, 2, 5)
|
||||
x = jnp.arange(np.prod(xshape)).reshape(xshape)
|
||||
y = f(x)
|
||||
self.assertAllClose(y, (jnp.sin(x * 2).transpose((1, 2, 0)), (x * 2).sum((0, 1))))
|
||||
self.assertEqual(y[0].sharding_spec.sharding,
|
||||
(pxla.Chunked(2), pxla.NoSharding(), pxla.NoSharding()))
|
||||
self.assertEqual(y[0].sharding_spec.mesh_mapping,
|
||||
(pxla.Replicated(2), pxla.ShardedAxis(0)) + (pxla.Replicated(2),) * (len(mesh) - 2))
|
||||
run_test()
|
||||
xshape = (4, 2, 5)
|
||||
x = jnp.arange(np.prod(xshape)).reshape(xshape)
|
||||
y = f(x)
|
||||
self.assertAllClose(y, (jnp.sin(x * 2).transpose((1, 2, 0)), (x * 2).sum((0, 1))))
|
||||
self.assertEqual(y[0].sharding_spec.sharding,
|
||||
(pxla.Chunked(2), pxla.NoSharding(), pxla.NoSharding()))
|
||||
self.assertEqual(y[0].sharding_spec.mesh_mapping,
|
||||
(pxla.Replicated(2), pxla.ShardedAxis(0)) + (pxla.Replicated(2),) * (len(mesh) - 2))
|
||||
|
||||
@parameterized.named_parameters(
|
||||
{"testcase_name": name, "mesh": mesh, "axis_resources": axis_resources}
|
||||
@ -232,23 +232,36 @@ class XMapTest(jtu.JaxTestCase):
|
||||
('', (), ()),
|
||||
('Mesh', (('x', 2),), (('i', 'x'),))
|
||||
))
|
||||
@with_mesh_from_kwargs
|
||||
@ignore_xmap_warning()
|
||||
def testMultipleCalls(self, mesh, axis_resources):
|
||||
def f(x, y):
|
||||
assert x.shape == y.shape == (3, 5)
|
||||
return jnp.tensordot(x, y, axes=([1], [1]))
|
||||
|
||||
@with_mesh(mesh)
|
||||
def run_test():
|
||||
f_mapped = xmap(f,
|
||||
in_axes=(['i', ...], ['j', ...]),
|
||||
out_axes=['i', 'j', ...],
|
||||
axis_resources=dict(axis_resources))
|
||||
x = jnp.arange(30).reshape(2, 3, 5)
|
||||
expected = jnp.einsum('imk,jnk->ijmn', x, x)
|
||||
for i in range(10):
|
||||
self.assertAllClose(f_mapped(x, x), expected)
|
||||
run_test()
|
||||
f_mapped = xmap(f,
|
||||
in_axes=(['i', ...], ['j', ...]),
|
||||
out_axes=['i', 'j', ...],
|
||||
axis_resources=dict(axis_resources))
|
||||
x = jnp.arange(30).reshape(2, 3, 5)
|
||||
expected = jnp.einsum('imk,jnk->ijmn', x, x)
|
||||
for i in range(10):
|
||||
self.assertAllClose(f_mapped(x, x), expected)
|
||||
|
||||
@parameterized.named_parameters(
|
||||
{"testcase_name": name, "mesh": mesh, "axis_resources": axis_resources}
|
||||
for name, mesh, axis_resources in (
|
||||
('', (), ()),
|
||||
('Mesh', (('x', 2),), (('i', 'x'),))
|
||||
))
|
||||
@with_mesh_from_kwargs
|
||||
@ignore_xmap_warning()
|
||||
def testAxisSizes(self, mesh, axis_resources):
|
||||
result = xmap(lambda: lax.axis_index('i'),
|
||||
in_axes=(), out_axes=['i', ...],
|
||||
axis_sizes={'i': 6},
|
||||
axis_resources=dict(axis_resources))()
|
||||
self.assertAllClose(result, jnp.arange(6, dtype=result.dtype))
|
||||
|
||||
def VmapOfXmapCases():
|
||||
xmap_in_axes = ([{}] +
|
||||
|
Loading…
x
Reference in New Issue
Block a user