From ea548e7c86d7d82e6ddaf79a27e7aed280c78785 Mon Sep 17 00:00:00 2001 From: George Necula Date: Tue, 9 Jul 2024 12:43:37 +0300 Subject: [PATCH] [pallas] Add more documentation and tests for BlockSpec. This PR deals with the default values for the parameters of the `BlockSpec` constructor, and the mapped block dimensions. Fix a bug where previously a missing block_shape while the index_map was present was resulting in a crash. --- docs/pallas/grid_blockspec.md | 49 ++++++++- jax/_src/pallas/core.py | 30 +++--- jax/_src/pallas/mosaic/core.py | 8 +- jax/_src/pallas/pallas_call.py | 9 +- tests/pallas/pallas_test.py | 190 ++++++++++++++++++++++----------- 5 files changed, 203 insertions(+), 83 deletions(-) diff --git a/docs/pallas/grid_blockspec.md b/docs/pallas/grid_blockspec.md index 11151983f..b58c991ee 100644 --- a/docs/pallas/grid_blockspec.md +++ b/docs/pallas/grid_blockspec.md @@ -33,7 +33,10 @@ for i in range(n): This generalizes to any tuple of integers (a length `d` grid will correspond to `d` nested loops). The kernel is executed as many times -as `prod(grid)`. Each of these invocations is referred to as a "program". +as `prod(grid)`. +The default grid value `None` stands for `()`, and results in one +kernel invocation. +Each of these invocations is referred to as a "program". To access which program (i.e. which element of the grid) the kernel is currently executing, we use {func}`jax.experimental.pallas.program_id`. For example, for invocation `(1, 2)`, `program_id(axis=0)` returns `1` and @@ -233,3 +236,47 @@ See [the Pallas TPU documentation](https://jax.readthedocs.io/en/latest/pallas/t [309 309 309 319 319 319]] ``` + +A `None` value appearing as a dimension value in the `block_shape` behaves +as the value `1`, except that the corresponding +block axis is squeezed. In the example below, observe that the +shape of the `o_ref` is (2,) when the block shape was specified as +`(None, 2)` (the leading dimension was squeezed). + +```python +>>> def kernel(o_ref): +... assert o_ref.shape == (2,) +... o_ref[...] = jnp.full((2,), 10 * pl.program_id(1) + pl.program_id(0)) +>>> pl.pallas_call(kernel, +... jax.ShapeDtypeStruct((3, 4), dtype=np.int32), +... out_specs=pl.BlockSpec((None, 2), lambda i, j: (i, j)), +... grid=(3, 2), interpret=True)() +Array([[ 0, 0, 10, 10], + [ 1, 1, 11, 11], + [ 2, 2, 12, 12]], dtype=int32) + +``` + +When we construct a `BlockSpec` we can use the value `None` for the +`block_shape` parameter, in which case the shape of the overall array +is used as `block_shape`. +And if we use the value `None` for the `index_map` parameter +then a default index map function that returns a tuple of zeros is +used: `index_map=lambda *invocation_indices: (0,) * len(block_shape)`. + +```python +>>> show_invocations(x_shape=(4, 4), block_shape=None, grid=(2, 3), +... out_index_map=None) +[[12 12 12 12] + [12 12 12 12] + [12 12 12 12] + [12 12 12 12]] + +>>> show_invocations(x_shape=(4, 4), block_shape=(4, 4), grid=(2, 3), +... out_index_map=None) +[[12 12 12 12] + [12 12 12 12] + [12 12 12 12] + [12 12 12 12]] + +``` diff --git a/jax/_src/pallas/core.py b/jax/_src/pallas/core.py index 7bc99c369..7b5f0b99e 100644 --- a/jax/_src/pallas/core.py +++ b/jax/_src/pallas/core.py @@ -211,12 +211,15 @@ class BlockSpec: def compute_index(self, *args): assert self.index_map is not None - assert self.block_shape is not None out = self.index_map(*args) if not isinstance(out, tuple): out = (out,) return out +class NoBlockSpec: + pass +no_block_spec = NoBlockSpec() + # A PyTree of BlockSpec | NoBlockSpec. BlockSpecTree = Any @@ -310,9 +313,11 @@ def _convert_block_spec_to_block_mapping( return None if block_spec.index_map is None: compute_index = lambda *args, **kwargs: (0,) * len(aval.shape) - block_shape = aval.shape else: compute_index = block_spec.compute_index + if block_spec.block_shape is None: + block_shape = aval.shape + else: block_shape = block_spec.block_shape block_shape = tuple( mapped if s is None else s for s in block_shape) @@ -338,8 +343,7 @@ def _tile_ref(ref: state.AbstractRef, block_shape: tuple[int, ...] | None return ref.update(inner_aval=ref.inner_aval.update(shape=shape)) -def _get_ref_avals(grid, - in_avals: Sequence[jax_core.ShapedArray], +def _get_ref_avals(in_avals: Sequence[jax_core.ShapedArray], in_specs: Sequence[BlockSpec], in_paths: Sequence[tree_util.KeyPath], out_avals: Sequence[jax_core.ShapedArray], @@ -363,9 +367,9 @@ def _get_ref_avals(grid, f"Block shape for {what}{tree_util.keystr(path)} (= {block_shape}) " f"must have the same number of dimensions as the array shape {ref_aval.shape}" ) - trimmed_block_shape = tuple(s for s in block_shape if s is not None) + block_shape_unmapped = tuple(s for s in block_shape if s is not None) ref_aval = ref_aval.update( - inner_aval=ref_aval.inner_aval.update(shape=trimmed_block_shape)) + inner_aval=ref_aval.inner_aval.update(shape=block_shape_unmapped)) if not jax_core.is_constant_shape(ref_aval.shape): raise ValueError( @@ -382,11 +386,7 @@ def _get_ref_avals(grid, make_ref_aval(aval, out_spec, out_path, "output") for aval, out_spec, out_path in zip(out_avals, out_specs, out_paths) ] - return in_specs, in_ref_avals, out_specs, out_ref_avals - -class NoBlockSpec: - pass -no_block_spec = NoBlockSpec() + return in_ref_avals, out_ref_avals @dataclasses.dataclass(init=False, unsafe_hash=True) @@ -453,8 +453,8 @@ class GridSpec: ) flat_in_specs, flat_out_specs = self._get_in_out_specs( in_avals, in_tree, out_avals, out_tree) - in_specs, in_ref_avals, out_specs, out_ref_avals = _get_ref_avals( - self.grid, in_avals, flat_in_specs, in_paths, + in_ref_avals, out_ref_avals = _get_ref_avals( + in_avals, flat_in_specs, in_paths, out_avals, flat_out_specs, out_paths) grid_avals = [jax_core.ShapedArray((), jnp.dtype("int32"))] * len(self.grid) # Create args, kwargs pytree def @@ -468,7 +468,7 @@ class GridSpec: mapped_dims=(), what="input", ), - in_specs, + flat_in_specs, in_paths, in_ref_avals, ) @@ -481,7 +481,7 @@ class GridSpec: mapped_dims=(), what="output", ), - out_specs, + flat_out_specs, out_paths, out_ref_avals, ) diff --git a/jax/_src/pallas/mosaic/core.py b/jax/_src/pallas/mosaic/core.py index 09979faf8..20a7f8207 100644 --- a/jax/_src/pallas/mosaic/core.py +++ b/jax/_src/pallas/mosaic/core.py @@ -187,9 +187,9 @@ class PrefetchScalarGridSpec(pallas_core.GridSpec): in_avals, in_avals_tree = tree_util.tree_flatten(tuple(unflat_in_avals)) flat_in_specs, flat_out_specs = self._get_in_out_specs( in_avals, in_avals_tree, out_avals, out_tree) - in_specs, in_ref_avals, out_specs, out_ref_avals = ( + in_ref_avals, out_ref_avals = ( pallas_core._get_ref_avals( - self.grid, in_avals, flat_in_specs, in_paths[num_flat_scalar_prefetch:], + in_avals, flat_in_specs, in_paths[num_flat_scalar_prefetch:], out_avals, flat_out_specs, out_paths)) scalar_ref_avals = [ AbstractMemoryRef(jax_core.ShapedArray(aval.shape, aval.dtype), @@ -209,7 +209,7 @@ class PrefetchScalarGridSpec(pallas_core.GridSpec): mapped_dims=(), what="input", ), - in_specs, + flat_in_specs, in_paths[num_flat_scalar_prefetch:], in_ref_avals, ) @@ -222,7 +222,7 @@ class PrefetchScalarGridSpec(pallas_core.GridSpec): mapped_dims=(), what="output", ), - out_specs, + flat_out_specs, out_paths, out_ref_avals, ) diff --git a/jax/_src/pallas/pallas_call.py b/jax/_src/pallas/pallas_call.py index 8dfb4e33a..e2d11fa73 100644 --- a/jax/_src/pallas/pallas_call.py +++ b/jax/_src/pallas/pallas_call.py @@ -1034,12 +1034,14 @@ def pallas_call( See details at :ref:`pallas_grid`. in_specs: a PyTree of :class:`jax.experimental.pallas.BlockSpec` with a structure matching that of the positional arguments. + The default value for ``in_specs`` specifies the whole array for all + inputs, e.g., as ``pl.BlockSpec(x.shape, lambda *indices: (0,) * x.ndim)``. See details at :ref:`pallas_blockspec`. out_specs: a PyTree of :class:`jax.experimental.pallas.BlockSpec` with a structure matching that of the outputs. + The default value for ``out_specs`` specifies the whole array, + e.g., as ``pl.BlockSpec(x.shape, lambda *indices: (0,) * x.ndim)``. See details at :ref:`pallas_blockspec`. - The default value for `out_specs` specifies the whole array, - e.g., as ``pl.BlockSpec(x.shape, lambda *indices: indices)``. input_output_aliases: a dictionary mapping the index of some inputs to the index of the output that aliases them. These indices are in the flattened inputs and outputs. @@ -1063,6 +1065,9 @@ def pallas_call( if grid_spec is None: grid_spec = GridSpec(grid, in_specs, out_specs) grid_spec, dynamic_grid_bounds = grid_spec.unzip_dynamic_grid_bounds() + # TODO(necula): this canonicalization may be convenient for some usage + # but it is lossy, because it prevents expressing functions that return + # lists. if isinstance(out_shape, list): out_shape = tuple(out_shape) flat_out_shapes_with_paths, out_tree = tree_util.tree_flatten_with_path(out_shape) diff --git a/tests/pallas/pallas_test.py b/tests/pallas/pallas_test.py index f2ddabc74..c2b4a1095 100644 --- a/tests/pallas/pallas_test.py +++ b/tests/pallas/pallas_test.py @@ -253,6 +253,114 @@ class PallasCallTest(PallasBaseTest): # TODO(necula): we normalize out_shape to a tuple, we shouldn't. self.assertIsInstance(res, tuple) + @jtu.skip_on_devices("gpu") # TODO: RET_CHECK failure + def test_block_spec_with_padding(self): + def f(*, shape, block_shape): + def kernel(o1_ref): + assert o1_ref.shape == block_shape + o1_ref[...] = jnp.full(o1_ref.shape, pl.program_id(0)) + + return self.pallas_call(kernel, + jax.ShapeDtypeStruct(shape, dtype=np.int32), + grid=((shape[0] + block_shape[0] - 1) // block_shape[0],), + out_specs=pl.BlockSpec(block_shape, lambda i: i))() + # No padding + pids = f(shape=(8,), block_shape=(2,)) + self.assertAllClose(pids, + np.array([0, 0, 1, 1, 2, 2, 3, 3], dtype=np.int32)) + # Pad the last block + pids = f(shape=(8,), block_shape=(3,)) + self.assertAllClose(pids, + np.array([0, 0, 0, 1, 1, 1, 2, 2], dtype=np.int32)) + # Works even if the shape is smaller than 1 block + pids = f(shape=(3,), block_shape=(8,)) + self.assertAllClose(pids, + np.array([0, 0, 0], dtype=np.int32)) + + def test_block_spec_mapped_dimension(self): + @functools.partial( + self.pallas_call, + out_shape=jax.ShapeDtypeStruct((4,), jnp.float32), + in_specs=[ + pl.BlockSpec((None, 4), lambda _: (0, 0)), + pl.BlockSpec((None, 4), lambda _: (1, 0)), + ], + grid=1, + ) + def add_vectors(x_ref, y_ref, o_ref): + o_ref[:] = x_ref[:] + y_ref[:] + xy = jnp.arange(8., dtype=np.float32).reshape((2, 4)) + out = add_vectors(xy, xy) + out_ref = xy[0] + xy[1] + np.testing.assert_allclose(out, out_ref) + + def test_pallas_call_no_grid(self): + o_ref_shape = None + def kernel(o_ref): + nonlocal o_ref_shape + o_ref_shape = o_ref.shape + o_ref[...] = jnp.full(o_ref.shape, 42) + + pids = self.pallas_call(kernel, + jax.ShapeDtypeStruct((8,), dtype=np.int32))() + self.assertAllClose(pids, np.array([42] * 8, dtype=np.int32)) + self.assertEqual(o_ref_shape, (8,)) + + def test_pallas_call_no_block_spec(self): + o_ref_shape = None + def kernel(o_ref): + nonlocal o_ref_shape + o_ref_shape = o_ref.shape + o_ref[...] = jnp.full(o_ref.shape, pl.program_id(0)) + + pids = self.pallas_call(kernel, + jax.ShapeDtypeStruct((8,), dtype=np.int32), + grid=(1,))() + self.assertEqual(o_ref_shape, (8,)) + self.assertAllClose(pids, np.array([0] * 8, dtype=np.int32)) + + def test_block_spec_no_block_shape_and_no_index_map(self): + o_ref_shape = None + def kernel(o_ref): + nonlocal o_ref_shape + o_ref_shape = o_ref.shape + o_ref[...] = jnp.full(o_ref.shape, pl.program_id(0)) + + pids = self.pallas_call(kernel, + jax.ShapeDtypeStruct((8,), dtype=np.int32), + out_specs=pl.BlockSpec(), + grid=(1,))() + self.assertEqual(o_ref_shape, (8,)) + self.assertAllClose(pids, np.array([0] * 8, dtype=np.int32)) + + def test_block_spec_no_block_shape(self): + o_ref_shape = None + def kernel(o_ref): + nonlocal o_ref_shape + o_ref_shape = o_ref.shape + o_ref[...] = jnp.full(o_ref.shape, pl.program_id(0)) + + pids = self.pallas_call(kernel, + jax.ShapeDtypeStruct((8,), dtype=np.int32), + out_specs=pl.BlockSpec(None, lambda i: i), + grid=(1,))() + self.assertEqual(o_ref_shape, (8,)) + self.assertAllClose(pids, np.array([0] * 8, dtype=np.int32)) + + def test_block_spec_no_index_map(self): + o_ref_shape = None + def kernel(o_ref): + nonlocal o_ref_shape + o_ref_shape = o_ref.shape + o_ref[...] = jnp.full(o_ref.shape, pl.program_id(0)) + + pids = self.pallas_call(kernel, + jax.ShapeDtypeStruct((8,), dtype=np.int32), + out_specs=pl.BlockSpec((4,)), + grid=(1,))() + self.assertEqual(o_ref_shape, (4,)) + self.assertAllClose(pids[0:4], np.array([0] * 4, dtype=np.int32)) + def test_hoisted_consts(self): # See https://github.com/google/jax/issues/21557. x = jnp.zeros(32) @@ -440,50 +548,6 @@ class PallasCallTest(PallasBaseTest): self.assertEqual(f(x), 2.) self.assertEqual(trace_count, 1) - def test_custom_jvp_call(self): - @functools.partial(jax.custom_jvp, nondiff_argnums=(1,)) - def softmax(x, axis=-1): - unnormalized = jnp.exp(x - jnp.max(x, axis, keepdims=True)) - return unnormalized / jnp.sum(unnormalized, axis, keepdims=True) - - @softmax.defjvp - def softmax_jvp(axis, primals, tangents): - (x,), (x_dot,) = primals, tangents - y = softmax(x, axis) - return y, y * (x_dot - (y * x_dot).sum(axis, keepdims=True)) - - m, n = 16, 32 - x = random.normal(random.key(0), (m, n)) - - @functools.partial(self.pallas_call, out_shape=x, grid=1) - def softmax_kernel(x_ref, y_ref): - y_ref[:] = softmax(x_ref[:]) - - np.testing.assert_allclose(softmax_kernel(x), jax.nn.softmax(x), atol=1e-7) - - @jtu.skip_on_devices("gpu") # TODO: RET_CHECK failure - def test_block_spec_with_padding(self): - def f(*, shape, block_shape): - def kernel(o1_ref): - assert o1_ref.shape == block_shape - o1_ref[...] = jnp.full(o1_ref.shape, pl.program_id(0)) - - return self.pallas_call(kernel, - jax.ShapeDtypeStruct(shape, dtype=np.int32), - grid=((shape[0] + block_shape[0] - 1) // block_shape[0],), - out_specs=pl.BlockSpec(block_shape, lambda i: i))() - # No padding - pids = f(shape=(8,), block_shape=(2,)) - self.assertAllClose(pids, - np.array([0, 0, 1, 1, 2, 2, 3, 3], dtype=np.int32)) - # Pad the last block - pids = f(shape=(8,), block_shape=(3,)) - self.assertAllClose(pids, - np.array([0, 0, 0, 1, 1, 1, 2, 2], dtype=np.int32)) - # Works even if the shape is smaller than 1 block - pids = f(shape=(3,), block_shape=(8,)) - self.assertAllClose(pids, - np.array([0, 0, 0], dtype=np.int32)) class PallasCallInterpreterTest(PallasCallTest): INTERPRET = True @@ -1495,6 +1559,27 @@ class PallasCallAutodifferentiationTest(PallasBaseTest): jtu.check_grads(pallas_impl, (x,), modes=["fwd"], order=2, atol=grad_tol, rtol=grad_tol) + def test_custom_jvp_call(self): + @functools.partial(jax.custom_jvp, nondiff_argnums=(1,)) + def softmax(x, axis=-1): + unnormalized = jnp.exp(x - jnp.max(x, axis, keepdims=True)) + return unnormalized / jnp.sum(unnormalized, axis, keepdims=True) + + @softmax.defjvp + def softmax_jvp(axis, primals, tangents): + (x,), (x_dot,) = primals, tangents + y = softmax(x, axis) + return y, y * (x_dot - (y * x_dot).sum(axis, keepdims=True)) + + m, n = 16, 32 + x = random.normal(random.key(0), (m, n)) + + @functools.partial(self.pallas_call, out_shape=x, grid=1) + def softmax_kernel(x_ref, y_ref): + y_ref[:] = softmax(x_ref[:]) + + np.testing.assert_allclose(softmax_kernel(x), jax.nn.softmax(x), atol=1e-7) + # TODO(sharadmv): enable this when we update Triton # def test_jvp_matmul(self): # k1, k2 = random.split(random.key(0)) @@ -1505,23 +1590,6 @@ class PallasCallAutodifferentiationTest(PallasBaseTest): # interpret=self.INTERPRET) # jtu.check_grads(mm, (x, y), modes=["fwd"], order=1) - def test_slicing_block_spec(self): - @functools.partial( - self.pallas_call, - out_shape=jax.ShapeDtypeStruct((4,), jnp.float32), - in_specs=[ - pl.BlockSpec((None, 4), lambda _: (0, 0)), - pl.BlockSpec((None, 4), lambda _: (1, 0)), - ], - grid=1, - ) - def add_vectors(x_ref, y_ref, o_ref): - o_ref[:] = x_ref[:] + y_ref[:] - xy = jnp.arange(8.).reshape((2, 4)) - out = add_vectors(xy, xy) - out_ref = xy[0] + xy[1] - np.testing.assert_allclose(out, out_ref) - class PallasCallAutodifferentiationInterpreterTest(PallasCallAutodifferentiationTest): INTERPRET = True