Add Layout support to jax.jit.

`jax.jit` now accepts `Layout` instances in the `in_shardings` and `out_shardings` arguments. The major changes are just plumbing `in_layouts` and `out_layouts` through everywhere.
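
A minimal sketch of the new usage, based on the tests updated in this change. The import path is an assumption (something like `jax.experimental.layout`); everything else mirrors the test code below:

    # Sketch only, not authoritative: assumes Layout and DeviceLocalLayout are
    # importable from jax.experimental.layout (the exact path may differ).
    import jax
    import jax.numpy as jnp
    from jax.experimental.layout import Layout, DeviceLocalLayout as DLL

    def f(x):
      return x.T

    x = jnp.arange(16.0).reshape(4, 4)

    # Layout instances go straight into in_shardings/out_shardings; DLL.AUTO
    # lets the compiler pick the device-local layouts.
    compiled = jax.jit(
        f,
        in_shardings=Layout(DLL.AUTO),
        out_shardings=Layout(DLL.AUTO),
    ).lower(x).compile()

    # The chosen layouts can be inspected on the compiled object.
    arg_layouts, _ = compiled.input_layouts()
    out_layouts = compiled.output_layouts()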

Note that the public API is `Layout(device_local_layout, sharding)`, which is how users will pass us the Layout, but internally we split it apart into `device_local_layout` and `sharding`.
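
A rough sketch of the combined form, mirroring the `test_sharding_and_layouts` update below (again assuming the `jax.experimental.layout` import path; the toy 1x1 mesh is only so the snippet runs on a single device):

    # Sketch only: Layout(device_local_layout, sharding) carries both pieces;
    # jit splits them apart internally.
    import numpy as np
    import jax
    from jax.sharding import Mesh, NamedSharding, PartitionSpec as P
    from jax.experimental.layout import Layout, DeviceLocalLayout as DLL

    mesh = Mesh(np.array(jax.devices()[:1]).reshape(1, 1), ('x', 'y'))
    s = NamedSharding(mesh, P('x', 'y'))
    x = np.arange(16.0).reshape(4, 4)

    compiled = jax.jit(
        lambda a: a.T,
        in_shardings=Layout(DLL.AUTO, s),
        out_shardings=Layout(DLL.AUTO, s),
    ).lower(x).compile()
    out = compiled(x)

    # Note: passing a DeviceLocalLayout directly (not wrapped in a Layout)
    # raises a ValueError, per this change.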

Docs are coming on how to use the API, what Layouts mean, and how to make sense of them (especially on TPU).

PiperOrigin-RevId: 622352537
Yash Katariya 2024-04-05 20:08:48 -07:00 committed by jax authors
parent f88139bf67
commit c125442644
15 changed files with 390 additions and 217 deletions


@ -34,7 +34,6 @@ from jax._src import deprecations
from jax._src import dispatch
from jax._src import dtypes
from jax._src import errors
from jax._src import layout
from jax._src import profiler
from jax._src import tree_util
from jax._src import xla_bridge
@ -47,6 +46,7 @@ from jax._src.sharding import Sharding
from jax._src.sharding_impls import (
SingleDeviceSharding, XLACompatibleSharding, PmapSharding,
device_replica_id_map, hashed_index)
from jax._src.layout import DeviceLocalLayout, Layout
from jax._src.typing import ArrayLike
from jax._src.util import safe_zip, unzip3, use_cpp_class, use_cpp_method
@ -529,15 +529,17 @@ class ArrayImpl(basearray.Array):
out.append(Shard(_get_device(a), self.sharding, self.shape, a))
return out
@property
@functools.cached_property
def layout(self):
# TODO(yashkatariya): Remove the deleted check from here.
if self.is_deleted():
return Layout(None, self.sharding)
try:
return layout.Layout(layout.DeviceLocalLayout(self._pjrt_layout),
self.sharding)
return Layout(DeviceLocalLayout(self._pjrt_layout), self.sharding)
except xe.XlaRuntimeError as e:
msg, *_ = e.args
if type(msg) is str and msg.startswith("UNIMPLEMENTED"):
return layout.Layout(None, self.sharding)
return Layout(None, self.sharding)
else:
raise


@ -895,9 +895,9 @@ def while_loop_error_check(error, enabled_errors, *in_flat, cond_nconsts,
error_checks[lax.while_p] = while_loop_error_check
def pjit_error_check(error, enabled_errors, *vals_in, jaxpr,
in_shardings, out_shardings, resource_env,
donated_invars, name,
inline, keep_unused):
in_shardings, out_shardings,
in_layouts, out_layouts,
resource_env, donated_invars, name, inline, keep_unused):
# jaxpr to checked_jaxpr
err_vals, err_tree = jtu.tree_flatten(error)
new_vals_in = [*err_vals, *vals_in]
@ -908,10 +908,12 @@ def pjit_error_check(error, enabled_errors, *vals_in, jaxpr,
# Update pjit params to account for extra error values.
num_error_vals = len(err_vals)
num_out_error_vals = out_tree.num_leaves - len(out_shardings)
sharding = sharding_impls.UNSPECIFIED
sharding = sharding_impls.UNSPECIFIED
new_in_shardings = (*[sharding] * num_error_vals, *in_shardings)
new_out_shardings = (*[sharding] * num_out_error_vals, *out_shardings)
new_in_layouts = (*[None] * num_error_vals, *in_layouts)
new_out_layouts = (*[None] * num_out_error_vals, *out_layouts)
new_donated_invars = (*[False] * num_error_vals, *donated_invars)
err_and_out = pjit.pjit_p.bind(
@ -919,6 +921,8 @@ def pjit_error_check(error, enabled_errors, *vals_in, jaxpr,
jaxpr=checked_jaxpr,
in_shardings=new_in_shardings,
out_shardings=new_out_shardings,
in_layouts=new_in_layouts,
out_layouts=new_out_layouts,
resource_env=resource_env,
donated_invars=new_donated_invars,
name=name,


@ -452,10 +452,7 @@ def _device_put_impl(
return x
if x_dll is None and dll is None:
return _device_put_sharding_impl(x, aval, l.sharding)
# TODO(yashkatariya): Pass layout to out_shardings directly and remove
# out_layouts from lower.
return api.jit(_identity_fn, out_shardings=l.sharding).lower(
x, _out_layouts=l).compile()(x)
return api.jit(_identity_fn, out_shardings=l)(x)
return _device_put_sharding_impl(x, aval, device)


@ -954,11 +954,6 @@ def lower_jaxpr_to_module(
else:
dim_vars = ()
arg_layouts = (map(_to_xla_layout, in_layouts) if in_layouts is not None
else in_layouts)
result_layouts = (map(_to_xla_layout, out_layouts) if out_layouts is not None
else out_layouts)
ctx = ModuleContext(backend_or_name=backend_or_name,
platforms=platforms, axis_context=axis_context,
keepalives=keepalives,
@ -992,8 +987,8 @@ def lower_jaxpr_to_module(
result_names=result_names,
arg_memory_kinds=arg_memory_kinds,
result_memory_kinds=result_memory_kinds,
arg_layouts=arg_layouts,
result_layouts=result_layouts)
arg_layouts=in_layouts,
result_layouts=out_layouts)
try:
if not ctx.module.operation.verify():
@ -1140,8 +1135,8 @@ def lower_jaxpr_to_fun(
result_names: Sequence[str | None] | None = None,
arg_memory_kinds: Sequence[str | None] | None = None,
result_memory_kinds: Sequence[str | None] | None = None,
arg_layouts: Sequence[str | None] | None = None,
result_layouts: Sequence[str | None] | None = None,
arg_layouts: Sequence[DeviceLocalLayout | None | AutoLayout] | None = None,
result_layouts: Sequence[DeviceLocalLayout | None | AutoLayout] | None = None,
) -> func_dialect.FuncOp:
"""Lowers jaxpr and its callees to an IR function.
@ -1262,7 +1257,8 @@ def lower_jaxpr_to_fun(
ir_arg_layouts = None
if arg_layouts is not None:
ir_arg_layouts = util.flatten(
[[l] * len(types) for l, types in zip(arg_layouts, input_types)])
[[_to_xla_layout(l)] * len(types)
for l, types in zip(arg_layouts, input_types)])
ir_donated_args = None
if xla_donated_args is not None:
@ -1285,7 +1281,8 @@ def lower_jaxpr_to_fun(
ir_result_layouts = None
if result_layouts is not None:
ir_result_layouts = util.flatten(
[[l] * len(types) for l, types in zip(result_layouts, output_types)])
[[_to_xla_layout(l)] * len(types)
for l, types in zip(result_layouts, output_types)])
if (
replicated_args is not None


@ -2035,15 +2035,15 @@ def lower_sharding_computation(
fun_name: str,
in_shardings: Sequence[MaybeSharding],
out_shardings: Sequence[MaybeSharding],
in_layouts: MaybeLayout,
out_layouts: MaybeLayout,
donated_invars: Sequence[bool],
global_in_avals: Sequence[core.ShapedArray],
*,
keep_unused: bool,
inline: bool,
devices_from_context: Sequence[xc.Device] | None = None,
lowering_parameters: mlir.LoweringParameters,
in_layouts: MaybeLayout,
out_layouts: MaybeLayout,
lowering_parameters: mlir.LoweringParameters
) -> MeshComputation:
"""Lowers a computation to XLA. It can take arbitrary shardings as input.
@ -3266,8 +3266,9 @@ def check_array_xla_sharding_layout_match(
arg.layout.device_local_layout != xl):
errors.append(
("Got input layout(s) that compiled object was called with: "
f"{arg.layout} and layout(s) the computation was compiled "
f"with: {xl} for arg {name} with shape: {arg.aval.str_short()}",
f"{arg.layout.device_local_layout} and layout(s) the computation was "
f"compiled with: {xl} for arg {name} with "
f"shape: {arg.aval.str_short()}",
'layout'))
if errors:


@ -714,9 +714,9 @@ def make_xmap_callable(fun: lu.WrappedFun,
return pxla.lower_sharding_computation(
core.ClosedJaxpr(jaxpr, consts), 'jit', name,
(UNSPECIFIED,) * len(in_avals), (UNSPECIFIED,) * len(out_avals),
(None,) * len(in_avals), (None,) * len(out_avals),
donated_invars, in_avals, keep_unused=True, inline=False,
devices_from_context=None, lowering_parameters=lowering_parameters,
in_layouts=(None,) * len(in_avals), out_layouts=(None,) * len(out_avals))
devices_from_context=None, lowering_parameters=lowering_parameters)
class EvaluationPlan(NamedTuple):


@ -53,7 +53,6 @@ from jax._src.errors import JAXTypeError
from jax._src.interpreters import partial_eval as pe
from jax._src.partition_spec import PartitionSpec
from jax._src.interpreters import xla
from jax._src.interpreters import ad
from jax._src.interpreters import batching
from jax._src.interpreters import mlir
@ -67,13 +66,13 @@ from jax._src.sharding_impls import (
SingleDeviceSharding, PmapSharding, AUTO, UNSPECIFIED, UnspecifiedValue,
ParsedPartitionSpec, SpecSync, get_single_pspec, is_auto, is_unspecified,
is_unspecified_or_auto, prepare_axis_resources, parse_flatten_op_sharding)
from jax._src.layout import Layout, LayoutOptions
from jax._src.layout import Layout, DeviceLocalLayout, AutoLayout
from jax._src.state import discharge as state_discharge, RefEffect
from jax._src.traceback_util import api_boundary
from jax._src.tree_util import (
tree_flatten, tree_unflatten, treedef_is_leaf, tree_structure,
treedef_children, broadcast_prefix, all_leaves, prefix_errors, keystr,
PyTreeDef)
PyTreeDef, none_leaf_registry as none_lr)
from jax._src.util import (
HashableFunction, safe_map, safe_zip, wraps,
distributed_debug_log, split_list, weakref_lru_cache,
@ -150,6 +149,10 @@ class PjitInfo(NamedTuple):
in_shardings_leaves: tuple[Any, ...]
out_shardings_treedef: PyTreeDef
out_shardings_leaves: tuple[Any, ...]
in_layouts_treedef: PyTreeDef
in_layouts_leaves: tuple[Any, ...]
out_layouts_treedef: PyTreeDef
out_layouts_leaves: tuple[Any, ...]
static_argnums: tuple[int, ...]
static_argnames: tuple[str, ...]
donate_argnums: tuple[int, ...]
@ -164,8 +167,9 @@ class PjitInfo(NamedTuple):
def _python_pjit_helper(jit_info, *args, **kwargs):
args_flat, _, params, _, out_tree, _, _, _, arg_names, attrs_tracked = \
_infer_params(jit_info, args, kwargs)
(args_flat, _, params, _, out_tree, _, arg_names,
attrs_tracked) = _infer_params(jit_info, args, kwargs)
for arg in args_flat:
dispatch.check_arg(arg)
@ -202,6 +206,7 @@ def _python_pjit_helper(jit_info, *args, **kwargs):
if attrs_tracked:
final_states, out_flat = split_list(out_flat, [len(attrs_tracked)])
_set_states(attrs_tracked, final_states)
outs = tree_unflatten(out_tree, out_flat)
return outs, out_flat, out_tree, args_flat, params['jaxpr'], attrs_tracked
@ -335,6 +340,30 @@ def _pjit_explicit_sharding(in_shardings, out_shardings, device,
any(not is_unspecified(i) for i in out_shardings_flat))
def _split_layout_and_sharding(entries):
entries_flat, treedef = tree_flatten(entries, is_leaf=lambda x: x is None)
layouts, shardings = [], []
for e in entries_flat:
if e is None or is_unspecified_or_auto(e):
layouts.append(None)
shardings.append(e)
elif isinstance(e, Layout):
layouts.append(e.device_local_layout)
shardings.append(e.sharding)
elif isinstance(e, (DeviceLocalLayout, AutoLayout)):
raise ValueError(
'`jax.jit` does not accept device-local layouts directly. Create '
'a `Layout` instance wrapping this device-local layout and pass '
f'that to `jit` instead. Got {e}')
else:
layouts.append(None)
shardings.append(e)
assert len(layouts) == len(shardings)
return tree_unflatten(treedef, layouts), tree_unflatten(treedef, shardings)
def _parse_jit_arguments(fun: Callable, in_shardings: Any, out_shardings: Any,
donate_argnums: int | Sequence[int] | None,
donate_argnames: str | Iterable[str] | None,
@ -378,16 +407,19 @@ def _parse_jit_arguments(fun: Callable, in_shardings: Any, out_shardings: Any,
# rather than raising an error. https://github.com/google/jax/issues/2367
in_shardings = tuple(in_shardings)
in_layouts, in_shardings = _split_layout_and_sharding(in_shardings)
out_layouts, out_shardings = _split_layout_and_sharding(out_shardings)
in_shardings = prepare_axis_resources(in_shardings, 'in_shardings')
out_shardings = prepare_axis_resources(out_shardings, 'out_shardings')
user_specified_in_shardings = (in_shardings is not None and
not is_unspecified(in_shardings))
none_leaf_registry = tree_util.none_leaf_registry
in_shardings_leaves, in_shardings_treedef = none_leaf_registry.flatten(
in_shardings)
out_shardings_leaves, out_shardings_treedef = none_leaf_registry.flatten(
out_shardings)
in_shardings_leaves, in_shardings_treedef = none_lr.flatten(in_shardings)
out_shardings_leaves, out_shardings_treedef = none_lr.flatten(out_shardings)
in_layouts_leaves, in_layouts_treedef = none_lr.flatten(in_layouts)
out_layouts_leaves, out_layouts_treedef = none_lr.flatten(out_layouts)
fun_sourceinfo = api_util.fun_sourceinfo(fun)
fun_signature = api_util.fun_signature(fun)
@ -408,6 +440,10 @@ def _parse_jit_arguments(fun: Callable, in_shardings: Any, out_shardings: Any,
in_shardings_leaves=tuple(in_shardings_leaves),
out_shardings_treedef=out_shardings_treedef,
out_shardings_leaves=tuple(out_shardings_leaves),
in_layouts_treedef=in_layouts_treedef,
in_layouts_leaves=tuple(in_layouts_leaves),
out_layouts_treedef=out_layouts_treedef,
out_layouts_leaves=tuple(out_layouts_leaves),
static_argnums=static_argnums,
static_argnames=static_argnames, donate_argnums=donate_argnums,
donate_argnames=donate_argnames, device=device, backend=backend,
@ -417,37 +453,58 @@ def _parse_jit_arguments(fun: Callable, in_shardings: Any, out_shardings: Any,
use_resource_env=use_resource_env)
# TODO(yashkatariya): Delete this function once internal users migrate off of
# the deprecated AOT API.
def _handle_layouts_in_aot(jit_info: PjitInfo, kwargs):
if '_in_layouts' in kwargs or '_out_layouts' in kwargs:
warnings.warn(
'Passing `_in_layouts` and `_out_layouts` to `.lower` is deprecated and'
' will be removed soon. Please pass your `Layout` instances to'
' `in_shardings` and `out_shardings` arguments of `jax.jit`',
DeprecationWarning)
in_layouts = kwargs.pop('_in_layouts', None)
out_layouts = kwargs.pop('_out_layouts', None)
in_layouts, _ = _split_layout_and_sharding(in_layouts)
out_layouts, _ = _split_layout_and_sharding(out_layouts)
in_layouts_leaves, in_layouts_treedef = none_lr.flatten(in_layouts)
out_layouts_leaves, out_layouts_treedef = none_lr.flatten(out_layouts)
return jit_info._replace(in_layouts_treedef=in_layouts_treedef,
in_layouts_leaves=tuple(in_layouts_leaves),
out_layouts_treedef=out_layouts_treedef,
out_layouts_leaves=tuple(out_layouts_leaves))
return jit_info
def _make_jit_wrapper(jit_info: PjitInfo):
wrapped = _cpp_pjit(jit_info)
@api_boundary
def lower(*args, **kwargs):
lowering_parameters = kwargs.pop(
'_experimental_lowering_parameters', mlir.LoweringParameters())
# TODO(yashkatariya): Remove this when it's added on jit.
in_layouts = kwargs.pop('_in_layouts', Layout())
out_layouts = kwargs.pop('_out_layouts', Layout())
# TODO(yashkatariya): Remove this handling once internal users migrate off
# of the deprecated API
new_jit_info = _handle_layouts_in_aot(jit_info, kwargs)
(args_flat, flat_global_in_avals, params, in_tree, out_tree,
donated_invars, in_layouts_flat, out_layouts_flat,
arg_names, ()) = _infer_params(
jit_info, args, kwargs, in_layouts=in_layouts, out_layouts=out_layouts)
donated_invars, arg_names, ()) = _infer_params(new_jit_info, args, kwargs)
resource_env = params['resource_env']
mesh = None if resource_env is None else resource_env.physical_mesh
try:
in_shardings = _resolve_in_shardings(
args_flat, params['in_shardings'], params['out_shardings'], mesh)
in_layouts_flat = _resolve_in_layouts(
args_flat, in_layouts_flat, in_shardings)
out_layouts_flat = _resolve_out_layouts(out_layouts_flat)
in_layouts = _resolve_in_layouts(
args_flat, params['in_layouts'], in_shardings,
params['jaxpr'].in_avals)
lowering = _pjit_lower(
params['jaxpr'], in_shardings, params['out_shardings'],
in_layouts, params['out_layouts'],
params['resource_env'], params['donated_invars'], params['name'],
params['keep_unused'], params['inline'], in_layouts=in_layouts_flat,
out_layouts=out_layouts_flat, lowering_parameters=lowering_parameters)
params['keep_unused'], params['inline'],
lowering_parameters=lowering_parameters)
except pxla.DeviceAssignmentMismatchError as e:
fails, = e.args
api_name = 'jit' if params['resource_env'] is None else 'pjit'
fun = jit_info.fun
fun = new_jit_info.fun
fun_name = getattr(fun, '__qualname__',
getattr(fun, '__name__', str(fun)))
msg = _device_assignment_mismatch_error(
@ -461,19 +518,18 @@ def _make_jit_wrapper(jit_info: PjitInfo):
@api_boundary
def eval_shape(*args, **kwargs):
_, _, params, _, out_tree, _, _, _, _, _ = _infer_params(
jit_info, args, kwargs, in_layouts=None, out_layouts=None
)
out_s = [None if is_unspecified(s) else getattr(s, '_original_sharding', s)
for s in params['out_shardings']]
_, _, params, _, out_tree, _, _, _ = _infer_params(jit_info, args, kwargs)
out_s = [None if is_unspecified(s) else s for s in params['out_shardings']]
out = [api.ShapeDtypeStruct(x.shape, x.dtype, x.named_shape, sharding=s)
for x, s in zip(params['jaxpr'].out_avals, out_s)]
return tree_unflatten(out_tree, out)
wrapped = _cpp_pjit(jit_info)
wrapped.lower = lower
wrapped.eval_shape = eval_shape
return wrapped
def make_jit(fun: Callable, in_shardings: Any, out_shardings: Any,
donate_argnums: int | Sequence[int] | None,
donate_argnames: str | Iterable[str] | None,
@ -490,10 +546,11 @@ def make_jit(fun: Callable, in_shardings: Any, out_shardings: Any,
return _make_jit_wrapper(jit_info)
def _infer_params(jit_info, args, kwargs, *, in_layouts=None, out_layouts=None):
def _infer_params(jit_info, args, kwargs):
(fun, fun_sourceinfo, fun_signature, user_specified_in_shardings,
in_shardings_treedef, in_shardings_leaves, out_shardings_treedef,
out_shardings_leaves, static_argnums, static_argnames,
out_shardings_leaves, in_layouts_treedef, in_layouts_leaves,
out_layouts_treedef, out_layouts_leaves, static_argnums, static_argnames,
donate_argnums, donate_argnames, device, backend, keep_unused, inline,
abstracted_axes, _, use_resource_env) = jit_info
@ -576,17 +633,18 @@ def _infer_params(jit_info, args, kwargs, *, in_layouts=None, out_layouts=None):
) from e
in_type = in_avals = tuple(avals)
canonicalized_in_shardings_flat, in_layouts_flat = _process_in_axis_resources(
in_shardings_treedef, in_shardings_leaves, hashable_pytree(in_layouts),
in_shardings_flat, in_layouts_flat = _process_in_axis_resources(
in_shardings_treedef, in_shardings_leaves,
in_layouts_treedef, in_layouts_leaves,
in_avals, in_tree, dbg, device_or_backend_set, have_kwargs)
jaxpr, consts, out_shardings, out_layouts_flat, attrs_tracked = _pjit_jaxpr(
jaxpr, consts, out_shardings_flat, out_layouts_flat, attrs_tracked = _pjit_jaxpr(
flat_fun, out_shardings_treedef, out_shardings_leaves,
hashable_pytree(out_layouts), in_type, dbg, device_or_backend_set,
HashableFunction(out_tree, closure=()),
out_layouts_treedef, out_layouts_leaves, in_type, dbg,
device_or_backend_set, HashableFunction(out_tree, closure=()),
HashableFunction(res_paths, closure=()), inline)
assert len(explicit_args) == len(canonicalized_in_shardings_flat) == len(in_layouts_flat)
assert len(explicit_args) == len(in_shardings_flat) == len(in_layouts_flat)
if config.dynamic_shapes.value:
implicit_args = _extract_implicit_args(in_type, explicit_args)
@ -595,18 +653,19 @@ def _infer_params(jit_info, args, kwargs, *, in_layouts=None, out_layouts=None):
args_flat = [*implicit_args, *explicit_args]
num_extra_args = len(implicit_args) + len(attrs_tracked) + len(consts)
canonicalized_in_shardings_flat = \
(UNSPECIFIED,) * num_extra_args + canonicalized_in_shardings_flat
in_shardings_flat = (UNSPECIFIED,) * num_extra_args + in_shardings_flat
in_layouts_flat = (None,) * num_extra_args + in_layouts_flat
donated_invars = (False,) * num_extra_args + donated_invars
assert (len(canonicalized_in_shardings_flat) == len(in_layouts_flat) ==
assert (len(in_shardings_flat) == len(in_layouts_flat) ==
len(donated_invars) == len(attrs_tracked) + len(consts) + len(args_flat))
# in_shardings and out_shardings here are all GSPMDSharding.
params = dict(
jaxpr=jaxpr,
in_shardings=canonicalized_in_shardings_flat,
out_shardings=out_shardings,
in_shardings=in_shardings_flat,
out_shardings=out_shardings_flat,
in_layouts=in_layouts_flat,
out_layouts=out_layouts_flat,
resource_env=resource_env,
donated_invars=donated_invars,
name=getattr(flat_fun, '__name__', '<unknown>'),
@ -614,8 +673,7 @@ def _infer_params(jit_info, args, kwargs, *, in_layouts=None, out_layouts=None):
inline=inline,
)
return (consts + args_flat, in_type, params, in_tree, out_tree(),
donated_invars, in_layouts_flat, out_layouts_flat,
dbg.arg_names if dbg else None, attrs_tracked)
donated_invars, dbg.arg_names if dbg else None, attrs_tracked)
def _extract_implicit_args(
in_type: Sequence[tuple[core.AbstractValue, bool]],
@ -973,8 +1031,8 @@ class PytreeLeaf:
@lru_cache(maxsize=4096)
def _process_in_axis_resources(in_shardings_treedef, in_shardings_leaves,
in_layouts_thunk, in_avals,
in_tree, debug_info,
in_layouts_treedef, in_layouts_leaves,
in_avals, in_tree, debug_info,
device_or_backend_set, kws):
if not kws:
in_tree, _ = treedef_children(in_tree)
@ -988,7 +1046,7 @@ def _process_in_axis_resources(in_shardings_treedef, in_shardings_leaves,
in_shardings_flat = flatten_axis_resources(
"pjit in_shardings", in_tree, orig_in_shardings, tupled_args=True)
in_layouts = in_layouts_thunk()
in_layouts = tree_unflatten(in_layouts_treedef, in_layouts_leaves)
if in_layouts is None:
in_layouts_flat = (in_layouts,) * len(in_avals)
else:
@ -1001,7 +1059,7 @@ def _process_in_axis_resources(in_shardings_treedef, in_shardings_leaves,
pjit_check_aval_sharding(in_shardings_flat, in_avals,
None if debug_info is None else debug_info.arg_names,
"pjit arguments", allow_uneven_sharding=False)
return in_shardings_flat, tuple(in_layouts_flat)
return in_shardings_flat, in_layouts_flat
callsites: set[str] = set()
@ -1168,13 +1226,9 @@ def _create_pjit_jaxpr(fun, in_type, debug_info, out_paths, ignored_inline):
@lru_cache(maxsize=4096)
def _check_and_canonicalize_out_shardings(
out_shardings_treedef, out_shardings_leaves, out_layouts_thunk, out_tree,
out_type, debug_info, device_or_backend_set):
out_shardings_treedef, out_shardings_leaves, out_layouts_treedef,
out_layouts_leaves, out_tree, out_type, debug_info, device_or_backend_set):
orig_out_shardings = tree_unflatten(out_shardings_treedef, out_shardings_leaves)
# TODO(yashkatariya): Remove the if branch and fix flatten_axis_resources
# instead. This condition exists because flatten_axis_resources passes in an
# `object()` while unflattening which breaks assertion is user defined
# pytrees (which shouldn't exist but they do).
if (is_unspecified(orig_out_shardings) or
isinstance(orig_out_shardings, XLACompatibleSharding)):
out_shardings_flat = (orig_out_shardings,) * len(out_type)
@ -1183,7 +1237,7 @@ def _check_and_canonicalize_out_shardings(
"pjit out_shardings", out_tree(), orig_out_shardings,
tupled_args=False)
out_layouts = out_layouts_thunk()
out_layouts = tree_unflatten(out_layouts_treedef, out_layouts_leaves)
if out_layouts is None:
out_layouts_flat = (out_layouts,) * len(out_type)
else:
@ -1195,18 +1249,20 @@ def _check_and_canonicalize_out_shardings(
out_shardings_flat, out_type,
None if debug_info is None else debug_info.result_paths,
"pjit outputs", allow_uneven_sharding=False)
return out_shardings_flat, tuple(out_layouts_flat)
return out_shardings_flat, out_layouts_flat
def _pjit_jaxpr(fun, out_shardings_treedef, out_shardings_leaves,
out_layouts_thunk, in_type, debug_info, device_or_backend_set,
out_tree, result_paths, inline):
out_layouts_treedef, out_layouts_leaves, in_type, debug_info,
device_or_backend_set, out_tree, result_paths, inline):
jaxpr, final_consts, out_type, attrs_tracked = _create_pjit_jaxpr(
fun, in_type, debug_info, result_paths, IgnoreKey(inline))
canonicalized_out_shardings_flat, out_layouts_flat = _check_and_canonicalize_out_shardings(
out_shardings_treedef, out_shardings_leaves, out_layouts_thunk, out_tree, tuple(out_type),
out_shardings_treedef, out_shardings_leaves, out_layouts_treedef,
out_layouts_leaves, out_tree, tuple(out_type),
jaxpr.jaxpr.debug_info, device_or_backend_set)
return jaxpr, final_consts, canonicalized_out_shardings_flat, out_layouts_flat, attrs_tracked
return (jaxpr, final_consts, canonicalized_out_shardings_flat,
out_layouts_flat, attrs_tracked)
@dataclasses.dataclass(frozen=True)
@ -1259,30 +1315,65 @@ pjit_p = core.AxisPrimitive("pjit")
pjit_p.multiple_results = True
def _resolve_in_layouts(args, jit_in_layouts, jit_in_shardings):
@lru_cache(maxsize=2048)
def _maybe_get_default_layout(arg_layout, jit_in_layout, sharding, aval):
if is_unspecified_or_auto(sharding):
return None
# TODO(yashkatariya): Figure out how layouts work with extended dtypes.
if dtypes.issubdtype(aval.dtype, dtypes.extended):
return None
if not core.is_constant_shape(aval.shape):
return None
shard_shape = sharding.shard_shape(aval.shape)
d = sharding._device_assignment[0]
# If a backend doesn't implement `get_default_layout` return `None` to avoid
# cache misses. This can happen when you have `jit(f, in_shardings=s)`. On
# first call you pass it a sharded array with layout and on second call you
# pass a numpy array. The layouts should be the same to get cache hits.
try:
al = DeviceLocalLayout(
d.client.get_default_layout(aval.dtype, shard_shape, d))
except:
return None
# argument does not have `.layout` property. ShapedArray, ShapedDtypeStruct,
# numpy array, etc are some examples.
if arg_layout is None:
return al if jit_in_layout is None else arg_layout # arg_layout is None
# If arg has a `.layout` property, then return device_local_layout as is.
return arg_layout.device_local_layout
def _resolve_in_layouts(args, jit_in_layouts, resolved_in_shardings, in_avals):
# If device or backend is set, return the default layout. This is because you
# can pass arrays on cpu (with untiled layouts) to jit with backend='tpu'
# which causes error checks to fail. Returning the default layout allows
# this to exist. It's the same for handling shardings.
if pxla.check_device_backend_on_shardings(jit_in_shardings):
if pxla.check_device_backend_on_shardings(resolved_in_shardings):
return (None,) * len(jit_in_layouts)
resolved_in_layouts = []
for arg, jit_in_l in safe_zip(args, jit_in_layouts):
for arg, jit_in_l, rs, aval in safe_zip(
args, jit_in_layouts, resolved_in_shardings, in_avals):
arg_layout, committed = (
(arg.layout.device_local_layout, getattr(arg, '_committed', True))
if getattr(arg, 'layout', None) is not None else (None, False))
jit_in_l = None if jit_in_l is None else jit_in_l.device_local_layout
_maybe_get_default_layout(getattr(arg, 'layout', None), jit_in_l, rs, aval),
getattr(arg, '_committed', True))
# Sharding can be unspecified when array is committed if it's a PmapSharding.
is_pmap_sharding = (is_unspecified(rs) or
isinstance(getattr(arg, 'sharding', None), PmapSharding))
if jit_in_l is None:
if committed:
resolved_in_layouts.append(arg_layout)
if is_pmap_sharding:
resolved_in_layouts.append(None)
else:
resolved_in_layouts.append(arg_layout)
else:
resolved_in_layouts.append(None)
else:
# arg_layout can be None because some backends don't implement the
# required layout methods. Hence `arr.layout` can return
# `Layout(None, sharding)`
if committed and arg_layout is not None and arg_layout != jit_in_l:
if (committed and not is_pmap_sharding and
arg_layout is not None and arg_layout != jit_in_l):
raise ValueError('Layout passed to jit does not match the layout '
'on the respective arg. '
f'Got pjit layout: {jit_in_l},\n'
@ -1292,13 +1383,6 @@ def _resolve_in_layouts(args, jit_in_layouts, jit_in_shardings):
return tuple(resolved_in_layouts)
def _resolve_out_layouts(out_layouts: Sequence[Layout]
) -> Sequence[LayoutOptions]:
# TODO(yashkatariya): Remove the if condition when all layouts come via the
# `layout.Layout` API or handle this properly when layout is on jit.
return tuple(None if o is None else o.device_local_layout for o in out_layouts)
def _resolve_in_shardings(
args, pjit_in_shardings: Sequence[PjitSharding],
out_shardings: Sequence[PjitSharding],
@ -1335,8 +1419,10 @@ def _resolve_in_shardings(
pxla._get_and_check_device_assignment(
it.chain(
util.stable_unique(committed_arg_shardings),
((i, pxla.MismatchType.IN_SHARDING, None) for i in util.stable_unique(pjit_in_shardings)),
((o, pxla.MismatchType.OUT_SHARDING, None) for o in util.stable_unique(out_shardings))),
((i, pxla.MismatchType.IN_SHARDING, None)
for i in util.stable_unique(pjit_in_shardings)),
((o, pxla.MismatchType.OUT_SHARDING, None)
for o in util.stable_unique(out_shardings))),
(None if pjit_mesh is None or pjit_mesh.empty else list(pjit_mesh.devices.flat)))
resolved_in_shardings = []
@ -1405,16 +1491,17 @@ def _resolve_in_shardings(
def _pjit_call_impl_python(
*args, jaxpr, in_shardings, out_shardings, resource_env, donated_invars,
name, keep_unused, inline):
*args, jaxpr, in_shardings, out_shardings, in_layouts, out_layouts,
resource_env, donated_invars, name, keep_unused, inline):
global _most_recent_pjit_call_executable
in_shardings = _resolve_in_shardings(
args, in_shardings, out_shardings,
resource_env.physical_mesh if resource_env is not None else None)
in_layouts = _resolve_in_layouts(args, in_layouts, in_shardings, jaxpr.in_avals)
compiled = _pjit_lower(
jaxpr, in_shardings, out_shardings, resource_env,
jaxpr, in_shardings, out_shardings, in_layouts, out_layouts, resource_env,
donated_invars, name, keep_unused, inline,
lowering_parameters=mlir.LoweringParameters()).compile()
_most_recent_pjit_call_executable.weak_key_dict[jaxpr] = compiled
@ -1434,6 +1521,8 @@ def _pjit_call_impl_python(
distributed_debug_log(("Running pjit'd function", name),
("in_shardings", in_shardings),
("out_shardings", out_shardings),
("in_layouts", in_layouts),
("out_layouts", out_layouts),
("abstract args", map(xla.abstractify, args)),
("fingerprint", fingerprint))
try:
@ -1465,8 +1554,9 @@ def _pjit_call_impl_python(
@weakref_lru_cache
def _get_jaxpr_as_fun(jaxpr, in_shardings, out_shardings, resource_env,
donated_invars, name, keep_unused, inline):
def _get_jaxpr_as_fun(jaxpr, in_shardings, out_shardings, in_layouts,
out_layouts, resource_env, donated_invars, name,
keep_unused, inline):
# The input jaxpr to `_get_jaxpr_as_fun` is under a weakref_lru_cache so
# returning `core.jaxpr_as_fun(jaxpr)` directly creates a strong reference to
# the jaxpr defeating the purpose of weakref_lru_cache. So return a function
@ -1478,12 +1568,14 @@ def _get_jaxpr_as_fun(jaxpr, in_shardings, out_shardings, resource_env,
def _pjit_call_impl(*args, jaxpr,
in_shardings, out_shardings, resource_env,
in_shardings, out_shardings, in_layouts, out_layouts,
resource_env,
donated_invars, name, keep_unused, inline):
def call_impl_cache_miss(*args_, **kwargs_):
out_flat, compiled = _pjit_call_impl_python(
*args, jaxpr=jaxpr, in_shardings=in_shardings,
out_shardings=out_shardings, resource_env=resource_env,
out_shardings=out_shardings, in_layouts=in_layouts,
out_layouts=out_layouts, resource_env=resource_env,
donated_invars=donated_invars, name=name, keep_unused=keep_unused,
inline=inline)
fastpath_data = _get_fastpath_data(
@ -1492,7 +1584,7 @@ def _pjit_call_impl(*args, jaxpr,
return out_flat, fastpath_data
f = _get_jaxpr_as_fun(
jaxpr, in_shardings, out_shardings,
jaxpr, in_shardings, out_shardings, in_layouts, out_layouts,
resource_env, donated_invars, name, keep_unused, inline)
donated_argnums = [i for i, d in enumerate(donated_invars) if d]
has_explicit_sharding = _pjit_explicit_sharding(
@ -1520,22 +1612,15 @@ def _pjit_lower_cached(
jaxpr: core.ClosedJaxpr,
in_shardings,
out_shardings,
in_layouts: pxla.MaybeLayout,
out_layouts: pxla.MaybeLayout,
resource_env,
donated_invars,
name: str,
keep_unused: bool,
inline: bool,
*,
lowering_parameters: mlir.LoweringParameters,
in_layouts: pxla.MaybeLayout | None = None,
out_layouts: pxla.MaybeLayout | None = None):
# TODO(yashkatariya): Remove this when layouts are supported on jit and
# passed to params.
if in_layouts is None:
in_layouts = (None,) * len(in_shardings)
if out_layouts is None:
out_layouts = (None,) * len(out_shardings)
lowering_parameters: mlir.LoweringParameters):
if resource_env is not None:
pxla.resource_typecheck(jaxpr, resource_env, {}, lambda: "pjit")
@ -1558,18 +1643,19 @@ def _pjit_lower_cached(
else:
return pxla.lower_sharding_computation(
jaxpr, api_name, name, in_shardings, out_shardings,
tuple(donated_invars), tuple(jaxpr.in_avals),
in_layouts, out_layouts, tuple(donated_invars), tuple(jaxpr.in_avals),
keep_unused=keep_unused, inline=inline,
devices_from_context=(
None if mesh is None or mesh.empty else list(mesh.devices.flat)),
lowering_parameters=lowering_parameters, in_layouts=in_layouts,
out_layouts=out_layouts)
lowering_parameters=lowering_parameters)
def pjit_staging_rule(trace, *args, **params):
if (params["inline"] and
all(is_unspecified(i) for i in params["in_shardings"]) and
all(is_unspecified(o) for o in params["out_shardings"])):
all(is_unspecified(o) for o in params["out_shardings"]) and
all(i is None for i in params["in_layouts"]) and
all(o is None for o in params["out_layouts"])):
jaxpr = params['jaxpr']
if config.dynamic_shapes.value:
# Inline jaxpr doesn't handle dynamic shapes when inlining. If dynamic
@ -1598,14 +1684,16 @@ def pjit_staging_rule(trace, *args, **params):
jaxpr, consts = pxla._move_mutable_consts(params['jaxpr'])
consts = map(trace.instantiate_const, consts)
in_shardings = (*params['in_shardings'],) + (UNSPECIFIED,) * len(consts)
in_layouts = (*params['in_layouts'],) + (None,) * len(consts)
donated_invars = (*params['donated_invars'],) + (False,) * len(consts)
new_params = dict(params, jaxpr=jaxpr, in_shardings=in_shardings,
donated_invars=donated_invars)
in_layouts=in_layouts, donated_invars=donated_invars)
return trace.default_process_primitive(pjit_p, (*args, *consts), new_params)
else:
return trace.default_process_primitive(pjit_p, args, params)
pe.custom_staging_rules[pjit_p] = pjit_staging_rule
# TODO(mattjj): remove/trivialize this when jaxprs have type annotation on them,
# since it's actually not possible in general to infer the type from the term
def _out_type(jaxpr: core.ClosedJaxpr) -> list[core.AbstractValue]:
@ -1630,13 +1718,14 @@ def _pjit_typecheck(ctx_factory, *in_atoms, jaxpr, **params):
core.custom_typechecks[pjit_p] = _pjit_typecheck
def _pjit_abstract_eval(*args, jaxpr, out_shardings, resource_env, **_):
def _pjit_abstract_eval(*args, jaxpr, **_):
return jaxpr.out_avals, jaxpr.effects
pjit_p.def_effectful_abstract_eval(_pjit_abstract_eval)
def _pjit_cached_lower_jaxpr_to_fun(ctx, name, jaxpr, effects, in_shardings,
out_shardings, api_name):
out_shardings, in_layouts, out_layouts,
api_name):
mod_ctx = ctx.module_context
axis_ctx = ctx.module_context.axis_context
num_devices = None
@ -1647,7 +1736,7 @@ def _pjit_cached_lower_jaxpr_to_fun(ctx, name, jaxpr, effects, in_shardings,
key = (pjit_p, name, jaxpr, effects, num_devices,
pxla.SemanticallyEqualShardings(in_shardings, jaxpr.in_avals),
pxla.SemanticallyEqualShardings(out_shardings, jaxpr.out_avals),
api_name)
in_layouts, out_layouts, api_name)
func = mod_ctx.cached_primitive_lowerings.get(key, None)
if func is None:
@ -1659,14 +1748,15 @@ def _pjit_cached_lower_jaxpr_to_fun(ctx, name, jaxpr, effects, in_shardings,
func = mlir.lower_jaxpr_to_fun(
mod_ctx, name, jaxpr, effects, ctx.name_stack,
arg_shardings=arg_shardings, result_shardings=result_shardings,
use_sharding_annotations=False, api_name=api_name)
use_sharding_annotations=False, api_name=api_name,
arg_layouts=in_layouts, result_layouts=out_layouts)
mod_ctx.cached_primitive_lowerings[key] = func
return func
def _pjit_lowering(ctx, *args, name, jaxpr, in_shardings,
out_shardings, resource_env, donated_invars,
keep_unused, inline):
out_shardings, in_layouts, out_layouts, resource_env,
donated_invars, keep_unused, inline):
effects = list(ctx.tokens_in.effects())
output_types = map(mlir.aval_to_ir_types, ctx.avals_out)
output_types = [mlir.token_type()] * len(effects) + output_types
@ -1674,7 +1764,8 @@ def _pjit_lowering(ctx, *args, name, jaxpr, in_shardings,
func = _pjit_cached_lower_jaxpr_to_fun(
ctx, name, jaxpr, tuple(effects), in_shardings,
out_shardings, api_name=('jit' if resource_env is None else 'pjit'))
out_shardings, in_layouts, out_layouts,
api_name=('jit' if resource_env is None else 'pjit'))
tokens_in = [ctx.tokens_in.get(eff) for eff in effects]
args = (*ctx.dim_var_values, *tokens_in, *args)
@ -1693,7 +1784,7 @@ mlir.register_lowering(pjit_p, _pjit_lowering)
def _pjit_batcher(insert_axis, spmd_axis_name,
axis_size, axis_name, main_type,
vals_in, dims_in,
jaxpr, in_shardings, out_shardings,
jaxpr, in_shardings, out_shardings, in_layouts, out_layouts,
resource_env, donated_invars, name, keep_unused, inline):
segment_lens, dims_in = batching.indirectify_ragged_axes(dims_in)
new_jaxpr, axes_out = batching.batch_jaxpr2(
@ -1718,16 +1809,24 @@ def _pjit_batcher(insert_axis, spmd_axis_name,
_pjit_batcher_for_sharding(o, axis_out, new_parts, mesh, aval.ndim)
if axis_out is not None else o
for axis_out, o, aval in zip(axes_out, out_shardings, new_jaxpr.out_avals))
# TODO(yashkatariya): Figure out layouts should change under vmap.
if not (all(l is None for l in in_layouts) and
all(l is None for l in out_layouts)):
raise NotImplementedError
vals_out = pjit_p.bind(
*vals_in,
jaxpr=new_jaxpr,
in_shardings=in_shardings,
out_shardings=out_shardings,
in_layouts=in_layouts,
out_layouts=out_layouts,
resource_env=resource_env,
donated_invars=donated_invars,
name=name,
keep_unused=keep_unused,
inline=inline)
resolved_axes_out = batching.resolve_ragged_axes_against_inputs_outputs(
vals_in, vals_out, axes_out)
return vals_out, resolved_axes_out
@ -1773,7 +1872,7 @@ def _pjit_batcher_for_sharding(
def _pjit_jvp(primals_in, tangents_in,
jaxpr, in_shardings, out_shardings,
jaxpr, in_shardings, out_shardings, in_layouts, out_layouts,
resource_env, donated_invars, name, keep_unused, inline):
is_nz_tangents_in = [type(t) is not ad.Zero for t in tangents_in]
jaxpr_jvp, is_nz_tangents_out = ad.jvp_jaxpr(
@ -1788,6 +1887,8 @@ def _pjit_jvp(primals_in, tangents_in,
jaxpr=jaxpr_jvp,
in_shardings=(*in_shardings, *_filter_zeros_in(in_shardings)),
out_shardings=(*out_shardings, *_filter_zeros_out(out_shardings)),
in_layouts=(*in_layouts, *_filter_zeros_in(in_layouts)),
out_layouts=(*out_layouts, *_filter_zeros_out(out_layouts)),
resource_env=resource_env,
donated_invars=(*donated_invars, *_filter_zeros_in(donated_invars)),
name=name,
@ -1813,7 +1914,8 @@ def _known_jaxpr_fwd(known_jaxpr: core.ClosedJaxpr,
def _pjit_partial_eval(trace, *in_tracers,
jaxpr, in_shardings, out_shardings,
resource_env, donated_invars, name, keep_unused, inline):
in_layouts, out_layouts, resource_env, donated_invars,
name, keep_unused, inline):
in_pvals = [t.pval for t in in_tracers]
known_ins = tuple(pv.is_known() for pv in in_pvals)
@ -1824,25 +1926,31 @@ def _pjit_partial_eval(trace, *in_tracers,
known_outs = tuple(not uk for uk in unknown_outs)
num_residuals = len(res_avals)
res_shardings = (UNSPECIFIED,) * num_residuals
res_layouts = (None,) * num_residuals
def keep_where(l, should_keep):
return tuple(x for x, keep in zip(l, should_keep) if keep)
known_out_shardings = keep_where(out_shardings, known_outs) + res_shardings
known_out_layouts = keep_where(out_layouts, known_outs) + res_layouts
# Input-to-output forwarding: compute which outputs are just forwarded inputs.
num_out_primals = len(known_jaxpr.out_avals) - num_residuals
in_fwd: list[int | None] = pe._jaxpr_forwarding(known_jaxpr.jaxpr)
# Only forward primal outputs when corresponding out_sharding is UNSPECIFIED.
in_fwd_primal, in_fwd_res = split_list(in_fwd, [num_out_primals])
in_fwd = [fwd if is_unspecified(os) else None for os, fwd in
zip(keep_where(out_shardings, known_outs), in_fwd_primal)
] + in_fwd_res
in_fwd = [
fwd if is_unspecified(os) and ol is None else None
for os, ol, fwd in zip(
keep_where(out_shardings, known_outs),
keep_where(out_layouts, known_outs), in_fwd_primal)
] + in_fwd_res
del in_fwd_primal, in_fwd_res
# Prune jaxpr outputs and out_shardings by removing the input-forwards.
keep = [f is None for f in in_fwd]
known_jaxpr = pe.prune_closed_jaxpr_outputs(known_jaxpr, keep)
known_out_shardings = keep_where(known_out_shardings, keep)
known_out_layouts = keep_where(known_out_layouts, keep)
# Update num_out_primals to reflect pruning.
kept_primals, kept_res = split_list(keep, [num_out_primals])
num_out_primals = sum(kept_primals)
@ -1856,14 +1964,18 @@ def _pjit_partial_eval(trace, *in_tracers,
keep = [f is None for f in out_fwd]
known_jaxpr = pe.prune_closed_jaxpr_outputs(known_jaxpr, keep)
known_out_shardings = keep_where(known_out_shardings, keep)
known_out_layouts = keep_where(known_out_layouts, keep)
del keep
known_params = dict(
jaxpr=known_jaxpr, in_shardings=keep_where(in_shardings, known_ins),
out_shardings=known_out_shardings, resource_env=resource_env,
out_shardings=known_out_shardings,
in_layouts=keep_where(in_layouts, known_ins),
out_layouts=known_out_layouts, resource_env=resource_env,
donated_invars=keep_where(donated_invars, known_ins),
name=name, keep_unused=keep_unused, inline=inline)
assert len(known_params['out_shardings']) == len(known_params['jaxpr'].out_avals)
assert len(known_params['out_layouts']) == len(known_params['jaxpr'].out_avals)
# Bind known things to pjit_p.
known_inputs = [pv.get_known() for pv in in_pvals if pv.is_known()]
@ -1888,6 +2000,8 @@ def _pjit_partial_eval(trace, *in_tracers,
jaxpr=unknown_jaxpr,
in_shardings=(keep_where(in_shardings, unknown_ins) + res_shardings),
out_shardings=keep_where(out_shardings, unknown_outs),
in_layouts=(keep_where(in_layouts, unknown_ins) + res_layouts),
out_layouts=keep_where(out_layouts, unknown_outs),
resource_env=resource_env,
donated_invars=(keep_where(donated_invars, unknown_ins) +
(False,) * num_residuals),
@ -1921,28 +2035,41 @@ def _pjit_partial_eval_custom_params_updater(
donated_invars_known, _ = pe.partition_list(unks_in, params_known['donated_invars'])
in_shardings_known, _ = pe.partition_list(unks_in, params_known['in_shardings'])
_, out_shardings_known = pe.partition_list(kept_outs_known, params_known['out_shardings'])
in_layouts_known, _ = pe.partition_list(unks_in, params_known['in_layouts'])
_, out_layouts_known = pe.partition_list(kept_outs_known, params_known['out_layouts'])
new_params_known = dict(params_known,
in_shardings=tuple(in_shardings_known),
out_shardings=(*out_shardings_known,
*[UNSPECIFIED] * num_res_out),
in_layouts=tuple(in_layouts_known),
out_layouts=(*out_layouts_known, *[None] * num_res_out),
donated_invars=tuple(donated_invars_known))
assert len(new_params_known['in_shardings']) == len(params_known['jaxpr'].in_avals)
assert len(new_params_known['out_shardings']) == len(params_known['jaxpr'].out_avals)
assert len(new_params_known['in_layouts']) == len(params_known['jaxpr'].in_avals)
assert len(new_params_known['out_layouts']) == len(params_known['jaxpr'].out_avals)
# added num_res new inputs to jaxpr_staged, and pruning according to inst_in
_, donated_invars_staged = pe.partition_list(inst_in, params_staged['donated_invars'])
donated_invars_staged = [False] * num_res_in + donated_invars_staged
_, in_shardings_staged = pe.partition_list(inst_in, params_staged['in_shardings'])
in_shardings_staged = [*[UNSPECIFIED] * num_res_in, *in_shardings_staged]
_, out_shardings_staged = pe.partition_list(kept_outs_staged, params_staged['out_shardings'])
_, in_layouts_staged = pe.partition_list(inst_in, params_staged['in_layouts'])
in_layouts_staged = [*[None] * num_res_in, *in_layouts_staged]
_, out_layouts_staged = pe.partition_list(kept_outs_staged, params_staged['out_layouts'])
new_params_staged = dict(params_staged,
in_shardings=tuple(in_shardings_staged),
out_shardings=tuple(out_shardings_staged),
in_layouts=tuple(in_layouts_staged),
out_layouts=tuple(out_layouts_staged),
donated_invars=tuple(donated_invars_staged))
assert len(new_params_staged['in_shardings']) == len(params_staged['jaxpr'].in_avals)
assert len(new_params_staged['out_shardings']) == len(params_staged['jaxpr'].out_avals)
assert len(new_params_staged['in_layouts']) == len(params_staged['jaxpr'].in_avals)
assert len(new_params_staged['out_layouts']) == len(params_staged['jaxpr'].out_avals)
return new_params_known, new_params_staged
pe.partial_eval_jaxpr_custom_rules[pjit_p] = \
@ -1959,7 +2086,7 @@ def _pjit_transpose_trace(fun, in_avals):
def _pjit_transpose(cts_in, *primals_in,
jaxpr, in_shardings, out_shardings,
jaxpr, in_shardings, out_shardings, in_layouts, out_layouts,
resource_env, donated_invars, name, keep_unused, inline):
def prune_type(ty, xs, maybe_zeros):
return tuple(x for x, mz in zip(xs, maybe_zeros) if type(mz) is not ty)
@ -1973,6 +2100,10 @@ def _pjit_transpose(cts_in, *primals_in,
*prune_type(ad.UndefinedPrimal, in_shardings, primals_in),
*prune_type(ad.Zero, out_shardings, cts_in)
)
transpose_in_layouts = (
*prune_type(ad.UndefinedPrimal, in_layouts, primals_in),
*prune_type(ad.Zero, out_layouts, cts_in)
)
global_cts_in_avals = tuple(core.raise_to_shaped(core.get_aval(ct))
for ct in primals_and_nz_cts_in)
@ -1983,26 +2114,36 @@ def _pjit_transpose(cts_in, *primals_in,
ad.Zero,
in_shardings,
tree_unflatten(cts_out_treedef, [object()] * cts_out_treedef.num_leaves))
transpose_out_layouts = prune_type(
ad.Zero,
in_layouts,
tree_unflatten(cts_out_treedef, [object()] * cts_out_treedef.num_leaves))
if attrs_tracked:
init_states = _get_states(attrs_tracked)
primals_and_nz_cts_in = [*init_states, *primals_and_nz_cts_in]
transpose_in_shardings = (UNSPECIFIED,) * len(attrs_tracked) + transpose_in_shardings
transpose_out_shardings = (UNSPECIFIED,) * len(attrs_tracked) + transpose_out_shardings
transpose_in_layouts = (None,) * len(attrs_tracked) + transpose_in_layouts
transpose_out_layouts = (None,) * len(attrs_tracked) + transpose_out_layouts
nz_cts_out = pjit_p.bind(
*primals_and_nz_cts_in,
jaxpr=transpose_jaxpr,
in_shardings=transpose_in_shardings,
out_shardings=transpose_out_shardings,
in_layouts=transpose_in_layouts,
out_layouts=transpose_out_layouts,
resource_env=resource_env,
donated_invars=(False,) * len(primals_and_nz_cts_in),
name=name,
keep_unused=keep_unused,
inline=inline)
if attrs_tracked:
final_states, nz_cts_out = split_list(nz_cts_out, [len(init_states)])
_set_states(attrs_tracked, final_states)
return tree_unflatten(cts_out_treedef, nz_cts_out)
ad.reducing_transposes[pjit_p] = _pjit_transpose
@ -2029,6 +2170,8 @@ def dce_jaxpr_pjit_rule(used_outputs: list[bool], eqn: core.JaxprEqn
jaxpr=dced_jaxpr,
in_shardings=keep_where(eqn_params["in_shardings"], used_inputs),
out_shardings=keep_where(eqn_params["out_shardings"], used_outputs),
in_layouts=keep_where(eqn_params["in_layouts"], used_inputs),
out_layouts=keep_where(eqn_params["out_layouts"], used_outputs),
donated_invars=keep_where(eqn_params["donated_invars"], used_inputs),
)
if not any(used_inputs) and not any(used_outputs) and not dced_jaxpr.effects:
@ -2112,6 +2255,10 @@ def _pjit_pp_rule(eqn, context, settings):
del params['in_shardings']
if all(is_unspecified(s) for s in params['out_shardings']):
del params['out_shardings']
if all(l is None for l in params['in_layouts']):
del params['in_layouts']
if all(l is None for l in params['out_layouts']):
del params['out_layouts']
if not params['keep_unused']:
del params['keep_unused']
if (params['resource_env'] is None or
@ -2126,18 +2273,28 @@ core.pp_eqn_rules[pjit_p] = _pjit_pp_rule
def _pjit_state_discharge_rule(
in_avals, out_avals, *args, jaxpr, in_shardings, out_shardings, **params):
in_avals, out_avals, *args, jaxpr, in_shardings, out_shardings,
in_layouts, out_layouts, **params):
if not (all(map(is_unspecified, in_shardings)) and
all(map(is_unspecified, out_shardings))): raise NotImplementedError
all(map(is_unspecified, out_shardings))):
raise NotImplementedError
if not (all(l is None for l in in_layouts) and
all(l is None for l in out_layouts)):
raise NotImplementedError
jaxpr, consts = jaxpr.jaxpr, jaxpr.consts
num_outs = len(jaxpr.outvars)
discharged_jaxpr, discharged_consts = state_discharge.discharge_state(jaxpr, consts)
discharged_closed_jaxpr = core.ClosedJaxpr(discharged_jaxpr, discharged_consts)
new_in_shardings = (UnspecifiedValue(),) * len(discharged_jaxpr.invars)
new_out_shardings = (UnspecifiedValue(),) * len(discharged_jaxpr.outvars)
new_in_layouts = (None,) * len(discharged_jaxpr.invars)
new_out_layouts = (None,) * len(discharged_jaxpr.outvars)
out_and_ref_vals = pjit_p.bind(
*args, jaxpr=discharged_closed_jaxpr, in_shardings=new_in_shardings,
out_shardings=new_out_shardings, **params)
out_shardings=new_out_shardings, in_layouts=new_in_layouts,
out_layouts=new_out_layouts, **params)
out_vals, ref_vals = split_list(out_and_ref_vals, [num_outs])
ref_vals_iter = iter(ref_vals)
new_invals = tuple(next(ref_vals_iter) if isinstance(aval, state_discharge.AbstractRef)


@ -88,12 +88,10 @@ class Executable(Protocol):
"""
raise NotImplementedError
# Layouts are exposed via jax.experimental.layouts
# TODO(frostig,yashkatariya): expose here when no longer experimental.
def _input_layouts(self):
def input_layouts(self):
raise NotImplementedError
def _output_layouts(self):
def output_layouts(self):
raise NotImplementedError
def as_text(self) -> str:
@ -228,11 +226,11 @@ class XlaExecutable(Executable):
raise NotImplementedError(
"compiled executable carries no output sharding information")
def _input_layouts(self):
def input_layouts(self):
raise NotImplementedError(
"compiled executable carries no input layout information")
def _output_layouts(self):
def output_layouts(self):
raise NotImplementedError(
"compiled executable carries no input layout information")
@ -511,7 +509,7 @@ class Compiled(Stage):
shardings_flat = self._executable.output_shardings()
return tree_util.tree_unflatten(self.out_tree, shardings_flat) # pytype: disable=attribute-error
def _input_layouts(self):
def input_layouts(self):
layouts_flat = self._executable.input_layouts()
assert all(isinstance(l, Layout) for l in layouts_flat)
# Some input layouts got DCE'd
@ -521,7 +519,7 @@ class Compiled(Stage):
else Layout() for i in range(self.in_tree.num_leaves)]
return tree_util.tree_unflatten(self.in_tree, layouts_flat) # pytype: disable=attribute-error
def _output_layouts(self):
def output_layouts(self):
layouts_flat = self._executable.output_layouts()
assert all(isinstance(l, Layout) for l in layouts_flat)
return tree_util.tree_unflatten(self.out_tree, layouts_flat) # pytype: disable=attribute-error


@ -1643,6 +1643,8 @@ def _rewrite_eqn(eqn: core.JaxprEqn, eqns: list[core.JaxprEqn],
eqn.params["out_shardings"]
+ (sharding_impls.UNSPECIFIED, sharding_impls.UNSPECIFIED)
),
in_layouts=(eqn.params["in_layouts"] + (None, None)),
out_layouts=(eqn.params["out_layouts"] + (None, None)),
),
)
)


@ -3504,6 +3504,7 @@ def _pjit(*args: TfVal,
jaxpr: core.ClosedJaxpr,
in_shardings: Sequence[sharding.XLACompatibleSharding],
out_shardings: Sequence[sharding.XLACompatibleSharding],
in_layouts, out_layouts,
resource_env: mesh.ResourceEnv,
donated_invars,
name: str,


@ -739,6 +739,8 @@ def _pjit_jet_rule(primals_in, series_in, **params):
params['out_shardings']
+ (sharding_impls.UNSPECIFIED,) * num_series_out
),
'in_layouts': params['in_layouts'] + (None,) * num_series_in,
'out_layouts': params['out_layouts'] + (None,) * num_series_out,
'donated_invars': params['donated_invars'] + (False,) * num_series_in,
}
result = pjit.pjit_p.bind(*primals_and_series, **new_params)


@ -772,7 +772,8 @@ sparse_rules_bcoo[lax.while_p] = _while_sparse
def _pjit_sparse(spenv, *spvalues, jaxpr, in_shardings, out_shardings,
resource_env, donated_invars, name, keep_unused, inline):
in_layouts, out_layouts, resource_env, donated_invars, name,
keep_unused, inline):
if any(donated_invars):
raise NotImplementedError("sparse xla_call with donated_invars")
@ -790,12 +791,20 @@ def _pjit_sparse(spenv, *spvalues, jaxpr, in_shardings, out_shardings,
sharding_impls.UNSPECIFIED
for _ in range(len(sp_call_jaxpr.out_avals) - len(out_shardings))
)
in_layouts = in_layouts + tuple(
None for _ in range(len(args_flat) - len(in_layouts))
)
out_layouts = out_layouts + tuple(
None for _ in range(len(sp_call_jaxpr.out_avals) - len(out_layouts))
)
out_flat = pjit.pjit_p.bind(
*args_flat,
jaxpr=sp_call_jaxpr,
in_shardings=in_shardings,
out_shardings=out_shardings,
in_layouts=in_layouts,
out_layouts=out_layouts,
resource_env=resource_env,
donated_invars=donated_invars,
name=name,


@ -451,8 +451,8 @@ class JaxExportTest(jtu.JaxTestCase):
self.assertIn("jax.uses_shape_polymorphism = true", module_str)
wrapped_main_expected_re = (
r"@_wrapped_jax_export_main\("
r"%arg0: tensor<i..> {jax.global_constant = \"h\"}.*"
r"%arg1: tensor<i..> {jax.global_constant = \"w\"}.*"
r"%arg0: tensor<i..> {jax.global_constant = \"h\".*"
r"%arg1: tensor<i..> {jax.global_constant = \"w\".*"
r"%arg2: tensor<\?x\?xf32>"
)
self.assertRegex(module_str, wrapped_main_expected_re)
@ -1238,12 +1238,12 @@ class JaxExportTest(jtu.JaxTestCase):
mlir_module_str = str(exp.mlir_module())
wrapped_main_expected_re = (
r"@_wrapped_jax_export_main\("
r"%arg0: tensor<i..> {jax.global_constant = \"b1\"}.*, "
r"%arg1: tensor<i..> {jax.global_constant = \"b2\"}.*, "
r"%arg2: !stablehlo.token {jax.token = true}.*, "
r"%arg0: tensor<i..> {jax.global_constant = \"b1\".* "
r"%arg1: tensor<i..> {jax.global_constant = \"b2\".* "
r"%arg2: !stablehlo.token {jax.token = true.* "
r"%arg3: tensor<\?x\?xf32>.*\) -> \("
# Results
r"!stablehlo.token {jax.token = true}, tensor<\?x\?xf32>.*\)")
r"!stablehlo.token {jax.token = true.*, tensor<\?x\?xf32>.*\)")
if exp.mlir_module_serialization_version < _export._VERSION_START_SUPPORT_EFFECTS_WITH_REAL_TOKENS:
wrapped_main_expected_re = wrapped_main_expected_re.replace("!stablehlo.token", "tensor<0xi1>")
self.assertRegex(mlir_module_str, wrapped_main_expected_re)
@ -1254,10 +1254,10 @@ class JaxExportTest(jtu.JaxTestCase):
else:
main_expected_re = (
r"@main\("
r"%arg0: !stablehlo.token {jax.token = true}.*, "
r"%arg0: !stablehlo.token {jax.token = true.*, "
r"%arg1: tensor<\?x\?xf32>.*\) -> \("
# Results
r"!stablehlo.token {jax.token = true}, tensor<\?x\?xf32>.*\)")
r"!stablehlo.token {jax.token = true.*, tensor<\?x\?xf32>.*\)")
self.assertRegex(mlir_module_str, main_expected_re)
res = export.call_exported(exp)(x)
@ -1284,13 +1284,13 @@ class JaxExportTest(jtu.JaxTestCase):
mlir_module_str = str(exp.mlir_module())
wrapped_main_expected_re = (
r"@_wrapped_jax_export_main\("
r"%arg0: tensor<i..> {jax.global_constant = \"_platform_index\"}.*, "
r"%arg1: tensor<i..> {jax.global_constant = \"b1\"}.*, "
r"%arg2: tensor<i..> {jax.global_constant = \"b2\"}.*, "
r"%arg3: !stablehlo.token {jax.token = true}.*, "
r"%arg0: tensor<i..> {jax.global_constant = \"_platform_index\".*, "
r"%arg1: tensor<i..> {jax.global_constant = \"b1\".*, "
r"%arg2: tensor<i..> {jax.global_constant = \"b2\".*, "
r"%arg3: !stablehlo.token {jax.token = true.*, "
r"%arg4: tensor<\?x\?xf32>.*\) -> \("
# Results
r"!stablehlo.token {jax.token = true}, tensor<\?x\?xf32>.*\)")
r"!stablehlo.token {jax.token = true.*, tensor<\?x\?xf32>.*\)")
if exp.mlir_module_serialization_version < _export._VERSION_START_SUPPORT_EFFECTS_WITH_REAL_TOKENS:
wrapped_main_expected_re = wrapped_main_expected_re.replace("!stablehlo.token", "tensor<0xi1>")
self.assertRegex(mlir_module_str, wrapped_main_expected_re)
@ -1301,11 +1301,11 @@ class JaxExportTest(jtu.JaxTestCase):
else:
main_expected_re = (
r"@main\("
r"%arg0: tensor<i..> {jax.global_constant = \"_platform_index\"}.*, "
r"%arg1: !stablehlo.token {jax.token = true}.*, "
r"%arg0: tensor<i..> {jax.global_constant = \"_platform_index\".*, "
r"%arg1: !stablehlo.token {jax.token = true.*, "
r"%arg2: tensor<\?x\?xf32>.*\) -> \("
# Results
r"!stablehlo.token {jax.token = true}, tensor<\?x\?xf32>.*\)")
r"!stablehlo.token {jax.token = true.*, tensor<\?x\?xf32>.*\)")
self.assertRegex(mlir_module_str, main_expected_re)
res = export.call_exported(exp)(x)
self.assertAllClose(10. + _testing_multi_platform_fun_expected(x),


@ -88,22 +88,22 @@ class LayoutTest(jtu.JaxTestCase):
sds1 = jax.ShapeDtypeStruct(np_inp1.shape, np_inp1.dtype, sharding=s1)
sds2 = jax.ShapeDtypeStruct(np_inp2.shape, np_inp2.dtype, sharding=s2)
lowered_apply = jax.jit(apply).lower(
sds1, sds2, _in_layouts=Layout(DLL.AUTO), _out_layouts=Layout(DLL.AUTO))
lowered_apply = jax.jit(apply, in_shardings=Layout(DLL.AUTO),
out_shardings=Layout(DLL.AUTO)).lower(sds1, sds2)
compiled_apply = lowered_apply.compile()
arg_layouts, kw_layouts = compiled_apply._input_layouts()
arg_layouts, kw_layouts = compiled_apply.input_layouts()
self.assertEmpty(kw_layouts)
for i, o in zip(arg_layouts, compiled_apply._output_layouts()):
for i, o in zip(arg_layouts, compiled_apply.output_layouts()):
self.assertEqual(extract_minor_to_major(i),
extract_minor_to_major(o)[::-1])
init_compiled = jax.jit(init).lower(
sds1, sds2, _out_layouts=arg_layouts).compile()
init_compiled = jax.jit(
init, out_shardings=arg_layouts).lower(sds1, sds2).compile()
for i, o in zip(init_compiled._input_layouts()[0],
init_compiled._output_layouts()):
for i, o in zip(init_compiled.input_layouts()[0],
init_compiled.output_layouts()):
self.assertEqual(i, o)
arr1 = jax.device_put(np_inp1, s1)
@ -114,16 +114,16 @@ class LayoutTest(jtu.JaxTestCase):
init_compiled(arr1, arr2)
self.assertEqual(init_count[0], 1)
self.assertEqual(init_out[0].layout, init_compiled._output_layouts()[0])
self.assertEqual(init_out[1].layout, init_compiled._output_layouts()[1])
self.assertEqual(init_out[0].layout, init_compiled.output_layouts()[0])
self.assertEqual(init_out[1].layout, init_compiled.output_layouts()[1])
with jtu.count_aot_jit_cpp_cache_miss() as apply_count:
apply_out = compiled_apply(*init_out)
compiled_apply(*init_out)
self.assertEqual(apply_count[0], 1)
self.assertEqual(apply_out[0].layout, compiled_apply._output_layouts()[0])
self.assertEqual(apply_out[1].layout, compiled_apply._output_layouts()[1])
self.assertEqual(apply_out[0].layout, compiled_apply.output_layouts()[0])
self.assertEqual(apply_out[1].layout, compiled_apply.output_layouts()[1])
self.assertTupleEqual(extract_minor_to_major(apply_out[0].layout),
extract_minor_to_major(init_out[0].layout)[::-1])
@ -146,24 +146,29 @@ class LayoutTest(jtu.JaxTestCase):
def f(x):
return x.T
lowered = jax.jit(f).lower(sds, _in_layouts=None, _out_layouts=None)
lowered = jax.jit(f, in_shardings=None, out_shardings=None).lower(sds)
self.assertIn("default", lowered.as_text())
compiled = lowered.compile()
out = compiled(arr)
self.assertTupleEqual(
extract_minor_to_major(compiled._input_layouts()[0][0]), (2, 1, 0))
extract_minor_to_major(compiled.input_layouts()[0][0]), (2, 1, 0))
self.assertTupleEqual(
extract_minor_to_major(compiled._output_layouts()), (2, 1, 0))
extract_minor_to_major(compiled.output_layouts()), (2, 1, 0))
self.assertArraysEqual(out, np_inp.T)
self.assertEqual(out.sharding, NamedSharding(mesh, P(None, 'y', 'x')))
compiled_auto = jax.jit(f).lower(sds, _in_layouts=Layout(DLL.AUTO),
_out_layouts=Layout(DLL.AUTO)).compile()
compiled_auto = jax.jit(f, in_shardings=Layout(DLL.AUTO),
out_shardings=Layout(DLL.AUTO)).lower(sds).compile()
self.assertTupleEqual(
extract_minor_to_major(compiled_auto._input_layouts()[0][0]), (2, 1, 0))
extract_minor_to_major(compiled_auto.input_layouts()[0][0]), (2, 1, 0))
self.assertTupleEqual(
extract_minor_to_major(compiled_auto._output_layouts()), (0, 1, 2))
extract_minor_to_major(compiled_auto.output_layouts()), (0, 1, 2))
with self.assertRaisesRegex(
ValueError, "jax.jit` does not accept device-local layouts directly"):
jax.jit(f, in_shardings=DLL.AUTO,
out_shardings=DLL.AUTO).lower(sds).compile()
def test_in_layouts_out_layouts(self):
mesh = jtu.create_global_mesh((2, 2), ('x', 'y'))
@ -175,16 +180,16 @@ class LayoutTest(jtu.JaxTestCase):
def f(x):
return x.T
compiled = jax.jit(f).lower(
arr, _in_layouts=Layout(), _out_layouts=Layout(DLL.AUTO)).compile()
compiled = jax.jit(f, in_shardings=Layout(),
out_shardings=Layout(DLL.AUTO)).lower(arr).compile()
self.assertTupleEqual(
extract_minor_to_major(compiled._input_layouts()[0][0]), (1, 0))
extract_minor_to_major(compiled.input_layouts()[0][0]), (1, 0))
self.assertTupleEqual(
extract_minor_to_major(compiled._output_layouts()), (0, 1))
extract_minor_to_major(compiled.output_layouts()), (0, 1))
out = compiled(arr)
self.assertArraysEqual(out, np_inp.T)
self.assertEqual(out.layout, compiled._output_layouts())
self.assertEqual(out.layout, compiled.output_layouts())
self.assertEqual(out.sharding, NamedSharding(mesh, P('y', 'x')))
def test_sharding_and_layouts(self):
@ -193,14 +198,13 @@ class LayoutTest(jtu.JaxTestCase):
np_inp = np.arange(math.prod(shape)).reshape(shape)
s = NamedSharding(mesh, P('x', 'y'))
compiled = jax.jit(lambda x: x.T, in_shardings=s, out_shardings=s).lower(
np_inp, _in_layouts=Layout(DLL.AUTO),
_out_layouts=Layout(DLL.AUTO)).compile()
compiled = jax.jit(lambda x: x.T, in_shardings=Layout(DLL.AUTO, s),
out_shardings=Layout(DLL.AUTO, s)).lower(np_inp).compile()
out = compiled(np_inp)
self.assertTupleEqual(
extract_minor_to_major(compiled._input_layouts()[0][0]), (1, 0))
extract_minor_to_major(compiled.input_layouts()[0][0]), (1, 0))
self.assertTupleEqual(
extract_minor_to_major(compiled._output_layouts()), (0, 1))
extract_minor_to_major(compiled.output_layouts()), (0, 1))
self.assertArraysEqual(out, np_inp.T)
self.assertEqual(out.sharding, s)
@ -210,15 +214,15 @@ class LayoutTest(jtu.JaxTestCase):
shape = (8, 2)
inps = [np.arange(math.prod(shape)).reshape(shape)] * 6
compiled = jax.jit(f).lower(*inps, _in_layouts=Layout(DLL.AUTO),
_out_layouts=Layout(DLL.AUTO)).compile()
arg_layouts, _ = compiled._input_layouts()
compiled = jax.jit(f, in_shardings=Layout(DLL.AUTO),
out_shardings=Layout(DLL.AUTO)).lower(*inps).compile()
arg_layouts, _ = compiled.input_layouts()
out1, out2 = compiled(*inps)
compiled2 = jax.jit(f).lower(*inps, _in_layouts=arg_layouts).compile()
compiled2 = jax.jit(f, in_shardings=arg_layouts).lower(*inps).compile()
out3, out4 = compiled2(*inps)
for l1, l2 in safe_zip(arg_layouts, compiled2._input_layouts()[0]):
for l1, l2 in safe_zip(arg_layouts, compiled2.input_layouts()[0]):
self.assertEqual(l1, l2)
self.assertArraysEqual(out1, out3)
@ -244,11 +248,10 @@ class LayoutTest(jtu.JaxTestCase):
with self.assertRaisesRegex(
ValueError,
'Layout passed to jit does not match the layout on the respective arg'):
jax.jit(f).lower(arr, _in_layouts=Layout(DLL.AUTO))
jax.jit(f, in_shardings=Layout(DLL.AUTO)).lower(arr)
compiled = jax.jit(f).lower(
sds, _in_layouts=Layout(DLL.AUTO),
_out_layouts=Layout(DLL.AUTO)).compile()
compiled = jax.jit(f, in_shardings=Layout(DLL.AUTO),
out_shardings=Layout(DLL.AUTO)).lower(sds).compile()
with self.assertRaisesRegex(
ValueError,
@ -271,8 +274,8 @@ class LayoutTest(jtu.JaxTestCase):
arr = jax.device_put(np_inp, s)
compiled = jax.jit(
lambda x: x * 2).lower(arr, _out_layouts=Layout(DLL.AUTO)).compile()
col = compiled._output_layouts()
lambda x: x * 2, out_shardings=Layout(DLL.AUTO)).lower(arr).compile()
col = compiled.output_layouts()
out = jax.device_put(np_inp, col)
self.assertEqual(out.layout, col)
@ -304,7 +307,7 @@ class LayoutTest(jtu.JaxTestCase):
compiled = jax.jit(lambda x: x).lower(x).compile()
with self.assertRaisesRegex(
ValueError, 'Sharding has to be concrete when layout.*'):
Layout(compiled._output_layouts()[0], None)
Layout(compiled.output_layouts()[0], None)
if __name__ == '__main__':


@ -1500,7 +1500,7 @@ class PallasCallInputOutputAliasingTest(PallasTPUTest):
)(x)
o = f(x)
np.testing.assert_array_equal(o, expected)
compiled = f.lower(x).compile()
compiled = f.lower(jax.ShapeDtypeStruct(x.shape, x.dtype)).compile()
mem_analysis = compiled.memory_analysis()
expected_num_bytes = np.prod(x.shape) * x.dtype.itemsize
self.assertEqual(mem_analysis.alias_size_in_bytes, expected_num_bytes)
@ -1528,7 +1528,7 @@ class PallasCallInputOutputAliasingTest(PallasTPUTest):
)(jnp.array([1,2,3]), x)
o = f(x)
np.testing.assert_array_equal(o, expected)
compiled = f.lower(x).compile()
compiled = f.lower(jax.ShapeDtypeStruct(x.shape, x.dtype)).compile()
mem_analysis = compiled.memory_analysis()
expected_num_bytes = np.prod(x.shape) * x.dtype.itemsize
self.assertEqual(mem_analysis.alias_size_in_bytes, expected_num_bytes)