[pallas] Add more documentation and tests for BlockSpec.

This PR deals with the default values for the parameters
of the `BlockSpec` constructor, and the mapped block dimensions.

Fix a bug where previously a missing block_shape while the
index_map was present was resulting in a crash.
This commit is contained in:
George Necula 2024-07-09 12:43:37 +03:00
parent 6f79925d61
commit ea548e7c86
5 changed files with 203 additions and 83 deletions

View File

@ -33,7 +33,10 @@ for i in range(n):
This generalizes to any tuple of integers (a length `d` grid will correspond
to `d` nested loops).
The kernel is executed as many times
as `prod(grid)`. Each of these invocations is referred to as a "program".
as `prod(grid)`.
The default grid value `None` stands for `()`, and results in one
kernel invocation.
Each of these invocations is referred to as a "program".
To access which program (i.e. which element of the grid) the kernel is currently
executing, we use {func}`jax.experimental.pallas.program_id`.
For example, for invocation `(1, 2)`, `program_id(axis=0)` returns `1` and
@ -233,3 +236,47 @@ See [the Pallas TPU documentation](https://jax.readthedocs.io/en/latest/pallas/t
[309 309 309 319 319 319]]
```
A `None` value appearing as a dimension value in the `block_shape` behaves
as the value `1`, except that the corresponding
block axis is squeezed. In the example below, observe that the
shape of the `o_ref` is (2,) when the block shape was specified as
`(None, 2)` (the leading dimension was squeezed).
```python
>>> def kernel(o_ref):
... assert o_ref.shape == (2,)
... o_ref[...] = jnp.full((2,), 10 * pl.program_id(1) + pl.program_id(0))
>>> pl.pallas_call(kernel,
... jax.ShapeDtypeStruct((3, 4), dtype=np.int32),
... out_specs=pl.BlockSpec((None, 2), lambda i, j: (i, j)),
... grid=(3, 2), interpret=True)()
Array([[ 0, 0, 10, 10],
[ 1, 1, 11, 11],
[ 2, 2, 12, 12]], dtype=int32)
```
When we construct a `BlockSpec` we can use the value `None` for the
`block_shape` parameter, in which case the shape of the overall array
is used as `block_shape`.
And if we use the value `None` for the `index_map` parameter
then a default index map function that returns a tuple of zeros is
used: `index_map=lambda *invocation_indices: (0,) * len(block_shape)`.
```python
>>> show_invocations(x_shape=(4, 4), block_shape=None, grid=(2, 3),
... out_index_map=None)
[[12 12 12 12]
[12 12 12 12]
[12 12 12 12]
[12 12 12 12]]
>>> show_invocations(x_shape=(4, 4), block_shape=(4, 4), grid=(2, 3),
... out_index_map=None)
[[12 12 12 12]
[12 12 12 12]
[12 12 12 12]
[12 12 12 12]]
```

View File

@ -211,12 +211,15 @@ class BlockSpec:
def compute_index(self, *args):
assert self.index_map is not None
assert self.block_shape is not None
out = self.index_map(*args)
if not isinstance(out, tuple):
out = (out,)
return out
class NoBlockSpec:
pass
no_block_spec = NoBlockSpec()
# A PyTree of BlockSpec | NoBlockSpec.
BlockSpecTree = Any
@ -310,9 +313,11 @@ def _convert_block_spec_to_block_mapping(
return None
if block_spec.index_map is None:
compute_index = lambda *args, **kwargs: (0,) * len(aval.shape)
block_shape = aval.shape
else:
compute_index = block_spec.compute_index
if block_spec.block_shape is None:
block_shape = aval.shape
else:
block_shape = block_spec.block_shape
block_shape = tuple(
mapped if s is None else s for s in block_shape)
@ -338,8 +343,7 @@ def _tile_ref(ref: state.AbstractRef, block_shape: tuple[int, ...] | None
return ref.update(inner_aval=ref.inner_aval.update(shape=shape))
def _get_ref_avals(grid,
in_avals: Sequence[jax_core.ShapedArray],
def _get_ref_avals(in_avals: Sequence[jax_core.ShapedArray],
in_specs: Sequence[BlockSpec],
in_paths: Sequence[tree_util.KeyPath],
out_avals: Sequence[jax_core.ShapedArray],
@ -363,9 +367,9 @@ def _get_ref_avals(grid,
f"Block shape for {what}{tree_util.keystr(path)} (= {block_shape}) "
f"must have the same number of dimensions as the array shape {ref_aval.shape}"
)
trimmed_block_shape = tuple(s for s in block_shape if s is not None)
block_shape_unmapped = tuple(s for s in block_shape if s is not None)
ref_aval = ref_aval.update(
inner_aval=ref_aval.inner_aval.update(shape=trimmed_block_shape))
inner_aval=ref_aval.inner_aval.update(shape=block_shape_unmapped))
if not jax_core.is_constant_shape(ref_aval.shape):
raise ValueError(
@ -382,11 +386,7 @@ def _get_ref_avals(grid,
make_ref_aval(aval, out_spec, out_path, "output")
for aval, out_spec, out_path in zip(out_avals, out_specs, out_paths)
]
return in_specs, in_ref_avals, out_specs, out_ref_avals
class NoBlockSpec:
pass
no_block_spec = NoBlockSpec()
return in_ref_avals, out_ref_avals
@dataclasses.dataclass(init=False, unsafe_hash=True)
@ -453,8 +453,8 @@ class GridSpec:
)
flat_in_specs, flat_out_specs = self._get_in_out_specs(
in_avals, in_tree, out_avals, out_tree)
in_specs, in_ref_avals, out_specs, out_ref_avals = _get_ref_avals(
self.grid, in_avals, flat_in_specs, in_paths,
in_ref_avals, out_ref_avals = _get_ref_avals(
in_avals, flat_in_specs, in_paths,
out_avals, flat_out_specs, out_paths)
grid_avals = [jax_core.ShapedArray((), jnp.dtype("int32"))] * len(self.grid)
# Create args, kwargs pytree def
@ -468,7 +468,7 @@ class GridSpec:
mapped_dims=(),
what="input",
),
in_specs,
flat_in_specs,
in_paths,
in_ref_avals,
)
@ -481,7 +481,7 @@ class GridSpec:
mapped_dims=(),
what="output",
),
out_specs,
flat_out_specs,
out_paths,
out_ref_avals,
)

View File

@ -187,9 +187,9 @@ class PrefetchScalarGridSpec(pallas_core.GridSpec):
in_avals, in_avals_tree = tree_util.tree_flatten(tuple(unflat_in_avals))
flat_in_specs, flat_out_specs = self._get_in_out_specs(
in_avals, in_avals_tree, out_avals, out_tree)
in_specs, in_ref_avals, out_specs, out_ref_avals = (
in_ref_avals, out_ref_avals = (
pallas_core._get_ref_avals(
self.grid, in_avals, flat_in_specs, in_paths[num_flat_scalar_prefetch:],
in_avals, flat_in_specs, in_paths[num_flat_scalar_prefetch:],
out_avals, flat_out_specs, out_paths))
scalar_ref_avals = [
AbstractMemoryRef(jax_core.ShapedArray(aval.shape, aval.dtype),
@ -209,7 +209,7 @@ class PrefetchScalarGridSpec(pallas_core.GridSpec):
mapped_dims=(),
what="input",
),
in_specs,
flat_in_specs,
in_paths[num_flat_scalar_prefetch:],
in_ref_avals,
)
@ -222,7 +222,7 @@ class PrefetchScalarGridSpec(pallas_core.GridSpec):
mapped_dims=(),
what="output",
),
out_specs,
flat_out_specs,
out_paths,
out_ref_avals,
)

View File

@ -1034,12 +1034,14 @@ def pallas_call(
See details at :ref:`pallas_grid`.
in_specs: a PyTree of :class:`jax.experimental.pallas.BlockSpec` with
a structure matching that of the positional arguments.
The default value for ``in_specs`` specifies the whole array for all
inputs, e.g., as ``pl.BlockSpec(x.shape, lambda *indices: (0,) * x.ndim)``.
See details at :ref:`pallas_blockspec`.
out_specs: a PyTree of :class:`jax.experimental.pallas.BlockSpec` with
a structure matching that of the outputs.
The default value for ``out_specs`` specifies the whole array,
e.g., as ``pl.BlockSpec(x.shape, lambda *indices: (0,) * x.ndim)``.
See details at :ref:`pallas_blockspec`.
The default value for `out_specs` specifies the whole array,
e.g., as ``pl.BlockSpec(x.shape, lambda *indices: indices)``.
input_output_aliases: a dictionary mapping the index of some inputs to
the index of the output that aliases them. These indices are in the
flattened inputs and outputs.
@ -1063,6 +1065,9 @@ def pallas_call(
if grid_spec is None:
grid_spec = GridSpec(grid, in_specs, out_specs)
grid_spec, dynamic_grid_bounds = grid_spec.unzip_dynamic_grid_bounds()
# TODO(necula): this canonicalization may be convenient for some usage
# but it is lossy, because it prevents expressing functions that return
# lists.
if isinstance(out_shape, list):
out_shape = tuple(out_shape)
flat_out_shapes_with_paths, out_tree = tree_util.tree_flatten_with_path(out_shape)

View File

@ -253,6 +253,114 @@ class PallasCallTest(PallasBaseTest):
# TODO(necula): we normalize out_shape to a tuple, we shouldn't.
self.assertIsInstance(res, tuple)
@jtu.skip_on_devices("gpu") # TODO: RET_CHECK failure
def test_block_spec_with_padding(self):
def f(*, shape, block_shape):
def kernel(o1_ref):
assert o1_ref.shape == block_shape
o1_ref[...] = jnp.full(o1_ref.shape, pl.program_id(0))
return self.pallas_call(kernel,
jax.ShapeDtypeStruct(shape, dtype=np.int32),
grid=((shape[0] + block_shape[0] - 1) // block_shape[0],),
out_specs=pl.BlockSpec(block_shape, lambda i: i))()
# No padding
pids = f(shape=(8,), block_shape=(2,))
self.assertAllClose(pids,
np.array([0, 0, 1, 1, 2, 2, 3, 3], dtype=np.int32))
# Pad the last block
pids = f(shape=(8,), block_shape=(3,))
self.assertAllClose(pids,
np.array([0, 0, 0, 1, 1, 1, 2, 2], dtype=np.int32))
# Works even if the shape is smaller than 1 block
pids = f(shape=(3,), block_shape=(8,))
self.assertAllClose(pids,
np.array([0, 0, 0], dtype=np.int32))
def test_block_spec_mapped_dimension(self):
@functools.partial(
self.pallas_call,
out_shape=jax.ShapeDtypeStruct((4,), jnp.float32),
in_specs=[
pl.BlockSpec((None, 4), lambda _: (0, 0)),
pl.BlockSpec((None, 4), lambda _: (1, 0)),
],
grid=1,
)
def add_vectors(x_ref, y_ref, o_ref):
o_ref[:] = x_ref[:] + y_ref[:]
xy = jnp.arange(8., dtype=np.float32).reshape((2, 4))
out = add_vectors(xy, xy)
out_ref = xy[0] + xy[1]
np.testing.assert_allclose(out, out_ref)
def test_pallas_call_no_grid(self):
o_ref_shape = None
def kernel(o_ref):
nonlocal o_ref_shape
o_ref_shape = o_ref.shape
o_ref[...] = jnp.full(o_ref.shape, 42)
pids = self.pallas_call(kernel,
jax.ShapeDtypeStruct((8,), dtype=np.int32))()
self.assertAllClose(pids, np.array([42] * 8, dtype=np.int32))
self.assertEqual(o_ref_shape, (8,))
def test_pallas_call_no_block_spec(self):
o_ref_shape = None
def kernel(o_ref):
nonlocal o_ref_shape
o_ref_shape = o_ref.shape
o_ref[...] = jnp.full(o_ref.shape, pl.program_id(0))
pids = self.pallas_call(kernel,
jax.ShapeDtypeStruct((8,), dtype=np.int32),
grid=(1,))()
self.assertEqual(o_ref_shape, (8,))
self.assertAllClose(pids, np.array([0] * 8, dtype=np.int32))
def test_block_spec_no_block_shape_and_no_index_map(self):
o_ref_shape = None
def kernel(o_ref):
nonlocal o_ref_shape
o_ref_shape = o_ref.shape
o_ref[...] = jnp.full(o_ref.shape, pl.program_id(0))
pids = self.pallas_call(kernel,
jax.ShapeDtypeStruct((8,), dtype=np.int32),
out_specs=pl.BlockSpec(),
grid=(1,))()
self.assertEqual(o_ref_shape, (8,))
self.assertAllClose(pids, np.array([0] * 8, dtype=np.int32))
def test_block_spec_no_block_shape(self):
o_ref_shape = None
def kernel(o_ref):
nonlocal o_ref_shape
o_ref_shape = o_ref.shape
o_ref[...] = jnp.full(o_ref.shape, pl.program_id(0))
pids = self.pallas_call(kernel,
jax.ShapeDtypeStruct((8,), dtype=np.int32),
out_specs=pl.BlockSpec(None, lambda i: i),
grid=(1,))()
self.assertEqual(o_ref_shape, (8,))
self.assertAllClose(pids, np.array([0] * 8, dtype=np.int32))
def test_block_spec_no_index_map(self):
o_ref_shape = None
def kernel(o_ref):
nonlocal o_ref_shape
o_ref_shape = o_ref.shape
o_ref[...] = jnp.full(o_ref.shape, pl.program_id(0))
pids = self.pallas_call(kernel,
jax.ShapeDtypeStruct((8,), dtype=np.int32),
out_specs=pl.BlockSpec((4,)),
grid=(1,))()
self.assertEqual(o_ref_shape, (4,))
self.assertAllClose(pids[0:4], np.array([0] * 4, dtype=np.int32))
def test_hoisted_consts(self):
# See https://github.com/google/jax/issues/21557.
x = jnp.zeros(32)
@ -440,50 +548,6 @@ class PallasCallTest(PallasBaseTest):
self.assertEqual(f(x), 2.)
self.assertEqual(trace_count, 1)
def test_custom_jvp_call(self):
@functools.partial(jax.custom_jvp, nondiff_argnums=(1,))
def softmax(x, axis=-1):
unnormalized = jnp.exp(x - jnp.max(x, axis, keepdims=True))
return unnormalized / jnp.sum(unnormalized, axis, keepdims=True)
@softmax.defjvp
def softmax_jvp(axis, primals, tangents):
(x,), (x_dot,) = primals, tangents
y = softmax(x, axis)
return y, y * (x_dot - (y * x_dot).sum(axis, keepdims=True))
m, n = 16, 32
x = random.normal(random.key(0), (m, n))
@functools.partial(self.pallas_call, out_shape=x, grid=1)
def softmax_kernel(x_ref, y_ref):
y_ref[:] = softmax(x_ref[:])
np.testing.assert_allclose(softmax_kernel(x), jax.nn.softmax(x), atol=1e-7)
@jtu.skip_on_devices("gpu") # TODO: RET_CHECK failure
def test_block_spec_with_padding(self):
def f(*, shape, block_shape):
def kernel(o1_ref):
assert o1_ref.shape == block_shape
o1_ref[...] = jnp.full(o1_ref.shape, pl.program_id(0))
return self.pallas_call(kernel,
jax.ShapeDtypeStruct(shape, dtype=np.int32),
grid=((shape[0] + block_shape[0] - 1) // block_shape[0],),
out_specs=pl.BlockSpec(block_shape, lambda i: i))()
# No padding
pids = f(shape=(8,), block_shape=(2,))
self.assertAllClose(pids,
np.array([0, 0, 1, 1, 2, 2, 3, 3], dtype=np.int32))
# Pad the last block
pids = f(shape=(8,), block_shape=(3,))
self.assertAllClose(pids,
np.array([0, 0, 0, 1, 1, 1, 2, 2], dtype=np.int32))
# Works even if the shape is smaller than 1 block
pids = f(shape=(3,), block_shape=(8,))
self.assertAllClose(pids,
np.array([0, 0, 0], dtype=np.int32))
class PallasCallInterpreterTest(PallasCallTest):
INTERPRET = True
@ -1495,6 +1559,27 @@ class PallasCallAutodifferentiationTest(PallasBaseTest):
jtu.check_grads(pallas_impl, (x,), modes=["fwd"], order=2,
atol=grad_tol, rtol=grad_tol)
def test_custom_jvp_call(self):
@functools.partial(jax.custom_jvp, nondiff_argnums=(1,))
def softmax(x, axis=-1):
unnormalized = jnp.exp(x - jnp.max(x, axis, keepdims=True))
return unnormalized / jnp.sum(unnormalized, axis, keepdims=True)
@softmax.defjvp
def softmax_jvp(axis, primals, tangents):
(x,), (x_dot,) = primals, tangents
y = softmax(x, axis)
return y, y * (x_dot - (y * x_dot).sum(axis, keepdims=True))
m, n = 16, 32
x = random.normal(random.key(0), (m, n))
@functools.partial(self.pallas_call, out_shape=x, grid=1)
def softmax_kernel(x_ref, y_ref):
y_ref[:] = softmax(x_ref[:])
np.testing.assert_allclose(softmax_kernel(x), jax.nn.softmax(x), atol=1e-7)
# TODO(sharadmv): enable this when we update Triton
# def test_jvp_matmul(self):
# k1, k2 = random.split(random.key(0))
@ -1505,23 +1590,6 @@ class PallasCallAutodifferentiationTest(PallasBaseTest):
# interpret=self.INTERPRET)
# jtu.check_grads(mm, (x, y), modes=["fwd"], order=1)
def test_slicing_block_spec(self):
@functools.partial(
self.pallas_call,
out_shape=jax.ShapeDtypeStruct((4,), jnp.float32),
in_specs=[
pl.BlockSpec((None, 4), lambda _: (0, 0)),
pl.BlockSpec((None, 4), lambda _: (1, 0)),
],
grid=1,
)
def add_vectors(x_ref, y_ref, o_ref):
o_ref[:] = x_ref[:] + y_ref[:]
xy = jnp.arange(8.).reshape((2, 4))
out = add_vectors(xy, xy)
out_ref = xy[0] + xy[1]
np.testing.assert_allclose(out, out_ref)
class PallasCallAutodifferentiationInterpreterTest(PallasCallAutodifferentiationTest):
INTERPRET = True