mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Change the internals of with_sharding_constraint
to use the sharding instances.
PiperOrigin-RevId: 459600050
This commit is contained in:
parent
fe1bbd59dd
commit
7da733f94b
@ -2719,14 +2719,12 @@ tf_impl_with_avals[pjit.pjit_p] = _pjit
|
||||
|
||||
|
||||
def _pjit_sharding_constraint(arg: TfVal, *,
|
||||
axis_resources: pjit.ParsedPartitionSpec,
|
||||
sharding: sharding.MeshPspecSharding,
|
||||
resource_env: maps.ResourceEnv,
|
||||
_in_avals: Sequence[core.ShapedArray],
|
||||
_out_aval: core.ShapedArray,
|
||||
**kwargs) -> TfVal:
|
||||
ms = sharding.MeshPspecSharding._from_parsed_pspec(
|
||||
resource_env.physical_mesh, axis_resources)
|
||||
return _shard_value(arg, _in_avals[0], ms)
|
||||
return _shard_value(arg, _in_avals[0], sharding)
|
||||
|
||||
|
||||
tf_impl_with_avals[pjit.sharding_constraint_p] = _pjit_sharding_constraint
|
||||
|
@ -1858,6 +1858,7 @@ def _check_no_loop_collectives(jaxpr, loop_axis_resources):
|
||||
|
||||
def _fix_inferred_spmd_sharding(jaxpr, resource_env, gen_fresh_name = None):
|
||||
from jax.experimental.pjit import sharding_constraint_p, ParsedPartitionSpec
|
||||
from jax.experimental.sharding import MeshPspecSharding
|
||||
rec = lambda jaxpr: _fix_inferred_spmd_sharding(jaxpr, resource_env, gen_fresh_name)
|
||||
if isinstance(jaxpr, core.ClosedJaxpr):
|
||||
return jaxpr.map_jaxpr(rec)
|
||||
@ -1871,10 +1872,13 @@ def _fix_inferred_spmd_sharding(jaxpr, resource_env, gen_fresh_name = None):
|
||||
new_eqns.append(eqn.replace(
|
||||
outvars=tmp_outvars, params=dict(eqn.params, **new_jaxpr_params)))
|
||||
for outvar, tmpvar in zip(eqn.outvars, tmp_outvars):
|
||||
new_eqns.append(core.JaxprEqn([tmpvar], [outvar], sharding_constraint_p,
|
||||
dict(resource_env=resource_env, axis_resources=ParsedPartitionSpec((), ())),
|
||||
set(),
|
||||
eqn.source_info))
|
||||
new_eqns.append(core.JaxprEqn(
|
||||
[tmpvar], [outvar], sharding_constraint_p,
|
||||
dict(resource_env=resource_env,
|
||||
sharding=MeshPspecSharding._from_parsed_pspec(
|
||||
resource_env.physical_mesh, ParsedPartitionSpec((), ()))),
|
||||
set(),
|
||||
eqn.source_info))
|
||||
return jaxpr.replace(eqns=new_eqns)
|
||||
|
||||
def _flatten_axes(what, tree, axes, tupled_args):
|
||||
|
@ -22,7 +22,7 @@ from functools import partial
|
||||
from jax.experimental import maps
|
||||
from jax.experimental.global_device_array import GlobalDeviceArray as GDA
|
||||
from jax.experimental.array import Array
|
||||
from jax.experimental import sharding
|
||||
from jax.experimental.sharding import MeshPspecSharding, Sharding
|
||||
from jax import core
|
||||
from jax import linear_util as lu
|
||||
from jax import stages
|
||||
@ -275,7 +275,7 @@ def pjit(fun: Callable,
|
||||
# unspecified. For `AUTO` sharding, it can only be used with
|
||||
# MeshPspecSharding.
|
||||
if not _is_unspecified(out_axis_resources):
|
||||
if not all(isinstance(s, sharding.Sharding) for s in tree_flatten(out_axis_resources)[0]):
|
||||
if not all(isinstance(s, Sharding) for s in tree_flatten(out_axis_resources)[0]):
|
||||
raise ValueError('When `config.jax_array` flag is enabled, '
|
||||
'out_axis_resources should contain instances of '
|
||||
'`Sharding`.')
|
||||
@ -411,7 +411,7 @@ def _create_mesh_pspec_sharding(mesh, x):
|
||||
return x
|
||||
if _is_from_gda(x):
|
||||
return x
|
||||
return sharding.MeshPspecSharding._from_parsed_pspec(mesh, x)
|
||||
return MeshPspecSharding._from_parsed_pspec(mesh, x)
|
||||
|
||||
|
||||
def flatten_axis_resources(what, tree, shardings, tupled_args):
|
||||
@ -566,7 +566,7 @@ def _pjit_jaxpr(fun, out_shardings_thunk, global_in_avals, out_tree):
|
||||
# TODO(yashkatariya): Replace this with shape check against sharding which
|
||||
# uses OpSharding.tile_assignment_dimension.
|
||||
def pjit_check_aval_sharding(
|
||||
sharding: sharding.MeshPspecSharding, aval, what_aval: str,
|
||||
sharding: MeshPspecSharding, aval, what_aval: str,
|
||||
allow_uneven_sharding: bool, local: bool = False):
|
||||
if local:
|
||||
m = sharding.mesh.local_mesh
|
||||
@ -907,7 +907,7 @@ def _pjit_batcher_for_sharding(s, dim, val):
|
||||
# `sync` attribute changes. To make sure we preserve that, we need to pass
|
||||
# that parsed partition spec when created the sharding instance.
|
||||
# Inferring the `PartitiionSpec` from that is easy as done in the classmethod.
|
||||
return sharding.MeshPspecSharding._from_parsed_pspec(s.mesh, parsed_pspec).normalize()
|
||||
return MeshPspecSharding._from_parsed_pspec(s.mesh, parsed_pspec).normalize()
|
||||
|
||||
|
||||
def _pjit_jvp(primals_in, tangents_in,
|
||||
@ -970,7 +970,7 @@ def _pjit_partial_eval(trace, *in_tracers,
|
||||
# shardings are that even in the `Array` codepath.
|
||||
out_shardings=(
|
||||
keep_where(out_shardings, known_outs) +
|
||||
(sharding.MeshPspecSharding(mesh, pxla.PartitionSpec(None)),) * num_residuals),
|
||||
(MeshPspecSharding(mesh, pxla.PartitionSpec(None)),) * num_residuals),
|
||||
resource_env=resource_env,
|
||||
donated_invars=keep_where(donated_invars, known_ins),
|
||||
name=name,
|
||||
@ -990,7 +990,7 @@ def _pjit_partial_eval(trace, *in_tracers,
|
||||
residual_specs = tuple(output_ppspec[-num_residuals:])
|
||||
else:
|
||||
residual_specs = ()
|
||||
residual_sharding = tuple(sharding.MeshPspecSharding._from_parsed_pspec(mesh, r)
|
||||
residual_sharding = tuple(MeshPspecSharding._from_parsed_pspec(mesh, r)
|
||||
for r in residual_specs)
|
||||
known_params['out_shardings'] = (
|
||||
keep_where(out_shardings, known_outs) + residual_sharding)
|
||||
@ -1128,22 +1128,27 @@ pxla.custom_resource_typing_rules[pjit_p] = _resource_typing_pjit
|
||||
|
||||
def with_sharding_constraint(x, axis_resources):
|
||||
x_flat, tree = tree_flatten(x)
|
||||
parsed_axis_resources, entries, _, _ = _prepare_axis_resources(
|
||||
parsed_axis_resources, _, _, _ = _prepare_axis_resources(
|
||||
axis_resources, "axis_resources", allow_unconstrained_dims=True)
|
||||
axis_resources_flat = tuple(
|
||||
flatten_axes("with_sharding_constraint axis_resources",
|
||||
tree, parsed_axis_resources))
|
||||
resource_env = pxla.thread_resources.env
|
||||
mesh = resource_env.physical_mesh
|
||||
_check_shapes_against_resources(
|
||||
"with_sharding_constraint arguments",
|
||||
mesh.is_multi_process, mesh.shape,
|
||||
x_flat, axis_resources_flat, allow_uneven_sharding=True)
|
||||
outs = [sharding_constraint_p.bind(y, axis_resources=r, resource_env=resource_env)
|
||||
for y, r in safe_zip(x_flat, axis_resources_flat)]
|
||||
|
||||
sharding_flat = [MeshPspecSharding._from_parsed_pspec(mesh, a)
|
||||
for a in axis_resources_flat]
|
||||
|
||||
for xf, i in safe_zip(x_flat, sharding_flat):
|
||||
pjit_check_aval_sharding(i, xf, "with_sharding_constraint arguments",
|
||||
allow_uneven_sharding=True)
|
||||
|
||||
outs = [sharding_constraint_p.bind(xf, sharding=i.normalize(),
|
||||
resource_env=resource_env)
|
||||
for xf, i in safe_zip(x_flat, sharding_flat)]
|
||||
return tree_unflatten(tree, outs)
|
||||
|
||||
def _sharding_constraint_impl(x, axis_resources, resource_env):
|
||||
def _sharding_constraint_impl(x, sharding, resource_env):
|
||||
# TODO(skye): can we also prevent this from being called in other
|
||||
# non-pjit contexts? (e.g. pmap, control flow)
|
||||
raise NotImplementedError(
|
||||
@ -1153,48 +1158,46 @@ sharding_constraint_p = core.Primitive("sharding_constraint")
|
||||
sharding_constraint_p.def_impl(_sharding_constraint_impl)
|
||||
sharding_constraint_p.def_abstract_eval(lambda x, **_: x)
|
||||
ad.deflinear2(sharding_constraint_p,
|
||||
lambda ct, _, axis_resources, resource_env: (
|
||||
lambda ct, _, sharding, resource_env: (
|
||||
sharding_constraint_p.bind(
|
||||
ct, axis_resources=axis_resources, resource_env=resource_env),))
|
||||
ct, sharding=sharding,
|
||||
resource_env=resource_env),))
|
||||
|
||||
def _sharding_constraint_mhlo_lowering(ctx, x_node, *, axis_resources,
|
||||
def _sharding_constraint_mhlo_lowering(ctx, x_node, *, sharding,
|
||||
resource_env):
|
||||
aval, = ctx.avals_in
|
||||
mesh = resource_env.physical_mesh
|
||||
assert isinstance(sharding, MeshPspecSharding)
|
||||
return [
|
||||
mlir.wrap_with_sharding_op(
|
||||
x_node,
|
||||
get_aval_sharding_proto(
|
||||
aval,
|
||||
axis_resources,
|
||||
mesh,
|
||||
ctx.module_context.axis_context,
|
||||
allow_uneven_axes=True),
|
||||
unspecified_dims=get_unconstrained_dims(axis_resources))
|
||||
sharding._to_xla_op_sharding(
|
||||
aval.ndim, axis_ctx=ctx.module_context.axis_context),
|
||||
unspecified_dims=get_unconstrained_dims(sharding))
|
||||
]
|
||||
mlir.register_lowering(sharding_constraint_p,
|
||||
_sharding_constraint_mhlo_lowering)
|
||||
|
||||
|
||||
def _sharding_constraint_batcher(insert_axis, axis_size, axis_name, main_type, vals_in, dims_in,
|
||||
axis_resources, resource_env):
|
||||
def _sharding_constraint_batcher(insert_axis, axis_size, axis_name, main_type,
|
||||
vals_in, dims_in, sharding, resource_env):
|
||||
x, = vals_in
|
||||
d, = dims_in
|
||||
# None means unconstrained in ParsedPartitionSpec
|
||||
new_parts = (axis_name,) if insert_axis else None
|
||||
y = sharding_constraint_p.bind(
|
||||
x,
|
||||
axis_resources=axis_resources.insert_axis_partitions(d, new_parts),
|
||||
sharding=_pjit_batcher_for_sharding(sharding, d, new_parts),
|
||||
resource_env=resource_env)
|
||||
return y, d
|
||||
batching.axis_primitive_batchers[sharding_constraint_p] = partial(_sharding_constraint_batcher, False)
|
||||
pxla.spmd_primitive_batchers[sharding_constraint_p] = partial(_sharding_constraint_batcher, True)
|
||||
|
||||
def _resource_typing_sharding_constraint(avals, params, source_info, resource_env, named_axis_resources):
|
||||
def _resource_typing_sharding_constraint(avals, params, source_info,
|
||||
resource_env, named_axis_resources):
|
||||
aval, = avals
|
||||
_check_resources_against_named_axes(
|
||||
"with_sharding_constraint input", aval,
|
||||
params['axis_resources'], named_axis_resources)
|
||||
params['sharding']._parsed_pspec, named_axis_resources)
|
||||
pxla.custom_resource_typing_rules[sharding_constraint_p] = \
|
||||
_resource_typing_sharding_constraint
|
||||
|
||||
@ -1212,24 +1215,9 @@ def get_array_mapping(
|
||||
if axes is not None for axis in axes)
|
||||
|
||||
|
||||
def get_aval_sharding_proto(aval: core.AbstractValue,
|
||||
axis_resources: ParsedPartitionSpec,
|
||||
mesh: maps.Mesh,
|
||||
axis_ctx: Optional[mlir.SPMDAxisContext] = None,
|
||||
allow_uneven_axes: bool = False) -> xc.OpSharding:
|
||||
array_mapping = get_array_mapping(axis_resources)
|
||||
sharding_spec = pxla.mesh_sharding_specs(
|
||||
mesh.shape, mesh.axis_names, allow_uneven_axes=allow_uneven_axes)(aval, array_mapping)
|
||||
special_axes = {}
|
||||
if axis_ctx is not None:
|
||||
axis_names = mesh.axis_names
|
||||
for manual_axis in axis_ctx.manual_axes:
|
||||
special_axes[axis_names.index(manual_axis)] = xc.OpSharding.Type.MANUAL
|
||||
return sharding_spec.sharding_proto(special_axes=special_axes)
|
||||
|
||||
|
||||
def get_unconstrained_dims(axis_resources: ParsedPartitionSpec):
|
||||
return {i for i, axes in enumerate(axis_resources) if axes is None}
|
||||
def get_unconstrained_dims(sharding: MeshPspecSharding):
|
||||
return {i for i, axes in enumerate(sharding._parsed_pspec)
|
||||
if axes is None}
|
||||
|
||||
|
||||
def global_to_local(positional_semantics, avals, shardings):
|
||||
@ -1269,7 +1257,7 @@ def _get_in_positional_semantics(arg) -> maps._PositionalSemantics:
|
||||
|
||||
|
||||
def _maybe_replace_from_gda_with_pspec(
|
||||
in_sharding_flat, arg) -> sharding.MeshPspecSharding:
|
||||
in_sharding_flat, arg) -> MeshPspecSharding:
|
||||
if isinstance(arg, GDA):
|
||||
gda_cpspec = CanonicalizedParsedPartitionSpec(
|
||||
ParsedPartitionSpec.from_user_input(arg.mesh_axes, arg_name="GDA spec"))
|
||||
@ -1281,7 +1269,7 @@ def _maybe_replace_from_gda_with_pspec(
|
||||
"use `jax.experimental.pjit.FROM_GDA` in `in_axis_resources` for GDA. "
|
||||
f"Got GDA spec: {gda_cpspec.user_spec} and "
|
||||
f"pjit spec: {in_sharding_flat.spec} for GDA: {arg}")
|
||||
return sharding.MeshPspecSharding._from_parsed_pspec(arg.mesh, gda_cpspec)
|
||||
return MeshPspecSharding._from_parsed_pspec(arg.mesh, gda_cpspec)
|
||||
return in_sharding_flat
|
||||
|
||||
|
||||
|
@ -532,8 +532,10 @@ class PJitTest(jtu.BufferDonationTestCase):
|
||||
jaxpr = jax.make_jaxpr(jax.vmap(f))(x)
|
||||
pjit_eqn, = jaxpr.eqns
|
||||
constraint_eqn, = pjit_eqn.params['jaxpr'].eqns
|
||||
self.assertEqual(constraint_eqn.params['axis_resources'].partitions, (None, ('x',)))
|
||||
self.assertEqual(constraint_eqn.params['axis_resources'].sync, SpecSync.DIM_PERMUTE)
|
||||
self.assertEqual(constraint_eqn.params['sharding']._parsed_pspec.partitions,
|
||||
(None, ('x',)))
|
||||
self.assertEqual(constraint_eqn.params['sharding']._parsed_pspec.sync,
|
||||
SpecSync.DIM_PERMUTE)
|
||||
|
||||
@jtu.with_mesh([('x', 2), ('y', 1)])
|
||||
def testShardingInXMap(self):
|
||||
@ -1757,7 +1759,7 @@ class UtilTest(jtu.JaxTestCase):
|
||||
dims = 5
|
||||
aval = jax.core.ShapedArray((len(devices),) * dims, jnp.float32)
|
||||
def roundtrip(spec):
|
||||
op_sharding = pjit_lib.get_aval_sharding_proto(aval, spec, mesh)
|
||||
op_sharding = MeshPspecSharding(mesh, spec)._to_xla_op_sharding(aval.ndim)
|
||||
parsed_spec = pjit_lib.parse_flatten_op_sharding(op_sharding, mesh)[0].partitions
|
||||
self.assertEqual(parsed_spec[:len(spec)], spec)
|
||||
self.assertEqual(parsed_spec[len(spec):], ((),) * (len(parsed_spec) - len(spec)))
|
||||
|
Loading…
x
Reference in New Issue
Block a user