diff --git a/jax/_src/checkify.py b/jax/_src/checkify.py index e67f624fc..32cc4feb9 100644 --- a/jax/_src/checkify.py +++ b/jax/_src/checkify.py @@ -25,6 +25,7 @@ import jax.numpy as jnp from jax import dtypes from jax import lax +from jax.experimental import shard_map from jax._src import api from jax._src import linear_util as lu from jax._src import config @@ -931,6 +932,64 @@ def pjit_error_check(error, enabled_errors, *vals_in, jaxpr, return tree_unflatten(out_tree, err_and_out) error_checks[pjit.pjit_p] = pjit_error_check + +def shard_map_error_check( + error, enabled_errors, *vals_in, jaxpr, in_names, out_names, **kwargs +): + if (mesh := kwargs.get('mesh')) is None: + raise ValueError('Mesh must be provided for shard_map with checkify.') + + err_vals, err_tree = jtu.tree_flatten(error) + num_error_vals = len(err_vals) + # Replicated sharding for in errors. + new_in_names = (*([{}] * num_error_vals), *in_names) + new_vals_in = [*err_vals, *vals_in] + in_avals = list(map(get_shaped_aval, new_vals_in)) + for i, v in enumerate(in_avals): + if not (sharder := core.shard_aval_handlers.get(type(v))): + raise ValueError(f'Unsupported aval type: {type(v)}') + in_avals[i] = sharder(mesh, new_in_names[i], v) + + if not isinstance(jaxpr, core.ClosedJaxpr): + jaxpr = core.ClosedJaxpr(jaxpr, ()) + with core.extend_axis_env_nd(mesh.shape.items()): + # jaxpr to checked_jaxpr + checked_jaxpr, out_tree, _ = jaxpr_to_checkify_jaxpr( + jaxpr, enabled_errors, err_tree, *in_avals + ) + num_out_error_vals = out_tree.num_leaves - len(out_names) + + @lu.wrap_init + def expand_errors_leading_dim(*xs): + outs = core.eval_jaxpr(checked_jaxpr.jaxpr, checked_jaxpr.consts, *xs) + errs, outs = split_list(outs, [num_out_error_vals]) + errs = [lax.expand_dims(e, [0]) for e in errs] + return *errs, *outs + + with core.extend_axis_env_nd(mesh.shape.items()): + jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic( + expand_errors_leading_dim, checked_jaxpr.in_avals + ) + checked_jaxpr = core.ClosedJaxpr(jaxpr, consts) + + # Update shard_map params to account for extra error values. + # Use fully sharded partitioning for out errors. + new_out_names = (*([{0: mesh.axis_names}] * num_out_error_vals), *out_names) + subfun = lu.hashable_partial( + lu.wrap_init(core.eval_jaxpr), checked_jaxpr.jaxpr, checked_jaxpr.consts + ) + new_params = dict( + jaxpr=checked_jaxpr.jaxpr, + in_names=new_in_names, + out_names=new_out_names, + **kwargs, + ) + _, new_params = shard_map.shard_map_p.get_bind_params(new_params) + + err_and_out = shard_map.shard_map_p.bind(subfun, *new_vals_in, **new_params) + return tree_unflatten(out_tree, err_and_out) +error_checks[shard_map.shard_map_p] = shard_map_error_check + def custom_jvp_call_rule(in_err, enabled_errors, *in_vals, num_consts, jvp_jaxpr_thunk, call_jaxpr, **params): # The types to have in mind are: diff --git a/tests/checkify_test.py b/tests/checkify_test.py index 726e89d1b..24387a767 100644 --- a/tests/checkify_test.py +++ b/tests/checkify_test.py @@ -23,6 +23,7 @@ import jax from jax import lax from jax.experimental import checkify from jax.experimental import pjit +from jax.experimental import shard_map from jax.sharding import NamedSharding from jax._src import array from jax._src import config @@ -539,6 +540,46 @@ class CheckifyTransformTests(jtu.JaxTestCase): self.assertIsNotNone(b_err.get()) self.assertStartsWith(b_err.get(), "division by zero") + @parameterized.parameters(True, False) + def test_shard_map(self, check_rep): + def f(x): + # unary func + return jax.lax.axis_index("dev") * x / x + + def g(x, y): + # binary func + return jax.lax.axis_index("dev") * x / y + + devices = jax.local_devices()[:8] # Taking up to 8 devices + mesh = jax.sharding.Mesh(np.array(devices), ["dev"]) + pspec = jax.sharding.PartitionSpec("dev") + ps = NamedSharding(mesh, pspec) + inp = np.tile(np.arange(4, dtype=np.int32), 2) + x = array.make_array_from_callback(inp.shape, ps, lambda idx: inp[idx]) + + f = shard_map.shard_map( + f, mesh, in_specs=pspec, out_specs=pspec, check_rep=check_rep + ) + f = jax.jit(f, in_shardings=ps, out_shardings=ps) + f = checkify.checkify(f, errors=checkify.float_checks) + g = shard_map.shard_map( + g, mesh, in_specs=(pspec, pspec), out_specs=pspec, check_rep=check_rep + ) + g = jax.jit(g, in_shardings=(ps, ps), out_shardings=ps) + g = checkify.checkify(g, errors=checkify.float_checks) + u_err, _ = f(x) + b_err, _ = g(x, x) + + divbyzero = "division by zero" + expected_err = f"at mapped index 0: {divbyzero}" + if (next_device_with_zero := len(devices) // 2) != 0: + expected_err += f"\nat mapped index {next_device_with_zero}: {divbyzero}" + + self.assertIsNotNone(u_err.get()) + self.assertEqual(u_err.get(), expected_err) + self.assertIsNotNone(b_err.get()) + self.assertEqual(b_err.get(), expected_err) + def test_empty_enabled_errors(self): def multi_errors(x): x = x/0 # DIV