mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Fix debug nans test after merging jit
and pjit
codepaths
PiperOrigin-RevId: 501122848
This commit is contained in:
parent
3f712480c6
commit
e02c1da4c7
@ -124,7 +124,10 @@ def _nan_check_posthook(fun, args, kwargs, output):
|
||||
buffers.append(da_or_sda.device_buffer)
|
||||
|
||||
try:
|
||||
dispatch.check_special(xla.xla_call_p, buffers)
|
||||
if jax.config.jax_jit_pjit_api_merge:
|
||||
dispatch.check_special(pjit.pjit_p, buffers)
|
||||
else:
|
||||
dispatch.check_special(xla.xla_call_p, buffers)
|
||||
except FloatingPointError:
|
||||
# compiled_fun can only raise in this case
|
||||
assert config.jax_debug_nans or config.jax_debug_infs
|
||||
|
@ -1064,7 +1064,28 @@ def _pjit_call_impl(*args, jaxpr,
|
||||
("out_shardings", out_shardings),
|
||||
("abstract args", list(map(xla.abstractify, args))),
|
||||
("fingerprint", fingerprint))
|
||||
return compiled.unsafe_call(*args)
|
||||
try:
|
||||
return compiled.unsafe_call(*args)
|
||||
except FloatingPointError:
|
||||
assert config.jax_debug_nans or config.jax_debug_infs # compiled_fun can only raise in this case
|
||||
msg = ("An invalid value was encountered in the output of the "
|
||||
f"`jit`-decorated function {name}. Because "
|
||||
"config.jax_debug_nans and/or config.jax_debug_infs is set, the "
|
||||
"de-optimized function (i.e., the function as if the `jit` "
|
||||
"decorator were removed) was called in an attempt to get a more "
|
||||
"precise error message. However, the de-optimized function did not "
|
||||
"produce invalid values during its execution. This behavior can "
|
||||
"result from `jit` optimizations causing the invalud value to be "
|
||||
"produced. It may also arise from having nan/inf constants as "
|
||||
"outputs, like `jax.jit(lambda ...: jax.numpy.nan)(...)`. "
|
||||
"\n\n"
|
||||
"It may be possible to avoid the invalid value by removing the "
|
||||
"`jit` decorator, at the cost of losing optimizations. "
|
||||
"\n\n"
|
||||
"If you see this error, consider opening a bug report at "
|
||||
"https://github.com/google/jax.")
|
||||
raise FloatingPointError(msg)
|
||||
|
||||
pjit_p.def_impl(_pjit_call_impl)
|
||||
|
||||
|
||||
|
@ -84,6 +84,7 @@ jax_test(
|
||||
jax_test(
|
||||
name = "debug_nans_test",
|
||||
srcs = ["debug_nans_test.py"],
|
||||
enable_configs = ["cpu_jit_pjit_api_merge"],
|
||||
)
|
||||
|
||||
py_test(
|
||||
|
Loading…
x
Reference in New Issue
Block a user