Always treat all mesh axes controlled by xmap as MANUAL

PiperOrigin-RevId: 430192736
This commit is contained in:
Adam Paszke 2022-02-22 06:01:36 -08:00 committed by jax authors
parent a65841f5db
commit 2641f06152
3 changed files with 50 additions and 23 deletions

View File

@ -731,10 +731,12 @@ def make_xmap_callable(fun: lu.WrappedFun,
for ax, av, ips in safe_zip(mesh_in_axes, in_avals, in_positional_semantics)
]
in_is_gda = [ips == _PositionalSemantics.GLOBAL for ips in in_positional_semantics]
tiling_method: pxla.TilingMethod
if config.experimental_xmap_spmd_lowering_manual:
tiling_method = pxla.TilingMethod.MANUAL
manual_mesh_axes = frozenset(it.chain.from_iterable(plan.physical_axis_resources.values()))
tiling_method = pxla.TileManual(manual_mesh_axes)
else:
tiling_method = pxla.TilingMethod.VECTORIZE
tiling_method = pxla.TileVectorize()
return pxla.lower_mesh_computation(
f, name, mesh,
mesh_in_axes, mesh_out_axes, donated_invars,
@ -1519,6 +1521,7 @@ def _xmap_lowering_rule_spmd_manual(ctx, *global_in_nodes,
xla.check_backend_matches(backend, ctx.module_context.platform)
plan = EvaluationPlan.from_axis_resources(
axis_resources, resource_env, global_axis_sizes, in_positional_semantics)
manual_mesh_axes = frozenset(it.chain.from_iterable(plan.physical_axis_resources.values()))
resource_call_jaxpr = plan.subst_axes_with_resources(call_jaxpr)
f = lu.wrap_init(core.jaxpr_as_fun(core.ClosedJaxpr(resource_call_jaxpr, ())))
@ -1528,7 +1531,7 @@ def _xmap_lowering_rule_spmd_manual(ctx, *global_in_nodes,
# NOTE: Sharding constraints are handled entirely by vtile_manual!
mesh_in_axes, mesh_out_axes = plan.to_mesh_axes(in_axes, out_axes)
mesh = resource_env.physical_mesh
f = pxla.vtile_manual(f, mesh, mesh_in_axes, mesh_out_axes)
f = pxla.vtile_manual(f, tuple(manual_mesh_axes), mesh, mesh_in_axes, mesh_out_axes)
# NOTE: We don't extend the resource env with the mesh shape, because those
# resources are already in scope! It's the outermost xmap that introduces
@ -1539,7 +1542,6 @@ def _xmap_lowering_rule_spmd_manual(ctx, *global_in_nodes,
# We in-line here rather than generating a Call HLO as in the xla_call
# translation rule just because the extra tuple stuff is a pain.
manual_mesh_axes = frozenset(it.chain.from_iterable(plan.physical_axis_resources.values()))
assert isinstance(ctx.module_context.axis_context, mlir.SPMDAxisContext)
sub_ctx = ctx.module_context.replace(
name_stack=xla.extend_name_stack(ctx.module_context.name_stack,

View File

@ -37,9 +37,8 @@ from functools import partial
import itertools as it
import operator as op
import threading
from typing import (Any, Callable, Dict, List, NamedTuple, Optional,
from typing import (Any, Callable, Dict, List, NamedTuple, Optional, FrozenSet,
Sequence, Set, Tuple, Type, Union, Iterable, Mapping, cast)
import enum
import sys
from absl import logging
@ -2047,11 +2046,11 @@ def vtile_by_mesh(fun: lu.WrappedFun,
full_to_shard_p = core.Primitive('full_to_shard')
@full_to_shard_p.def_abstract_eval
def _full_to_shard_abstract_eval(x, axes, mesh):
def _full_to_shard_abstract_eval(x, axes, mesh, **_):
# TODO: Assert x is a global aval! Or ideally check that it's global in dims from axes!
return tile_aval_nd(mesh.shape, axes, x)
def _manual_proto(aval, axes, mesh):
def _manual_proto(aval: core.ShapedArray, manual_axes_set: FrozenSet[MeshAxisName], mesh: Mesh):
"""Create an OpSharding proto that declares all mesh axes from `manual_axes_set` as manual
and all others as replicated.
"""
@ -2059,8 +2058,8 @@ def _manual_proto(aval, axes, mesh):
mesh_shape = list(named_mesh_shape.values())
axis_order = {axis: i for i, axis in enumerate(mesh.axis_names)}
manual_axes = list(axes)
replicated_axes = list(axis for axis in mesh.axis_names if axis not in axes)
manual_axes = list(sorted(manual_axes_set, key=str))
replicated_axes = list(axis for axis in mesh.axis_names if axis not in manual_axes_set)
tad_perm = ([axis_order[a] for a in replicated_axes] +
[axis_order[a] for a in manual_axes])
@ -2077,29 +2076,29 @@ def _manual_proto(aval, axes, mesh):
return proto
@partial(mlir.register_lowering, full_to_shard_p)
def _full_to_shard_lowering(ctx, x, *, axes: ArrayMapping, mesh: Mesh):
def _full_to_shard_lowering(ctx, x, *, axes: ArrayMapping, mesh: Mesh, manual_axes: FrozenSet[MeshAxisName]):
# TODO: Can we short-circuit for replicated values? Probably not.
aval_in, = ctx.avals_in
aval_out, = ctx.avals_out
sharding_proto = mesh_sharding_specs(mesh.shape, mesh.axis_names)(aval_in, axes).sharding_proto()
unspecified_dims = set(range(aval_in.ndim)) - set(axes.values())
sx = mlir.wrap_with_sharding_op(x, sharding_proto, unspecified_dims=unspecified_dims)
manual_proto = _manual_proto(aval_in, axes, mesh)
manual_proto = _manual_proto(aval_in, manual_axes, mesh)
result_type, = mlir.aval_to_ir_types(aval_out)
return mlir.wrap_with_full_to_shard_op(result_type, sx, manual_proto, unspecified_dims=unspecified_dims),
shard_to_full_p = core.Primitive('shard_to_full')
@shard_to_full_p.def_abstract_eval
def _shard_to_full_abstract_eval(x, axes, mesh):
def _shard_to_full_abstract_eval(x, axes, mesh, **_):
# TODO: Assert x is a global aval! Or ideally check that it's global in dims from axes!
return untile_aval_nd(mesh.shape, axes, x)
@partial(mlir.register_lowering, shard_to_full_p)
def _shard_to_full_lowering(ctx, x, *, axes: ArrayMapping, mesh: Mesh):
def _shard_to_full_lowering(ctx, x, *, axes: ArrayMapping, mesh: Mesh, manual_axes: FrozenSet[MeshAxisName]):
aval_in, = ctx.avals_in
aval_out, = ctx.avals_out
manual_proto = _manual_proto(aval_in, axes, mesh)
manual_proto = _manual_proto(aval_in, manual_axes, mesh)
result_type, = mlir.aval_to_ir_types(aval_out)
unspecified_dims = set(range(aval_in.ndim)) - set(axes.values())
sx = mlir.wrap_with_sharding_op(x, manual_proto, unspecified_dims=unspecified_dims)
@ -2107,18 +2106,29 @@ def _shard_to_full_lowering(ctx, x, *, axes: ArrayMapping, mesh: Mesh):
return mlir.wrap_with_shard_to_full_op(result_type, sx, sharding_proto, unspecified_dims),
@lu.transformation
def vtile_manual(mesh: Mesh,
def vtile_manual(manual_axes: FrozenSet[MeshAxisName],
mesh: Mesh,
in_axes: Sequence[ArrayMapping],
out_axes: Sequence[ArrayMapping],
*args):
tiled_args = [full_to_shard_p.bind(arg, axes=axes, mesh=mesh)
tiled_args = [full_to_shard_p.bind(arg, axes=axes, mesh=mesh, manual_axes=manual_axes)
for arg, axes in zip(args, in_axes)]
tiled_outs = yield tiled_args, {}
outs = [shard_to_full_p.bind(out, axes=axes, mesh=mesh)
outs = [shard_to_full_p.bind(out, axes=axes, mesh=mesh, manual_axes=manual_axes)
for out, axes in zip(tiled_outs, out_axes)]
yield outs
TilingMethod = enum.Enum("TilingMethod", ["VECTORIZE", "MANUAL"])
@dataclasses.dataclass(frozen=True)
class TileVectorize:
pass
@dataclasses.dataclass(frozen=True)
class TileManual:
manual_axes: FrozenSet[MeshAxisName]
TilingMethod = Union[TileVectorize, TileManual]
@profiler.annotate_function
def lower_mesh_computation(
@ -2152,17 +2162,17 @@ def lower_mesh_computation(
if spmd_lowering:
# TODO: Consider handling xmap's 'vectorize' in here. We can vmap once instead of vtile twice!
if tiling_method is not None:
if tiling_method is TilingMethod.VECTORIZE:
if isinstance(tiling_method, TileVectorize):
tiling_transform = vtile_by_mesh
elif tiling_method is TilingMethod.MANUAL:
tiling_transform = vtile_manual
elif isinstance(tiling_method, TileManual):
tiling_transform = lambda f, *args: vtile_manual(f, tiling_method.manual_axes, *args) # type: ignore
else:
raise NotImplementedError(f"Unrecognized tiling method: {tiling_method}")
assert not callable(out_axes)
fun = tiling_transform(fun, mesh, in_axes, out_axes)
in_jaxpr_avals = global_in_avals
else:
assert tiling_method is TilingMethod.VECTORIZE
assert isinstance(tiling_method, TileVectorize)
in_jaxpr_avals = in_tiled_avals
with core.extend_axis_env_nd(mesh.shape.items()):
with dispatch.log_elapsed_time(f"Finished tracing + transforming {name_stack} "

View File

@ -717,6 +717,14 @@ class XMapTestManualSPMD(ManualSPMDTestMixin, XMapTestCase):
x = jnp.arange(20, dtype=jnp.float32)
self.assertAllClose(fx(x), f(x))
@jtu.with_mesh([('x', 2)])
def testReplicated(self):
# TODO(apaszke): This seems to be failing if I try to have a replicated and a mapped argument?
f = lambda x: jnp.sin(jnp.cos(x) + x) * x
fx = xmap(f, in_axes=[...], out_axes=[...], axis_sizes={'i': 4}, axis_resources={'i': 'x'})
x = jnp.arange(20, dtype=jnp.float32)
self.assertAllClose(fx(x), f(x))
@jtu.with_mesh([('x', 2), ('y', 1)])
def testInPJit(self):
f = xmap(lambda x: jnp.sin(x) + x, in_axes=['i'], out_axes=['i'], axis_resources={'i': 'x'})
@ -724,6 +732,13 @@ class XMapTestManualSPMD(ManualSPMDTestMixin, XMapTestCase):
x = jnp.arange(20, dtype=jnp.float32)
self.assertAllClose(h(x), jnp.sin(x * x) + x * x + x)
@jtu.with_mesh([('x', 2), ('y', 1)])
def testInPJitReplicated(self):
f = xmap(lambda x: jnp.sin(x) + x, in_axes={}, out_axes={}, axis_sizes={'i': 4}, axis_resources={'i': 'x'})
h = pjit(lambda x: f(x * x) + x, in_axis_resources=P('y'), out_axis_resources=None)
x = jnp.arange(20, dtype=jnp.float32)
self.assertAllClose(h(x), jnp.sin(x * x) + x * x + x)
@jtu.with_mesh([('x', 2), ('y', 1)])
def testNestedConstraint(self):
# TODO(b/219691408): Using P('y') instead of P() causes an XLA crash!