mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
[pallas] Add more documentation and tests for BlockSpec.
This PR deals with the default values for the parameters of the `BlockSpec` constructor, and the mapped block dimensions. Fix a bug where previously a missing block_shape while the index_map was present was resulting in a crash.
This commit is contained in:
parent
6f79925d61
commit
ea548e7c86
@ -33,7 +33,10 @@ for i in range(n):
|
||||
This generalizes to any tuple of integers (a length `d` grid will correspond
|
||||
to `d` nested loops).
|
||||
The kernel is executed as many times
|
||||
as `prod(grid)`. Each of these invocations is referred to as a "program".
|
||||
as `prod(grid)`.
|
||||
The default grid value `None` stands for `()`, and results in one
|
||||
kernel invocation.
|
||||
Each of these invocations is referred to as a "program".
|
||||
To access which program (i.e. which element of the grid) the kernel is currently
|
||||
executing, we use {func}`jax.experimental.pallas.program_id`.
|
||||
For example, for invocation `(1, 2)`, `program_id(axis=0)` returns `1` and
|
||||
@ -233,3 +236,47 @@ See [the Pallas TPU documentation](https://jax.readthedocs.io/en/latest/pallas/t
|
||||
[309 309 309 319 319 319]]
|
||||
|
||||
```
|
||||
|
||||
A `None` value appearing as a dimension value in the `block_shape` behaves
|
||||
as the value `1`, except that the corresponding
|
||||
block axis is squeezed. In the example below, observe that the
|
||||
shape of the `o_ref` is (2,) when the block shape was specified as
|
||||
`(None, 2)` (the leading dimension was squeezed).
|
||||
|
||||
```python
|
||||
>>> def kernel(o_ref):
|
||||
... assert o_ref.shape == (2,)
|
||||
... o_ref[...] = jnp.full((2,), 10 * pl.program_id(1) + pl.program_id(0))
|
||||
>>> pl.pallas_call(kernel,
|
||||
... jax.ShapeDtypeStruct((3, 4), dtype=np.int32),
|
||||
... out_specs=pl.BlockSpec((None, 2), lambda i, j: (i, j)),
|
||||
... grid=(3, 2), interpret=True)()
|
||||
Array([[ 0, 0, 10, 10],
|
||||
[ 1, 1, 11, 11],
|
||||
[ 2, 2, 12, 12]], dtype=int32)
|
||||
|
||||
```
|
||||
|
||||
When we construct a `BlockSpec` we can use the value `None` for the
|
||||
`block_shape` parameter, in which case the shape of the overall array
|
||||
is used as `block_shape`.
|
||||
And if we use the value `None` for the `index_map` parameter
|
||||
then a default index map function that returns a tuple of zeros is
|
||||
used: `index_map=lambda *invocation_indices: (0,) * len(block_shape)`.
|
||||
|
||||
```python
|
||||
>>> show_invocations(x_shape=(4, 4), block_shape=None, grid=(2, 3),
|
||||
... out_index_map=None)
|
||||
[[12 12 12 12]
|
||||
[12 12 12 12]
|
||||
[12 12 12 12]
|
||||
[12 12 12 12]]
|
||||
|
||||
>>> show_invocations(x_shape=(4, 4), block_shape=(4, 4), grid=(2, 3),
|
||||
... out_index_map=None)
|
||||
[[12 12 12 12]
|
||||
[12 12 12 12]
|
||||
[12 12 12 12]
|
||||
[12 12 12 12]]
|
||||
|
||||
```
|
||||
|
@ -211,12 +211,15 @@ class BlockSpec:
|
||||
|
||||
def compute_index(self, *args):
|
||||
assert self.index_map is not None
|
||||
assert self.block_shape is not None
|
||||
out = self.index_map(*args)
|
||||
if not isinstance(out, tuple):
|
||||
out = (out,)
|
||||
return out
|
||||
|
||||
class NoBlockSpec:
|
||||
pass
|
||||
no_block_spec = NoBlockSpec()
|
||||
|
||||
|
||||
# A PyTree of BlockSpec | NoBlockSpec.
|
||||
BlockSpecTree = Any
|
||||
@ -310,9 +313,11 @@ def _convert_block_spec_to_block_mapping(
|
||||
return None
|
||||
if block_spec.index_map is None:
|
||||
compute_index = lambda *args, **kwargs: (0,) * len(aval.shape)
|
||||
block_shape = aval.shape
|
||||
else:
|
||||
compute_index = block_spec.compute_index
|
||||
if block_spec.block_shape is None:
|
||||
block_shape = aval.shape
|
||||
else:
|
||||
block_shape = block_spec.block_shape
|
||||
block_shape = tuple(
|
||||
mapped if s is None else s for s in block_shape)
|
||||
@ -338,8 +343,7 @@ def _tile_ref(ref: state.AbstractRef, block_shape: tuple[int, ...] | None
|
||||
return ref.update(inner_aval=ref.inner_aval.update(shape=shape))
|
||||
|
||||
|
||||
def _get_ref_avals(grid,
|
||||
in_avals: Sequence[jax_core.ShapedArray],
|
||||
def _get_ref_avals(in_avals: Sequence[jax_core.ShapedArray],
|
||||
in_specs: Sequence[BlockSpec],
|
||||
in_paths: Sequence[tree_util.KeyPath],
|
||||
out_avals: Sequence[jax_core.ShapedArray],
|
||||
@ -363,9 +367,9 @@ def _get_ref_avals(grid,
|
||||
f"Block shape for {what}{tree_util.keystr(path)} (= {block_shape}) "
|
||||
f"must have the same number of dimensions as the array shape {ref_aval.shape}"
|
||||
)
|
||||
trimmed_block_shape = tuple(s for s in block_shape if s is not None)
|
||||
block_shape_unmapped = tuple(s for s in block_shape if s is not None)
|
||||
ref_aval = ref_aval.update(
|
||||
inner_aval=ref_aval.inner_aval.update(shape=trimmed_block_shape))
|
||||
inner_aval=ref_aval.inner_aval.update(shape=block_shape_unmapped))
|
||||
|
||||
if not jax_core.is_constant_shape(ref_aval.shape):
|
||||
raise ValueError(
|
||||
@ -382,11 +386,7 @@ def _get_ref_avals(grid,
|
||||
make_ref_aval(aval, out_spec, out_path, "output")
|
||||
for aval, out_spec, out_path in zip(out_avals, out_specs, out_paths)
|
||||
]
|
||||
return in_specs, in_ref_avals, out_specs, out_ref_avals
|
||||
|
||||
class NoBlockSpec:
|
||||
pass
|
||||
no_block_spec = NoBlockSpec()
|
||||
return in_ref_avals, out_ref_avals
|
||||
|
||||
|
||||
@dataclasses.dataclass(init=False, unsafe_hash=True)
|
||||
@ -453,8 +453,8 @@ class GridSpec:
|
||||
)
|
||||
flat_in_specs, flat_out_specs = self._get_in_out_specs(
|
||||
in_avals, in_tree, out_avals, out_tree)
|
||||
in_specs, in_ref_avals, out_specs, out_ref_avals = _get_ref_avals(
|
||||
self.grid, in_avals, flat_in_specs, in_paths,
|
||||
in_ref_avals, out_ref_avals = _get_ref_avals(
|
||||
in_avals, flat_in_specs, in_paths,
|
||||
out_avals, flat_out_specs, out_paths)
|
||||
grid_avals = [jax_core.ShapedArray((), jnp.dtype("int32"))] * len(self.grid)
|
||||
# Create args, kwargs pytree def
|
||||
@ -468,7 +468,7 @@ class GridSpec:
|
||||
mapped_dims=(),
|
||||
what="input",
|
||||
),
|
||||
in_specs,
|
||||
flat_in_specs,
|
||||
in_paths,
|
||||
in_ref_avals,
|
||||
)
|
||||
@ -481,7 +481,7 @@ class GridSpec:
|
||||
mapped_dims=(),
|
||||
what="output",
|
||||
),
|
||||
out_specs,
|
||||
flat_out_specs,
|
||||
out_paths,
|
||||
out_ref_avals,
|
||||
)
|
||||
|
@ -187,9 +187,9 @@ class PrefetchScalarGridSpec(pallas_core.GridSpec):
|
||||
in_avals, in_avals_tree = tree_util.tree_flatten(tuple(unflat_in_avals))
|
||||
flat_in_specs, flat_out_specs = self._get_in_out_specs(
|
||||
in_avals, in_avals_tree, out_avals, out_tree)
|
||||
in_specs, in_ref_avals, out_specs, out_ref_avals = (
|
||||
in_ref_avals, out_ref_avals = (
|
||||
pallas_core._get_ref_avals(
|
||||
self.grid, in_avals, flat_in_specs, in_paths[num_flat_scalar_prefetch:],
|
||||
in_avals, flat_in_specs, in_paths[num_flat_scalar_prefetch:],
|
||||
out_avals, flat_out_specs, out_paths))
|
||||
scalar_ref_avals = [
|
||||
AbstractMemoryRef(jax_core.ShapedArray(aval.shape, aval.dtype),
|
||||
@ -209,7 +209,7 @@ class PrefetchScalarGridSpec(pallas_core.GridSpec):
|
||||
mapped_dims=(),
|
||||
what="input",
|
||||
),
|
||||
in_specs,
|
||||
flat_in_specs,
|
||||
in_paths[num_flat_scalar_prefetch:],
|
||||
in_ref_avals,
|
||||
)
|
||||
@ -222,7 +222,7 @@ class PrefetchScalarGridSpec(pallas_core.GridSpec):
|
||||
mapped_dims=(),
|
||||
what="output",
|
||||
),
|
||||
out_specs,
|
||||
flat_out_specs,
|
||||
out_paths,
|
||||
out_ref_avals,
|
||||
)
|
||||
|
@ -1034,12 +1034,14 @@ def pallas_call(
|
||||
See details at :ref:`pallas_grid`.
|
||||
in_specs: a PyTree of :class:`jax.experimental.pallas.BlockSpec` with
|
||||
a structure matching that of the positional arguments.
|
||||
The default value for ``in_specs`` specifies the whole array for all
|
||||
inputs, e.g., as ``pl.BlockSpec(x.shape, lambda *indices: (0,) * x.ndim)``.
|
||||
See details at :ref:`pallas_blockspec`.
|
||||
out_specs: a PyTree of :class:`jax.experimental.pallas.BlockSpec` with
|
||||
a structure matching that of the outputs.
|
||||
The default value for ``out_specs`` specifies the whole array,
|
||||
e.g., as ``pl.BlockSpec(x.shape, lambda *indices: (0,) * x.ndim)``.
|
||||
See details at :ref:`pallas_blockspec`.
|
||||
The default value for `out_specs` specifies the whole array,
|
||||
e.g., as ``pl.BlockSpec(x.shape, lambda *indices: indices)``.
|
||||
input_output_aliases: a dictionary mapping the index of some inputs to
|
||||
the index of the output that aliases them. These indices are in the
|
||||
flattened inputs and outputs.
|
||||
@ -1063,6 +1065,9 @@ def pallas_call(
|
||||
if grid_spec is None:
|
||||
grid_spec = GridSpec(grid, in_specs, out_specs)
|
||||
grid_spec, dynamic_grid_bounds = grid_spec.unzip_dynamic_grid_bounds()
|
||||
# TODO(necula): this canonicalization may be convenient for some usage
|
||||
# but it is lossy, because it prevents expressing functions that return
|
||||
# lists.
|
||||
if isinstance(out_shape, list):
|
||||
out_shape = tuple(out_shape)
|
||||
flat_out_shapes_with_paths, out_tree = tree_util.tree_flatten_with_path(out_shape)
|
||||
|
@ -253,6 +253,114 @@ class PallasCallTest(PallasBaseTest):
|
||||
# TODO(necula): we normalize out_shape to a tuple, we shouldn't.
|
||||
self.assertIsInstance(res, tuple)
|
||||
|
||||
@jtu.skip_on_devices("gpu") # TODO: RET_CHECK failure
|
||||
def test_block_spec_with_padding(self):
|
||||
def f(*, shape, block_shape):
|
||||
def kernel(o1_ref):
|
||||
assert o1_ref.shape == block_shape
|
||||
o1_ref[...] = jnp.full(o1_ref.shape, pl.program_id(0))
|
||||
|
||||
return self.pallas_call(kernel,
|
||||
jax.ShapeDtypeStruct(shape, dtype=np.int32),
|
||||
grid=((shape[0] + block_shape[0] - 1) // block_shape[0],),
|
||||
out_specs=pl.BlockSpec(block_shape, lambda i: i))()
|
||||
# No padding
|
||||
pids = f(shape=(8,), block_shape=(2,))
|
||||
self.assertAllClose(pids,
|
||||
np.array([0, 0, 1, 1, 2, 2, 3, 3], dtype=np.int32))
|
||||
# Pad the last block
|
||||
pids = f(shape=(8,), block_shape=(3,))
|
||||
self.assertAllClose(pids,
|
||||
np.array([0, 0, 0, 1, 1, 1, 2, 2], dtype=np.int32))
|
||||
# Works even if the shape is smaller than 1 block
|
||||
pids = f(shape=(3,), block_shape=(8,))
|
||||
self.assertAllClose(pids,
|
||||
np.array([0, 0, 0], dtype=np.int32))
|
||||
|
||||
def test_block_spec_mapped_dimension(self):
|
||||
@functools.partial(
|
||||
self.pallas_call,
|
||||
out_shape=jax.ShapeDtypeStruct((4,), jnp.float32),
|
||||
in_specs=[
|
||||
pl.BlockSpec((None, 4), lambda _: (0, 0)),
|
||||
pl.BlockSpec((None, 4), lambda _: (1, 0)),
|
||||
],
|
||||
grid=1,
|
||||
)
|
||||
def add_vectors(x_ref, y_ref, o_ref):
|
||||
o_ref[:] = x_ref[:] + y_ref[:]
|
||||
xy = jnp.arange(8., dtype=np.float32).reshape((2, 4))
|
||||
out = add_vectors(xy, xy)
|
||||
out_ref = xy[0] + xy[1]
|
||||
np.testing.assert_allclose(out, out_ref)
|
||||
|
||||
def test_pallas_call_no_grid(self):
|
||||
o_ref_shape = None
|
||||
def kernel(o_ref):
|
||||
nonlocal o_ref_shape
|
||||
o_ref_shape = o_ref.shape
|
||||
o_ref[...] = jnp.full(o_ref.shape, 42)
|
||||
|
||||
pids = self.pallas_call(kernel,
|
||||
jax.ShapeDtypeStruct((8,), dtype=np.int32))()
|
||||
self.assertAllClose(pids, np.array([42] * 8, dtype=np.int32))
|
||||
self.assertEqual(o_ref_shape, (8,))
|
||||
|
||||
def test_pallas_call_no_block_spec(self):
|
||||
o_ref_shape = None
|
||||
def kernel(o_ref):
|
||||
nonlocal o_ref_shape
|
||||
o_ref_shape = o_ref.shape
|
||||
o_ref[...] = jnp.full(o_ref.shape, pl.program_id(0))
|
||||
|
||||
pids = self.pallas_call(kernel,
|
||||
jax.ShapeDtypeStruct((8,), dtype=np.int32),
|
||||
grid=(1,))()
|
||||
self.assertEqual(o_ref_shape, (8,))
|
||||
self.assertAllClose(pids, np.array([0] * 8, dtype=np.int32))
|
||||
|
||||
def test_block_spec_no_block_shape_and_no_index_map(self):
|
||||
o_ref_shape = None
|
||||
def kernel(o_ref):
|
||||
nonlocal o_ref_shape
|
||||
o_ref_shape = o_ref.shape
|
||||
o_ref[...] = jnp.full(o_ref.shape, pl.program_id(0))
|
||||
|
||||
pids = self.pallas_call(kernel,
|
||||
jax.ShapeDtypeStruct((8,), dtype=np.int32),
|
||||
out_specs=pl.BlockSpec(),
|
||||
grid=(1,))()
|
||||
self.assertEqual(o_ref_shape, (8,))
|
||||
self.assertAllClose(pids, np.array([0] * 8, dtype=np.int32))
|
||||
|
||||
def test_block_spec_no_block_shape(self):
|
||||
o_ref_shape = None
|
||||
def kernel(o_ref):
|
||||
nonlocal o_ref_shape
|
||||
o_ref_shape = o_ref.shape
|
||||
o_ref[...] = jnp.full(o_ref.shape, pl.program_id(0))
|
||||
|
||||
pids = self.pallas_call(kernel,
|
||||
jax.ShapeDtypeStruct((8,), dtype=np.int32),
|
||||
out_specs=pl.BlockSpec(None, lambda i: i),
|
||||
grid=(1,))()
|
||||
self.assertEqual(o_ref_shape, (8,))
|
||||
self.assertAllClose(pids, np.array([0] * 8, dtype=np.int32))
|
||||
|
||||
def test_block_spec_no_index_map(self):
|
||||
o_ref_shape = None
|
||||
def kernel(o_ref):
|
||||
nonlocal o_ref_shape
|
||||
o_ref_shape = o_ref.shape
|
||||
o_ref[...] = jnp.full(o_ref.shape, pl.program_id(0))
|
||||
|
||||
pids = self.pallas_call(kernel,
|
||||
jax.ShapeDtypeStruct((8,), dtype=np.int32),
|
||||
out_specs=pl.BlockSpec((4,)),
|
||||
grid=(1,))()
|
||||
self.assertEqual(o_ref_shape, (4,))
|
||||
self.assertAllClose(pids[0:4], np.array([0] * 4, dtype=np.int32))
|
||||
|
||||
def test_hoisted_consts(self):
|
||||
# See https://github.com/google/jax/issues/21557.
|
||||
x = jnp.zeros(32)
|
||||
@ -440,50 +548,6 @@ class PallasCallTest(PallasBaseTest):
|
||||
self.assertEqual(f(x), 2.)
|
||||
self.assertEqual(trace_count, 1)
|
||||
|
||||
def test_custom_jvp_call(self):
|
||||
@functools.partial(jax.custom_jvp, nondiff_argnums=(1,))
|
||||
def softmax(x, axis=-1):
|
||||
unnormalized = jnp.exp(x - jnp.max(x, axis, keepdims=True))
|
||||
return unnormalized / jnp.sum(unnormalized, axis, keepdims=True)
|
||||
|
||||
@softmax.defjvp
|
||||
def softmax_jvp(axis, primals, tangents):
|
||||
(x,), (x_dot,) = primals, tangents
|
||||
y = softmax(x, axis)
|
||||
return y, y * (x_dot - (y * x_dot).sum(axis, keepdims=True))
|
||||
|
||||
m, n = 16, 32
|
||||
x = random.normal(random.key(0), (m, n))
|
||||
|
||||
@functools.partial(self.pallas_call, out_shape=x, grid=1)
|
||||
def softmax_kernel(x_ref, y_ref):
|
||||
y_ref[:] = softmax(x_ref[:])
|
||||
|
||||
np.testing.assert_allclose(softmax_kernel(x), jax.nn.softmax(x), atol=1e-7)
|
||||
|
||||
@jtu.skip_on_devices("gpu") # TODO: RET_CHECK failure
|
||||
def test_block_spec_with_padding(self):
|
||||
def f(*, shape, block_shape):
|
||||
def kernel(o1_ref):
|
||||
assert o1_ref.shape == block_shape
|
||||
o1_ref[...] = jnp.full(o1_ref.shape, pl.program_id(0))
|
||||
|
||||
return self.pallas_call(kernel,
|
||||
jax.ShapeDtypeStruct(shape, dtype=np.int32),
|
||||
grid=((shape[0] + block_shape[0] - 1) // block_shape[0],),
|
||||
out_specs=pl.BlockSpec(block_shape, lambda i: i))()
|
||||
# No padding
|
||||
pids = f(shape=(8,), block_shape=(2,))
|
||||
self.assertAllClose(pids,
|
||||
np.array([0, 0, 1, 1, 2, 2, 3, 3], dtype=np.int32))
|
||||
# Pad the last block
|
||||
pids = f(shape=(8,), block_shape=(3,))
|
||||
self.assertAllClose(pids,
|
||||
np.array([0, 0, 0, 1, 1, 1, 2, 2], dtype=np.int32))
|
||||
# Works even if the shape is smaller than 1 block
|
||||
pids = f(shape=(3,), block_shape=(8,))
|
||||
self.assertAllClose(pids,
|
||||
np.array([0, 0, 0], dtype=np.int32))
|
||||
|
||||
class PallasCallInterpreterTest(PallasCallTest):
|
||||
INTERPRET = True
|
||||
@ -1495,6 +1559,27 @@ class PallasCallAutodifferentiationTest(PallasBaseTest):
|
||||
jtu.check_grads(pallas_impl, (x,), modes=["fwd"], order=2,
|
||||
atol=grad_tol, rtol=grad_tol)
|
||||
|
||||
def test_custom_jvp_call(self):
|
||||
@functools.partial(jax.custom_jvp, nondiff_argnums=(1,))
|
||||
def softmax(x, axis=-1):
|
||||
unnormalized = jnp.exp(x - jnp.max(x, axis, keepdims=True))
|
||||
return unnormalized / jnp.sum(unnormalized, axis, keepdims=True)
|
||||
|
||||
@softmax.defjvp
|
||||
def softmax_jvp(axis, primals, tangents):
|
||||
(x,), (x_dot,) = primals, tangents
|
||||
y = softmax(x, axis)
|
||||
return y, y * (x_dot - (y * x_dot).sum(axis, keepdims=True))
|
||||
|
||||
m, n = 16, 32
|
||||
x = random.normal(random.key(0), (m, n))
|
||||
|
||||
@functools.partial(self.pallas_call, out_shape=x, grid=1)
|
||||
def softmax_kernel(x_ref, y_ref):
|
||||
y_ref[:] = softmax(x_ref[:])
|
||||
|
||||
np.testing.assert_allclose(softmax_kernel(x), jax.nn.softmax(x), atol=1e-7)
|
||||
|
||||
# TODO(sharadmv): enable this when we update Triton
|
||||
# def test_jvp_matmul(self):
|
||||
# k1, k2 = random.split(random.key(0))
|
||||
@ -1505,23 +1590,6 @@ class PallasCallAutodifferentiationTest(PallasBaseTest):
|
||||
# interpret=self.INTERPRET)
|
||||
# jtu.check_grads(mm, (x, y), modes=["fwd"], order=1)
|
||||
|
||||
def test_slicing_block_spec(self):
|
||||
@functools.partial(
|
||||
self.pallas_call,
|
||||
out_shape=jax.ShapeDtypeStruct((4,), jnp.float32),
|
||||
in_specs=[
|
||||
pl.BlockSpec((None, 4), lambda _: (0, 0)),
|
||||
pl.BlockSpec((None, 4), lambda _: (1, 0)),
|
||||
],
|
||||
grid=1,
|
||||
)
|
||||
def add_vectors(x_ref, y_ref, o_ref):
|
||||
o_ref[:] = x_ref[:] + y_ref[:]
|
||||
xy = jnp.arange(8.).reshape((2, 4))
|
||||
out = add_vectors(xy, xy)
|
||||
out_ref = xy[0] + xy[1]
|
||||
np.testing.assert_allclose(out, out_ref)
|
||||
|
||||
|
||||
class PallasCallAutodifferentiationInterpreterTest(PallasCallAutodifferentiationTest):
|
||||
INTERPRET = True
|
||||
|
Loading…
x
Reference in New Issue
Block a user