Mirror of https://github.com/ROCm/jax.git (synced 2025-04-16 03:46:06 +00:00)
Add `Layout` support to `jax.jit`.

`jax.jit` now accepts `Layout` instances in its `in_shardings` and `out_shardings` arguments. The bulk of the change is plumbing `in_layouts` and `out_layouts` through everywhere. Note that the public API is `Layout(device_local_layout, sharding)`, which is how users pass the layout to us; internally we split it apart into `device_local_layout` and `sharding`. Docs are coming on how to use the API, what layouts mean, and how to make sense of them (especially on TPU).

PiperOrigin-RevId: 622352537
This commit is contained in:
parent f88139bf67
commit c125442644
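For orientation, here is a minimal sketch of the API this change adds. The sharding, function, and input below are illustrative assumptions for the example, not taken from this commit:

```python
import jax
import jax.numpy as jnp
from jax._src.layout import Layout  # import path as it appears in the diff below

# Layout(device_local_layout, sharding): passing None for the device-local
# layout leaves that choice to the compiler.
sharding = jax.sharding.SingleDeviceSharding(jax.devices()[0])
out_layout = Layout(None, sharding)

# Layout instances can now be passed wherever in_shardings/out_shardings
# previously took shardings.
f = jax.jit(lambda x: x * 2, out_shardings=out_layout)
y = f(jnp.arange(8.0))
```

Internally, as noted above, `jit` splits each `Layout` back into its `device_local_layout` and `sharding` parts (see `_split_layout_and_sharding` in the `jax/_src/pjit.py` hunks below).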
@@ -34,7 +34,6 @@ from jax._src import deprecations
from jax._src import dispatch
from jax._src import dtypes
from jax._src import errors
from jax._src import layout
from jax._src import profiler
from jax._src import tree_util
from jax._src import xla_bridge
@@ -47,6 +46,7 @@ from jax._src.sharding import Sharding
from jax._src.sharding_impls import (
SingleDeviceSharding, XLACompatibleSharding, PmapSharding,
device_replica_id_map, hashed_index)
from jax._src.layout import DeviceLocalLayout, Layout
from jax._src.typing import ArrayLike
from jax._src.util import safe_zip, unzip3, use_cpp_class, use_cpp_method

@@ -529,15 +529,17 @@ class ArrayImpl(basearray.Array):
out.append(Shard(_get_device(a), self.sharding, self.shape, a))
return out

@property
@functools.cached_property
def layout(self):
# TODO(yashkatariya): Remove the deleted check from here.
if self.is_deleted():
return Layout(None, self.sharding)
try:
return layout.Layout(layout.DeviceLocalLayout(self._pjrt_layout),
self.sharding)
return Layout(DeviceLocalLayout(self._pjrt_layout), self.sharding)
except xe.XlaRuntimeError as e:
msg, *_ = e.args
if type(msg) is str and msg.startswith("UNIMPLEMENTED"):
return layout.Layout(None, self.sharding)
return Layout(None, self.sharding)
else:
raise
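The hunk above makes `Array.layout` a cached property that wraps the PJRT layout in the public `Layout` type, and falls back to `Layout(None, sharding)` when the backend reports the query as unimplemented. A hedged sketch of reading it (the array construction is illustrative):

```python
import jax
import jax.numpy as jnp

x = jax.device_put(jnp.arange(4.0), jax.devices()[0])
l = x.layout                  # a Layout(device_local_layout, sharding)
print(l.device_local_layout)  # may be None if the backend says UNIMPLEMENTED
print(l.sharding)
```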
@@ -895,9 +895,9 @@ def while_loop_error_check(error, enabled_errors, *in_flat, cond_nconsts,
error_checks[lax.while_p] = while_loop_error_check

def pjit_error_check(error, enabled_errors, *vals_in, jaxpr,
in_shardings, out_shardings, resource_env,
donated_invars, name,
inline, keep_unused):
in_shardings, out_shardings,
in_layouts, out_layouts,
resource_env, donated_invars, name, inline, keep_unused):
# jaxpr to checked_jaxpr
err_vals, err_tree = jtu.tree_flatten(error)
new_vals_in = [*err_vals, *vals_in]
@@ -908,10 +908,12 @@ def pjit_error_check(error, enabled_errors, *vals_in, jaxpr,
# Update pjit params to account for extra error values.
num_error_vals = len(err_vals)
num_out_error_vals = out_tree.num_leaves - len(out_shardings)
sharding = sharding_impls.UNSPECIFIED

sharding = sharding_impls.UNSPECIFIED
new_in_shardings = (*[sharding] * num_error_vals, *in_shardings)
new_out_shardings = (*[sharding] * num_out_error_vals, *out_shardings)
new_in_layouts = (*[None] * num_error_vals, *in_layouts)
new_out_layouts = (*[None] * num_out_error_vals, *out_layouts)
new_donated_invars = (*[False] * num_error_vals, *donated_invars)

err_and_out = pjit.pjit_p.bind(
@@ -919,6 +921,8 @@ def pjit_error_check(error, enabled_errors, *vals_in, jaxpr,
jaxpr=checked_jaxpr,
in_shardings=new_in_shardings,
out_shardings=new_out_shardings,
in_layouts=new_in_layouts,
out_layouts=new_out_layouts,
resource_env=resource_env,
donated_invars=new_donated_invars,
name=name,
@@ -452,10 +452,7 @@ def _device_put_impl(
return x
if x_dll is None and dll is None:
return _device_put_sharding_impl(x, aval, l.sharding)
# TODO(yashkatariya): Pass layout to out_shardings directly and remove
# out_layouts from lower.
return api.jit(_identity_fn, out_shardings=l.sharding).lower(
x, _out_layouts=l).compile()(x)
return api.jit(_identity_fn, out_shardings=l)(x)

return _device_put_sharding_impl(x, aval, device)
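The `_device_put_impl` change above now routes a `Layout` destination through `jit` with the layout passed as `out_shardings`. Roughly, at the user level (hedged; the concrete values are illustrative and not taken from this diff):

```python
import jax
import jax.numpy as jnp
from jax._src.layout import Layout

s = jax.sharding.SingleDeviceSharding(jax.devices()[0])
x = jnp.ones((8, 8))

# A Layout destination with no explicit device-local layout behaves like
# device_put with just the sharding; a concrete device-local layout would
# additionally relayout the buffer.
y = jax.device_put(x, Layout(None, s))
```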
@@ -954,11 +954,6 @@ def lower_jaxpr_to_module(
else:
dim_vars = ()

arg_layouts = (map(_to_xla_layout, in_layouts) if in_layouts is not None
else in_layouts)
result_layouts = (map(_to_xla_layout, out_layouts) if out_layouts is not None
else out_layouts)

ctx = ModuleContext(backend_or_name=backend_or_name,
platforms=platforms, axis_context=axis_context,
keepalives=keepalives,
@@ -992,8 +987,8 @@ def lower_jaxpr_to_module(
result_names=result_names,
arg_memory_kinds=arg_memory_kinds,
result_memory_kinds=result_memory_kinds,
arg_layouts=arg_layouts,
result_layouts=result_layouts)
arg_layouts=in_layouts,
result_layouts=out_layouts)

try:
if not ctx.module.operation.verify():
@@ -1140,8 +1135,8 @@ def lower_jaxpr_to_fun(
result_names: Sequence[str | None] | None = None,
arg_memory_kinds: Sequence[str | None] | None = None,
result_memory_kinds: Sequence[str | None] | None = None,
arg_layouts: Sequence[str | None] | None = None,
result_layouts: Sequence[str | None] | None = None,
arg_layouts: Sequence[DeviceLocalLayout | None | AutoLayout] | None = None,
result_layouts: Sequence[DeviceLocalLayout | None | AutoLayout] | None = None,
) -> func_dialect.FuncOp:
"""Lowers jaxpr and its callees to an IR function.

@@ -1262,7 +1257,8 @@ def lower_jaxpr_to_fun(
ir_arg_layouts = None
if arg_layouts is not None:
ir_arg_layouts = util.flatten(
[[l] * len(types) for l, types in zip(arg_layouts, input_types)])
[[_to_xla_layout(l)] * len(types)
for l, types in zip(arg_layouts, input_types)])

ir_donated_args = None
if xla_donated_args is not None:
@@ -1285,7 +1281,8 @@ def lower_jaxpr_to_fun(
ir_result_layouts = None
if result_layouts is not None:
ir_result_layouts = util.flatten(
[[l] * len(types) for l, types in zip(result_layouts, output_types)])
[[_to_xla_layout(l)] * len(types)
for l, types in zip(result_layouts, output_types)])

if (
replicated_args is not None
@@ -2035,15 +2035,15 @@ def lower_sharding_computation(
fun_name: str,
in_shardings: Sequence[MaybeSharding],
out_shardings: Sequence[MaybeSharding],
in_layouts: MaybeLayout,
out_layouts: MaybeLayout,
donated_invars: Sequence[bool],
global_in_avals: Sequence[core.ShapedArray],
*,
keep_unused: bool,
inline: bool,
devices_from_context: Sequence[xc.Device] | None = None,
lowering_parameters: mlir.LoweringParameters,
in_layouts: MaybeLayout,
out_layouts: MaybeLayout,
lowering_parameters: mlir.LoweringParameters
) -> MeshComputation:
"""Lowers a computation to XLA. It can take arbitrary shardings as input.

@@ -3266,8 +3266,9 @@ def check_array_xla_sharding_layout_match(
arg.layout.device_local_layout != xl):
errors.append(
("Got input layout(s) that compiled object was called with: "
f"{arg.layout} and layout(s) the computation was compiled "
f"with: {xl} for arg {name} with shape: {arg.aval.str_short()}",
f"{arg.layout.device_local_layout} and layout(s) the computation was "
f"compiled with: {xl} for arg {name} with "
f"shape: {arg.aval.str_short()}",
'layout'))

if errors:
@@ -714,9 +714,9 @@ def make_xmap_callable(fun: lu.WrappedFun,
return pxla.lower_sharding_computation(
core.ClosedJaxpr(jaxpr, consts), 'jit', name,
(UNSPECIFIED,) * len(in_avals), (UNSPECIFIED,) * len(out_avals),
(None,) * len(in_avals), (None,) * len(out_avals),
donated_invars, in_avals, keep_unused=True, inline=False,
devices_from_context=None, lowering_parameters=lowering_parameters,
in_layouts=(None,) * len(in_avals), out_layouts=(None,) * len(out_avals))
devices_from_context=None, lowering_parameters=lowering_parameters)


class EvaluationPlan(NamedTuple):
jax/_src/pjit.py
@@ -53,7 +53,6 @@ from jax._src.errors import JAXTypeError
from jax._src.interpreters import partial_eval as pe
from jax._src.partition_spec import PartitionSpec
from jax._src.interpreters import xla

from jax._src.interpreters import ad
from jax._src.interpreters import batching
from jax._src.interpreters import mlir
@@ -67,13 +66,13 @@ from jax._src.sharding_impls import (
SingleDeviceSharding, PmapSharding, AUTO, UNSPECIFIED, UnspecifiedValue,
ParsedPartitionSpec, SpecSync, get_single_pspec, is_auto, is_unspecified,
is_unspecified_or_auto, prepare_axis_resources, parse_flatten_op_sharding)
from jax._src.layout import Layout, LayoutOptions
from jax._src.layout import Layout, DeviceLocalLayout, AutoLayout
from jax._src.state import discharge as state_discharge, RefEffect
from jax._src.traceback_util import api_boundary
from jax._src.tree_util import (
tree_flatten, tree_unflatten, treedef_is_leaf, tree_structure,
treedef_children, broadcast_prefix, all_leaves, prefix_errors, keystr,
PyTreeDef)
PyTreeDef, none_leaf_registry as none_lr)
from jax._src.util import (
HashableFunction, safe_map, safe_zip, wraps,
distributed_debug_log, split_list, weakref_lru_cache,
@@ -150,6 +149,10 @@ class PjitInfo(NamedTuple):
in_shardings_leaves: tuple[Any, ...]
out_shardings_treedef: PyTreeDef
out_shardings_leaves: tuple[Any, ...]
in_layouts_treedef: PyTreeDef
in_layouts_leaves: tuple[Any, ...]
out_layouts_treedef: PyTreeDef
out_layouts_leaves: tuple[Any, ...]
static_argnums: tuple[int, ...]
static_argnames: tuple[str, ...]
donate_argnums: tuple[int, ...]
@@ -164,8 +167,9 @@ class PjitInfo(NamedTuple):


def _python_pjit_helper(jit_info, *args, **kwargs):
args_flat, _, params, _, out_tree, _, _, _, arg_names, attrs_tracked = \
_infer_params(jit_info, args, kwargs)
(args_flat, _, params, _, out_tree, _, arg_names,
attrs_tracked) = _infer_params(jit_info, args, kwargs)

for arg in args_flat:
dispatch.check_arg(arg)

@@ -202,6 +206,7 @@ def _python_pjit_helper(jit_info, *args, **kwargs):
if attrs_tracked:
final_states, out_flat = split_list(out_flat, [len(attrs_tracked)])
_set_states(attrs_tracked, final_states)

outs = tree_unflatten(out_tree, out_flat)
return outs, out_flat, out_tree, args_flat, params['jaxpr'], attrs_tracked

@@ -335,6 +340,30 @@ def _pjit_explicit_sharding(in_shardings, out_shardings, device,
any(not is_unspecified(i) for i in out_shardings_flat))


def _split_layout_and_sharding(entries):
entries_flat, treedef = tree_flatten(entries, is_leaf=lambda x: x is None)
layouts, shardings = [], []

for e in entries_flat:
if e is None or is_unspecified_or_auto(e):
layouts.append(None)
shardings.append(e)
elif isinstance(e, Layout):
layouts.append(e.device_local_layout)
shardings.append(e.sharding)
elif isinstance(e, (DeviceLocalLayout, AutoLayout)):
raise ValueError(
'`jax.jit` does not accept device-local layouts directly. Create '
'a `Layout` instance wrapping this device-local layout and pass '
f'that to `jit` instead. Got {e}')
else:
layouts.append(None)
shardings.append(e)

assert len(layouts) == len(shardings)
return tree_unflatten(treedef, layouts), tree_unflatten(treedef, shardings)


def _parse_jit_arguments(fun: Callable, in_shardings: Any, out_shardings: Any,
donate_argnums: int | Sequence[int] | None,
donate_argnames: str | Iterable[str] | None,
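As `_split_layout_and_sharding` above shows, each `in_shardings`/`out_shardings` entry may be a plain sharding or a `Layout` wrapper, while a bare `DeviceLocalLayout`/`AutoLayout` is rejected with the error message quoted there. A hedged sketch (the mesh and function are illustrative, not part of this diff):

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax._src.layout import Layout
from jax.sharding import Mesh, NamedSharding, PartitionSpec

mesh = Mesh(np.array(jax.devices()), ('x',))
s = NamedSharding(mesh, PartitionSpec('x'))

# Plain shardings and Layout wrappers can be mixed; jit splits the Layout
# entries back into (device_local_layout, sharding) pairs.
f = jax.jit(lambda a, b: a + b,
            in_shardings=(s, Layout(None, s)),
            out_shardings=Layout(None, s))
```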
@ -378,16 +407,19 @@ def _parse_jit_arguments(fun: Callable, in_shardings: Any, out_shardings: Any,
|
||||
# rather than raising an error. https://github.com/google/jax/issues/2367
|
||||
in_shardings = tuple(in_shardings)
|
||||
|
||||
in_layouts, in_shardings = _split_layout_and_sharding(in_shardings)
|
||||
out_layouts, out_shardings = _split_layout_and_sharding(out_shardings)
|
||||
|
||||
in_shardings = prepare_axis_resources(in_shardings, 'in_shardings')
|
||||
out_shardings = prepare_axis_resources(out_shardings, 'out_shardings')
|
||||
|
||||
user_specified_in_shardings = (in_shardings is not None and
|
||||
not is_unspecified(in_shardings))
|
||||
none_leaf_registry = tree_util.none_leaf_registry
|
||||
in_shardings_leaves, in_shardings_treedef = none_leaf_registry.flatten(
|
||||
in_shardings)
|
||||
out_shardings_leaves, out_shardings_treedef = none_leaf_registry.flatten(
|
||||
out_shardings)
|
||||
|
||||
in_shardings_leaves, in_shardings_treedef = none_lr.flatten(in_shardings)
|
||||
out_shardings_leaves, out_shardings_treedef = none_lr.flatten(out_shardings)
|
||||
in_layouts_leaves, in_layouts_treedef = none_lr.flatten(in_layouts)
|
||||
out_layouts_leaves, out_layouts_treedef = none_lr.flatten(out_layouts)
|
||||
|
||||
fun_sourceinfo = api_util.fun_sourceinfo(fun)
|
||||
fun_signature = api_util.fun_signature(fun)
|
||||
@ -408,6 +440,10 @@ def _parse_jit_arguments(fun: Callable, in_shardings: Any, out_shardings: Any,
|
||||
in_shardings_leaves=tuple(in_shardings_leaves),
|
||||
out_shardings_treedef=out_shardings_treedef,
|
||||
out_shardings_leaves=tuple(out_shardings_leaves),
|
||||
in_layouts_treedef=in_layouts_treedef,
|
||||
in_layouts_leaves=tuple(in_layouts_leaves),
|
||||
out_layouts_treedef=out_layouts_treedef,
|
||||
out_layouts_leaves=tuple(out_layouts_leaves),
|
||||
static_argnums=static_argnums,
|
||||
static_argnames=static_argnames, donate_argnums=donate_argnums,
|
||||
donate_argnames=donate_argnames, device=device, backend=backend,
|
||||
@ -417,37 +453,58 @@ def _parse_jit_arguments(fun: Callable, in_shardings: Any, out_shardings: Any,
|
||||
use_resource_env=use_resource_env)
|
||||
|
||||
|
||||
# TODO(yashkatariya): Delete this function once internal users migrate off of
|
||||
# the deprecated AOT API.
|
||||
def _handle_layouts_in_aot(jit_info: PjitInfo, kwargs):
|
||||
if '_in_layouts' in kwargs or '_out_layouts' in kwargs:
|
||||
warnings.warn(
|
||||
'Passing `_in_layouts` and `_out_layouts` to `.lower` is deprecated and'
|
||||
' will be removed soon. Please pass your `Layout` instances to'
|
||||
' `in_shardings` and `out_shardings` arguments of `jax.jit`',
|
||||
DeprecationWarning)
|
||||
in_layouts = kwargs.pop('_in_layouts', None)
|
||||
out_layouts = kwargs.pop('_out_layouts', None)
|
||||
in_layouts, _ = _split_layout_and_sharding(in_layouts)
|
||||
out_layouts, _ = _split_layout_and_sharding(out_layouts)
|
||||
in_layouts_leaves, in_layouts_treedef = none_lr.flatten(in_layouts)
|
||||
out_layouts_leaves, out_layouts_treedef = none_lr.flatten(out_layouts)
|
||||
return jit_info._replace(in_layouts_treedef=in_layouts_treedef,
|
||||
in_layouts_leaves=tuple(in_layouts_leaves),
|
||||
out_layouts_treedef=out_layouts_treedef,
|
||||
out_layouts_leaves=tuple(out_layouts_leaves))
|
||||
return jit_info
|
||||
|
||||
|
||||
def _make_jit_wrapper(jit_info: PjitInfo):
|
||||
wrapped = _cpp_pjit(jit_info)
|
||||
|
||||
@api_boundary
|
||||
def lower(*args, **kwargs):
|
||||
lowering_parameters = kwargs.pop(
|
||||
'_experimental_lowering_parameters', mlir.LoweringParameters())
|
||||
# TODO(yashkatariya): Remove this when it's added on jit.
|
||||
in_layouts = kwargs.pop('_in_layouts', Layout())
|
||||
out_layouts = kwargs.pop('_out_layouts', Layout())
|
||||
# TODO(yashkatariya): Remove this handling once internal users migrate off
|
||||
# of the deprecated API
|
||||
new_jit_info = _handle_layouts_in_aot(jit_info, kwargs)
|
||||
|
||||
(args_flat, flat_global_in_avals, params, in_tree, out_tree,
|
||||
donated_invars, in_layouts_flat, out_layouts_flat,
|
||||
arg_names, ()) = _infer_params(
|
||||
jit_info, args, kwargs, in_layouts=in_layouts, out_layouts=out_layouts)
|
||||
donated_invars, arg_names, ()) = _infer_params(new_jit_info, args, kwargs)
|
||||
resource_env = params['resource_env']
|
||||
mesh = None if resource_env is None else resource_env.physical_mesh
|
||||
try:
|
||||
in_shardings = _resolve_in_shardings(
|
||||
args_flat, params['in_shardings'], params['out_shardings'], mesh)
|
||||
in_layouts_flat = _resolve_in_layouts(
|
||||
args_flat, in_layouts_flat, in_shardings)
|
||||
out_layouts_flat = _resolve_out_layouts(out_layouts_flat)
|
||||
in_layouts = _resolve_in_layouts(
|
||||
args_flat, params['in_layouts'], in_shardings,
|
||||
params['jaxpr'].in_avals)
|
||||
lowering = _pjit_lower(
|
||||
params['jaxpr'], in_shardings, params['out_shardings'],
|
||||
in_layouts, params['out_layouts'],
|
||||
params['resource_env'], params['donated_invars'], params['name'],
|
||||
params['keep_unused'], params['inline'], in_layouts=in_layouts_flat,
|
||||
out_layouts=out_layouts_flat, lowering_parameters=lowering_parameters)
|
||||
params['keep_unused'], params['inline'],
|
||||
lowering_parameters=lowering_parameters)
|
||||
except pxla.DeviceAssignmentMismatchError as e:
|
||||
fails, = e.args
|
||||
api_name = 'jit' if params['resource_env'] is None else 'pjit'
|
||||
fun = jit_info.fun
|
||||
fun = new_jit_info.fun
|
||||
fun_name = getattr(fun, '__qualname__',
|
||||
getattr(fun, '__name__', str(fun)))
|
||||
msg = _device_assignment_mismatch_error(
|
||||
@ -461,19 +518,18 @@ def _make_jit_wrapper(jit_info: PjitInfo):
|
||||
|
||||
@api_boundary
|
||||
def eval_shape(*args, **kwargs):
|
||||
_, _, params, _, out_tree, _, _, _, _, _ = _infer_params(
|
||||
jit_info, args, kwargs, in_layouts=None, out_layouts=None
|
||||
)
|
||||
out_s = [None if is_unspecified(s) else getattr(s, '_original_sharding', s)
|
||||
for s in params['out_shardings']]
|
||||
_, _, params, _, out_tree, _, _, _ = _infer_params(jit_info, args, kwargs)
|
||||
out_s = [None if is_unspecified(s) else s for s in params['out_shardings']]
|
||||
out = [api.ShapeDtypeStruct(x.shape, x.dtype, x.named_shape, sharding=s)
|
||||
for x, s in zip(params['jaxpr'].out_avals, out_s)]
|
||||
return tree_unflatten(out_tree, out)
|
||||
|
||||
wrapped = _cpp_pjit(jit_info)
|
||||
wrapped.lower = lower
|
||||
wrapped.eval_shape = eval_shape
|
||||
return wrapped
|
||||
|
||||
|
||||
def make_jit(fun: Callable, in_shardings: Any, out_shardings: Any,
|
||||
donate_argnums: int | Sequence[int] | None,
|
||||
donate_argnames: str | Iterable[str] | None,
|
||||
@ -490,10 +546,11 @@ def make_jit(fun: Callable, in_shardings: Any, out_shardings: Any,
|
||||
return _make_jit_wrapper(jit_info)
|
||||
|
||||
|
||||
def _infer_params(jit_info, args, kwargs, *, in_layouts=None, out_layouts=None):
|
||||
def _infer_params(jit_info, args, kwargs):
|
||||
(fun, fun_sourceinfo, fun_signature, user_specified_in_shardings,
|
||||
in_shardings_treedef, in_shardings_leaves, out_shardings_treedef,
|
||||
out_shardings_leaves, static_argnums, static_argnames,
|
||||
out_shardings_leaves, in_layouts_treedef, in_layouts_leaves,
|
||||
out_layouts_treedef, out_layouts_leaves, static_argnums, static_argnames,
|
||||
donate_argnums, donate_argnames, device, backend, keep_unused, inline,
|
||||
abstracted_axes, _, use_resource_env) = jit_info
|
||||
|
||||
@ -576,17 +633,18 @@ def _infer_params(jit_info, args, kwargs, *, in_layouts=None, out_layouts=None):
|
||||
) from e
|
||||
in_type = in_avals = tuple(avals)
|
||||
|
||||
canonicalized_in_shardings_flat, in_layouts_flat = _process_in_axis_resources(
|
||||
in_shardings_treedef, in_shardings_leaves, hashable_pytree(in_layouts),
|
||||
in_shardings_flat, in_layouts_flat = _process_in_axis_resources(
|
||||
in_shardings_treedef, in_shardings_leaves,
|
||||
in_layouts_treedef, in_layouts_leaves,
|
||||
in_avals, in_tree, dbg, device_or_backend_set, have_kwargs)
|
||||
|
||||
jaxpr, consts, out_shardings, out_layouts_flat, attrs_tracked = _pjit_jaxpr(
|
||||
jaxpr, consts, out_shardings_flat, out_layouts_flat, attrs_tracked = _pjit_jaxpr(
|
||||
flat_fun, out_shardings_treedef, out_shardings_leaves,
|
||||
hashable_pytree(out_layouts), in_type, dbg, device_or_backend_set,
|
||||
HashableFunction(out_tree, closure=()),
|
||||
out_layouts_treedef, out_layouts_leaves, in_type, dbg,
|
||||
device_or_backend_set, HashableFunction(out_tree, closure=()),
|
||||
HashableFunction(res_paths, closure=()), inline)
|
||||
|
||||
assert len(explicit_args) == len(canonicalized_in_shardings_flat) == len(in_layouts_flat)
|
||||
assert len(explicit_args) == len(in_shardings_flat) == len(in_layouts_flat)
|
||||
|
||||
if config.dynamic_shapes.value:
|
||||
implicit_args = _extract_implicit_args(in_type, explicit_args)
|
||||
@ -595,18 +653,19 @@ def _infer_params(jit_info, args, kwargs, *, in_layouts=None, out_layouts=None):
|
||||
args_flat = [*implicit_args, *explicit_args]
|
||||
|
||||
num_extra_args = len(implicit_args) + len(attrs_tracked) + len(consts)
|
||||
canonicalized_in_shardings_flat = \
|
||||
(UNSPECIFIED,) * num_extra_args + canonicalized_in_shardings_flat
|
||||
in_shardings_flat = (UNSPECIFIED,) * num_extra_args + in_shardings_flat
|
||||
in_layouts_flat = (None,) * num_extra_args + in_layouts_flat
|
||||
donated_invars = (False,) * num_extra_args + donated_invars
|
||||
assert (len(canonicalized_in_shardings_flat) == len(in_layouts_flat) ==
|
||||
assert (len(in_shardings_flat) == len(in_layouts_flat) ==
|
||||
len(donated_invars) == len(attrs_tracked) + len(consts) + len(args_flat))
|
||||
|
||||
# in_shardings and out_shardings here are all GSPMDSharding.
|
||||
params = dict(
|
||||
jaxpr=jaxpr,
|
||||
in_shardings=canonicalized_in_shardings_flat,
|
||||
out_shardings=out_shardings,
|
||||
in_shardings=in_shardings_flat,
|
||||
out_shardings=out_shardings_flat,
|
||||
in_layouts=in_layouts_flat,
|
||||
out_layouts=out_layouts_flat,
|
||||
resource_env=resource_env,
|
||||
donated_invars=donated_invars,
|
||||
name=getattr(flat_fun, '__name__', '<unknown>'),
|
||||
@ -614,8 +673,7 @@ def _infer_params(jit_info, args, kwargs, *, in_layouts=None, out_layouts=None):
|
||||
inline=inline,
|
||||
)
|
||||
return (consts + args_flat, in_type, params, in_tree, out_tree(),
|
||||
donated_invars, in_layouts_flat, out_layouts_flat,
|
||||
dbg.arg_names if dbg else None, attrs_tracked)
|
||||
donated_invars, dbg.arg_names if dbg else None, attrs_tracked)
|
||||
|
||||
def _extract_implicit_args(
|
||||
in_type: Sequence[tuple[core.AbstractValue, bool]],
|
||||
@ -973,8 +1031,8 @@ class PytreeLeaf:
|
||||
|
||||
@lru_cache(maxsize=4096)
|
||||
def _process_in_axis_resources(in_shardings_treedef, in_shardings_leaves,
|
||||
in_layouts_thunk, in_avals,
|
||||
in_tree, debug_info,
|
||||
in_layouts_treedef, in_layouts_leaves,
|
||||
in_avals, in_tree, debug_info,
|
||||
device_or_backend_set, kws):
|
||||
if not kws:
|
||||
in_tree, _ = treedef_children(in_tree)
|
||||
@ -988,7 +1046,7 @@ def _process_in_axis_resources(in_shardings_treedef, in_shardings_leaves,
|
||||
in_shardings_flat = flatten_axis_resources(
|
||||
"pjit in_shardings", in_tree, orig_in_shardings, tupled_args=True)
|
||||
|
||||
in_layouts = in_layouts_thunk()
|
||||
in_layouts = tree_unflatten(in_layouts_treedef, in_layouts_leaves)
|
||||
if in_layouts is None:
|
||||
in_layouts_flat = (in_layouts,) * len(in_avals)
|
||||
else:
|
||||
@ -1001,7 +1059,7 @@ def _process_in_axis_resources(in_shardings_treedef, in_shardings_leaves,
|
||||
pjit_check_aval_sharding(in_shardings_flat, in_avals,
|
||||
None if debug_info is None else debug_info.arg_names,
|
||||
"pjit arguments", allow_uneven_sharding=False)
|
||||
return in_shardings_flat, tuple(in_layouts_flat)
|
||||
return in_shardings_flat, in_layouts_flat
|
||||
|
||||
callsites: set[str] = set()
|
||||
|
||||
@ -1168,13 +1226,9 @@ def _create_pjit_jaxpr(fun, in_type, debug_info, out_paths, ignored_inline):
|
||||
|
||||
@lru_cache(maxsize=4096)
|
||||
def _check_and_canonicalize_out_shardings(
|
||||
out_shardings_treedef, out_shardings_leaves, out_layouts_thunk, out_tree,
|
||||
out_type, debug_info, device_or_backend_set):
|
||||
out_shardings_treedef, out_shardings_leaves, out_layouts_treedef,
|
||||
out_layouts_leaves, out_tree, out_type, debug_info, device_or_backend_set):
|
||||
orig_out_shardings = tree_unflatten(out_shardings_treedef, out_shardings_leaves)
|
||||
# TODO(yashkatariya): Remove the if branch and fix flatten_axis_resources
|
||||
# instead. This condition exists because flatten_axis_resources passes in an
|
||||
# `object()` while unflattening which breaks assertion is user defined
|
||||
# pytrees (which shouldn't exist but they do).
|
||||
if (is_unspecified(orig_out_shardings) or
|
||||
isinstance(orig_out_shardings, XLACompatibleSharding)):
|
||||
out_shardings_flat = (orig_out_shardings,) * len(out_type)
|
||||
@ -1183,7 +1237,7 @@ def _check_and_canonicalize_out_shardings(
|
||||
"pjit out_shardings", out_tree(), orig_out_shardings,
|
||||
tupled_args=False)
|
||||
|
||||
out_layouts = out_layouts_thunk()
|
||||
out_layouts = tree_unflatten(out_layouts_treedef, out_layouts_leaves)
|
||||
if out_layouts is None:
|
||||
out_layouts_flat = (out_layouts,) * len(out_type)
|
||||
else:
|
||||
@ -1195,18 +1249,20 @@ def _check_and_canonicalize_out_shardings(
|
||||
out_shardings_flat, out_type,
|
||||
None if debug_info is None else debug_info.result_paths,
|
||||
"pjit outputs", allow_uneven_sharding=False)
|
||||
return out_shardings_flat, tuple(out_layouts_flat)
|
||||
return out_shardings_flat, out_layouts_flat
|
||||
|
||||
|
||||
def _pjit_jaxpr(fun, out_shardings_treedef, out_shardings_leaves,
|
||||
out_layouts_thunk, in_type, debug_info, device_or_backend_set,
|
||||
out_tree, result_paths, inline):
|
||||
out_layouts_treedef, out_layouts_leaves, in_type, debug_info,
|
||||
device_or_backend_set, out_tree, result_paths, inline):
|
||||
jaxpr, final_consts, out_type, attrs_tracked = _create_pjit_jaxpr(
|
||||
fun, in_type, debug_info, result_paths, IgnoreKey(inline))
|
||||
canonicalized_out_shardings_flat, out_layouts_flat = _check_and_canonicalize_out_shardings(
|
||||
out_shardings_treedef, out_shardings_leaves, out_layouts_thunk, out_tree, tuple(out_type),
|
||||
out_shardings_treedef, out_shardings_leaves, out_layouts_treedef,
|
||||
out_layouts_leaves, out_tree, tuple(out_type),
|
||||
jaxpr.jaxpr.debug_info, device_or_backend_set)
|
||||
return jaxpr, final_consts, canonicalized_out_shardings_flat, out_layouts_flat, attrs_tracked
|
||||
return (jaxpr, final_consts, canonicalized_out_shardings_flat,
|
||||
out_layouts_flat, attrs_tracked)
|
||||
|
||||
|
||||
@dataclasses.dataclass(frozen=True)
|
||||
@ -1259,30 +1315,65 @@ pjit_p = core.AxisPrimitive("pjit")
|
||||
pjit_p.multiple_results = True
|
||||
|
||||
|
||||
def _resolve_in_layouts(args, jit_in_layouts, jit_in_shardings):
|
||||
@lru_cache(maxsize=2048)
|
||||
def _maybe_get_default_layout(arg_layout, jit_in_layout, sharding, aval):
|
||||
if is_unspecified_or_auto(sharding):
|
||||
return None
|
||||
# TODO(yashkatariya): Figure out how layouts work with extended dtypes.
|
||||
if dtypes.issubdtype(aval.dtype, dtypes.extended):
|
||||
return None
|
||||
if not core.is_constant_shape(aval.shape):
|
||||
return None
|
||||
shard_shape = sharding.shard_shape(aval.shape)
|
||||
d = sharding._device_assignment[0]
|
||||
# If a backend doesn't implement `get_default_layout` return `None` to avoid
|
||||
# cache misses. This can happen when you have `jit(f, in_shardings=s)`. On
|
||||
# first call you pass it a sharded array with layout and on second call you
|
||||
# pass a numpy array. The layouts should be the same to get cache hits.
|
||||
try:
|
||||
al = DeviceLocalLayout(
|
||||
d.client.get_default_layout(aval.dtype, shard_shape, d))
|
||||
except:
|
||||
return None
|
||||
# argument does not have `.layout` property. ShapedArray, ShapedDtypeStruct,
|
||||
# numpy array, etc are some examples.
|
||||
if arg_layout is None:
|
||||
return al if jit_in_layout is None else arg_layout # arg_layout is None
|
||||
# If arg has a `.layout` property, then return device_local_layout as is.
|
||||
return arg_layout.device_local_layout
|
||||
|
||||
|
||||
def _resolve_in_layouts(args, jit_in_layouts, resolved_in_shardings, in_avals):
|
||||
# If device or backend is set, return the default layout. This is because you
|
||||
# can pass arrays on cpu (with untiled layouts) to jit with backend='tpu'
|
||||
# which causes error checks to fail. Returning the default layout allows
|
||||
# this to exist. It's the same for handling shardings.
|
||||
if pxla.check_device_backend_on_shardings(jit_in_shardings):
|
||||
if pxla.check_device_backend_on_shardings(resolved_in_shardings):
|
||||
return (None,) * len(jit_in_layouts)
|
||||
|
||||
resolved_in_layouts = []
|
||||
for arg, jit_in_l in safe_zip(args, jit_in_layouts):
|
||||
for arg, jit_in_l, rs, aval in safe_zip(
|
||||
args, jit_in_layouts, resolved_in_shardings, in_avals):
|
||||
arg_layout, committed = (
|
||||
(arg.layout.device_local_layout, getattr(arg, '_committed', True))
|
||||
if getattr(arg, 'layout', None) is not None else (None, False))
|
||||
jit_in_l = None if jit_in_l is None else jit_in_l.device_local_layout
|
||||
_maybe_get_default_layout(getattr(arg, 'layout', None), jit_in_l, rs, aval),
|
||||
getattr(arg, '_committed', True))
|
||||
# Sharding can be unspecified when array is committed if it's a PmapSharding.
|
||||
is_pmap_sharding = (is_unspecified(rs) or
|
||||
isinstance(getattr(arg, 'sharding', None), PmapSharding))
|
||||
if jit_in_l is None:
|
||||
if committed:
|
||||
resolved_in_layouts.append(arg_layout)
|
||||
if is_pmap_sharding:
|
||||
resolved_in_layouts.append(None)
|
||||
else:
|
||||
resolved_in_layouts.append(arg_layout)
|
||||
else:
|
||||
resolved_in_layouts.append(None)
|
||||
else:
|
||||
# arg_layout can be None because some backends don't implement the
|
||||
# required layout methods. Hence `arr.layout` can return
|
||||
# `Layout(None, sharding)`
|
||||
if committed and arg_layout is not None and arg_layout != jit_in_l:
|
||||
if (committed and not is_pmap_sharding and
|
||||
arg_layout is not None and arg_layout != jit_in_l):
|
||||
raise ValueError('Layout passed to jit does not match the layout '
|
||||
'on the respective arg. '
|
||||
f'Got pjit layout: {jit_in_l},\n'
|
||||
@ -1292,13 +1383,6 @@ def _resolve_in_layouts(args, jit_in_layouts, jit_in_shardings):
|
||||
return tuple(resolved_in_layouts)
|
||||
|
||||
|
||||
def _resolve_out_layouts(out_layouts: Sequence[Layout]
|
||||
) -> Sequence[LayoutOptions]:
|
||||
# TODO(yashkatariya): Remove the if condition when all layouts come via the
|
||||
# `layout.Layout` API or handle this properly when layout is on jit.
|
||||
return tuple(None if o is None else o.device_local_layout for o in out_layouts)
|
||||
|
||||
|
||||
def _resolve_in_shardings(
|
||||
args, pjit_in_shardings: Sequence[PjitSharding],
|
||||
out_shardings: Sequence[PjitSharding],
|
||||
@ -1335,8 +1419,10 @@ def _resolve_in_shardings(
|
||||
pxla._get_and_check_device_assignment(
|
||||
it.chain(
|
||||
util.stable_unique(committed_arg_shardings),
|
||||
((i, pxla.MismatchType.IN_SHARDING, None) for i in util.stable_unique(pjit_in_shardings)),
|
||||
((o, pxla.MismatchType.OUT_SHARDING, None) for o in util.stable_unique(out_shardings))),
|
||||
((i, pxla.MismatchType.IN_SHARDING, None)
|
||||
for i in util.stable_unique(pjit_in_shardings)),
|
||||
((o, pxla.MismatchType.OUT_SHARDING, None)
|
||||
for o in util.stable_unique(out_shardings))),
|
||||
(None if pjit_mesh is None or pjit_mesh.empty else list(pjit_mesh.devices.flat)))
|
||||
|
||||
resolved_in_shardings = []
|
||||
@ -1405,16 +1491,17 @@ def _resolve_in_shardings(
|
||||
|
||||
|
||||
def _pjit_call_impl_python(
|
||||
*args, jaxpr, in_shardings, out_shardings, resource_env, donated_invars,
|
||||
name, keep_unused, inline):
|
||||
*args, jaxpr, in_shardings, out_shardings, in_layouts, out_layouts,
|
||||
resource_env, donated_invars, name, keep_unused, inline):
|
||||
global _most_recent_pjit_call_executable
|
||||
|
||||
in_shardings = _resolve_in_shardings(
|
||||
args, in_shardings, out_shardings,
|
||||
resource_env.physical_mesh if resource_env is not None else None)
|
||||
in_layouts = _resolve_in_layouts(args, in_layouts, in_shardings, jaxpr.in_avals)
|
||||
|
||||
compiled = _pjit_lower(
|
||||
jaxpr, in_shardings, out_shardings, resource_env,
|
||||
jaxpr, in_shardings, out_shardings, in_layouts, out_layouts, resource_env,
|
||||
donated_invars, name, keep_unused, inline,
|
||||
lowering_parameters=mlir.LoweringParameters()).compile()
|
||||
_most_recent_pjit_call_executable.weak_key_dict[jaxpr] = compiled
|
||||
@ -1434,6 +1521,8 @@ def _pjit_call_impl_python(
|
||||
distributed_debug_log(("Running pjit'd function", name),
|
||||
("in_shardings", in_shardings),
|
||||
("out_shardings", out_shardings),
|
||||
("in_layouts", in_layouts),
|
||||
("out_layouts", out_layouts),
|
||||
("abstract args", map(xla.abstractify, args)),
|
||||
("fingerprint", fingerprint))
|
||||
try:
|
||||
@ -1465,8 +1554,9 @@ def _pjit_call_impl_python(
|
||||
|
||||
|
||||
@weakref_lru_cache
|
||||
def _get_jaxpr_as_fun(jaxpr, in_shardings, out_shardings, resource_env,
|
||||
donated_invars, name, keep_unused, inline):
|
||||
def _get_jaxpr_as_fun(jaxpr, in_shardings, out_shardings, in_layouts,
|
||||
out_layouts, resource_env, donated_invars, name,
|
||||
keep_unused, inline):
|
||||
# The input jaxpr to `_get_jaxpr_as_fun` is under a weakref_lru_cache so
|
||||
# returning `core.jaxpr_as_fun(jaxpr)` directly creates a strong reference to
|
||||
# the jaxpr defeating the purpose of weakref_lru_cache. So return a function
|
||||
@ -1478,12 +1568,14 @@ def _get_jaxpr_as_fun(jaxpr, in_shardings, out_shardings, resource_env,
|
||||
|
||||
|
||||
def _pjit_call_impl(*args, jaxpr,
|
||||
in_shardings, out_shardings, resource_env,
|
||||
in_shardings, out_shardings, in_layouts, out_layouts,
|
||||
resource_env,
|
||||
donated_invars, name, keep_unused, inline):
|
||||
def call_impl_cache_miss(*args_, **kwargs_):
|
||||
out_flat, compiled = _pjit_call_impl_python(
|
||||
*args, jaxpr=jaxpr, in_shardings=in_shardings,
|
||||
out_shardings=out_shardings, resource_env=resource_env,
|
||||
out_shardings=out_shardings, in_layouts=in_layouts,
|
||||
out_layouts=out_layouts, resource_env=resource_env,
|
||||
donated_invars=donated_invars, name=name, keep_unused=keep_unused,
|
||||
inline=inline)
|
||||
fastpath_data = _get_fastpath_data(
|
||||
@ -1492,7 +1584,7 @@ def _pjit_call_impl(*args, jaxpr,
|
||||
return out_flat, fastpath_data
|
||||
|
||||
f = _get_jaxpr_as_fun(
|
||||
jaxpr, in_shardings, out_shardings,
|
||||
jaxpr, in_shardings, out_shardings, in_layouts, out_layouts,
|
||||
resource_env, donated_invars, name, keep_unused, inline)
|
||||
donated_argnums = [i for i, d in enumerate(donated_invars) if d]
|
||||
has_explicit_sharding = _pjit_explicit_sharding(
|
||||
@ -1520,22 +1612,15 @@ def _pjit_lower_cached(
|
||||
jaxpr: core.ClosedJaxpr,
|
||||
in_shardings,
|
||||
out_shardings,
|
||||
in_layouts: pxla.MaybeLayout,
|
||||
out_layouts: pxla.MaybeLayout,
|
||||
resource_env,
|
||||
donated_invars,
|
||||
name: str,
|
||||
keep_unused: bool,
|
||||
inline: bool,
|
||||
*,
|
||||
lowering_parameters: mlir.LoweringParameters,
|
||||
in_layouts: pxla.MaybeLayout | None = None,
|
||||
out_layouts: pxla.MaybeLayout | None = None):
|
||||
# TODO(yashkatariya): Remove this when layouts are supported on jit and
|
||||
# passed to params.
|
||||
if in_layouts is None:
|
||||
in_layouts = (None,) * len(in_shardings)
|
||||
if out_layouts is None:
|
||||
out_layouts = (None,) * len(out_shardings)
|
||||
|
||||
lowering_parameters: mlir.LoweringParameters):
|
||||
if resource_env is not None:
|
||||
pxla.resource_typecheck(jaxpr, resource_env, {}, lambda: "pjit")
|
||||
|
||||
@ -1558,18 +1643,19 @@ def _pjit_lower_cached(
|
||||
else:
|
||||
return pxla.lower_sharding_computation(
|
||||
jaxpr, api_name, name, in_shardings, out_shardings,
|
||||
tuple(donated_invars), tuple(jaxpr.in_avals),
|
||||
in_layouts, out_layouts, tuple(donated_invars), tuple(jaxpr.in_avals),
|
||||
keep_unused=keep_unused, inline=inline,
|
||||
devices_from_context=(
|
||||
None if mesh is None or mesh.empty else list(mesh.devices.flat)),
|
||||
lowering_parameters=lowering_parameters, in_layouts=in_layouts,
|
||||
out_layouts=out_layouts)
|
||||
lowering_parameters=lowering_parameters)
|
||||
|
||||
|
||||
def pjit_staging_rule(trace, *args, **params):
|
||||
if (params["inline"] and
|
||||
all(is_unspecified(i) for i in params["in_shardings"]) and
|
||||
all(is_unspecified(o) for o in params["out_shardings"])):
|
||||
all(is_unspecified(o) for o in params["out_shardings"]) and
|
||||
all(i is None for i in params["in_layouts"]) and
|
||||
all(o is None for o in params["out_layouts"])):
|
||||
jaxpr = params['jaxpr']
|
||||
if config.dynamic_shapes.value:
|
||||
# Inline jaxpr doesn't handle dynamic shapes when inlining. If dynamic
|
||||
@ -1598,14 +1684,16 @@ def pjit_staging_rule(trace, *args, **params):
|
||||
jaxpr, consts = pxla._move_mutable_consts(params['jaxpr'])
|
||||
consts = map(trace.instantiate_const, consts)
|
||||
in_shardings = (*params['in_shardings'],) + (UNSPECIFIED,) * len(consts)
|
||||
in_layouts = (*params['in_layouts'],) + (None,) * len(consts)
|
||||
donated_invars = (*params['donated_invars'],) + (False,) * len(consts)
|
||||
new_params = dict(params, jaxpr=jaxpr, in_shardings=in_shardings,
|
||||
donated_invars=donated_invars)
|
||||
in_layouts=in_layouts, donated_invars=donated_invars)
|
||||
return trace.default_process_primitive(pjit_p, (*args, *consts), new_params)
|
||||
else:
|
||||
return trace.default_process_primitive(pjit_p, args, params)
|
||||
pe.custom_staging_rules[pjit_p] = pjit_staging_rule
|
||||
|
||||
|
||||
# TODO(mattjj): remove/trivialize this when jaxprs have type annotation on them,
|
||||
# since it's actually not possible in general to infer the type from the term
|
||||
def _out_type(jaxpr: core.ClosedJaxpr) -> list[core.AbstractValue]:
|
||||
@ -1630,13 +1718,14 @@ def _pjit_typecheck(ctx_factory, *in_atoms, jaxpr, **params):
|
||||
core.custom_typechecks[pjit_p] = _pjit_typecheck
|
||||
|
||||
|
||||
def _pjit_abstract_eval(*args, jaxpr, out_shardings, resource_env, **_):
|
||||
def _pjit_abstract_eval(*args, jaxpr, **_):
|
||||
return jaxpr.out_avals, jaxpr.effects
|
||||
pjit_p.def_effectful_abstract_eval(_pjit_abstract_eval)
|
||||
|
||||
|
||||
def _pjit_cached_lower_jaxpr_to_fun(ctx, name, jaxpr, effects, in_shardings,
|
||||
out_shardings, api_name):
|
||||
out_shardings, in_layouts, out_layouts,
|
||||
api_name):
|
||||
mod_ctx = ctx.module_context
|
||||
axis_ctx = ctx.module_context.axis_context
|
||||
num_devices = None
|
||||
@ -1647,7 +1736,7 @@ def _pjit_cached_lower_jaxpr_to_fun(ctx, name, jaxpr, effects, in_shardings,
|
||||
key = (pjit_p, name, jaxpr, effects, num_devices,
|
||||
pxla.SemanticallyEqualShardings(in_shardings, jaxpr.in_avals),
|
||||
pxla.SemanticallyEqualShardings(out_shardings, jaxpr.out_avals),
|
||||
api_name)
|
||||
in_layouts, out_layouts, api_name)
|
||||
|
||||
func = mod_ctx.cached_primitive_lowerings.get(key, None)
|
||||
if func is None:
|
||||
@ -1659,14 +1748,15 @@ def _pjit_cached_lower_jaxpr_to_fun(ctx, name, jaxpr, effects, in_shardings,
|
||||
func = mlir.lower_jaxpr_to_fun(
|
||||
mod_ctx, name, jaxpr, effects, ctx.name_stack,
|
||||
arg_shardings=arg_shardings, result_shardings=result_shardings,
|
||||
use_sharding_annotations=False, api_name=api_name)
|
||||
use_sharding_annotations=False, api_name=api_name,
|
||||
arg_layouts=in_layouts, result_layouts=out_layouts)
|
||||
mod_ctx.cached_primitive_lowerings[key] = func
|
||||
return func
|
||||
|
||||
|
||||
def _pjit_lowering(ctx, *args, name, jaxpr, in_shardings,
|
||||
out_shardings, resource_env, donated_invars,
|
||||
keep_unused, inline):
|
||||
out_shardings, in_layouts, out_layouts, resource_env,
|
||||
donated_invars, keep_unused, inline):
|
||||
effects = list(ctx.tokens_in.effects())
|
||||
output_types = map(mlir.aval_to_ir_types, ctx.avals_out)
|
||||
output_types = [mlir.token_type()] * len(effects) + output_types
|
||||
@ -1674,7 +1764,8 @@ def _pjit_lowering(ctx, *args, name, jaxpr, in_shardings,
|
||||
|
||||
func = _pjit_cached_lower_jaxpr_to_fun(
|
||||
ctx, name, jaxpr, tuple(effects), in_shardings,
|
||||
out_shardings, api_name=('jit' if resource_env is None else 'pjit'))
|
||||
out_shardings, in_layouts, out_layouts,
|
||||
api_name=('jit' if resource_env is None else 'pjit'))
|
||||
|
||||
tokens_in = [ctx.tokens_in.get(eff) for eff in effects]
|
||||
args = (*ctx.dim_var_values, *tokens_in, *args)
|
||||
@ -1693,7 +1784,7 @@ mlir.register_lowering(pjit_p, _pjit_lowering)
|
||||
def _pjit_batcher(insert_axis, spmd_axis_name,
|
||||
axis_size, axis_name, main_type,
|
||||
vals_in, dims_in,
|
||||
jaxpr, in_shardings, out_shardings,
|
||||
jaxpr, in_shardings, out_shardings, in_layouts, out_layouts,
|
||||
resource_env, donated_invars, name, keep_unused, inline):
|
||||
segment_lens, dims_in = batching.indirectify_ragged_axes(dims_in)
|
||||
new_jaxpr, axes_out = batching.batch_jaxpr2(
|
||||
@ -1718,16 +1809,24 @@ def _pjit_batcher(insert_axis, spmd_axis_name,
|
||||
_pjit_batcher_for_sharding(o, axis_out, new_parts, mesh, aval.ndim)
|
||||
if axis_out is not None else o
|
||||
for axis_out, o, aval in zip(axes_out, out_shardings, new_jaxpr.out_avals))
|
||||
# TODO(yashkatariya): Figure out layouts should change under vmap.
|
||||
if not (all(l is None for l in in_layouts) and
|
||||
all(l is None for l in out_layouts)):
|
||||
raise NotImplementedError
|
||||
|
||||
vals_out = pjit_p.bind(
|
||||
*vals_in,
|
||||
jaxpr=new_jaxpr,
|
||||
in_shardings=in_shardings,
|
||||
out_shardings=out_shardings,
|
||||
in_layouts=in_layouts,
|
||||
out_layouts=out_layouts,
|
||||
resource_env=resource_env,
|
||||
donated_invars=donated_invars,
|
||||
name=name,
|
||||
keep_unused=keep_unused,
|
||||
inline=inline)
|
||||
|
||||
resolved_axes_out = batching.resolve_ragged_axes_against_inputs_outputs(
|
||||
vals_in, vals_out, axes_out)
|
||||
return vals_out, resolved_axes_out
|
||||
@ -1773,7 +1872,7 @@ def _pjit_batcher_for_sharding(
|
||||
|
||||
|
||||
def _pjit_jvp(primals_in, tangents_in,
|
||||
jaxpr, in_shardings, out_shardings,
|
||||
jaxpr, in_shardings, out_shardings, in_layouts, out_layouts,
|
||||
resource_env, donated_invars, name, keep_unused, inline):
|
||||
is_nz_tangents_in = [type(t) is not ad.Zero for t in tangents_in]
|
||||
jaxpr_jvp, is_nz_tangents_out = ad.jvp_jaxpr(
|
||||
@ -1788,6 +1887,8 @@ def _pjit_jvp(primals_in, tangents_in,
|
||||
jaxpr=jaxpr_jvp,
|
||||
in_shardings=(*in_shardings, *_filter_zeros_in(in_shardings)),
|
||||
out_shardings=(*out_shardings, *_filter_zeros_out(out_shardings)),
|
||||
in_layouts=(*in_layouts, *_filter_zeros_in(in_layouts)),
|
||||
out_layouts=(*out_layouts, *_filter_zeros_out(out_layouts)),
|
||||
resource_env=resource_env,
|
||||
donated_invars=(*donated_invars, *_filter_zeros_in(donated_invars)),
|
||||
name=name,
|
||||
@ -1813,7 +1914,8 @@ def _known_jaxpr_fwd(known_jaxpr: core.ClosedJaxpr,
|
||||
|
||||
def _pjit_partial_eval(trace, *in_tracers,
|
||||
jaxpr, in_shardings, out_shardings,
|
||||
resource_env, donated_invars, name, keep_unused, inline):
|
||||
in_layouts, out_layouts, resource_env, donated_invars,
|
||||
name, keep_unused, inline):
|
||||
in_pvals = [t.pval for t in in_tracers]
|
||||
|
||||
known_ins = tuple(pv.is_known() for pv in in_pvals)
|
||||
@ -1824,25 +1926,31 @@ def _pjit_partial_eval(trace, *in_tracers,
|
||||
known_outs = tuple(not uk for uk in unknown_outs)
|
||||
num_residuals = len(res_avals)
|
||||
res_shardings = (UNSPECIFIED,) * num_residuals
|
||||
res_layouts = (None,) * num_residuals
|
||||
|
||||
def keep_where(l, should_keep):
|
||||
return tuple(x for x, keep in zip(l, should_keep) if keep)
|
||||
|
||||
known_out_shardings = keep_where(out_shardings, known_outs) + res_shardings
|
||||
known_out_layouts = keep_where(out_layouts, known_outs) + res_layouts
|
||||
|
||||
# Input-to-output forwarding: compute which outputs are just forwarded inputs.
|
||||
num_out_primals = len(known_jaxpr.out_avals) - num_residuals
|
||||
in_fwd: list[int | None] = pe._jaxpr_forwarding(known_jaxpr.jaxpr)
|
||||
# Only forward primal outputs when corresponding out_sharding is UNSPECIFIED.
|
||||
in_fwd_primal, in_fwd_res = split_list(in_fwd, [num_out_primals])
|
||||
in_fwd = [fwd if is_unspecified(os) else None for os, fwd in
|
||||
zip(keep_where(out_shardings, known_outs), in_fwd_primal)
|
||||
] + in_fwd_res
|
||||
in_fwd = [
|
||||
fwd if is_unspecified(os) and ol is None else None
|
||||
for os, ol, fwd in zip(
|
||||
keep_where(out_shardings, known_outs),
|
||||
keep_where(out_layouts, known_outs), in_fwd_primal)
|
||||
] + in_fwd_res
|
||||
del in_fwd_primal, in_fwd_res
|
||||
# Prune jaxpr outputs and out_shardings by removing the input-forwards.
|
||||
keep = [f is None for f in in_fwd]
|
||||
known_jaxpr = pe.prune_closed_jaxpr_outputs(known_jaxpr, keep)
|
||||
known_out_shardings = keep_where(known_out_shardings, keep)
|
||||
known_out_layouts = keep_where(known_out_layouts, keep)
|
||||
# Update num_out_primals to reflect pruning.
|
||||
kept_primals, kept_res = split_list(keep, [num_out_primals])
|
||||
num_out_primals = sum(kept_primals)
|
||||
@ -1856,14 +1964,18 @@ def _pjit_partial_eval(trace, *in_tracers,
|
||||
keep = [f is None for f in out_fwd]
|
||||
known_jaxpr = pe.prune_closed_jaxpr_outputs(known_jaxpr, keep)
|
||||
known_out_shardings = keep_where(known_out_shardings, keep)
|
||||
known_out_layouts = keep_where(known_out_layouts, keep)
|
||||
del keep
|
||||
|
||||
known_params = dict(
|
||||
jaxpr=known_jaxpr, in_shardings=keep_where(in_shardings, known_ins),
|
||||
out_shardings=known_out_shardings, resource_env=resource_env,
|
||||
out_shardings=known_out_shardings,
|
||||
in_layouts=keep_where(in_layouts, known_ins),
|
||||
out_layouts=known_out_layouts, resource_env=resource_env,
|
||||
donated_invars=keep_where(donated_invars, known_ins),
|
||||
name=name, keep_unused=keep_unused, inline=inline)
|
||||
assert len(known_params['out_shardings']) == len(known_params['jaxpr'].out_avals)
|
||||
assert len(known_params['out_layouts']) == len(known_params['jaxpr'].out_avals)
|
||||
|
||||
# Bind known things to pjit_p.
|
||||
known_inputs = [pv.get_known() for pv in in_pvals if pv.is_known()]
|
||||
@ -1888,6 +2000,8 @@ def _pjit_partial_eval(trace, *in_tracers,
|
||||
jaxpr=unknown_jaxpr,
|
||||
in_shardings=(keep_where(in_shardings, unknown_ins) + res_shardings),
|
||||
out_shardings=keep_where(out_shardings, unknown_outs),
|
||||
in_layouts=(keep_where(in_layouts, unknown_ins) + res_layouts),
|
||||
out_layouts=keep_where(out_layouts, unknown_outs),
|
||||
resource_env=resource_env,
|
||||
donated_invars=(keep_where(donated_invars, unknown_ins) +
|
||||
(False,) * num_residuals),
|
||||
@ -1921,28 +2035,41 @@ def _pjit_partial_eval_custom_params_updater(
|
||||
donated_invars_known, _ = pe.partition_list(unks_in, params_known['donated_invars'])
|
||||
in_shardings_known, _ = pe.partition_list(unks_in, params_known['in_shardings'])
|
||||
_, out_shardings_known = pe.partition_list(kept_outs_known, params_known['out_shardings'])
|
||||
in_layouts_known, _ = pe.partition_list(unks_in, params_known['in_layouts'])
|
||||
_, out_layouts_known = pe.partition_list(kept_outs_known, params_known['out_layouts'])
|
||||
|
||||
new_params_known = dict(params_known,
|
||||
in_shardings=tuple(in_shardings_known),
|
||||
out_shardings=(*out_shardings_known,
|
||||
*[UNSPECIFIED] * num_res_out),
|
||||
in_layouts=tuple(in_layouts_known),
|
||||
out_layouts=(*out_layouts_known, *[None] * num_res_out),
|
||||
donated_invars=tuple(donated_invars_known))
|
||||
assert len(new_params_known['in_shardings']) == len(params_known['jaxpr'].in_avals)
|
||||
assert len(new_params_known['out_shardings']) == len(params_known['jaxpr'].out_avals)
|
||||
assert len(new_params_known['in_layouts']) == len(params_known['jaxpr'].in_avals)
|
||||
assert len(new_params_known['out_layouts']) == len(params_known['jaxpr'].out_avals)
|
||||
|
||||
# added num_res new inputs to jaxpr_staged, and pruning according to inst_in
|
||||
_, donated_invars_staged = pe.partition_list(inst_in, params_staged['donated_invars'])
|
||||
donated_invars_staged = [False] * num_res_in + donated_invars_staged
|
||||
_, in_shardings_staged = pe.partition_list(inst_in, params_staged['in_shardings'])
|
||||
in_shardings_staged = [*[UNSPECIFIED] * num_res_in, *in_shardings_staged]
|
||||
|
||||
_, out_shardings_staged = pe.partition_list(kept_outs_staged, params_staged['out_shardings'])
|
||||
_, in_layouts_staged = pe.partition_list(inst_in, params_staged['in_layouts'])
|
||||
in_layouts_staged = [*[None] * num_res_in, *in_layouts_staged]
|
||||
_, out_layouts_staged = pe.partition_list(kept_outs_staged, params_staged['out_layouts'])
|
||||
|
||||
new_params_staged = dict(params_staged,
|
||||
in_shardings=tuple(in_shardings_staged),
|
||||
out_shardings=tuple(out_shardings_staged),
|
||||
in_layouts=tuple(in_layouts_staged),
|
||||
out_layouts=tuple(out_layouts_staged),
|
||||
donated_invars=tuple(donated_invars_staged))
|
||||
assert len(new_params_staged['in_shardings']) == len(params_staged['jaxpr'].in_avals)
|
||||
assert len(new_params_staged['out_shardings']) == len(params_staged['jaxpr'].out_avals)
|
||||
assert len(new_params_staged['in_layouts']) == len(params_staged['jaxpr'].in_avals)
|
||||
assert len(new_params_staged['out_layouts']) == len(params_staged['jaxpr'].out_avals)
|
||||
return new_params_known, new_params_staged
|
||||
|
||||
pe.partial_eval_jaxpr_custom_rules[pjit_p] = \
|
||||
@ -1959,7 +2086,7 @@ def _pjit_transpose_trace(fun, in_avals):
|
||||
|
||||
|
||||
def _pjit_transpose(cts_in, *primals_in,
|
||||
jaxpr, in_shardings, out_shardings,
|
||||
jaxpr, in_shardings, out_shardings, in_layouts, out_layouts,
|
||||
resource_env, donated_invars, name, keep_unused, inline):
|
||||
def prune_type(ty, xs, maybe_zeros):
|
||||
return tuple(x for x, mz in zip(xs, maybe_zeros) if type(mz) is not ty)
|
||||
@ -1973,6 +2100,10 @@ def _pjit_transpose(cts_in, *primals_in,
|
||||
*prune_type(ad.UndefinedPrimal, in_shardings, primals_in),
|
||||
*prune_type(ad.Zero, out_shardings, cts_in)
|
||||
)
|
||||
transpose_in_layouts = (
|
||||
*prune_type(ad.UndefinedPrimal, in_layouts, primals_in),
|
||||
*prune_type(ad.Zero, out_layouts, cts_in)
|
||||
)
|
||||
global_cts_in_avals = tuple(core.raise_to_shaped(core.get_aval(ct))
|
||||
for ct in primals_and_nz_cts_in)
|
||||
|
||||
@ -1983,26 +2114,36 @@ def _pjit_transpose(cts_in, *primals_in,
|
||||
ad.Zero,
|
||||
in_shardings,
|
||||
tree_unflatten(cts_out_treedef, [object()] * cts_out_treedef.num_leaves))
|
||||
transpose_out_layouts = prune_type(
|
||||
ad.Zero,
|
||||
in_layouts,
|
||||
tree_unflatten(cts_out_treedef, [object()] * cts_out_treedef.num_leaves))
|
||||
|
||||
if attrs_tracked:
|
||||
init_states = _get_states(attrs_tracked)
|
||||
primals_and_nz_cts_in = [*init_states, *primals_and_nz_cts_in]
|
||||
transpose_in_shardings = (UNSPECIFIED,) * len(attrs_tracked) + transpose_in_shardings
|
||||
transpose_out_shardings = (UNSPECIFIED,) * len(attrs_tracked) + transpose_out_shardings
|
||||
transpose_in_layouts = (None,) * len(attrs_tracked) + transpose_in_layouts
|
||||
transpose_out_layouts = (None,) * len(attrs_tracked) + transpose_out_layouts
|
||||
|
||||
nz_cts_out = pjit_p.bind(
|
||||
*primals_and_nz_cts_in,
|
||||
jaxpr=transpose_jaxpr,
|
||||
in_shardings=transpose_in_shardings,
|
||||
out_shardings=transpose_out_shardings,
|
||||
in_layouts=transpose_in_layouts,
|
||||
out_layouts=transpose_out_layouts,
|
||||
resource_env=resource_env,
|
||||
donated_invars=(False,) * len(primals_and_nz_cts_in),
|
||||
name=name,
|
||||
keep_unused=keep_unused,
|
||||
inline=inline)
|
||||
|
||||
if attrs_tracked:
|
||||
final_states, nz_cts_out = split_list(nz_cts_out, [len(init_states)])
|
||||
_set_states(attrs_tracked, final_states)
|
||||
|
||||
return tree_unflatten(cts_out_treedef, nz_cts_out)
|
||||
ad.reducing_transposes[pjit_p] = _pjit_transpose
|
||||
|
||||
@ -2029,6 +2170,8 @@ def dce_jaxpr_pjit_rule(used_outputs: list[bool], eqn: core.JaxprEqn
|
||||
jaxpr=dced_jaxpr,
|
||||
in_shardings=keep_where(eqn_params["in_shardings"], used_inputs),
|
||||
out_shardings=keep_where(eqn_params["out_shardings"], used_outputs),
|
||||
in_layouts=keep_where(eqn_params["in_layouts"], used_inputs),
|
||||
out_layouts=keep_where(eqn_params["out_layouts"], used_outputs),
|
||||
donated_invars=keep_where(eqn_params["donated_invars"], used_inputs),
|
||||
)
|
||||
if not any(used_inputs) and not any(used_outputs) and not dced_jaxpr.effects:
|
||||
@ -2112,6 +2255,10 @@ def _pjit_pp_rule(eqn, context, settings):
|
||||
del params['in_shardings']
|
||||
if all(is_unspecified(s) for s in params['out_shardings']):
|
||||
del params['out_shardings']
|
||||
if all(l is None for l in params['in_layouts']):
|
||||
del params['in_layouts']
|
||||
if all(l is None for l in params['out_layouts']):
|
||||
del params['out_layouts']
|
||||
if not params['keep_unused']:
|
||||
del params['keep_unused']
|
||||
if (params['resource_env'] is None or
|
||||
@ -2126,18 +2273,28 @@ core.pp_eqn_rules[pjit_p] = _pjit_pp_rule
|
||||
|
||||
|
||||
def _pjit_state_discharge_rule(
|
||||
in_avals, out_avals, *args, jaxpr, in_shardings, out_shardings, **params):
|
||||
in_avals, out_avals, *args, jaxpr, in_shardings, out_shardings,
|
||||
in_layouts, out_layouts, **params):
|
||||
if not (all(map(is_unspecified, in_shardings)) and
|
||||
all(map(is_unspecified, out_shardings))): raise NotImplementedError
|
||||
all(map(is_unspecified, out_shardings))):
|
||||
raise NotImplementedError
|
||||
|
||||
if not (all(l is None for l in in_layouts) and
|
||||
all(l is None for l in out_layouts)):
|
||||
raise NotImplementedError
|
||||
|
||||
jaxpr, consts = jaxpr.jaxpr, jaxpr.consts
|
||||
num_outs = len(jaxpr.outvars)
|
||||
discharged_jaxpr, discharged_consts = state_discharge.discharge_state(jaxpr, consts)
|
||||
discharged_closed_jaxpr = core.ClosedJaxpr(discharged_jaxpr, discharged_consts)
|
||||
new_in_shardings = (UnspecifiedValue(),) * len(discharged_jaxpr.invars)
|
||||
new_out_shardings = (UnspecifiedValue(),) * len(discharged_jaxpr.outvars)
|
||||
new_in_layouts = (None,) * len(discharged_jaxpr.invars)
|
||||
new_out_layouts = (None,) * len(discharged_jaxpr.outvars)
|
||||
out_and_ref_vals = pjit_p.bind(
|
||||
*args, jaxpr=discharged_closed_jaxpr, in_shardings=new_in_shardings,
|
||||
out_shardings=new_out_shardings, **params)
|
||||
out_shardings=new_out_shardings, in_layouts=new_in_layouts,
|
||||
out_layouts=new_out_layouts, **params)
|
||||
out_vals, ref_vals = split_list(out_and_ref_vals, [num_outs])
|
||||
ref_vals_iter = iter(ref_vals)
|
||||
new_invals = tuple(next(ref_vals_iter) if isinstance(aval, state_discharge.AbstractRef)
|
||||
|
@ -88,12 +88,10 @@ class Executable(Protocol):
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
# Layouts are exposed via jax.experimental.layouts
|
||||
# TODO(frostig,yashkatariya): expose here when no longer experimental.
|
||||
def _input_layouts(self):
|
||||
def input_layouts(self):
|
||||
raise NotImplementedError
|
||||
|
||||
def _output_layouts(self):
|
||||
def output_layouts(self):
|
||||
raise NotImplementedError
|
||||
|
||||
def as_text(self) -> str:
|
||||
@ -228,11 +226,11 @@ class XlaExecutable(Executable):
raise NotImplementedError(
"compiled executable carries no output sharding information")

def _input_layouts(self):
def input_layouts(self):
raise NotImplementedError(
"compiled executable carries no input layout information")

def _output_layouts(self):
def output_layouts(self):
raise NotImplementedError(
"compiled executable carries no output layout information")
@ -511,7 +509,7 @@ class Compiled(Stage):
shardings_flat = self._executable.output_shardings()
return tree_util.tree_unflatten(self.out_tree, shardings_flat) # pytype: disable=attribute-error

def _input_layouts(self):
def input_layouts(self):
layouts_flat = self._executable.input_layouts()
assert all(isinstance(l, Layout) for l in layouts_flat)
# Some input layouts got DCE'd
@ -521,7 +519,7 @@ class Compiled(Stage):
else Layout() for i in range(self.in_tree.num_leaves)]
return tree_util.tree_unflatten(self.in_tree, layouts_flat) # pytype: disable=attribute-error

def _output_layouts(self):
def output_layouts(self):
layouts_flat = self._executable.output_layouts()
assert all(isinstance(l, Layout) for l in layouts_flat)
return tree_util.tree_unflatten(self.out_tree, layouts_flat) # pytype: disable=attribute-error
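A usage sketch of the renamed accessors above (the exact experimental import path for `Layout` and `DeviceLocalLayout` is an assumption here; the tests further down import them as `Layout` and `DLL`):

import numpy as np
import jax
from jax.experimental.layout import Layout, DeviceLocalLayout as DLL  # assumed path

def f(x):
  return x.T

x = np.arange(8.0).reshape(4, 2)

# Let the compiler pick layouts, then inspect what it chose on the compiled object.
compiled = jax.jit(f, in_shardings=Layout(DLL.AUTO),
                   out_shardings=Layout(DLL.AUTO)).lower(x).compile()
arg_layouts, kwarg_layouts = compiled.input_layouts()   # was _input_layouts()
out_layout_tree = compiled.output_layouts()             # was _output_layouts()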
@ -1643,6 +1643,8 @@ def _rewrite_eqn(eqn: core.JaxprEqn, eqns: list[core.JaxprEqn],
eqn.params["out_shardings"]
+ (sharding_impls.UNSPECIFIED, sharding_impls.UNSPECIFIED)
),
in_layouts=(eqn.params["in_layouts"] + (None, None)),
out_layouts=(eqn.params["out_layouts"] + (None, None)),
),
)
)

@ -3504,6 +3504,7 @@ def _pjit(*args: TfVal,
jaxpr: core.ClosedJaxpr,
in_shardings: Sequence[sharding.XLACompatibleSharding],
out_shardings: Sequence[sharding.XLACompatibleSharding],
in_layouts, out_layouts,
resource_env: mesh.ResourceEnv,
donated_invars,
name: str,
@ -739,6 +739,8 @@ def _pjit_jet_rule(primals_in, series_in, **params):
params['out_shardings']
+ (sharding_impls.UNSPECIFIED,) * num_series_out
),
'in_layouts': params['in_layouts'] + (None,) * num_series_in,
'out_layouts': params['out_layouts'] + (None,) * num_series_out,
'donated_invars': params['donated_invars'] + (False,) * num_series_in,
}
result = pjit.pjit_p.bind(*primals_and_series, **new_params)

@ -772,7 +772,8 @@ sparse_rules_bcoo[lax.while_p] = _while_sparse


def _pjit_sparse(spenv, *spvalues, jaxpr, in_shardings, out_shardings,
resource_env, donated_invars, name, keep_unused, inline):
in_layouts, out_layouts, resource_env, donated_invars, name,
keep_unused, inline):
if any(donated_invars):
raise NotImplementedError("sparse xla_call with donated_invars")

@ -790,12 +791,20 @@ def _pjit_sparse(spenv, *spvalues, jaxpr, in_shardings, out_shardings,
sharding_impls.UNSPECIFIED
for _ in range(len(sp_call_jaxpr.out_avals) - len(out_shardings))
)
in_layouts = in_layouts + tuple(
None for _ in range(len(args_flat) - len(in_layouts))
)
out_layouts = out_layouts + tuple(
None for _ in range(len(sp_call_jaxpr.out_avals) - len(out_layouts))
)

out_flat = pjit.pjit_p.bind(
*args_flat,
jaxpr=sp_call_jaxpr,
in_shardings=in_shardings,
out_shardings=out_shardings,
in_layouts=in_layouts,
out_layouts=out_layouts,
resource_env=resource_env,
donated_invars=donated_invars,
name=name,
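The interior-rule updates above all follow the same convention: whenever a rule adds inputs or outputs to a `pjit_p.bind` call, the `in_layouts` and `out_layouts` tuples are padded with `None` (meaning "no layout constraint") so they stay the same length as the corresponding avals. A hypothetical helper distilling that pattern (not part of this change):

def _pad_layout_params(params, num_extra_in, num_extra_out):
  # Illustration only: extra inputs/outputs get None layouts, matching the
  # jet, sparse, and jax2tf rule updates above.
  return dict(
      params,
      in_layouts=params['in_layouts'] + (None,) * num_extra_in,
      out_layouts=params['out_layouts'] + (None,) * num_extra_out,
  )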
@ -451,8 +451,8 @@ class JaxExportTest(jtu.JaxTestCase):
self.assertIn("jax.uses_shape_polymorphism = true", module_str)
wrapped_main_expected_re = (
r"@_wrapped_jax_export_main\("
r"%arg0: tensor<i..> {jax.global_constant = \"h\"}.*"
r"%arg1: tensor<i..> {jax.global_constant = \"w\"}.*"
r"%arg0: tensor<i..> {jax.global_constant = \"h\".*"
r"%arg1: tensor<i..> {jax.global_constant = \"w\".*"
r"%arg2: tensor<\?x\?xf32>"
)
self.assertRegex(module_str, wrapped_main_expected_re)

@ -1238,12 +1238,12 @@ class JaxExportTest(jtu.JaxTestCase):
mlir_module_str = str(exp.mlir_module())
wrapped_main_expected_re = (
r"@_wrapped_jax_export_main\("
r"%arg0: tensor<i..> {jax.global_constant = \"b1\"}.*, "
r"%arg1: tensor<i..> {jax.global_constant = \"b2\"}.*, "
r"%arg2: !stablehlo.token {jax.token = true}.*, "
r"%arg0: tensor<i..> {jax.global_constant = \"b1\".* "
r"%arg1: tensor<i..> {jax.global_constant = \"b2\".* "
r"%arg2: !stablehlo.token {jax.token = true.* "
r"%arg3: tensor<\?x\?xf32>.*\) -> \("
# Results
r"!stablehlo.token {jax.token = true}, tensor<\?x\?xf32>.*\)")
r"!stablehlo.token {jax.token = true.*, tensor<\?x\?xf32>.*\)")
if exp.mlir_module_serialization_version < _export._VERSION_START_SUPPORT_EFFECTS_WITH_REAL_TOKENS:
wrapped_main_expected_re = wrapped_main_expected_re.replace("!stablehlo.token", "tensor<0xi1>")
self.assertRegex(mlir_module_str, wrapped_main_expected_re)

@ -1254,10 +1254,10 @@ class JaxExportTest(jtu.JaxTestCase):
else:
main_expected_re = (
r"@main\("
r"%arg0: !stablehlo.token {jax.token = true}.*, "
r"%arg0: !stablehlo.token {jax.token = true.*, "
r"%arg1: tensor<\?x\?xf32>.*\) -> \("
# Results
r"!stablehlo.token {jax.token = true}, tensor<\?x\?xf32>.*\)")
r"!stablehlo.token {jax.token = true.*, tensor<\?x\?xf32>.*\)")
self.assertRegex(mlir_module_str, main_expected_re)

res = export.call_exported(exp)(x)

@ -1284,13 +1284,13 @@ class JaxExportTest(jtu.JaxTestCase):
mlir_module_str = str(exp.mlir_module())
wrapped_main_expected_re = (
r"@_wrapped_jax_export_main\("
r"%arg0: tensor<i..> {jax.global_constant = \"_platform_index\"}.*, "
r"%arg1: tensor<i..> {jax.global_constant = \"b1\"}.*, "
r"%arg2: tensor<i..> {jax.global_constant = \"b2\"}.*, "
r"%arg3: !stablehlo.token {jax.token = true}.*, "
r"%arg0: tensor<i..> {jax.global_constant = \"_platform_index\".*, "
r"%arg1: tensor<i..> {jax.global_constant = \"b1\".*, "
r"%arg2: tensor<i..> {jax.global_constant = \"b2\".*, "
r"%arg3: !stablehlo.token {jax.token = true.*, "
r"%arg4: tensor<\?x\?xf32>.*\) -> \("
# Results
r"!stablehlo.token {jax.token = true}, tensor<\?x\?xf32>.*\)")
r"!stablehlo.token {jax.token = true.*, tensor<\?x\?xf32>.*\)")
if exp.mlir_module_serialization_version < _export._VERSION_START_SUPPORT_EFFECTS_WITH_REAL_TOKENS:
wrapped_main_expected_re = wrapped_main_expected_re.replace("!stablehlo.token", "tensor<0xi1>")
self.assertRegex(mlir_module_str, wrapped_main_expected_re)

@ -1301,11 +1301,11 @@ class JaxExportTest(jtu.JaxTestCase):
else:
main_expected_re = (
r"@main\("
r"%arg0: tensor<i..> {jax.global_constant = \"_platform_index\"}.*, "
r"%arg1: !stablehlo.token {jax.token = true}.*, "
r"%arg0: tensor<i..> {jax.global_constant = \"_platform_index\".*, "
r"%arg1: !stablehlo.token {jax.token = true.*, "
r"%arg2: tensor<\?x\?xf32>.*\) -> \("
# Results
r"!stablehlo.token {jax.token = true}, tensor<\?x\?xf32>.*\)")
r"!stablehlo.token {jax.token = true.*, tensor<\?x\?xf32>.*\)")
self.assertRegex(mlir_module_str, main_expected_re)
res = export.call_exported(exp)(x)
self.assertAllClose(10. + _testing_multi_platform_fun_expected(x),
@ -88,22 +88,22 @@ class LayoutTest(jtu.JaxTestCase):
sds1 = jax.ShapeDtypeStruct(np_inp1.shape, np_inp1.dtype, sharding=s1)
sds2 = jax.ShapeDtypeStruct(np_inp2.shape, np_inp2.dtype, sharding=s2)

lowered_apply = jax.jit(apply).lower(
sds1, sds2, _in_layouts=Layout(DLL.AUTO), _out_layouts=Layout(DLL.AUTO))
lowered_apply = jax.jit(apply, in_shardings=Layout(DLL.AUTO),
out_shardings=Layout(DLL.AUTO)).lower(sds1, sds2)
compiled_apply = lowered_apply.compile()

arg_layouts, kw_layouts = compiled_apply._input_layouts()
arg_layouts, kw_layouts = compiled_apply.input_layouts()
self.assertEmpty(kw_layouts)

for i, o in zip(arg_layouts, compiled_apply._output_layouts()):
for i, o in zip(arg_layouts, compiled_apply.output_layouts()):
self.assertEqual(extract_minor_to_major(i),
extract_minor_to_major(o)[::-1])

init_compiled = jax.jit(init).lower(
sds1, sds2, _out_layouts=arg_layouts).compile()
init_compiled = jax.jit(
init, out_shardings=arg_layouts).lower(sds1, sds2).compile()

for i, o in zip(init_compiled._input_layouts()[0],
init_compiled._output_layouts()):
for i, o in zip(init_compiled.input_layouts()[0],
init_compiled.output_layouts()):
self.assertEqual(i, o)

arr1 = jax.device_put(np_inp1, s1)

@ -114,16 +114,16 @@ class LayoutTest(jtu.JaxTestCase):
init_compiled(arr1, arr2)
self.assertEqual(init_count[0], 1)

self.assertEqual(init_out[0].layout, init_compiled._output_layouts()[0])
self.assertEqual(init_out[1].layout, init_compiled._output_layouts()[1])
self.assertEqual(init_out[0].layout, init_compiled.output_layouts()[0])
self.assertEqual(init_out[1].layout, init_compiled.output_layouts()[1])

with jtu.count_aot_jit_cpp_cache_miss() as apply_count:
apply_out = compiled_apply(*init_out)
compiled_apply(*init_out)
self.assertEqual(apply_count[0], 1)

self.assertEqual(apply_out[0].layout, compiled_apply._output_layouts()[0])
self.assertEqual(apply_out[1].layout, compiled_apply._output_layouts()[1])
self.assertEqual(apply_out[0].layout, compiled_apply.output_layouts()[0])
self.assertEqual(apply_out[1].layout, compiled_apply.output_layouts()[1])

self.assertTupleEqual(extract_minor_to_major(apply_out[0].layout),
extract_minor_to_major(init_out[0].layout)[::-1])
@ -146,24 +146,29 @@ class LayoutTest(jtu.JaxTestCase):
def f(x):
return x.T

lowered = jax.jit(f).lower(sds, _in_layouts=None, _out_layouts=None)
lowered = jax.jit(f, in_shardings=None, out_shardings=None).lower(sds)
self.assertIn("default", lowered.as_text())
compiled = lowered.compile()
out = compiled(arr)

self.assertTupleEqual(
extract_minor_to_major(compiled._input_layouts()[0][0]), (2, 1, 0))
extract_minor_to_major(compiled.input_layouts()[0][0]), (2, 1, 0))
self.assertTupleEqual(
extract_minor_to_major(compiled._output_layouts()), (2, 1, 0))
extract_minor_to_major(compiled.output_layouts()), (2, 1, 0))
self.assertArraysEqual(out, np_inp.T)
self.assertEqual(out.sharding, NamedSharding(mesh, P(None, 'y', 'x')))

compiled_auto = jax.jit(f).lower(sds, _in_layouts=Layout(DLL.AUTO),
_out_layouts=Layout(DLL.AUTO)).compile()
compiled_auto = jax.jit(f, in_shardings=Layout(DLL.AUTO),
out_shardings=Layout(DLL.AUTO)).lower(sds).compile()
self.assertTupleEqual(
extract_minor_to_major(compiled_auto._input_layouts()[0][0]), (2, 1, 0))
extract_minor_to_major(compiled_auto.input_layouts()[0][0]), (2, 1, 0))
self.assertTupleEqual(
extract_minor_to_major(compiled_auto._output_layouts()), (0, 1, 2))
extract_minor_to_major(compiled_auto.output_layouts()), (0, 1, 2))

with self.assertRaisesRegex(
ValueError, "jax.jit` does not accept device-local layouts directly"):
jax.jit(f, in_shardings=DLL.AUTO,
out_shardings=DLL.AUTO).lower(sds).compile()

def test_in_layouts_out_layouts(self):
mesh = jtu.create_global_mesh((2, 2), ('x', 'y'))
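The error path exercised just above is worth calling out: `jax.jit` accepts `DeviceLocalLayout.AUTO` only when it is wrapped in a `Layout`; a bare device-local layout is rejected. A short sketch under the same assumed imports as earlier:

import numpy as np
import jax
from jax.experimental.layout import Layout, DeviceLocalLayout as DLL  # assumed path

x = np.arange(6.0).reshape(2, 3)
f = lambda a: a.T

# Accepted: the device-local layout is wrapped in a Layout.
jax.jit(f, in_shardings=Layout(DLL.AUTO),
        out_shardings=Layout(DLL.AUTO)).lower(x).compile()

# Rejected, per the assertRaisesRegex above (raises ValueError once lowered/compiled):
# jax.jit(f, in_shardings=DLL.AUTO, out_shardings=DLL.AUTO).lower(x).compile()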
@ -175,16 +180,16 @@ class LayoutTest(jtu.JaxTestCase):
def f(x):
return x.T

compiled = jax.jit(f).lower(
arr, _in_layouts=Layout(), _out_layouts=Layout(DLL.AUTO)).compile()
compiled = jax.jit(f, in_shardings=Layout(),
out_shardings=Layout(DLL.AUTO)).lower(arr).compile()
self.assertTupleEqual(
extract_minor_to_major(compiled._input_layouts()[0][0]), (1, 0))
extract_minor_to_major(compiled.input_layouts()[0][0]), (1, 0))
self.assertTupleEqual(
extract_minor_to_major(compiled._output_layouts()), (0, 1))
extract_minor_to_major(compiled.output_layouts()), (0, 1))

out = compiled(arr)
self.assertArraysEqual(out, np_inp.T)
self.assertEqual(out.layout, compiled._output_layouts())
self.assertEqual(out.layout, compiled.output_layouts())
self.assertEqual(out.sharding, NamedSharding(mesh, P('y', 'x')))

def test_sharding_and_layouts(self):
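To read the minor-to-major assertions in this test: the tuple lists dimensions from fastest-varying to slowest-varying in memory, so (1, 0) is the usual row-major layout of a 2-D array and (0, 1) is its column-major counterpart. For the transpose above, an input laid out as (1, 0) and an output laid out as (0, 1) describe the same bytes, which is why the compiler picks it when given DLL.AUTO. A small numpy analogy (illustration only, not the Layout API):

import numpy as np

a = np.arange(6.0).reshape(2, 3)   # C order, i.e. minor-to-major (1, 0)
t = a.T                            # same buffer viewed in Fortran order, i.e. (0, 1)
assert t.flags['F_CONTIGUOUS'] and not t.flags['C_CONTIGUOUS']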
@ -193,14 +198,13 @@ class LayoutTest(jtu.JaxTestCase):
np_inp = np.arange(math.prod(shape)).reshape(shape)
s = NamedSharding(mesh, P('x', 'y'))

compiled = jax.jit(lambda x: x.T, in_shardings=s, out_shardings=s).lower(
np_inp, _in_layouts=Layout(DLL.AUTO),
_out_layouts=Layout(DLL.AUTO)).compile()
compiled = jax.jit(lambda x: x.T, in_shardings=Layout(DLL.AUTO, s),
out_shardings=Layout(DLL.AUTO, s)).lower(np_inp).compile()
out = compiled(np_inp)
self.assertTupleEqual(
extract_minor_to_major(compiled._input_layouts()[0][0]), (1, 0))
extract_minor_to_major(compiled.input_layouts()[0][0]), (1, 0))
self.assertTupleEqual(
extract_minor_to_major(compiled._output_layouts()), (0, 1))
extract_minor_to_major(compiled.output_layouts()), (0, 1))
self.assertArraysEqual(out, np_inp.T)
self.assertEqual(out.sharding, s)
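The test above also shows the two halves of the public API being combined: `Layout(DLL.AUTO, s)` pins the sharding `s` while leaving the device-local layout to the compiler. A self-contained sketch on a single-device mesh (same assumed import path as earlier):

import numpy as np
import jax
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P
from jax.experimental.layout import Layout, DeviceLocalLayout as DLL  # assumed path

mesh = Mesh(np.array(jax.devices()[:1]).reshape(1, 1), ('x', 'y'))
s = NamedSharding(mesh, P('x', 'y'))
x = np.arange(8.0).reshape(4, 2)

# Sharding is fixed to s; the device-local layout is left for the compiler to choose.
compiled = jax.jit(lambda a: a.T,
                   in_shardings=Layout(DLL.AUTO, s),
                   out_shardings=Layout(DLL.AUTO, s)).lower(x).compile()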
@ -210,15 +214,15 @@ class LayoutTest(jtu.JaxTestCase):

shape = (8, 2)
inps = [np.arange(math.prod(shape)).reshape(shape)] * 6
compiled = jax.jit(f).lower(*inps, _in_layouts=Layout(DLL.AUTO),
_out_layouts=Layout(DLL.AUTO)).compile()
arg_layouts, _ = compiled._input_layouts()
compiled = jax.jit(f, in_shardings=Layout(DLL.AUTO),
out_shardings=Layout(DLL.AUTO)).lower(*inps).compile()
arg_layouts, _ = compiled.input_layouts()
out1, out2 = compiled(*inps)

compiled2 = jax.jit(f).lower(*inps, _in_layouts=arg_layouts).compile()
compiled2 = jax.jit(f, in_shardings=arg_layouts).lower(*inps).compile()
out3, out4 = compiled2(*inps)

for l1, l2 in safe_zip(arg_layouts, compiled2._input_layouts()[0]):
for l1, l2 in safe_zip(arg_layouts, compiled2.input_layouts()[0]):
self.assertEqual(l1, l2)

self.assertArraysEqual(out1, out3)
@ -244,11 +248,10 @@ class LayoutTest(jtu.JaxTestCase):
with self.assertRaisesRegex(
ValueError,
'Layout passed to jit does not match the layout on the respective arg'):
jax.jit(f).lower(arr, _in_layouts=Layout(DLL.AUTO))
jax.jit(f, in_shardings=Layout(DLL.AUTO)).lower(arr)

compiled = jax.jit(f).lower(
sds, _in_layouts=Layout(DLL.AUTO),
_out_layouts=Layout(DLL.AUTO)).compile()
compiled = jax.jit(f, in_shardings=Layout(DLL.AUTO),
out_shardings=Layout(DLL.AUTO)).lower(sds).compile()

with self.assertRaisesRegex(
ValueError,
@ -271,8 +274,8 @@ class LayoutTest(jtu.JaxTestCase):
arr = jax.device_put(np_inp, s)

compiled = jax.jit(
lambda x: x * 2).lower(arr, _out_layouts=Layout(DLL.AUTO)).compile()
col = compiled._output_layouts()
lambda x: x * 2, out_shardings=Layout(DLL.AUTO)).lower(arr).compile()
col = compiled.output_layouts()

out = jax.device_put(np_inp, col)
self.assertEqual(out.layout, col)
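The snippet above doubles as a simple recipe: once a function is compiled with `out_shardings=Layout(DLL.AUTO)`, the layout it chose can be handed straight to `jax.device_put` so data is materialized in that layout up front. Sketch, same assumed imports:

import numpy as np
import jax
from jax.experimental.layout import Layout, DeviceLocalLayout as DLL  # assumed path

x = np.arange(8.0).reshape(4, 2)
compiled = jax.jit(lambda a: a * 2,
                   out_shardings=Layout(DLL.AUTO)).lower(x).compile()
chosen = compiled.output_layouts()

y = jax.device_put(x, chosen)   # per the assertion above, y.layout == chosen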
@ -304,7 +307,7 @@ class LayoutTest(jtu.JaxTestCase):
compiled = jax.jit(lambda x: x).lower(x).compile()
with self.assertRaisesRegex(
ValueError, 'Sharding has to be concrete when layout.*'):
Layout(compiled._output_layouts()[0], None)
Layout(compiled.output_layouts()[0], None)


if __name__ == '__main__':
@ -1500,7 +1500,7 @@ class PallasCallInputOutputAliasingTest(PallasTPUTest):
)(x)
o = f(x)
np.testing.assert_array_equal(o, expected)
compiled = f.lower(x).compile()
compiled = f.lower(jax.ShapeDtypeStruct(x.shape, x.dtype)).compile()
mem_analysis = compiled.memory_analysis()
expected_num_bytes = np.prod(x.shape) * x.dtype.itemsize
self.assertEqual(mem_analysis.alias_size_in_bytes, expected_num_bytes)

@ -1528,7 +1528,7 @@ class PallasCallInputOutputAliasingTest(PallasTPUTest):
)(jnp.array([1,2,3]), x)
o = f(x)
np.testing.assert_array_equal(o, expected)
compiled = f.lower(x).compile()
compiled = f.lower(jax.ShapeDtypeStruct(x.shape, x.dtype)).compile()
mem_analysis = compiled.memory_analysis()
expected_num_bytes = np.prod(x.shape) * x.dtype.itemsize
self.assertEqual(mem_analysis.alias_size_in_bytes, expected_num_bytes)
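One last note on the Pallas test tweaks above: `lower` now takes an abstract `jax.ShapeDtypeStruct` rather than the concrete array, so the AOT memory analysis is computed from shape and dtype alone. A minimal sketch with a hypothetical stand-in function:

import jax
import jax.numpy as jnp

f = jax.jit(lambda a: a + 1)     # stand-in for the pallas_call under test
x = jnp.ones((8, 128), jnp.float32)

compiled = f.lower(jax.ShapeDtypeStruct(x.shape, x.dtype)).compile()
mem = compiled.memory_analysis()  # available without a live input buffer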