[Checkify] Add checks for shard_map.

PiperOrigin-RevId: 677798938
This commit is contained in:
jax authors 2024-09-23 08:10:29 -07:00
parent 1256e18fd4
commit 6c52ddc97f
2 changed files with 100 additions and 0 deletions

View File

@ -25,6 +25,7 @@ import jax.numpy as jnp
from jax import dtypes
from jax import lax
from jax.experimental import shard_map
from jax._src import api
from jax._src import linear_util as lu
from jax._src import config
@ -931,6 +932,64 @@ def pjit_error_check(error, enabled_errors, *vals_in, jaxpr,
return tree_unflatten(out_tree, err_and_out)
error_checks[pjit.pjit_p] = pjit_error_check
def shard_map_error_check(
    error, enabled_errors, *vals_in, jaxpr, in_names, out_names, **kwargs
):
  """Checkify rule for ``shard_map``: threads error values through the map.

  The leaves of ``error`` are prepended to the operands, the body jaxpr is
  converted with ``jaxpr_to_checkify_jaxpr``, and ``shard_map_p`` is bound
  again with ``in_names``/``out_names`` extended to cover the extra error
  values.

  Args:
    error: incoming error pytree; its flattened leaves become leading operands.
    enabled_errors: forwarded unchanged to ``jaxpr_to_checkify_jaxpr``.
    *vals_in: operands of the original shard_map.
    jaxpr: body of the shard_map; wrapped in a ``core.ClosedJaxpr`` (with no
      consts) when it is not one already.
    in_names: names for the original operands, one entry per operand.
    out_names: names for the original outputs, one entry per output.
    **kwargs: the remaining shard_map params; must hold a non-None ``mesh``.

  Returns:
    The flat outputs of the new bind, unflattened with the ``out_tree``
    produced by ``jaxpr_to_checkify_jaxpr``.

  Raises:
    ValueError: if ``kwargs`` has no mesh, or if an input aval type has no
      entry in ``core.shard_aval_handlers``.
  """
  mesh = kwargs.get('mesh')
  if mesh is None:
    raise ValueError('Mesh must be provided for shard_map with checkify.')

  flat_errs, err_tree = jtu.tree_flatten(error)
  # Replicated sharding for in errors.
  new_in_names = (*([{}] * len(flat_errs)), *in_names)
  new_vals_in = [*flat_errs, *vals_in]

  def shard_aval(position, v):
    # Look up the handler registered for this aval type and apply it with the
    # names that belong to the operand at `position`.
    handler = core.shard_aval_handlers.get(type(v))
    if not handler:
      raise ValueError(f'Unsupported aval type: {type(v)}')
    return handler(mesh, new_in_names[position], v)

  # Compute every aval first, then shard them, so failures surface in the same
  # order as a two-pass loop would produce.
  unsharded_avals = [get_shaped_aval(val) for val in new_vals_in]
  in_avals = [shard_aval(i, aval) for i, aval in enumerate(unsharded_avals)]

  closed_body = (
      jaxpr if isinstance(jaxpr, core.ClosedJaxpr)
      else core.ClosedJaxpr(jaxpr, ())
  )
  # The body refers to mesh axis names, so the conversion runs with the mesh
  # axes added to the axis environment.
  with core.extend_axis_env_nd(mesh.shape.items()):
    checked_body, out_tree, _ = jaxpr_to_checkify_jaxpr(
        closed_body, enabled_errors, err_tree, *in_avals
    )
  num_out_error_vals = out_tree.num_leaves - len(out_names)

  def expand_errors_leading_dim(*xs):
    # Give every output error value a new leading axis so that dimension 0
    # can be named in `new_out_names` below; regular outputs pass through.
    flat_outs = core.eval_jaxpr(checked_body.jaxpr, checked_body.consts, *xs)
    out_errs, out_vals = split_list(flat_outs, [num_out_error_vals])
    return (*[lax.expand_dims(e, [0]) for e in out_errs], *out_vals)

  traced_fun = lu.wrap_init(expand_errors_leading_dim)
  with core.extend_axis_env_nd(mesh.shape.items()):
    # The trailing `()` target requires the fourth result to be empty.
    expanded_jaxpr, _, expanded_consts, () = pe.trace_to_jaxpr_dynamic(
        traced_fun, checked_body.in_avals
    )
  expanded_body = core.ClosedJaxpr(expanded_jaxpr, expanded_consts)

  # Update shard_map params to account for extra error values.
  # Use fully sharded partitioning for out errors.
  new_out_names = (
      *([{0: mesh.axis_names}] * num_out_error_vals),
      *out_names,
  )
  subfun = lu.hashable_partial(
      lu.wrap_init(core.eval_jaxpr), expanded_body.jaxpr, expanded_body.consts
  )
  bind_params = {
      'jaxpr': expanded_body.jaxpr,
      'in_names': new_in_names,
      'out_names': new_out_names,
      **kwargs,
  }
  _, bind_params = shard_map.shard_map_p.get_bind_params(bind_params)

  flat_results = shard_map.shard_map_p.bind(
      subfun, *new_vals_in, **bind_params
  )
  return tree_unflatten(out_tree, flat_results)
# Register shard_map_error_check as the error_checks entry for the shard_map
# primitive.
error_checks[shard_map.shard_map_p] = shard_map_error_check
def custom_jvp_call_rule(in_err, enabled_errors, *in_vals, num_consts,
jvp_jaxpr_thunk, call_jaxpr, **params):
# The types to have in mind are:

View File

@ -23,6 +23,7 @@ import jax
from jax import lax
from jax.experimental import checkify
from jax.experimental import pjit
from jax.experimental import shard_map
from jax.sharding import NamedSharding
from jax._src import array
from jax._src import config
@ -539,6 +540,46 @@ class CheckifyTransformTests(jtu.JaxTestCase):
self.assertIsNotNone(b_err.get())
self.assertStartsWith(b_err.get(), "division by zero")
@parameterized.parameters(True, False)
def test_shard_map(self, check_rep):
  """Checkified shard_map functions report division by zero per mapped index."""

  def unary_fn(x):
    return jax.lax.axis_index("dev") * x / x

  def binary_fn(x, y):
    return jax.lax.axis_index("dev") * x / y

  # Taking up to 8 devices
  devices = jax.local_devices()[:8]
  mesh = jax.sharding.Mesh(np.array(devices), ["dev"])
  pspec = jax.sharding.PartitionSpec("dev")
  ps = NamedSharding(mesh, pspec)
  # Eight values 0..3, 0..3: zeros sit at positions 0 and 4.
  inp = np.tile(np.arange(4, dtype=np.int32), 2)
  x = array.make_array_from_callback(
      inp.shape, ps, lambda index: inp[index]
  )

  def checkified(fn, in_specs, in_shardings):
    # shard_map -> jit -> checkify, with float checks enabled.
    mapped = shard_map.shard_map(
        fn, mesh, in_specs=in_specs, out_specs=pspec, check_rep=check_rep
    )
    jitted = jax.jit(mapped, in_shardings=in_shardings, out_shardings=ps)
    return checkify.checkify(jitted, errors=checkify.float_checks)

  checked_unary = checkified(unary_fn, pspec, ps)
  checked_binary = checkified(binary_fn, (pspec, pspec), (ps, ps))
  u_err, _ = checked_unary(x)
  b_err, _ = checked_binary(x, x)

  divbyzero = "division by zero"
  expected_err = f"at mapped index 0: {divbyzero}"
  # With more than one device, the second zero (position 4 of 8) falls on the
  # device halfway through the mesh.
  next_device_with_zero = len(devices) // 2
  if next_device_with_zero != 0:
    expected_err += f"\nat mapped index {next_device_with_zero}: {divbyzero}"

  for err in (u_err, b_err):
    self.assertIsNotNone(err.get())
    self.assertEqual(err.get(), expected_err)
def test_empty_enabled_errors(self):
def multi_errors(x):
x = x/0 # DIV