mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
Add a private set_mesh
API to enter into sharding_in_types mode. This is how users will enable sharding in types mode (with correct axis types set too but that doesn't work yet).
Also adding a device_context so `set_mesh` sets the devices the computation should run on correctly. The device_context however enters concrete devices into tracing and lowering cache but this should be fixed with the other jax context work going on. PiperOrigin-RevId: 700537898
This commit is contained in:
parent
13726690dd
commit
0d2dfea4b1
@ -212,6 +212,7 @@ if xla_extension_version >= 295:
|
||||
return (axis_env_state.value, mesh_context_manager.value,
|
||||
xla_metadata_context_manager.value,
|
||||
abstract_mesh_context_manager.value,
|
||||
device_context.value,
|
||||
compute_on_context_manager.value, enable_x64.value,
|
||||
numpy_rank_promotion.value, default_matmul_precision.value,
|
||||
dynamic_shapes.value,
|
||||
@ -245,6 +246,7 @@ else:
|
||||
axis_env_state = ()
|
||||
mesh_context_manager = ()
|
||||
abstract_mesh_context_manager = ()
|
||||
device_context = ()
|
||||
xla_metadata_context_manager = ()
|
||||
compute_on_context_manager = ()
|
||||
|
||||
@ -255,12 +257,14 @@ else:
|
||||
mesh_context_manager = context.mesh_context_manager
|
||||
if context and context.abstract_mesh_context_manager:
|
||||
abstract_mesh_context_manager = context.abstract_mesh_context_manager
|
||||
if context and context.device_context:
|
||||
device_context = context.device_context
|
||||
if context and context.xla_metadata_context_manager:
|
||||
xla_metadata_context_manager = context.xla_metadata_context_manager
|
||||
if context and context.compute_on_context_manager:
|
||||
compute_on_context_manager = context.compute_on_context_manager
|
||||
return (axis_env_state, mesh_context_manager, abstract_mesh_context_manager,
|
||||
xla_metadata_context_manager,
|
||||
device_context, xla_metadata_context_manager,
|
||||
compute_on_context_manager, enable_x64.value,
|
||||
numpy_rank_promotion.value, default_matmul_precision.value,
|
||||
dynamic_shapes.value,
|
||||
@ -976,6 +980,7 @@ if xla_extension_version >= 295:
|
||||
axis_env_state = config_ext.Config((), include_in_jit_key=True)
|
||||
mesh_context_manager = config_ext.Config((), include_in_jit_key=True)
|
||||
abstract_mesh_context_manager = config_ext.Config((), include_in_jit_key=True)
|
||||
device_context = config_ext.Config((), include_in_jit_key=True)
|
||||
compute_on_context_manager = config_ext.Config((), include_in_jit_key=True)
|
||||
xla_metadata_context_manager = config_ext.Config((), include_in_jit_key=True)
|
||||
else:
|
||||
@ -1019,6 +1024,7 @@ else:
|
||||
axis_env_state: Hashable = ()
|
||||
mesh_context_manager: Hashable = ()
|
||||
abstract_mesh_context_manager: Hashable = ()
|
||||
device_context: Hashable = ()
|
||||
compute_on_context_manager: Hashable = ()
|
||||
xla_metadata_context_manager: Hashable = ()
|
||||
|
||||
@ -1086,6 +1092,7 @@ else:
|
||||
axis_env_state = JitConfig('axis_env_state')
|
||||
mesh_context_manager = JitConfig('mesh_context_manager')
|
||||
abstract_mesh_context_manager = JitConfig('abstract_mesh_context_manager')
|
||||
device_context = JitConfig('device_context')
|
||||
compute_on_context_manager = JitConfig('compute_on_context_manager')
|
||||
xla_metadata_context_manager = JitConfig('xla_metadata_context_manager')
|
||||
|
||||
|
@ -1605,7 +1605,7 @@ def get_sharding(sharding, ndim):
|
||||
assert len(sharding.spec) == ndim
|
||||
return sharding
|
||||
|
||||
context_mesh = mesh_lib.mesh_context.mesh
|
||||
context_mesh = mesh_lib.abstract_mesh_context.mesh
|
||||
# TODO(yashkatariya): Error out and ask users to set the context mesh in their
|
||||
# code.
|
||||
if context_mesh is None:
|
||||
|
@ -2193,8 +2193,15 @@ def lower_sharding_computation(
|
||||
assert len(out_shardings) == len(out_layouts) == len(global_out_avals), (
|
||||
len(out_shardings), len(out_layouts), len(global_out_avals))
|
||||
|
||||
devices_from_context = (None if context_mesh is None or context_mesh.empty
|
||||
else context_mesh._flat_devices_tuple)
|
||||
if config.sharding_in_types.value:
|
||||
# TODO(yashkatariya): Thread it via jit path and remove the None check by
|
||||
# making tests go via set_mesh API always.
|
||||
devices_from_context = (
|
||||
None if mesh_lib.device_context.concrete_mesh is None
|
||||
else mesh_lib.device_context.concrete_mesh._flat_devices_tuple)
|
||||
else:
|
||||
devices_from_context = (None if context_mesh is None or context_mesh.empty
|
||||
else context_mesh._flat_devices_tuple)
|
||||
# Device assignment across all inputs, outputs and shardings inside jaxpr
|
||||
# should be the same.
|
||||
unique_intermediate_shardings = util.stable_unique(
|
||||
|
@ -455,10 +455,10 @@ class AbstractMesh:
|
||||
_raise_value_error("local_mesh")
|
||||
|
||||
def __enter__(self):
|
||||
return push_mesh_context(self)
|
||||
return push_abstract_mesh_context(self)
|
||||
|
||||
def __exit__(self, exc_type, exc_value, traceback):
|
||||
pop_mesh_context()
|
||||
pop_abstract_mesh_context()
|
||||
return False
|
||||
|
||||
@staticmethod
|
||||
@ -473,27 +473,29 @@ def _raise_value_error(name):
|
||||
raise ValueError(f"AbstractMesh does not implement {name}")
|
||||
|
||||
|
||||
class MeshContext(threading.local):
|
||||
class AbstractMeshContext(threading.local):
|
||||
def __init__(self):
|
||||
self.stack = [None]
|
||||
self.mesh = self.stack[-1]
|
||||
|
||||
mesh_context = MeshContext()
|
||||
abstract_mesh_context = AbstractMeshContext()
|
||||
|
||||
def push_mesh_context(val):
|
||||
mesh_context.stack.append(val)
|
||||
mesh_context.mesh = val
|
||||
def push_abstract_mesh_context(val):
|
||||
abstract_mesh_context.stack.append(val)
|
||||
abstract_mesh_context.mesh = val
|
||||
# TODO(yashkatariya): Allow setting empty tuples and tuples with None in them.
|
||||
# Right now that leads to weird numerical issues.
|
||||
non_none_meshes = tuple(m for m in mesh_context.stack if m is not None)
|
||||
non_none_meshes = tuple(m for m in abstract_mesh_context.stack
|
||||
if m is not None)
|
||||
if non_none_meshes:
|
||||
jax_config.abstract_mesh_context_manager.set_local(non_none_meshes)
|
||||
return val
|
||||
|
||||
def pop_mesh_context():
|
||||
mesh_context.stack.pop()
|
||||
mesh_context.mesh = mesh_context.stack[-1]
|
||||
non_none_meshes = tuple(m for m in mesh_context.stack if m is not None)
|
||||
def pop_abstract_mesh_context():
|
||||
abstract_mesh_context.stack.pop()
|
||||
abstract_mesh_context.mesh = abstract_mesh_context.stack[-1]
|
||||
non_none_meshes = tuple(m for m in abstract_mesh_context.stack
|
||||
if m is not None)
|
||||
if non_none_meshes:
|
||||
jax_config.abstract_mesh_context_manager.set_local(non_none_meshes)
|
||||
|
||||
@ -501,8 +503,40 @@ def pop_mesh_context():
|
||||
class null_mesh_context:
|
||||
|
||||
def __enter__(self):
|
||||
return push_mesh_context(None)
|
||||
return push_abstract_mesh_context(None)
|
||||
|
||||
def __exit__(self, *excinfo):
|
||||
pop_mesh_context()
|
||||
pop_abstract_mesh_context()
|
||||
return False
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def set_mesh(mesh: Mesh):
|
||||
with (mesh.abstract_mesh, jax_config.sharding_in_types(True),
|
||||
enter_device_context(mesh)):
|
||||
yield
|
||||
|
||||
|
||||
class DeviceContext(threading.local):
|
||||
def __init__(self):
|
||||
self.stack = [None]
|
||||
self.concrete_mesh = self.stack[-1]
|
||||
|
||||
device_context = DeviceContext()
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def enter_device_context(mesh: Mesh):
|
||||
device_context.stack.append(mesh)
|
||||
device_context.concrete_mesh = mesh
|
||||
non_none_meshes = tuple(m for m in device_context.stack if m is not None)
|
||||
if non_none_meshes:
|
||||
jax_config.device_context.set_local(non_none_meshes)
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
device_context.stack.pop()
|
||||
device_context.concrete_mesh = device_context.stack[-1]
|
||||
non_none_meshes = tuple(m for m in device_context.stack if m is not None)
|
||||
if non_none_meshes:
|
||||
jax_config.device_context.set_local(non_none_meshes)
|
||||
|
@ -644,8 +644,8 @@ def _infer_params_impl(
|
||||
attr_token = _attr_token(flat_fun, in_type)
|
||||
|
||||
abstract_mesh = (
|
||||
get_abstract_mesh(in_type) if mesh_lib.mesh_context.mesh is None
|
||||
else mesh_lib.mesh_context.mesh)
|
||||
get_abstract_mesh(in_type) if mesh_lib.abstract_mesh_context.mesh is None
|
||||
else mesh_lib.abstract_mesh_context.mesh)
|
||||
with abstract_mesh:
|
||||
jaxpr, consts, out_avals, attrs_tracked = _create_pjit_jaxpr(
|
||||
flat_fun, in_type, attr_token, dbg,
|
||||
|
@ -4622,6 +4622,28 @@ class ArrayPjitTest(jtu.JaxTestCase):
|
||||
ins, _ = f.lower(np.arange(8)).compile().input_shardings
|
||||
self.assertEqual(ins[0], SingleDeviceSharding(jax.devices()[0]))
|
||||
|
||||
def test_sharding_in_types_with_set_mesh(self):
|
||||
if config.use_shardy_partitioner.value:
|
||||
self.skipTest("ShiT doesn't work with shardy")
|
||||
mesh = jtu.create_mesh((2, 2), ('x', 'y'))
|
||||
with mesh_lib.set_mesh(mesh):
|
||||
np_inp = np.arange(16.).reshape(8, 2)
|
||||
s = NamedSharding(mesh, P('x', 'y'))
|
||||
arr = jax.device_put(np_inp, s)
|
||||
|
||||
@jax.jit
|
||||
def f(x):
|
||||
self.assertEqual(x.sharding.spec, s.spec)
|
||||
x = x * 2
|
||||
self.assertEqual(x.sharding.spec, s.spec)
|
||||
x = x * x
|
||||
self.assertEqual(x.sharding.spec, s.spec)
|
||||
return x
|
||||
|
||||
out = f(arr)
|
||||
self.assertEqual(out.sharding, s)
|
||||
self.assertArraysEqual(out, (np_inp * 2) * (np_inp * 2))
|
||||
|
||||
|
||||
def spec_regex(s):
|
||||
return str(s).replace(r"(", r"\(").replace(r")", r"\)")
|
||||
@ -5229,7 +5251,7 @@ class ShardingInTypesTest(jtu.JaxTestCase):
|
||||
def g(x, y):
|
||||
self.assertTrue(x.sharding.mesh._are_all_axes_collective)
|
||||
self.assertTrue(y.sharding.mesh._are_all_axes_collective)
|
||||
self.assertTrue(mesh_lib.mesh_context.mesh._are_all_axes_collective)
|
||||
self.assertTrue(mesh_lib.abstract_mesh_context.mesh._are_all_axes_collective)
|
||||
return x * y
|
||||
|
||||
@jax.jit
|
||||
@ -5254,7 +5276,7 @@ class ShardingInTypesTest(jtu.JaxTestCase):
|
||||
def g(x, y):
|
||||
self.assertTrue(x.sharding.mesh._are_all_axes_collective)
|
||||
self.assertTrue(y.sharding.mesh._are_all_axes_collective)
|
||||
self.assertTrue(mesh_lib.mesh_context.mesh._are_all_axes_collective)
|
||||
self.assertTrue(mesh_lib.abstract_mesh_context.mesh._are_all_axes_collective)
|
||||
allgatherd_y = jax.lax.all_gather(y, axis_name='x', axis=1, tiled=True)
|
||||
z = x @ allgatherd_y
|
||||
return jax.lax.psum(z, axis_name='y')
|
||||
|
Loading…
x
Reference in New Issue
Block a user