mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
[sharding_in_types] Initial support for partial-auto/explicit shard_map + sharding-in-types. If the axes in shmap(..., auto=...)
is an explicit axes in the outer mesh context, then that axis is treated as Explicit instead of Auto.
PiperOrigin-RevId: 728920514
This commit is contained in:
parent
cb0d326e16
commit
8305803b76
@ -964,13 +964,14 @@ def shard_map_error_check(
|
||||
new_in_names = (*([{}] * num_error_vals), *in_names)
|
||||
new_vals_in = [*err_vals, *vals_in]
|
||||
in_avals = list(map(core.get_aval, new_vals_in))
|
||||
auto = kwargs.get('auto')
|
||||
for i, v in enumerate(in_avals):
|
||||
if not (sharder := core.shard_aval_handlers.get(type(v))):
|
||||
raise ValueError(f'Unsupported aval type: {type(v)}')
|
||||
in_avals[i] = sharder(mesh, new_in_names[i], v)
|
||||
in_avals[i] = sharder(mesh, auto, new_in_names[i], v)
|
||||
|
||||
with (core.extend_axis_env_nd(mesh.shape.items()),
|
||||
mesh_lib.set_abstract_mesh(shard_map._as_manual_mesh(mesh))):
|
||||
with (shard_map._extend_axis_env(mesh, auto),
|
||||
mesh_lib.set_abstract_mesh(shard_map._as_manual_mesh(mesh, auto))):
|
||||
# jaxpr to checked_jaxpr
|
||||
checked_jaxpr, out_tree, _ = jaxpr_to_checkify_jaxpr(
|
||||
pe.close_jaxpr(jaxpr), enabled_errors, err_tree, *in_avals
|
||||
|
@ -1808,8 +1808,7 @@ def _maybe_modify_sharding(sharding, ndim):
|
||||
out = sharding.with_spec(modify_spec_for_auto_manual(
|
||||
sharding.spec, sharding.mesh))
|
||||
if (len(out.spec) != ndim and
|
||||
(out.mesh.empty or out.mesh._are_all_axes_auto or
|
||||
out.mesh._are_all_axes_manual)):
|
||||
(out.mesh.empty or out.mesh._are_all_axes_auto_or_manual)):
|
||||
out = _make_lengths_same(out, ndim)
|
||||
return out
|
||||
|
||||
|
@ -68,8 +68,8 @@ def _get_abstract_mesh_from_avals(in_avals) -> mesh_lib.AbstractMesh:
|
||||
def call_sharding_rule(prim, rule, num_out, *avals, **kwargs):
|
||||
cur_mesh = mesh_lib.get_abstract_mesh()
|
||||
aval_mesh = _get_abstract_mesh_from_avals(avals)
|
||||
if ((cur_mesh.empty or cur_mesh._are_all_axes_auto or cur_mesh._are_all_axes_manual) and
|
||||
(aval_mesh.empty or aval_mesh._are_all_axes_auto or aval_mesh._are_all_axes_manual)):
|
||||
if ((cur_mesh.empty or cur_mesh._are_all_axes_auto_or_manual) and
|
||||
(aval_mesh.empty or aval_mesh._are_all_axes_auto_or_manual)):
|
||||
aval_mesh = cur_mesh if aval_mesh.empty else aval_mesh
|
||||
s = NamedSharding(aval_mesh, P())
|
||||
return s if num_out is None else [s] * num_out
|
||||
|
@ -159,6 +159,13 @@ class _BaseMesh:
|
||||
def _are_all_axes_explicit(self) -> bool:
|
||||
return all_axis_types_match(self.axis_types, AxisTypes.Explicit)
|
||||
|
||||
@functools.cached_property
|
||||
def _are_all_axes_auto_or_manual(self) -> bool:
|
||||
if not self.axis_types:
|
||||
return False
|
||||
return all(t == AxisTypes.Auto or t == AxisTypes.Manual
|
||||
for t in self.axis_types.keys())
|
||||
|
||||
@functools.cached_property
|
||||
def _any_axis_manual(self) -> bool:
|
||||
return any_axis_types_match(self.axis_types, AxisTypes.Manual)
|
||||
@ -173,6 +180,8 @@ class _BaseMesh:
|
||||
|
||||
@functools.cached_property
|
||||
def axis_types(self):
|
||||
if not self.axis_names:
|
||||
return {}
|
||||
d = collections.defaultdict(list)
|
||||
for n, t in safe_zip(self.axis_names, self._axis_types_tuple):
|
||||
d[t].append(n)
|
||||
|
@ -401,7 +401,7 @@ def shaped_array_ref(
|
||||
shape: tuple[int, ...], dtype, weak_type: bool = False) -> AbstractRef:
|
||||
return AbstractRef(core.ShapedArray(shape, dtype, weak_type=weak_type))
|
||||
|
||||
def _shard_ref(mesh, names, ref_aval: AbstractRef):
|
||||
def _shard_ref(mesh, auto, names, ref_aval: AbstractRef):
|
||||
del mesh
|
||||
if names:
|
||||
# Can't actually shard a ref, can only close over it.
|
||||
|
@ -479,9 +479,27 @@ shard_map_p = ShardMapPrimitive('shard_map')
|
||||
|
||||
# Staging
|
||||
|
||||
def _as_manual_mesh(mesh):
|
||||
@util.cache(max_size=256, trace_context_in_key=True)
|
||||
def _as_manual_mesh(mesh, auto: frozenset):
|
||||
manual_axes = tuple(set(mesh.axis_names) - auto)
|
||||
cur_mesh = get_abstract_mesh()
|
||||
if cur_mesh.empty:
|
||||
auto_axes = tuple(auto)
|
||||
explicit_axes = ()
|
||||
else:
|
||||
explicit_axes, auto_axes = [], [] # type: ignore
|
||||
for a in auto:
|
||||
if cur_mesh._name_to_type[a] == AxisTypes.Auto:
|
||||
auto_axes.append(a)
|
||||
else:
|
||||
assert cur_mesh._name_to_type[a] == AxisTypes.Explicit
|
||||
explicit_axes.append(a)
|
||||
explicit_axes, auto_axes = tuple(explicit_axes), tuple(auto_axes) # type: ignore
|
||||
return AbstractMesh(
|
||||
mesh.shape_tuple, axis_types={AxisTypes.Manual: mesh.axis_names})
|
||||
mesh.shape_tuple,
|
||||
axis_types={
|
||||
AxisTypes.Manual: manual_axes, AxisTypes.Auto: auto_axes,
|
||||
AxisTypes.Explicit: explicit_axes})
|
||||
|
||||
def _extend_axis_env(mesh, auto):
|
||||
return core.extend_axis_env_nd([(k, v) for k, v in mesh.shape.items()
|
||||
@ -498,10 +516,9 @@ def _shard_map_staging(
|
||||
) -> Sequence[pe.DynamicJaxprTracer]:
|
||||
in_tracers = map(trace.to_jaxpr_tracer, in_tracers)
|
||||
in_avals = [t.aval for t in in_tracers]
|
||||
in_avals_ = map(partial(_shard_aval, mesh), in_names, in_avals)
|
||||
manual_mesh = _as_manual_mesh(mesh)
|
||||
with (_extend_axis_env(mesh, auto),
|
||||
set_abstract_mesh(manual_mesh)):
|
||||
in_avals_ = map(partial(_shard_aval, mesh, auto), in_names, in_avals)
|
||||
manual_mesh = _as_manual_mesh(mesh, auto)
|
||||
with _extend_axis_env(mesh, auto), set_abstract_mesh(manual_mesh):
|
||||
jaxpr, out_avals_, consts, () = pe.trace_to_jaxpr_dynamic(f, in_avals_)
|
||||
_check_names(out_names_thunk(), out_avals_)
|
||||
if check_rep:
|
||||
@ -517,8 +534,7 @@ def _shard_map_staging(
|
||||
constvars = map(trace.getvar, map(trace.to_jaxpr_tracer, consts))
|
||||
outvars = map(trace.makevar, out_tracers)
|
||||
in_names_staged = ({},) * len(consts) + tuple(in_names) # type: ignore
|
||||
with (_extend_axis_env(mesh, auto),
|
||||
set_abstract_mesh(manual_mesh)):
|
||||
with _extend_axis_env(mesh, auto), set_abstract_mesh(manual_mesh):
|
||||
jaxpr = pe.convert_constvars_jaxpr(jaxpr)
|
||||
params = dict(mesh=mesh, in_names=in_names_staged,
|
||||
out_names=tuple(out_names_thunk()), jaxpr=jaxpr,
|
||||
@ -534,10 +550,10 @@ def _check_shapedarray(aval: core.AbstractValue) -> core.ShapedArray:
|
||||
assert isinstance(aval, core.ShapedArray)
|
||||
return aval
|
||||
|
||||
def _shard_aval(mesh: Mesh, names: AxisNames, aval: core.AbstractValue
|
||||
def _shard_aval(mesh: Mesh, auto, names: AxisNames, aval: core.AbstractValue
|
||||
) -> core.AbstractValue:
|
||||
if type(aval) in core.shard_aval_handlers:
|
||||
return core.shard_aval_handlers[type(aval)](mesh, names, aval)
|
||||
return core.shard_aval_handlers[type(aval)](mesh, auto, names, aval)
|
||||
raise NotImplementedError(f"Unsupported aval type: {type(aval)}")
|
||||
|
||||
def _unshard_aval(mesh: Mesh, names: AxisNames, aval: core.AbstractValue
|
||||
@ -547,14 +563,13 @@ def _unshard_aval(mesh: Mesh, names: AxisNames, aval: core.AbstractValue
|
||||
else:
|
||||
raise NotImplementedError(f"Unsupported aval type: {type(aval)}")
|
||||
|
||||
def _shard_shaped_array(mesh: Mesh, names: AxisNames, aval: core.AbstractValue
|
||||
) -> core.AbstractValue:
|
||||
def _shard_shaped_array(mesh: Mesh, auto: frozenset, names: AxisNames,
|
||||
aval: core.AbstractValue) -> core.AbstractValue:
|
||||
assert isinstance(aval, core.ShapedArray)
|
||||
new_shape = tuple(sz // prod(mesh.shape[n] for n in names.get(i, ()))
|
||||
for i, sz in enumerate(aval.shape))
|
||||
new_mesh = AbstractMesh(
|
||||
mesh.shape_tuple, axis_types={AxisTypes.Manual: mesh.axis_names})
|
||||
new_sharding = NamedSharding(new_mesh, P(*[None] * aval.ndim))
|
||||
manual_mesh = _as_manual_mesh(mesh, auto)
|
||||
new_sharding = NamedSharding(manual_mesh, aval.sharding.spec)
|
||||
return aval.update(shape=new_shape, sharding=new_sharding)
|
||||
core.shard_aval_handlers[core.ShapedArray] = _shard_shaped_array
|
||||
|
||||
@ -563,10 +578,27 @@ def _unshard_shaped_array(mesh: Mesh, names: AxisNames,
|
||||
assert isinstance(aval, core.ShapedArray)
|
||||
new_shape = tuple(sz * prod(mesh.shape[n] for n in names.get(i, ()))
|
||||
for i, sz in enumerate(aval.shape))
|
||||
spec = _names_to_pspec(names)._normalized_spec_for_aval(aval.ndim)
|
||||
names_spec = _names_to_pspec(names)._normalized_spec_for_aval(aval.ndim)
|
||||
if aval.ndim == 0:
|
||||
out_spec = names_spec
|
||||
else:
|
||||
out_spec = [] # type: ignore
|
||||
for name_s, aval_s in zip(names_spec, aval.sharding.spec):
|
||||
if name_s and not aval_s:
|
||||
out_spec.append(name_s)
|
||||
elif aval_s and not name_s:
|
||||
out_spec.append(aval_s)
|
||||
elif not name_s and not aval_s:
|
||||
out_spec.append(None)
|
||||
else:
|
||||
assert name_s and aval_s
|
||||
name_s = name_s if isinstance(name_s, tuple) else (name_s,)
|
||||
aval_s = aval_s if isinstance(aval_s, tuple) else (aval_s,)
|
||||
out_spec.append(name_s + aval_s)
|
||||
out_spec = PartitionSpec(*out_spec)
|
||||
new_mesh = (mesh.abstract_mesh if get_abstract_mesh().empty else
|
||||
get_abstract_mesh())
|
||||
new_sharding = NamedSharding(new_mesh, spec)
|
||||
new_sharding = NamedSharding(new_mesh, out_spec)
|
||||
return aval.update(shape=new_shape, sharding=new_sharding)
|
||||
core.unshard_aval_handlers[core.ShapedArray] = _unshard_shaped_array
|
||||
|
||||
@ -578,7 +610,7 @@ def _shard_map_typecheck(_, *in_atoms, jaxpr, mesh, in_names, out_names,
|
||||
check_rep, rewrite, auto):
|
||||
# TODO(mattjj,parkers): check auto
|
||||
for v, x, in_name in zip(jaxpr.invars, in_atoms, in_names):
|
||||
if not core.typecompat(v.aval, _shard_aval(mesh, in_name, x.aval)):
|
||||
if not core.typecompat(v.aval, _shard_aval(mesh, auto, in_name, x.aval)):
|
||||
raise core.JaxprTypeError("shard_map argument avals not compatible with "
|
||||
"jaxpr binder avals and in_names")
|
||||
with _extend_axis_env(mesh, auto):
|
||||
@ -702,9 +734,11 @@ def _shard_map_lowering_shardy(
|
||||
def _shard_map_lowering(ctx, *in_nodes, jaxpr, mesh, in_names, out_names,
|
||||
check_rep, rewrite, auto):
|
||||
del check_rep, rewrite
|
||||
|
||||
if config.use_shardy_partitioner.value:
|
||||
return _shard_map_lowering_shardy(
|
||||
ctx, in_nodes, jaxpr, mesh, in_names, out_names, auto)
|
||||
|
||||
in_avals_ = [v.aval for v in jaxpr.invars]
|
||||
out_avals_ = [x.aval for x in jaxpr.outvars]
|
||||
in_nodes_ = map(partial(_xla_shard, ctx, mesh, auto), in_names, ctx.avals_in,
|
||||
@ -814,9 +848,8 @@ core.EvalTrace.process_shard_map = _shard_map_impl
|
||||
def _run_shmap(f, mesh, auto, args, reps, check_rep, context_mesh):
|
||||
trace = ShardMapTrace(mesh, auto, check_rep, context_mesh)
|
||||
in_tracers = map(partial(ShardMapTracer, trace), reps, args)
|
||||
manual_mesh = _as_manual_mesh(mesh)
|
||||
with (core.set_current_trace(trace),
|
||||
_extend_axis_env(mesh, auto),
|
||||
manual_mesh = _as_manual_mesh(mesh, auto)
|
||||
with (core.set_current_trace(trace), _extend_axis_env(mesh, auto),
|
||||
set_abstract_mesh(manual_mesh)):
|
||||
ans = f.call_wrapped(*in_tracers)
|
||||
outs, out_rep = unzip2(map(trace.to_val_rep_pair, ans))
|
||||
@ -971,8 +1004,9 @@ class ShardMapTracer(core.Tracer):
|
||||
def aval(self):
|
||||
aval = core.get_aval(self.val)
|
||||
out = core.mapped_aval(self._trace.mesh.size, 0, aval)
|
||||
new_sharding = NamedSharding(_as_manual_mesh(self._trace.mesh),
|
||||
out.sharding.spec) # pytype: disable=attribute-error
|
||||
new_sharding = NamedSharding(
|
||||
_as_manual_mesh(self._trace.mesh, self._trace.auto),
|
||||
out.sharding.spec) # pytype: disable=attribute-error
|
||||
return out.update(sharding=new_sharding)
|
||||
|
||||
def to_concrete_value(self):
|
||||
@ -1526,7 +1560,7 @@ def _shard_map_partial_eval(trace: pe.JaxprTrace, shard_map_p,
|
||||
in_knowns, in_avals, in_consts = pe.partition_pvals(in_pvals)
|
||||
unk_in_names, known_in_names = pe.partition_list(in_knowns, in_names)
|
||||
all_names = _all_newly_manual_mesh_names(mesh, auto, trace)
|
||||
in_avals_sharded = map(partial(_shard_aval, mesh), unk_in_names, in_avals)
|
||||
in_avals_sharded = map(partial(_shard_aval, mesh, auto), unk_in_names, in_avals)
|
||||
f = pe.trace_to_subjaxpr_nounits_fwd2(f, trace.tag, f.debug_info, False)
|
||||
f = _promote_scalar_residuals(f)
|
||||
f_known, aux = pe.partial_eval_wrapper_nounits(
|
||||
@ -1675,12 +1709,12 @@ def _shard_map_transpose(out_cts, *args,
|
||||
jaxpr: core.Jaxpr, mesh, in_names, out_names,
|
||||
check_rep, rewrite, auto):
|
||||
mb_div = lambda x, y: x / y if y != 1 else x
|
||||
out_cts = [ad.Zero(_shard_aval(mesh, ns, x.aval)) if type(x) is ad.Zero
|
||||
out_cts = [ad.Zero(_shard_aval(mesh, auto, ns, x.aval)) if type(x) is ad.Zero
|
||||
else x if rewrite or dtypes.dtype(x) == dtypes.float0
|
||||
else mb_div(x, prod(map(mesh.shape.get, _unmentioned2(mesh, ns, auto))))
|
||||
for ns, x in zip(out_names, out_cts)]
|
||||
args = tuple(x if type(x) is not ad.UndefinedPrimal else
|
||||
ad.UndefinedPrimal(_shard_aval(mesh, ns, x.aval))
|
||||
ad.UndefinedPrimal(_shard_aval(mesh, auto, ns, x.aval))
|
||||
for ns, x in zip(in_names, args))
|
||||
all_args, in_tree = tree_flatten((out_cts, args))
|
||||
|
||||
@ -1693,9 +1727,9 @@ def _shard_map_transpose(out_cts, *args,
|
||||
jaxpr_unknown.jaxpr, False, (), (*res_reshaped, *undefs), out_cts
|
||||
)
|
||||
out = [ad.Zero(_unshard_aval(mesh, ns, x.aval)) if type(x) is ad.Zero
|
||||
else x if rewrite
|
||||
else jax.lax.psum(x, tuple(_unmentioned2(mesh, ns, auto)))
|
||||
for ns, x in zip(in_names, out)]
|
||||
else x if rewrite
|
||||
else jax.lax.psum(x, tuple(_unmentioned2(mesh, ns, auto)))
|
||||
for ns, x in zip(in_names, out)]
|
||||
return out
|
||||
|
||||
fun_trans = lu.wrap_init(fun_trans_callable,
|
||||
@ -1748,7 +1782,7 @@ def _partial_eval_jaxpr_custom_rule(
|
||||
which = [f1 is None and f2 is None for f1, f2 in zip(in_fwd, out_fwd)]
|
||||
mesh = eqn.params['mesh']
|
||||
with (_extend_axis_env(mesh, auto),
|
||||
set_abstract_mesh(_as_manual_mesh(mesh))):
|
||||
set_abstract_mesh(_as_manual_mesh(mesh, auto))):
|
||||
jaxpr_known = pe.prune_jaxpr_outputs(jaxpr_known, [True] * num_out_primals + which)
|
||||
jaxpr_known, jaxpr_staged = _add_reshapes(which, jaxpr_known, jaxpr_staged)
|
||||
jaxpr_known = core.remove_named_axis_effects(jaxpr_known, mesh.axis_names)
|
||||
|
@ -40,7 +40,7 @@ from jax._src import test_util as jtu
|
||||
from jax._src.lib.mlir.dialects import sdy
|
||||
from jax._src.util import safe_zip, safe_map, partition_list, merge_lists
|
||||
from jax._src.ad_checkpoint import saved_residuals
|
||||
from jax._src.mesh import AbstractMesh
|
||||
from jax._src.mesh import AbstractMesh, AxisTypes
|
||||
from jax._src.interpreters import partial_eval as pe
|
||||
from jax._src import linear_util as lu
|
||||
from jax._src import tree_util
|
||||
@ -1890,6 +1890,8 @@ class ShardMapTest(jtu.JaxTestCase):
|
||||
mesh = jtu.create_mesh((2, 2), ('i', 'j'))
|
||||
|
||||
def g(x):
|
||||
self.assertDictEqual(x.aval.sharding.mesh.axis_types,
|
||||
{AxisTypes.Manual: ('i',), AxisTypes.Auto: ('j',)})
|
||||
x = jax.lax.with_sharding_constraint(
|
||||
x, jax.sharding.NamedSharding(mesh, P(None, 'j')))
|
||||
return x * x
|
||||
@ -1917,7 +1919,78 @@ class ShardMapTest(jtu.JaxTestCase):
|
||||
' replicated}}',
|
||||
f.lower(v).as_text('hlo'),
|
||||
)
|
||||
self.assertAllClose(v*v, f(v), check_dtypes=False)
|
||||
self.assertAllClose(v * v, f(v), check_dtypes=False)
|
||||
|
||||
@jtu.with_user_mesh((2, 2), ('i', 'j'))
|
||||
def test_partial_auto_explicit(self, mesh):
|
||||
def g(x):
|
||||
self.assertDictEqual(x.aval.sharding.mesh.axis_types,
|
||||
{AxisTypes.Manual: ('i',), AxisTypes.Explicit: ('j',)})
|
||||
self.assertEqual(x.aval.sharding.spec, P(None, 'j'))
|
||||
out = x * x
|
||||
self.assertEqual(out.aval.sharding.spec, P(None, 'j'))
|
||||
return out
|
||||
|
||||
@jax.jit
|
||||
def f(x):
|
||||
x = shard_map(g, mesh,
|
||||
in_specs=P('i', None),
|
||||
out_specs=P('i', None),
|
||||
auto=frozenset({'j'}))(x)
|
||||
self.assertEqual(x.aval.sharding.spec, P('i', 'j'))
|
||||
return x
|
||||
|
||||
v = jnp.arange(32.).reshape(4, 8)
|
||||
v = jax.device_put(v, jax.sharding.NamedSharding(mesh, P('i', 'j')))
|
||||
|
||||
out = f(v)
|
||||
self.assertEqual(out.sharding, NamedSharding(mesh, P('i', 'j')))
|
||||
self.assertAllClose(v * v, out, check_dtypes=False)
|
||||
|
||||
if config.use_shardy_partitioner.value:
|
||||
self.assertIn(
|
||||
'sdy.sharding_constraint %1 <@mesh, [{}, {"j"}]>',
|
||||
f.lower(v).as_text(),
|
||||
)
|
||||
else:
|
||||
self.assertIn(
|
||||
'mhlo.sharding = "{devices=[1,2,2]<=[2,2]T(1,0) last_tile_dims={manual}}"}',
|
||||
f.lower(v).as_text(),
|
||||
)
|
||||
|
||||
@jax.jit
|
||||
def h(x):
|
||||
return jnp.sum(f(x))
|
||||
|
||||
jax.grad(h)(v) # doesn't crash
|
||||
jax.jit(jax.grad(h))(v) # doesn't crash
|
||||
|
||||
@jtu.with_user_mesh((2, 1, 2, 2), ('i', 'j', 'k', 'l'))
|
||||
def test_partial_auto_explicit_multi_explicit(self, mesh):
|
||||
def g(x):
|
||||
self.assertDictEqual(x.aval.sharding.mesh.axis_types,
|
||||
{AxisTypes.Manual: ('i', 'j'),
|
||||
AxisTypes.Explicit: ('k', 'l')})
|
||||
self.assertEqual(x.aval.sharding.spec, P(None, None, 'k', 'l'))
|
||||
out = x.T
|
||||
self.assertEqual(out.aval.sharding.spec, P('l', 'k', None, None))
|
||||
return out
|
||||
|
||||
@jax.jit
|
||||
def f(x):
|
||||
x = shard_map(g, mesh,
|
||||
in_specs=P('i', 'j', None, None),
|
||||
out_specs=P('i', 'j', None, None),
|
||||
auto=frozenset({'k', 'l'}))(x)
|
||||
self.assertEqual(x.aval.sharding.spec, P(('i', 'l'), ('j', 'k'), None, None))
|
||||
return x
|
||||
|
||||
v = jnp.arange(64.).reshape(4, 2, 2, 4)
|
||||
v = jax.device_put(v, jax.sharding.NamedSharding(mesh, P('i', 'j', 'k', 'l')))
|
||||
|
||||
out = f(v)
|
||||
self.assertEqual(
|
||||
out.sharding, NamedSharding(mesh, P(('i', 'l'), ('j', 'k'), None, None)))
|
||||
|
||||
def test_partial_auto_propagate_through(self):
|
||||
mesh = jtu.create_mesh((2, 2, 2), ('i', 'j', 'k'))
|
||||
|
Loading…
x
Reference in New Issue
Block a user