mirror of
https://github.com/ROCm/jax.git
synced 2025-04-17 12:26:07 +00:00
[sharding_in_types] Error out when using auto_axes
or explicit_axes
API when there is no context mesh.
Those APIs don't support that right now anyways and they raise an ugly KeyError. Instead we raise a better error here. I have added a TODO to get the mesh from args so that computation follows data works but we can decide to do that in the future if a lot of users request that and don't want to use `use_mesh`. PiperOrigin-RevId: 730687231
This commit is contained in:
parent
41faf51a16
commit
b707f0bdbb
@ -2827,9 +2827,16 @@ batching.skippable_batchers[reshard_p] = lambda _: ()
|
||||
# -------------------- auto and user mode -------------------------
|
||||
|
||||
def _get_new_mesh(axes: str | tuple[str, ...] | None,
|
||||
axis_type: mesh_lib.AxisTypes,
|
||||
axis_type: mesh_lib.AxisTypes, name: str,
|
||||
error_on_manual_to_auto_explict=False):
|
||||
cur_mesh = mesh_lib.get_abstract_mesh()
|
||||
# TODO(yashkatariya): Maybe allow fetching mesh from the args to enable
|
||||
# computation follows data?
|
||||
if cur_mesh.empty:
|
||||
raise ValueError(
|
||||
f'Context mesh {cur_mesh} cannot be empty. Please use'
|
||||
' `jax.sharding.use_mesh` API to enter into a mesh context when using'
|
||||
f' `{name}` API.')
|
||||
if axes is None:
|
||||
axes = cur_mesh.axis_names
|
||||
if not isinstance(axes, tuple):
|
||||
@ -2848,7 +2855,7 @@ def _get_new_mesh(axes: str | tuple[str, ...] | None,
|
||||
def auto_axes(fun, *, axes: str | tuple[str, ...] | None = None,
|
||||
out_shardings):
|
||||
def decorator(*args, **kwargs):
|
||||
new_mesh = _get_new_mesh(axes, mesh_lib.AxisTypes.Auto,
|
||||
new_mesh = _get_new_mesh(axes, mesh_lib.AxisTypes.Auto, 'auto_axes',
|
||||
error_on_manual_to_auto_explict=True)
|
||||
with mesh_lib.set_abstract_mesh(new_mesh):
|
||||
in_specs = tree_map(lambda a: core.modify_spec_for_auto_manual(
|
||||
@ -2860,7 +2867,7 @@ def auto_axes(fun, *, axes: str | tuple[str, ...] | None = None,
|
||||
|
||||
@contextlib.contextmanager
|
||||
def use_auto_axes(*axes):
|
||||
new_mesh = _get_new_mesh(axes, mesh_lib.AxisTypes.Auto)
|
||||
new_mesh = _get_new_mesh(axes, mesh_lib.AxisTypes.Auto, 'use_auto_axes')
|
||||
with mesh_lib.set_abstract_mesh(new_mesh):
|
||||
yield
|
||||
|
||||
@ -2868,7 +2875,7 @@ def use_auto_axes(*axes):
|
||||
def explicit_axes(fun, *, axes: str | tuple[str, ...] | None = None,
|
||||
in_shardings):
|
||||
def decorator(*args, **kwargs):
|
||||
new_mesh = _get_new_mesh(axes, mesh_lib.AxisTypes.Explicit,
|
||||
new_mesh = _get_new_mesh(axes, mesh_lib.AxisTypes.Explicit, 'explicit_axes',
|
||||
error_on_manual_to_auto_explict=True)
|
||||
with mesh_lib.set_abstract_mesh(new_mesh):
|
||||
args = mesh_cast(args, in_shardings)
|
||||
@ -2880,7 +2887,8 @@ def explicit_axes(fun, *, axes: str | tuple[str, ...] | None = None,
|
||||
|
||||
@contextlib.contextmanager
|
||||
def use_explicit_axes(*axes):
|
||||
new_mesh = _get_new_mesh(axes, mesh_lib.AxisTypes.Explicit)
|
||||
new_mesh = _get_new_mesh(axes, mesh_lib.AxisTypes.Explicit,
|
||||
'use_explicit_axes')
|
||||
with mesh_lib.set_abstract_mesh(new_mesh):
|
||||
yield
|
||||
|
||||
|
@ -6863,6 +6863,18 @@ class ShardingInTypesTest(jtu.JaxTestCase):
|
||||
self.assertEqual(out.shape, expected_shape)
|
||||
self.assertEqual(out.sharding, NamedSharding(mesh, expected_spec))
|
||||
|
||||
def test_auto_axes_computation_follows_data_error(self):
|
||||
mesh = jtu.create_mesh((2,), ('x',), axis_types={AxisTypes.Explicit: 'x'})
|
||||
s = NamedSharding(mesh, P('x'))
|
||||
arr = jax.device_put(np.arange(8), s)
|
||||
|
||||
@jax.jit
|
||||
def f(x):
|
||||
return x * 2
|
||||
|
||||
with self.assertRaisesRegex(ValueError, "Context mesh.*cannot be empty"):
|
||||
auto_axes(f, out_shardings=s)(arr)
|
||||
|
||||
|
||||
@jtu.pytest_mark_if_available('multiaccelerator')
|
||||
class PJitErrorTest(jtu.JaxTestCase):
|
||||
|
Loading…
x
Reference in New Issue
Block a user