Merge pull request #22275 from gnecula:pallas_interpret

PiperOrigin-RevId: 650607250
jax authors 2024-07-09 06:35:39 -07:00
commit 4b260cdc6b
4 changed files with 205 additions and 156 deletions


@@ -104,11 +104,11 @@ the overall array). Each block index is then multiplied by the
corresponding axis size from `block_shape`
to get the actual element index on the corresponding array axis.
```{note}
This documentation applies to the case when the block shape divides
the array shape.
The documentation for the other cases is pending.
```
If the block shape does not evenly divide the overall shape, then the
last iteration on each axis still receives references to blocks
of `block_shape`, but the out-of-bounds elements are padded
on input and discarded on output. Note that at least one element
of each block must be within bounds.
More precisely, the slices for each axis of the input `x` of
shape `x_shape` are computed as in the function `slices_for_invocation`
@@ -125,10 +125,9 @@ below:
... assert len(x_shape) == len(x_spec.block_shape) == len(block_indices)
... elem_indices = []
... for x_size, block_size, block_idx in zip(x_shape, x_spec.block_shape, block_indices):
... assert block_size <= x_size # Blocks must not be larger than the array
... start_idx = block_idx * block_size
... # For now, we document only the case when the entire iteration is in bounds
... assert start_idx + block_size <= x_size
... # At least one element of the block must be within bounds
... assert start_idx < x_size
... elem_indices.append(slice(start_idx, start_idx + block_size))
... return elem_indices
@@ -139,15 +138,22 @@ For example:
>>> slices_for_invocation(x_shape=(100, 100),
... x_spec = pl.BlockSpec((10, 20), lambda i, j: (i, j)),
... grid = (10, 5),
... invocation_indices = (2, 3))
[slice(20, 30, None), slice(60, 80, None)]
... invocation_indices = (2, 4))
[slice(20, 30, None), slice(80, 100, None)]
>>> # Same array and block shapes, but we iterate over each block 4 times
>>> slices_for_invocation(x_shape=(100, 100),
... x_spec = pl.BlockSpec((10, 20), lambda i, j, k: (i, j)),
... grid = (10, 5, 4),
... invocation_indices = (2, 3, 0))
[slice(20, 30, None), slice(60, 80, None)]
... invocation_indices = (2, 4, 0))
[slice(20, 30, None), slice(80, 100, None)]
>>> # An example where the block is partially out of bounds in the 2nd axis.
>>> slices_for_invocation(x_shape=(100, 90),
... x_spec = pl.BlockSpec((10, 20), lambda i, j: (i, j)),
... grid = (10, 5),
... invocation_indices = (2, 4))
[slice(20, 30, None), slice(80, 100, None)]
```
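
For reference, here is a self-contained version of `slices_for_invocation` assembled from the fragments in the hunks above. The opening lines of the function (the derivation of `block_indices` from the spec's index map) are not visible in this diff, so that part is a reconstruction, and `BlockSpec` is stubbed so the sketch runs without Pallas.

```python
# A self-contained sketch of slices_for_invocation, assembled from the
# diff fragments above. BlockSpec is stubbed so the sketch runs without
# Pallas; the derivation of block_indices is reconstructed.
from collections import namedtuple

BlockSpec = namedtuple("BlockSpec", ["block_shape", "index_map"])

def slices_for_invocation(*, x_shape, x_spec, grid, invocation_indices):
  assert len(invocation_indices) == len(grid)
  block_indices = x_spec.index_map(*invocation_indices)
  assert len(x_shape) == len(x_spec.block_shape) == len(block_indices)
  elem_indices = []
  for x_size, block_size, block_idx in zip(x_shape, x_spec.block_shape,
                                           block_indices):
    assert block_size <= x_size  # Blocks must not be larger than the array
    start_idx = block_idx * block_size
    # At least one element of the block must be within bounds
    assert start_idx < x_size
    elem_indices.append(slice(start_idx, start_idx + block_size))
  return elem_indices

print(slices_for_invocation(x_shape=(100, 90),
                            x_spec=BlockSpec((10, 20), lambda i, j: (i, j)),
                            grid=(10, 5), invocation_indices=(2, 4)))
# [slice(20, 30, None), slice(80, 100, None)]
```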
@@ -186,6 +192,20 @@ For example:
[30 30 30 31 31 31]
[30 30 30 31 31 31]]
>>> # An example with out-of-bounds accesses
>>> show_invocations(x_shape=(7, 5), block_shape=(2, 3), grid=(4, 2))
[[ 0 0 0 1 1]
[ 0 0 0 1 1]
[10 10 10 11 11]
[10 10 10 11 11]
[20 20 20 21 21]
[20 20 20 21 21]
[30 30 30 31 31]]
>>> # The array shape may even be smaller than block_shape
>>> show_invocations(x_shape=(1, 2), block_shape=(2, 3), grid=(1, 1))
[[0 0]]
```
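
`show_invocations` itself is not defined in this diff; the following NumPy sketch reproduces the outputs above, assuming each invocation `(i, j)` fills its block with `10*i + j` and that out-of-bounds elements are simply discarded on output (modeled here by clipping the slices).

```python
import numpy as np

def show_invocations(*, x_shape, block_shape, grid):
  # Hypothetical model: invocation (i, j) writes 10*i + j into its block;
  # out-of-bounds elements are discarded, modeled by clipping the slices.
  x = np.zeros(x_shape, dtype=np.int32)
  for i in range(grid[0]):
    for j in range(grid[1]):
      rows = slice(i * block_shape[0], min((i + 1) * block_shape[0], x_shape[0]))
      cols = slice(j * block_shape[1], min((j + 1) * block_shape[1], x_shape[1]))
      x[rows, cols] = 10 * i + j
  print(x)

show_invocations(x_shape=(7, 5), block_shape=(2, 3), grid=(4, 2))
```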
When multiple invocations write to the same elements of the output


@@ -35,7 +35,6 @@ from jax._src.interpreters import ad
from jax._src.interpreters import batching
from jax._src.interpreters import mlir
from jax._src.interpreters import partial_eval as pe
from jax._src.interpreters import xla
from jax._src.pallas import core as pallas_core
from jax._src.pallas.primitives import uninitialized_value
from jax._src.state import discharge as state_discharge
@@ -164,142 +163,140 @@ def _get_next_indices(grid, indices):
next_indices.append(jnp.where(carry, 0, i))
return tuple(reversed(next_indices))
def _pallas_call_impl(*args, jaxpr, name, out_shapes,
interpret, debug: bool,
in_shapes,
input_output_aliases: tuple[tuple[int, int], ...],
grid_mapping: GridMapping,
compiler_params: Any):
def _pallas_call_impl(*args, **kwargs):
assert False # We always jit a pallas call, we only need the lowering rule
def _pallas_call_impl_interpret(
*args,
jaxpr: jax_core.Jaxpr,
name: str, in_shapes, out_shapes,
debug: bool,
input_output_aliases: tuple[tuple[int, int], ...],
grid_mapping: GridMapping,
compiler_params: Any):
del compiler_params, name, in_shapes
# If we're in interpreter mode, we *scan* over the grid and eval the
# discharged jaxpr.
dynamic_grid_args, args = split_list( # type: ignore
args, [grid_mapping.num_dynamic_grid_bounds]
)
if interpret:
# If we're in interpreter mode, we *scan* over the grid and eval the
# discharged jaxpr. This should reproduce exactly what compiling to Triton
# will do.
dynamic_grid_args_iter = iter(dynamic_grid_args)
grid = tuple(
a if a is not pallas_core.dynamic_grid_dim
else next(dynamic_grid_args_iter)
for a in grid_mapping.grid
dynamic_grid_args_iter = iter(dynamic_grid_args)
grid = tuple(
a if a is not pallas_core.dynamic_grid_dim
else next(dynamic_grid_args_iter)
for a in grid_mapping.grid
)
assert next(dynamic_grid_args_iter, None) is None
with grid_mapping.trace_env():
discharged_jaxpr, consts = state_discharge.discharge_state(jaxpr, ())
if debug:
print(discharged_jaxpr)
out = _initialize_output_vals(out_shapes, args, input_output_aliases)
scalars, args = split_list(args, [grid_mapping.num_index_operands]) # type: ignore
# invars: [*scalar_prefetch, *inputs, *outputs, *scratch]
num_invars = len(jaxpr.invars)
num_inputs_outputs = (
num_invars
- grid_mapping.num_index_operands
- grid_mapping.num_scratch_operands
)
_, _, scratch_invars = split_list(
jaxpr.invars, [grid_mapping.num_index_operands, num_inputs_outputs]
)
scratch_avals = [v.aval for v in scratch_invars]
scratch_values = _initialize_scratch_vals(scratch_avals)
carry = []
for x, bm in zip(itertools.chain(args, out), grid_mapping.block_mappings):
if bm is not None and isinstance(bm.indexing_mode, pallas_core.Unblocked):
padding = bm.indexing_mode.padding
if padding is not None and any(p != (0, 0) for p in padding):
if input_output_aliases:
raise NotImplementedError("Padding with aliasing not supported.")
x = lax.pad(x, jnp.zeros((), x.dtype), [(*p, 0) for p in padding])
carry.append(x)
block_shapes_without_mapped_dims = [
None if block_mapping is None else block_mapping.block_shape
for block_mapping in grid_mapping.block_mappings
]
is_indexing_dim = [
None if bm is None else tuple(b is pallas_core.mapped for b in bm)
for bm in block_shapes_without_mapped_dims
]
block_shapes = [
None if (bm is None or iid is None)
else tuple(1 if i else b for i, b in zip(iid, bm))
for iid, bm in zip(is_indexing_dim, block_shapes_without_mapped_dims)
]
# Pad values to evenly divide into block dimensions. This matches the
# behavior of the non-interpret mode. We pad with NaN, to make it easier
# to catch OOB accesses.
carry = map(_pad_values_to_block_dimension, carry, block_shapes)
carry.extend(scratch_values)
num_inout = len(args) + len(out)
grid_start_indices = (jnp.int32(0),) * len(grid)
if grid:
num_iterations = reduce(jnp.multiply, grid)
else:
# Base case is always one iteration when grid is ()
num_iterations = 1
def cond(carry):
i, *_ = carry
return i < num_iterations
def body(carry):
i, loop_idx, *carry = carry
local_grid_env = tuple(
pallas_core.GridAxis(idx, b)
for dim, (idx, b) in enumerate(zip(loop_idx, grid))
if dim not in grid_mapping.mapped_dims
)
assert next(dynamic_grid_args_iter, None) is None
with grid_mapping.trace_env():
discharged_jaxpr, consts = state_discharge.discharge_state(jaxpr, ())
if debug:
print(discharged_jaxpr)
out = _initialize_output_vals(out_shapes, args, input_output_aliases)
scalars, args = split_list(args, [grid_mapping.num_index_operands]) # type: ignore
# invars: [*scalar_prefetch, *inputs, *outputs, *scratch]
num_invars = len(jaxpr.invars)
num_inputs_outputs = (
num_invars
- grid_mapping.num_index_operands
- grid_mapping.num_scratch_operands
)
_, _, scratch_invars = split_list(
jaxpr.invars, [grid_mapping.num_index_operands, num_inputs_outputs]
)
scratch_avals = [v.aval for v in scratch_invars]
scratch_values = _initialize_scratch_vals(scratch_avals)
carry = []
for x, bm in zip(itertools.chain(args, out), grid_mapping.block_mappings):
if bm is not None and isinstance(bm.indexing_mode, pallas_core.Unblocked):
padding = bm.indexing_mode.padding
if padding is not None and any(p != (0, 0) for p in padding):
if input_output_aliases:
raise NotImplementedError("Padding with aliasing not supported.")
x = lax.pad(x, jnp.zeros((), x.dtype), [(*p, 0) for p in padding])
carry.append(x)
block_shapes_without_mapped_dims = [
None if block_mapping is None else block_mapping.block_shape
for block_mapping in grid_mapping.block_mappings
]
is_indexing_dim = [
None if bm is None else tuple(b is pallas_core.mapped for b in bm)
for bm in block_shapes_without_mapped_dims
]
block_shapes = [
None if (bm is None or iid is None)
else tuple(1 if i else b for i, b in zip(iid, bm))
for iid, bm in zip(is_indexing_dim, block_shapes_without_mapped_dims)
]
# Pad values to evenly divide into block dimensions.
# This allows interpret mode to catch errors on OOB memory accesses
# by poisoning values with NaN. It also fixes an inconsistency with
# lax.dynamic_slice where if the slice goes out of bounds, it will instead
# move the start_index backwards so the slice will fit in memory.
carry = map(_pad_values_to_block_dimension, carry, block_shapes)
carry.extend(scratch_values)
num_inout = len(args) + len(out)
grid_start_indices = (jnp.int32(0),) * len(grid)
if grid:
num_iterations = reduce(jnp.multiply, grid)
else:
# Base case is always one iteration when grid is ()
num_iterations = 1
def cond(carry):
i, *_ = carry
return i < num_iterations
def body(carry):
i, loop_idx, *carry = carry
local_grid_env = tuple(
pallas_core.GridAxis(idx, b)
for dim, (idx, b) in enumerate(zip(loop_idx, grid))
if dim not in grid_mapping.mapped_dims
carry, scratch = split_list(carry, [num_inout])
with pallas_core.grid_env(local_grid_env):
start_indices = [
None if bm is None else bm.compute_start_indices(loop_idx, *scalars)
for bm in grid_mapping.block_mappings]
blocks = map(_maybe_dynamic_slice, start_indices, block_shapes, carry,
is_indexing_dim)
with pallas_core.grid_env(local_grid_env):
assert len(discharged_jaxpr.invars) == len(scalars) + len(blocks) + len(
scratch_values
), (
len(discharged_jaxpr.invars),
len(scalars),
len(blocks),
len(scratch_values),
)
carry, scratch = split_list(carry, [num_inout])
with pallas_core.grid_env(local_grid_env):
start_indices = [
None if bm is None else bm.compute_start_indices(loop_idx, *scalars)
for bm in grid_mapping.block_mappings]
blocks = map(_maybe_dynamic_slice, start_indices, block_shapes, carry,
is_indexing_dim)
with pallas_core.grid_env(local_grid_env):
assert len(discharged_jaxpr.invars) == len(scalars) + len(blocks) + len(
scratch_values
), (
len(discharged_jaxpr.invars),
len(scalars),
len(blocks),
len(scratch_values),
)
blocks = jax.core.eval_jaxpr(discharged_jaxpr, consts, *scalars,
*blocks, *scratch)
blocks = blocks[grid_mapping.num_index_operands:]
blocks, out_scratch = split_list(blocks, [num_inout])
carry = map(_maybe_dynamic_update_slice, start_indices, block_shapes,
carry, blocks, is_indexing_dim)
return (i + 1, _get_next_indices(grid, loop_idx), *carry, *out_scratch)
(_, _, *carry) = lax.while_loop(
cond, body, (jnp.int32(0), grid_start_indices, *carry)
)
_, out, _ = split_list(carry, [len(args), len(out)])
assert len(grid_mapping.block_mappings) == len(args) + len(out)
out_block_mappings = grid_mapping.block_mappings[len(args):]
out_nopad = []
for o, bm in zip(out, out_block_mappings):
if bm is not None and isinstance(bm.indexing_mode, pallas_core.Unblocked):
padding = bm.indexing_mode.padding
if padding is not None and any(p != (0, 0) for p in padding):
if input_output_aliases:
raise NotImplementedError("Padding with aliasing not supported.")
pad_low, pad_high = zip(*padding)
limit_indices = [s - p for s, p in zip(o.shape, pad_high)]
o = lax.slice(o, pad_low, limit_indices)
out_nopad.append(o)
return out_nopad
return xla.apply_primitive(pallas_call_p, *args, jaxpr=jaxpr, name=name,
in_shapes=in_shapes,
out_shapes=out_shapes,
grid_mapping=grid_mapping, interpret=interpret,
debug=debug,
input_output_aliases=input_output_aliases,
compiler_params=compiler_params)
blocks = jax.core.eval_jaxpr(discharged_jaxpr, consts, *scalars,
*blocks, *scratch)
blocks = blocks[grid_mapping.num_index_operands:]
blocks, out_scratch = split_list(blocks, [num_inout])
carry = map(_maybe_dynamic_update_slice, start_indices, block_shapes,
carry, blocks, is_indexing_dim)
return (i + 1, _get_next_indices(grid, loop_idx), *carry, *out_scratch)
(_, _, *carry) = lax.while_loop(
cond, body, (jnp.int32(0), grid_start_indices, *carry)
)
_, out, _ = split_list(carry, [len(args), len(out)])
assert len(grid_mapping.block_mappings) == len(args) + len(out)
out_block_mappings = grid_mapping.block_mappings[len(args):]
out_nopad = []
for o, expected_o_shape, bm in zip(out, out_shapes, out_block_mappings):
if bm is not None and isinstance(bm.indexing_mode, pallas_core.Unblocked):
padding = bm.indexing_mode.padding
if padding is not None and any(p != (0, 0) for p in padding):
if input_output_aliases:
raise NotImplementedError("Padding with aliasing not supported.")
pad_low, pad_high = zip(*padding)
limit_indices = [s - p for s, p in zip(o.shape, pad_high)]
o = lax.slice(o, pad_low, limit_indices)
if o.shape != expected_o_shape.shape:
o = lax.slice(o, (0,) * o.ndim, expected_o_shape.shape)
out_nopad.append(o)
return out_nopad
pallas_call_p.def_impl(_pallas_call_impl)
def _pallas_call_abstract_eval(*avals, out_shapes, **_):
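
In outline, the new `_pallas_call_impl_interpret` pads each operand up to a multiple of its block shape (poisoning the padding so out-of-bounds reads are visible), scans the grid with `lax.while_loop`, slicing one block per operand and evaluating the discharged jaxpr on each iteration, and finally slices the outputs back to their declared shapes. A much-simplified NumPy model of that loop, assuming an identity index map and a single float input and output (the real code also threads scalars, scratch buffers, and dynamic grid sizes through the loop carry):

```python
import itertools
import math
import numpy as np

def interpret_pallas_call(kernel, x, *, block_shape, grid, out_shape):
  # Pad the (float) input up to a multiple of the block shape, poisoning
  # the padding with NaN so out-of-bounds reads are easy to spot.
  pad = lambda shape: tuple(math.ceil(s / b) * b
                            for s, b in zip(shape, block_shape))
  xp = np.full(pad(x.shape), np.nan, dtype=x.dtype)
  xp[tuple(slice(0, s) for s in x.shape)] = x
  out = np.zeros(pad(out_shape), dtype=x.dtype)
  # One kernel invocation per grid point; identity index map assumed, so
  # invocation idx reads input block idx and writes output block idx.
  for idx in itertools.product(*map(range, grid)):
    sl = tuple(slice(i * b, (i + 1) * b) for i, b in zip(idx, block_shape))
    out[sl] = kernel(xp[sl])
  # Out-of-bounds output elements are discarded.
  return out[tuple(slice(0, s) for s in out_shape)]

y = interpret_pallas_call(lambda block: block + 1.0,
                          np.ones((7, 5), np.float32),
                          block_shape=(2, 3), grid=(4, 2), out_shape=(7, 5))
```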
@@ -954,7 +951,7 @@ def _pallas_call_lowering(
):
if interpret:
# If we are in interpret mode, we don't care what platform we are on.
impl = partial(_pallas_call_impl, **params, interpret=True)
impl = partial(_pallas_call_impl_interpret, **params)
return mlir.lower_fun(impl, multiple_results=True)(ctx, *in_nodes)
def cpu_lowering(ctx: mlir.LoweringRuleContext,
@@ -1080,6 +1077,7 @@ def pallas_call(
for a in flat_args)
flat_out_avals = tuple(jax_core.ShapedArray(v.shape, v.dtype)
for v in flat_out_shapes)
# TODO(necula): check that input_output_aliases is well-formed: shapes match, no duplicates, etc.
grid_mapping, jaxpr, consts, f_out_tree = _trace_to_jaxpr(
f, grid_spec, flat_in_avals, flat_out_avals, in_tree, in_paths,
out_tree, out_paths, interpret=interpret)


@@ -1471,23 +1471,31 @@ def parameterized_filterable(*,
testcase_name: Callable[[dict[str, Any]], str] | None = None,
one_containing: str | None = None,
):
"""
Decorator for named parameterized tests, with filtering.
"""Decorator for named parameterized tests, with filtering support.
Works like parameterized.named_parameters, except that it supports the
`one_containing` option. This is useful to select only one of the tests,
and to leave the test name unchanged (helps with specifying the desired test
when debugging).
Works like ``parameterized.named_parameters``, except that it sanitizes the test
names so that we can use ``pytest -k`` and ``python test.py -k`` test filtering.
This means, e.g., that many special characters are replaced with ``_``.
It also supports the ``one_containing`` argument to select one of the tests,
while leaving the name unchanged, which makes it easy for IDEs to pick up
the enclosing test name.
Usage:
@jtu.parameterized_filterable(
# one_containing="a_4",
[dict(a=4, b=5),
dict(a=5, b=4)])
def test_my_test(self, *, a, b): ...
Args:
kwargs: Each entry is a set of kwargs to be passed to the test function.
testcase_name: Optionally, a function to construct the testcase_name from
one kwargs dict. If not given then kwarg may contain `testcase_name` and
if not, the test case name is constructed as `str(kwarg)`.
one kwargs dict. If not given, then ``kwargs`` may contain ``testcase_name``;
failing that, the test case name is constructed as ``str(kwarg)``.
We sanitize the test names to work with -k test filters. See
`sanitize_test_name`.
one_containing: If given, then leave the test name unchanged, and use
only one `kwargs` whose `testcase_name` includes `one_containing`.
``sanitize_test_name``.
one_containing: If given, leaves the test name unchanged and uses
only the one ``kwargs`` whose ``testcase_name`` includes ``one_containing``.
"""
# Ensure that all kwargs contain a testcase_name
kwargs_with_testcase_name: Sequence[dict[str, Any]]
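
To make the filtering behavior concrete, here is a sketch of how the decorator would be used and what the sanitized names enable. The exact generated names and the import path are illustrative, not taken from this diff.

```python
from jax._src import test_util as jtu  # import path as used by JAX's own tests

class MyTest(jtu.JaxTestCase):
  @jtu.parameterized_filterable(
      # one_containing="a_4",
      [dict(a=4, b=5), dict(a=5, b=4)])
  def test_my_test(self, *, a, b):
    self.assertEqual(a + b, 9)

# The sanitized names would look roughly like test_my_test_a_4_b_5
# (illustrative), selectable with `pytest my_tests.py -k a_4_b_5` or
# `python my_tests.py -k a_4_b_5`. With one_containing="a_4" uncommented,
# only that case is kept, under the unchanged name test_my_test, which is
# convenient when launching a single test from an IDE.
```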


@@ -461,6 +461,29 @@ class PallasCallTest(PallasBaseTest):
np.testing.assert_allclose(softmax_kernel(x), jax.nn.softmax(x), atol=1e-7)
@jtu.skip_on_devices("gpu") # TODO: RET_CHECK failure
def test_block_spec_with_padding(self):
def f(*, shape, block_shape):
def kernel(o1_ref):
assert o1_ref.shape == block_shape
o1_ref[...] = jnp.full(o1_ref.shape, pl.program_id(0))
return self.pallas_call(kernel,
jax.ShapeDtypeStruct(shape, dtype=np.int32),
grid=((shape[0] + block_shape[0] - 1) // block_shape[0],),
out_specs=pl.BlockSpec(block_shape, lambda i: i))()
# No padding
pids = f(shape=(8,), block_shape=(2,))
self.assertAllClose(pids,
np.array([0, 0, 1, 1, 2, 2, 3, 3], dtype=np.int32))
# Pad the last block
pids = f(shape=(8,), block_shape=(3,))
self.assertAllClose(pids,
np.array([0, 0, 0, 1, 1, 1, 2, 2], dtype=np.int32))
# Works even if the shape is smaller than 1 block
pids = f(shape=(3,), block_shape=(8,))
self.assertAllClose(pids,
np.array([0, 0, 0], dtype=np.int32))
class PallasCallInterpreterTest(PallasCallTest):
INTERPRET = True
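
The grid expression in the test's `f` is plain ceiling division, which is why the last block is the one that gets padded. Pallas exposes the same computation as the public helper `pl.cdiv`, which the test could use instead; a sketch of the equivalence:

```python
import jax.experimental.pallas as pl

shape, block_shape = (8,), (3,)
# (8 + 3 - 1) // 3 == ceil(8 / 3) == 3 blocks, the last one padded.
assert (shape[0] + block_shape[0] - 1) // block_shape[0] \
    == pl.cdiv(shape[0], block_shape[0]) == 3
```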