diff --git a/jax/_src/custom_partitioning.py b/jax/_src/custom_partitioning.py
index c038ef064..8f48746dd 100644
--- a/jax/_src/custom_partitioning.py
+++ b/jax/_src/custom_partitioning.py
@@ -24,6 +24,7 @@ import inspect
 from typing import Any
 import weakref
 
+import numpy as np
 import jax
 from jax import tree_util
 from jax._src import api_util
@@ -481,17 +482,20 @@ def _custom_partitioning_lowering_rule(ctx: mlir.LoweringRuleContext, *values,
                                        infer_sharding_from_operands,
                                        decode_shardings,
                                        static_args):
-  mesh = mesh_lib.thread_resources.env.physical_mesh
   axis_context = ctx.module_context.axis_context
   if (isinstance(axis_context, sharding_impls.SPMDAxisContext) and
       set(axis_context.manual_axes) == set(axis_context.mesh.axis_names)):
     return mlir.lower_fun(
         core.jaxpr_as_fun(call), multiple_results=True)(ctx, *values)
+  mesh = mesh_lib.thread_resources.env.physical_mesh
   if isinstance(axis_context, sharding_impls.ShardingContext):
     devices = axis_context.device_assignment
     if devices is None:
       raise AssertionError(
           'Please file a bug at https://github.com/google/jax/issues')
+    if axis_context.mesh_shape is not None:
+      ma, ms = list(zip(*axis_context.mesh_shape))
+      mesh = mesh_lib.Mesh(np.array(devices).reshape(ms), ma)
   elif isinstance(axis_context, sharding_impls.SPMDAxisContext):
     devices = axis_context.mesh._flat_devices_tuple
   else:
diff --git a/jax/_src/dispatch.py b/jax/_src/dispatch.py
index 8605c58a8..068b5e3b7 100644
--- a/jax/_src/dispatch.py
+++ b/jax/_src/dispatch.py
@@ -204,7 +204,7 @@ def jaxpr_has_primitive(jaxpr: core.Jaxpr, prim_name: str) -> bool:
 # stablehlo is oblivious of physical devices.
 prim_requires_devices_during_lowering: set[core.Primitive] = set()
 
-def jaxpr_has_prim_requiring_devices(jaxpr: core.Jaxpr):
+def jaxpr_has_prim_requiring_devices(jaxpr: core.Jaxpr) -> bool:
   for eqn in jaxpr.eqns:
     if eqn.primitive in prim_requires_devices_during_lowering:
       return True
diff --git a/jax/_src/interpreters/mlir.py b/jax/_src/interpreters/mlir.py
index 4d15e803a..814c6a988 100644
--- a/jax/_src/interpreters/mlir.py
+++ b/jax/_src/interpreters/mlir.py
@@ -1033,7 +1033,6 @@ def lower_jaxpr_to_module(
     input_output_aliases: None | tuple[int | None, ...] = None,
     propagated_out_mem_kinds: tuple[None | str, ...] | None = None,
     lowering_parameters: LoweringParameters,
-    mesh_shape_tuple: tuple[tuple[str, int], ...] | None = None,
 ) -> LoweringResult:
   """Lowers a top-level jaxpr to an MLIR module.
 
@@ -1121,13 +1120,14 @@
   # XLA computation preserves the module name.
   attrs = ctx.module.operation.attributes
   if config.use_shardy_partitioner.value:
-    assert mesh_shape_tuple is not None
+    assert (isinstance(axis_context, sharding_impls.ShardingContext) and
+            axis_context.mesh_shape is not None)
     ctx.module.body.append(
         dialects.sdy.MeshOp(
             "mesh",
             dialects.sdy.MeshAttr.get(
                 [dialects.sdy.MeshAxisAttr.get(name, size)
-                 for name, size in mesh_shape_tuple])))
+                 for name, size in axis_context.mesh_shape])))
   module_name = _module_name_regex.sub("_", module_name)
   attrs["sym_name"] = ir.StringAttr.get(module_name)
   attrs["mhlo.num_replicas"] = i32_attr(num_replicas)
diff --git a/jax/_src/interpreters/pxla.py b/jax/_src/interpreters/pxla.py
index 4dcdfbcbf..ce96f7e81 100644
--- a/jax/_src/interpreters/pxla.py
+++ b/jax/_src/interpreters/pxla.py
@@ -1881,7 +1881,7 @@ def _cached_lowering_to_hlo(closed_jaxpr, api_name, fun_name, backend,
                             propagated_out_mem_kinds: tuple[None | str, ...],
                             platforms: tuple[str, ...],
                             lowering_parameters: mlir.LoweringParameters,
-                            mesh_shape_tuple: tuple[tuple[str, int], ...]):
+                            mesh_shape_tuple: tuple[tuple[str, int], ...] | None):
   jaxpr = closed_jaxpr.jaxpr
   in_shardings = semantic_in_shardings.shardings
   out_shardings = semantic_out_shardings.shardings
@@ -1911,7 +1911,8 @@ def _cached_lowering_to_hlo(closed_jaxpr, api_name, fun_name, backend,
     in_mlir_shardings = map(_to_logical_sharding, global_in_avals, in_shardings)
     out_mlir_shardings = map(_to_logical_sharding, global_out_avals, out_shardings)
     replicated_args = [False] * len(global_in_avals)
-    axis_ctx = sharding_impls.ShardingContext(num_devices, device_assignment)
+    axis_ctx = sharding_impls.ShardingContext(num_devices, device_assignment,
+                                              mesh_shape_tuple)
     num_partitions = num_devices
   else:
     # This path is triggered for `jit(pmap)` cases.
@@ -1957,8 +1958,7 @@ def _cached_lowering_to_hlo(closed_jaxpr, api_name, fun_name, backend,
       all_default_mem_kind=all_default_mem_kind,
       input_output_aliases=inout_aliases,
       propagated_out_mem_kinds=propagated_out_mem_kinds,
-      lowering_parameters=lowering_parameters,
-      mesh_shape_tuple=mesh_shape_tuple)
+      lowering_parameters=lowering_parameters)
   tuple_args = dispatch.should_tuple_args(len(global_in_avals), backend.platform)
   unordered_effects = list(
       effects.ordered_effects.filter_not_in(closed_jaxpr.effects))
@@ -2202,15 +2202,21 @@ def lower_sharding_computation(
       in_shardings, global_in_avals)  # type: ignore
   semantic_out_shardings = SemanticallyEqualShardings(
       out_shardings, global_out_avals)  # type: ignore
 
+  prim_requires_devices = dispatch.jaxpr_has_prim_requiring_devices(jaxpr)
+
   mesh_shape_tuple = None
-  if config.use_shardy_partitioner.value:
-    for sharding in it.chain(
-        in_shardings, out_shardings,
-        [js for js, _ in unique_intermediate_shardings]):
+  if config.use_shardy_partitioner.value or prim_requires_devices:
+    for sharding in it.chain(in_shardings, out_shardings,
+                             [js for js, _ in unique_intermediate_shardings]):
       if isinstance(sharding, sharding_impls.NamedSharding):
+        if (mesh_shape_tuple is not None and
+            mesh_shape_tuple != sharding.mesh.shape_tuple):
+          raise ValueError(
+              "mesh should be the same across the entire program. Got mesh"
+              f" shape for one sharding {mesh_shape_tuple} and"
+              f" {sharding.mesh.shape_tuple} for another")
         mesh_shape_tuple = sharding.mesh.shape_tuple
-        break
 
   (module, keepalive, host_callbacks, unordered_effects, ordered_effects,
    nreps, tuple_args, shape_poly_state) = _cached_lowering_to_hlo(
diff --git a/jax/_src/sharding_impls.py b/jax/_src/sharding_impls.py
index d41ef7410..1a23f4ba7 100644
--- a/jax/_src/sharding_impls.py
+++ b/jax/_src/sharding_impls.py
@@ -1162,6 +1162,7 @@ class ShardingContext:
   """
   num_devices: int
   device_assignment: tuple[xc.Device, ...] | None = None
+  mesh_shape: tuple[tuple[str, int], ...] | None = None
 
   def __post_init__(self):
     if self.device_assignment is not None:
diff --git a/tests/pjit_test.py b/tests/pjit_test.py
index 44782ec15..9368f7da9 100644
--- a/tests/pjit_test.py
+++ b/tests/pjit_test.py
@@ -1514,6 +1514,45 @@ class CustomPartitionerTest(jtu.JaxTestCase):
     xs = jnp.ones([32, 16])
     self.assertEqual(pjit_f(xs), xs.sum())
 
+  def test_custom_partitioning_no_mesh_context(self):
+    self.skip_if_custom_partitioning_not_supported()
+
+    @custom_partitioning
+    def f(x):
+      return x
+
+    def partition(mesh, arg_shapes, result_shape):
+      def lower_fn(x):
+        @jax.jit
+        def g(y):
+          return y
+
+        return g(x)
+
+      x_shard = arg_shapes[0].sharding
+      return (
+          mesh,
+          lower_fn,
+          NamedSharding(x_shard.mesh, P('x')),
+          (NamedSharding(x_shard.mesh, P('x')),),
+      )
+
+    def infer_sharding_from_operands(mesh, arg_shapes, result_shape):
+      x_shard = arg_shapes[0].sharding
+      return NamedSharding(x_shard.mesh, P('x'))
+
+    f.def_partition(
+        infer_sharding_from_operands=infer_sharding_from_operands,
+        partition=partition,
+    )
+
+    mesh = jtu.create_global_mesh((4,), ('x',))
+    x = np.asarray(np.random.randint(0, 20, (32,)), dtype=np.float32)
+    s = NamedSharding(mesh, P('x'))
+
+    jit_f = jax.jit(f, in_shardings=s, out_shardings=s)
+    self.assertArraysEqual(x, jit_f(x))
+
 
 @jtu.pytest_mark_if_available('multiaccelerator')
 class AutoShardingPjitTest(jtu.JaxTestCase):