Fix debug nans test after merging jit and pjit codepaths

PiperOrigin-RevId: 501122848
This commit is contained in:
Yash Katariya 2023-01-10 16:26:18 -08:00 committed by jax authors
parent 3f712480c6
commit e02c1da4c7
3 changed files with 27 additions and 2 deletions

View File

@ -124,7 +124,10 @@ def _nan_check_posthook(fun, args, kwargs, output):
buffers.append(da_or_sda.device_buffer)
try:
dispatch.check_special(xla.xla_call_p, buffers)
if jax.config.jax_jit_pjit_api_merge:
dispatch.check_special(pjit.pjit_p, buffers)
else:
dispatch.check_special(xla.xla_call_p, buffers)
except FloatingPointError:
# compiled_fun can only raise in this case
assert config.jax_debug_nans or config.jax_debug_infs

View File

@ -1064,7 +1064,28 @@ def _pjit_call_impl(*args, jaxpr,
("out_shardings", out_shardings),
("abstract args", list(map(xla.abstractify, args))),
("fingerprint", fingerprint))
return compiled.unsafe_call(*args)
try:
return compiled.unsafe_call(*args)
except FloatingPointError:
assert config.jax_debug_nans or config.jax_debug_infs # compiled_fun can only raise in this case
msg = ("An invalid value was encountered in the output of the "
f"`jit`-decorated function {name}. Because "
"config.jax_debug_nans and/or config.jax_debug_infs is set, the "
"de-optimized function (i.e., the function as if the `jit` "
"decorator were removed) was called in an attempt to get a more "
"precise error message. However, the de-optimized function did not "
"produce invalid values during its execution. This behavior can "
"result from `jit` optimizations causing the invalud value to be "
"produced. It may also arise from having nan/inf constants as "
"outputs, like `jax.jit(lambda ...: jax.numpy.nan)(...)`. "
"\n\n"
"It may be possible to avoid the invalid value by removing the "
"`jit` decorator, at the cost of losing optimizations. "
"\n\n"
"If you see this error, consider opening a bug report at "
"https://github.com/google/jax.")
raise FloatingPointError(msg)
pjit_p.def_impl(_pjit_call_impl)

View File

@ -84,6 +84,7 @@ jax_test(
jax_test(
name = "debug_nans_test",
srcs = ["debug_nans_test.py"],
enable_configs = ["cpu_jit_pjit_api_merge"],
)
py_test(