diff --git a/jax/BUILD b/jax/BUILD index 2f7480e31..41fefd860 100644 --- a/jax/BUILD +++ b/jax/BUILD @@ -800,10 +800,11 @@ pytype_strict_library( name = "layout", srcs = ["_src/layout.py"], deps = [ + ":dtypes", ":sharding", ":sharding_impls", "//jax/_src/lib", - ], + ] + py_deps("numpy"), ) pytype_strict_library( diff --git a/jax/_src/array.py b/jax/_src/array.py index 6e3f0a76f..4dcb57b91 100644 --- a/jax/_src/array.py +++ b/jax/_src/array.py @@ -506,7 +506,8 @@ class ArrayImpl(basearray.Array): if self.is_deleted(): return Layout(None, self.sharding) try: - return Layout(DeviceLocalLayout(self._pjrt_layout), self.sharding) + return Layout(DeviceLocalLayout.from_pjrt_layout(self._pjrt_layout), + self.sharding) except xe.XlaRuntimeError as e: msg, *_ = e.args if type(msg) is str and msg.startswith("UNIMPLEMENTED"): diff --git a/jax/_src/interpreters/mlir.py b/jax/_src/interpreters/mlir.py index eccb29577..a729d7fe3 100644 --- a/jax/_src/interpreters/mlir.py +++ b/jax/_src/interpreters/mlir.py @@ -824,12 +824,15 @@ def _to_physical_op_sharding( return sharding._to_xla_hlo_sharding(aval.ndim).to_proto() # type: ignore -def _to_xla_layout(layout: DeviceLocalLayout | None | AutoLayout) -> str | None: +def _to_xla_layout(layout: DeviceLocalLayout | None | AutoLayout, + aval: core.AbstractValue) -> str | None: if layout is None: return "default" if isinstance(layout, AutoLayout): return "auto" - return layout._to_xla_layout() + if aval is core.abstract_token: + return "default" + return layout._to_xla_layout(aval.dtype) # type: ignore def _get_mem_kind(s: JSharding | None) -> str | None: @@ -1194,11 +1197,9 @@ def lower_jaxpr_to_fun( ir_arg_shardings = None if arg_shardings is not None: - in_avals = [None] * (num_dim_vars + num_tokens) + list(jaxpr.in_avals) ir_arg_shardings = util.flatten( [[_to_physical_op_sharding(a, s)] * len(types) - for a, s, types in zip(in_avals, arg_shardings, input_types)]) - del in_avals + for a, s, types in zip(input_avals, arg_shardings, input_types)]) ir_arg_memory_kinds = None if arg_memory_kinds is not None: @@ -1208,8 +1209,8 @@ def lower_jaxpr_to_fun( ir_arg_layouts = None if arg_layouts is not None: ir_arg_layouts = util.flatten( - [[_to_xla_layout(l)] * len(types) - for l, types in zip(arg_layouts, input_types)]) + [[_to_xla_layout(l, a)] * len(types) + for l, a, types in zip(arg_layouts, input_avals, input_types)]) ir_donated_args = None if xla_donated_args is not None: @@ -1244,8 +1245,8 @@ def lower_jaxpr_to_fun( ir_result_layouts = None if result_layouts is not None: ir_result_layouts = util.flatten( - [[_to_xla_layout(l)] * len(types) - for l, types in zip(result_layouts, output_types)]) + [[_to_xla_layout(l, a)] * len(types) + for l, a, types in zip(result_layouts, output_avals, output_types)]) if ( replicated_args is not None @@ -2171,6 +2172,29 @@ def get_sharding_attr(sharding_proto: xc.OpSharding): return ir.StringAttr.get(repr(xc.HloSharding.from_proto(sharding_proto))) +def wrap_with_layout_op(ctx: LoweringRuleContext, + x: ir.Value, + aval_out: core.AbstractValue, + layout: DeviceLocalLayout, + aval_in: core.AbstractValue): + result_type = aval_to_ir_type(aval_out) + out_shape = core.physical_aval(aval_out).shape # type: ignore + if core.is_constant_shape(out_shape): + result_shapes = None + else: + result_shapes = [eval_dynamic_shape_as_tensor(ctx, out_shape)] + + op = custom_call('LayoutConstraint', result_types=[result_type], operands=[x], + api_version=1, + result_shapes=result_shapes, + # Set operand layouts to anything. XLA will ignore it. + operand_layouts=[list(range(aval_in.ndim))], # type: ignore + # TODO(yashkatariya): Figure out how to pass tiling to the + # custom call. + result_layouts=[layout.major_to_minor[::-1]]) + return op.result + + # MLIR lowerings for lax primitives def cache_lowering(f): diff --git a/jax/_src/interpreters/pxla.py b/jax/_src/interpreters/pxla.py index 69d7c619b..378aa2c5f 100644 --- a/jax/_src/interpreters/pxla.py +++ b/jax/_src/interpreters/pxla.py @@ -1937,7 +1937,7 @@ def _maybe_get_default_layout(arg_layout, jit_in_layout, sharding, aval # first call you pass it a sharded array with layout and on second call you # pass a numpy array. The layouts should be the same to get cache hits. try: - al = DeviceLocalLayout( + al = DeviceLocalLayout.from_pjrt_layout( d.client.get_default_layout(aval.dtype, shard_shape, d)) except: return None @@ -2704,7 +2704,7 @@ def _get_layouts_from_executable( new_in_layouts = [] for x, i in safe_zip(in_layouts_xla, in_layouts): - x = DeviceLocalLayout(x) + x = DeviceLocalLayout.from_pjrt_layout(x) if isinstance(i, DeviceLocalLayout): if i != x: raise AssertionError( @@ -2716,7 +2716,7 @@ def _get_layouts_from_executable( new_out_layouts = [] for x, o in safe_zip(out_layouts_xla, out_layouts): - x = DeviceLocalLayout(x) + x = DeviceLocalLayout.from_pjrt_layout(x) if isinstance(o, DeviceLocalLayout): if o != x: raise AssertionError( diff --git a/jax/_src/layout.py b/jax/_src/layout.py index 2071794a0..3b4424345 100644 --- a/jax/_src/layout.py +++ b/jax/_src/layout.py @@ -16,9 +16,12 @@ from __future__ import annotations from typing import Union +import numpy as np +from jax._src.dtypes import iinfo, issubdtype from jax._src.sharding import Sharding from jax._src.sharding_impls import AUTO as AutoSharding, is_auto from jax._src.lib import xla_client as xc +from jax._src.lib import xla_extension_version class AutoLayout: @@ -27,31 +30,78 @@ class AutoLayout: return "AUTO" -class DeviceLocalLayout: - layout: xc.PjRtLayout +if xla_extension_version >= 274: + class DeviceLocalLayout: + major_to_minor: tuple[int, ...] + tiling: tuple[tuple[int, ...], ...] | None - AUTO = AutoLayout() + AUTO = AutoLayout() - def __init__(self, layout: xc.PjRtLayout): - self._layout = layout - self._layout_str = str(self._layout) + def __init__(self, major_to_minor: tuple[int, ...], + tiling: tuple[tuple[int, ...], ...] | None = None): + self.major_to_minor = tuple(major_to_minor) + self.tiling = None if tiling is None else tuple(map(tuple, tiling)) - def __repr__(self): - return f'DeviceLocalLayout({self._layout_str})' + @staticmethod + def from_pjrt_layout(pjrt_layout: xc.PjRtLayout): + xla_layout = pjrt_layout._xla_layout() + return DeviceLocalLayout(xla_layout.minor_to_major()[::-1], # pytype: disable=wrong-arg-types + xla_layout.tiling()) - def __hash__(self): - return hash(self._layout) + def __repr__(self): + return (f'DeviceLocalLayout(major_to_minor={self.major_to_minor},' + f' tiling={self.tiling})') - def __eq__(self, other): - if not isinstance(other, DeviceLocalLayout): - return False - return self._layout == other._layout + def __hash__(self): + return hash((self.major_to_minor, self.tiling)) - def _to_xla_layout(self) -> str: - return self._layout_str + def __eq__(self, other): + if not isinstance(other, DeviceLocalLayout): + return False + return (self.major_to_minor == other.major_to_minor and + self.tiling == other.tiling) + + def _to_xla_layout(self, dtype) -> str: + if self.tiling is None: + xla_layout = xc.Layout(self.major_to_minor[::-1]) + else: + if issubdtype(dtype, np.integer): + sub_byte_size = iinfo(dtype).bits if iinfo(dtype).bits < 8 else 0 + else: + sub_byte_size = 0 + xla_layout = xc.Layout(self.major_to_minor[::-1], self.tiling, # type: ignore + sub_byte_size) + return str(xla_layout) +else: + class DeviceLocalLayout: # type: ignore + layout: xc.PjRtLayout + + AUTO = AutoLayout() + + def __init__(self, layout: xc.PjRtLayout): + self._layout = layout + self._layout_str = str(self._layout) + + @staticmethod + def from_pjrt_layout(pjrt_layout: xc.PjRtLayout): + return DeviceLocalLayout(pjrt_layout) # type: ignore + + def __repr__(self): + return f'DeviceLocalLayout({self._layout_str})' + + def __hash__(self): + return hash(self._layout) + + def __eq__(self, other): + if not isinstance(other, DeviceLocalLayout): + return False + return self._layout == other._layout + + def _to_xla_layout(self, dtype) -> str: + return self._layout_str -LayoutOptions = Union[DeviceLocalLayout, None, AutoLayout] +LayoutOptions = Union[DeviceLocalLayout, None, AutoLayout] # pytype: disable=invalid-annotation ShardingOptions = Union[Sharding, None, AutoSharding] diff --git a/jax/_src/maps.py b/jax/_src/maps.py index 20fc54d8f..4b574775d 100644 --- a/jax/_src/maps.py +++ b/jax/_src/maps.py @@ -1816,6 +1816,7 @@ def _fix_inferred_spmd_sharding(jaxpr, resource_env, gen_fresh_name = None): [tmpvar], [outvar], sharding_constraint_p, dict(resource_env=resource_env, sharding=gspmd_sharding, + layout=None, unconstrained_dims=unconstrained_dims), set(), eqn.source_info, eqn.ctx)) diff --git a/jax/_src/pjit.py b/jax/_src/pjit.py index 204c288d6..454611424 100644 --- a/jax/_src/pjit.py +++ b/jax/_src/pjit.py @@ -2488,6 +2488,9 @@ def with_sharding_constraint(x, shardings): .. _Distributed arrays and automatic parallelization: https://jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html """ x_flat, tree = tree_flatten(x) + + layouts, shardings = _split_layout_and_sharding(shardings) + user_shardings = prepare_axis_resources( shardings, "shardings", allow_unconstrained_dims=True) del shardings @@ -2496,6 +2499,10 @@ def with_sharding_constraint(x, shardings): flatten_axes("with_sharding_constraint shardings", tree, user_shardings)) del user_shardings + user_layouts_flat = tuple( + flatten_axes("with_sharding_constraint layouts", tree, layouts)) + del layouts + resource_env = mesh_lib.thread_resources.env mesh = resource_env.physical_mesh @@ -2511,19 +2518,27 @@ def with_sharding_constraint(x, shardings): shardings_flat, x_flat, None, "with_sharding_constraint arguments", allow_uneven_sharding=True) - outs = [sharding_constraint_p.bind(xf, sharding=s, + outs = [sharding_constraint_p.bind(xf, sharding=s, layout=l, resource_env=resource_env, unconstrained_dims=ud) - for xf, s, ud in zip(x_flat, shardings_flat, unconstrained_dims)] + for xf, s, l, ud in zip(x_flat, shardings_flat, user_layouts_flat, + unconstrained_dims)] return tree_unflatten(tree, outs) def _identity_fn(x): return x -def _sharding_constraint_impl(x, sharding, resource_env, unconstrained_dims): - if hasattr(x, 'sharding') and x.sharding.is_equivalent_to(sharding, x.ndim): - return x - # Run a jit here to raise good errors when device assignment don't match. - return api.jit(_identity_fn, out_shardings=sharding)(x) +def _sharding_constraint_impl(x, sharding, layout, resource_env, + unconstrained_dims): + if layout is None: + if hasattr(x, 'sharding') and x.sharding.is_equivalent_to(sharding, x.ndim): + return x + # Run a jit here to raise good errors when device assignment don't match. + return api.jit(_identity_fn, out_shardings=sharding)(x) + else: + if (hasattr(x, 'layout') and x.layout.device_local_layout == layout and + x.sharding.is_equivalent_to(sharding, x.ndim)): + return x + return api.jit(_identity_fn, out_shardings=Layout(layout, sharding))(x) sharding_constraint_p = core.Primitive("sharding_constraint") @@ -2532,7 +2547,7 @@ sharding_constraint_p.def_abstract_eval(lambda x, **_: x) ad.deflinear2(sharding_constraint_p, lambda ct, _, **params: (sharding_constraint_p.bind(ct, **params),)) -def _sharding_constraint_hlo_lowering(ctx, x_node, *, sharding, +def _sharding_constraint_hlo_lowering(ctx, x_node, *, sharding, layout, resource_env, unconstrained_dims): aval, = ctx.avals_in out_aval, = ctx.avals_out @@ -2547,19 +2562,19 @@ def _sharding_constraint_hlo_lowering(ctx, x_node, *, sharding, sharding._to_xla_hlo_sharding(aval.ndim), mesh)[0] sharding = NamedSharding._from_parsed_pspec( mesh, parsed_pspec, _manual_axes=axis_ctx.manual_axes) - return [ - mlir.wrap_with_sharding_op(ctx, - x_node, out_aval, - sharding._to_xla_hlo_sharding(aval.ndim).to_proto(), - unspecified_dims=unconstrained_dims) - ] + out = mlir.wrap_with_sharding_op( + ctx, x_node, out_aval, sharding._to_xla_hlo_sharding(aval.ndim).to_proto(), + unspecified_dims=unconstrained_dims) + if layout is not None: + out = mlir.wrap_with_layout_op(ctx, out, out_aval, layout, aval) + return [out] mlir.register_lowering(sharding_constraint_p, _sharding_constraint_hlo_lowering) -def _sharding_constraint_batcher(insert_axis, spmd_axis_name, axis_size, - axis_name, main_type, vals_in, dims_in, - sharding, resource_env, unconstrained_dims): +def _sharding_constraint_batcher( + insert_axis, spmd_axis_name, axis_size, axis_name, main_type, vals_in, + dims_in, sharding, layout, resource_env, unconstrained_dims): if spmd_axis_name is not None and isinstance(sharding, NamedSharding): used = {n for ns in sharding.spec for n in (ns if isinstance(ns, tuple) else (ns,))} @@ -2586,9 +2601,14 @@ def _sharding_constraint_batcher(insert_axis, spmd_axis_name, axis_size, vmapped_sharding = NamedSharding( vmapped_sharding.mesh, PartitionSpec(*new_spec)) + # TODO(yashkatariya): Figure out layouts should change under vmap. + if layout is not None: + raise NotImplementedError + y = sharding_constraint_p.bind( x, sharding=vmapped_sharding, + layout=layout, resource_env=resource_env, unconstrained_dims=unconstrained_dims) return y, d diff --git a/jax/experimental/array_serialization/serialization_test.py b/jax/experimental/array_serialization/serialization_test.py index b71c3cac2..ccf2d0546 100644 --- a/jax/experimental/array_serialization/serialization_test.py +++ b/jax/experimental/array_serialization/serialization_test.py @@ -17,7 +17,6 @@ import asyncio import contextlib import math from functools import partial -import re import os import pathlib import tracemalloc as tm @@ -46,13 +45,6 @@ def tearDownModule(): _exit_stack.close() -pattern = re.compile(r"\{(.*?):") - -def extract_minor_to_major(l): - match = re.search(pattern, str(l)) - return tuple(int(i) for i in match.groups()[0].split(',')) - - class CheckpointTest(jtu.JaxTestCase): def _on_commit_callback(self, temp_ckpt_dir, final_ckpt_dir): @@ -436,8 +428,8 @@ class CheckpointTest(jtu.JaxTestCase): out_layout = jax.jit(lambda x: x.T, out_shardings=Layout(DLL.AUTO)).lower( arr).compile().output_layouts() - self.assertEqual(extract_minor_to_major(arr.layout), - extract_minor_to_major(out_layout)[::-1]) + self.assertEqual(arr.layout.device_local_layout.major_to_minor, + out_layout.device_local_layout.major_to_minor[::-1]) ckpt_dir = pathlib.Path(self.create_tempdir('ckpt').full_path) ckpt_path = pathlib.Path(self.create_tempdir(f'{ckpt_dir}/first').full_path) diff --git a/tests/layout_test.py b/tests/layout_test.py index d071d6cb7..d0d0a27b8 100644 --- a/tests/layout_test.py +++ b/tests/layout_test.py @@ -14,9 +14,9 @@ import contextlib import math -import re from absl.testing import absltest import numpy as np +from functools import partial import jax import jax.numpy as jnp @@ -37,15 +37,6 @@ def tearDownModule(): _exit_stack.close() -pattern = re.compile(r"\{(.*?):") - -# Extract minor_to_major from str(layout) because layout doesn't have a -# minor_to_major property yet. -def extract_minor_to_major(l): - match = re.search(pattern, str(l)) - return tuple(int(i) for i in match.groups()[0].split(',')) - - class LayoutTest(jtu.JaxTestCase): def setUp(self): @@ -79,8 +70,8 @@ class LayoutTest(jtu.JaxTestCase): self.assertEmpty(kw_layouts) for i, o in zip(arg_layouts, compiled_apply.output_layouts()): - self.assertEqual(extract_minor_to_major(i), - extract_minor_to_major(o)[::-1]) + self.assertEqual(i.device_local_layout.major_to_minor, + o.device_local_layout.major_to_minor[::-1]) init_compiled = jax.jit( init, out_shardings=arg_layouts).lower(sds1, sds2).compile() @@ -108,10 +99,10 @@ class LayoutTest(jtu.JaxTestCase): self.assertEqual(apply_out[0].layout, compiled_apply.output_layouts()[0]) self.assertEqual(apply_out[1].layout, compiled_apply.output_layouts()[1]) - self.assertTupleEqual(extract_minor_to_major(apply_out[0].layout), - extract_minor_to_major(init_out[0].layout)[::-1]) - self.assertTupleEqual(extract_minor_to_major(apply_out[1].layout), - extract_minor_to_major(init_out[1].layout)[::-1]) + self.assertTupleEqual(apply_out[0].layout.device_local_layout.major_to_minor, + init_out[0].layout.device_local_layout.major_to_minor[::-1]) + self.assertTupleEqual(apply_out[1].layout.device_local_layout.major_to_minor, + init_out[1].layout.device_local_layout.major_to_minor[::-1]) self.assertArraysEqual(init_out[0], np_inp1 * 2) self.assertArraysEqual(init_out[1], np_inp2 * 2) @@ -135,18 +126,22 @@ class LayoutTest(jtu.JaxTestCase): out = compiled(arr) self.assertTupleEqual( - extract_minor_to_major(compiled.input_layouts()[0][0]), (2, 1, 0)) + compiled.input_layouts()[0][0].device_local_layout.major_to_minor[::-1], + (2, 1, 0)) self.assertTupleEqual( - extract_minor_to_major(compiled.output_layouts()), (2, 1, 0)) + compiled.output_layouts().device_local_layout.major_to_minor[::-1], + (2, 1, 0)) self.assertArraysEqual(out, np_inp.T) self.assertEqual(out.sharding, NamedSharding(mesh, P(None, 'y', 'x'))) compiled_auto = jax.jit(f, in_shardings=Layout(DLL.AUTO), out_shardings=Layout(DLL.AUTO)).lower(sds).compile() self.assertTupleEqual( - extract_minor_to_major(compiled_auto.input_layouts()[0][0]), (2, 1, 0)) + compiled_auto.input_layouts()[0][0].device_local_layout.major_to_minor[::-1], + (2, 1, 0)) self.assertTupleEqual( - extract_minor_to_major(compiled_auto.output_layouts()), (0, 1, 2)) + compiled_auto.output_layouts().device_local_layout.major_to_minor[::-1], + (0, 1, 2)) with self.assertRaisesRegex( ValueError, "jax.jit` does not accept device-local layouts directly"): @@ -166,9 +161,11 @@ class LayoutTest(jtu.JaxTestCase): compiled = jax.jit(f, in_shardings=Layout(), out_shardings=Layout(DLL.AUTO)).lower(arr).compile() self.assertTupleEqual( - extract_minor_to_major(compiled.input_layouts()[0][0]), (1, 0)) + compiled.input_layouts()[0][0].device_local_layout.major_to_minor[::-1], + (1, 0)) self.assertTupleEqual( - extract_minor_to_major(compiled.output_layouts()), (0, 1)) + compiled.output_layouts().device_local_layout.major_to_minor[::-1], + (0, 1)) out = compiled(arr) self.assertArraysEqual(out, np_inp.T) @@ -185,9 +182,11 @@ class LayoutTest(jtu.JaxTestCase): out_shardings=Layout(DLL.AUTO, s)).lower(np_inp).compile() out = compiled(np_inp) self.assertTupleEqual( - extract_minor_to_major(compiled.input_layouts()[0][0]), (1, 0)) + compiled.input_layouts()[0][0].device_local_layout.major_to_minor[::-1], + (1, 0)) self.assertTupleEqual( - extract_minor_to_major(compiled.output_layouts()), (0, 1)) + compiled.output_layouts().device_local_layout.major_to_minor[::-1], + (0, 1)) self.assertArraysEqual(out, np_inp.T) self.assertEqual(out.sharding, s) @@ -358,6 +357,64 @@ class LayoutTest(jtu.JaxTestCase): jax.make_array_from_callback( np_inp.shape, Layout(None, None), lambda idx: np_inp[idx]) + def test_wsc_concrete_layout(self): + mesh = jtu.create_global_mesh((2, 2), ('x', 'y')) + shape = (128, 128) + s = NamedSharding(mesh, P('x')) + np_inp = np.arange(math.prod(shape)).reshape(shape) + arr = jax.device_put(np_inp, s) + + # Create a custom layout instead of using `arr.layout` to test the API. + custom_dll = DLL(major_to_minor=(0, 1), tiling=((8, 128),)) + + # We need AUTO so that XLA can override the entry computation layout set. + # TODO(yashkatariya): Expose a config that sets out_shardings to AUTO by + # default instead of `None` i.e. default layout and let the compiler choose + # the layout or try setting it to AUTO by default and see if there is chaos. + @partial(jax.jit, out_shardings=Layout(DLL.AUTO)) + def f(x): + y = x.T + # Constrain `y` to the original layout of `arr` because without it, + # the layout of `y` would be the transpose of `arr`. + return jax.lax.with_sharding_constraint(y, Layout(custom_dll, s)) + + out = f(arr) + self.assertEqual(out.layout, Layout(custom_dll, s)) + self.assertEqual(out.layout, arr.layout) + self.assertArraysEqual(out, np_inp.T) + + def test_wsc_concrete_layout_bfloat16(self): + mesh = jtu.create_global_mesh((2, 2), ('x', 'y')) + shape = (128, 128) + s = NamedSharding(mesh, P('x')) + inp = jnp.arange(math.prod(shape), dtype=jnp.bfloat16).reshape(shape) + arr = jax.device_put(inp, s) + + # Create a custom layout instead of using `arr.layout` to test the API. + custom_dll = DLL(major_to_minor=(0, 1), tiling=((8, 128), (2, 1))) + + @partial(jax.jit, out_shardings=Layout(DLL.AUTO)) + def f(x): + y = x.T + # Constrain `y` to the original layout of `arr` because without it, + # the layout of `y` would be the transpose of `arr`. + return jax.lax.with_sharding_constraint(y, Layout(custom_dll, s)) + + out = f(arr) + self.assertEqual(out.layout, Layout(custom_dll, s)) + self.assertEqual(out.layout, arr.layout) + self.assertArraysEqual(out, inp.T) + + def test_device_put_user_concrete_layout(self): + shape = (8, 128) + np_inp = np.arange(math.prod(shape)).reshape(shape) + dll = DLL(major_to_minor=(1, 0), tiling=((8, 128),)) + s = SingleDeviceSharding(jax.devices()[0]) + + out = jax.device_put(np_inp, Layout(dll, s)) + self.assertEqual(out.layout, Layout(dll, s)) + self.assertArraysEqual(out, np_inp) + if __name__ == '__main__': absltest.main(testLoader=jtu.JaxTestLoader())