[Pallas] Add checkify support for pallas_call in interpret mode.
PiperOrigin-RevId: 644181742
parent 1d77720e9a
commit fb68f3449b
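
For context, a minimal usage sketch of what this change enables. The kernel and shapes here are illustrative (not from this commit), and `interpret=True` is required since checkify is only supported in interpret mode:

    import jax
    import jax.numpy as jnp
    from jax.experimental import checkify
    from jax.experimental import pallas as pl

    def kernel(x_ref, o_ref):
      # A runtime check inside the kernel body.
      checkify.check(jnp.all(x_ref[...] >= 0), "negative input")
      o_ref[...] = jnp.sqrt(x_ref[...])

    f = pl.pallas_call(
        kernel,
        out_shape=jax.ShapeDtypeStruct((8,), jnp.float32),
        interpret=True)  # checkify for pallas_call currently requires interpret mode
    err, out = checkify.checkify(f)(jnp.arange(8, dtype=jnp.float32))
    err.throw()  # raises checkify.JaxRuntimeError only if the check failed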
@@ -25,6 +25,7 @@ from jax import api_util
 from jax import lax
 from jax import tree_util
 from jax._src import ad_util
+from jax._src import checkify
 from jax._src import config
 from jax._src import core as jax_core
 from jax._src import effects
@@ -114,6 +115,21 @@ def _pad_values_to_block_dimension(value,
     value = jnp.pad(value, pad_width, constant_values=pad_value)
   return value

+def _initialize_scratch_vals(scratch_avals) -> tuple[jax.Array, ...]:
+  scratch_avals = (jax_core.raise_to_shaped(x) for x in scratch_avals)
+  return tuple(uninitialized_value(a.shape, a.dtype) for a in scratch_avals)
+
+def _initialize_output_vals(
+    out_shapes, input_args, input_output_aliases) -> Sequence[jax.Array]:
+  oi_map = {v: k for k, v in input_output_aliases}
+  output_vals = []
+  for i, out_shape in enumerate(out_shapes):
+    if i in oi_map:
+      output_vals.append(input_args[oi_map[i]])
+    else:
+      output_vals.append(uninitialized_value(out_shape.shape, out_shape.dtype))
+  return output_vals
+
 def _logical_to_interpret_mode_dtype(dtype):
   """Converts logical dtypes into JAX dtypes for interpret mode.
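
To make the aliasing convention concrete: `input_output_aliases` holds `(input_index, output_index)` pairs, and `oi_map` inverts them so each aliased output starts out as the corresponding input's value. A stand-alone sketch of that logic (hypothetical shapes; `jnp.zeros` stands in for `uninitialized_value`):

    import jax
    import jax.numpy as jnp

    def initialize_output_vals(out_shapes, input_args, input_output_aliases):
      oi_map = {v: k for k, v in input_output_aliases}  # output index -> input index
      return [input_args[oi_map[i]] if i in oi_map
              else jnp.zeros(s.shape, s.dtype)  # stand-in for uninitialized_value
              for i, s in enumerate(out_shapes)]

    outs = initialize_output_vals(
        out_shapes=[jax.ShapeDtypeStruct((2,), jnp.float32)] * 2,
        input_args=[jnp.ones(2, jnp.float32)],
        input_output_aliases=((0, 1),))  # input 0 aliases output 1
    # outs[0] is a fresh buffer; outs[1] starts as input 0's value.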
@@ -171,13 +187,7 @@ def _pallas_call_impl(*args, jaxpr, name, out_shapes, which_linear,
   discharged_jaxpr, consts = state_discharge.discharge_state(jaxpr, ())
   if debug:
     print(discharged_jaxpr)
-  oi_map = {v: k for k, v in input_output_aliases}
-  out = []
-  for i, out_shape in enumerate(out_shapes):
-    if i in oi_map:
-      out.append(args[oi_map[i]])
-    else:
-      out.append(uninitialized_value(out_shape.shape, out_shape.dtype))
+  out = _initialize_output_vals(out_shapes, args, input_output_aliases)
   scalars, args = split_list(args, [grid_mapping.num_index_operands])  # type: ignore
   # invars: [*scalar_prefetch, *inputs, *outputs, *scratch]
   num_invars = len(jaxpr.invars)
@@ -190,12 +200,7 @@ def _pallas_call_impl(*args, jaxpr, name, out_shapes, which_linear,
       jaxpr.invars, [grid_mapping.num_index_operands, num_inputs_outputs]
   )
   scratch_avals = [v.aval for v in scratch_invars]
-  if not all(
-      hasattr(a, "shape") and hasattr(a, "dtype") for a in scratch_avals
-  ):
-    raise NotImplementedError(f"Cannot initialize scratch: {scratch_avals}")
-  scratch_values = [uninitialized_value(a.shape, a.dtype)
-                    for a in scratch_avals]
+  scratch_values = _initialize_scratch_vals(scratch_avals)

   carry = []
   for x, bm in zip(itertools.chain(args, out), grid_mapping.block_mappings):
@@ -729,6 +734,168 @@ def _hoist_consts_to_refs(jaxpr: jax_core.Jaxpr) -> jax_core.Jaxpr:
   assert not consts, "All consts should have been converted to refs"
   return hoisted_jaxpr

+
+def checkify_pallas_kernel_body_jaxpr(
+    body_jaxpr: jax_core.ClosedJaxpr,
+    enabled_errors,
+    error: checkify.Error,
+    grid_mapping: GridMapping) -> tuple[
+        jax_core.ClosedJaxpr, tree_util.PyTreeDef, set[checkify.ErrorEffect]]:
+  err_vals, err_tree = tree_util.tree_flatten(error)
+  err_vals = map(checkify.get_shaped_aval, err_vals)
+  flat_err_and_in_vals = [*err_vals, *body_jaxpr.in_avals]
+
+  with pallas_core.tracing_grid_env(grid_mapping.grid, ()):
+    checked_jaxpr, out_tree, error_effects = checkify.jaxpr_to_checkify_jaxpr(
+        body_jaxpr, enabled_errors, err_tree, *flat_err_and_in_vals)
+  return checked_jaxpr, out_tree, error_effects
+
+def pallas_call_checkify_rule(error: checkify.Error,
+                              enabled_errors,
+                              *args: jax_core.Value,
+                              jaxpr: jax_core.Jaxpr,
+                              interpret: bool,
+                              input_output_aliases: tuple[tuple[int, int], ...],
+                              grid_mapping: GridMapping,
+                              out_shapes,
+                              **kwargs):
+  # TODO(b/346651778): Support TPU/GPU checkify.
+  if not interpret:
+    raise NotImplementedError(
+        "Checkify for pallas_call only supports interpret mode.")
+  # We implement the checkify rule in 4 steps:
+  # 1) First, trace the kernel body to get the expected error shapes.
+  # 2) Checkify the kernel body to obtain a jaxpr with errors as inputs
+  #    and outputs.
+  # 3) Create a new kernel which stores the errors in output memrefs instead of
+  #    returning them, since pallas kernels do not return outputs.
+  # 4) Create block specs for the error state and call pallas_call with
+  #    the new kernel.
+  dynamic_grid_bounds, scalars, args = split_list(  # type: ignore
+      args, [grid_mapping.num_dynamic_grid_bounds, grid_mapping.num_index_operands]
+  )
+  num_scalars = len(scalars)
+  num_invars = len(jaxpr.invars)
+  num_inputs_outputs = (
+      num_invars
+      - grid_mapping.num_index_operands
+      - grid_mapping.num_scratch_operands
+  )
+  num_kernel_inputs = len(args)
+  num_scratch = num_invars - num_inputs_outputs
+  num_kernel_outputs = num_invars - num_scratch - num_kernel_inputs
+
+  # Trace the jaxpr to get an initial error value so the kernel jaxpr has all of
+  # the required inputs.
+  closed_jaxpr = pe.close_jaxpr(jaxpr)
+  _jaxpr, _, error_effects = checkify_pallas_kernel_body_jaxpr(
+      closed_jaxpr, enabled_errors, error, grid_mapping)
+  error = error._add_placeholder_effects(error_effects)
+  err_vals, err_tree = jax.tree.flatten(error)
+  shaped_err_avals = map(checkify.get_shaped_aval, err_vals)
+
+  # Trace the kernel jaxpr to get a checkified jaxpr. This jaxpr will have
+  # all enabled errors removed, but have the error as inputs and return values.
+  input_avals = [v.aval for v in jaxpr.invars]
+  num_err_vals = len(err_vals)
+  shaped_input_avals = tuple(jax_core.raise_to_shaped(x) for x in input_avals)
+  checkify_in_avals = [*shaped_err_avals,
+                       *shaped_input_avals]
+  closed_kernel_jaxpr = pe.close_jaxpr(jaxpr)
+  with pallas_core.tracing_grid_env(grid_mapping.grid, ()):
+    checked_jaxpr, out_tree, _ = checkify.jaxpr_to_checkify_jaxpr(
+        closed_kernel_jaxpr, enabled_errors, err_tree, *checkify_in_avals)
+
+  # Create a new kernel to remove the errors as return values and instead
+  # write them to a memref. This is because pallas kernels are expected
+  # to have no return values but instead write their outputs to a ref.
+  def checked_kernel_fn(*args):
+    (scalars, _, inputs, out_error_refs, outputs, scratch
+     ) = split_list(
+         args,
+         [num_scalars, num_err_vals,
+          num_kernel_inputs, num_err_vals, num_kernel_outputs])
+    input_error_vals = [err_ref[...] for err_ref in out_error_refs]
+    # We need to re-order the inputs here. A checkified jaxpr always expects
+    # errors before other arguments.
+    jaxpr_args = [*input_error_vals, *scalars, *inputs, *outputs, *scratch]
+    assert len(checked_jaxpr.jaxpr.invars) == len(jaxpr_args)
+    result_flat = jax.core.eval_jaxpr(
+        checked_jaxpr.jaxpr, checked_jaxpr.consts, *jaxpr_args)
+    output_errors, _ = split_list(result_flat, [num_err_vals])
+    # Store new errors back in the error refs.
+    for out_ref, error in zip(out_error_refs, output_errors):
+      out_ref[...] = error
+    return []
+
+  # Trace the new checked_kernel_fn with Memref inputs so that
+  # we can replace the old kernel jaxpr with the new checked jaxpr in
+  # pallas_call.
+  # TODO(justinfu): Place errors in scalar memory for non-interpret mode.
+  error_mem_space = None
+  error_memref_aval = [pallas_core.AbstractMemoryRef(
+      err_val, error_mem_space) for err_val in shaped_err_avals]
+  shaped_scalar_avals, input_aval, output_aval, scratch_aval = split_list(
+      shaped_input_avals, [num_scalars, num_kernel_inputs, num_kernel_outputs])
+  retrace_in_avals = [*shaped_scalar_avals, *error_memref_aval, *input_aval,
+                      *error_memref_aval, *output_aval, *scratch_aval]
+  jaxpr_flat_avals, jaxpr_in_tree = tree_util.tree_flatten(retrace_in_avals)
+  wrapped_kernel_with_err, out_tree_thunk = api_util.flatten_fun_nokwargs(
+      lu.wrap_init(checked_kernel_fn), jaxpr_in_tree)
+  debug = pe.debug_info(
+      checked_kernel_fn, jaxpr_in_tree, out_tree_thunk, False, "checkify_pallas")
+  with pallas_core.tracing_grid_env(grid_mapping.grid, ()):
+    final_jaxpr, _, _, () = pe.trace_to_jaxpr_dynamic(
+        wrapped_kernel_with_err, jaxpr_flat_avals, debug)
+
+  # Prepare pallas_call inputs. We need to create new block specs
+  # for the new error inputs and outputs.
+  scalar_avals = map(checkify.get_shaped_aval, scalars)
+  error_block_specs = [no_block_spec] * num_err_vals
+  grid_avals = [
+      jax_core.ShapedArray((), jnp.dtype("int32"))] * len(grid_mapping.grid)
+  # TODO(justinfu): Place these in device-specific scalar memory.
+  scalar_ref_avals = [
+      pallas_core.AbstractMemoryRef(
+          jax_core.ShapedArray(aval.shape, aval.dtype), None)
+      for aval in scalar_avals]
+  grid_tree = tree_util.tree_structure(((*grid_avals, *scalar_avals), {}))
+  error_block_mappings = map(
+      partial(
+          pallas_core._convert_block_spec_to_block_mapping,
+          (*grid_avals, *scalar_ref_avals),
+          in_tree=grid_tree,
+          grid=grid_mapping.grid,
+          mapped_dims=grid_mapping.mapped_dims),
+      error_block_specs, error_memref_aval)
+  input_block_mappings, output_block_mappings = split_list(
+      grid_mapping.block_mappings, [num_kernel_inputs,])
+  grid_mapping_with_error = grid_mapping.replace(
+      block_mappings=(*error_block_mappings, *input_block_mappings,
+                      *error_block_mappings, *output_block_mappings)
+  )
+  error_out_shapes = tuple(
+      jax.ShapeDtypeStruct(e.shape, e.dtype) for e in shaped_err_avals)
+  # Bump all input_output_aliases by num_err_vals to make room for the errors.
+  # TODO(justinfu): Don't bump scalars here.
+  input_output_aliases = tuple(
+      (i+num_err_vals, o+num_err_vals) for (i, o) in input_output_aliases)
+  input_output_aliases_with_error = tuple(
+      (i+num_scalars, i) for i in range(num_err_vals)) + input_output_aliases
+
+  new_vals_in = [*scalars, *err_vals, *args]
+  result = pallas_call_p.bind(*dynamic_grid_bounds, *new_vals_in,
+                              jaxpr=final_jaxpr,
+                              interpret=interpret,
+                              grid_mapping=grid_mapping_with_error,
+                              input_output_aliases=input_output_aliases_with_error,
+                              out_shapes=error_out_shapes + out_shapes,
+                              **kwargs)
+  errors, results = split_list(result, [num_err_vals])
+  new_error, _ = jax.tree.unflatten(out_tree, errors)
+  return new_error, results
+checkify.error_checks[pallas_call_p] = pallas_call_checkify_rule
+
 @weakref_lru_cache
 def _trace_to_jaxpr(fun: Callable, grid_spec: GridSpec, flat_in_avals,
                     flat_out_avals, in_tree, out_tree, interpret: bool):
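
Step 2 above relies on checkify's standard functionalization: errors become explicit values that flow in and out of the jaxpr. A minimal sketch with the public API (illustrative, not part of this diff) of the value-level behavior that the rule then routes through the aliased error memrefs:

    import jax.numpy as jnp
    from jax.experimental import checkify

    def f(x):
      checkify.check(x > 0, "x must be positive, got {x}", x=x)
      return jnp.log(x)

    checked = checkify.checkify(f, errors=checkify.user_checks)
    err, y = checked(jnp.float32(-1.0))  # the error is returned as a value, not raised
    # err.throw() would raise checkify.JaxRuntimeError here; the pallas rule instead
    # writes the same error value into an output ref aliased with an input ref.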
@@ -23,6 +23,7 @@ from absl.testing import absltest
 from absl.testing import parameterized
 import jax
 from jax import lax
+from jax._src import checkify
 from jax._src import state
 from jax._src import test_util as jtu
 from jax._src.interpreters import partial_eval as pe
@@ -2481,5 +2482,108 @@ class PallasCallTraceTest(PallasTPUTest):
     self.assertEqual(num_start, 2)
     self.assertEqual(num_stop, 2)

+
+class PallasCallTPUCheckifyTest(PallasTPUTest):
+  interpret: bool = True
+
+  @parameterized.parameters((2,), (5,), (6,), (7,))
+  def test_checkify_with_scalar_prefetch(self, threshold):
+    def body(scalar_ref, x_ref, o_ref):
+      scalar = scalar_ref[pl.program_id(0)]
+      o_ref[...] = x_ref[...]
+      checkify.check(scalar < threshold, 'failed on value {x}', x=scalar)
+
+    s = jnp.array([4, 3, 2, 6, 3, 5, 2, 7], jnp.int32)
+    x = jnp.arange(8 * 8 * 128, dtype=jnp.int32).reshape((8 * 8, 128))
+
+    def _x_transform(i, s_ref):
+      s = pl.load(s_ref, (i,))
+      return (s, 0)
+
+    pallas_call = self.pallas_call(
+        body,
+        out_shape=jax.ShapeDtypeStruct(x.shape, jnp.int32),
+        grid_spec=pltpu.PrefetchScalarGridSpec(
+            num_scalar_prefetch=1,
+            in_specs=[
+                pl.BlockSpec(_x_transform, (x.shape[0] // 8, x.shape[1])),
+            ],
+            out_specs=pl.BlockSpec(lambda i, _: (i, 0),
+                                   (x.shape[0] // 8, x.shape[1])),
+            grid=8,
+        ),
+    )
+    checked_call = checkify.checkify(pallas_call)
+    err, out = checked_call(s, x)
+    expected_error_value = s[jnp.argmax(s >= threshold)]
+    with self.assertRaisesRegex(
+        checkify.JaxRuntimeError, f'failed on value {expected_error_value}'):
+      err.throw()
+    np.testing.assert_allclose(out, x.reshape((8, 8, -1))[s].reshape(x.shape))
+
+  def test_checkify_with_scratch(self):
+    def body(x_ref, o_ref, scratch_ref):
+      scratch_ref[...] = x_ref[...]
+      o_ref[...] = scratch_ref[...]
+      all_nequal = ~jnp.all(o_ref[...] == x_ref[...])
+      checkify.check(all_nequal, 'x_ref equals o_ref id=({x}, {y})',
+                     x=pl.program_id(0), y=pl.program_id(1))
+
+    x = jax.random.uniform(jax.random.key(0), (128, 128), dtype=jnp.float32)
+    pallas_call = self.pallas_call(
+        body,
+        out_shape=jax.ShapeDtypeStruct(x.shape, jnp.float32),
+        grid_spec=pltpu.PrefetchScalarGridSpec(
+            num_scalar_prefetch=0,
+            in_specs=[
+                pl.BlockSpec(lambda i, j: (i, j), (32, 32)),
+            ],
+            out_specs=pl.BlockSpec(lambda i, j: (i, j), (32, 32)),
+            scratch_shapes=[pltpu.VMEM((32, 32), dtype=jnp.float32)],
+            grid=(4, 4),
+        ),
+    )
+    checked_call = checkify.checkify(pallas_call)
+    err, out = checked_call(x)
+    with self.assertRaisesRegex(
+        checkify.JaxRuntimeError, r'x_ref equals o_ref id=\(0, 0\)'):
+      err.throw()
+    np.testing.assert_allclose(out, x)
+
+  @parameterized.parameters((4,), (9,))
+  def test_checkify_with_dynamic_grid(self, iteration):
+    grid_size = 4
+    shape = (8, 128)
+    result_ty = jax.ShapeDtypeStruct(shape, jnp.float32)
+
+    def kernel(y_ref):
+      @pl.when(pl.program_id(0) == 0)
+      def _init():
+        y_ref[...] = jnp.zeros_like(y_ref)
+      y_ref[...] += 1
+      @pl.when(pl.program_id(0) == iteration)
+      def _():
+        checkify.check(False, f"error on iteration {iteration}")
+
+    @jax.jit
+    def dynamic_kernel(steps):
+      pallas_call = self.pallas_call(
+          kernel,
+          grid=(steps * 2,),
+          out_specs=pl.BlockSpec(lambda i: (0, 0), shape),
+          out_shape=result_ty,
+      )
+      return checkify.checkify(pallas_call)()
+
+    err, result = dynamic_kernel(jnp.int32(grid_size))
+    if iteration < grid_size * 2:
+      with self.assertRaisesRegex(
+          checkify.JaxRuntimeError, f"error on iteration {iteration}"):
+        err.throw()
+    np.testing.assert_array_equal(
+        result, np.full(shape, grid_size * 2.0, np.float32)
+    )
+
+
 if __name__ == '__main__':
   absltest.main(testLoader=jtu.JaxTestLoader())
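
As a sanity check on `expected_error_value` in the scalar-prefetch test above: checkify keeps the first error in grid order, so for `threshold=5` and `s = [4, 3, 2, 6, 3, 5, 2, 7]` the first scalar satisfying `s >= 5` is `s[3] = 6`, and the reported message is "failed on value 6":

    import jax.numpy as jnp

    s = jnp.array([4, 3, 2, 6, 3, 5, 2, 7], jnp.int32)
    print(s[jnp.argmax(s >= 5)])  # 6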
@@ -25,6 +25,7 @@ from absl.testing import parameterized
 import jax
 from jax import lax
 from jax import random
+from jax._src import checkify
 from jax._src import config
 from jax._src import linear_util as lu
 from jax._src import state
@@ -2260,5 +2261,114 @@ class PallasInterpretModeOutOfBoundsTest(PallasTest):
     np.testing.assert_allclose(out, expected, atol=atol)


+class PallasCheckifyTest(PallasTest):
+  # TODO(b/346651778): Support non-interpret mode checkify.
+  INTERPRET: bool = True
+
+  def test_no_checkify(self):
+    def kernel(y_ref):
+      y_ref[...] = jnp.zeros_like(y_ref[...])
+    out_shape = jax.ShapeDtypeStruct((2, 2), jnp.float32)
+    pallas_call = self.pallas_call(kernel,
+                                   out_shape=out_shape)
+    checked_call = checkify.checkify(pallas_call)
+    err, result = checked_call()
+    err.throw()  # Should not raise.
+    np.testing.assert_allclose(result, jnp.zeros_like(result))
+
+  def test_does_not_clobber_previous_error(self):
+    def kernel(y_ref):
+      y_ref[...] = jnp.zeros_like(y_ref[...])
+      checkify.check(False, "error in kernel")
+    out_shape = jax.ShapeDtypeStruct((2, 2), jnp.float32)
+    pallas_call = self.pallas_call(kernel,
+                                   out_shape=out_shape)
+    def error_before_call():
+      checkify.check(False, "error before call")
+      return pallas_call()
+    checked_call = checkify.checkify(error_before_call)
+    err, result = checked_call()
+    with self.assertRaisesRegex(
+        checkify.JaxRuntimeError, "error before call"):
+      err.throw()
+    np.testing.assert_allclose(result, jnp.zeros_like(result))
+
+  @parameterized.parameters((False,), (True,))
+  def test_trivial_check(self, assert_cond):
+    def kernel(x_ref, y_ref):
+      y_ref[...] = x_ref[...]
+      checkify.check(assert_cond, "pallas check failed")
+    input = jnp.arange(4, dtype=jnp.int32)
+    out_shape = jax.ShapeDtypeStruct(input.shape, input.dtype)
+    pallas_call = self.pallas_call(kernel,
+                                   out_shape=out_shape)
+    checked_call = checkify.checkify(pallas_call)
+    err, result = checked_call(input)
+    if not assert_cond:
+      with self.assertRaisesRegex(
+          checkify.JaxRuntimeError, "pallas check failed"):
+        err.throw()
+    np.testing.assert_allclose(result, input)
+
+  def test_nan_error(self):
+    def kernel(x_ref, y_ref):
+      y_ref[...] = jnp.log(x_ref[...])
+    input = jnp.arange(4, dtype=jnp.float32) - 2
+    out_shape = jax.ShapeDtypeStruct(input.shape, input.dtype)
+    pallas_call = self.pallas_call(kernel,
+                                   out_shape=out_shape)
+    checked_call = checkify.checkify(pallas_call,
+                                     errors=checkify.all_checks)
+    err, result = checked_call(input)
+    with self.assertRaisesRegex(
+        checkify.JaxRuntimeError, "nan generated by primitive: log"):
+      err.throw()
+    is_nan = jnp.isnan(result)
+    np.testing.assert_allclose(is_nan, input < 0)
+
+  def test_nan_error_with_assertion(self):
+    # TODO(b/346842088): Fix check asserts clobbering other errors.
+    self.skipTest('Known failure.')
+    # Test NaN error is not clobbered by an assertion failure.
+    def kernel(x_ref, y_ref):
+      y_ref[...] = jnp.log(x_ref[...])
+      checkify.check(False, "do not raise")
+    input = jnp.arange(4, dtype=jnp.float32) - 10
+    out_shape = jax.ShapeDtypeStruct(input.shape, input.dtype)
+    pallas_call = self.pallas_call(kernel,
+                                   out_shape=out_shape)
+    checked_call = checkify.checkify(pallas_call,
+                                     errors=checkify.all_checks)
+    err, _ = checked_call(input)
+    with self.assertRaisesRegex(
+        checkify.JaxRuntimeError, "nan generated by primitive: log"):
+      err.throw()
+
+  @parameterized.parameters((5, 0), (8, 3), (4, 3))
+  def test_checkify_returns_first_error_in_grid(
+      self, num_loops, fail_iteration):
+    # Check that checkify returns the first error that occurs.
+    # TODO(justinfu): This test doesn't make sense on GPU, where threads run
+    # in parallel. Update checkify to return a grid of errors.
+    def kernel(x_ref, _):
+      value = jnp.squeeze(x_ref[...])
+      checkify.check(
+          value < fail_iteration, "failed on loop {itr}", itr=value)
+    input_arr = jnp.arange(num_loops, dtype=jnp.float32)
+    in_specs = [pl.BlockSpec(lambda x: (x,), (1,))]
+    out_shape = jax.ShapeDtypeStruct((1,), dtype=jnp.float32)
+    pallas_call = self.pallas_call(kernel,
+                                   grid=(num_loops,),
+                                   in_specs=in_specs,
+                                   out_shape=out_shape)
+
+    checked_call = checkify.checkify(pallas_call,
+                                     errors=checkify.all_checks)
+    err, _ = checked_call(input_arr)
+    with self.assertRaisesRegex(
+        checkify.JaxRuntimeError, f"failed on loop {fail_iteration}"):
+      err.throw()
+
+
 if __name__ == "__main__":
   absltest.main()
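
The NaN tests rely on checkify's automatic float checks (enabled via `errors=checkify.all_checks`) rather than on explicit `checkify.check` calls; outside a kernel the same mechanism looks like this (a minimal sketch, assuming the `checkify.float_checks` error set):

    import jax.numpy as jnp
    from jax.experimental import checkify

    err, y = checkify.checkify(jnp.log, errors=checkify.float_checks)(jnp.float32(-1.0))
    # err.throw() raises checkify.JaxRuntimeError: nan generated by primitive: log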