mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
[sharding_in_types] Allow auto_axes
and explicit_axes
to take numpy arrays, python scalars.
PiperOrigin-RevId: 729729215
This commit is contained in:
parent
34077851d8
commit
7c4fe2a7cc
@ -63,7 +63,8 @@ from jax._src.lib.mlir import ir
|
||||
from jax._src.lib.mlir.dialects import hlo
|
||||
from jax._src.partition_spec import PartitionSpec
|
||||
from jax._src.sharding import Sharding as JSharding
|
||||
from jax._src.mesh import AbstractMesh, Mesh
|
||||
from jax._src.mesh import (AbstractMesh, Mesh, get_abstract_mesh,
|
||||
get_concrete_mesh)
|
||||
from jax._src.sharding_impls import (
|
||||
ArrayMapping, ArrayMappingOrAutoOrUnspecified, AUTO, UnspecifiedValue,
|
||||
get_array_mapping as _get_array_mapping, array_mapping_to_axis_resources,
|
||||
@ -2190,6 +2191,20 @@ def _concretize_abstract_out_shardings(shardings, avals, device_assignment,
|
||||
return tuple(out)
|
||||
|
||||
|
||||
def _get_context_mesh(context_mesh: Mesh | None) -> Mesh | None:
|
||||
if context_mesh is None:
|
||||
return context_mesh
|
||||
# Don't update the mesh because the old `with mesh` ctx mgr is set.
|
||||
if get_concrete_mesh() is None:
|
||||
return context_mesh
|
||||
cur_mesh = get_abstract_mesh()
|
||||
if cur_mesh.empty or context_mesh.empty:
|
||||
return context_mesh
|
||||
if cur_mesh == context_mesh.abstract_mesh:
|
||||
return context_mesh
|
||||
return context_mesh.update_axis_types(cur_mesh.axis_types)
|
||||
|
||||
|
||||
@profiler.annotate_function
|
||||
def lower_sharding_computation(
|
||||
closed_jaxpr: core.ClosedJaxpr,
|
||||
@ -2245,6 +2260,8 @@ def lower_sharding_computation(
|
||||
assert len(out_shardings) == len(out_layouts) == len(global_out_avals), (
|
||||
len(out_shardings), len(out_layouts), len(global_out_avals))
|
||||
|
||||
context_mesh = _get_context_mesh(context_mesh)
|
||||
|
||||
devices_from_context = (None if context_mesh is None or context_mesh.empty
|
||||
else context_mesh._flat_devices_tuple)
|
||||
# Device assignment across all inputs, outputs and shardings inside jaxpr
|
||||
@ -2609,10 +2626,10 @@ def maybe_recover_user_shardings(
|
||||
new_shardings, [None] * len(new_shardings), i, None)
|
||||
|
||||
# For nullary cases like: `jit(lambda: ..., out_shardings=(None, sharding))`
|
||||
for oi in new_shardings:
|
||||
if oi is not None and type(oi) in _orig_out_sharding_handlers:
|
||||
for ns in new_shardings:
|
||||
if ns is not None and type(ns) in _orig_out_sharding_handlers:
|
||||
return _get_out_sharding_from_orig_sharding(
|
||||
new_shardings, [None] * len(new_shardings), oi, None)
|
||||
new_shardings, new_avals, ns, None)
|
||||
|
||||
if context_mesh is not None and not context_mesh.empty:
|
||||
return [sharding_impls._gspmd_to_named_sharding_via_mesh(n, context_mesh)
|
||||
|
@ -191,12 +191,11 @@ class _BaseMesh:
|
||||
def _name_to_type(self):
|
||||
return dict(safe_zip(self.axis_names, self._axis_types_tuple))
|
||||
|
||||
def update_axis_types(self, new_axis_types) -> AbstractMesh:
|
||||
def _get_new_axis_types(self, new_axis_types):
|
||||
# dict(self._name_to_type) will copy it.
|
||||
updated_name_to_type = dict(self._name_to_type)
|
||||
updated_name_to_type.update(axis_names_to_types(new_axis_types))
|
||||
new_axis_types = axis_types_to_names(updated_name_to_type)
|
||||
return AbstractMesh(self.shape_tuple, axis_types=new_axis_types)
|
||||
return axis_types_to_names(updated_name_to_type)
|
||||
|
||||
|
||||
_mesh_object_dict = {} # type: ignore
|
||||
@ -427,6 +426,10 @@ class Mesh(_BaseMesh, contextlib.ContextDecorator):
|
||||
def abstract_mesh(self):
|
||||
return AbstractMesh(self.shape_tuple, axis_types=self.axis_types)
|
||||
|
||||
def update_axis_types(self, new_axis_types) -> Mesh:
|
||||
new_axis_types = self._get_new_axis_types(new_axis_types)
|
||||
return Mesh(self.devices, self.axis_names, axis_types=new_axis_types)
|
||||
|
||||
|
||||
EMPTY_ENV = ResourceEnv(Mesh(np.empty((), dtype=object), ()))
|
||||
|
||||
@ -505,6 +508,10 @@ class AbstractMesh(_BaseMesh):
|
||||
def abstract_mesh(self):
|
||||
return self
|
||||
|
||||
def update_axis_types(self, new_axis_types) -> AbstractMesh:
|
||||
new_axis_types = self._get_new_axis_types(new_axis_types)
|
||||
return AbstractMesh(self.shape_tuple, axis_types=new_axis_types)
|
||||
|
||||
@property
|
||||
def devices(self):
|
||||
_raise_value_error("devices")
|
||||
|
@ -2742,9 +2742,7 @@ def _mesh_cast_abstract_eval(aval, dst_sharding):
|
||||
mesh_cast_p.def_abstract_eval(_mesh_cast_abstract_eval)
|
||||
|
||||
def _mesh_cast_impl(x, dst_sharding):
|
||||
x_aval = core.shaped_abstractify(x)
|
||||
with mesh_lib.set_abstract_mesh(x_aval.sharding.mesh):
|
||||
return dispatch.apply_primitive(mesh_cast_p, x, dst_sharding=dst_sharding)
|
||||
return dispatch.apply_primitive(mesh_cast_p, x, dst_sharding=dst_sharding)
|
||||
mesh_cast_p.def_impl(_mesh_cast_impl)
|
||||
|
||||
def _mesh_cast_transpose_rule(ct, x, dst_sharding):
|
||||
@ -2851,7 +2849,7 @@ def auto_axes(fun, *, axes: str | tuple[str, ...] | None = None,
|
||||
error_on_manual_to_auto_explict=True)
|
||||
with mesh_lib.set_abstract_mesh(new_mesh):
|
||||
in_specs = tree_map(lambda a: core.modify_spec_for_auto_manual(
|
||||
a.aval.sharding.spec, new_mesh), args)
|
||||
core.get_aval(a).sharding.spec, new_mesh), args)
|
||||
args = mesh_cast(args, in_specs)
|
||||
out = fun(*args, **kwargs)
|
||||
return mesh_cast(out, out_shardings)
|
||||
@ -2873,7 +2871,7 @@ def explicit_axes(fun, *, axes: str | tuple[str, ...] | None = None,
|
||||
args = mesh_cast(args, in_shardings)
|
||||
out = fun(*args, **kwargs)
|
||||
out_specs = tree_map(lambda o: core.modify_spec_for_auto_manual(
|
||||
o.aval.sharding.spec, mesh_lib.get_abstract_mesh()), out)
|
||||
core.get_aval(o).sharding.spec, mesh_lib.get_abstract_mesh()), out)
|
||||
return mesh_cast(out, out_specs)
|
||||
return decorator
|
||||
|
||||
|
@ -6753,6 +6753,17 @@ class ShardingInTypesTest(jtu.JaxTestCase):
|
||||
jax.jit(shard_map(g, mesh=mesh, in_specs=P('x', 'y'), out_specs=P('x', 'y'))
|
||||
)(np.arange(16).reshape(8, 2))
|
||||
|
||||
@jtu.with_user_mesh((2,), ('x',))
|
||||
def test_auto_axes_numpy_array(self, mesh):
|
||||
@jax.jit
|
||||
def f(x):
|
||||
self.assertTrue(x.aval.sharding.mesh._are_all_axes_auto)
|
||||
return x * 2
|
||||
|
||||
out = auto_axes(f, out_shardings=P('x'))(np.arange(8))
|
||||
self.assertEqual(out.sharding, NamedSharding(mesh, P('x')))
|
||||
self.assertArraysEqual(out, np.arange(8) * 2)
|
||||
|
||||
|
||||
@jtu.pytest_mark_if_available('multiaccelerator')
|
||||
class PJitErrorTest(jtu.JaxTestCase):
|
||||
|
Loading…
x
Reference in New Issue
Block a user