Refactorings to the jit implementation.

Notably:
* We can share more code between jit/pjit: there is no significant difference between the two other than the handling of the resource environment, so most of the implementation can be common.
* Rather than having an infer_params callback, we can just teach common_infer_params (now named _infer_params) to handle the resource environment, which is the only meaningful difference. common_infer_params already had to understand the two cases, so there's no reason we need to hoist part of that logic into a callback.
* If we slightly alter the role of PjitInfo so it contains only the things we know about a jit() or can deduce from its arguments, we can construct it ahead of time. This does require that we split out a couple of things that we cannot deduce at that time — namely the resource environment and the two layout parameters — into separate arguments, but the result reads more cleanly to me.

No functional changes intended, this is just to improve readability.

PiperOrigin-RevId: 617812557
This commit is contained in:
Peter Hawkins 2024-03-21 05:35:44 -07:00 committed by jax authors
parent 2bd579bc61
commit d3e03fff5d
2 changed files with 122 additions and 124 deletions

View File

@ -300,34 +300,10 @@ def jit(
>>> g(jnp.arange(4), 3)
Array([ 0, 1, 256, 6561], dtype=int32)
"""
(in_shardings, out_shardings, donate_argnums, donate_argnames, static_argnums,
static_argnames) = pjit.pre_infer_params(
return pjit.make_jit(
fun, in_shardings, out_shardings, donate_argnums, donate_argnames,
static_argnums, static_argnames, device, backend, abstracted_axes)
fun_sourceinfo = api_util.fun_sourceinfo(fun)
fun_signature = api_util.fun_signature(fun)
def infer_params(*args, **kwargs):
# TODO(yashkatariya): Remove this when it's added on jit.
in_layouts = kwargs.pop('_in_layouts', None)
out_layouts = kwargs.pop('_out_layouts', None)
pjit_info_args = pjit.PjitInfo(
fun=fun, fun_sourceinfo=fun_sourceinfo, fun_signature=fun_signature,
in_shardings=in_shardings,
out_shardings=out_shardings, static_argnums=static_argnums,
static_argnames=static_argnames, donate_argnums=donate_argnums,
donate_argnames=donate_argnames, device=device, backend=backend,
keep_unused=keep_unused, inline=inline, resource_env=None,
abstracted_axes=abstracted_axes, in_layouts=in_layouts,
out_layouts=out_layouts)
return pjit.common_infer_params(pjit_info_args, *args, **kwargs)
has_explicit_sharding = pjit._pjit_explicit_sharding(
in_shardings, out_shardings, device, backend)
return pjit.post_infer_params(fun, infer_params, static_argnums,
static_argnames, donate_argnums,
abstracted_axes, has_explicit_sharding)
static_argnums, static_argnames, device, backend, abstracted_axes,
keep_unused, inline, use_resource_env=False)
@contextmanager

View File

@ -132,9 +132,33 @@ def _device_assignment_mismatch_error(fun_name, fails, args_flat, api_name,
return msg
def _python_pjit_helper(fun, infer_params_fn, *args, **kwargs):
class PjitInfo(NamedTuple):
  """Things that we know about a jit instance before it is called.

  In other words, this structure contains arguments to jit()/pjit(),
  preprocessed and validated.
  """
  # The user function to be jitted.
  fun: Callable
  # Source info / signature captured via api_util at construction time;
  # either may be unavailable (None) for some callables.
  fun_sourceinfo: str | None
  fun_signature: inspect.Signature | None
  # User-requested shardings for inputs/outputs (pytrees; leaves may be
  # "unspecified" sentinels).
  in_shardings: Any
  out_shardings: Any
  # Argument positions/names resolved and normalized to tuples
  # (see resolve_argnums in _parse_jit_arguments).
  static_argnums: tuple[int, ...]
  static_argnames: tuple[str, ...]
  donate_argnums: tuple[int, ...]
  donate_argnames: tuple[str, ...]
  # Optional explicit device/backend placement requested by the caller.
  device: xc.Device | None
  backend: str | None
  keep_unused: bool
  inline: bool
  abstracted_axes: Any | None
  # Precomputed result of _pjit_explicit_sharding(in_shardings,
  # out_shardings, device, backend).
  has_explicit_sharding: bool
  use_resource_env: bool  # False for jit, True for pjit
def _python_pjit_helper(jit_info, *args, **kwargs):
args_flat, _, params, _, out_tree, _, _, _, arg_names, attrs_tracked = \
infer_params_fn(*args, **kwargs)
_infer_params(jit_info, args, kwargs)
for arg in args_flat:
dispatch.check_arg(arg)
if attrs_tracked:
@ -145,6 +169,7 @@ def _python_pjit_helper(fun, infer_params_fn, *args, **kwargs):
except pxla.DeviceAssignmentMismatchError as e:
fails, = e.args
api_name = 'jit' if params['resource_env'] is None else 'pjit'
fun = jit_info.fun
fun_name = getattr(fun, '__qualname__', getattr(fun, '__name__', str(fun)))
msg = _device_assignment_mismatch_error(
fun_name, fails, args_flat, api_name, arg_names)
@ -165,14 +190,16 @@ def _get_states(attrs_tracked):
return [jax_getattr(obj, attr) for (obj, attr) in attrs_tracked]
def _python_pjit(fun: Callable, infer_params_fn):
def _python_pjit(jit_info: PjitInfo):
fun = jit_info.fun
@wraps(fun)
@api_boundary
def wrapped(*args, **kwargs):
if config.disable_jit.value:
return fun(*args, **kwargs)
return _python_pjit_helper(fun, infer_params_fn, *args, **kwargs)[0]
return _python_pjit_helper(jit_info, *args, **kwargs)[0]
def _python_pjit_evict_fn():
_create_pjit_jaxpr.evict_function(fun) # type: ignore
@ -254,31 +281,31 @@ def _get_cpp_global_cache(pjit_has_explicit_sharding):
return _cpp_pjit_cache
def _cpp_pjit(fun: Callable, infer_params_fn, static_argnums, static_argnames,
donate_argnums, pjit_has_explicit_sharding):
def _cpp_pjit(jit_info: PjitInfo):
@api_boundary
def cache_miss(*args, **kwargs):
outs, out_flat, out_tree, args_flat, jaxpr, attrs_tracked = _python_pjit_helper(
fun, infer_params_fn, *args, **kwargs)
jit_info, *args, **kwargs)
executable = _read_most_recent_pjit_call_executable(jaxpr)
maybe_fastpath_data = _get_fastpath_data(
executable, out_tree, args_flat, out_flat, attrs_tracked, jaxpr.effects)
return outs, maybe_fastpath_data
fun = jit_info.fun
if xla_extension_version >= 226:
cpp_pjit_f = xc._xla.pjit( # type: ignore
getattr(fun, "__name__", "<unnamed function>"),
fun, cache_miss, static_argnums, static_argnames,
donate_argnums, tree_util.dispatch_registry,
fun, cache_miss, jit_info.static_argnums, jit_info.static_argnames,
jit_info.donate_argnums, tree_util.dispatch_registry,
pxla.shard_arg if xla_extension_version >= 229 else pxla.temp_shard_arg, # type: ignore
_get_cpp_global_cache(pjit_has_explicit_sharding)) # type: ignore
_get_cpp_global_cache(jit_info.has_explicit_sharding)) # type: ignore
else:
cpp_pjit_f = xc._xla.pjit( # type: ignore
getattr(fun, "__name__", "<unnamed function>"),
fun, cache_miss, static_argnums, static_argnames,
donate_argnums, tree_util.dispatch_registry,
_get_cpp_global_cache(pjit_has_explicit_sharding))
fun, cache_miss, jit_info.static_argnums, jit_info.static_argnames,
jit_info.donate_argnums, tree_util.dispatch_registry,
_get_cpp_global_cache(jit_info.has_explicit_sharding))
cpp_pjitted_f = wraps(fun)(cpp_pjit_f)
cpp_pjitted_f._fun = fun
@ -286,10 +313,29 @@ def _cpp_pjit(fun: Callable, infer_params_fn, static_argnums, static_argnames,
return cpp_pjitted_f
def pre_infer_params(fun, in_shardings, out_shardings,
donate_argnums, donate_argnames,
static_argnums, static_argnames, device,
backend, abstracted_axes):
def _pjit_explicit_sharding(in_shardings, out_shardings, device,
                            backend) -> bool:
  """Returns True if the call requests any explicit placement.

  That is: a device or backend was given, or at least one leaf of the
  in/out sharding pytrees is not the "unspecified" sentinel.
  """
  if device is not None or backend is not None:
    return True
  in_leaves, _ = tree_flatten(in_shardings)
  out_leaves, _ = tree_flatten(out_shardings)
  return not all(is_unspecified(s) for s in in_leaves + out_leaves)
def _parse_jit_arguments(fun: Callable, in_shardings: Any, out_shardings: Any,
donate_argnums: int | Sequence[int] | None,
donate_argnames: str | Iterable[str] | None,
static_argnums: int | Sequence[int] | None,
static_argnames: str | Iterable[str] | None,
device: xc.Device | None, backend: str | None,
abstracted_axes: Any | None, keep_unused: bool,
inline: bool, use_resource_env: bool) -> PjitInfo:
"""Parses the arguments to jit/pjit.
Performs any preprocessing and validation of the arguments that we can do
ahead of time before the jit()-ed function is invoked.
"""
if abstracted_axes and not config.dynamic_shapes.value:
raise ValueError("abstracted_axes must be used with --jax_dynamic_shapes")
@ -326,18 +372,31 @@ def pre_infer_params(fun, in_shardings, out_shardings,
donate_argnums, donate_argnames, static_argnums, static_argnames = resolve_argnums(
fun, donate_argnums, donate_argnames, static_argnums, static_argnames)
return (in_shardings, out_shardings, donate_argnums, donate_argnames,
static_argnums, static_argnames)
fun_sourceinfo = api_util.fun_sourceinfo(fun)
fun_signature = api_util.fun_signature(fun)
has_explicit_sharding = _pjit_explicit_sharding(
in_shardings, out_shardings, device, backend)
return PjitInfo(
fun=fun,
fun_sourceinfo=fun_sourceinfo,
fun_signature=fun_signature,
in_shardings=in_shardings,
out_shardings=out_shardings, static_argnums=static_argnums,
static_argnames=static_argnames, donate_argnums=donate_argnums,
donate_argnames=donate_argnames, device=device, backend=backend,
keep_unused=keep_unused, inline=inline,
abstracted_axes=abstracted_axes,
has_explicit_sharding=has_explicit_sharding,
use_resource_env=use_resource_env)
def post_infer_params(fun, infer_params_fn, static_argnums, static_argnames,
donate_argnums, abstracted_axes,
pjit_has_explicit_sharding):
if abstracted_axes is None:
wrapped = _cpp_pjit(fun, infer_params_fn, static_argnums, static_argnames,
donate_argnums, pjit_has_explicit_sharding)
def _make_jit_wrapper(jit_info: PjitInfo):
if jit_info.abstracted_axes is None:
wrapped = _cpp_pjit(jit_info)
else:
wrapped = _python_pjit(fun, infer_params_fn)
wrapped = _python_pjit(jit_info)
@api_boundary
def lower(*args, **kwargs):
@ -348,8 +407,8 @@ def post_infer_params(fun, infer_params_fn, static_argnums, static_argnames,
out_layouts = kwargs.pop('_out_layouts', None)
(args_flat, flat_global_in_avals, params, in_tree, out_tree,
donated_invars, in_layouts_flat, out_layouts_flat,
arg_names, ()) = infer_params_fn(
*args, **kwargs, _in_layouts=in_layouts, _out_layouts=out_layouts)
arg_names, ()) = _infer_params(
jit_info, args, kwargs, in_layouts=in_layouts, out_layouts=out_layouts)
resource_env = params['resource_env']
mesh = None if resource_env is None else resource_env.physical_mesh
try:
@ -363,7 +422,9 @@ def post_infer_params(fun, infer_params_fn, static_argnums, static_argnames,
except pxla.DeviceAssignmentMismatchError as e:
fails, = e.args
api_name = 'jit' if params['resource_env'] is None else 'pjit'
fun_name = getattr(fun, '__qualname__', getattr(fun, '__name__', str(fun)))
fun = jit_info.fun
fun_name = getattr(fun, '__qualname__',
getattr(fun, '__name__', str(fun)))
msg = _device_assignment_mismatch_error(
fun_name, fails, args_flat, api_name, arg_names)
raise ValueError(msg) from None
@ -375,8 +436,9 @@ def post_infer_params(fun, infer_params_fn, static_argnums, static_argnames,
@api_boundary
def eval_shape(*args, **kwargs):
_, _, params, _, out_tree, _, _, _, _, _ = infer_params_fn(
*args, **kwargs, _in_layouts=None, _out_layouts=None)
_, _, params, _, out_tree, _, _, _, _, _ = _infer_params(
jit_info, args, kwargs, in_layouts=None, out_layouts=None
)
out_s = [None if is_unspecified(s) else getattr(s, '_original_sharding', s)
for s in params['out_shardings']]
out = [api.ShapeDtypeStruct(x.shape, x.dtype, x.named_shape, sharding=s)
@ -387,52 +449,43 @@ def post_infer_params(fun, infer_params_fn, static_argnums, static_argnames,
wrapped.eval_shape = eval_shape
return wrapped
def _pjit_explicit_sharding(in_shardings, out_shardings, device,
backend) -> bool:
in_shardings_flat, _ = tree_flatten(in_shardings)
out_shardings_flat, _ = tree_flatten(out_shardings)
return (device is not None or
backend is not None or
any(not is_unspecified(i) for i in in_shardings_flat) or
any(not is_unspecified(i) for i in out_shardings_flat))
def make_jit(fun: Callable, in_shardings: Any, out_shardings: Any,
             donate_argnums: int | Sequence[int] | None,
             donate_argnames: str | Iterable[str] | None,
             static_argnums: int | Sequence[int] | None,
             static_argnames: str | Iterable[str] | None,
             device: xc.Device | None, backend: str | None,
             abstracted_axes: Any | None, keep_unused: bool,
             inline: bool, use_resource_env: bool) -> Any:
  """Shared implementation behind jit() and pjit().

  Validates/preprocesses the arguments into a PjitInfo and returns the
  resulting jit wrapper callable.
  """
  return _make_jit_wrapper(_parse_jit_arguments(
      fun, in_shardings, out_shardings, donate_argnums, donate_argnames,
      static_argnums, static_argnames, device, backend, abstracted_axes,
      keep_unused, inline, use_resource_env))
class PjitInfo(NamedTuple):
  """Preprocessed arguments describing a single jit()/pjit() invocation."""
  # The user function being jitted.
  fun: Callable
  # Source info / signature captured via api_util; sourceinfo may be None,
  # and fun_signature may also be unavailable for some callables.
  fun_sourceinfo: str | None
  fun_signature: inspect.Signature | None
  # User-requested shardings for inputs/outputs (pytrees; leaves may be
  # "unspecified" sentinels).
  in_shardings: Any
  out_shardings: Any
  # Argument positions/names normalized to tuples.
  static_argnums: tuple[int, ...]
  static_argnames: tuple[str, ...]
  donate_argnums: tuple[int, ...]
  donate_argnames: tuple[str, ...]
  # Optional explicit device/backend placement requested by the caller.
  device: xc.Device | None
  backend: str | None
  keep_unused: bool
  inline: bool
  # None for jit; for pjit, the thread-local mesh resource environment
  # captured at call time (api_name is chosen by whether this is None).
  resource_env: Any
  abstracted_axes: Any | None
  in_layouts: Any  # pytree[XlaCompatibleLayout] | None
  out_layouts: Any  # pytree[XlaCompatibleLayout] | None
def common_infer_params(pjit_info_args, *args, **kwargs):
def _infer_params(jit_info, args, kwargs, *, in_layouts=None, out_layouts=None):
(fun, fun_sourceinfo, fun_signature, user_in_shardings, user_out_shardings,
static_argnums, static_argnames,
donate_argnums, donate_argnames, device, backend, keep_unused, inline,
resource_env, abstracted_axes, in_layouts, out_layouts) = pjit_info_args
abstracted_axes, _, use_resource_env) = jit_info
if (kwargs and user_in_shardings is not None and
not is_unspecified(user_in_shardings)):
raise ValueError(
"pjit does not support kwargs when in_shardings is specified.")
if resource_env is not None:
if use_resource_env:
# We need to fetch the mesh from inside the wrapped function, because
# meshes are dynamically scoped (i.e., with a context manager).
resource_env = mesh_lib.thread_resources.env
pjit_mesh = resource_env.physical_mesh
jit_name = 'pjit'
else:
resource_env = None
pjit_mesh = None
jit_name = 'jit'
if (backend or device) and pjit_mesh is not None and not pjit_mesh.empty:
raise ValueError(
@ -441,8 +494,6 @@ def common_infer_params(pjit_info_args, *args, **kwargs):
axes_specs = _flat_axes_specs(abstracted_axes, *args, **kwargs)
jit_name = 'jit' if resource_env is None else 'pjit'
dbg = debug_info(jit_name, fun_sourceinfo, fun_signature, args, kwargs,
static_argnums, static_argnames)
f = lu.wrap_init(fun)
@ -782,39 +833,10 @@ def pjit(
... print(f(x)) # doctest: +SKIP
[ 0.5 2. 4. 6. 8. 10. 12. 10. ]
"""
(in_shardings, out_shardings, donate_argnums, donate_argnames, static_argnums,
static_argnames) = pre_infer_params(
return make_jit(
fun, in_shardings, out_shardings, donate_argnums, donate_argnames,
static_argnums, static_argnames, device, backend, abstracted_axes)
fun_sourceinfo = api_util.fun_sourceinfo(fun)
fun_signature = api_util.fun_signature(fun)
def infer_params(*args, **kwargs):
# Putting this outside of wrapped would make resources lexically scoped
resource_env = mesh_lib.thread_resources.env
# TODO(yashkatariya): Remove this when it's added on jit. Also default to
# layout.DefaultLayout() when out of experimental.
in_layouts = kwargs.pop('_in_layouts', None)
out_layouts = kwargs.pop('_out_layouts', None)
pjit_info_args = PjitInfo(
fun=fun,
fun_sourceinfo=fun_sourceinfo,
fun_signature=fun_signature,
in_shardings=in_shardings,
out_shardings=out_shardings, static_argnums=static_argnums,
static_argnames=static_argnames, donate_argnums=donate_argnums,
donate_argnames=donate_argnames, device=device, backend=backend,
keep_unused=keep_unused, inline=inline, resource_env=resource_env,
abstracted_axes=abstracted_axes, in_layouts=in_layouts,
out_layouts=out_layouts)
return common_infer_params(pjit_info_args, *args, **kwargs)
has_explicit_sharding = _pjit_explicit_sharding(
in_shardings, out_shardings, device, backend)
return post_infer_params(fun, infer_params, static_argnums, static_argnames,
donate_argnums, abstracted_axes,
has_explicit_sharding)
static_argnums, static_argnames, device, backend, abstracted_axes,
keep_unused, inline, use_resource_env=True)
def hashable_pytree(pytree):