[sharding_in_types] Error out when using auto_axes or explicit_axes API when there is no context mesh.

Those APIs don't support that right now anyways and they raise an ugly KeyError. Instead we raise a better error here.

I have added a TODO to get the mesh from args so that computation follows data works but we can decide to do that in the future if a lot of users request that and don't want to use `use_mesh`.

PiperOrigin-RevId: 730687231
This commit is contained in:
Yash Katariya 2025-02-24 19:19:11 -08:00 committed by jax authors
parent 41faf51a16
commit b707f0bdbb
2 changed files with 25 additions and 5 deletions

View File

@ -2827,9 +2827,16 @@ batching.skippable_batchers[reshard_p] = lambda _: ()
# -------------------- auto and user mode -------------------------
def _get_new_mesh(axes: str | tuple[str, ...] | None,
axis_type: mesh_lib.AxisTypes,
axis_type: mesh_lib.AxisTypes, name: str,
error_on_manual_to_auto_explict=False):
cur_mesh = mesh_lib.get_abstract_mesh()
# TODO(yashkatariya): Maybe allow fetching mesh from the args to enable
# computation follows data?
if cur_mesh.empty:
raise ValueError(
f'Context mesh {cur_mesh} cannot be empty. Please use'
' `jax.sharding.use_mesh` API to enter into a mesh context when using'
f' `{name}` API.')
if axes is None:
axes = cur_mesh.axis_names
if not isinstance(axes, tuple):
@ -2848,7 +2855,7 @@ def _get_new_mesh(axes: str | tuple[str, ...] | None,
def auto_axes(fun, *, axes: str | tuple[str, ...] | None = None,
out_shardings):
def decorator(*args, **kwargs):
new_mesh = _get_new_mesh(axes, mesh_lib.AxisTypes.Auto,
new_mesh = _get_new_mesh(axes, mesh_lib.AxisTypes.Auto, 'auto_axes',
error_on_manual_to_auto_explict=True)
with mesh_lib.set_abstract_mesh(new_mesh):
in_specs = tree_map(lambda a: core.modify_spec_for_auto_manual(
@ -2860,7 +2867,7 @@ def auto_axes(fun, *, axes: str | tuple[str, ...] | None = None,
@contextlib.contextmanager
def use_auto_axes(*axes):
new_mesh = _get_new_mesh(axes, mesh_lib.AxisTypes.Auto)
new_mesh = _get_new_mesh(axes, mesh_lib.AxisTypes.Auto, 'use_auto_axes')
with mesh_lib.set_abstract_mesh(new_mesh):
yield
@ -2868,7 +2875,7 @@ def use_auto_axes(*axes):
def explicit_axes(fun, *, axes: str | tuple[str, ...] | None = None,
in_shardings):
def decorator(*args, **kwargs):
new_mesh = _get_new_mesh(axes, mesh_lib.AxisTypes.Explicit,
new_mesh = _get_new_mesh(axes, mesh_lib.AxisTypes.Explicit, 'explicit_axes',
error_on_manual_to_auto_explict=True)
with mesh_lib.set_abstract_mesh(new_mesh):
args = mesh_cast(args, in_shardings)
@ -2880,7 +2887,8 @@ def explicit_axes(fun, *, axes: str | tuple[str, ...] | None = None,
@contextlib.contextmanager
def use_explicit_axes(*axes):
new_mesh = _get_new_mesh(axes, mesh_lib.AxisTypes.Explicit)
new_mesh = _get_new_mesh(axes, mesh_lib.AxisTypes.Explicit,
'use_explicit_axes')
with mesh_lib.set_abstract_mesh(new_mesh):
yield

View File

@ -6863,6 +6863,18 @@ class ShardingInTypesTest(jtu.JaxTestCase):
self.assertEqual(out.shape, expected_shape)
self.assertEqual(out.sharding, NamedSharding(mesh, expected_spec))
def test_auto_axes_computation_follows_data_error(self):
mesh = jtu.create_mesh((2,), ('x',), axis_types={AxisTypes.Explicit: 'x'})
s = NamedSharding(mesh, P('x'))
arr = jax.device_put(np.arange(8), s)
@jax.jit
def f(x):
return x * 2
with self.assertRaisesRegex(ValueError, "Context mesh.*cannot be empty"):
auto_axes(f, out_shardings=s)(arr)
@jtu.pytest_mark_if_available('multiaccelerator')
class PJitErrorTest(jtu.JaxTestCase):