[Pallas] Add checkify support for pallas_call in interpret mode.

PiperOrigin-RevId: 644181742
This commit is contained in:
Justin Fu 2024-06-17 17:14:58 -07:00 committed by jax authors
parent 1d77720e9a
commit fb68f3449b
3 changed files with 394 additions and 13 deletions

View File

@ -25,6 +25,7 @@ from jax import api_util
from jax import lax
from jax import tree_util
from jax._src import ad_util
from jax._src import checkify
from jax._src import config
from jax._src import core as jax_core
from jax._src import effects
@ -114,6 +115,21 @@ def _pad_values_to_block_dimension(value,
value = jnp.pad(value, pad_width, constant_values=pad_value)
return value
def _initialize_scratch_vals(scratch_avals) -> tuple[jax.Array, ...]:
scratch_avals = (jax_core.raise_to_shaped(x) for x in scratch_avals)
return tuple(uninitialized_value(a.shape, a.dtype) for a in scratch_avals)
def _initialize_output_vals(
out_shapes, input_args, input_output_aliases) -> Sequence[jax.Array]:
oi_map = {v: k for k, v in input_output_aliases}
output_vals = []
for i, out_shape in enumerate(out_shapes):
if i in oi_map:
output_vals.append(input_args[oi_map[i]])
else:
output_vals.append(uninitialized_value(out_shape.shape, out_shape.dtype))
return output_vals
def _logical_to_interpret_mode_dtype(dtype):
"""Converts logical dtypes into JAX dtypes for interpret mode.
@ -171,13 +187,7 @@ def _pallas_call_impl(*args, jaxpr, name, out_shapes, which_linear,
discharged_jaxpr, consts = state_discharge.discharge_state(jaxpr, ())
if debug:
print(discharged_jaxpr)
oi_map = {v: k for k, v in input_output_aliases}
out = []
for i, out_shape in enumerate(out_shapes):
if i in oi_map:
out.append(args[oi_map[i]])
else:
out.append(uninitialized_value(out_shape.shape, out_shape.dtype))
out = _initialize_output_vals(out_shapes, args, input_output_aliases)
scalars, args = split_list(args, [grid_mapping.num_index_operands]) # type: ignore
# invars: [*scalar_prefetch, *inputs, *outputs, *scratch]
num_invars = len(jaxpr.invars)
@ -190,12 +200,7 @@ def _pallas_call_impl(*args, jaxpr, name, out_shapes, which_linear,
jaxpr.invars, [grid_mapping.num_index_operands, num_inputs_outputs]
)
scratch_avals = [v.aval for v in scratch_invars]
if not all(
hasattr(a, "shape") and hasattr(a, "dtype") for a in scratch_avals
):
raise NotImplementedError(f"Cannot initialize scratch: {scratch_avals}")
scratch_values = [uninitialized_value(a.shape, a.dtype)
for a in scratch_avals]
scratch_values = _initialize_scratch_vals(scratch_avals)
carry = []
for x, bm in zip(itertools.chain(args, out), grid_mapping.block_mappings):
@ -729,6 +734,168 @@ def _hoist_consts_to_refs(jaxpr: jax_core.Jaxpr) -> jax_core.Jaxpr:
assert not consts, "All consts should have been converted to refs"
return hoisted_jaxpr
def checkify_pallas_kernel_body_jaxpr(
body_jaxpr: jax_core.ClosedJaxpr,
enabled_errors,
error: checkify.Error,
grid_mapping: GridMapping) -> tuple[
jax_core.ClosedJaxpr, tree_util.PyTreeDef, set[checkify.ErrorEffect]]:
err_vals, err_tree = tree_util.tree_flatten(error)
err_vals = map(checkify.get_shaped_aval, err_vals)
flat_err_and_in_vals = [*err_vals, *body_jaxpr.in_avals]
with pallas_core.tracing_grid_env(grid_mapping.grid, ()):
checked_jaxpr, out_tree, error_effects = checkify.jaxpr_to_checkify_jaxpr(
body_jaxpr, enabled_errors, err_tree, *flat_err_and_in_vals)
return checked_jaxpr, out_tree, error_effects
def pallas_call_checkify_rule(error: checkify.Error,
enabled_errors,
*args: jax_core.Value,
jaxpr: jax_core.Jaxpr,
interpret: bool,
input_output_aliases: tuple[tuple[int, int], ...],
grid_mapping: GridMapping,
out_shapes,
**kwargs):
# TODO(b/346651778): Support TPU/GPU checkify.
if not interpret:
raise NotImplementedError(
"Checkify for pallas_call only supports interpret mode.")
# We implement the checkify rule in 4 steps:
# 1) First, trace the kernel body to get the expected error shapes.
# 2) Checkify the kernel body to obtain a jaxpr with errors as inputs
# and outputs.
# 3) Create a new kernel which stores the errors in output memrefs instead of
# returning them, since pallas kernels do not return outputs.
# 4) Create block specs for the error state and call pallas_call with
# the new kernel.
dynamic_grid_bounds, scalars, args = split_list( # type: ignore
args, [grid_mapping.num_dynamic_grid_bounds, grid_mapping.num_index_operands]
)
num_scalars = len(scalars)
num_invars = len(jaxpr.invars)
num_inputs_outputs = (
num_invars
- grid_mapping.num_index_operands
- grid_mapping.num_scratch_operands
)
num_kernel_inputs = len(args)
num_scratch = num_invars - num_inputs_outputs
num_kernel_outputs = num_invars - num_scratch - num_kernel_inputs
# Trace the jaxpr to get an initial error value so the kernel jaxpr has all of
# the required inputs.
closed_jaxpr = pe.close_jaxpr(jaxpr)
_jaxpr, _, error_effects = checkify_pallas_kernel_body_jaxpr(
closed_jaxpr, enabled_errors, error, grid_mapping)
error = error._add_placeholder_effects(error_effects)
err_vals, err_tree = jax.tree.flatten(error)
shaped_err_avals = map(checkify.get_shaped_aval, err_vals)
# Trace the kernel jaxpr to get a checkified jaxpr. This jaxpr will have
# all enabled errors removed, but have the error as inputs and return values.
input_avals = [v.aval for v in jaxpr.invars]
num_err_vals = len(err_vals)
shaped_input_avals = tuple(jax_core.raise_to_shaped(x) for x in input_avals)
checkify_in_avals = [*shaped_err_avals,
*shaped_input_avals]
closed_kernel_jaxpr = pe.close_jaxpr(jaxpr)
with pallas_core.tracing_grid_env(grid_mapping.grid, ()):
checked_jaxpr, out_tree, _ = checkify.jaxpr_to_checkify_jaxpr(
closed_kernel_jaxpr, enabled_errors, err_tree, *checkify_in_avals)
# Create a new kernel to remove the error as an return value and instead
# write them to a memref. This is because pallas kernels are expected
# to have no return values but instead write their outputs to a ref.
def checked_kernel_fn(*args):
(scalars, _, inputs, out_error_refs, outputs, scratch
) = split_list(
args,
[num_scalars, num_err_vals,
num_kernel_inputs, num_err_vals, num_kernel_outputs])
input_error_vals = [err_ref[...] for err_ref in out_error_refs]
# We need to re-order the inputs here. A checkified jaxpr always expects
# errors before other arguments.
jaxpr_args = [*input_error_vals, *scalars, *inputs, *outputs, *scratch]
assert len(checked_jaxpr.jaxpr.invars) == len(jaxpr_args)
result_flat = jax.core.eval_jaxpr(
checked_jaxpr.jaxpr, checked_jaxpr.consts, *jaxpr_args)
output_errors, _ = split_list(result_flat, [num_err_vals])
# Store new errors back in the error refs.
for out_ref, error in zip(out_error_refs, output_errors):
out_ref[...] = error
return []
# Trace the new checked_kernel_fn with Memref inputs so that
# we can replace the old kernel jaxpr with the new checked jaxpr in
# pallas_call.
# TODO(justinfu): Place errors in scalar memory for non-interpret mode.
error_mem_space = None
error_memref_aval = [pallas_core.AbstractMemoryRef(
err_val, error_mem_space) for err_val in shaped_err_avals]
shaped_scalar_avals, input_aval, output_aval, scratch_aval = split_list(
shaped_input_avals, [num_scalars, num_kernel_inputs, num_kernel_outputs])
retrace_in_avals = [*shaped_scalar_avals, *error_memref_aval, *input_aval,
*error_memref_aval, *output_aval, *scratch_aval]
jaxpr_flat_avals, jaxpr_in_tree = tree_util.tree_flatten(retrace_in_avals)
wrapped_kernel_with_err, out_tree_thunk = api_util.flatten_fun_nokwargs(
lu.wrap_init(checked_kernel_fn), jaxpr_in_tree)
debug = pe.debug_info(
checked_kernel_fn, jaxpr_in_tree, out_tree_thunk, False, "checkify_pallas")
with pallas_core.tracing_grid_env(grid_mapping.grid, ()):
final_jaxpr, _, _, () = pe.trace_to_jaxpr_dynamic(
wrapped_kernel_with_err, jaxpr_flat_avals, debug)
# Prepare pallas_call inputs. We need to create new block specs
# for the new error inputs and outputs.
scalar_avals = map(checkify.get_shaped_aval, scalars)
error_block_specs = [no_block_spec] * num_err_vals
grid_avals = [
jax_core.ShapedArray((), jnp.dtype("int32"))] * len(grid_mapping.grid)
# TODO(justinfu): Place these in device-specific scalar memory.
scalar_ref_avals = [
pallas_core.AbstractMemoryRef(
jax_core.ShapedArray(aval.shape, aval.dtype), None)
for aval in scalar_avals]
grid_tree = tree_util.tree_structure(((*grid_avals, *scalar_avals), {}))
error_block_mappings = map(
partial(
pallas_core._convert_block_spec_to_block_mapping,
(*grid_avals, *scalar_ref_avals),
in_tree=grid_tree,
grid=grid_mapping.grid,
mapped_dims=grid_mapping.mapped_dims),
error_block_specs, error_memref_aval)
input_block_mappings, output_block_mappings = split_list(
grid_mapping.block_mappings, [num_kernel_inputs,])
grid_mapping_with_error = grid_mapping.replace(
block_mappings=(*error_block_mappings, *input_block_mappings,
*error_block_mappings, *output_block_mappings)
)
error_out_shapes = tuple(
jax.ShapeDtypeStruct(e.shape, e.dtype) for e in shaped_err_avals)
# Bump all input_output_aliases by num_err_vals to make room for error
# TODO(justinfu): Don't bump scalars here.
input_output_aliases = tuple(
(i+num_err_vals, o+num_err_vals) for (i, o) in input_output_aliases)
input_output_aliases_with_error = tuple(
(i+num_scalars, i) for i in range(num_err_vals)) + input_output_aliases
new_vals_in = [*scalars, *err_vals, *args]
result = pallas_call_p.bind(*dynamic_grid_bounds, *new_vals_in,
jaxpr=final_jaxpr,
interpret=interpret,
grid_mapping=grid_mapping_with_error,
input_output_aliases=input_output_aliases_with_error,
out_shapes=error_out_shapes + out_shapes,
**kwargs)
errors, results = split_list(result, [num_err_vals])
new_error, _ = jax.tree.unflatten(out_tree, errors)
return new_error, results
checkify.error_checks[pallas_call_p] = pallas_call_checkify_rule
@weakref_lru_cache
def _trace_to_jaxpr(fun: Callable, grid_spec: GridSpec, flat_in_avals,
flat_out_avals, in_tree, out_tree, interpret: bool):

View File

@ -23,6 +23,7 @@ from absl.testing import absltest
from absl.testing import parameterized
import jax
from jax import lax
from jax._src import checkify
from jax._src import state
from jax._src import test_util as jtu
from jax._src.interpreters import partial_eval as pe
@ -2481,5 +2482,108 @@ class PallasCallTraceTest(PallasTPUTest):
self.assertEqual(num_start, 2)
self.assertEqual(num_stop, 2)
class PallasCallTPUCheckifyTest(PallasTPUTest):
interpret: bool = True
@parameterized.parameters((2,), (5,), (6,), (7,))
def test_checkify_with_scalar_prefetch(self, threshold):
def body(scalar_ref, x_ref, o_ref):
scalar = scalar_ref[pl.program_id(0)]
o_ref[...] = x_ref[...]
checkify.check(scalar < threshold, 'failed on value {x}', x=scalar)
s = jnp.array([4, 3, 2, 6, 3, 5, 2, 7], jnp.int32)
x = jnp.arange(8 * 8 * 128, dtype=jnp.int32).reshape((8 * 8, 128))
def _x_transform(i, s_ref):
s = pl.load(s_ref, (i,))
return (s, 0)
pallas_call = self.pallas_call(
body,
out_shape=jax.ShapeDtypeStruct(x.shape, jnp.int32),
grid_spec=pltpu.PrefetchScalarGridSpec(
num_scalar_prefetch=1,
in_specs=[
pl.BlockSpec(_x_transform, (x.shape[0] // 8, x.shape[1])),
],
out_specs=pl.BlockSpec(lambda i, _: (i, 0),
(x.shape[0] // 8, x.shape[1])),
grid=8,
),
)
checked_call = checkify.checkify(pallas_call)
err, out = checked_call(s, x)
expected_error_value = s[jnp.argmax(s >= threshold)]
with self.assertRaisesRegex(
checkify.JaxRuntimeError, f'failed on value {expected_error_value}'):
err.throw()
np.testing.assert_allclose(out, x.reshape((8, 8, -1))[s].reshape(x.shape))
def test_checkify_with_scratch(self):
def body(x_ref, o_ref, scratch_ref):
scratch_ref[...] = x_ref[...]
o_ref[...] = scratch_ref[...]
all_nequal = ~jnp.all(o_ref[...] == x_ref[...])
checkify.check(all_nequal, 'x_ref equals o_ref id=({x}, {y})',
x=pl.program_id(0), y=pl.program_id(1))
x = jax.random.uniform(jax.random.key(0), (128, 128), dtype=jnp.float32)
pallas_call = self.pallas_call(
body,
out_shape=jax.ShapeDtypeStruct(x.shape, jnp.float32),
grid_spec=pltpu.PrefetchScalarGridSpec(
num_scalar_prefetch=0,
in_specs=[
pl.BlockSpec(lambda i, j: (i, j), (32, 32)),
],
out_specs=pl.BlockSpec(lambda i, j: (i, j), (32, 32)),
scratch_shapes=[pltpu.VMEM((32, 32), dtype=jnp.float32)],
grid=(4, 4),
),
)
checked_call = checkify.checkify(pallas_call)
err, out = checked_call(x)
with self.assertRaisesRegex(
checkify.JaxRuntimeError, r'x_ref equals o_ref id=\(0, 0\)'):
err.throw()
np.testing.assert_allclose(out, x)
@parameterized.parameters((4,), (9,))
def test_checkify_with_dynamic_grid(self, iteration):
grid_size = 4
shape = (8, 128)
result_ty = jax.ShapeDtypeStruct(shape, jnp.float32)
def kernel(y_ref):
@pl.when(pl.program_id(0) == 0)
def _init():
y_ref[...] = jnp.zeros_like(y_ref)
y_ref[...] += 1
@pl.when(pl.program_id(0) == iteration)
def _():
checkify.check(False, f"error on iteration {iteration}")
@jax.jit
def dynamic_kernel(steps):
pallas_call = self.pallas_call(
kernel,
grid=(steps * 2,),
out_specs=pl.BlockSpec(lambda i: (0, 0), shape),
out_shape=result_ty,
)
return checkify.checkify(pallas_call)()
err, result = dynamic_kernel(jnp.int32(grid_size))
if iteration < grid_size * 2:
with self.assertRaisesRegex(
checkify.JaxRuntimeError, f"error on iteration {iteration}"):
err.throw()
np.testing.assert_array_equal(
result, np.full(shape, grid_size * 2.0, np.float32)
)
if __name__ == '__main__':
absltest.main(testLoader=jtu.JaxTestLoader())

View File

@ -25,6 +25,7 @@ from absl.testing import parameterized
import jax
from jax import lax
from jax import random
from jax._src import checkify
from jax._src import config
from jax._src import linear_util as lu
from jax._src import state
@ -2260,5 +2261,114 @@ class PallasInterpretModeOutOfBoundsTest(PallasTest):
np.testing.assert_allclose(out, expected, atol=atol)
class PallasCheckifyTest(PallasTest):
# TODO(b/346651778): Support non-interpret mode checkify.
INTERPRET: bool = True
def test_no_checkify(self,):
def kernel(y_ref):
y_ref[...] = jnp.zeros_like(y_ref[...])
out_shape = jax.ShapeDtypeStruct((2, 2), jnp.float32)
pallas_call = self.pallas_call(kernel,
out_shape=out_shape)
checked_call = checkify.checkify(pallas_call)
err, result = checked_call()
err.throw() # Should not raise.
np.testing.assert_allclose(result, jnp.zeros_like(result))
def test_does_not_clobber_previous_error(self,):
def kernel(y_ref):
y_ref[...] = jnp.zeros_like(y_ref[...])
checkify.check(False, "error in kernel")
out_shape = jax.ShapeDtypeStruct((2, 2), jnp.float32)
pallas_call = self.pallas_call(kernel,
out_shape=out_shape)
def error_before_call():
checkify.check(False, "error before call")
return pallas_call()
checked_call = checkify.checkify(error_before_call)
err, result = checked_call()
with self.assertRaisesRegex(
checkify.JaxRuntimeError, "error before call"):
err.throw()
np.testing.assert_allclose(result, jnp.zeros_like(result))
@parameterized.parameters((False,), (True,))
def test_trivial_check(self, assert_cond):
def kernel(x_ref, y_ref):
y_ref[...] = x_ref[...]
checkify.check(assert_cond, "pallas check failed")
input = jnp.arange(4, dtype=jnp.int32)
out_shape = jax.ShapeDtypeStruct(input.shape, input.dtype)
pallas_call = self.pallas_call(kernel,
out_shape=out_shape)
checked_call = checkify.checkify(pallas_call)
err, result = checked_call(input)
if not assert_cond:
with self.assertRaisesRegex(
checkify.JaxRuntimeError, "pallas check failed"):
err.throw()
np.testing.assert_allclose(result, input)
def test_nan_error(self):
def kernel(x_ref, y_ref):
y_ref[...] = jnp.log(x_ref[...])
input = jnp.arange(4, dtype=jnp.float32) - 2
out_shape = jax.ShapeDtypeStruct(input.shape, input.dtype)
pallas_call = self.pallas_call(kernel,
out_shape=out_shape)
checked_call = checkify.checkify(pallas_call,
errors=checkify.all_checks)
err, result = checked_call(input)
with self.assertRaisesRegex(
checkify.JaxRuntimeError, "nan generated by primitive: log"):
err.throw()
is_nan = jnp.isnan(result)
np.testing.assert_allclose(is_nan, input < 0)
def test_nan_error_with_assertion(self):
# TODO(b/346842088): Fix check asserts clobbering other errors.
self.skipTest('Known failure.')
# Test NaN error is not clobbered by an assertion failure
def kernel(x_ref, y_ref):
y_ref[...] = jnp.log(x_ref[...])
checkify.check(False, "do not raise")
input = jnp.arange(4, dtype=jnp.float32) - 10
out_shape = jax.ShapeDtypeStruct(input.shape, input.dtype)
pallas_call = self.pallas_call(kernel,
out_shape=out_shape)
checked_call = checkify.checkify(pallas_call,
errors=checkify.all_checks)
err, _ = checked_call(input)
with self.assertRaisesRegex(
checkify.JaxRuntimeError, "nan generated by primitive: log"):
err.throw()
@parameterized.parameters((5, 0), (8, 3), (4, 3))
def test_checkify_returns_first_error_in_grid(
self, num_loops, fail_iteration):
# Check that checkify returns the first error that occurs
# TODO(justinfu): This test doesn't make sense on GPU, where threads run
# in parallel. Update checkify to return a grid of errors.
def kernel(x_ref, _):
value = jnp.squeeze(x_ref[...])
checkify.check(
value < fail_iteration, "failed on loop {itr}", itr=value)
input_arr = jnp.arange(num_loops, dtype=jnp.float32)
in_specs = [pl.BlockSpec(lambda x : (x,), (1,))]
out_shape = jax.ShapeDtypeStruct((1,), dtype=jnp.float32)
pallas_call = self.pallas_call(kernel,
grid=(num_loops,),
in_specs=in_specs,
out_shape=out_shape)
checked_call = checkify.checkify(pallas_call,
errors=checkify.all_checks)
err, _ = checked_call(input_arr)
with self.assertRaisesRegex(
checkify.JaxRuntimeError, f"failed on loop {fail_iteration}"):
err.throw()
if __name__ == "__main__":
absltest.main()