Reverts 342cb7b99a09180472823a33c7cdad8a8db77875

PiperOrigin-RevId: 733782497
Authored by Dan Foreman-Mackey on 2025-03-05 10:21:52 -08:00, committed by jax authors
parent 4493889cda
commit 4a93c8b30c
5 changed files with 250 additions and 97 deletions
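
Note (not part of the commit): the machinery this revert restores is reached through the public `jax.custom_vjp` API via `defvjp(..., optimize_remat=True)`. A minimal sketch with toy functions of my own, not taken from the diff:

import jax
import jax.numpy as jnp

@jax.custom_vjp
def f(x, y):
  return jnp.sin(x) * y

def f_fwd(x, y):
  # With optimize_remat=True, custom_vjp wraps this forward rule via
  # optimize_remat_of_custom_vjp_fwd, so that if the residuals it saves are
  # later dead-code-eliminated, the cheaper primal `f` is run in its place.
  return jnp.sin(x) * y, (jnp.cos(x), jnp.sin(x), y)

def f_bwd(res, g):
  cos_x, sin_x, y = res
  return (cos_x * g * y, sin_x * g)

f.defvjp(f_fwd, f_bwd, optimize_remat=True)

x, y = 0.5, 0.1
print(jax.jit(jax.value_and_grad(f))(x, y))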


@@ -23,9 +23,6 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG.
     true, matching the current behavior. If set to false, JAX does not need to
     emit code clamping negative indices, which improves code size.
-* Breaking changes
-  * The ``jax.custom_derivatives.remat_opt_p`` helper primitive was removed.
 
 ## jax 0.5.1 (Feb 24, 2025)
 
 * New Features


@@ -16,7 +16,7 @@ from __future__ import annotations
 from collections.abc import Callable, Sequence
 import dataclasses
-from functools import update_wrapper, reduce, partial
+from functools import update_wrapper, reduce, partial, wraps
 from typing import Any, Generic, TypeVar
 
 from jax._src import config
@@ -32,7 +32,6 @@ from jax._src.ad_util import (
 from jax._src.api_util import (
     argnums_partial, flatten_fun_nokwargs, resolve_kwargs,
     prepend_static_args, debug_info)
-from jax._src.custom_dce import custom_dce
 from jax._src.errors import UnexpectedTracerError
 from jax._src.state.types import AbstractRef
 from jax._src.interpreters import ad
@@ -658,12 +657,10 @@ class custom_vjp(Generic[ReturnValue]):
     # TODO(necula): figure out how to construct the debug_bwd args
     debug_bwd = debug_info("custom_vjp bwd", self.bwd, args, {})
     if self.optimize_remat:
-      if self.symbolic_zeros:
-        # TODO(dfm): This probably shouldn't be too hard to support.
-        raise NotImplementedError(
-            "remat optimization for custom_vjp does not support symbolic zeros")
       fwd = optimize_remat_of_custom_vjp_fwd(
-          self.fun, self.fwd, nondiff_argnums=self.nondiff_argnums)
+          self.fun, debug_fun, self.fwd, debug_fwd,
+          nondiff_argnums=self.nondiff_argnums,
+          symbolic_zeros=self.symbolic_zeros)
     else:
       fwd = self.fwd
     if config.enable_custom_vjp_by_custom_transpose.value:
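
Note (not part of the commit): after this change the symbolic-zeros check moves from `custom_vjp.__call__` into `optimize_remat_of_custom_vjp_fwd`, but the combination stays unsupported. A hypothetical example of my own that would still raise at call time:

import jax
import jax.numpy as jnp

@jax.custom_vjp
def g(x):
  return jnp.sin(x)

def g_fwd(x):
  return jnp.sin(x), (jnp.cos(x),)

def g_bwd(res, ct):
  cos_x, = res
  return (cos_x * ct,)

# Requesting both options together still raises NotImplementedError
# ("remat optimization for custom_vjp does not support symbolic zeros"),
# now from inside optimize_remat_of_custom_vjp_fwd. (g_fwd/g_bwd are not
# written for the symbolic-zeros calling convention; the call never gets
# that far because the error is raised first.)
g.defvjp(g_fwd, g_bwd, optimize_remat=True, symbolic_zeros=True)
try:
  g(0.5)
except NotImplementedError as err:
  print(err)
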
@@ -1574,31 +1571,229 @@ custom_jvp_call_jaxpr_p = core.Primitive("custom_jvp_call_jaxpr")
 # simpler, but it would be worth revisiting this.
 def optimize_remat_of_custom_vjp_fwd(
     fun: Callable[..., ReturnValue],
+    debug_fun: core.DebugInfo,
     fwd: Callable[..., tuple[ReturnValue, Any]],
+    debug_fwd: core.DebugInfo,
     nondiff_argnums: Sequence[int] = (),
+    symbolic_zeros: bool = False,
 ) -> Callable[..., tuple[ReturnValue, Any]]:
-  wrapped_fwd = custom_dce(
-      # It might seem like we don't need this lambda, but there are some real
-      # world use cases where the signature of `fwd` is wrong, and we shouldn't
-      # error out when resolving the arguments in those cases. This is fine,
-      # because the arguments have already been resolved in custom_vjp.
-      lambda *args: fwd(*args),  # pylint: disable=unnecessary-lambda
-      static_argnums=nondiff_argnums,
-  )
-
-  @wrapped_fwd.def_dce
-  def _(*args):
-    static_args, used_outs, args = split_list(args, [len(nondiff_argnums), 1])
-    static_args_iter = iter(static_args)
-    args_iter = iter(args)
-    nondiff_argnums_ = set(nondiff_argnums)
-    fun_args = [
-        next(static_args_iter) if i in nondiff_argnums_ else next(args_iter)
-        for i in range(len(static_args) + len(args))]
-    used_outs, = used_outs
-    _, used_res = used_outs
-    if any(tree_leaves(used_res)):
-      return fwd(*fun_args)
-    return fun(*fun_args), None
+  if symbolic_zeros:
+    # TODO(dfm): This probably shouldn't be too hard to support.
+    raise NotImplementedError(
+        "remat optimization for custom_vjp does not support symbolic zeros")
+
+  @wraps(fwd)
+  def wrapped_fwd(*args, **kwargs) -> tuple[ReturnValue, Any]:
+    # TODO(dfm): This initial logic is duplicated from custom_vjp.__call__
+    # above and it would be good to consolidate it.
+    fwd_name = debug_fwd.func_name if debug_fwd else str(fwd)
+    # Note: we use `fun` instead of `fwd` here for consistency with
+    # custom_vjp.__call__ above.
+    args = resolve_kwargs(fun, args, kwargs)
+    if nondiff_argnums:
+      for i in nondiff_argnums: _check_for_tracers(args[i])
+      nondiff_argnums_ = set(nondiff_argnums)
+      dyn_argnums = [i for i in range(len(args)) if i not in nondiff_argnums_]
+      f_, dyn_args = argnums_partial(lu.wrap_init(fun, debug_info=debug_fun),
+                                     dyn_argnums,
+                                     args, require_static_args_hashable=False)
+      fwd_, _ = argnums_partial(lu.wrap_init(fwd, debug_info=debug_fwd),
+                                dyn_argnums, args,
+                                require_static_args_hashable=False)
+    else:
+      f_, dyn_args = lu.wrap_init(fun, debug_info=debug_fun), args
+      fwd_ = lu.wrap_init(fwd, debug_info=debug_fwd)
+    args_flat, in_tree = tree_flatten(dyn_args)
+    flat_fun, out_type = _flatten_fun_nokwargs(f_, in_tree)
+    flat_fwd, out_trees = _flatten_fwd(fwd_, nondiff_argnums, False,
+                                       debug_fun, debug_fwd, in_tree, out_type)
+    flat_fwd = _fix_fwd_args(flat_fwd)
+
+    in_avals = [core.get_aval(x) for x in args_flat]
+    fwd_jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(flat_fwd, in_avals)
+    fwd_jaxpr = pe.close_jaxpr(pe.convert_constvars_jaxpr(fwd_jaxpr))
+    prim_tree, res_tree = out_trees()
+    num_res = res_tree.num_leaves
+
+    if fwd_jaxpr.effects:
+      raise NotImplementedError(
+          "remat optimization for custom_vjp does not support forward "
+          f"functions with side effects, but {fwd_name} has the following "
+          f"effects: {fwd_jaxpr.effects}")
+
+    @pe._memoize
+    def fun_jaxpr_thunk():
+      jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(flat_fun, in_avals)
+      return jaxpr, consts
+
+    out_flat = remat_opt_p.bind(*consts, *args_flat,
+                                num_consts=len(consts),
+                                num_res=num_res,
+                                fwd_jaxpr=fwd_jaxpr,
+                                fun_jaxpr_thunk=fun_jaxpr_thunk)
+    res, out_flat = split_list(out_flat, [num_res])
+    out_tree = treedef_tuple((prim_tree, res_tree))
+    return tree_unflatten(out_tree, (*out_flat, *res))
 
   return wrapped_fwd
+
+
+@lu.transformation2
+def _fix_fwd_args(f, *args):
+  args = [(x, True) for x in args]
+  args = [x for pair in args for x in pair]
+  return f(*args)
+
+
+def _remat_opt_impl(
+    *args,
+    num_consts: int,
+    num_res: int,
+    fwd_jaxpr: core.ClosedJaxpr,
+    fun_jaxpr_thunk: Callable[[], core.ClosedJaxpr],
+):
+  del num_consts, num_res, fun_jaxpr_thunk  # unused
+  return core.jaxpr_as_fun(fwd_jaxpr)(*args)
+
+
+def _remat_opt_abstract_eval(*args, fwd_jaxpr: core.ClosedJaxpr, **_):
+  del args
+  return fwd_jaxpr.out_avals, fwd_jaxpr.effects
+
+
+def _remat_opt_vmap(
+    axis_data, args, in_dims,
+    *,
+    num_consts: int,
+    num_res: int,
+    fwd_jaxpr: core.ClosedJaxpr,
+    fun_jaxpr_thunk: Callable[[], core.ClosedJaxpr],
+):
+  args = [batching.moveaxis(x, d, 0) if d is not not_mapped and d != 0
+          else x for x, d in zip(args, in_dims)]
+  in_batched = [d is not not_mapped for d in in_dims]
+  batched_fwd_jaxpr, out_batched = batching.batch_jaxpr(
+      fwd_jaxpr, axis_data, in_batched, False)
+  extra_consts = batched_fwd_jaxpr.consts
+  batched_fwd_jaxpr = pe.close_jaxpr(
+      pe.convert_constvars_jaxpr(batched_fwd_jaxpr.jaxpr))
+  out_dims = [0 if b else not_mapped for b in out_batched]
+
+  _, prim_batched = split_list(in_batched, [num_consts])
+
+  @pe._memoize
+  def batched_fun_jaxpr_thunk():
+    fun_jaxpr = core.ClosedJaxpr(*fun_jaxpr_thunk())
+    batched_fun_jaxpr, out_batched = batching.batch_jaxpr(
+        fun_jaxpr, axis_data, prim_batched, False)
+    return batched_fun_jaxpr.jaxpr, batched_fun_jaxpr.consts
+
+  batched_outs = remat_opt_p.bind(*extra_consts, *args,
+                                  num_consts=num_consts + len(extra_consts),
+                                  num_res=num_res,
+                                  fwd_jaxpr=batched_fwd_jaxpr,
+                                  fun_jaxpr_thunk=batched_fun_jaxpr_thunk)
+
+  return batched_outs, out_dims
+
+
+def _remat_opt_jvp(
+    primals,
+    tangents,
+    *,
+    num_consts: int,
+    num_res: int,
+    fwd_jaxpr: core.ClosedJaxpr,
+    fun_jaxpr_thunk: Callable[[], core.ClosedJaxpr],
+):
+  consts, primals = split_list(primals, [num_consts])
+  consts_dot, tangents = split_list(tangents, [num_consts])
+  # Tangents must be instantated in case we end up DCEing later.
+  tangents = map(ad.instantiate_zeros, tangents)
+  consts_nz = [not isinstance(t, Zero) for t in consts_dot]
+  consts_dot = [c for nz, c in zip(consts_nz, consts_dot) if nz]
+  in_nz = consts_nz + [True] * len(tangents)
+  fwd_jaxpr_jvp_, out_nz = ad.jvp_jaxpr(fwd_jaxpr, in_nz, True)
+  num_out = len(out_nz) - num_res
+  fwd_jaxpr_jvp_ = ad.rearrange_binders(
+      fwd_jaxpr_jvp_, [num_consts, len(primals)],
+      [len(consts_dot), len(tangents)], [num_res, num_out], [num_res, num_out])
+  fwd_jaxpr_jvp = pe.close_jaxpr(pe.convert_constvars_jaxpr(fwd_jaxpr_jvp_.jaxpr))
+
+  # @pe._memoize
+  def fun_jvp_jaxpr_thunk():
+    fun_jaxpr = core.ClosedJaxpr(*fun_jaxpr_thunk())
+    in_nz = [True] * len(primals)
+    fun_jvp_jaxpr, _ = ad.jvp_jaxpr(fun_jaxpr, in_nz, True)
+    return fun_jvp_jaxpr.jaxpr, fun_jvp_jaxpr.consts
+
+  new_num_consts = len(fwd_jaxpr_jvp_.consts) + num_consts + len(consts_dot)
+  outs = remat_opt_p.bind(*fwd_jaxpr_jvp_.consts, *consts, *consts_dot,
+                          *primals, *tangents, num_consts=new_num_consts,
+                          num_res=2 * num_res, fwd_jaxpr=fwd_jaxpr_jvp,
+                          fun_jaxpr_thunk=fun_jvp_jaxpr_thunk)
+  res, res_dot, outs, outs_dot = split_list(outs, [num_res, num_res, num_out])
+  return (*res, *outs), (*res_dot, *outs_dot)
+
+
+def _remat_opt_transpose(
+    cts, *args,
+    num_consts: int,
+    num_res: int,
+    fwd_jaxpr: core.ClosedJaxpr,
+    fun_jaxpr_thunk: Callable[[], core.ClosedJaxpr],
+):
+  # TODO(dfm): It shouldn't be too hard to implement this as needed in the
+  # future.
+  raise NotImplementedError(
+      "remat optimization for custom_vjp does not support higher-order AD")
+
+
+def _remat_opt_dce(used_outs: list[bool], eqn: core.JaxprEqn):
+  if not any(used_outs) and not pe.has_effects(eqn):
+    return [False] * len(eqn.invars), None
+  used_res, used_prims = split_list(used_outs, [eqn.params["num_res"]])
+  outvars = [v for used, v in zip(used_outs, eqn.outvars) if used]
+  if any(used_res):
+    # If any of the residuals are used, we still need to run fwd at this point,
+    # but we may end up DCEing again in the future, so we must instantiate all
+    # the input primals.
+    instantiate = [False] * eqn.params["num_consts"]
+    instantiate += [True] * (len(eqn.invars) - eqn.params["num_consts"])
+    new_jaxpr, used_ins = pe.dce_jaxpr(eqn.params["fwd_jaxpr"].jaxpr, used_outs,
+                                       instantiate=instantiate)
+    assert not new_jaxpr.constvars
+    closed_jaxpr = pe.close_jaxpr(new_jaxpr)
+    invars = [v for used, v in zip(used_ins, eqn.invars) if used]
+    new_params = dict(eqn.params)
+    new_num_consts = sum(split_list(used_ins, [eqn.params["num_consts"]])[0])
+    new_params["num_consts"] = new_num_consts
+    new_params["fwd_jaxpr"] = closed_jaxpr
+    new_params["num_res"] = sum(used_res)
+    new_eqn = pe.new_jaxpr_eqn(
+        invars, outvars, remat_opt_p, new_params, closed_jaxpr.effects,
+        eqn.source_info, eqn.ctx)
+    return used_ins, new_eqn
+  else:
+    # If none of the residuals are used, we run the primal computation instead.
+    # At this point we drop this custom DCE behavior, but since the primal might
+    # have different consts than fwd, we build a new JaxprEqn with a closed_call
+    # primitive.
+    fun_jaxpr, consts = eqn.params["fun_jaxpr_thunk"]()
+    new_jaxpr, used_consts, used_ins = pe.dce_jaxpr_consts(fun_jaxpr, used_prims)
+    consts = [c for used, c in zip(used_consts, consts) if used]
+    closed_jaxpr = core.ClosedJaxpr(new_jaxpr, consts)
+    _, invars = split_list(eqn.invars, [eqn.params["num_consts"]])
+    invars = [v for used, v in zip(used_ins, invars) if used]
+    new_eqn = pe.new_jaxpr_eqn(
+        invars, outvars, core.closed_call_p, dict(call_jaxpr=closed_jaxpr),
+        closed_jaxpr.effects, eqn.source_info, eqn.ctx)
+    used_ins = [False] * eqn.params["num_consts"] + used_ins
+    return used_ins, new_eqn
+
+
+remat_opt_p = core.Primitive("remat_opt")
+remat_opt_p.multiple_results = True
+remat_opt_p.def_impl(_remat_opt_impl)
+remat_opt_p.def_effectful_abstract_eval(_remat_opt_abstract_eval)
+xla.register_initial_style_primitive(remat_opt_p)
+mlir.register_lowering(remat_opt_p, mlir.lower_fun(
+    _remat_opt_impl, multiple_results=True))
+batching.fancy_primitive_batchers[remat_opt_p] = _remat_opt_vmap
+ad.primitive_jvps[remat_opt_p] = _remat_opt_jvp
+ad.primitive_transposes[remat_opt_p] = _remat_opt_transpose
+pe.dce_rules[remat_opt_p] = _remat_opt_dce
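
Note (not part of the commit): the restored signature threads `core.DebugInfo` for both `fun` and `fwd`, which is why the updated tests below build `api_util.debug_info(...)` values explicitly. A rough sketch of a direct call, using stand-in functions of my own (these are `jax._src` internals, not public API, and may change):

import jax
import jax.numpy as jnp
from jax._src import api_util
from jax._src import custom_derivatives

def fun(x):
  return x                 # primal; intentionally differs from fwd's output

def fwd(x):
  return 2 * x * x, (x,)   # (primal according to fwd, residuals)

x = jnp.linspace(0, 5.0, 10)
opt_fwd = custom_derivatives.optimize_remat_of_custom_vjp_fwd(
    fun, api_util.debug_info("custom_vjp fun", fun, (x,), {}),
    fwd, api_util.debug_info("custom_vjp fwd", fwd, (x,), {}))

print(jax.jit(opt_fwd)(x)[0])               # residuals kept, fwd runs (2*x*x)
print(jax.jit(lambda x: opt_fwd(x)[0])(x))  # residuals DCEed, fun runs (x)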


@@ -30,6 +30,7 @@ from jax._src.custom_derivatives import (
     custom_vjp_primal_tree_values as custom_vjp_primal_tree_values,
     CustomVJPPrimal as CustomVJPPrimal,
     linear_call as linear_call,
+    remat_opt_p as remat_opt_p,
 )
 from jax._src.ad_util import (


@@ -45,7 +45,6 @@ from jax._src import api
 from jax._src import api_util
 from jax._src import config
 from jax._src import core
-from jax._src import custom_dce
 from jax._src import dispatch
 from jax._src import dtypes
 from jax._src import linear_util as lu
@@ -3474,14 +3473,15 @@ def _custom_lin(*args: TfVal, **_) -> Sequence[TfVal]:
 tf_impl[ad.custom_lin_p] = _custom_lin
 
 
-def _custom_dce(*args: TfVal, num_consts: int, fun_jaxpr: core.ClosedJaxpr,
-                dce_jaxpr_thunk: Callable) -> Sequence[TfVal]:
-  del num_consts, dce_jaxpr_thunk
-  return _interpret_jaxpr(fun_jaxpr, *args, extra_name_stack="custom_dce_call",
+def _remat_opt(*args: TfVal, num_consts: int, num_res: int,
+               fwd_jaxpr: core.ClosedJaxpr,
+               fun_jaxpr_thunk: Callable) -> Sequence[TfVal]:
+  del num_consts, num_res, fun_jaxpr_thunk
+  return _interpret_jaxpr(fwd_jaxpr, *args, extra_name_stack="remat_opt",
                           fresh_constant_cache=False)
 
-tf_impl[custom_dce.custom_dce_p] = _custom_dce
+tf_impl[custom_derivatives.remat_opt_p] = _remat_opt
 
 
 PartitionsOrReplicated = Union[tuple[int, ...], None]


@@ -9599,7 +9599,10 @@ class CustomVJPTest(jtu.JaxTestCase):
       return np.array([2.0])*x*x/np.array([1.0]), (x,)
 
     x = jnp.linspace(0, 5.0, 10)
-    fwd = custom_derivatives.optimize_remat_of_custom_vjp_fwd(fun, fwd)
+    fwd = custom_derivatives.optimize_remat_of_custom_vjp_fwd(
+        fun, api_util.debug_info("custom_vjp fun", fun, (x,), {}),
+        fwd, api_util.debug_info("custom_vjp fwd", fwd, (x,), {}))
     self.assertAllClose(jax.jit(fwd)(x)[0], 2*x*x)  # Shouldn't hit custom DCE
     self.assertAllClose(jax.jit(lambda x: fwd(x)[0])(x), x)  # Should be DCEed
@@ -9609,7 +9612,9 @@ class CustomVJPTest(jtu.JaxTestCase):
     def fwd(x):
      return (np.array([2.0])*x*x/np.array([1.0]))[0], (x,)
 
    x = jnp.linspace(0, 5.0, 10)
-    fwd = custom_derivatives.optimize_remat_of_custom_vjp_fwd(fun, fwd)
+    fwd = custom_derivatives.optimize_remat_of_custom_vjp_fwd(
+        fun, api_util.debug_info("custom_vjp fun", fun, (x,), {}),
+        fwd, api_util.debug_info("custom_vjp fwd", fwd, (x,), {}))
     self.assertAllClose(jax.jit(jax.vmap(fwd))(x)[0], 2*x*x)
     self.assertAllClose(jax.jit(lambda x: jax.vmap(fwd)(x)[0])(x), x)
@@ -9620,7 +9625,9 @@ class CustomVJPTest(jtu.JaxTestCase):
       return x*x, (x,)
 
     x = jnp.linspace(0, 5.0, 10)
-    fwd = custom_derivatives.optimize_remat_of_custom_vjp_fwd(fun, fwd)
+    fwd = custom_derivatives.optimize_remat_of_custom_vjp_fwd(
+        fun, api_util.debug_info("custom_vjp fun", fun, (x,), {}),
+        fwd, api_util.debug_info("custom_vjp fwd", fwd, (x,), {}))
 
     def g(x):
       return jax.lax.cond(True, fwd, lambda x: (2.0 * x, (x,)), x)
@@ -9634,7 +9641,9 @@ class CustomVJPTest(jtu.JaxTestCase):
     def fwd_(x):
       return x*x, (x,)
 
-    fwd = custom_derivatives.optimize_remat_of_custom_vjp_fwd(fun, fwd_)
+    fwd = custom_derivatives.optimize_remat_of_custom_vjp_fwd(
+        fun, api_util.debug_info("custom_vjp fun", fun, (3.2,), {}),
+        fwd_, api_util.debug_info("custom_vjp fwd", fwd_, (3.2,), {}))
     calc = jax.jvp(fwd, (3.2,), (1.0,))
     expected = jax.jvp(fwd_, (3.2,), (1.0,))
     self.assertAllClose(calc, expected)
@@ -9731,55 +9740,6 @@ class CustomVJPTest(jtu.JaxTestCase):
     x, y = jnp.linspace(0.0, 1.0, 5), jnp.linspace(2.0, 5.0, 5)
     jax.jit(jax.vmap(jax.grad(f)))(x, y)  # Doesn't error
 
-  def test_optimize_remat_nondiff_argnums(self):
-    @partial(jax.custom_vjp, nondiff_argnums=(2,))
-    def f(x, y, fun):
-      return fun(x, y)
-
-    def f_fwd(x, y, fun):
-      del fun
-      return jnp.cos(x) * y, (jnp.cos(x), jnp.sin(x), y)
-
-    def f_bwd(fun, res, g):
-      del fun
-      cos_x, sin_x, y = res
-      return (cos_x * g * y, sin_x * g)
-
-    def fun(x, y):
-      return jnp.sin(x) * y
-
-    f.defvjp(f_fwd, f_bwd, optimize_remat=True)
-
-    x, y = 0.5, 0.1
-    res = jax.value_and_grad(lambda *args: f(*args, fun))(x, y)[0]
-    self.assertAllClose(res, f_fwd(x, y, fun)[0])
-    res = jax.jit(lambda *args: jax.value_and_grad(
-        lambda *args: f(*args, fun))(*args)[0])(x, y)
-    self.assertAllClose(res, fun(x, y))
-
-  def test_optimize_remat_incorrect_signature(self):
-    def f_(x, y):
-      return jnp.sin(x) * y
-
-    @jax.custom_vjp
-    def f(x, y):
-      return f_(x, y)
-
-    def wrong_signature(x, y, z):
-      self.fail("wrong_signature should not be called")
-
-    @functools.wraps(wrong_signature)
-    def f_fwd(x, y):
-      return f(x, y), (jnp.cos(x), jnp.sin(x), y)
-
-    def f_bwd(res, g):
-      cos_x, sin_x, y = res
-      return (cos_x * g * y, sin_x * g)
-
-    f.defvjp(f_fwd, f_bwd, optimize_remat=True)
-
-    x, y = 3.2, 1.0
-    self.assertAllClose(jax.grad(f)(x, y), jax.grad(f_)(x, y))
-
   def test_dce(self):
     @jax.custom_vjp
     def f(x, y):
@@ -10508,20 +10468,20 @@ class CustomDceTest(jtu.JaxTestCase):
     self.assertAllClose(v, jnp.tan(3.2)**2)
 
   def test_static_argnums(self):
-    @partial(jax.experimental.custom_dce.custom_dce, static_argnums=(1,))
-    def g(x, f):
+    @partial(jax.experimental.custom_dce.custom_dce, static_argnums=(0,))
+    def g(f, x):
       return f(x), 10 * f(x)
 
    @g.def_dce
     def g_dce(f, used_outs, x):  # note: static_argnums are always passes first
       self.assertTrue(callable(f))
-      return [2 * v if used else None for used, v in zip(used_outs, g(x, f))]
+      return [2 * v if used else None for used, v in zip(used_outs, g(f, x))]
 
     x = 1.1234
     f = lambda x: jnp.exp(x)
-    expected = g(x, f)
-    self.assertAllClose(jax.jit(lambda x: g(x, f)[0])(x), 2 * expected[0])
-    self.assertAllClose(jax.jit(lambda x: g(x, f)[1])(x), 2 * expected[1])
+    expected = g(f, x)
+    self.assertAllClose(jax.jit(lambda x: g(f, x)[0])(x), 2 * expected[0])
+    self.assertAllClose(jax.jit(lambda x: g(f, x)[1])(x), 2 * expected[1])
 
   def test_shape_mismatch_error(self):
     @jax.experimental.custom_dce.custom_dce