mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Merge pull request #10251 from mattjj:debug-nans-error
PiperOrigin-RevId: 441530769
This commit is contained in:
commit
0b898ea627
@ -149,24 +149,45 @@ def _xla_call_impl(fun: lu.WrappedFun, *args, device, backend, name,
|
||||
compiled_fun = _xla_callable(fun, device, backend, name, donated_invars,
|
||||
*arg_specs)
|
||||
try:
|
||||
out = compiled_fun(*args)
|
||||
return compiled_fun(*args)
|
||||
except FloatingPointError:
|
||||
assert config.jax_debug_nans or config.jax_debug_infs # compiled_fun can only raise in this case
|
||||
print("Invalid value encountered in the output of a jit/pmap-ed function. "
|
||||
print("Invalid value encountered in the output of a jit-decorated function. "
|
||||
"Calling the de-optimized version.")
|
||||
# We want to run the wrapped function again (after _xla_callable already ran
|
||||
# it), but linear_util.WrappedFun instances are meant to be run only once.
|
||||
# In addition to re-executing the Python code, which is usually undesirable
|
||||
# but which config.jax_debug_nans is meant to opt into, we'll be re-executing
|
||||
# any linear_util.py-style side effects, i.e. re-populating Stores created
|
||||
# by any transformation_with_aux's applied to fun. Since this is
|
||||
# intentional here, to avoid "Store occupied" errors we clone the WrappedFun
|
||||
# with empty stores.
|
||||
# but which config.jax_debug_nans is meant to opt into, we'll be
|
||||
# re-executing any linear_util.py-style side effects, i.e. re-populating
|
||||
# Stores created by any transformation_with_aux's applied to fun. Since this
|
||||
# is intentional here, to avoid "Store occupied" errors we clone the
|
||||
# WrappedFun with empty stores.
|
||||
stores = [lu.Store() for _ in fun.stores]
|
||||
clone = lu.WrappedFun(fun.f, fun.transforms, stores, fun.params, fun.in_type)
|
||||
|
||||
with core.new_sublevel():
|
||||
_ = clone.call_wrapped(*args) # probably won't return
|
||||
return out
|
||||
_ = clone.call_wrapped(*args) # may raise, not return
|
||||
|
||||
# If control reaches this line, we got a NaN on the output of `compiled_fun`
|
||||
# but not `clone.call_wrapped` on the same arguments. Let's tell the user.
|
||||
fun_info = pe.fun_sourceinfo(fun.f)
|
||||
msg = ("An invalid value was encountered in the output of the "
|
||||
f"`jit`-decorated function {fun_info}. Because "
|
||||
"config.jax_debug_nans and/or config.jax_debug_infs is set, the "
|
||||
"de-optimized function (i.e., the function as if the `jit` "
|
||||
"decorator were removed) was called in an attempt to get a more "
|
||||
"precise error message. However, the de-optimized function did not "
|
||||
"produce invalid values during its execution. This behavior can "
|
||||
"result from `jit` optimizations causing the invalid value to be "
|
||||
"produced. It may also arise from having nan/inf constants as "
|
||||
"outputs, like `jax.jit(lambda ...: jax.numpy.nan)(...)`. "
|
||||
"\n\n"
|
||||
"It may be possible to avoid the invalid value by removing the "
|
||||
"`jit` decorator, at the cost of losing optimizations. "
|
||||
"\n\n"
|
||||
"If you see this error, consider opening a bug report at "
|
||||
"https://github.com/google/jax.")
|
||||
raise FloatingPointError(msg)
|
||||
|
||||
xla.xla_call_p.def_impl(_xla_call_impl)
|
||||
|
||||
|
@ -220,5 +220,16 @@ class DebugInfsTest(jtu.JaxTestCase):
|
||||
except FloatingPointError:
|
||||
pass
|
||||
|
||||
def testDebugNansDoesntReturnDeoptimizedResult(self):
  """Check that a jitted function returning NaN raises under debug_nans.

  The call is expected to raise a FloatingPointError whose message matches
  "de-optimized".
  """
  def f(x):
    x + 2  # avoid trivial dispatch path by adding some eqn
    return jnp.nan

  jitted = jax.jit(f)

  # Same nesting order as separate `with` blocks: the assertion context is
  # entered first, then NaN debugging is switched on inside it.
  with self.assertRaisesRegex(FloatingPointError, "de-optimized"), \
       jax.debug_nans(True):
    jitted(3)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
absltest.main(testLoader=jtu.JaxTestLoader())
|
||||
|
Loading…
x
Reference in New Issue
Block a user