Reverts 342cb7b99a09180472823a33c7cdad8a8db77875

PiperOrigin-RevId: 733782497

commit 4a93c8b30c (parent 4493889cda)
@@ -23,9 +23,6 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG.md
     true, matching the current behavior. If set to false, JAX does not need to
     emit code clamping negative indices, which improves code size.
 
-* Breaking changes
-  * The ``jax.custom_derivatives.remat_opt_p`` helper primitive was removed.
-
 ## jax 0.5.1 (Feb 24, 2025)
 
 * New Features
@@ -16,7 +16,7 @@ from __future__ import annotations
 
 from collections.abc import Callable, Sequence
 import dataclasses
-from functools import update_wrapper, reduce, partial
+from functools import update_wrapper, reduce, partial, wraps
 from typing import Any, Generic, TypeVar
 
 from jax._src import config
@@ -32,7 +32,6 @@ from jax._src.ad_util import (
 from jax._src.api_util import (
     argnums_partial, flatten_fun_nokwargs, resolve_kwargs,
     prepend_static_args, debug_info)
-from jax._src.custom_dce import custom_dce
 from jax._src.errors import UnexpectedTracerError
 from jax._src.state.types import AbstractRef
 from jax._src.interpreters import ad
@@ -658,12 +657,10 @@ class custom_vjp(Generic[ReturnValue]):
     # TODO(necula): figure out how to construct the debug_bwd args
     debug_bwd = debug_info("custom_vjp bwd", self.bwd, args, {})
     if self.optimize_remat:
-      if self.symbolic_zeros:
-        # TODO(dfm): This probably shouldn't be too hard to support.
-        raise NotImplementedError(
-            "remat optimization for custom_vjp does not support symbolic zeros")
       fwd = optimize_remat_of_custom_vjp_fwd(
-          self.fun, self.fwd, nondiff_argnums=self.nondiff_argnums)
+          self.fun, debug_fun, self.fwd, debug_fwd,
+          nondiff_argnums=self.nondiff_argnums,
+          symbolic_zeros=self.symbolic_zeros)
     else:
       fwd = self.fwd
     if config.enable_custom_vjp_by_custom_transpose.value:
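For orientation, here is a minimal sketch of the user-facing path that reaches the branch above: passing `optimize_remat=True` to `defvjp` routes the forward rule through `optimize_remat_of_custom_vjp_fwd` (restored below). The function names are illustrative, and the pattern mirrors the tests touched later in this diff.

```python
import jax
import jax.numpy as jnp

@jax.custom_vjp
def f(x, y):
  return jnp.sin(x) * y

def f_fwd(x, y):
  # Residuals saved for the backward pass.
  return f(x, y), (jnp.cos(x), jnp.sin(x), y)

def f_bwd(res, g):
  cos_x, sin_x, y = res
  return (cos_x * g * y, sin_x * g)

# optimize_remat=True is the flag that reaches the
# optimize_remat_of_custom_vjp_fwd call shown in the hunk above.
f.defvjp(f_fwd, f_bwd, optimize_remat=True)
print(jax.grad(f)(0.5, 0.1))  # d/dx [sin(x) * y] = cos(x) * y
```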
@@ -1574,31 +1571,229 @@ custom_jvp_call_jaxpr_p = core.Primitive("custom_jvp_call_jaxpr")
 # simpler, but it would be worth revisiting this.
 def optimize_remat_of_custom_vjp_fwd(
     fun: Callable[..., ReturnValue],
+    debug_fun: core.DebugInfo,
     fwd: Callable[..., tuple[ReturnValue, Any]],
+    debug_fwd: core.DebugInfo,
     nondiff_argnums: Sequence[int] = (),
+    symbolic_zeros: bool = False,
 ) -> Callable[..., tuple[ReturnValue, Any]]:
-  wrapped_fwd = custom_dce(
-      # It might seem like we don't need this lambda, but there are some real
-      # world use cases where the signature of `fwd` is wrong, and we shouldn't
-      # error out when resolving the arguments in those cases. This is fine,
-      # because the arguments have already been resolved in custom_vjp.
-      lambda *args: fwd(*args),  # pylint: disable=unnecessary-lambda
-      static_argnums=nondiff_argnums,
-  )
+  if symbolic_zeros:
+    # TODO(dfm): This probably shouldn't be too hard to support.
+    raise NotImplementedError(
+        "remat optimization for custom_vjp does not support symbolic zeros")
 
-  @wrapped_fwd.def_dce
-  def _(*args):
-    static_args, used_outs, args = split_list(args, [len(nondiff_argnums), 1])
-    static_args_iter = iter(static_args)
-    args_iter = iter(args)
-    nondiff_argnums_ = set(nondiff_argnums)
-    fun_args = [
-        next(static_args_iter) if i in nondiff_argnums_ else next(args_iter)
-        for i in range(len(static_args) + len(args))]
-    used_outs, = used_outs
-    _, used_res = used_outs
-    if any(tree_leaves(used_res)):
-      return fwd(*fun_args)
-    return fun(*fun_args), None
+  @wraps(fwd)
+  def wrapped_fwd(*args, **kwargs) -> tuple[ReturnValue, Any]:
+    # TODO(dfm): This initial logic is duplicated from custom_vjp.__call__
+    # above and it would be good to consolidate it.
+    fwd_name = debug_fwd.func_name if debug_fwd else str(fwd)
+    # Note: we use `fun` instead of `fwd` here for consistency with
+    # custom_vjp.__call__ above.
+    args = resolve_kwargs(fun, args, kwargs)
+    if nondiff_argnums:
+      for i in nondiff_argnums: _check_for_tracers(args[i])
+      nondiff_argnums_ = set(nondiff_argnums)
+      dyn_argnums = [i for i in range(len(args)) if i not in nondiff_argnums_]
+      f_, dyn_args = argnums_partial(lu.wrap_init(fun, debug_info=debug_fun),
+                                     dyn_argnums,
+                                     args, require_static_args_hashable=False)
+      fwd_, _ = argnums_partial(lu.wrap_init(fwd, debug_info=debug_fwd),
+                                dyn_argnums, args,
+                                require_static_args_hashable=False)
+    else:
+      f_, dyn_args = lu.wrap_init(fun, debug_info=debug_fun), args
+      fwd_ = lu.wrap_init(fwd, debug_info=debug_fwd)
+    args_flat, in_tree = tree_flatten(dyn_args)
+    flat_fun, out_type = _flatten_fun_nokwargs(f_, in_tree)
+    flat_fwd, out_trees = _flatten_fwd(fwd_, nondiff_argnums, False,
+                                       debug_fun, debug_fwd, in_tree, out_type)
+    flat_fwd = _fix_fwd_args(flat_fwd)
+
+    in_avals = [core.get_aval(x) for x in args_flat]
+    fwd_jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(flat_fwd, in_avals)
+    fwd_jaxpr = pe.close_jaxpr(pe.convert_constvars_jaxpr(fwd_jaxpr))
+    prim_tree, res_tree = out_trees()
+    num_res = res_tree.num_leaves
+
+    if fwd_jaxpr.effects:
+      raise NotImplementedError(
+          "remat optimization for custom_vjp does not support forward "
+          f"functions with side effects, but {fwd_name} has the following "
+          f"effects: {fwd_jaxpr.effects}")
+
+    @pe._memoize
+    def fun_jaxpr_thunk():
+      jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(flat_fun, in_avals)
+      return jaxpr, consts
+
+    out_flat = remat_opt_p.bind(*consts, *args_flat,
+                                num_consts=len(consts),
+                                num_res=num_res,
+                                fwd_jaxpr=fwd_jaxpr,
+                                fun_jaxpr_thunk=fun_jaxpr_thunk)
+    res, out_flat = split_list(out_flat, [num_res])
+    out_tree = treedef_tuple((prim_tree, res_tree))
+    return tree_unflatten(out_tree, (*out_flat, *res))
 
   return wrapped_fwd
+
+
+@lu.transformation2
+def _fix_fwd_args(f, *args):
+  args = [(x, True) for x in args]
+  args = [x for pair in args for x in pair]
+  return f(*args)
+
+
+def _remat_opt_impl(
+    *args,
+    num_consts: int,
+    num_res: int,
+    fwd_jaxpr: core.ClosedJaxpr,
+    fun_jaxpr_thunk: Callable[[], core.ClosedJaxpr],
+):
+  del num_consts, num_res, fun_jaxpr_thunk  # unused
+  return core.jaxpr_as_fun(fwd_jaxpr)(*args)
+
+
+def _remat_opt_abstract_eval(*args, fwd_jaxpr: core.ClosedJaxpr, **_):
+  del args
+  return fwd_jaxpr.out_avals, fwd_jaxpr.effects
+
+
+def _remat_opt_vmap(
+    axis_data, args, in_dims,
+    *,
+    num_consts: int,
+    num_res: int,
+    fwd_jaxpr: core.ClosedJaxpr,
+    fun_jaxpr_thunk: Callable[[], core.ClosedJaxpr],
+):
+  args = [batching.moveaxis(x, d, 0) if d is not not_mapped and d != 0
+          else x for x, d in zip(args, in_dims)]
+  in_batched = [d is not not_mapped for d in in_dims]
+  batched_fwd_jaxpr, out_batched = batching.batch_jaxpr(
+      fwd_jaxpr, axis_data, in_batched, False)
+  extra_consts = batched_fwd_jaxpr.consts
+  batched_fwd_jaxpr = pe.close_jaxpr(
+      pe.convert_constvars_jaxpr(batched_fwd_jaxpr.jaxpr))
+  out_dims = [0 if b else not_mapped for b in out_batched]
+
+  _, prim_batched = split_list(in_batched, [num_consts])
+
+  @pe._memoize
+  def batched_fun_jaxpr_thunk():
+    fun_jaxpr = core.ClosedJaxpr(*fun_jaxpr_thunk())
+    batched_fun_jaxpr, out_batched = batching.batch_jaxpr(
+        fun_jaxpr, axis_data, prim_batched, False)
+    return batched_fun_jaxpr.jaxpr, batched_fun_jaxpr.consts
+
+  batched_outs = remat_opt_p.bind(*extra_consts, *args,
+                                  num_consts=num_consts + len(extra_consts),
+                                  num_res=num_res,
+                                  fwd_jaxpr=batched_fwd_jaxpr,
+                                  fun_jaxpr_thunk=batched_fun_jaxpr_thunk)
+
+  return batched_outs, out_dims
+
+
+def _remat_opt_jvp(
+    primals,
+    tangents,
+    *,
+    num_consts: int,
+    num_res: int,
+    fwd_jaxpr: core.ClosedJaxpr,
+    fun_jaxpr_thunk: Callable[[], core.ClosedJaxpr],
+):
+  consts, primals = split_list(primals, [num_consts])
+  consts_dot, tangents = split_list(tangents, [num_consts])
+  # Tangents must be instantated in case we end up DCEing later.
+  tangents = map(ad.instantiate_zeros, tangents)
+  consts_nz = [not isinstance(t, Zero) for t in consts_dot]
+  consts_dot = [c for nz, c in zip(consts_nz, consts_dot) if nz]
+  in_nz = consts_nz + [True] * len(tangents)
+  fwd_jaxpr_jvp_, out_nz = ad.jvp_jaxpr(fwd_jaxpr, in_nz, True)
+  num_out = len(out_nz) - num_res
+  fwd_jaxpr_jvp_ = ad.rearrange_binders(
+      fwd_jaxpr_jvp_, [num_consts, len(primals)],
+      [len(consts_dot), len(tangents)], [num_res, num_out], [num_res, num_out])
+  fwd_jaxpr_jvp = pe.close_jaxpr(pe.convert_constvars_jaxpr(fwd_jaxpr_jvp_.jaxpr))
+
+  # @pe._memoize
+  def fun_jvp_jaxpr_thunk():
+    fun_jaxpr = core.ClosedJaxpr(*fun_jaxpr_thunk())
+    in_nz = [True] * len(primals)
+    fun_jvp_jaxpr, _ = ad.jvp_jaxpr(fun_jaxpr, in_nz, True)
+    return fun_jvp_jaxpr.jaxpr, fun_jvp_jaxpr.consts
+
+  new_num_consts = len(fwd_jaxpr_jvp_.consts) + num_consts + len(consts_dot)
+  outs = remat_opt_p.bind(*fwd_jaxpr_jvp_.consts, *consts, *consts_dot,
+                          *primals, *tangents, num_consts=new_num_consts,
+                          num_res=2 * num_res, fwd_jaxpr=fwd_jaxpr_jvp,
+                          fun_jaxpr_thunk=fun_jvp_jaxpr_thunk)
+  res, res_dot, outs, outs_dot = split_list(outs, [num_res, num_res, num_out])
+  return (*res, *outs), (*res_dot, *outs_dot)
+
+
+def _remat_opt_transpose(
+    cts, *args,
+    num_consts: int,
+    num_res: int,
+    fwd_jaxpr: core.ClosedJaxpr,
+    fun_jaxpr_thunk: Callable[[], core.ClosedJaxpr],
+):
+  # TODO(dfm): It shouldn't be too hard to implement this as needed in the
+  # future.
+  raise NotImplementedError(
+      "remat optimization for custom_vjp does not support higher-order AD")
+
+
+def _remat_opt_dce(used_outs: list[bool], eqn: core.JaxprEqn):
+  if not any(used_outs) and not pe.has_effects(eqn):
+    return [False] * len(eqn.invars), None
+  used_res, used_prims = split_list(used_outs, [eqn.params["num_res"]])
+  outvars = [v for used, v in zip(used_outs, eqn.outvars) if used]
+  if any(used_res):
+    # If any of the residuals are used, we still need to run fwd at this point,
+    # but we may end up DCEing again in the future, so we must instantiate all
+    # the input primals.
+    instantiate = [False] * eqn.params["num_consts"]
+    instantiate += [True] * (len(eqn.invars) - eqn.params["num_consts"])
+    new_jaxpr, used_ins = pe.dce_jaxpr(eqn.params["fwd_jaxpr"].jaxpr, used_outs,
+                                       instantiate=instantiate)
+    assert not new_jaxpr.constvars
+    closed_jaxpr = pe.close_jaxpr(new_jaxpr)
+    invars = [v for used, v in zip(used_ins, eqn.invars) if used]
+    new_params = dict(eqn.params)
+    new_num_consts = sum(split_list(used_ins, [eqn.params["num_consts"]])[0])
+    new_params["num_consts"] = new_num_consts
+    new_params["fwd_jaxpr"] = closed_jaxpr
+    new_params["num_res"] = sum(used_res)
+    new_eqn = pe.new_jaxpr_eqn(
+        invars, outvars, remat_opt_p, new_params, closed_jaxpr.effects,
+        eqn.source_info, eqn.ctx)
+    return used_ins, new_eqn
+  else:
+    # If none of the residuals are used, we run the primal computation instead.
+    # At this point we drop this custom DCE behavior, but since the primal might
+    # have different consts than fwd, we build a new JaxprEqn with a closed_call
+    # primitive.
+    fun_jaxpr, consts = eqn.params["fun_jaxpr_thunk"]()
+    new_jaxpr, used_consts, used_ins = pe.dce_jaxpr_consts(fun_jaxpr, used_prims)
+    consts = [c for used, c in zip(used_consts, consts) if used]
+    closed_jaxpr = core.ClosedJaxpr(new_jaxpr, consts)
+    _, invars = split_list(eqn.invars, [eqn.params["num_consts"]])
+    invars = [v for used, v in zip(used_ins, invars) if used]
+    new_eqn = pe.new_jaxpr_eqn(
+        invars, outvars, core.closed_call_p, dict(call_jaxpr=closed_jaxpr),
+        closed_jaxpr.effects, eqn.source_info, eqn.ctx)
+    used_ins = [False] * eqn.params["num_consts"] + used_ins
+    return used_ins, new_eqn
+
+
+remat_opt_p = core.Primitive("remat_opt")
+remat_opt_p.multiple_results = True
+remat_opt_p.def_impl(_remat_opt_impl)
+remat_opt_p.def_effectful_abstract_eval(_remat_opt_abstract_eval)
+xla.register_initial_style_primitive(remat_opt_p)
+mlir.register_lowering(remat_opt_p, mlir.lower_fun(
+    _remat_opt_impl, multiple_results=True))
+
+batching.fancy_primitive_batchers[remat_opt_p] = _remat_opt_vmap
+ad.primitive_jvps[remat_opt_p] = _remat_opt_jvp
+ad.primitive_transposes[remat_opt_p] = _remat_opt_transpose
+pe.dce_rules[remat_opt_p] = _remat_opt_dce
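As a cross-check, a hedged usage sketch of the restored helper, following the call pattern used by the updated tests further down in this diff; `fun` and `fwd` are illustrative, and `api_util.debug_info` / `custom_derivatives` are the internal `jax._src` modules those tests import.

```python
import jax
import jax.numpy as jnp
from jax._src import api_util, custom_derivatives

def fun(x):
  return jnp.sin(x)

def fwd(x):
  # Primal output plus residuals for the backward pass.
  return jnp.sin(x), (jnp.cos(x),)

x = jnp.linspace(0.0, 5.0, 10)
opt_fwd = custom_derivatives.optimize_remat_of_custom_vjp_fwd(
    fun, api_util.debug_info("custom_vjp fun", fun, (x,), {}),
    fwd, api_util.debug_info("custom_vjp fwd", fwd, (x,), {}))

# Both outputs used: behaves like fwd. Residuals dropped: the remat_opt_p
# DCE rule lets the cheaper primal fun run instead.
full = jax.jit(opt_fwd)(x)
primal_only = jax.jit(lambda x: opt_fwd(x)[0])(x)
```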
@@ -30,6 +30,7 @@ from jax._src.custom_derivatives import (
   custom_vjp_primal_tree_values as custom_vjp_primal_tree_values,
   CustomVJPPrimal as CustomVJPPrimal,
   linear_call as linear_call,
+  remat_opt_p as remat_opt_p,
 )
 
 from jax._src.ad_util import (
@@ -45,7 +45,6 @@ from jax._src import api
 from jax._src import api_util
 from jax._src import config
 from jax._src import core
-from jax._src import custom_dce
 from jax._src import dispatch
 from jax._src import dtypes
 from jax._src import linear_util as lu
@@ -3474,14 +3473,15 @@ def _custom_lin(*args: TfVal, **_) -> Sequence[TfVal]:
 tf_impl[ad.custom_lin_p] = _custom_lin
 
 
-def _custom_dce(*args: TfVal, num_consts: int, fun_jaxpr: core.ClosedJaxpr,
-                dce_jaxpr_thunk: Callable) -> Sequence[TfVal]:
-  del num_consts, dce_jaxpr_thunk
-  return _interpret_jaxpr(fun_jaxpr, *args, extra_name_stack="custom_dce_call",
+def _remat_opt(*args: TfVal, num_consts: int, num_res: int,
+               fwd_jaxpr: core.ClosedJaxpr,
+               fun_jaxpr_thunk: Callable) -> Sequence[TfVal]:
+  del num_consts, num_res, fun_jaxpr_thunk
+  return _interpret_jaxpr(fwd_jaxpr, *args, extra_name_stack="remat_opt",
                           fresh_constant_cache=False)
 
 
-tf_impl[custom_dce.custom_dce_p] = _custom_dce
+tf_impl[custom_derivatives.remat_opt_p] = _remat_opt
 
 
 PartitionsOrReplicated = Union[tuple[int, ...], None]
@@ -9599,7 +9599,10 @@ class CustomVJPTest(jtu.JaxTestCase):
       return np.array([2.0])*x*x/np.array([1.0]), (x,)
 
     x = jnp.linspace(0, 5.0, 10)
-    fwd = custom_derivatives.optimize_remat_of_custom_vjp_fwd(fun, fwd)
+    fwd = custom_derivatives.optimize_remat_of_custom_vjp_fwd(
+        fun, api_util.debug_info("custom_vjp fun", fun, (x,), {}),
+        fwd, api_util.debug_info("custom_vjp fwd", fwd, (x,), {}))
+
     self.assertAllClose(jax.jit(fwd)(x)[0], 2*x*x)  # Shouldn't hit custom DCE
     self.assertAllClose(jax.jit(lambda x: fwd(x)[0])(x), x)  # Should be DCEed
 
@@ -9609,7 +9612,9 @@ class CustomVJPTest(jtu.JaxTestCase):
     def fwd(x):
      return (np.array([2.0])*x*x/np.array([1.0]))[0], (x,)
     x = jnp.linspace(0, 5.0, 10)
-    fwd = custom_derivatives.optimize_remat_of_custom_vjp_fwd(fun, fwd)
+    fwd = custom_derivatives.optimize_remat_of_custom_vjp_fwd(
+        fun, api_util.debug_info("custom_vjp fun", fun, (x,), {}),
+        fwd, api_util.debug_info("custom_vjp fwd", fwd, (x,), {}))
     self.assertAllClose(jax.jit(jax.vmap(fwd))(x)[0], 2*x*x)
     self.assertAllClose(jax.jit(lambda x: jax.vmap(fwd)(x)[0])(x), x)
 
@@ -9620,7 +9625,9 @@ class CustomVJPTest(jtu.JaxTestCase):
       return x*x, (x,)
 
     x = jnp.linspace(0, 5.0, 10)
-    fwd = custom_derivatives.optimize_remat_of_custom_vjp_fwd(fun, fwd)
+    fwd = custom_derivatives.optimize_remat_of_custom_vjp_fwd(
+        fun, api_util.debug_info("custom_vjp fun", fun, (x,), {}),
+        fwd, api_util.debug_info("custom_vjp fwd", fwd, (x,), {}))
 
     def g(x):
       return jax.lax.cond(True, fwd, lambda x: (2.0 * x, (x,)), x)
@@ -9634,7 +9641,9 @@ class CustomVJPTest(jtu.JaxTestCase):
     def fwd_(x):
       return x*x, (x,)
 
-    fwd = custom_derivatives.optimize_remat_of_custom_vjp_fwd(fun, fwd_)
+    fwd = custom_derivatives.optimize_remat_of_custom_vjp_fwd(
+        fun, api_util.debug_info("custom_vjp fun", fun, (3.2,), {}),
+        fwd_, api_util.debug_info("custom_vjp fwd", fwd_, (3.2,), {}))
     calc = jax.jvp(fwd, (3.2,), (1.0,))
     expected = jax.jvp(fwd_, (3.2,), (1.0,))
     self.assertAllClose(calc, expected)
@@ -9731,55 +9740,6 @@ class CustomVJPTest(jtu.JaxTestCase):
     x, y = jnp.linspace(0.0, 1.0, 5), jnp.linspace(2.0, 5.0, 5)
     jax.jit(jax.vmap(jax.grad(f)))(x, y)  # Doesn't error
 
-  def test_optimize_remat_nondiff_argnums(self):
-    @partial(jax.custom_vjp, nondiff_argnums=(2,))
-    def f(x, y, fun):
-      return fun(x, y)
-
-    def f_fwd(x, y, fun):
-      del fun
-      return jnp.cos(x) * y, (jnp.cos(x), jnp.sin(x), y)
-
-    def f_bwd(fun, res, g):
-      del fun
-      cos_x, sin_x, y = res
-      return (cos_x * g * y, sin_x * g)
-
-    def fun(x, y):
-      return jnp.sin(x) * y
-
-    f.defvjp(f_fwd, f_bwd, optimize_remat=True)
-    x, y = 0.5, 0.1
-    res = jax.value_and_grad(lambda *args: f(*args, fun))(x, y)[0]
-    self.assertAllClose(res, f_fwd(x, y, fun)[0])
-    res = jax.jit(lambda *args: jax.value_and_grad(
-        lambda *args: f(*args, fun))(*args)[0])(x, y)
-    self.assertAllClose(res, fun(x, y))
-
-  def test_optimize_remat_incorrect_signature(self):
-    def f_(x, y):
-      return jnp.sin(x) * y
-
-    @jax.custom_vjp
-    def f(x, y):
-      return f_(x, y)
-
-    def wrong_signature(x, y, z):
-      self.fail("wrong_signature should not be called")
-
-    @functools.wraps(wrong_signature)
-    def f_fwd(x, y):
-      return f(x, y), (jnp.cos(x), jnp.sin(x), y)
-
-    def f_bwd(res, g):
-      cos_x, sin_x, y = res
-      return (cos_x * g * y, sin_x * g)
-
-    f.defvjp(f_fwd, f_bwd, optimize_remat=True)
-    x, y = 3.2, 1.0
-    self.assertAllClose(jax.grad(f)(x, y), jax.grad(f_)(x, y))
-
-
   def test_dce(self):
     @jax.custom_vjp
     def f(x, y):
@@ -10508,20 +10468,20 @@ class CustomDceTest(jtu.JaxTestCase):
     self.assertAllClose(v, jnp.tan(3.2)**2)
 
   def test_static_argnums(self):
-    @partial(jax.experimental.custom_dce.custom_dce, static_argnums=(1,))
-    def g(x, f):
+    @partial(jax.experimental.custom_dce.custom_dce, static_argnums=(0,))
+    def g(f, x):
       return f(x), 10 * f(x)
 
     @g.def_dce
     def g_dce(f, used_outs, x):  # note: static_argnums are always passes first
       self.assertTrue(callable(f))
-      return [2 * v if used else None for used, v in zip(used_outs, g(x, f))]
+      return [2 * v if used else None for used, v in zip(used_outs, g(f, x))]
 
     x = 1.1234
     f = lambda x: jnp.exp(x)
-    expected = g(x, f)
-    self.assertAllClose(jax.jit(lambda x: g(x, f)[0])(x), 2 * expected[0])
-    self.assertAllClose(jax.jit(lambda x: g(x, f)[1])(x), 2 * expected[1])
+    expected = g(f, x)
+    self.assertAllClose(jax.jit(lambda x: g(f, x)[0])(x), 2 * expected[0])
+    self.assertAllClose(jax.jit(lambda x: g(f, x)[1])(x), 2 * expected[1])
 
   def test_shape_mismatch_error(self):
     @jax.experimental.custom_dce.custom_dce