Merge pull request #22275 from gnecula:pallas_interpret

PiperOrigin-RevId: 650607250
Commit: 4b260cdc6b
@ -104,11 +104,11 @@ the overall array). Each block index is then multiplied by the
corresponding axis size from `block_shape`
to get the actual element index on the corresponding array axis.

```{note}
This documentation applies to the case when the block shape divides
the array shape.
The documentation for the other cases is pending.
```
If the block shape does not evenly divide the overall shape then the
last iteration on each axis will still receive references to blocks
of `block_shape`, but the elements that are out-of-bounds are padded
on input and discarded on output. Note that at least one of the
elements in each block must be within bounds.

More precisely, the slices for each axis of the input `x` of
shape `x_shape` are computed as in the function `slices_for_invocation`
@ -125,10 +125,9 @@ below:
...   assert len(x_shape) == len(x_spec.block_shape) == len(block_indices)
...   elem_indices = []
...   for x_size, block_size, block_idx in zip(x_shape, x_spec.block_shape, block_indices):
...     assert block_size <= x_size  # Blocks must be no larger than the array
...     start_idx = block_idx * block_size
...     # For now, we document only the case when the entire iteration is in bounds
...     assert start_idx + block_size <= x_size
...     # At least one element of the block must be within bounds
...     assert start_idx < x_size
...     elem_indices.append(slice(start_idx, start_idx + block_size))
...   return elem_indices

@ -139,15 +138,22 @@ For example:
>>> slices_for_invocation(x_shape=(100, 100),
...                       x_spec = pl.BlockSpec((10, 20), lambda i, j: (i, j)),
...                       grid = (10, 5),
...                       invocation_indices = (2, 3))
[slice(20, 30, None), slice(60, 80, None)]
...                       invocation_indices = (2, 4))
[slice(20, 30, None), slice(80, 100, None)]

>>> # Same shape of the array and blocks, but we iterate over each block 4 times
>>> slices_for_invocation(x_shape=(100, 100),
...                       x_spec = pl.BlockSpec((10, 20), lambda i, j, k: (i, j)),
...                       grid = (10, 5, 4),
...                       invocation_indices = (2, 3, 0))
[slice(20, 30, None), slice(60, 80, None)]
...                       invocation_indices = (2, 4, 0))
[slice(20, 30, None), slice(80, 100, None)]

>>> # An example when the block is partially out-of-bounds in the 2nd axis.
>>> slices_for_invocation(x_shape=(100, 90),
...                       x_spec = pl.BlockSpec((10, 20), lambda i, j: (i, j)),
...                       grid = (10, 5),
...                       invocation_indices = (2, 4))
[slice(20, 30, None), slice(80, 100, None)]

```
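To run the doctest above as a standalone script, here is a hedged, self-contained sketch of `slices_for_invocation`; the `def` line falls outside this excerpt, so the signature and the use of `x_spec.index_map` and `x_spec.block_shape` are reconstructed from the doctest body and the calls shown:

```python
from jax.experimental import pallas as pl

def slices_for_invocation(x_shape, x_spec, grid, invocation_indices):
  assert len(invocation_indices) == len(grid)
  # The index map turns invocation (grid) indices into block indices.
  block_indices = x_spec.index_map(*invocation_indices)
  elem_indices = []
  for x_size, block_size, block_idx in zip(
      x_shape, x_spec.block_shape, block_indices):
    start_idx = block_idx * block_size
    # At least one element of the block must be within bounds.
    assert start_idx < x_size
    elem_indices.append(slice(start_idx, start_idx + block_size))
  return elem_indices

print(slices_for_invocation(x_shape=(100, 90),
                            x_spec=pl.BlockSpec((10, 20), lambda i, j: (i, j)),
                            grid=(10, 5),
                            invocation_indices=(2, 4)))
# Expected, matching the partially out-of-bounds example above:
# [slice(20, 30, None), slice(80, 100, None)]
```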

@ -186,6 +192,20 @@ For example:
 [30 30 30 31 31 31]
 [30 30 30 31 31 31]]

>>> # An example with out-of-bounds accesses
>>> show_invocations(x_shape=(7, 5), block_shape=(2, 3), grid=(4, 2))
[[ 0  0  0  1  1]
 [ 0  0  0  1  1]
 [10 10 10 11 11]
 [10 10 10 11 11]
 [20 20 20 21 21]
 [20 20 20 21 21]
 [30 30 30 31 31]]

>>> # It is allowed for the shape to be smaller than block_shape
>>> show_invocations(x_shape=(1, 2), block_shape=(2, 3), grid=(1, 1))
[[0 0]]

```
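`show_invocations` itself is not reproduced in this excerpt. As a model of the semantics only (an assumption, not the `pallas_call`-based helper the docs use), the following NumPy sketch reproduces the grids above by writing the invocation id `i * 10 + j` into each block, clipped to the array bounds:

```python
import numpy as np

def show_invocations(*, x_shape, block_shape, grid):
  # Model: each invocation (i, j) writes i*10 + j into its block;
  # out-of-bounds elements are discarded, and the last write wins.
  x = np.zeros(x_shape, dtype=np.int32)
  for i in range(grid[0]):
    for j in range(grid[1]):
      slices = tuple(
          slice(idx * b, min((idx + 1) * b, size))
          for idx, b, size in zip((i, j), block_shape, x_shape))
      x[slices] = i * 10 + j
  print(x)

show_invocations(x_shape=(7, 5), block_shape=(2, 3), grid=(4, 2))
```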

When multiple invocations write to the same elements of the output

@ -35,7 +35,6 @@ from jax._src.interpreters import ad
from jax._src.interpreters import batching
from jax._src.interpreters import mlir
from jax._src.interpreters import partial_eval as pe
from jax._src.interpreters import xla
from jax._src.pallas import core as pallas_core
from jax._src.pallas.primitives import uninitialized_value
from jax._src.state import discharge as state_discharge

@ -164,142 +163,140 @@ def _get_next_indices(grid, indices):
    next_indices.append(jnp.where(carry, 0, i))
  return tuple(reversed(next_indices))
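The two context lines above are the tail of the grid-index increment: the interpreter steps through the grid like an odometer, last axis fastest, using `jnp.where` so the update stays traceable inside the `lax.while_loop` below. A sketch of the full helper, reconstructed around these two lines (the surrounding lines are assumptions):

```python
import jax.numpy as jnp

def _get_next_indices(grid, indices):
  # Advance the multi-dimensional grid index lexicographically,
  # carrying into the next-slower axis on wrap-around.
  next_indices = []
  carry = True
  for dim_size, index in reversed(list(zip(grid, indices))):
    i = jnp.where(carry, index + 1, index)
    carry = jnp.logical_and(carry, i == dim_size)
    next_indices.append(jnp.where(carry, 0, i))
  return tuple(reversed(next_indices))
```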

def _pallas_call_impl(*args, jaxpr, name, out_shapes,
                      interpret, debug: bool,
                      in_shapes,
                      input_output_aliases: tuple[tuple[int, int], ...],
                      grid_mapping: GridMapping,
                      compiler_params: Any):
def _pallas_call_impl(*args, **kwargs):
  assert False  # We always jit a pallas call, we only need the lowering rule

def _pallas_call_impl_interpret(
    *args,
    jaxpr: jax_core.Jaxpr,
    name: str, in_shapes, out_shapes,
    debug: bool,
    input_output_aliases: tuple[tuple[int, int], ...],
    grid_mapping: GridMapping,
    compiler_params: Any):
  del compiler_params, name, in_shapes
  # If we're in interpreter mode, we *scan* over the grid and eval the
  # discharged jaxpr.
  dynamic_grid_args, args = split_list(  # type: ignore
      args, [grid_mapping.num_dynamic_grid_bounds]
  )
  if interpret:
    # If we're in interpreter mode, we *scan* over the grid and eval the
    # discharged jaxpr. This should reproduce exactly what compiling to Triton
    # will do.
    dynamic_grid_args_iter = iter(dynamic_grid_args)
    grid = tuple(
        a if a is not pallas_core.dynamic_grid_dim
        else next(dynamic_grid_args_iter)
        for a in grid_mapping.grid
  dynamic_grid_args_iter = iter(dynamic_grid_args)
  grid = tuple(
      a if a is not pallas_core.dynamic_grid_dim
      else next(dynamic_grid_args_iter)
      for a in grid_mapping.grid
  )
    assert next(dynamic_grid_args_iter, None) is None
    with grid_mapping.trace_env():
      discharged_jaxpr, consts = state_discharge.discharge_state(jaxpr, ())
    if debug:
      print(discharged_jaxpr)
    out = _initialize_output_vals(out_shapes, args, input_output_aliases)
    scalars, args = split_list(args, [grid_mapping.num_index_operands])  # type: ignore
    # invars: [*scalar_prefetch, *inputs, *outputs, *scratch]
    num_invars = len(jaxpr.invars)
    num_inputs_outputs = (
        num_invars
        - grid_mapping.num_index_operands
        - grid_mapping.num_scratch_operands
    )
    _, _, scratch_invars = split_list(
        jaxpr.invars, [grid_mapping.num_index_operands, num_inputs_outputs]
    )
    scratch_avals = [v.aval for v in scratch_invars]
    scratch_values = _initialize_scratch_vals(scratch_avals)

    carry = []
    for x, bm in zip(itertools.chain(args, out), grid_mapping.block_mappings):
      if bm is not None and isinstance(bm.indexing_mode, pallas_core.Unblocked):
        padding = bm.indexing_mode.padding
        if padding is not None and any(p != (0, 0) for p in padding):
          if input_output_aliases:
            raise NotImplementedError("Padding with aliasing not supported.")
          x = lax.pad(x, jnp.zeros((), x.dtype), [(*p, 0) for p in padding])
      carry.append(x)

    block_shapes_without_mapped_dims = [
        None if block_mapping is None else block_mapping.block_shape
        for block_mapping in grid_mapping.block_mappings
    ]
    is_indexing_dim = [
        None if bm is None else tuple(b is pallas_core.mapped for b in bm)
        for bm in block_shapes_without_mapped_dims
    ]
    block_shapes = [
        None if (bm is None or iid is None)
        else tuple(1 if i else b for i, b in zip(iid, bm))
        for iid, bm in zip(is_indexing_dim, block_shapes_without_mapped_dims)
    ]

    # Pad values to evenly divide into block dimensions. This matches the
    # behavior of the non-interpret mode. We pad with NaN, to make it easier
    # to catch OOB accesses.
    carry = map(_pad_values_to_block_dimension, carry, block_shapes)
    carry.extend(scratch_values)

    num_inout = len(args) + len(out)
    grid_start_indices = (jnp.int32(0),) * len(grid)
    if grid:
      num_iterations = reduce(jnp.multiply, grid)
    else:
      # Base case is always one iteration when grid is ()
      num_iterations = 1
    def cond(carry):
      i, *_ = carry
      return i < num_iterations
    def body(carry):
      i, loop_idx, *carry = carry
      local_grid_env = tuple(
          pallas_core.GridAxis(idx, b)
          for dim, (idx, b) in enumerate(zip(loop_idx, grid))
          if dim not in grid_mapping.mapped_dims
      )
  assert next(dynamic_grid_args_iter, None) is None
  with grid_mapping.trace_env():
    discharged_jaxpr, consts = state_discharge.discharge_state(jaxpr, ())
  if debug:
    print(discharged_jaxpr)
  out = _initialize_output_vals(out_shapes, args, input_output_aliases)
  scalars, args = split_list(args, [grid_mapping.num_index_operands])  # type: ignore
  # invars: [*scalar_prefetch, *inputs, *outputs, *scratch]
  num_invars = len(jaxpr.invars)
  num_inputs_outputs = (
      num_invars
      - grid_mapping.num_index_operands
      - grid_mapping.num_scratch_operands
  )
  _, _, scratch_invars = split_list(
      jaxpr.invars, [grid_mapping.num_index_operands, num_inputs_outputs]
  )
  scratch_avals = [v.aval for v in scratch_invars]
  scratch_values = _initialize_scratch_vals(scratch_avals)

  carry = []
  for x, bm in zip(itertools.chain(args, out), grid_mapping.block_mappings):
    if bm is not None and isinstance(bm.indexing_mode, pallas_core.Unblocked):
      padding = bm.indexing_mode.padding
      if padding is not None and any(p != (0, 0) for p in padding):
        if input_output_aliases:
          raise NotImplementedError("Padding with aliasing not supported.")
        x = lax.pad(x, jnp.zeros((), x.dtype), [(*p, 0) for p in padding])
    carry.append(x)

  block_shapes_without_mapped_dims = [
      None if block_mapping is None else block_mapping.block_shape
      for block_mapping in grid_mapping.block_mappings
  ]
  is_indexing_dim = [
      None if bm is None else tuple(b is pallas_core.mapped for b in bm)
      for bm in block_shapes_without_mapped_dims
  ]
  block_shapes = [
      None if (bm is None or iid is None)
      else tuple(1 if i else b for i, b in zip(iid, bm))
      for iid, bm in zip(is_indexing_dim, block_shapes_without_mapped_dims)
  ]

  # Pad values to evenly divide into block dimensions.
  # This allows interpret mode to catch errors on OOB memory accesses
  # by poisoning values with NaN. It also fixes an inconsistency with
  # lax.dynamic_slice where if the slice goes out of bounds, it will instead
  # move the start_index backwards so the slice will fit in memory.
  carry = map(_pad_values_to_block_dimension, carry, block_shapes)
  carry.extend(scratch_values)

  num_inout = len(args) + len(out)
  grid_start_indices = (jnp.int32(0),) * len(grid)
  if grid:
    num_iterations = reduce(jnp.multiply, grid)
  else:
    # Base case is always one iteration when grid is ()
    num_iterations = 1
  def cond(carry):
    i, *_ = carry
    return i < num_iterations
  def body(carry):
    i, loop_idx, *carry = carry
    local_grid_env = tuple(
        pallas_core.GridAxis(idx, b)
        for dim, (idx, b) in enumerate(zip(loop_idx, grid))
        if dim not in grid_mapping.mapped_dims
      carry, scratch = split_list(carry, [num_inout])
      with pallas_core.grid_env(local_grid_env):
        start_indices = [
            None if bm is None else bm.compute_start_indices(loop_idx, *scalars)
            for bm in grid_mapping.block_mappings]
      blocks = map(_maybe_dynamic_slice, start_indices, block_shapes, carry,
                   is_indexing_dim)
      with pallas_core.grid_env(local_grid_env):
        assert len(discharged_jaxpr.invars) == len(scalars) + len(blocks) + len(
            scratch_values
        ), (
            len(discharged_jaxpr.invars),
            len(scalars),
            len(blocks),
            len(scratch_values),
        )
    carry, scratch = split_list(carry, [num_inout])
    with pallas_core.grid_env(local_grid_env):
      start_indices = [
          None if bm is None else bm.compute_start_indices(loop_idx, *scalars)
          for bm in grid_mapping.block_mappings]
    blocks = map(_maybe_dynamic_slice, start_indices, block_shapes, carry,
                 is_indexing_dim)
    with pallas_core.grid_env(local_grid_env):
      assert len(discharged_jaxpr.invars) == len(scalars) + len(blocks) + len(
          scratch_values
      ), (
          len(discharged_jaxpr.invars),
          len(scalars),
          len(blocks),
          len(scratch_values),
      )
        blocks = jax.core.eval_jaxpr(discharged_jaxpr, consts, *scalars,
                                     *blocks, *scratch)
        blocks = blocks[grid_mapping.num_index_operands:]
        blocks, out_scratch = split_list(blocks, [num_inout])
      carry = map(_maybe_dynamic_update_slice, start_indices, block_shapes,
                  carry, blocks, is_indexing_dim)
      return (i + 1, _get_next_indices(grid, loop_idx), *carry, *out_scratch)
    (_, _, *carry) = lax.while_loop(
        cond, body, (jnp.int32(0), grid_start_indices, *carry)
    )
    _, out, _ = split_list(carry, [len(args), len(out)])
    assert len(grid_mapping.block_mappings) == len(args) + len(out)
    out_block_mappings = grid_mapping.block_mappings[len(args):]
    out_nopad = []
    for o, bm in zip(out, out_block_mappings):
      if bm is not None and isinstance(bm.indexing_mode, pallas_core.Unblocked):
        padding = bm.indexing_mode.padding
        if padding is not None and any(p != (0, 0) for p in padding):
          if input_output_aliases:
            raise NotImplementedError("Padding with aliasing not supported.")
          pad_low, pad_high = zip(*padding)
          limit_indices = [s - p for s, p in zip(o.shape, pad_high)]
          o = lax.slice(o, pad_low, limit_indices)
      out_nopad.append(o)
    return out_nopad
  return xla.apply_primitive(pallas_call_p, *args, jaxpr=jaxpr, name=name,
                             in_shapes=in_shapes,
                             out_shapes=out_shapes,
                             grid_mapping=grid_mapping, interpret=interpret,
                             debug=debug,
                             input_output_aliases=input_output_aliases,
                             compiler_params=compiler_params)
      blocks = jax.core.eval_jaxpr(discharged_jaxpr, consts, *scalars,
                                   *blocks, *scratch)
      blocks = blocks[grid_mapping.num_index_operands:]
      blocks, out_scratch = split_list(blocks, [num_inout])
    carry = map(_maybe_dynamic_update_slice, start_indices, block_shapes,
                carry, blocks, is_indexing_dim)
    return (i + 1, _get_next_indices(grid, loop_idx), *carry, *out_scratch)
  (_, _, *carry) = lax.while_loop(
      cond, body, (jnp.int32(0), grid_start_indices, *carry)
  )
  _, out, _ = split_list(carry, [len(args), len(out)])
  assert len(grid_mapping.block_mappings) == len(args) + len(out)
  out_block_mappings = grid_mapping.block_mappings[len(args):]
  out_nopad = []
  for o, expected_o_shape, bm in zip(out, out_shapes, out_block_mappings):
    if bm is not None and isinstance(bm.indexing_mode, pallas_core.Unblocked):
      padding = bm.indexing_mode.padding
      if padding is not None and any(p != (0, 0) for p in padding):
        if input_output_aliases:
          raise NotImplementedError("Padding with aliasing not supported.")
        pad_low, pad_high = zip(*padding)
        limit_indices = [s - p for s, p in zip(o.shape, pad_high)]
        o = lax.slice(o, pad_low, limit_indices)
    if o.shape != expected_o_shape.shape:
      o = lax.slice(o, (0,) * o.ndim, expected_o_shape.shape)
    out_nopad.append(o)
  return out_nopad

pallas_call_p.def_impl(_pallas_call_impl)
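With this change the registered impl never runs: `pallas_call` is always jitted, and interpret mode is reached via `_pallas_call_impl_interpret` in the lowering rule below. From the user's side the entry point is unchanged; a minimal usage sketch (kernel and shapes are illustrative, not from this commit):

```python
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl

def add_one_kernel(x_ref, o_ref):
  # Each invocation reads its input block and writes the output block.
  o_ref[...] = x_ref[...] + 1

x = jnp.arange(8, dtype=jnp.int32)
out = pl.pallas_call(
    add_one_kernel,
    out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
    interpret=True,  # exercise the interpreter path above
)(x)
assert (out == x + 1).all()
```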

def _pallas_call_abstract_eval(*avals, out_shapes, **_):

@ -954,7 +951,7 @@ def _pallas_call_lowering(
):
  if interpret:
    # If we are in interpret mode, we don't care what platform we are on.
    impl = partial(_pallas_call_impl, **params, interpret=True)
    impl = partial(_pallas_call_impl_interpret, **params)
    return mlir.lower_fun(impl, multiple_results=True)(ctx, *in_nodes)

  def cpu_lowering(ctx: mlir.LoweringRuleContext,

@ -1080,6 +1077,7 @@ def pallas_call(
                         for a in flat_args)
    flat_out_avals = tuple(jax_core.ShapedArray(v.shape, v.dtype)
                           for v in flat_out_shapes)
    # TODO(necula): check that input_output_aliases is well-formed: shapes match, no duplicates, etc.
    grid_mapping, jaxpr, consts, f_out_tree = _trace_to_jaxpr(
        f, grid_spec, flat_in_avals, flat_out_avals, in_tree, in_paths,
        out_tree, out_paths, interpret=interpret)

@ -1471,23 +1471,31 @@ def parameterized_filterable(*,
    testcase_name: Callable[[dict[str, Any]], str] | None = None,
    one_containing: str | None = None,
):
  """
  Decorator for named parameterized tests, with filtering.
  """Decorator for named parameterized tests, with filtering support.

  Works like parameterized.named_parameters, except that it supports the
  `one_containing` option. This is useful to select only one of the tests,
  and to leave the test name unchanged (helps with specifying the desired test
  when debugging).
  Works like ``parameterized.named_parameters``, except that it sanitizes the test
  names so that we can use ``pytest -k`` and ``python test.py -k`` test filtering.
  This means, e.g., that many special characters are replaced with ``_``.
  It also supports the ``one_containing`` argument to select one of the tests, while
  leaving the name unchanged, which makes it easy for IDEs to pick up
  the enclosing test name.

  Usage:
    @jtu.parameterized_filterable(
      # one_containing="a_4",
      [dict(a=4, b=5),
       dict(a=5, b=4)])
    def test_my_test(self, *, a, b): ...

  Args:
    kwargs: Each entry is a set of kwargs to be passed to the test function.
    testcase_name: Optionally, a function to construct the testcase_name from
      one kwargs dict. If not given then kwarg may contain `testcase_name` and
      if not, the test case name is constructed as `str(kwarg)`.
      one kwargs dict. If not given then ``kwargs`` may contain ``testcase_name`` and
      otherwise the test case name is constructed as ``str(kwarg)``.
      We sanitize the test names to work with -k test filters. See
      `sanitize_test_name`.
    one_containing: If given, then leave the test name unchanged, and use
      only one `kwargs` whose `testcase_name` includes `one_containing`.
      ``sanitize_test_name``.
    one_containing: If given, then leave the test name unchanged, and use
      only one of the ``kwargs`` whose ``testcase_name`` includes ``one_containing``.
  """
  # Ensure that all kwargs contain a testcase_name
  kwargs_with_testcase_name: Sequence[dict[str, Any]]

@ -461,6 +461,29 @@ class PallasCallTest(PallasBaseTest):

    np.testing.assert_allclose(softmax_kernel(x), jax.nn.softmax(x), atol=1e-7)

  @jtu.skip_on_devices("gpu")  # TODO: RET_CHECK failure
  def test_block_spec_with_padding(self):
    def f(*, shape, block_shape):
      def kernel(o1_ref):
        assert o1_ref.shape == block_shape
        o1_ref[...] = jnp.full(o1_ref.shape, pl.program_id(0))

      return self.pallas_call(kernel,
                              jax.ShapeDtypeStruct(shape, dtype=np.int32),
                              grid=((shape[0] + block_shape[0] - 1) // block_shape[0],),
                              out_specs=pl.BlockSpec(block_shape, lambda i: i))()
    # No padding
    pids = f(shape=(8,), block_shape=(2,))
    self.assertAllClose(pids,
                        np.array([0, 0, 1, 1, 2, 2, 3, 3], dtype=np.int32))
    # Pad the last block
    pids = f(shape=(8,), block_shape=(3,))
    self.assertAllClose(pids,
                        np.array([0, 0, 0, 1, 1, 1, 2, 2], dtype=np.int32))
    # Works even if the shape is smaller than 1 block
    pids = f(shape=(3,), block_shape=(8,))
    self.assertAllClose(pids,
                        np.array([0, 0, 0], dtype=np.int32))


class PallasCallInterpreterTest(PallasCallTest):
  INTERPRET = True