Make `jit` a thin wrapper around `pjit` that ignores the mesh context manager (just as `jit` does today)

`jit` passes `None` as the `resource_env`, so the `jit(pjit)` wrapper ignores any outer mesh, matching today's behavior where `jit` sets the resource env to an empty mesh.

This does not make `jit` and `pjit` the same API, but it shares all of the code (both the C++ and Python paths) between the two APIs while preserving the current semantics of both `jit` and `pjit`.
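For illustration, a minimal usage sketch of the resulting behavior (not part of this change; it assumes the `jax_jit_pjit_api_merge` flag is enabled before `jax` is imported, e.g. via `JAX_JIT_PJIT_API_MERGE=1`, and uses the `in_axis_resources` / `out_axis_resources` API of this version):

import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, PartitionSpec as P
from jax.experimental.pjit import pjit

# A trivial one-device mesh with a single named axis 'x'.
mesh = Mesh(np.array(jax.devices()[:1]), ('x',))

with mesh:
  # pjit reads the surrounding mesh from the thread-local resource env, so a
  # PartitionSpec naming the mesh axis resolves against it.
  out = pjit(lambda x: x + 1,
             in_axis_resources=P('x'), out_axis_resources=P('x'))(jnp.arange(8.))

  # jit passes resource_env=None to the shared machinery, so the same
  # PartitionSpec cannot be resolved against the context-manager mesh and a
  # RuntimeError ("jit does not support using the mesh context manager") is
  # raised instead.
  try:
    jax.jit(lambda x: x + 1,
            in_axis_resources=P('x'), out_axis_resources=P('x'))(jnp.arange(8.))
  except RuntimeError as e:
    print(e)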

PiperOrigin-RevId: 501707496
Yash Katariya 2023-01-12 17:23:55 -08:00 committed by jax authors
parent 7206cb5b7b
commit c8ad89e358
6 changed files with 414 additions and 260 deletions


@@ -282,10 +282,36 @@ def jit(
return _jit(False, fun, static_argnums, static_argnames, device, backend,
donate_argnums, inline, keep_unused, abstracted_axes)
# TODO(yashkatariya): Remove the above jit function after
# `jax_jit_pjit_api_merge` defaults to True.
if jax.config.jax_jit_pjit_api_merge:
jit = pjit.pjit # type: ignore # noqa: F811
def jit( # type: ignore # noqa: F811 # pylint: disable=function-redefined
fun: Callable,
in_axis_resources=pxla._UNSPECIFIED,
out_axis_resources=pxla._UNSPECIFIED,
static_argnums: Union[int, Sequence[int], None] = None,
static_argnames: Union[str, Iterable[str], None] = None,
donate_argnums: Union[int, Sequence[int]] = (),
keep_unused: bool = False,
device: Optional[xc.Device] = None,
backend: Optional[str] = None,
inline: bool = False,
) -> stages.Wrapped:
(in_axis_resources, out_axis_resources, donate_argnums, static_argnums,
static_argnames) = pjit.pre_infer_params(
fun, in_axis_resources, out_axis_resources, donate_argnums,
static_argnums, static_argnames, device, backend)
def infer_params(*args, **kwargs):
pjit_info_args = pjit.PjitInfo(
fun=fun, in_axis_resources=in_axis_resources,
out_axis_resources=out_axis_resources, static_argnums=static_argnums,
static_argnames=static_argnames, donate_argnums=donate_argnums,
device=device, backend=backend, keep_unused=keep_unused,
inline=inline, resource_env=None)
return pjit.common_infer_params(pjit_info_args, *args, **kwargs)
return pjit.post_infer_params(fun, infer_params, static_argnums,
static_argnames)
def _jit(


@@ -855,8 +855,12 @@ def pjit_error_check(error, enabled_errors, *vals_in, jaxpr,
# Update pjit params to account for extra error values.
num_error_vals = len(err_vals)
num_out_error_vals = out_tree.num_leaves - len(out_shardings)
sharding = OpShardingSharding.get_replicated(
list(resource_env.physical_mesh.devices.flat))
if jax.config.jax_array:
sharding = pjit._UNSPECIFIED
else:
sharding = OpShardingSharding.get_replicated(
list(resource_env.physical_mesh.devices.flat))
new_in_shardings = (*[sharding] * num_error_vals, *in_shardings)
new_out_shardings = (*[sharding] * num_out_error_vals, *out_shardings)


@@ -17,7 +17,7 @@ from enum import IntEnum
import numpy as np
from collections import OrderedDict, Counter
from typing import (Callable, Sequence, Tuple, Union, cast, List, Optional,
Iterable)
Iterable, NamedTuple, Any)
import itertools as it
from functools import partial, lru_cache
import threading
@@ -170,7 +170,237 @@ def _cpp_pjit(fun: Callable, infer_params, static_argnums, static_argnames):
return wraps(fun)(cpp_pjit_f)
# TODO(yashkatariya): Add pjit microbenchmarks.
def pre_infer_params(fun, in_axis_resources, out_axis_resources,
donate_argnums, static_argnums, static_argnames, device,
backend):
check_callable(fun)
if not config.jax_array and (_is_unspecified(in_axis_resources) or
_is_unspecified(out_axis_resources)):
raise ValueError(
"in_axis_resources and out_axis_resources should not "
"be the unspecified singleton value. Please enable `jax.Array` to use "
"this feature. You can use jax.config.update('jax_array', True) or "
"set the environment variable JAX_ARRAY=1 , or set the `jax_array` "
"boolean flag to something true-like.")
if backend is not None or device is not None:
warnings.warn(
'backend and device argument on jit is deprecated. You can use a '
'`jax.sharding.Mesh` context manager or device_put the arguments '
'before passing them to `jit`. Please see '
'https://jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html '
'for more information.', DeprecationWarning)
if device is not None and backend is not None:
raise ValueError("can't specify both a device and a backend for jit, "
f"got {device=} and {backend=}")
if not _is_unspecified(in_axis_resources):
raise ValueError('If backend or device is specified on jit, then '
'in_axis_resources should not be specified.')
if not _is_unspecified(out_axis_resources):
raise ValueError('If backend or device is specified on jit, then '
'out_axis_resources should not be specified.')
if isinstance(in_axis_resources, list):
# To be a tree prefix of the positional args tuple, in_axes can never be a
# list: if in_axes is not a leaf, it must be a tuple of trees. However,
# in cases like these users expect tuples and lists to be treated
# essentially interchangeably, so we canonicalize lists to tuples here
# rather than raising an error. https://github.com/google/jax/issues/2367
in_axis_resources = tuple(in_axis_resources)
in_axis_resources, _, _ = _prepare_axis_resources(
in_axis_resources, "in_axis_resources")
out_axis_resources, _, _ = _prepare_axis_resources(
out_axis_resources, "out_axis_resources")
donate_argnums, static_argnums, static_argnames = resolve_argnums(
fun, donate_argnums, static_argnums, static_argnames)
return (in_axis_resources, out_axis_resources, donate_argnums, static_argnums,
static_argnames)
def post_infer_params(fun, infer_params, static_argnums, static_argnames):
if FLAGS.experimental_cpp_pjit and xla_extension_version >= 115:
wrapped = _cpp_pjit(fun, infer_params, static_argnums, static_argnames)
else:
wrapped = _python_pjit(fun, infer_params)
def lower(*args, **kwargs):
(args_flat, flat_local_in_avals, params, in_tree, out_tree,
donate_argnums) = infer_params(*args, **kwargs)
if config.jax_array:
resource_env = params['resource_env']
mesh = None if resource_env is None else resource_env.physical_mesh
in_shardings = _resolve_in_shardings(
args_flat, params['in_shardings'], params['out_shardings'], mesh)
else:
in_shardings = params['in_shardings']
in_is_global = _calc_is_global_sequence(
params['in_positional_semantics'], in_shardings)
lowering = _pjit_lower(
params['jaxpr'], in_shardings, params['out_shardings'],
params['resource_env'], params['donated_invars'], params['name'],
in_is_global, params['keep_unused'], always_lower=True)
if kwargs:
args_kwargs_in_tree = in_tree
local_in_avals = in_tree.unflatten(flat_local_in_avals)
else:
args_kwargs_in_tree = treedef_tuple([in_tree, tree_flatten({})[1]])
local_in_avals = args_kwargs_in_tree.unflatten(flat_local_in_avals)
return stages.Lowered.from_flat_info(
lowering,
args_kwargs_in_tree,
local_in_avals,
donate_argnums,
out_tree,
no_kwargs=True)
wrapped.lower = lower
return wrapped
class PjitInfo(NamedTuple):
fun: Callable
in_axis_resources: Any
out_axis_resources: Any
static_argnums: Tuple[int, ...]
static_argnames: Tuple[str, ...]
donate_argnums: Tuple[int, ...]
device: Optional[xc.Device]
backend: Optional[str]
keep_unused: bool
inline: bool
resource_env: Any
def common_infer_params(pjit_info_args, *args, **kwargs):
(fun, in_axis_resources, out_axis_resources, static_argnums, static_argnames,
donate_argnums, device, backend, keep_unused, inline,
resource_env) = pjit_info_args
if kwargs and not _is_unspecified(in_axis_resources):
raise ValueError(
"pjit does not support kwargs when in_axis_resources is specified.")
if resource_env is not None:
pjit_mesh = resource_env.physical_mesh
if pjit_mesh.empty:
if config.jax_array:
# Don't enforce requiring a mesh when `jax_array` flag is enabled. But
# if mesh is not empty then pjit will respect it.
pass
else:
raise RuntimeError("pjit requires a non-empty mesh! Are you sure that "
"it's defined at the call site?")
else:
pjit_mesh = None
if (backend or device) and pjit_mesh is not None and not pjit_mesh.empty:
raise ValueError(
"Mesh context manager should not be used with jit when backend or "
"device is also specified as an argument to jit.")
f = lu.wrap_init(fun)
f, dyn_args = argnums_partial_except(f, static_argnums, args,
allow_invalid=True)
del args
# TODO(yashkatariya): Merge the nokwargs and kwargs path. One blocker is
# flatten_axes which if kwargs are present in the treedef (even empty {}),
# leads to wrong expansion.
if kwargs:
f, dyn_kwargs = argnames_partial_except(f, static_argnames, kwargs)
args_flat, in_tree = tree_flatten((dyn_args, dyn_kwargs))
flat_fun, out_tree = flatten_fun(f, in_tree)
else:
args_flat, in_tree = tree_flatten(dyn_args)
flat_fun, out_tree = flatten_fun_nokwargs(f, in_tree)
dyn_kwargs = ()
del kwargs
if donate_argnums and not config.jax_debug_nans:
donated_invars = donation_vector(donate_argnums, dyn_args, dyn_kwargs)
else:
donated_invars = (False,) * len(args_flat)
if config.jax_array:
# If backend or device is set as an arg on jit, then resolve them to
# in_shardings and out_shardings as if user passed in in_axis_resources
# and out_axis_resources.
if backend or device:
in_shardings = out_shardings = _create_sharding_with_device_backend(
device, backend)
else:
in_shardings = tree_map(
lambda x: _create_sharding_for_array(pjit_mesh, x), in_axis_resources)
out_shardings = tree_map(
lambda x: _create_sharding_for_array(pjit_mesh, x), out_axis_resources)
else:
in_shardings = tree_map(
lambda x: _create_mesh_pspec_sharding_from_parsed_pspec(pjit_mesh, x),
in_axis_resources)
out_shardings = tree_map(
lambda x: x if _is_unspecified(x) else
_create_mesh_pspec_sharding_from_parsed_pspec(pjit_mesh, x), out_axis_resources)
# This check fails extremely rarely and has a huge cost in the dispatch
# path. So hide it behind the jax_enable_checks flag.
if config.jax_enable_checks:
_maybe_check_pjit_gda_mesh(args_flat, pjit_mesh)
local_in_avals = tuple(shaped_abstractify(a) for a in args_flat)
# TODO(yashkatariya): This is a hack. This should go away when avals have
# is_global attribute.
if config.jax_array:
in_positional_semantics = (pxla._PositionalSemantics.GLOBAL,) * len(args_flat)
else:
in_positional_semantics = tuple(tree_map(_get_in_positional_semantics, args_flat))
out_positional_semantics = (
pxla._PositionalSemantics.GLOBAL
if config.jax_parallel_functions_output_gda or config.jax_array else
pxla._positional_semantics.val)
global_in_avals, canonicalized_in_shardings_flat = _process_in_axis_resources(
hashable_pytree(in_shardings), local_in_avals, in_tree, in_positional_semantics,
tuple(isinstance(a, GDA) for a in args_flat), resource_env)
jaxpr, consts, canonicalized_out_shardings_flat = _pjit_jaxpr(
flat_fun, hashable_pytree(out_shardings), global_in_avals,
HashableFunction(out_tree, closure=()))
if (any(_is_from_gda(i) for i in canonicalized_in_shardings_flat) or
not config.jax_array):
canonicalized_in_shardings_flat = _maybe_replace_from_gda_with_pspec(
canonicalized_in_shardings_flat, args_flat)
assert len(args_flat) == len(canonicalized_in_shardings_flat)
canonicalized_in_shardings_flat = (
_UNSPECIFIED,) * len(consts) + canonicalized_in_shardings_flat
donated_invars = (False,) * len(consts) + donated_invars
in_positional_semantics = (
pxla._PositionalSemantics.GLOBAL,) * len(consts) + in_positional_semantics
# in_shardings and out_shardings here are all OpShardingSharding.
params = dict(
jaxpr=jaxpr,
in_shardings=canonicalized_in_shardings_flat,
out_shardings=canonicalized_out_shardings_flat,
resource_env=resource_env,
donated_invars=donated_invars,
name=getattr(flat_fun, '__name__', '<unnamed function>'),
in_positional_semantics=in_positional_semantics,
out_positional_semantics=out_positional_semantics,
keep_unused=keep_unused,
inline=inline,
)
return (consts + args_flat, local_in_avals, params, in_tree, out_tree(),
donate_argnums)
# in_axis_resources and out_axis_resources can't be None as the default value
# because `None` means that the input is fully replicated.
def pjit(
@@ -320,206 +550,24 @@ def pjit(
... print(f(x)) # doctest: +SKIP
[ 0.5 2. 4. 6. 8. 10. 12. 10. ]
"""
check_callable(fun)
if not config.jax_array and (_is_unspecified(in_axis_resources) or
_is_unspecified(out_axis_resources)):
raise ValueError(
"in_axis_resources and out_axis_resources should not "
"be the unspecified singleton value. Please enable `jax.Array` to use "
"this feature. You can use jax.config.update('jax_array', True) or "
"set the environment variable JAX_ARRAY=1 , or set the `jax_array` "
"boolean flag to something true-like.")
if backend is not None or device is not None:
warnings.warn(
'backend and device argument on jit is deprecated. You can use a '
'`jax.sharding.Mesh` context manager or device_put the arguments '
'before passing them to `jit`. Please see '
'https://jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html '
'for more information.', DeprecationWarning)
if device is not None and backend is not None:
raise ValueError("can't specify both a device and a backend for jit, "
f"got {device=} and {backend=}")
if not _is_unspecified(in_axis_resources):
raise ValueError('If backend or device is specified on jit, then '
'in_axis_resources should not be specified.')
if not _is_unspecified(out_axis_resources):
raise ValueError('If backend or device is specified on jit, then '
'out_axis_resources should not be specified.')
if isinstance(in_axis_resources, list):
# To be a tree prefix of the positional args tuple, in_axes can never be a
# list: if in_axes is not a leaf, it must be a tuple of trees. However,
# in cases like these users expect tuples and lists to be treated
# essentially interchangeably, so we canonicalize lists to tuples here
# rather than raising an error. https://github.com/google/jax/issues/2367
in_axis_resources = tuple(in_axis_resources)
in_axis_resources, _, _ = _prepare_axis_resources(
in_axis_resources, "in_axis_resources")
out_axis_resources, _, _ = _prepare_axis_resources(
out_axis_resources, "out_axis_resources")
donate_argnums, static_argnums, static_argnames = resolve_argnums(
fun, donate_argnums, static_argnums, static_argnames)
(in_axis_resources, out_axis_resources, donate_argnums, static_argnums,
static_argnames) = pre_infer_params(
fun, in_axis_resources, out_axis_resources, donate_argnums,
static_argnums, static_argnames, device, backend)
def infer_params(*args, **kwargs):
if kwargs and not _is_unspecified(in_axis_resources):
raise ValueError(
"pjit does not support kwargs when in_axis_resources is specified.")
# Putting this outside of wrapped would make resources lexically scoped
resource_env = pxla.thread_resources.env
pjit_mesh = resource_env.physical_mesh
if pjit_mesh.empty:
if config.jax_array:
# Don't enforce requiring a mesh when `jax_array` flag is enabled. But
# if mesh is not empty then pjit will respect it.
pass
else:
raise RuntimeError("pjit requires a non-empty mesh! Are you sure that "
"it's defined at the call site?")
pjit_info_args = PjitInfo(
fun=fun, in_axis_resources=in_axis_resources,
out_axis_resources=out_axis_resources, static_argnums=static_argnums,
static_argnames=static_argnames, donate_argnums=donate_argnums,
device=device, backend=backend, keep_unused=keep_unused,
inline=inline, resource_env=resource_env)
return common_infer_params(pjit_info_args, *args, **kwargs)
if (backend or device) and not pjit_mesh.empty:
raise ValueError(
"Mesh context manager should not be used with jit when backend or "
"device is also specified as an argument to jit.")
return post_infer_params(fun, infer_params, static_argnums, static_argnames)
f = lu.wrap_init(fun)
f, dyn_args = argnums_partial_except(f, static_argnums, args,
allow_invalid=True)
del args
# TODO(yashkatariya): Merge the nokwargs and kwargs path. One blocker is
# flatten_axes which if kwargs are present in the treedef (even empty {}),
# leads to wrong expansion.
if kwargs:
f, dyn_kwargs = argnames_partial_except(f, static_argnames, kwargs)
args_flat, in_tree = tree_flatten((dyn_args, dyn_kwargs))
flat_fun, out_tree = flatten_fun(f, in_tree)
else:
args_flat, in_tree = tree_flatten(dyn_args)
flat_fun, out_tree = flatten_fun_nokwargs(f, in_tree)
dyn_kwargs = ()
del kwargs
if donate_argnums and not config.jax_debug_nans:
donated_invars = donation_vector(donate_argnums, dyn_args, dyn_kwargs)
else:
donated_invars = (False,) * len(args_flat)
if config.jax_array:
# If backend or device is set as an arg on jit, then resolve them to
# in_shardings and out_shardings as if user passed in in_axis_resources
# and out_axis_resources.
if backend or device:
in_shardings = out_shardings = _create_sharding_with_device_backend(
device, backend)
else:
in_shardings = tree_map(
lambda x: _create_sharding_for_array(pjit_mesh, x), in_axis_resources)
out_shardings = tree_map(
lambda x: _create_sharding_for_array(pjit_mesh, x), out_axis_resources)
else:
in_shardings = tree_map(
lambda x: _create_mesh_pspec_sharding_from_parsed_pspec(pjit_mesh, x),
in_axis_resources)
out_shardings = tree_map(
lambda x: x if _is_unspecified(x) else
_create_mesh_pspec_sharding_from_parsed_pspec(pjit_mesh, x), out_axis_resources)
# This check fails extremely rarely and has a huge cost in the dispatch
# path. So hide it behind the jax_enable_checks flag.
if config.jax_enable_checks:
_maybe_check_pjit_gda_mesh(args_flat, pjit_mesh)
local_in_avals = tuple(shaped_abstractify(a) for a in args_flat)
# TODO(yashkatariya): This is a hack. This should go away when avals have
# is_global attribute.
if config.jax_array:
in_positional_semantics = (pxla._PositionalSemantics.GLOBAL,) * len(args_flat)
else:
in_positional_semantics = tuple(tree_map(_get_in_positional_semantics, args_flat))
out_positional_semantics = (
pxla._PositionalSemantics.GLOBAL
if config.jax_parallel_functions_output_gda or config.jax_array else
pxla._positional_semantics.val)
global_in_avals, canonicalized_in_shardings_flat = _process_in_axis_resources(
hashable_pytree(in_shardings), local_in_avals, in_tree, in_positional_semantics,
tuple(isinstance(a, GDA) for a in args_flat), resource_env)
jaxpr, consts, canonicalized_out_shardings_flat = _pjit_jaxpr(
flat_fun, hashable_pytree(out_shardings), global_in_avals,
HashableFunction(out_tree, closure=()))
if (any(_is_from_gda(i) for i in canonicalized_in_shardings_flat) or
not config.jax_array):
canonicalized_in_shardings_flat = _maybe_replace_from_gda_with_pspec(
canonicalized_in_shardings_flat, args_flat)
assert len(args_flat) == len(canonicalized_in_shardings_flat)
canonicalized_in_shardings_flat = (
_UNSPECIFIED,) * len(consts) + canonicalized_in_shardings_flat
donated_invars = (False,) * len(consts) + donated_invars
in_positional_semantics = (
pxla._PositionalSemantics.GLOBAL,) * len(consts) + in_positional_semantics
# in_shardings and out_shardings here are all OpShardingSharding.
params = dict(
jaxpr=jaxpr,
in_shardings=canonicalized_in_shardings_flat,
out_shardings=canonicalized_out_shardings_flat,
resource_env=resource_env,
donated_invars=donated_invars,
name=getattr(flat_fun, '__name__', '<unnamed function>'),
in_positional_semantics=in_positional_semantics,
out_positional_semantics=out_positional_semantics,
keep_unused=keep_unused,
inline=inline,
)
return (consts + args_flat, local_in_avals, params, in_tree, out_tree(),
donate_argnums)
if FLAGS.experimental_cpp_pjit and xla_extension_version >= 115:
wrapped = _cpp_pjit(fun, infer_params, static_argnums, static_argnames)
else:
wrapped = _python_pjit(fun, infer_params)
def lower(*args, **kwargs):
(args_flat, flat_local_in_avals, params, in_tree, out_tree,
donate_argnums) = infer_params(*args, **kwargs)
if config.jax_array:
in_shardings = _resolve_in_shardings(
args_flat, params['in_shardings'], params['out_shardings'],
params['resource_env'].physical_mesh)
else:
in_shardings = params['in_shardings']
in_is_global = _calc_is_global_sequence(
params['in_positional_semantics'], in_shardings)
lowering = _pjit_lower(
params['jaxpr'], in_shardings, params['out_shardings'],
params['resource_env'], params['donated_invars'], params['name'],
in_is_global, params['keep_unused'], always_lower=True)
if kwargs:
args_kwargs_in_tree = in_tree
local_in_avals = in_tree.unflatten(flat_local_in_avals)
else:
args_kwargs_in_tree = treedef_tuple([in_tree, tree_flatten({})[1]])
local_in_avals = args_kwargs_in_tree.unflatten(flat_local_in_avals)
return stages.Lowered.from_flat_info(
lowering,
args_kwargs_in_tree,
local_in_avals,
donate_argnums,
out_tree,
no_kwargs=True)
wrapped.lower = lower
return wrapped
class _ListWithW(list):
__slots__ = ('__weakref__',)
@@ -543,6 +591,10 @@ def _create_sharding_for_array(mesh, x):
# FROM_GDA is removed.
if isinstance(x, XLACompatibleSharding) or _is_unspecified_or_from_gda_or_auto(x):
return x
if mesh is None:
raise RuntimeError(
"jit does not support using the mesh context manager. Please pass in "
"the sharding explicitly via in_axis_resources or out_axis_resources.")
if mesh.empty:
raise RuntimeError("pjit requires a non-empty mesh! Is a mesh defined at "
"the call site? Alternatively, provide a "
@@ -944,7 +996,10 @@ pjit_p = core.Primitive("pjit")
pjit_p.multiple_results = True
def _resolve_in_shardings(args, pjit_in_shardings, out_shardings, pjit_mesh):
def _resolve_in_shardings(
args, pjit_in_shardings: Sequence[PjitSharding],
out_shardings: Sequence[PjitSharding],
pjit_mesh: Optional[pxla.Mesh]) -> Sequence[PjitSharding]:
# If True, means that device or backend is set by the user on pjit and it
# has the same semantics as device_put i.e. doesn't matter which device the
# arg is on, reshard it to the device mentioned. So don't do any of the
@@ -972,7 +1027,7 @@ def _resolve_in_shardings(args, pjit_in_shardings, out_shardings, pjit_mesh):
pxla._get_and_check_device_assignment(
it.chain(
committed_arg_shardings, pjit_in_shardings, out_shardings),
(None if pjit_mesh.empty else list(pjit_mesh.devices.flat)))
(None if pjit_mesh is None or pjit_mesh.empty else list(pjit_mesh.devices.flat)))
resolved_in_shardings = []
for arg, pjit_in_s in safe_zip(args, pjit_in_shardings):
@@ -1037,8 +1092,9 @@ def _pjit_call_impl(*args, jaxpr,
global _most_recent_pjit_call_executable
if config.jax_array:
in_shardings = _resolve_in_shardings(args, in_shardings, out_shardings,
resource_env.physical_mesh)
in_shardings = _resolve_in_shardings(
args, in_shardings, out_shardings,
resource_env.physical_mesh if resource_env is not None else None)
in_is_global = _calc_is_global_sequence(in_positional_semantics, in_shardings)
if config.jax_array and all(_is_unspecified(o) for o in out_shardings):
@@ -1145,11 +1201,15 @@ def _pjit_lower_cached(
out_shardings: Tuple[PjitSharding, ...] = sdat_out_shardings.shardings
pxla.resource_typecheck(jaxpr, resource_env, {}, lambda: "pjit")
f = core.jaxpr_as_fun(jaxpr)
f.__name__ = name
fun = lu.wrap_init(f)
mesh = resource_env.physical_mesh
if resource_env is not None:
mesh = resource_env.physical_mesh
else:
mesh = None
# Convert to `NamedSharding` when `jax_array` is not enabled. This is
# because GDA/SDA/DA are dependent on mesh for generating outputs.
@@ -1187,7 +1247,8 @@ def _pjit_lower_cached(
fun, 'pjit', name, in_shardings, out_shardings, donated_invars,
jaxpr.in_avals, in_is_global=in_is_global, keep_unused=keep_unused,
always_lower=always_lower,
devices_from_context=(None if mesh.empty else list(mesh.devices.flat)))
devices_from_context=(
None if mesh is None or mesh.empty else list(mesh.devices.flat)))
def pjit_staging_rule(trace, *args, **params):
@@ -1207,6 +1268,8 @@ def _pjit_abstract_eval(*args, jaxpr, out_shardings, resource_env,
disallowed_effects = jaxpr.effects - mlir.lowerable_effects
if disallowed_effects:
raise ValueError('Effects not supported in `pjit`.')
if config.jax_array:
return jaxpr.out_avals, jaxpr.effects
return global_to_local(out_positional_semantics, jaxpr.out_avals,
out_shardings, resource_env.physical_mesh), jaxpr.effects
pjit_p.def_effectful_abstract_eval(_pjit_abstract_eval)
@@ -1216,7 +1279,7 @@ def _pjit_lowering(ctx, *args, name, jaxpr, in_shardings,
out_shardings, resource_env, donated_invars,
in_positional_semantics, out_positional_semantics,
keep_unused, inline):
if not config.jax_array:
if not config.jax_jit_pjit_api_merge:
if not isinstance(ctx.module_context.axis_context,
(mlir.SPMDAxisContext, mlir.ShardingContext)):
raise RuntimeError("Nesting pjit() inside jit() is not allowed.")
@@ -1264,7 +1327,12 @@ def _pjit_batcher(insert_axis, spmd_axis_name,
# `insert_axis` is set to True only for some `xmap` uses.
new_parts = (axis_name,) if insert_axis else (
() if spmd_axis_name is None else (spmd_axis_name,))
mesh = resource_env.physical_mesh
if resource_env is not None:
mesh = resource_env.physical_mesh
else:
mesh = None
in_shardings = tuple(
_pjit_batcher_for_sharding(i, 0, new_parts, mesh, aval.ndim) if is_mapped else i
for is_mapped, i, aval in zip(is_mapped_in, in_shardings, new_jaxpr.in_avals))
@@ -1302,7 +1370,7 @@ def _pjit_batcher_for_sharding(
return OpShardingSharding(s._device_assignment, new_op) # type: ignore
else:
assert isinstance(s, OpShardingSharding)
assert not mesh.empty
assert mesh is not None and not mesh.empty
parsed_pspec = parse_flatten_op_sharding(s._op_sharding, mesh)[0] # type: ignore
parsed_pspec = parsed_pspec.insert_axis_partitions(dim, val)
mps = NamedSharding._from_parsed_pspec(mesh, parsed_pspec)
@@ -1378,32 +1446,37 @@ def _pjit_partial_eval(trace, *in_tracers,
keep_unused=keep_unused,
inline=inline)
if num_residuals:
in_is_global = _calc_is_global_sequence(
known_params['in_positional_semantics'], known_params['in_shardings'])
compiled = _pjit_lower(
known_params["jaxpr"], known_params["in_shardings"],
known_params["out_shardings"], known_params["resource_env"],
known_params["donated_invars"], known_params["name"],
in_is_global, known_params['keep_unused'], always_lower=False).compile(
_allow_propagation_to_outputs=True,
_allow_compile_replicated=False)
da = compiled._device_assignment
_, out_op_sharding_shardings = pxla._get_op_sharding_shardings_from_executable(
compiled.xla_executable, da, len(known_jaxpr.in_avals),
len(known_jaxpr.out_avals))
assert len(out_op_sharding_shardings) == len(known_jaxpr.out_avals), (
len(out_op_sharding_shardings), len(known_jaxpr.out_avals))
out_op_shardings = [o._to_xla_op_sharding(a.ndim) for o, a in
safe_zip(out_op_sharding_shardings, known_jaxpr.out_avals)]
residual_op_shardings = tuple(out_op_shardings[-num_residuals:])
else:
residual_op_shardings = ()
assert len(residual_shardings) == len(residual_op_shardings), (
len(residual_shardings), len(residual_op_shardings))
residual_shardings = tuple(OpShardingSharding(da, op) for op in residual_op_shardings)
known_params['out_shardings'] = (
keep_where(out_shardings, known_outs) + residual_shardings)
# resource_env is None in the jit wrapper around pjit.
# TODO(apaszke,yashkatariya): Replace this check with
# `if not config.jax_array` after XLA stops overriding user shardings when
# `_allow_propagation_to_outputs = True`.
if resource_env is not None:
if num_residuals:
in_is_global = _calc_is_global_sequence(
known_params['in_positional_semantics'], known_params['in_shardings'])
compiled = _pjit_lower(
known_params["jaxpr"], known_params["in_shardings"],
known_params["out_shardings"], known_params["resource_env"],
known_params["donated_invars"], known_params["name"],
in_is_global, known_params['keep_unused'], always_lower=False).compile(
_allow_propagation_to_outputs=True,
_allow_compile_replicated=False)
da = compiled._device_assignment
_, out_op_sharding_shardings = pxla._get_op_sharding_shardings_from_executable(
compiled.xla_executable, da, len(known_jaxpr.in_avals),
len(known_jaxpr.out_avals))
assert len(out_op_sharding_shardings) == len(known_jaxpr.out_avals), (
len(out_op_sharding_shardings), len(known_jaxpr.out_avals))
out_op_shardings = [o._to_xla_op_sharding(a.ndim) for o, a in
safe_zip(out_op_sharding_shardings, known_jaxpr.out_avals)]
residual_op_shardings = tuple(out_op_shardings[-num_residuals:])
else:
residual_op_shardings = ()
assert len(residual_shardings) == len(residual_op_shardings), (
len(residual_shardings), len(residual_op_shardings))
residual_shardings = tuple(OpShardingSharding(da, op) for op in residual_op_shardings)
known_params['out_shardings'] = (
keep_where(out_shardings, known_outs) + residual_shardings)
all_known_outs = pjit_p.bind(
*(pv.get_known() for pv in in_pvals if pv.is_known()),
@@ -1436,12 +1509,16 @@ def _pjit_partial_eval(trace, *in_tracers,
keep_unused=keep_unused,
inline=inline)
unknown_tracers_in = [t for t in in_tracers if not t.pval.is_known()]
if config.jax_array:
unknown_out_avals = unknown_jaxpr.out_avals
else:
unknown_out_avals = global_to_local(
unknown_params["out_positional_semantics"], unknown_jaxpr.out_avals,
unknown_params["out_shardings"],
unknown_params["resource_env"].physical_mesh)
unknown_tracers_out = [
pe.JaxprTracer(trace, pe.PartialVal.unknown(aval), None)
for aval in global_to_local(unknown_params["out_positional_semantics"],
unknown_jaxpr.out_avals,
unknown_params["out_shardings"],
unknown_params["resource_env"].physical_mesh)
for aval in unknown_out_avals
]
eqn = pe.new_eqn_recipe((*unknown_tracers_in, *residual_tracers),
unknown_tracers_out,
@@ -1525,8 +1602,9 @@ def _check_resources_against_named_axes(what, aval, pos_axis_resources, named_ax
def _resource_typing_pjit(avals, params, source_info, resource_env, named_axis_resources):
jaxpr = params["jaxpr"]
what = "pjit input"
if resource_env.physical_mesh != params['resource_env'].physical_mesh:
raise RuntimeError("Changing the physical mesh is not allowed inside pjit.")
if (resource_env is not None and params['resource_env'] is not None and
resource_env.physical_mesh != params['resource_env'].physical_mesh):
raise RuntimeError("Changing the physical mesh is not allowed inside pjit.")
for aval, s in zip(jaxpr.in_avals, params['in_shardings']):
if _is_unspecified(s) or _is_auto(s):
@@ -1535,9 +1613,14 @@ def _resource_typing_pjit(avals, params, source_info, resource_env, named_axis_r
s._original_sharding, '_parsed_pspec'):
parsed_pspec = s._original_sharding._parsed_pspec
else:
parsed_pspec = parse_flatten_op_sharding(
s._op_sharding, resource_env.physical_mesh)[0]
_check_resources_against_named_axes(what, aval, parsed_pspec, named_axis_resources)
if resource_env is not None:
parsed_pspec = parse_flatten_op_sharding(
s._op_sharding, resource_env.physical_mesh)[0]
else:
parsed_pspec = None
if parsed_pspec is not None:
_check_resources_against_named_axes(what, aval, parsed_pspec,
named_axis_resources)
pxla.resource_typecheck(
jaxpr.jaxpr, resource_env, named_axis_resources,
@@ -1552,9 +1635,14 @@ def _resource_typing_pjit(avals, params, source_info, resource_env, named_axis_r
s._original_sharding, '_parsed_pspec'):
parsed_pspec = s._original_sharding._parsed_pspec
else:
parsed_pspec = parse_flatten_op_sharding(
s._op_sharding, resource_env.physical_mesh)[0]
_check_resources_against_named_axes(what, aval, parsed_pspec, named_axis_resources)
if resource_env is not None:
parsed_pspec = parse_flatten_op_sharding(
s._op_sharding, resource_env.physical_mesh)[0]
else:
parsed_pspec = None
if parsed_pspec is not None:
_check_resources_against_named_axes(what, aval, parsed_pspec,
named_axis_resources)
pxla.custom_resource_typing_rules[pjit_p] = _resource_typing_pjit


@@ -2710,8 +2710,10 @@ def _check_if_any_auto(
def _get_and_check_device_assignment(
shardings: Iterable[sharding_internal.XLACompatibleSharding],
devices: Optional[Sequence[xc.Device]]) -> Tuple[xla.Backend, Sequence[xc.Device]]:
shardings: Iterable[Union[sharding_internal.XLACompatibleSharding,
_UnspecifiedValue, _AUTOAxisResource]],
devices: Optional[Sequence[xc.Device]]
) -> Tuple[xla.Backend, Sequence[xc.Device]]:
from jax._src.api import local_devices
first_device_assignment = None


@@ -199,6 +199,7 @@ jax_test(
backend_tags = {
"tpu": ["notsan"], # Times out under tsan.
},
enable_configs = ["cpu_jit_pjit_api_merge"],
pjrt_c_api_bypass = True,
shard_count = {
"cpu": 5,


@@ -600,7 +600,10 @@ class PJitTest(jtu.BufferDonationTestCase):
f = pjit(lambda x: jax.grad(h)(x),
in_axis_resources=None, out_axis_resources=None)
x = jnp.arange(8, dtype=jnp.float32)
self.assertAllClose(f(x), jnp.cos(x))
out = f(x)
self.assertAllClose(out, jnp.cos(x))
if jax.config.jax_array:
self.assertLen(out.devices(), 2)
@jtu.with_mesh([('x', 2)])
def testNoopPartitionSpecs(self):
@@ -2081,8 +2084,8 @@ class ArrayPjitTest(jtu.JaxTestCase):
@jax_array(True)
def test_pjit_single_device_sharding_add(self):
a = jnp.array([1, 2, 3], dtype=jnp.float32)
b = jnp.array([4, 5, 6], dtype=jnp.float32)
a = np.array([1, 2, 3], dtype=jnp.float32)
b = np.array([4, 5, 6], dtype=jnp.float32)
@pjit
def add(x, y):
@@ -2462,11 +2465,18 @@ class ArrayPjitTest(jtu.JaxTestCase):
out = jnp.zeros(shape, jnp.bfloat16)
return jax.lax.with_sharding_constraint(out, NamedSharding(mesh, pspec))
with self.assertRaisesRegex(
ValueError,
"Pjit's devices and Array's devices should be equal. "
r"Got Pjit's device ids \[0\] on platform.*and "
r"Array's device ids \[0, 1, 2, 3\] on platform"):
# This split is needed because original `jit` adds `device` as a
# `devices_from_context` whereas `pjit` passes it as an in_sharding.
if jax.config.jax_jit_pjit_api_merge:
error_msg = ("Devices of all `Array` inputs and outputs should be the same. "
r"Got array device ids \[0\] on platform.*and "
r"another array's device ids \[0, 1, 2, 3\] on platform")
else:
error_msg = ("Pjit's devices and Array's devices should be equal. "
r"Got Pjit's device ids \[0\] on platform.*and "
r"Array's device ids \[0, 1, 2, 3\] on platform")
with self.assertRaisesRegex(ValueError, error_msg):
sharded_zeros((4096, 3072), P('x', 'y'))
@jax_array(True)
@@ -2920,7 +2930,6 @@ class ArrayPjitTest(jtu.JaxTestCase):
_check(out2, jax.devices()[1], y)
self.assertEqual(cache_info2.hits, cache_info1.hits + 1)
self.assertEqual(cache_info2.misses, cache_info1.misses)
h = pjit(mul, device=jax.devices()[-1])
h_out = h(y)
@@ -2928,7 +2937,6 @@ class ArrayPjitTest(jtu.JaxTestCase):
_check(h_out, jax.devices()[-1], y)
self.assertEqual(cache_info3.hits, cache_info2.hits)
self.assertEqual(cache_info3.misses, cache_info2.misses + 1)
# AOT test
compiled = f.lower(jax.ShapedArray(y.shape, y.dtype)).compile()
@@ -3130,6 +3138,31 @@ class ArrayPjitTest(jtu.JaxTestCase):
# Second call is to trigger C++ dispatch.
f(inp) # doesn't crash
def test_pjit_sin_nested(self):
mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
@pjit
def f(x):
return jnp.sin(x)
with mesh:
inp = jnp.arange(8.)
out = f(inp)
self.assertArraysAllClose(out, np.sin(inp))
self.assertLen(out.devices(), 8)
def test_jit_with_mesh_context_manager(self):
if not jax.config.jax_jit_pjit_api_merge:
self.skipTest("This test only works if jax_jit_pjit_api_merge is True")
mesh = jtu.create_global_mesh((1,), ('x',))
with self.assertRaisesRegex(
RuntimeError,
"jit does not support using the mesh context manager"):
with mesh:
jax.jit(lambda x: x, in_axis_resources=P('x'),
out_axis_resources=P('x'))(jnp.arange(8))
class TempSharding(Sharding):