From c125442644df7524d2e1ebc4ee3fc45217df4810 Mon Sep 17 00:00:00 2001
From: Yash Katariya
Date: Fri, 5 Apr 2024 20:08:48 -0700
Subject: [PATCH] Add `Layout` support to `jax.jit`.

`jax.jit` now accepts `Layout` instances in the `in_shardings` and
`out_shardings` arguments. The major changes are just plumbing `in_layouts`
and `out_layouts` everywhere.

Note that the public API is `Layout(device_local_layout, sharding)`, which is
how users will pass us the Layout, but internally we split it apart into
`device_local_layout` and `sharding`.

Docs are coming up on how to use the API, what Layouts mean, and how to make
sense of them (especially on TPU).

PiperOrigin-RevId: 622352537
---
 jax/_src/array.py                    |  12 +-
 jax/_src/checkify.py                 |  12 +-
 jax/_src/dispatch.py                 |   5 +-
 jax/_src/interpreters/mlir.py        |  19 +-
 jax/_src/interpreters/pxla.py        |  11 +-
 jax/_src/maps.py                     |   4 +-
 jax/_src/pjit.py                     | 393 +++++++++++++++++++--------
 jax/_src/stages.py                   |  14 +-
 jax/experimental/host_callback.py    |   2 +
 jax/experimental/jax2tf/jax2tf.py    |   1 +
 jax/experimental/jet.py              |   2 +
 jax/experimental/sparse/transform.py |  11 +-
 tests/export_test.py                 |  32 +--
 tests/layout_test.py                 |  85 +++---
 tests/pallas/pallas_call_tpu_test.py |   4 +-
 15 files changed, 390 insertions(+), 217 deletions(-)

diff --git a/jax/_src/array.py b/jax/_src/array.py
index 4c85f1fad..5592e4d3a 100644
--- a/jax/_src/array.py
+++ b/jax/_src/array.py
@@ -34,7 +34,6 @@ from jax._src import deprecations
 from jax._src import dispatch
 from jax._src import dtypes
 from jax._src import errors
-from jax._src import layout
 from jax._src import profiler
 from jax._src import tree_util
 from jax._src import xla_bridge
@@ -47,6 +46,7 @@ from jax._src.sharding import Sharding
 from jax._src.sharding_impls import (
     SingleDeviceSharding, XLACompatibleSharding, PmapSharding,
     device_replica_id_map, hashed_index)
+from jax._src.layout import DeviceLocalLayout, Layout
 from jax._src.typing import ArrayLike
 from jax._src.util import safe_zip, unzip3, use_cpp_class, use_cpp_method
@@ -529,15 +529,17 @@ class ArrayImpl(basearray.Array):
       out.append(Shard(_get_device(a), self.sharding, self.shape, a))
     return out

-  @property
+  @functools.cached_property
   def layout(self):
+    # TODO(yashkatariya): Remove the deleted check from here.
+    if self.is_deleted():
+      return Layout(None, self.sharding)
     try:
-      return layout.Layout(layout.DeviceLocalLayout(self._pjrt_layout),
-                           self.sharding)
+      return Layout(DeviceLocalLayout(self._pjrt_layout), self.sharding)
     except xe.XlaRuntimeError as e:
       msg, *_ = e.args
       if type(msg) is str and msg.startswith("UNIMPLEMENTED"):
-        return layout.Layout(None, self.sharding)
+        return Layout(None, self.sharding)
       else:
         raise

diff --git a/jax/_src/checkify.py b/jax/_src/checkify.py
index 9b68afcf3..69e3dd158 100644
--- a/jax/_src/checkify.py
+++ b/jax/_src/checkify.py
@@ -895,9 +895,9 @@ def while_loop_error_check(error, enabled_errors, *in_flat, cond_nconsts,
 error_checks[lax.while_p] = while_loop_error_check

 def pjit_error_check(error, enabled_errors, *vals_in, jaxpr,
-                     in_shardings, out_shardings, resource_env,
-                     donated_invars, name,
-                     inline, keep_unused):
+                     in_shardings, out_shardings,
+                     in_layouts, out_layouts,
+                     resource_env, donated_invars, name, inline, keep_unused):
   # jaxpr to checked_jaxpr
   err_vals, err_tree = jtu.tree_flatten(error)
   new_vals_in = [*err_vals, *vals_in]
@@ -908,10 +908,12 @@ def pjit_error_check(error, enabled_errors, *vals_in, jaxpr,
   # Update pjit params to account for extra error values.
num_error_vals = len(err_vals) num_out_error_vals = out_tree.num_leaves - len(out_shardings) - sharding = sharding_impls.UNSPECIFIED + sharding = sharding_impls.UNSPECIFIED new_in_shardings = (*[sharding] * num_error_vals, *in_shardings) new_out_shardings = (*[sharding] * num_out_error_vals, *out_shardings) + new_in_layouts = (*[None] * num_error_vals, *in_layouts) + new_out_layouts = (*[None] * num_out_error_vals, *out_layouts) new_donated_invars = (*[False] * num_error_vals, *donated_invars) err_and_out = pjit.pjit_p.bind( @@ -919,6 +921,8 @@ def pjit_error_check(error, enabled_errors, *vals_in, jaxpr, jaxpr=checked_jaxpr, in_shardings=new_in_shardings, out_shardings=new_out_shardings, + in_layouts=new_in_layouts, + out_layouts=new_out_layouts, resource_env=resource_env, donated_invars=new_donated_invars, name=name, diff --git a/jax/_src/dispatch.py b/jax/_src/dispatch.py index e1ffb0f87..d62d57ea2 100644 --- a/jax/_src/dispatch.py +++ b/jax/_src/dispatch.py @@ -452,10 +452,7 @@ def _device_put_impl( return x if x_dll is None and dll is None: return _device_put_sharding_impl(x, aval, l.sharding) - # TODO(yashkatariya): Pass layout to out_shardings directly and remove - # out_layouts from lower. - return api.jit(_identity_fn, out_shardings=l.sharding).lower( - x, _out_layouts=l).compile()(x) + return api.jit(_identity_fn, out_shardings=l)(x) return _device_put_sharding_impl(x, aval, device) diff --git a/jax/_src/interpreters/mlir.py b/jax/_src/interpreters/mlir.py index 2ed7ab713..8a9450201 100644 --- a/jax/_src/interpreters/mlir.py +++ b/jax/_src/interpreters/mlir.py @@ -954,11 +954,6 @@ def lower_jaxpr_to_module( else: dim_vars = () - arg_layouts = (map(_to_xla_layout, in_layouts) if in_layouts is not None - else in_layouts) - result_layouts = (map(_to_xla_layout, out_layouts) if out_layouts is not None - else out_layouts) - ctx = ModuleContext(backend_or_name=backend_or_name, platforms=platforms, axis_context=axis_context, keepalives=keepalives, @@ -992,8 +987,8 @@ def lower_jaxpr_to_module( result_names=result_names, arg_memory_kinds=arg_memory_kinds, result_memory_kinds=result_memory_kinds, - arg_layouts=arg_layouts, - result_layouts=result_layouts) + arg_layouts=in_layouts, + result_layouts=out_layouts) try: if not ctx.module.operation.verify(): @@ -1140,8 +1135,8 @@ def lower_jaxpr_to_fun( result_names: Sequence[str | None] | None = None, arg_memory_kinds: Sequence[str | None] | None = None, result_memory_kinds: Sequence[str | None] | None = None, - arg_layouts: Sequence[str | None] | None = None, - result_layouts: Sequence[str | None] | None = None, + arg_layouts: Sequence[DeviceLocalLayout | None | AutoLayout] | None = None, + result_layouts: Sequence[DeviceLocalLayout | None | AutoLayout] | None = None, ) -> func_dialect.FuncOp: """Lowers jaxpr and its callees to an IR function. 
@@ -1262,7 +1257,8 @@ def lower_jaxpr_to_fun( ir_arg_layouts = None if arg_layouts is not None: ir_arg_layouts = util.flatten( - [[l] * len(types) for l, types in zip(arg_layouts, input_types)]) + [[_to_xla_layout(l)] * len(types) + for l, types in zip(arg_layouts, input_types)]) ir_donated_args = None if xla_donated_args is not None: @@ -1285,7 +1281,8 @@ def lower_jaxpr_to_fun( ir_result_layouts = None if result_layouts is not None: ir_result_layouts = util.flatten( - [[l] * len(types) for l, types in zip(result_layouts, output_types)]) + [[_to_xla_layout(l)] * len(types) + for l, types in zip(result_layouts, output_types)]) if ( replicated_args is not None diff --git a/jax/_src/interpreters/pxla.py b/jax/_src/interpreters/pxla.py index f173bc50b..403ee11c5 100644 --- a/jax/_src/interpreters/pxla.py +++ b/jax/_src/interpreters/pxla.py @@ -2035,15 +2035,15 @@ def lower_sharding_computation( fun_name: str, in_shardings: Sequence[MaybeSharding], out_shardings: Sequence[MaybeSharding], + in_layouts: MaybeLayout, + out_layouts: MaybeLayout, donated_invars: Sequence[bool], global_in_avals: Sequence[core.ShapedArray], *, keep_unused: bool, inline: bool, devices_from_context: Sequence[xc.Device] | None = None, - lowering_parameters: mlir.LoweringParameters, - in_layouts: MaybeLayout, - out_layouts: MaybeLayout, + lowering_parameters: mlir.LoweringParameters ) -> MeshComputation: """Lowers a computation to XLA. It can take arbitrary shardings as input. @@ -3266,8 +3266,9 @@ def check_array_xla_sharding_layout_match( arg.layout.device_local_layout != xl): errors.append( ("Got input layout(s) that compiled object was called with: " - f"{arg.layout} and layout(s) the computation was compiled " - f"with: {xl} for arg {name} with shape: {arg.aval.str_short()}", + f"{arg.layout.device_local_layout} and layout(s) the computation was " + f"compiled with: {xl} for arg {name} with " + f"shape: {arg.aval.str_short()}", 'layout')) if errors: diff --git a/jax/_src/maps.py b/jax/_src/maps.py index eca363684..4e37153b4 100644 --- a/jax/_src/maps.py +++ b/jax/_src/maps.py @@ -714,9 +714,9 @@ def make_xmap_callable(fun: lu.WrappedFun, return pxla.lower_sharding_computation( core.ClosedJaxpr(jaxpr, consts), 'jit', name, (UNSPECIFIED,) * len(in_avals), (UNSPECIFIED,) * len(out_avals), + (None,) * len(in_avals), (None,) * len(out_avals), donated_invars, in_avals, keep_unused=True, inline=False, - devices_from_context=None, lowering_parameters=lowering_parameters, - in_layouts=(None,) * len(in_avals), out_layouts=(None,) * len(out_avals)) + devices_from_context=None, lowering_parameters=lowering_parameters) class EvaluationPlan(NamedTuple): diff --git a/jax/_src/pjit.py b/jax/_src/pjit.py index c090bc79d..05d95fffd 100644 --- a/jax/_src/pjit.py +++ b/jax/_src/pjit.py @@ -53,7 +53,6 @@ from jax._src.errors import JAXTypeError from jax._src.interpreters import partial_eval as pe from jax._src.partition_spec import PartitionSpec from jax._src.interpreters import xla - from jax._src.interpreters import ad from jax._src.interpreters import batching from jax._src.interpreters import mlir @@ -67,13 +66,13 @@ from jax._src.sharding_impls import ( SingleDeviceSharding, PmapSharding, AUTO, UNSPECIFIED, UnspecifiedValue, ParsedPartitionSpec, SpecSync, get_single_pspec, is_auto, is_unspecified, is_unspecified_or_auto, prepare_axis_resources, parse_flatten_op_sharding) -from jax._src.layout import Layout, LayoutOptions +from jax._src.layout import Layout, DeviceLocalLayout, AutoLayout from jax._src.state import discharge as 
state_discharge, RefEffect from jax._src.traceback_util import api_boundary from jax._src.tree_util import ( tree_flatten, tree_unflatten, treedef_is_leaf, tree_structure, treedef_children, broadcast_prefix, all_leaves, prefix_errors, keystr, - PyTreeDef) + PyTreeDef, none_leaf_registry as none_lr) from jax._src.util import ( HashableFunction, safe_map, safe_zip, wraps, distributed_debug_log, split_list, weakref_lru_cache, @@ -150,6 +149,10 @@ class PjitInfo(NamedTuple): in_shardings_leaves: tuple[Any, ...] out_shardings_treedef: PyTreeDef out_shardings_leaves: tuple[Any, ...] + in_layouts_treedef: PyTreeDef + in_layouts_leaves: tuple[Any, ...] + out_layouts_treedef: PyTreeDef + out_layouts_leaves: tuple[Any, ...] static_argnums: tuple[int, ...] static_argnames: tuple[str, ...] donate_argnums: tuple[int, ...] @@ -164,8 +167,9 @@ class PjitInfo(NamedTuple): def _python_pjit_helper(jit_info, *args, **kwargs): - args_flat, _, params, _, out_tree, _, _, _, arg_names, attrs_tracked = \ - _infer_params(jit_info, args, kwargs) + (args_flat, _, params, _, out_tree, _, arg_names, + attrs_tracked) = _infer_params(jit_info, args, kwargs) + for arg in args_flat: dispatch.check_arg(arg) @@ -202,6 +206,7 @@ def _python_pjit_helper(jit_info, *args, **kwargs): if attrs_tracked: final_states, out_flat = split_list(out_flat, [len(attrs_tracked)]) _set_states(attrs_tracked, final_states) + outs = tree_unflatten(out_tree, out_flat) return outs, out_flat, out_tree, args_flat, params['jaxpr'], attrs_tracked @@ -335,6 +340,30 @@ def _pjit_explicit_sharding(in_shardings, out_shardings, device, any(not is_unspecified(i) for i in out_shardings_flat)) +def _split_layout_and_sharding(entries): + entries_flat, treedef = tree_flatten(entries, is_leaf=lambda x: x is None) + layouts, shardings = [], [] + + for e in entries_flat: + if e is None or is_unspecified_or_auto(e): + layouts.append(None) + shardings.append(e) + elif isinstance(e, Layout): + layouts.append(e.device_local_layout) + shardings.append(e.sharding) + elif isinstance(e, (DeviceLocalLayout, AutoLayout)): + raise ValueError( + '`jax.jit` does not accept device-local layouts directly. Create ' + 'a `Layout` instance wrapping this device-local layout and pass ' + f'that to `jit` instead. Got {e}') + else: + layouts.append(None) + shardings.append(e) + + assert len(layouts) == len(shardings) + return tree_unflatten(treedef, layouts), tree_unflatten(treedef, shardings) + + def _parse_jit_arguments(fun: Callable, in_shardings: Any, out_shardings: Any, donate_argnums: int | Sequence[int] | None, donate_argnames: str | Iterable[str] | None, @@ -378,16 +407,19 @@ def _parse_jit_arguments(fun: Callable, in_shardings: Any, out_shardings: Any, # rather than raising an error. 
https://github.com/google/jax/issues/2367 in_shardings = tuple(in_shardings) + in_layouts, in_shardings = _split_layout_and_sharding(in_shardings) + out_layouts, out_shardings = _split_layout_and_sharding(out_shardings) + in_shardings = prepare_axis_resources(in_shardings, 'in_shardings') out_shardings = prepare_axis_resources(out_shardings, 'out_shardings') user_specified_in_shardings = (in_shardings is not None and not is_unspecified(in_shardings)) - none_leaf_registry = tree_util.none_leaf_registry - in_shardings_leaves, in_shardings_treedef = none_leaf_registry.flatten( - in_shardings) - out_shardings_leaves, out_shardings_treedef = none_leaf_registry.flatten( - out_shardings) + + in_shardings_leaves, in_shardings_treedef = none_lr.flatten(in_shardings) + out_shardings_leaves, out_shardings_treedef = none_lr.flatten(out_shardings) + in_layouts_leaves, in_layouts_treedef = none_lr.flatten(in_layouts) + out_layouts_leaves, out_layouts_treedef = none_lr.flatten(out_layouts) fun_sourceinfo = api_util.fun_sourceinfo(fun) fun_signature = api_util.fun_signature(fun) @@ -408,6 +440,10 @@ def _parse_jit_arguments(fun: Callable, in_shardings: Any, out_shardings: Any, in_shardings_leaves=tuple(in_shardings_leaves), out_shardings_treedef=out_shardings_treedef, out_shardings_leaves=tuple(out_shardings_leaves), + in_layouts_treedef=in_layouts_treedef, + in_layouts_leaves=tuple(in_layouts_leaves), + out_layouts_treedef=out_layouts_treedef, + out_layouts_leaves=tuple(out_layouts_leaves), static_argnums=static_argnums, static_argnames=static_argnames, donate_argnums=donate_argnums, donate_argnames=donate_argnames, device=device, backend=backend, @@ -417,37 +453,58 @@ def _parse_jit_arguments(fun: Callable, in_shardings: Any, out_shardings: Any, use_resource_env=use_resource_env) +# TODO(yashkatariya): Delete this function once internal users migrate off of +# the deprecated AOT API. +def _handle_layouts_in_aot(jit_info: PjitInfo, kwargs): + if '_in_layouts' in kwargs or '_out_layouts' in kwargs: + warnings.warn( + 'Passing `_in_layouts` and `_out_layouts` to `.lower` is deprecated and' + ' will be removed soon. Please pass your `Layout` instances to' + ' `in_shardings` and `out_shardings` arguments of `jax.jit`', + DeprecationWarning) + in_layouts = kwargs.pop('_in_layouts', None) + out_layouts = kwargs.pop('_out_layouts', None) + in_layouts, _ = _split_layout_and_sharding(in_layouts) + out_layouts, _ = _split_layout_and_sharding(out_layouts) + in_layouts_leaves, in_layouts_treedef = none_lr.flatten(in_layouts) + out_layouts_leaves, out_layouts_treedef = none_lr.flatten(out_layouts) + return jit_info._replace(in_layouts_treedef=in_layouts_treedef, + in_layouts_leaves=tuple(in_layouts_leaves), + out_layouts_treedef=out_layouts_treedef, + out_layouts_leaves=tuple(out_layouts_leaves)) + return jit_info + + def _make_jit_wrapper(jit_info: PjitInfo): - wrapped = _cpp_pjit(jit_info) @api_boundary def lower(*args, **kwargs): lowering_parameters = kwargs.pop( '_experimental_lowering_parameters', mlir.LoweringParameters()) - # TODO(yashkatariya): Remove this when it's added on jit. 
- in_layouts = kwargs.pop('_in_layouts', Layout()) - out_layouts = kwargs.pop('_out_layouts', Layout()) + # TODO(yashkatariya): Remove this handling once internal users migrate off + # of the deprecated API + new_jit_info = _handle_layouts_in_aot(jit_info, kwargs) + (args_flat, flat_global_in_avals, params, in_tree, out_tree, - donated_invars, in_layouts_flat, out_layouts_flat, - arg_names, ()) = _infer_params( - jit_info, args, kwargs, in_layouts=in_layouts, out_layouts=out_layouts) + donated_invars, arg_names, ()) = _infer_params(new_jit_info, args, kwargs) resource_env = params['resource_env'] mesh = None if resource_env is None else resource_env.physical_mesh try: in_shardings = _resolve_in_shardings( args_flat, params['in_shardings'], params['out_shardings'], mesh) - in_layouts_flat = _resolve_in_layouts( - args_flat, in_layouts_flat, in_shardings) - out_layouts_flat = _resolve_out_layouts(out_layouts_flat) + in_layouts = _resolve_in_layouts( + args_flat, params['in_layouts'], in_shardings, + params['jaxpr'].in_avals) lowering = _pjit_lower( params['jaxpr'], in_shardings, params['out_shardings'], + in_layouts, params['out_layouts'], params['resource_env'], params['donated_invars'], params['name'], - params['keep_unused'], params['inline'], in_layouts=in_layouts_flat, - out_layouts=out_layouts_flat, lowering_parameters=lowering_parameters) + params['keep_unused'], params['inline'], + lowering_parameters=lowering_parameters) except pxla.DeviceAssignmentMismatchError as e: fails, = e.args api_name = 'jit' if params['resource_env'] is None else 'pjit' - fun = jit_info.fun + fun = new_jit_info.fun fun_name = getattr(fun, '__qualname__', getattr(fun, '__name__', str(fun))) msg = _device_assignment_mismatch_error( @@ -461,19 +518,18 @@ def _make_jit_wrapper(jit_info: PjitInfo): @api_boundary def eval_shape(*args, **kwargs): - _, _, params, _, out_tree, _, _, _, _, _ = _infer_params( - jit_info, args, kwargs, in_layouts=None, out_layouts=None - ) - out_s = [None if is_unspecified(s) else getattr(s, '_original_sharding', s) - for s in params['out_shardings']] + _, _, params, _, out_tree, _, _, _ = _infer_params(jit_info, args, kwargs) + out_s = [None if is_unspecified(s) else s for s in params['out_shardings']] out = [api.ShapeDtypeStruct(x.shape, x.dtype, x.named_shape, sharding=s) for x, s in zip(params['jaxpr'].out_avals, out_s)] return tree_unflatten(out_tree, out) + wrapped = _cpp_pjit(jit_info) wrapped.lower = lower wrapped.eval_shape = eval_shape return wrapped + def make_jit(fun: Callable, in_shardings: Any, out_shardings: Any, donate_argnums: int | Sequence[int] | None, donate_argnames: str | Iterable[str] | None, @@ -490,10 +546,11 @@ def make_jit(fun: Callable, in_shardings: Any, out_shardings: Any, return _make_jit_wrapper(jit_info) -def _infer_params(jit_info, args, kwargs, *, in_layouts=None, out_layouts=None): +def _infer_params(jit_info, args, kwargs): (fun, fun_sourceinfo, fun_signature, user_specified_in_shardings, in_shardings_treedef, in_shardings_leaves, out_shardings_treedef, - out_shardings_leaves, static_argnums, static_argnames, + out_shardings_leaves, in_layouts_treedef, in_layouts_leaves, + out_layouts_treedef, out_layouts_leaves, static_argnums, static_argnames, donate_argnums, donate_argnames, device, backend, keep_unused, inline, abstracted_axes, _, use_resource_env) = jit_info @@ -576,17 +633,18 @@ def _infer_params(jit_info, args, kwargs, *, in_layouts=None, out_layouts=None): ) from e in_type = in_avals = tuple(avals) - canonicalized_in_shardings_flat, 
in_layouts_flat = _process_in_axis_resources( - in_shardings_treedef, in_shardings_leaves, hashable_pytree(in_layouts), + in_shardings_flat, in_layouts_flat = _process_in_axis_resources( + in_shardings_treedef, in_shardings_leaves, + in_layouts_treedef, in_layouts_leaves, in_avals, in_tree, dbg, device_or_backend_set, have_kwargs) - jaxpr, consts, out_shardings, out_layouts_flat, attrs_tracked = _pjit_jaxpr( + jaxpr, consts, out_shardings_flat, out_layouts_flat, attrs_tracked = _pjit_jaxpr( flat_fun, out_shardings_treedef, out_shardings_leaves, - hashable_pytree(out_layouts), in_type, dbg, device_or_backend_set, - HashableFunction(out_tree, closure=()), + out_layouts_treedef, out_layouts_leaves, in_type, dbg, + device_or_backend_set, HashableFunction(out_tree, closure=()), HashableFunction(res_paths, closure=()), inline) - assert len(explicit_args) == len(canonicalized_in_shardings_flat) == len(in_layouts_flat) + assert len(explicit_args) == len(in_shardings_flat) == len(in_layouts_flat) if config.dynamic_shapes.value: implicit_args = _extract_implicit_args(in_type, explicit_args) @@ -595,18 +653,19 @@ def _infer_params(jit_info, args, kwargs, *, in_layouts=None, out_layouts=None): args_flat = [*implicit_args, *explicit_args] num_extra_args = len(implicit_args) + len(attrs_tracked) + len(consts) - canonicalized_in_shardings_flat = \ - (UNSPECIFIED,) * num_extra_args + canonicalized_in_shardings_flat + in_shardings_flat = (UNSPECIFIED,) * num_extra_args + in_shardings_flat in_layouts_flat = (None,) * num_extra_args + in_layouts_flat donated_invars = (False,) * num_extra_args + donated_invars - assert (len(canonicalized_in_shardings_flat) == len(in_layouts_flat) == + assert (len(in_shardings_flat) == len(in_layouts_flat) == len(donated_invars) == len(attrs_tracked) + len(consts) + len(args_flat)) # in_shardings and out_shardings here are all GSPMDSharding. 
params = dict( jaxpr=jaxpr, - in_shardings=canonicalized_in_shardings_flat, - out_shardings=out_shardings, + in_shardings=in_shardings_flat, + out_shardings=out_shardings_flat, + in_layouts=in_layouts_flat, + out_layouts=out_layouts_flat, resource_env=resource_env, donated_invars=donated_invars, name=getattr(flat_fun, '__name__', ''), @@ -614,8 +673,7 @@ def _infer_params(jit_info, args, kwargs, *, in_layouts=None, out_layouts=None): inline=inline, ) return (consts + args_flat, in_type, params, in_tree, out_tree(), - donated_invars, in_layouts_flat, out_layouts_flat, - dbg.arg_names if dbg else None, attrs_tracked) + donated_invars, dbg.arg_names if dbg else None, attrs_tracked) def _extract_implicit_args( in_type: Sequence[tuple[core.AbstractValue, bool]], @@ -973,8 +1031,8 @@ class PytreeLeaf: @lru_cache(maxsize=4096) def _process_in_axis_resources(in_shardings_treedef, in_shardings_leaves, - in_layouts_thunk, in_avals, - in_tree, debug_info, + in_layouts_treedef, in_layouts_leaves, + in_avals, in_tree, debug_info, device_or_backend_set, kws): if not kws: in_tree, _ = treedef_children(in_tree) @@ -988,7 +1046,7 @@ def _process_in_axis_resources(in_shardings_treedef, in_shardings_leaves, in_shardings_flat = flatten_axis_resources( "pjit in_shardings", in_tree, orig_in_shardings, tupled_args=True) - in_layouts = in_layouts_thunk() + in_layouts = tree_unflatten(in_layouts_treedef, in_layouts_leaves) if in_layouts is None: in_layouts_flat = (in_layouts,) * len(in_avals) else: @@ -1001,7 +1059,7 @@ def _process_in_axis_resources(in_shardings_treedef, in_shardings_leaves, pjit_check_aval_sharding(in_shardings_flat, in_avals, None if debug_info is None else debug_info.arg_names, "pjit arguments", allow_uneven_sharding=False) - return in_shardings_flat, tuple(in_layouts_flat) + return in_shardings_flat, in_layouts_flat callsites: set[str] = set() @@ -1168,13 +1226,9 @@ def _create_pjit_jaxpr(fun, in_type, debug_info, out_paths, ignored_inline): @lru_cache(maxsize=4096) def _check_and_canonicalize_out_shardings( - out_shardings_treedef, out_shardings_leaves, out_layouts_thunk, out_tree, - out_type, debug_info, device_or_backend_set): + out_shardings_treedef, out_shardings_leaves, out_layouts_treedef, + out_layouts_leaves, out_tree, out_type, debug_info, device_or_backend_set): orig_out_shardings = tree_unflatten(out_shardings_treedef, out_shardings_leaves) - # TODO(yashkatariya): Remove the if branch and fix flatten_axis_resources - # instead. This condition exists because flatten_axis_resources passes in an - # `object()` while unflattening which breaks assertion is user defined - # pytrees (which shouldn't exist but they do). 
if (is_unspecified(orig_out_shardings) or isinstance(orig_out_shardings, XLACompatibleSharding)): out_shardings_flat = (orig_out_shardings,) * len(out_type) @@ -1183,7 +1237,7 @@ def _check_and_canonicalize_out_shardings( "pjit out_shardings", out_tree(), orig_out_shardings, tupled_args=False) - out_layouts = out_layouts_thunk() + out_layouts = tree_unflatten(out_layouts_treedef, out_layouts_leaves) if out_layouts is None: out_layouts_flat = (out_layouts,) * len(out_type) else: @@ -1195,18 +1249,20 @@ def _check_and_canonicalize_out_shardings( out_shardings_flat, out_type, None if debug_info is None else debug_info.result_paths, "pjit outputs", allow_uneven_sharding=False) - return out_shardings_flat, tuple(out_layouts_flat) + return out_shardings_flat, out_layouts_flat def _pjit_jaxpr(fun, out_shardings_treedef, out_shardings_leaves, - out_layouts_thunk, in_type, debug_info, device_or_backend_set, - out_tree, result_paths, inline): + out_layouts_treedef, out_layouts_leaves, in_type, debug_info, + device_or_backend_set, out_tree, result_paths, inline): jaxpr, final_consts, out_type, attrs_tracked = _create_pjit_jaxpr( fun, in_type, debug_info, result_paths, IgnoreKey(inline)) canonicalized_out_shardings_flat, out_layouts_flat = _check_and_canonicalize_out_shardings( - out_shardings_treedef, out_shardings_leaves, out_layouts_thunk, out_tree, tuple(out_type), + out_shardings_treedef, out_shardings_leaves, out_layouts_treedef, + out_layouts_leaves, out_tree, tuple(out_type), jaxpr.jaxpr.debug_info, device_or_backend_set) - return jaxpr, final_consts, canonicalized_out_shardings_flat, out_layouts_flat, attrs_tracked + return (jaxpr, final_consts, canonicalized_out_shardings_flat, + out_layouts_flat, attrs_tracked) @dataclasses.dataclass(frozen=True) @@ -1259,30 +1315,65 @@ pjit_p = core.AxisPrimitive("pjit") pjit_p.multiple_results = True -def _resolve_in_layouts(args, jit_in_layouts, jit_in_shardings): +@lru_cache(maxsize=2048) +def _maybe_get_default_layout(arg_layout, jit_in_layout, sharding, aval): + if is_unspecified_or_auto(sharding): + return None + # TODO(yashkatariya): Figure out how layouts work with extended dtypes. + if dtypes.issubdtype(aval.dtype, dtypes.extended): + return None + if not core.is_constant_shape(aval.shape): + return None + shard_shape = sharding.shard_shape(aval.shape) + d = sharding._device_assignment[0] + # If a backend doesn't implement `get_default_layout` return `None` to avoid + # cache misses. This can happen when you have `jit(f, in_shardings=s)`. On + # first call you pass it a sharded array with layout and on second call you + # pass a numpy array. The layouts should be the same to get cache hits. + try: + al = DeviceLocalLayout( + d.client.get_default_layout(aval.dtype, shard_shape, d)) + except: + return None + # argument does not have `.layout` property. ShapedArray, ShapedDtypeStruct, + # numpy array, etc are some examples. + if arg_layout is None: + return al if jit_in_layout is None else arg_layout # arg_layout is None + # If arg has a `.layout` property, then return device_local_layout as is. + return arg_layout.device_local_layout + + +def _resolve_in_layouts(args, jit_in_layouts, resolved_in_shardings, in_avals): # If device or backend is set, return the default layout. This is because you # can pass arrays on cpu (with untiled layouts) to jit with backend='tpu' # which causes error checks to fail. Returning the default layout allows # this to exist. It's the same for handling shardings. 
- if pxla.check_device_backend_on_shardings(jit_in_shardings): + if pxla.check_device_backend_on_shardings(resolved_in_shardings): return (None,) * len(jit_in_layouts) resolved_in_layouts = [] - for arg, jit_in_l in safe_zip(args, jit_in_layouts): + for arg, jit_in_l, rs, aval in safe_zip( + args, jit_in_layouts, resolved_in_shardings, in_avals): arg_layout, committed = ( - (arg.layout.device_local_layout, getattr(arg, '_committed', True)) - if getattr(arg, 'layout', None) is not None else (None, False)) - jit_in_l = None if jit_in_l is None else jit_in_l.device_local_layout + _maybe_get_default_layout(getattr(arg, 'layout', None), jit_in_l, rs, aval), + getattr(arg, '_committed', True)) + # Sharding can be unspecified when array is committed if it's a PmapSharding. + is_pmap_sharding = (is_unspecified(rs) or + isinstance(getattr(arg, 'sharding', None), PmapSharding)) if jit_in_l is None: if committed: - resolved_in_layouts.append(arg_layout) + if is_pmap_sharding: + resolved_in_layouts.append(None) + else: + resolved_in_layouts.append(arg_layout) else: resolved_in_layouts.append(None) else: # arg_layout can be None because some backends don't implement the # required layout methods. Hence `arr.layout` can return # `Layout(None, sharding)` - if committed and arg_layout is not None and arg_layout != jit_in_l: + if (committed and not is_pmap_sharding and + arg_layout is not None and arg_layout != jit_in_l): raise ValueError('Layout passed to jit does not match the layout ' 'on the respective arg. ' f'Got pjit layout: {jit_in_l},\n' @@ -1292,13 +1383,6 @@ def _resolve_in_layouts(args, jit_in_layouts, jit_in_shardings): return tuple(resolved_in_layouts) -def _resolve_out_layouts(out_layouts: Sequence[Layout] - ) -> Sequence[LayoutOptions]: - # TODO(yashkatariya): Remove the if condition when all layouts come via the - # `layout.Layout` API or handle this properly when layout is on jit. 
- return tuple(None if o is None else o.device_local_layout for o in out_layouts) - - def _resolve_in_shardings( args, pjit_in_shardings: Sequence[PjitSharding], out_shardings: Sequence[PjitSharding], @@ -1335,8 +1419,10 @@ def _resolve_in_shardings( pxla._get_and_check_device_assignment( it.chain( util.stable_unique(committed_arg_shardings), - ((i, pxla.MismatchType.IN_SHARDING, None) for i in util.stable_unique(pjit_in_shardings)), - ((o, pxla.MismatchType.OUT_SHARDING, None) for o in util.stable_unique(out_shardings))), + ((i, pxla.MismatchType.IN_SHARDING, None) + for i in util.stable_unique(pjit_in_shardings)), + ((o, pxla.MismatchType.OUT_SHARDING, None) + for o in util.stable_unique(out_shardings))), (None if pjit_mesh is None or pjit_mesh.empty else list(pjit_mesh.devices.flat))) resolved_in_shardings = [] @@ -1405,16 +1491,17 @@ def _resolve_in_shardings( def _pjit_call_impl_python( - *args, jaxpr, in_shardings, out_shardings, resource_env, donated_invars, - name, keep_unused, inline): + *args, jaxpr, in_shardings, out_shardings, in_layouts, out_layouts, + resource_env, donated_invars, name, keep_unused, inline): global _most_recent_pjit_call_executable in_shardings = _resolve_in_shardings( args, in_shardings, out_shardings, resource_env.physical_mesh if resource_env is not None else None) + in_layouts = _resolve_in_layouts(args, in_layouts, in_shardings, jaxpr.in_avals) compiled = _pjit_lower( - jaxpr, in_shardings, out_shardings, resource_env, + jaxpr, in_shardings, out_shardings, in_layouts, out_layouts, resource_env, donated_invars, name, keep_unused, inline, lowering_parameters=mlir.LoweringParameters()).compile() _most_recent_pjit_call_executable.weak_key_dict[jaxpr] = compiled @@ -1434,6 +1521,8 @@ def _pjit_call_impl_python( distributed_debug_log(("Running pjit'd function", name), ("in_shardings", in_shardings), ("out_shardings", out_shardings), + ("in_layouts", in_layouts), + ("out_layouts", out_layouts), ("abstract args", map(xla.abstractify, args)), ("fingerprint", fingerprint)) try: @@ -1465,8 +1554,9 @@ def _pjit_call_impl_python( @weakref_lru_cache -def _get_jaxpr_as_fun(jaxpr, in_shardings, out_shardings, resource_env, - donated_invars, name, keep_unused, inline): +def _get_jaxpr_as_fun(jaxpr, in_shardings, out_shardings, in_layouts, + out_layouts, resource_env, donated_invars, name, + keep_unused, inline): # The input jaxpr to `_get_jaxpr_as_fun` is under a weakref_lru_cache so # returning `core.jaxpr_as_fun(jaxpr)` directly creates a strong reference to # the jaxpr defeating the purpose of weakref_lru_cache. 
So return a function @@ -1478,12 +1568,14 @@ def _get_jaxpr_as_fun(jaxpr, in_shardings, out_shardings, resource_env, def _pjit_call_impl(*args, jaxpr, - in_shardings, out_shardings, resource_env, + in_shardings, out_shardings, in_layouts, out_layouts, + resource_env, donated_invars, name, keep_unused, inline): def call_impl_cache_miss(*args_, **kwargs_): out_flat, compiled = _pjit_call_impl_python( *args, jaxpr=jaxpr, in_shardings=in_shardings, - out_shardings=out_shardings, resource_env=resource_env, + out_shardings=out_shardings, in_layouts=in_layouts, + out_layouts=out_layouts, resource_env=resource_env, donated_invars=donated_invars, name=name, keep_unused=keep_unused, inline=inline) fastpath_data = _get_fastpath_data( @@ -1492,7 +1584,7 @@ def _pjit_call_impl(*args, jaxpr, return out_flat, fastpath_data f = _get_jaxpr_as_fun( - jaxpr, in_shardings, out_shardings, + jaxpr, in_shardings, out_shardings, in_layouts, out_layouts, resource_env, donated_invars, name, keep_unused, inline) donated_argnums = [i for i, d in enumerate(donated_invars) if d] has_explicit_sharding = _pjit_explicit_sharding( @@ -1520,22 +1612,15 @@ def _pjit_lower_cached( jaxpr: core.ClosedJaxpr, in_shardings, out_shardings, + in_layouts: pxla.MaybeLayout, + out_layouts: pxla.MaybeLayout, resource_env, donated_invars, name: str, keep_unused: bool, inline: bool, *, - lowering_parameters: mlir.LoweringParameters, - in_layouts: pxla.MaybeLayout | None = None, - out_layouts: pxla.MaybeLayout | None = None): - # TODO(yashkatariya): Remove this when layouts are supported on jit and - # passed to params. - if in_layouts is None: - in_layouts = (None,) * len(in_shardings) - if out_layouts is None: - out_layouts = (None,) * len(out_shardings) - + lowering_parameters: mlir.LoweringParameters): if resource_env is not None: pxla.resource_typecheck(jaxpr, resource_env, {}, lambda: "pjit") @@ -1558,18 +1643,19 @@ def _pjit_lower_cached( else: return pxla.lower_sharding_computation( jaxpr, api_name, name, in_shardings, out_shardings, - tuple(donated_invars), tuple(jaxpr.in_avals), + in_layouts, out_layouts, tuple(donated_invars), tuple(jaxpr.in_avals), keep_unused=keep_unused, inline=inline, devices_from_context=( None if mesh is None or mesh.empty else list(mesh.devices.flat)), - lowering_parameters=lowering_parameters, in_layouts=in_layouts, - out_layouts=out_layouts) + lowering_parameters=lowering_parameters) def pjit_staging_rule(trace, *args, **params): if (params["inline"] and all(is_unspecified(i) for i in params["in_shardings"]) and - all(is_unspecified(o) for o in params["out_shardings"])): + all(is_unspecified(o) for o in params["out_shardings"]) and + all(i is None for i in params["in_layouts"]) and + all(o is None for o in params["out_layouts"])): jaxpr = params['jaxpr'] if config.dynamic_shapes.value: # Inline jaxpr doesn't handle dynamic shapes when inlining. 
If dynamic @@ -1598,14 +1684,16 @@ def pjit_staging_rule(trace, *args, **params): jaxpr, consts = pxla._move_mutable_consts(params['jaxpr']) consts = map(trace.instantiate_const, consts) in_shardings = (*params['in_shardings'],) + (UNSPECIFIED,) * len(consts) + in_layouts = (*params['in_layouts'],) + (None,) * len(consts) donated_invars = (*params['donated_invars'],) + (False,) * len(consts) new_params = dict(params, jaxpr=jaxpr, in_shardings=in_shardings, - donated_invars=donated_invars) + in_layouts=in_layouts, donated_invars=donated_invars) return trace.default_process_primitive(pjit_p, (*args, *consts), new_params) else: return trace.default_process_primitive(pjit_p, args, params) pe.custom_staging_rules[pjit_p] = pjit_staging_rule + # TODO(mattjj): remove/trivialize this when jaxprs have type annotation on them, # since it's actually not possible in general to infer the type from the term def _out_type(jaxpr: core.ClosedJaxpr) -> list[core.AbstractValue]: @@ -1630,13 +1718,14 @@ def _pjit_typecheck(ctx_factory, *in_atoms, jaxpr, **params): core.custom_typechecks[pjit_p] = _pjit_typecheck -def _pjit_abstract_eval(*args, jaxpr, out_shardings, resource_env, **_): +def _pjit_abstract_eval(*args, jaxpr, **_): return jaxpr.out_avals, jaxpr.effects pjit_p.def_effectful_abstract_eval(_pjit_abstract_eval) def _pjit_cached_lower_jaxpr_to_fun(ctx, name, jaxpr, effects, in_shardings, - out_shardings, api_name): + out_shardings, in_layouts, out_layouts, + api_name): mod_ctx = ctx.module_context axis_ctx = ctx.module_context.axis_context num_devices = None @@ -1647,7 +1736,7 @@ def _pjit_cached_lower_jaxpr_to_fun(ctx, name, jaxpr, effects, in_shardings, key = (pjit_p, name, jaxpr, effects, num_devices, pxla.SemanticallyEqualShardings(in_shardings, jaxpr.in_avals), pxla.SemanticallyEqualShardings(out_shardings, jaxpr.out_avals), - api_name) + in_layouts, out_layouts, api_name) func = mod_ctx.cached_primitive_lowerings.get(key, None) if func is None: @@ -1659,14 +1748,15 @@ def _pjit_cached_lower_jaxpr_to_fun(ctx, name, jaxpr, effects, in_shardings, func = mlir.lower_jaxpr_to_fun( mod_ctx, name, jaxpr, effects, ctx.name_stack, arg_shardings=arg_shardings, result_shardings=result_shardings, - use_sharding_annotations=False, api_name=api_name) + use_sharding_annotations=False, api_name=api_name, + arg_layouts=in_layouts, result_layouts=out_layouts) mod_ctx.cached_primitive_lowerings[key] = func return func def _pjit_lowering(ctx, *args, name, jaxpr, in_shardings, - out_shardings, resource_env, donated_invars, - keep_unused, inline): + out_shardings, in_layouts, out_layouts, resource_env, + donated_invars, keep_unused, inline): effects = list(ctx.tokens_in.effects()) output_types = map(mlir.aval_to_ir_types, ctx.avals_out) output_types = [mlir.token_type()] * len(effects) + output_types @@ -1674,7 +1764,8 @@ def _pjit_lowering(ctx, *args, name, jaxpr, in_shardings, func = _pjit_cached_lower_jaxpr_to_fun( ctx, name, jaxpr, tuple(effects), in_shardings, - out_shardings, api_name=('jit' if resource_env is None else 'pjit')) + out_shardings, in_layouts, out_layouts, + api_name=('jit' if resource_env is None else 'pjit')) tokens_in = [ctx.tokens_in.get(eff) for eff in effects] args = (*ctx.dim_var_values, *tokens_in, *args) @@ -1693,7 +1784,7 @@ mlir.register_lowering(pjit_p, _pjit_lowering) def _pjit_batcher(insert_axis, spmd_axis_name, axis_size, axis_name, main_type, vals_in, dims_in, - jaxpr, in_shardings, out_shardings, + jaxpr, in_shardings, out_shardings, in_layouts, out_layouts, resource_env, 
donated_invars, name, keep_unused, inline): segment_lens, dims_in = batching.indirectify_ragged_axes(dims_in) new_jaxpr, axes_out = batching.batch_jaxpr2( @@ -1718,16 +1809,24 @@ def _pjit_batcher(insert_axis, spmd_axis_name, _pjit_batcher_for_sharding(o, axis_out, new_parts, mesh, aval.ndim) if axis_out is not None else o for axis_out, o, aval in zip(axes_out, out_shardings, new_jaxpr.out_avals)) + # TODO(yashkatariya): Figure out layouts should change under vmap. + if not (all(l is None for l in in_layouts) and + all(l is None for l in out_layouts)): + raise NotImplementedError + vals_out = pjit_p.bind( *vals_in, jaxpr=new_jaxpr, in_shardings=in_shardings, out_shardings=out_shardings, + in_layouts=in_layouts, + out_layouts=out_layouts, resource_env=resource_env, donated_invars=donated_invars, name=name, keep_unused=keep_unused, inline=inline) + resolved_axes_out = batching.resolve_ragged_axes_against_inputs_outputs( vals_in, vals_out, axes_out) return vals_out, resolved_axes_out @@ -1773,7 +1872,7 @@ def _pjit_batcher_for_sharding( def _pjit_jvp(primals_in, tangents_in, - jaxpr, in_shardings, out_shardings, + jaxpr, in_shardings, out_shardings, in_layouts, out_layouts, resource_env, donated_invars, name, keep_unused, inline): is_nz_tangents_in = [type(t) is not ad.Zero for t in tangents_in] jaxpr_jvp, is_nz_tangents_out = ad.jvp_jaxpr( @@ -1788,6 +1887,8 @@ def _pjit_jvp(primals_in, tangents_in, jaxpr=jaxpr_jvp, in_shardings=(*in_shardings, *_filter_zeros_in(in_shardings)), out_shardings=(*out_shardings, *_filter_zeros_out(out_shardings)), + in_layouts=(*in_layouts, *_filter_zeros_in(in_layouts)), + out_layouts=(*out_layouts, *_filter_zeros_out(out_layouts)), resource_env=resource_env, donated_invars=(*donated_invars, *_filter_zeros_in(donated_invars)), name=name, @@ -1813,7 +1914,8 @@ def _known_jaxpr_fwd(known_jaxpr: core.ClosedJaxpr, def _pjit_partial_eval(trace, *in_tracers, jaxpr, in_shardings, out_shardings, - resource_env, donated_invars, name, keep_unused, inline): + in_layouts, out_layouts, resource_env, donated_invars, + name, keep_unused, inline): in_pvals = [t.pval for t in in_tracers] known_ins = tuple(pv.is_known() for pv in in_pvals) @@ -1824,25 +1926,31 @@ def _pjit_partial_eval(trace, *in_tracers, known_outs = tuple(not uk for uk in unknown_outs) num_residuals = len(res_avals) res_shardings = (UNSPECIFIED,) * num_residuals + res_layouts = (None,) * num_residuals def keep_where(l, should_keep): return tuple(x for x, keep in zip(l, should_keep) if keep) known_out_shardings = keep_where(out_shardings, known_outs) + res_shardings + known_out_layouts = keep_where(out_layouts, known_outs) + res_layouts # Input-to-output forwarding: compute which outputs are just forwarded inputs. num_out_primals = len(known_jaxpr.out_avals) - num_residuals in_fwd: list[int | None] = pe._jaxpr_forwarding(known_jaxpr.jaxpr) # Only forward primal outputs when corresponding out_sharding is UNSPECIFIED. in_fwd_primal, in_fwd_res = split_list(in_fwd, [num_out_primals]) - in_fwd = [fwd if is_unspecified(os) else None for os, fwd in - zip(keep_where(out_shardings, known_outs), in_fwd_primal) - ] + in_fwd_res + in_fwd = [ + fwd if is_unspecified(os) and ol is None else None + for os, ol, fwd in zip( + keep_where(out_shardings, known_outs), + keep_where(out_layouts, known_outs), in_fwd_primal) + ] + in_fwd_res del in_fwd_primal, in_fwd_res # Prune jaxpr outputs and out_shardings by removing the input-forwards. 
keep = [f is None for f in in_fwd] known_jaxpr = pe.prune_closed_jaxpr_outputs(known_jaxpr, keep) known_out_shardings = keep_where(known_out_shardings, keep) + known_out_layouts = keep_where(known_out_layouts, keep) # Update num_out_primals to reflect pruning. kept_primals, kept_res = split_list(keep, [num_out_primals]) num_out_primals = sum(kept_primals) @@ -1856,14 +1964,18 @@ def _pjit_partial_eval(trace, *in_tracers, keep = [f is None for f in out_fwd] known_jaxpr = pe.prune_closed_jaxpr_outputs(known_jaxpr, keep) known_out_shardings = keep_where(known_out_shardings, keep) + known_out_layouts = keep_where(known_out_layouts, keep) del keep known_params = dict( jaxpr=known_jaxpr, in_shardings=keep_where(in_shardings, known_ins), - out_shardings=known_out_shardings, resource_env=resource_env, + out_shardings=known_out_shardings, + in_layouts=keep_where(in_layouts, known_ins), + out_layouts=known_out_layouts, resource_env=resource_env, donated_invars=keep_where(donated_invars, known_ins), name=name, keep_unused=keep_unused, inline=inline) assert len(known_params['out_shardings']) == len(known_params['jaxpr'].out_avals) + assert len(known_params['out_layouts']) == len(known_params['jaxpr'].out_avals) # Bind known things to pjit_p. known_inputs = [pv.get_known() for pv in in_pvals if pv.is_known()] @@ -1888,6 +2000,8 @@ def _pjit_partial_eval(trace, *in_tracers, jaxpr=unknown_jaxpr, in_shardings=(keep_where(in_shardings, unknown_ins) + res_shardings), out_shardings=keep_where(out_shardings, unknown_outs), + in_layouts=(keep_where(in_layouts, unknown_ins) + res_layouts), + out_layouts=keep_where(out_layouts, unknown_outs), resource_env=resource_env, donated_invars=(keep_where(donated_invars, unknown_ins) + (False,) * num_residuals), @@ -1921,28 +2035,41 @@ def _pjit_partial_eval_custom_params_updater( donated_invars_known, _ = pe.partition_list(unks_in, params_known['donated_invars']) in_shardings_known, _ = pe.partition_list(unks_in, params_known['in_shardings']) _, out_shardings_known = pe.partition_list(kept_outs_known, params_known['out_shardings']) + in_layouts_known, _ = pe.partition_list(unks_in, params_known['in_layouts']) + _, out_layouts_known = pe.partition_list(kept_outs_known, params_known['out_layouts']) + new_params_known = dict(params_known, in_shardings=tuple(in_shardings_known), out_shardings=(*out_shardings_known, *[UNSPECIFIED] * num_res_out), + in_layouts=tuple(in_layouts_known), + out_layouts=(*out_layouts_known, *[None] * num_res_out), donated_invars=tuple(donated_invars_known)) assert len(new_params_known['in_shardings']) == len(params_known['jaxpr'].in_avals) assert len(new_params_known['out_shardings']) == len(params_known['jaxpr'].out_avals) + assert len(new_params_known['in_layouts']) == len(params_known['jaxpr'].in_avals) + assert len(new_params_known['out_layouts']) == len(params_known['jaxpr'].out_avals) # added num_res new inputs to jaxpr_staged, and pruning according to inst_in _, donated_invars_staged = pe.partition_list(inst_in, params_staged['donated_invars']) donated_invars_staged = [False] * num_res_in + donated_invars_staged _, in_shardings_staged = pe.partition_list(inst_in, params_staged['in_shardings']) in_shardings_staged = [*[UNSPECIFIED] * num_res_in, *in_shardings_staged] - _, out_shardings_staged = pe.partition_list(kept_outs_staged, params_staged['out_shardings']) + _, in_layouts_staged = pe.partition_list(inst_in, params_staged['in_layouts']) + in_layouts_staged = [*[None] * num_res_in, *in_layouts_staged] + _, out_layouts_staged = 
pe.partition_list(kept_outs_staged, params_staged['out_layouts']) new_params_staged = dict(params_staged, in_shardings=tuple(in_shardings_staged), out_shardings=tuple(out_shardings_staged), + in_layouts=tuple(in_layouts_staged), + out_layouts=tuple(out_layouts_staged), donated_invars=tuple(donated_invars_staged)) assert len(new_params_staged['in_shardings']) == len(params_staged['jaxpr'].in_avals) assert len(new_params_staged['out_shardings']) == len(params_staged['jaxpr'].out_avals) + assert len(new_params_staged['in_layouts']) == len(params_staged['jaxpr'].in_avals) + assert len(new_params_staged['out_layouts']) == len(params_staged['jaxpr'].out_avals) return new_params_known, new_params_staged pe.partial_eval_jaxpr_custom_rules[pjit_p] = \ @@ -1959,7 +2086,7 @@ def _pjit_transpose_trace(fun, in_avals): def _pjit_transpose(cts_in, *primals_in, - jaxpr, in_shardings, out_shardings, + jaxpr, in_shardings, out_shardings, in_layouts, out_layouts, resource_env, donated_invars, name, keep_unused, inline): def prune_type(ty, xs, maybe_zeros): return tuple(x for x, mz in zip(xs, maybe_zeros) if type(mz) is not ty) @@ -1973,6 +2100,10 @@ def _pjit_transpose(cts_in, *primals_in, *prune_type(ad.UndefinedPrimal, in_shardings, primals_in), *prune_type(ad.Zero, out_shardings, cts_in) ) + transpose_in_layouts = ( + *prune_type(ad.UndefinedPrimal, in_layouts, primals_in), + *prune_type(ad.Zero, out_layouts, cts_in) + ) global_cts_in_avals = tuple(core.raise_to_shaped(core.get_aval(ct)) for ct in primals_and_nz_cts_in) @@ -1983,26 +2114,36 @@ def _pjit_transpose(cts_in, *primals_in, ad.Zero, in_shardings, tree_unflatten(cts_out_treedef, [object()] * cts_out_treedef.num_leaves)) + transpose_out_layouts = prune_type( + ad.Zero, + in_layouts, + tree_unflatten(cts_out_treedef, [object()] * cts_out_treedef.num_leaves)) if attrs_tracked: init_states = _get_states(attrs_tracked) primals_and_nz_cts_in = [*init_states, *primals_and_nz_cts_in] transpose_in_shardings = (UNSPECIFIED,) * len(attrs_tracked) + transpose_in_shardings transpose_out_shardings = (UNSPECIFIED,) * len(attrs_tracked) + transpose_out_shardings + transpose_in_layouts = (None,) * len(attrs_tracked) + transpose_in_layouts + transpose_out_layouts = (None,) * len(attrs_tracked) + transpose_out_layouts nz_cts_out = pjit_p.bind( *primals_and_nz_cts_in, jaxpr=transpose_jaxpr, in_shardings=transpose_in_shardings, out_shardings=transpose_out_shardings, + in_layouts=transpose_in_layouts, + out_layouts=transpose_out_layouts, resource_env=resource_env, donated_invars=(False,) * len(primals_and_nz_cts_in), name=name, keep_unused=keep_unused, inline=inline) + if attrs_tracked: final_states, nz_cts_out = split_list(nz_cts_out, [len(init_states)]) _set_states(attrs_tracked, final_states) + return tree_unflatten(cts_out_treedef, nz_cts_out) ad.reducing_transposes[pjit_p] = _pjit_transpose @@ -2029,6 +2170,8 @@ def dce_jaxpr_pjit_rule(used_outputs: list[bool], eqn: core.JaxprEqn jaxpr=dced_jaxpr, in_shardings=keep_where(eqn_params["in_shardings"], used_inputs), out_shardings=keep_where(eqn_params["out_shardings"], used_outputs), + in_layouts=keep_where(eqn_params["in_layouts"], used_inputs), + out_layouts=keep_where(eqn_params["out_layouts"], used_outputs), donated_invars=keep_where(eqn_params["donated_invars"], used_inputs), ) if not any(used_inputs) and not any(used_outputs) and not dced_jaxpr.effects: @@ -2112,6 +2255,10 @@ def _pjit_pp_rule(eqn, context, settings): del params['in_shardings'] if all(is_unspecified(s) for s in params['out_shardings']): del 
params['out_shardings'] + if all(l is None for l in params['in_layouts']): + del params['in_layouts'] + if all(l is None for l in params['out_layouts']): + del params['out_layouts'] if not params['keep_unused']: del params['keep_unused'] if (params['resource_env'] is None or @@ -2126,18 +2273,28 @@ core.pp_eqn_rules[pjit_p] = _pjit_pp_rule def _pjit_state_discharge_rule( - in_avals, out_avals, *args, jaxpr, in_shardings, out_shardings, **params): + in_avals, out_avals, *args, jaxpr, in_shardings, out_shardings, + in_layouts, out_layouts, **params): if not (all(map(is_unspecified, in_shardings)) and - all(map(is_unspecified, out_shardings))): raise NotImplementedError + all(map(is_unspecified, out_shardings))): + raise NotImplementedError + + if not (all(l is None for l in in_layouts) and + all(l is None for l in out_layouts)): + raise NotImplementedError + jaxpr, consts = jaxpr.jaxpr, jaxpr.consts num_outs = len(jaxpr.outvars) discharged_jaxpr, discharged_consts = state_discharge.discharge_state(jaxpr, consts) discharged_closed_jaxpr = core.ClosedJaxpr(discharged_jaxpr, discharged_consts) new_in_shardings = (UnspecifiedValue(),) * len(discharged_jaxpr.invars) new_out_shardings = (UnspecifiedValue(),) * len(discharged_jaxpr.outvars) + new_in_layouts = (None,) * len(discharged_jaxpr.invars) + new_out_layouts = (None,) * len(discharged_jaxpr.outvars) out_and_ref_vals = pjit_p.bind( *args, jaxpr=discharged_closed_jaxpr, in_shardings=new_in_shardings, - out_shardings=new_out_shardings, **params) + out_shardings=new_out_shardings, in_layouts=new_in_layouts, + out_layouts=new_out_layouts, **params) out_vals, ref_vals = split_list(out_and_ref_vals, [num_outs]) ref_vals_iter = iter(ref_vals) new_invals = tuple(next(ref_vals_iter) if isinstance(aval, state_discharge.AbstractRef) diff --git a/jax/_src/stages.py b/jax/_src/stages.py index 0e0d19d97..2ffb8515c 100644 --- a/jax/_src/stages.py +++ b/jax/_src/stages.py @@ -88,12 +88,10 @@ class Executable(Protocol): """ raise NotImplementedError - # Layouts are exposed via jax.experimental.layouts - # TODO(frostig,yashkatariya): expose here when no longer experimental. 
-  def _input_layouts(self):
+  def input_layouts(self):
     raise NotImplementedError

-  def _output_layouts(self):
+  def output_layouts(self):
     raise NotImplementedError

   def as_text(self) -> str:
@@ -228,11 +226,11 @@ class XlaExecutable(Executable):
     raise NotImplementedError(
         "compiled executable carries no output sharding information")

-  def _input_layouts(self):
+  def input_layouts(self):
     raise NotImplementedError(
         "compiled executable carries no input layout information")

-  def _output_layouts(self):
+  def output_layouts(self):
     raise NotImplementedError(
         "compiled executable carries no input layout information")

@@ -511,7 +509,7 @@ class Compiled(Stage):
     shardings_flat = self._executable.output_shardings()
     return tree_util.tree_unflatten(self.out_tree, shardings_flat)  # pytype: disable=attribute-error

-  def _input_layouts(self):
+  def input_layouts(self):
     layouts_flat = self._executable.input_layouts()
     assert all(isinstance(l, Layout) for l in layouts_flat)
     # Some input layouts got DCE'd
@@ -521,7 +519,7 @@ class Compiled(Stage):
                     else Layout() for i in range(self.in_tree.num_leaves)]
     return tree_util.tree_unflatten(self.in_tree, layouts_flat)  # pytype: disable=attribute-error

-  def _output_layouts(self):
+  def output_layouts(self):
     layouts_flat = self._executable.output_layouts()
     assert all(isinstance(l, Layout) for l in layouts_flat)
     return tree_util.tree_unflatten(self.out_tree, layouts_flat)  # pytype: disable=attribute-error
diff --git a/jax/experimental/host_callback.py b/jax/experimental/host_callback.py
index 15c4497d4..f46e9b1fe 100644
--- a/jax/experimental/host_callback.py
+++ b/jax/experimental/host_callback.py
@@ -1643,6 +1643,8 @@ def _rewrite_eqn(eqn: core.JaxprEqn, eqns: list[core.JaxprEqn],
                 eqn.params["out_shardings"] +
                 (sharding_impls.UNSPECIFIED, sharding_impls.UNSPECIFIED)
             ),
+            in_layouts=(eqn.params["in_layouts"] + (None, None)),
+            out_layouts=(eqn.params["out_layouts"] + (None, None)),
         ),
     )
   )
diff --git a/jax/experimental/jax2tf/jax2tf.py b/jax/experimental/jax2tf/jax2tf.py
index 84d67ae44..521e95d22 100644
--- a/jax/experimental/jax2tf/jax2tf.py
+++ b/jax/experimental/jax2tf/jax2tf.py
@@ -3504,6 +3504,7 @@ def _pjit(*args: TfVal,
           jaxpr: core.ClosedJaxpr,
           in_shardings: Sequence[sharding.XLACompatibleSharding],
           out_shardings: Sequence[sharding.XLACompatibleSharding],
+          in_layouts, out_layouts,
           resource_env: mesh.ResourceEnv,
           donated_invars,
           name: str,
diff --git a/jax/experimental/jet.py b/jax/experimental/jet.py
index 2e304f176..ac23debd6 100644
--- a/jax/experimental/jet.py
+++ b/jax/experimental/jet.py
@@ -739,6 +739,8 @@ def _pjit_jet_rule(primals_in, series_in, **params):
           params['out_shardings'] +
           (sharding_impls.UNSPECIFIED,) * num_series_out
       ),
+      'in_layouts': params['in_layouts'] + (None,) * num_series_in,
+      'out_layouts': params['out_layouts'] + (None,) * num_series_out,
       'donated_invars': params['donated_invars'] + (False,) * num_series_in,
   }
   result = pjit.pjit_p.bind(*primals_and_series, **new_params)
diff --git a/jax/experimental/sparse/transform.py b/jax/experimental/sparse/transform.py
index 19f4ca736..a6f98ed59 100644
--- a/jax/experimental/sparse/transform.py
+++ b/jax/experimental/sparse/transform.py
@@ -772,7 +772,8 @@ sparse_rules_bcoo[lax.while_p] = _while_sparse


 def _pjit_sparse(spenv, *spvalues, jaxpr, in_shardings, out_shardings,
-                 resource_env, donated_invars, name, keep_unused, inline):
+                 in_layouts, out_layouts, resource_env, donated_invars, name,
+                 keep_unused, inline):
   if any(donated_invars):
     raise NotImplementedError("sparse xla_call with donated_invars")
@@ -790,12 +791,20 @@ def _pjit_sparse(spenv, *spvalues, jaxpr, in_shardings, out_shardings,
       sharding_impls.UNSPECIFIED
       for _ in range(len(sp_call_jaxpr.out_avals) - len(out_shardings))
   )
+  in_layouts = in_layouts + tuple(
+      None for _ in range(len(args_flat) - len(in_layouts))
+  )
+  out_layouts = out_layouts + tuple(
+      None for _ in range(len(sp_call_jaxpr.out_avals) - len(out_layouts))
+  )

   out_flat = pjit.pjit_p.bind(
       *args_flat,
       jaxpr=sp_call_jaxpr,
       in_shardings=in_shardings,
       out_shardings=out_shardings,
+      in_layouts=in_layouts,
+      out_layouts=out_layouts,
       resource_env=resource_env,
       donated_invars=donated_invars,
       name=name,
diff --git a/tests/export_test.py b/tests/export_test.py
index 144956475..784351487 100644
--- a/tests/export_test.py
+++ b/tests/export_test.py
@@ -451,8 +451,8 @@ class JaxExportTest(jtu.JaxTestCase):
     self.assertIn("jax.uses_shape_polymorphism = true", module_str)
     wrapped_main_expected_re = (
         r"@_wrapped_jax_export_main\("
-        r"%arg0: tensor {jax.global_constant = \"h\"}.*"
-        r"%arg1: tensor {jax.global_constant = \"w\"}.*"
+        r"%arg0: tensor {jax.global_constant = \"h\".*"
+        r"%arg1: tensor {jax.global_constant = \"w\".*"
        r"%arg2: tensor<\?x\?xf32>"
     )
     self.assertRegex(module_str, wrapped_main_expected_re)
@@ -1238,12 +1238,12 @@ class JaxExportTest(jtu.JaxTestCase):
     mlir_module_str = str(exp.mlir_module())
     wrapped_main_expected_re = (
         r"@_wrapped_jax_export_main\("
-        r"%arg0: tensor {jax.global_constant = \"b1\"}.*, "
-        r"%arg1: tensor {jax.global_constant = \"b2\"}.*, "
-        r"%arg2: !stablehlo.token {jax.token = true}.*, "
+        r"%arg0: tensor {jax.global_constant = \"b1\".* "
+        r"%arg1: tensor {jax.global_constant = \"b2\".* "
+        r"%arg2: !stablehlo.token {jax.token = true.* "
         r"%arg3: tensor<\?x\?xf32>.*\) -> \("
         # Results
-        r"!stablehlo.token {jax.token = true}, tensor<\?x\?xf32>.*\)")
+        r"!stablehlo.token {jax.token = true.*, tensor<\?x\?xf32>.*\)")
     if exp.mlir_module_serialization_version < _export._VERSION_START_SUPPORT_EFFECTS_WITH_REAL_TOKENS:
       wrapped_main_expected_re = wrapped_main_expected_re.replace("!stablehlo.token", "tensor<0xi1>")
     self.assertRegex(mlir_module_str, wrapped_main_expected_re)
@@ -1254,10 +1254,10 @@ class JaxExportTest(jtu.JaxTestCase):
     else:
       main_expected_re = (
         r"@main\("
-        r"%arg0: !stablehlo.token {jax.token = true}.*, "
+        r"%arg0: !stablehlo.token {jax.token = true.*, "
         r"%arg1: tensor<\?x\?xf32>.*\) -> \("
         # Results
-        r"!stablehlo.token {jax.token = true}, tensor<\?x\?xf32>.*\)")
+        r"!stablehlo.token {jax.token = true.*, tensor<\?x\?xf32>.*\)")
     self.assertRegex(mlir_module_str, main_expected_re)

     res = export.call_exported(exp)(x)
@@ -1284,13 +1284,13 @@ class JaxExportTest(jtu.JaxTestCase):
     mlir_module_str = str(exp.mlir_module())
     wrapped_main_expected_re = (
         r"@_wrapped_jax_export_main\("
-        r"%arg0: tensor {jax.global_constant = \"_platform_index\"}.*, "
-        r"%arg1: tensor {jax.global_constant = \"b1\"}.*, "
-        r"%arg2: tensor {jax.global_constant = \"b2\"}.*, "
-        r"%arg3: !stablehlo.token {jax.token = true}.*, "
+        r"%arg0: tensor {jax.global_constant = \"_platform_index\".*, "
+        r"%arg1: tensor {jax.global_constant = \"b1\".*, "
+        r"%arg2: tensor {jax.global_constant = \"b2\".*, "
+        r"%arg3: !stablehlo.token {jax.token = true.*, "
         r"%arg4: tensor<\?x\?xf32>.*\) -> \("
         # Results
-        r"!stablehlo.token {jax.token = true}, tensor<\?x\?xf32>.*\)")
+        r"!stablehlo.token {jax.token = true.*, tensor<\?x\?xf32>.*\)")
     if exp.mlir_module_serialization_version < _export._VERSION_START_SUPPORT_EFFECTS_WITH_REAL_TOKENS:
       wrapped_main_expected_re = wrapped_main_expected_re.replace("!stablehlo.token", "tensor<0xi1>")
     self.assertRegex(mlir_module_str, wrapped_main_expected_re)
@@ -1301,11 +1301,11 @@ class JaxExportTest(jtu.JaxTestCase):
     else:
       main_expected_re = (
         r"@main\("
-        r"%arg0: tensor {jax.global_constant = \"_platform_index\"}.*, "
-        r"%arg1: !stablehlo.token {jax.token = true}.*, "
+        r"%arg0: tensor {jax.global_constant = \"_platform_index\".*, "
+        r"%arg1: !stablehlo.token {jax.token = true.*, "
         r"%arg2: tensor<\?x\?xf32>.*\) -> \("
         # Results
-        r"!stablehlo.token {jax.token = true}, tensor<\?x\?xf32>.*\)")
+        r"!stablehlo.token {jax.token = true.*, tensor<\?x\?xf32>.*\)")
     self.assertRegex(mlir_module_str, main_expected_re)
     res = export.call_exported(exp)(x)
     self.assertAllClose(10. + _testing_multi_platform_fun_expected(x),
diff --git a/tests/layout_test.py b/tests/layout_test.py
index f17ef7860..c2b87c1e6 100644
--- a/tests/layout_test.py
+++ b/tests/layout_test.py
@@ -88,22 +88,22 @@ class LayoutTest(jtu.JaxTestCase):
     sds1 = jax.ShapeDtypeStruct(np_inp1.shape, np_inp1.dtype, sharding=s1)
     sds2 = jax.ShapeDtypeStruct(np_inp2.shape, np_inp2.dtype, sharding=s2)

-    lowered_apply = jax.jit(apply).lower(
-        sds1, sds2, _in_layouts=Layout(DLL.AUTO), _out_layouts=Layout(DLL.AUTO))
+    lowered_apply = jax.jit(apply, in_shardings=Layout(DLL.AUTO),
+                            out_shardings=Layout(DLL.AUTO)).lower(sds1, sds2)
     compiled_apply = lowered_apply.compile()

-    arg_layouts, kw_layouts = compiled_apply._input_layouts()
+    arg_layouts, kw_layouts = compiled_apply.input_layouts()
     self.assertEmpty(kw_layouts)

-    for i, o in zip(arg_layouts, compiled_apply._output_layouts()):
+    for i, o in zip(arg_layouts, compiled_apply.output_layouts()):
       self.assertEqual(extract_minor_to_major(i),
                        extract_minor_to_major(o)[::-1])

-    init_compiled = jax.jit(init).lower(
-        sds1, sds2, _out_layouts=arg_layouts).compile()
+    init_compiled = jax.jit(
+        init, out_shardings=arg_layouts).lower(sds1, sds2).compile()

-    for i, o in zip(init_compiled._input_layouts()[0],
-                    init_compiled._output_layouts()):
+    for i, o in zip(init_compiled.input_layouts()[0],
+                    init_compiled.output_layouts()):
       self.assertEqual(i, o)

     arr1 = jax.device_put(np_inp1, s1)
@@ -114,16 +114,16 @@ class LayoutTest(jtu.JaxTestCase):
       init_compiled(arr1, arr2)
     self.assertEqual(init_count[0], 1)

-    self.assertEqual(init_out[0].layout, init_compiled._output_layouts()[0])
-    self.assertEqual(init_out[1].layout, init_compiled._output_layouts()[1])
+    self.assertEqual(init_out[0].layout, init_compiled.output_layouts()[0])
+    self.assertEqual(init_out[1].layout, init_compiled.output_layouts()[1])

     with jtu.count_aot_jit_cpp_cache_miss() as apply_count:
       apply_out = compiled_apply(*init_out)
       compiled_apply(*init_out)
     self.assertEqual(apply_count[0], 1)

-    self.assertEqual(apply_out[0].layout, compiled_apply._output_layouts()[0])
-    self.assertEqual(apply_out[1].layout, compiled_apply._output_layouts()[1])
+    self.assertEqual(apply_out[0].layout, compiled_apply.output_layouts()[0])
+    self.assertEqual(apply_out[1].layout, compiled_apply.output_layouts()[1])

     self.assertTupleEqual(extract_minor_to_major(apply_out[0].layout),
                           extract_minor_to_major(init_out[0].layout)[::-1])
@@ -146,24 +146,29 @@ class LayoutTest(jtu.JaxTestCase):
     def f(x):
       return x.T

-    lowered = jax.jit(f).lower(sds, _in_layouts=None, _out_layouts=None)
+    lowered = jax.jit(f, in_shardings=None, out_shardings=None).lower(sds)
     self.assertIn("default", lowered.as_text())
     compiled = lowered.compile()
     out = compiled(arr)

     self.assertTupleEqual(
-        extract_minor_to_major(compiled._input_layouts()[0][0]), (2, 1, 0))
+        extract_minor_to_major(compiled.input_layouts()[0][0]), (2, 1, 0))
     self.assertTupleEqual(
-        extract_minor_to_major(compiled._output_layouts()), (2, 1, 0))
+        extract_minor_to_major(compiled.output_layouts()), (2, 1, 0))
     self.assertArraysEqual(out, np_inp.T)
     self.assertEqual(out.sharding, NamedSharding(mesh, P(None, 'y', 'x')))

-    compiled_auto = jax.jit(f).lower(sds, _in_layouts=Layout(DLL.AUTO),
-                                     _out_layouts=Layout(DLL.AUTO)).compile()
+    compiled_auto = jax.jit(f, in_shardings=Layout(DLL.AUTO),
+                            out_shardings=Layout(DLL.AUTO)).lower(sds).compile()
     self.assertTupleEqual(
-        extract_minor_to_major(compiled_auto._input_layouts()[0][0]), (2, 1, 0))
+        extract_minor_to_major(compiled_auto.input_layouts()[0][0]), (2, 1, 0))
     self.assertTupleEqual(
-        extract_minor_to_major(compiled_auto._output_layouts()), (0, 1, 2))
+        extract_minor_to_major(compiled_auto.output_layouts()), (0, 1, 2))
+
+    with self.assertRaisesRegex(
+        ValueError, "jax.jit` does not accept device-local layouts directly"):
+      jax.jit(f, in_shardings=DLL.AUTO,
+              out_shardings=DLL.AUTO).lower(sds).compile()

   def test_in_layouts_out_layouts(self):
     mesh = jtu.create_global_mesh((2, 2), ('x', 'y'))
@@ -175,16 +180,16 @@ class LayoutTest(jtu.JaxTestCase):
     def f(x):
       return x.T

-    compiled = jax.jit(f).lower(
-        arr, _in_layouts=Layout(), _out_layouts=Layout(DLL.AUTO)).compile()
+    compiled = jax.jit(f, in_shardings=Layout(),
+                       out_shardings=Layout(DLL.AUTO)).lower(arr).compile()
     self.assertTupleEqual(
-        extract_minor_to_major(compiled._input_layouts()[0][0]), (1, 0))
+        extract_minor_to_major(compiled.input_layouts()[0][0]), (1, 0))
     self.assertTupleEqual(
-        extract_minor_to_major(compiled._output_layouts()), (0, 1))
+        extract_minor_to_major(compiled.output_layouts()), (0, 1))

     out = compiled(arr)
     self.assertArraysEqual(out, np_inp.T)
-    self.assertEqual(out.layout, compiled._output_layouts())
+    self.assertEqual(out.layout, compiled.output_layouts())
     self.assertEqual(out.sharding, NamedSharding(mesh, P('y', 'x')))

   def test_sharding_and_layouts(self):
@@ -193,14 +198,13 @@ class LayoutTest(jtu.JaxTestCase):
     np_inp = np.arange(math.prod(shape)).reshape(shape)
     s = NamedSharding(mesh, P('x', 'y'))

-    compiled = jax.jit(lambda x: x.T, in_shardings=s, out_shardings=s).lower(
-        np_inp, _in_layouts=Layout(DLL.AUTO),
-        _out_layouts=Layout(DLL.AUTO)).compile()
+    compiled = jax.jit(lambda x: x.T, in_shardings=Layout(DLL.AUTO, s),
+                       out_shardings=Layout(DLL.AUTO, s)).lower(np_inp).compile()
     out = compiled(np_inp)
     self.assertTupleEqual(
-        extract_minor_to_major(compiled._input_layouts()[0][0]), (1, 0))
+        extract_minor_to_major(compiled.input_layouts()[0][0]), (1, 0))
     self.assertTupleEqual(
-        extract_minor_to_major(compiled._output_layouts()), (0, 1))
+        extract_minor_to_major(compiled.output_layouts()), (0, 1))
     self.assertArraysEqual(out, np_inp.T)
     self.assertEqual(out.sharding, s)

@@ -210,15 +214,15 @@ class LayoutTest(jtu.JaxTestCase):
     shape = (8, 2)
     inps = [np.arange(math.prod(shape)).reshape(shape)] * 6

-    compiled = jax.jit(f).lower(*inps, _in_layouts=Layout(DLL.AUTO),
-                                _out_layouts=Layout(DLL.AUTO)).compile()
-    arg_layouts, _ = compiled._input_layouts()
+    compiled = jax.jit(f, in_shardings=Layout(DLL.AUTO),
+                       out_shardings=Layout(DLL.AUTO)).lower(*inps).compile()
+    arg_layouts, _ = compiled.input_layouts()
     out1, out2 = compiled(*inps)

-    compiled2 = jax.jit(f).lower(*inps, _in_layouts=arg_layouts).compile()
+    compiled2 = jax.jit(f, in_shardings=arg_layouts).lower(*inps).compile()
     out3, out4 = compiled2(*inps)

-    for l1, l2 in safe_zip(arg_layouts, compiled2._input_layouts()[0]):
+    for l1, l2 in safe_zip(arg_layouts, compiled2.input_layouts()[0]):
       self.assertEqual(l1, l2)

     self.assertArraysEqual(out1, out3)
@@ -244,11 +248,10 @@ class LayoutTest(jtu.JaxTestCase):
     with self.assertRaisesRegex(
         ValueError,
         'Layout passed to jit does not match the layout on the respective arg'):
-      jax.jit(f).lower(arr, _in_layouts=Layout(DLL.AUTO))
+      jax.jit(f, in_shardings=Layout(DLL.AUTO)).lower(arr)

-    compiled = jax.jit(f).lower(
-        sds, _in_layouts=Layout(DLL.AUTO),
-        _out_layouts=Layout(DLL.AUTO)).compile()
+    compiled = jax.jit(f, in_shardings=Layout(DLL.AUTO),
+                       out_shardings=Layout(DLL.AUTO)).lower(sds).compile()

     with self.assertRaisesRegex(
         ValueError,
@@ -271,8 +274,8 @@ class LayoutTest(jtu.JaxTestCase):
     arr = jax.device_put(np_inp, s)

     compiled = jax.jit(
-        lambda x: x * 2).lower(arr, _out_layouts=Layout(DLL.AUTO)).compile()
-    col = compiled._output_layouts()
+        lambda x: x * 2, out_shardings=Layout(DLL.AUTO)).lower(arr).compile()
+    col = compiled.output_layouts()

     out = jax.device_put(np_inp, col)
     self.assertEqual(out.layout, col)
@@ -304,7 +307,7 @@ class LayoutTest(jtu.JaxTestCase):
     compiled = jax.jit(lambda x: x).lower(x).compile()
     with self.assertRaisesRegex(
         ValueError, 'Sharding has to be concrete when layout.*'):
-      Layout(compiled._output_layouts()[0], None)
+      Layout(compiled.output_layouts()[0], None)


 if __name__ == '__main__':
diff --git a/tests/pallas/pallas_call_tpu_test.py b/tests/pallas/pallas_call_tpu_test.py
index ca771c254..dbc3ad47f 100644
--- a/tests/pallas/pallas_call_tpu_test.py
+++ b/tests/pallas/pallas_call_tpu_test.py
@@ -1500,7 +1500,7 @@ class PallasCallInputOutputAliasingTest(PallasTPUTest):
     )(x)
     o = f(x)
     np.testing.assert_array_equal(o, expected)
-    compiled = f.lower(x).compile()
+    compiled = f.lower(jax.ShapeDtypeStruct(x.shape, x.dtype)).compile()
     mem_analysis = compiled.memory_analysis()
     expected_num_bytes = np.prod(x.shape) * x.dtype.itemsize
     self.assertEqual(mem_analysis.alias_size_in_bytes, expected_num_bytes)
@@ -1528,7 +1528,7 @@ class PallasCallInputOutputAliasingTest(PallasTPUTest):
     )(jnp.array([1,2,3]), x)
     o = f(x)
     np.testing.assert_array_equal(o, expected)
-    compiled = f.lower(x).compile()
+    compiled = f.lower(jax.ShapeDtypeStruct(x.shape, x.dtype)).compile()
     mem_analysis = compiled.memory_analysis()
     expected_num_bytes = np.prod(x.shape) * x.dtype.itemsize
     self.assertEqual(mem_analysis.alias_size_in_bytes, expected_num_bytes)
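Usage note (illustrative, not part of the patch): the layout tests above exercise the new API surface roughly as sketched below. The (2, 2) mesh, the (8, 2) input shape, and the `jax._src.layout` import path simply mirror what the tests use and are assumptions of this sketch, not something the patch itself prescribes; running it requires at least four devices.

    # Minimal sketch of the usage pattern from tests/layout_test.py: ask the
    # compiler to pick layouts (DLL.AUTO) while pinning the sharding, then
    # inspect the layouts the compiled executable settled on.
    import numpy as np
    import jax
    from jax.sharding import NamedSharding, PartitionSpec as P
    from jax._src.layout import Layout, DeviceLocalLayout as DLL  # import path mirrors the patch
    from jax._src import test_util as jtu  # the tests use this helper to build a mesh

    mesh = jtu.create_global_mesh((2, 2), ('x', 'y'))  # assumes >= 4 devices
    s = NamedSharding(mesh, P('x', 'y'))
    np_inp = np.arange(16).reshape(8, 2)

    # Layout(device_local_layout, sharding): AUTO lets XLA choose the device-local layout.
    compiled = jax.jit(
        lambda x: x.T,
        in_shardings=Layout(DLL.AUTO, s),
        out_shardings=Layout(DLL.AUTO, s),
    ).lower(np_inp).compile()

    arg_layouts, kwarg_layouts = compiled.input_layouts()  # layouts chosen for the arguments
    out_layouts = compiled.output_layouts()                # layouts chosen for the results

    out = compiled(np_inp)
    assert out.layout == out_layouts  # arrays expose their Layout (device-local layout + sharding)

As in the tests, layouts obtained from one compiled function (e.g. `arg_layouts`) can be fed back into `in_shardings`/`out_shardings` of another `jax.jit` call to pin its argument or result layouts.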