Merge pull request #10251 from mattjj:debug-nans-error

PiperOrigin-RevId: 441530769
This commit is contained in:
jax authors 2022-04-13 11:17:23 -07:00
commit 0b898ea627
2 changed files with 41 additions and 9 deletions

View File

@ -149,24 +149,45 @@ def _xla_call_impl(fun: lu.WrappedFun, *args, device, backend, name,
compiled_fun = _xla_callable(fun, device, backend, name, donated_invars,
*arg_specs)
try:
out = compiled_fun(*args)
return compiled_fun(*args)
except FloatingPointError:
assert config.jax_debug_nans or config.jax_debug_infs # compiled_fun can only raise in this case
print("Invalid value encountered in the output of a jit/pmap-ed function. "
print("Invalid value encountered in the output of a jit-decorated function. "
"Calling the de-optimized version.")
# We want to run the wrapped function again (after _xla_callable already ran
# it), but linear_util.WrappedFun instances are meant to be run only once.
# In addition to re-executing the Python code, which is usually undesirable
# but which config.jax_debug_nans is meant to opt into, we'll be re-executing
# any linear_util.py-style side effects, i.e. re-populating Stores created
# by any transformation_with_aux's applied to fun. Since this is
# intentional here, to avoid "Store occupied" errors we clone the WrappedFun
# with empty stores.
# but which config.jax_debug_nans is meant to opt into, we'll be
# re-executing any linear_util.py-style side effects, i.e. re-populating
# Stores created by any transformation_with_aux's applied to fun. Since this
# is intentional here, to avoid "Store occupied" errors we clone the
# WrappedFun with empty stores.
stores = [lu.Store() for _ in fun.stores]
clone = lu.WrappedFun(fun.f, fun.transforms, stores, fun.params, fun.in_type)
with core.new_sublevel():
_ = clone.call_wrapped(*args) # probably won't return
return out
_ = clone.call_wrapped(*args) # may raise, not return
# If control reaches this line, we got a NaN on the output of `compiled_fun`
# but not `clone.call_wrapped` on the same arguments. Let's tell the user.
fun_info = pe.fun_sourceinfo(fun.f)
msg = ("An invalid value was encountered in the output of the "
f"`jit`-decorated function {fun_info}. Because "
"config.jax_debug_nans and/or config.jax_debug_infs is set, the "
"de-optimized function (i.e., the function as if the `jit` "
"decorator were removed) was called in an attempt to get a more "
"precise error message. However, the de-optimized function did not "
"produce invalid values during its execution. This behavior can "
"result from `jit` optimizations causing the invalud value to be "
"produced. It may also arise from having nan/inf constants as "
"outputs, like `jax.jit(lambda ...: jax.numpy.nan)(...)`. "
"\n\n"
"It may be possible to avoid the invalid value by removing the "
"`jit` decorator, at the cost of losing optimizations. "
"\n\n"
"If you see this error, consider opening a bug report at "
"https://github.com/google/jax.")
raise FloatingPointError(msg)
xla.xla_call_p.def_impl(_xla_call_impl)

View File

@ -220,5 +220,16 @@ class DebugInfsTest(jtu.JaxTestCase):
except FloatingPointError:
pass
def testDebugNansDoesntReturnDeoptimizedResult(self):
@jax.jit
def f(x):
x + 2 # avoid trivial dispatch path by adding some eqn
return jnp.nan
with self.assertRaisesRegex(FloatingPointError, "de-optimized"):
with jax.debug_nans(True):
f(3)
if __name__ == '__main__':
absltest.main(testLoader=jtu.JaxTestLoader())