diff --git a/docs/pallas/CHANGELOG.md b/docs/pallas/CHANGELOG.md
index 58a59f492..4e9e8d11b 100644
--- a/docs/pallas/CHANGELOG.md
+++ b/docs/pallas/CHANGELOG.md
@@ -35,6 +35,7 @@ Remember to align the itemized text with the first line of an item within a list
     and `lax.erf_inv` ({jax-issue}`#22310`).
   * Added initial support for shape polymorphism for the Pallas TPU custom kernels\
     ({jax-issue}`#22084`).
+  * Added TPU support for checkify. ({jax-issue}`#22480`)
 
 ## Released with JAX 0.4.30 (June 18, 2024)
 
diff --git a/jax/_src/pallas/core.py b/jax/_src/pallas/core.py
index 7b5f0b99e..c088fbcfe 100644
--- a/jax/_src/pallas/core.py
+++ b/jax/_src/pallas/core.py
@@ -19,6 +19,7 @@ from collections.abc import Callable, Iterator, Sequence
 import contextlib
 import copy
 import dataclasses
+import enum
 import functools
 import threading
 from typing import Any, Union
@@ -86,6 +87,19 @@ class AbstractMemoryRef(state.AbstractRef):
     return hash((self.__class__, self.inner_aval, self.memory_space))
 
 
+class MemorySpace(enum.Enum):
+  """ Logical, device-agnostic memory spaces.
+
+  Each memory space will be translated to a device-specific memory
+  type during lowering.
+  """
+  ERROR = "error"  # Memory space for checkify errors.
+  INDEX = "index"  # Memory space for scalar prefetch arguments.
+
+  def __str__(self) -> str:
+    return self.value
+
+
 def _ref_raise_to_shaped(ref_aval: AbstractMemoryRef, weak_type):
   return AbstractMemoryRef(
       jax_core.raise_to_shaped(ref_aval.inner_aval, weak_type),
diff --git a/jax/_src/pallas/mosaic/lowering.py b/jax/_src/pallas/mosaic/lowering.py
index eb8ab086b..04ff17781 100644
--- a/jax/_src/pallas/mosaic/lowering.py
+++ b/jax/_src/pallas/mosaic/lowering.py
@@ -68,6 +68,7 @@ import numpy as np
 
 NDIndexer = indexing.NDIndexer
 TPUMemorySpace = tpu_core.TPUMemorySpace
+MemorySpace = pl_core.MemorySpace | TPUMemorySpace
 VMEM = tpu_core.TPUMemorySpace.VMEM
 SMEM = tpu_core.TPUMemorySpace.SMEM
 # Booleans are stored as the following type in memrefs.
@@ -117,10 +118,14 @@ class LoweringRuleContext:
 
   replace = dataclasses.replace
 
-def _memory_space_to_tpu_memspace(memory_space: TPUMemorySpace | None
+def _memory_space_to_tpu_memspace(memory_space: MemorySpace | None
                                   ) -> ir.Attribute:
   if memory_space is None:
     memory_space = VMEM
+  elif memory_space == pl_core.MemorySpace.ERROR:
+    memory_space = SMEM
+  elif memory_space == pl_core.MemorySpace.INDEX:
+    memory_space = SMEM
   return ir.Attribute.parse(f"#tpu.memory_space<{memory_space}>")
 
 def _dtype_to_ir_type(dtype: jnp.dtype,
@@ -146,7 +151,7 @@ def _dtype_to_ir_type(dtype: jnp.dtype,
 
 def aval_to_ir_type(aval,
                     shape=None,
-                    memory_space: TPUMemorySpace | None = None,
+                    memory_space: MemorySpace | None = None,
                     is_kernel_boundary: bool = False):
   if isinstance(aval, tpu_core.AbstractSemaphore):
     if aval.sem_type is tpu_core.SemaphoreType.DMA:
diff --git a/jax/_src/pallas/pallas_call.py b/jax/_src/pallas/pallas_call.py
index 2e7cb519f..0bac322fe 100644
--- a/jax/_src/pallas/pallas_call.py
+++ b/jax/_src/pallas/pallas_call.py
@@ -717,10 +717,6 @@ def pallas_call_checkify_rule(error: checkify.Error,
                               grid_mapping: GridMapping,
                               out_shapes,
                               **kwargs):
-  # TODO(b/346651778): Support TPU/GPU checkify.
-  if not interpret:
-    raise NotImplementedError(
-        "Checkify for pallas_call only supports interpret mode.")
   # We implement the checkify rule in 4 steps:
   # 1) First, trace the kernel body to get the expected error shapes.
   # 2) Checkify the kernel body to obtain a jaxpr with errors as inputs
@@ -749,7 +745,7 @@ def pallas_call_checkify_rule(error: checkify.Error,
   _jaxpr, _, error_effects = checkify_pallas_kernel_body_jaxpr(
       closed_jaxpr, enabled_errors, error, grid_mapping)
   error = error._add_placeholder_effects(error_effects)
-  err_vals, err_tree = jax.tree.flatten(error)
+  err_vals, err_in_tree = jax.tree.flatten(error)
   shaped_err_avals = map(checkify.get_shaped_aval, err_vals)
 
   # Trace the kernel jaxpr to get a checkified jaxpr. This jaxpr will have
@@ -761,19 +757,20 @@ def pallas_call_checkify_rule(error: checkify.Error,
                       *shaped_input_avals]
   closed_kernel_jaxpr = pe.close_jaxpr(jaxpr)
   with pallas_core.tracing_grid_env(grid_mapping.grid, ()):
-    checked_jaxpr, out_tree, _ = checkify.jaxpr_to_checkify_jaxpr(
-        closed_kernel_jaxpr, enabled_errors, err_tree, *checkify_in_avals)
+    checked_jaxpr, error_out_tree, _ = checkify.jaxpr_to_checkify_jaxpr(
+        closed_kernel_jaxpr, enabled_errors, err_in_tree, *checkify_in_avals)
 
   # Create a new kernel to remove the error as an return value and instead
   # write them to a memref. This is because pallas kernels are expected
   # to have no return values but instead write their outputs to a ref.
   def checked_kernel_fn(*args):
-    (scalars, _, inputs, out_error_refs, outputs, scratch
+    (scalars, in_error_refs, inputs, out_error_refs, outputs, scratch
      ) = split_list(
          args, [num_scalars, num_err_vals, num_kernel_inputs, num_err_vals,
                 num_kernel_outputs])
 
-    input_error_vals = [err_ref[...] for err_ref in out_error_refs]
+    # TODO(b/350593266): Remove zero-indexing once we support ()-shaped scalars.
+    input_error_vals = [err_ref[0, 0] for err_ref in in_error_refs]
     # We need to re-order the inputs here. A checkified jaxpr always expects
     # errors before other arguments.
     jaxpr_args = [*input_error_vals, *scalars, *inputs, *outputs, *scratch]
@@ -782,17 +779,33 @@ def pallas_call_checkify_rule(error: checkify.Error,
         checked_jaxpr.jaxpr, checked_jaxpr.consts, *jaxpr_args)
     output_errors, _ = split_list(result_flat, [num_err_vals])
     # Store new errors back in the error refs.
-    for out_ref, error in zip(out_error_refs, output_errors):
-      out_ref[...] = error
+    for in_ref, out_ref, error in zip(
+        in_error_refs, out_error_refs, output_errors):
+      in_ref[0, 0] = error
+      out_ref[0, 0] = error
     return []
 
   # Trace the new checked_kernel_fn with Memref inputs so that
   # we can replace the old kernel jaxpr with the new checked jaxpr in
   # pallas_call.
-  # TODO(justinfu): Place errors in scalar memory for non-interpret mode.
-  error_mem_space = None
+
+  # ensure_2d_shape is only necessary because pallas does not support
+  # ()-shaped Memrefs.
+  # TODO(b/350593266): Remove once we support ()-shaped scalars.
+  def _ensure_2d_error_shape(arg):
+    if isinstance(arg, jax_core.ShapedArray):
+      dtype = arg.dtype
+      return jax_core.ShapedArray((1, 1) + arg.shape, dtype=dtype,
+                                  weak_type=arg.weak_type)
+    elif isinstance(arg, jax.Array):
+      return jnp.reshape(arg, (1, 1) + arg.shape)
+    else:
+      return jnp.array([[arg]])
+  shaped_err_avals = map(_ensure_2d_error_shape, shaped_err_avals)
+  err_vals = map(_ensure_2d_error_shape, err_vals)
+
   error_memref_aval = [pallas_core.AbstractMemoryRef(
-      err_val, error_mem_space) for err_val in shaped_err_avals]
+      err_val, pallas_core.MemorySpace.ERROR) for err_val in shaped_err_avals]
   shaped_scalar_avals, input_aval, output_aval, scratch_aval = split_list(
       shaped_input_avals, [num_scalars, num_kernel_inputs, num_kernel_outputs])
   retrace_in_avals = [*shaped_scalar_avals, *error_memref_aval, *input_aval,
@@ -809,14 +822,17 @@ def pallas_call_checkify_rule(error: checkify.Error,
   # Prepare pallas_call inputs. We need to create new block specs
   # for the new error inputs and outputs.
   scalar_avals = map(checkify.get_shaped_aval, scalars)
-  error_block_specs = [no_block_spec] * num_err_vals
+  error_block_specs = [pallas_core.BlockSpec(
+      index_map=lambda *args: (0,) * len(error.shape),
+      block_shape=error.shape)
+      for error in shaped_err_avals]
   error_paths, _ = unzip2(tree_util.tree_flatten_with_path(error_block_specs)[0])
   grid_avals = [
       jax_core.ShapedArray((), jnp.dtype("int32"))] * len(grid_mapping.grid)
-  # TODO(justinfu): Place these in device-specific scalar memory.
   scalar_ref_avals = [
       pallas_core.AbstractMemoryRef(
-          jax_core.ShapedArray(aval.shape, aval.dtype), None)
+          jax_core.ShapedArray(aval.shape, aval.dtype),
+          pallas_core.MemorySpace.INDEX)
       for aval in scalar_avals]
   grid_tree = tree_util.tree_structure(((*grid_avals, *scalar_avals), {}))
   error_block_mappings = map(
@@ -844,15 +860,22 @@ def pallas_call_checkify_rule(error: checkify.Error,
       (i+num_scalars, i) for i in range(num_err_vals)) + input_output_aliases
   new_vals_in = [*scalars, *err_vals, *args]
+  new_input_shapes = tuple(
+      jax.ShapeDtypeStruct(x.shape, x.dtype) for x in [
+          *scalars, *shaped_err_avals, *args])
+  del kwargs['in_shapes']
   result = pallas_call_p.bind(*dynamic_grid_bounds, *new_vals_in,
                               jaxpr=final_jaxpr,
                               interpret=interpret,
                               grid_mapping=grid_mapping_with_error,
                               input_output_aliases=input_output_aliases_with_error,
+                              in_shapes=new_input_shapes,
                               out_shapes=error_out_shapes + out_shapes,
                               **kwargs)
   errors, results = split_list(result, [num_err_vals])
-  new_error, _ = jax.tree.unflatten(out_tree, errors)
+  # TODO(b/350593266): Remove line below once we support ()-shaped scalars.
+  errors = [err_val[0, 0] for err_val in errors]
+  new_error, _ = jax.tree.unflatten(error_out_tree, errors)
   return new_error, results
 
 
 checkify.error_checks[pallas_call_p] = pallas_call_checkify_rule
diff --git a/tests/pallas/pallas_test.py b/tests/pallas/pallas_test.py
index e1d48db20..f24c2cb28 100644
--- a/tests/pallas/pallas_test.py
+++ b/tests/pallas/pallas_test.py
@@ -1796,10 +1796,12 @@ class PallasCheckifyInterpreterTest(PallasBaseTest):
           value < fail_iteration, "failed on loop {itr}", itr=value)
     input_arr = jnp.arange(num_loops, dtype=jnp.float32)
     in_specs = [pl.BlockSpec((1,), lambda x: (x,))]
+    out_specs = pl.BlockSpec((1,), lambda x: (x,))
     out_shape = jax.ShapeDtypeStruct((1,), dtype=jnp.float32)
     pallas_call = self.pallas_call(kernel,
                                    grid=(num_loops,),
                                    in_specs=in_specs,
+                                   out_specs=out_specs,
                                    out_shape=out_shape)
 
     checked_call = checkify.checkify(pallas_call,
diff --git a/tests/pallas/tpu_pallas_test.py b/tests/pallas/tpu_pallas_test.py
index 9f49a0365..dbb549cd3 100644
--- a/tests/pallas/tpu_pallas_test.py
+++ b/tests/pallas/tpu_pallas_test.py
@@ -2253,13 +2253,11 @@ class PallasCallTPUBooleanTest(PallasBaseTest):
     )(input_arr)
 
 
-
 class PallasCallTPUBooleanInterpretTest(PallasCallTPUBooleanTest):
   INTERPRET: bool = True
 
 
 class PallasCallTPUCheckifyTest(PallasBaseTest):
-  INTERPRET: bool = True
 
   @parameterized.parameters((2,), (5,), (6,), (7,))
   def test_checkify_with_scalar_prefetch(self, threshold):
@@ -2305,17 +2303,17 @@ class PallasCallTPUCheckifyTest(PallasBaseTest):
         checkify.check(all_nequal, 'x_ref equals o_ref id=({x}, {y})',
                        x=pl.program_id(0), y=pl.program_id(1))
 
-    x = jax.random.uniform(jax.random.key(0), (128, 128), dtype=jnp.float32)
+    x = jax.random.uniform(jax.random.key(0), (128, 512), dtype=jnp.float32)
     pallas_call = self.pallas_call(
         body,
         out_shape=jax.ShapeDtypeStruct(x.shape, jnp.float32),
         grid_spec=pltpu.PrefetchScalarGridSpec(
            num_scalar_prefetch=0,
            in_specs=[
-                pl.BlockSpec((32, 32), lambda i, j: (i, j)),
+                pl.BlockSpec((32, 128), lambda i, j: (i, j)),
            ],
-            out_specs=pl.BlockSpec((32, 32), lambda i, j: (i, j)),
-            scratch_shapes=[pltpu.VMEM((32, 32), dtype=jnp.float32)],
+            out_specs=pl.BlockSpec((32, 128), lambda i, j: (i, j)),
+            scratch_shapes=[pltpu.VMEM((32, 128), dtype=jnp.float32)],
             grid=(4, 4),
         ),
     )
@@ -2361,6 +2359,10 @@ class PallasCallTPUCheckifyTest(PallasBaseTest):
     )
 
 
+class PallasCallTPUCheckifyInterpretTest(PallasCallTPUCheckifyTest):
+  INTERPRET: bool = True
+
+
 class MiscellaneousInterpreterTest(PallasBaseTest):
   """Tests for recently reported bugs; only pass in interpret mode."""
 
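
A minimal usage sketch of the new non-interpret TPU checkify path, modeled on the usage shown in the updated tests; it is not part of this change, and the kernel body, shapes, and names below are illustrative assumptions only.

import jax
import jax.numpy as jnp
from jax.experimental import checkify
from jax.experimental import pallas as pl

def kernel(x_ref, o_ref):
  # A check inside the kernel body; with this change it lowers for TPU
  # pallas_call instead of requiring interpret=True.
  checkify.check(x_ref[0, 0] < 1.0, "first element too large: {v}", v=x_ref[0, 0])
  o_ref[...] = x_ref[...] + 1.0

x = jnp.full((8, 128), 0.5, dtype=jnp.float32)
call = pl.pallas_call(kernel, out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype))
checked_call = checkify.checkify(call, errors=checkify.all_checks)
err, out = checked_call(x)
err.throw()  # raises only if the in-kernel check failed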