From e1a496d3b6b6bae81552d3f2fdd8b7c9fd312995 Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Thu, 27 Jun 2024 16:46:44 -0700 Subject: [PATCH] Add concrete layout API to JAX. The API takes `major_to_minor: tuple[int, ...]` and `tiling: tuple[tuple[int, ...], ...]` as the arguments. Allows users to pass layouts to `with_sharding_constraint` to constrain the layout + sharding. `sub_byte_element_size_in_bits` is a lowering only thing for now (since we know the dtype of the aval so JAX can add the appropriate value). We can expose it to the user API if required. memory space is exposed via JAX memories API so it doesn't have to be in the layout API. Also expose `_xla_layout` as a private API from `PJRTLayout` so that we can access fields to create JAX layouts. Add construtors to `xla::Layout` so that JAX can create Layouts with minor_to_major and tiling information. PiperOrigin-RevId: 647487510 --- jax/BUILD | 3 +- jax/_src/array.py | 3 +- jax/_src/interpreters/mlir.py | 42 +++++-- jax/_src/interpreters/pxla.py | 6 +- jax/_src/layout.py | 84 +++++++++++--- jax/_src/maps.py | 1 + jax/_src/pjit.py | 54 ++++++--- .../array_serialization/serialization_test.py | 12 +- tests/layout_test.py | 105 ++++++++++++++---- 9 files changed, 228 insertions(+), 82 deletions(-) diff --git a/jax/BUILD b/jax/BUILD index 2f7480e31..41fefd860 100644 --- a/jax/BUILD +++ b/jax/BUILD @@ -800,10 +800,11 @@ pytype_strict_library( name = "layout", srcs = ["_src/layout.py"], deps = [ + ":dtypes", ":sharding", ":sharding_impls", "//jax/_src/lib", - ], + ] + py_deps("numpy"), ) pytype_strict_library( diff --git a/jax/_src/array.py b/jax/_src/array.py index 6e3f0a76f..4dcb57b91 100644 --- a/jax/_src/array.py +++ b/jax/_src/array.py @@ -506,7 +506,8 @@ class ArrayImpl(basearray.Array): if self.is_deleted(): return Layout(None, self.sharding) try: - return Layout(DeviceLocalLayout(self._pjrt_layout), self.sharding) + return Layout(DeviceLocalLayout.from_pjrt_layout(self._pjrt_layout), + self.sharding) except xe.XlaRuntimeError as e: msg, *_ = e.args if type(msg) is str and msg.startswith("UNIMPLEMENTED"): diff --git a/jax/_src/interpreters/mlir.py b/jax/_src/interpreters/mlir.py index eccb29577..a729d7fe3 100644 --- a/jax/_src/interpreters/mlir.py +++ b/jax/_src/interpreters/mlir.py @@ -824,12 +824,15 @@ def _to_physical_op_sharding( return sharding._to_xla_hlo_sharding(aval.ndim).to_proto() # type: ignore -def _to_xla_layout(layout: DeviceLocalLayout | None | AutoLayout) -> str | None: +def _to_xla_layout(layout: DeviceLocalLayout | None | AutoLayout, + aval: core.AbstractValue) -> str | None: if layout is None: return "default" if isinstance(layout, AutoLayout): return "auto" - return layout._to_xla_layout() + if aval is core.abstract_token: + return "default" + return layout._to_xla_layout(aval.dtype) # type: ignore def _get_mem_kind(s: JSharding | None) -> str | None: @@ -1194,11 +1197,9 @@ def lower_jaxpr_to_fun( ir_arg_shardings = None if arg_shardings is not None: - in_avals = [None] * (num_dim_vars + num_tokens) + list(jaxpr.in_avals) ir_arg_shardings = util.flatten( [[_to_physical_op_sharding(a, s)] * len(types) - for a, s, types in zip(in_avals, arg_shardings, input_types)]) - del in_avals + for a, s, types in zip(input_avals, arg_shardings, input_types)]) ir_arg_memory_kinds = None if arg_memory_kinds is not None: @@ -1208,8 +1209,8 @@ def lower_jaxpr_to_fun( ir_arg_layouts = None if arg_layouts is not None: ir_arg_layouts = util.flatten( - [[_to_xla_layout(l)] * len(types) - for l, types in zip(arg_layouts, input_types)]) + [[_to_xla_layout(l, a)] * len(types) + for l, a, types in zip(arg_layouts, input_avals, input_types)]) ir_donated_args = None if xla_donated_args is not None: @@ -1244,8 +1245,8 @@ def lower_jaxpr_to_fun( ir_result_layouts = None if result_layouts is not None: ir_result_layouts = util.flatten( - [[_to_xla_layout(l)] * len(types) - for l, types in zip(result_layouts, output_types)]) + [[_to_xla_layout(l, a)] * len(types) + for l, a, types in zip(result_layouts, output_avals, output_types)]) if ( replicated_args is not None @@ -2171,6 +2172,29 @@ def get_sharding_attr(sharding_proto: xc.OpSharding): return ir.StringAttr.get(repr(xc.HloSharding.from_proto(sharding_proto))) +def wrap_with_layout_op(ctx: LoweringRuleContext, + x: ir.Value, + aval_out: core.AbstractValue, + layout: DeviceLocalLayout, + aval_in: core.AbstractValue): + result_type = aval_to_ir_type(aval_out) + out_shape = core.physical_aval(aval_out).shape # type: ignore + if core.is_constant_shape(out_shape): + result_shapes = None + else: + result_shapes = [eval_dynamic_shape_as_tensor(ctx, out_shape)] + + op = custom_call('LayoutConstraint', result_types=[result_type], operands=[x], + api_version=1, + result_shapes=result_shapes, + # Set operand layouts to anything. XLA will ignore it. + operand_layouts=[list(range(aval_in.ndim))], # type: ignore + # TODO(yashkatariya): Figure out how to pass tiling to the + # custom call. + result_layouts=[layout.major_to_minor[::-1]]) + return op.result + + # MLIR lowerings for lax primitives def cache_lowering(f): diff --git a/jax/_src/interpreters/pxla.py b/jax/_src/interpreters/pxla.py index 69d7c619b..378aa2c5f 100644 --- a/jax/_src/interpreters/pxla.py +++ b/jax/_src/interpreters/pxla.py @@ -1937,7 +1937,7 @@ def _maybe_get_default_layout(arg_layout, jit_in_layout, sharding, aval # first call you pass it a sharded array with layout and on second call you # pass a numpy array. The layouts should be the same to get cache hits. try: - al = DeviceLocalLayout( + al = DeviceLocalLayout.from_pjrt_layout( d.client.get_default_layout(aval.dtype, shard_shape, d)) except: return None @@ -2704,7 +2704,7 @@ def _get_layouts_from_executable( new_in_layouts = [] for x, i in safe_zip(in_layouts_xla, in_layouts): - x = DeviceLocalLayout(x) + x = DeviceLocalLayout.from_pjrt_layout(x) if isinstance(i, DeviceLocalLayout): if i != x: raise AssertionError( @@ -2716,7 +2716,7 @@ def _get_layouts_from_executable( new_out_layouts = [] for x, o in safe_zip(out_layouts_xla, out_layouts): - x = DeviceLocalLayout(x) + x = DeviceLocalLayout.from_pjrt_layout(x) if isinstance(o, DeviceLocalLayout): if o != x: raise AssertionError( diff --git a/jax/_src/layout.py b/jax/_src/layout.py index 2071794a0..3b4424345 100644 --- a/jax/_src/layout.py +++ b/jax/_src/layout.py @@ -16,9 +16,12 @@ from __future__ import annotations from typing import Union +import numpy as np +from jax._src.dtypes import iinfo, issubdtype from jax._src.sharding import Sharding from jax._src.sharding_impls import AUTO as AutoSharding, is_auto from jax._src.lib import xla_client as xc +from jax._src.lib import xla_extension_version class AutoLayout: @@ -27,31 +30,78 @@ class AutoLayout: return "AUTO" -class DeviceLocalLayout: - layout: xc.PjRtLayout +if xla_extension_version >= 274: + class DeviceLocalLayout: + major_to_minor: tuple[int, ...] + tiling: tuple[tuple[int, ...], ...] | None - AUTO = AutoLayout() + AUTO = AutoLayout() - def __init__(self, layout: xc.PjRtLayout): - self._layout = layout - self._layout_str = str(self._layout) + def __init__(self, major_to_minor: tuple[int, ...], + tiling: tuple[tuple[int, ...], ...] | None = None): + self.major_to_minor = tuple(major_to_minor) + self.tiling = None if tiling is None else tuple(map(tuple, tiling)) - def __repr__(self): - return f'DeviceLocalLayout({self._layout_str})' + @staticmethod + def from_pjrt_layout(pjrt_layout: xc.PjRtLayout): + xla_layout = pjrt_layout._xla_layout() + return DeviceLocalLayout(xla_layout.minor_to_major()[::-1], # pytype: disable=wrong-arg-types + xla_layout.tiling()) - def __hash__(self): - return hash(self._layout) + def __repr__(self): + return (f'DeviceLocalLayout(major_to_minor={self.major_to_minor},' + f' tiling={self.tiling})') - def __eq__(self, other): - if not isinstance(other, DeviceLocalLayout): - return False - return self._layout == other._layout + def __hash__(self): + return hash((self.major_to_minor, self.tiling)) - def _to_xla_layout(self) -> str: - return self._layout_str + def __eq__(self, other): + if not isinstance(other, DeviceLocalLayout): + return False + return (self.major_to_minor == other.major_to_minor and + self.tiling == other.tiling) + + def _to_xla_layout(self, dtype) -> str: + if self.tiling is None: + xla_layout = xc.Layout(self.major_to_minor[::-1]) + else: + if issubdtype(dtype, np.integer): + sub_byte_size = iinfo(dtype).bits if iinfo(dtype).bits < 8 else 0 + else: + sub_byte_size = 0 + xla_layout = xc.Layout(self.major_to_minor[::-1], self.tiling, # type: ignore + sub_byte_size) + return str(xla_layout) +else: + class DeviceLocalLayout: # type: ignore + layout: xc.PjRtLayout + + AUTO = AutoLayout() + + def __init__(self, layout: xc.PjRtLayout): + self._layout = layout + self._layout_str = str(self._layout) + + @staticmethod + def from_pjrt_layout(pjrt_layout: xc.PjRtLayout): + return DeviceLocalLayout(pjrt_layout) # type: ignore + + def __repr__(self): + return f'DeviceLocalLayout({self._layout_str})' + + def __hash__(self): + return hash(self._layout) + + def __eq__(self, other): + if not isinstance(other, DeviceLocalLayout): + return False + return self._layout == other._layout + + def _to_xla_layout(self, dtype) -> str: + return self._layout_str -LayoutOptions = Union[DeviceLocalLayout, None, AutoLayout] +LayoutOptions = Union[DeviceLocalLayout, None, AutoLayout] # pytype: disable=invalid-annotation ShardingOptions = Union[Sharding, None, AutoSharding] diff --git a/jax/_src/maps.py b/jax/_src/maps.py index 20fc54d8f..4b574775d 100644 --- a/jax/_src/maps.py +++ b/jax/_src/maps.py @@ -1816,6 +1816,7 @@ def _fix_inferred_spmd_sharding(jaxpr, resource_env, gen_fresh_name = None): [tmpvar], [outvar], sharding_constraint_p, dict(resource_env=resource_env, sharding=gspmd_sharding, + layout=None, unconstrained_dims=unconstrained_dims), set(), eqn.source_info, eqn.ctx)) diff --git a/jax/_src/pjit.py b/jax/_src/pjit.py index 204c288d6..454611424 100644 --- a/jax/_src/pjit.py +++ b/jax/_src/pjit.py @@ -2488,6 +2488,9 @@ def with_sharding_constraint(x, shardings): .. _Distributed arrays and automatic parallelization: https://jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html """ x_flat, tree = tree_flatten(x) + + layouts, shardings = _split_layout_and_sharding(shardings) + user_shardings = prepare_axis_resources( shardings, "shardings", allow_unconstrained_dims=True) del shardings @@ -2496,6 +2499,10 @@ def with_sharding_constraint(x, shardings): flatten_axes("with_sharding_constraint shardings", tree, user_shardings)) del user_shardings + user_layouts_flat = tuple( + flatten_axes("with_sharding_constraint layouts", tree, layouts)) + del layouts + resource_env = mesh_lib.thread_resources.env mesh = resource_env.physical_mesh @@ -2511,19 +2518,27 @@ def with_sharding_constraint(x, shardings): shardings_flat, x_flat, None, "with_sharding_constraint arguments", allow_uneven_sharding=True) - outs = [sharding_constraint_p.bind(xf, sharding=s, + outs = [sharding_constraint_p.bind(xf, sharding=s, layout=l, resource_env=resource_env, unconstrained_dims=ud) - for xf, s, ud in zip(x_flat, shardings_flat, unconstrained_dims)] + for xf, s, l, ud in zip(x_flat, shardings_flat, user_layouts_flat, + unconstrained_dims)] return tree_unflatten(tree, outs) def _identity_fn(x): return x -def _sharding_constraint_impl(x, sharding, resource_env, unconstrained_dims): - if hasattr(x, 'sharding') and x.sharding.is_equivalent_to(sharding, x.ndim): - return x - # Run a jit here to raise good errors when device assignment don't match. - return api.jit(_identity_fn, out_shardings=sharding)(x) +def _sharding_constraint_impl(x, sharding, layout, resource_env, + unconstrained_dims): + if layout is None: + if hasattr(x, 'sharding') and x.sharding.is_equivalent_to(sharding, x.ndim): + return x + # Run a jit here to raise good errors when device assignment don't match. + return api.jit(_identity_fn, out_shardings=sharding)(x) + else: + if (hasattr(x, 'layout') and x.layout.device_local_layout == layout and + x.sharding.is_equivalent_to(sharding, x.ndim)): + return x + return api.jit(_identity_fn, out_shardings=Layout(layout, sharding))(x) sharding_constraint_p = core.Primitive("sharding_constraint") @@ -2532,7 +2547,7 @@ sharding_constraint_p.def_abstract_eval(lambda x, **_: x) ad.deflinear2(sharding_constraint_p, lambda ct, _, **params: (sharding_constraint_p.bind(ct, **params),)) -def _sharding_constraint_hlo_lowering(ctx, x_node, *, sharding, +def _sharding_constraint_hlo_lowering(ctx, x_node, *, sharding, layout, resource_env, unconstrained_dims): aval, = ctx.avals_in out_aval, = ctx.avals_out @@ -2547,19 +2562,19 @@ def _sharding_constraint_hlo_lowering(ctx, x_node, *, sharding, sharding._to_xla_hlo_sharding(aval.ndim), mesh)[0] sharding = NamedSharding._from_parsed_pspec( mesh, parsed_pspec, _manual_axes=axis_ctx.manual_axes) - return [ - mlir.wrap_with_sharding_op(ctx, - x_node, out_aval, - sharding._to_xla_hlo_sharding(aval.ndim).to_proto(), - unspecified_dims=unconstrained_dims) - ] + out = mlir.wrap_with_sharding_op( + ctx, x_node, out_aval, sharding._to_xla_hlo_sharding(aval.ndim).to_proto(), + unspecified_dims=unconstrained_dims) + if layout is not None: + out = mlir.wrap_with_layout_op(ctx, out, out_aval, layout, aval) + return [out] mlir.register_lowering(sharding_constraint_p, _sharding_constraint_hlo_lowering) -def _sharding_constraint_batcher(insert_axis, spmd_axis_name, axis_size, - axis_name, main_type, vals_in, dims_in, - sharding, resource_env, unconstrained_dims): +def _sharding_constraint_batcher( + insert_axis, spmd_axis_name, axis_size, axis_name, main_type, vals_in, + dims_in, sharding, layout, resource_env, unconstrained_dims): if spmd_axis_name is not None and isinstance(sharding, NamedSharding): used = {n for ns in sharding.spec for n in (ns if isinstance(ns, tuple) else (ns,))} @@ -2586,9 +2601,14 @@ def _sharding_constraint_batcher(insert_axis, spmd_axis_name, axis_size, vmapped_sharding = NamedSharding( vmapped_sharding.mesh, PartitionSpec(*new_spec)) + # TODO(yashkatariya): Figure out layouts should change under vmap. + if layout is not None: + raise NotImplementedError + y = sharding_constraint_p.bind( x, sharding=vmapped_sharding, + layout=layout, resource_env=resource_env, unconstrained_dims=unconstrained_dims) return y, d diff --git a/jax/experimental/array_serialization/serialization_test.py b/jax/experimental/array_serialization/serialization_test.py index b71c3cac2..ccf2d0546 100644 --- a/jax/experimental/array_serialization/serialization_test.py +++ b/jax/experimental/array_serialization/serialization_test.py @@ -17,7 +17,6 @@ import asyncio import contextlib import math from functools import partial -import re import os import pathlib import tracemalloc as tm @@ -46,13 +45,6 @@ def tearDownModule(): _exit_stack.close() -pattern = re.compile(r"\{(.*?):") - -def extract_minor_to_major(l): - match = re.search(pattern, str(l)) - return tuple(int(i) for i in match.groups()[0].split(',')) - - class CheckpointTest(jtu.JaxTestCase): def _on_commit_callback(self, temp_ckpt_dir, final_ckpt_dir): @@ -436,8 +428,8 @@ class CheckpointTest(jtu.JaxTestCase): out_layout = jax.jit(lambda x: x.T, out_shardings=Layout(DLL.AUTO)).lower( arr).compile().output_layouts() - self.assertEqual(extract_minor_to_major(arr.layout), - extract_minor_to_major(out_layout)[::-1]) + self.assertEqual(arr.layout.device_local_layout.major_to_minor, + out_layout.device_local_layout.major_to_minor[::-1]) ckpt_dir = pathlib.Path(self.create_tempdir('ckpt').full_path) ckpt_path = pathlib.Path(self.create_tempdir(f'{ckpt_dir}/first').full_path) diff --git a/tests/layout_test.py b/tests/layout_test.py index d071d6cb7..d0d0a27b8 100644 --- a/tests/layout_test.py +++ b/tests/layout_test.py @@ -14,9 +14,9 @@ import contextlib import math -import re from absl.testing import absltest import numpy as np +from functools import partial import jax import jax.numpy as jnp @@ -37,15 +37,6 @@ def tearDownModule(): _exit_stack.close() -pattern = re.compile(r"\{(.*?):") - -# Extract minor_to_major from str(layout) because layout doesn't have a -# minor_to_major property yet. -def extract_minor_to_major(l): - match = re.search(pattern, str(l)) - return tuple(int(i) for i in match.groups()[0].split(',')) - - class LayoutTest(jtu.JaxTestCase): def setUp(self): @@ -79,8 +70,8 @@ class LayoutTest(jtu.JaxTestCase): self.assertEmpty(kw_layouts) for i, o in zip(arg_layouts, compiled_apply.output_layouts()): - self.assertEqual(extract_minor_to_major(i), - extract_minor_to_major(o)[::-1]) + self.assertEqual(i.device_local_layout.major_to_minor, + o.device_local_layout.major_to_minor[::-1]) init_compiled = jax.jit( init, out_shardings=arg_layouts).lower(sds1, sds2).compile() @@ -108,10 +99,10 @@ class LayoutTest(jtu.JaxTestCase): self.assertEqual(apply_out[0].layout, compiled_apply.output_layouts()[0]) self.assertEqual(apply_out[1].layout, compiled_apply.output_layouts()[1]) - self.assertTupleEqual(extract_minor_to_major(apply_out[0].layout), - extract_minor_to_major(init_out[0].layout)[::-1]) - self.assertTupleEqual(extract_minor_to_major(apply_out[1].layout), - extract_minor_to_major(init_out[1].layout)[::-1]) + self.assertTupleEqual(apply_out[0].layout.device_local_layout.major_to_minor, + init_out[0].layout.device_local_layout.major_to_minor[::-1]) + self.assertTupleEqual(apply_out[1].layout.device_local_layout.major_to_minor, + init_out[1].layout.device_local_layout.major_to_minor[::-1]) self.assertArraysEqual(init_out[0], np_inp1 * 2) self.assertArraysEqual(init_out[1], np_inp2 * 2) @@ -135,18 +126,22 @@ class LayoutTest(jtu.JaxTestCase): out = compiled(arr) self.assertTupleEqual( - extract_minor_to_major(compiled.input_layouts()[0][0]), (2, 1, 0)) + compiled.input_layouts()[0][0].device_local_layout.major_to_minor[::-1], + (2, 1, 0)) self.assertTupleEqual( - extract_minor_to_major(compiled.output_layouts()), (2, 1, 0)) + compiled.output_layouts().device_local_layout.major_to_minor[::-1], + (2, 1, 0)) self.assertArraysEqual(out, np_inp.T) self.assertEqual(out.sharding, NamedSharding(mesh, P(None, 'y', 'x'))) compiled_auto = jax.jit(f, in_shardings=Layout(DLL.AUTO), out_shardings=Layout(DLL.AUTO)).lower(sds).compile() self.assertTupleEqual( - extract_minor_to_major(compiled_auto.input_layouts()[0][0]), (2, 1, 0)) + compiled_auto.input_layouts()[0][0].device_local_layout.major_to_minor[::-1], + (2, 1, 0)) self.assertTupleEqual( - extract_minor_to_major(compiled_auto.output_layouts()), (0, 1, 2)) + compiled_auto.output_layouts().device_local_layout.major_to_minor[::-1], + (0, 1, 2)) with self.assertRaisesRegex( ValueError, "jax.jit` does not accept device-local layouts directly"): @@ -166,9 +161,11 @@ class LayoutTest(jtu.JaxTestCase): compiled = jax.jit(f, in_shardings=Layout(), out_shardings=Layout(DLL.AUTO)).lower(arr).compile() self.assertTupleEqual( - extract_minor_to_major(compiled.input_layouts()[0][0]), (1, 0)) + compiled.input_layouts()[0][0].device_local_layout.major_to_minor[::-1], + (1, 0)) self.assertTupleEqual( - extract_minor_to_major(compiled.output_layouts()), (0, 1)) + compiled.output_layouts().device_local_layout.major_to_minor[::-1], + (0, 1)) out = compiled(arr) self.assertArraysEqual(out, np_inp.T) @@ -185,9 +182,11 @@ class LayoutTest(jtu.JaxTestCase): out_shardings=Layout(DLL.AUTO, s)).lower(np_inp).compile() out = compiled(np_inp) self.assertTupleEqual( - extract_minor_to_major(compiled.input_layouts()[0][0]), (1, 0)) + compiled.input_layouts()[0][0].device_local_layout.major_to_minor[::-1], + (1, 0)) self.assertTupleEqual( - extract_minor_to_major(compiled.output_layouts()), (0, 1)) + compiled.output_layouts().device_local_layout.major_to_minor[::-1], + (0, 1)) self.assertArraysEqual(out, np_inp.T) self.assertEqual(out.sharding, s) @@ -358,6 +357,64 @@ class LayoutTest(jtu.JaxTestCase): jax.make_array_from_callback( np_inp.shape, Layout(None, None), lambda idx: np_inp[idx]) + def test_wsc_concrete_layout(self): + mesh = jtu.create_global_mesh((2, 2), ('x', 'y')) + shape = (128, 128) + s = NamedSharding(mesh, P('x')) + np_inp = np.arange(math.prod(shape)).reshape(shape) + arr = jax.device_put(np_inp, s) + + # Create a custom layout instead of using `arr.layout` to test the API. + custom_dll = DLL(major_to_minor=(0, 1), tiling=((8, 128),)) + + # We need AUTO so that XLA can override the entry computation layout set. + # TODO(yashkatariya): Expose a config that sets out_shardings to AUTO by + # default instead of `None` i.e. default layout and let the compiler choose + # the layout or try setting it to AUTO by default and see if there is chaos. + @partial(jax.jit, out_shardings=Layout(DLL.AUTO)) + def f(x): + y = x.T + # Constrain `y` to the original layout of `arr` because without it, + # the layout of `y` would be the transpose of `arr`. + return jax.lax.with_sharding_constraint(y, Layout(custom_dll, s)) + + out = f(arr) + self.assertEqual(out.layout, Layout(custom_dll, s)) + self.assertEqual(out.layout, arr.layout) + self.assertArraysEqual(out, np_inp.T) + + def test_wsc_concrete_layout_bfloat16(self): + mesh = jtu.create_global_mesh((2, 2), ('x', 'y')) + shape = (128, 128) + s = NamedSharding(mesh, P('x')) + inp = jnp.arange(math.prod(shape), dtype=jnp.bfloat16).reshape(shape) + arr = jax.device_put(inp, s) + + # Create a custom layout instead of using `arr.layout` to test the API. + custom_dll = DLL(major_to_minor=(0, 1), tiling=((8, 128), (2, 1))) + + @partial(jax.jit, out_shardings=Layout(DLL.AUTO)) + def f(x): + y = x.T + # Constrain `y` to the original layout of `arr` because without it, + # the layout of `y` would be the transpose of `arr`. + return jax.lax.with_sharding_constraint(y, Layout(custom_dll, s)) + + out = f(arr) + self.assertEqual(out.layout, Layout(custom_dll, s)) + self.assertEqual(out.layout, arr.layout) + self.assertArraysEqual(out, inp.T) + + def test_device_put_user_concrete_layout(self): + shape = (8, 128) + np_inp = np.arange(math.prod(shape)).reshape(shape) + dll = DLL(major_to_minor=(1, 0), tiling=((8, 128),)) + s = SingleDeviceSharding(jax.devices()[0]) + + out = jax.device_put(np_inp, Layout(dll, s)) + self.assertEqual(out.layout, Layout(dll, s)) + self.assertArraysEqual(out, np_inp) + if __name__ == '__main__': absltest.main(testLoader=jtu.JaxTestLoader())