diff --git a/jax/_src/api.py b/jax/_src/api.py
index a45812444..03775602c 100644
--- a/jax/_src/api.py
+++ b/jax/_src/api.py
@@ -282,10 +282,36 @@ def jit(
   return _jit(False, fun, static_argnums, static_argnames, device, backend,
               donate_argnums, inline, keep_unused, abstracted_axes)
 
-# TODO(yashkatariya): Remove the above jit function after
-# `jax_jit_pjit_api_merge` defaults to True.
+
 if jax.config.jax_jit_pjit_api_merge:
-  jit = pjit.pjit  # type: ignore  # noqa: F811
+  def jit(  # type: ignore  # noqa: F811  # pylint: disable=function-redefined
+      fun: Callable,
+      in_axis_resources=pxla._UNSPECIFIED,
+      out_axis_resources=pxla._UNSPECIFIED,
+      static_argnums: Union[int, Sequence[int], None] = None,
+      static_argnames: Union[str, Iterable[str], None] = None,
+      donate_argnums: Union[int, Sequence[int]] = (),
+      keep_unused: bool = False,
+      device: Optional[xc.Device] = None,
+      backend: Optional[str] = None,
+      inline: bool = False,
+  ) -> stages.Wrapped:
+    (in_axis_resources, out_axis_resources, donate_argnums, static_argnums,
+     static_argnames) = pjit.pre_infer_params(
+         fun, in_axis_resources, out_axis_resources, donate_argnums,
+         static_argnums, static_argnames, device, backend)
+
+    def infer_params(*args, **kwargs):
+      pjit_info_args = pjit.PjitInfo(
+          fun=fun, in_axis_resources=in_axis_resources,
+          out_axis_resources=out_axis_resources, static_argnums=static_argnums,
+          static_argnames=static_argnames, donate_argnums=donate_argnums,
+          device=device, backend=backend, keep_unused=keep_unused,
+          inline=inline, resource_env=None)
+      return pjit.common_infer_params(pjit_info_args, *args, **kwargs)
+
+    return pjit.post_infer_params(fun, infer_params, static_argnums,
+                                  static_argnames)
 
 
 def _jit(
diff --git a/jax/_src/checkify.py b/jax/_src/checkify.py
index 201f23cf1..4d348a84f 100644
--- a/jax/_src/checkify.py
+++ b/jax/_src/checkify.py
@@ -855,8 +855,12 @@ def pjit_error_check(error, enabled_errors, *vals_in, jaxpr,
   # Update pjit params to account for extra error values.
   num_error_vals = len(err_vals)
   num_out_error_vals = out_tree.num_leaves - len(out_shardings)
-  sharding = OpShardingSharding.get_replicated(
-      list(resource_env.physical_mesh.devices.flat))
+  if jax.config.jax_array:
+    sharding = pjit._UNSPECIFIED
+  else:
+    sharding = OpShardingSharding.get_replicated(
+        list(resource_env.physical_mesh.devices.flat))
+
   new_in_shardings = (*[sharding] * num_error_vals, *in_shardings)
   new_out_shardings = (*[sharding] * num_out_error_vals, *out_shardings)
 
diff --git a/jax/_src/pjit.py b/jax/_src/pjit.py
index f0453d2fc..61e20a109 100644
--- a/jax/_src/pjit.py
+++ b/jax/_src/pjit.py
@@ -17,7 +17,7 @@ from enum import IntEnum
 import numpy as np
 from collections import OrderedDict, Counter
 from typing import (Callable, Sequence, Tuple, Union, cast, List, Optional,
-                    Iterable)
+                    Iterable, NamedTuple, Any)
 import itertools as it
 from functools import partial, lru_cache
 import threading
@@ -170,7 +170,237 @@ def _cpp_pjit(fun: Callable, infer_params, static_argnums, static_argnames):
   return wraps(fun)(cpp_pjit_f)
 
 
-# TODO(yashkatariya): Add pjit microbenchmarks.
+def pre_infer_params(fun, in_axis_resources, out_axis_resources,
+                     donate_argnums, static_argnums, static_argnames, device,
+                     backend):
+  check_callable(fun)
+
+  if not config.jax_array and (_is_unspecified(in_axis_resources) or
+                               _is_unspecified(out_axis_resources)):
+    raise ValueError(
+        "in_axis_resources and out_axis_resources should not "
+        "be the unspecified singleton value. Please enable `jax.Array` to use "
+        "this feature. You can use jax.config.update('jax_array', True) or "
+        "set the environment variable JAX_ARRAY=1 , or set the `jax_array` "
+        "boolean flag to something true-like.")
+
+  if backend is not None or device is not None:
+    warnings.warn(
+        'backend and device argument on jit is deprecated. You can use a '
+        '`jax.sharding.Mesh` context manager or device_put the arguments '
+        'before passing them to `jit`. Please see '
+        'https://jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html '
+        'for more information.', DeprecationWarning)
+    if device is not None and backend is not None:
+      raise ValueError("can't specify both a device and a backend for jit, "
+                       f"got {device=} and {backend=}")
+    if not _is_unspecified(in_axis_resources):
+      raise ValueError('If backend or device is specified on jit, then '
+                       'in_axis_resources should not be specified.')
+    if not _is_unspecified(out_axis_resources):
+      raise ValueError('If backend or device is specified on jit, then '
+                       'out_axis_resources should not be specified.')
+
+  if isinstance(in_axis_resources, list):
+    # To be a tree prefix of the positional args tuple, in_axes can never be a
+    # list: if in_axes is not a leaf, it must be a tuple of trees. However,
+    # in cases like these users expect tuples and lists to be treated
+    # essentially interchangeably, so we canonicalize lists to tuples here
+    # rather than raising an error. https://github.com/google/jax/issues/2367
+    in_axis_resources = tuple(in_axis_resources)
+
+  in_axis_resources, _, _ = _prepare_axis_resources(
+      in_axis_resources, "in_axis_resources")
+  out_axis_resources, _, _ = _prepare_axis_resources(
+      out_axis_resources, "out_axis_resources")
+
+  donate_argnums, static_argnums, static_argnames = resolve_argnums(
+      fun, donate_argnums, static_argnums, static_argnames)
+
+  return (in_axis_resources, out_axis_resources, donate_argnums, static_argnums,
+          static_argnames)
+
+
+def post_infer_params(fun, infer_params, static_argnums, static_argnames):
+  if FLAGS.experimental_cpp_pjit and xla_extension_version >= 115:
+    wrapped = _cpp_pjit(fun, infer_params, static_argnums, static_argnames)
+  else:
+    wrapped = _python_pjit(fun, infer_params)
+
+  def lower(*args, **kwargs):
+    (args_flat, flat_local_in_avals, params, in_tree, out_tree,
+     donate_argnums) = infer_params(*args, **kwargs)
+    if config.jax_array:
+      resource_env = params['resource_env']
+      mesh = None if resource_env is None else resource_env.physical_mesh
+      in_shardings = _resolve_in_shardings(
+          args_flat, params['in_shardings'], params['out_shardings'], mesh)
+    else:
+      in_shardings = params['in_shardings']
+    in_is_global = _calc_is_global_sequence(
+        params['in_positional_semantics'], in_shardings)
+    lowering = _pjit_lower(
+        params['jaxpr'], in_shardings, params['out_shardings'],
+        params['resource_env'], params['donated_invars'], params['name'],
+        in_is_global, params['keep_unused'], always_lower=True)
+
+    if kwargs:
+      args_kwargs_in_tree = in_tree
+      local_in_avals = in_tree.unflatten(flat_local_in_avals)
+    else:
+      args_kwargs_in_tree = treedef_tuple([in_tree, tree_flatten({})[1]])
+      local_in_avals = args_kwargs_in_tree.unflatten(flat_local_in_avals)
+
+    return stages.Lowered.from_flat_info(
+        lowering,
+        args_kwargs_in_tree,
+        local_in_avals,
+        donate_argnums,
+        out_tree,
+        no_kwargs=True)
+
+  wrapped.lower = lower
+  return wrapped
+
+
+class PjitInfo(NamedTuple):
+  fun: Callable
+  in_axis_resources: Any
+  out_axis_resources: Any
+  static_argnums: Tuple[int, ...]
+  static_argnames: Tuple[str, ...]
+  donate_argnums: Tuple[int, ...]
+  device: Optional[xc.Device]
+  backend: Optional[str]
+  keep_unused: bool
+  inline: bool
+  resource_env: Any
+
+
+def common_infer_params(pjit_info_args, *args, **kwargs):
+  (fun, in_axis_resources, out_axis_resources, static_argnums, static_argnames,
+   donate_argnums, device, backend, keep_unused, inline,
+   resource_env) = pjit_info_args
+
+  if kwargs and not _is_unspecified(in_axis_resources):
+    raise ValueError(
+        "pjit does not support kwargs when in_axis_resources is specified.")
+
+  if resource_env is not None:
+    pjit_mesh = resource_env.physical_mesh
+    if pjit_mesh.empty:
+      if config.jax_array:
+        # Don't enforce requiring a mesh when `jax_array` flag is enabled. But
+        # if mesh is not empty then pjit will respect it.
+        pass
+      else:
+        raise RuntimeError("pjit requires a non-empty mesh! Are you sure that "
+                           "it's defined at the call site?")
+  else:
+    pjit_mesh = None
+
+  if (backend or device) and pjit_mesh is not None and not pjit_mesh.empty:
+    raise ValueError(
+        "Mesh context manager should not be used with jit when backend or "
+        "device is also specified as an argument to jit.")
+
+  f = lu.wrap_init(fun)
+  f, dyn_args = argnums_partial_except(f, static_argnums, args,
+                                       allow_invalid=True)
+  del args
+
+  # TODO(yashkatariya): Merge the nokwargs and kwargs path. One blocker is
+  # flatten_axes which if kwargs are present in the treedef (even empty {}),
+  # leads to wrong expansion.
+  if kwargs:
+    f, dyn_kwargs = argnames_partial_except(f, static_argnames, kwargs)
+    args_flat, in_tree = tree_flatten((dyn_args, dyn_kwargs))
+    flat_fun, out_tree = flatten_fun(f, in_tree)
+  else:
+    args_flat, in_tree = tree_flatten(dyn_args)
+    flat_fun, out_tree = flatten_fun_nokwargs(f, in_tree)
+    dyn_kwargs = ()
+  del kwargs
+
+  if donate_argnums and not config.jax_debug_nans:
+    donated_invars = donation_vector(donate_argnums, dyn_args, dyn_kwargs)
+  else:
+    donated_invars = (False,) * len(args_flat)
+
+  if config.jax_array:
+    # If backend or device is set as an arg on jit, then resolve them to
+    # in_shardings and out_shardings as if user passed in in_axis_resources
+    # and out_axis_resources.
+    if backend or device:
+      in_shardings = out_shardings = _create_sharding_with_device_backend(
+          device, backend)
+    else:
+      in_shardings = tree_map(
+          lambda x: _create_sharding_for_array(pjit_mesh, x), in_axis_resources)
+      out_shardings = tree_map(
+          lambda x: _create_sharding_for_array(pjit_mesh, x), out_axis_resources)
+  else:
+    in_shardings = tree_map(
+        lambda x: _create_mesh_pspec_sharding_from_parsed_pspec(pjit_mesh, x),
+        in_axis_resources)
+    out_shardings = tree_map(
+        lambda x: x if _is_unspecified(x) else
+        _create_mesh_pspec_sharding_from_parsed_pspec(pjit_mesh, x), out_axis_resources)
+    # This check fails extremely rarely and has a huge cost in the dispatch
+    # path. So hide it behind the jax_enable_checks flag.
+    if config.jax_enable_checks:
+      _maybe_check_pjit_gda_mesh(args_flat, pjit_mesh)
+
+  local_in_avals = tuple(shaped_abstractify(a) for a in args_flat)
+  # TODO(yashkatariya): This is a hack. This should go away when avals have
+  # is_global attribute.
+  if config.jax_array:
+    in_positional_semantics = (pxla._PositionalSemantics.GLOBAL,) * len(args_flat)
+  else:
+    in_positional_semantics = tuple(tree_map(_get_in_positional_semantics, args_flat))
+  out_positional_semantics = (
+      pxla._PositionalSemantics.GLOBAL
+      if config.jax_parallel_functions_output_gda or config.jax_array else
+      pxla._positional_semantics.val)
+
+  global_in_avals, canonicalized_in_shardings_flat = _process_in_axis_resources(
+      hashable_pytree(in_shardings), local_in_avals, in_tree, in_positional_semantics,
+      tuple(isinstance(a, GDA) for a in args_flat), resource_env)
+
+  jaxpr, consts, canonicalized_out_shardings_flat = _pjit_jaxpr(
+      flat_fun, hashable_pytree(out_shardings), global_in_avals,
+      HashableFunction(out_tree, closure=()))
+
+  if (any(_is_from_gda(i) for i in canonicalized_in_shardings_flat) or
+      not config.jax_array):
+    canonicalized_in_shardings_flat = _maybe_replace_from_gda_with_pspec(
+        canonicalized_in_shardings_flat, args_flat)
+
+  assert len(args_flat) == len(canonicalized_in_shardings_flat)
+
+  canonicalized_in_shardings_flat = (
+      _UNSPECIFIED,) * len(consts) + canonicalized_in_shardings_flat
+  donated_invars = (False,) * len(consts) + donated_invars
+  in_positional_semantics = (
+      pxla._PositionalSemantics.GLOBAL,) * len(consts) + in_positional_semantics
+
+  # in_shardings and out_shardings here are all OpShardingSharding.
+  params = dict(
+      jaxpr=jaxpr,
+      in_shardings=canonicalized_in_shardings_flat,
+      out_shardings=canonicalized_out_shardings_flat,
+      resource_env=resource_env,
+      donated_invars=donated_invars,
+      name=getattr(flat_fun, '__name__', ''),
+      in_positional_semantics=in_positional_semantics,
+      out_positional_semantics=out_positional_semantics,
+      keep_unused=keep_unused,
+      inline=inline,
+  )
+  return (consts + args_flat, local_in_avals, params, in_tree, out_tree(),
+          donate_argnums)
+
+
 # in_axis_resources and out_axis_resources can't be None as the default value
 # because `None` means that the input is fully replicated.
 def pjit(
@@ -320,206 +550,24 @@ def pjit(
   ...   print(f(x))  # doctest: +SKIP
   [ 0.5  2.   4.   6.   8.  10.  12.  10. ]
   """
-  check_callable(fun)
-
-  if not config.jax_array and (_is_unspecified(in_axis_resources) or
-                               _is_unspecified(out_axis_resources)):
-    raise ValueError(
-        "in_axis_resources and out_axis_resources should not "
-        "be the unspecified singleton value. Please enable `jax.Array` to use "
-        "this feature. You can use jax.config.update('jax_array', True) or "
-        "set the environment variable JAX_ARRAY=1 , or set the `jax_array` "
-        "boolean flag to something true-like.")
-
-  if backend is not None or device is not None:
-    warnings.warn(
-        'backend and device argument on jit is deprecated. You can use a '
-        '`jax.sharding.Mesh` context manager or device_put the arguments '
-        'before passing them to `jit`. Please see '
-        'https://jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html '
-        'for more information.', DeprecationWarning)
-    if device is not None and backend is not None:
-      raise ValueError("can't specify both a device and a backend for jit, "
-                       f"got {device=} and {backend=}")
-    if not _is_unspecified(in_axis_resources):
-      raise ValueError('If backend or device is specified on jit, then '
-                       'in_axis_resources should not be specified.')
-    if not _is_unspecified(out_axis_resources):
-      raise ValueError('If backend or device is specified on jit, then '
-                       'out_axis_resources should not be specified.')
-
-  if isinstance(in_axis_resources, list):
-    # To be a tree prefix of the positional args tuple, in_axes can never be a
-    # list: if in_axes is not a leaf, it must be a tuple of trees. However,
-    # in cases like these users expect tuples and lists to be treated
-    # essentially interchangeably, so we canonicalize lists to tuples here
-    # rather than raising an error. https://github.com/google/jax/issues/2367
-    in_axis_resources = tuple(in_axis_resources)
-
-  in_axis_resources, _, _ = _prepare_axis_resources(
-      in_axis_resources, "in_axis_resources")
-  out_axis_resources, _, _ = _prepare_axis_resources(
-      out_axis_resources, "out_axis_resources")
-
-  donate_argnums, static_argnums, static_argnames = resolve_argnums(
-      fun, donate_argnums, static_argnums, static_argnames)
+  (in_axis_resources, out_axis_resources, donate_argnums, static_argnums,
+   static_argnames) = pre_infer_params(
+       fun, in_axis_resources, out_axis_resources, donate_argnums,
+       static_argnums, static_argnames, device, backend)
 
   def infer_params(*args, **kwargs):
-    if kwargs and not _is_unspecified(in_axis_resources):
-      raise ValueError(
-          "pjit does not support kwargs when in_axis_resources is specified.")
-
     # Putting this outside of wrapped would make resources lexically scoped
     resource_env = pxla.thread_resources.env
-    pjit_mesh = resource_env.physical_mesh
-    if pjit_mesh.empty:
-      if config.jax_array:
-        # Don't enforce requiring a mesh when `jax_array` flag is enabled. But
-        # if mesh is not empty then pjit will respect it.
-        pass
-      else:
-        raise RuntimeError("pjit requires a non-empty mesh! Are you sure that "
-                           "it's defined at the call site?")
+    pjit_info_args = PjitInfo(
+        fun=fun, in_axis_resources=in_axis_resources,
+        out_axis_resources=out_axis_resources, static_argnums=static_argnums,
+        static_argnames=static_argnames, donate_argnums=donate_argnums,
+        device=device, backend=backend, keep_unused=keep_unused,
+        inline=inline, resource_env=resource_env)
+    return common_infer_params(pjit_info_args, *args, **kwargs)
 
-    if (backend or device) and not pjit_mesh.empty:
-      raise ValueError(
-          "Mesh context manager should not be used with jit when backend or "
-          "device is also specified as an argument to jit.")
+  return post_infer_params(fun, infer_params, static_argnums, static_argnames)
 
-    f = lu.wrap_init(fun)
-    f, dyn_args = argnums_partial_except(f, static_argnums, args,
-                                         allow_invalid=True)
-    del args
-
-    # TODO(yashkatariya): Merge the nokwargs and kwargs path. One blocker is
-    # flatten_axes which if kwargs are present in the treedef (even empty {}),
-    # leads to wrong expansion.
-    if kwargs:
-      f, dyn_kwargs = argnames_partial_except(f, static_argnames, kwargs)
-      args_flat, in_tree = tree_flatten((dyn_args, dyn_kwargs))
-      flat_fun, out_tree = flatten_fun(f, in_tree)
-    else:
-      args_flat, in_tree = tree_flatten(dyn_args)
-      flat_fun, out_tree = flatten_fun_nokwargs(f, in_tree)
-      dyn_kwargs = ()
-    del kwargs
-
-    if donate_argnums and not config.jax_debug_nans:
-      donated_invars = donation_vector(donate_argnums, dyn_args, dyn_kwargs)
-    else:
-      donated_invars = (False,) * len(args_flat)
-
-    if config.jax_array:
-      # If backend or device is set as an arg on jit, then resolve them to
-      # in_shardings and out_shardings as if user passed in in_axis_resources
-      # and out_axis_resources.
-      if backend or device:
-        in_shardings = out_shardings = _create_sharding_with_device_backend(
-            device, backend)
-      else:
-        in_shardings = tree_map(
-            lambda x: _create_sharding_for_array(pjit_mesh, x), in_axis_resources)
-        out_shardings = tree_map(
-            lambda x: _create_sharding_for_array(pjit_mesh, x), out_axis_resources)
-    else:
-      in_shardings = tree_map(
-          lambda x: _create_mesh_pspec_sharding_from_parsed_pspec(pjit_mesh, x),
-          in_axis_resources)
-      out_shardings = tree_map(
-          lambda x: x if _is_unspecified(x) else
-          _create_mesh_pspec_sharding_from_parsed_pspec(pjit_mesh, x), out_axis_resources)
-      # This check fails extremely rarely and has a huge cost in the dispatch
-      # path. So hide it behind the jax_enable_checks flag.
-      if config.jax_enable_checks:
-        _maybe_check_pjit_gda_mesh(args_flat, pjit_mesh)
-
-    local_in_avals = tuple(shaped_abstractify(a) for a in args_flat)
-    # TODO(yashkatariya): This is a hack. This should go away when avals have
-    # is_global attribute.
-    if config.jax_array:
-      in_positional_semantics = (pxla._PositionalSemantics.GLOBAL,) * len(args_flat)
-    else:
-      in_positional_semantics = tuple(tree_map(_get_in_positional_semantics, args_flat))
-    out_positional_semantics = (
-        pxla._PositionalSemantics.GLOBAL
-        if config.jax_parallel_functions_output_gda or config.jax_array else
-        pxla._positional_semantics.val)
-
-    global_in_avals, canonicalized_in_shardings_flat = _process_in_axis_resources(
-        hashable_pytree(in_shardings), local_in_avals, in_tree, in_positional_semantics,
-        tuple(isinstance(a, GDA) for a in args_flat), resource_env)
-
-    jaxpr, consts, canonicalized_out_shardings_flat = _pjit_jaxpr(
-        flat_fun, hashable_pytree(out_shardings), global_in_avals,
-        HashableFunction(out_tree, closure=()))
-
-    if (any(_is_from_gda(i) for i in canonicalized_in_shardings_flat) or
-        not config.jax_array):
-      canonicalized_in_shardings_flat = _maybe_replace_from_gda_with_pspec(
-          canonicalized_in_shardings_flat, args_flat)
-
-    assert len(args_flat) == len(canonicalized_in_shardings_flat)
-
-    canonicalized_in_shardings_flat = (
-        _UNSPECIFIED,) * len(consts) + canonicalized_in_shardings_flat
-    donated_invars = (False,) * len(consts) + donated_invars
-    in_positional_semantics = (
-        pxla._PositionalSemantics.GLOBAL,) * len(consts) + in_positional_semantics
-
-    # in_shardings and out_shardings here are all OpShardingSharding.
-    params = dict(
-        jaxpr=jaxpr,
-        in_shardings=canonicalized_in_shardings_flat,
-        out_shardings=canonicalized_out_shardings_flat,
-        resource_env=resource_env,
-        donated_invars=donated_invars,
-        name=getattr(flat_fun, '__name__', ''),
-        in_positional_semantics=in_positional_semantics,
-        out_positional_semantics=out_positional_semantics,
-        keep_unused=keep_unused,
-        inline=inline,
-    )
-    return (consts + args_flat, local_in_avals, params, in_tree, out_tree(),
-            donate_argnums)
-
-  if FLAGS.experimental_cpp_pjit and xla_extension_version >= 115:
-    wrapped = _cpp_pjit(fun, infer_params, static_argnums, static_argnames)
-  else:
-    wrapped = _python_pjit(fun, infer_params)
-
-  def lower(*args, **kwargs):
-    (args_flat, flat_local_in_avals, params, in_tree, out_tree,
-     donate_argnums) = infer_params(*args, **kwargs)
-    if config.jax_array:
-      in_shardings = _resolve_in_shardings(
-          args_flat, params['in_shardings'], params['out_shardings'],
-          params['resource_env'].physical_mesh)
-    else:
-      in_shardings = params['in_shardings']
-    in_is_global = _calc_is_global_sequence(
-        params['in_positional_semantics'], in_shardings)
-    lowering = _pjit_lower(
-        params['jaxpr'], in_shardings, params['out_shardings'],
-        params['resource_env'], params['donated_invars'], params['name'],
-        in_is_global, params['keep_unused'], always_lower=True)
-
-    if kwargs:
-      args_kwargs_in_tree = in_tree
-      local_in_avals = in_tree.unflatten(flat_local_in_avals)
-    else:
-      args_kwargs_in_tree = treedef_tuple([in_tree, tree_flatten({})[1]])
-      local_in_avals = args_kwargs_in_tree.unflatten(flat_local_in_avals)
-
-    return stages.Lowered.from_flat_info(
-        lowering,
-        args_kwargs_in_tree,
-        local_in_avals,
-        donate_argnums,
-        out_tree,
-        no_kwargs=True)
-
-  wrapped.lower = lower
-  return wrapped
 
 
 class _ListWithW(list):
   __slots__ = ('__weakref__',)
@@ -543,6 +591,10 @@ def _create_sharding_for_array(mesh, x):
   # FROM_GDA is removed.
   if isinstance(x, XLACompatibleSharding) or _is_unspecified_or_from_gda_or_auto(x):
     return x
+  if mesh is None:
+    raise RuntimeError(
+        "jit does not support using the mesh context manager. Please pass in "
+        "the sharding explicitly via in_axis_resources or out_axis_resources.")
   if mesh.empty:
     raise RuntimeError("pjit requires a non-empty mesh! Is a mesh defined at "
                        "the call site? Alternatively, provide a "
@@ -944,7 +996,10 @@ pjit_p = core.Primitive("pjit")
 pjit_p.multiple_results = True
 
 
-def _resolve_in_shardings(args, pjit_in_shardings, out_shardings, pjit_mesh):
+def _resolve_in_shardings(
+    args, pjit_in_shardings: Sequence[PjitSharding],
+    out_shardings: Sequence[PjitSharding],
+    pjit_mesh: Optional[pxla.Mesh]) -> Sequence[PjitSharding]:
   # If True, means that device or backend is set by the user on pjit and it
   # has the same semantics as device_put i.e. doesn't matter which device the
   # arg is on, reshard it to the device mentioned. So don't do any of the
@@ -972,7 +1027,7 @@ def _resolve_in_shardings(args, pjit_in_shardings, out_shardings, pjit_mesh):
   pxla._get_and_check_device_assignment(
       it.chain(
           committed_arg_shardings, pjit_in_shardings, out_shardings),
-      (None if pjit_mesh.empty else list(pjit_mesh.devices.flat)))
+      (None if pjit_mesh is None or pjit_mesh.empty else list(pjit_mesh.devices.flat)))
 
   resolved_in_shardings = []
   for arg, pjit_in_s in safe_zip(args, pjit_in_shardings):
@@ -1037,8 +1092,9 @@ def _pjit_call_impl(*args, jaxpr,
   global _most_recent_pjit_call_executable
 
   if config.jax_array:
-    in_shardings = _resolve_in_shardings(args, in_shardings, out_shardings,
-                                         resource_env.physical_mesh)
+    in_shardings = _resolve_in_shardings(
+        args, in_shardings, out_shardings,
+        resource_env.physical_mesh if resource_env is not None else None)
 
   in_is_global = _calc_is_global_sequence(in_positional_semantics, in_shardings)
   if config.jax_array and all(_is_unspecified(o) for o in out_shardings):
@@ -1145,11 +1201,15 @@ def _pjit_lower_cached(
     out_shardings: Tuple[PjitSharding, ...] = sdat_out_shardings.shardings
 
   pxla.resource_typecheck(jaxpr, resource_env, {}, lambda: "pjit")
+
   f = core.jaxpr_as_fun(jaxpr)
   f.__name__ = name
   fun = lu.wrap_init(f)
 
-  mesh = resource_env.physical_mesh
+  if resource_env is not None:
+    mesh = resource_env.physical_mesh
+  else:
+    mesh = None
 
   # Convert to `NamedSharding` when `jax_array` is not enabled. This is
   # because GDA/SDA/DA are dependent on mesh for generating outputs.
@@ -1187,7 +1247,8 @@ def _pjit_lower_cached(
       fun, 'pjit', name, in_shardings, out_shardings, donated_invars,
       jaxpr.in_avals, in_is_global=in_is_global, keep_unused=keep_unused,
       always_lower=always_lower,
-      devices_from_context=(None if mesh.empty else list(mesh.devices.flat)))
+      devices_from_context=(
+          None if mesh is None or mesh.empty else list(mesh.devices.flat)))
 
 
 def pjit_staging_rule(trace, *args, **params):
@@ -1207,6 +1268,8 @@ def _pjit_abstract_eval(*args, jaxpr, out_shardings, resource_env,
   disallowed_effects = jaxpr.effects - mlir.lowerable_effects
   if disallowed_effects:
     raise ValueError('Effects not supported in `pjit`.')
+  if config.jax_array:
+    return jaxpr.out_avals, jaxpr.effects
   return global_to_local(out_positional_semantics, jaxpr.out_avals,
                          out_shardings, resource_env.physical_mesh), jaxpr.effects
 pjit_p.def_effectful_abstract_eval(_pjit_abstract_eval)
@@ -1216,7 +1279,7 @@ def _pjit_lowering(ctx, *args, name, jaxpr, in_shardings,
                    out_shardings, resource_env, donated_invars,
                    in_positional_semantics, out_positional_semantics,
                    keep_unused, inline):
-  if not config.jax_array:
+  if not config.jax_jit_pjit_api_merge:
     if not isinstance(ctx.module_context.axis_context,
                       (mlir.SPMDAxisContext, mlir.ShardingContext)):
       raise RuntimeError("Nesting pjit() inside jit() is not allowed.")
@@ -1264,7 +1327,12 @@ def _pjit_batcher(insert_axis, spmd_axis_name,
   # `insert_axis` is set to True only for some `xmap` uses.
   new_parts = (axis_name,) if insert_axis else (
       () if spmd_axis_name is None else (spmd_axis_name,))
-  mesh = resource_env.physical_mesh
+
+  if resource_env is not None:
+    mesh = resource_env.physical_mesh
+  else:
+    mesh = None
+
   in_shardings = tuple(
       _pjit_batcher_for_sharding(i, 0, new_parts, mesh, aval.ndim) if is_mapped else i
       for is_mapped, i, aval in zip(is_mapped_in, in_shardings, new_jaxpr.in_avals))
@@ -1302,7 +1370,7 @@ def _pjit_batcher_for_sharding(
     return OpShardingSharding(s._device_assignment, new_op)  # type: ignore
   else:
     assert isinstance(s, OpShardingSharding)
-    assert not mesh.empty
+    assert mesh is not None and not mesh.empty
     parsed_pspec = parse_flatten_op_sharding(s._op_sharding, mesh)[0]  # type: ignore
     parsed_pspec = parsed_pspec.insert_axis_partitions(dim, val)
     mps = NamedSharding._from_parsed_pspec(mesh, parsed_pspec)
@@ -1378,32 +1446,37 @@ def _pjit_partial_eval(trace, *in_tracers,
       keep_unused=keep_unused,
       inline=inline)
 
-  if num_residuals:
-    in_is_global = _calc_is_global_sequence(
-        known_params['in_positional_semantics'], known_params['in_shardings'])
-    compiled = _pjit_lower(
-        known_params["jaxpr"], known_params["in_shardings"],
-        known_params["out_shardings"], known_params["resource_env"],
-        known_params["donated_invars"], known_params["name"],
-        in_is_global, known_params['keep_unused'], always_lower=False).compile(
-            _allow_propagation_to_outputs=True,
-            _allow_compile_replicated=False)
-    da = compiled._device_assignment
-    _, out_op_sharding_shardings = pxla._get_op_sharding_shardings_from_executable(
-        compiled.xla_executable, da, len(known_jaxpr.in_avals),
-        len(known_jaxpr.out_avals))
-    assert len(out_op_sharding_shardings) == len(known_jaxpr.out_avals), (
-        len(out_op_sharding_shardings), len(known_jaxpr.out_avals))
-    out_op_shardings = [o._to_xla_op_sharding(a.ndim) for o, a in
-                        safe_zip(out_op_sharding_shardings, known_jaxpr.out_avals)]
-    residual_op_shardings = tuple(out_op_shardings[-num_residuals:])
-  else:
-    residual_op_shardings = ()
-  assert len(residual_shardings) == len(residual_op_shardings), (
-      len(residual_shardings), len(residual_op_shardings))
-  residual_shardings = tuple(OpShardingSharding(da, op) for op in residual_op_shardings)
-  known_params['out_shardings'] = (
-      keep_where(out_shardings, known_outs) + residual_shardings)
+  # resource_env is None in the jit wrapper around pjit.
+  # TODO(apaszke,yashkatariya): Replace this check with
+  # `if not config.jax_array` after XLA stops overriding user shardings when
+  # `_allow_propagation_to_outputs = True`.
+  if resource_env is not None:
+    if num_residuals:
+      in_is_global = _calc_is_global_sequence(
+          known_params['in_positional_semantics'], known_params['in_shardings'])
+      compiled = _pjit_lower(
+          known_params["jaxpr"], known_params["in_shardings"],
+          known_params["out_shardings"], known_params["resource_env"],
+          known_params["donated_invars"], known_params["name"],
+          in_is_global, known_params['keep_unused'], always_lower=False).compile(
+              _allow_propagation_to_outputs=True,
+              _allow_compile_replicated=False)
+      da = compiled._device_assignment
+      _, out_op_sharding_shardings = pxla._get_op_sharding_shardings_from_executable(
+          compiled.xla_executable, da, len(known_jaxpr.in_avals),
+          len(known_jaxpr.out_avals))
+      assert len(out_op_sharding_shardings) == len(known_jaxpr.out_avals), (
+          len(out_op_sharding_shardings), len(known_jaxpr.out_avals))
+      out_op_shardings = [o._to_xla_op_sharding(a.ndim) for o, a in
+                          safe_zip(out_op_sharding_shardings, known_jaxpr.out_avals)]
+      residual_op_shardings = tuple(out_op_shardings[-num_residuals:])
+    else:
+      residual_op_shardings = ()
+    assert len(residual_shardings) == len(residual_op_shardings), (
+        len(residual_shardings), len(residual_op_shardings))
+    residual_shardings = tuple(OpShardingSharding(da, op) for op in residual_op_shardings)
+    known_params['out_shardings'] = (
+        keep_where(out_shardings, known_outs) + residual_shardings)
 
   all_known_outs = pjit_p.bind(
       *(pv.get_known() for pv in in_pvals if pv.is_known()),
@@ -1436,12 +1509,16 @@ def _pjit_partial_eval(trace, *in_tracers,
       keep_unused=keep_unused,
       inline=inline)
   unknown_tracers_in = [t for t in in_tracers if not t.pval.is_known()]
+  if config.jax_array:
+    unknown_out_avals = unknown_jaxpr.out_avals
+  else:
+    unknown_out_avals = global_to_local(
+        unknown_params["out_positional_semantics"], unknown_jaxpr.out_avals,
+        unknown_params["out_shardings"],
+        unknown_params["resource_env"].physical_mesh)
   unknown_tracers_out = [
       pe.JaxprTracer(trace, pe.PartialVal.unknown(aval), None)
-      for aval in global_to_local(unknown_params["out_positional_semantics"],
-                                  unknown_jaxpr.out_avals,
-                                  unknown_params["out_shardings"],
-                                  unknown_params["resource_env"].physical_mesh)
+      for aval in unknown_out_avals
   ]
   eqn = pe.new_eqn_recipe((*unknown_tracers_in, *residual_tracers),
                           unknown_tracers_out,
@@ -1525,8 +1602,9 @@ def _check_resources_against_named_axes(what, aval, pos_axis_resources, named_ax
 def _resource_typing_pjit(avals, params, source_info, resource_env, named_axis_resources):
   jaxpr = params["jaxpr"]
   what = "pjit input"
-  if resource_env.physical_mesh != params['resource_env'].physical_mesh:
-    raise RuntimeError("Changing the physical mesh is not allowed inside pjit.")
+  if (resource_env is not None and params['resource_env'] is not None and
+      resource_env.physical_mesh != params['resource_env'].physical_mesh):
+    raise RuntimeError("Changing the physical mesh is not allowed inside pjit.")
 
   for aval, s in zip(jaxpr.in_avals, params['in_shardings']):
     if _is_unspecified(s) or _is_auto(s):
@@ -1535,9 +1613,14 @@ def _resource_typing_pjit(avals, params, source_info, resource_env, named_axis_r
         s._original_sharding, '_parsed_pspec'):
       parsed_pspec = s._original_sharding._parsed_pspec
     else:
-      parsed_pspec = parse_flatten_op_sharding(
-          s._op_sharding, resource_env.physical_mesh)[0]
-    _check_resources_against_named_axes(what, aval, parsed_pspec, named_axis_resources)
+      if resource_env is not None:
+        parsed_pspec = parse_flatten_op_sharding(
+            s._op_sharding, resource_env.physical_mesh)[0]
+      else:
+        parsed_pspec = None
+    if parsed_pspec is not None:
+      _check_resources_against_named_axes(what, aval, parsed_pspec,
+                                          named_axis_resources)
 
   pxla.resource_typecheck(
       jaxpr.jaxpr, resource_env, named_axis_resources,
@@ -1552,9 +1635,14 @@ def _resource_typing_pjit(avals, params, source_info, resource_env, named_axis_r
         s._original_sharding, '_parsed_pspec'):
       parsed_pspec = s._original_sharding._parsed_pspec
     else:
-      parsed_pspec = parse_flatten_op_sharding(
-          s._op_sharding, resource_env.physical_mesh)[0]
-    _check_resources_against_named_axes(what, aval, parsed_pspec, named_axis_resources)
+      if resource_env is not None:
+        parsed_pspec = parse_flatten_op_sharding(
+            s._op_sharding, resource_env.physical_mesh)[0]
+      else:
+        parsed_pspec = None
+    if parsed_pspec is not None:
+      _check_resources_against_named_axes(what, aval, parsed_pspec,
+                                          named_axis_resources)
 
 
 pxla.custom_resource_typing_rules[pjit_p] = _resource_typing_pjit
diff --git a/jax/interpreters/pxla.py b/jax/interpreters/pxla.py
index 512967649..35e039c31 100644
--- a/jax/interpreters/pxla.py
+++ b/jax/interpreters/pxla.py
@@ -2710,8 +2710,10 @@ def _check_if_any_auto(
 
 
 def _get_and_check_device_assignment(
-    shardings: Iterable[sharding_internal.XLACompatibleSharding],
-    devices: Optional[Sequence[xc.Device]]) -> Tuple[xla.Backend, Sequence[xc.Device]]:
+    shardings: Iterable[Union[sharding_internal.XLACompatibleSharding,
+                              _UnspecifiedValue, _AUTOAxisResource]],
+    devices: Optional[Sequence[xc.Device]]
+) -> Tuple[xla.Backend, Sequence[xc.Device]]:
   from jax._src.api import local_devices
 
   first_device_assignment = None
diff --git a/tests/BUILD b/tests/BUILD
index 6b94e90c3..0e3deeb30 100644
--- a/tests/BUILD
+++ b/tests/BUILD
@@ -199,6 +199,7 @@ jax_test(
     backend_tags = {
         "tpu": ["notsan"],  # Times out under tsan.
     },
+    enable_configs = ["cpu_jit_pjit_api_merge"],
     pjrt_c_api_bypass = True,
     shard_count = {
         "cpu": 5,
diff --git a/tests/pjit_test.py b/tests/pjit_test.py
index 2f252c3d4..a1780e1a5 100644
--- a/tests/pjit_test.py
+++ b/tests/pjit_test.py
@@ -600,7 +600,10 @@ class PJitTest(jtu.BufferDonationTestCase):
     f = pjit(lambda x: jax.grad(h)(x),
              in_axis_resources=None, out_axis_resources=None)
     x = jnp.arange(8, dtype=jnp.float32)
-    self.assertAllClose(f(x), jnp.cos(x))
+    out = f(x)
+    self.assertAllClose(out, jnp.cos(x))
+    if jax.config.jax_array:
+      self.assertLen(out.devices(), 2)
 
   @jtu.with_mesh([('x', 2)])
   def testNoopPartitionSpecs(self):
@@ -2081,8 +2084,8 @@ class ArrayPjitTest(jtu.JaxTestCase):
 
   @jax_array(True)
   def test_pjit_single_device_sharding_add(self):
-    a = jnp.array([1, 2, 3], dtype=jnp.float32)
-    b = jnp.array([4, 5, 6], dtype=jnp.float32)
+    a = np.array([1, 2, 3], dtype=jnp.float32)
+    b = np.array([4, 5, 6], dtype=jnp.float32)
 
     @pjit
     def add(x, y):
@@ -2462,11 +2465,18 @@ class ArrayPjitTest(jtu.JaxTestCase):
       out = jnp.zeros(shape, jnp.bfloat16)
       return jax.lax.with_sharding_constraint(out, NamedSharding(mesh, pspec))
 
-    with self.assertRaisesRegex(
-        ValueError,
-        "Pjit's devices and Array's devices should be equal. "
-        r"Got Pjit's device ids \[0\] on platform.*and "
-        r"Array's device ids \[0, 1, 2, 3\] on platform"):
+    # This split is needed because original `jit` adds `device` as a
+    # `devices_from_context` whereas `pjit` passes it as an in_sharding.
+    if jax.config.jax_jit_pjit_api_merge:
+      error_msg = ("Devices of all `Array` inputs and outputs should be the same. "
+                   r"Got array device ids \[0\] on platform.*and "
+                   r"another array's device ids \[0, 1, 2, 3\] on platform")
+    else:
+      error_msg = ("Pjit's devices and Array's devices should be equal. "
+                   r"Got Pjit's device ids \[0\] on platform.*and "
+                   r"Array's device ids \[0, 1, 2, 3\] on platform")
+
+    with self.assertRaisesRegex(ValueError, error_msg):
       sharded_zeros((4096, 3072), P('x', 'y'))
 
   @jax_array(True)
@@ -2920,7 +2930,6 @@ class ArrayPjitTest(jtu.JaxTestCase):
     _check(out2, jax.devices()[1], y)
 
     self.assertEqual(cache_info2.hits, cache_info1.hits + 1)
-    self.assertEqual(cache_info2.misses, cache_info1.misses)
 
     h = pjit(mul, device=jax.devices()[-1])
     h_out = h(y)
     _check(h_out, jax.devices()[-1], y)
 
     self.assertEqual(cache_info3.hits, cache_info2.hits)
-    self.assertEqual(cache_info3.misses, cache_info2.misses + 1)
 
     # AOT test
     compiled = f.lower(jax.ShapedArray(y.shape, y.dtype)).compile()
@@ -3130,6 +3138,31 @@ class ArrayPjitTest(jtu.JaxTestCase):
     # Second call is to trigger C++ dispatch.
     f(inp)  # doesn't crash
 
+  def test_pjit_sin_nested(self):
+    mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
+
+    @pjit
+    def f(x):
+      return jnp.sin(x)
+
+    with mesh:
+      inp = jnp.arange(8.)
+      out = f(inp)
+      self.assertArraysAllClose(out, np.sin(inp))
+      self.assertLen(out.devices(), 8)
+
+  def test_jit_with_mesh_context_manager(self):
+    if not jax.config.jax_jit_pjit_api_merge:
+      self.skipTest("This test only works if jax_jit_pjit_api_merge is True")
+
+    mesh = jtu.create_global_mesh((1,), ('x',))
+    with self.assertRaisesRegex(
+        RuntimeError,
+        "jit does not support using the mesh context manager"):
+      with mesh:
+        jax.jit(lambda x: x, in_axis_resources=P('x'),
+                out_axis_resources=P('x'))(jnp.arange(8))
+
 
 class TempSharding(Sharding):