[Pallas] Add support for checkify in TPU execution mode.

PiperOrigin-RevId: 653045818
This commit is contained in:
Justin Fu 2024-07-16 18:12:19 -07:00 committed by jax authors
parent 7069a5a2e1
commit 6ba889c01c
6 changed files with 73 additions and 26 deletions

View File

@ -35,6 +35,7 @@ Remember to align the itemized text with the first line of an item within a list
and `lax.erf_inv` ({jax-issue}`#22310`).
* Added initial support for shape polymorphism for the Pallas TPU custom kernels\
({jax-issue}`#22084`).
* Added TPU support for checkify. ({jax-issue}`#22480`)
## Released with JAX 0.4.30 (June 18, 2024)

View File

@ -19,6 +19,7 @@ from collections.abc import Callable, Iterator, Sequence
import contextlib
import copy
import dataclasses
import enum
import functools
import threading
from typing import Any, Union
@ -86,6 +87,19 @@ class AbstractMemoryRef(state.AbstractRef):
return hash((self.__class__, self.inner_aval, self.memory_space))
class MemorySpace(enum.Enum):
""" Logical, device-agnostic memory spaces.
Each memory space will be translated to a device-specific memory
type during lowering.
"""
ERROR = "error" # Memory space for checkify errors.
INDEX = "index" # Memory space for scalar prefetch arguments.
def __str__(self) -> str:
return self.value
def _ref_raise_to_shaped(ref_aval: AbstractMemoryRef, weak_type):
return AbstractMemoryRef(
jax_core.raise_to_shaped(ref_aval.inner_aval, weak_type),

View File

@ -68,6 +68,7 @@ import numpy as np
NDIndexer = indexing.NDIndexer
TPUMemorySpace = tpu_core.TPUMemorySpace
MemorySpace = pl_core.MemorySpace | TPUMemorySpace
VMEM = tpu_core.TPUMemorySpace.VMEM
SMEM = tpu_core.TPUMemorySpace.SMEM
# Booleans are stored as the following type in memrefs.
@ -117,10 +118,14 @@ class LoweringRuleContext:
replace = dataclasses.replace
def _memory_space_to_tpu_memspace(memory_space: TPUMemorySpace | None
def _memory_space_to_tpu_memspace(memory_space: MemorySpace | None
) -> ir.Attribute:
if memory_space is None:
memory_space = VMEM
elif memory_space == pl_core.MemorySpace.ERROR:
memory_space = SMEM
elif memory_space == pl_core.MemorySpace.INDEX:
memory_space = SMEM
return ir.Attribute.parse(f"#tpu.memory_space<{memory_space}>")
def _dtype_to_ir_type(dtype: jnp.dtype,
@ -146,7 +151,7 @@ def _dtype_to_ir_type(dtype: jnp.dtype,
def aval_to_ir_type(aval,
shape=None,
memory_space: TPUMemorySpace | None = None,
memory_space: MemorySpace | None = None,
is_kernel_boundary: bool = False):
if isinstance(aval, tpu_core.AbstractSemaphore):
if aval.sem_type is tpu_core.SemaphoreType.DMA:

View File

@ -717,10 +717,6 @@ def pallas_call_checkify_rule(error: checkify.Error,
grid_mapping: GridMapping,
out_shapes,
**kwargs):
# TODO(b/346651778): Support TPU/GPU checkify.
if not interpret:
raise NotImplementedError(
"Checkify for pallas_call only supports interpret mode.")
# We implement the checkify rule in 4 steps:
# 1) First, trace the kernel body to get the expected error shapes.
# 2) Checkify the kernel body to obtain a jaxpr with errors as inputs
@ -749,7 +745,7 @@ def pallas_call_checkify_rule(error: checkify.Error,
_jaxpr, _, error_effects = checkify_pallas_kernel_body_jaxpr(
closed_jaxpr, enabled_errors, error, grid_mapping)
error = error._add_placeholder_effects(error_effects)
err_vals, err_tree = jax.tree.flatten(error)
err_vals, err_in_tree = jax.tree.flatten(error)
shaped_err_avals = map(checkify.get_shaped_aval, err_vals)
# Trace the kernel jaxpr to get a checkified jaxpr. This jaxpr will have
@ -761,19 +757,20 @@ def pallas_call_checkify_rule(error: checkify.Error,
*shaped_input_avals]
closed_kernel_jaxpr = pe.close_jaxpr(jaxpr)
with pallas_core.tracing_grid_env(grid_mapping.grid, ()):
checked_jaxpr, out_tree, _ = checkify.jaxpr_to_checkify_jaxpr(
closed_kernel_jaxpr, enabled_errors, err_tree, *checkify_in_avals)
checked_jaxpr, error_out_tree, _ = checkify.jaxpr_to_checkify_jaxpr(
closed_kernel_jaxpr, enabled_errors, err_in_tree, *checkify_in_avals)
# Create a new kernel to remove the error as an return value and instead
# write them to a memref. This is because pallas kernels are expected
# to have no return values but instead write their outputs to a ref.
def checked_kernel_fn(*args):
(scalars, _, inputs, out_error_refs, outputs, scratch
(scalars, in_error_refs, inputs, out_error_refs, outputs, scratch
) = split_list(
args,
[num_scalars, num_err_vals,
num_kernel_inputs, num_err_vals, num_kernel_outputs])
input_error_vals = [err_ref[...] for err_ref in out_error_refs]
# TODO(b/350593266): Remove zero-indexing once we support ()-shaped scalars.
input_error_vals = [err_ref[0, 0] for err_ref in in_error_refs]
# We need to re-order the inputs here. A checkified jaxpr always expects
# errors before other arguments.
jaxpr_args = [*input_error_vals, *scalars, *inputs, *outputs, *scratch]
@ -782,17 +779,33 @@ def pallas_call_checkify_rule(error: checkify.Error,
checked_jaxpr.jaxpr, checked_jaxpr.consts, *jaxpr_args)
output_errors, _ = split_list(result_flat, [num_err_vals])
# Store new errors back in the error refs.
for out_ref, error in zip(out_error_refs, output_errors):
out_ref[...] = error
for in_ref, out_ref, error in zip(
in_error_refs, out_error_refs, output_errors):
in_ref[0, 0] = error
out_ref[0, 0] = error
return []
# Trace the new checked_kernel_fn with Memref inputs so that
# we can replace the old kernel jaxpr with the new checked jaxpr in
# pallas_call.
# TODO(justinfu): Place errors in scalar memory for non-interpret mode.
error_mem_space = None
# ensure_2d_shape is only necessary because pallas does not support
# ()-shaped Memrefs.
# TODO(b/350593266): Remove once we support ()-shaped scalars.
def _ensure_2d_error_shape(arg):
if isinstance(arg, jax_core.ShapedArray):
dtype = arg.dtype
return jax_core.ShapedArray((1, 1) + arg.shape, dtype=dtype,
weak_type=arg.weak_type)
elif isinstance(arg, jax.Array):
return jnp.reshape(arg, (1, 1) + arg.shape)
else:
return jnp.array([[arg]])
shaped_err_avals = map(_ensure_2d_error_shape, shaped_err_avals)
err_vals = map(_ensure_2d_error_shape, err_vals)
error_memref_aval = [pallas_core.AbstractMemoryRef(
err_val, error_mem_space) for err_val in shaped_err_avals]
err_val, pallas_core.MemorySpace.ERROR) for err_val in shaped_err_avals]
shaped_scalar_avals, input_aval, output_aval, scratch_aval = split_list(
shaped_input_avals, [num_scalars, num_kernel_inputs, num_kernel_outputs])
retrace_in_avals = [*shaped_scalar_avals, *error_memref_aval, *input_aval,
@ -809,14 +822,17 @@ def pallas_call_checkify_rule(error: checkify.Error,
# Prepare pallas_call inputs. We need to create new block specs
# for the new error inputs and outputs.
scalar_avals = map(checkify.get_shaped_aval, scalars)
error_block_specs = [no_block_spec] * num_err_vals
error_block_specs = [pallas_core.BlockSpec(
index_map=lambda *args: (0,) * len(error.shape),
block_shape=error.shape)
for error in shaped_err_avals]
error_paths, _ = unzip2(tree_util.tree_flatten_with_path(error_block_specs)[0])
grid_avals = [
jax_core.ShapedArray((), jnp.dtype("int32"))] * len(grid_mapping.grid)
# TODO(justinfu): Place these in device-specific scalar memory.
scalar_ref_avals = [
pallas_core.AbstractMemoryRef(
jax_core.ShapedArray(aval.shape, aval.dtype), None)
jax_core.ShapedArray(aval.shape, aval.dtype),
pallas_core.MemorySpace.INDEX)
for aval in scalar_avals]
grid_tree = tree_util.tree_structure(((*grid_avals, *scalar_avals), {}))
error_block_mappings = map(
@ -844,15 +860,22 @@ def pallas_call_checkify_rule(error: checkify.Error,
(i+num_scalars, i) for i in range(num_err_vals)) + input_output_aliases
new_vals_in = [*scalars, *err_vals, *args]
new_input_shapes = tuple(
jax.ShapeDtypeStruct(x.shape, x.dtype) for x in [
*scalars, *shaped_err_avals, *args])
del kwargs['in_shapes']
result = pallas_call_p.bind(*dynamic_grid_bounds, *new_vals_in,
jaxpr=final_jaxpr,
interpret=interpret,
grid_mapping=grid_mapping_with_error,
input_output_aliases=input_output_aliases_with_error,
in_shapes=new_input_shapes,
out_shapes=error_out_shapes + out_shapes,
**kwargs)
errors, results = split_list(result, [num_err_vals])
new_error, _ = jax.tree.unflatten(out_tree, errors)
# TODO(b/350593266): Remove line below once we support ()-shaped scalars.
errors = [err_val[0, 0] for err_val in errors]
new_error, _ = jax.tree.unflatten(error_out_tree, errors)
return new_error, results
checkify.error_checks[pallas_call_p] = pallas_call_checkify_rule

View File

@ -1796,10 +1796,12 @@ class PallasCheckifyInterpreterTest(PallasBaseTest):
value < fail_iteration, "failed on loop {itr}", itr=value)
input_arr = jnp.arange(num_loops, dtype=jnp.float32)
in_specs = [pl.BlockSpec((1,), lambda x: (x,))]
out_specs = pl.BlockSpec((1,), lambda x: (x,))
out_shape = jax.ShapeDtypeStruct((1,), dtype=jnp.float32)
pallas_call = self.pallas_call(kernel,
grid=(num_loops,),
in_specs=in_specs,
out_specs=out_specs,
out_shape=out_shape)
checked_call = checkify.checkify(pallas_call,

View File

@ -2253,13 +2253,11 @@ class PallasCallTPUBooleanTest(PallasBaseTest):
)(input_arr)
class PallasCallTPUBooleanInterpretTest(PallasCallTPUBooleanTest):
INTERPRET: bool = True
class PallasCallTPUCheckifyTest(PallasBaseTest):
INTERPRET: bool = True
@parameterized.parameters((2,), (5,), (6,), (7,))
def test_checkify_with_scalar_prefetch(self, threshold):
@ -2305,17 +2303,17 @@ class PallasCallTPUCheckifyTest(PallasBaseTest):
checkify.check(all_nequal, 'x_ref equals o_ref id=({x}, {y})',
x=pl.program_id(0), y=pl.program_id(1))
x = jax.random.uniform(jax.random.key(0), (128, 128), dtype=jnp.float32)
x = jax.random.uniform(jax.random.key(0), (128, 512), dtype=jnp.float32)
pallas_call = self.pallas_call(
body,
out_shape=jax.ShapeDtypeStruct(x.shape, jnp.float32),
grid_spec=pltpu.PrefetchScalarGridSpec(
num_scalar_prefetch=0,
in_specs=[
pl.BlockSpec((32, 32), lambda i, j: (i, j)),
pl.BlockSpec((32, 128), lambda i, j: (i, j)),
],
out_specs=pl.BlockSpec((32, 32), lambda i, j: (i, j)),
scratch_shapes=[pltpu.VMEM((32, 32), dtype=jnp.float32)],
out_specs=pl.BlockSpec((32, 128), lambda i, j: (i, j)),
scratch_shapes=[pltpu.VMEM((32, 128), dtype=jnp.float32)],
grid=(4, 4),
),
)
@ -2361,6 +2359,10 @@ class PallasCallTPUCheckifyTest(PallasBaseTest):
)
class PallasCallTPUCheckifyInterpretTest(PallasCallTPUCheckifyTest):
INTERPRET: bool = True
class MiscellaneousInterpreterTest(PallasBaseTest):
"""Tests for recently reported bugs; only pass in interpret mode."""