diff --git a/jax/_src/interpreters/pxla.py b/jax/_src/interpreters/pxla.py
index c9c1628a7..6500eca14 100644
--- a/jax/_src/interpreters/pxla.py
+++ b/jax/_src/interpreters/pxla.py
@@ -41,7 +41,6 @@ from jax._src import dispatch
 from jax._src import dtypes
 from jax._src import effects
 from jax._src import linear_util as lu
-from jax._src import mesh as mesh_lib
 from jax._src import op_shardings
 from jax._src import sharding_specs
 from jax._src import profiler
@@ -65,6 +64,7 @@ from jax._src.lib.mlir import ir
 from jax._src.lib.mlir.dialects import hlo
 from jax._src.partition_spec import PartitionSpec, UnconstrainedSingleton
 from jax._src.sharding import Sharding as JSharding
+from jax._src.mesh import AbstractMesh, Mesh
 from jax._src.sharding_impls import (
     ArrayMapping, ArrayMappingOrAutoOrUnspecified, AUTO, UNSPECIFIED,
     UnspecifiedValue, get_array_mapping as _get_array_mapping,
@@ -98,7 +98,6 @@ ShardedAxis = sharding_specs.ShardedAxis
 Replicated = sharding_specs.Replicated

 AvalDimSharding = Union[Unstacked, Chunked, NoSharding]
-Mesh = mesh_lib.Mesh
 MeshAxisName = sharding_impls.MeshAxisName
 MeshDimAssignment = Union[ShardedAxis, Replicated]
 ShardingSpec = sharding_specs.ShardingSpec
@@ -1723,20 +1722,19 @@ def _get_and_check_device_assignment(
     devices: Sequence[xc.Device] | None,
 ) -> tuple[xc.Client, tuple[xc.Device, ...]]:
   first_sharding_info = None
-  if devices is None:
-    devices = ()
-  else:
-    devices = tuple(devices)
+  devices = () if devices is None else tuple(devices)

-  for i, s_type, source_info in shardings:
-    if isinstance(i, UnspecifiedValue):
+  for sh, s_type, source_info in shardings:
+    if isinstance(sh, UnspecifiedValue):
+      continue
+    if isinstance(sh, NamedSharding) and isinstance(sh.mesh, AbstractMesh):
       continue
-
     if first_sharding_info is None:
       first_sharding_info = (
-          (i.mesh._flat_devices_tuple, s_type, source_info) if isinstance(i, AUTO)
-          else (i._device_assignment, s_type, source_info))
-    arr_device_assignment = i.mesh._flat_devices_tuple if isinstance(i, AUTO) else i._device_assignment
+          (sh.mesh._flat_devices_tuple, s_type, source_info) if isinstance(sh, AUTO)
+          else (sh._device_assignment, s_type, source_info))
+    arr_device_assignment = (sh.mesh._flat_devices_tuple if isinstance(sh, AUTO)
+                             else sh._device_assignment)
     if not devices:
       if first_sharding_info[0] != arr_device_assignment:
         raise DeviceAssignmentMismatchError([
@@ -1837,7 +1835,8 @@ class SemanticallyEqualShardings:
   def __init__(self, shardings: tuple[GSPMDSharding | UnspecifiedValue, ...],
                avals: tuple[core.AbstractValue]):
     gspmd_shardings = [
-        s if isinstance(s, (UnspecifiedValue, AUTO))
+        s if (isinstance(s, (UnspecifiedValue, AUTO)) or
+              (isinstance(s, NamedSharding) and isinstance(s.mesh, AbstractMesh)))
         else to_gspmd_sharding(s, a.ndim)  # pytype: disable=attribute-error
         for s, a in zip(shardings, avals)]
     self._gspmd_shardings = gspmd_shardings
@@ -1895,7 +1894,7 @@ def _cached_lowering_to_hlo(closed_jaxpr, api_name, fun_name, backend,
                             propagated_out_mem_kinds: tuple[None | str, ...],
                             platforms: tuple[str, ...],
                             lowering_parameters: mlir.LoweringParameters,
-                            abstract_mesh: mesh_lib.AbstractMesh | None):
+                            abstract_mesh: AbstractMesh | None):
   jaxpr = closed_jaxpr.jaxpr
   in_shardings = semantic_in_shardings.shardings
   out_shardings = semantic_out_shardings.shardings
@@ -2082,6 +2081,40 @@ def get_out_layouts_via_propagation(closed_jaxpr: core.ClosedJaxpr
   return tuple(safe_map(read, jaxpr.outvars))


+def _get_num_devices(shardings, device_assignment, lowering_platforms,
+                     prim_requires_devices) -> int:
+  ext_abstract_mesh, concrete_sharding = None, False
+  for s in shardings:
+    if isinstance(s, UnspecifiedValue):
+      continue
+    elif isinstance(s, NamedSharding) and isinstance(s.mesh, AbstractMesh):
+      if ext_abstract_mesh is not None and ext_abstract_mesh != s.mesh:
+        raise ValueError("AbstractMesh should be the same across all "
+                         f"shardings. Got {ext_abstract_mesh} and {s.mesh}")
+      ext_abstract_mesh = s.mesh
+    else:
+      concrete_sharding = True
+  if (concrete_sharding and ext_abstract_mesh is not None and
+      len(device_assignment) != ext_abstract_mesh.size):
+    raise ValueError(
+        f"AbstractMesh size: {ext_abstract_mesh.size} does not match the"
+        f" device assignment size: {len(device_assignment)}")
+  if concrete_sharding:
+    return len(device_assignment)
+  if ext_abstract_mesh is None:
+    return len(device_assignment)
+  if lowering_platforms is None:
+    raise ValueError(
+        "Passing lowering_platforms via"
+        " jit(f).trace(*args).lower(lowering_platforms=...) is required when"
+        " only AbstractMesh exists in a jitted computation.")
+  if prim_requires_devices:
+    raise ValueError(
+        "AbstractMesh cannot be used when jaxpr contains primitives that"
+        " require devices to be present during lowering.")
+  return ext_abstract_mesh.size
+
+
 MaybeLayout = Sequence[Union[DeviceLocalLayout, AutoLayout, None]]


@@ -2126,7 +2159,7 @@ def _concretize_abstract_shardings(shardings, avals, device_assignment):

   @lru_cache(maxsize=128)
   def _abstract_to_concrete_mesh(abstract_mesh):
-    return mesh_lib.Mesh(
+    return Mesh(
         np_dev.reshape(abstract_mesh.axis_sizes), abstract_mesh.axis_names,
         axis_types=abstract_mesh.axis_types)

@@ -2153,7 +2186,7 @@ def lower_sharding_computation(
     donated_invars: Sequence[bool],
     *,
     keep_unused: bool,
-    context_mesh: mesh_lib.Mesh | None,
+    context_mesh: Mesh | None,
     compiler_options_kvs: tuple[tuple[str, Any], ...],
     lowering_platforms: tuple[str, ...] | None,
     lowering_parameters: mlir.LoweringParameters,
@@ -2211,6 +2244,7 @@ def lower_sharding_computation(
           ((js, MismatchType.SHARDING_INSIDE_COMPUTATION, source_info)
            for js, source_info in unique_intermediate_shardings)),
       devices_from_context)
+  unique_intermediate_shardings = [js for js, _ in unique_intermediate_shardings]

   if config.sharding_in_types.value:
     out_shardings = _concretize_abstract_shardings(
@@ -2221,12 +2255,23 @@ def lower_sharding_computation(
   platforms = lowering_platforms or (
       getattr(backend, "_raw_platform", backend.platform),)

+  prim_requires_devices = dispatch.jaxpr_has_prim_requiring_devices(jaxpr)
+
+  # TODO(yashkatariya): All device specific logic should go in compilation
+  # but this requires a big refactor. The current `_get_num_devices` logic
+  # is good enough to lower with AbstractMesh but cannot be compiled. Once
+  # I refactor, this will also work well with mesh being provided at
+  # compile time.
+  num_devices = _get_num_devices(
+      it.chain(unique_in_shardings, unique_out_shardings,
+               unique_intermediate_shardings),
+      device_assignment, lowering_platforms, prim_requires_devices)
+
   committed = bool(
-      devices_from_context or
-      len(device_assignment) > 1 or
-      any(not isinstance(i, UnspecifiedValue) for i in unique_in_shardings) or
-      any(not isinstance(js, UnspecifiedValue) for js, _ in unique_intermediate_shardings) or
-      any(not isinstance(o, UnspecifiedValue) for o in unique_out_shardings))
+      devices_from_context
+      or num_devices > 1
+      or any(not isinstance(s, UnspecifiedValue) for s in it.chain(
+          unique_in_shardings, unique_out_shardings, unique_intermediate_shardings)))

   da_object = _create_da_object(tuple(device_assignment))

@@ -2234,8 +2279,7 @@ def lower_sharding_computation(
   all_default_mem_kind = are_all_shardings_default_mem_kind(
       da_object,
       it.chain(unique_in_shardings, unique_out_shardings,
-               [js for js, _ in unique_intermediate_shardings],
-               transfer_mem_kind_in_jaxpr))  # pytype: disable=wrong-arg-types
+               unique_intermediate_shardings, transfer_mem_kind_in_jaxpr))  # pytype: disable=wrong-arg-types

   if all_default_mem_kind:
     propagated_out_mem_kinds = (None,) * len(global_out_avals)
@@ -2244,12 +2288,11 @@ def lower_sharding_computation(
         closed_jaxpr, in_shardings)

   # 2. Build up the HLO
-  prim_requires_devices = dispatch.jaxpr_has_prim_requiring_devices(jaxpr)
   abstract_mesh = None
   if prim_requires_devices:
     for sharding in it.chain(unique_in_shardings, unique_out_shardings,
-                             [js for js, _ in unique_intermediate_shardings]):
+                             unique_intermediate_shardings):
       if isinstance(sharding, NamedSharding):
         if (abstract_mesh is not None and
             abstract_mesh != sharding.mesh.abstract_mesh):
@@ -2267,7 +2310,7 @@ def lower_sharding_computation(
   (module, keepalive, host_callbacks, unordered_effects, ordered_effects,
    nreps, tuple_args, shape_poly_state) = _cached_lowering_to_hlo(
        closed_jaxpr, api_name, fun_name, backend, semantic_in_shardings,
-       semantic_out_shardings, in_layouts, out_layouts, len(da_object),
+       semantic_out_shardings, in_layouts, out_layouts, num_devices,
        tuple(da_object) if prim_requires_devices else None, donated_invars,
        name_stack, all_default_mem_kind, inout_aliases,
        propagated_out_mem_kinds, platforms,
@@ -2310,7 +2353,7 @@ def lower_sharding_computation(
       all_default_mem_kind=all_default_mem_kind,
       all_args_info=all_args_info,
       pgle_profiler=pgle_profiler,
-      intermediate_shardings=[s for s, _ in unique_intermediate_shardings],
+      intermediate_shardings=unique_intermediate_shardings,
       context_mesh=context_mesh)


@@ -2480,7 +2523,7 @@ def _register_out_sharding_handler(

 def _gspmd_to_named_sharding(
     out_s: GSPMDSharding, orig_in_s: NamedSharding) -> NamedSharding:
-  assert isinstance(orig_in_s.mesh, mesh_lib.Mesh)
+  assert isinstance(orig_in_s.mesh, Mesh)
   return sharding_impls._gspmd_to_named_sharding_via_mesh(out_s, orig_in_s.mesh)

 _register_out_sharding_handler(NamedSharding, _gspmd_to_named_sharding)
@@ -2532,7 +2575,7 @@ def _get_out_sharding_from_orig_sharding(

 def maybe_recover_user_shardings(
     old_shardings, new_shardings, old_avals, new_avals,
-    intermediate_shardings=None, context_mesh: mesh_lib.Mesh | None = None):
+    intermediate_shardings=None, context_mesh: Mesh | None = None):
   if all(not isinstance(o, sharding_impls.GSPMDSharding)
          for o in new_shardings):
     return new_shardings
@@ -2832,8 +2875,14 @@ class UnloadedMeshExecutable:
       all_args_info: AllArgsInfo | None = None,
       pgle_profiler: profiler.PGLEProfiler | None = None,
       intermediate_shardings: Sequence[JSharding] | None = None,
-      context_mesh: mesh_lib.Mesh | None = None
+      context_mesh: Mesh | None = None,
   ) -> MeshExecutable:
+    if any(isinstance(s, NamedSharding) and isinstance(s.mesh, AbstractMesh)
+           for s in it.chain(in_shardings, out_shardings)):
+      raise RuntimeError(
+          "A jitted computation cannot contain AbstractMesh in in_shardings and"
+          " out_shardings during compilation. You can use `jax.export` to "
+          " lower with an AbstractMesh and later compile with concrete devices.")
     if shape_poly_state is not None and shape_poly_state.uses_dim_vars:
       hlo = mlir.refine_polymorphic_shapes(hlo)
     if isinstance(device_assignment, xc.DeviceList):
diff --git a/jax/_src/pjit.py b/jax/_src/pjit.py
index eea2a1e34..39c28f11c 100644
--- a/jax/_src/pjit.py
+++ b/jax/_src/pjit.py
@@ -498,7 +498,7 @@ def _make_jit_wrapper(fun: Callable, jit_info: PjitInfo):
     donate_argnums = tuple(i for i, d in enumerate(p.donated_invars) if d)
     args_info = stages.make_args_info(p.in_tree, p.in_avals, donate_argnums)
     lower_callable = partial(_resolve_and_lower, args_flat, **p.params,
-                            pgle_profiler=None)
+                             pgle_profiler=None)
     return stages.Traced(
         p.params['jaxpr'], args_info, p.params["name"], p.out_tree,
         lower_callable, p.abstract_mesh, args_flat, p.arg_names, p.num_consts)
diff --git a/tests/pjit_test.py b/tests/pjit_test.py
index ad57434f1..249ebde08 100644
--- a/tests/pjit_test.py
+++ b/tests/pjit_test.py
@@ -4631,6 +4631,54 @@ class ArrayPjitTest(jtu.JaxTestCase):
     ins, _ = f.lower(np.arange(8)).compile().input_shardings
     self.assertEqual(ins[0], SingleDeviceSharding(jax.devices()[0]))

+  def test_abstract_mesh_lower(self):
+    mesh = jtu.create_mesh((2,), 'x')
+    mesh2 = jtu.create_mesh((1,), 'x')
+
+    abstract_sds = jax.ShapeDtypeStruct(
+        (8, 2), jnp.float32, sharding=NamedSharding(mesh.abstract_mesh, P('x')))
+    abstract_sds2 = jax.ShapeDtypeStruct(
+        (8, 2), jnp.float32, sharding=NamedSharding(mesh2.abstract_mesh, P('x')))
+
+    @jax.jit
+    def f(x):
+      return x * 2
+
+    lowered = f.trace(abstract_sds).lower(lowering_platforms=('tpu',))
+    self.assertIn('num_partitions = 2', lowered.as_text())
+
+    with self.assertRaisesRegex(
+        RuntimeError, 'A jitted computation cannot contain AbstractMesh'):
+      lowered.compile()
+
+    @jax.jit
+    def g(x, y):
+      return x, y
+
+    concrete_s = NamedSharding(mesh, P('x'))
+    concrete_sds = jax.ShapeDtypeStruct((8,), jnp.float32, sharding=concrete_s)
+    with self.assertRaisesRegex(
+        ValueError,
+        'AbstractMesh size: 1 does not match the device assignment size: 2'):
+      g.lower(abstract_sds2, concrete_sds)
+
+    with self.assertRaisesRegex(
+        ValueError, "Passing lowering_platforms.*is required"):
+      g.lower(abstract_sds, np.arange(8))
+
+    lowered2 = g.trace(abstract_sds, np.arange(8)).lower(
+        lowering_platforms=('tpu',))
+    self.assertIn('num_partitions = 2', lowered2.as_text())
+    with self.assertRaisesRegex(
+        RuntimeError, 'A jitted computation cannot contain AbstractMesh'):
+      lowered2.compile()
+
+    lowered3 = g.lower(abstract_sds, concrete_sds)
+    self.assertIn('num_partitions = 2', lowered3.as_text())
+    with self.assertRaisesRegex(
+        RuntimeError, 'A jitted computation cannot contain AbstractMesh'):
+      lowered3.compile()
+

 def spec_regex(s):
   return str(s).replace(r"(", r"\(").replace(r")", r"\)")
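For context on the new RuntimeError raised in UnloadedMeshExecutable.load: this change lets lowering proceed with only an AbstractMesh, while compilation still needs concrete devices, and the error message points users at jax.export to bridge the two. The new test_abstract_mesh_lower above exercises the lowering half and the error paths; the sketch below fills in the lower-now, compile-later half. It is not part of this patch: the jax.export calls (export.export, Exported.serialize, export.deserialize, Exported.call) and the two-device TPU setup are assumptions about the surrounding JAX API, so treat it as illustrative only.

import jax
import jax.numpy as jnp
from jax import export
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Assumption: a host with at least two devices of the target platform.
mesh = Mesh(jax.devices()[:2], ('x',))

# Only the abstract mesh (axis names and sizes, no devices) goes into the input spec.
abstract_in = jax.ShapeDtypeStruct(
    (8, 2), jnp.float32, sharding=NamedSharding(mesh.abstract_mesh, P('x')))

@jax.jit
def f(x):
  return x * 2

# Lowering against an AbstractMesh needs an explicit platform, since there are
# no devices to infer it from; export.export supplies one via `platforms`.
exported = export.export(f, platforms=('tpu',))(abstract_in)
blob = exported.serialize()

# Later, possibly on a different host: deserialize, then compile and run with
# concrete, sharded inputs.
restored = export.deserialize(blob)
x = jax.device_put(jnp.ones((8, 2), jnp.float32), NamedSharding(mesh, P('x')))
y = restored.call(x)

This mirrors what _get_num_devices enforces during lowering: with only an AbstractMesh there is no device assignment to read a platform or device count from, so the mesh size plus an explicit lowering platform stand in for them until compile time.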