From 9e2708eb57ab1810ee576d7da6d489c64bb995ad Mon Sep 17 00:00:00 2001
From: Yash Katariya
Date: Wed, 4 Dec 2024 12:11:54 -0800
Subject: [PATCH] [sharding_in_types] Use `set_mesh` API to trigger
 sharding_in_types instead of the config option.

PiperOrigin-RevId: 702814257
---
 jax/_src/interpreters/pxla.py |  11 +--
 jax/_src/pjit.py              |  11 +--
 jax/_src/test_util.py         |  11 +++
 tests/pjit_test.py            | 160 +++++++++++++++-------------------
 4 files changed, 88 insertions(+), 105 deletions(-)

diff --git a/jax/_src/interpreters/pxla.py b/jax/_src/interpreters/pxla.py
index 98cfbbb1d..41f91202e 100644
--- a/jax/_src/interpreters/pxla.py
+++ b/jax/_src/interpreters/pxla.py
@@ -2193,15 +2193,8 @@ def lower_sharding_computation(
   assert len(out_shardings) == len(out_layouts) == len(global_out_avals), (
       len(out_shardings), len(out_layouts), len(global_out_avals))
 
-  if config.sharding_in_types.value:
-    # TODO(yashkatariya): Thread it via jit path and remove the None check by
-    # making tests go via set_mesh API always.
-    devices_from_context = (
-        None if mesh_lib.device_context.concrete_mesh is None
-        else mesh_lib.device_context.concrete_mesh._flat_devices_tuple)
-  else:
-    devices_from_context = (None if context_mesh is None or context_mesh.empty
-                            else context_mesh._flat_devices_tuple)
+  devices_from_context = (None if context_mesh is None or context_mesh.empty
+                          else context_mesh._flat_devices_tuple)
   # Device assignment across all inputs, outputs and shardings inside jaxpr
   # should be the same.
   unique_intermediate_shardings = util.stable_unique(
diff --git a/jax/_src/pjit.py b/jax/_src/pjit.py
index 462be851e..1f8378d84 100644
--- a/jax/_src/pjit.py
+++ b/jax/_src/pjit.py
@@ -707,9 +707,6 @@ def get_abstract_mesh(in_avals):
           f'Mesh for all inputs should be equal. Got one mesh: {m} and'
           f' another mesh: {a.sharding.mesh}')
     m = a.sharding.mesh  # type: ignore
-  # TODO(yashkatariya): Remove this when mesh context can be set by the user.
-  if m is None:
-    return contextlib.nullcontext()
   assert isinstance(m, AbstractMesh)
   return m
 
@@ -1791,8 +1788,12 @@ def _pjit_lower(
     lowering_platforms: tuple[str, ...] | None,
     lowering_parameters: mlir.LoweringParameters,
     pgle_profiler: profiler.PGLEProfiler | None):
-  mesh, api_name = ((resource_env.physical_mesh, 'pjit')
-                    if resource_env is not None else (None, 'jit'))
+  if config.sharding_in_types.value:
+    mesh = mesh_lib.device_context.concrete_mesh
+    api_name = 'jit'
+  else:
+    mesh, api_name = ((resource_env.physical_mesh, 'pjit')
+                      if resource_env is not None else (None, 'jit'))
   return pxla.lower_sharding_computation(
       jaxpr, api_name, name, in_shardings, out_shardings,
       in_layouts, out_layouts, tuple(donated_invars),
diff --git a/jax/_src/test_util.py b/jax/_src/test_util.py
index c639ebd03..0bd5c7b13 100644
--- a/jax/_src/test_util.py
+++ b/jax/_src/test_util.py
@@ -51,6 +51,7 @@ from jax._src import monitoring
 from jax._src import pjit as pjit_lib
 from jax._src import stages
 from jax._src import xla_bridge
+from jax._src import mesh as mesh_lib
 from jax._src.cloud_tpu_init import running_in_cloud_tpu_vm
 from jax._src.interpreters import mlir
 from jax._src.interpreters import pxla
@@ -1442,6 +1443,16 @@ def with_and_without_mesh(f):
       ('Mesh', (('x', 2),), (('i', 'x'),))
   ))(with_mesh_from_kwargs(f))
 
+def with_user_mesh(sizes, names):
+  def decorator(fn):
+    def mesh_fn(*args, **kwargs):
+      mesh = create_mesh(sizes, names)
+      with mesh_lib.set_mesh(mesh):
+        return fn(*args, **kwargs, mesh=mesh)
+    return mesh_fn
+  return decorator
+
+
 def create_mesh(mesh_shape, axis_names, iota_order=False):
   size = math.prod(mesh_shape)
   if len(jax.devices()) < size:
diff --git a/tests/pjit_test.py b/tests/pjit_test.py
index af81b3557..52261ef02 100644
--- a/tests/pjit_test.py
+++ b/tests/pjit_test.py
@@ -4630,38 +4630,16 @@ class ArrayPjitTest(jtu.JaxTestCase):
     ins, _ = f.lower(np.arange(8)).compile().input_shardings
     self.assertEqual(ins[0], SingleDeviceSharding(jax.devices()[0]))
 
-  def test_sharding_in_types_with_set_mesh(self):
-    if config.use_shardy_partitioner.value:
-      self.skipTest("ShiT doesn't work with shardy")
-    mesh = jtu.create_mesh((2, 2), ('x', 'y'))
-    with mesh_lib.set_mesh(mesh):
-      np_inp = np.arange(16.).reshape(8, 2)
-      s = NamedSharding(mesh, P('x', 'y'))
-      arr = jax.device_put(np_inp, s)
-
-      @jax.jit
-      def f(x):
-        self.assertEqual(x.sharding.spec, s.spec)
-        x = x * 2
-        self.assertEqual(x.sharding.spec, s.spec)
-        x = x * x
-        self.assertEqual(x.sharding.spec, s.spec)
-        return x
-
-      out = f(arr)
-      self.assertEqual(out.sharding, s)
-      self.assertArraysEqual(out, (np_inp * 2) * (np_inp * 2))
-
 
 def spec_regex(s):
   return str(s).replace(r"(", r"\(").replace(r")", r"\)")
 
 
-@jtu.with_config(jax_sharding_in_types=True, jax_use_shardy_partitioner=False)
+@jtu.with_config(jax_use_shardy_partitioner=False)
 class ShardingInTypesTest(jtu.JaxTestCase):
 
-  def test_basic_mul(self):
-    mesh = jtu.create_mesh((2, 2), ('x', 'y'))
+  @jtu.with_user_mesh((2, 2), ('x', 'y'))
+  def test_basic_mul(self, mesh):
     np_inp = np.arange(16.).reshape(8, 2)
     s = NamedSharding(mesh, P('x', 'y'))
     arr = jax.device_put(np_inp, s)
@@ -4696,8 +4674,8 @@ class ShardingInTypesTest(jtu.JaxTestCase):
     out = jax.jit(jax.grad(g))(arr)
     self.assertEqual(out.sharding, arr.sharding)
 
-  def test_fully_replicated_array_mul(self):
-    mesh = jtu.create_mesh((2, 2), ('x', 'y'))
+  @jtu.with_user_mesh((2, 2), ('x', 'y'))
+  def test_fully_replicated_array_mul(self, mesh):
     np_inp1 = np.arange(16).reshape(8, 2)
     s = NamedSharding(mesh, P('x', 'y'))
     arr1 = jax.device_put(np_inp1, s)
@@ -4745,8 +4723,8 @@ class ShardingInTypesTest(jtu.JaxTestCase):
       ('half_tp', P(None, 'y'), P(None, 'y'), P(None, 'y'), 'all-gather'),
      ('other_half_tp', P(None, 'y'), P('y', None), P(None, None), 'all-reduce')
   )
-  def test_dot_general(self, spec1, spec2, out_spec, collective_name):
-    mesh = jtu.create_mesh((2, 2), ('x', 'y'))
+  @jtu.with_user_mesh((2, 2), ('x', 'y'))
+  def test_dot_general(self, spec1, spec2, out_spec, collective_name, mesh):
     np_inp1 = np.arange(16.).reshape(8, 2)
     arr1 = jax.device_put(np_inp1, NamedSharding(mesh, spec1))
     arr2 = jax.device_put(np_inp1.T, NamedSharding(mesh, spec2))
@@ -4781,8 +4759,8 @@ class ShardingInTypesTest(jtu.JaxTestCase):
     self.assertEqual(out[0].sharding, arr1.sharding)
     self.assertEqual(out[1].sharding, arr2.sharding)
 
-  def test_dot_general_out_type(self):
-    mesh = jtu.create_mesh((4,), ('x',))
+  @jtu.with_user_mesh((4,), ('x',))
+  def test_dot_general_out_type(self, mesh):
     np_inp1 = np.arange(16.).reshape(8, 2)
     arr1 = jax.device_put(np_inp1, NamedSharding(mesh, P('x', None)))
     arr2 = jax.device_put(np_inp1.T, NamedSharding(mesh, P(None, 'x')))
@@ -4824,8 +4802,8 @@
       "dot_general requires contracting dimensions to have consistent sharding",
       TypeError),
   )
-  def test_dot_general_error(self, spec1, spec2, error_msg, error_type):
-    mesh = jtu.create_mesh((2, 2), ('x', 'y'))
+  @jtu.with_user_mesh((2, 2), ('x', 'y'))
+  def test_dot_general_error(self, spec1, spec2, error_msg, error_type, mesh):
     np_inp1 = np.arange(16).reshape(8, 2)
     arr1 = jax.device_put(np_inp1, NamedSharding(mesh, spec1))
     arr2 = jax.device_put(np_inp1.T, NamedSharding(mesh, spec2))
@@ -4837,8 +4815,8 @@ class ShardingInTypesTest(jtu.JaxTestCase):
     with self.assertRaisesRegex(error_type, error_msg):
       f(arr1, arr2)
 
-  def test_dot_general_batch_error(self):
-    mesh = jtu.create_mesh((2, 2, 1), ('x', 'y', 'z'))
+  @jtu.with_user_mesh((2, 2, 1), ('x', 'y', 'z'))
+  def test_dot_general_batch_error(self, mesh):
     arr1 = jax.device_put(np.ones((8, 4, 2)),
                           NamedSharding(mesh, P('x', 'y', 'z')))
     arr2 = jax.device_put(np.ones((8, 2, 4)),
@@ -4856,9 +4834,8 @@ class ShardingInTypesTest(jtu.JaxTestCase):
         ' have the consistent sharding'):
       jnp.einsum('abc,acz->abz', arr1, arr2)
 
-  def test_aval_repr(self):
-    mesh = jtu.create_mesh((2, 2), ('model', 'data'))
-
+  @jtu.with_user_mesh((2, 2), ('model', 'data'))
+  def test_aval_repr(self, mesh):
     aval = core.ShapedArray((128, 64), np.float32,
                             sharding=NamedSharding(mesh, P('model', 'data')))
     self.assertEqual(aval.str_short(), 'float32[128@model,64@data]')
@@ -4876,14 +4853,14 @@ class ShardingInTypesTest(jtu.JaxTestCase):
     self.assertEqual(aval.str_short(), 'float32[128@(model,data),64]')
 
   @parameterized.named_parameters(
-      ('all', None, P('x', 'y'), P()),
-      ('first', 0, P('x', 'y'), P('y')),
-      ('second', 1, P('x', 'y'), P('x')),
-      ('first2', 0, P(('x', 'y'), None), P(None)),
+      ('all', None, P('x', 'y'), P(), True),
+      ('first', 0, P('x', 'y'), P('y'), True),
+      ('second', 1, P('x', 'y'), P('x'), True),
+      ('first2', 0, P(('x', 'y'), None), P(None), True),
       ('second2', 1, P(('x', 'y'), None), P(('x', 'y')), False),
   )
-  def test_reduce_sum(self, axis, in_spec, out_spec, reduce=True):
-    mesh = jtu.create_mesh((2, 2), ('x', 'y'))
+  @jtu.with_user_mesh((2, 2), ('x', 'y'))
+  def test_reduce_sum(self, axis, in_spec, out_spec, reduce, mesh):
     np_inp = np.arange(16).reshape(8, 2)
     s = NamedSharding(mesh, in_spec)
     arr = jax.device_put(np_inp, s)
@@ -4907,14 +4884,14 @@ class ShardingInTypesTest(jtu.JaxTestCase):
     self.assertIn('all-reduce', compiled_text)
 
   @parameterized.named_parameters(
-      ('all', None, P('x', 'y'), P()),
-      ('first', 0, P('x', 'y'), P('y')),
-      ('second', 1, P('x', 'y'), P('x')),
-      ('first2', 0, P(('x', 'y'), None), P(None)),
+      ('all', None, P('x', 'y'), P(), True),
+      ('first', 0, P('x', 'y'), P('y'), True),
+      ('second', 1, P('x', 'y'), P('x'), True),
+      ('first2', 0, P(('x', 'y'), None), P(None), True),
       ('second2', 1, P(('x', 'y'), None), P(('x', 'y')), False),
   )
-  def test_reduce_max(self, axis, in_spec, out_spec, reduce=True):
-    mesh = jtu.create_mesh((2, 2), ('x', 'y'))
+  @jtu.with_user_mesh((2, 2), ('x', 'y'))
+  def test_reduce_max(self, axis, in_spec, out_spec, reduce, mesh):
     np_inp = np.arange(16.).reshape(8, 2)
     s = NamedSharding(mesh, in_spec)
     arr = jax.device_put(np_inp, s)
@@ -4954,8 +4931,8 @@ class ShardingInTypesTest(jtu.JaxTestCase):
      ('2', 2, P('x', 'y', None)),
      ('-1', -1, P('x', 'y', None)),
   )
-  def test_broadcast_in_dim(self, axis, out_spec):
-    mesh = jtu.create_mesh((2, 2), ('x', 'y'))
+  @jtu.with_user_mesh((2, 2), ('x', 'y'))
+  def test_broadcast_in_dim(self, axis, out_spec, mesh):
     np_inp = np.arange(16).reshape(8, 2)
     s = NamedSharding(mesh, P('x', 'y'))
     arr = jax.device_put(np_inp, s)
@@ -4980,8 +4957,8 @@ class ShardingInTypesTest(jtu.JaxTestCase):
      ('3', 3),
      ('4', 4),
   )
-  def test_integer_pow(self, pow):
-    mesh = jtu.create_mesh((2, 2), ('x', 'y'))
+  @jtu.with_user_mesh((2, 2), ('x', 'y'))
+  def test_integer_pow(self, pow, mesh):
     np_inp = np.arange(16).reshape(8, 2)
     s = NamedSharding(mesh, P('x', 'y'))
     arr = jax.device_put(np_inp, s)
@@ -5010,12 +4987,13 @@ class ShardingInTypesTest(jtu.JaxTestCase):
     def f(x, y):
       return x + y
 
-    with self.assertRaisesRegex(
-        ValueError, "Mesh for all inputs should be equal"):
-      f(arr1, arr2)
+    with config.sharding_in_types(True):
+      with self.assertRaisesRegex(
+          ValueError, "Mesh for all inputs should be equal"):
+        f(arr1, arr2)
 
-  def test_sin_unop(self):
-    mesh = jtu.create_mesh((2, 2), ('x', 'y'))
+  @jtu.with_user_mesh((2, 2), ('x', 'y'))
+  def test_sin_unop(self, mesh):
     np_inp = np.arange(16.).reshape(8, 2)
     s = NamedSharding(mesh, P('x', 'y'))
     arr = jax.device_put(np_inp, s)
@@ -5032,8 +5010,8 @@ class ShardingInTypesTest(jtu.JaxTestCase):
     lowered_text = f.lower(arr).as_text()
     self.assertIn('@Sharding', lowered_text)
 
-  def test_jnp_array(self):
-    mesh = jtu.create_mesh((2, 2), ('x', 'y'))
+  @jtu.with_user_mesh((2, 2), ('x', 'y'))
+  def test_jnp_array(self, mesh):
     np_inp = np.arange(16, dtype=jnp.int32).reshape(8, 2)
     s = NamedSharding(mesh, P('x', 'y'))
     arr = jax.device_put(np_inp, s)
@@ -5048,8 +5026,8 @@ class ShardingInTypesTest(jtu.JaxTestCase):
 
     f(arr)
 
-  def test_lax_transpose_rule(self):
-    mesh = jtu.create_mesh((2, 2, 1), ('x', 'y', 'z'))
+  @jtu.with_user_mesh((2, 2, 1), ('x', 'y', 'z'))
+  def test_lax_transpose_rule(self, mesh):
     np_inp = np.arange(16).reshape(4, 2, 2)
     s = NamedSharding(mesh, P('x', 'y', 'z'))
     arr = jax.device_put(np_inp, s)
@@ -5067,8 +5045,8 @@ class ShardingInTypesTest(jtu.JaxTestCase):
     lowered_text = f.lower(arr).as_text()
     self.assertIn('@Sharding', lowered_text)
 
-  def test_broadcasted_iota_with_sharding(self):
-    mesh = jtu.create_mesh((2, 2), ('x', 'y'))
+  @jtu.with_user_mesh((2, 2), ('x', 'y'))
+  def test_broadcasted_iota_with_sharding(self, mesh):
     np_inp = np.arange(4)
     s = NamedSharding(mesh, P('x'))
     arr = jax.device_put(np_inp, s)
@@ -5094,8 +5072,8 @@ class ShardingInTypesTest(jtu.JaxTestCase):
     _, out = g(arr)
     self.assertEqual(out.sharding, NamedSharding(mesh, P('x', 'y')))
 
-  def test_einsum_with_out_type(self):
-    mesh = jtu.create_mesh((2, 2), ('x', 'y'))
+  @jtu.with_user_mesh((2, 2), ('x', 'y'))
+  def test_einsum_with_out_type(self, mesh):
     np_inp = np.arange(16.).reshape(8, 2)
     arr1 = jax.device_put(np_inp, NamedSharding(mesh, P('x', 'y')))
     arr2 = jax.device_put(np_inp.T, NamedSharding(mesh, P('y', 'x')))
@@ -5140,8 +5118,8 @@ class ShardingInTypesTest(jtu.JaxTestCase):
     self.assertEqual(out[0].sharding, arr3.sharding)
     self.assertEqual(out[1].sharding, arr4.sharding)
 
-  def test_einsum_inverse(self):
-    mesh = jtu.create_mesh((2, 2), ('x', 'y'))
+  @jtu.with_user_mesh((2, 2), ('x', 'y'))
+  def test_einsum_inverse(self, mesh):
     np_inp = np.arange(64.)
 
     @jax.jit
@@ -5179,9 +5157,9 @@ class ShardingInTypesTest(jtu.JaxTestCase):
      ('2', (8, 2, 1), (1, 16, 1), P('x', None, None), P(None, 'x', None), True),
      ('3', (8, 1), (1, 4, 2), P('x', None), P(None, None, 'x'), True)
   )
+  @jtu.with_user_mesh((2,), ('x',))
   def test_reshape(self, src_shape, dst_shape, src_spec, dst_spec,
-                   use_sharding_arg):
-    mesh = jtu.create_mesh((2,), ('x',))
+                   use_sharding_arg, mesh):
     np_inp = np.arange(math.prod(src_shape),
                        dtype=np.float32).reshape(src_shape)
     arr = jax.device_put(np_inp, NamedSharding(mesh, src_spec))
@@ -5209,8 +5187,8 @@ class ShardingInTypesTest(jtu.JaxTestCase):
     out = jax.jit(jax.grad(g))(arr)
     self.assertEqual(out.sharding, arr.sharding)
 
-  def test_select(self):
-    mesh = jtu.create_mesh((2, 2), ('x', 'y'))
+  @jtu.with_user_mesh((2, 2), ('x', 'y'))
+  def test_select(self, mesh):
     np_inp = np.arange(16).reshape(8, 2)
     s = NamedSharding(mesh, P('x', 'y'))
     arr1 = jax.device_put(np_inp, s)
@@ -5234,8 +5212,8 @@ class ShardingInTypesTest(jtu.JaxTestCase):
         TypeError, "select cases must have the same shardings"):
       f(arr1 == arr2, arr1, arr3)
 
-  def test_device_put_reshard(self):
-    mesh = jtu.create_mesh((2, 2), ('x', 'y'))
+  @jtu.with_user_mesh((2, 2), ('x', 'y'))
+  def test_device_put_reshard(self, mesh):
     np_inp = np.arange(16).reshape(8, 2)
     s = NamedSharding(mesh, P('x', 'y'))
     arr = jax.device_put(np_inp, s)
@@ -5250,8 +5228,8 @@ class ShardingInTypesTest(jtu.JaxTestCase):
     self.assertArraysEqual(out, np_inp)
     self.assertEqual(out.sharding, NamedSharding(mesh, P('x', None)))
 
-  def test_shard_map_full_manual(self):
-    mesh = jtu.create_mesh((2, 2), ('x', 'y'))
+  @jtu.with_user_mesh((2, 2), ('x', 'y'))
+  def test_shard_map_full_manual(self, mesh):
     np_inp = np.arange(16).reshape(8, 2)
     arr = jax.device_put(np_inp, NamedSharding(mesh, P('x', 'y')))
     arr2 = jax.device_put(np_inp, NamedSharding(mesh, P('x', 'y')))
@@ -5275,8 +5253,8 @@ class ShardingInTypesTest(jtu.JaxTestCase):
     self.assertArraysEqual(out, (np_inp * np_inp) * 2)
     self.assertEqual(out.sharding, NamedSharding(mesh, P('x', 'y')))
 
-  def test_shard_map_dot(self):
-    mesh = jtu.create_mesh((2, 2), ('x', 'y'))
+  @jtu.with_user_mesh((2, 2), ('x', 'y'))
+  def test_shard_map_dot(self, mesh):
     np_inp = np.arange(16).reshape(8, 2)
     arr = jax.device_put(np_inp, NamedSharding(mesh, P('x', 'y')))
     arr2 = jax.device_put(np_inp.T, NamedSharding(mesh, P('y', 'x')))
@@ -5302,8 +5280,8 @@ class ShardingInTypesTest(jtu.JaxTestCase):
     self.assertArraysEqual(out, (np_inp @ np_inp.T) * 2)
     self.assertEqual(out.sharding, NamedSharding(mesh, P('x', None)))
 
-  def test_slice(self):
-    mesh = jtu.create_mesh((2, 2), ('x', 'y'))
+  @jtu.with_user_mesh((2, 2), ('x', 'y'))
+  def test_slice(self, mesh):
     np_inp = np.arange(16.).reshape(4, 4)
     arr = jax.device_put(np_inp, NamedSharding(mesh, P('x', None)))
 
@@ -5333,8 +5311,8 @@ class ShardingInTypesTest(jtu.JaxTestCase):
     with self.assertRaisesRegex(NotImplementedError, "slicing on sharded dims"):
      f(jax.device_put(np_inp, NamedSharding(mesh, P(None, ('x', 'y')))))
 
-  def test_squeeze(self):
-    mesh = jtu.create_mesh((2, 2), ('x', 'y'))
+  @jtu.with_user_mesh((2, 2), ('x', 'y'))
+  def test_squeeze(self, mesh):
     np_inp = np.arange(16.).reshape(4, 4, 1)
     arr = jax.device_put(np_inp, NamedSharding(mesh, P('x', None, None)))
 
@@ -5359,8 +5337,8 @@ class ShardingInTypesTest(jtu.JaxTestCase):
     out = jax.jit(jax.grad(g))(arr)
     self.assertEqual(out.sharding, arr.sharding)
 
-  def test_pad(self):
-    mesh = jtu.create_mesh((2, 1), ('x', 'y'))
+  @jtu.with_user_mesh((2, 2), ('x', 'y'))
+  def test_pad(self, mesh):
     np_inp = np.arange(8.)
     arr = jax.device_put(np_inp, NamedSharding(mesh, P('x')))
 
@@ -5401,8 +5379,8 @@ class ShardingInTypesTest(jtu.JaxTestCase):
     arr = jax.device_put(np_inp, NamedSharding(mesh, P(('x', 'y'))))
     f(arr, ((4, 4, 1),), None)
 
-  def test_concatenate(self):
-    mesh = jtu.create_mesh((2, 1), ('x', 'y'))
+  @jtu.with_user_mesh((2, 1), ('x', 'y'))
+  def test_concatenate(self, mesh):
     np_inp = np.arange(16.).reshape(4, 4)
     s = NamedSharding(mesh, P('x', 'y'))
     arr1 = jax.device_put(np_inp, s)
@@ -5443,8 +5421,8 @@ class ShardingInTypesTest(jtu.JaxTestCase):
     out = jax.jit(jax.grad(g))(arr1, arr2)
     self.assertEqual(out.sharding, s)
 
-  def test_scan(self):
-    mesh = jtu.create_mesh((2, 2), ('x', 'y'))
+  @jtu.with_user_mesh((2, 2), ('x', 'y'))
+  def test_scan(self, mesh):
     carry = jax.device_put(np.arange(16.).reshape(2, 8),
                            NamedSharding(mesh, P(None, 'x')))
     arr = jax.device_put(np.arange(128.).reshape(8, 8, 2),
@@ -5481,8 +5459,8 @@ class ShardingInTypesTest(jtu.JaxTestCase):
         ValueError, "0th dimension of all xs should be replicated"):
       f(carry, jax.device_put(arr, NamedSharding(mesh, P('x', None, None))))
 
-  def test_argminmax(self):
-    mesh = jtu.create_mesh((2, 2), ('x', 'y'))
+  @jtu.with_user_mesh((2, 2), ('x', 'y'))
+  def test_argminmax(self, mesh):
     np_inp = np.arange(16.).reshape(8, 2)
     s = NamedSharding(mesh, P('x', 'y'))
     arr = jax.device_put(np_inp, s)
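
Usage sketch (not part of the patch): after this change, sharding-in-types is driven by installing a concrete mesh via mesh_lib.set_mesh (jax._src.mesh.set_mesh) rather than by the jax_sharding_in_types config flag, and the converted tests do this through the new jtu.with_user_mesh decorator added in jax/_src/test_util.py. The snippet below is a minimal, hypothetical example of that pattern; it assumes it lives inside a jtu.JaxTestCase subclass in tests/pjit_test.py, so jtu, jax, np, NamedSharding, and P are already imported there, and the test name is illustrative.

    @jtu.with_user_mesh((2, 2), ('x', 'y'))
    def test_example_mul(self, mesh):
      # with_user_mesh builds a 2x2 mesh, installs it with mesh_lib.set_mesh,
      # and passes it to the test body as `mesh`.
      np_inp = np.arange(16.).reshape(8, 2)
      s = NamedSharding(mesh, P('x', 'y'))
      arr = jax.device_put(np_inp, s)

      @jax.jit
      def f(x):
        # Under set_mesh, the sharding is carried on the abstract value, so it
        # is visible on traced intermediates inside jit.
        self.assertEqual(x.sharding.spec, s.spec)
        return x * 2

      out = f(arr)
      self.assertEqual(out.sharding, s)
      self.assertArraysEqual(out, np_inp * 2)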