Rename hidden_mode -> hidden_axes and hidden_mode_ctx -> use_hidden_axes. Likewise, rename visible_mode -> visible_axes and visible_mode_ctx -> use_visible_axes.

Also make the `axes` parameter of the `hidden_axes` and `visible_axes` functions optional. If `axes` is not specified, you drop into full hidden/visible mode.

PiperOrigin-RevId: 716771872
This commit is contained in:
Yash Katariya 2025-01-17 13:00:26 -08:00 committed by jax authors
parent 783d03c5b2
commit 12b59f8e53
3 changed files with 28 additions and 23 deletions

View File

@ -42,7 +42,7 @@ from jax._src.lib import xla_client as xc
from jax._src.numpy import array_api_metadata
from jax._src.numpy import lax_numpy
from jax._src import mesh as mesh_lib
from jax._src.pjit import hidden_mode, PartitionSpec
from jax._src.pjit import hidden_axes, PartitionSpec
from jax._src.sharding_impls import canonicalize_sharding, NamedSharding
from jax._src.numpy import reductions
from jax._src.numpy import ufuncs
@ -783,8 +783,8 @@ class _IndexUpdateRef:
if out_sharding is not None:
assert isinstance(out_sharding, (NamedSharding, PartitionSpec))
out_sharding = canonicalize_sharding(out_sharding)
take = hidden_mode(take, axes=mesh_lib.get_abstract_mesh().axis_names, # type: ignore
out_specs=out_sharding.spec)
take = hidden_axes(take, axes=mesh_lib.get_abstract_mesh().axis_names, # type: ignore
out_shardings=out_sharding.spec)
return take(self.array, self.index)
def set(self, values, *, indices_are_sorted=False, unique_indices=False,

View File

@ -2771,10 +2771,13 @@ mlir.register_lowering(mesh_cast_p, _mesh_cast_hlo_lowering)
# -------------------- auto and user mode -------------------------
def _get_new_mesh(axes: str | tuple[str, ...], axis_type: mesh_lib.AxisTypes):
def _get_new_mesh(axes: str | tuple[str, ...] | None,
axis_type: mesh_lib.AxisTypes):
cur_mesh = mesh_lib.get_abstract_mesh()
if axes is None:
axes = cur_mesh.axis_names # type: ignore
if not isinstance(axes, tuple):
axes = (axes,)
cur_mesh = mesh_lib.get_abstract_mesh()
for a in axes:
if cur_mesh._name_to_type[a] == axis_type: # type: ignore
raise ValueError(f'Axes {a} cannot be casted to type {axis_type} since '
@ -2782,7 +2785,8 @@ def _get_new_mesh(axes: str | tuple[str, ...], axis_type: mesh_lib.AxisTypes):
new_mesh = cur_mesh.update_axis_types({axis_type: axes}) # type: ignore
return new_mesh
def hidden_mode(fun, *, axes: str | tuple[str, ...], out_specs):
def hidden_axes(fun, *, axes: str | tuple[str, ...] | None = None,
out_shardings):
new_mesh = _get_new_mesh(axes, mesh_lib.AxisTypes.Hidden)
def decorator(*args, **kwargs):
with mesh_lib.set_abstract_mesh(new_mesh):
@ -2790,22 +2794,22 @@ def hidden_mode(fun, *, axes: str | tuple[str, ...], out_specs):
a.sharding.spec, new_mesh), args)
args = mesh_cast(args, in_specs)
out = fun(*args, **kwargs)
return mesh_cast(out, out_specs)
return mesh_cast(out, out_shardings)
return decorator
@contextlib.contextmanager
def hidden_axes(axes: str | tuple[str, ...]):
def use_hidden_axes(*axes):
new_mesh = _get_new_mesh(axes, mesh_lib.AxisTypes.Hidden)
with mesh_lib.set_abstract_mesh(new_mesh):
yield
def visible_mode(fun, *, axes: str | tuple[str, ...], in_specs):
def visible_axes(fun, *, axes: str | tuple[str, ...] | None = None,
in_shardings):
new_mesh = _get_new_mesh(axes, mesh_lib.AxisTypes.Visible)
def decorator(*args, **kwargs):
with mesh_lib.set_abstract_mesh(new_mesh):
args = mesh_cast(args, in_specs)
args = mesh_cast(args, in_shardings)
out = fun(*args, **kwargs)
out_specs = tree_map(lambda o: core.modify_spec_for_hidden(
o.sharding.spec, mesh_lib.get_abstract_mesh()), out)
@ -2813,7 +2817,7 @@ def visible_mode(fun, *, axes: str | tuple[str, ...], in_specs):
return decorator
@contextlib.contextmanager
def visible_axes(axes: str | tuple[str, ...]):
def use_visible_axes(*axes):
new_mesh = _get_new_mesh(axes, mesh_lib.AxisTypes.Visible)
with mesh_lib.set_abstract_mesh(new_mesh):
yield

View File

@ -51,8 +51,8 @@ from jax._src import sharding_impls
from jax._src.sharding_impls import (
AUTO, UNSPECIFIED, NamedSharding, GSPMDSharding, PositionalSharding,
SingleDeviceSharding, parse_flatten_op_sharding)
from jax._src.pjit import (pjit, mesh_cast, hidden_mode, visible_mode,
hidden_axes, visible_axes)
from jax._src.pjit import (pjit, mesh_cast, hidden_axes, visible_axes,
use_hidden_axes, use_visible_axes)
from jax._src import mesh as mesh_lib
from jax._src.mesh import set_abstract_mesh, get_abstract_mesh, AxisTypes
from jax._src.interpreters import pxla
@ -5816,7 +5816,7 @@ class ShardingInTypesTest(jtu.JaxTestCase):
@jax.jit
def f(x):
y = x * 2
with hidden_axes(axes=('x', 'y')):
with use_hidden_axes('x', 'y'):
y = mesh_cast(y, P(None, None))
self.assertEqual(y.sharding.spec, P(None, None))
z = jnp.sin(y)
@ -5844,7 +5844,7 @@ class ShardingInTypesTest(jtu.JaxTestCase):
@jax.jit
def f(x):
y = x * 2
with visible_axes(axes=('x', 'y')):
with use_visible_axes('x', 'y'):
y = mesh_cast(y, P(None, 'y'))
self.assertEqual(y.sharding.spec, P(None, 'y'))
z = jnp.sin(y)
@ -5870,7 +5870,7 @@ class ShardingInTypesTest(jtu.JaxTestCase):
@jax.jit
def f(x):
y = x * 2
with hidden_axes('x'):
with use_hidden_axes('x'):
y = mesh_cast(y, P(None, 'y'))
self.assertEqual(y.sharding.spec, P(None, 'y'))
z = jnp.sin(y)
@ -5897,7 +5897,7 @@ class ShardingInTypesTest(jtu.JaxTestCase):
@jax.jit
def f(x, y):
x = x * 2
with hidden_axes('x'):
with use_hidden_axes('x'):
z = x @ y
return z
@ -6009,7 +6009,7 @@ class ShardingInTypesTest(jtu.JaxTestCase):
@jax.jit
def g(x, y):
with hidden_axes('x'):
with use_hidden_axes('x'):
out = jnp.einsum('xy,yz->xz', x, y, out_sharding=P('x', None))
return out
@ -6048,7 +6048,7 @@ class ShardingInTypesTest(jtu.JaxTestCase):
s = NamedSharding(mesh, P('x', 'y'))
arr = jax.device_put(np_inp, s)
@partial(hidden_mode, axes='x', out_specs=P('x', None))
@partial(hidden_axes, axes='x', out_shardings=P('x', None))
def h(y):
self.assertEqual(y.sharding.spec, P(None, 'y'))
z = jnp.sin(y)
@ -6078,7 +6078,8 @@ class ShardingInTypesTest(jtu.JaxTestCase):
s = NamedSharding(mesh, P('x', 'y'))
arr = jax.device_put(np_inp, s)
@partial(visible_mode, axes=('x', 'y'), in_specs=P('x', 'y'))
# No axes specified means full visible mode.
@partial(visible_axes, in_shardings=P('x', 'y'))
def h(y):
self.assertEqual(y.sharding.spec, P('x', 'y'))
z = jnp.sin(y)
@ -6108,7 +6109,7 @@ class ShardingInTypesTest(jtu.JaxTestCase):
s = NamedSharding(mesh, P('x', 'y'))
arr = jax.device_put(np_inp, s)
@partial(visible_mode, axes='y', in_specs=P('x', 'y'))
@partial(visible_axes, axes='y', in_shardings=P('x', 'y'))
def h(y):
self.assertEqual(y.sharding.spec, P('x', 'y'))
z = jnp.sin(y)
@ -6134,7 +6135,7 @@ class ShardingInTypesTest(jtu.JaxTestCase):
s = NamedSharding(mesh, P('x', 'y'))
arr = jax.device_put(np_inp, s)
@partial(visible_mode, axes='y', in_specs=P(None, 'y'))
@partial(visible_axes, axes='y', in_shardings=P(None, 'y'))
def h(y):
self.assertEqual(y.sharding.spec, P(None, 'y'))
z = jnp.sin(y)