mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
make remat reduce precision of saved values to avoid xla excess precision
problem: f(x) != value_and_grad(f)(x)[0] ?? Co-authored-by: Peter Hawkins <phawkins@google.com>
This commit is contained in:
parent
e4b606e38a
commit
f498d7a3ba
@ -28,6 +28,7 @@ from jax._src import api
|
||||
from jax._src import config
|
||||
from jax._src import core
|
||||
from jax._src import dispatch
|
||||
from jax._src import dtypes
|
||||
from jax._src import linear_util as lu
|
||||
from jax._src import effects
|
||||
from jax._src import source_info_util
|
||||
@ -544,10 +545,15 @@ def remat_partial_eval(trace, *tracers, jaxpr, **params):
|
||||
jaxpr_known, in_used_known = pe.dce_jaxpr(jaxpr_known, out_used_known)
|
||||
num_res = sum(used_res)
|
||||
|
||||
# To avoid precision mismatches in fwd and bwd passes due to XLA excess
|
||||
# precision, insert explicit x = reduce_precision(x, **finfo(x.dtype)) calls
|
||||
# on producers of any residuals. See https://github.com/google/jax/pull/22244.
|
||||
jaxpr_known_ = _insert_reduce_precision(jaxpr_known, num_res)
|
||||
|
||||
# compute known outputs and residuals (hoisted out of remat primitive)
|
||||
_, in_consts_ = unzip2(t.pval for t in tracers if t.pval.is_known())
|
||||
_, in_consts = partition_list(in_used_known, in_consts_)
|
||||
out_consts = core.eval_jaxpr(jaxpr_known, (), *in_consts)
|
||||
out_consts = core.eval_jaxpr(jaxpr_known_, (), *in_consts)
|
||||
out_knowns, residuals = split_list(out_consts, [len(out_consts)-num_res])
|
||||
|
||||
# set up unknown outputs with a recipe to call remat
|
||||
@ -587,6 +593,43 @@ def remat_partial_eval(trace, *tracers, jaxpr, **params):
|
||||
return merge_lists(out_unknowns, out_knowns, out_jaxpr_tracers)
|
||||
pe.custom_partial_eval_rules[remat_p] = remat_partial_eval
|
||||
|
||||
@weakref_lru_cache
|
||||
def _insert_reduce_precision(jaxpr: core.Jaxpr, num_res: int) -> core.Jaxpr:
|
||||
res_vars = jaxpr.outvars[len(jaxpr.outvars) - num_res:]
|
||||
used_vars = {x for e in jaxpr.eqns for x in e.invars if isinstance(x, core.Var)}
|
||||
invars, constvars, eqns = jaxpr.invars[:], jaxpr.constvars[:], jaxpr.eqns[:]
|
||||
for v in res_vars:
|
||||
if (not isinstance(v.aval, core.UnshapedArray) or
|
||||
not dtypes.issubdtype(v.aval.dtype, np.inexact)):
|
||||
continue
|
||||
if v not in used_vars:
|
||||
continue
|
||||
assert isinstance(v, core.Var)
|
||||
newvar = core.Var(v.suffix, v.aval)
|
||||
finfo = dtypes.finfo(v.aval.dtype)
|
||||
params = dict(exponent_bits=finfo.nexp, mantissa_bits=finfo.nmant)
|
||||
if v in constvars or v in invars:
|
||||
lst = constvars if v in constvars else invars
|
||||
new_eqn = core.new_jaxpr_eqn(
|
||||
[newvar], [v], lax_internal.reduce_precision_p, params, set())
|
||||
lst[lst.index(v)] = newvar
|
||||
eqns.insert(0, new_eqn)
|
||||
else:
|
||||
(eqn_idx, eqn), = ((i, e) for i, e in enumerate(eqns) if v in e.outvars)
|
||||
if (eqn.primitive == lax_internal.reduce_precision_p and
|
||||
eqn.params == params):
|
||||
continue
|
||||
replace_eqn = eqn.replace(outvars=[v_ if v_ != v else newvar
|
||||
for v_ in eqn.outvars])
|
||||
new_eqn = core.new_jaxpr_eqn(
|
||||
[newvar], [v], lax_internal.reduce_precision_p, params, set(),
|
||||
eqn.source_info, eqn.ctx)
|
||||
eqns[eqn_idx] = replace_eqn
|
||||
eqns.insert(eqn_idx+1, new_eqn)
|
||||
new_jaxpr = jaxpr.replace(invars=invars, constvars=constvars, eqns=eqns)
|
||||
config.enable_checks.value and core.check_jaxpr(new_jaxpr)
|
||||
return new_jaxpr
|
||||
|
||||
def remat_partial_eval_custom_params_updater(*args):
|
||||
*_, params_known, params_staged = args
|
||||
return params_known, dict(params_staged, differentiated=True)
|
||||
|
@ -6410,6 +6410,39 @@ class RematTest(jtu.JaxTestCase):
|
||||
self.assertFalse(any(' sin ' in line for line in l.output))
|
||||
self.assertTrue(any(' cos ' in line for line in l.output))
|
||||
|
||||
def test_excess_precision_hell(self):
|
||||
finfo = jnp.finfo('bfloat16')
|
||||
eps = finfo.eps
|
||||
|
||||
@jax.custom_vjp
|
||||
def dot(x):
|
||||
return jnp.dot(x, x)
|
||||
def dot_fwd(x):
|
||||
return dot(x), None
|
||||
def dot_bwd(_, g):
|
||||
return g,
|
||||
dot.defvjp(dot_fwd, dot_bwd)
|
||||
|
||||
@jax.custom_vjp
|
||||
def foo(x):
|
||||
return jnp.float32(1.) * x.astype('float32')
|
||||
def foo_fwd(x):
|
||||
return foo(x), x
|
||||
def foo_bwd(x, _):
|
||||
return jnp.float32(1.) * x.astype('float32'),
|
||||
foo.defvjp(foo_fwd, foo_bwd)
|
||||
|
||||
@jax.jit
|
||||
@partial(jax.remat, policy=lambda *_, **__: True)
|
||||
def f(x):
|
||||
x = dot(x)
|
||||
return foo(x)
|
||||
|
||||
x = (jnp.bfloat16(1) + eps) * jnp.eye(2, dtype='bfloat16')
|
||||
y, vjp = jax.vjp(f, x)
|
||||
y_, = vjp(jnp.ones_like(y))
|
||||
self.assertAllClose(y, y_, atol=0, rtol=0)
|
||||
|
||||
|
||||
@jtu.with_config(jax_pprint_use_color=False)
|
||||
class JaxprTest(jtu.JaxTestCase):
|
||||
|
Loading…
x
Reference in New Issue
Block a user