mirror of
https://github.com/ROCm/jax.git
synced 2025-04-18 12:56:07 +00:00
Fix shard_map debug_nan leakage of manual out_avals in the impl rules of jit i.e. impl rule of jit saw a manual out_aval which is not expected. This is a band-aid for now with a TODO to do a proper fix
PiperOrigin-RevId: 730499532
This commit is contained in:
parent
4b4f2f9cb9
commit
6d8be966a0
@ -2555,8 +2555,7 @@ def _gspmd_to_named_sharding(
|
||||
assert isinstance(out_s, GSPMDSharding)
|
||||
assert isinstance(orig_in_s, NamedSharding)
|
||||
assert isinstance(orig_in_s.mesh, Mesh)
|
||||
if (out_aval is not None and not out_aval.sharding.mesh.empty and
|
||||
out_aval.sharding.mesh._are_all_axes_auto):
|
||||
if out_aval is not None and not out_aval.sharding.mesh.empty:
|
||||
mesh = _abstract_to_concrete_mesh(
|
||||
out_aval.sharding.mesh, out_s._device_assignment)
|
||||
else:
|
||||
|
@ -1707,10 +1707,12 @@ def _shard_map_transpose(out_cts, *args,
|
||||
jaxpr: core.Jaxpr, mesh, in_names, out_names,
|
||||
check_rep, rewrite, auto):
|
||||
mb_div = lambda x, y: x / y if y != 1 else x
|
||||
out_cts = [ad.Zero(_shard_aval(mesh, auto, ns, x.aval)) if type(x) is ad.Zero
|
||||
out_cts = [
|
||||
ad.Zero(_shard_aval(mesh, auto, ns, x.aval)) if type(x) is ad.Zero
|
||||
else x if rewrite or dtypes.dtype(x) == dtypes.float0
|
||||
else mb_div(x, prod(map(mesh.shape.get, _unmentioned2(mesh, ns, auto))))
|
||||
for ns, x in zip(out_names, out_cts)]
|
||||
for ns, x in zip(out_names, out_cts)
|
||||
]
|
||||
args = tuple(x if type(x) is not ad.UndefinedPrimal else
|
||||
ad.UndefinedPrimal(_shard_aval(mesh, auto, ns, x.aval))
|
||||
for ns, x in zip(in_names, args))
|
||||
@ -1751,12 +1753,17 @@ def _shard_map_transpose(out_cts, *args,
|
||||
print("Invalid nan value encountered in the backward pass of a shard_map "
|
||||
"function. Calling the de-optimized backward pass.")
|
||||
try:
|
||||
_ = fun_trans.call_wrapped(out_cts, args)
|
||||
# TODO(mattjj): Remove this and do `fun_trans.call_wrapped(out_cts, args)`
|
||||
# in eager mode so that output of shmap are not manual.
|
||||
with jax.disable_jit(True):
|
||||
_ = shard_map_p.bind(
|
||||
fun_trans_flat, *all_args, mesh=mesh, in_names=tuple(new_in_names),
|
||||
out_names_thunk=new_out_names_thunk, check_rep=check_rep,
|
||||
rewrite=rewrite, auto=auto)
|
||||
except (FloatingPointError, ZeroDivisionError) as e2:
|
||||
raise e2 from None
|
||||
else:
|
||||
dispatch._raise_no_nan_in_deoptimized(e)
|
||||
|
||||
return tree_unflatten(out_tree(), out_flat)
|
||||
ad.primitive_transposes[shard_map_p] = _shard_map_transpose
|
||||
|
||||
|
@ -155,8 +155,7 @@ class DebugNaNsTest(jtu.JaxTestCase):
|
||||
_, f_vjp = jax.vjp(shmap_f, jnp.zeros([1]))
|
||||
|
||||
with self.assertRaisesRegex(
|
||||
FloatingPointError,
|
||||
r"invalid value \(nan\) encountered in mul\nWhen differentiating"):
|
||||
FloatingPointError, r"Invalid value \(nan\) encountered"):
|
||||
ans, = f_vjp(jnp.ones([1]))
|
||||
ans.block_until_ready()
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user