Merge pull request #11669 from LenaMartens:check-of-pjit

PiperOrigin-RevId: 464133301
This commit is contained in:
jax authors 2022-07-29 12:25:32 -07:00
commit 21f632740c
2 changed files with 64 additions and 1 deletions

View File

@ -16,7 +16,7 @@ import enum
from dataclasses import dataclass
from functools import partial
import itertools as it
from typing import Union, Optional, Callable, Dict, Tuple, TypeVar, FrozenSet
from typing import Union, Optional, Callable, Dict, Tuple, TypeVar, FrozenSet, Iterable
import numpy as np
@ -25,9 +25,12 @@ import jax.numpy as jnp
from jax import core
from jax import linear_util as lu
from jax.api_util import flatten_fun
from jax.experimental import pjit
from jax.experimental import maps
from jax.interpreters import batching
from jax.interpreters import mlir
from jax.interpreters import partial_eval as pe
from jax.interpreters import pxla
from jax.tree_util import tree_flatten, tree_unflatten, register_pytree_node
from jax._src import source_info_util, traceback_util
from jax._src.lax import control_flow as cf
@ -669,6 +672,40 @@ def while_loop_error_check(error, enabled_errors, *in_flat, cond_nconsts,
return out, Error(err, code, new_msgs, payload)
error_checks[lax.while_p] = while_loop_error_check
def pjit_error_check(error, enabled_errors, *vals_in, jaxpr,
in_shardings, out_shardings, resource_env,
donated_invars, name,
in_positional_semantics, out_positional_semantics):
checked_jaxpr, msgs = checkify_jaxpr(jaxpr, error, enabled_errors)
new_vals_in = [error.err, error.code, error.payload, *vals_in]
# TODO(lenamartens, yashkatariya): replace with OpShardingSharding.
sharding = pxla._create_mesh_pspec_sharding(pxla.thread_resources.env.physical_mesh,
pxla.PartitionSpec(None))
pos_sem = maps._positional_semantics.val
new_in_shardings = (*[sharding]*3, *in_shardings)
new_out_shardings = (*[sharding]*3, *out_shardings)
if not isinstance(in_positional_semantics, Iterable):
in_positional_semantics = (in_positional_semantics,)
if not isinstance(out_positional_semantics, Iterable):
out_positional_semantics = (out_positional_semantics,)
new_positional_sems_in = (*[pos_sem]*3, *in_positional_semantics)
new_positional_sems_out = (*[pos_sem]*3, *out_positional_semantics)
new_donated_invars = (*[False]*3, *donated_invars)
err, code, payload, *vals_out = pjit.pjit_p.bind(
*new_vals_in,
jaxpr=checked_jaxpr,
in_shardings=new_in_shardings,
out_shardings=new_out_shardings,
resource_env=resource_env,
donated_invars=new_donated_invars,
name=name,
in_positional_semantics=new_positional_sems_in,
out_positional_semantics=new_positional_sems_out)
return vals_out, Error(err, code, msgs, payload)
error_checks[pjit.pjit_p] = pjit_error_check
def add_nan_check(prim):
error_checks[prim] = partial(nan_error_check, prim)

View File

@ -24,6 +24,8 @@ from jax import lax
import jax._src.test_util as jtu
from jax.config import config
from jax.experimental import checkify
from jax.experimental import pjit
from jax.experimental import maps
from jax._src.checkify import CheckEffect
import jax.numpy as jnp
@ -410,6 +412,30 @@ class CheckifyTransformTests(jtu.JaxTestCase):
# first error which occurs is in cond
self.assertStartsWith(err.get(), "nan generated by primitive sin")
def test_pjit(self):
def f(x):
# unary func
return x / x
def g(x, y):
# binary func
return x / y
ps = pjit.PartitionSpec("dev")
f = pjit.pjit(f, in_axis_resources=ps, out_axis_resources=ps)
f = checkify.checkify(f, errors=checkify.float_checks)
g = pjit.pjit(g, in_axis_resources=ps, out_axis_resources=ps)
g = checkify.checkify(g, errors=checkify.float_checks)
with maps.Mesh(np.array(jax.devices()), ["dev"]):
x = jnp.arange(8)
u_err, _ = f(x)
b_err, _ = g(x, x)
self.assertIsNotNone(u_err.get())
self.assertStartsWith(u_err.get(), "divided by zero")
self.assertIsNotNone(b_err.get())
self.assertStartsWith(b_err.get(), "divided by zero")
def test_empty_enabled_errors(self):
def multi_errors(x):
x = x/0 # DIV