mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Merge pull request #22816 from dfm:fix-remat-opt
PiperOrigin-RevId: 658378186
This commit is contained in:
commit
6bddaad6c7
@ -1468,6 +1468,7 @@ def optimize_remat_of_custom_vjp_fwd(
|
||||
flat_fun, out_type = _flatten_fun_nokwargs(f_, in_tree)
|
||||
flat_fwd, out_trees = _flatten_fwd(fwd_, False, primal_name, fwd_name,
|
||||
in_tree, out_type)
|
||||
flat_fwd = _fix_fwd_args(flat_fwd)
|
||||
|
||||
in_avals = [core.raise_to_shaped(core.get_aval(x)) for x in args_flat]
|
||||
fwd_jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(flat_fwd, in_avals)
|
||||
@ -1497,6 +1498,12 @@ def optimize_remat_of_custom_vjp_fwd(
|
||||
|
||||
return wrapped_fwd
|
||||
|
||||
@lu.transformation
|
||||
def _fix_fwd_args(*args):
|
||||
args = [(x, True) for x in args]
|
||||
args = [x for pair in args for x in pair]
|
||||
yield (yield args, {})
|
||||
|
||||
def _remat_opt_impl(
|
||||
*args,
|
||||
num_consts: int,
|
||||
|
@ -9730,6 +9730,25 @@ class CustomVJPTest(jtu.JaxTestCase):
|
||||
v, g = jax.value_and_grad(temp)(3.2)
|
||||
self.assertAllClose(v, jnp.tan(3.2)**2)
|
||||
|
||||
def test_optimize_remat_multiple_args(self):
|
||||
def f_(x, y):
|
||||
return jnp.sin(x) * y
|
||||
|
||||
@jax.custom_vjp
|
||||
def f(x, y):
|
||||
return f_(x, y)
|
||||
|
||||
def f_fwd(x, y):
|
||||
return f(x, y), (jnp.cos(x), jnp.sin(x), y)
|
||||
|
||||
def f_bwd(res, g):
|
||||
cos_x, sin_x, y = res
|
||||
return (cos_x * g * y, sin_x * g)
|
||||
|
||||
f.defvjp(f_fwd, f_bwd, optimize_remat=True)
|
||||
x, y = 3.2, 1.0
|
||||
self.assertAllClose(jax.grad(f)(x, y), jax.grad(f_)(x, y))
|
||||
|
||||
|
||||
def transpose_unary(f, x_example):
|
||||
def transposed(y):
|
||||
|
Loading…
x
Reference in New Issue
Block a user