Allow `P.UNCONSTRAINED` in `out_shardings` at the top-level `jit`. This is required for sharding-in-types to work properly when out_avals contain UNCONSTRAINED specs.

This also simplifies the `impl` rule of `sharding_cast`.
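
For illustration, the new behavior looks like this (a minimal sketch; assumes
at least 4 devices and uses `jax.make_mesh`, which is not part of this change):

    import jax
    import numpy as np
    from jax.sharding import NamedSharding, PartitionSpec as P

    mesh = jax.make_mesh((2, 2), ('x', 'y'))
    arr = jax.device_put(np.arange(16).reshape(8, 2),
                         NamedSharding(mesh, P('x', 'y')))

    # Dim 0 is left for the compiler to choose; dim 1 must be sharded on 'y'.
    f = jax.jit(lambda x: x * 2,
                out_shardings=NamedSharding(mesh, P(P.UNCONSTRAINED, 'y')))
    out = f(arr)  # sharding propagation typically picks P('x', 'y') here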

PiperOrigin-RevId: 707349491
Yash Katariya 2024-12-17 19:17:48 -08:00 committed by jax authors
parent b56dc63160
commit e854f1657a
5 changed files with 100 additions and 28 deletions


@@ -49,7 +49,7 @@ from jax._src.interpreters import partial_eval as pe
 from jax._src.interpreters import xla
 from jax._src.layout import AutoLayout, DeviceLocalLayout
 from jax._src.sharding import Sharding as JSharding
-from jax._src.sharding_impls import AUTO
+from jax._src.sharding_impls import AUTO, NamedSharding
 from jax._src.partition_spec import UnconstrainedSingleton
 from jax._src.lib import xla_client as xc
 from jax._src.lib import xla_extension
@@ -1055,6 +1055,21 @@ def _get_mem_kind(s: JSharding | AUTO | None) -> str | None:
   assert isinstance(s, JSharding)
   return s.memory_kind


+def contains_unconstrained(s):
+  return isinstance(s, NamedSharding) and None in s._parsed_pspec
+
+def all_unconstrained(s, aval):
+  if isinstance(s, NamedSharding):
+    if aval.ndim != len(s._parsed_pspec):
+      return False
+    return all(p is None for p in s._parsed_pspec)
+  return False
+
+def _get_unconstrained_dimensions(s, aval):
+  us = contains_unconstrained(s)
+  return (us, all_unconstrained(s, aval),
+          ({i for i, p in enumerate(s._parsed_pspec) if p is None} if us else None))
+
 def lower_jaxpr_to_module(
     module_name: str,
@@ -1114,7 +1129,8 @@ def lower_jaxpr_to_module(
       f"only {platforms_with_donation} support donation")
   if (num_partitions > 1 and
       (result_shardings is None or
-       all(s is None or isinstance(s, AUTO) for s in result_shardings))):
+       all(s is None or isinstance(s, AUTO) or contains_unconstrained(s)
+           for s in result_shardings))):
     xla_donated_args = donated_args
     donated_args = [False] * len(donated_args)
   if xla_donated_args is None:
@@ -1448,7 +1464,8 @@ def lower_jaxpr_to_fun(
   ir_arg_memory_kinds = None
   if arg_memory_kinds is not None:
     ir_arg_memory_kinds = util.flatten(
-        [[mk] * len_ir_types(types) for mk, types in zip(arg_memory_kinds, input_types)])
+        [[mk] * len_ir_types(types)
+         for mk, types in zip(arg_memory_kinds, input_types)])

   ir_arg_layouts = None
   if arg_layouts is not None:
@@ -1459,13 +1476,18 @@ def lower_jaxpr_to_fun(
   ir_donated_args = None
   if xla_donated_args is not None:
     ir_donated_args = util.flatten(
-        [[is_donated] * len_ir_types(types) for is_donated, types in zip(xla_donated_args, input_types)])
+        [[is_donated] * len_ir_types(types)
+         for is_donated, types in zip(xla_donated_args, input_types)])

   ir_result_shardings = None
+  unconstrained_shardings = None
   if result_shardings is not None:
     ir_result_shardings = util.flatten(
         [[_to_physical_op_sharding(ctx, a, s)] * len_ir_types(types)
          for a, s, types in zip(output_avals, result_shardings, output_types)])
+    unconstrained_shardings = util.flatten(
+        [[_get_unconstrained_dimensions(s, a)] * len_ir_types(types)
+         for a, s, types in zip(output_avals, result_shardings, output_types)])

   ir_result_memory_kinds = None
   custom_call_ir_result_memory_kinds = None
@@ -1580,8 +1602,9 @@ def lower_jaxpr_to_fun(
         attrs['jax.result_info'] = ir.StringAttr.get(name_)

   if use_sharding_annotations and ir_result_shardings is not None:
-    for attrs, sharding in zip(result_attrs, ir_result_shardings):
-      if sharding is not None:
+    for attrs, sharding, us in zip(result_attrs, ir_result_shardings,
+                                   unconstrained_shardings):  # type: ignore
+      if sharding is not None and not us[0]:
         if config.use_shardy_partitioner.value:
           attrs["sdy.sharding"] = get_sharding_attr(sharding)
         else:
@@ -1658,6 +1681,15 @@ def lower_jaxpr_to_fun(
           o if s is None else wrap_with_sharding_op(entry_lowering_ctx, o, o_aval, s)
          for o, s, o_aval in zip(flat_outputs, ir_result_shardings, output_avals)]

+    if ir_result_shardings is not None:
+      flat_outputs = [
+          wrap_with_sharding_op(entry_lowering_ctx, o, o_aval, s,
+                                unspecified_dims=us[2])
+          if us[0] and not us[1] else o
+          for o, s, o_aval, us in zip(flat_outputs, ir_result_shardings,
+                                      output_avals, unconstrained_shardings)  # type: ignore
+      ]
+
     # Insert a custom call if output is on host because XLA needs that to do the
     # transfer.
     if custom_call_ir_result_memory_kinds is not None and name == "main":
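
For every output, `_get_unconstrained_dimensions` above computes a triple
(any dim unconstrained, all dims unconstrained, set of unconstrained dim
indices). Fully-unconstrained outputs get no sharding annotation at all, while
partially-unconstrained outputs are wrapped in a sharding op with
`unspecified_dims`. A rough equivalent in terms of the public `PartitionSpec`
API (a sketch only; the real code reads the internal `_parsed_pspec`, where
unconstrained entries appear as `None`):

    from jax.sharding import PartitionSpec as P

    def sketch_unconstrained_dims(spec, ndim):
      # Indices whose entry is P.UNCONSTRAINED.
      idxs = {i for i, p in enumerate(spec) if p is P.UNCONSTRAINED}
      return bool(idxs), len(idxs) == ndim == len(spec), (idxs or None)

    sketch_unconstrained_dims(P(P.UNCONSTRAINED, 'y'), 2)  # (True, False, {0})
    sketch_unconstrained_dims(P('x', 'y'), 2)              # (False, False, None)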


@@ -62,7 +62,7 @@ from jax._src.layout import DeviceLocalLayout, AutoLayout, Layout
 from jax._src.lib import xla_client as xc
 from jax._src.lib.mlir import ir
 from jax._src.lib.mlir.dialects import hlo
-from jax._src.partition_spec import PartitionSpec, UnconstrainedSingleton
+from jax._src.partition_spec import PartitionSpec
 from jax._src.sharding import Sharding as JSharding
 from jax._src.mesh import AbstractMesh, Mesh
 from jax._src.sharding_impls import (
@@ -2162,10 +2162,7 @@ def _concretize_abstract_shardings(shardings, avals, device_assignment):
   out = []
   for s, a in zip(shardings, avals):
-    # Remove the `UnconstrainedSingleton` logic after UNCONSTRAINED is supported
-    # in out_shardings at top level jit.
-    if (isinstance(s, UnspecifiedValue) and a.sharding is not None and
-        all(not isinstance(s, UnconstrainedSingleton) for s in a.sharding.spec)):
+    if isinstance(s, UnspecifiedValue) and a.sharding is not None:
       out.append(NamedSharding(_abstract_to_concrete_mesh(a.sharding.mesh),
                                a.sharding.spec))
     else:
@@ -2794,6 +2791,11 @@ def _maybe_get_and_check_out_shardings(
           dtypes.issubdtype(aval.dtype, dtypes.extended)):
         xla_s = sharding_impls.logical_sharding(aval, xla_s)
       new_out_shardings.append(xla_s)
+    elif mlir.contains_unconstrained(orig):
+      if (aval is not core.abstract_token and
+          dtypes.issubdtype(aval.dtype, dtypes.extended)):
+        xla_s = sharding_impls.logical_sharding(aval, xla_s)
+      new_out_shardings.append(_gspmd_to_named_sharding(xla_s, orig))  # type: ignore
     else:
       xla_hlo_s = xla_s._to_xla_hlo_sharding(aval.ndim)
       orig_hlo_s = orig._to_xla_hlo_sharding(aval.ndim)  # pytype: disable=attribute-error
@@ -2909,8 +2911,9 @@ class UnloadedMeshExecutable:
     allow_prop_to_inputs = tuple(isinstance(i, (UnspecifiedValue, AUTO))
                                  for i in in_shardings)
-    allow_prop_to_outputs = tuple(isinstance(o, (UnspecifiedValue, AUTO))
-                                  for o in out_shardings)
+    allow_prop_to_outputs = tuple(
+        isinstance(o, (UnspecifiedValue, AUTO)) or mlir.contains_unconstrained(o)
+        for o in out_shardings)

     mesh = None
     if auto_spmd_lowering:
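
In pxla.py, outputs whose sharding mentions UNCONSTRAINED are now treated like
`UnspecifiedValue`/`AUTO` outputs for sharding propagation, so XLA may choose
shardings for them (and, per the mlir.py hunk above, such outputs also make
donation fall back to XLA's donation mechanism). A standalone sketch of the
propagation predicate (`UnspecifiedValue` and `AUTO` are JAX internals; `None`
stands in for them here):

    from jax.sharding import NamedSharding, PartitionSpec as P

    def may_propagate(out_sharding) -> bool:
      if out_sharding is None:  # stand-in for UnspecifiedValue / AUTO
        return True
      return (isinstance(out_sharding, NamedSharding) and
              any(p is P.UNCONSTRAINED for p in out_sharding.spec))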


@@ -353,6 +353,22 @@ class Mesh(contextlib.ContextDecorator):
   def with_axis_types(self, new_axis_types) -> Mesh:
     return Mesh(self.devices, self.axis_names, axis_types=new_axis_types)

+  @functools.cached_property
+  def _are_all_axes_collective(self) -> bool:
+    return all(t == AxisTypes.Collective for t in self.axis_types.keys())
+
+  @functools.cached_property
+  def _are_all_axes_auto(self) -> bool:
+    return all(t == AxisTypes.Auto for t in self.axis_types.keys())
+
+  @functools.cached_property
+  def _any_axis_collective(self) -> bool:
+    return any(t == AxisTypes.Collective for t in self.axis_types.keys())
+
+  @functools.cached_property
+  def _any_axis_auto(self) -> bool:
+    return any(t == AxisTypes.Auto for t in self.axis_types.keys())
+

 EMPTY_ENV = ResourceEnv(Mesh(np.empty((), dtype=object), ()))
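
The new `Mesh` helpers are simple reductions over `axis_types`, a dict keyed by
`AxisTypes` values. Hypothetical usage (internal API, subject to change; the
`axis_types` constructor argument is assumed from the context lines above):

    import jax
    import numpy as np
    from jax._src.mesh import Mesh, AxisTypes  # internal module

    devices = np.asarray(jax.devices()[:2])
    mesh = Mesh(devices, ('x',), axis_types={AxisTypes.Auto: ('x',)})
    assert mesh._are_all_axes_auto        # only AxisTypes.Auto keys present
    assert not mesh._any_axis_collective  # no AxisTypes.Collective key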


@@ -2675,13 +2675,6 @@ batching.skippable_batchers[sharding_constraint_p] = lambda _: ()

 # -------------------- sharding_cast ---------------------------

-def _check_mesh_shape_same(src_sharding, dst_sharding, aval):
-  if src_sharding.mesh.shape_tuple != dst_sharding.mesh.shape_tuple:
-    raise ValueError(
-        f'Mesh shape of the input {src_sharding.mesh.shape_tuple} does not'
-        ' match the mesh shape of the target sharding'
-        f' {dst_sharding.mesh.shape_tuple} for shape {aval.str_short()}')
-
 def sharding_cast(xs, shardings):
   if isinstance(shardings, NamedSharding):
     return tree_map(lambda x: sharding_cast_p.bind(
@@ -2695,17 +2688,17 @@ def sharding_cast(xs, shardings):
 sharding_cast_p = core.Primitive('sharding_cast')

 def _sharding_cast_abstract_eval(aval, src_sharding, dst_sharding):
-  _check_mesh_shape_same(src_sharding, dst_sharding, aval)
+  if src_sharding.mesh.shape_tuple != dst_sharding.mesh.shape_tuple:
+    raise ValueError(
+        f'Mesh shape of the input {src_sharding.mesh.shape_tuple} does not'
+        ' match the mesh shape of the target sharding'
+        f' {dst_sharding.mesh.shape_tuple} for shape {aval.str_short()}')
   return aval.update(sharding=dst_sharding)
 sharding_cast_p.def_abstract_eval(_sharding_cast_abstract_eval)

 def _sharding_cast_impl(x, src_sharding, dst_sharding):
-  aval = shaped_abstractify(x)
-  _check_mesh_shape_same(x.sharding, dst_sharding, aval)
-  new_mesh = x.sharding.mesh.with_axis_types(dst_sharding.mesh.axis_types)
-  concrete_dst_sharding = NamedSharding(new_mesh, dst_sharding.spec)
-  # TODO(yashkatariya): Replace this with `dispatch.apply_primitive(...)`
-  return api.jit(_identity_fn, out_shardings=concrete_dst_sharding)(x)
+  return dispatch.apply_primitive(sharding_cast_p, x, src_sharding=src_sharding,
+                                  dst_sharding=dst_sharding)
 sharding_cast_p.def_impl(_sharding_cast_impl)

 def _sharding_cast_transpose_rule(ct, _, src_sharding, dst_sharding):
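
With the `impl` rule now dispatching the primitive directly, an eager
`sharding_cast` follows the standard primitive path instead of re-jitting an
identity function with `out_shardings`. A hedged usage sketch (internal API;
the import path is assumed from this diff, and this likely requires the
experimental sharding-in-types machinery to be active):

    import jax
    import numpy as np
    from jax.sharding import NamedSharding, PartitionSpec as P
    from jax._src.pjit import sharding_cast  # internal; name from this diff

    mesh = jax.make_mesh((2,), ('x',))
    x = jax.device_put(np.arange(8.0), NamedSharding(mesh, P('x')))
    # Same mesh shape, new spec: OK. A different mesh shape would raise the
    # ValueError now inlined in _sharding_cast_abstract_eval above.
    y = sharding_cast(x, NamedSharding(mesh, P(None)))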


@@ -4680,6 +4680,34 @@ class ArrayPjitTest(jtu.JaxTestCase):
         RuntimeError, 'A jitted computation cannot contain AbstractMesh'):
       lowered3.compile()

+  def test_jit_out_shardings_unconstrained(self):
+    mesh = jtu.create_mesh((2, 2), ('x', 'y'))
+    s = NamedSharding(mesh, P('x', 'y'))
+    np_inp = np.arange(16).reshape(8, 2)
+    arr = jax.device_put(np_inp, s)
+
+    out_s = NamedSharding(mesh, P(P.UNCONSTRAINED, P.UNCONSTRAINED))
+    @partial(jax.jit, out_shardings=out_s)
+    def f(x):
+      return x * 2
+
+    out = f(arr)
+    self.assertEqual(out.sharding, s)
+    self.assertArraysEqual(out, np_inp * 2)
+
+    @partial(jax.jit, out_shardings=NamedSharding(mesh, P(P.UNCONSTRAINED, 'y')))
+    def g(x):
+      return x * 3
+    out = g(arr)
+    self.assertArraysEqual(out, np_inp * 3)
+    self.assertEqual(out.sharding, s)
+
+    lowered_text = g.lower(arr).as_text()
+    if config.use_shardy_partitioner.value:
+      self.assertIn('<@mesh, [{?}, {"y"}]>', lowered_text)
+    else:
+      self.assertIn("unspecified_dims=[0]", lowered_text)
+

 def spec_regex(s):
   return str(s).replace(r"(", r"\(").replace(r")", r"\)")
@@ -5548,7 +5576,7 @@ class ShardingInTypesTest(jtu.JaxTestCase):
       return a

     out = f(arr, arr.T)
-    self.assertEqual(out.sharding, NamedSharding(mesh, P('x', None)))
+    self.assertEqual(out.sharding, NamedSharding(mesh, P('x',)))

   def test_auto_user(self):
     mesh = jtu.create_mesh((2, 2), ('x', 'y'),