From fb68f3449b2fc89db728fb1a69ccbc4125ce8e54 Mon Sep 17 00:00:00 2001
From: Justin Fu
Date: Mon, 17 Jun 2024 17:14:58 -0700
Subject: [PATCH] [Pallas] Add checkify support for pallas_call in interpret
 mode.

PiperOrigin-RevId: 644181742
---
 jax/_src/pallas/pallas_call.py       | 193 +++++++++++++++++++++++++--
 tests/pallas/pallas_call_tpu_test.py | 104 +++++++++++++++
 tests/pallas/pallas_test.py          | 110 +++++++++++++++
 3 files changed, 394 insertions(+), 13 deletions(-)
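
Reviewer note: a minimal sketch of the user-facing behavior this change
enables (the kernel and shapes below are illustrative, not part of the
patch; interpret mode only):

    import jax
    import jax.numpy as jnp
    from jax.experimental import checkify
    from jax.experimental import pallas as pl

    def kernel(x_ref, o_ref):
      o_ref[...] = x_ref[...]
      checkify.check(jnp.all(x_ref[...] < 10), "input too large")

    x = jnp.arange(4, dtype=jnp.float32)
    call = pl.pallas_call(
        kernel,
        out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
        interpret=True)
    err, out = checkify.checkify(call)(x)
    err.throw()  # raises checkify.JaxRuntimeError if the check failed
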
diff --git a/jax/_src/pallas/pallas_call.py b/jax/_src/pallas/pallas_call.py
index c5d2c7841..973ca4054 100644
--- a/jax/_src/pallas/pallas_call.py
+++ b/jax/_src/pallas/pallas_call.py
@@ -25,6 +25,7 @@ from jax import api_util
 from jax import lax
 from jax import tree_util
 from jax._src import ad_util
+from jax._src import checkify
 from jax._src import config
 from jax._src import core as jax_core
 from jax._src import effects
@@ -114,6 +115,21 @@ def _pad_values_to_block_dimension(value,
     value = jnp.pad(value, pad_width, constant_values=pad_value)
   return value
 
+def _initialize_scratch_vals(scratch_avals) -> tuple[jax.Array, ...]:
+  scratch_avals = (jax_core.raise_to_shaped(x) for x in scratch_avals)
+  return tuple(uninitialized_value(a.shape, a.dtype) for a in scratch_avals)
+
+def _initialize_output_vals(
+    out_shapes, input_args, input_output_aliases) -> Sequence[jax.Array]:
+  oi_map = {v: k for k, v in input_output_aliases}
+  output_vals = []
+  for i, out_shape in enumerate(out_shapes):
+    if i in oi_map:
+      output_vals.append(input_args[oi_map[i]])
+    else:
+      output_vals.append(uninitialized_value(out_shape.shape, out_shape.dtype))
+  return output_vals
+
 def _logical_to_interpret_mode_dtype(dtype):
   """Converts logical dtypes into JAX dtypes for interpret mode.
@@ -171,13 +187,7 @@ def _pallas_call_impl(*args, jaxpr, name, out_shapes, which_linear,
     discharged_jaxpr, consts = state_discharge.discharge_state(jaxpr, ())
     if debug:
       print(discharged_jaxpr)
-    oi_map = {v: k for k, v in input_output_aliases}
-    out = []
-    for i, out_shape in enumerate(out_shapes):
-      if i in oi_map:
-        out.append(args[oi_map[i]])
-      else:
-        out.append(uninitialized_value(out_shape.shape, out_shape.dtype))
+    out = _initialize_output_vals(out_shapes, args, input_output_aliases)
     scalars, args = split_list(args, [grid_mapping.num_index_operands])  # type: ignore
     # invars: [*scalar_prefetch, *inputs, *outputs, *scratch]
     num_invars = len(jaxpr.invars)
@@ -190,12 +200,7 @@
         jaxpr.invars, [grid_mapping.num_index_operands, num_inputs_outputs]
     )
     scratch_avals = [v.aval for v in scratch_invars]
-    if not all(
-        hasattr(a, "shape") and hasattr(a, "dtype") for a in scratch_avals
-    ):
-      raise NotImplementedError(f"Cannot initialize scratch: {scratch_avals}")
-    scratch_values = [uninitialized_value(a.shape, a.dtype)
-                      for a in scratch_avals]
+    scratch_values = _initialize_scratch_vals(scratch_avals)
 
     carry = []
     for x, bm in zip(itertools.chain(args, out), grid_mapping.block_mappings):
@@ -729,6 +734,168 @@ def _hoist_consts_to_refs(jaxpr: jax_core.Jaxpr) -> jax_core.Jaxpr:
   assert not consts, "All consts should have been converted to refs"
   return hoisted_jaxpr
 
+
+def checkify_pallas_kernel_body_jaxpr(
+    body_jaxpr: jax_core.ClosedJaxpr,
+    enabled_errors,
+    error: checkify.Error,
+    grid_mapping: GridMapping) -> tuple[
+        jax_core.ClosedJaxpr, tree_util.PyTreeDef, set[checkify.ErrorEffect]]:
+  err_vals, err_tree = tree_util.tree_flatten(error)
+  err_vals = map(checkify.get_shaped_aval, err_vals)
+  flat_err_and_in_vals = [*err_vals, *body_jaxpr.in_avals]
+
+  with pallas_core.tracing_grid_env(grid_mapping.grid, ()):
+    checked_jaxpr, out_tree, error_effects = checkify.jaxpr_to_checkify_jaxpr(
+        body_jaxpr, enabled_errors, err_tree, *flat_err_and_in_vals)
+  return checked_jaxpr, out_tree, error_effects
+
+def pallas_call_checkify_rule(error: checkify.Error,
+                              enabled_errors,
+                              *args: jax_core.Value,
+                              jaxpr: jax_core.Jaxpr,
+                              interpret: bool,
+                              input_output_aliases: tuple[tuple[int, int], ...],
+                              grid_mapping: GridMapping,
+                              out_shapes,
+                              **kwargs):
+  # TODO(b/346651778): Support TPU/GPU checkify.
+  if not interpret:
+    raise NotImplementedError(
+        "Checkify for pallas_call only supports interpret mode.")
+  # We implement the checkify rule in 4 steps:
+  # 1) First, trace the kernel body to get the expected error shapes.
+  # 2) Checkify the kernel body to obtain a jaxpr with errors as inputs
+  #    and outputs.
+  # 3) Create a new kernel which stores the errors in output memrefs instead
+  #    of returning them, since pallas kernels do not return outputs.
+  # 4) Create block specs for the error state and call pallas_call with
+  #    the new kernel.
+  dynamic_grid_bounds, scalars, args = split_list(  # type: ignore
+      args, [grid_mapping.num_dynamic_grid_bounds,
+             grid_mapping.num_index_operands]
+  )
+  num_scalars = len(scalars)
+  num_invars = len(jaxpr.invars)
+  num_inputs_outputs = (
+      num_invars
+      - grid_mapping.num_index_operands
+      - grid_mapping.num_scratch_operands
+  )
+  num_kernel_inputs = len(args)
+  num_scratch = num_invars - num_inputs_outputs
+  num_kernel_outputs = num_invars - num_scratch - num_kernel_inputs
+
+  # Trace the jaxpr to get an initial error value so the kernel jaxpr has all
+  # of the required inputs.
+  closed_jaxpr = pe.close_jaxpr(jaxpr)
+  _jaxpr, _, error_effects = checkify_pallas_kernel_body_jaxpr(
+      closed_jaxpr, enabled_errors, error, grid_mapping)
+  error = error._add_placeholder_effects(error_effects)
+  err_vals, err_tree = jax.tree.flatten(error)
+  shaped_err_avals = map(checkify.get_shaped_aval, err_vals)
+
+  # Trace the kernel jaxpr to get a checkified jaxpr. This jaxpr will have
+  # all enabled errors removed, but will take the errors as inputs and
+  # return them as outputs.
+  input_avals = [v.aval for v in jaxpr.invars]
+  num_err_vals = len(err_vals)
+  shaped_input_avals = tuple(jax_core.raise_to_shaped(x) for x in input_avals)
+  checkify_in_avals = [*shaped_err_avals,
+                       *shaped_input_avals]
+  closed_kernel_jaxpr = pe.close_jaxpr(jaxpr)
+  with pallas_core.tracing_grid_env(grid_mapping.grid, ()):
+    checked_jaxpr, out_tree, _ = checkify.jaxpr_to_checkify_jaxpr(
+        closed_kernel_jaxpr, enabled_errors, err_tree, *checkify_in_avals)
+
+  # Create a new kernel to remove the errors as return values and instead
+  # write them to a memref. This is because pallas kernels are expected
+  # to have no return values and to write their outputs to refs.
+  def checked_kernel_fn(*args):
+    (scalars, _, inputs, out_error_refs, outputs, scratch
+     ) = split_list(
+         args,
+         [num_scalars, num_err_vals,
+          num_kernel_inputs, num_err_vals, num_kernel_outputs])
+    input_error_vals = [err_ref[...] for err_ref in out_error_refs]
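+    # The error output refs are aliased to the incoming error values (via
+    # input_output_aliases_with_error below), so reading them here recovers
+    # the current error state, including errors from earlier grid iterations.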
+    # We need to re-order the inputs here. A checkified jaxpr always expects
+    # errors before other arguments.
+    jaxpr_args = [*input_error_vals, *scalars, *inputs, *outputs, *scratch]
+    assert len(checked_jaxpr.jaxpr.invars) == len(jaxpr_args)
+    result_flat = jax.core.eval_jaxpr(
+        checked_jaxpr.jaxpr, checked_jaxpr.consts, *jaxpr_args)
+    output_errors, _ = split_list(result_flat, [num_err_vals])
+    # Store the new errors back in the error refs.
+    for out_ref, error in zip(out_error_refs, output_errors):
+      out_ref[...] = error
+    return []
+
+  # Trace the new checked_kernel_fn with Memref inputs so that
+  # we can replace the old kernel jaxpr with the new checked jaxpr in
+  # pallas_call.
+  # TODO(justinfu): Place errors in scalar memory for non-interpret mode.
+  error_mem_space = None
+  error_memref_aval = [pallas_core.AbstractMemoryRef(
+      err_val, error_mem_space) for err_val in shaped_err_avals]
+  shaped_scalar_avals, input_aval, output_aval, scratch_aval = split_list(
+      shaped_input_avals, [num_scalars, num_kernel_inputs, num_kernel_outputs])
+  retrace_in_avals = [*shaped_scalar_avals, *error_memref_aval, *input_aval,
+                      *error_memref_aval, *output_aval, *scratch_aval]
+  jaxpr_flat_avals, jaxpr_in_tree = tree_util.tree_flatten(retrace_in_avals)
+  wrapped_kernel_with_err, out_tree_thunk = api_util.flatten_fun_nokwargs(
+      lu.wrap_init(checked_kernel_fn), jaxpr_in_tree)
+  debug = pe.debug_info(
+      checked_kernel_fn, jaxpr_in_tree, out_tree_thunk, False,
+      "checkify_pallas")
+  with pallas_core.tracing_grid_env(grid_mapping.grid, ()):
+    final_jaxpr, _, _, () = pe.trace_to_jaxpr_dynamic(
+        wrapped_kernel_with_err, jaxpr_flat_avals, debug)
+
+  # Prepare pallas_call inputs. We need to create new block specs
+  # for the new error inputs and outputs.
+  scalar_avals = map(checkify.get_shaped_aval, scalars)
+  error_block_specs = [no_block_spec] * num_err_vals
+  grid_avals = [
+      jax_core.ShapedArray((), jnp.dtype("int32"))] * len(grid_mapping.grid)
+  # TODO(justinfu): Place these in device-specific scalar memory.
+  scalar_ref_avals = [
+      pallas_core.AbstractMemoryRef(
+          jax_core.ShapedArray(aval.shape, aval.dtype), None)
+      for aval in scalar_avals]
+  grid_tree = tree_util.tree_structure(((*grid_avals, *scalar_avals), {}))
+  error_block_mappings = map(
+      partial(
+          pallas_core._convert_block_spec_to_block_mapping,
+          (*grid_avals, *scalar_ref_avals),
+          in_tree=grid_tree,
+          grid=grid_mapping.grid,
+          mapped_dims=grid_mapping.mapped_dims),
+      error_block_specs, error_memref_aval)
+  input_block_mappings, output_block_mappings = split_list(
+      grid_mapping.block_mappings, [num_kernel_inputs,])
+  grid_mapping_with_error = grid_mapping.replace(
+      block_mappings=(*error_block_mappings, *input_block_mappings,
+                      *error_block_mappings, *output_block_mappings)
+  )
+  error_out_shapes = tuple(
+      jax.ShapeDtypeStruct(e.shape, e.dtype) for e in shaped_err_avals)
+  # Bump all input_output_aliases by num_err_vals to make room for the error
+  # values.
+  # TODO(justinfu): Don't bump scalars here.
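+  # For example, with one scalar and one error value, an existing alias
+  # (1, 0) (first kernel input -> first output) becomes (2, 1), while the
+  # error value at input index 1 is aliased to error output index 0.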
+  input_output_aliases = tuple(
+      (i + num_err_vals, o + num_err_vals) for (i, o) in input_output_aliases)
+  input_output_aliases_with_error = tuple(
+      (i + num_scalars, i) for i in range(num_err_vals)) + input_output_aliases
+
+  new_vals_in = [*scalars, *err_vals, *args]
+  result = pallas_call_p.bind(*dynamic_grid_bounds, *new_vals_in,
+                              jaxpr=final_jaxpr,
+                              interpret=interpret,
+                              grid_mapping=grid_mapping_with_error,
+                              input_output_aliases=input_output_aliases_with_error,
+                              out_shapes=error_out_shapes + out_shapes,
+                              **kwargs)
+  errors, results = split_list(result, [num_err_vals])
+  new_error, _ = jax.tree.unflatten(out_tree, errors)
+  return new_error, results
+checkify.error_checks[pallas_call_p] = pallas_call_checkify_rule
+
 @weakref_lru_cache
 def _trace_to_jaxpr(fun: Callable, grid_spec: GridSpec, flat_in_avals,
                     flat_out_avals, in_tree, out_tree, interpret: bool):
diff --git a/tests/pallas/pallas_call_tpu_test.py b/tests/pallas/pallas_call_tpu_test.py
index 6a30c1d92..62096ebf8 100644
--- a/tests/pallas/pallas_call_tpu_test.py
+++ b/tests/pallas/pallas_call_tpu_test.py
@@ -23,6 +23,7 @@ from absl.testing import absltest
 from absl.testing import parameterized
 import jax
 from jax import lax
+from jax._src import checkify
 from jax._src import state
 from jax._src import test_util as jtu
 from jax._src.interpreters import partial_eval as pe
@@ -2481,5 +2482,108 @@ class PallasCallTraceTest(PallasTPUTest):
     self.assertEqual(num_start, 2)
     self.assertEqual(num_stop, 2)
 
+
+class PallasCallTPUCheckifyTest(PallasTPUTest):
+  interpret: bool = True
+
+  @parameterized.parameters((2,), (5,), (6,), (7,))
+  def test_checkify_with_scalar_prefetch(self, threshold):
+    def body(scalar_ref, x_ref, o_ref):
+      scalar = scalar_ref[pl.program_id(0)]
+      o_ref[...] = x_ref[...]
+      checkify.check(scalar < threshold, 'failed on value {x}', x=scalar)
+
+    s = jnp.array([4, 3, 2, 6, 3, 5, 2, 7], jnp.int32)
+    x = jnp.arange(8 * 8 * 128, dtype=jnp.int32).reshape((8 * 8, 128))
+
+    def _x_transform(i, s_ref):
+      s = pl.load(s_ref, (i,))
+      return (s, 0)
+
+    pallas_call = self.pallas_call(
+        body,
+        out_shape=jax.ShapeDtypeStruct(x.shape, jnp.int32),
+        grid_spec=pltpu.PrefetchScalarGridSpec(
+            num_scalar_prefetch=1,
+            in_specs=[
+                pl.BlockSpec(_x_transform, (x.shape[0] // 8, x.shape[1])),
+            ],
+            out_specs=pl.BlockSpec(lambda i, _: (i, 0),
+                                   (x.shape[0] // 8, x.shape[1])),
+            grid=8,
+        ),
+    )
+    checked_call = checkify.checkify(pallas_call)
+    err, out = checked_call(s, x)
+    expected_error_value = s[jnp.argmax(s >= threshold)]
+    with self.assertRaisesRegex(
+        checkify.JaxRuntimeError, f'failed on value {expected_error_value}'):
+      err.throw()
+    np.testing.assert_allclose(out, x.reshape((8, 8, -1))[s].reshape(x.shape))
+
+  def test_checkify_with_scratch(self):
+    def body(x_ref, o_ref, scratch_ref):
+      scratch_ref[...] = x_ref[...]
+      o_ref[...] = scratch_ref[...]
+      all_nequal = ~jnp.all(o_ref[...] == x_ref[...])
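+      # o_ref has just been copied from x_ref, so all elements are equal and
+      # this check fails on every grid cell; checkify should report the first
+      # failure, id=(0, 0).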
+      checkify.check(all_nequal, 'x_ref equals o_ref id=({x}, {y})',
+                     x=pl.program_id(0), y=pl.program_id(1))
+
+    x = jax.random.uniform(jax.random.key(0), (128, 128), dtype=jnp.float32)
+    pallas_call = self.pallas_call(
+        body,
+        out_shape=jax.ShapeDtypeStruct(x.shape, jnp.float32),
+        grid_spec=pltpu.PrefetchScalarGridSpec(
+            num_scalar_prefetch=0,
+            in_specs=[
+                pl.BlockSpec(lambda i, j: (i, j), (32, 32)),
+            ],
+            out_specs=pl.BlockSpec(lambda i, j: (i, j), (32, 32)),
+            scratch_shapes=[pltpu.VMEM((32, 32), dtype=jnp.float32)],
+            grid=(4, 4),
+        ),
+    )
+    checked_call = checkify.checkify(pallas_call)
+    err, out = checked_call(x)
+    with self.assertRaisesRegex(
+        checkify.JaxRuntimeError, r'x_ref equals o_ref id=\(0, 0\)'):
+      err.throw()
+    np.testing.assert_allclose(out, x)
+
+  @parameterized.parameters((4,), (9,))
+  def test_checkify_with_dynamic_grid(self, iteration):
+    grid_size = 4
+    shape = (8, 128)
+    result_ty = jax.ShapeDtypeStruct(shape, jnp.float32)
+
+    def kernel(y_ref):
+      @pl.when(pl.program_id(0) == 0)
+      def _init():
+        y_ref[...] = jnp.zeros_like(y_ref)
+      y_ref[...] += 1
+      @pl.when(pl.program_id(0) == iteration)
+      def _():
+        checkify.check(False, f"error on iteration {iteration}")
+
+    @jax.jit
+    def dynamic_kernel(steps):
+      pallas_call = self.pallas_call(
+          kernel,
+          grid=(steps * 2,),
+          out_specs=pl.BlockSpec(lambda i: (0, 0), shape),
+          out_shape=result_ty,
+      )
+      return checkify.checkify(pallas_call)()
+
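+    # With steps=4 the grid has steps * 2 = 8 iterations, so the check fires
+    # for iteration 4 but never for iteration 9, which lies outside the grid.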
+    err, result = dynamic_kernel(jnp.int32(grid_size))
+    if iteration < grid_size * 2:
+      with self.assertRaisesRegex(
+          checkify.JaxRuntimeError, f"error on iteration {iteration}"):
+        err.throw()
+    np.testing.assert_array_equal(
+        result, np.full(shape, grid_size * 2.0, np.float32)
+    )
+
+
 if __name__ == '__main__':
   absltest.main(testLoader=jtu.JaxTestLoader())
diff --git a/tests/pallas/pallas_test.py b/tests/pallas/pallas_test.py
index e71622ce2..95b195e11 100644
--- a/tests/pallas/pallas_test.py
+++ b/tests/pallas/pallas_test.py
@@ -25,6 +25,7 @@ from absl.testing import parameterized
 import jax
 from jax import lax
 from jax import random
+from jax._src import checkify
 from jax._src import config
 from jax._src import linear_util as lu
 from jax._src import state
@@ -2260,5 +2261,114 @@ class PallasInterpretModeOutOfBoundsTest(PallasTest):
     np.testing.assert_allclose(out, expected, atol=atol)
 
 
+class PallasCheckifyTest(PallasTest):
+  # TODO(b/346651778): Support non-interpret mode checkify.
+  INTERPRET: bool = True
+
+  def test_no_checkify(self):
+    def kernel(y_ref):
+      y_ref[...] = jnp.zeros_like(y_ref[...])
+    out_shape = jax.ShapeDtypeStruct((2, 2), jnp.float32)
+    pallas_call = self.pallas_call(kernel,
+                                   out_shape=out_shape)
+    checked_call = checkify.checkify(pallas_call)
+    err, result = checked_call()
+    err.throw()  # Should not raise.
+    np.testing.assert_allclose(result, jnp.zeros_like(result))
+
+  def test_does_not_clobber_previous_error(self):
+    def kernel(y_ref):
+      y_ref[...] = jnp.zeros_like(y_ref[...])
+      checkify.check(False, "error in kernel")
+    out_shape = jax.ShapeDtypeStruct((2, 2), jnp.float32)
+    pallas_call = self.pallas_call(kernel,
+                                   out_shape=out_shape)
+    def error_before_call():
+      checkify.check(False, "error before call")
+      return pallas_call()
+    checked_call = checkify.checkify(error_before_call)
+    err, result = checked_call()
+    with self.assertRaisesRegex(
+        checkify.JaxRuntimeError, "error before call"):
+      err.throw()
+    np.testing.assert_allclose(result, jnp.zeros_like(result))
+
+  @parameterized.parameters((False,), (True,))
+  def test_trivial_check(self, assert_cond):
+    def kernel(x_ref, y_ref):
+      y_ref[...] = x_ref[...]
+      checkify.check(assert_cond, "pallas check failed")
+    input = jnp.arange(4, dtype=jnp.int32)
+    out_shape = jax.ShapeDtypeStruct(input.shape, input.dtype)
+    pallas_call = self.pallas_call(kernel,
+                                   out_shape=out_shape)
+    checked_call = checkify.checkify(pallas_call)
+    err, result = checked_call(input)
+    if not assert_cond:
+      with self.assertRaisesRegex(
+          checkify.JaxRuntimeError, "pallas check failed"):
+        err.throw()
+    np.testing.assert_allclose(result, input)
+
+  def test_nan_error(self):
+    def kernel(x_ref, y_ref):
+      y_ref[...] = jnp.log(x_ref[...])
+    input = jnp.arange(4, dtype=jnp.float32) - 2
+    out_shape = jax.ShapeDtypeStruct(input.shape, input.dtype)
+    pallas_call = self.pallas_call(kernel,
+                                   out_shape=out_shape)
+    checked_call = checkify.checkify(pallas_call,
+                                     errors=checkify.all_checks)
+    err, result = checked_call(input)
+    with self.assertRaisesRegex(
+        checkify.JaxRuntimeError, "nan generated by primitive: log"):
+      err.throw()
+    is_nan = jnp.isnan(result)
+    np.testing.assert_allclose(is_nan, input < 0)
+
+  def test_nan_error_with_assertion(self):
+    # TODO(b/346842088): Fix check asserts clobbering other errors.
+    self.skipTest('Known failure.')
+    # Test that a NaN error is not clobbered by an assertion failure.
+    def kernel(x_ref, y_ref):
+      y_ref[...] = jnp.log(x_ref[...])
+      checkify.check(False, "do not raise")
+    input = jnp.arange(4, dtype=jnp.float32) - 10
+    out_shape = jax.ShapeDtypeStruct(input.shape, input.dtype)
+    pallas_call = self.pallas_call(kernel,
+                                   out_shape=out_shape)
+    checked_call = checkify.checkify(pallas_call,
+                                     errors=checkify.all_checks)
+    err, _ = checked_call(input)
+    with self.assertRaisesRegex(
+        checkify.JaxRuntimeError, "nan generated by primitive: log"):
+      err.throw()
+
+  @parameterized.parameters((5, 0), (8, 3), (4, 3))
+  def test_checkify_returns_first_error_in_grid(
+      self, num_loops, fail_iteration):
+    # Check that checkify returns the first error that occurs.
+    # TODO(justinfu): This test doesn't make sense on GPU, where threads run
+    # in parallel. Update checkify to return a grid of errors.
+    def kernel(x_ref, _):
+      value = jnp.squeeze(x_ref[...])
+      checkify.check(
+          value < fail_iteration, "failed on loop {itr}", itr=value)
+    input_arr = jnp.arange(num_loops, dtype=jnp.float32)
+    in_specs = [pl.BlockSpec(lambda x: (x,), (1,))]
+    out_shape = jax.ShapeDtypeStruct((1,), dtype=jnp.float32)
+    pallas_call = self.pallas_call(kernel,
+                                   grid=(num_loops,),
+                                   in_specs=in_specs,
+                                   out_shape=out_shape)
+    checked_call = checkify.checkify(pallas_call,
+                                     errors=checkify.all_checks)
+    err, _ = checked_call(input_arr)
+    with self.assertRaisesRegex(
+        checkify.JaxRuntimeError, f"failed on loop {fail_iteration}"):
+      err.throw()
+
+
 if __name__ == "__main__":
   absltest.main()