mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Rename hidden_mode -> hidden_axes and hidden_mode_ctx -> use_hidden_axes. Same for visible mode and visible_mode_ctx.
Also make the `axes` parameter optional of hidden_axes and visible_axes functions. If axes is optional, you drop into full hidden/visible mode. PiperOrigin-RevId: 716771872
This commit is contained in:
parent
783d03c5b2
commit
12b59f8e53
@ -42,7 +42,7 @@ from jax._src.lib import xla_client as xc
|
||||
from jax._src.numpy import array_api_metadata
|
||||
from jax._src.numpy import lax_numpy
|
||||
from jax._src import mesh as mesh_lib
|
||||
from jax._src.pjit import hidden_mode, PartitionSpec
|
||||
from jax._src.pjit import hidden_axes, PartitionSpec
|
||||
from jax._src.sharding_impls import canonicalize_sharding, NamedSharding
|
||||
from jax._src.numpy import reductions
|
||||
from jax._src.numpy import ufuncs
|
||||
@ -783,8 +783,8 @@ class _IndexUpdateRef:
|
||||
if out_sharding is not None:
|
||||
assert isinstance(out_sharding, (NamedSharding, PartitionSpec))
|
||||
out_sharding = canonicalize_sharding(out_sharding)
|
||||
take = hidden_mode(take, axes=mesh_lib.get_abstract_mesh().axis_names, # type: ignore
|
||||
out_specs=out_sharding.spec)
|
||||
take = hidden_axes(take, axes=mesh_lib.get_abstract_mesh().axis_names, # type: ignore
|
||||
out_shardings=out_sharding.spec)
|
||||
return take(self.array, self.index)
|
||||
|
||||
def set(self, values, *, indices_are_sorted=False, unique_indices=False,
|
||||
|
@ -2771,10 +2771,13 @@ mlir.register_lowering(mesh_cast_p, _mesh_cast_hlo_lowering)
|
||||
|
||||
# -------------------- auto and user mode -------------------------
|
||||
|
||||
def _get_new_mesh(axes: str | tuple[str, ...], axis_type: mesh_lib.AxisTypes):
|
||||
def _get_new_mesh(axes: str | tuple[str, ...] | None,
|
||||
axis_type: mesh_lib.AxisTypes):
|
||||
cur_mesh = mesh_lib.get_abstract_mesh()
|
||||
if axes is None:
|
||||
axes = cur_mesh.axis_names # type: ignore
|
||||
if not isinstance(axes, tuple):
|
||||
axes = (axes,)
|
||||
cur_mesh = mesh_lib.get_abstract_mesh()
|
||||
for a in axes:
|
||||
if cur_mesh._name_to_type[a] == axis_type: # type: ignore
|
||||
raise ValueError(f'Axes {a} cannot be casted to type {axis_type} since '
|
||||
@ -2782,7 +2785,8 @@ def _get_new_mesh(axes: str | tuple[str, ...], axis_type: mesh_lib.AxisTypes):
|
||||
new_mesh = cur_mesh.update_axis_types({axis_type: axes}) # type: ignore
|
||||
return new_mesh
|
||||
|
||||
def hidden_mode(fun, *, axes: str | tuple[str, ...], out_specs):
|
||||
def hidden_axes(fun, *, axes: str | tuple[str, ...] | None = None,
|
||||
out_shardings):
|
||||
new_mesh = _get_new_mesh(axes, mesh_lib.AxisTypes.Hidden)
|
||||
def decorator(*args, **kwargs):
|
||||
with mesh_lib.set_abstract_mesh(new_mesh):
|
||||
@ -2790,22 +2794,22 @@ def hidden_mode(fun, *, axes: str | tuple[str, ...], out_specs):
|
||||
a.sharding.spec, new_mesh), args)
|
||||
args = mesh_cast(args, in_specs)
|
||||
out = fun(*args, **kwargs)
|
||||
return mesh_cast(out, out_specs)
|
||||
return mesh_cast(out, out_shardings)
|
||||
return decorator
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def hidden_axes(axes: str | tuple[str, ...]):
|
||||
def use_hidden_axes(*axes):
|
||||
new_mesh = _get_new_mesh(axes, mesh_lib.AxisTypes.Hidden)
|
||||
with mesh_lib.set_abstract_mesh(new_mesh):
|
||||
yield
|
||||
|
||||
|
||||
def visible_mode(fun, *, axes: str | tuple[str, ...], in_specs):
|
||||
def visible_axes(fun, *, axes: str | tuple[str, ...] | None = None,
|
||||
in_shardings):
|
||||
new_mesh = _get_new_mesh(axes, mesh_lib.AxisTypes.Visible)
|
||||
def decorator(*args, **kwargs):
|
||||
with mesh_lib.set_abstract_mesh(new_mesh):
|
||||
args = mesh_cast(args, in_specs)
|
||||
args = mesh_cast(args, in_shardings)
|
||||
out = fun(*args, **kwargs)
|
||||
out_specs = tree_map(lambda o: core.modify_spec_for_hidden(
|
||||
o.sharding.spec, mesh_lib.get_abstract_mesh()), out)
|
||||
@ -2813,7 +2817,7 @@ def visible_mode(fun, *, axes: str | tuple[str, ...], in_specs):
|
||||
return decorator
|
||||
|
||||
@contextlib.contextmanager
|
||||
def visible_axes(axes: str | tuple[str, ...]):
|
||||
def use_visible_axes(*axes):
|
||||
new_mesh = _get_new_mesh(axes, mesh_lib.AxisTypes.Visible)
|
||||
with mesh_lib.set_abstract_mesh(new_mesh):
|
||||
yield
|
||||
|
@ -51,8 +51,8 @@ from jax._src import sharding_impls
|
||||
from jax._src.sharding_impls import (
|
||||
AUTO, UNSPECIFIED, NamedSharding, GSPMDSharding, PositionalSharding,
|
||||
SingleDeviceSharding, parse_flatten_op_sharding)
|
||||
from jax._src.pjit import (pjit, mesh_cast, hidden_mode, visible_mode,
|
||||
hidden_axes, visible_axes)
|
||||
from jax._src.pjit import (pjit, mesh_cast, hidden_axes, visible_axes,
|
||||
use_hidden_axes, use_visible_axes)
|
||||
from jax._src import mesh as mesh_lib
|
||||
from jax._src.mesh import set_abstract_mesh, get_abstract_mesh, AxisTypes
|
||||
from jax._src.interpreters import pxla
|
||||
@ -5816,7 +5816,7 @@ class ShardingInTypesTest(jtu.JaxTestCase):
|
||||
@jax.jit
|
||||
def f(x):
|
||||
y = x * 2
|
||||
with hidden_axes(axes=('x', 'y')):
|
||||
with use_hidden_axes('x', 'y'):
|
||||
y = mesh_cast(y, P(None, None))
|
||||
self.assertEqual(y.sharding.spec, P(None, None))
|
||||
z = jnp.sin(y)
|
||||
@ -5844,7 +5844,7 @@ class ShardingInTypesTest(jtu.JaxTestCase):
|
||||
@jax.jit
|
||||
def f(x):
|
||||
y = x * 2
|
||||
with visible_axes(axes=('x', 'y')):
|
||||
with use_visible_axes('x', 'y'):
|
||||
y = mesh_cast(y, P(None, 'y'))
|
||||
self.assertEqual(y.sharding.spec, P(None, 'y'))
|
||||
z = jnp.sin(y)
|
||||
@ -5870,7 +5870,7 @@ class ShardingInTypesTest(jtu.JaxTestCase):
|
||||
@jax.jit
|
||||
def f(x):
|
||||
y = x * 2
|
||||
with hidden_axes('x'):
|
||||
with use_hidden_axes('x'):
|
||||
y = mesh_cast(y, P(None, 'y'))
|
||||
self.assertEqual(y.sharding.spec, P(None, 'y'))
|
||||
z = jnp.sin(y)
|
||||
@ -5897,7 +5897,7 @@ class ShardingInTypesTest(jtu.JaxTestCase):
|
||||
@jax.jit
|
||||
def f(x, y):
|
||||
x = x * 2
|
||||
with hidden_axes('x'):
|
||||
with use_hidden_axes('x'):
|
||||
z = x @ y
|
||||
return z
|
||||
|
||||
@ -6009,7 +6009,7 @@ class ShardingInTypesTest(jtu.JaxTestCase):
|
||||
|
||||
@jax.jit
|
||||
def g(x, y):
|
||||
with hidden_axes('x'):
|
||||
with use_hidden_axes('x'):
|
||||
out = jnp.einsum('xy,yz->xz', x, y, out_sharding=P('x', None))
|
||||
return out
|
||||
|
||||
@ -6048,7 +6048,7 @@ class ShardingInTypesTest(jtu.JaxTestCase):
|
||||
s = NamedSharding(mesh, P('x', 'y'))
|
||||
arr = jax.device_put(np_inp, s)
|
||||
|
||||
@partial(hidden_mode, axes='x', out_specs=P('x', None))
|
||||
@partial(hidden_axes, axes='x', out_shardings=P('x', None))
|
||||
def h(y):
|
||||
self.assertEqual(y.sharding.spec, P(None, 'y'))
|
||||
z = jnp.sin(y)
|
||||
@ -6078,7 +6078,8 @@ class ShardingInTypesTest(jtu.JaxTestCase):
|
||||
s = NamedSharding(mesh, P('x', 'y'))
|
||||
arr = jax.device_put(np_inp, s)
|
||||
|
||||
@partial(visible_mode, axes=('x', 'y'), in_specs=P('x', 'y'))
|
||||
# No axes specified means full visible mode.
|
||||
@partial(visible_axes, in_shardings=P('x', 'y'))
|
||||
def h(y):
|
||||
self.assertEqual(y.sharding.spec, P('x', 'y'))
|
||||
z = jnp.sin(y)
|
||||
@ -6108,7 +6109,7 @@ class ShardingInTypesTest(jtu.JaxTestCase):
|
||||
s = NamedSharding(mesh, P('x', 'y'))
|
||||
arr = jax.device_put(np_inp, s)
|
||||
|
||||
@partial(visible_mode, axes='y', in_specs=P('x', 'y'))
|
||||
@partial(visible_axes, axes='y', in_shardings=P('x', 'y'))
|
||||
def h(y):
|
||||
self.assertEqual(y.sharding.spec, P('x', 'y'))
|
||||
z = jnp.sin(y)
|
||||
@ -6134,7 +6135,7 @@ class ShardingInTypesTest(jtu.JaxTestCase):
|
||||
s = NamedSharding(mesh, P('x', 'y'))
|
||||
arr = jax.device_put(np_inp, s)
|
||||
|
||||
@partial(visible_mode, axes='y', in_specs=P(None, 'y'))
|
||||
@partial(visible_axes, axes='y', in_shardings=P(None, 'y'))
|
||||
def h(y):
|
||||
self.assertEqual(y.sharding.spec, P(None, 'y'))
|
||||
z = jnp.sin(y)
|
||||
|
Loading…
x
Reference in New Issue
Block a user