mirror of
https://github.com/ROCm/jax.git
synced 2025-04-14 10:56:06 +00:00
set_mesh
should return the prev_mesh instead of nothing. Users can choose to use the return value or ignore it.
PiperOrigin-RevId: 738039559
This commit is contained in:
parent
7c5871f464
commit
a5c0f200e7
@ -1391,12 +1391,20 @@ def use_mesh(mesh: mesh_lib.Mesh):
|
||||
mesh_lib.use_concrete_mesh(mesh)):
|
||||
yield
|
||||
|
||||
def set_mesh(mesh: mesh_lib.Mesh) -> None:
|
||||
if not isinstance(mesh, mesh_lib.Mesh):
|
||||
def set_mesh(mesh: mesh_lib.Mesh | None) -> mesh_lib.Mesh | None:
|
||||
"""Sets the given concrete mesh globally and returns the previous concrete
|
||||
mesh."""
|
||||
if mesh is not None and not isinstance(mesh, mesh_lib.Mesh):
|
||||
raise ValueError(
|
||||
f"Expected mesh of type `jax.sharding.Mesh`. Got {type(mesh)}")
|
||||
if not core.trace_state_clean():
|
||||
raise ValueError('`set_mesh` can only be used outside of `jax.jit`.')
|
||||
|
||||
config.abstract_mesh_context_manager.set_local(mesh.abstract_mesh)
|
||||
config.device_context.set_local(mesh)
|
||||
if mesh is None:
|
||||
config.abstract_mesh_context_manager.set_global(mesh_lib.empty_abstract_mesh) # type: ignore
|
||||
else:
|
||||
config.abstract_mesh_context_manager.set_global(mesh.abstract_mesh) # type: ignore
|
||||
|
||||
prev_mesh = config.device_context.get_global()
|
||||
config.device_context.set_global(mesh)
|
||||
return prev_mesh
|
||||
|
@ -7096,16 +7096,12 @@ class ShardingInTypesTest(jtu.JaxTestCase):
|
||||
|
||||
def test_set_mesh(self):
|
||||
mesh = jtu.create_mesh((2,), ('x',), axis_types=(AxisType.Explicit,))
|
||||
prev_mesh = config.device_context.value
|
||||
prev_abstract_mesh = config.abstract_mesh_context_manager.value
|
||||
try:
|
||||
jax.sharding.set_mesh(mesh)
|
||||
|
||||
prev_mesh = jax.sharding.set_mesh(mesh)
|
||||
out = reshard(np.arange(8), P('x'))
|
||||
self.assertEqual(out.sharding, NamedSharding(mesh, P('x')))
|
||||
finally:
|
||||
config.device_context.set_local(prev_mesh)
|
||||
config.abstract_mesh_context_manager.set_local(prev_abstract_mesh)
|
||||
jax.sharding.set_mesh(prev_mesh)
|
||||
|
||||
@jtu.with_user_mesh((2,), ('x',))
|
||||
def test_auto_axes_late_bind(self, mesh):
|
||||
|
Loading…
x
Reference in New Issue
Block a user