From e854f1657a9d77853e38e0c5b2a5adf752f2ebef Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Tue, 17 Dec 2024 19:17:48 -0800 Subject: [PATCH] Allow `P.UNCONSTRAINED` in out_shardings at top level jit. This is required for sharding in types to work properly when out_avals contain UNCONSTRAINED specs. This also simplifies the `impl` rule of `sharding_cast`. PiperOrigin-RevId: 707349491 --- jax/_src/interpreters/mlir.py | 44 ++++++++++++++++++++++++++++++----- jax/_src/interpreters/pxla.py | 17 ++++++++------ jax/_src/mesh.py | 16 +++++++++++++ jax/_src/pjit.py | 21 ++++++----------- tests/pjit_test.py | 30 +++++++++++++++++++++++- 5 files changed, 100 insertions(+), 28 deletions(-) diff --git a/jax/_src/interpreters/mlir.py b/jax/_src/interpreters/mlir.py index 5923cfe00..97f91555c 100644 --- a/jax/_src/interpreters/mlir.py +++ b/jax/_src/interpreters/mlir.py @@ -49,7 +49,7 @@ from jax._src.interpreters import partial_eval as pe from jax._src.interpreters import xla from jax._src.layout import AutoLayout, DeviceLocalLayout from jax._src.sharding import Sharding as JSharding -from jax._src.sharding_impls import AUTO +from jax._src.sharding_impls import AUTO, NamedSharding from jax._src.partition_spec import UnconstrainedSingleton from jax._src.lib import xla_client as xc from jax._src.lib import xla_extension @@ -1055,6 +1055,21 @@ def _get_mem_kind(s: JSharding | AUTO | None) -> str | None: assert isinstance(s, JSharding) return s.memory_kind +def contains_unconstrained(s): + return isinstance(s, NamedSharding) and None in s._parsed_pspec + +def all_unconstrained(s, aval): + if isinstance(s, NamedSharding): + if aval.ndim != len(s._parsed_pspec): + return False + return all(p is None for p in s._parsed_pspec) + return False + +def _get_unconstrained_dimensions(s, aval): + us = contains_unconstrained(s) + return (us, all_unconstrained(s, aval), + ({i for i, p in enumerate(s._parsed_pspec) if p is None} if us else None)) + def lower_jaxpr_to_module( module_name: str, @@ -1114,7 +1129,8 @@ def lower_jaxpr_to_module( f"only {platforms_with_donation} support donation") if (num_partitions > 1 and (result_shardings is None or - all(s is None or isinstance(s, AUTO) for s in result_shardings))): + all(s is None or isinstance(s, AUTO) or contains_unconstrained(s) + for s in result_shardings))): xla_donated_args = donated_args donated_args = [False] * len(donated_args) if xla_donated_args is None: @@ -1448,7 +1464,8 @@ def lower_jaxpr_to_fun( ir_arg_memory_kinds = None if arg_memory_kinds is not None: ir_arg_memory_kinds = util.flatten( - [[mk] * len_ir_types(types) for mk, types in zip(arg_memory_kinds, input_types)]) + [[mk] * len_ir_types(types) + for mk, types in zip(arg_memory_kinds, input_types)]) ir_arg_layouts = None if arg_layouts is not None: @@ -1459,13 +1476,18 @@ def lower_jaxpr_to_fun( ir_donated_args = None if xla_donated_args is not None: ir_donated_args = util.flatten( - [[is_donated] * len_ir_types(types) for is_donated, types in zip(xla_donated_args, input_types)]) + [[is_donated] * len_ir_types(types) + for is_donated, types in zip(xla_donated_args, input_types)]) ir_result_shardings = None + unconstrained_shardings = None if result_shardings is not None: ir_result_shardings = util.flatten( [[_to_physical_op_sharding(ctx, a, s)] * len_ir_types(types) for a, s, types in zip(output_avals, result_shardings, output_types)]) + unconstrained_shardings = util.flatten( + [[_get_unconstrained_dimensions(s, a)] * len_ir_types(types) + for a, s, types in zip(output_avals, result_shardings, output_types)]) ir_result_memory_kinds = None custom_call_ir_result_memory_kinds = None @@ -1580,8 +1602,9 @@ def lower_jaxpr_to_fun( attrs['jax.result_info'] = ir.StringAttr.get(name_) if use_sharding_annotations and ir_result_shardings is not None: - for attrs, sharding in zip(result_attrs, ir_result_shardings): - if sharding is not None: + for attrs, sharding, us in zip(result_attrs, ir_result_shardings, + unconstrained_shardings): # type: ignore + if sharding is not None and not us[0]: if config.use_shardy_partitioner.value: attrs["sdy.sharding"] = get_sharding_attr(sharding) else: @@ -1658,6 +1681,15 @@ def lower_jaxpr_to_fun( o if s is None else wrap_with_sharding_op(entry_lowering_ctx, o, o_aval, s) for o, s, o_aval in zip(flat_outputs, ir_result_shardings, output_avals)] + if ir_result_shardings is not None: + flat_outputs = [ + wrap_with_sharding_op(entry_lowering_ctx, o, o_aval, s, + unspecified_dims=us[2]) + if us[0] and not us[1] else o + for o, s, o_aval, us in zip(flat_outputs, ir_result_shardings, + output_avals, unconstrained_shardings) # type: ignore + ] + # Insert a custom call if output is on host because XLA needs that to do the # transfer. if custom_call_ir_result_memory_kinds is not None and name == "main": diff --git a/jax/_src/interpreters/pxla.py b/jax/_src/interpreters/pxla.py index a1936d213..a5cd193b5 100644 --- a/jax/_src/interpreters/pxla.py +++ b/jax/_src/interpreters/pxla.py @@ -62,7 +62,7 @@ from jax._src.layout import DeviceLocalLayout, AutoLayout, Layout from jax._src.lib import xla_client as xc from jax._src.lib.mlir import ir from jax._src.lib.mlir.dialects import hlo -from jax._src.partition_spec import PartitionSpec, UnconstrainedSingleton +from jax._src.partition_spec import PartitionSpec from jax._src.sharding import Sharding as JSharding from jax._src.mesh import AbstractMesh, Mesh from jax._src.sharding_impls import ( @@ -2162,10 +2162,7 @@ def _concretize_abstract_shardings(shardings, avals, device_assignment): out = [] for s, a in zip(shardings, avals): - # Remove the `UnconstrainedSingleton` logic after UNCONSTRAINED is supported - # in out_shardings at top level jit. - if (isinstance(s, UnspecifiedValue) and a.sharding is not None and - all(not isinstance(s, UnconstrainedSingleton) for s in a.sharding.spec)): + if isinstance(s, UnspecifiedValue) and a.sharding is not None: out.append(NamedSharding(_abstract_to_concrete_mesh(a.sharding.mesh), a.sharding.spec)) else: @@ -2794,6 +2791,11 @@ def _maybe_get_and_check_out_shardings( dtypes.issubdtype(aval.dtype, dtypes.extended)): xla_s = sharding_impls.logical_sharding(aval, xla_s) new_out_shardings.append(xla_s) + elif mlir.contains_unconstrained(orig): + if (aval is not core.abstract_token and + dtypes.issubdtype(aval.dtype, dtypes.extended)): + xla_s = sharding_impls.logical_sharding(aval, xla_s) + new_out_shardings.append(_gspmd_to_named_sharding(xla_s, orig)) # type: ignore else: xla_hlo_s = xla_s._to_xla_hlo_sharding(aval.ndim) orig_hlo_s = orig._to_xla_hlo_sharding(aval.ndim) # pytype: disable=attribute-error @@ -2909,8 +2911,9 @@ class UnloadedMeshExecutable: allow_prop_to_inputs = tuple(isinstance(i, (UnspecifiedValue, AUTO)) for i in in_shardings) - allow_prop_to_outputs = tuple(isinstance(o, (UnspecifiedValue, AUTO)) - for o in out_shardings) + allow_prop_to_outputs = tuple( + isinstance(o, (UnspecifiedValue, AUTO)) or mlir.contains_unconstrained(o) + for o in out_shardings) mesh = None if auto_spmd_lowering: diff --git a/jax/_src/mesh.py b/jax/_src/mesh.py index 2b1f5c178..25fb2b38f 100644 --- a/jax/_src/mesh.py +++ b/jax/_src/mesh.py @@ -353,6 +353,22 @@ class Mesh(contextlib.ContextDecorator): def with_axis_types(self, new_axis_types) -> Mesh: return Mesh(self.devices, self.axis_names, axis_types=new_axis_types) + @functools.cached_property + def _are_all_axes_collective(self) -> bool: + return all(t == AxisTypes.Collective for t in self.axis_types.keys()) + + @functools.cached_property + def _are_all_axes_auto(self) -> bool: + return all(t == AxisTypes.Auto for t in self.axis_types.keys()) + + @functools.cached_property + def _any_axis_collective(self) -> bool: + return any(t == AxisTypes.Collective for t in self.axis_types.keys()) + + @functools.cached_property + def _any_axis_auto(self) -> bool: + return any(t == AxisTypes.Auto for t in self.axis_types.keys()) + EMPTY_ENV = ResourceEnv(Mesh(np.empty((), dtype=object), ())) diff --git a/jax/_src/pjit.py b/jax/_src/pjit.py index 4a08558bd..aafaaad68 100644 --- a/jax/_src/pjit.py +++ b/jax/_src/pjit.py @@ -2675,13 +2675,6 @@ batching.skippable_batchers[sharding_constraint_p] = lambda _: () # -------------------- sharding_cast --------------------------- -def _check_mesh_shape_same(src_sharding, dst_sharding, aval): - if src_sharding.mesh.shape_tuple != dst_sharding.mesh.shape_tuple: - raise ValueError( - f'Mesh shape of the input {src_sharding.mesh.shape_tuple} does not' - ' match the mesh shape of the target sharding' - f' {dst_sharding.mesh.shape_tuple} for shape {aval.str_short()}') - def sharding_cast(xs, shardings): if isinstance(shardings, NamedSharding): return tree_map(lambda x: sharding_cast_p.bind( @@ -2695,17 +2688,17 @@ def sharding_cast(xs, shardings): sharding_cast_p = core.Primitive('sharding_cast') def _sharding_cast_abstract_eval(aval, src_sharding, dst_sharding): - _check_mesh_shape_same(src_sharding, dst_sharding, aval) + if src_sharding.mesh.shape_tuple != dst_sharding.mesh.shape_tuple: + raise ValueError( + f'Mesh shape of the input {src_sharding.mesh.shape_tuple} does not' + ' match the mesh shape of the target sharding' + f' {dst_sharding.mesh.shape_tuple} for shape {aval.str_short()}') return aval.update(sharding=dst_sharding) sharding_cast_p.def_abstract_eval(_sharding_cast_abstract_eval) def _sharding_cast_impl(x, src_sharding, dst_sharding): - aval = shaped_abstractify(x) - _check_mesh_shape_same(x.sharding, dst_sharding, aval) - new_mesh = x.sharding.mesh.with_axis_types(dst_sharding.mesh.axis_types) - concrete_dst_sharding = NamedSharding(new_mesh, dst_sharding.spec) - # TODO(yashkatariya): Replace this with `dispatch.apply_primitive(...)` - return api.jit(_identity_fn, out_shardings=concrete_dst_sharding)(x) + return dispatch.apply_primitive(sharding_cast_p, x, src_sharding=src_sharding, + dst_sharding=dst_sharding) sharding_cast_p.def_impl(_sharding_cast_impl) def _sharding_cast_transpose_rule(ct, _, src_sharding, dst_sharding): diff --git a/tests/pjit_test.py b/tests/pjit_test.py index 619de3f02..6bfa9f632 100644 --- a/tests/pjit_test.py +++ b/tests/pjit_test.py @@ -4680,6 +4680,34 @@ class ArrayPjitTest(jtu.JaxTestCase): RuntimeError, 'A jitted computation cannot contain AbstractMesh'): lowered3.compile() + def test_jit_out_shardings_unconstrained(self): + mesh = jtu.create_mesh((2, 2), ('x', 'y')) + s = NamedSharding(mesh, P('x', 'y')) + np_inp = np.arange(16).reshape(8, 2) + arr = jax.device_put(np_inp, s) + + out_s = NamedSharding(mesh, P(P.UNCONSTRAINED, P.UNCONSTRAINED)) + @partial(jax.jit, out_shardings=out_s) + def f(x): + return x * 2 + + out = f(arr) + self.assertEqual(out.sharding, s) + self.assertArraysEqual(out, np_inp * 2) + + @partial(jax.jit, out_shardings=NamedSharding(mesh, P(P.UNCONSTRAINED, 'y'))) + def g(x): + return x * 3 + + out = g(arr) + self.assertArraysEqual(out, np_inp * 3) + self.assertEqual(out.sharding, s) + lowered_text = g.lower(arr).as_text() + if config.use_shardy_partitioner.value: + self.assertIn('<@mesh, [{?}, {"y"}]>', lowered_text) + else: + self.assertIn("unspecified_dims=[0]", lowered_text) + def spec_regex(s): return str(s).replace(r"(", r"\(").replace(r")", r"\)") @@ -5548,7 +5576,7 @@ class ShardingInTypesTest(jtu.JaxTestCase): return a out = f(arr, arr.T) - self.assertEqual(out.sharding, NamedSharding(mesh, P('x', None))) + self.assertEqual(out.sharding, NamedSharding(mesh, P('x',))) def test_auto_user(self): mesh = jtu.create_mesh((2, 2), ('x', 'y'),