From 6d8be966a00e86b4b1ce41adc7f0b1f30d4f99b9 Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Mon, 24 Feb 2025 10:13:34 -0800 Subject: [PATCH] Fix shard_map debug_nan leakage of manual out_avals in the impl rules of jit i.e. impl rule of jit saw a manual out_aval which is not expected. This is a band-aid for now with a TODO to do a proper fix PiperOrigin-RevId: 730499532 --- jax/_src/interpreters/pxla.py | 3 +-- jax/experimental/shard_map.py | 15 +++++++++++---- tests/debug_nans_test.py | 3 +-- 3 files changed, 13 insertions(+), 8 deletions(-) diff --git a/jax/_src/interpreters/pxla.py b/jax/_src/interpreters/pxla.py index 9baa5e977..a69b3e774 100644 --- a/jax/_src/interpreters/pxla.py +++ b/jax/_src/interpreters/pxla.py @@ -2555,8 +2555,7 @@ def _gspmd_to_named_sharding( assert isinstance(out_s, GSPMDSharding) assert isinstance(orig_in_s, NamedSharding) assert isinstance(orig_in_s.mesh, Mesh) - if (out_aval is not None and not out_aval.sharding.mesh.empty and - out_aval.sharding.mesh._are_all_axes_auto): + if out_aval is not None and not out_aval.sharding.mesh.empty: mesh = _abstract_to_concrete_mesh( out_aval.sharding.mesh, out_s._device_assignment) else: diff --git a/jax/experimental/shard_map.py b/jax/experimental/shard_map.py index 3d17c410d..b8881cf3e 100644 --- a/jax/experimental/shard_map.py +++ b/jax/experimental/shard_map.py @@ -1707,10 +1707,12 @@ def _shard_map_transpose(out_cts, *args, jaxpr: core.Jaxpr, mesh, in_names, out_names, check_rep, rewrite, auto): mb_div = lambda x, y: x / y if y != 1 else x - out_cts = [ad.Zero(_shard_aval(mesh, auto, ns, x.aval)) if type(x) is ad.Zero + out_cts = [ + ad.Zero(_shard_aval(mesh, auto, ns, x.aval)) if type(x) is ad.Zero else x if rewrite or dtypes.dtype(x) == dtypes.float0 else mb_div(x, prod(map(mesh.shape.get, _unmentioned2(mesh, ns, auto)))) - for ns, x in zip(out_names, out_cts)] + for ns, x in zip(out_names, out_cts) + ] args = tuple(x if type(x) is not ad.UndefinedPrimal else ad.UndefinedPrimal(_shard_aval(mesh, auto, ns, x.aval)) for ns, x in zip(in_names, args)) @@ -1751,12 +1753,17 @@ def _shard_map_transpose(out_cts, *args, print("Invalid nan value encountered in the backward pass of a shard_map " "function. Calling the de-optimized backward pass.") try: - _ = fun_trans.call_wrapped(out_cts, args) + # TODO(mattjj): Remove this and do `fun_trans.call_wrapped(out_cts, args)` + # in eager mode so that output of shmap are not manual. + with jax.disable_jit(True): + _ = shard_map_p.bind( + fun_trans_flat, *all_args, mesh=mesh, in_names=tuple(new_in_names), + out_names_thunk=new_out_names_thunk, check_rep=check_rep, + rewrite=rewrite, auto=auto) except (FloatingPointError, ZeroDivisionError) as e2: raise e2 from None else: dispatch._raise_no_nan_in_deoptimized(e) - return tree_unflatten(out_tree(), out_flat) ad.primitive_transposes[shard_map_p] = _shard_map_transpose diff --git a/tests/debug_nans_test.py b/tests/debug_nans_test.py index c0cef5084..d3dcfb2e7 100644 --- a/tests/debug_nans_test.py +++ b/tests/debug_nans_test.py @@ -155,8 +155,7 @@ class DebugNaNsTest(jtu.JaxTestCase): _, f_vjp = jax.vjp(shmap_f, jnp.zeros([1])) with self.assertRaisesRegex( - FloatingPointError, - r"invalid value \(nan\) encountered in mul\nWhen differentiating"): + FloatingPointError, r"Invalid value \(nan\) encountered"): ans, = f_vjp(jnp.ones([1])) ans.block_until_ready()