mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Refactorings to the jit implementation.
Notably: * We can share more code between jit/pjit. There's no significant difference between the two, other than the handling of the resource environment, so we can share more of the code. * Rather than having an infer_params callback, we can just teach common_infer_params (now named _infer_params) to handle the resource environment, which is the only meaningful difference. common_infer_params already had to understand the two cases, so there's no reason we need to hoist part of that logic into a callback. * If we slightly alter the role of PjitInfo so it contains only the things we know about a jit() or can deduce from its arguments, we can construct it ahead of time. This does require that we split out a couple of things that we cannot deduce at that time, namely the resource environment and the two layout parameters into separate arguments, but the result reads more cleanly to me. No functional changes intended, this is just to improve readability. PiperOrigin-RevId: 617812557
This commit is contained in:
parent
2bd579bc61
commit
d3e03fff5d
@ -300,34 +300,10 @@ def jit(
|
||||
>>> g(jnp.arange(4), 3)
|
||||
Array([ 0, 1, 256, 6561], dtype=int32)
|
||||
"""
|
||||
(in_shardings, out_shardings, donate_argnums, donate_argnames, static_argnums,
|
||||
static_argnames) = pjit.pre_infer_params(
|
||||
return pjit.make_jit(
|
||||
fun, in_shardings, out_shardings, donate_argnums, donate_argnames,
|
||||
static_argnums, static_argnames, device, backend, abstracted_axes)
|
||||
|
||||
fun_sourceinfo = api_util.fun_sourceinfo(fun)
|
||||
fun_signature = api_util.fun_signature(fun)
|
||||
|
||||
def infer_params(*args, **kwargs):
|
||||
# TODO(yashkatariya): Remove this when it's added on jit.
|
||||
in_layouts = kwargs.pop('_in_layouts', None)
|
||||
out_layouts = kwargs.pop('_out_layouts', None)
|
||||
pjit_info_args = pjit.PjitInfo(
|
||||
fun=fun, fun_sourceinfo=fun_sourceinfo, fun_signature=fun_signature,
|
||||
in_shardings=in_shardings,
|
||||
out_shardings=out_shardings, static_argnums=static_argnums,
|
||||
static_argnames=static_argnames, donate_argnums=donate_argnums,
|
||||
donate_argnames=donate_argnames, device=device, backend=backend,
|
||||
keep_unused=keep_unused, inline=inline, resource_env=None,
|
||||
abstracted_axes=abstracted_axes, in_layouts=in_layouts,
|
||||
out_layouts=out_layouts)
|
||||
return pjit.common_infer_params(pjit_info_args, *args, **kwargs)
|
||||
|
||||
has_explicit_sharding = pjit._pjit_explicit_sharding(
|
||||
in_shardings, out_shardings, device, backend)
|
||||
return pjit.post_infer_params(fun, infer_params, static_argnums,
|
||||
static_argnames, donate_argnums,
|
||||
abstracted_axes, has_explicit_sharding)
|
||||
static_argnums, static_argnames, device, backend, abstracted_axes,
|
||||
keep_unused, inline, use_resource_env=False)
|
||||
|
||||
|
||||
@contextmanager
|
||||
|
216
jax/_src/pjit.py
216
jax/_src/pjit.py
@ -132,9 +132,33 @@ def _device_assignment_mismatch_error(fun_name, fails, args_flat, api_name,
|
||||
return msg
|
||||
|
||||
|
||||
def _python_pjit_helper(fun, infer_params_fn, *args, **kwargs):
|
||||
class PjitInfo(NamedTuple):
|
||||
"""Things that we know about a jit instance before it is called.
|
||||
|
||||
In other words, this structure contains arguments to jit()/pjit(),
|
||||
preprocessed and validated.
|
||||
"""
|
||||
fun: Callable
|
||||
fun_sourceinfo: str | None
|
||||
fun_signature: inspect.Signature | None
|
||||
in_shardings: Any
|
||||
out_shardings: Any
|
||||
static_argnums: tuple[int, ...]
|
||||
static_argnames: tuple[str, ...]
|
||||
donate_argnums: tuple[int, ...]
|
||||
donate_argnames: tuple[str, ...]
|
||||
device: xc.Device | None
|
||||
backend: str | None
|
||||
keep_unused: bool
|
||||
inline: bool
|
||||
abstracted_axes: Any | None
|
||||
has_explicit_sharding: bool
|
||||
use_resource_env: bool # False for jit, True for pjit
|
||||
|
||||
|
||||
def _python_pjit_helper(jit_info, *args, **kwargs):
|
||||
args_flat, _, params, _, out_tree, _, _, _, arg_names, attrs_tracked = \
|
||||
infer_params_fn(*args, **kwargs)
|
||||
_infer_params(jit_info, args, kwargs)
|
||||
for arg in args_flat:
|
||||
dispatch.check_arg(arg)
|
||||
if attrs_tracked:
|
||||
@ -145,6 +169,7 @@ def _python_pjit_helper(fun, infer_params_fn, *args, **kwargs):
|
||||
except pxla.DeviceAssignmentMismatchError as e:
|
||||
fails, = e.args
|
||||
api_name = 'jit' if params['resource_env'] is None else 'pjit'
|
||||
fun = jit_info.fun
|
||||
fun_name = getattr(fun, '__qualname__', getattr(fun, '__name__', str(fun)))
|
||||
msg = _device_assignment_mismatch_error(
|
||||
fun_name, fails, args_flat, api_name, arg_names)
|
||||
@ -165,14 +190,16 @@ def _get_states(attrs_tracked):
|
||||
return [jax_getattr(obj, attr) for (obj, attr) in attrs_tracked]
|
||||
|
||||
|
||||
def _python_pjit(fun: Callable, infer_params_fn):
|
||||
def _python_pjit(jit_info: PjitInfo):
|
||||
|
||||
fun = jit_info.fun
|
||||
|
||||
@wraps(fun)
|
||||
@api_boundary
|
||||
def wrapped(*args, **kwargs):
|
||||
if config.disable_jit.value:
|
||||
return fun(*args, **kwargs)
|
||||
return _python_pjit_helper(fun, infer_params_fn, *args, **kwargs)[0]
|
||||
return _python_pjit_helper(jit_info, *args, **kwargs)[0]
|
||||
|
||||
def _python_pjit_evict_fn():
|
||||
_create_pjit_jaxpr.evict_function(fun) # type: ignore
|
||||
@ -254,31 +281,31 @@ def _get_cpp_global_cache(pjit_has_explicit_sharding):
|
||||
return _cpp_pjit_cache
|
||||
|
||||
|
||||
def _cpp_pjit(fun: Callable, infer_params_fn, static_argnums, static_argnames,
|
||||
donate_argnums, pjit_has_explicit_sharding):
|
||||
def _cpp_pjit(jit_info: PjitInfo):
|
||||
|
||||
@api_boundary
|
||||
def cache_miss(*args, **kwargs):
|
||||
outs, out_flat, out_tree, args_flat, jaxpr, attrs_tracked = _python_pjit_helper(
|
||||
fun, infer_params_fn, *args, **kwargs)
|
||||
jit_info, *args, **kwargs)
|
||||
executable = _read_most_recent_pjit_call_executable(jaxpr)
|
||||
maybe_fastpath_data = _get_fastpath_data(
|
||||
executable, out_tree, args_flat, out_flat, attrs_tracked, jaxpr.effects)
|
||||
return outs, maybe_fastpath_data
|
||||
|
||||
fun = jit_info.fun
|
||||
if xla_extension_version >= 226:
|
||||
cpp_pjit_f = xc._xla.pjit( # type: ignore
|
||||
getattr(fun, "__name__", "<unnamed function>"),
|
||||
fun, cache_miss, static_argnums, static_argnames,
|
||||
donate_argnums, tree_util.dispatch_registry,
|
||||
fun, cache_miss, jit_info.static_argnums, jit_info.static_argnames,
|
||||
jit_info.donate_argnums, tree_util.dispatch_registry,
|
||||
pxla.shard_arg if xla_extension_version >= 229 else pxla.temp_shard_arg, # type: ignore
|
||||
_get_cpp_global_cache(pjit_has_explicit_sharding)) # type: ignore
|
||||
_get_cpp_global_cache(jit_info.has_explicit_sharding)) # type: ignore
|
||||
else:
|
||||
cpp_pjit_f = xc._xla.pjit( # type: ignore
|
||||
getattr(fun, "__name__", "<unnamed function>"),
|
||||
fun, cache_miss, static_argnums, static_argnames,
|
||||
donate_argnums, tree_util.dispatch_registry,
|
||||
_get_cpp_global_cache(pjit_has_explicit_sharding))
|
||||
fun, cache_miss, jit_info.static_argnums, jit_info.static_argnames,
|
||||
jit_info.donate_argnums, tree_util.dispatch_registry,
|
||||
_get_cpp_global_cache(jit_info.has_explicit_sharding))
|
||||
|
||||
cpp_pjitted_f = wraps(fun)(cpp_pjit_f)
|
||||
cpp_pjitted_f._fun = fun
|
||||
@ -286,10 +313,29 @@ def _cpp_pjit(fun: Callable, infer_params_fn, static_argnums, static_argnames,
|
||||
return cpp_pjitted_f
|
||||
|
||||
|
||||
def pre_infer_params(fun, in_shardings, out_shardings,
|
||||
donate_argnums, donate_argnames,
|
||||
static_argnums, static_argnames, device,
|
||||
backend, abstracted_axes):
|
||||
def _pjit_explicit_sharding(in_shardings, out_shardings, device,
|
||||
backend) -> bool:
|
||||
in_shardings_flat, _ = tree_flatten(in_shardings)
|
||||
out_shardings_flat, _ = tree_flatten(out_shardings)
|
||||
return (device is not None or
|
||||
backend is not None or
|
||||
any(not is_unspecified(i) for i in in_shardings_flat) or
|
||||
any(not is_unspecified(i) for i in out_shardings_flat))
|
||||
|
||||
|
||||
def _parse_jit_arguments(fun: Callable, in_shardings: Any, out_shardings: Any,
|
||||
donate_argnums: int | Sequence[int] | None,
|
||||
donate_argnames: str | Iterable[str] | None,
|
||||
static_argnums: int | Sequence[int] | None,
|
||||
static_argnames: str | Iterable[str] | None,
|
||||
device: xc.Device | None, backend: str | None,
|
||||
abstracted_axes: Any | None, keep_unused: bool,
|
||||
inline: bool, use_resource_env: bool) -> PjitInfo:
|
||||
"""Parses the arguments to jit/pjit.
|
||||
|
||||
Performs any preprocessing and validation of the arguments that we can do
|
||||
ahead of time before the jit()-ed function is invoked.
|
||||
"""
|
||||
if abstracted_axes and not config.dynamic_shapes.value:
|
||||
raise ValueError("abstracted_axes must be used with --jax_dynamic_shapes")
|
||||
|
||||
@ -326,18 +372,31 @@ def pre_infer_params(fun, in_shardings, out_shardings,
|
||||
donate_argnums, donate_argnames, static_argnums, static_argnames = resolve_argnums(
|
||||
fun, donate_argnums, donate_argnames, static_argnums, static_argnames)
|
||||
|
||||
return (in_shardings, out_shardings, donate_argnums, donate_argnames,
|
||||
static_argnums, static_argnames)
|
||||
fun_sourceinfo = api_util.fun_sourceinfo(fun)
|
||||
fun_signature = api_util.fun_signature(fun)
|
||||
|
||||
has_explicit_sharding = _pjit_explicit_sharding(
|
||||
in_shardings, out_shardings, device, backend)
|
||||
|
||||
return PjitInfo(
|
||||
fun=fun,
|
||||
fun_sourceinfo=fun_sourceinfo,
|
||||
fun_signature=fun_signature,
|
||||
in_shardings=in_shardings,
|
||||
out_shardings=out_shardings, static_argnums=static_argnums,
|
||||
static_argnames=static_argnames, donate_argnums=donate_argnums,
|
||||
donate_argnames=donate_argnames, device=device, backend=backend,
|
||||
keep_unused=keep_unused, inline=inline,
|
||||
abstracted_axes=abstracted_axes,
|
||||
has_explicit_sharding=has_explicit_sharding,
|
||||
use_resource_env=use_resource_env)
|
||||
|
||||
|
||||
def post_infer_params(fun, infer_params_fn, static_argnums, static_argnames,
|
||||
donate_argnums, abstracted_axes,
|
||||
pjit_has_explicit_sharding):
|
||||
if abstracted_axes is None:
|
||||
wrapped = _cpp_pjit(fun, infer_params_fn, static_argnums, static_argnames,
|
||||
donate_argnums, pjit_has_explicit_sharding)
|
||||
def _make_jit_wrapper(jit_info: PjitInfo):
|
||||
if jit_info.abstracted_axes is None:
|
||||
wrapped = _cpp_pjit(jit_info)
|
||||
else:
|
||||
wrapped = _python_pjit(fun, infer_params_fn)
|
||||
wrapped = _python_pjit(jit_info)
|
||||
|
||||
@api_boundary
|
||||
def lower(*args, **kwargs):
|
||||
@ -348,8 +407,8 @@ def post_infer_params(fun, infer_params_fn, static_argnums, static_argnames,
|
||||
out_layouts = kwargs.pop('_out_layouts', None)
|
||||
(args_flat, flat_global_in_avals, params, in_tree, out_tree,
|
||||
donated_invars, in_layouts_flat, out_layouts_flat,
|
||||
arg_names, ()) = infer_params_fn(
|
||||
*args, **kwargs, _in_layouts=in_layouts, _out_layouts=out_layouts)
|
||||
arg_names, ()) = _infer_params(
|
||||
jit_info, args, kwargs, in_layouts=in_layouts, out_layouts=out_layouts)
|
||||
resource_env = params['resource_env']
|
||||
mesh = None if resource_env is None else resource_env.physical_mesh
|
||||
try:
|
||||
@ -363,7 +422,9 @@ def post_infer_params(fun, infer_params_fn, static_argnums, static_argnames,
|
||||
except pxla.DeviceAssignmentMismatchError as e:
|
||||
fails, = e.args
|
||||
api_name = 'jit' if params['resource_env'] is None else 'pjit'
|
||||
fun_name = getattr(fun, '__qualname__', getattr(fun, '__name__', str(fun)))
|
||||
fun = jit_info.fun
|
||||
fun_name = getattr(fun, '__qualname__',
|
||||
getattr(fun, '__name__', str(fun)))
|
||||
msg = _device_assignment_mismatch_error(
|
||||
fun_name, fails, args_flat, api_name, arg_names)
|
||||
raise ValueError(msg) from None
|
||||
@ -375,8 +436,9 @@ def post_infer_params(fun, infer_params_fn, static_argnums, static_argnames,
|
||||
|
||||
@api_boundary
|
||||
def eval_shape(*args, **kwargs):
|
||||
_, _, params, _, out_tree, _, _, _, _, _ = infer_params_fn(
|
||||
*args, **kwargs, _in_layouts=None, _out_layouts=None)
|
||||
_, _, params, _, out_tree, _, _, _, _, _ = _infer_params(
|
||||
jit_info, args, kwargs, in_layouts=None, out_layouts=None
|
||||
)
|
||||
out_s = [None if is_unspecified(s) else getattr(s, '_original_sharding', s)
|
||||
for s in params['out_shardings']]
|
||||
out = [api.ShapeDtypeStruct(x.shape, x.dtype, x.named_shape, sharding=s)
|
||||
@ -387,52 +449,43 @@ def post_infer_params(fun, infer_params_fn, static_argnums, static_argnames,
|
||||
wrapped.eval_shape = eval_shape
|
||||
return wrapped
|
||||
|
||||
|
||||
def _pjit_explicit_sharding(in_shardings, out_shardings, device,
|
||||
backend) -> bool:
|
||||
in_shardings_flat, _ = tree_flatten(in_shardings)
|
||||
out_shardings_flat, _ = tree_flatten(out_shardings)
|
||||
return (device is not None or
|
||||
backend is not None or
|
||||
any(not is_unspecified(i) for i in in_shardings_flat) or
|
||||
any(not is_unspecified(i) for i in out_shardings_flat))
|
||||
def make_jit(fun: Callable, in_shardings: Any, out_shardings: Any,
|
||||
donate_argnums: int | Sequence[int] | None,
|
||||
donate_argnames: str | Iterable[str] | None,
|
||||
static_argnums: int | Sequence[int] | None,
|
||||
static_argnames: str | Iterable[str] | None,
|
||||
device: xc.Device | None, backend: str | None,
|
||||
abstracted_axes: Any | None, keep_unused: bool,
|
||||
inline: bool, use_resource_env: bool) -> Any:
|
||||
"""jit() and pjit() are thin wrappers around this function."""
|
||||
jit_info = _parse_jit_arguments(
|
||||
fun, in_shardings, out_shardings, donate_argnums, donate_argnames,
|
||||
static_argnums, static_argnames, device, backend, abstracted_axes,
|
||||
keep_unused, inline, use_resource_env)
|
||||
return _make_jit_wrapper(jit_info)
|
||||
|
||||
|
||||
class PjitInfo(NamedTuple):
|
||||
fun: Callable
|
||||
fun_sourceinfo: str | None
|
||||
fun_signature: inspect.Signature
|
||||
in_shardings: Any
|
||||
out_shardings: Any
|
||||
static_argnums: tuple[int, ...]
|
||||
static_argnames: tuple[str, ...]
|
||||
donate_argnums: tuple[int, ...]
|
||||
donate_argnames: tuple[str, ...]
|
||||
device: xc.Device | None
|
||||
backend: str | None
|
||||
keep_unused: bool
|
||||
inline: bool
|
||||
resource_env: Any
|
||||
abstracted_axes: Any | None
|
||||
in_layouts: Any # pytree[XlaCompatibleLayout] | None
|
||||
out_layouts: Any # pytree[XlaCompatibleLayout] | None
|
||||
|
||||
|
||||
def common_infer_params(pjit_info_args, *args, **kwargs):
|
||||
def _infer_params(jit_info, args, kwargs, *, in_layouts=None, out_layouts=None):
|
||||
(fun, fun_sourceinfo, fun_signature, user_in_shardings, user_out_shardings,
|
||||
static_argnums, static_argnames,
|
||||
donate_argnums, donate_argnames, device, backend, keep_unused, inline,
|
||||
resource_env, abstracted_axes, in_layouts, out_layouts) = pjit_info_args
|
||||
abstracted_axes, _, use_resource_env) = jit_info
|
||||
|
||||
if (kwargs and user_in_shardings is not None and
|
||||
not is_unspecified(user_in_shardings)):
|
||||
raise ValueError(
|
||||
"pjit does not support kwargs when in_shardings is specified.")
|
||||
|
||||
if resource_env is not None:
|
||||
if use_resource_env:
|
||||
# We need to fetch the mesh from inside the wrapped function, because
|
||||
# meshes are dynamically scoped (i.e., with a context manager).
|
||||
resource_env = mesh_lib.thread_resources.env
|
||||
pjit_mesh = resource_env.physical_mesh
|
||||
jit_name = 'pjit'
|
||||
else:
|
||||
resource_env = None
|
||||
pjit_mesh = None
|
||||
jit_name = 'jit'
|
||||
|
||||
if (backend or device) and pjit_mesh is not None and not pjit_mesh.empty:
|
||||
raise ValueError(
|
||||
@ -441,8 +494,6 @@ def common_infer_params(pjit_info_args, *args, **kwargs):
|
||||
|
||||
axes_specs = _flat_axes_specs(abstracted_axes, *args, **kwargs)
|
||||
|
||||
jit_name = 'jit' if resource_env is None else 'pjit'
|
||||
|
||||
dbg = debug_info(jit_name, fun_sourceinfo, fun_signature, args, kwargs,
|
||||
static_argnums, static_argnames)
|
||||
f = lu.wrap_init(fun)
|
||||
@ -782,39 +833,10 @@ def pjit(
|
||||
... print(f(x)) # doctest: +SKIP
|
||||
[ 0.5 2. 4. 6. 8. 10. 12. 10. ]
|
||||
"""
|
||||
(in_shardings, out_shardings, donate_argnums, donate_argnames, static_argnums,
|
||||
static_argnames) = pre_infer_params(
|
||||
return make_jit(
|
||||
fun, in_shardings, out_shardings, donate_argnums, donate_argnames,
|
||||
static_argnums, static_argnames, device, backend, abstracted_axes)
|
||||
|
||||
fun_sourceinfo = api_util.fun_sourceinfo(fun)
|
||||
fun_signature = api_util.fun_signature(fun)
|
||||
|
||||
def infer_params(*args, **kwargs):
|
||||
# Putting this outside of wrapped would make resources lexically scoped
|
||||
resource_env = mesh_lib.thread_resources.env
|
||||
# TODO(yashkatariya): Remove this when it's added on jit. Also default to
|
||||
# layout.DefaultLayout() when out of experimental.
|
||||
in_layouts = kwargs.pop('_in_layouts', None)
|
||||
out_layouts = kwargs.pop('_out_layouts', None)
|
||||
pjit_info_args = PjitInfo(
|
||||
fun=fun,
|
||||
fun_sourceinfo=fun_sourceinfo,
|
||||
fun_signature=fun_signature,
|
||||
in_shardings=in_shardings,
|
||||
out_shardings=out_shardings, static_argnums=static_argnums,
|
||||
static_argnames=static_argnames, donate_argnums=donate_argnums,
|
||||
donate_argnames=donate_argnames, device=device, backend=backend,
|
||||
keep_unused=keep_unused, inline=inline, resource_env=resource_env,
|
||||
abstracted_axes=abstracted_axes, in_layouts=in_layouts,
|
||||
out_layouts=out_layouts)
|
||||
return common_infer_params(pjit_info_args, *args, **kwargs)
|
||||
|
||||
has_explicit_sharding = _pjit_explicit_sharding(
|
||||
in_shardings, out_shardings, device, backend)
|
||||
return post_infer_params(fun, infer_params, static_argnums, static_argnames,
|
||||
donate_argnums, abstracted_axes,
|
||||
has_explicit_sharding)
|
||||
static_argnums, static_argnames, device, backend, abstracted_axes,
|
||||
keep_unused, inline, use_resource_env=True)
|
||||
|
||||
|
||||
def hashable_pytree(pytree):
|
||||
|
Loading…
x
Reference in New Issue
Block a user