diff --git a/jax/experimental/custom_partitioning.py b/jax/experimental/custom_partitioning.py
index 6209a37d9..22cdff8c6 100644
--- a/jax/experimental/custom_partitioning.py
+++ b/jax/experimental/custom_partitioning.py
@@ -481,6 +481,9 @@ def _custom_partitioning_lowering_rule(ctx: mlir.LoweringRuleContext, *values,
                                        static_args):
   mesh = mesh_lib.thread_resources.env.physical_mesh
   axis_context = ctx.module_context.axis_context
+  if (isinstance(axis_context, sharding_impls.SPMDAxisContext) and
+      set(axis_context.manual_axes) == set(axis_context.mesh.axis_names)):
+    return mlir.lower_fun(core.jaxpr_as_fun(call), multiple_results=True)(ctx, *values)
 
   if isinstance(axis_context, sharding_impls.ShardingContext):
     devices = axis_context.device_assignment
diff --git a/tests/BUILD b/tests/BUILD
index 220bba3c3..f89b113e9 100644
--- a/tests/BUILD
+++ b/tests/BUILD
@@ -1235,6 +1235,7 @@ jax_test(
         "notsan",
     ],  # Times out under *SAN.
     deps = [
+        "//jax:experimental",
         "//jax:tree_util",
     ],
 )
diff --git a/tests/shard_map_test.py b/tests/shard_map_test.py
index 52621645a..6c5ad7208 100644
--- a/tests/shard_map_test.py
+++ b/tests/shard_map_test.py
@@ -44,6 +44,7 @@
 from jax._src import linear_util as lu
 from jax._src import tree_util
 import jax.numpy as jnp
+from jax.experimental.custom_partitioning import custom_partitioning
 from jax.experimental.shard_map import shard_map
 
 config.parse_flags_with_absl()
@@ -1684,6 +1685,61 @@ class ShardMapSystematicTest(jtu.JaxTestCase):
     tol = 1e-2 if jtu.test_device_matches(['tpu']) else None
     self.assertAllClose(ans, expected, check_dtypes=False, atol=tol, rtol=tol)
 
+@jtu.pytest_mark_if_available('multiaccelerator')
+class CustomPartitionerTest(jtu.JaxTestCase):
+
+  def skip_if_custom_partitioning_not_supported(self):
+    if jtu.is_cloud_tpu():
+      raise unittest.SkipTest("Custom partitioning is not supported on libtpu.")
+    if xla_bridge.using_pjrt_c_api():
+      raise unittest.SkipTest('custom partitioning not implemented in PJRT C API')
+
+  def test_custom_partitioning(self):
+    self.skip_if_custom_partitioning_not_supported()
+
+    mesh, a, _ = create_inputs(P('z', ('x', 'y')), P(None, None))
+    assert a.addressable_data(0).shape == (4, 2)
+
+    def partition(mesh, arg_shapes, result_shape):
+      def lower_fn(x):
+        return x
+
+      return (
+          mesh,
+          lower_fn,
+          arg_shapes[0].sharding,
+          (arg_shapes[0].sharding,),
+      )
+
+    def infer_sharding_from_operands(mesh, arg_shapes, result_shape):
+      return arg_shapes[0].sharding
+
+    def propagate_user_sharding(mesh, user_shape):
+      return user_shape.sharding
+
+    @custom_partitioning
+    def f(x):
+      return x
+
+    f.def_partition(
+        infer_sharding_from_operands=infer_sharding_from_operands,
+        partition=partition,
+        propagate_user_sharding=propagate_user_sharding,
+    )
+
+    @jax.jit
+    def fwd(a):
+      c = shard_map(
+          f,
+          mesh,
+          check_rep=False,
+          in_specs=(P('z', ('x', 'y')),),
+          out_specs=P('z', ('x', 'y')))(a)
+      return c
+
+    c = fwd(a)
+    self.assertEqual(c.addressable_data(0).shape, (4, 2))
+
 
 if __name__ == '__main__':
   absltest.main(testLoader=jtu.JaxTestLoader())
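
Note for reviewers (not part of the patch): the new branch in _custom_partitioning_lowering_rule detects a fully-manual SPMDAxisContext, i.e. every mesh axis is already manual as it is under shard_map, and inlines the wrapped function via mlir.lower_fun instead of emitting the custom-partitioning custom call, since there is nothing left for the partitioner callbacks to shard. Below is a minimal self-contained sketch of what this enables; the 1-D mesh, array size, and identity body are illustrative assumptions only (the test above exercises the same path on a 3-D mesh).

# Illustrative sketch, assuming a single-axis mesh over all local devices.
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, PartitionSpec as P
from jax.experimental.custom_partitioning import custom_partitioning
from jax.experimental.shard_map import shard_map

@custom_partitioning
def f(x):
  return x

def partition(mesh, arg_shapes, result_shape):
  # Run the body per-shard and keep the operand's sharding for the result.
  def lower_fn(x):
    return x
  return mesh, lower_fn, arg_shapes[0].sharding, (arg_shapes[0].sharding,)

def infer_sharding_from_operands(mesh, arg_shapes, result_shape):
  return arg_shapes[0].sharding

f.def_partition(
    partition=partition,
    infer_sharding_from_operands=infer_sharding_from_operands,
)

# Calling the custom_partitioning function under a fully-manual shard_map:
# with this change, f's body is inlined rather than lowered to a custom call.
mesh = Mesh(jax.devices(), ('x',))
g = jax.jit(shard_map(f, mesh, in_specs=(P('x'),), out_specs=P('x'),
                      check_rep=False))
x = jnp.arange(2.0 * len(jax.devices()))
assert (g(x) == x).all()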