[sharding_in_types] Enforce AxisTypes to always exist if set_mesh is used.

Also support `Auto` mode, either for all axes or mixed in with `User` mode. This works by overriding the sharding of `Auto` axes in the PartitionSpec with `Unconstrained` in the `ShapedArray` constructor. The `ShapedArray` constructor is the central place where we can make such substitutions.

During lowering of shardings with auto axes, we mark the auto dims as `unspecified_dims`. We don't mark all dims as unspecified because that would enable XLA to shard them even further, which is not what we want if some of the dims are user sharded.

PiperOrigin-RevId: 704911253
This commit is contained in:
Yash Katariya 2024-12-10 18:02:42 -08:00 committed by jax authors
parent e88b578356
commit b5e4fd161d
10 changed files with 157 additions and 56 deletions

View File

@ -455,6 +455,7 @@ pytype_strict_library(
":dtypes",
":effects",
":mesh",
":partition_spec",
":pretty_printer",
":source_info_util",
":traceback_util",
@ -558,6 +559,7 @@ pytype_strict_library(
":layout",
":op_shardings",
":partial_eval",
":partition_spec",
":path",
":pickle_util",
":sharding",

View File

@ -39,6 +39,7 @@ from jax._src import config
from jax._src import effects
from jax._src import compute_on
from jax._src import mesh as mesh_lib
from jax._src.partition_spec import PartitionSpec as P, UnconstrainedSingleton
from jax._src.errors import (
ConcretizationTypeError, TracerArrayConversionError, TracerBoolConversionError,
TracerIntegerConversionError, UnexpectedTracerError)
@ -1599,13 +1600,30 @@ def _invalid_shape_error(shape: Shape, context: str=""):
return TypeError(msg)
# TODO(yashkatariya): Only works with User/Auto. Generalize it to work with
# Collective too.
def _maybe_modify_sharding(sharding):
if mesh_lib.AxisTypes.Auto not in sharding.mesh.axis_types:
return sharding
new_spec = []
for s in sharding.spec:
if s is None or isinstance(s, UnconstrainedSingleton):
new_spec.append(s)
else:
temp_s = s[0] if isinstance(s, tuple) else s
new_spec.append(
P.UNCONSTRAINED
if sharding.mesh._name_to_type[temp_s] == mesh_lib.AxisTypes.Auto else s)
return sharding.with_spec(new_spec)
def get_sharding(sharding, ndim):
from jax._src.sharding_impls import NamedSharding, PartitionSpec as P # type: ignore
from jax._src.sharding_impls import NamedSharding # type: ignore
if sharding is not None:
assert len(sharding.spec) == ndim
return sharding
return _maybe_modify_sharding(sharding)
context_mesh = mesh_lib.get_abstract_mesh()
# TODO(yashkatariya): Error out and ask users to set the context mesh in their
@ -1675,9 +1693,7 @@ class ShapedArray(UnshapedArray):
dt_str = dt_str.replace('void', 'float0')
if hasattr(self, 'sharding') and self.sharding is not None:
shapestr = _get_shape_sharding_str(self.shape, self.sharding.spec)
axis_types = self.sharding.mesh.axis_types
axt = _get_axis_type_str(axis_types) if axis_types is not None else ''
return f'{dt_str}[{shapestr}]{axt}'
return f'{dt_str}[{shapestr}]'
else:
shapestr = ','.join(map(str, self.shape))
return f'{dt_str}[{shapestr}]'
@ -1689,26 +1705,13 @@ class ShapedArray(UnshapedArray):
raise TypeError("len() of unsized object") from err # same as numpy error
def _get_axis_type_str(axis_types):
from jax._src.mesh import AxisTypes # type: ignore
out = []
for t, axes in axis_types.items():
a = f"({','.join(a for a in axes)})" if isinstance(axes, tuple) else axes
if t == AxisTypes.Collective:
out.append(f"C:{a}")
elif t == AxisTypes.User:
out.append(f"U:{a}")
else:
assert t == AxisTypes.Auto
out.append(f"A:{a}")
return f"{{{', '.join(out)}}}"
def _get_shape_sharding_str(shape, spec):
out = []
for s1, s2 in zip(shape, spec):
if s2 is None:
out.append(f"{s1}")
elif isinstance(s2, UnconstrainedSingleton):
out.append(f"{s1}")
elif isinstance(s2, tuple):
ss = ','.join(s for s in s2)
out.append(f"{s1}@({ss})")

View File

@ -50,6 +50,7 @@ from jax._src.interpreters import xla
from jax._src.layout import AutoLayout, DeviceLocalLayout
from jax._src.sharding import Sharding as JSharding
from jax._src.sharding_impls import AUTO
from jax._src.partition_spec import UnconstrainedSingleton
from jax._src.lib import xla_client as xc
from jax._src.lib import xla_extension
from jax._src.lib.mlir import dialects, ir, passmanager
@ -2524,12 +2525,19 @@ def lower_sharding_under_shit(ctx, op, aval, sharding_proto=None):
# Don't emit a wsc under full manual mode to avoid increasing HLO size.
if aval.sharding.mesh._are_all_axes_collective:
return op
if aval.sharding.mesh._are_all_axes_auto:
return op
# TODO(yashkatariya): If all the axes in pspec are AUTO or collective,
# `return op` early and avoid bloating HLO size.
proto = (aval.sharding._to_xla_hlo_sharding(aval.ndim).to_proto()
if sharding_proto is None else sharding_proto)
# TODO(yashkatariya): Enable this
# unspecified_dims = (set(range(aval.ndim))
# if aval.sharding.mesh._any_axis_collective else None)
return wrap_with_sharding_op(ctx, op, aval, proto)
unspecified_dims = None
if aval.sharding.mesh._any_axis_collective:
unspecified_dims = set(range(aval.ndim))
elif aval.sharding.mesh._any_axis_auto:
unspecified_dims = {i for i, s in enumerate(aval.sharding.spec)
if isinstance(s, UnconstrainedSingleton)}
return wrap_with_sharding_op(ctx, op, aval, proto, unspecified_dims)
def set_sharding(op, sharding: xc.OpSharding | sharding_impls.SdyArraySharding):

View File

@ -63,7 +63,7 @@ from jax._src.layout import DeviceLocalLayout, AutoLayout, Layout
from jax._src.lib import xla_client as xc
from jax._src.lib.mlir import ir
from jax._src.lib.mlir.dialects import hlo
from jax._src.partition_spec import PartitionSpec
from jax._src.partition_spec import PartitionSpec, UnconstrainedSingleton
from jax._src.sharding import Sharding as JSharding
from jax._src.sharding_impls import (
ArrayMapping, ArrayMappingOrAutoOrUnspecified, AUTO, UNSPECIFIED,
@ -2123,11 +2123,13 @@ def _concretize_abstract_shardings(shardings, avals, device_assignment):
@lru_cache(maxsize=128)
def _abstract_to_concrete_mesh(abstract_mesh):
return mesh_lib.Mesh(
np_dev.reshape(abstract_mesh.axis_sizes), abstract_mesh.axis_names)
np_dev.reshape(abstract_mesh.axis_sizes), abstract_mesh.axis_names,
axis_types=abstract_mesh.axis_types)
out = []
for s, a in zip(shardings, avals):
if isinstance(s, UnspecifiedValue) and a.sharding is not None:
if (isinstance(s, UnspecifiedValue) and a.sharding is not None and
all(not isinstance(s, UnconstrainedSingleton) for s in a.sharding.spec)):
out.append(NamedSharding(_abstract_to_concrete_mesh(a.sharding.mesh),
a.sharding.spec))
else:

View File

@ -124,6 +124,7 @@ def axis_names_to_types(axis_types) -> dict[str, AxisTypes]:
_mesh_object_dict = {} # type: ignore
MeshAxisType = dict[AxisTypes, str | tuple[str, ...]]
class Mesh(contextlib.ContextDecorator):
"""Declare the hardware resources available in the scope of this manager.
@ -178,11 +179,11 @@ class Mesh(contextlib.ContextDecorator):
devices: np.ndarray
axis_names: tuple[MeshAxisName, ...]
axis_types: dict[AxisTypes, str | tuple[str, ...]] | None
axis_types: MeshAxisType | None
def __new__(cls, devices: np.ndarray | Sequence[xc.Device],
axis_names: str | Sequence[MeshAxisName],
axis_types: dict[AxisTypes, str | tuple[str, ...]] | None = None):
axis_names: str | Sequence[MeshAxisName], *,
axis_types: MeshAxisType | None = None):
if not isinstance(devices, np.ndarray):
devices = np.array(devices)
if isinstance(axis_names, str):
@ -216,7 +217,8 @@ class Mesh(contextlib.ContextDecorator):
return self
def __reduce__(self):
return (type(self), (self.devices, self.axis_names, self.axis_types))
return (type(self), (self.devices, self.axis_names),
{'axis_types': self.axis_types})
def __eq__(self, other):
if not isinstance(other, Mesh):
@ -348,7 +350,7 @@ class Mesh(contextlib.ContextDecorator):
@functools.cached_property
def abstract_mesh(self):
return AbstractMesh(self.shape_tuple, self.axis_types)
return AbstractMesh(self.shape_tuple, axis_types=self.axis_types)
EMPTY_ENV = ResourceEnv(Mesh(np.empty((), dtype=object), ()))
@ -373,8 +375,8 @@ class AbstractMesh:
details.
"""
def __init__(self, shape_tuple: tuple[tuple[str, int], ...],
axis_types: dict[AxisTypes, str | tuple[str, ...]] | None = None):
def __init__(self, shape_tuple: tuple[tuple[str, int], ...], *,
axis_types: MeshAxisType | None = None):
self.shape_tuple = shape_tuple
self.axis_types = axis_types
if self.shape_tuple:
@ -434,6 +436,24 @@ class AbstractMesh:
return False
return all(t == AxisTypes.Collective for t in self.axis_types.keys())
@functools.cached_property
def _are_all_axes_auto(self) -> bool:
if self.axis_types is None:
return False
return all(t == AxisTypes.Auto for t in self.axis_types.keys())
@functools.cached_property
def _any_axis_collective(self) -> bool:
if self.axis_types is None:
return False
return any(t == AxisTypes.Collective for t in self.axis_types.keys())
@functools.cached_property
def _any_axis_auto(self) -> bool:
if self.axis_types is None:
return False
return any(t == AxisTypes.Auto for t in self.axis_types.keys())
@property
def devices(self):
_raise_value_error("devices")
@ -474,6 +494,8 @@ def _raise_value_error(name):
@contextlib.contextmanager
def set_abstract_mesh(mesh: AbstractMesh):
if mesh is not None and mesh.axis_types is None:
raise RuntimeError('Please set the AxisTypes of Mesh.')
prev_val = jax_config.abstract_mesh_context_manager.swap_local(mesh)
try:
yield

View File

@ -14,7 +14,7 @@
from __future__ import annotations
class _UnconstrainedPartitionSingleton:
class UnconstrainedSingleton:
def __repr__(self):
return "UNCONSTRAINED"
@ -23,7 +23,7 @@ class _UnconstrainedPartitionSingleton:
# Unconstrained sentinel value for PartitionSpec, representing a dimension for
# which the user wants XLA to assign the best partitioning.
# TODO(yashkatariya): May rename to AUTO.
_UNCONSTRAINED_PARTITION = _UnconstrainedPartitionSingleton()
_UNCONSTRAINED_PARTITION = UnconstrainedSingleton()
class PartitionSpec(tuple):

View File

@ -67,6 +67,19 @@ def _check_mesh_resource_axis(mesh, parsed_pspec, _manual_axes):
f"is also found in manual_axes: {_manual_axes}.") from None
@util.cache(max_size=128, trace_context_in_key=False)
def _check_axis_type_consistency(mesh, parsed_pspec):
if mesh.axis_types is None:
return
for p in parsed_pspec:
if p is not None:
if not all(mesh._name_to_type[p[0]] == mesh._name_to_type[r] for r in p):
raise ValueError(
'AxisTypes should be the same in a tuple subset of PartitionSpec:'
f' {parsed_pspec.get_partition_spec()}. Got subset {p} with axis'
f' types: ({", ".join(str(mesh._name_to_type[r]) for r in p)})')
def hashed_index(x) -> int:
# This works for both `pjit` indices and `pmap` indices (which might
# have an integer instead of a slice).
@ -1084,6 +1097,7 @@ def preprocess(mesh, spec, parsed_pspec, _manual_axes=frozenset()):
PartitionSpec() if spec is None else spec,
"NamedSharding spec", allow_unconstrained_dims=True)
_check_mesh_resource_axis(mesh, parsed_pspec, _manual_axes)
_check_axis_type_consistency(mesh, parsed_pspec)
return parsed_pspec
@ -1673,7 +1687,8 @@ def _gspmd_to_named_sharding_via_mesh(
def make_mesh(axis_shapes: Sequence[int], axis_names: Sequence[str],
*, devices: Sequence[xc.Device] | None = None) -> mesh_lib.Mesh:
*, devices: Sequence[xc.Device] | None = None,
axis_types: mesh_lib.MeshAxisType | None = None) -> mesh_lib.Mesh:
"""Creates an efficient mesh with the shape and axis names specified.
This function attempts to automatically compute a good mapping from a set of
@ -1735,4 +1750,4 @@ def make_mesh(axis_shapes: Sequence[int], axis_names: Sequence[str],
mesh_devices = mesh_utils.create_device_mesh(
new_axis_shapes, devices,
allow_split_physical_axes=allow_split_physical_axes)
return mesh_lib.Mesh(mesh_devices, axis_names)
return mesh_lib.Mesh(mesh_devices, axis_names, axis_types=axis_types)

View File

@ -1443,26 +1443,28 @@ def with_and_without_mesh(f):
('Mesh', (('x', 2),), (('i', 'x'),))
))(with_mesh_from_kwargs(f))
def with_user_mesh(sizes, names):
def with_user_mesh(sizes, names, axis_types=None):
axis_types = ({mesh_lib.AxisTypes.User: names}
if axis_types is None else axis_types)
def decorator(fn):
def mesh_fn(*args, **kwargs):
mesh = create_mesh(sizes, names)
mesh = create_mesh(sizes, names, axis_types=axis_types)
with mesh_lib.set_mesh(mesh):
return fn(*args, **kwargs, mesh=mesh)
return mesh_fn
return decorator
def create_mesh(mesh_shape, axis_names, iota_order=False):
def create_mesh(mesh_shape, axis_names, iota_order=False, axis_types=None):
size = math.prod(mesh_shape)
if len(jax.devices()) < size:
raise unittest.SkipTest(f"Test requires {size} global devices.")
if iota_order:
devices = sorted(jax.devices(), key=lambda d: d.id)
mesh_devices = np.array(devices[:size]).reshape(mesh_shape)
return jax.sharding.Mesh(mesh_devices, axis_names)
return jax.sharding.Mesh(mesh_devices, axis_names, axis_types=axis_types)
else:
return jax.make_mesh(mesh_shape, axis_names)
return jax.make_mesh(mesh_shape, axis_names, axis_types=axis_types)
class _cached_property:
null = object()

View File

@ -46,7 +46,8 @@ from jax._src import source_info_util
from jax._src import traceback_util
from jax._src import util
from jax._src.core import Tracer
from jax._src.mesh import AbstractMesh, Mesh, AxisTypes, set_abstract_mesh
from jax._src.mesh import (AbstractMesh, Mesh, AxisTypes, set_abstract_mesh,
get_abstract_mesh)
from jax._src.api import _shared_code_pmap, _prepare_pmap
from jax._src.lax import (lax, parallel as lax_parallel, slicing,
windowed_reductions, convolution, fft, linalg,
@ -536,7 +537,7 @@ def _shard_shaped_array(mesh: Mesh, names: AxisNames, aval: core.AbstractValue
for i, sz in enumerate(aval.shape))
if config.sharding_in_types.value:
new_mesh = AbstractMesh(
mesh.shape_tuple, {AxisTypes.Collective: mesh.axis_names})
mesh.shape_tuple, axis_types={AxisTypes.Collective: mesh.axis_names})
new_sharding = NamedSharding(new_mesh, P(*[None] * aval.ndim))
else:
new_sharding = None
@ -548,11 +549,9 @@ def _unshard_shaped_array(mesh: Mesh, names: AxisNames,
assert isinstance(aval, core.ShapedArray)
new_shape = tuple(sz * prod(mesh.shape[n] for n in names.get(i, ()))
for i, sz in enumerate(aval.shape))
# TODO(yashkatariya): Reset the mesh properly based on the input avals if the
# mesh of shard_map specifies collective axes.
if config.sharding_in_types.value:
spec = _names_to_pspec(names)._normalized_spec(aval.ndim)
new_sharding = NamedSharding(AbstractMesh(mesh.shape_tuple), spec)
new_sharding = NamedSharding(get_abstract_mesh(), spec)
else:
new_sharding = None
return aval.update(shape=new_shape, sharding=new_sharding)

View File

@ -4837,6 +4837,7 @@ class ShardingInTypesTest(jtu.JaxTestCase):
@jtu.with_user_mesh((2, 2), ('model', 'data'))
def test_aval_repr(self, mesh):
mesh = mesh.abstract_mesh
aval = core.ShapedArray((128, 64), np.float32,
sharding=NamedSharding(mesh, P('model', 'data')))
self.assertEqual(aval.str_short(), 'float32[128@model,64@data]')
@ -4977,21 +4978,21 @@ class ShardingInTypesTest(jtu.JaxTestCase):
lowered_text = f.lower(arr).as_text()
self.assertIn('@Sharding', lowered_text)
def test_broadcasting_nary_error(self):
mesh1 = Mesh([jax.devices()[0]], 'x')
mesh2 = Mesh([jax.devices()[0]], 'y')
@jtu.with_user_mesh((1,), 'x')
def test_broadcasting_nary_error(self, mesh):
mesh2 = Mesh([jax.devices()[0]], 'y',
axis_types={mesh_lib.AxisTypes.User: 'y'})
arr1 = jax.device_put(np.arange(8), NamedSharding(mesh1, P()))
arr1 = jax.device_put(np.arange(8), NamedSharding(mesh, P()))
arr2 = jax.device_put(np.arange(8), NamedSharding(mesh2, P()))
@jax.jit
def f(x, y):
return x + y
with config.sharding_in_types(True):
with self.assertRaisesRegex(
ValueError, "Mesh for all inputs should be equal"):
f(arr1, arr2)
with self.assertRaisesRegex(
ValueError, "Mesh for all inputs should be equal"):
f(arr1, arr2)
@jtu.with_user_mesh((2, 2), ('x', 'y'))
def test_sin_unop(self, mesh):
@ -5482,6 +5483,53 @@ class ShardingInTypesTest(jtu.JaxTestCase):
self.assertIn('@Sharding', f.lower(arr).as_text())
def test_auto_user(self):
mesh = jtu.create_mesh((2, 2), ('x', 'y'),
axis_types={mesh_lib.AxisTypes.Auto: ('x', 'y')})
np_inp = np.arange(16.).reshape(8, 2)
s = NamedSharding(mesh, P('x', 'y'))
arr = jax.device_put(np_inp, s)
@jax.jit
def f(x, x2):
y = x * 2
z = jnp.sin(y)
a = z @ x2
return a
with mesh_lib.set_mesh(mesh):
out = f(arr, arr.T)
self.assertEqual(out.sharding, NamedSharding(mesh, P('x',)))
lowered_text = f.lower(arr, arr.T).as_text()
self.assertNotIn('unspecified_dims', lowered_text)
mesh2 = jtu.create_mesh((2, 2), ('x', 'y'),
axis_types={mesh_lib.AxisTypes.User: 'x',
mesh_lib.AxisTypes.Auto: 'y'})
with mesh_lib.set_mesh(mesh2):
arr = jax.device_put(arr, NamedSharding(mesh2, P('x', 'y')))
arr2 = jax.device_put(np_inp.T, NamedSharding(mesh2, P('y', None)))
out = f(arr, arr2)
self.assertEqual(out.sharding, NamedSharding(mesh2, P('x', None)))
lowered_text = f.lower(arr, arr2).as_text()
self.assertTrue(lowered_text.count("unspecified_dims") == 3)
mesh3 = jtu.create_mesh((2, 2), ('x', 'y'),
axis_types={mesh_lib.AxisTypes.User: 'y',
mesh_lib.AxisTypes.Auto: 'x'})
with mesh_lib.set_mesh(mesh3):
arr = jax.device_put(arr, NamedSharding(mesh3, P('x', 'y')))
arr2 = jax.device_put(np_inp.T, NamedSharding(mesh3, P('y', 'x')))
out = f(arr, arr2)
self.assertEqual(out.sharding, NamedSharding(mesh3, P('x',)))
lowered_text = f.lower(arr, arr2).as_text()
self.assertTrue(lowered_text.count("unspecified_dims") == 4)
with self.assertRaisesRegex(
ValueError,
"AxisTypes should be the same in a tuple subset of PartitionSpec"):
NamedSharding(mesh2, P(('x', 'y')))
@jtu.pytest_mark_if_available('multiaccelerator')
class PJitErrorTest(jtu.JaxTestCase):