mirror of
https://github.com/ROCm/jax.git
synced 2025-04-15 19:36:06 +00:00
[Pallas] Add support for checkify in TPU execution mode.
PiperOrigin-RevId: 653045818
This commit is contained in:
parent
7069a5a2e1
commit
6ba889c01c
@ -35,6 +35,7 @@ Remember to align the itemized text with the first line of an item within a list
|
||||
and `lax.erf_inv` ({jax-issue}`#22310`).
|
||||
* Added initial support for shape polymorphism for the Pallas TPU custom kernels\
|
||||
({jax-issue}`#22084`).
|
||||
* Added TPU support for checkify. ({jax-issue}`#22480`)
|
||||
|
||||
## Released with JAX 0.4.30 (June 18, 2024)
|
||||
|
||||
|
@ -19,6 +19,7 @@ from collections.abc import Callable, Iterator, Sequence
|
||||
import contextlib
|
||||
import copy
|
||||
import dataclasses
|
||||
import enum
|
||||
import functools
|
||||
import threading
|
||||
from typing import Any, Union
|
||||
@ -86,6 +87,19 @@ class AbstractMemoryRef(state.AbstractRef):
|
||||
return hash((self.__class__, self.inner_aval, self.memory_space))
|
||||
|
||||
|
||||
class MemorySpace(enum.Enum):
|
||||
""" Logical, device-agnostic memory spaces.
|
||||
|
||||
Each memory space will be translated to a device-specific memory
|
||||
type during lowering.
|
||||
"""
|
||||
ERROR = "error" # Memory space for checkify errors.
|
||||
INDEX = "index" # Memory space for scalar prefetch arguments.
|
||||
|
||||
def __str__(self) -> str:
|
||||
return self.value
|
||||
|
||||
|
||||
def _ref_raise_to_shaped(ref_aval: AbstractMemoryRef, weak_type):
|
||||
return AbstractMemoryRef(
|
||||
jax_core.raise_to_shaped(ref_aval.inner_aval, weak_type),
|
||||
|
@ -68,6 +68,7 @@ import numpy as np
|
||||
|
||||
NDIndexer = indexing.NDIndexer
|
||||
TPUMemorySpace = tpu_core.TPUMemorySpace
|
||||
MemorySpace = pl_core.MemorySpace | TPUMemorySpace
|
||||
VMEM = tpu_core.TPUMemorySpace.VMEM
|
||||
SMEM = tpu_core.TPUMemorySpace.SMEM
|
||||
# Booleans are stored as the following type in memrefs.
|
||||
@ -117,10 +118,14 @@ class LoweringRuleContext:
|
||||
replace = dataclasses.replace
|
||||
|
||||
|
||||
def _memory_space_to_tpu_memspace(memory_space: TPUMemorySpace | None
|
||||
def _memory_space_to_tpu_memspace(memory_space: MemorySpace | None
|
||||
) -> ir.Attribute:
|
||||
if memory_space is None:
|
||||
memory_space = VMEM
|
||||
elif memory_space == pl_core.MemorySpace.ERROR:
|
||||
memory_space = SMEM
|
||||
elif memory_space == pl_core.MemorySpace.INDEX:
|
||||
memory_space = SMEM
|
||||
return ir.Attribute.parse(f"#tpu.memory_space<{memory_space}>")
|
||||
|
||||
def _dtype_to_ir_type(dtype: jnp.dtype,
|
||||
@ -146,7 +151,7 @@ def _dtype_to_ir_type(dtype: jnp.dtype,
|
||||
|
||||
def aval_to_ir_type(aval,
|
||||
shape=None,
|
||||
memory_space: TPUMemorySpace | None = None,
|
||||
memory_space: MemorySpace | None = None,
|
||||
is_kernel_boundary: bool = False):
|
||||
if isinstance(aval, tpu_core.AbstractSemaphore):
|
||||
if aval.sem_type is tpu_core.SemaphoreType.DMA:
|
||||
|
@ -717,10 +717,6 @@ def pallas_call_checkify_rule(error: checkify.Error,
|
||||
grid_mapping: GridMapping,
|
||||
out_shapes,
|
||||
**kwargs):
|
||||
# TODO(b/346651778): Support TPU/GPU checkify.
|
||||
if not interpret:
|
||||
raise NotImplementedError(
|
||||
"Checkify for pallas_call only supports interpret mode.")
|
||||
# We implement the checkify rule in 4 steps:
|
||||
# 1) First, trace the kernel body to get the expected error shapes.
|
||||
# 2) Checkify the kernel body to obtain a jaxpr with errors as inputs
|
||||
@ -749,7 +745,7 @@ def pallas_call_checkify_rule(error: checkify.Error,
|
||||
_jaxpr, _, error_effects = checkify_pallas_kernel_body_jaxpr(
|
||||
closed_jaxpr, enabled_errors, error, grid_mapping)
|
||||
error = error._add_placeholder_effects(error_effects)
|
||||
err_vals, err_tree = jax.tree.flatten(error)
|
||||
err_vals, err_in_tree = jax.tree.flatten(error)
|
||||
shaped_err_avals = map(checkify.get_shaped_aval, err_vals)
|
||||
|
||||
# Trace the kernel jaxpr to get a checkified jaxpr. This jaxpr will have
|
||||
@ -761,19 +757,20 @@ def pallas_call_checkify_rule(error: checkify.Error,
|
||||
*shaped_input_avals]
|
||||
closed_kernel_jaxpr = pe.close_jaxpr(jaxpr)
|
||||
with pallas_core.tracing_grid_env(grid_mapping.grid, ()):
|
||||
checked_jaxpr, out_tree, _ = checkify.jaxpr_to_checkify_jaxpr(
|
||||
closed_kernel_jaxpr, enabled_errors, err_tree, *checkify_in_avals)
|
||||
checked_jaxpr, error_out_tree, _ = checkify.jaxpr_to_checkify_jaxpr(
|
||||
closed_kernel_jaxpr, enabled_errors, err_in_tree, *checkify_in_avals)
|
||||
|
||||
# Create a new kernel to remove the error as an return value and instead
|
||||
# write them to a memref. This is because pallas kernels are expected
|
||||
# to have no return values but instead write their outputs to a ref.
|
||||
def checked_kernel_fn(*args):
|
||||
(scalars, _, inputs, out_error_refs, outputs, scratch
|
||||
(scalars, in_error_refs, inputs, out_error_refs, outputs, scratch
|
||||
) = split_list(
|
||||
args,
|
||||
[num_scalars, num_err_vals,
|
||||
num_kernel_inputs, num_err_vals, num_kernel_outputs])
|
||||
input_error_vals = [err_ref[...] for err_ref in out_error_refs]
|
||||
# TODO(b/350593266): Remove zero-indexing once we support ()-shaped scalars.
|
||||
input_error_vals = [err_ref[0, 0] for err_ref in in_error_refs]
|
||||
# We need to re-order the inputs here. A checkified jaxpr always expects
|
||||
# errors before other arguments.
|
||||
jaxpr_args = [*input_error_vals, *scalars, *inputs, *outputs, *scratch]
|
||||
@ -782,17 +779,33 @@ def pallas_call_checkify_rule(error: checkify.Error,
|
||||
checked_jaxpr.jaxpr, checked_jaxpr.consts, *jaxpr_args)
|
||||
output_errors, _ = split_list(result_flat, [num_err_vals])
|
||||
# Store new errors back in the error refs.
|
||||
for out_ref, error in zip(out_error_refs, output_errors):
|
||||
out_ref[...] = error
|
||||
for in_ref, out_ref, error in zip(
|
||||
in_error_refs, out_error_refs, output_errors):
|
||||
in_ref[0, 0] = error
|
||||
out_ref[0, 0] = error
|
||||
return []
|
||||
|
||||
# Trace the new checked_kernel_fn with Memref inputs so that
|
||||
# we can replace the old kernel jaxpr with the new checked jaxpr in
|
||||
# pallas_call.
|
||||
# TODO(justinfu): Place errors in scalar memory for non-interpret mode.
|
||||
error_mem_space = None
|
||||
|
||||
# ensure_2d_shape is only necessary because pallas does not support
|
||||
# ()-shaped Memrefs.
|
||||
# TODO(b/350593266): Remove once we support ()-shaped scalars.
|
||||
def _ensure_2d_error_shape(arg):
|
||||
if isinstance(arg, jax_core.ShapedArray):
|
||||
dtype = arg.dtype
|
||||
return jax_core.ShapedArray((1, 1) + arg.shape, dtype=dtype,
|
||||
weak_type=arg.weak_type)
|
||||
elif isinstance(arg, jax.Array):
|
||||
return jnp.reshape(arg, (1, 1) + arg.shape)
|
||||
else:
|
||||
return jnp.array([[arg]])
|
||||
shaped_err_avals = map(_ensure_2d_error_shape, shaped_err_avals)
|
||||
err_vals = map(_ensure_2d_error_shape, err_vals)
|
||||
|
||||
error_memref_aval = [pallas_core.AbstractMemoryRef(
|
||||
err_val, error_mem_space) for err_val in shaped_err_avals]
|
||||
err_val, pallas_core.MemorySpace.ERROR) for err_val in shaped_err_avals]
|
||||
shaped_scalar_avals, input_aval, output_aval, scratch_aval = split_list(
|
||||
shaped_input_avals, [num_scalars, num_kernel_inputs, num_kernel_outputs])
|
||||
retrace_in_avals = [*shaped_scalar_avals, *error_memref_aval, *input_aval,
|
||||
@ -809,14 +822,17 @@ def pallas_call_checkify_rule(error: checkify.Error,
|
||||
# Prepare pallas_call inputs. We need to create new block specs
|
||||
# for the new error inputs and outputs.
|
||||
scalar_avals = map(checkify.get_shaped_aval, scalars)
|
||||
error_block_specs = [no_block_spec] * num_err_vals
|
||||
error_block_specs = [pallas_core.BlockSpec(
|
||||
index_map=lambda *args: (0,) * len(error.shape),
|
||||
block_shape=error.shape)
|
||||
for error in shaped_err_avals]
|
||||
error_paths, _ = unzip2(tree_util.tree_flatten_with_path(error_block_specs)[0])
|
||||
grid_avals = [
|
||||
jax_core.ShapedArray((), jnp.dtype("int32"))] * len(grid_mapping.grid)
|
||||
# TODO(justinfu): Place these in device-specific scalar memory.
|
||||
scalar_ref_avals = [
|
||||
pallas_core.AbstractMemoryRef(
|
||||
jax_core.ShapedArray(aval.shape, aval.dtype), None)
|
||||
jax_core.ShapedArray(aval.shape, aval.dtype),
|
||||
pallas_core.MemorySpace.INDEX)
|
||||
for aval in scalar_avals]
|
||||
grid_tree = tree_util.tree_structure(((*grid_avals, *scalar_avals), {}))
|
||||
error_block_mappings = map(
|
||||
@ -844,15 +860,22 @@ def pallas_call_checkify_rule(error: checkify.Error,
|
||||
(i+num_scalars, i) for i in range(num_err_vals)) + input_output_aliases
|
||||
|
||||
new_vals_in = [*scalars, *err_vals, *args]
|
||||
new_input_shapes = tuple(
|
||||
jax.ShapeDtypeStruct(x.shape, x.dtype) for x in [
|
||||
*scalars, *shaped_err_avals, *args])
|
||||
del kwargs['in_shapes']
|
||||
result = pallas_call_p.bind(*dynamic_grid_bounds, *new_vals_in,
|
||||
jaxpr=final_jaxpr,
|
||||
interpret=interpret,
|
||||
grid_mapping=grid_mapping_with_error,
|
||||
input_output_aliases=input_output_aliases_with_error,
|
||||
in_shapes=new_input_shapes,
|
||||
out_shapes=error_out_shapes + out_shapes,
|
||||
**kwargs)
|
||||
errors, results = split_list(result, [num_err_vals])
|
||||
new_error, _ = jax.tree.unflatten(out_tree, errors)
|
||||
# TODO(b/350593266): Remove line below once we support ()-shaped scalars.
|
||||
errors = [err_val[0, 0] for err_val in errors]
|
||||
new_error, _ = jax.tree.unflatten(error_out_tree, errors)
|
||||
return new_error, results
|
||||
checkify.error_checks[pallas_call_p] = pallas_call_checkify_rule
|
||||
|
||||
|
@ -1796,10 +1796,12 @@ class PallasCheckifyInterpreterTest(PallasBaseTest):
|
||||
value < fail_iteration, "failed on loop {itr}", itr=value)
|
||||
input_arr = jnp.arange(num_loops, dtype=jnp.float32)
|
||||
in_specs = [pl.BlockSpec((1,), lambda x: (x,))]
|
||||
out_specs = pl.BlockSpec((1,), lambda x: (x,))
|
||||
out_shape = jax.ShapeDtypeStruct((1,), dtype=jnp.float32)
|
||||
pallas_call = self.pallas_call(kernel,
|
||||
grid=(num_loops,),
|
||||
in_specs=in_specs,
|
||||
out_specs=out_specs,
|
||||
out_shape=out_shape)
|
||||
|
||||
checked_call = checkify.checkify(pallas_call,
|
||||
|
@ -2253,13 +2253,11 @@ class PallasCallTPUBooleanTest(PallasBaseTest):
|
||||
)(input_arr)
|
||||
|
||||
|
||||
|
||||
class PallasCallTPUBooleanInterpretTest(PallasCallTPUBooleanTest):
|
||||
INTERPRET: bool = True
|
||||
|
||||
|
||||
class PallasCallTPUCheckifyTest(PallasBaseTest):
|
||||
INTERPRET: bool = True
|
||||
|
||||
@parameterized.parameters((2,), (5,), (6,), (7,))
|
||||
def test_checkify_with_scalar_prefetch(self, threshold):
|
||||
@ -2305,17 +2303,17 @@ class PallasCallTPUCheckifyTest(PallasBaseTest):
|
||||
checkify.check(all_nequal, 'x_ref equals o_ref id=({x}, {y})',
|
||||
x=pl.program_id(0), y=pl.program_id(1))
|
||||
|
||||
x = jax.random.uniform(jax.random.key(0), (128, 128), dtype=jnp.float32)
|
||||
x = jax.random.uniform(jax.random.key(0), (128, 512), dtype=jnp.float32)
|
||||
pallas_call = self.pallas_call(
|
||||
body,
|
||||
out_shape=jax.ShapeDtypeStruct(x.shape, jnp.float32),
|
||||
grid_spec=pltpu.PrefetchScalarGridSpec(
|
||||
num_scalar_prefetch=0,
|
||||
in_specs=[
|
||||
pl.BlockSpec((32, 32), lambda i, j: (i, j)),
|
||||
pl.BlockSpec((32, 128), lambda i, j: (i, j)),
|
||||
],
|
||||
out_specs=pl.BlockSpec((32, 32), lambda i, j: (i, j)),
|
||||
scratch_shapes=[pltpu.VMEM((32, 32), dtype=jnp.float32)],
|
||||
out_specs=pl.BlockSpec((32, 128), lambda i, j: (i, j)),
|
||||
scratch_shapes=[pltpu.VMEM((32, 128), dtype=jnp.float32)],
|
||||
grid=(4, 4),
|
||||
),
|
||||
)
|
||||
@ -2361,6 +2359,10 @@ class PallasCallTPUCheckifyTest(PallasBaseTest):
|
||||
)
|
||||
|
||||
|
||||
class PallasCallTPUCheckifyInterpretTest(PallasCallTPUCheckifyTest):
|
||||
INTERPRET: bool = True
|
||||
|
||||
|
||||
class MiscellaneousInterpreterTest(PallasBaseTest):
|
||||
"""Tests for recently reported bugs; only pass in interpret mode."""
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user