mirror of
https://github.com/ROCm/jax.git
synced 2025-04-15 19:36:06 +00:00
[Pallas] Add support for runtime checking of grid bounds using checkify.
PiperOrigin-RevId: 683791662
This commit is contained in:
parent
9748e2ab1a
commit
9cf952a535
@ -993,6 +993,85 @@ def checkify_pallas_kernel_body_jaxpr(
|
||||
body_jaxpr, enabled_errors, err_tree, *flat_err_and_in_vals)
|
||||
return checked_jaxpr, out_tree, error_effects
|
||||
|
||||
def pallas_call_checkify_oob_grid(error: checkify.Error,
                                  enabled_errors,
                                  args: tuple[jax_core.Value, ...],
                                  grid_mapping: GridMapping,
                                  input_output_aliases) -> checkify.Error:
  """Checks the block accessed at every grid step for out-of-bounds indexing.

  Traces a `lax.while_loop` that walks all grid iterations, computes the start
  indices of every input/output block at each step and dynamically slices the
  corresponding operand. The traced loop is then run through
  `checkify.checkify_jaxpr` with `checkify.index_checks`, so any slice that
  falls outside its operand is recorded in the returned error.

  Args:
    error: The error accumulated so far; threaded through the check.
    enabled_errors: The set of error categories enabled by the caller. When
      `checkify.OOBError` is not among them, no check is performed.
    args: The flat `pallas_call` operands: dynamic grid bounds first, followed
      by the index (scalar) operands, the inputs and any remaining operands.
    grid_mapping: The grid and block mappings of the `pallas_call`.
    input_output_aliases: Forwarded to `_initialize_output_vals` when building
      the output operands.

  Returns:
    `error` unchanged when OOB checks are disabled, otherwise the error
    produced by checkifying the traced grid loop.
  """
  if checkify.OOBError not in enabled_errors:
    return error

  dynamic_grid_args, args = split_list(
      args, [grid_mapping.num_dynamic_grid_bounds]
  )
  output_args = _initialize_output_vals(grid_mapping.block_mappings_output,
                                        args, input_output_aliases)
  scalars, input_args, _ = split_list(
      args, [grid_mapping.num_index_operands,
             grid_mapping.num_inputs],
  )

  # Substitute the runtime values for every dynamic grid dimension, in order.
  dynamic_grid_args_iter = iter(dynamic_grid_args)
  grid = tuple(
      a if a is not pallas_core.dynamic_grid_dim
      else next(dynamic_grid_args_iter)
      for a in grid_mapping.grid
  )
  grid_start_indices = (jnp.int32(0),) * len(grid)
  if grid:
    num_iterations = reduce(jnp.multiply, grid)
  else:
    # Base case is always one iteration when grid is ()
    num_iterations = 1

  # For each block mapping, flag the dimensions marked as `mapped`; those are
  # given extent 1 in the block shape used for slicing.
  is_indexing_dim = [
      tuple(b is pallas_core.mapped for b in bm.block_shape)
      for bm in grid_mapping.block_mappings
  ]
  block_shapes = [
      tuple(1 if i else b for i, b in zip(iid, bm.block_shape))
      for iid, bm in zip(is_indexing_dim, grid_mapping.block_mappings)
  ]
  # Operands that get sliced, in the same order as the block mappings.
  operands = [*input_args, *output_args]

  # The loop carry: (i, loop_idx)
  # i:int32 is the iteration index
  # loop_idx: tuple[int32] are the program ids for each grid axis
  def cond(carry):
    i, *_ = carry
    return i < num_iterations

  def body(carry):
    i, loop_idx = carry
    if grid_mapping.local_grid_env is not None:
      local_grid_env = grid_mapping.local_grid_env(loop_idx, grid)
    else:
      local_grid_env = tuple(
          pallas_core.GridAxis(idx, b)
          for dim, (idx, b) in enumerate(zip(loop_idx, grid))
          if dim not in grid_mapping.vmapped_dims
      )
    with pallas_core.grid_env(local_grid_env):
      start_indices = [
          bm.compute_start_indices_interpret(loop_idx, *scalars)
          for bm in grid_mapping.block_mappings
      ]
    # We perform a dynamic slice on the i/o blocks, which will be checked by
    # checkify for OOB accesses. The results are discarded on purpose; an
    # explicit loop guarantees that every slice is traced even if `map` is
    # lazy in this module.
    for start_idx, block_shape, operand, indexing in zip(
        start_indices, block_shapes, operands, is_indexing_dim):
      _maybe_dynamic_slice(start_idx, block_shape, operand, indexing)
    return (i + 1, _get_next_indices(grid, loop_idx))

  def f(_):
    # Only the checks recorded while tracing matter; the loop result is unused.
    lax.while_loop(
        cond, body, (jnp.int32(0), grid_start_indices)
    )

  # `f` takes a single dummy int32 so that it can be traced to a jaxpr.
  flat_args, jaxpr_in_tree = jax.tree_util.tree_flatten((jnp.int32(0),))
  wrapped_loop, _ = api_util.flatten_fun_nokwargs(
      lu.wrap_init(f), jaxpr_in_tree)
  with pallas_core.tracing_grid_env(grid_mapping.grid, ()):
    avals_in = [jax_core.get_aval(x) for x in flat_args]
    traced_loop, _, consts, () = pe.trace_to_jaxpr_dynamic(
        wrapped_loop, avals_in)
    traced_loop = jax_core.ClosedJaxpr(traced_loop, consts)
  # Unpack the flat arguments so each jaxpr input receives one value rather
  # than the whole list being bound to the first input.
  out_error, _ = checkify.checkify_jaxpr(
      traced_loop, checkify.index_checks, error, *flat_args)
  return out_error
|
||||
|
||||
def pallas_call_checkify_rule(error: checkify.Error,
|
||||
enabled_errors,
|
||||
*args: jax_core.Value,
|
||||
@ -1002,6 +1081,10 @@ def pallas_call_checkify_rule(error: checkify.Error,
|
||||
grid_mapping: GridMapping,
|
||||
out_avals: tuple[jax_core.AbstractValue, ...],
|
||||
**kwargs):
|
||||
# Check for OOB accesses in the grid.
|
||||
error = pallas_call_checkify_oob_grid(error, enabled_errors,
|
||||
args, grid_mapping,
|
||||
input_output_aliases)
|
||||
# We implement the checkify rule in 4 steps:
|
||||
# 1) First, trace the kernel body to get the expected error shapes.
|
||||
# 2) Checkify the kernel body to obtain a jaxpr with errors as inputs
|
||||
|
@ -2033,11 +2033,12 @@ class PallasOutOfBoundsInterpretTest(PallasBaseTest):
|
||||
np.testing.assert_allclose(out, expected, atol=atol)
|
||||
|
||||
|
||||
class PallasCheckifyInterpretTest(PallasBaseTest):
|
||||
# TODO(b/346651778): Support non-interpret mode checkify.
|
||||
INTERPRET = True
|
||||
class PallasCheckifyTest(PallasBaseTest):
|
||||
INTERPRET = False
|
||||
|
||||
def test_no_checkify(self,):
|
||||
if jtu.test_device_matches(["gpu"]):
|
||||
self.skipTest("Not supported on GPU.")
|
||||
def kernel(y_ref):
|
||||
y_ref[...] = jnp.zeros_like(y_ref[...])
|
||||
out_shape = jax.ShapeDtypeStruct((2, 2), jnp.float32)
|
||||
@ -2049,6 +2050,8 @@ class PallasCheckifyInterpretTest(PallasBaseTest):
|
||||
np.testing.assert_allclose(result, jnp.zeros_like(result))
|
||||
|
||||
def test_does_not_clobber_previous_error(self,):
|
||||
if jtu.test_device_matches(["gpu"]):
|
||||
self.skipTest("Not supported on GPU.")
|
||||
def kernel(y_ref):
|
||||
y_ref[...] = jnp.zeros_like(y_ref[...])
|
||||
checkify.check(False, "error in kernel")
|
||||
@ -2067,6 +2070,8 @@ class PallasCheckifyInterpretTest(PallasBaseTest):
|
||||
|
||||
@parameterized.parameters((False,), (True,))
|
||||
def test_trivial_check(self, assert_cond):
|
||||
if jtu.test_device_matches(["gpu"]):
|
||||
self.skipTest("Not supported on GPU.")
|
||||
def kernel(x_ref, y_ref):
|
||||
y_ref[...] = x_ref[...]
|
||||
checkify.check(assert_cond, "pallas check failed")
|
||||
@ -2083,6 +2088,8 @@ class PallasCheckifyInterpretTest(PallasBaseTest):
|
||||
np.testing.assert_allclose(result, input)
|
||||
|
||||
def test_nan_error(self):
|
||||
if not self.INTERPRET:
|
||||
self.skipTest("Not supported in non-interpret mode.")
|
||||
def kernel(x_ref, y_ref):
|
||||
y_ref[...] = jnp.log(x_ref[...])
|
||||
input = jnp.arange(4, dtype=jnp.float32) - 2
|
||||
@ -2090,7 +2097,7 @@ class PallasCheckifyInterpretTest(PallasBaseTest):
|
||||
pallas_call = self.pallas_call(kernel,
|
||||
out_shape=out_shape)
|
||||
checked_call = checkify.checkify(pallas_call,
|
||||
errors=checkify.all_checks)
|
||||
errors=checkify.nan_checks)
|
||||
err, result = checked_call(input)
|
||||
with self.assertRaisesRegex(
|
||||
checkify.JaxRuntimeError, "nan generated by primitive: log"):
|
||||
@ -2119,6 +2126,8 @@ class PallasCheckifyInterpretTest(PallasBaseTest):
|
||||
@parameterized.parameters((5, 0), (8, 3), (4, 3))
|
||||
def test_checkify_returns_first_error_in_grid(
|
||||
self, num_loops, fail_iteration):
|
||||
if not self.INTERPRET:
|
||||
self.skipTest("Not supported in non-interpret mode.")
|
||||
# Check that checkify returns the first error that occurs
|
||||
# TODO(justinfu): This test doesn't make sense on GPU, where threads run
|
||||
# in parallel. Update checkify to return a grid of errors.
|
||||
@ -2137,12 +2146,42 @@ class PallasCheckifyInterpretTest(PallasBaseTest):
|
||||
out_shape=out_shape)
|
||||
|
||||
checked_call = checkify.checkify(pallas_call,
|
||||
errors=checkify.all_checks)
|
||||
errors=checkify.user_checks)
|
||||
err, _ = checked_call(input_arr)
|
||||
with self.assertRaisesRegex(
|
||||
checkify.JaxRuntimeError, f"failed on loop {fail_iteration}"):
|
||||
err.throw()
|
||||
|
||||
def test_checkify_on_oob_grid_access(self):
  """An 18-element array tiled in blocks of 8 over 3 grid steps reports OOB.

  The last block starts at index 16, which the index check flags; the copied
  output must still equal the input.
  """
  if not self.INTERPRET:
    self.skipTest("Not supported in non-interpret mode.")
  if config.enable_x64.value:
    self.skipTest("Not supported in x64 mode.")

  def kernel(x_ref, o_ref):
    # Plain copy of the current input block to the output block.
    o_ref[...] = x_ref[...]

  num_elements = 18
  block_size = 8
  input_arr = jnp.arange(num_elements, dtype=jnp.float32)
  pallas_call = self.pallas_call(
      kernel,
      grid=(3,),
      in_specs=[pl.BlockSpec((block_size,), lambda x: (x,))],
      out_specs=pl.BlockSpec((block_size,), lambda x: (x,)),
      out_shape=jax.ShapeDtypeStruct((num_elements,), dtype=jnp.float32),
  )
  checked_call = checkify.checkify(pallas_call,
                                   errors=checkify.index_checks)
  err, result = checked_call(input_arr)

  expected_message = (
      r"out-of-bounds indexing for array of shape \(18,\): index 16 "
      r"is out of bounds for axis 0 with size 18")
  with self.assertRaisesRegex(checkify.JaxRuntimeError, expected_message):
    err.throw()
  np.testing.assert_array_equal(result, input_arr)
|
||||
|
||||
|
||||
class PallasCheckifyInterpretTest(PallasCheckifyTest):
  """Runs the `PallasCheckifyTest` cases with `INTERPRET` set to True."""

  INTERPRET = True
|
||||
|
||||
|
||||
class PallasCallNamedGridTest(PallasBaseTest):
|
||||
def test_named_grid(self):
|
||||
|
Loading…
x
Reference in New Issue
Block a user