[sharding_in_types] Use set_mesh API to trigger sharding_in_types instead of the config option.

PiperOrigin-RevId: 702814257
Author: Yash Katariya, 2024-12-04 12:11:54 -08:00 (committed by jax authors)
parent fa6585dea1
commit 9e2708eb57
4 changed files with 88 additions and 105 deletions


@@ -2193,15 +2193,8 @@ def lower_sharding_computation(
assert len(out_shardings) == len(out_layouts) == len(global_out_avals), (
len(out_shardings), len(out_layouts), len(global_out_avals))
if config.sharding_in_types.value:
# TODO(yashkatariya): Thread it via jit path and remove the None check by
# making tests go via set_mesh API always.
devices_from_context = (
None if mesh_lib.device_context.concrete_mesh is None
else mesh_lib.device_context.concrete_mesh._flat_devices_tuple)
else:
devices_from_context = (None if context_mesh is None or context_mesh.empty
else context_mesh._flat_devices_tuple)
devices_from_context = (None if context_mesh is None or context_mesh.empty
else context_mesh._flat_devices_tuple)
# Device assignment across all inputs, outputs and shardings inside jaxpr
# should be the same.
unique_intermediate_shardings = util.stable_unique(
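Example (not part of the diff): after this hunk, the device assignment used for lowering always flows from the surrounding context mesh, rather than from a sharding_in_types-specific read of mesh_lib.device_context. A rough sketch of the usage this unifies, assuming at least four local devices; mesh_lib and jtu refer to the internal jax._src.mesh and jax._src.test_util modules that appear later in this commit.

import numpy as np
import jax
from jax._src import mesh as mesh_lib
from jax._src import test_util as jtu
from jax.sharding import NamedSharding, PartitionSpec as P

mesh = jtu.create_mesh((2, 2), ('x', 'y'))   # concrete mesh over 4 devices
with mesh_lib.set_mesh(mesh):                # the context consulted at lowering time
  arr = jax.device_put(np.arange(16.).reshape(8, 2), NamedSharding(mesh, P('x', 'y')))
  out = jax.jit(lambda x: x * 2)(arr)        # device assignment comes from the context mesh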


@@ -707,9 +707,6 @@ def get_abstract_mesh(in_avals):
f'Mesh for all inputs should be equal. Got one mesh: {m} and'
f' another mesh: {a.sharding.mesh}')
m = a.sharding.mesh # type: ignore
# TODO(yashkatariya): Remove this when mesh context can be set by the user.
if m is None:
return contextlib.nullcontext()
assert isinstance(m, AbstractMesh)
return m
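Example (not part of the diff): with the None fallback gone, get_abstract_mesh always returns an AbstractMesh, so once sharding-in-types is active every input to a jitted computation must carry the same mesh. A condensed, hypothetical version of the mismatched-mesh test further down in this commit (the exact mesh construction there is not shown, so the two single-device meshes below are an assumption); the flag is enabled directly here because the inputs deliberately disagree, leaving no single mesh to set.

import numpy as np
import jax
from jax._src import config
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

dev = jax.devices()[0]
mesh1 = Mesh(np.array([dev]), ('x',))
mesh2 = Mesh(np.array([dev]), ('y',))        # same device, different axis name
arr1 = jax.device_put(np.arange(8.), NamedSharding(mesh1, P('x')))
arr2 = jax.device_put(np.arange(8.), NamedSharding(mesh2, P('y')))

with config.sharding_in_types(True):
  jax.jit(lambda x, y: x + y)(arr1, arr2)    # ValueError: Mesh for all inputs should be equal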
@@ -1791,8 +1788,12 @@ def _pjit_lower(
lowering_platforms: tuple[str, ...] | None,
lowering_parameters: mlir.LoweringParameters,
pgle_profiler: profiler.PGLEProfiler | None):
mesh, api_name = ((resource_env.physical_mesh, 'pjit')
if resource_env is not None else (None, 'jit'))
if config.sharding_in_types.value:
mesh = mesh_lib.device_context.concrete_mesh
api_name = 'jit'
else:
mesh, api_name = ((resource_env.physical_mesh, 'pjit')
if resource_env is not None else (None, 'jit'))
return pxla.lower_sharding_computation(
jaxpr, api_name, name, in_shardings, out_shardings,
in_layouts, out_layouts, tuple(donated_invars),
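Example (not part of the diff): under sharding_in_types the mesh for a jit lowering is now read from the device context that set_mesh installs, instead of from a pjit resource_env. A tiny sketch of that assumption; these are internal APIs and the exact attribute shapes are inferred from the hunks above.

from jax._src import mesh as mesh_lib
from jax._src import test_util as jtu

mesh = jtu.create_mesh((2, 2), ('x', 'y'))
with mesh_lib.set_mesh(mesh):
  # Assumption: set_mesh populates the device context for the duration of the
  # block; this is what _pjit_lower reads when sharding_in_types is enabled.
  print(mesh_lib.device_context.concrete_mesh)   # -> the mesh created above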


@@ -51,6 +51,7 @@ from jax._src import monitoring
from jax._src import pjit as pjit_lib
from jax._src import stages
from jax._src import xla_bridge
from jax._src import mesh as mesh_lib
from jax._src.cloud_tpu_init import running_in_cloud_tpu_vm
from jax._src.interpreters import mlir
from jax._src.interpreters import pxla
@@ -1442,6 +1443,16 @@ def with_and_without_mesh(f):
('Mesh', (('x', 2),), (('i', 'x'),))
))(with_mesh_from_kwargs(f))
def with_user_mesh(sizes, names):
def decorator(fn):
def mesh_fn(*args, **kwargs):
mesh = create_mesh(sizes, names)
with mesh_lib.set_mesh(mesh):
return fn(*args, **kwargs, mesh=mesh)
return mesh_fn
return decorator
def create_mesh(mesh_shape, axis_names, iota_order=False):
size = math.prod(mesh_shape)
if len(jax.devices()) < size:


@@ -4630,38 +4630,16 @@ class ArrayPjitTest(jtu.JaxTestCase):
ins, _ = f.lower(np.arange(8)).compile().input_shardings
self.assertEqual(ins[0], SingleDeviceSharding(jax.devices()[0]))
def test_sharding_in_types_with_set_mesh(self):
if config.use_shardy_partitioner.value:
self.skipTest("ShiT doesn't work with shardy")
mesh = jtu.create_mesh((2, 2), ('x', 'y'))
with mesh_lib.set_mesh(mesh):
np_inp = np.arange(16.).reshape(8, 2)
s = NamedSharding(mesh, P('x', 'y'))
arr = jax.device_put(np_inp, s)
@jax.jit
def f(x):
self.assertEqual(x.sharding.spec, s.spec)
x = x * 2
self.assertEqual(x.sharding.spec, s.spec)
x = x * x
self.assertEqual(x.sharding.spec, s.spec)
return x
out = f(arr)
self.assertEqual(out.sharding, s)
self.assertArraysEqual(out, (np_inp * 2) * (np_inp * 2))
def spec_regex(s):
return str(s).replace(r"(", r"\(").replace(r")", r"\)")
@jtu.with_config(jax_sharding_in_types=True, jax_use_shardy_partitioner=False)
@jtu.with_config(jax_use_shardy_partitioner=False)
class ShardingInTypesTest(jtu.JaxTestCase):
def test_basic_mul(self):
mesh = jtu.create_mesh((2, 2), ('x', 'y'))
@jtu.with_user_mesh((2, 2), ('x', 'y'))
def test_basic_mul(self, mesh):
np_inp = np.arange(16.).reshape(8, 2)
s = NamedSharding(mesh, P('x', 'y'))
arr = jax.device_put(np_inp, s)
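Example (not part of the diff): the standalone test_sharding_in_types_with_set_mesh above is deleted because the whole class now runs per-test under with_user_mesh, but the behaviour it exercised is unchanged: with sharding-in-types active, a tracer's sharding spec is visible inside jit. A sketch of that check in the new style (the test name is hypothetical).

@jtu.with_user_mesh((2, 2), ('x', 'y'))
def test_spec_visible_on_tracers(self, mesh):  # hypothetical name
  np_inp = np.arange(16.).reshape(8, 2)
  s = NamedSharding(mesh, P('x', 'y'))
  arr = jax.device_put(np_inp, s)

  @jax.jit
  def f(x):
    self.assertEqual(x.sharding.spec, s.spec)  # sharding is carried on the aval
    return x * 2

  out = f(arr)
  self.assertEqual(out.sharding, s)
  self.assertArraysEqual(out, np_inp * 2)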
@@ -4696,8 +4674,8 @@ class ShardingInTypesTest(jtu.JaxTestCase):
out = jax.jit(jax.grad(g))(arr)
self.assertEqual(out.sharding, arr.sharding)
def test_fully_replicated_array_mul(self):
mesh = jtu.create_mesh((2, 2), ('x', 'y'))
@jtu.with_user_mesh((2, 2), ('x', 'y'))
def test_fully_replicated_array_mul(self, mesh):
np_inp1 = np.arange(16).reshape(8, 2)
s = NamedSharding(mesh, P('x', 'y'))
arr1 = jax.device_put(np_inp1, s)
@@ -4745,8 +4723,8 @@ class ShardingInTypesTest(jtu.JaxTestCase):
('half_tp', P(None, 'y'), P(None, 'y'), P(None, 'y'), 'all-gather'),
('other_half_tp', P(None, 'y'), P('y', None), P(None, None), 'all-reduce')
)
def test_dot_general(self, spec1, spec2, out_spec, collective_name):
mesh = jtu.create_mesh((2, 2), ('x', 'y'))
@jtu.with_user_mesh((2, 2), ('x', 'y'))
def test_dot_general(self, spec1, spec2, out_spec, collective_name, mesh):
np_inp1 = np.arange(16.).reshape(8, 2)
arr1 = jax.device_put(np_inp1, NamedSharding(mesh, spec1))
arr2 = jax.device_put(np_inp1.T, NamedSharding(mesh, spec2))
@@ -4781,8 +4759,8 @@ class ShardingInTypesTest(jtu.JaxTestCase):
self.assertEqual(out[0].sharding, arr1.sharding)
self.assertEqual(out[1].sharding, arr2.sharding)
def test_dot_general_out_type(self):
mesh = jtu.create_mesh((4,), ('x',))
@jtu.with_user_mesh((4,), ('x',))
def test_dot_general_out_type(self, mesh):
np_inp1 = np.arange(16.).reshape(8, 2)
arr1 = jax.device_put(np_inp1, NamedSharding(mesh, P('x', None)))
arr2 = jax.device_put(np_inp1.T, NamedSharding(mesh, P(None, 'x')))
@@ -4824,8 +4802,8 @@ class ShardingInTypesTest(jtu.JaxTestCase):
"dot_general requires contracting dimensions to have consistent sharding",
TypeError),
)
def test_dot_general_error(self, spec1, spec2, error_msg, error_type):
mesh = jtu.create_mesh((2, 2), ('x', 'y'))
@jtu.with_user_mesh((2, 2), ('x', 'y'))
def test_dot_general_error(self, spec1, spec2, error_msg, error_type, mesh):
np_inp1 = np.arange(16).reshape(8, 2)
arr1 = jax.device_put(np_inp1, NamedSharding(mesh, spec1))
arr2 = jax.device_put(np_inp1.T, NamedSharding(mesh, spec2))
@@ -4837,8 +4815,8 @@ class ShardingInTypesTest(jtu.JaxTestCase):
with self.assertRaisesRegex(error_type, error_msg):
f(arr1, arr2)
def test_dot_general_batch_error(self):
mesh = jtu.create_mesh((2, 2, 1), ('x', 'y', 'z'))
@jtu.with_user_mesh((2, 2, 1), ('x', 'y', 'z'))
def test_dot_general_batch_error(self, mesh):
arr1 = jax.device_put(np.ones((8, 4, 2)),
NamedSharding(mesh, P('x', 'y', 'z')))
arr2 = jax.device_put(np.ones((8, 2, 4)),
@@ -4856,9 +4834,8 @@ class ShardingInTypesTest(jtu.JaxTestCase):
' have the consistent sharding'):
jnp.einsum('abc,acz->abz', arr1, arr2)
def test_aval_repr(self):
mesh = jtu.create_mesh((2, 2), ('model', 'data'))
@jtu.with_user_mesh((2, 2), ('model', 'data'))
def test_aval_repr(self, mesh):
aval = core.ShapedArray((128, 64), np.float32,
sharding=NamedSharding(mesh, P('model', 'data')))
self.assertEqual(aval.str_short(), 'float32[128@model,64@data]')
@@ -4876,14 +4853,14 @@ class ShardingInTypesTest(jtu.JaxTestCase):
self.assertEqual(aval.str_short(), 'float32[128@(model,data),64]')
@parameterized.named_parameters(
('all', None, P('x', 'y'), P()),
('first', 0, P('x', 'y'), P('y')),
('second', 1, P('x', 'y'), P('x')),
('first2', 0, P(('x', 'y'), None), P(None)),
('all', None, P('x', 'y'), P(), True),
('first', 0, P('x', 'y'), P('y'), True),
('second', 1, P('x', 'y'), P('x'), True),
('first2', 0, P(('x', 'y'), None), P(None), True),
('second2', 1, P(('x', 'y'), None), P(('x', 'y')), False),
)
def test_reduce_sum(self, axis, in_spec, out_spec, reduce=True):
mesh = jtu.create_mesh((2, 2), ('x', 'y'))
@jtu.with_user_mesh((2, 2), ('x', 'y'))
def test_reduce_sum(self, axis, in_spec, out_spec, reduce, mesh):
np_inp = np.arange(16).reshape(8, 2)
s = NamedSharding(mesh, in_spec)
arr = jax.device_put(np_inp, s)
@@ -4907,14 +4884,14 @@ class ShardingInTypesTest(jtu.JaxTestCase):
self.assertIn('all-reduce', compiled_text)
@parameterized.named_parameters(
('all', None, P('x', 'y'), P()),
('first', 0, P('x', 'y'), P('y')),
('second', 1, P('x', 'y'), P('x')),
('first2', 0, P(('x', 'y'), None), P(None)),
('all', None, P('x', 'y'), P(), True),
('first', 0, P('x', 'y'), P('y'), True),
('second', 1, P('x', 'y'), P('x'), True),
('first2', 0, P(('x', 'y'), None), P(None), True),
('second2', 1, P(('x', 'y'), None), P(('x', 'y')), False),
)
def test_reduce_max(self, axis, in_spec, out_spec, reduce=True):
mesh = jtu.create_mesh((2, 2), ('x', 'y'))
@jtu.with_user_mesh((2, 2), ('x', 'y'))
def test_reduce_max(self, axis, in_spec, out_spec, reduce, mesh):
np_inp = np.arange(16.).reshape(8, 2)
s = NamedSharding(mesh, in_spec)
arr = jax.device_put(np_inp, s)
@@ -4954,8 +4931,8 @@ class ShardingInTypesTest(jtu.JaxTestCase):
('2', 2, P('x', 'y', None)),
('-1', -1, P('x', 'y', None)),
)
def test_broadcast_in_dim(self, axis, out_spec):
mesh = jtu.create_mesh((2, 2), ('x', 'y'))
@jtu.with_user_mesh((2, 2), ('x', 'y'))
def test_broadcast_in_dim(self, axis, out_spec, mesh):
np_inp = np.arange(16).reshape(8, 2)
s = NamedSharding(mesh, P('x', 'y'))
arr = jax.device_put(np_inp, s)
@@ -4980,8 +4957,8 @@ class ShardingInTypesTest(jtu.JaxTestCase):
('3', 3),
('4', 4),
)
def test_integer_pow(self, pow):
mesh = jtu.create_mesh((2, 2), ('x', 'y'))
@jtu.with_user_mesh((2, 2), ('x', 'y'))
def test_integer_pow(self, pow, mesh):
np_inp = np.arange(16).reshape(8, 2)
s = NamedSharding(mesh, P('x', 'y'))
arr = jax.device_put(np_inp, s)
@@ -5010,12 +4987,13 @@ class ShardingInTypesTest(jtu.JaxTestCase):
def f(x, y):
return x + y
with self.assertRaisesRegex(
ValueError, "Mesh for all inputs should be equal"):
f(arr1, arr2)
with config.sharding_in_types(True):
with self.assertRaisesRegex(
ValueError, "Mesh for all inputs should be equal"):
f(arr1, arr2)
def test_sin_unop(self):
mesh = jtu.create_mesh((2, 2), ('x', 'y'))
@jtu.with_user_mesh((2, 2), ('x', 'y'))
def test_sin_unop(self, mesh):
np_inp = np.arange(16.).reshape(8, 2)
s = NamedSharding(mesh, P('x', 'y'))
arr = jax.device_put(np_inp, s)
@@ -5032,8 +5010,8 @@ class ShardingInTypesTest(jtu.JaxTestCase):
lowered_text = f.lower(arr).as_text()
self.assertIn('@Sharding', lowered_text)
def test_jnp_array(self):
mesh = jtu.create_mesh((2, 2), ('x', 'y'))
@jtu.with_user_mesh((2, 2), ('x', 'y'))
def test_jnp_array(self, mesh):
np_inp = np.arange(16, dtype=jnp.int32).reshape(8, 2)
s = NamedSharding(mesh, P('x', 'y'))
arr = jax.device_put(np_inp, s)
@@ -5048,8 +5026,8 @@ class ShardingInTypesTest(jtu.JaxTestCase):
f(arr)
def test_lax_transpose_rule(self):
mesh = jtu.create_mesh((2, 2, 1), ('x', 'y', 'z'))
@jtu.with_user_mesh((2, 2, 1), ('x', 'y', 'z'))
def test_lax_transpose_rule(self, mesh):
np_inp = np.arange(16).reshape(4, 2, 2)
s = NamedSharding(mesh, P('x', 'y', 'z'))
arr = jax.device_put(np_inp, s)
@@ -5067,8 +5045,8 @@ class ShardingInTypesTest(jtu.JaxTestCase):
lowered_text = f.lower(arr).as_text()
self.assertIn('@Sharding', lowered_text)
def test_broadcasted_iota_with_sharding(self):
mesh = jtu.create_mesh((2, 2), ('x', 'y'))
@jtu.with_user_mesh((2, 2), ('x', 'y'))
def test_broadcasted_iota_with_sharding(self, mesh):
np_inp = np.arange(4)
s = NamedSharding(mesh, P('x'))
arr = jax.device_put(np_inp, s)
@@ -5094,8 +5072,8 @@ class ShardingInTypesTest(jtu.JaxTestCase):
_, out = g(arr)
self.assertEqual(out.sharding, NamedSharding(mesh, P('x', 'y')))
def test_einsum_with_out_type(self):
mesh = jtu.create_mesh((2, 2), ('x', 'y'))
@jtu.with_user_mesh((2, 2), ('x', 'y'))
def test_einsum_with_out_type(self, mesh):
np_inp = np.arange(16.).reshape(8, 2)
arr1 = jax.device_put(np_inp, NamedSharding(mesh, P('x', 'y')))
arr2 = jax.device_put(np_inp.T, NamedSharding(mesh, P('y', 'x')))
@@ -5140,8 +5118,8 @@ class ShardingInTypesTest(jtu.JaxTestCase):
self.assertEqual(out[0].sharding, arr3.sharding)
self.assertEqual(out[1].sharding, arr4.sharding)
def test_einsum_inverse(self):
mesh = jtu.create_mesh((2, 2), ('x', 'y'))
@jtu.with_user_mesh((2, 2), ('x', 'y'))
def test_einsum_inverse(self, mesh):
np_inp = np.arange(64.)
@jax.jit
@@ -5179,9 +5157,9 @@ class ShardingInTypesTest(jtu.JaxTestCase):
('2', (8, 2, 1), (1, 16, 1), P('x', None, None), P(None, 'x', None), True),
('3', (8, 1), (1, 4, 2), P('x', None), P(None, None, 'x'), True)
)
@jtu.with_user_mesh((2,), ('x',))
def test_reshape(self, src_shape, dst_shape, src_spec, dst_spec,
use_sharding_arg):
mesh = jtu.create_mesh((2,), ('x',))
use_sharding_arg, mesh):
np_inp = np.arange(math.prod(src_shape),
dtype=np.float32).reshape(src_shape)
arr = jax.device_put(np_inp, NamedSharding(mesh, src_spec))
@@ -5209,8 +5187,8 @@ class ShardingInTypesTest(jtu.JaxTestCase):
out = jax.jit(jax.grad(g))(arr)
self.assertEqual(out.sharding, arr.sharding)
def test_select(self):
mesh = jtu.create_mesh((2, 2), ('x', 'y'))
@jtu.with_user_mesh((2, 2), ('x', 'y'))
def test_select(self, mesh):
np_inp = np.arange(16).reshape(8, 2)
s = NamedSharding(mesh, P('x', 'y'))
arr1 = jax.device_put(np_inp, s)
@@ -5234,8 +5212,8 @@ class ShardingInTypesTest(jtu.JaxTestCase):
TypeError, "select cases must have the same shardings"):
f(arr1 == arr2, arr1, arr3)
def test_device_put_reshard(self):
mesh = jtu.create_mesh((2, 2), ('x', 'y'))
@jtu.with_user_mesh((2, 2), ('x', 'y'))
def test_device_put_reshard(self, mesh):
np_inp = np.arange(16).reshape(8, 2)
s = NamedSharding(mesh, P('x', 'y'))
arr = jax.device_put(np_inp, s)
@@ -5250,8 +5228,8 @@ class ShardingInTypesTest(jtu.JaxTestCase):
self.assertArraysEqual(out, np_inp)
self.assertEqual(out.sharding, NamedSharding(mesh, P('x', None)))
def test_shard_map_full_manual(self):
mesh = jtu.create_mesh((2, 2), ('x', 'y'))
@jtu.with_user_mesh((2, 2), ('x', 'y'))
def test_shard_map_full_manual(self, mesh):
np_inp = np.arange(16).reshape(8, 2)
arr = jax.device_put(np_inp, NamedSharding(mesh, P('x', 'y')))
arr2 = jax.device_put(np_inp, NamedSharding(mesh, P('x', 'y')))
@@ -5275,8 +5253,8 @@ class ShardingInTypesTest(jtu.JaxTestCase):
self.assertArraysEqual(out, (np_inp * np_inp) * 2)
self.assertEqual(out.sharding, NamedSharding(mesh, P('x', 'y')))
def test_shard_map_dot(self):
mesh = jtu.create_mesh((2, 2), ('x', 'y'))
@jtu.with_user_mesh((2, 2), ('x', 'y'))
def test_shard_map_dot(self, mesh):
np_inp = np.arange(16).reshape(8, 2)
arr = jax.device_put(np_inp, NamedSharding(mesh, P('x', 'y')))
arr2 = jax.device_put(np_inp.T, NamedSharding(mesh, P('y', 'x')))
@@ -5302,8 +5280,8 @@ class ShardingInTypesTest(jtu.JaxTestCase):
self.assertArraysEqual(out, (np_inp @ np_inp.T) * 2)
self.assertEqual(out.sharding, NamedSharding(mesh, P('x', None)))
def test_slice(self):
mesh = jtu.create_mesh((2, 2), ('x', 'y'))
@jtu.with_user_mesh((2, 2), ('x', 'y'))
def test_slice(self, mesh):
np_inp = np.arange(16.).reshape(4, 4)
arr = jax.device_put(np_inp, NamedSharding(mesh, P('x', None)))
@@ -5333,8 +5311,8 @@ class ShardingInTypesTest(jtu.JaxTestCase):
with self.assertRaisesRegex(NotImplementedError, "slicing on sharded dims"):
f(jax.device_put(np_inp, NamedSharding(mesh, P(None, ('x', 'y')))))
def test_squeeze(self):
mesh = jtu.create_mesh((2, 2), ('x', 'y'))
@jtu.with_user_mesh((2, 2), ('x', 'y'))
def test_squeeze(self, mesh):
np_inp = np.arange(16.).reshape(4, 4, 1)
arr = jax.device_put(np_inp, NamedSharding(mesh, P('x', None, None)))
@@ -5359,8 +5337,8 @@ class ShardingInTypesTest(jtu.JaxTestCase):
out = jax.jit(jax.grad(g))(arr)
self.assertEqual(out.sharding, arr.sharding)
def test_pad(self):
mesh = jtu.create_mesh((2, 1), ('x', 'y'))
@jtu.with_user_mesh((2, 2), ('x', 'y'))
def test_pad(self, mesh):
np_inp = np.arange(8.)
arr = jax.device_put(np_inp, NamedSharding(mesh, P('x')))
@@ -5401,8 +5379,8 @@ class ShardingInTypesTest(jtu.JaxTestCase):
arr = jax.device_put(np_inp, NamedSharding(mesh, P(('x', 'y'))))
f(arr, ((4, 4, 1),), None)
def test_concatenate(self):
mesh = jtu.create_mesh((2, 1), ('x', 'y'))
@jtu.with_user_mesh((2, 1), ('x', 'y'))
def test_concatenate(self, mesh):
np_inp = np.arange(16.).reshape(4, 4)
s = NamedSharding(mesh, P('x', 'y'))
arr1 = jax.device_put(np_inp, s)
@@ -5443,8 +5421,8 @@ class ShardingInTypesTest(jtu.JaxTestCase):
out = jax.jit(jax.grad(g))(arr1, arr2)
self.assertEqual(out.sharding, s)
def test_scan(self):
mesh = jtu.create_mesh((2, 2), ('x', 'y'))
@jtu.with_user_mesh((2, 2), ('x', 'y'))
def test_scan(self, mesh):
carry = jax.device_put(np.arange(16.).reshape(2, 8),
NamedSharding(mesh, P(None, 'x')))
arr = jax.device_put(np.arange(128.).reshape(8, 8, 2),
@@ -5481,8 +5459,8 @@ class ShardingInTypesTest(jtu.JaxTestCase):
ValueError, "0th dimension of all xs should be replicated"):
f(carry, jax.device_put(arr, NamedSharding(mesh, P('x', None, None))))
def test_argminmax(self):
mesh = jtu.create_mesh((2, 2), ('x', 'y'))
@jtu.with_user_mesh((2, 2), ('x', 'y'))
def test_argminmax(self, mesh):
np_inp = np.arange(16.).reshape(8, 2)
s = NamedSharding(mesh, P('x', 'y'))
arr = jax.device_put(np_inp, s)