make remat reduce precision of saved values to avoid xla excess precision

problem: f(x) != value_and_grad(f)(x)[0] ??

Co-authored-by: Peter Hawkins <phawkins@google.com>
This commit is contained in:
Matthew Johnson 2024-07-02 23:13:09 +00:00
parent e4b606e38a
commit f498d7a3ba
2 changed files with 77 additions and 1 deletions

View File

@ -28,6 +28,7 @@ from jax._src import api
from jax._src import config
from jax._src import core
from jax._src import dispatch
from jax._src import dtypes
from jax._src import linear_util as lu
from jax._src import effects
from jax._src import source_info_util
@ -544,10 +545,15 @@ def remat_partial_eval(trace, *tracers, jaxpr, **params):
jaxpr_known, in_used_known = pe.dce_jaxpr(jaxpr_known, out_used_known)
num_res = sum(used_res)
# To avoid precision mismatches in fwd and bwd passes due to XLA excess
# precision, insert explicit x = reduce_precision(x, **finfo(x.dtype)) calls
# on producers of any residuals. See https://github.com/google/jax/pull/22244.
jaxpr_known_ = _insert_reduce_precision(jaxpr_known, num_res)
# compute known outputs and residuals (hoisted out of remat primitive)
_, in_consts_ = unzip2(t.pval for t in tracers if t.pval.is_known())
_, in_consts = partition_list(in_used_known, in_consts_)
out_consts = core.eval_jaxpr(jaxpr_known, (), *in_consts)
out_consts = core.eval_jaxpr(jaxpr_known_, (), *in_consts)
out_knowns, residuals = split_list(out_consts, [len(out_consts)-num_res])
# set up unknown outputs with a recipe to call remat
@ -587,6 +593,43 @@ def remat_partial_eval(trace, *tracers, jaxpr, **params):
return merge_lists(out_unknowns, out_knowns, out_jaxpr_tracers)
pe.custom_partial_eval_rules[remat_p] = remat_partial_eval
@weakref_lru_cache
def _insert_reduce_precision(jaxpr: core.Jaxpr, num_res: int) -> core.Jaxpr:
  """Return a copy of `jaxpr` whose residual outputs are explicitly rounded.

  The last `num_res` entries of `jaxpr.outvars` are treated as residuals. For
  each residual that has an inexact-dtype array aval and is also consumed by
  at least one equation of `jaxpr`, a `reduce_precision` equation is added
  whose exponent/mantissa widths come from `finfo` of the residual's own
  dtype. The residual keeps its original `Var`; whatever used to define it
  (an equation output, an invar or a constvar) is rebound to a fresh `Var`
  that feeds the new equation.

  Args:
    jaxpr: jaxpr whose trailing outputs are the residuals.
    num_res: number of trailing outvars that are residuals.

  Returns:
    A new jaxpr built with `jaxpr.replace`; the variable and equation lists
    of the input are copied before being modified.
  """
  first_res = len(jaxpr.outvars) - num_res
  # Every Var that some equation reads as an input.
  consumed = set()
  for e in jaxpr.eqns:
    consumed.update(x for x in e.invars if isinstance(x, core.Var))

  new_invars = jaxpr.invars[:]
  new_constvars = jaxpr.constvars[:]
  new_eqns = jaxpr.eqns[:]

  for res in jaxpr.outvars[first_res:]:
    aval = res.aval
    is_inexact_array = (isinstance(aval, core.UnshapedArray) and
                        dtypes.issubdtype(aval.dtype, np.inexact))
    # Only inexact arrays that are also read inside this jaxpr are rewritten.
    if not is_inexact_array or res not in consumed:
      continue
    assert isinstance(res, core.Var)

    unrounded = core.Var(res.suffix, aval)
    info = dtypes.finfo(aval.dtype)
    rp_params = dict(exponent_bits=info.nexp, mantissa_bits=info.nmant)

    is_constvar = res in new_constvars
    if is_constvar or res in new_invars:
      # The residual is a jaxpr input: rename the input and make the rounding
      # the very first equation so that `res` is defined before any use.
      binders = new_constvars if is_constvar else new_invars
      rounding_eqn = core.new_jaxpr_eqn(
          [unrounded], [res], lax_internal.reduce_precision_p, rp_params,
          set())
      binders[binders.index(res)] = unrounded
      new_eqns.insert(0, rounding_eqn)
      continue

    # Otherwise exactly one equation must produce the residual; the
    # single-element unpacking fails if there are zero or several.
    (pos, producer), = ((i, e) for i, e in enumerate(new_eqns)
                        if res in e.outvars)
    if (producer.primitive == lax_internal.reduce_precision_p and
        producer.params == rp_params):
      # Already rounded to exactly this format; nothing to add.
      continue
    renamed_producer = producer.replace(
        outvars=[o if o != res else unrounded for o in producer.outvars])
    rounding_eqn = core.new_jaxpr_eqn(
        [unrounded], [res], lax_internal.reduce_precision_p, rp_params, set(),
        producer.source_info, producer.ctx)
    new_eqns[pos] = renamed_producer
    new_eqns.insert(pos + 1, rounding_eqn)

  rewritten = jaxpr.replace(
      invars=new_invars, constvars=new_constvars, eqns=new_eqns)
  if config.enable_checks.value:
    core.check_jaxpr(rewritten)
  return rewritten
def remat_partial_eval_custom_params_updater(*args):
  """Pass the known params through and mark the staged params differentiated.

  Only the last two positional arguments are used: `params_known` and
  `params_staged`. Any arguments before them are ignored.

  Returns:
    A pair `(params_known, staged)` where `params_known` is the same object
    that was passed in and `staged` is a new dict copied from `params_staged`
    with `differentiated` set to True.
  """
  params_known, params_staged = args[-2:]
  staged = dict(params_staged)
  staged['differentiated'] = True
  return params_known, staged

View File

@ -6410,6 +6410,39 @@ class RematTest(jtu.JaxTestCase):
self.assertFalse(any(' sin ' in line for line in l.output))
self.assertTrue(any(' cos ' in line for line in l.output))
def test_excess_precision_hell(self):
  # The forward output and the value rebuilt from the saved residual in the
  # backward pass must agree bit-for-bit (atol=0, rtol=0 below).
  bf16_eps = jnp.finfo('bfloat16').eps

  # x @ x with an identity-like VJP: the cotangent is passed straight through.
  @jax.custom_vjp
  def square(a):
    return jnp.dot(a, a)

  def square_fwd(a):
    return square(a), None

  def square_bwd(_, ct):
    return (ct,)

  square.defvjp(square_fwd, square_bwd)

  # Upcast to float32. The backward rule ignores the cotangent and instead
  # repeats the forward expression on the saved input, so the "gradient"
  # equals the forward output only if the residual holds the same value the
  # forward pass consumed.
  @jax.custom_vjp
  def upcast(v):
    return jnp.float32(1.) * v.astype('float32')

  def upcast_fwd(v):
    return upcast(v), v

  def upcast_bwd(saved, _):
    return (jnp.float32(1.) * saved.astype('float32'),)

  upcast.defvjp(upcast_fwd, upcast_bwd)

  # Checkpoint policy that answers True for every query.
  save_everything = lambda *_, **__: True

  @jax.jit
  @partial(jax.remat, policy=save_everything)
  def f(a):
    return upcast(square(a))

  # Diagonal entries are 1 + eps; their square, 1 + 2*eps + eps**2, has more
  # significant bits than bfloat16 can hold.
  primal_in = (jnp.bfloat16(1) + bf16_eps) * jnp.eye(2, dtype='bfloat16')
  primal_out, pullback = jax.vjp(f, primal_in)
  (rebuilt,) = pullback(jnp.ones_like(primal_out))
  self.assertAllClose(primal_out, rebuilt, atol=0, rtol=0)
@jtu.with_config(jax_pprint_use_color=False)
class JaxprTest(jtu.JaxTestCase):