diff --git a/jax/_src/array.py b/jax/_src/array.py index 07d55f018..e412f5c3c 100644 --- a/jax/_src/array.py +++ b/jax/_src/array.py @@ -34,10 +34,12 @@ from jax._src import deprecations from jax._src import dispatch from jax._src import dtypes from jax._src import errors +from jax._src import layout from jax._src import profiler from jax._src import tree_util from jax._src import xla_bridge from jax._src.lib import xla_client as xc +from jax._src.lib import xla_extension as xe from jax._src.interpreters import mlir from jax._src.interpreters import pxla from jax._src.interpreters import xla @@ -527,6 +529,18 @@ class ArrayImpl(basearray.Array): out.append(Shard(_get_device(a), self.sharding, self.shape, a)) return out + @property + def layout(self): + # TODO(yashkatariya): Remove the try;except when pathways supports layouts. + try: + return layout.SpecifiedLayout(self._pjrt_layout) + except xe.XlaRuntimeError as e: + msg, *_ = e.args + if type(msg) is str and msg.startswith("UNIMPLEMENTED"): + return None + else: + raise + @property def global_shards(self) -> Sequence[Shard]: """Returns list of all `Shard`s of the Array across all devices. @@ -637,7 +651,7 @@ if not TYPE_CHECKING: ArrayImpl = use_cpp_class(xc.ArrayImpl)(ArrayImpl) -# explicitly set to be unhashable. Same as what device_array.py does. +# explicitly set to be unhashable. setattr(ArrayImpl, "__hash__", None) setattr(ArrayImpl, "__array_priority__", 100) diff --git a/jax/_src/interpreters/mlir.py b/jax/_src/interpreters/mlir.py index a11c2c182..af3b7fb31 100644 --- a/jax/_src/interpreters/mlir.py +++ b/jax/_src/interpreters/mlir.py @@ -46,7 +46,7 @@ from jax._src import util from jax._src import xla_bridge as xb from jax._src.interpreters import partial_eval as pe from jax._src.interpreters import xla -from jax._src.layout import XLACompatibleLayout, LayoutRequest +from jax._src.layout import AutoLayout, SpecifiedLayout from jax._src.lib import xla_client as xc from jax._src.lib import xla_extension from jax._src.lib import xla_extension_version @@ -834,10 +834,10 @@ def _to_physical_op_sharding( return sharding._to_xla_hlo_sharding(aval.ndim).to_proto() # type: ignore -def _to_xla_layout(layout: XLACompatibleLayout | None | LayoutRequest) -> str | None: +def _to_xla_layout(layout: SpecifiedLayout | None | AutoLayout) -> str | None: if layout is None: return "default" - if isinstance(layout, LayoutRequest): + if isinstance(layout, AutoLayout): return "auto" return layout._to_xla_layout() @@ -862,8 +862,8 @@ def lower_jaxpr_to_module( replicated_args: Sequence[bool] | None = None, arg_shardings: Sequence[XLACompatibleSharding | None] | None = None, result_shardings: Sequence[XLACompatibleSharding | None] | None = None, - in_layouts: Sequence[XLACompatibleLayout | None | LayoutRequest] | None = None, - out_layouts: Sequence[XLACompatibleLayout | None | LayoutRequest] | None = None, + in_layouts: Sequence[SpecifiedLayout | None | AutoLayout] | None = None, + out_layouts: Sequence[SpecifiedLayout | None | AutoLayout] | None = None, arg_names: Sequence[str | None] | None = None, result_names: Sequence[str | None] | None = None, num_replicas: int = 1, diff --git a/jax/_src/interpreters/pxla.py b/jax/_src/interpreters/pxla.py index dc759333f..949c5692a 100644 --- a/jax/_src/interpreters/pxla.py +++ b/jax/_src/interpreters/pxla.py @@ -60,7 +60,7 @@ from jax._src.interpreters import batching from jax._src.interpreters import partial_eval as pe from jax._src.interpreters import mlir from jax._src.interpreters import xla -from 
jax._src.layout import XLACompatibleLayout, SpecifiedLayout, LayoutRequest +from jax._src.layout import SpecifiedLayout, AutoLayout from jax._src.lib import xla_client as xc from jax._src.lib import xla_extension_version from jax._src.lib.mlir import ir @@ -1985,13 +1985,14 @@ def are_all_shardings_default_mem_kind(da_object, shardings): return False return True -MaybeLayout = Sequence[Union[XLACompatibleLayout, LayoutRequest, None]] +MaybeLayout = Sequence[Union[SpecifiedLayout, AutoLayout, None]] class AllArgsInfo(NamedTuple): """Avals, shardings, layouts and debug_info for all arguments prior to DCE.""" in_avals: Sequence[core.ShapedArray] in_shardings: Any + in_layouts: Any debug_info: core.JaxprDebugInfo | None @@ -2023,7 +2024,7 @@ def lower_sharding_computation( auto_spmd_lowering = check_if_any_auto( it.chain.from_iterable([in_shardings, out_shardings])) # type: ignore - all_args_info = AllArgsInfo(global_in_avals, in_shardings, + all_args_info = AllArgsInfo(global_in_avals, in_shardings, in_layouts, closed_jaxpr.jaxpr.debug_info) (closed_jaxpr, global_in_avals, global_out_avals, donated_invars, @@ -2227,8 +2228,6 @@ def lower_mesh_computation( out_jaxpr_avals = fun_or_jaxpr.out_avals consts = fun_or_jaxpr.consts - all_args_info = AllArgsInfo(global_in_avals, in_shardings, jaxpr.debug_info) - assert len(out_shardings) == len(out_jaxpr_avals) if spmd_lowering: global_out_avals = out_jaxpr_avals @@ -2319,7 +2318,7 @@ def lower_mesh_computation( in_layouts=(None,) * len(global_in_avals), out_layouts=(None,) * len(global_out_avals), shape_poly_state=lowering_result.shape_poly_state, - all_args_info=all_args_info) + all_args_info=None) class MeshComputation(stages.XlaLowering): _hlo: ir.Module | None @@ -2599,7 +2598,7 @@ def _get_layouts_from_executable( if isinstance(i, SpecifiedLayout): if i != x: raise AssertionError( - f"Unexpected XLA layout override: (XLA) {x} != {i} (User sharding)") + f"Unexpected XLA layout override: (XLA) {x} != {i} (User layout)") new_in_layouts.append(i) else: new_in_layouts.append(x) @@ -2610,7 +2609,7 @@ def _get_layouts_from_executable( if isinstance(o, SpecifiedLayout): if o != x: raise AssertionError( - f"Unexpected XLA layout override: (XLA) {x} != {o} (User sharding)") + f"Unexpected XLA layout override: (XLA) {x} != {o} (User layout)") new_out_layouts.append(o) else: new_out_layouts.append(x) @@ -3016,6 +3015,7 @@ class MeshExecutable(stages.XlaExecutable): kept_args = [a for i, a in enumerate(args) if i in self._kept_var_idx] ref_avals = self.in_avals in_shardings = self._in_shardings + in_layouts = self._in_layouts debug_info = None else: kept_args = args @@ -3023,12 +3023,16 @@ class MeshExecutable(stages.XlaExecutable): iter_in_shardings = iter(self._in_shardings) in_shardings = [next(iter_in_shardings) if i in self._kept_var_idx else s for i, s in enumerate(self._all_args_info.in_shardings)] + iter_in_layouts = iter(self._in_layouts) + in_layouts = [next(iter_in_layouts) if i in self._kept_var_idx else s + for i, s in enumerate(self._all_args_info.in_layouts)] debug_info = self._all_args_info.debug_info arg_avals = map(xla.abstractify, kept_args) check_arg_avals_for_call(ref_avals, arg_avals, debug_info) # Check the GDA sharding and the input sharding. 
- check_gda_or_array_xla_sharding_match(kept_args, in_shardings, debug_info) + check_array_xla_sharding_layout_match(kept_args, in_shardings, + in_layouts, debug_info) return self.unsafe_call(*args) # pylint: disable=not-callable def input_shardings(self) -> Sequence[sharding_impls.XLACompatibleSharding]: @@ -3184,15 +3188,17 @@ def check_device_backend_on_shardings(shardings) -> bool: return False -def check_gda_or_array_xla_sharding_match( +def check_array_xla_sharding_layout_match( args, in_xla_shardings: Sequence[sharding_impls.XLACompatibleSharding], + in_xla_layouts: Sequence[SpecifiedLayout], jaxpr_debug_info: core.JaxprDebugInfo | None) -> None: from jax._src.array import ArrayImpl arg_names = ([''] * len(args) if jaxpr_debug_info is None else jaxpr_debug_info.arg_names) errors = [] num_errors = 5 - for arg, xs, name in safe_zip(args, in_xla_shardings, arg_names): + for arg, xs, xl, name in safe_zip(args, in_xla_shardings, in_xla_layouts, + arg_names): if not isinstance(arg, ArrayImpl): continue if is_unspecified_or_auto(xs): @@ -3205,27 +3211,45 @@ def check_gda_or_array_xla_sharding_match( # Raise memory kind mismatch error even if the arg is uncommitted. if arg.sharding.memory_kind != xs.memory_kind: errors.append( - "Got input sharding(s) that compiled object was called with: " + ("Got input sharding(s) that compiled object was called with: " f"{arg.sharding} and sharding(s) the computation was compiled " - f"with: {xs} for arg {name} with shape: {arg.aval.str_short()}") + f"with: {xs} for arg {name} with shape: {arg.aval.str_short()}", + 'sharding')) if (not db_xs and arg._committed and not op_shardings.are_op_shardings_equal( arg.sharding._to_xla_hlo_sharding(arg.ndim), xs._to_xla_hlo_sharding(arg.ndim))): errors.append( - "Got input sharding(s) that compiled object was called with: " + ("Got input sharding(s) that compiled object was called with: " f"{arg.sharding} and sharding(s) the computation was compiled " - f"with: {xs} for arg {name} with shape: {arg.aval.str_short()}") + f"with: {xs} for arg {name} with shape: {arg.aval.str_short()}", + 'sharding')) + + if (xla_extension_version >= 249 and not db_xs and arg._committed and + arg.layout is not None and xl is not None and arg.layout != xl): + errors.append( + ("Got input layout(s) that compiled object was called with: " + f"{arg.layout} and layout(s) the computation was compiled " + f"with: {xl} for arg {name} with shape: {arg.aval.str_short()}", + 'layout')) if errors: - str_errors = '\n'.join(errors[:num_errors]) + first_errors, error_kinds = unzip2(errors[:num_errors]) + str_errors = '\n'.join(first_errors) + if all(k == 'sharding' for k in error_kinds): + kind_str = r'sharding(s)' + elif all(k == 'layout' for k in error_kinds): + kind_str = 'layout(s)' + else: + kind_str = 'sharding(s) and layout(s)' num_mismatch_str = ( f'the {len(errors)} mismatches' if len(errors) < num_errors else f"{num_errors} mismatches out of {len(errors)}") raise ValueError( - "Compiled object called with input sharding(s) does not match the " - "sharding(s) the computation was compiled with. " + f"Compiled object called with input {kind_str} does " + f"not match the {kind_str} the computation was " + "compiled with. 
" f"Here are {num_mismatch_str}:\n{str_errors}") diff --git a/jax/_src/layout.py b/jax/_src/layout.py index fcf4f3053..4b1b74ba5 100644 --- a/jax/_src/layout.py +++ b/jax/_src/layout.py @@ -14,8 +14,6 @@ from __future__ import annotations -import re - from jax._src.lib import xla_client as xc @@ -24,16 +22,10 @@ class Layout: pass -class XLACompatibleLayout(Layout): +class SpecifiedLayout(Layout): + layout: xc.PjRtLayout - def _to_xla_layout(self) -> str: - raise NotImplementedError("Subclasses should implement this method.") - - -class SpecifiedLayout(XLACompatibleLayout): - layout: xc.Layout - - def __init__(self, layout: xc.Layout): + def __init__(self, layout: xc.PjRtLayout): self._layout = layout self._layout_str = str(self._layout) @@ -51,19 +43,10 @@ class SpecifiedLayout(XLACompatibleLayout): def _to_xla_layout(self) -> str: return self._layout_str - @property - def _minor_to_major(self): - m = re.search("{([0-9,]*):", str(self)) - assert m is not None - m2m_str = m.group(1) - if m2m_str == "": - return () - return tuple(int(x) for x in m2m_str.split(",")) - -class LayoutRequest: +class AutoLayout: def __repr__(self): - return "Request a layout from the compiler" + return "AUTO" -AUTO = LayoutRequest() +AUTO = AutoLayout() diff --git a/jax/_src/pjit.py b/jax/_src/pjit.py index 5c8d386c5..c3d902a71 100644 --- a/jax/_src/pjit.py +++ b/jax/_src/pjit.py @@ -435,6 +435,8 @@ def _make_jit_wrapper(jit_info: PjitInfo): try: in_shardings = _resolve_in_shardings( args_flat, params['in_shardings'], params['out_shardings'], mesh) + in_layouts_flat = _resolve_in_layouts( + args_flat, in_layouts_flat, in_shardings) lowering = _pjit_lower( params['jaxpr'], in_shardings, params['out_shardings'], params['resource_env'], params['donated_invars'], params['name'], @@ -1130,7 +1132,6 @@ def explain_tracing_cache_miss( p("explanation unavailable! please open an issue at https://github.com/google/jax") return done() - @partial(lu.cache, explain=explain_tracing_cache_miss) def _create_pjit_jaxpr(fun, in_type, debug_info, out_paths, ignored_inline): del ignored_inline # just for explain_cache_miss @@ -1264,6 +1265,35 @@ pjit_p = core.AxisPrimitive("pjit") pjit_p.multiple_results = True +def _resolve_in_layouts(args, jit_in_layouts, jit_in_shardings): + # If device or backend is set, return the default layout. This is because you + # can pass arrays on cpu (with untiled layouts) to jit with backend='tpu' + # which causes error checks to fail. Returning the default layout allows + # this to exist. It's the same for handling shardings. + if pxla.check_device_backend_on_shardings(jit_in_shardings): + return (None,) * len(jit_in_layouts) + + resolved_in_layouts = [] + for arg, jit_in_l in safe_zip(args, jit_in_layouts): + arg_layout, committed = ( + (arg.layout, getattr(arg, '_committed', True)) + if getattr(arg, 'layout', None) is not None else (None, False)) + if jit_in_l is None: + if committed: + resolved_in_layouts.append(arg_layout) + else: + resolved_in_layouts.append(None) + else: + if committed and arg_layout != jit_in_l: + raise ValueError('Layout passed to jit does not match the layout ' + 'on the respective arg. 
' + f'Got pjit layout: {jit_in_l},\n' + f'arg sharding: {arg_layout} for ' + f'arg shape: {shaped_abstractify(arg).str_short()}') + resolved_in_layouts.append(jit_in_l) + return tuple(resolved_in_layouts) + + def _resolve_in_shardings( args, pjit_in_shardings: Sequence[PjitSharding], out_shardings: Sequence[PjitSharding], @@ -1387,8 +1417,9 @@ def _pjit_call_impl_python( _most_recent_pjit_call_executable.weak_key_dict[jaxpr] = compiled # This check is expensive so only do it if enable_checks is on. if compiled._auto_spmd_lowering and config.enable_checks.value: - pxla.check_gda_or_array_xla_sharding_match(args, compiled._in_shardings, - jaxpr.jaxpr.debug_info) + pxla.check_array_xla_sharding_layout_match( + args, compiled._in_shardings, compiled._in_layouts, + jaxpr.jaxpr.debug_info) if config.distributed_debug.value: # Defensively only perform fingerprint logic if debug logging is enabled # NOTE(skyewm): I didn't benchmark this diff --git a/tests/layout_test.py b/tests/layout_test.py index 7a7168ee7..6bf7e3087 100644 --- a/tests/layout_test.py +++ b/tests/layout_test.py @@ -14,10 +14,12 @@ import math import os +import re from absl.testing import absltest import numpy as np import jax +import jax.numpy as jnp from jax.sharding import NamedSharding, PartitionSpec as P from jax._src import config from jax._src import layout @@ -50,6 +52,15 @@ def tearDownModule(): xla_bridge.get_backend.cache_clear() +pattern = re.compile(r"\{(.*?):") + +# Extract minor_to_major from str(layout) because layout doesn't have a +# minor_to_major property yet. +def extract_minor_to_major(l): + match = re.search(pattern, str(l)) + return tuple(int(i) for i in match.groups()[0].split(',')) + + class LayoutTest(jtu.JaxTestCase): def setUp(self): @@ -60,9 +71,11 @@ class LayoutTest(jtu.JaxTestCase): super().setUp() def test_auto_layout(self): - mesh = jtu.create_global_mesh((4, 2), ('x', 'y')) + mesh = jtu.create_global_mesh((2, 2), ('x', 'y')) shape1 = (128, 128) shape2 = (128, 128) + s1 = NamedSharding(mesh, P('x', 'y')) + s2 = NamedSharding(mesh, P('x')) def apply(x, y): return x.T, y.T @@ -71,70 +84,89 @@ class LayoutTest(jtu.JaxTestCase): return x * 2, y * 2 np_inp1 = np.arange(math.prod(shape1)).reshape(shape1) - arr1 = jax.device_put(np_inp1, NamedSharding(mesh, P('x', 'y'))) np_inp2 = np.arange(math.prod(shape2)).reshape(shape2) - arr2 = jax.device_put(np_inp2, NamedSharding(mesh, P('x'))) + sds1 = jax.ShapeDtypeStruct(np_inp1.shape, np_inp1.dtype, sharding=s1) + sds2 = jax.ShapeDtypeStruct(np_inp2.shape, np_inp2.dtype, sharding=s2) - lowered_apply = jax.jit(apply).lower(arr1, arr2, _in_layouts=layout.AUTO, - _out_layouts=layout.AUTO) + lowered_apply = jax.jit(apply).lower( + sds1, sds2, _in_layouts=layout.AUTO, _out_layouts=layout.AUTO) compiled_apply = lowered_apply.compile() arg_layouts, kw_layouts = compiled_apply._input_layouts() self.assertEmpty(kw_layouts) + for i, o in zip(arg_layouts, compiled_apply._output_layouts()): - self.assertEqual(i._minor_to_major, o._minor_to_major[::-1]) + self.assertEqual(extract_minor_to_major(i), + extract_minor_to_major(o)[::-1]) init_compiled = jax.jit(init).lower( - arr1, arr2, _out_layouts=arg_layouts).compile() + sds1, sds2, _out_layouts=arg_layouts).compile() for i, o in zip(init_compiled._input_layouts()[0], init_compiled._output_layouts()): - self.assertEqual(i._minor_to_major, o._minor_to_major) + self.assertEqual(i, o) + + arr1 = jax.device_put(np_inp1, s1) + arr2 = jax.device_put(np_inp2, s2) with jtu.count_aot_jit_cpp_cache_miss() as init_count: init_out = 
init_compiled(arr1, arr2) init_compiled(arr1, arr2) self.assertEqual(init_count[0], 1) + self.assertEqual(init_out[0].layout, init_compiled._output_layouts()[0]) + self.assertEqual(init_out[1].layout, init_compiled._output_layouts()[0]) + with jtu.count_aot_jit_cpp_cache_miss() as apply_count: apply_out = compiled_apply(*init_out) compiled_apply(*init_out) self.assertEqual(apply_count[0], 1) + self.assertEqual(apply_out[0].layout, compiled_apply._output_layouts()[0]) + self.assertEqual(apply_out[1].layout, compiled_apply._output_layouts()[1]) + + self.assertTupleEqual(extract_minor_to_major(apply_out[0].layout), + extract_minor_to_major(init_out[0].layout)[::-1]) + self.assertTupleEqual(extract_minor_to_major(apply_out[1].layout), + extract_minor_to_major(init_out[1].layout)[::-1]) + self.assertArraysEqual(init_out[0], np_inp1 * 2) self.assertArraysEqual(init_out[1], np_inp2 * 2) self.assertArraysEqual(apply_out[0], (np_inp1 * 2).T) self.assertArraysEqual(apply_out[1], (np_inp2 * 2).T) def test_default_layout(self): - mesh = jtu.create_global_mesh((4, 2), ('x', 'y')) - shape = (8, 4, 2) + mesh = jtu.create_global_mesh((2, 2), ('x', 'y')) + shape = (4, 4, 2) np_inp = np.arange(math.prod(shape)).reshape(shape) s = NamedSharding(mesh, P('x', 'y')) + sds = jax.ShapeDtypeStruct(np_inp.shape, np_inp.dtype, sharding=s) arr = jax.device_put(np_inp, s) def f(x): return x.T - lowered = jax.jit(f).lower(arr, _in_layouts=None, _out_layouts=None) + lowered = jax.jit(f).lower(sds, _in_layouts=None, _out_layouts=None) self.assertIn("default", lowered.as_text()) compiled = lowered.compile() out = compiled(arr) - self.assertTupleEqual(compiled._input_layouts()[0][0]._minor_to_major, (2, 1, 0)) - self.assertTupleEqual(compiled._output_layouts()._minor_to_major, (2, 1, 0)) + self.assertTupleEqual( + extract_minor_to_major(compiled._input_layouts()[0][0]), (2, 1, 0)) + self.assertTupleEqual( + extract_minor_to_major(compiled._output_layouts()), (2, 1, 0)) self.assertArraysEqual(out, np_inp.T) self.assertEqual(out.sharding, NamedSharding(mesh, P(None, 'y', 'x'))) - compiled_auto = jax.jit(f).lower(arr, _in_layouts=layout.AUTO, + compiled_auto = jax.jit(f).lower(sds, _in_layouts=layout.AUTO, _out_layouts=layout.AUTO).compile() - self.assertTupleEqual(compiled_auto._input_layouts()[0][0]._minor_to_major, - (2, 1, 0)) - self.assertTupleEqual(compiled_auto._output_layouts()._minor_to_major, - (0, 1, 2)) + self.assertTupleEqual( + extract_minor_to_major(compiled_auto._input_layouts()[0][0]), (2, 1, 0)) + self.assertTupleEqual( + extract_minor_to_major(compiled_auto._output_layouts()), (0, 1, 2)) def test_in_layouts_out_layouts(self): - mesh = jtu.create_global_mesh((4, 2), ('x', 'y')) + mesh = jtu.create_global_mesh((2, 2), ('x', 'y')) shape = (8, 8) np_inp = np.arange(math.prod(shape)).reshape(shape) s = NamedSharding(mesh, P('x', 'y')) @@ -142,17 +174,21 @@ class LayoutTest(jtu.JaxTestCase): def f(x): return x.T + compiled = jax.jit(f).lower( arr, _in_layouts=None, _out_layouts=layout.AUTO).compile() - self.assertTupleEqual(compiled._input_layouts()[0][0]._minor_to_major, (1, 0)) - self.assertTupleEqual(compiled._output_layouts()._minor_to_major, (0, 1)) + self.assertTupleEqual( + extract_minor_to_major(compiled._input_layouts()[0][0]), (1, 0)) + self.assertTupleEqual( + extract_minor_to_major(compiled._output_layouts()), (0, 1)) out = compiled(arr) self.assertArraysEqual(out, np_inp.T) + self.assertEqual(out.layout, compiled._output_layouts()) self.assertEqual(out.sharding, NamedSharding(mesh, P('y', 'x'))) def 
test_sharding_and_layouts(self): - mesh = jtu.create_global_mesh((4, 2), ('x', 'y')) + mesh = jtu.create_global_mesh((2, 2), ('x', 'y')) shape = (4, 8) np_inp = np.arange(math.prod(shape)).reshape(shape) s = NamedSharding(mesh, P('x', 'y')) @@ -160,8 +196,10 @@ class LayoutTest(jtu.JaxTestCase): compiled = jax.jit(lambda x: x.T, in_shardings=s, out_shardings=s).lower( np_inp, _in_layouts=layout.AUTO, _out_layouts=layout.AUTO).compile() out = compiled(np_inp) - self.assertTupleEqual(compiled._input_layouts()[0][0]._minor_to_major, (1, 0)) - self.assertTupleEqual(compiled._output_layouts()._minor_to_major, (0, 1)) + self.assertTupleEqual( + extract_minor_to_major(compiled._input_layouts()[0][0]), (1, 0)) + self.assertTupleEqual( + extract_minor_to_major(compiled._output_layouts()), (0, 1)) self.assertArraysEqual(out, np_inp.T) self.assertEqual(out.sharding, s) @@ -188,6 +226,39 @@ class LayoutTest(jtu.JaxTestCase): # TODO(yashkatariya, frostig): Also use the arg_layouts to create an Array # and then pass that back into compiled. + def test_aot_layout_mismatch(self): + mesh = jtu.create_global_mesh((2, 2), ('x', 'y')) + shape = (256, 4, 2) + np_inp = np.arange(math.prod(shape)).reshape(shape) + s = NamedSharding(mesh, P('x')) + + sds = jax.ShapeDtypeStruct(np_inp.shape, np_inp.dtype, sharding=s) + arr = jax.device_put(np_inp, s) + + def f(x): + return (x * 2).T + + with self.assertRaisesRegex( + ValueError, + 'Layout passed to jit does not match the layout on the respective arg'): + jax.jit(f).lower(arr, _in_layouts=layout.AUTO) + + compiled = jax.jit(f).lower( + sds, _in_layouts=layout.AUTO, _out_layouts=layout.AUTO).compile() + + with self.assertRaisesRegex( + ValueError, + r'Compiled object called with input layout\(s\) does' + r' not match the layout\(s\) the computation was' + ' compiled with'): + compiled(arr) + + def test_cpu_default_backend_layout(self): + out_cpu = jax.jit(jnp.dot, backend='cpu')(np.ones((8, 8)), np.ones((8, 8))) + + jax.jit(jnp.dot, backend=jax.default_backend()).lower( + out_cpu, out_cpu).compile() # doesn't crash + if __name__ == '__main__': absltest.main(testLoader=jtu.JaxTestLoader()) diff --git a/tests/memories_test.py b/tests/memories_test.py index e2d390599..02cb12b79 100644 --- a/tests/memories_test.py +++ b/tests/memories_test.py @@ -198,8 +198,10 @@ class ShardingMemoriesTest(jtu.JaxTestCase): self.assertEqual(dev.default_memory().kind, "device") def test_parameter_streaming(self): - _, s_host, _, inp_host = _create_inputs( - (8, 2), P("x", "y"), mem_kind="unpinned_host") + self.skipTest("Enable after pinned_host support exists") + + _, s_host, np_inp, inp_host = _create_inputs( + (8, 2), P("x", "y"), mem_kind="pinned_host") s_dev = s_host.with_memory_kind('device') inp_dev = jax.device_put(inp_host, s_dev) diff --git a/tests/pjit_test.py b/tests/pjit_test.py index ed197498f..42e03a978 100644 --- a/tests/pjit_test.py +++ b/tests/pjit_test.py @@ -2848,6 +2848,20 @@ class ArrayPjitTest(jtu.JaxTestCase): self.assertEqual(compiled._executable._kept_var_idx, {5}) self.assertLen(compiled._executable.in_avals, 1) + def test_pjit_relayout_multi_slice(self): + mesh = jtu.create_global_mesh((2, 2), ('x', 'y')) + + @jax.jit + def mul(x): + return x @ x.T + + x = jnp.arange(8).reshape(4, 2) + y = jax.device_put(x, jax.sharding.NamedSharding(mesh, P('x', 'y'))) + compiled = mul.lower(jax.ShapeDtypeStruct( + y.shape, y.dtype, sharding=y.sharding)).compile() + out = compiled(y) + self.assertArraysEqual(out, x @ x.T) + def test_pjit_with_device_arg(self): def mul(x): 
return x @ x.T
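
Usage note: the end-to-end flow added by this patch is easiest to see in tests/layout_test.py. Below is a minimal sketch in the spirit of test_in_layouts_out_layouts, not a stable API: the _in_layouts/_out_layouts arguments to lower(), the _input_layouts()/_output_layouts() accessors on the compiled object, and jtu.create_global_mesh are all private or test-only names taken from the diff above, and the mesh assumes at least four devices are available (the test module arranges for multiple CPU devices).

# Minimal sketch mirroring test_in_layouts_out_layouts above; private APIs,
# names may change.
import math
import re

import numpy as np
import jax
from jax.sharding import NamedSharding, PartitionSpec as P
from jax._src import layout
from jax._src import test_util as jtu  # test-only helper, used here for the mesh


def extract_minor_to_major(l):
  # Same trick as the test helper: SpecifiedLayout no longer exposes a
  # _minor_to_major property, so parse it out of str(layout),
  # e.g. "{1,0:...}" -> (1, 0).
  return tuple(
      int(i) for i in re.search(r"\{(.*?):", str(l)).group(1).split(","))


mesh = jtu.create_global_mesh((2, 2), ('x', 'y'))
s = NamedSharding(mesh, P('x', 'y'))
shape = (8, 8)
np_inp = np.arange(math.prod(shape)).reshape(shape)
arr = jax.device_put(np_inp, s)

# Keep the default input layout but let the compiler choose the output layout.
compiled = jax.jit(lambda x: x.T).lower(
    arr, _in_layouts=None, _out_layouts=layout.AUTO).compile()

out = compiled(arr)
# The new ArrayImpl.layout property returns a SpecifiedLayout (or None on
# runtimes that report layouts as UNIMPLEMENTED), and it matches what the
# executable reports for its output.
assert out.layout == compiled._output_layouts()
print(extract_minor_to_major(compiled._input_layouts()[0][0]),   # e.g. (1, 0)
      extract_minor_to_major(compiled._output_layouts()))        # e.g. (0, 1)

# Calling a compiled object with an array whose committed layout differs from
# the layout it was compiled for now raises the new "Compiled object called
# with input layout(s) does not match ..." ValueError; see
# test_aot_layout_mismatch above for that path.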