mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Merge pull request #11669 from LenaMartens:check-of-pjit
PiperOrigin-RevId: 464133301
This commit is contained in:
commit
21f632740c
@ -16,7 +16,7 @@ import enum
|
||||
from dataclasses import dataclass
|
||||
from functools import partial
|
||||
import itertools as it
|
||||
from typing import Union, Optional, Callable, Dict, Tuple, TypeVar, FrozenSet
|
||||
from typing import Union, Optional, Callable, Dict, Tuple, TypeVar, FrozenSet, Iterable
|
||||
|
||||
import numpy as np
|
||||
|
||||
@ -25,9 +25,12 @@ import jax.numpy as jnp
|
||||
from jax import core
|
||||
from jax import linear_util as lu
|
||||
from jax.api_util import flatten_fun
|
||||
from jax.experimental import pjit
|
||||
from jax.experimental import maps
|
||||
from jax.interpreters import batching
|
||||
from jax.interpreters import mlir
|
||||
from jax.interpreters import partial_eval as pe
|
||||
from jax.interpreters import pxla
|
||||
from jax.tree_util import tree_flatten, tree_unflatten, register_pytree_node
|
||||
from jax._src import source_info_util, traceback_util
|
||||
from jax._src.lax import control_flow as cf
|
||||
@ -669,6 +672,40 @@ def while_loop_error_check(error, enabled_errors, *in_flat, cond_nconsts,
|
||||
return out, Error(err, code, new_msgs, payload)
|
||||
error_checks[lax.while_p] = while_loop_error_check
|
||||
|
||||
|
||||
def pjit_error_check(error, enabled_errors, *vals_in, jaxpr,
|
||||
in_shardings, out_shardings, resource_env,
|
||||
donated_invars, name,
|
||||
in_positional_semantics, out_positional_semantics):
|
||||
checked_jaxpr, msgs = checkify_jaxpr(jaxpr, error, enabled_errors)
|
||||
new_vals_in = [error.err, error.code, error.payload, *vals_in]
|
||||
# TODO(lenamartens, yashkatariya): replace with OpShardingSharding.
|
||||
sharding = pxla._create_mesh_pspec_sharding(pxla.thread_resources.env.physical_mesh,
|
||||
pxla.PartitionSpec(None))
|
||||
pos_sem = maps._positional_semantics.val
|
||||
new_in_shardings = (*[sharding]*3, *in_shardings)
|
||||
new_out_shardings = (*[sharding]*3, *out_shardings)
|
||||
if not isinstance(in_positional_semantics, Iterable):
|
||||
in_positional_semantics = (in_positional_semantics,)
|
||||
if not isinstance(out_positional_semantics, Iterable):
|
||||
out_positional_semantics = (out_positional_semantics,)
|
||||
new_positional_sems_in = (*[pos_sem]*3, *in_positional_semantics)
|
||||
new_positional_sems_out = (*[pos_sem]*3, *out_positional_semantics)
|
||||
new_donated_invars = (*[False]*3, *donated_invars)
|
||||
err, code, payload, *vals_out = pjit.pjit_p.bind(
|
||||
*new_vals_in,
|
||||
jaxpr=checked_jaxpr,
|
||||
in_shardings=new_in_shardings,
|
||||
out_shardings=new_out_shardings,
|
||||
resource_env=resource_env,
|
||||
donated_invars=new_donated_invars,
|
||||
name=name,
|
||||
in_positional_semantics=new_positional_sems_in,
|
||||
out_positional_semantics=new_positional_sems_out)
|
||||
return vals_out, Error(err, code, msgs, payload)
|
||||
error_checks[pjit.pjit_p] = pjit_error_check
|
||||
|
||||
|
||||
def add_nan_check(prim):
|
||||
error_checks[prim] = partial(nan_error_check, prim)
|
||||
|
||||
|
@ -24,6 +24,8 @@ from jax import lax
|
||||
import jax._src.test_util as jtu
|
||||
from jax.config import config
|
||||
from jax.experimental import checkify
|
||||
from jax.experimental import pjit
|
||||
from jax.experimental import maps
|
||||
from jax._src.checkify import CheckEffect
|
||||
import jax.numpy as jnp
|
||||
|
||||
@ -410,6 +412,30 @@ class CheckifyTransformTests(jtu.JaxTestCase):
|
||||
# first error which occurs is in cond
|
||||
self.assertStartsWith(err.get(), "nan generated by primitive sin")
|
||||
|
||||
def test_pjit(self):
|
||||
def f(x):
|
||||
# unary func
|
||||
return x / x
|
||||
|
||||
def g(x, y):
|
||||
# binary func
|
||||
return x / y
|
||||
|
||||
ps = pjit.PartitionSpec("dev")
|
||||
f = pjit.pjit(f, in_axis_resources=ps, out_axis_resources=ps)
|
||||
f = checkify.checkify(f, errors=checkify.float_checks)
|
||||
g = pjit.pjit(g, in_axis_resources=ps, out_axis_resources=ps)
|
||||
g = checkify.checkify(g, errors=checkify.float_checks)
|
||||
with maps.Mesh(np.array(jax.devices()), ["dev"]):
|
||||
x = jnp.arange(8)
|
||||
u_err, _ = f(x)
|
||||
b_err, _ = g(x, x)
|
||||
|
||||
self.assertIsNotNone(u_err.get())
|
||||
self.assertStartsWith(u_err.get(), "divided by zero")
|
||||
self.assertIsNotNone(b_err.get())
|
||||
self.assertStartsWith(b_err.get(), "divided by zero")
|
||||
|
||||
def test_empty_enabled_errors(self):
|
||||
def multi_errors(x):
|
||||
x = x/0 # DIV
|
||||
|
Loading…
x
Reference in New Issue
Block a user