[pallas] Add more documentation and tests for BlockSpec.

This PR deals with the default values for the parameters of the
`BlockSpec` constructor, and with mapped block dimensions.

It also fixes a bug where constructing a `BlockSpec` with an `index_map`
but no `block_shape` resulted in a crash. A minimal sketch of that
previously crashing case is shown below, after the commit metadata.
George Necula 2024-07-09 12:43:37 +03:00
parent 6f79925d61
commit ea548e7c86
5 changed files with 203 additions and 83 deletions
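To make the fix concrete, here is a minimal sketch (not part of the diff) of the previously crashing case: a `BlockSpec` with an `index_map` but no `block_shape`. The kernel and variable names are illustrative; it mirrors the new `test_block_spec_no_block_shape` test added below and assumes `interpret=True` so it runs without an accelerator.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.experimental import pallas as pl

def kernel(o_ref):
  # With block_shape defaulted, the block is the whole (8,) output array.
  o_ref[...] = jnp.full(o_ref.shape, pl.program_id(0))

out = pl.pallas_call(
    kernel,
    jax.ShapeDtypeStruct((8,), np.int32),
    # index_map given, block_shape omitted: this used to crash; with the fix
    # the block shape defaults to the full array shape.
    out_specs=pl.BlockSpec(None, lambda i: i),
    grid=(1,),
    interpret=True,
)()
print(out)  # [0 0 0 0 0 0 0 0]
```

With the fix, the omitted `block_shape` defaults to the full array shape, so the kernel sees the whole `(8,)` output.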


@@ -33,7 +33,10 @@ for i in range(n):
 This generalizes to any tuple of integers (a length `d` grid will correspond
 to `d` nested loops).
 The kernel is executed as many times
-as `prod(grid)`. Each of these invocations is referred to as a "program".
+as `prod(grid)`.
+The default grid value `None` stands for `()`, and results in one
+kernel invocation.
+Each of these invocations is referred to as a "program".
 To access which program (i.e. which element of the grid) the kernel is currently
 executing, we use {func}`jax.experimental.pallas.program_id`.
 For example, for invocation `(1, 2)`, `program_id(axis=0)` returns `1` and
@@ -233,3 +236,47 @@ See [the Pallas TPU documentation](https://jax.readthedocs.io/en/latest/pallas/t
  [309 309 309 319 319 319]]
 ```
+
+A `None` value appearing as a dimension value in the `block_shape` behaves
+as the value `1`, except that the corresponding
+block axis is squeezed. In the example below, observe that the
+shape of the `o_ref` is (2,) when the block shape was specified as
+`(None, 2)` (the leading dimension was squeezed).
+
+```python
+>>> def kernel(o_ref):
+...   assert o_ref.shape == (2,)
+...   o_ref[...] = jnp.full((2,), 10 * pl.program_id(1) + pl.program_id(0))
+
+>>> pl.pallas_call(kernel,
+...                jax.ShapeDtypeStruct((3, 4), dtype=np.int32),
+...                out_specs=pl.BlockSpec((None, 2), lambda i, j: (i, j)),
+...                grid=(3, 2), interpret=True)()
+Array([[ 0,  0, 10, 10],
+       [ 1,  1, 11, 11],
+       [ 2,  2, 12, 12]], dtype=int32)
+
+```
+
+When we construct a `BlockSpec` we can use the value `None` for the
+`block_shape` parameter, in which case the shape of the overall array
+is used as `block_shape`.
+And if we use the value `None` for the `index_map` parameter
+then a default index map function that returns a tuple of zeros is
+used: `index_map=lambda *invocation_indices: (0,) * len(block_shape)`.
+
+```python
+>>> show_invocations(x_shape=(4, 4), block_shape=None, grid=(2, 3),
+...                  out_index_map=None)
+[[12 12 12 12]
+ [12 12 12 12]
+ [12 12 12 12]
+ [12 12 12 12]]
+
+>>> show_invocations(x_shape=(4, 4), block_shape=(4, 4), grid=(2, 3),
+...                  out_index_map=None)
+[[12 12 12 12]
+ [12 12 12 12]
+ [12 12 12 12]
+ [12 12 12 12]]
+
+```
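As an aside (a sketch, not text from the diff): with the defaults described above, a fully default `pl.BlockSpec()` selects the whole array with a zero index map, i.e. it behaves like the explicitly spelled-out spec:

```python
import jax.numpy as jnp
from jax.experimental import pallas as pl

x = jnp.zeros((4, 4), jnp.float32)  # any array passed to pallas_call

# Both specs select the whole array on every kernel invocation.
spec_default = pl.BlockSpec()
spec_explicit = pl.BlockSpec(x.shape, lambda *invocation_indices: (0,) * x.ndim)
```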


@@ -211,12 +211,15 @@ class BlockSpec:
   def compute_index(self, *args):
     assert self.index_map is not None
+    assert self.block_shape is not None
     out = self.index_map(*args)
     if not isinstance(out, tuple):
       out = (out,)
     return out

+class NoBlockSpec:
+  pass
+
+no_block_spec = NoBlockSpec()

 # A PyTree of BlockSpec | NoBlockSpec.
 BlockSpecTree = Any
@@ -310,9 +313,11 @@ def _convert_block_spec_to_block_mapping(
     return None
   if block_spec.index_map is None:
     compute_index = lambda *args, **kwargs: (0,) * len(aval.shape)
-    block_shape = aval.shape
   else:
     compute_index = block_spec.compute_index
+  if block_spec.block_shape is None:
+    block_shape = aval.shape
+  else:
     block_shape = block_spec.block_shape
   block_shape = tuple(
       mapped if s is None else s for s in block_shape)
@@ -338,8 +343,7 @@ def _tile_ref(ref: state.AbstractRef, block_shape: tuple[int, ...] | None
   return ref.update(inner_aval=ref.inner_aval.update(shape=shape))

-def _get_ref_avals(grid,
-                   in_avals: Sequence[jax_core.ShapedArray],
+def _get_ref_avals(in_avals: Sequence[jax_core.ShapedArray],
                    in_specs: Sequence[BlockSpec],
                    in_paths: Sequence[tree_util.KeyPath],
                    out_avals: Sequence[jax_core.ShapedArray],
@@ -363,9 +367,9 @@ def _get_ref_avals(grid,
            f"Block shape for {what}{tree_util.keystr(path)} (= {block_shape}) "
            f"must have the same number of dimensions as the array shape {ref_aval.shape}"
        )
-      trimmed_block_shape = tuple(s for s in block_shape if s is not None)
+      block_shape_unmapped = tuple(s for s in block_shape if s is not None)
       ref_aval = ref_aval.update(
-          inner_aval=ref_aval.inner_aval.update(shape=trimmed_block_shape))
+          inner_aval=ref_aval.inner_aval.update(shape=block_shape_unmapped))

     if not jax_core.is_constant_shape(ref_aval.shape):
       raise ValueError(
@@ -382,11 +386,7 @@ def _get_ref_avals(grid,
       make_ref_aval(aval, out_spec, out_path, "output")
       for aval, out_spec, out_path in zip(out_avals, out_specs, out_paths)
   ]
-  return in_specs, in_ref_avals, out_specs, out_ref_avals
-
-class NoBlockSpec:
-  pass
-
-no_block_spec = NoBlockSpec()
+  return in_ref_avals, out_ref_avals

 @dataclasses.dataclass(init=False, unsafe_hash=True)
@@ -453,8 +453,8 @@ class GridSpec:
     )
     flat_in_specs, flat_out_specs = self._get_in_out_specs(
         in_avals, in_tree, out_avals, out_tree)
-    in_specs, in_ref_avals, out_specs, out_ref_avals = _get_ref_avals(
-        self.grid, in_avals, flat_in_specs, in_paths,
+    in_ref_avals, out_ref_avals = _get_ref_avals(
+        in_avals, flat_in_specs, in_paths,
         out_avals, flat_out_specs, out_paths)
     grid_avals = [jax_core.ShapedArray((), jnp.dtype("int32"))] * len(self.grid)
     # Create args, kwargs pytree def
@@ -468,7 +468,7 @@
             mapped_dims=(),
             what="input",
         ),
-        in_specs,
+        flat_in_specs,
         in_paths,
         in_ref_avals,
     )
@@ -481,7 +481,7 @@
             mapped_dims=(),
             what="output",
         ),
-        out_specs,
+        flat_out_specs,
         out_paths,
         out_ref_avals,
     )


@@ -187,9 +187,9 @@ class PrefetchScalarGridSpec(pallas_core.GridSpec):
     in_avals, in_avals_tree = tree_util.tree_flatten(tuple(unflat_in_avals))
     flat_in_specs, flat_out_specs = self._get_in_out_specs(
         in_avals, in_avals_tree, out_avals, out_tree)
-    in_specs, in_ref_avals, out_specs, out_ref_avals = (
+    in_ref_avals, out_ref_avals = (
         pallas_core._get_ref_avals(
-            self.grid, in_avals, flat_in_specs, in_paths[num_flat_scalar_prefetch:],
+            in_avals, flat_in_specs, in_paths[num_flat_scalar_prefetch:],
             out_avals, flat_out_specs, out_paths))
     scalar_ref_avals = [
         AbstractMemoryRef(jax_core.ShapedArray(aval.shape, aval.dtype),
@@ -209,7 +209,7 @@ class PrefetchScalarGridSpec(pallas_core.GridSpec):
             mapped_dims=(),
             what="input",
         ),
-        in_specs,
+        flat_in_specs,
         in_paths[num_flat_scalar_prefetch:],
         in_ref_avals,
     )
@@ -222,7 +222,7 @@ class PrefetchScalarGridSpec(pallas_core.GridSpec):
             mapped_dims=(),
             what="output",
         ),
-        out_specs,
+        flat_out_specs,
         out_paths,
         out_ref_avals,
     )


@@ -1034,12 +1034,14 @@ def pallas_call(
       See details at :ref:`pallas_grid`.
     in_specs: a PyTree of :class:`jax.experimental.pallas.BlockSpec` with
       a structure matching that of the positional arguments.
+      The default value for ``in_specs`` specifies the whole array for all
+      inputs, e.g., as ``pl.BlockSpec(x.shape, lambda *indices: (0,) * x.ndim)``.
       See details at :ref:`pallas_blockspec`.
     out_specs: a PyTree of :class:`jax.experimental.pallas.BlockSpec` with
       a structure matching that of the outputs.
+      The default value for ``out_specs`` specifies the whole array,
+      e.g., as ``pl.BlockSpec(x.shape, lambda *indices: (0,) * x.ndim)``.
       See details at :ref:`pallas_blockspec`.
-      The default value for `out_specs` specifies the whole array,
-      e.g., as ``pl.BlockSpec(x.shape, lambda *indices: indices)``.
     input_output_aliases: a dictionary mapping the index of some inputs to
       the index of the output that aliases them. These indices are in the
       flattened inputs and outputs.
@@ -1063,6 +1065,9 @@
   if grid_spec is None:
     grid_spec = GridSpec(grid, in_specs, out_specs)
   grid_spec, dynamic_grid_bounds = grid_spec.unzip_dynamic_grid_bounds()
+  # TODO(necula): this canonicalization may be convenient for some usage
+  # but it is lossy, because it prevents expressing functions that return
+  # lists.
   if isinstance(out_shape, list):
     out_shape = tuple(out_shape)
   flat_out_shapes_with_paths, out_tree = tree_util.tree_flatten_with_path(out_shape)
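To illustrate the documented defaults (a sketch, not part of the diff; the kernel name is illustrative and `interpret=True` is assumed so it runs anywhere): calling `pallas_call` with no `grid`, `in_specs`, or `out_specs` runs the kernel once over whole-array references.

```python
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl

def add_one_kernel(x_ref, o_ref):
  # Default specs: both refs cover the whole (8,) array; the default grid
  # (None, i.e. ()) runs the kernel exactly once.
  o_ref[...] = x_ref[...] + 1

x = jnp.arange(8, dtype=jnp.int32)
out = pl.pallas_call(
    add_one_kernel,
    out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
    interpret=True,
)(x)
print(out)  # [1 2 3 4 5 6 7 8]
```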


@@ -253,6 +253,114 @@ class PallasCallTest(PallasBaseTest):
     # TODO(necula): we normalize out_shape to a tuple, we shouldn't.
     self.assertIsInstance(res, tuple)

+  @jtu.skip_on_devices("gpu")  # TODO: RET_CHECK failure
+  def test_block_spec_with_padding(self):
+    def f(*, shape, block_shape):
+      def kernel(o1_ref):
+        assert o1_ref.shape == block_shape
+        o1_ref[...] = jnp.full(o1_ref.shape, pl.program_id(0))
+      return self.pallas_call(kernel,
+                              jax.ShapeDtypeStruct(shape, dtype=np.int32),
+                              grid=((shape[0] + block_shape[0] - 1) // block_shape[0],),
+                              out_specs=pl.BlockSpec(block_shape, lambda i: i))()
+    # No padding
+    pids = f(shape=(8,), block_shape=(2,))
+    self.assertAllClose(pids,
+                        np.array([0, 0, 1, 1, 2, 2, 3, 3], dtype=np.int32))
+    # Pad the last block
+    pids = f(shape=(8,), block_shape=(3,))
+    self.assertAllClose(pids,
+                        np.array([0, 0, 0, 1, 1, 1, 2, 2], dtype=np.int32))
+    # Works even if the shape is smaller than 1 block
+    pids = f(shape=(3,), block_shape=(8,))
+    self.assertAllClose(pids,
+                        np.array([0, 0, 0], dtype=np.int32))
+
+  def test_block_spec_mapped_dimension(self):
+    @functools.partial(
+        self.pallas_call,
+        out_shape=jax.ShapeDtypeStruct((4,), jnp.float32),
+        in_specs=[
+            pl.BlockSpec((None, 4), lambda _: (0, 0)),
+            pl.BlockSpec((None, 4), lambda _: (1, 0)),
+        ],
+        grid=1,
+    )
+    def add_vectors(x_ref, y_ref, o_ref):
+      o_ref[:] = x_ref[:] + y_ref[:]
+    xy = jnp.arange(8., dtype=np.float32).reshape((2, 4))
+    out = add_vectors(xy, xy)
+    out_ref = xy[0] + xy[1]
+    np.testing.assert_allclose(out, out_ref)
+
+  def test_pallas_call_no_grid(self):
+    o_ref_shape = None
+    def kernel(o_ref):
+      nonlocal o_ref_shape
+      o_ref_shape = o_ref.shape
+      o_ref[...] = jnp.full(o_ref.shape, 42)
+    pids = self.pallas_call(kernel,
+                            jax.ShapeDtypeStruct((8,), dtype=np.int32))()
+    self.assertAllClose(pids, np.array([42] * 8, dtype=np.int32))
+    self.assertEqual(o_ref_shape, (8,))
+
+  def test_pallas_call_no_block_spec(self):
+    o_ref_shape = None
+    def kernel(o_ref):
+      nonlocal o_ref_shape
+      o_ref_shape = o_ref.shape
+      o_ref[...] = jnp.full(o_ref.shape, pl.program_id(0))
+    pids = self.pallas_call(kernel,
+                            jax.ShapeDtypeStruct((8,), dtype=np.int32),
+                            grid=(1,))()
+    self.assertEqual(o_ref_shape, (8,))
+    self.assertAllClose(pids, np.array([0] * 8, dtype=np.int32))
+
+  def test_block_spec_no_block_shape_and_no_index_map(self):
+    o_ref_shape = None
+    def kernel(o_ref):
+      nonlocal o_ref_shape
+      o_ref_shape = o_ref.shape
+      o_ref[...] = jnp.full(o_ref.shape, pl.program_id(0))
+    pids = self.pallas_call(kernel,
+                            jax.ShapeDtypeStruct((8,), dtype=np.int32),
+                            out_specs=pl.BlockSpec(),
+                            grid=(1,))()
+    self.assertEqual(o_ref_shape, (8,))
+    self.assertAllClose(pids, np.array([0] * 8, dtype=np.int32))
+
+  def test_block_spec_no_block_shape(self):
+    o_ref_shape = None
+    def kernel(o_ref):
+      nonlocal o_ref_shape
+      o_ref_shape = o_ref.shape
+      o_ref[...] = jnp.full(o_ref.shape, pl.program_id(0))
+    pids = self.pallas_call(kernel,
+                            jax.ShapeDtypeStruct((8,), dtype=np.int32),
+                            out_specs=pl.BlockSpec(None, lambda i: i),
+                            grid=(1,))()
+    self.assertEqual(o_ref_shape, (8,))
+    self.assertAllClose(pids, np.array([0] * 8, dtype=np.int32))
+
+  def test_block_spec_no_index_map(self):
+    o_ref_shape = None
+    def kernel(o_ref):
+      nonlocal o_ref_shape
+      o_ref_shape = o_ref.shape
+      o_ref[...] = jnp.full(o_ref.shape, pl.program_id(0))
+    pids = self.pallas_call(kernel,
+                            jax.ShapeDtypeStruct((8,), dtype=np.int32),
+                            out_specs=pl.BlockSpec((4,)),
+                            grid=(1,))()
+    self.assertEqual(o_ref_shape, (4,))
+    self.assertAllClose(pids[0:4], np.array([0] * 4, dtype=np.int32))
+
   def test_hoisted_consts(self):
     # See https://github.com/google/jax/issues/21557.
     x = jnp.zeros(32)
@@ -440,50 +548,6 @@
     self.assertEqual(f(x), 2.)
     self.assertEqual(trace_count, 1)

-  def test_custom_jvp_call(self):
-    @functools.partial(jax.custom_jvp, nondiff_argnums=(1,))
-    def softmax(x, axis=-1):
-      unnormalized = jnp.exp(x - jnp.max(x, axis, keepdims=True))
-      return unnormalized / jnp.sum(unnormalized, axis, keepdims=True)
-
-    @softmax.defjvp
-    def softmax_jvp(axis, primals, tangents):
-      (x,), (x_dot,) = primals, tangents
-      y = softmax(x, axis)
-      return y, y * (x_dot - (y * x_dot).sum(axis, keepdims=True))
-
-    m, n = 16, 32
-    x = random.normal(random.key(0), (m, n))
-
-    @functools.partial(self.pallas_call, out_shape=x, grid=1)
-    def softmax_kernel(x_ref, y_ref):
-      y_ref[:] = softmax(x_ref[:])
-
-    np.testing.assert_allclose(softmax_kernel(x), jax.nn.softmax(x), atol=1e-7)
-
-  @jtu.skip_on_devices("gpu")  # TODO: RET_CHECK failure
-  def test_block_spec_with_padding(self):
-    def f(*, shape, block_shape):
-      def kernel(o1_ref):
-        assert o1_ref.shape == block_shape
-        o1_ref[...] = jnp.full(o1_ref.shape, pl.program_id(0))
-      return self.pallas_call(kernel,
-                              jax.ShapeDtypeStruct(shape, dtype=np.int32),
-                              grid=((shape[0] + block_shape[0] - 1) // block_shape[0],),
-                              out_specs=pl.BlockSpec(block_shape, lambda i: i))()
-    # No padding
-    pids = f(shape=(8,), block_shape=(2,))
-    self.assertAllClose(pids,
-                        np.array([0, 0, 1, 1, 2, 2, 3, 3], dtype=np.int32))
-    # Pad the last block
-    pids = f(shape=(8,), block_shape=(3,))
-    self.assertAllClose(pids,
-                        np.array([0, 0, 0, 1, 1, 1, 2, 2], dtype=np.int32))
-    # Works even if the shape is smaller than 1 block
-    pids = f(shape=(3,), block_shape=(8,))
-    self.assertAllClose(pids,
-                        np.array([0, 0, 0], dtype=np.int32))

 class PallasCallInterpreterTest(PallasCallTest):
   INTERPRET = True
@@ -1495,6 +1559,27 @@
     jtu.check_grads(pallas_impl, (x,), modes=["fwd"], order=2,
                     atol=grad_tol, rtol=grad_tol)

+  def test_custom_jvp_call(self):
+    @functools.partial(jax.custom_jvp, nondiff_argnums=(1,))
+    def softmax(x, axis=-1):
+      unnormalized = jnp.exp(x - jnp.max(x, axis, keepdims=True))
+      return unnormalized / jnp.sum(unnormalized, axis, keepdims=True)
+
+    @softmax.defjvp
+    def softmax_jvp(axis, primals, tangents):
+      (x,), (x_dot,) = primals, tangents
+      y = softmax(x, axis)
+      return y, y * (x_dot - (y * x_dot).sum(axis, keepdims=True))
+
+    m, n = 16, 32
+    x = random.normal(random.key(0), (m, n))
+
+    @functools.partial(self.pallas_call, out_shape=x, grid=1)
+    def softmax_kernel(x_ref, y_ref):
+      y_ref[:] = softmax(x_ref[:])
+
+    np.testing.assert_allclose(softmax_kernel(x), jax.nn.softmax(x), atol=1e-7)
+
   # TODO(sharadmv): enable this when we update Triton
   # def test_jvp_matmul(self):
   #   k1, k2 = random.split(random.key(0))
@@ -1505,23 +1590,6 @@
   #                    interpret=self.INTERPRET)
   #   jtu.check_grads(mm, (x, y), modes=["fwd"], order=1)

-  def test_slicing_block_spec(self):
-    @functools.partial(
-        self.pallas_call,
-        out_shape=jax.ShapeDtypeStruct((4,), jnp.float32),
-        in_specs=[
-            pl.BlockSpec((None, 4), lambda _: (0, 0)),
-            pl.BlockSpec((None, 4), lambda _: (1, 0)),
-        ],
-        grid=1,
-    )
-    def add_vectors(x_ref, y_ref, o_ref):
-      o_ref[:] = x_ref[:] + y_ref[:]
-    xy = jnp.arange(8.).reshape((2, 4))
-    out = add_vectors(xy, xy)
-    out_ref = xy[0] + xy[1]
-    np.testing.assert_allclose(out, out_ref)

 class PallasCallAutodifferentiationInterpreterTest(PallasCallAutodifferentiationTest):
   INTERPRET = True