mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Always treat all mesh axes controlled by xmap as MANUAL
PiperOrigin-RevId: 430192736
This commit is contained in:
parent
a65841f5db
commit
2641f06152
@ -731,10 +731,12 @@ def make_xmap_callable(fun: lu.WrappedFun,
|
||||
for ax, av, ips in safe_zip(mesh_in_axes, in_avals, in_positional_semantics)
|
||||
]
|
||||
in_is_gda = [ips == _PositionalSemantics.GLOBAL for ips in in_positional_semantics]
|
||||
tiling_method: pxla.TilingMethod
|
||||
if config.experimental_xmap_spmd_lowering_manual:
|
||||
tiling_method = pxla.TilingMethod.MANUAL
|
||||
manual_mesh_axes = frozenset(it.chain.from_iterable(plan.physical_axis_resources.values()))
|
||||
tiling_method = pxla.TileManual(manual_mesh_axes)
|
||||
else:
|
||||
tiling_method = pxla.TilingMethod.VECTORIZE
|
||||
tiling_method = pxla.TileVectorize()
|
||||
return pxla.lower_mesh_computation(
|
||||
f, name, mesh,
|
||||
mesh_in_axes, mesh_out_axes, donated_invars,
|
||||
@ -1519,6 +1521,7 @@ def _xmap_lowering_rule_spmd_manual(ctx, *global_in_nodes,
|
||||
xla.check_backend_matches(backend, ctx.module_context.platform)
|
||||
plan = EvaluationPlan.from_axis_resources(
|
||||
axis_resources, resource_env, global_axis_sizes, in_positional_semantics)
|
||||
manual_mesh_axes = frozenset(it.chain.from_iterable(plan.physical_axis_resources.values()))
|
||||
|
||||
resource_call_jaxpr = plan.subst_axes_with_resources(call_jaxpr)
|
||||
f = lu.wrap_init(core.jaxpr_as_fun(core.ClosedJaxpr(resource_call_jaxpr, ())))
|
||||
@ -1528,7 +1531,7 @@ def _xmap_lowering_rule_spmd_manual(ctx, *global_in_nodes,
|
||||
# NOTE: Sharding constraints are handled entirely by vtile_manual!
|
||||
mesh_in_axes, mesh_out_axes = plan.to_mesh_axes(in_axes, out_axes)
|
||||
mesh = resource_env.physical_mesh
|
||||
f = pxla.vtile_manual(f, mesh, mesh_in_axes, mesh_out_axes)
|
||||
f = pxla.vtile_manual(f, tuple(manual_mesh_axes), mesh, mesh_in_axes, mesh_out_axes)
|
||||
|
||||
# NOTE: We don't extend the resource env with the mesh shape, because those
|
||||
# resources are already in scope! It's the outermost xmap that introduces
|
||||
@ -1539,7 +1542,6 @@ def _xmap_lowering_rule_spmd_manual(ctx, *global_in_nodes,
|
||||
|
||||
# We in-line here rather than generating a Call HLO as in the xla_call
|
||||
# translation rule just because the extra tuple stuff is a pain.
|
||||
manual_mesh_axes = frozenset(it.chain.from_iterable(plan.physical_axis_resources.values()))
|
||||
assert isinstance(ctx.module_context.axis_context, mlir.SPMDAxisContext)
|
||||
sub_ctx = ctx.module_context.replace(
|
||||
name_stack=xla.extend_name_stack(ctx.module_context.name_stack,
|
||||
|
@ -37,9 +37,8 @@ from functools import partial
|
||||
import itertools as it
|
||||
import operator as op
|
||||
import threading
|
||||
from typing import (Any, Callable, Dict, List, NamedTuple, Optional,
|
||||
from typing import (Any, Callable, Dict, List, NamedTuple, Optional, FrozenSet,
|
||||
Sequence, Set, Tuple, Type, Union, Iterable, Mapping, cast)
|
||||
import enum
|
||||
import sys
|
||||
|
||||
from absl import logging
|
||||
@ -2047,11 +2046,11 @@ def vtile_by_mesh(fun: lu.WrappedFun,
|
||||
full_to_shard_p = core.Primitive('full_to_shard')
|
||||
|
||||
@full_to_shard_p.def_abstract_eval
|
||||
def _full_to_shard_abstract_eval(x, axes, mesh):
|
||||
def _full_to_shard_abstract_eval(x, axes, mesh, **_):
|
||||
# TODO: Assert x is a global aval! Or ideally check that it's global in dims from axes!
|
||||
return tile_aval_nd(mesh.shape, axes, x)
|
||||
|
||||
def _manual_proto(aval, axes, mesh):
|
||||
def _manual_proto(aval: core.ShapedArray, manual_axes_set: FrozenSet[MeshAxisName], mesh: Mesh):
|
||||
"""Create an OpSharding proto that declares all mesh axes from `axes` as manual
|
||||
and all others as replicated.
|
||||
"""
|
||||
@ -2059,8 +2058,8 @@ def _manual_proto(aval, axes, mesh):
|
||||
mesh_shape = list(named_mesh_shape.values())
|
||||
axis_order = {axis: i for i, axis in enumerate(mesh.axis_names)}
|
||||
|
||||
manual_axes = list(axes)
|
||||
replicated_axes = list(axis for axis in mesh.axis_names if axis not in axes)
|
||||
manual_axes = list(sorted(manual_axes_set, key=str))
|
||||
replicated_axes = list(axis for axis in mesh.axis_names if axis not in manual_axes_set)
|
||||
|
||||
tad_perm = ([axis_order[a] for a in replicated_axes] +
|
||||
[axis_order[a] for a in manual_axes])
|
||||
@ -2077,29 +2076,29 @@ def _manual_proto(aval, axes, mesh):
|
||||
return proto
|
||||
|
||||
@partial(mlir.register_lowering, full_to_shard_p)
|
||||
def _full_to_shard_lowering(ctx, x, *, axes: ArrayMapping, mesh: Mesh):
|
||||
def _full_to_shard_lowering(ctx, x, *, axes: ArrayMapping, mesh: Mesh, manual_axes: FrozenSet[MeshAxisName]):
|
||||
# TODO: Can we short-circuit for replicated values? Probably not.
|
||||
aval_in, = ctx.avals_in
|
||||
aval_out, = ctx.avals_out
|
||||
sharding_proto = mesh_sharding_specs(mesh.shape, mesh.axis_names)(aval_in, axes).sharding_proto()
|
||||
unspecified_dims = set(range(aval_in.ndim)) - set(axes.values())
|
||||
sx = mlir.wrap_with_sharding_op(x, sharding_proto, unspecified_dims=unspecified_dims)
|
||||
manual_proto = _manual_proto(aval_in, axes, mesh)
|
||||
manual_proto = _manual_proto(aval_in, manual_axes, mesh)
|
||||
result_type, = mlir.aval_to_ir_types(aval_out)
|
||||
return mlir.wrap_with_full_to_shard_op(result_type, sx, manual_proto, unspecified_dims=unspecified_dims),
|
||||
|
||||
shard_to_full_p = core.Primitive('shard_to_full')
|
||||
|
||||
@shard_to_full_p.def_abstract_eval
|
||||
def _shard_to_full_abstract_eval(x, axes, mesh):
|
||||
def _shard_to_full_abstract_eval(x, axes, mesh, **_):
|
||||
# TODO: Assert x is a global aval! Or ideally check that it's global in dims from axes!
|
||||
return untile_aval_nd(mesh.shape, axes, x)
|
||||
|
||||
@partial(mlir.register_lowering, shard_to_full_p)
|
||||
def _shard_to_full_lowering(ctx, x, *, axes: ArrayMapping, mesh: Mesh):
|
||||
def _shard_to_full_lowering(ctx, x, *, axes: ArrayMapping, mesh: Mesh, manual_axes: FrozenSet[MeshAxisName]):
|
||||
aval_in, = ctx.avals_in
|
||||
aval_out, = ctx.avals_out
|
||||
manual_proto = _manual_proto(aval_in, axes, mesh)
|
||||
manual_proto = _manual_proto(aval_in, manual_axes, mesh)
|
||||
result_type, = mlir.aval_to_ir_types(aval_out)
|
||||
unspecified_dims = set(range(aval_in.ndim)) - set(axes.values())
|
||||
sx = mlir.wrap_with_sharding_op(x, manual_proto, unspecified_dims=unspecified_dims)
|
||||
@ -2107,18 +2106,29 @@ def _shard_to_full_lowering(ctx, x, *, axes: ArrayMapping, mesh: Mesh):
|
||||
return mlir.wrap_with_shard_to_full_op(result_type, sx, sharding_proto, unspecified_dims),
|
||||
|
||||
@lu.transformation
|
||||
def vtile_manual(mesh: Mesh,
|
||||
def vtile_manual(manual_axes: FrozenSet[MeshAxisName],
|
||||
mesh: Mesh,
|
||||
in_axes: Sequence[ArrayMapping],
|
||||
out_axes: Sequence[ArrayMapping],
|
||||
*args):
|
||||
tiled_args = [full_to_shard_p.bind(arg, axes=axes, mesh=mesh)
|
||||
tiled_args = [full_to_shard_p.bind(arg, axes=axes, mesh=mesh, manual_axes=manual_axes)
|
||||
for arg, axes in zip(args, in_axes)]
|
||||
tiled_outs = yield tiled_args, {}
|
||||
outs = [shard_to_full_p.bind(out, axes=axes, mesh=mesh)
|
||||
outs = [shard_to_full_p.bind(out, axes=axes, mesh=mesh, manual_axes=manual_axes)
|
||||
for out, axes in zip(tiled_outs, out_axes)]
|
||||
yield outs
|
||||
|
||||
TilingMethod = enum.Enum("TilingMethod", ["VECTORIZE", "MANUAL"])
|
||||
|
||||
@dataclasses.dataclass(frozen=True)
|
||||
class TileVectorize:
|
||||
pass
|
||||
|
||||
@dataclasses.dataclass(frozen=True)
|
||||
class TileManual:
|
||||
manual_axes: FrozenSet[MeshAxisName]
|
||||
|
||||
TilingMethod = Union[TileVectorize, TileManual]
|
||||
|
||||
|
||||
@profiler.annotate_function
|
||||
def lower_mesh_computation(
|
||||
@ -2152,17 +2162,17 @@ def lower_mesh_computation(
|
||||
if spmd_lowering:
|
||||
# TODO: Consider handling xmap's 'vectorize' in here. We can vmap once instead of vtile twice!
|
||||
if tiling_method is not None:
|
||||
if tiling_method is TilingMethod.VECTORIZE:
|
||||
if isinstance(tiling_method, TileVectorize):
|
||||
tiling_transform = vtile_by_mesh
|
||||
elif tiling_method is TilingMethod.MANUAL:
|
||||
tiling_transform = vtile_manual
|
||||
elif isinstance(tiling_method, TileManual):
|
||||
tiling_transform = lambda f, *args: vtile_manual(f, tiling_method.manual_axes, *args) # type: ignore
|
||||
else:
|
||||
raise NotImplementedError(f"Unrecognized tiling method: {tiling_method}")
|
||||
assert not callable(out_axes)
|
||||
fun = tiling_transform(fun, mesh, in_axes, out_axes)
|
||||
in_jaxpr_avals = global_in_avals
|
||||
else:
|
||||
assert tiling_method is TilingMethod.VECTORIZE
|
||||
assert isinstance(tiling_method, TileVectorize)
|
||||
in_jaxpr_avals = in_tiled_avals
|
||||
with core.extend_axis_env_nd(mesh.shape.items()):
|
||||
with dispatch.log_elapsed_time(f"Finished tracing + transforming {name_stack} "
|
||||
|
@ -717,6 +717,14 @@ class XMapTestManualSPMD(ManualSPMDTestMixin, XMapTestCase):
|
||||
x = jnp.arange(20, dtype=jnp.float32)
|
||||
self.assertAllClose(fx(x), f(x))
|
||||
|
||||
@jtu.with_mesh([('x', 2)])
|
||||
def testReplicated(self):
|
||||
# TODO(apaszke): This seems to be failing if I try to have a replicated and a mapped argument?
|
||||
f = lambda x: jnp.sin(jnp.cos(x) + x) * x
|
||||
fx = xmap(f, in_axes=[...], out_axes=[...], axis_sizes={'i': 4}, axis_resources={'i': 'x'})
|
||||
x = jnp.arange(20, dtype=jnp.float32)
|
||||
self.assertAllClose(fx(x), f(x))
|
||||
|
||||
@jtu.with_mesh([('x', 2), ('y', 1)])
|
||||
def testInPJit(self):
|
||||
f = xmap(lambda x: jnp.sin(x) + x, in_axes=['i'], out_axes=['i'], axis_resources={'i': 'x'})
|
||||
@ -724,6 +732,13 @@ class XMapTestManualSPMD(ManualSPMDTestMixin, XMapTestCase):
|
||||
x = jnp.arange(20, dtype=jnp.float32)
|
||||
self.assertAllClose(h(x), jnp.sin(x * x) + x * x + x)
|
||||
|
||||
@jtu.with_mesh([('x', 2), ('y', 1)])
|
||||
def testInPJitReplicated(self):
|
||||
f = xmap(lambda x: jnp.sin(x) + x, in_axes={}, out_axes={}, axis_sizes={'i': 4}, axis_resources={'i': 'x'})
|
||||
h = pjit(lambda x: f(x * x) + x, in_axis_resources=P('y'), out_axis_resources=None)
|
||||
x = jnp.arange(20, dtype=jnp.float32)
|
||||
self.assertAllClose(h(x), jnp.sin(x * x) + x * x + x)
|
||||
|
||||
@jtu.with_mesh([('x', 2), ('y', 1)])
|
||||
def testNestedConstraint(self):
|
||||
# TODO(b/219691408): Using P('y') instead of P() causes an XLA crash!
|
||||
|
Loading…
x
Reference in New Issue
Block a user