mirror of
https://github.com/ROCm/jax.git
synced 2025-04-17 20:36:05 +00:00

The plan is for all `core.Jaxpr` and `lu.WrappedFun` to carry non-None debug info. We change `lu.wrap_init` to construct the result paths thunk whenever it is passed a `debug_info`. The goal is to make sure that all `WrappedFun` have a debug info with result paths support. We change some calling conventions for internal functions to not pass along a separate debug_info if we have a `WrappedFun` or a `Jaxpr`. We obtain several improvements in presence of debug infos in debug_info_test.py
457 lines
17 KiB
Python
457 lines
17 KiB
Python
# Copyright 2025 The JAX Authors.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# https://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
|
|
from collections.abc import Callable, Sequence
|
|
import functools
|
|
from typing import Any
|
|
|
|
from jax._src import api_util
|
|
from jax._src import core
|
|
from jax._src import custom_api_util
|
|
from jax._src import errors
|
|
from jax._src import linear_util as lu
|
|
from jax._src import source_info_util
|
|
from jax._src import traceback_util
|
|
from jax._src import tree_util
|
|
from jax._src import util
|
|
from jax._src.interpreters import ad
|
|
from jax._src.interpreters import batching
|
|
from jax._src.interpreters import mlir
|
|
from jax._src.interpreters import partial_eval as pe
|
|
|
|
source_info_util.register_exclusion(__file__)
|
|
traceback_util.register_exclusion(__file__)
|
|
|
|
map, unsafe_map = util.safe_map, map
|
|
zip, unsafe_zip = util.safe_zip, zip
|
|
|
|
|
|
@custom_api_util.register_custom_decorator_type
|
|
class custom_dce:
|
|
"""Customize the DCE behavior of a JAX-transformable function.
|
|
|
|
JAX uses dead code elimination (DCE) to remove unused computations from a
|
|
JAX program. This typically works transparently when the program is
|
|
completely specified by known JAX operations, but opaque kernels like calls
|
|
to :py:func:`~jax.experimental.pallas.pallas_call` or
|
|
:py:func:`~jax.ffi.ffi_call`, for example, may cause problems.
|
|
|
|
In JAX, DCE is performed when a function is staged out using
|
|
:py:func:`jax.jit`, so it won't be applied when running JAX in eager mode.
|
|
Similarly, the ``custom_dce`` decorator requires that both the decorated
|
|
function and the custom DCE rule be compatible with :py:func:`~jax.jit`.
|
|
|
|
This decorator allows users to customize the DCE behavior of a function by
|
|
defining a custom DCE rule. For a ``custom_dce`` wrapped function
|
|
``f(*args)``, the signature of the DCE rule is ``dce_rule(used_outs, *args)``
|
|
where ``used_outs`` is a Pytree with the same structure as the output of
|
|
``f``, and each leaf is is a ``bool`` indicating which outputs should
|
|
be computed. The remaining arguments ``*args`` are the original arguments to
|
|
``f``. The rule ``dce_rule`` should return a Pytree with the same structure
|
|
as the original output of ``f``, but any unused outputs can be replaced with
|
|
``None``.
|
|
|
|
For example::
|
|
|
|
>>> @jax.experimental.custom_dce.custom_dce
|
|
... def f(x, y):
|
|
... return jnp.sin(x) * y, x * jnp.sin(y)
|
|
...
|
|
>>> @f.def_dce
|
|
... def f_dce_rule(used_outs, x, y):
|
|
... return (
|
|
... jnp.sin(x) * y if used_outs[0] else None,
|
|
... x * jnp.sin(y) if used_outs[1] else None,
|
|
... )
|
|
|
|
In this example, ``used_outs`` is a ``tuple`` with two ``bool`` values,
|
|
indicating which outputs are required. The DCE rule only computes the
|
|
required outputs, replacing the unused outputs with ``None``.
|
|
|
|
If the ``static_argnums`` argument is provided to ``custom_dce``, the
|
|
indicated arguments are treated as static when the function is traced, and
|
|
they will be moved to the front when calling the DCE rule. For example, if
|
|
``fun`` takes 2 arguments ``fun(x, y)``, and ``static_argnums`` is ``(1,)``,
|
|
then the DCE rule will be called as ``dce_rule(y, used_outs, x)``.
|
|
"""
|
|
|
|
fun: Callable[..., Any]
|
|
static_argnums: Sequence[int]
|
|
dce_rule: Callable[..., Any] | None
|
|
|
|
def __init__(
|
|
self, fun: Callable[..., Any], *, static_argnums: Sequence[int] = ()
|
|
):
|
|
functools.update_wrapper(self, fun)
|
|
self.fun = fun
|
|
self.static_argnums = static_argnums
|
|
self.dce_rule = None
|
|
|
|
__getattr__ = custom_api_util.forward_attr
|
|
|
|
def def_dce(
|
|
self,
|
|
dce_rule: Callable[..., Any],
|
|
) -> Callable[..., Any]:
|
|
"""Define a custom DCE rule for this function.
|
|
|
|
Args:
|
|
dce_rule: A function that takes (a) any arguments indicated as static
|
|
using ``static_argnums``, (b) a Pytree of ``bool`` values
|
|
(``used_outs``) indicating which outputs should be computed, and (c)
|
|
the rest of the (non-static) arguments to the original function. The
|
|
rule should return a Pytree with with the same structure as the output
|
|
of the original function, but any unused outputs (as indicated by
|
|
``used_outs``) can be replaced with ``None``.
|
|
"""
|
|
self.dce_rule = dce_rule
|
|
return dce_rule
|
|
|
|
@traceback_util.api_boundary
|
|
def __call__(self, *args, **kwargs):
|
|
fun_name = util.fun_name(self.fun)
|
|
if self.dce_rule is None:
|
|
raise AttributeError(
|
|
f"No DCE rule defined for custom_dce function {fun_name} using "
|
|
"def_dce."
|
|
)
|
|
rule_name = util.fun_name(self.dce_rule)
|
|
debug = api_util.debug_info("custom_dce", self.fun,
|
|
args, {},
|
|
static_argnums=self.static_argnums)
|
|
debug_rule = api_util.debug_info("custom_dce_rule", self.dce_rule,
|
|
args, {},
|
|
static_argnums=self.static_argnums)
|
|
args = api_util.resolve_kwargs(self.fun, args, kwargs)
|
|
if self.static_argnums:
|
|
static_argnums = set(self.static_argnums)
|
|
for i in static_argnums:
|
|
check_for_tracers(args[i])
|
|
dyn_argnums = [i for i in range(len(args)) if i not in static_argnums]
|
|
fun, dyn_args = api_util.argnums_partial(
|
|
lu.wrap_init(self.fun, debug_info=debug),
|
|
dyn_argnums,
|
|
args,
|
|
require_static_args_hashable=False,
|
|
)
|
|
static_args = [args[i] for i in self.static_argnums]
|
|
dce_rule = api_util.prepend_static_args(
|
|
lu.wrap_init(self.dce_rule, debug_info=debug_rule), static_args
|
|
)
|
|
else:
|
|
fun = lu.wrap_init(self.fun, debug_info=debug)
|
|
dce_rule = lu.wrap_init(self.dce_rule, debug_info=debug_rule)
|
|
dyn_args = args
|
|
|
|
args_flat, in_tree = tree_util.tree_flatten(dyn_args)
|
|
flat_fun, out_tree = api_util.flatten_fun_nokwargs(fun, in_tree)
|
|
in_avals = [core.get_aval(x) for x in args_flat]
|
|
|
|
@pe._memoize
|
|
def dce_jaxpr_thunk(
|
|
*used_outs: bool,
|
|
) -> tuple[core.ClosedJaxpr, Sequence[bool]]:
|
|
for store in dce_rule.stores:
|
|
if store:
|
|
store.reset()
|
|
flat_rule, rule_out_tree = flatten_dce_rule( # pylint: disable=unbalanced-tuple-unpacking
|
|
dce_rule,
|
|
fun_name,
|
|
rule_name,
|
|
used_outs,
|
|
in_tree,
|
|
out_tree(),
|
|
out_avals,
|
|
)
|
|
assert self.dce_rule is not None
|
|
dce_jaxpr, _, dce_consts, () = pe.trace_to_jaxpr_dynamic(
|
|
flat_rule, in_avals
|
|
)
|
|
|
|
# This second round of DCE is used to work out which inputs are actually
|
|
# referenced by the DCEed Jaxpr. To avoid infinite recursion when the DCE
|
|
# rule calls back into the primal, we replace all custom_dce primitives
|
|
# with a sentinel primitive with a no-op DCE rule.
|
|
dce_jaxpr = swap_primitives(dce_jaxpr, custom_dce_p, dce_sential_p)
|
|
dce_jaxpr, used_ins = pe.dce_jaxpr(
|
|
dce_jaxpr, [True] * len(dce_jaxpr.outvars)
|
|
)
|
|
dce_jaxpr = swap_primitives(dce_jaxpr, dce_sential_p, custom_dce_p)
|
|
|
|
return core.ClosedJaxpr(dce_jaxpr, dce_consts), used_ins
|
|
|
|
jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(flat_fun, in_avals)
|
|
closed_call = pe.close_jaxpr(pe.convert_constvars_jaxpr(jaxpr))
|
|
out_avals = closed_call.out_avals
|
|
out_flat = custom_dce_p.bind(
|
|
*consts,
|
|
*args_flat,
|
|
num_consts=len(consts),
|
|
fun_jaxpr=closed_call,
|
|
dce_jaxpr_thunk=dce_jaxpr_thunk,
|
|
)
|
|
return tree_util.tree_unflatten(out_tree(), out_flat)
|
|
|
|
|
|
def check_for_tracers(x):
|
|
# TODO(danfm): de-duplicate this with the version in custom_derivatives
|
|
for leaf in tree_util.tree_leaves(x):
|
|
if isinstance(leaf, core.Tracer):
|
|
msg = (
|
|
"Found a JAX Tracer object passed as an argument to a custom_dce "
|
|
"function in a position indicated by static_argnums as static. "
|
|
"Tracers cannot be passed as static arguments to custom_dce "
|
|
"functions; instead, static_argnums should only be used for "
|
|
"arguments that can't be or contain JAX tracers, e.g. "
|
|
"function-valued arguments. In particular, array-valued arguments "
|
|
"should typically not be indicated as static_argnums."
|
|
)
|
|
raise errors.UnexpectedTracerError(msg)
|
|
|
|
|
|
@lu.transformation_with_aux2
|
|
def flatten_dce_rule(
|
|
f: Callable[..., Any],
|
|
store: lu.Store,
|
|
fun_name: str,
|
|
rule_name: str,
|
|
used_outs: Sequence[bool],
|
|
in_tree,
|
|
out_tree,
|
|
out_avals: Sequence[core.AbstractValue],
|
|
*args_flat,
|
|
):
|
|
py_used_outs = tree_util.tree_unflatten(out_tree, used_outs)
|
|
py_args = tree_util.tree_unflatten(in_tree, args_flat)
|
|
py_out = f(py_used_outs, *py_args)
|
|
if isinstance(py_out, list) and len(py_out) == len(
|
|
tree_util.treedef_children(out_tree)
|
|
):
|
|
py_out = tuple(py_out)
|
|
|
|
# For error checking purposes, we need to reformat the pytree structure
|
|
# of the output of the DCE rule to match the original output. The catch is
|
|
# that the DCE rule can return a None to indicated an unused subtree, so we
|
|
# need to rebuild those subtrees with a sentinal value at the leaves. This
|
|
# logic is very similar to what is used in custom_dervatives._flatten_bwd.
|
|
sentinal = object()
|
|
dummy = tree_util.tree_unflatten(out_tree, [object()] * out_tree.num_leaves)
|
|
keypaths, _ = util.unzip2(tree_util.tree_flatten_with_path(dummy)[0])
|
|
out_flat = []
|
|
|
|
def append(x, d):
|
|
num_leaves = len(tree_util.tree_flatten(d)[0])
|
|
if x is None and d is not None:
|
|
out_flat.extend([sentinal] * num_leaves)
|
|
elif x is not None:
|
|
out_flat.extend([x] * num_leaves)
|
|
return x
|
|
|
|
_, rule_tree = tree_util.tree_flatten(py_out)
|
|
try:
|
|
tree_util.tree_map(append, py_out, dummy, is_leaf=lambda x: x is None)
|
|
except ValueError:
|
|
raise TypeError(
|
|
f"Custom DCE rule {rule_name} for function {fun_name} must produce an "
|
|
"output with the same container (pytree) structure as the output of "
|
|
f"the function {fun_name}. But, the DCE rule returned structure "
|
|
f"{rule_tree}, while {fun_name} returned {out_tree}."
|
|
) from None
|
|
|
|
results = []
|
|
for kp, used, aval, val in zip(keypaths, used_outs, out_avals, out_flat):
|
|
if not used:
|
|
continue
|
|
if val is sentinal:
|
|
raise ValueError(
|
|
f"Custom DCE rule {rule_name} for function {fun_name} must produce "
|
|
"values for all of the required outputs (as specified by the "
|
|
"used_outs argument to the rule). But, at "
|
|
f"output{tree_util.keystr(kp)}, the DCE rule returned None instead "
|
|
f"of an output with shape/dtype {aval.str_short()}."
|
|
)
|
|
if not core.typematch(aval, aval_ := core.get_aval(val)):
|
|
raise ValueError(
|
|
f"Custom DCE rule {rule_name} for function {fun_name} must produce "
|
|
"an output with the same shapes/dtypes as the output of the function "
|
|
f"{fun_name}. But, at output{tree_util.keystr(kp)}, the DCE rule "
|
|
f"returned an output with shape/dtype {aval_.str_short()}, while "
|
|
f"{aval.str_short()} was expected."
|
|
)
|
|
results.append(val)
|
|
|
|
store.store(rule_tree)
|
|
return results
|
|
|
|
|
|
def custom_dce_impl(*args, fun_jaxpr: core.ClosedJaxpr, **_):
|
|
return core.jaxpr_as_fun(fun_jaxpr)(*args)
|
|
|
|
|
|
def custom_dce_abstract_eval(*args, fun_jaxpr: core.ClosedJaxpr, **_):
|
|
del args # unused
|
|
return fun_jaxpr.out_avals, fun_jaxpr.effects
|
|
|
|
|
|
def custom_dce_batching(
|
|
axis_data: batching.AxisData,
|
|
args,
|
|
dims,
|
|
*,
|
|
num_consts: int,
|
|
fun_jaxpr: core.ClosedJaxpr,
|
|
dce_jaxpr_thunk: Callable[..., tuple[core.ClosedJaxpr, Sequence[bool]]],
|
|
):
|
|
in_batched = [d is not batching.not_mapped for d in dims]
|
|
args = [
|
|
batching.moveaxis(x, d, 0) if b else x
|
|
for b, x, d in zip(in_batched, args, dims)
|
|
]
|
|
batched_fun_jaxpr, out_batched = batching.batch_jaxpr(
|
|
fun_jaxpr, axis_data, in_batched, False
|
|
)
|
|
|
|
@pe._memoize
|
|
def batched_dce_jaxpr_thunk(
|
|
*used_outs: bool,
|
|
) -> tuple[core.ClosedJaxpr, Sequence[bool]]:
|
|
dce_jaxpr, used_ins = dce_jaxpr_thunk(*used_outs)
|
|
used_out_batched = [b for used, b in zip(used_outs, out_batched) if used]
|
|
dce_jaxpr_batched, dce_out_batched = batching.batch_jaxpr(
|
|
dce_jaxpr,
|
|
axis_data,
|
|
[b for used, b in zip(used_ins, in_batched[num_consts:]) if used],
|
|
used_out_batched,
|
|
)
|
|
# TODO(danfm): For now we require that the DCE rule produce the same
|
|
# batching pattern for the used outputs as the original function's batching.
|
|
# While it is OK for the DCE rule to produce _less batched_ outputs (using
|
|
# instantiate), we don't support the opposite case. It seems like we could
|
|
# solve this by using instantiate=True when batching fun_jaxpr, but this
|
|
# ends up changing the batching behavior sufficiently that it breaks some
|
|
# real world use cases. Revisit if needed.
|
|
assert all(a == b for a, b in zip(dce_out_batched, used_out_batched))
|
|
return dce_jaxpr_batched, used_ins
|
|
|
|
out_flat = custom_dce_p.bind(
|
|
*args,
|
|
num_consts=num_consts,
|
|
fun_jaxpr=batched_fun_jaxpr,
|
|
dce_jaxpr_thunk=batched_dce_jaxpr_thunk,
|
|
)
|
|
out_dims = [0 if b else batching.not_mapped for b in out_batched]
|
|
return out_flat, out_dims
|
|
|
|
|
|
def custom_dce_jvp(primals, tangents, *, fun_jaxpr: core.ClosedJaxpr, **_):
|
|
in_nz = [not isinstance(t, ad.Zero) for t in tangents]
|
|
tangents = [t for nz, t in zip(in_nz, tangents) if nz]
|
|
jvp_jaxpr, out_nz = ad.jvp_jaxpr(fun_jaxpr, in_nz, False)
|
|
|
|
# TODO(danfm): We should avoid losing the DCE rule here, but it is more
|
|
# straightforward to implement it like this to start. Instead, we should
|
|
# bind a custom_dce primitive. To support that, we would need to add a
|
|
# partial eval rule, and maybe a transpose rule. That being said, I expect
|
|
# that most users of this API would compose this with a custom_jvp or
|
|
# custom_vjp, which makes this less urgent.
|
|
out = core.call_p.bind(
|
|
lu.wrap_init(core.jaxpr_as_fun(jvp_jaxpr),
|
|
debug_info=jvp_jaxpr.jaxpr.debug_info), *primals, *tangents
|
|
)
|
|
|
|
out_primals, out_tangents = util.split_list(out, [len(out_nz)])
|
|
out_tangents_iter = iter(out_tangents)
|
|
out_tangents = [
|
|
next(out_tangents_iter) if nz else ad.Zero.from_primal_value(p)
|
|
for p, nz in zip(out_primals, out_nz)
|
|
]
|
|
return out_primals, out_tangents
|
|
|
|
|
|
def custom_dce_rule(used_outs: Sequence[bool], eqn: core.JaxprEqn):
|
|
if not any(used_outs) and not pe.has_effects(eqn):
|
|
return [False] * len(eqn.invars), None
|
|
if all(used_outs):
|
|
return [True] * len(eqn.invars), eqn
|
|
|
|
num_consts = eqn.params["num_consts"]
|
|
dce_jaxpr_thunk = eqn.params["dce_jaxpr_thunk"]
|
|
jaxpr, used_ins = dce_jaxpr_thunk(*used_outs)
|
|
|
|
@pe._memoize
|
|
def new_dce_jaxpr_thunk(
|
|
*new_used_outs: bool,
|
|
) -> tuple[core.ClosedJaxpr, Sequence[bool]]:
|
|
if all(new_used_outs):
|
|
return jaxpr, used_ins
|
|
all_used_outs = util.merge_lists(
|
|
used_outs,
|
|
[False] * (len(used_outs) - len(new_used_outs)),
|
|
new_used_outs,
|
|
)
|
|
new_jaxpr, all_used_ins = dce_jaxpr_thunk(*all_used_outs)
|
|
not_used, new_used_ins = util.partition_list(used_ins, all_used_ins)
|
|
assert not any(not_used)
|
|
return new_jaxpr, new_used_ins
|
|
|
|
_, invars = util.split_list(eqn.invars, [num_consts])
|
|
invars = [v for used, v in zip(used_ins, invars) if used]
|
|
outvars = [v for used, v in zip(used_outs, eqn.outvars) if used]
|
|
new_params = dict(
|
|
eqn.params,
|
|
num_consts=0,
|
|
fun_jaxpr=jaxpr,
|
|
dce_jaxpr_thunk=new_dce_jaxpr_thunk,
|
|
)
|
|
new_eqn = pe.new_jaxpr_eqn(
|
|
invars,
|
|
outvars,
|
|
custom_dce_p,
|
|
new_params,
|
|
jaxpr.effects,
|
|
eqn.source_info,
|
|
eqn.ctx,
|
|
)
|
|
return [False] * eqn.params["num_consts"] + used_ins, new_eqn
|
|
|
|
|
|
custom_dce_p = core.Primitive("custom_dce_call")
|
|
custom_dce_p.multiple_results = True
|
|
custom_dce_p.def_impl(custom_dce_impl)
|
|
custom_dce_p.def_effectful_abstract_eval(custom_dce_abstract_eval)
|
|
mlir.register_lowering(
|
|
custom_dce_p, mlir.lower_fun(custom_dce_impl, multiple_results=True)
|
|
)
|
|
batching.fancy_primitive_batchers[custom_dce_p] = custom_dce_batching
|
|
ad.primitive_jvps[custom_dce_p] = custom_dce_jvp
|
|
pe.dce_rules[custom_dce_p] = custom_dce_rule
|
|
|
|
|
|
def swap_primitives(
|
|
jaxpr: core.Jaxpr, old: core.Primitive, new: core.Primitive
|
|
) -> core.Jaxpr:
|
|
new_eqns = []
|
|
for eqn in jaxpr.eqns:
|
|
if eqn.primitive is old:
|
|
new_eqns.append(eqn.replace(primitive=new))
|
|
else:
|
|
new_eqns.append(eqn)
|
|
return jaxpr.replace(eqns=new_eqns)
|
|
|
|
|
|
dce_sential_p = core.Primitive("dce_sential")
|
|
dce_sential_p.multiple_results = True
|
|
dce_sential_p.def_impl(custom_dce_impl)
|
|
dce_sential_p.def_effectful_abstract_eval(custom_dce_abstract_eval)
|