mirror of
https://github.com/ROCm/jax.git
synced 2025-04-15 19:36:06 +00:00
[Checkify] Add checks for shard_map.
PiperOrigin-RevId: 677798938
This commit is contained in:
parent
1256e18fd4
commit
6c52ddc97f
@ -25,6 +25,7 @@ import jax.numpy as jnp
|
||||
from jax import dtypes
|
||||
from jax import lax
|
||||
|
||||
from jax.experimental import shard_map
|
||||
from jax._src import api
|
||||
from jax._src import linear_util as lu
|
||||
from jax._src import config
|
||||
@ -931,6 +932,64 @@ def pjit_error_check(error, enabled_errors, *vals_in, jaxpr,
|
||||
return tree_unflatten(out_tree, err_and_out)
|
||||
error_checks[pjit.pjit_p] = pjit_error_check
|
||||
|
||||
|
||||
def shard_map_error_check(
|
||||
error, enabled_errors, *vals_in, jaxpr, in_names, out_names, **kwargs
|
||||
):
|
||||
if (mesh := kwargs.get('mesh')) is None:
|
||||
raise ValueError('Mesh must be provided for shard_map with checkify.')
|
||||
|
||||
err_vals, err_tree = jtu.tree_flatten(error)
|
||||
num_error_vals = len(err_vals)
|
||||
# Replicated sharding for in errors.
|
||||
new_in_names = (*([{}] * num_error_vals), *in_names)
|
||||
new_vals_in = [*err_vals, *vals_in]
|
||||
in_avals = list(map(get_shaped_aval, new_vals_in))
|
||||
for i, v in enumerate(in_avals):
|
||||
if not (sharder := core.shard_aval_handlers.get(type(v))):
|
||||
raise ValueError(f'Unsupported aval type: {type(v)}')
|
||||
in_avals[i] = sharder(mesh, new_in_names[i], v)
|
||||
|
||||
if not isinstance(jaxpr, core.ClosedJaxpr):
|
||||
jaxpr = core.ClosedJaxpr(jaxpr, ())
|
||||
with core.extend_axis_env_nd(mesh.shape.items()):
|
||||
# jaxpr to checked_jaxpr
|
||||
checked_jaxpr, out_tree, _ = jaxpr_to_checkify_jaxpr(
|
||||
jaxpr, enabled_errors, err_tree, *in_avals
|
||||
)
|
||||
num_out_error_vals = out_tree.num_leaves - len(out_names)
|
||||
|
||||
@lu.wrap_init
|
||||
def expand_errors_leading_dim(*xs):
|
||||
outs = core.eval_jaxpr(checked_jaxpr.jaxpr, checked_jaxpr.consts, *xs)
|
||||
errs, outs = split_list(outs, [num_out_error_vals])
|
||||
errs = [lax.expand_dims(e, [0]) for e in errs]
|
||||
return *errs, *outs
|
||||
|
||||
with core.extend_axis_env_nd(mesh.shape.items()):
|
||||
jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(
|
||||
expand_errors_leading_dim, checked_jaxpr.in_avals
|
||||
)
|
||||
checked_jaxpr = core.ClosedJaxpr(jaxpr, consts)
|
||||
|
||||
# Update shard_map params to account for extra error values.
|
||||
# Use fully sharded partitioning for out errors.
|
||||
new_out_names = (*([{0: mesh.axis_names}] * num_out_error_vals), *out_names)
|
||||
subfun = lu.hashable_partial(
|
||||
lu.wrap_init(core.eval_jaxpr), checked_jaxpr.jaxpr, checked_jaxpr.consts
|
||||
)
|
||||
new_params = dict(
|
||||
jaxpr=checked_jaxpr.jaxpr,
|
||||
in_names=new_in_names,
|
||||
out_names=new_out_names,
|
||||
**kwargs,
|
||||
)
|
||||
_, new_params = shard_map.shard_map_p.get_bind_params(new_params)
|
||||
|
||||
err_and_out = shard_map.shard_map_p.bind(subfun, *new_vals_in, **new_params)
|
||||
return tree_unflatten(out_tree, err_and_out)
|
||||
error_checks[shard_map.shard_map_p] = shard_map_error_check
|
||||
|
||||
def custom_jvp_call_rule(in_err, enabled_errors, *in_vals, num_consts,
|
||||
jvp_jaxpr_thunk, call_jaxpr, **params):
|
||||
# The types to have in mind are:
|
||||
|
@ -23,6 +23,7 @@ import jax
|
||||
from jax import lax
|
||||
from jax.experimental import checkify
|
||||
from jax.experimental import pjit
|
||||
from jax.experimental import shard_map
|
||||
from jax.sharding import NamedSharding
|
||||
from jax._src import array
|
||||
from jax._src import config
|
||||
@ -539,6 +540,46 @@ class CheckifyTransformTests(jtu.JaxTestCase):
|
||||
self.assertIsNotNone(b_err.get())
|
||||
self.assertStartsWith(b_err.get(), "division by zero")
|
||||
|
||||
@parameterized.parameters(True, False)
|
||||
def test_shard_map(self, check_rep):
|
||||
def f(x):
|
||||
# unary func
|
||||
return jax.lax.axis_index("dev") * x / x
|
||||
|
||||
def g(x, y):
|
||||
# binary func
|
||||
return jax.lax.axis_index("dev") * x / y
|
||||
|
||||
devices = jax.local_devices()[:8] # Taking up to 8 devices
|
||||
mesh = jax.sharding.Mesh(np.array(devices), ["dev"])
|
||||
pspec = jax.sharding.PartitionSpec("dev")
|
||||
ps = NamedSharding(mesh, pspec)
|
||||
inp = np.tile(np.arange(4, dtype=np.int32), 2)
|
||||
x = array.make_array_from_callback(inp.shape, ps, lambda idx: inp[idx])
|
||||
|
||||
f = shard_map.shard_map(
|
||||
f, mesh, in_specs=pspec, out_specs=pspec, check_rep=check_rep
|
||||
)
|
||||
f = jax.jit(f, in_shardings=ps, out_shardings=ps)
|
||||
f = checkify.checkify(f, errors=checkify.float_checks)
|
||||
g = shard_map.shard_map(
|
||||
g, mesh, in_specs=(pspec, pspec), out_specs=pspec, check_rep=check_rep
|
||||
)
|
||||
g = jax.jit(g, in_shardings=(ps, ps), out_shardings=ps)
|
||||
g = checkify.checkify(g, errors=checkify.float_checks)
|
||||
u_err, _ = f(x)
|
||||
b_err, _ = g(x, x)
|
||||
|
||||
divbyzero = "division by zero"
|
||||
expected_err = f"at mapped index 0: {divbyzero}"
|
||||
if (next_device_with_zero := len(devices) // 2) != 0:
|
||||
expected_err += f"\nat mapped index {next_device_with_zero}: {divbyzero}"
|
||||
|
||||
self.assertIsNotNone(u_err.get())
|
||||
self.assertEqual(u_err.get(), expected_err)
|
||||
self.assertIsNotNone(b_err.get())
|
||||
self.assertEqual(b_err.get(), expected_err)
|
||||
|
||||
def test_empty_enabled_errors(self):
|
||||
def multi_errors(x):
|
||||
x = x/0 # DIV
|
||||
|
Loading…
x
Reference in New Issue
Block a user