Mirror of https://github.com/ROCm/jax.git (synced 2025-04-18 12:56:07 +00:00)
[Take 2] Expose .layout on jax.Array. Also add checks in the AOT path to make sure that the input Array's layout matches the layout given to jax.jit.
Reverts cd79e71d85621a8d6dede9a710bdb2a29bb380fd
PiperOrigin-RevId: 618878870
Parent: b9e699f554
Commit: 25d01e983c
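A minimal sketch of the user-facing behavior this change introduces, pieced together from the tests added at the end of this diff. It assumes a multi-device test backend; jtu.create_global_mesh is the internal test helper those tests use, and _in_layouts/_out_layouts are the private .lower() arguments they exercise, so treat every name and shape here as illustrative rather than public API:

import math
import numpy as np
import jax
from jax.sharding import NamedSharding, PartitionSpec as P
from jax._src import layout
from jax._src import test_util as jtu

mesh = jtu.create_global_mesh((2, 2), ('x', 'y'))
s = NamedSharding(mesh, P('x'))
np_inp = np.arange(math.prod((256, 4, 2))).reshape(256, 4, 2)
arr = jax.device_put(np_inp, s)

# New: a committed jax.Array exposes the layout it actually has on device.
print(arr.layout)   # a SpecifiedLayout, or None where the runtime reports UNIMPLEMENTED

# New AOT check: compile for compiler-chosen (AUTO) layouts, then call with an
# Array whose layout was fixed by device_put; a mismatch now raises instead of
# being accepted silently.
sds = jax.ShapeDtypeStruct(np_inp.shape, np_inp.dtype, sharding=s)
compiled = jax.jit(lambda x: (x * 2).T).lower(
    sds, _in_layouts=layout.AUTO, _out_layouts=layout.AUTO).compile()
compiled(arr)   # ValueError: Compiled object called with input layout(s) does not match ...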
@@ -34,10 +34,12 @@ from jax._src import deprecations
 from jax._src import dispatch
 from jax._src import dtypes
 from jax._src import errors
+from jax._src import layout
 from jax._src import profiler
 from jax._src import tree_util
 from jax._src import xla_bridge
 from jax._src.lib import xla_client as xc
+from jax._src.lib import xla_extension as xe
 from jax._src.interpreters import mlir
 from jax._src.interpreters import pxla
 from jax._src.interpreters import xla
@@ -527,6 +529,18 @@ class ArrayImpl(basearray.Array):
       out.append(Shard(_get_device(a), self.sharding, self.shape, a))
     return out

+  @property
+  def layout(self):
+    # TODO(yashkatariya): Remove the try;except when pathways supports layouts.
+    try:
+      return layout.SpecifiedLayout(self._pjrt_layout)
+    except xe.XlaRuntimeError as e:
+      msg, *_ = e.args
+      if type(msg) is str and msg.startswith("UNIMPLEMENTED"):
+        return None
+      else:
+        raise
+
   @property
   def global_shards(self) -> Sequence[Shard]:
     """Returns list of all `Shard`s of the Array across all devices.
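With this property in place, layouts become observable from Python and can be compared against what a compiled executable reports. A hedged fragment continuing the sketch after the commit header (arr_with_matching_layout is a placeholder argument whose layout agrees with the compiled object; _output_layouts() is the internal accessor used by the tests below, not public API):

out = compiled(arr_with_matching_layout)
if out.layout is not None:   # None on runtimes (e.g. Pathways) that report UNIMPLEMENTED
  # Single-output case, compared directly as in the tests below.
  assert out.layout == compiled._output_layouts()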
@@ -637,7 +651,7 @@ if not TYPE_CHECKING:
   ArrayImpl = use_cpp_class(xc.ArrayImpl)(ArrayImpl)


-# explicitly set to be unhashable. Same as what device_array.py does.
+# explicitly set to be unhashable.
 setattr(ArrayImpl, "__hash__", None)
 setattr(ArrayImpl, "__array_priority__", 100)
@@ -46,7 +46,7 @@ from jax._src import util
 from jax._src import xla_bridge as xb
 from jax._src.interpreters import partial_eval as pe
 from jax._src.interpreters import xla
-from jax._src.layout import XLACompatibleLayout, LayoutRequest
+from jax._src.layout import AutoLayout, SpecifiedLayout
 from jax._src.lib import xla_client as xc
 from jax._src.lib import xla_extension
 from jax._src.lib import xla_extension_version
@@ -834,10 +834,10 @@ def _to_physical_op_sharding(
   return sharding._to_xla_hlo_sharding(aval.ndim).to_proto()  # type: ignore


-def _to_xla_layout(layout: XLACompatibleLayout | None | LayoutRequest) -> str | None:
+def _to_xla_layout(layout: SpecifiedLayout | None | AutoLayout) -> str | None:
   if layout is None:
     return "default"
-  if isinstance(layout, LayoutRequest):
+  if isinstance(layout, AutoLayout):
     return "auto"
   return layout._to_xla_layout()
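Put differently, a layout request reaches XLA as one of three strings. A hedged illustration of the mapping implemented above (some_specified is a placeholder for a SpecifiedLayout instance; the exact serialized string is backend dependent):

_to_xla_layout(None)            # -> "default"  (keep the backend's default layout)
_to_xla_layout(layout.AUTO)     # -> "auto"     (let the compiler choose a layout)
_to_xla_layout(some_specified)  # -> some_specified._to_xla_layout(), the layout's string form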
@@ -862,8 +862,8 @@ def lower_jaxpr_to_module(
     replicated_args: Sequence[bool] | None = None,
     arg_shardings: Sequence[XLACompatibleSharding | None] | None = None,
     result_shardings: Sequence[XLACompatibleSharding | None] | None = None,
-    in_layouts: Sequence[XLACompatibleLayout | None | LayoutRequest] | None = None,
-    out_layouts: Sequence[XLACompatibleLayout | None | LayoutRequest] | None = None,
+    in_layouts: Sequence[SpecifiedLayout | None | AutoLayout] | None = None,
+    out_layouts: Sequence[SpecifiedLayout | None | AutoLayout] | None = None,
     arg_names: Sequence[str | None] | None = None,
     result_names: Sequence[str | None] | None = None,
     num_replicas: int = 1,
@@ -60,7 +60,7 @@ from jax._src.interpreters import batching
 from jax._src.interpreters import partial_eval as pe
 from jax._src.interpreters import mlir
 from jax._src.interpreters import xla
-from jax._src.layout import XLACompatibleLayout, SpecifiedLayout, LayoutRequest
+from jax._src.layout import SpecifiedLayout, AutoLayout
 from jax._src.lib import xla_client as xc
 from jax._src.lib import xla_extension_version
 from jax._src.lib.mlir import ir
@@ -1985,13 +1985,14 @@ def are_all_shardings_default_mem_kind(da_object, shardings):
       return False
   return True

-MaybeLayout = Sequence[Union[XLACompatibleLayout, LayoutRequest, None]]
+MaybeLayout = Sequence[Union[SpecifiedLayout, AutoLayout, None]]


 class AllArgsInfo(NamedTuple):
   """Avals, shardings, layouts and debug_info for all arguments prior to DCE."""
   in_avals: Sequence[core.ShapedArray]
   in_shardings: Any
+  in_layouts: Any
   debug_info: core.JaxprDebugInfo | None

@@ -2023,7 +2024,7 @@ def lower_sharding_computation(
   auto_spmd_lowering = check_if_any_auto(
       it.chain.from_iterable([in_shardings, out_shardings]))  # type: ignore

-  all_args_info = AllArgsInfo(global_in_avals, in_shardings,
+  all_args_info = AllArgsInfo(global_in_avals, in_shardings, in_layouts,
                               closed_jaxpr.jaxpr.debug_info)

   (closed_jaxpr, global_in_avals, global_out_avals, donated_invars,
@@ -2227,8 +2228,6 @@ def lower_mesh_computation(
   out_jaxpr_avals = fun_or_jaxpr.out_avals
   consts = fun_or_jaxpr.consts

-  all_args_info = AllArgsInfo(global_in_avals, in_shardings, jaxpr.debug_info)
-
   assert len(out_shardings) == len(out_jaxpr_avals)
   if spmd_lowering:
     global_out_avals = out_jaxpr_avals
@@ -2319,7 +2318,7 @@ def lower_mesh_computation(
       in_layouts=(None,) * len(global_in_avals),
       out_layouts=(None,) * len(global_out_avals),
       shape_poly_state=lowering_result.shape_poly_state,
-      all_args_info=all_args_info)
+      all_args_info=None)

 class MeshComputation(stages.XlaLowering):
   _hlo: ir.Module | None
@@ -2599,7 +2598,7 @@ def _get_layouts_from_executable(
     if isinstance(i, SpecifiedLayout):
       if i != x:
         raise AssertionError(
-            f"Unexpected XLA layout override: (XLA) {x} != {i} (User sharding)")
+            f"Unexpected XLA layout override: (XLA) {x} != {i} (User layout)")
       new_in_layouts.append(i)
     else:
       new_in_layouts.append(x)
@@ -2610,7 +2609,7 @@ def _get_layouts_from_executable(
     if isinstance(o, SpecifiedLayout):
       if o != x:
         raise AssertionError(
-            f"Unexpected XLA layout override: (XLA) {x} != {o} (User sharding)")
+            f"Unexpected XLA layout override: (XLA) {x} != {o} (User layout)")
       new_out_layouts.append(o)
     else:
       new_out_layouts.append(x)
@@ -3016,6 +3015,7 @@ class MeshExecutable(stages.XlaExecutable):
       kept_args = [a for i, a in enumerate(args) if i in self._kept_var_idx]
       ref_avals = self.in_avals
       in_shardings = self._in_shardings
+      in_layouts = self._in_layouts
       debug_info = None
     else:
       kept_args = args
@@ -3023,12 +3023,16 @@ class MeshExecutable(stages.XlaExecutable):
       iter_in_shardings = iter(self._in_shardings)
       in_shardings = [next(iter_in_shardings) if i in self._kept_var_idx else s
                       for i, s in enumerate(self._all_args_info.in_shardings)]
+      iter_in_layouts = iter(self._in_layouts)
+      in_layouts = [next(iter_in_layouts) if i in self._kept_var_idx else s
+                    for i, s in enumerate(self._all_args_info.in_layouts)]
       debug_info = self._all_args_info.debug_info

     arg_avals = map(xla.abstractify, kept_args)
     check_arg_avals_for_call(ref_avals, arg_avals, debug_info)
-    # Check the GDA sharding and the input sharding.
-    check_gda_or_array_xla_sharding_match(kept_args, in_shardings, debug_info)
+    check_array_xla_sharding_layout_match(kept_args, in_shardings,
+                                          in_layouts, debug_info)
     return self.unsafe_call(*args)  # pylint: disable=not-callable

   def input_shardings(self) -> Sequence[sharding_impls.XLACompatibleSharding]:
@@ -3184,15 +3188,17 @@ def check_device_backend_on_shardings(shardings) -> bool:
   return False


-def check_gda_or_array_xla_sharding_match(
+def check_array_xla_sharding_layout_match(
     args, in_xla_shardings: Sequence[sharding_impls.XLACompatibleSharding],
+    in_xla_layouts: Sequence[SpecifiedLayout],
     jaxpr_debug_info: core.JaxprDebugInfo | None) -> None:
   from jax._src.array import ArrayImpl
   arg_names = ([''] * len(args) if jaxpr_debug_info is None else
                jaxpr_debug_info.arg_names)
   errors = []
   num_errors = 5
-  for arg, xs, name in safe_zip(args, in_xla_shardings, arg_names):
+  for arg, xs, xl, name in safe_zip(args, in_xla_shardings, in_xla_layouts,
+                                    arg_names):
     if not isinstance(arg, ArrayImpl):
       continue
     if is_unspecified_or_auto(xs):
@@ -3205,27 +3211,45 @@ def check_array_xla_sharding_layout_match(
     # Raise memory kind mismatch error even if the arg is uncommitted.
     if arg.sharding.memory_kind != xs.memory_kind:
       errors.append(
-          "Got input sharding(s) that compiled object was called with: "
+          ("Got input sharding(s) that compiled object was called with: "
           f"{arg.sharding} and sharding(s) the computation was compiled "
-          f"with: {xs} for arg {name} with shape: {arg.aval.str_short()}")
+           f"with: {xs} for arg {name} with shape: {arg.aval.str_short()}",
+           'sharding'))

     if (not db_xs and arg._committed and
         not op_shardings.are_op_shardings_equal(
             arg.sharding._to_xla_hlo_sharding(arg.ndim),
             xs._to_xla_hlo_sharding(arg.ndim))):
       errors.append(
-          "Got input sharding(s) that compiled object was called with: "
+          ("Got input sharding(s) that compiled object was called with: "
          f"{arg.sharding} and sharding(s) the computation was compiled "
-          f"with: {xs} for arg {name} with shape: {arg.aval.str_short()}")
+           f"with: {xs} for arg {name} with shape: {arg.aval.str_short()}",
+           'sharding'))

+    if (xla_extension_version >= 249 and not db_xs and arg._committed and
+        arg.layout is not None and xl is not None and arg.layout != xl):
+      errors.append(
+          ("Got input layout(s) that compiled object was called with: "
+           f"{arg.layout} and layout(s) the computation was compiled "
+           f"with: {xl} for arg {name} with shape: {arg.aval.str_short()}",
+           'layout'))
+
   if errors:
-    str_errors = '\n'.join(errors[:num_errors])
+    first_errors, error_kinds = unzip2(errors[:num_errors])
+    str_errors = '\n'.join(first_errors)
+    if all(k == 'sharding' for k in error_kinds):
+      kind_str = r'sharding(s)'
+    elif all(k == 'layout' for k in error_kinds):
+      kind_str = 'layout(s)'
+    else:
+      kind_str = 'sharding(s) and layout(s)'
     num_mismatch_str = (
         f'the {len(errors)} mismatches' if len(errors) < num_errors else
         f"{num_errors} mismatches out of {len(errors)}")
     raise ValueError(
-        "Compiled object called with input sharding(s) does not match the "
-        "sharding(s) the computation was compiled with. "
+        f"Compiled object called with input {kind_str} does "
+        f"not match the {kind_str} the computation was "
+        "compiled with. "
        f"Here are {num_mismatch_str}:\n{str_errors}")
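Each recorded mismatch is now a (message, kind) pair so the final ValueError can say whether shardings, layouts, or both were wrong. A small self-contained sketch of that selection step (plain zip(*...) stands in for jax's internal unzip2 helper used above; the messages are made up):

errors = [("sharding mismatch for arg x ...", 'sharding'),
          ("layout mismatch for arg y ...", 'layout')]
first_errors, error_kinds = zip(*errors)      # what unzip2 does here
if all(k == 'sharding' for k in error_kinds):
  kind_str = 'sharding(s)'
elif all(k == 'layout' for k in error_kinds):
  kind_str = 'layout(s)'
else:
  kind_str = 'sharding(s) and layout(s)'
print(kind_str)   # -> 'sharding(s) and layout(s)' for this mixed example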
@@ -14,8 +14,6 @@

 from __future__ import annotations

-import re
-
 from jax._src.lib import xla_client as xc

@@ -24,16 +22,10 @@ class Layout:
   pass


-class XLACompatibleLayout(Layout):
+class SpecifiedLayout(Layout):
+  layout: xc.PjRtLayout

-  def _to_xla_layout(self) -> str:
-    raise NotImplementedError("Subclasses should implement this method.")
-
-
-class SpecifiedLayout(XLACompatibleLayout):
-  layout: xc.Layout
-
-  def __init__(self, layout: xc.Layout):
+  def __init__(self, layout: xc.PjRtLayout):
     self._layout = layout
     self._layout_str = str(self._layout)
@@ -51,19 +43,10 @@ class SpecifiedLayout(XLACompatibleLayout):
   def _to_xla_layout(self) -> str:
     return self._layout_str

-  @property
-  def _minor_to_major(self):
-    m = re.search("{([0-9,]*):", str(self))
-    assert m is not None
-    m2m_str = m.group(1)
-    if m2m_str == "":
-      return ()
-    return tuple(int(x) for x in m2m_str.split(","))


-class LayoutRequest:
+class AutoLayout:

   def __repr__(self):
-    return "Request a layout from the compiler"
+    return "AUTO"

-AUTO = LayoutRequest()
+AUTO = AutoLayout()
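Read together, the layout.py hunks leave the module with one concrete layout type and one sentinel. A hedged reconstruction of the resulting module, condensed from the lines shown above (anything the diff does not show is omitted):

from jax._src.lib import xla_client as xc


class Layout:
  pass


class SpecifiedLayout(Layout):
  layout: xc.PjRtLayout

  def __init__(self, layout: xc.PjRtLayout):
    self._layout = layout
    self._layout_str = str(self._layout)

  def _to_xla_layout(self) -> str:
    return self._layout_str


class AutoLayout:

  def __repr__(self):
    return "AUTO"


# Sentinel passed as _in_layouts/_out_layouts to request compiler-chosen layouts.
AUTO = AutoLayout()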
@@ -435,6 +435,8 @@ def _make_jit_wrapper(jit_info: PjitInfo):
     try:
       in_shardings = _resolve_in_shardings(
           args_flat, params['in_shardings'], params['out_shardings'], mesh)
+      in_layouts_flat = _resolve_in_layouts(
+          args_flat, in_layouts_flat, in_shardings)
       lowering = _pjit_lower(
           params['jaxpr'], in_shardings, params['out_shardings'],
           params['resource_env'], params['donated_invars'], params['name'],
@@ -1130,7 +1132,6 @@ def explain_tracing_cache_miss(
   p("explanation unavailable! please open an issue at https://github.com/google/jax")
   return done()


 @partial(lu.cache, explain=explain_tracing_cache_miss)
 def _create_pjit_jaxpr(fun, in_type, debug_info, out_paths, ignored_inline):
   del ignored_inline  # just for explain_cache_miss
@@ -1264,6 +1265,35 @@ pjit_p = core.AxisPrimitive("pjit")
 pjit_p.multiple_results = True


+def _resolve_in_layouts(args, jit_in_layouts, jit_in_shardings):
+  # If device or backend is set, return the default layout. This is because you
+  # can pass arrays on cpu (with untiled layouts) to jit with backend='tpu'
+  # which causes error checks to fail. Returning the default layout allows
+  # this to exist. It's the same for handling shardings.
+  if pxla.check_device_backend_on_shardings(jit_in_shardings):
+    return (None,) * len(jit_in_layouts)
+
+  resolved_in_layouts = []
+  for arg, jit_in_l in safe_zip(args, jit_in_layouts):
+    arg_layout, committed = (
+        (arg.layout, getattr(arg, '_committed', True))
+        if getattr(arg, 'layout', None) is not None else (None, False))
+    if jit_in_l is None:
+      if committed:
+        resolved_in_layouts.append(arg_layout)
+      else:
+        resolved_in_layouts.append(None)
+    else:
+      if committed and arg_layout != jit_in_l:
+        raise ValueError('Layout passed to jit does not match the layout '
+                         'on the respective arg. '
+                         f'Got pjit layout: {jit_in_l},\n'
+                         f'arg sharding: {arg_layout} for '
+                         f'arg shape: {shaped_abstractify(arg).str_short()}')
+      resolved_in_layouts.append(jit_in_l)
+  return tuple(resolved_in_layouts)
+
+
 def _resolve_in_shardings(
     args, pjit_in_shardings: Sequence[PjitSharding],
     out_shardings: Sequence[PjitSharding],
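One consequence of this resolution rule, exercised by test_aot_layout_mismatch later in the diff: a concrete, committed Array cannot be lowered with _in_layouts=layout.AUTO, because its layout is already fixed and differs from the AUTO sentinel, while an abstract ShapeDtypeStruct can. A hedged sketch (mesh, sharding and the function are illustrative, mirroring that test):

arr = jax.device_put(np_inp, NamedSharding(mesh, P('x')))   # committed; carries a layout
jax.jit(lambda x: (x * 2).T).lower(arr, _in_layouts=layout.AUTO)
# ValueError: Layout passed to jit does not match the layout on the respective arg. ...

sds = jax.ShapeDtypeStruct(np_inp.shape, np_inp.dtype, sharding=arr.sharding)
jax.jit(lambda x: (x * 2).T).lower(sds, _in_layouts=layout.AUTO)   # fine: no concrete layout to conflict with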
@@ -1387,7 +1417,8 @@ def _pjit_call_impl_python(
   _most_recent_pjit_call_executable.weak_key_dict[jaxpr] = compiled
   # This check is expensive so only do it if enable_checks is on.
   if compiled._auto_spmd_lowering and config.enable_checks.value:
-    pxla.check_gda_or_array_xla_sharding_match(args, compiled._in_shardings,
+    pxla.check_array_xla_sharding_layout_match(
+        args, compiled._in_shardings, compiled._in_layouts,
         jaxpr.jaxpr.debug_info)
   if config.distributed_debug.value:
     # Defensively only perform fingerprint logic if debug logging is enabled
@@ -14,10 +14,12 @@

 import math
 import os
+import re
 from absl.testing import absltest
 import numpy as np

 import jax
 import jax.numpy as jnp
 from jax.sharding import NamedSharding, PartitionSpec as P
 from jax._src import config
 from jax._src import layout
@@ -50,6 +52,15 @@ def tearDownModule():
   xla_bridge.get_backend.cache_clear()


+pattern = re.compile(r"\{(.*?):")
+
+# Extract minor_to_major from str(layout) because layout doesn't have a
+# minor_to_major property yet.
+def extract_minor_to_major(l):
+  match = re.search(pattern, str(l))
+  return tuple(int(i) for i in match.groups()[0].split(','))
+
+
 class LayoutTest(jtu.JaxTestCase):

   def setUp(self):
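The helper just pulls the leading minor-to-major digits out of a layout's string form. A hedged illustration (the sample string imitates an XLA tiled-layout repr such as a TPU backend might print; the exact format is backend dependent):

import re

pattern = re.compile(r"\{(.*?):")
sample = "{1,0:T(8,128)}"                                  # hypothetical str(layout)
m = re.search(pattern, sample)
print(tuple(int(i) for i in m.groups()[0].split(',')))     # -> (1, 0)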
@@ -60,9 +71,11 @@ class LayoutTest(jtu.JaxTestCase):
     super().setUp()

   def test_auto_layout(self):
-    mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
+    mesh = jtu.create_global_mesh((2, 2), ('x', 'y'))
     shape1 = (128, 128)
     shape2 = (128, 128)
+    s1 = NamedSharding(mesh, P('x', 'y'))
+    s2 = NamedSharding(mesh, P('x'))

     def apply(x, y):
       return x.T, y.T
@@ -71,70 +84,89 @@ class LayoutTest(jtu.JaxTestCase):
       return x * 2, y * 2

     np_inp1 = np.arange(math.prod(shape1)).reshape(shape1)
-    arr1 = jax.device_put(np_inp1, NamedSharding(mesh, P('x', 'y')))
     np_inp2 = np.arange(math.prod(shape2)).reshape(shape2)
-    arr2 = jax.device_put(np_inp2, NamedSharding(mesh, P('x')))
+    sds1 = jax.ShapeDtypeStruct(np_inp1.shape, np_inp1.dtype, sharding=s1)
+    sds2 = jax.ShapeDtypeStruct(np_inp2.shape, np_inp2.dtype, sharding=s2)

-    lowered_apply = jax.jit(apply).lower(arr1, arr2, _in_layouts=layout.AUTO,
-                                         _out_layouts=layout.AUTO)
+    lowered_apply = jax.jit(apply).lower(
+        sds1, sds2, _in_layouts=layout.AUTO, _out_layouts=layout.AUTO)
     compiled_apply = lowered_apply.compile()

     arg_layouts, kw_layouts = compiled_apply._input_layouts()
     self.assertEmpty(kw_layouts)

     for i, o in zip(arg_layouts, compiled_apply._output_layouts()):
-      self.assertEqual(i._minor_to_major, o._minor_to_major[::-1])
+      self.assertEqual(extract_minor_to_major(i),
+                       extract_minor_to_major(o)[::-1])

     init_compiled = jax.jit(init).lower(
-        arr1, arr2, _out_layouts=arg_layouts).compile()
+        sds1, sds2, _out_layouts=arg_layouts).compile()

     for i, o in zip(init_compiled._input_layouts()[0],
                     init_compiled._output_layouts()):
-      self.assertEqual(i._minor_to_major, o._minor_to_major)
+      self.assertEqual(i, o)

+    arr1 = jax.device_put(np_inp1, s1)
+    arr2 = jax.device_put(np_inp2, s2)
+
     with jtu.count_aot_jit_cpp_cache_miss() as init_count:
       init_out = init_compiled(arr1, arr2)
       init_compiled(arr1, arr2)
     self.assertEqual(init_count[0], 1)

+    self.assertEqual(init_out[0].layout, init_compiled._output_layouts()[0])
+    self.assertEqual(init_out[1].layout, init_compiled._output_layouts()[0])
+
     with jtu.count_aot_jit_cpp_cache_miss() as apply_count:
       apply_out = compiled_apply(*init_out)
       compiled_apply(*init_out)
     self.assertEqual(apply_count[0], 1)

+    self.assertEqual(apply_out[0].layout, compiled_apply._output_layouts()[0])
+    self.assertEqual(apply_out[1].layout, compiled_apply._output_layouts()[1])
+
+    self.assertTupleEqual(extract_minor_to_major(apply_out[0].layout),
+                          extract_minor_to_major(init_out[0].layout)[::-1])
+    self.assertTupleEqual(extract_minor_to_major(apply_out[1].layout),
+                          extract_minor_to_major(init_out[1].layout)[::-1])
+
     self.assertArraysEqual(init_out[0], np_inp1 * 2)
     self.assertArraysEqual(init_out[1], np_inp2 * 2)
     self.assertArraysEqual(apply_out[0], (np_inp1 * 2).T)
     self.assertArraysEqual(apply_out[1], (np_inp2 * 2).T)

   def test_default_layout(self):
-    mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
-    shape = (8, 4, 2)
+    mesh = jtu.create_global_mesh((2, 2), ('x', 'y'))
+    shape = (4, 4, 2)
     np_inp = np.arange(math.prod(shape)).reshape(shape)
     s = NamedSharding(mesh, P('x', 'y'))
+    sds = jax.ShapeDtypeStruct(np_inp.shape, np_inp.dtype, sharding=s)
     arr = jax.device_put(np_inp, s)

     def f(x):
       return x.T

-    lowered = jax.jit(f).lower(arr, _in_layouts=None, _out_layouts=None)
+    lowered = jax.jit(f).lower(sds, _in_layouts=None, _out_layouts=None)
     self.assertIn("default", lowered.as_text())
     compiled = lowered.compile()
     out = compiled(arr)

-    self.assertTupleEqual(compiled._input_layouts()[0][0]._minor_to_major, (2, 1, 0))
-    self.assertTupleEqual(compiled._output_layouts()._minor_to_major, (2, 1, 0))
+    self.assertTupleEqual(
+        extract_minor_to_major(compiled._input_layouts()[0][0]), (2, 1, 0))
+    self.assertTupleEqual(
+        extract_minor_to_major(compiled._output_layouts()), (2, 1, 0))
     self.assertArraysEqual(out, np_inp.T)
     self.assertEqual(out.sharding, NamedSharding(mesh, P(None, 'y', 'x')))

-    compiled_auto = jax.jit(f).lower(arr, _in_layouts=layout.AUTO,
+    compiled_auto = jax.jit(f).lower(sds, _in_layouts=layout.AUTO,
                                      _out_layouts=layout.AUTO).compile()
-    self.assertTupleEqual(compiled_auto._input_layouts()[0][0]._minor_to_major,
-                          (2, 1, 0))
-    self.assertTupleEqual(compiled_auto._output_layouts()._minor_to_major,
-                          (0, 1, 2))
+    self.assertTupleEqual(
+        extract_minor_to_major(compiled_auto._input_layouts()[0][0]), (2, 1, 0))
+    self.assertTupleEqual(
+        extract_minor_to_major(compiled_auto._output_layouts()), (0, 1, 2))

   def test_in_layouts_out_layouts(self):
-    mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
+    mesh = jtu.create_global_mesh((2, 2), ('x', 'y'))
     shape = (8, 8)
     np_inp = np.arange(math.prod(shape)).reshape(shape)
     s = NamedSharding(mesh, P('x', 'y'))
@@ -142,17 +174,21 @@ class LayoutTest(jtu.JaxTestCase):

     def f(x):
       return x.T

     compiled = jax.jit(f).lower(
         arr, _in_layouts=None, _out_layouts=layout.AUTO).compile()
-    self.assertTupleEqual(compiled._input_layouts()[0][0]._minor_to_major, (1, 0))
-    self.assertTupleEqual(compiled._output_layouts()._minor_to_major, (0, 1))
+    self.assertTupleEqual(
+        extract_minor_to_major(compiled._input_layouts()[0][0]), (1, 0))
+    self.assertTupleEqual(
+        extract_minor_to_major(compiled._output_layouts()), (0, 1))

     out = compiled(arr)
     self.assertArraysEqual(out, np_inp.T)
+    self.assertEqual(out.layout, compiled._output_layouts())
     self.assertEqual(out.sharding, NamedSharding(mesh, P('y', 'x')))

   def test_sharding_and_layouts(self):
-    mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
+    mesh = jtu.create_global_mesh((2, 2), ('x', 'y'))
     shape = (4, 8)
     np_inp = np.arange(math.prod(shape)).reshape(shape)
     s = NamedSharding(mesh, P('x', 'y'))
@@ -160,8 +196,10 @@ class LayoutTest(jtu.JaxTestCase):
     compiled = jax.jit(lambda x: x.T, in_shardings=s, out_shardings=s).lower(
         np_inp, _in_layouts=layout.AUTO, _out_layouts=layout.AUTO).compile()
     out = compiled(np_inp)
-    self.assertTupleEqual(compiled._input_layouts()[0][0]._minor_to_major, (1, 0))
-    self.assertTupleEqual(compiled._output_layouts()._minor_to_major, (0, 1))
+    self.assertTupleEqual(
+        extract_minor_to_major(compiled._input_layouts()[0][0]), (1, 0))
+    self.assertTupleEqual(
+        extract_minor_to_major(compiled._output_layouts()), (0, 1))
     self.assertArraysEqual(out, np_inp.T)
     self.assertEqual(out.sharding, s)
@@ -188,6 +226,39 @@ class LayoutTest(jtu.JaxTestCase):
     # TODO(yashkatariya, frostig): Also use the arg_layouts to create an Array
     # and then pass that back into compiled.

+  def test_aot_layout_mismatch(self):
+    mesh = jtu.create_global_mesh((2, 2), ('x', 'y'))
+    shape = (256, 4, 2)
+    np_inp = np.arange(math.prod(shape)).reshape(shape)
+    s = NamedSharding(mesh, P('x'))
+
+    sds = jax.ShapeDtypeStruct(np_inp.shape, np_inp.dtype, sharding=s)
+    arr = jax.device_put(np_inp, s)
+
+    def f(x):
+      return (x * 2).T
+
+    with self.assertRaisesRegex(
+        ValueError,
+        'Layout passed to jit does not match the layout on the respective arg'):
+      jax.jit(f).lower(arr, _in_layouts=layout.AUTO)
+
+    compiled = jax.jit(f).lower(
+        sds, _in_layouts=layout.AUTO, _out_layouts=layout.AUTO).compile()
+
+    with self.assertRaisesRegex(
+        ValueError,
+        r'Compiled object called with input layout\(s\) does'
+        r' not match the layout\(s\) the computation was'
+        ' compiled with'):
+      compiled(arr)
+
+  def test_cpu_default_backend_layout(self):
+    out_cpu = jax.jit(jnp.dot, backend='cpu')(np.ones((8, 8)), np.ones((8, 8)))
+
+    jax.jit(jnp.dot, backend=jax.default_backend()).lower(
+        out_cpu, out_cpu).compile()  # doesn't crash
+

 if __name__ == '__main__':
   absltest.main(testLoader=jtu.JaxTestLoader())
@@ -198,8 +198,10 @@ class ShardingMemoriesTest(jtu.JaxTestCase):
     self.assertEqual(dev.default_memory().kind, "device")

   def test_parameter_streaming(self):
-    _, s_host, _, inp_host = _create_inputs(
-        (8, 2), P("x", "y"), mem_kind="unpinned_host")
+    self.skipTest("Enable after pinned_host support exists")
+
+    _, s_host, np_inp, inp_host = _create_inputs(
+        (8, 2), P("x", "y"), mem_kind="pinned_host")
     s_dev = s_host.with_memory_kind('device')
     inp_dev = jax.device_put(inp_host, s_dev)
@@ -2848,6 +2848,20 @@ class ArrayPjitTest(jtu.JaxTestCase):
     self.assertEqual(compiled._executable._kept_var_idx, {5})
     self.assertLen(compiled._executable.in_avals, 1)

+  def test_pjit_relayout_multi_slice(self):
+    mesh = jtu.create_global_mesh((2, 2), ('x', 'y'))
+
+    @jax.jit
+    def mul(x):
+      return x @ x.T
+
+    x = jnp.arange(8).reshape(4, 2)
+    y = jax.device_put(x, jax.sharding.NamedSharding(mesh, P('x', 'y')))
+    compiled = mul.lower(jax.ShapeDtypeStruct(
+        y.shape, y.dtype, sharding=y.sharding)).compile()
+    out = compiled(y)
+    self.assertArraysEqual(out, x @ x.T)
+
   def test_pjit_with_device_arg(self):
     def mul(x):
       return x @ x.T