Merge pull request #22816 from dfm:fix-remat-opt

PiperOrigin-RevId: 658378186
This commit is contained in:
jax authors 2024-08-01 06:16:57 -07:00
commit 6bddaad6c7
2 changed files with 26 additions and 0 deletions

View File

@@ -1468,6 +1468,7 @@ def optimize_remat_of_custom_vjp_fwd(
flat_fun, out_type = _flatten_fun_nokwargs(f_, in_tree)
flat_fwd, out_trees = _flatten_fwd(fwd_, False, primal_name, fwd_name,
in_tree, out_type)
flat_fwd = _fix_fwd_args(flat_fwd)
in_avals = [core.raise_to_shaped(core.get_aval(x)) for x in args_flat]
fwd_jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(flat_fwd, in_avals)
@@ -1497,6 +1498,12 @@ def optimize_remat_of_custom_vjp_fwd(
return wrapped_fwd
@lu.transformation
def _fix_fwd_args(*args):
  """Pair every positional argument with a literal ``True``, flattened.

  For inputs ``a, b`` this generator first yields ``([a, True, b, True], {})``
  and then yields back, unchanged, whatever value is sent into it.

  NOTE(review): under the ``lu.transformation`` protocol the first yield is
  presumably the ``(args, kwargs)`` handed to the wrapped function and the
  second its result; the meaning of the ``True`` flag is defined by the
  consumer of these arguments -- confirm against ``_flatten_fwd``.
  """
  flagged = []
  for arg in args:
    # Each argument is immediately followed by its ``True`` marker.
    flagged.extend((arg, True))
  # Hand off the rewritten arguments (with empty kwargs) and wait for a result.
  outs = yield flagged, {}
  # Pass the received result through without modification.
  yield outs
def _remat_opt_impl(
*args,
num_consts: int,

View File

@@ -9730,6 +9730,25 @@ class CustomVJPTest(jtu.JaxTestCase):
v, g = jax.value_and_grad(temp)(3.2)
self.assertAllClose(v, jnp.tan(3.2)**2)
def test_optimize_remat_multiple_args(self):
  """Check ``optimize_remat=True`` on a ``custom_vjp`` with two arguments.

  The gradient of the custom-VJP function ``f`` is compared against the
  gradient of the plain reference ``f_`` with respect to *each* positional
  argument, so that both cotangents returned by ``f_bwd`` are verified.
  """
  def f_(x, y):
    # Reference implementation, differentiated directly by JAX.
    return jnp.sin(x) * y

  @jax.custom_vjp
  def f(x, y):
    return f_(x, y)

  def f_fwd(x, y):
    # Residuals carry everything the backward rule needs.
    return f(x, y), (jnp.cos(x), jnp.sin(x), y)

  def f_bwd(res, g):
    cos_x, sin_x, y = res
    # d/dx[sin(x) * y] = cos(x) * y  and  d/dy[sin(x) * y] = sin(x).
    return (cos_x * g * y, sin_x * g)

  f.defvjp(f_fwd, f_bwd, optimize_remat=True)

  x, y = 3.2, 1.0
  # argnum 0 reproduces the original check; argnum 1 additionally covers the
  # cotangent for the second argument.
  for argnum in (0, 1):
    self.assertAllClose(jax.grad(f, argnums=argnum)(x, y),
                        jax.grad(f_, argnums=argnum)(x, y))
def transpose_unary(f, x_example):
def transposed(y):