[sharding_in_types] Allow auto_axes and explicit_axes to take numpy arrays, python scalars.

PiperOrigin-RevId: 729729215
This commit is contained in:
Yash Katariya 2025-02-21 18:48:26 -08:00 committed by jax authors
parent 34077851d8
commit 7c4fe2a7cc
4 changed files with 45 additions and 12 deletions

View File

@ -63,7 +63,8 @@ from jax._src.lib.mlir import ir
from jax._src.lib.mlir.dialects import hlo
from jax._src.partition_spec import PartitionSpec
from jax._src.sharding import Sharding as JSharding
from jax._src.mesh import AbstractMesh, Mesh
from jax._src.mesh import (AbstractMesh, Mesh, get_abstract_mesh,
get_concrete_mesh)
from jax._src.sharding_impls import (
ArrayMapping, ArrayMappingOrAutoOrUnspecified, AUTO, UnspecifiedValue,
get_array_mapping as _get_array_mapping, array_mapping_to_axis_resources,
@ -2190,6 +2191,20 @@ def _concretize_abstract_out_shardings(shardings, avals, device_assignment,
return tuple(out)
def _get_context_mesh(context_mesh: Mesh | None) -> Mesh | None:
if context_mesh is None:
return context_mesh
# Don't update the mesh because the old `with mesh` ctx mgr is set.
if get_concrete_mesh() is None:
return context_mesh
cur_mesh = get_abstract_mesh()
if cur_mesh.empty or context_mesh.empty:
return context_mesh
if cur_mesh == context_mesh.abstract_mesh:
return context_mesh
return context_mesh.update_axis_types(cur_mesh.axis_types)
@profiler.annotate_function
def lower_sharding_computation(
closed_jaxpr: core.ClosedJaxpr,
@ -2245,6 +2260,8 @@ def lower_sharding_computation(
assert len(out_shardings) == len(out_layouts) == len(global_out_avals), (
len(out_shardings), len(out_layouts), len(global_out_avals))
context_mesh = _get_context_mesh(context_mesh)
devices_from_context = (None if context_mesh is None or context_mesh.empty
else context_mesh._flat_devices_tuple)
# Device assignment across all inputs, outputs and shardings inside jaxpr
@ -2609,10 +2626,10 @@ def maybe_recover_user_shardings(
new_shardings, [None] * len(new_shardings), i, None)
# For nullary cases like: `jit(lambda: ..., out_shardings=(None, sharding))`
for oi in new_shardings:
if oi is not None and type(oi) in _orig_out_sharding_handlers:
for ns in new_shardings:
if ns is not None and type(ns) in _orig_out_sharding_handlers:
return _get_out_sharding_from_orig_sharding(
new_shardings, [None] * len(new_shardings), oi, None)
new_shardings, new_avals, ns, None)
if context_mesh is not None and not context_mesh.empty:
return [sharding_impls._gspmd_to_named_sharding_via_mesh(n, context_mesh)

View File

@ -191,12 +191,11 @@ class _BaseMesh:
def _name_to_type(self):
return dict(safe_zip(self.axis_names, self._axis_types_tuple))
def update_axis_types(self, new_axis_types) -> AbstractMesh:
def _get_new_axis_types(self, new_axis_types):
# dict(self._name_to_type) will copy it.
updated_name_to_type = dict(self._name_to_type)
updated_name_to_type.update(axis_names_to_types(new_axis_types))
new_axis_types = axis_types_to_names(updated_name_to_type)
return AbstractMesh(self.shape_tuple, axis_types=new_axis_types)
return axis_types_to_names(updated_name_to_type)
_mesh_object_dict = {} # type: ignore
@ -427,6 +426,10 @@ class Mesh(_BaseMesh, contextlib.ContextDecorator):
def abstract_mesh(self):
return AbstractMesh(self.shape_tuple, axis_types=self.axis_types)
def update_axis_types(self, new_axis_types) -> Mesh:
new_axis_types = self._get_new_axis_types(new_axis_types)
return Mesh(self.devices, self.axis_names, axis_types=new_axis_types)
EMPTY_ENV = ResourceEnv(Mesh(np.empty((), dtype=object), ()))
@ -505,6 +508,10 @@ class AbstractMesh(_BaseMesh):
def abstract_mesh(self):
return self
def update_axis_types(self, new_axis_types) -> AbstractMesh:
new_axis_types = self._get_new_axis_types(new_axis_types)
return AbstractMesh(self.shape_tuple, axis_types=new_axis_types)
@property
def devices(self):
_raise_value_error("devices")

View File

@ -2742,9 +2742,7 @@ def _mesh_cast_abstract_eval(aval, dst_sharding):
mesh_cast_p.def_abstract_eval(_mesh_cast_abstract_eval)
def _mesh_cast_impl(x, dst_sharding):
x_aval = core.shaped_abstractify(x)
with mesh_lib.set_abstract_mesh(x_aval.sharding.mesh):
return dispatch.apply_primitive(mesh_cast_p, x, dst_sharding=dst_sharding)
return dispatch.apply_primitive(mesh_cast_p, x, dst_sharding=dst_sharding)
mesh_cast_p.def_impl(_mesh_cast_impl)
def _mesh_cast_transpose_rule(ct, x, dst_sharding):
@ -2851,7 +2849,7 @@ def auto_axes(fun, *, axes: str | tuple[str, ...] | None = None,
error_on_manual_to_auto_explict=True)
with mesh_lib.set_abstract_mesh(new_mesh):
in_specs = tree_map(lambda a: core.modify_spec_for_auto_manual(
a.aval.sharding.spec, new_mesh), args)
core.get_aval(a).sharding.spec, new_mesh), args)
args = mesh_cast(args, in_specs)
out = fun(*args, **kwargs)
return mesh_cast(out, out_shardings)
@ -2873,7 +2871,7 @@ def explicit_axes(fun, *, axes: str | tuple[str, ...] | None = None,
args = mesh_cast(args, in_shardings)
out = fun(*args, **kwargs)
out_specs = tree_map(lambda o: core.modify_spec_for_auto_manual(
o.aval.sharding.spec, mesh_lib.get_abstract_mesh()), out)
core.get_aval(o).sharding.spec, mesh_lib.get_abstract_mesh()), out)
return mesh_cast(out, out_specs)
return decorator

View File

@ -6753,6 +6753,17 @@ class ShardingInTypesTest(jtu.JaxTestCase):
jax.jit(shard_map(g, mesh=mesh, in_specs=P('x', 'y'), out_specs=P('x', 'y'))
)(np.arange(16).reshape(8, 2))
@jtu.with_user_mesh((2,), ('x',))
def test_auto_axes_numpy_array(self, mesh):
@jax.jit
def f(x):
self.assertTrue(x.aval.sharding.mesh._are_all_axes_auto)
return x * 2
out = auto_axes(f, out_shardings=P('x'))(np.arange(8))
self.assertEqual(out.sharding, NamedSharding(mesh, P('x')))
self.assertArraysEqual(out, np.arange(8) * 2)
@jtu.pytest_mark_if_available('multiaccelerator')
class PJitErrorTest(jtu.JaxTestCase):