[better_errors] Add debug info to the Jaxprs formed for AD
Following #26078, we add debug info to more calls of lu.wrap_init.
This commit is contained in:
parent 414449e142
commit abcaec7081
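The change applies one pattern throughout: wherever a function is wrapped with lu.wrap_init, a DebugInfo describing the traced function is also passed so it ends up on the Jaxprs formed for AD. A minimal sketch of that pattern follows (not part of the diff; it assumes the internal jax._src.api_util.debug_info and lu.wrap_init(..., debug_info=...) signatures used below; the function fun and the label "example" are hypothetical):

from jax._src import api_util
from jax._src import linear_util as lu

def fun(x, y):
  return x * y

args, kwargs = (3.0, 4.0), {}
# Build a DebugInfo carrying the traced-for label, function name, and argument names.
dbg = api_util.debug_info("example", fun, args, kwargs)
# Attach it when wrapping; Jaxprs traced from `wrapped` then record this debug info.
wrapped = lu.wrap_init(fun, debug_info=dbg)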
@@ -564,6 +564,7 @@ pytype_strict_library(
srcs = ["_src/interpreters/mlir.py"],
deps = [
":ad_util",
":api_util",
":config",
":core",
":dtypes",
@@ -682,12 +682,13 @@ def transpose_jaxpr(jaxpr: core.ClosedJaxpr, in_linear: bool | Sequence[bool],
return _transpose_jaxpr(jaxpr, tuple(in_linear), tuple(out_zeros))
@weakref_lru_cache
def _transpose_jaxpr(jaxpr, in_lin, out_zeros):
def _transpose_jaxpr(jaxpr: core.ClosedJaxpr,
in_lin: Sequence[bool],
out_zeros: Sequence[bool]):
in_avals = ([a for a, lin in zip(jaxpr.in_avals, in_lin ) if not lin] +
[a for a, zero in zip(jaxpr.out_avals, out_zeros) if not zero])
cell = lambda: None
@lu.wrap_init
def transposed(*args_flat):
ins_flat, out_cts_flat = split_list(args_flat, [len(in_lin) - sum(in_lin)])

@@ -715,7 +716,10 @@ def _transpose_jaxpr(jaxpr, in_lin, out_zeros):
in_cts_nz, _ = partition_list(in_zeros, in_cts)
return in_cts_nz
transposed_jaxpr_, _, consts, () = pe.trace_to_jaxpr_dynamic(transposed, in_avals)
transposed_wrapped = lu.wrap_init(transposed,
debug_info=jaxpr.jaxpr.debug_info)
transposed_jaxpr_, _, consts, () = pe.trace_to_jaxpr_dynamic(
transposed_wrapped, in_avals)
transposed_jaxpr = core.ClosedJaxpr(transposed_jaxpr_, consts)
return transposed_jaxpr, cell.in_cts_zero # pytype: disable=attribute-error
@@ -983,7 +983,7 @@ def vmap(fun: F,
"to the positional arguments passed to the function, "
f"but got {len(in_axes)=}, {len(args)=}")
args_flat, in_tree = tree_flatten((args, kwargs), is_leaf=batching.is_vmappable)
f = lu.wrap_init(fun)
f = lu.wrap_init(fun, debug_info=debug_info("vmap", fun, args, kwargs))
flat_fun, out_tree = batching.flatten_fun_for_vmap(f, in_tree)
in_axes_flat = flatten_axes("vmap in_axes", in_tree, (in_axes, 0), kws=True)
axis_size_ = (axis_size if axis_size is not None else

@@ -1715,15 +1715,15 @@ def jvp(
0.19900084
"""
check_callable(fun)
return _jvp(lu.wrap_init(fun), primals, tangents, has_aux=has_aux)
def _jvp(fun: lu.WrappedFun, primals, tangents, has_aux=False):
"""Variant of jvp() that takes an lu.WrappedFun."""
if (not isinstance(primals, (tuple, list)) or
not isinstance(tangents, (tuple, list))):
raise TypeError("primal and tangent arguments to jax.jvp must be tuples or lists; "
f"found {type(primals).__name__} and {type(tangents).__name__}.")
return _jvp(lu.wrap_init(fun, debug_info=debug_info("jvp", fun, primals, {})),
primals, tangents, has_aux=has_aux)
def _jvp(fun: lu.WrappedFun, primals, tangents, has_aux=False):
"""Variant of jvp() that takes an lu.WrappedFun."""
ps_flat, tree_def = tree_flatten(primals)
ts_flat, tree_def_2 = tree_flatten(tangents)
if tree_def != tree_def_2:

@@ -1835,7 +1835,7 @@ def linearize(fun: Callable, *primals, has_aux: bool = False
-6.676704
"""
check_callable(fun)
f = lu.wrap_init(fun)
f = lu.wrap_init(fun, debug_info=debug_info("linearize", fun, primals, {}))
primals_flat, in_tree = tree_flatten(primals)
if has_aux:
jaxtree_fun, out_tree = flatten_fun_nokwargs2(f, in_tree)

@@ -1983,8 +1983,9 @@ def vjp(
raise NotImplementedError("reduce_axes argument to vjp is deprecated")
del reduce_axes
check_callable(fun)
return _vjp(
lu.wrap_init(fun), *primals, has_aux=has_aux)
wrapped_fun = lu.wrap_init(fun,
debug_info=debug_info("vjp", fun, primals, {}))
return _vjp(wrapped_fun, *primals, has_aux=has_aux)
def _vjp(fun: lu.WrappedFun, *primals, has_aux=False):
"""Variant of vjp() that takes an lu.WrappedFun."""

@@ -2049,7 +2050,10 @@ def linear_transpose(fun: Callable, *primals, reduce_axes=()) -> Callable:
raise NotImplementedError("reduce_axes argument to transpose is deprecated")
del reduce_axes
primals_flat, in_tree = tree_flatten(primals)
flat_fun, out_tree = flatten_fun_nokwargs(lu.wrap_init(fun), in_tree)
flat_fun, out_tree = flatten_fun_nokwargs(
lu.wrap_init(fun,
debug_info=debug_info("linear_transpose", fun, primals, {})),
in_tree)
in_avals = map(shaped_abstractify, primals_flat)
in_dtypes = map(dtypes.dtype, in_avals)
@@ -67,7 +67,8 @@ def _ensure_str_tuple(x: str | Iterable[str]) -> tuple[str, ...]:
return tuple(map(_ensure_str, x))
@lu.transformation_with_aux2
def flatten_fun(f, store, in_tree, *args_flat):
def flatten_fun(f: Callable, store: lu.Store,
in_tree: PyTreeDef, *args_flat):
py_args, py_kwargs = tree_unflatten(in_tree, args_flat)
ans = f(*py_args, **py_kwargs)
ans, out_tree = tree_flatten(ans)

@@ -587,8 +588,8 @@ def debug_info(
args: Sequence[Any],
kwargs: dict[str, Any],
*,
static_argnums: tuple[int, ...] = (),
static_argnames: tuple[str, ...] = (),
static_argnums: Sequence[int] = (),
static_argnames: Sequence[str] = (),
result_paths_thunk: Callable[[], tuple[str, ...]] | None = None,
# TODO(necula): check if we really need this, e.g., to speed up tracing?
sourceinfo: str | None = None,
@@ -361,8 +361,9 @@ def default_checkify_rule(primitive: core.Primitive, error: Error,
else:
jaxpr, consts = call_jaxpr, ()
consts_ = tuple(HashableWrapper(c) for c in consts)
partial_checkify = lu.hashable_partial(lu.wrap_init(
checkify_jaxpr_flat_hashable), jaxpr, consts_, enabled_errors, err_tree)
partial_checkify = lu.hashable_partial(
lu.wrap_init(checkify_jaxpr_flat_hashable, debug_info=jaxpr.debug_info),
jaxpr, consts_, enabled_errors, err_tree)
partial_checkify, metadata = _flatten_and_get_error_metadata_thunk(
partial_checkify)

@@ -746,7 +747,7 @@ def jaxpr_to_checkify_jaxpr(
checkify_jaxpr_partial = functools.partial(checkify_jaxpr_flat, jaxpr.jaxpr,
jaxpr.consts, enabled_errors,
err_tree)
fun = lu.wrap_init(checkify_jaxpr_partial)
fun = lu.wrap_init(checkify_jaxpr_partial, debug_info=jaxpr.jaxpr.debug_info)
fun, metadata = _flatten_and_get_error_metadata_thunk(fun)
new_jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(fun, flat_err_and_in_vals)
@@ -2416,8 +2416,9 @@ call_p.def_impl(call_impl)
class ClosedCallPrimitive(CallPrimitive):
def get_bind_params(self, params):
new_params = dict(params)
jaxpr = new_params.pop('call_jaxpr')
subfun = lu.wrap_init(partial(eval_jaxpr, jaxpr.jaxpr, jaxpr.consts))
jaxpr: ClosedJaxpr = new_params.pop('call_jaxpr')
subfun = lu.wrap_init(partial(eval_jaxpr, jaxpr.jaxpr, jaxpr.consts),
debug_info=jaxpr.jaxpr.debug_info)
return [subfun], new_params
closed_call_p: ClosedCallPrimitive = ClosedCallPrimitive('closed_call')
@@ -31,7 +31,7 @@ from jax._src.ad_util import (
stop_gradient_p, SymbolicZero, Zero, zeros_like_aval)
from jax._src.api_util import (
argnums_partial, flatten_fun_nokwargs, resolve_kwargs, fun_signature,
_non_static_arg_names, prepend_static_args)
_non_static_arg_names, prepend_static_args, debug_info)
from jax._src.errors import UnexpectedTracerError
from jax._src.state.types import AbstractRef
from jax._src.interpreters import ad

@@ -44,7 +44,7 @@ from jax._src.lax import lax
from jax._src.tree_util import (
tree_flatten, tree_unflatten, tree_map, treedef_is_leaf, treedef_tuple,
register_pytree_node_class, tree_leaves, tree_flatten_with_path,
tree_leaves_with_path, keystr, treedef_children)
tree_leaves_with_path, keystr, treedef_children, PyTreeDef)
from jax._src.util import (cache, safe_zip, safe_map, split_list, unzip2,
weakref_lru_cache)

@@ -78,7 +78,9 @@ _stop_gradient = partial(
# like the api_util.py function, but also grabs output avals for error checking
@lu.transformation_with_aux2
def _flatten_fun_nokwargs(f, store, in_tree, *args_flat):
def _flatten_fun_nokwargs(f: Callable,
store: lu.Store, in_tree: PyTreeDef,
*args_flat):
py_args = tree_unflatten(in_tree, args_flat)
ans = f(*py_args)
ans_flat, ans_tree = tree_flatten(ans)

@@ -204,7 +206,7 @@ class custom_jvp(Generic[ReturnValue]):
*jvps: a sequence of functions, one for each positional argument of the
:class:`~jax.custom_jvp` function. Each function takes as arguments
the tangent value for the corresponding primal input, the primal
output, and the ßprimal inputs. See the example below.
output, and the primal inputs. See the example below.
Returns:
None.
@@ -239,28 +241,40 @@ class custom_jvp(Generic[ReturnValue]):
@traceback_util.api_boundary
def __call__(self, *args: Any, **kwargs: Any) -> ReturnValue: # pytype: disable=invalid-annotation
primal_name = getattr(self.fun, '__name__', str(self.fun))
debug = debug_info("custom_jvp fun", self.fun, args, kwargs,
static_argnums=self.nondiff_argnums)
primal_name = debug.func_name
if not self.jvp:
msg = f"No JVP defined for custom_jvp function {primal_name} using defjvp."
raise AttributeError(msg)
jvp_name = getattr(self.jvp, '__name__', str(self.jvp))
args = resolve_kwargs(self.fun, args, kwargs)
if self.nondiff_argnums:
nondiff_argnums = set(self.nondiff_argnums)
args = tuple(_stop_gradient(x) if i in nondiff_argnums else x
for i, x in enumerate(args))
diff_argnums = [i for i in range(len(args)) if i not in nondiff_argnums]
f_, dyn_args = argnums_partial(lu.wrap_init(self.fun), diff_argnums, args,
f_, dyn_args = argnums_partial(lu.wrap_init(self.fun, debug_info=debug),
diff_argnums, args,
require_static_args_hashable=False)
static_args = [args[i] for i in self.nondiff_argnums]
jvp = prepend_static_args(lu.wrap_init(self.jvp), static_args)
diff_args = [args[i] for i, a in enumerate(args) if i not in self.nondiff_argnums]
debug_jvp = debug_info("custom_jvp jvp", self.jvp,
(*static_args, diff_args, diff_args),
{},
static_argnums=tuple(range(len(static_args))))
jvp = prepend_static_args(lu.wrap_init(self.jvp,
debug_info=debug_jvp), static_args)
else:
f_, dyn_args = lu.wrap_init(self.fun), args
jvp = lu.wrap_init(self.jvp)
f_, dyn_args = lu.wrap_init(self.fun, debug_info=debug), args
debug_jvp = debug_info("custom_jvp jvp", self.jvp,
(args, args),
{})
jvp = lu.wrap_init(self.jvp, debug_info=debug_jvp)
args_flat, in_tree = tree_flatten(dyn_args)
flat_fun, out_type1 = _flatten_fun_nokwargs(f_, in_tree)
flat_jvp, out_type2 = _flatten_jvp(jvp, primal_name, jvp_name, in_tree,
out_type1)
flat_jvp, out_type2 = _flatten_jvp(jvp, primal_name, debug_jvp.func_name,
in_tree, out_type1)
out_flat = custom_jvp_call_p.bind(flat_fun, flat_jvp, *args_flat,
symbolic_zeros=self.symbolic_zeros)
_, (out_tree, _) = lu.merge_linear_aux(out_type1, out_type2)
@@ -611,15 +625,20 @@ class custom_vjp(Generic[ReturnValue]):
@traceback_util.api_boundary
def __call__(self, *args: Any, **kwargs: Any) -> ReturnValue: # pytype: disable=invalid-annotation
primal_name = getattr(self.fun, '__name__', str(self.fun))
debug_fun = debug_info("custom_vjp fun", self.fun, args, kwargs,
static_argnums=self.nondiff_argnums)
if not self.fwd or not self.bwd:
msg = f"No VJP defined for custom_vjp function {primal_name} using defvjp."
msg = f"No VJP defined for custom_vjp function {debug_fun.func_name} using defvjp."
raise AttributeError(msg)
fwd_name = getattr(self.fwd, '__name__', str(self.fwd))
args = resolve_kwargs(self.fun, args, kwargs)
debug_fwd = debug_info("custom_vjp fwd", self.fwd, args, kwargs,
static_argnums=self.nondiff_argnums)
# TODO(necula): figure out how to construct the debug_bwd args
debug_bwd = debug_info("custom_vjp bwd", self.bwd, args, {})
if self.optimize_remat:
fwd = optimize_remat_of_custom_vjp_fwd(
self.fun, self.fwd, nondiff_argnums=self.nondiff_argnums,
self.fun, debug_fun, self.fwd, debug_fwd,
nondiff_argnums=self.nondiff_argnums,
symbolic_zeros=self.symbolic_zeros)
else:
fwd = self.fwd

@@ -633,23 +652,27 @@ class custom_vjp(Generic[ReturnValue]):
for i in self.nondiff_argnums: _check_for_tracers(args[i])
nondiff_argnums = set(self.nondiff_argnums)
dyn_argnums = [i for i in range(len(args)) if i not in nondiff_argnums]
f_, dyn_args = argnums_partial(lu.wrap_init(self.fun), dyn_argnums,
args, require_static_args_hashable=False)
f_, dyn_args = argnums_partial(
lu.wrap_init(self.fun, debug_info=debug_fun), dyn_argnums,
args, require_static_args_hashable=False)
static_args = [args[i] for i in self.nondiff_argnums]
fwd_, _ = argnums_partial(lu.wrap_init(fwd), dyn_argnums, args,
require_static_args_hashable=False)
bwd = prepend_static_args(lu.wrap_init(self.bwd), static_args)
fwd_, _ = argnums_partial(lu.wrap_init(fwd, debug_info=debug_fwd),
dyn_argnums, args,
require_static_args_hashable=False)
bwd = prepend_static_args(lu.wrap_init(self.bwd, debug_info=debug_bwd),
static_args)
else:
f_, dyn_args = lu.wrap_init(self.fun), args
fwd_, bwd = lu.wrap_init(fwd), lu.wrap_init(self.bwd)
f_, dyn_args = lu.wrap_init(self.fun, debug_info=debug_fun), args
fwd_ = lu.wrap_init(fwd, debug_info=debug_fwd)
bwd = lu.wrap_init(self.bwd, debug_info=debug_bwd)
args_flat, in_tree = tree_flatten(dyn_args)
in_avals = [core.get_aval(x) for x in args_flat]
if config.mutable_array_checks.value:
f_ = _check_primal_refs(f_, self.nondiff_argnums)
f_ = _check_primal_refs(f_, self.nondiff_argnums, f_.debug_info)
flat_fun, out_type = _flatten_fun_nokwargs(f_, in_tree)
flat_fwd, out_trees = _flatten_fwd(
fwd_, self.nondiff_argnums, self.symbolic_zeros, primal_name,
fwd_name, in_tree, out_type)
fwd_, self.nondiff_argnums, self.symbolic_zeros, debug_fun,
debug_fwd, in_tree, out_type)
flat_bwd = _flatten_bwd(bwd, in_tree, in_avals, out_trees).call_wrapped
out_flat = custom_vjp_call_p.bind(flat_fun, flat_fwd, flat_bwd,
*args_flat, out_trees=out_trees,
@@ -658,19 +681,27 @@ class custom_vjp(Generic[ReturnValue]):
return tree_unflatten(out_tree, out_flat)
@lu.transformation2
def _check_primal_refs(f, nondiff_argnums, *args):
_check_for_aliased_refs(f, nondiff_argnums, args)
def _check_primal_refs(f: Callable, nondiff_argnums: Sequence[int],
debug_info: core.DebugInfo | None, *args):
_check_for_aliased_refs(f, nondiff_argnums, debug_info, args)
out = f(*args)
_check_for_returned_refs(f, out, 'primal')
return out
def _check_for_aliased_refs(f, nondiff_argnums, args):
def _check_for_aliased_refs(f: Callable,
nondiff_argnums: Sequence[int],
debug: core.DebugInfo | None,
args):
leaves = tree_leaves(args)
refs: dict[int, int] = {}
for i, x in enumerate(leaves):
if (isinstance((a := core.get_aval(x)), AbstractRef) and
(dup_idx := refs.setdefault(id(core.get_referent(x)), i)) != i):
arg_names = _non_static_arg_names(fun_signature(f), args, {}, nondiff_argnums, ())
if debug is not None:
arg_names = debug.safe_arg_names(len(leaves))
else:
# TODO(necula): drop this branch
arg_names = _non_static_arg_names(fun_signature(f), args, {}, nondiff_argnums, ())
if arg_names is None:
arg_names = [f'flat index {j}' for j in range(len(leaves))]
raise ValueError(

@@ -725,18 +756,24 @@ def _check_for_tracers(x):
raise UnexpectedTracerError(msg)
@partial(lu.transformation_with_aux2, use_eq_store=True)
def _flatten_fwd(f, store, nondiff_argnums, symbolic_zeros, primal_name,
fwd_name, in_tree, maybe_out_type, *args):
def _flatten_fwd(f: Callable, store: lu.EqualStore,
nondiff_argnums: Sequence[int],
symbolic_zeros: bool,
debug_primal: core.DebugInfo | None,
debug_fwd: core.DebugInfo | None,
in_tree: PyTreeDef, maybe_out_type, *args):
primal_name = debug_primal.func_name if debug_primal else str(f)
fwd_name = debug_fwd.func_name if debug_fwd else "<unknown>"
if symbolic_zeros:
args = [CustomVJPPrimal(x, z) for x, z in zip(args[::2], args[1::2])]
args = tuple(CustomVJPPrimal(x, z) for x, z in zip(args[::2], args[1::2]))
else:
args = args[::2]
py_args = tree_unflatten(in_tree, args)
if config.mutable_array_checks.value:
_check_for_aliased_refs(f, nondiff_argnums, py_args)
_check_for_aliased_refs(f, nondiff_argnums, debug_primal, py_args)
pair_out = f(*py_args)
if config.mutable_array_checks.value:
_check_for_returned_refs(f, pair_out, 'fwd')
_check_for_returned_refs(f, pair_out, "fwd")
if not isinstance(pair_out, (list, tuple)) or len(pair_out) != 2:
msg = (f"Custom VJP fwd rule {fwd_name} for function {primal_name} "
"must produce a pair (list or tuple of length two) where the first "
@@ -790,7 +827,10 @@ def _flatten_fwd(f, store, nondiff_argnums, symbolic_zeros, primal_name,
return (*res, *primals_out)
@lu.transformation2
def _flatten_bwd(f, in_tree, in_avals, out_trees, *args):
def _flatten_bwd(f: Callable,
in_tree: PyTreeDef,
in_avals: Sequence[core.AbstractValue],
out_trees: Callable[[], Sequence[PyTreeDef]], *args):
out_tree, res_tree = out_trees()
assert len(args) == res_tree.num_leaves + out_tree.num_leaves
res, cts_out = split_list(args, [res_tree.num_leaves])

@@ -1494,7 +1534,9 @@ custom_jvp_call_jaxpr_p = core.Primitive("custom_jvp_call_jaxpr")
# simpler, but it would be worth revisiting this.
def optimize_remat_of_custom_vjp_fwd(
fun: Callable[..., ReturnValue],
debug_fun: core.DebugInfo | None,
fwd: Callable[..., tuple[ReturnValue, Any]],
debug_fwd: core.DebugInfo | None,
nondiff_argnums: Sequence[int] = (),
symbolic_zeros: bool = False,
) -> Callable[..., tuple[ReturnValue, Any]]:

@@ -1507,8 +1549,7 @@ def optimize_remat_of_custom_vjp_fwd(
def wrapped_fwd(*args, **kwargs) -> tuple[ReturnValue, Any]:
# TODO(dfm): This initial logic is duplicated from custom_vjp.__call__
# above and it would be good to consolidate it.
primal_name = getattr(fun, "__name__", str(fun))
fwd_name = getattr(fwd, "__name__", str(fwd))
fwd_name = debug_fwd.func_name if debug_fwd else str(fwd)
# Note: we use `fun` instead of `fwd` here for consistency with
# custom_vjp.__call__ above.
args = resolve_kwargs(fun, args, kwargs)

@@ -1516,17 +1557,19 @@ def optimize_remat_of_custom_vjp_fwd(
for i in nondiff_argnums: _check_for_tracers(args[i])
nondiff_argnums_ = set(nondiff_argnums)
dyn_argnums = [i for i in range(len(args)) if i not in nondiff_argnums_]
f_, dyn_args = argnums_partial(lu.wrap_init(fun), dyn_argnums,
args, require_static_args_hashable=False)
fwd_, _ = argnums_partial(lu.wrap_init(fwd), dyn_argnums, args,
f_, dyn_args = argnums_partial(lu.wrap_init(fun, debug_info=debug_fun),
dyn_argnums,
args, require_static_args_hashable=False)
fwd_, _ = argnums_partial(lu.wrap_init(fwd, debug_info=debug_fwd),
dyn_argnums, args,
require_static_args_hashable=False)
else:
f_, dyn_args = lu.wrap_init(fun), args
fwd_ = lu.wrap_init(fwd)
f_, dyn_args = lu.wrap_init(fun, debug_info=debug_fun), args
fwd_ = lu.wrap_init(fwd, debug_info=debug_fwd)
args_flat, in_tree = tree_flatten(dyn_args)
flat_fun, out_type = _flatten_fun_nokwargs(f_, in_tree)
flat_fwd, out_trees = _flatten_fwd(fwd_, nondiff_argnums, False,
primal_name, fwd_name, in_tree, out_type)
debug_fun, debug_fwd, in_tree, out_type)
flat_fwd = _fix_fwd_args(flat_fwd)
in_avals = [core.get_aval(x) for x in args_flat]
@@ -68,7 +68,7 @@ def jvp(fun: lu.WrappedFun, has_aux=False, instantiate=True,
return jvpfun(fun, instantiate, transform_stack), aux
@lu.transformation2
def jvpfun(f, instantiate, transform_stack, primals, tangents):
def jvpfun(f: Callable, instantiate, transform_stack, primals, tangents):
tag = core.TraceTag()
tangents = [Zero.from_primal_value(t) if not isinstance(t, Zero)
and dtype(t) == float0 else t for t in tangents]

@@ -106,7 +106,7 @@ def linearize_subtrace(_f: Callable, _store, _tag, nzs_in, *primals, **params):
return tuple(consts) + tuple(out_primals)
@lu.transformation2
def jvp_subtrace(f, tag, primals, tangents):
def jvp_subtrace(f: Callable, tag: core.TraceTag, primals, tangents):
with core.take_current_trace() as parent_trace:
trace = JVPTrace(parent_trace, tag)
in_tracers = [maybe_jvp_tracer(trace, x, t)

@@ -778,7 +778,8 @@ def linearize_from_jvp(jvp, multiple_results, nonzeros,
out_nz_tracers = [trace.to_jaxpr_tracer(r)
for (r, nz) in zip(out_tangents, out_nzs) if nz]
in_tracers = [t for t, nz in zip(tangent_args, nonzeros) if nz]
jaxpr, out_consts, _ = pe.tracers_to_jaxpr(in_tracers, out_nz_tracers)
# TODO(necula): pass debug_info here
jaxpr, out_consts, _ = pe.tracers_to_jaxpr(in_tracers, out_nz_tracers, None)
def linearized(residuals, *tangents):
nz_tangents_in = [t for (t, nz) in zip(tangents, nonzeros) if nz]
@@ -932,13 +933,15 @@ def traceable(f, store, in_tree, *primals_and_tangents):
return out_flat
def call_transpose(primitive, params, call_jaxpr, args, ct, _):
def call_transpose(primitive, params, call_jaxpr: core.Jaxpr, args, ct, _):
if isinstance(call_jaxpr, core.ClosedJaxpr):
call_jaxpr, consts = call_jaxpr.jaxpr, call_jaxpr.consts
else:
consts = ()
all_args, in_tree_def = tree_flatten((consts, args, ct))
fun = lu.hashable_partial(lu.wrap_init(backward_pass), call_jaxpr, False)
fun = lu.hashable_partial(lu.wrap_init(backward_pass,
debug_info=call_jaxpr.debug_info),
call_jaxpr, False)
fun, out_tree = flatten_fun_nokwargs(fun, in_tree_def)
update_params = call_transpose_param_updaters.get(primitive)
if update_params:

@@ -950,7 +953,7 @@ def call_transpose(primitive, params, call_jaxpr, args, ct, _):
res_invars, _ = partition_list(which_lin, call_jaxpr.invars)
new_invars = [*res_invars, *call_jaxpr.outvars]
dbidx_map = {v: core.DBIdx(i) for i, v in enumerate(new_invars)}
in_type = [(v.aval.update(shape=tuple(dbidx_map.get(d, d) for d in v.aval.shape))
in_type = [(v.aval.update(shape=tuple(dbidx_map.get(d, d) for d in v.aval.shape)) # type: ignore[arg-type]
if type(v.aval) is core.DShapedArray else v.aval, True) for v in new_invars]
fun = lu.annotate(fun, tuple(in_type))
out_flat = primitive.bind(fun, *all_args, **params)

@@ -1027,8 +1030,8 @@ def jvp_jaxpr(jaxpr: core.ClosedJaxpr, nonzeros: Sequence[bool],
def _jvp_jaxpr(jaxpr: core.ClosedJaxpr,
nonzeros: Sequence[bool], instantiate: Sequence[bool]):
assert len(jaxpr.in_avals) == len(nonzeros)
debug_info = jaxpr.jaxpr.debug_info
f = lu.wrap_init(core.jaxpr_as_fun(jaxpr), debug_info=debug_info)
f = lu.wrap_init(core.jaxpr_as_fun(jaxpr),
debug_info=jaxpr.jaxpr.debug_info)
f_jvp, out_nonzeros = f_jvp_traceable(jvp(f, instantiate=instantiate, transform_stack=False),
nonzeros)
tangent_avals = [aval.to_tangent_aval() for aval, nz in zip(jaxpr.in_avals, nonzeros) if nz]
@@ -34,6 +34,7 @@ import warnings
import numpy as np
from jax._src import ad_util
from jax._src import api_util
from jax._src import config
from jax._src import core
from jax._src import dtypes

@@ -2156,7 +2157,8 @@ def lower_fun(fun: Callable, multiple_results: bool = True) -> Callable:
as `avals_out`."""
def f_lowered(ctx: LoweringRuleContext, *args, **params):
f = fun if multiple_results else lambda *args, **kw: (fun(*args, **kw),)
wrapped_fun = lu.wrap_init(f, params)
wrapped_fun = lu.wrap_init(f, params,
debug_info=api_util.debug_info("lower_fun", fun, args, params))
manager = (contextlib.nullcontext() if ctx.jaxpr_eqn_ctx is None else
ctx.jaxpr_eqn_ctx.manager)
@@ -236,7 +236,7 @@ class JaxprTrace(Trace['JaxprTracer']):
params, effects, source)
return out_tracer
def process_call(self, primitive, f, tracers, params):
def process_call(self, primitive, f: lu.WrappedFun, tracers, params):
tracers = map(self.to_jaxpr_tracer, tracers)
rule = call_partial_eval_rules.get(primitive)
if rule:

@@ -254,7 +254,7 @@ class JaxprTrace(Trace['JaxprTracer']):
# which were unknown to the first call (corresponding to in_avals).
# Wrap f to perform the partial evaluation and plumb out aux data.
f_ = trace_to_subjaxpr_nounits_fwd(f, self.tag, False)
f_ = trace_to_subjaxpr_nounits_fwd(f, self.tag, f.debug_info, False)
f_, aux = partial_eval_wrapper_nounits(f_, tuple(in_knowns), tuple(in_avals))
# Adjust parameters (e.g. donated_invars) for the call to be evaluated now.

@@ -330,7 +330,7 @@ class JaxprTrace(Trace['JaxprTracer']):
for ax, aval in zip(unk_in_axes, in_avals)]
# Wrap f to perform partial evaluation and plumb out aux data.
f = trace_to_subjaxpr_nounits2(f, self.tag, False)
f = trace_to_subjaxpr_nounits2(f, self.tag, f.debug_info, False)
f, aux = partial_eval_wrapper_nounits(f, tuple(in_knowns),
tuple(in_avals_mapped))
# Adjust params for knowns (e.g. donated_invars, in_axes, out_axes_thunk)

@@ -429,7 +429,7 @@ class JaxprTrace(Trace['JaxprTracer']):
tracers = map(self.instantiate_const_abstracted, tracers)
# Because we instantiate all tracers, in_knowns is all False.
in_knowns, in_avals, () = partition_pvals([t.pval for t in tracers])
f = trace_to_subjaxpr_nounits(f, self, True)
f = trace_to_subjaxpr_nounits(f, self, True, f.debug_info)
f, aux = partial_eval_wrapper_nounits(f, (*in_knowns,), (*in_avals,))
with core.set_current_trace(self.parent_trace):
out_flat = prim.bind(f, fwd, bwd, out_trees=out_trees,

@@ -445,7 +445,7 @@ class JaxprTrace(Trace['JaxprTracer']):
@_memoize
def fwd_jaxpr_thunk(*zeros):
fwd_ = _interleave_fun(fwd, zeros)
fwd_ = trace_to_subjaxpr_nounits(fwd_, self, True)
fwd_ = trace_to_subjaxpr_nounits(fwd_, self, True, fwd_.debug_info)
fwd_, aux = partial_eval_wrapper_nounits(fwd_, (*in_knowns,), (*in_avals,))
out_flat = fwd_.call_wrapped()
out_knowns, out_avals, jaxpr, env = aux()
@@ -476,7 +476,10 @@ def partition_pvals(
@lu.transformation_with_aux2
def partial_eval_wrapper_nounits(
f, store, in_knowns: Sequence[bool], in_avals: Sequence[AbstractValue],
f: Callable,
store: lu.Store,
in_knowns: Sequence[bool],
in_avals: Sequence[AbstractValue],
*in_consts: Any):
in_avals_, in_consts_ = iter(in_avals), iter(in_consts)
in_pvals = [PartialVal.known(next(in_consts_)) if known else

@@ -566,7 +569,7 @@ def trace_to_jaxpr_nounits(
with core.take_current_trace() as parent_trace:
trace = JaxprTrace(parent_trace, current_name_stack, TraceTag())
with core.ensure_no_leaks(trace):
fun = trace_to_subjaxpr_nounits(fun, trace, instantiate)
fun = trace_to_subjaxpr_nounits(fun, trace, instantiate, fun.debug_info)
with core.set_current_trace(trace):
jaxpr, (out_pvals, consts, env) = fun.call_wrapped(pvals)
assert not env

@@ -579,10 +582,11 @@ def trace_to_subjaxpr_nounits(
f: Callable,
trace: JaxprTrace,
instantiate: Sequence[bool] | bool,
debug_info: core.DebugInfo | None,
in_pvals: Sequence[PartialVal]):
assert all(isinstance(pv, PartialVal) for pv in in_pvals), in_pvals
out_tracers, jaxpr, out_consts, env = _trace_to_subjaxpr_nounits(
f, trace, instantiate, in_pvals)
f, trace, instantiate, in_pvals, debug_info)
out_pvals = [t.pval for t in out_tracers]
del out_tracers
return jaxpr, (out_pvals, out_consts, env)

@@ -591,6 +595,7 @@ def trace_to_subjaxpr_nounits(
def trace_to_subjaxpr_nounits2(
f: Callable,
tag: TraceTag,
debug_info: core.DebugInfo | None,
instantiate: bool | Sequence[bool],
in_pvals: Sequence[PartialVal]):
assert isinstance(tag, TraceTag)

@@ -599,14 +604,15 @@ def trace_to_subjaxpr_nounits2(
with core.take_current_trace() as parent_trace:
trace = JaxprTrace(parent_trace, current_name_stack, tag)
out_tracers, jaxpr, out_consts, env = _trace_to_subjaxpr_nounits(
f, trace, instantiate, in_pvals)
f, trace, instantiate, in_pvals, debug_info)
out_pvals = [t.pval for t in out_tracers]
del out_tracers
return jaxpr, (out_pvals, out_consts, env)
def _trace_to_subjaxpr_nounits(f: Callable, trace: JaxprTrace,
instantiate: Sequence[bool] | bool,
in_pvals: Sequence[PartialVal]):
in_pvals: Sequence[PartialVal],
debug_info: core.DebugInfo | None):
in_knowns = [pval.is_known() for pval in in_pvals]
in_consts = [pval.get_known() for pval in in_pvals if pval.is_known()]
in_tracers = [trace.new_arg(pval) for pval in in_pvals if not pval.is_known()]
@@ -623,7 +629,7 @@ def _trace_to_subjaxpr_nounits(f: Callable, trace: JaxprTrace,
out_tracers = [trace.instantiate_const(t) if inst else t
for inst, t in zip(instantiate, out_tracers)]
out_tracers_ = [t for t in out_tracers if not t.is_known()]
jaxpr, out_consts, env = tracers_to_jaxpr(in_tracers, out_tracers_)
jaxpr, out_consts, env = tracers_to_jaxpr(in_tracers, out_tracers_, debug_info)
return out_tracers, jaxpr, out_consts, env
# The below variant implements an optimization where residuals which are also

@@ -631,8 +637,9 @@ def _trace_to_subjaxpr_nounits(f: Callable, trace: JaxprTrace,
# TODO(mattjj): update all callers to use this version, delete other version.
@lu.transformation2
def trace_to_subjaxpr_nounits_fwd(
f,
f: Callable,
tag: TraceTag,
debug_info: core.DebugInfo | None,
instantiate: bool | Sequence[bool],
in_pvals: Sequence[PartialVal]):
assert all(isinstance(pv, PartialVal) for pv in in_pvals), in_pvals

@@ -641,7 +648,7 @@ def trace_to_subjaxpr_nounits_fwd(
trace = JaxprTrace(parent_trace, current_name_stack, tag)
with core.set_current_trace(trace):
out_tracers, jaxpr, out_consts, env = _trace_to_subjaxpr_nounits(
f, trace, instantiate, in_pvals)
f, trace, instantiate, in_pvals, debug_info)
out_pvals = [t.pval for t in out_tracers]
# Which out_consts (aka residuals) are just forwarded inputs? Check obj id.

@@ -660,8 +667,9 @@ def trace_to_subjaxpr_nounits_fwd(
# than passed as redundant outputs.
@lu.transformation2
def trace_to_subjaxpr_nounits_fwd2(
f,
f: Callable,
tag: TraceTag,
debug_info: core.DebugInfo | None,
instantiate: bool | Sequence[bool],
in_pvals: Sequence[PartialVal]):
assert all(isinstance(pv, PartialVal) for pv in in_pvals), in_pvals

@@ -669,7 +677,7 @@ def trace_to_subjaxpr_nounits_fwd2(
with core.take_current_trace() as parent_trace:
trace = JaxprTrace(parent_trace, current_name_stack, tag)
out_tracers, jaxpr, consts, env = _trace_to_subjaxpr_nounits(
f, trace, instantiate, in_pvals)
f, trace, instantiate, in_pvals, debug_info)
out_pvals = [t.pval for t in out_tracers]
# Which consts (aka residuals) are just forwarded inputs? Check obj id.
@@ -743,7 +751,8 @@ def recipe_to_eqn(getvar: Callable[[JaxprTracer], Atom],
def tracers_to_jaxpr(
in_tracers: Sequence[JaxprTracer],
out_tracers: Sequence[JaxprTracer]
out_tracers: Sequence[JaxprTracer],
debug_info: core.DebugInfo | None,
) -> tuple[Jaxpr, tuple[Any, ...], tuple[Any, ...]]:
"""Constructs Jaxpr given tracers for inputs and outputs.

@@ -819,7 +828,8 @@ def tracers_to_jaxpr(
outvars = map(get_atom, out_tracers) # type: ignore[arg-type]
jaxpr_effects = make_jaxpr_effects(const_vars, invars, outvars, eqns)
jaxpr = Jaxpr(const_vars, invars, # type: ignore[arg-type]
outvars, eqns, jaxpr_effects)
outvars, eqns, jaxpr_effects,
debug_info)
config.enable_checks.value and core.check_jaxpr(jaxpr)
# del getvar # needed to avoid cyclic-reference closure, apparently!
return jaxpr, const_vals, env_vals

@@ -1135,7 +1145,8 @@ def _partial_eval_jaxpr_custom_cached(
staged_effects = make_jaxpr_effects(jaxpr.constvars, staged_invars,
outs_staged, staged_eqns)
jaxpr_staged = Jaxpr(jaxpr.constvars, staged_invars,
outs_staged, staged_eqns, staged_effects)
outs_staged, staged_eqns, staged_effects,
jaxpr.debug_info)
config.enable_checks.value and core.check_jaxpr(jaxpr_staged)
return (jaxpr_known, jaxpr_staged, out_unknowns, out_inst, len(residuals),

@@ -1504,7 +1515,8 @@ def _move_binders_to_front(closed_jaxpr: ClosedJaxpr, to_move: tuple[bool, ...]
new_invars = _move_to_front(closed_jaxpr.jaxpr.invars, to_move)
new_jaxpr = Jaxpr(closed_jaxpr.jaxpr.constvars, new_invars,
closed_jaxpr.jaxpr.outvars, closed_jaxpr.jaxpr.eqns,
closed_jaxpr.jaxpr.effects)
closed_jaxpr.jaxpr.effects,
closed_jaxpr.jaxpr.debug_info)
new_closed_jaxpr = core.ClosedJaxpr(new_jaxpr, closed_jaxpr.consts)
return new_closed_jaxpr
@@ -1666,7 +1678,8 @@ class JaxprStackFrame:
set_states(self.attrs_tracked, self.attrs_inits)
return jaxpr, list(constvals), zip(init_trees, end_trees, self.attrs_tracked)
def to_jaxpr2(self, out_tracers):
def to_jaxpr2(self, out_tracers: Sequence[core.Tracer],
debug_info: core.DebugInfo | None):
# It's not necessary, but we keep the tracer-to-var mapping injective:
assert len(self.tracer_to_var) == len(set(self.tracer_to_var.values()))
constvars, constvals = unzip2(self.constvar_to_val.items())

@@ -1674,10 +1687,10 @@ class JaxprStackFrame:
jaxpr_effects = make_jaxpr_effects(constvars, self.invars, expl_outvars,
self.eqns)
jaxpr = Jaxpr(constvars, self.invars, expl_outvars, self.eqns,
jaxpr_effects)
jaxpr_effects, debug_info)
# We can't run check_jaxpr until after we normalize.
jaxpr, constvals = _const_folding_and_forwarding(jaxpr, constvals)
jaxpr, constvals = _inline_literals(jaxpr, constvals)
jaxpr, constvals = _inline_literals(jaxpr, constvals) # type: ignore[assignment]
jaxpr, out_type = _add_implicit_outputs(jaxpr)
config.enable_checks.value and core.check_jaxpr(jaxpr)
return jaxpr, out_type, constvals

@@ -2184,7 +2197,7 @@ def trace_to_jaxpr_dynamic2(
with core.set_current_trace(trace):
ans = fun.call_wrapped(*in_tracers)
out_tracers = map(trace.to_jaxpr_tracer, ans)
jaxpr = trace.frame.to_jaxpr2(out_tracers)
jaxpr = trace.frame.to_jaxpr2(out_tracers, fun.debug_info)
del trace, in_tracers, out_tracers, ans
return jaxpr
@@ -240,7 +240,8 @@ def _check_tree(func_name, expected_name, actual_tree, expected_tree, has_aux=Fa
def _prune_zeros(ts):
return [t for t in ts if type(t) is not ad_util.Zero]
def _make_closed_jaxpr(traceable: lu.WrappedFun, in_avals: Sequence[core.AbstractValue]):
def _make_closed_jaxpr(traceable: lu.WrappedFun,
in_avals: Sequence[core.AbstractValue]):
jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(traceable, in_avals)
return core.ClosedJaxpr(jaxpr, consts)
@@ -327,7 +327,7 @@ def _cond_with_per_branch_args(pred,
lambda op: false_fun(op[1]),
(true_operand, false_operand))
def _join_cond_effects(branches: Sequence[core.Jaxpr]) -> effects.Effects:
def _join_cond_effects(branches: Sequence[core.ClosedJaxpr]) -> effects.Effects:
joined_effects = set()
for b in branches:
for eff in b.effects:

@@ -337,7 +337,8 @@ def _join_cond_effects(branches: Sequence[core.Jaxpr]) -> effects.Effects:
joined_effects.add(eff)
return joined_effects
def _cond_abstract_eval(*avals, branches, **_):
def _cond_abstract_eval(*avals: core.AbstractValue,
branches: Sequence[core.ClosedJaxpr], **_):
joined_effects = _join_cond_effects(branches)
disallowed_effects = effects.control_flow_allowed_effects.filter_not_in(joined_effects)
if disallowed_effects:

@@ -614,10 +615,11 @@ def _merge_branch_residuals(branch_res_avals):
# This function augments branch outputs to agree with the merged residual
# format: each branch is made to return zero-filled values in the places of
# residual outputs that it does not populate.
def _join_cond_outputs(jaxprs, all_res_avals, res_aval_indices_per_jaxpr,
def _join_cond_outputs(jaxprs: Sequence[core.ClosedJaxpr],
all_res_avals, res_aval_indices_per_jaxpr,
num_non_res_outputs):
def augment_jaxpr(jaxpr, res_indices):
@lu.wrap_init
def augment_jaxpr(jaxpr: core.ClosedJaxpr,
res_indices):
def f_aug(*args):
outs_and_residuals = core.jaxpr_as_fun(jaxpr)(*args)
outs, residuals = split_list(outs_and_residuals, [num_non_res_outputs])
@@ -625,19 +627,21 @@ def _join_cond_outputs(jaxprs, all_res_avals, res_aval_indices_per_jaxpr,
aug_residuals = util.subvals(aug_residuals, zip(res_indices, residuals))
return outs + list(aug_residuals)
return _make_closed_jaxpr(f_aug, jaxpr.in_avals)
wrapped_f_aug = lu.wrap_init(f_aug, debug_info=jaxpr.jaxpr.debug_info)
return _make_closed_jaxpr(wrapped_f_aug, jaxpr.in_avals)
return tuple(map(augment_jaxpr, jaxprs, res_aval_indices_per_jaxpr))
# This function augments branch inputs to agree with the merged residual format:
# each branch is made to accept all residuals, even though it will ignore those
# that it does not read.
def _join_cond_pe_staged_jaxpr_inputs(jaxprs, all_res_avals,
def _join_cond_pe_staged_jaxpr_inputs(jaxprs: Sequence[core.ClosedJaxpr],
all_res_avals,
res_aval_indices_per_jaxpr):
newvar = core.gensym(suffix='_')
all_res_vars = map(newvar, all_res_avals)
def augment_jaxpr(jaxpr, res_indices):
def augment_jaxpr(jaxpr: core.ClosedJaxpr, res_indices) -> core.ClosedJaxpr:
num_res = len(res_indices)
res_vars = jaxpr.jaxpr.invars[:num_res]
non_res_vars = jaxpr.jaxpr.invars[num_res:]

@@ -646,9 +650,9 @@ def _join_cond_pe_staged_jaxpr_inputs(jaxprs, all_res_avals,
aug_invars = aug_res_vars + non_res_vars
jaxpr_aug = core.Jaxpr(jaxpr.jaxpr.constvars, aug_invars,
jaxpr.jaxpr.outvars, jaxpr.jaxpr.eqns,
jaxpr.jaxpr.effects)
jaxpr_aug = core.ClosedJaxpr(jaxpr_aug, jaxpr.consts)
return jaxpr_aug
jaxpr.jaxpr.effects,
jaxpr.jaxpr.debug_info)
return core.ClosedJaxpr(jaxpr_aug, jaxpr.consts)
return tuple(map(augment_jaxpr, jaxprs, res_aval_indices_per_jaxpr))
@@ -679,8 +683,7 @@ def _cond_dce_rule(used_outputs: list[bool], eqn: core.JaxprEqn,
# Finally, update parameters and form the new eqn.
new_params = dict(eqn.params, branches=tuple(dce_branches))
new_effects = core.join_effects(*(b.effects for b in dce_branches))
new_effects = _join_cond_effects(dce_branches_)
new_effects = _join_cond_effects(dce_branches)
new_eqn = pe.new_jaxpr_eqn(
[v for v, used in zip(eqn.invars, [True, *used_inputs]) if used],
[v for v, used in zip(eqn.outvars, used_outputs) if used],

@@ -693,10 +696,10 @@ def _cond_dce_rule(used_outputs: list[bool], eqn: core.JaxprEqn,
return [True, *used_inputs], new_eqn
def _transpose_cond_jaxpr(jaxpr, num_res):
def _transpose_cond_jaxpr(jaxpr: core.ClosedJaxpr,
num_res: int):
res_avals, primal_avals = split_list(jaxpr.in_avals, [num_res])
@lu.wrap_init
def transposed(*args):
res, cts_out = split_list(args, [num_res])
primals = res + [ad.UndefinedPrimal(aval) for aval in primal_avals]

@@ -705,7 +708,9 @@ def _transpose_cond_jaxpr(jaxpr, num_res):
_, cts_in = split_list(cts_in, [num_res])
return map(ad.instantiate_zeros, cts_in)
return _make_closed_jaxpr(transposed, res_avals + jaxpr.out_avals)
return _make_closed_jaxpr(lu.wrap_init(transposed,
debug_info=jaxpr.jaxpr.debug_info),
res_avals + jaxpr.out_avals)
def _cond_transpose(cts, *args, branches):
index, *ops = args
@@ -70,11 +70,13 @@ for_p.multiple_results = True
### Tracing utilities
def _trace_to_jaxpr_with_refs(f, state_tree: PyTreeDef,
state_avals: Sequence[core.AbstractValue]
def _trace_to_jaxpr_with_refs(f: Callable, state_tree: PyTreeDef,
state_avals: Sequence[core.AbstractValue],
debug_info: core.DebugInfo | None,
) -> tuple[core.Jaxpr, list[Any], PyTreeDef]:
f, out_tree_thunk = api_util.flatten_fun_nokwargs(
lu.wrap_init(f), treedef_tuple((tree_structure(0), state_tree)))
lu.wrap_init(f, debug_info=debug_info),
treedef_tuple((tree_structure(0), state_tree)))
jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(
f, state_avals)
return jaxpr, consts, out_tree_thunk()

@@ -129,12 +131,13 @@ def for_loop(nsteps: int | Sequence[int],
rest_steps, functools.partial(body, i), vals, unroll=unroll)
tree_map(lambda ref, val: ref_set(ref, (), val), refs, vals)
return for_loop(outer_step, wrapped_body, init_state, unroll=unroll)
dbg = api_util.debug_info("for_loop", body, (0, init_state), {})
nsteps, = nsteps
flat_state, state_tree = tree_flatten(init_state)
state_avals = map(state_utils.val_to_ref_aval, flat_state)
idx_aval = core.ShapedArray((), dtypes.canonicalize_dtype(np.int64))
jaxpr, consts, out_tree = _trace_to_jaxpr_with_refs(
body, state_tree, [idx_aval, *state_avals])
body, state_tree, [idx_aval, *state_avals], dbg)
if out_tree != tree_structure(None):
raise Exception("`body` should not return anything.")
jaxpr = state_utils.hoist_consts_to_refs(jaxpr, index=1)

@@ -212,6 +215,7 @@ def scan(f: Callable[[Carry, X], tuple[Carry, Y]],
tree_map(lambda c_ref, c: ref_set(c_ref, (), c), carry_refs, carry)
tree_map(lambda y_ref, y: ref_set(y_ref, (i,), y), ys_refs, y)
assert isinstance(length, int)
api_util.save_wrapped_fun_sourceinfo(for_body, f)
init, _, ys = for_loop(length, for_body, (init, xs, ys), reverse=reverse,
unroll=unroll)
return init, ys
@@ -572,7 +576,6 @@ def _for_partial_eval_custom(saveable, in_unknowns, in_inst, eqn):
newvar = core.gensym()
resvars = map(newvar, res_avals)
@lu.wrap_init
def known(*known_vals):
empty_res = map(ad_util.zeros_like_aval, res_avals)
jaxpr_known_args = [*known_vals, *empty_res]

@@ -581,7 +584,8 @@ def _for_partial_eval_custom(saveable, in_unknowns, in_inst, eqn):
reverse=reverse, which_linear=jaxpr_known_which_linear,
unroll=unroll)
call_jaxpr_, _, call_jaxpr_consts, () = pe.trace_to_jaxpr_dynamic(
known, [v.aval for v in known_invars])
lu.wrap_init(known, debug_info=jaxpr.debug_info),
[v.aval for v in known_invars])
call_jaxpr = core.ClosedJaxpr(call_jaxpr_, call_jaxpr_consts)
eqn_known = pe.new_jaxpr_eqn(known_invars, [*known_outvars, *resvars],
core.closed_call_p, dict(call_jaxpr=call_jaxpr),

@@ -596,14 +600,14 @@ def _for_partial_eval_custom(saveable, in_unknowns, in_inst, eqn):
which_linear=which_linear_unknown,
unroll=unroll)
@lu.wrap_init
def staged(*res_and_refs):
out_flat = for_p.bind(*res_and_refs, **params_staged)
_, ans = split_list(out_flat, [num_res])
_, ans = partition_list(out_inst, ans)
return ans
call_jaxpr_, _, call_jaxpr_consts, () = pe.trace_to_jaxpr_dynamic(
staged, [v.aval for v in [*resvars, *eqn.invars]])
lu.wrap_init(staged, debug_info=jaxpr_staged.debug_info),
[v.aval for v in [*resvars, *eqn.invars]])
assert len(jaxpr_staged.invars) - 1 == len(call_jaxpr_.invars)
call_jaxpr = core.ClosedJaxpr(call_jaxpr_, call_jaxpr_consts)
_, outvars = partition_list(out_inst, eqn.outvars)
@@ -621,7 +625,7 @@ def _convert_outputs_to_writes(
assert not jaxpr.constvars, "Jaxpr shouldn't have constvars."
in_avals = [v.aval for v in jaxpr.invars] # [i, *orig_ref_avals]
@lu.wrap_init
def eval_jaxpr(i, *refs):
# We split the refs into the original input refs and the dummy residual
# refs.

@@ -641,7 +645,8 @@ def _convert_outputs_to_writes(
v.aval.dtype)) # pytype: disable=attribute-error
for v, loop_invar in zip(jaxpr.outvars, loop_invar_res)]
jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(
eval_jaxpr, [*in_avals, *res_ref_avals])
lu.wrap_init(eval_jaxpr, debug_info=jaxpr.debug_info),
[*in_avals, *res_ref_avals])
assert not consts
return jaxpr, [core.ShapedArray(a.shape, a.dtype) for a in res_ref_avals] # pytype: disable=attribute-error

@@ -650,7 +655,6 @@ def _convert_inputs_to_reads(
loop_invar_res: Sequence[bool]) -> core.Jaxpr:
assert not jaxpr.constvars, "Jaxpr should not have constvars"
@lu.wrap_init
def eval_jaxpr(i, *refs):
residual_refs, orig_refs = split_list(refs, [num_res])
residual_vals = [r[()] if loop_invar else r[i] for r, loop_invar
@@ -667,7 +671,8 @@ def _convert_inputs_to_reads(
for aval, loop_invar in zip(res_val_avals, loop_invar_res)]
jaxpr, _, (), () = pe.trace_to_jaxpr_dynamic(
eval_jaxpr, [i_aval, *res_ref_avals, *orig_ref_avals])
lu.wrap_init(eval_jaxpr, debug_info=jaxpr.debug_info),
[i_aval, *res_ref_avals, *orig_ref_avals])
return jaxpr
def transpose_jaxpr(jaxpr: core.Jaxpr, which_linear: list[bool]) -> core.Jaxpr:

@@ -693,7 +698,8 @@ def transpose_jaxpr(jaxpr: core.Jaxpr, which_linear: list[bool]) -> core.Jaxpr:
ad.backward_pass(tangent_jaxpr, False, (), (*primals_args, *ct_args), ())
return []
jaxpr_trans, _, _, () = pe.trace_to_jaxpr_dynamic(
lu.wrap_init(trans), [v.aval for v in jaxpr.invars])
lu.wrap_init(trans, debug_info=jaxpr.debug_info),
[v.aval for v in jaxpr.invars])
return jaxpr_trans
def _for_transpose(in_cts, *args, jaxpr, nsteps, reverse, which_linear, unroll):

@@ -746,8 +752,9 @@ def discharged_for_loop(nsteps, body, init_state, *, reverse: bool = False):
flat_state, state_tree = tree_flatten(init_state)
state_avals = map(state_utils.val_to_ref_aval, flat_state)
idx_aval = core.ShapedArray((), dtypes.canonicalize_dtype(np.int64))
debug = api_util.debug_info("discharged_for_loop", body, (0, init_state), {})
jaxpr, consts, out_tree = _trace_to_jaxpr_with_refs(
body, state_tree, [idx_aval, *state_avals])
body, state_tree, [idx_aval, *state_avals], debug)
if out_tree != tree_structure(None):
raise Exception("`body` should not return anything.")
discharged_jaxpr, discharged_consts = discharge_state(jaxpr, consts)
@@ -511,7 +511,7 @@ def _empty_array(prefix, length_spec, aval):
eval_jaxpr_p = core.Primitive('eval_jaxpr')
eval_jaxpr_p.multiple_results = True
def _stage_jaxpr(trace, *tracers, jaxpr):
def _stage_jaxpr(trace: pe.JaxprTrace, *tracers, jaxpr: core.ClosedJaxpr):
params = dict(call_jaxpr=jaxpr)
return trace.default_process_primitive(core.closed_call_p, tracers, params)
pe.custom_staging_rules[eval_jaxpr_p] = _stage_jaxpr
@@ -334,7 +334,9 @@ def _custom_linear_solve_impl(*args, const_lengths, jaxprs):
return x
def _tangent_linear_map(func, params, params_dot, *x):
def _tangent_linear_map(func: Callable, params, params_dot,
debug_info: core.DebugInfo | None,
*x):
"""Compute the tangent of a linear map.
Assuming ``func(*params, *x)`` is linear in ``x`` and computes ``A @ x``,

@@ -342,7 +344,7 @@ def _tangent_linear_map(func, params, params_dot, *x):
"""
assert any(type(p) is not ad_util.Zero for p in params_dot)
zeros = _map(ad_util.Zero.from_primal_value, x)
_, out_tangent = ad.jvp(lu.wrap_init(func)).call_wrapped(
_, out_tangent = ad.jvp(lu.wrap_init(func, debug_info=debug_info)).call_wrapped(
params + list(x), params_dot + zeros)
return out_tangent

@@ -369,7 +371,8 @@ def _custom_linear_solve_jvp(primals, tangents, const_lengths, jaxprs):
rhs = b_dot
else:
matvec_tangents = _tangent_linear_map(
core.jaxpr_as_fun(jaxprs.matvec), params.matvec, params_dot.matvec, *x_leaves)
core.jaxpr_as_fun(jaxprs.matvec), params.matvec, params_dot.matvec,
jaxprs.matvec.jaxpr.debug_info, *x_leaves)
rhs = _map(ad.add_tangents, b_dot, _map(operator.neg, matvec_tangents))
x_dot = linear_solve_p.bind(*(_flatten(params) + rhs), **kwargs)
@@ -1847,8 +1847,8 @@ def reduce(operands: Any,
return tree_util.tree_unflatten(out_tree, out)
@cache()
def _reduction_jaxpr(computation, aval):
@lu.wrap_init
def _reduction_jaxpr(computation: Callable,
aval: core.AbstractValue):
def comp(x, y):
result = computation(x, y)
if not (isinstance(result, core.Tracer) or core.valid_jaxtype(result)):

@@ -1857,7 +1857,11 @@ def _reduction_jaxpr(computation, aval):
f"Reduction functions should only return an array.\n"
f"Full return value: {result}")
return (result,)
jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(comp, (aval, aval))
comp_wrapped = lu.wrap_init(
comp,
debug_info=api_util.debug_info("reduction_jaxpr", computation,
(aval, aval), {}))
jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(comp_wrapped, (aval, aval))
if any(isinstance(c, core.Tracer) for c in consts):
raise NotImplementedError(
"Reduction computations can't close over Tracers. Please open an issue "
@@ -282,6 +282,10 @@ class DebugInfo(NamedTuple):
return self._replace(result_paths=tuple(self.result_paths()))
return self
@property
def func_name(self) -> str:
return self.func_src_info.split(" ")[0]
def safe_arg_names(self, expected: int) -> tuple[str | None, ...]:
"""Get the arg_names with a safety check."""
if len(self.arg_names) == expected:

@@ -328,7 +332,13 @@ def wrap_init(f: Callable, params=None, *,
@transformation_with_aux2
def _get_result_paths_thunk(_fun: Callable, _store: Store, *args, **kwargs):
ans = _fun(*args, **kwargs)
_store.store([keystr(path) for path, _ in generate_key_paths(ans)])
result_paths = [keystr(path) for path, _ in generate_key_paths(ans)]
if _store:
# In some instances a lu.WrappedFun is called multiple times, e.g.,
# the bwd function in a custom_vjp
assert _store.val == result_paths, (_store, result_paths)
else:
_store.store(result_paths)
return ans
def annotate(f: WrappedFun, in_type: core.InputType | None) -> WrappedFun:
@@ -2352,7 +2352,8 @@ pe.partial_eval_jaxpr_custom_rules[pjit_p] = \
@lu.cache
def _pjit_transpose_trace(fun, in_avals):
def _pjit_transpose_trace(fun: lu.WrappedFun,
in_avals: Sequence[core.AbstractValue]):
transpose_jaxpr, _, consts, attrs_tracked = pe.trace_to_jaxpr_dynamic(
fun, in_avals)
transpose_jaxpr = core.ClosedJaxpr(transpose_jaxpr, consts)

@@ -2360,13 +2361,15 @@ def _pjit_transpose_trace(fun, in_avals):
def _pjit_transpose(cts_in, *primals_in,
jaxpr, in_shardings, out_shardings, in_layouts, out_layouts,
jaxpr: core.ClosedJaxpr,
in_shardings, out_shardings, in_layouts, out_layouts,
resource_env, donated_invars, name, keep_unused, inline,
compiler_options_kvs):
def prune_type(ty, xs, maybe_zeros):
return tuple(x for x, mz in zip(xs, maybe_zeros) if type(mz) is not ty)
body = lu.wrap_init(ad.closed_backward_pass)
body = lu.wrap_init(ad.closed_backward_pass,
debug_info=jaxpr.jaxpr._debug_info)
body = lu.hashable_partial(body, jaxpr, False)
primals_and_nz_cts_in, in_treedef = tree_flatten((primals_in, cts_in))
body, cts_out_treedef_thunk = flatten_fun_nokwargs(body, in_treedef)
@@ -79,7 +79,8 @@ def discharge_state(jaxpr: core.Jaxpr, consts: Sequence[Any], * ,
if isinstance(v.aval, AbstractRef) and d
else v.aval for v, d in zip(jaxpr.invars, should_discharge)]
eval_jaxpr = lu.wrap_init(partial(_eval_jaxpr_discharge_state, jaxpr,
should_discharge, consts))
should_discharge, consts),
debug_info=jaxpr.debug_info)
new_jaxpr, _ , new_consts, () = pe.trace_to_jaxpr_dynamic(eval_jaxpr, in_avals)
return new_jaxpr, new_consts
@@ -29,6 +29,7 @@ import jax.numpy as jnp
from jax.sharding import NamedSharding, PartitionSpec
from jax._src import ad_checkpoint
from jax._src import ad_util
from jax._src import api_util
from jax._src import callback
from jax._src import config
from jax._src import core

@@ -58,7 +59,6 @@ from jax._src.lib.mlir.dialects import sdy
from jax._src.util import (HashableFunction, HashablePartial, unzip2,
as_hashable_function, memoize, partition_list,
split_list, subs_list2)
from jax.api_util import flatten_fun_nokwargs, argnums_partial
from jax._src.interpreters import batching
from jax._src.interpreters import mlir
from jax._src.interpreters import partial_eval as pe

@@ -160,9 +160,10 @@ def _shard_map(f: Callable, mesh: Mesh | AbstractMesh, in_specs: Specs,
@util.wraps(f)
@traceback_util.api_boundary
def wrapped(*args):
fun = lu.wrap_init(f)
fun = lu.wrap_init(f,
debug_info=api_util.debug_info("shard_map", f, args, {}))
args_flat, in_tree = tree_flatten(args)
fun, out_tree = flatten_fun_nokwargs(fun, in_tree)
fun, out_tree = api_util.flatten_fun_nokwargs(fun, in_tree)
try: in_specs_flat = broadcast_prefix(in_specs, args,
is_leaf=lambda x: x is None)
except ValueError:
@ -170,7 +171,7 @@ def _shard_map(f: Callable, mesh: Mesh | AbstractMesh, in_specs: Specs,
|
||||
raise e('shard_map in_specs') from None
|
||||
dyn_argnums, in_specs_flat = unzip2((i, s) for i, s in enumerate(in_specs_flat)
|
||||
if s is not None)
|
||||
fun, args_flat = argnums_partial(fun, dyn_argnums, args_flat, False)
|
||||
fun, args_flat = api_util.argnums_partial(fun, dyn_argnums, args_flat, False)
|
||||
_check_specs_vs_args(f, mesh, in_tree, in_specs, dyn_argnums, in_specs_flat, args_flat)
|
||||
in_names_flat = tuple(map(_canonicalize_spec, in_specs_flat))
|
||||
|
||||
@ -460,13 +461,16 @@ class ShardMapPrimitive(core.Primitive):

return self._true_bind(*args, **params)

def bind_with_trace(self, trace, fun_and_args, params):
fun: lu.WrappedFun
fun, *args = fun_and_args
return trace.process_shard_map(shard_map_p, fun, args, **params)

def get_bind_params(self, params):
new_params = dict(params)
jaxpr = new_params.pop('jaxpr')
subfun = lu.hashable_partial(lu.wrap_init(core.eval_jaxpr), jaxpr, ())
jaxpr: core.Jaxpr = new_params.pop('jaxpr')
subfun = lu.hashable_partial(lu.wrap_init(core.eval_jaxpr,
debug_info=jaxpr.debug_info),
jaxpr, ())
axes = new_params.pop('out_names')
new_params['out_names_thunk'] = HashableFunction(lambda: axes, closure=axes)
return [subfun], new_params

@ -1511,7 +1515,8 @@ def _shard_map_jvp(trace, shard_map_p, f, tracers, mesh, in_names,

return [ad.JVPTracer(trace, p, t) for p, t in zip(primal_out, tangent_out)]
ad.JVPTrace.process_shard_map = _shard_map_jvp

def _shard_map_partial_eval(trace, shard_map_p, f, tracers, mesh, in_names,
def _shard_map_partial_eval(trace: pe.JaxprTrace, shard_map_p,
f: lu.WrappedFun, tracers, mesh, in_names,
out_names_thunk, check_rep, rewrite, auto):
tracers = map(trace.to_jaxpr_tracer, tracers)
in_pvals = [t.pval for t in tracers]

@ -1519,7 +1524,7 @@ def _shard_map_partial_eval(trace, shard_map_p, f, tracers, mesh, in_names,

unk_in_names, known_in_names = pe.partition_list(in_knowns, in_names)
all_names = _all_mesh_names_except_spmd(mesh, trace)
in_avals_sharded = map(partial(_shard_aval, mesh), unk_in_names, in_avals)
f = pe.trace_to_subjaxpr_nounits_fwd2(f, trace.tag, False)
f = pe.trace_to_subjaxpr_nounits_fwd2(f, trace.tag, f.debug_info, False)
f = _promote_scalar_residuals(f)
f_known, aux = pe.partial_eval_wrapper_nounits(
f, (*in_knowns,), (*in_avals_sharded,))

@ -1545,7 +1550,7 @@ def _shard_map_partial_eval(trace, shard_map_p, f, tracers, mesh, in_names,

res_names = [known_in_names[f1] if f1 is not None else
known_out_names_[f2] if f2 is not None else
{0: all_names} for f1, f2 in zip(in_fwd, out_fwd)]
unk_in_names = (*res_names,) + ({},) * len(env) + (*unk_in_names,)
unk_in_names = (*res_names,) + ({},) * len(env) + (*unk_in_names,) # type: ignore[assignment]
const_tracers = map(trace.new_instantiated_const, res)
env_tracers = map(trace.to_jaxpr_tracer, env)
unk_arg_tracers = [t for t in tracers if not t.is_known()]

@ -1628,7 +1633,7 @@ def _promote_scalar_residuals_lin(f, linearize_outs_thunk, *args, **kwargs):

return residuals + primals

@lu.transformation2
def _promote_scalar_residuals(f, *args, **kwargs):
def _promote_scalar_residuals(f: Callable, *args, **kwargs):
jaxpr, (in_fwds, out_fwds, out_pvals, out_consts, env) = f(*args, **kwargs)
which = [f1 is None and f2 is None and not v.aval.shape
for f1, f2, v in zip(in_fwds, out_fwds, jaxpr.constvars)]

@ -1686,7 +1691,7 @@ def _shard_map_transpose(out_cts, *args, jaxpr, mesh, in_names, out_names,

return out

fun_trans, nz_arg_cts = ad.nonzero_outputs(fun_trans)
fun_trans_flat, out_tree = flatten_fun_nokwargs(fun_trans, in_tree)
fun_trans_flat, out_tree = api_util.flatten_fun_nokwargs(fun_trans, in_tree)

new_in_names = \
[n for n, x in zip(out_names, out_cts) if type(x) is not ad.Zero] + \

@ -9520,8 +9520,11 @@ class CustomVJPTest(jtu.JaxTestCase):

def fwd(x):
return np.array([2.0])*x*x/np.array([1.0]), (x,)

fwd = custom_derivatives.optimize_remat_of_custom_vjp_fwd(fun, fwd)
x = jnp.linspace(0, 5.0, 10)
fwd = custom_derivatives.optimize_remat_of_custom_vjp_fwd(
fun, api_util.debug_info("custom_vjp fun", fun, (x,), {}),
fwd, api_util.debug_info("custom_vjp fwd", fwd, (x,), {}))

self.assertAllClose(jax.jit(fwd)(x)[0], 2*x*x) # Shouldn't hit custom DCE
self.assertAllClose(jax.jit(lambda x: fwd(x)[0])(x), x) # Should be DCEed

@ -9530,8 +9533,10 @@ class CustomVJPTest(jtu.JaxTestCase):

return (np.array([1.0])*x)[0]
def fwd(x):
return (np.array([2.0])*x*x/np.array([1.0]))[0], (x,)
fwd = custom_derivatives.optimize_remat_of_custom_vjp_fwd(fun, fwd)
x = jnp.linspace(0, 5.0, 10)
fwd = custom_derivatives.optimize_remat_of_custom_vjp_fwd(
fun, api_util.debug_info("custom_vjp fun", fun, (x,), {}),
fwd, api_util.debug_info("custom_vjp fwd", fwd, (x,), {}))
self.assertAllClose(jax.jit(jax.vmap(fwd))(x)[0], 2*x*x)
self.assertAllClose(jax.jit(lambda x: jax.vmap(fwd)(x)[0])(x), x)

@ -9540,11 +9545,15 @@ class CustomVJPTest(jtu.JaxTestCase):

return x
def fwd(x):
return x*x, (x,)
fwd = custom_derivatives.optimize_remat_of_custom_vjp_fwd(fun, fwd)

x = jnp.linspace(0, 5.0, 10)
fwd = custom_derivatives.optimize_remat_of_custom_vjp_fwd(
fun, api_util.debug_info("custom_vjp fun", fun, (x,), {}),
fwd, api_util.debug_info("custom_vjp fwd", fwd, (x,), {}))

def g(x):
return jax.lax.cond(True, fwd, lambda x: (2.0 * x, (x,)), x)
x = jnp.linspace(0, 5.0, 10)

self.assertAllClose(jax.jit(g)(x)[0], x*x)
self.assertAllClose(jax.jit(lambda x: g(x)[0])(x), x)

@ -9553,7 +9562,10 @@ class CustomVJPTest(jtu.JaxTestCase):

return x**2
def fwd_(x):
return x*x, (x,)
fwd = custom_derivatives.optimize_remat_of_custom_vjp_fwd(fun, fwd_)

fwd = custom_derivatives.optimize_remat_of_custom_vjp_fwd(
fun, api_util.debug_info("custom_vjp fun", fun, (3.2,), {}),
fwd_, api_util.debug_info("custom_vjp fwd", fwd_, (3.2,), {}))
calc = jax.jvp(fwd, (3.2,), (1.0,))
expected = jax.jvp(fwd_, (3.2,), (1.0,))
self.assertAllClose(calc, expected)

@ -31,9 +31,13 @@ import jax.custom_transpose

from jax.experimental import checkify
import jax.experimental.custom_dce
from jax.experimental import pallas as pl
from jax.experimental.shard_map import shard_map
import jax.numpy as jnp
import jax.scipy as jsp

from jax.sharding import Mesh
from jax.sharding import PartitionSpec as P

from jax._src import api_util
from jax._src.ad_checkpoint import saved_residuals
from jax._src import config

@ -159,14 +163,17 @@ class DebugInfoTest(jtu.JaxTestCase):

t_debug_info = _debug_info_to_string(t._debug_info)
if check_tracer_arg_name:
msg = str(exc)
m = re.match(r".* while tracing the function (.+) for (.+)\. .* depends on the value of the argument ([^\n]+)\.",
m = re.match(r".* while tracing the function (.+) for ([^.]+)\.",
msg,
re.DOTALL)
self.assertIsNotNone(m, msg)
self.assertEqual(t._debug_info.func_src_info, m.group(1))
self.assertEqual(t._debug_info.traced_for, m.group(2))
m = re.match(r".* depends on the value of the argument ([^\n]+)\.",
msg,
re.DOTALL)
found_tracer_debug_infos.append(
f"{t_debug_info}, from {m.group(3)}")
f"{t_debug_info}, from {m.group(1) if m else None}")
else:
found_tracer_debug_infos.append(t_debug_info)
else:

@ -804,7 +811,8 @@ class DebugInfoTest(jtu.JaxTestCase):

jnp.float32(1.), (jnp.float32(2.),), [jnp.float32(3.)],
expected_jaxpr_debug_infos=[
"traced_for=jit, fun=my_f, arg_names=x,y[0], result_paths=",
"None", # TODO(necula): missing debug info
# TODO(necula): arg_names?
"traced_for=jit, fun=my_f, arg_names=None,None,None,None, result_paths=['a'],['b'][0][0]",
],
tracer_spy=tracer_spy,
check_tracer_arg_name=True,

@ -837,8 +845,8 @@ class DebugInfoTest(jtu.JaxTestCase):

"traced_for=jit, fun=<lambda>, arg_names=x,y,res_ct, result_paths=[0],[1]",
# TODO(necula): result_paths
"traced_for=jit, fun=my_g, arg_names=u,v, result_paths=",
# TODO(necula): missing debug info
"None",
# TODO(necula): arg_names
"traced_for=jit, fun=my_g, arg_names=None,None,u,v, result_paths=['c'],['d']",
],
check_tracer_arg_name=True,
expected_tracer_debug_infos=[

@ -854,6 +862,138 @@ class DebugInfoTest(jtu.JaxTestCase):

re.compile(r".*func.func public @main\(.*jax.result_info = \"\[1\]\"}"),
])

def test_custom_jvp(self):
tracer_spy = TracerSpy()
@jax.custom_jvp
def my_fun(x, y, c=1.):
tracer_spy.append(y)
return c * (x + y)
def my_jvp(primals, tangents):
x, y, c = primals
t_x, t_y, t_c = tangents
tracer_spy.append(t_y)
return my_fun(x, y, c), t_c
my_fun.defjvp(my_jvp)

def top_f(x, y):
return jnp.square(my_fun(x, y, c=2.)).sum()

self._check_tracers_and_jaxprs(
jax.jit(lambda a: jax.jvp(top_f, (a, a),
(jnp.ones_like(a), jnp.ones_like(a)))),
42.,
tracer_spy=tracer_spy,
expected_jaxpr_debug_infos=[
"traced_for=jit, fun=<lambda>, arg_names=a, result_paths=[0],[1]",
"traced_for=custom_jvp fun, fun=my_fun, arg_names=x,y,c, result_paths=",
],
check_tracer_arg_name=True,
expected_tracer_debug_infos=[
"traced_for=jit, fun=<lambda>, arg_names=a, from None",
"traced_for=custom_jvp fun, fun=my_fun, arg_names=x,y,c, from y",
])

def test_custom_jvp_nondiff_args(self):
tracer_spy = TracerSpy()
def top_f(xy):
tracer_spy.append(xy[0])
@functools.partial(jax.custom_jvp, nondiff_argnums=(0,))
def my_g(h, xy):
x, y = xy
tracer_spy.append(x)
return h(x)
@my_g.defjvp
def my_g_jvp(h, primals, tangents):
(x, y), = primals
(xt, yt), = tangents
tracer_spy.append(xt)
return my_g(h, (x, y)), 2. * xt
h = lambda y: xy[0] + y # capture x
return my_g(h, xy)

self._check_tracers_and_jaxprs(
jax.jit(lambda a, b: jax.jvp(top_f, ((a, b),),
((jnp.ones_like(a), jnp.ones_like(b)),))),
42., 43.,
tracer_spy=tracer_spy,
expected_jaxpr_debug_infos=[
# TODO(necula): arg_names
"traced_for=jit, fun=<lambda>, arg_names=None,a,b, result_paths=[0],[1]",
"traced_for=custom_jvp fun, fun=my_g, arg_names=None,xy[0],xy[1], result_paths=",
],
check_tracer_arg_name=True,
expected_tracer_debug_infos=[
"traced_for=custom_jvp fun, fun=my_g, arg_names=xy[0],xy[1], from xy[0]",
# TODO(necula): from None
"traced_for=jit, fun=<lambda>, arg_names=a,b, from None",
"None", # TODO(necula): None
])

def test_custom_vjp(self):
tracer_spy = TracerSpy()
@jax.custom_vjp
def my_f(x):
tracer_spy.append(x["a"])
return {"b": jnp.sin(x["a"])}
def my_f_fwd(x):
tracer_spy.append(x["a"])
return my_f(x), {"r": jnp.cos(x["a"])}
def my_f_bwd(res, g):
tracer_spy.append(g["b"])
cos_x = res["r"]
return ({"a": 2 * cos_x * g["b"]},)
my_f.defvjp(my_f_fwd, my_f_bwd)

def to_diff(x):
return my_f(x)["b"]

self._check_tracers_and_jaxprs(
jax.jit(jax.grad(to_diff)),
{"a" : 3.},
tracer_spy=tracer_spy,
expected_jaxpr_debug_infos=[
"traced_for=jit, fun=to_diff, arg_names=x['a'], result_paths=['a']",
"traced_for=custom_vjp fun, fun=my_f, arg_names=x['a'], result_paths=['b']",
],
check_tracer_arg_name=True,
expected_tracer_debug_infos=[
"traced_for=custom_vjp fun, fun=my_f, arg_names=x['a'], from x['a']",
# TODO(necula): from None
"traced_for=jit, fun=to_diff, arg_names=x['a'], from None",
"traced_for=jit, fun=to_diff, arg_names=x['a'], from x['a']",
])

def test_custom_vjp_nondiff_args(self):
tracer_spy = TracerSpy()
@functools.partial(jax.custom_vjp, nondiff_argnums=(0,))
def app(f, xy):
tracer_spy.append(xy[0])
return f(xy)
def app_fwd(f, xy):
tracer_spy.append(xy[0])
return app(f, xy), jnp.cos(xy[0])
def app_rev(f, cos_x0, g):
tracer_spy.append(cos_x0)
tracer_spy.append(g)
return ((cos_x0 * g, cos_x0),)
app.defvjp(app_fwd, app_rev)

self._check_tracers_and_jaxprs(
jax.jit(jax.grad(lambda xy: app(lambda x: 2 * x[0], xy))),
(3., 3.),
tracer_spy=tracer_spy,
expected_jaxpr_debug_infos=[
"traced_for=jit, fun=<lambda>, arg_names=xy[0],xy[1], result_paths=[0],[1]",
"traced_for=custom_vjp fun, fun=app, arg_names=xy[0],xy[1], result_paths=",
],
check_tracer_arg_name=True,
expected_tracer_debug_infos=[
"traced_for=jit, fun=<lambda>, arg_names=xy[0],xy[1], from xy[0]",
"traced_for=custom_vjp fun, fun=app, arg_names=xy[0],xy[1], from xy[0]",
# TODO(necula): from None
"traced_for=jit, fun=<lambda>, arg_names=xy[0],xy[1], from None",
])

def test_vmap_of_nested_jit(self):
tracer_spy = TracerSpy()

@ -995,8 +1135,12 @@ class DebugInfoTest(jtu.JaxTestCase):

tracer_spy=tracer_spy,
expected_jaxpr_debug_infos=[
'traced_for=jit, fun=my_f, arg_names=x,y, result_paths=',
# TODO(necula): some Jaxprs without debug info
'None',
# TODO(necula): arg_names? result_paths?
"traced_for=cond, fun=my_true_branch, arg_names=None, result_paths=,",
"traced_for=cond, fun=my_false_branch, arg_names=None, result_paths=,",
"traced_for=cond, fun=my_true_branch, arg_names=a,b, result_paths=[0],[1]",
"traced_for=cond, fun=my_false_branch, arg_names=c,d, result_paths=[0],[1]",
"traced_for=checkpoint / remat, fun=my_g, arg_names=None,None, result_paths=,",
],
expected_tracer_debug_infos=[
'traced_for=cond, fun=my_true_branch, arg_names=a,b',

@ -1020,7 +1164,9 @@ class DebugInfoTest(jtu.JaxTestCase):

@jax.jit
def my_f(x, as_):
tracer_spy.append(x)
return jax.remat(lambda *args: for_loop.scan(f, *args))(c, as_)
def to_remat(a, b):
return for_loop.scan(f, a, b)
return jax.remat(to_remat)(c, as_)

def the_grad(c, as_):
tracer_spy.append(c)

@ -1035,13 +1181,21 @@ class DebugInfoTest(jtu.JaxTestCase):

"traced_for=jit, fun=the_grad, arg_names=c,as_, result_paths=[0],[1]",
# TODO(necula): bad result paths
"traced_for=jit, fun=my_f, arg_names=x,as_, result_paths=,,",
'None', # TODO(necula): some Jaxprs without debug info
# TODO(necula): arg_names?
"traced_for=for_loop, fun=f, arg_names=i,refs[0],refs[1],refs[2], result_paths=",
"traced_for=for_loop, fun=f, arg_names=None,None,None, result_paths=,",
"traced_for=for_loop, fun=f, arg_names=None,None,None,None,None,None, result_paths=,",
"traced_for=for_loop, fun=f, arg_names=None,None,None,None,None,None,None,None,None,None,None, result_paths=",
"traced_for=for_loop, fun=f, arg_names=None,None,None,None,None,None,None,None,None,None,None,None,None,None,None, result_paths=,",
"traced_for=checkpoint / remat, fun=to_remat, arg_names=None,None,None, result_paths=,",
"traced_for=jit, fun=my_f, arg_names=None,None,x,as_, result_paths=",
],
expected_tracer_debug_infos=[
"traced_for=jit, fun=the_grad, arg_names=c,as_",
"traced_for=scan, fun=f, arg_names=c,a",
"traced_for=jit, fun=my_f, arg_names=x,as_",
'None', # TODO(necula): some missing debug info
# TODO(necula): arg_names
"traced_for=for_loop, fun=f, arg_names=i,refs[0],refs[1],refs[2]",
],
expected_lowering_lines=[
re.compile(r".*func.func public @main\(%arg0: tensor<f..> loc\(\"c\"\)"),

@ -1321,12 +1475,44 @@ class DebugInfoTest(jtu.JaxTestCase):

tracer_spy=tracer_spy,
expected_jaxpr_debug_infos=[
"traced_for=jit, fun=my_f, arg_names=x, result_paths=",
# TODO(necula): some Jaxprs without debug info
"None"],
# TODO(necula): arg_names?
"traced_for=checkpoint / remat, fun=my_g, arg_names=None,None, result_paths=",
],
expected_tracer_debug_infos=[
"traced_for=checkpoint / remat, fun=my_g, arg_names=y"
])

def test_remat_shard_map(self):
tracer_spy = TracerSpy()
if len(jax.devices()) < 2:
self.skipTest("requires at least 2 devices")
# this tests remat-of-shmap
mesh = Mesh(np.array(jax.devices()[:2]), ('x',))

# check param updating is handled
@jax.remat
@functools.partial(shard_map, mesh=mesh, in_specs=P('x'), out_specs=P('x'))
def my_f(x):
tracer_spy.append(x)
return jnp.sin(jnp.sin(x))

self._check_tracers_and_jaxprs(
jax.jit(jax.grad(lambda x: my_f(x).sum())),
jnp.arange(2, dtype=np.float32),
tracer_spy=tracer_spy,
expected_jaxpr_debug_infos=[
# TODO(necula): arg_names
"traced_for=jit, fun=<lambda>, arg_names=x, result_paths=",
"traced_for=checkpoint / remat, fun=my_f, arg_names=None,None, result_paths=",
"traced_for=shard_map, fun=my_f, arg_names=x, result_paths=",
"traced_for=shard_map, fun=my_f, arg_names=None,None, result_paths=",
"None", # TODO(necula): missing
],
check_tracer_arg_name=True,
expected_tracer_debug_infos=[
"None" # TODO(necula): missing
])

def test_remat_saved_residuals(self):
@functools.partial(jax.remat,
static_argnums=(1,),

@ -1359,7 +1545,7 @@ class DebugInfoTest(jtu.JaxTestCase):

# TODO(necula): this should not be pointing into the JAX internals
re.compile(r"traced_for=jit, fun=checked_fun at .*jax/_src/checkify.py:.*, arg_names=args\[0\]"),
re.compile(r"traced_for=jit, fun=argsort at .*numpy/lax_numpy.py:.*, arg_names=a, result_paths="),
"None", # TODO(necula): missing tracer debug info
"traced_for=pmap, fun=my_f, arg_names=my_x, result_paths=[0]",
],
expected_tracer_debug_infos=[
"traced_for=pmap, fun=my_f, arg_names=my_x",

@ -1495,7 +1681,8 @@ class DebugInfoTest(jtu.JaxTestCase):

tracer_spy=tracer_spy,
expected_jaxpr_debug_infos=[
"traced_for=jit, fun=<lambda>, arg_names=x, result_paths=[0],[1]",
"None", # TODO(necula): there are missing Jaxpr debug info
# TODO(necula): internal function?
re.compile(r"traced_for=custom_jvp fun, fun=_custom_root at .*/control_flow/solves.py:.*, arg_names=args\[0\], result_paths=\[0\]"),
],
expected_tracer_debug_infos=[
"traced_for=custom_root, fun=my_f, arg_names=x",