mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
parent
efbdaf66bf
commit
ae1d6b875f
@ -2027,7 +2027,8 @@ def checkpoint(fun, concrete=False):
|
||||
def fun_remat(*args, **kwargs):
|
||||
args_flat, in_tree = tree_flatten((args, kwargs))
|
||||
flat_fun, out_tree = flatten_fun(lu.wrap_init(fun), in_tree)
|
||||
out_flat = pe.remat_call(flat_fun, *args_flat, concrete=concrete)
|
||||
out_flat = pe.remat_call(flat_fun, *args_flat, name=flat_fun.__name__,
|
||||
concrete=concrete)
|
||||
return tree_unflatten(out_tree(), out_flat)
|
||||
return fun_remat
|
||||
remat = checkpoint
|
||||
|
@ -526,7 +526,7 @@ def _remat_partial_eval(trace, f, tracers, params):
|
||||
# Since we traced with everything marked as unknown, but we need to know which
|
||||
# outputs are known/unknown, we use partial_eval_jaxpr to get out_unknowns.
|
||||
jaxpr_converted = convert_freevars_jaxpr(jaxpr)
|
||||
in_avals = ([raise_to_shaped(t.pval[0]) for t in env]
|
||||
in_avals = ([raise_to_shaped(partial_val_aval(*t.pval)) for t in env]
|
||||
+ [raise_to_shaped(pv) for pv in in_pvs])
|
||||
out_avals = [raise_to_shaped(pv if pv is not None
|
||||
else abstract_unit if var is unitvar
|
||||
|
@ -1600,6 +1600,46 @@ class APITest(jtu.JaxTestCase):
|
||||
|
||||
api.grad(func)(5.0) # doesn't crash
|
||||
|
||||
def test_remat_jit2(self):
|
||||
@api.jit
|
||||
def f(x):
|
||||
y = 2 * x
|
||||
|
||||
@api.remat
|
||||
def g():
|
||||
return y
|
||||
|
||||
return g()
|
||||
|
||||
self.assertAllClose(f(3), 6, check_dtypes=False)
|
||||
|
||||
def test_remat_nontrivial_env(self):
|
||||
# simplified from https://github.com/google/jax/issues/2030
|
||||
|
||||
@api.remat
|
||||
def foo(state, dt=0.5, c=1):
|
||||
u, u_t = state
|
||||
u_tt = c**2 * u
|
||||
u_t = u_t + u_tt * dt
|
||||
return (u, u_t)
|
||||
|
||||
@partial(api.jit, static_argnums=(1,))
|
||||
def _multi_step(state, count, dt, c):
|
||||
f = lambda s, _: (foo(s, dt, c), _)
|
||||
return lax.scan(f, state, None, count)
|
||||
|
||||
def multi_step(state, count, dt=1/np.sqrt(2), c=1):
|
||||
return _multi_step(state, count, dt, c)
|
||||
|
||||
def loss(u0, target, steps, dt=1/np.sqrt(2), c=1):
|
||||
init = (u0, np.zeros_like(u0))
|
||||
(uf, _), _ = multi_step(init, steps, dt, c)
|
||||
return ((uf - target) ** 2).mean()
|
||||
|
||||
target = np.zeros((128, 128))
|
||||
u0 = np.ones_like(target)
|
||||
loss(u0, target, 10) # doesn't crash
|
||||
|
||||
def test_trivial_computations(self):
|
||||
x = np.array([1, 2, 3])
|
||||
y = api.jit(lambda x: x)(x)
|
||||
|
Loading…
x
Reference in New Issue
Block a user