[sharding_in_types] Initial support for partial-auto/explicit shard_map + sharding-in-types. If an axis in shmap(..., auto=...) is an Explicit axis in the outer mesh context, then that axis is treated as Explicit instead of Auto.

PiperOrigin-RevId: 728920514
This commit is contained in:
Yash Katariya 2025-02-19 20:04:19 -08:00 committed by jax authors
parent cb0d326e16
commit 8305803b76
7 changed files with 156 additions and 40 deletions

View File

@ -964,13 +964,14 @@ def shard_map_error_check(
new_in_names = (*([{}] * num_error_vals), *in_names)
new_vals_in = [*err_vals, *vals_in]
in_avals = list(map(core.get_aval, new_vals_in))
auto = kwargs.get('auto')
for i, v in enumerate(in_avals):
if not (sharder := core.shard_aval_handlers.get(type(v))):
raise ValueError(f'Unsupported aval type: {type(v)}')
in_avals[i] = sharder(mesh, new_in_names[i], v)
in_avals[i] = sharder(mesh, auto, new_in_names[i], v)
with (core.extend_axis_env_nd(mesh.shape.items()),
mesh_lib.set_abstract_mesh(shard_map._as_manual_mesh(mesh))):
with (shard_map._extend_axis_env(mesh, auto),
mesh_lib.set_abstract_mesh(shard_map._as_manual_mesh(mesh, auto))):
# jaxpr to checked_jaxpr
checked_jaxpr, out_tree, _ = jaxpr_to_checkify_jaxpr(
pe.close_jaxpr(jaxpr), enabled_errors, err_tree, *in_avals

View File

@ -1808,8 +1808,7 @@ def _maybe_modify_sharding(sharding, ndim):
out = sharding.with_spec(modify_spec_for_auto_manual(
sharding.spec, sharding.mesh))
if (len(out.spec) != ndim and
(out.mesh.empty or out.mesh._are_all_axes_auto or
out.mesh._are_all_axes_manual)):
(out.mesh.empty or out.mesh._are_all_axes_auto_or_manual)):
out = _make_lengths_same(out, ndim)
return out

View File

@ -68,8 +68,8 @@ def _get_abstract_mesh_from_avals(in_avals) -> mesh_lib.AbstractMesh:
def call_sharding_rule(prim, rule, num_out, *avals, **kwargs):
cur_mesh = mesh_lib.get_abstract_mesh()
aval_mesh = _get_abstract_mesh_from_avals(avals)
if ((cur_mesh.empty or cur_mesh._are_all_axes_auto or cur_mesh._are_all_axes_manual) and
(aval_mesh.empty or aval_mesh._are_all_axes_auto or aval_mesh._are_all_axes_manual)):
if ((cur_mesh.empty or cur_mesh._are_all_axes_auto_or_manual) and
(aval_mesh.empty or aval_mesh._are_all_axes_auto_or_manual)):
aval_mesh = cur_mesh if aval_mesh.empty else aval_mesh
s = NamedSharding(aval_mesh, P())
return s if num_out is None else [s] * num_out

View File

@ -159,6 +159,13 @@ class _BaseMesh:
def _are_all_axes_explicit(self) -> bool:
return all_axis_types_match(self.axis_types, AxisTypes.Explicit)
@functools.cached_property
def _are_all_axes_auto_or_manual(self) -> bool:
  """True iff every axis type on this mesh is Auto or Manual.

  Returns False when `axis_types` is empty (no axes at all), so an empty
  mesh does not count as "all auto/manual".
  """
  if not self.axis_types:
    return False
  # `axis_types` maps AxisTypes -> axis names; iterating the mapping yields
  # the type keys directly, so `.keys()` is unnecessary, and a membership
  # test replaces the chained `==` comparisons.
  return all(t in (AxisTypes.Auto, AxisTypes.Manual) for t in self.axis_types)
@functools.cached_property
def _any_axis_manual(self) -> bool:
return any_axis_types_match(self.axis_types, AxisTypes.Manual)
@ -173,6 +180,8 @@ class _BaseMesh:
@functools.cached_property
def axis_types(self):
if not self.axis_names:
return {}
d = collections.defaultdict(list)
for n, t in safe_zip(self.axis_names, self._axis_types_tuple):
d[t].append(n)

View File

@ -401,7 +401,7 @@ def shaped_array_ref(
shape: tuple[int, ...], dtype, weak_type: bool = False) -> AbstractRef:
return AbstractRef(core.ShapedArray(shape, dtype, weak_type=weak_type))
def _shard_ref(mesh, names, ref_aval: AbstractRef):
def _shard_ref(mesh, auto, names, ref_aval: AbstractRef):
del mesh
if names:
# Can't actually shard a ref, can only close over it.

View File

@ -479,9 +479,27 @@ shard_map_p = ShardMapPrimitive('shard_map')
# Staging
def _as_manual_mesh(mesh):
@util.cache(max_size=256, trace_context_in_key=True)
def _as_manual_mesh(mesh, auto: frozenset):
manual_axes = tuple(set(mesh.axis_names) - auto)
cur_mesh = get_abstract_mesh()
if cur_mesh.empty:
auto_axes = tuple(auto)
explicit_axes = ()
else:
explicit_axes, auto_axes = [], [] # type: ignore
for a in auto:
if cur_mesh._name_to_type[a] == AxisTypes.Auto:
auto_axes.append(a)
else:
assert cur_mesh._name_to_type[a] == AxisTypes.Explicit
explicit_axes.append(a)
explicit_axes, auto_axes = tuple(explicit_axes), tuple(auto_axes) # type: ignore
return AbstractMesh(
mesh.shape_tuple, axis_types={AxisTypes.Manual: mesh.axis_names})
mesh.shape_tuple,
axis_types={
AxisTypes.Manual: manual_axes, AxisTypes.Auto: auto_axes,
AxisTypes.Explicit: explicit_axes})
def _extend_axis_env(mesh, auto):
return core.extend_axis_env_nd([(k, v) for k, v in mesh.shape.items()
@ -498,10 +516,9 @@ def _shard_map_staging(
) -> Sequence[pe.DynamicJaxprTracer]:
in_tracers = map(trace.to_jaxpr_tracer, in_tracers)
in_avals = [t.aval for t in in_tracers]
in_avals_ = map(partial(_shard_aval, mesh), in_names, in_avals)
manual_mesh = _as_manual_mesh(mesh)
with (_extend_axis_env(mesh, auto),
set_abstract_mesh(manual_mesh)):
in_avals_ = map(partial(_shard_aval, mesh, auto), in_names, in_avals)
manual_mesh = _as_manual_mesh(mesh, auto)
with _extend_axis_env(mesh, auto), set_abstract_mesh(manual_mesh):
jaxpr, out_avals_, consts, () = pe.trace_to_jaxpr_dynamic(f, in_avals_)
_check_names(out_names_thunk(), out_avals_)
if check_rep:
@ -517,8 +534,7 @@ def _shard_map_staging(
constvars = map(trace.getvar, map(trace.to_jaxpr_tracer, consts))
outvars = map(trace.makevar, out_tracers)
in_names_staged = ({},) * len(consts) + tuple(in_names) # type: ignore
with (_extend_axis_env(mesh, auto),
set_abstract_mesh(manual_mesh)):
with _extend_axis_env(mesh, auto), set_abstract_mesh(manual_mesh):
jaxpr = pe.convert_constvars_jaxpr(jaxpr)
params = dict(mesh=mesh, in_names=in_names_staged,
out_names=tuple(out_names_thunk()), jaxpr=jaxpr,
@ -534,10 +550,10 @@ def _check_shapedarray(aval: core.AbstractValue) -> core.ShapedArray:
assert isinstance(aval, core.ShapedArray)
return aval
def _shard_aval(mesh: Mesh, names: AxisNames, aval: core.AbstractValue
def _shard_aval(mesh: Mesh, auto, names: AxisNames, aval: core.AbstractValue
) -> core.AbstractValue:
if type(aval) in core.shard_aval_handlers:
return core.shard_aval_handlers[type(aval)](mesh, names, aval)
return core.shard_aval_handlers[type(aval)](mesh, auto, names, aval)
raise NotImplementedError(f"Unsupported aval type: {type(aval)}")
def _unshard_aval(mesh: Mesh, names: AxisNames, aval: core.AbstractValue
@ -547,14 +563,13 @@ def _unshard_aval(mesh: Mesh, names: AxisNames, aval: core.AbstractValue
else:
raise NotImplementedError(f"Unsupported aval type: {type(aval)}")
def _shard_shaped_array(mesh: Mesh, names: AxisNames, aval: core.AbstractValue
) -> core.AbstractValue:
def _shard_shaped_array(mesh: Mesh, auto: frozenset, names: AxisNames,
aval: core.AbstractValue) -> core.AbstractValue:
assert isinstance(aval, core.ShapedArray)
new_shape = tuple(sz // prod(mesh.shape[n] for n in names.get(i, ()))
for i, sz in enumerate(aval.shape))
new_mesh = AbstractMesh(
mesh.shape_tuple, axis_types={AxisTypes.Manual: mesh.axis_names})
new_sharding = NamedSharding(new_mesh, P(*[None] * aval.ndim))
manual_mesh = _as_manual_mesh(mesh, auto)
new_sharding = NamedSharding(manual_mesh, aval.sharding.spec)
return aval.update(shape=new_shape, sharding=new_sharding)
core.shard_aval_handlers[core.ShapedArray] = _shard_shaped_array
@ -563,10 +578,27 @@ def _unshard_shaped_array(mesh: Mesh, names: AxisNames,
assert isinstance(aval, core.ShapedArray)
new_shape = tuple(sz * prod(mesh.shape[n] for n in names.get(i, ()))
for i, sz in enumerate(aval.shape))
spec = _names_to_pspec(names)._normalized_spec_for_aval(aval.ndim)
names_spec = _names_to_pspec(names)._normalized_spec_for_aval(aval.ndim)
if aval.ndim == 0:
out_spec = names_spec
else:
out_spec = [] # type: ignore
for name_s, aval_s in zip(names_spec, aval.sharding.spec):
if name_s and not aval_s:
out_spec.append(name_s)
elif aval_s and not name_s:
out_spec.append(aval_s)
elif not name_s and not aval_s:
out_spec.append(None)
else:
assert name_s and aval_s
name_s = name_s if isinstance(name_s, tuple) else (name_s,)
aval_s = aval_s if isinstance(aval_s, tuple) else (aval_s,)
out_spec.append(name_s + aval_s)
out_spec = PartitionSpec(*out_spec)
new_mesh = (mesh.abstract_mesh if get_abstract_mesh().empty else
get_abstract_mesh())
new_sharding = NamedSharding(new_mesh, spec)
new_sharding = NamedSharding(new_mesh, out_spec)
return aval.update(shape=new_shape, sharding=new_sharding)
core.unshard_aval_handlers[core.ShapedArray] = _unshard_shaped_array
@ -578,7 +610,7 @@ def _shard_map_typecheck(_, *in_atoms, jaxpr, mesh, in_names, out_names,
check_rep, rewrite, auto):
# TODO(mattjj,parkers): check auto
for v, x, in_name in zip(jaxpr.invars, in_atoms, in_names):
if not core.typecompat(v.aval, _shard_aval(mesh, in_name, x.aval)):
if not core.typecompat(v.aval, _shard_aval(mesh, auto, in_name, x.aval)):
raise core.JaxprTypeError("shard_map argument avals not compatible with "
"jaxpr binder avals and in_names")
with _extend_axis_env(mesh, auto):
@ -702,9 +734,11 @@ def _shard_map_lowering_shardy(
def _shard_map_lowering(ctx, *in_nodes, jaxpr, mesh, in_names, out_names,
check_rep, rewrite, auto):
del check_rep, rewrite
if config.use_shardy_partitioner.value:
return _shard_map_lowering_shardy(
ctx, in_nodes, jaxpr, mesh, in_names, out_names, auto)
in_avals_ = [v.aval for v in jaxpr.invars]
out_avals_ = [x.aval for x in jaxpr.outvars]
in_nodes_ = map(partial(_xla_shard, ctx, mesh, auto), in_names, ctx.avals_in,
@ -814,9 +848,8 @@ core.EvalTrace.process_shard_map = _shard_map_impl
def _run_shmap(f, mesh, auto, args, reps, check_rep, context_mesh):
trace = ShardMapTrace(mesh, auto, check_rep, context_mesh)
in_tracers = map(partial(ShardMapTracer, trace), reps, args)
manual_mesh = _as_manual_mesh(mesh)
with (core.set_current_trace(trace),
_extend_axis_env(mesh, auto),
manual_mesh = _as_manual_mesh(mesh, auto)
with (core.set_current_trace(trace), _extend_axis_env(mesh, auto),
set_abstract_mesh(manual_mesh)):
ans = f.call_wrapped(*in_tracers)
outs, out_rep = unzip2(map(trace.to_val_rep_pair, ans))
@ -971,8 +1004,9 @@ class ShardMapTracer(core.Tracer):
def aval(self):
aval = core.get_aval(self.val)
out = core.mapped_aval(self._trace.mesh.size, 0, aval)
new_sharding = NamedSharding(_as_manual_mesh(self._trace.mesh),
out.sharding.spec) # pytype: disable=attribute-error
new_sharding = NamedSharding(
_as_manual_mesh(self._trace.mesh, self._trace.auto),
out.sharding.spec) # pytype: disable=attribute-error
return out.update(sharding=new_sharding)
def to_concrete_value(self):
@ -1526,7 +1560,7 @@ def _shard_map_partial_eval(trace: pe.JaxprTrace, shard_map_p,
in_knowns, in_avals, in_consts = pe.partition_pvals(in_pvals)
unk_in_names, known_in_names = pe.partition_list(in_knowns, in_names)
all_names = _all_newly_manual_mesh_names(mesh, auto, trace)
in_avals_sharded = map(partial(_shard_aval, mesh), unk_in_names, in_avals)
in_avals_sharded = map(partial(_shard_aval, mesh, auto), unk_in_names, in_avals)
f = pe.trace_to_subjaxpr_nounits_fwd2(f, trace.tag, f.debug_info, False)
f = _promote_scalar_residuals(f)
f_known, aux = pe.partial_eval_wrapper_nounits(
@ -1675,12 +1709,12 @@ def _shard_map_transpose(out_cts, *args,
jaxpr: core.Jaxpr, mesh, in_names, out_names,
check_rep, rewrite, auto):
mb_div = lambda x, y: x / y if y != 1 else x
out_cts = [ad.Zero(_shard_aval(mesh, ns, x.aval)) if type(x) is ad.Zero
out_cts = [ad.Zero(_shard_aval(mesh, auto, ns, x.aval)) if type(x) is ad.Zero
else x if rewrite or dtypes.dtype(x) == dtypes.float0
else mb_div(x, prod(map(mesh.shape.get, _unmentioned2(mesh, ns, auto))))
for ns, x in zip(out_names, out_cts)]
args = tuple(x if type(x) is not ad.UndefinedPrimal else
ad.UndefinedPrimal(_shard_aval(mesh, ns, x.aval))
ad.UndefinedPrimal(_shard_aval(mesh, auto, ns, x.aval))
for ns, x in zip(in_names, args))
all_args, in_tree = tree_flatten((out_cts, args))
@ -1693,9 +1727,9 @@ def _shard_map_transpose(out_cts, *args,
jaxpr_unknown.jaxpr, False, (), (*res_reshaped, *undefs), out_cts
)
out = [ad.Zero(_unshard_aval(mesh, ns, x.aval)) if type(x) is ad.Zero
else x if rewrite
else jax.lax.psum(x, tuple(_unmentioned2(mesh, ns, auto)))
for ns, x in zip(in_names, out)]
else x if rewrite
else jax.lax.psum(x, tuple(_unmentioned2(mesh, ns, auto)))
for ns, x in zip(in_names, out)]
return out
fun_trans = lu.wrap_init(fun_trans_callable,
@ -1748,7 +1782,7 @@ def _partial_eval_jaxpr_custom_rule(
which = [f1 is None and f2 is None for f1, f2 in zip(in_fwd, out_fwd)]
mesh = eqn.params['mesh']
with (_extend_axis_env(mesh, auto),
set_abstract_mesh(_as_manual_mesh(mesh))):
set_abstract_mesh(_as_manual_mesh(mesh, auto))):
jaxpr_known = pe.prune_jaxpr_outputs(jaxpr_known, [True] * num_out_primals + which)
jaxpr_known, jaxpr_staged = _add_reshapes(which, jaxpr_known, jaxpr_staged)
jaxpr_known = core.remove_named_axis_effects(jaxpr_known, mesh.axis_names)

View File

@ -40,7 +40,7 @@ from jax._src import test_util as jtu
from jax._src.lib.mlir.dialects import sdy
from jax._src.util import safe_zip, safe_map, partition_list, merge_lists
from jax._src.ad_checkpoint import saved_residuals
from jax._src.mesh import AbstractMesh
from jax._src.mesh import AbstractMesh, AxisTypes
from jax._src.interpreters import partial_eval as pe
from jax._src import linear_util as lu
from jax._src import tree_util
@ -1890,6 +1890,8 @@ class ShardMapTest(jtu.JaxTestCase):
mesh = jtu.create_mesh((2, 2), ('i', 'j'))
def g(x):
self.assertDictEqual(x.aval.sharding.mesh.axis_types,
{AxisTypes.Manual: ('i',), AxisTypes.Auto: ('j',)})
x = jax.lax.with_sharding_constraint(
x, jax.sharding.NamedSharding(mesh, P(None, 'j')))
return x * x
@ -1917,7 +1919,78 @@ class ShardMapTest(jtu.JaxTestCase):
' replicated}}',
f.lower(v).as_text('hlo'),
)
self.assertAllClose(v*v, f(v), check_dtypes=False)
self.assertAllClose(v * v, f(v), check_dtypes=False)
# Tests that an axis passed in shard_map's `auto=` set which is Explicit in
# the outer user mesh stays Explicit (not Auto) inside the shard_map body.
@jtu.with_user_mesh((2, 2), ('i', 'j'))
def test_partial_auto_explicit(self, mesh):
def g(x):
# Inside shard_map: 'i' is Manual (it was mapped), 'j' stays Explicit.
self.assertDictEqual(x.aval.sharding.mesh.axis_types,
{AxisTypes.Manual: ('i',), AxisTypes.Explicit: ('j',)})
self.assertEqual(x.aval.sharding.spec, P(None, 'j'))
out = x * x
# Elementwise multiply preserves the explicit 'j' sharding.
self.assertEqual(out.aval.sharding.spec, P(None, 'j'))
return out
@jax.jit
def f(x):
x = shard_map(g, mesh,
in_specs=P('i', None),
out_specs=P('i', None),
auto=frozenset({'j'}))(x)
# After shard_map the result is sharded over both axes again.
self.assertEqual(x.aval.sharding.spec, P('i', 'j'))
return x
v = jnp.arange(32.).reshape(4, 8)
v = jax.device_put(v, jax.sharding.NamedSharding(mesh, P('i', 'j')))
out = f(v)
self.assertEqual(out.sharding, NamedSharding(mesh, P('i', 'j')))
self.assertAllClose(v * v, out, check_dtypes=False)
# Lowered text differs by partitioner backend (Shardy vs GSPMD/MHLO).
if config.use_shardy_partitioner.value:
self.assertIn(
'sdy.sharding_constraint %1 <@mesh, [{}, {"j"}]>',
f.lower(v).as_text(),
)
else:
self.assertIn(
'mhlo.sharding = "{devices=[1,2,2]<=[2,2]T(1,0) last_tile_dims={manual}}"}',
f.lower(v).as_text(),
)
@jax.jit
def h(x):
return jnp.sum(f(x))
# Autodiff through the partial-auto/explicit shard_map should not error.
jax.grad(h)(v) # doesn't crash
jax.jit(jax.grad(h))(v) # doesn't crash
# Same as test_partial_auto_explicit but with two Explicit axes ('k', 'l')
# kept out of the manual mapping, exercising multi-axis Explicit handling.
@jtu.with_user_mesh((2, 1, 2, 2), ('i', 'j', 'k', 'l'))
def test_partial_auto_explicit_multi_explicit(self, mesh):
def g(x):
# 'i'/'j' were mapped (Manual); 'k'/'l' remain Explicit inside the body.
self.assertDictEqual(x.aval.sharding.mesh.axis_types,
{AxisTypes.Manual: ('i', 'j'),
AxisTypes.Explicit: ('k', 'l')})
self.assertEqual(x.aval.sharding.spec, P(None, None, 'k', 'l'))
out = x.T
# Transpose reverses the dims, carrying the explicit axes with them.
self.assertEqual(out.aval.sharding.spec, P('l', 'k', None, None))
return out
@jax.jit
def f(x):
x = shard_map(g, mesh,
in_specs=P('i', 'j', None, None),
out_specs=P('i', 'j', None, None),
auto=frozenset({'k', 'l'}))(x)
# Output combines the out_specs axes with the explicit axes per dim.
self.assertEqual(x.aval.sharding.spec, P(('i', 'l'), ('j', 'k'), None, None))
return x
v = jnp.arange(64.).reshape(4, 2, 2, 4)
v = jax.device_put(v, jax.sharding.NamedSharding(mesh, P('i', 'j', 'k', 'l')))
out = f(v)
self.assertEqual(
out.sharding, NamedSharding(mesh, P(('i', 'l'), ('j', 'k'), None, None)))
def test_partial_auto_propagate_through(self):
mesh = jtu.create_mesh((2, 2, 2), ('i', 'j', 'k'))