Add a private set_mesh API to enter into sharding_in_types mode. This is how users will enable sharding in types mode (with correct axis types set too but that doesn't work yet).

Also adding a device_context so `set_mesh` sets the devices the computation should run on correctly. The device_context however enters concrete devices into tracing and lowering cache but this should be fixed with the other jax context work going on.

PiperOrigin-RevId: 700537898
This commit is contained in:
Yash Katariya 2024-11-26 20:00:19 -08:00 committed by jax authors
parent 13726690dd
commit 0d2dfea4b1
6 changed files with 92 additions and 22 deletions

View File

@ -212,6 +212,7 @@ if xla_extension_version >= 295:
return (axis_env_state.value, mesh_context_manager.value,
xla_metadata_context_manager.value,
abstract_mesh_context_manager.value,
device_context.value,
compute_on_context_manager.value, enable_x64.value,
numpy_rank_promotion.value, default_matmul_precision.value,
dynamic_shapes.value,
@ -245,6 +246,7 @@ else:
axis_env_state = ()
mesh_context_manager = ()
abstract_mesh_context_manager = ()
device_context = ()
xla_metadata_context_manager = ()
compute_on_context_manager = ()
@ -255,12 +257,14 @@ else:
mesh_context_manager = context.mesh_context_manager
if context and context.abstract_mesh_context_manager:
abstract_mesh_context_manager = context.abstract_mesh_context_manager
if context and context.device_context:
device_context = context.device_context
if context and context.xla_metadata_context_manager:
xla_metadata_context_manager = context.xla_metadata_context_manager
if context and context.compute_on_context_manager:
compute_on_context_manager = context.compute_on_context_manager
return (axis_env_state, mesh_context_manager, abstract_mesh_context_manager,
xla_metadata_context_manager,
device_context, xla_metadata_context_manager,
compute_on_context_manager, enable_x64.value,
numpy_rank_promotion.value, default_matmul_precision.value,
dynamic_shapes.value,
@ -976,6 +980,7 @@ if xla_extension_version >= 295:
axis_env_state = config_ext.Config((), include_in_jit_key=True)
mesh_context_manager = config_ext.Config((), include_in_jit_key=True)
abstract_mesh_context_manager = config_ext.Config((), include_in_jit_key=True)
device_context = config_ext.Config((), include_in_jit_key=True)
compute_on_context_manager = config_ext.Config((), include_in_jit_key=True)
xla_metadata_context_manager = config_ext.Config((), include_in_jit_key=True)
else:
@ -1019,6 +1024,7 @@ else:
axis_env_state: Hashable = ()
mesh_context_manager: Hashable = ()
abstract_mesh_context_manager: Hashable = ()
device_context: Hashable = ()
compute_on_context_manager: Hashable = ()
xla_metadata_context_manager: Hashable = ()
@ -1086,6 +1092,7 @@ else:
axis_env_state = JitConfig('axis_env_state')
mesh_context_manager = JitConfig('mesh_context_manager')
abstract_mesh_context_manager = JitConfig('abstract_mesh_context_manager')
device_context = JitConfig('device_context')
compute_on_context_manager = JitConfig('compute_on_context_manager')
xla_metadata_context_manager = JitConfig('xla_metadata_context_manager')

View File

@ -1605,7 +1605,7 @@ def get_sharding(sharding, ndim):
assert len(sharding.spec) == ndim
return sharding
context_mesh = mesh_lib.mesh_context.mesh
context_mesh = mesh_lib.abstract_mesh_context.mesh
# TODO(yashkatariya): Error out and ask users to set the context mesh in their
# code.
if context_mesh is None:

View File

@ -2193,8 +2193,15 @@ def lower_sharding_computation(
assert len(out_shardings) == len(out_layouts) == len(global_out_avals), (
len(out_shardings), len(out_layouts), len(global_out_avals))
devices_from_context = (None if context_mesh is None or context_mesh.empty
else context_mesh._flat_devices_tuple)
if config.sharding_in_types.value:
# TODO(yashkatariya): Thread it via jit path and remove the None check by
# making tests go via set_mesh API always.
devices_from_context = (
None if mesh_lib.device_context.concrete_mesh is None
else mesh_lib.device_context.concrete_mesh._flat_devices_tuple)
else:
devices_from_context = (None if context_mesh is None or context_mesh.empty
else context_mesh._flat_devices_tuple)
# Device assignment across all inputs, outputs and shardings inside jaxpr
# should be the same.
unique_intermediate_shardings = util.stable_unique(

View File

@ -455,10 +455,10 @@ class AbstractMesh:
_raise_value_error("local_mesh")
def __enter__(self):
return push_mesh_context(self)
return push_abstract_mesh_context(self)
def __exit__(self, exc_type, exc_value, traceback):
pop_mesh_context()
pop_abstract_mesh_context()
return False
@staticmethod
@ -473,27 +473,29 @@ def _raise_value_error(name):
raise ValueError(f"AbstractMesh does not implement {name}")
class MeshContext(threading.local):
class AbstractMeshContext(threading.local):
def __init__(self):
self.stack = [None]
self.mesh = self.stack[-1]
mesh_context = MeshContext()
abstract_mesh_context = AbstractMeshContext()
def push_mesh_context(val):
mesh_context.stack.append(val)
mesh_context.mesh = val
def push_abstract_mesh_context(val):
abstract_mesh_context.stack.append(val)
abstract_mesh_context.mesh = val
# TODO(yashkatariya): Allow setting empty tuples and tuples with None in them.
# Right now that leads to weird numerical issues.
non_none_meshes = tuple(m for m in mesh_context.stack if m is not None)
non_none_meshes = tuple(m for m in abstract_mesh_context.stack
if m is not None)
if non_none_meshes:
jax_config.abstract_mesh_context_manager.set_local(non_none_meshes)
return val
def pop_mesh_context():
mesh_context.stack.pop()
mesh_context.mesh = mesh_context.stack[-1]
non_none_meshes = tuple(m for m in mesh_context.stack if m is not None)
def pop_abstract_mesh_context():
abstract_mesh_context.stack.pop()
abstract_mesh_context.mesh = abstract_mesh_context.stack[-1]
non_none_meshes = tuple(m for m in abstract_mesh_context.stack
if m is not None)
if non_none_meshes:
jax_config.abstract_mesh_context_manager.set_local(non_none_meshes)
@ -501,8 +503,40 @@ def pop_mesh_context():
class null_mesh_context:
def __enter__(self):
return push_mesh_context(None)
return push_abstract_mesh_context(None)
def __exit__(self, *excinfo):
pop_mesh_context()
pop_abstract_mesh_context()
return False
@contextlib.contextmanager
def set_mesh(mesh: Mesh):
with (mesh.abstract_mesh, jax_config.sharding_in_types(True),
enter_device_context(mesh)):
yield
class DeviceContext(threading.local):
def __init__(self):
self.stack = [None]
self.concrete_mesh = self.stack[-1]
device_context = DeviceContext()
@contextlib.contextmanager
def enter_device_context(mesh: Mesh):
device_context.stack.append(mesh)
device_context.concrete_mesh = mesh
non_none_meshes = tuple(m for m in device_context.stack if m is not None)
if non_none_meshes:
jax_config.device_context.set_local(non_none_meshes)
try:
yield
finally:
device_context.stack.pop()
device_context.concrete_mesh = device_context.stack[-1]
non_none_meshes = tuple(m for m in device_context.stack if m is not None)
if non_none_meshes:
jax_config.device_context.set_local(non_none_meshes)

View File

@ -644,8 +644,8 @@ def _infer_params_impl(
attr_token = _attr_token(flat_fun, in_type)
abstract_mesh = (
get_abstract_mesh(in_type) if mesh_lib.mesh_context.mesh is None
else mesh_lib.mesh_context.mesh)
get_abstract_mesh(in_type) if mesh_lib.abstract_mesh_context.mesh is None
else mesh_lib.abstract_mesh_context.mesh)
with abstract_mesh:
jaxpr, consts, out_avals, attrs_tracked = _create_pjit_jaxpr(
flat_fun, in_type, attr_token, dbg,

View File

@ -4622,6 +4622,28 @@ class ArrayPjitTest(jtu.JaxTestCase):
ins, _ = f.lower(np.arange(8)).compile().input_shardings
self.assertEqual(ins[0], SingleDeviceSharding(jax.devices()[0]))
def test_sharding_in_types_with_set_mesh(self):
if config.use_shardy_partitioner.value:
self.skipTest("ShiT doesn't work with shardy")
mesh = jtu.create_mesh((2, 2), ('x', 'y'))
with mesh_lib.set_mesh(mesh):
np_inp = np.arange(16.).reshape(8, 2)
s = NamedSharding(mesh, P('x', 'y'))
arr = jax.device_put(np_inp, s)
@jax.jit
def f(x):
self.assertEqual(x.sharding.spec, s.spec)
x = x * 2
self.assertEqual(x.sharding.spec, s.spec)
x = x * x
self.assertEqual(x.sharding.spec, s.spec)
return x
out = f(arr)
self.assertEqual(out.sharding, s)
self.assertArraysEqual(out, (np_inp * 2) * (np_inp * 2))
def spec_regex(s):
return str(s).replace(r"(", r"\(").replace(r")", r"\)")
@ -5229,7 +5251,7 @@ class ShardingInTypesTest(jtu.JaxTestCase):
def g(x, y):
self.assertTrue(x.sharding.mesh._are_all_axes_collective)
self.assertTrue(y.sharding.mesh._are_all_axes_collective)
self.assertTrue(mesh_lib.mesh_context.mesh._are_all_axes_collective)
self.assertTrue(mesh_lib.abstract_mesh_context.mesh._are_all_axes_collective)
return x * y
@jax.jit
@ -5254,7 +5276,7 @@ class ShardingInTypesTest(jtu.JaxTestCase):
def g(x, y):
self.assertTrue(x.sharding.mesh._are_all_axes_collective)
self.assertTrue(y.sharding.mesh._are_all_axes_collective)
self.assertTrue(mesh_lib.mesh_context.mesh._are_all_axes_collective)
self.assertTrue(mesh_lib.abstract_mesh_context.mesh._are_all_axes_collective)
allgatherd_y = jax.lax.all_gather(y, axis_name='x', axis=1, tiled=True)
z = x @ allgatherd_y
return jax.lax.psum(z, axis_name='y')