Fix shard_map debug_nans leakage of manual out_avals into the impl rule of jit, i.e. the impl rule of jit saw a manual out_aval, which is not expected. This is a band-aid for now, with a TODO to do a proper fix.

PiperOrigin-RevId: 730499532
This commit is contained in:
Yash Katariya 2025-02-24 10:13:34 -08:00 committed by jax authors
parent 4b4f2f9cb9
commit 6d8be966a0
3 changed files with 13 additions and 8 deletions

View File

@ -2555,8 +2555,7 @@ def _gspmd_to_named_sharding(
assert isinstance(out_s, GSPMDSharding)
assert isinstance(orig_in_s, NamedSharding)
assert isinstance(orig_in_s.mesh, Mesh)
if (out_aval is not None and not out_aval.sharding.mesh.empty and
out_aval.sharding.mesh._are_all_axes_auto):
if out_aval is not None and not out_aval.sharding.mesh.empty:
mesh = _abstract_to_concrete_mesh(
out_aval.sharding.mesh, out_s._device_assignment)
else:

View File

@ -1707,10 +1707,12 @@ def _shard_map_transpose(out_cts, *args,
jaxpr: core.Jaxpr, mesh, in_names, out_names,
check_rep, rewrite, auto):
mb_div = lambda x, y: x / y if y != 1 else x
out_cts = [ad.Zero(_shard_aval(mesh, auto, ns, x.aval)) if type(x) is ad.Zero
out_cts = [
ad.Zero(_shard_aval(mesh, auto, ns, x.aval)) if type(x) is ad.Zero
else x if rewrite or dtypes.dtype(x) == dtypes.float0
else mb_div(x, prod(map(mesh.shape.get, _unmentioned2(mesh, ns, auto))))
for ns, x in zip(out_names, out_cts)]
for ns, x in zip(out_names, out_cts)
]
args = tuple(x if type(x) is not ad.UndefinedPrimal else
ad.UndefinedPrimal(_shard_aval(mesh, auto, ns, x.aval))
for ns, x in zip(in_names, args))
@ -1751,12 +1753,17 @@ def _shard_map_transpose(out_cts, *args,
print("Invalid nan value encountered in the backward pass of a shard_map "
"function. Calling the de-optimized backward pass.")
try:
_ = fun_trans.call_wrapped(out_cts, args)
# TODO(mattjj): Remove this and do `fun_trans.call_wrapped(out_cts, args)`
# in eager mode so that output of shmap are not manual.
with jax.disable_jit(True):
_ = shard_map_p.bind(
fun_trans_flat, *all_args, mesh=mesh, in_names=tuple(new_in_names),
out_names_thunk=new_out_names_thunk, check_rep=check_rep,
rewrite=rewrite, auto=auto)
except (FloatingPointError, ZeroDivisionError) as e2:
raise e2 from None
else:
dispatch._raise_no_nan_in_deoptimized(e)
return tree_unflatten(out_tree(), out_flat)
ad.primitive_transposes[shard_map_p] = _shard_map_transpose

View File

@ -155,8 +155,7 @@ class DebugNaNsTest(jtu.JaxTestCase):
_, f_vjp = jax.vjp(shmap_f, jnp.zeros([1]))
with self.assertRaisesRegex(
FloatingPointError,
r"invalid value \(nan\) encountered in mul\nWhen differentiating"):
FloatingPointError, r"Invalid value \(nan\) encountered"):
ans, = f_vjp(jnp.ones([1]))
ans.block_until_ready()