Change the internals of with_sharding_constraint to use the sharding instances.

PiperOrigin-RevId: 459600050
Yash Katariya 2022-07-07 14:21:38 -07:00 committed by jax authors
parent fe1bbd59dd
commit 7da733f94b
4 changed files with 54 additions and 62 deletions
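
In short: every consumer of sharding_constraint_p now receives a ready-made MeshPspecSharding in params['sharding'] instead of rebuilding one from params['axis_resources'] and the resource_env. A minimal construction sketch, assuming a one-axis mesh over the available devices (the variable names are illustrative, not taken from the diff):

import numpy as np
import jax
from jax.experimental import maps
from jax.experimental.pjit import ParsedPartitionSpec
from jax.experimental.sharding import MeshPspecSharding
from jax.interpreters import pxla

# A one-axis mesh over whatever devices are available.
mesh = maps.Mesh(np.asarray(jax.devices()), ('x',))
# Parse a user-facing PartitionSpec into the internal ParsedPartitionSpec.
parsed = ParsedPartitionSpec.from_user_input(
    pxla.PartitionSpec('x'), arg_name="example")
# The instance that now travels in the primitive's params.
s = MeshPspecSharding._from_parsed_pspec(mesh, parsed)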


@@ -2719,14 +2719,12 @@ tf_impl_with_avals[pjit.pjit_p] = _pjit
 def _pjit_sharding_constraint(arg: TfVal, *,
-                              axis_resources: pjit.ParsedPartitionSpec,
+                              sharding: sharding.MeshPspecSharding,
                               resource_env: maps.ResourceEnv,
                               _in_avals: Sequence[core.ShapedArray],
                               _out_aval: core.ShapedArray,
                               **kwargs) -> TfVal:
-  ms = sharding.MeshPspecSharding._from_parsed_pspec(
-      resource_env.physical_mesh, axis_resources)
-  return _shard_value(arg, _in_avals[0], ms)
+  return _shard_value(arg, _in_avals[0], sharding)
 
 tf_impl_with_avals[pjit.sharding_constraint_p] = _pjit_sharding_constraint
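
A note on the jax2tf rule above: primitive params reach the rule as keyword arguments, so renaming the param from axis_resources to sharding (and deleting the rebuild) is all the plumbing this rule needs. A schematic of that registration pattern with a hypothetical toy table, not jax2tf source:

# Hypothetical toy mirroring the tf_impl_with_avals registration pattern:
# params flow into the rule as kwargs, so a renamed param is a renamed kwarg.
toy_impl_table = {}

def _toy_sharding_constraint_rule(arg, *, sharding, resource_env, **kwargs):
  del resource_env, kwargs  # kept only for signature parity
  return (arg, sharding)

toy_impl_table['sharding_constraint'] = _toy_sharding_constraint_rule
print(toy_impl_table['sharding_constraint'](1.0, sharding='s', resource_env=None))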


@@ -1858,6 +1858,7 @@ def _check_no_loop_collectives(jaxpr, loop_axis_resources):
 def _fix_inferred_spmd_sharding(jaxpr, resource_env, gen_fresh_name = None):
   from jax.experimental.pjit import sharding_constraint_p, ParsedPartitionSpec
+  from jax.experimental.sharding import MeshPspecSharding
   rec = lambda jaxpr: _fix_inferred_spmd_sharding(jaxpr, resource_env, gen_fresh_name)
   if isinstance(jaxpr, core.ClosedJaxpr):
     return jaxpr.map_jaxpr(rec)
@@ -1871,10 +1872,13 @@ def _fix_inferred_spmd_sharding(jaxpr, resource_env, gen_fresh_name = None):
       new_eqns.append(eqn.replace(
           outvars=tmp_outvars, params=dict(eqn.params, **new_jaxpr_params)))
       for outvar, tmpvar in zip(eqn.outvars, tmp_outvars):
-        new_eqns.append(core.JaxprEqn([tmpvar], [outvar], sharding_constraint_p,
-            dict(resource_env=resource_env, axis_resources=ParsedPartitionSpec((), ())),
-            set(),
-            eqn.source_info))
+        new_eqns.append(core.JaxprEqn(
+            [tmpvar], [outvar], sharding_constraint_p,
+            dict(resource_env=resource_env,
+                 sharding=MeshPspecSharding._from_parsed_pspec(
+                     resource_env.physical_mesh, ParsedPartitionSpec((), ()))),
+            set(),
+            eqn.source_info))
   return jaxpr.replace(eqns=new_eqns)
 
 def _flatten_axes(what, tree, axes, tupled_args):
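
For readers following the core.JaxprEqn call above: the positional fields are invars, outvars, primitive, params, effects, and source_info, and the only change here is that params carries a pre-built sharding. The same construction factored into a helper, as a sketch (the helper name is hypothetical):

from jax import core
from jax.experimental.pjit import sharding_constraint_p, ParsedPartitionSpec
from jax.experimental.sharding import MeshPspecSharding

def _unconstrained_sharding_eqn(tmpvar, outvar, resource_env, source_info):
  # ParsedPartitionSpec((), ()) is the empty spec: no partitioning requested.
  s = MeshPspecSharding._from_parsed_pspec(
      resource_env.physical_mesh, ParsedPartitionSpec((), ()))
  return core.JaxprEqn([tmpvar], [outvar], sharding_constraint_p,
                       dict(resource_env=resource_env, sharding=s),
                       set(),  # effects: none
                       source_info)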


@@ -22,7 +22,7 @@ from functools import partial
 from jax.experimental import maps
 from jax.experimental.global_device_array import GlobalDeviceArray as GDA
 from jax.experimental.array import Array
-from jax.experimental import sharding
+from jax.experimental.sharding import MeshPspecSharding, Sharding
 from jax import core
 from jax import linear_util as lu
 from jax import stages
@@ -275,7 +275,7 @@ def pjit(fun: Callable,
     # unspecified. For `AUTO` sharding, it can only be used with
     # MeshPspecSharding.
     if not _is_unspecified(out_axis_resources):
-      if not all(isinstance(s, sharding.Sharding) for s in tree_flatten(out_axis_resources)[0]):
+      if not all(isinstance(s, Sharding) for s in tree_flatten(out_axis_resources)[0]):
         raise ValueError('When `config.jax_array` flag is enabled, '
                          'out_axis_resources should contain instances of '
                          '`Sharding`.')
@@ -411,7 +411,7 @@ def _create_mesh_pspec_sharding(mesh, x):
     return x
   if _is_from_gda(x):
     return x
-  return sharding.MeshPspecSharding._from_parsed_pspec(mesh, x)
+  return MeshPspecSharding._from_parsed_pspec(mesh, x)
 
 def flatten_axis_resources(what, tree, shardings, tupled_args):
@@ -566,7 +566,7 @@ def _pjit_jaxpr(fun, out_shardings_thunk, global_in_avals, out_tree):
 # TODO(yashkatariya): Replace this with shape check against sharding which
 # uses OpSharding.tile_assignment_dimension.
 def pjit_check_aval_sharding(
-    sharding: sharding.MeshPspecSharding, aval, what_aval: str,
+    sharding: MeshPspecSharding, aval, what_aval: str,
     allow_uneven_sharding: bool, local: bool = False):
   if local:
     m = sharding.mesh.local_mesh
@@ -907,7 +907,7 @@ def _pjit_batcher_for_sharding(s, dim, val):
   # `sync` attribute changes. To make sure we preserve that, we need to pass
   # that parsed partition spec when creating the sharding instance.
   # Inferring the `PartitionSpec` from that is easy, as done in the classmethod.
-  return sharding.MeshPspecSharding._from_parsed_pspec(s.mesh, parsed_pspec).normalize()
+  return MeshPspecSharding._from_parsed_pspec(s.mesh, parsed_pspec).normalize()
 
 def _pjit_jvp(primals_in, tangents_in,
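
The hunk above shows only the tail of _pjit_batcher_for_sharding. Given its comment about preserving the `sync` attribute, and the old batcher's insert_axis_partitions call further down in this diff, a plausible reading of the full body is the following sketch (not the verbatim source):

def _pjit_batcher_for_sharding(s, dim, val):
  # Insert the vmapped dimension into the existing *parsed* spec so its
  # `sync` attribute survives, then rebuild and normalize the sharding.
  parsed_pspec = s._parsed_pspec.insert_axis_partitions(dim, val)
  return MeshPspecSharding._from_parsed_pspec(s.mesh, parsed_pspec).normalize()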
@@ -970,7 +970,7 @@ def _pjit_partial_eval(trace, *in_tracers,
       # shardings are that even in the `Array` codepath.
       out_shardings=(
           keep_where(out_shardings, known_outs) +
-          (sharding.MeshPspecSharding(mesh, pxla.PartitionSpec(None)),) * num_residuals),
+          (MeshPspecSharding(mesh, pxla.PartitionSpec(None)),) * num_residuals),
       resource_env=resource_env,
       donated_invars=keep_where(donated_invars, known_ins),
       name=name,
@@ -990,7 +990,7 @@ def _pjit_partial_eval(trace, *in_tracers,
     residual_specs = tuple(output_ppspec[-num_residuals:])
   else:
     residual_specs = ()
-  residual_sharding = tuple(sharding.MeshPspecSharding._from_parsed_pspec(mesh, r)
+  residual_sharding = tuple(MeshPspecSharding._from_parsed_pspec(mesh, r)
                             for r in residual_specs)
   known_params['out_shardings'] = (
       keep_where(out_shardings, known_outs) + residual_sharding)
@@ -1128,22 +1128,27 @@ pxla.custom_resource_typing_rules[pjit_p] = _resource_typing_pjit
 def with_sharding_constraint(x, axis_resources):
   x_flat, tree = tree_flatten(x)
-  parsed_axis_resources, entries, _, _ = _prepare_axis_resources(
+  parsed_axis_resources, _, _, _ = _prepare_axis_resources(
       axis_resources, "axis_resources", allow_unconstrained_dims=True)
   axis_resources_flat = tuple(
       flatten_axes("with_sharding_constraint axis_resources",
                    tree, parsed_axis_resources))
   resource_env = pxla.thread_resources.env
   mesh = resource_env.physical_mesh
-  _check_shapes_against_resources(
-      "with_sharding_constraint arguments",
-      mesh.is_multi_process, mesh.shape,
-      x_flat, axis_resources_flat, allow_uneven_sharding=True)
-  outs = [sharding_constraint_p.bind(y, axis_resources=r, resource_env=resource_env)
-          for y, r in safe_zip(x_flat, axis_resources_flat)]
+  sharding_flat = [MeshPspecSharding._from_parsed_pspec(mesh, a)
+                   for a in axis_resources_flat]
+  for xf, i in safe_zip(x_flat, sharding_flat):
+    pjit_check_aval_sharding(i, xf, "with_sharding_constraint arguments",
+                             allow_uneven_sharding=True)
+  outs = [sharding_constraint_p.bind(xf, sharding=i.normalize(),
+                                     resource_env=resource_env)
+          for xf, i in safe_zip(x_flat, sharding_flat)]
   return tree_unflatten(tree, outs)
 
-def _sharding_constraint_impl(x, axis_resources, resource_env):
+def _sharding_constraint_impl(x, sharding, resource_env):
   # TODO(skye): can we also prevent this from being called in other
   # non-pjit contexts? (e.g. pmap, control flow)
   raise NotImplementedError(
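
The user-facing API of with_sharding_constraint is unchanged by this refactor; only the primitive's params differ. A minimal usage sketch, assuming at least two devices and the jax.experimental APIs as of this revision:

import numpy as np
import jax
from jax.experimental import maps
from jax.experimental.pjit import pjit, with_sharding_constraint
from jax.interpreters import pxla

P = pxla.PartitionSpec
mesh = maps.Mesh(np.asarray(jax.devices()[:2]), ('x',))  # assumes >= 2 devices

def f(x):
  y = x * 2
  # Internally this now binds sharding_constraint_p with a normalized
  # MeshPspecSharding built from P('x'), not a raw ParsedPartitionSpec.
  return with_sharding_constraint(y, P('x'))

with mesh:
  out = pjit(f, in_axis_resources=P('x'), out_axis_resources=P('x'))(
      np.arange(8, dtype=np.float32))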
@@ -1153,48 +1158,46 @@ sharding_constraint_p = core.Primitive("sharding_constraint")
 sharding_constraint_p.def_impl(_sharding_constraint_impl)
 sharding_constraint_p.def_abstract_eval(lambda x, **_: x)
 ad.deflinear2(sharding_constraint_p,
-              lambda ct, _, axis_resources, resource_env: (
+              lambda ct, _, sharding, resource_env: (
                   sharding_constraint_p.bind(
-                      ct, axis_resources=axis_resources, resource_env=resource_env),))
+                      ct, sharding=sharding,
+                      resource_env=resource_env),))
 
-def _sharding_constraint_mhlo_lowering(ctx, x_node, *, axis_resources,
+def _sharding_constraint_mhlo_lowering(ctx, x_node, *, sharding,
                                        resource_env):
   aval, = ctx.avals_in
-  mesh = resource_env.physical_mesh
+  assert isinstance(sharding, MeshPspecSharding)
   return [
       mlir.wrap_with_sharding_op(
           x_node,
-          get_aval_sharding_proto(
-              aval,
-              axis_resources,
-              mesh,
-              ctx.module_context.axis_context,
-              allow_uneven_axes=True),
-          unspecified_dims=get_unconstrained_dims(axis_resources))
+          sharding._to_xla_op_sharding(
+              aval.ndim, axis_ctx=ctx.module_context.axis_context),
+          unspecified_dims=get_unconstrained_dims(sharding))
   ]
 mlir.register_lowering(sharding_constraint_p,
                        _sharding_constraint_mhlo_lowering)
 
-def _sharding_constraint_batcher(insert_axis, axis_size, axis_name, main_type, vals_in, dims_in,
-                                 axis_resources, resource_env):
+def _sharding_constraint_batcher(insert_axis, axis_size, axis_name, main_type,
+                                 vals_in, dims_in, sharding, resource_env):
   x, = vals_in
   d, = dims_in
   # None means unconstrained in ParsedPartitionSpec
   new_parts = (axis_name,) if insert_axis else None
   y = sharding_constraint_p.bind(
       x,
-      axis_resources=axis_resources.insert_axis_partitions(d, new_parts),
+      sharding=_pjit_batcher_for_sharding(sharding, d, new_parts),
       resource_env=resource_env)
   return y, d
 batching.axis_primitive_batchers[sharding_constraint_p] = partial(_sharding_constraint_batcher, False)
 pxla.spmd_primitive_batchers[sharding_constraint_p] = partial(_sharding_constraint_batcher, True)
 
-def _resource_typing_sharding_constraint(avals, params, source_info, resource_env, named_axis_resources):
+def _resource_typing_sharding_constraint(avals, params, source_info,
+                                         resource_env, named_axis_resources):
   aval, = avals
   _check_resources_against_named_axes(
       "with_sharding_constraint input", aval,
-      params['axis_resources'], named_axis_resources)
+      params['sharding']._parsed_pspec, named_axis_resources)
 pxla.custom_resource_typing_rules[sharding_constraint_p] = \
     _resource_typing_sharding_constraint
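
A pattern that recurs in the rules above: wherever the old code consumed params['axis_resources'] directly, the new code reads the same ParsedPartitionSpec back off the sharding via the private _parsed_pspec attribute. A small read-back sketch (spec values are illustrative; the exact partitions tuple is an expectation, not a guarantee):

import numpy as np
import jax
from jax.experimental import maps
from jax.experimental.pjit import ParsedPartitionSpec
from jax.experimental.sharding import MeshPspecSharding
from jax.interpreters import pxla

mesh = maps.Mesh(np.asarray(jax.devices()), ('x',))
parsed = ParsedPartitionSpec.from_user_input(
    pxla.PartitionSpec('x'), arg_name="example")
s = MeshPspecSharding._from_parsed_pspec(mesh, parsed)
# What _resource_typing_sharding_constraint now inspects:
print(s._parsed_pspec.partitions)  # expected: (('x',),)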
@@ -1212,24 +1215,9 @@ def get_array_mapping(
              if axes is not None for axis in axes)
 
-def get_aval_sharding_proto(aval: core.AbstractValue,
-                            axis_resources: ParsedPartitionSpec,
-                            mesh: maps.Mesh,
-                            axis_ctx: Optional[mlir.SPMDAxisContext] = None,
-                            allow_uneven_axes: bool = False) -> xc.OpSharding:
-  array_mapping = get_array_mapping(axis_resources)
-  sharding_spec = pxla.mesh_sharding_specs(
-      mesh.shape, mesh.axis_names, allow_uneven_axes=allow_uneven_axes)(aval, array_mapping)
-  special_axes = {}
-  if axis_ctx is not None:
-    axis_names = mesh.axis_names
-    for manual_axis in axis_ctx.manual_axes:
-      special_axes[axis_names.index(manual_axis)] = xc.OpSharding.Type.MANUAL
-  return sharding_spec.sharding_proto(special_axes=special_axes)
-
-def get_unconstrained_dims(axis_resources: ParsedPartitionSpec):
-  return {i for i, axes in enumerate(axis_resources) if axes is None}
+def get_unconstrained_dims(sharding: MeshPspecSharding):
+  return {i for i, axes in enumerate(sharding._parsed_pspec)
+          if axes is None}
 
 def global_to_local(positional_semantics, avals, shardings):
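
Per the batcher comment earlier ("None means unconstrained in ParsedPartitionSpec"), the rewritten get_unconstrained_dims changed only where it finds the entries to scan. A toy illustration of the set it computes, in plain Python with illustrative entries:

# One entry per array dimension: a tuple of mesh axes, () for replicated,
# or None for an unconstrained dimension (values here are illustrative).
entries = (('x',), None, ())
unconstrained = {i for i, axes in enumerate(entries) if axes is None}
assert unconstrained == {1}  # dim 1 feeds unspecified_dims in the lowering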
@@ -1269,7 +1257,7 @@ def _get_in_positional_semantics(arg) -> maps._PositionalSemantics:
 def _maybe_replace_from_gda_with_pspec(
-    in_sharding_flat, arg) -> sharding.MeshPspecSharding:
+    in_sharding_flat, arg) -> MeshPspecSharding:
   if isinstance(arg, GDA):
     gda_cpspec = CanonicalizedParsedPartitionSpec(
         ParsedPartitionSpec.from_user_input(arg.mesh_axes, arg_name="GDA spec"))
@@ -1281,7 +1269,7 @@ def _maybe_replace_from_gda_with_pspec(
           "use `jax.experimental.pjit.FROM_GDA` in `in_axis_resources` for GDA. "
           f"Got GDA spec: {gda_cpspec.user_spec} and "
           f"pjit spec: {in_sharding_flat.spec} for GDA: {arg}")
-    return sharding.MeshPspecSharding._from_parsed_pspec(arg.mesh, gda_cpspec)
+    return MeshPspecSharding._from_parsed_pspec(arg.mesh, gda_cpspec)
   return in_sharding_flat


@@ -532,8 +532,10 @@ class PJitTest(jtu.BufferDonationTestCase):
     jaxpr = jax.make_jaxpr(jax.vmap(f))(x)
     pjit_eqn, = jaxpr.eqns
     constraint_eqn, = pjit_eqn.params['jaxpr'].eqns
-    self.assertEqual(constraint_eqn.params['axis_resources'].partitions, (None, ('x',)))
-    self.assertEqual(constraint_eqn.params['axis_resources'].sync, SpecSync.DIM_PERMUTE)
+    self.assertEqual(constraint_eqn.params['sharding']._parsed_pspec.partitions,
+                     (None, ('x',)))
+    self.assertEqual(constraint_eqn.params['sharding']._parsed_pspec.sync,
+                     SpecSync.DIM_PERMUTE)
 
   @jtu.with_mesh([('x', 2), ('y', 1)])
   def testShardingInXMap(self):
@@ -1757,7 +1759,7 @@ class UtilTest(jtu.JaxTestCase):
     dims = 5
     aval = jax.core.ShapedArray((len(devices),) * dims, jnp.float32)
     def roundtrip(spec):
-      op_sharding = pjit_lib.get_aval_sharding_proto(aval, spec, mesh)
+      op_sharding = MeshPspecSharding(mesh, spec)._to_xla_op_sharding(aval.ndim)
       parsed_spec = pjit_lib.parse_flatten_op_sharding(op_sharding, mesh)[0].partitions
       self.assertEqual(parsed_spec[:len(spec)], spec)
       self.assertEqual(parsed_spec[len(spec):], ((),) * (len(parsed_spec) - len(spec)))
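
The updated roundtrip test doubles as the migration recipe for the deleted get_aval_sharding_proto helper. A condensed, self-contained version of it (assumes at least two devices so the 'x' axis is genuinely partitioned):

import numpy as np
import jax
import jax.numpy as jnp
from jax.experimental import maps
from jax.experimental import pjit as pjit_lib
from jax.experimental.sharding import MeshPspecSharding
from jax.interpreters import pxla

mesh = maps.Mesh(np.asarray(jax.devices()[:2]), ('x',))  # assumes >= 2 devices
aval = jax.core.ShapedArray((2, 2), jnp.float32)
op_sharding = MeshPspecSharding(
    mesh, pxla.PartitionSpec('x'))._to_xla_op_sharding(aval.ndim)
parsed = pjit_lib.parse_flatten_op_sharding(op_sharding, mesh)[0].partitions
print(parsed)  # expected to begin with ('x',), padded with () entries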