[Take 2] Expose .layout on jax.Array. Also add checks in the AOT path to make sure that the input Array's layout matches the layout given to jax.jit.

Reverts cd79e71d85621a8d6dede9a710bdb2a29bb380fd

PiperOrigin-RevId: 618878870
Yash Katariya 2024-03-25 10:07:55 -07:00 committed by jax authors
parent b9e699f554
commit 25d01e983c
8 changed files with 215 additions and 76 deletions
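For orientation, a minimal sketch of the behavior this change adds, reconstructed from the test updates below. The _in_layouts/_out_layouts lowering parameters are JAX-internal and may change; the exact layout representation is backend-dependent.

import jax
import jax.numpy as jnp
from jax._src import layout  # internal module, used the same way in the tests below

x = jnp.arange(16.0).reshape(4, 4)
# New: jax.Array exposes its on-device layout (None on runtimes that cannot
# report layouts yet, e.g. Pathways).
print(x.layout)

# AOT path: compile with compiler-chosen layouts, then call the compiled object.
# If the input array's layout does not match the layout the computation was
# compiled with, the new check raises a ValueError.
compiled = jax.jit(lambda a: (a * 2).T).lower(
    jax.ShapeDtypeStruct(x.shape, x.dtype),
    _in_layouts=layout.AUTO, _out_layouts=layout.AUTO).compile()
out = compiled(x)  # ValueError here if x.layout mismatches the compiled layout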

jax/_src/array.py

@@ -34,10 +34,12 @@ from jax._src import deprecations
from jax._src import dispatch
from jax._src import dtypes
from jax._src import errors
from jax._src import layout
from jax._src import profiler
from jax._src import tree_util
from jax._src import xla_bridge
from jax._src.lib import xla_client as xc
from jax._src.lib import xla_extension as xe
from jax._src.interpreters import mlir
from jax._src.interpreters import pxla
from jax._src.interpreters import xla
@@ -527,6 +529,18 @@ class ArrayImpl(basearray.Array):
out.append(Shard(_get_device(a), self.sharding, self.shape, a))
return out
@property
def layout(self):
# TODO(yashkatariya): Remove the try;except when pathways supports layouts.
try:
return layout.SpecifiedLayout(self._pjrt_layout)
except xe.XlaRuntimeError as e:
msg, *_ = e.args
if type(msg) is str and msg.startswith("UNIMPLEMENTED"):
return None
else:
raise
@property
def global_shards(self) -> Sequence[Shard]:
"""Returns list of all `Shard`s of the Array across all devices.
@@ -637,7 +651,7 @@ if not TYPE_CHECKING:
ArrayImpl = use_cpp_class(xc.ArrayImpl)(ArrayImpl)
# explicitly set to be unhashable. Same as what device_array.py does.
# explicitly set to be unhashable.
setattr(ArrayImpl, "__hash__", None)
setattr(ArrayImpl, "__array_priority__", 100)

jax/_src/interpreters/mlir.py

@@ -46,7 +46,7 @@ from jax._src import util
from jax._src import xla_bridge as xb
from jax._src.interpreters import partial_eval as pe
from jax._src.interpreters import xla
from jax._src.layout import XLACompatibleLayout, LayoutRequest
from jax._src.layout import AutoLayout, SpecifiedLayout
from jax._src.lib import xla_client as xc
from jax._src.lib import xla_extension
from jax._src.lib import xla_extension_version
@@ -834,10 +834,10 @@ def _to_physical_op_sharding(
return sharding._to_xla_hlo_sharding(aval.ndim).to_proto() # type: ignore
def _to_xla_layout(layout: XLACompatibleLayout | None | LayoutRequest) -> str | None:
def _to_xla_layout(layout: SpecifiedLayout | None | AutoLayout) -> str | None:
if layout is None:
return "default"
if isinstance(layout, LayoutRequest):
if isinstance(layout, AutoLayout):
return "auto"
return layout._to_xla_layout()
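In effect, lowering serializes the layout specification into the string XLA consumes. A quick check of the mapping implemented by _to_xla_layout above (module paths are JAX-internal and may move):

from jax._src import layout
from jax._src.interpreters import mlir

assert mlir._to_xla_layout(None) == "default"      # use the backend's default layout
assert mlir._to_xla_layout(layout.AUTO) == "auto"  # let the compiler pick the layout
# A SpecifiedLayout serializes to its concrete XLA layout string, e.g. "{1,0}".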
@@ -862,8 +862,8 @@ def lower_jaxpr_to_module(
replicated_args: Sequence[bool] | None = None,
arg_shardings: Sequence[XLACompatibleSharding | None] | None = None,
result_shardings: Sequence[XLACompatibleSharding | None] | None = None,
in_layouts: Sequence[XLACompatibleLayout | None | LayoutRequest] | None = None,
out_layouts: Sequence[XLACompatibleLayout | None | LayoutRequest] | None = None,
in_layouts: Sequence[SpecifiedLayout | None | AutoLayout] | None = None,
out_layouts: Sequence[SpecifiedLayout | None | AutoLayout] | None = None,
arg_names: Sequence[str | None] | None = None,
result_names: Sequence[str | None] | None = None,
num_replicas: int = 1,

jax/_src/interpreters/pxla.py

@@ -60,7 +60,7 @@ from jax._src.interpreters import batching
from jax._src.interpreters import partial_eval as pe
from jax._src.interpreters import mlir
from jax._src.interpreters import xla
from jax._src.layout import XLACompatibleLayout, SpecifiedLayout, LayoutRequest
from jax._src.layout import SpecifiedLayout, AutoLayout
from jax._src.lib import xla_client as xc
from jax._src.lib import xla_extension_version
from jax._src.lib.mlir import ir
@@ -1985,13 +1985,14 @@ def are_all_shardings_default_mem_kind(da_object, shardings):
return False
return True
MaybeLayout = Sequence[Union[XLACompatibleLayout, LayoutRequest, None]]
MaybeLayout = Sequence[Union[SpecifiedLayout, AutoLayout, None]]
class AllArgsInfo(NamedTuple):
"""Avals, shardings, layouts and debug_info for all arguments prior to DCE."""
in_avals: Sequence[core.ShapedArray]
in_shardings: Any
in_layouts: Any
debug_info: core.JaxprDebugInfo | None
@@ -2023,7 +2024,7 @@ def lower_sharding_computation(
auto_spmd_lowering = check_if_any_auto(
it.chain.from_iterable([in_shardings, out_shardings])) # type: ignore
all_args_info = AllArgsInfo(global_in_avals, in_shardings,
all_args_info = AllArgsInfo(global_in_avals, in_shardings, in_layouts,
closed_jaxpr.jaxpr.debug_info)
(closed_jaxpr, global_in_avals, global_out_avals, donated_invars,
@@ -2227,8 +2228,6 @@ def lower_mesh_computation(
out_jaxpr_avals = fun_or_jaxpr.out_avals
consts = fun_or_jaxpr.consts
all_args_info = AllArgsInfo(global_in_avals, in_shardings, jaxpr.debug_info)
assert len(out_shardings) == len(out_jaxpr_avals)
if spmd_lowering:
global_out_avals = out_jaxpr_avals
@@ -2319,7 +2318,7 @@ def lower_mesh_computation(
in_layouts=(None,) * len(global_in_avals),
out_layouts=(None,) * len(global_out_avals),
shape_poly_state=lowering_result.shape_poly_state,
all_args_info=all_args_info)
all_args_info=None)
class MeshComputation(stages.XlaLowering):
_hlo: ir.Module | None
@@ -2599,7 +2598,7 @@ def _get_layouts_from_executable(
if isinstance(i, SpecifiedLayout):
if i != x:
raise AssertionError(
f"Unexpected XLA layout override: (XLA) {x} != {i} (User sharding)")
f"Unexpected XLA layout override: (XLA) {x} != {i} (User layout)")
new_in_layouts.append(i)
else:
new_in_layouts.append(x)
@@ -2610,7 +2609,7 @@ def _get_layouts_from_executable(
if isinstance(o, SpecifiedLayout):
if o != x:
raise AssertionError(
f"Unexpected XLA layout override: (XLA) {x} != {o} (User sharding)")
f"Unexpected XLA layout override: (XLA) {x} != {o} (User layout)")
new_out_layouts.append(o)
else:
new_out_layouts.append(x)
@@ -3016,6 +3015,7 @@ class MeshExecutable(stages.XlaExecutable):
kept_args = [a for i, a in enumerate(args) if i in self._kept_var_idx]
ref_avals = self.in_avals
in_shardings = self._in_shardings
in_layouts = self._in_layouts
debug_info = None
else:
kept_args = args
@@ -3023,12 +3023,16 @@ class MeshExecutable(stages.XlaExecutable):
iter_in_shardings = iter(self._in_shardings)
in_shardings = [next(iter_in_shardings) if i in self._kept_var_idx else s
for i, s in enumerate(self._all_args_info.in_shardings)]
iter_in_layouts = iter(self._in_layouts)
in_layouts = [next(iter_in_layouts) if i in self._kept_var_idx else s
for i, s in enumerate(self._all_args_info.in_layouts)]
debug_info = self._all_args_info.debug_info
arg_avals = map(xla.abstractify, kept_args)
check_arg_avals_for_call(ref_avals, arg_avals, debug_info)
# Check the GDA sharding and the input sharding.
check_gda_or_array_xla_sharding_match(kept_args, in_shardings, debug_info)
check_array_xla_sharding_layout_match(kept_args, in_shardings,
in_layouts, debug_info)
return self.unsafe_call(*args) # pylint: disable=not-callable
def input_shardings(self) -> Sequence[sharding_impls.XLACompatibleSharding]:
@@ -3184,15 +3188,17 @@ def check_device_backend_on_shardings(shardings) -> bool:
return False
def check_gda_or_array_xla_sharding_match(
def check_array_xla_sharding_layout_match(
args, in_xla_shardings: Sequence[sharding_impls.XLACompatibleSharding],
in_xla_layouts: Sequence[SpecifiedLayout],
jaxpr_debug_info: core.JaxprDebugInfo | None) -> None:
from jax._src.array import ArrayImpl
arg_names = ([''] * len(args) if jaxpr_debug_info is None else
jaxpr_debug_info.arg_names)
errors = []
num_errors = 5
for arg, xs, name in safe_zip(args, in_xla_shardings, arg_names):
for arg, xs, xl, name in safe_zip(args, in_xla_shardings, in_xla_layouts,
arg_names):
if not isinstance(arg, ArrayImpl):
continue
if is_unspecified_or_auto(xs):
@@ -3205,27 +3211,45 @@ def check_gda_or_array_xla_sharding_match(
# Raise memory kind mismatch error even if the arg is uncommitted.
if arg.sharding.memory_kind != xs.memory_kind:
errors.append(
"Got input sharding(s) that compiled object was called with: "
("Got input sharding(s) that compiled object was called with: "
f"{arg.sharding} and sharding(s) the computation was compiled "
f"with: {xs} for arg {name} with shape: {arg.aval.str_short()}")
f"with: {xs} for arg {name} with shape: {arg.aval.str_short()}",
'sharding'))
if (not db_xs and arg._committed and
not op_shardings.are_op_shardings_equal(
arg.sharding._to_xla_hlo_sharding(arg.ndim),
xs._to_xla_hlo_sharding(arg.ndim))):
errors.append(
"Got input sharding(s) that compiled object was called with: "
("Got input sharding(s) that compiled object was called with: "
f"{arg.sharding} and sharding(s) the computation was compiled "
f"with: {xs} for arg {name} with shape: {arg.aval.str_short()}")
f"with: {xs} for arg {name} with shape: {arg.aval.str_short()}",
'sharding'))
if (xla_extension_version >= 249 and not db_xs and arg._committed and
arg.layout is not None and xl is not None and arg.layout != xl):
errors.append(
("Got input layout(s) that compiled object was called with: "
f"{arg.layout} and layout(s) the computation was compiled "
f"with: {xl} for arg {name} with shape: {arg.aval.str_short()}",
'layout'))
if errors:
str_errors = '\n'.join(errors[:num_errors])
first_errors, error_kinds = unzip2(errors[:num_errors])
str_errors = '\n'.join(first_errors)
if all(k == 'sharding' for k in error_kinds):
kind_str = r'sharding(s)'
elif all(k == 'layout' for k in error_kinds):
kind_str = 'layout(s)'
else:
kind_str = 'sharding(s) and layout(s)'
num_mismatch_str = (
f'the {len(errors)} mismatches' if len(errors) < num_errors else
f"{num_errors} mismatches out of {len(errors)}")
raise ValueError(
"Compiled object called with input sharding(s) does not match the "
"sharding(s) the computation was compiled with. "
f"Compiled object called with input {kind_str} does "
f"not match the {kind_str} the computation was "
"compiled with. "
f"Here are {num_mismatch_str}:\n{str_errors}")

jax/_src/layout.py

@@ -14,8 +14,6 @@
from __future__ import annotations
import re
from jax._src.lib import xla_client as xc
@@ -24,16 +22,10 @@ class Layout:
pass
class XLACompatibleLayout(Layout):
class SpecifiedLayout(Layout):
layout: xc.PjRtLayout
def _to_xla_layout(self) -> str:
raise NotImplementedError("Subclasses should implement this method.")
class SpecifiedLayout(XLACompatibleLayout):
layout: xc.Layout
def __init__(self, layout: xc.Layout):
def __init__(self, layout: xc.PjRtLayout):
self._layout = layout
self._layout_str = str(self._layout)
@@ -51,19 +43,10 @@ class SpecifiedLayout(XLACompatibleLayout):
def _to_xla_layout(self) -> str:
return self._layout_str
@property
def _minor_to_major(self):
m = re.search("{([0-9,]*):", str(self))
assert m is not None
m2m_str = m.group(1)
if m2m_str == "":
return ()
return tuple(int(x) for x in m2m_str.split(","))
class LayoutRequest:
class AutoLayout:
def __repr__(self):
return "Request a layout from the compiler"
return "AUTO"
AUTO = LayoutRequest()
AUTO = AutoLayout()

jax/_src/pjit.py

@@ -435,6 +435,8 @@ def _make_jit_wrapper(jit_info: PjitInfo):
try:
in_shardings = _resolve_in_shardings(
args_flat, params['in_shardings'], params['out_shardings'], mesh)
in_layouts_flat = _resolve_in_layouts(
args_flat, in_layouts_flat, in_shardings)
lowering = _pjit_lower(
params['jaxpr'], in_shardings, params['out_shardings'],
params['resource_env'], params['donated_invars'], params['name'],
@@ -1130,7 +1132,6 @@ def explain_tracing_cache_miss(
p("explanation unavailable! please open an issue at https://github.com/google/jax")
return done()
@partial(lu.cache, explain=explain_tracing_cache_miss)
def _create_pjit_jaxpr(fun, in_type, debug_info, out_paths, ignored_inline):
del ignored_inline # just for explain_cache_miss
@@ -1264,6 +1265,35 @@ pjit_p = core.AxisPrimitive("pjit")
pjit_p.multiple_results = True
def _resolve_in_layouts(args, jit_in_layouts, jit_in_shardings):
# If device or backend is set, return the default layout. This is because you
# can pass arrays on cpu (with untiled layouts) to jit with backend='tpu'
# which causes error checks to fail. Returning the default layout allows
# this to exist. It's the same for handling shardings.
if pxla.check_device_backend_on_shardings(jit_in_shardings):
return (None,) * len(jit_in_layouts)
resolved_in_layouts = []
for arg, jit_in_l in safe_zip(args, jit_in_layouts):
arg_layout, committed = (
(arg.layout, getattr(arg, '_committed', True))
if getattr(arg, 'layout', None) is not None else (None, False))
if jit_in_l is None:
if committed:
resolved_in_layouts.append(arg_layout)
else:
resolved_in_layouts.append(None)
else:
if committed and arg_layout != jit_in_l:
raise ValueError('Layout passed to jit does not match the layout '
'on the respective arg. '
f'Got pjit layout: {jit_in_l},\n'
f'arg layout: {arg_layout} for '
f'arg shape: {shaped_abstractify(arg).str_short()}')
resolved_in_layouts.append(jit_in_l)
return tuple(resolved_in_layouts)
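Distilled to a single argument, the rules above amount to the following standalone sketch (resolve_one is an illustrative helper, not part of the change):

def resolve_one(arg_layout, arg_committed, jit_layout):
    # No layout requested in jit: a committed arg keeps its own layout,
    # an uncommitted arg falls back to the default (None).
    if jit_layout is None:
        return arg_layout if arg_committed else None
    # A requested layout must agree with a committed arg's existing layout.
    if arg_committed and arg_layout != jit_layout:
        raise ValueError("Layout passed to jit does not match the layout on the arg")
    return jit_layout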
def _resolve_in_shardings(
args, pjit_in_shardings: Sequence[PjitSharding],
out_shardings: Sequence[PjitSharding],
@@ -1387,7 +1417,8 @@ def _pjit_call_impl_python(
_most_recent_pjit_call_executable.weak_key_dict[jaxpr] = compiled
# This check is expensive so only do it if enable_checks is on.
if compiled._auto_spmd_lowering and config.enable_checks.value:
pxla.check_gda_or_array_xla_sharding_match(args, compiled._in_shardings,
pxla.check_array_xla_sharding_layout_match(
args, compiled._in_shardings, compiled._in_layouts,
jaxpr.jaxpr.debug_info)
if config.distributed_debug.value:
# Defensively only perform fingerprint logic if debug logging is enabled

tests/layout_test.py

@@ -14,10 +14,12 @@
import math
import os
import re
from absl.testing import absltest
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import NamedSharding, PartitionSpec as P
from jax._src import config
from jax._src import layout
@@ -50,6 +52,15 @@ def tearDownModule():
xla_bridge.get_backend.cache_clear()
pattern = re.compile(r"\{(.*?):")
# Extract minor_to_major from str(layout) because layout doesn't have a
# minor_to_major property yet.
def extract_minor_to_major(l):
match = re.search(pattern, str(l))
return tuple(int(i) for i in match.groups()[0].split(','))
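For illustration, assuming a TPU-style layout string such as "{1,0:T(8,128)}" (an assumption; the exact format is backend-dependent), the helper pulls out the leading minor-to-major tuple:

class _FakeLayout:
    def __str__(self):
        return "{1,0:T(8,128)}"  # hypothetical layout string for this example

assert extract_minor_to_major(_FakeLayout()) == (1, 0)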
class LayoutTest(jtu.JaxTestCase):
def setUp(self):
@@ -60,9 +71,11 @@ class LayoutTest(jtu.JaxTestCase):
super().setUp()
def test_auto_layout(self):
mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
mesh = jtu.create_global_mesh((2, 2), ('x', 'y'))
shape1 = (128, 128)
shape2 = (128, 128)
s1 = NamedSharding(mesh, P('x', 'y'))
s2 = NamedSharding(mesh, P('x'))
def apply(x, y):
return x.T, y.T
@@ -71,70 +84,89 @@ class LayoutTest(jtu.JaxTestCase):
return x * 2, y * 2
np_inp1 = np.arange(math.prod(shape1)).reshape(shape1)
arr1 = jax.device_put(np_inp1, NamedSharding(mesh, P('x', 'y')))
np_inp2 = np.arange(math.prod(shape2)).reshape(shape2)
arr2 = jax.device_put(np_inp2, NamedSharding(mesh, P('x')))
sds1 = jax.ShapeDtypeStruct(np_inp1.shape, np_inp1.dtype, sharding=s1)
sds2 = jax.ShapeDtypeStruct(np_inp2.shape, np_inp2.dtype, sharding=s2)
lowered_apply = jax.jit(apply).lower(arr1, arr2, _in_layouts=layout.AUTO,
_out_layouts=layout.AUTO)
lowered_apply = jax.jit(apply).lower(
sds1, sds2, _in_layouts=layout.AUTO, _out_layouts=layout.AUTO)
compiled_apply = lowered_apply.compile()
arg_layouts, kw_layouts = compiled_apply._input_layouts()
self.assertEmpty(kw_layouts)
for i, o in zip(arg_layouts, compiled_apply._output_layouts()):
self.assertEqual(i._minor_to_major, o._minor_to_major[::-1])
self.assertEqual(extract_minor_to_major(i),
extract_minor_to_major(o)[::-1])
init_compiled = jax.jit(init).lower(
arr1, arr2, _out_layouts=arg_layouts).compile()
sds1, sds2, _out_layouts=arg_layouts).compile()
for i, o in zip(init_compiled._input_layouts()[0],
init_compiled._output_layouts()):
self.assertEqual(i._minor_to_major, o._minor_to_major)
self.assertEqual(i, o)
arr1 = jax.device_put(np_inp1, s1)
arr2 = jax.device_put(np_inp2, s2)
with jtu.count_aot_jit_cpp_cache_miss() as init_count:
init_out = init_compiled(arr1, arr2)
init_compiled(arr1, arr2)
self.assertEqual(init_count[0], 1)
self.assertEqual(init_out[0].layout, init_compiled._output_layouts()[0])
self.assertEqual(init_out[1].layout, init_compiled._output_layouts()[1])
with jtu.count_aot_jit_cpp_cache_miss() as apply_count:
apply_out = compiled_apply(*init_out)
compiled_apply(*init_out)
self.assertEqual(apply_count[0], 1)
self.assertEqual(apply_out[0].layout, compiled_apply._output_layouts()[0])
self.assertEqual(apply_out[1].layout, compiled_apply._output_layouts()[1])
self.assertTupleEqual(extract_minor_to_major(apply_out[0].layout),
extract_minor_to_major(init_out[0].layout)[::-1])
self.assertTupleEqual(extract_minor_to_major(apply_out[1].layout),
extract_minor_to_major(init_out[1].layout)[::-1])
self.assertArraysEqual(init_out[0], np_inp1 * 2)
self.assertArraysEqual(init_out[1], np_inp2 * 2)
self.assertArraysEqual(apply_out[0], (np_inp1 * 2).T)
self.assertArraysEqual(apply_out[1], (np_inp2 * 2).T)
def test_default_layout(self):
mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
shape = (8, 4, 2)
mesh = jtu.create_global_mesh((2, 2), ('x', 'y'))
shape = (4, 4, 2)
np_inp = np.arange(math.prod(shape)).reshape(shape)
s = NamedSharding(mesh, P('x', 'y'))
sds = jax.ShapeDtypeStruct(np_inp.shape, np_inp.dtype, sharding=s)
arr = jax.device_put(np_inp, s)
def f(x):
return x.T
lowered = jax.jit(f).lower(arr, _in_layouts=None, _out_layouts=None)
lowered = jax.jit(f).lower(sds, _in_layouts=None, _out_layouts=None)
self.assertIn("default", lowered.as_text())
compiled = lowered.compile()
out = compiled(arr)
self.assertTupleEqual(compiled._input_layouts()[0][0]._minor_to_major, (2, 1, 0))
self.assertTupleEqual(compiled._output_layouts()._minor_to_major, (2, 1, 0))
self.assertTupleEqual(
extract_minor_to_major(compiled._input_layouts()[0][0]), (2, 1, 0))
self.assertTupleEqual(
extract_minor_to_major(compiled._output_layouts()), (2, 1, 0))
self.assertArraysEqual(out, np_inp.T)
self.assertEqual(out.sharding, NamedSharding(mesh, P(None, 'y', 'x')))
compiled_auto = jax.jit(f).lower(arr, _in_layouts=layout.AUTO,
compiled_auto = jax.jit(f).lower(sds, _in_layouts=layout.AUTO,
_out_layouts=layout.AUTO).compile()
self.assertTupleEqual(compiled_auto._input_layouts()[0][0]._minor_to_major,
(2, 1, 0))
self.assertTupleEqual(compiled_auto._output_layouts()._minor_to_major,
(0, 1, 2))
self.assertTupleEqual(
extract_minor_to_major(compiled_auto._input_layouts()[0][0]), (2, 1, 0))
self.assertTupleEqual(
extract_minor_to_major(compiled_auto._output_layouts()), (0, 1, 2))
def test_in_layouts_out_layouts(self):
mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
mesh = jtu.create_global_mesh((2, 2), ('x', 'y'))
shape = (8, 8)
np_inp = np.arange(math.prod(shape)).reshape(shape)
s = NamedSharding(mesh, P('x', 'y'))
@@ -142,17 +174,21 @@ class LayoutTest(jtu.JaxTestCase):
def f(x):
return x.T
compiled = jax.jit(f).lower(
arr, _in_layouts=None, _out_layouts=layout.AUTO).compile()
self.assertTupleEqual(compiled._input_layouts()[0][0]._minor_to_major, (1, 0))
self.assertTupleEqual(compiled._output_layouts()._minor_to_major, (0, 1))
self.assertTupleEqual(
extract_minor_to_major(compiled._input_layouts()[0][0]), (1, 0))
self.assertTupleEqual(
extract_minor_to_major(compiled._output_layouts()), (0, 1))
out = compiled(arr)
self.assertArraysEqual(out, np_inp.T)
self.assertEqual(out.layout, compiled._output_layouts())
self.assertEqual(out.sharding, NamedSharding(mesh, P('y', 'x')))
def test_sharding_and_layouts(self):
mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
mesh = jtu.create_global_mesh((2, 2), ('x', 'y'))
shape = (4, 8)
np_inp = np.arange(math.prod(shape)).reshape(shape)
s = NamedSharding(mesh, P('x', 'y'))
@@ -160,8 +196,10 @@ class LayoutTest(jtu.JaxTestCase):
compiled = jax.jit(lambda x: x.T, in_shardings=s, out_shardings=s).lower(
np_inp, _in_layouts=layout.AUTO, _out_layouts=layout.AUTO).compile()
out = compiled(np_inp)
self.assertTupleEqual(compiled._input_layouts()[0][0]._minor_to_major, (1, 0))
self.assertTupleEqual(compiled._output_layouts()._minor_to_major, (0, 1))
self.assertTupleEqual(
extract_minor_to_major(compiled._input_layouts()[0][0]), (1, 0))
self.assertTupleEqual(
extract_minor_to_major(compiled._output_layouts()), (0, 1))
self.assertArraysEqual(out, np_inp.T)
self.assertEqual(out.sharding, s)
@@ -188,6 +226,39 @@ class LayoutTest(jtu.JaxTestCase):
# TODO(yashkatariya, frostig): Also use the arg_layouts to create an Array
# and then pass that back into compiled.
def test_aot_layout_mismatch(self):
mesh = jtu.create_global_mesh((2, 2), ('x', 'y'))
shape = (256, 4, 2)
np_inp = np.arange(math.prod(shape)).reshape(shape)
s = NamedSharding(mesh, P('x'))
sds = jax.ShapeDtypeStruct(np_inp.shape, np_inp.dtype, sharding=s)
arr = jax.device_put(np_inp, s)
def f(x):
return (x * 2).T
with self.assertRaisesRegex(
ValueError,
'Layout passed to jit does not match the layout on the respective arg'):
jax.jit(f).lower(arr, _in_layouts=layout.AUTO)
compiled = jax.jit(f).lower(
sds, _in_layouts=layout.AUTO, _out_layouts=layout.AUTO).compile()
with self.assertRaisesRegex(
ValueError,
r'Compiled object called with input layout\(s\) does'
r' not match the layout\(s\) the computation was'
' compiled with'):
compiled(arr)
def test_cpu_default_backend_layout(self):
out_cpu = jax.jit(jnp.dot, backend='cpu')(np.ones((8, 8)), np.ones((8, 8)))
jax.jit(jnp.dot, backend=jax.default_backend()).lower(
out_cpu, out_cpu).compile() # doesn't crash
if __name__ == '__main__':
absltest.main(testLoader=jtu.JaxTestLoader())

tests/memories_test.py

@@ -198,8 +198,10 @@ class ShardingMemoriesTest(jtu.JaxTestCase):
self.assertEqual(dev.default_memory().kind, "device")
def test_parameter_streaming(self):
_, s_host, _, inp_host = _create_inputs(
(8, 2), P("x", "y"), mem_kind="unpinned_host")
self.skipTest("Enable after pinned_host support exists")
_, s_host, np_inp, inp_host = _create_inputs(
(8, 2), P("x", "y"), mem_kind="pinned_host")
s_dev = s_host.with_memory_kind('device')
inp_dev = jax.device_put(inp_host, s_dev)

tests/pjit_test.py

@@ -2848,6 +2848,20 @@ class ArrayPjitTest(jtu.JaxTestCase):
self.assertEqual(compiled._executable._kept_var_idx, {5})
self.assertLen(compiled._executable.in_avals, 1)
def test_pjit_relayout_multi_slice(self):
mesh = jtu.create_global_mesh((2, 2), ('x', 'y'))
@jax.jit
def mul(x):
return x @ x.T
x = jnp.arange(8).reshape(4, 2)
y = jax.device_put(x, jax.sharding.NamedSharding(mesh, P('x', 'y')))
compiled = mul.lower(jax.ShapeDtypeStruct(
y.shape, y.dtype, sharding=y.sharding)).compile()
out = compiled(y)
self.assertArraysEqual(out, x @ x.T)
def test_pjit_with_device_arg(self):
def mul(x):
return x @ x.T