From e90336947a7f763226e8609ea96bc49a64fdb2c9 Mon Sep 17 00:00:00 2001 From: Sergei Lebedev Date: Wed, 18 Sep 2024 05:25:37 -0700 Subject: [PATCH] Pulled `scratch_shapes` into `GridSpec` It is supported by Mosaic TPU and Mosaic GPU and unsupported by Triton. PiperOrigin-RevId: 675950199 --- docs/pallas/CHANGELOG.md | 16 ++++++++-------- jax/_src/pallas/core.py | 25 ++++++++++++++++--------- jax/_src/pallas/mosaic/core.py | 22 +++++----------------- jax/_src/pallas/mosaic_gpu/__init__.py | 1 - jax/_src/pallas/mosaic_gpu/core.py | 20 -------------------- jax/_src/pallas/pallas_call.py | 16 +++++++++++++--- jax/_src/pallas/triton/lowering.py | 4 ++++ tests/pallas/mosaic_gpu_test.py | 13 ++++++------- 8 files changed, 52 insertions(+), 65 deletions(-) diff --git a/docs/pallas/CHANGELOG.md b/docs/pallas/CHANGELOG.md index b39d0211c..43ba3ebd6 100644 --- a/docs/pallas/CHANGELOG.md +++ b/docs/pallas/CHANGELOG.md @@ -19,18 +19,22 @@ Remember to align the itemized text with the first line of an item within a list to be scalars. The restrictions on the arguments are backend-specific: Non-scalar arguments are currently only supported on GPU, when using Triton. +* Deprecations + +* New functionality + + * {func}`jax.experimental.pallas.pallas_call` now accepts `scratch_shapes`, + a PyTree specifying backend-specific temporary objects needed by the + kernel, for example, buffers, synchronization primitives etc. + ## Released with jax 0.4.33 (September 16, 2024) ## Released with jax 0.4.32 (September 11, 2024) -## Released with jax 0.4.32 - * Changes * The kernel function is not allowed to close over constants. Instead, all the needed arrays must be passed as inputs, with proper block specs ({jax-issue}`#22746`). -* Deprecations - * New functionality * Improved error messages for mistakes in the signature of the index map functions, to include the name and source location of the index map. @@ -56,10 +60,6 @@ Remember to align the itemized text with the first line of an item within a list * Previously it was possible to import many APIs that are meant to be private, as `jax.experimental.pallas.pallas`. This is not possible anymore. - -* Deprecations - - * New Functionality * Added documentation for BlockSpec: {ref}`pallas_grids_and_blockspecs`. * Improved error messages for the {func}`jax.experimental.pallas.pallas_call` diff --git a/jax/_src/pallas/core.py b/jax/_src/pallas/core.py index 56c47b940..f354dd83f 100644 --- a/jax/_src/pallas/core.py +++ b/jax/_src/pallas/core.py @@ -728,7 +728,16 @@ def _convert_block_spec_to_block_mapping( index_map_grid_aval = jax_core.ShapedArray((), jnp.int32) -@dataclasses.dataclass(init=False) + +class ScratchShape(Protocol): + def get_aval(self) -> jax_core.AbstractValue: + ... + + +ScratchShapeTree = Sequence[Union[ScratchShape, "ScratchShapeTree"]] + + +@dataclasses.dataclass(init=False, kw_only=True) class GridSpec: """Encodes the grid parameters for :func:`jax.experimental.pallas.pallas_call`. @@ -741,12 +750,14 @@ class GridSpec: grid_names: tuple[Hashable, ...] 
| None in_specs: BlockSpecTree out_specs: BlockSpecTree + scratch_shapes: ScratchShapeTree = () def __init__( self, grid: Grid = (), in_specs: BlockSpecTree = no_block_spec, out_specs: BlockSpecTree = no_block_spec, + scratch_shapes: ScratchShapeTree = (), ): # Be more lenient for in/out_specs if isinstance(in_specs, list): @@ -758,6 +769,7 @@ class GridSpec: self.in_specs = in_specs self.out_specs = out_specs + self.scratch_shapes = tuple(scratch_shapes) grid_names = None if isinstance(grid, int): @@ -773,9 +785,6 @@ class GridSpec: self.grid = grid # type: ignore self.grid_names = grid_names - def _make_scratch_aval(self, obj: object) -> jax_core.AbstractValue: - assert False # Not needed in GridSpec - def _make_scalar_ref_aval(self, aval): assert False # Not needed in GridSpec @@ -820,12 +829,10 @@ def get_grid_mapping( else: num_flat_scalar_prefetch = 0 jaxpr_scalar_ref_avals = () - - scratch_shapes: tuple[Any, ...] = getattr(grid_spec, "scratch_shapes", ()) - if scratch_shapes: + if grid_spec.scratch_shapes: flat_scratch_shapes, scratch_tree = tree_util.tree_flatten( - scratch_shapes) - flat_scratch_avals = map(grid_spec._make_scratch_aval, flat_scratch_shapes) + grid_spec.scratch_shapes) + flat_scratch_avals = map(lambda s: s.get_aval(), flat_scratch_shapes) num_flat_scratch_operands = len(flat_scratch_avals) jaxpr_scratch_avals = tree_util.tree_unflatten( scratch_tree, flat_scratch_avals) diff --git a/jax/_src/pallas/mosaic/core.py b/jax/_src/pallas/mosaic/core.py index 61b1dc435..b2b892a64 100644 --- a/jax/_src/pallas/mosaic/core.py +++ b/jax/_src/pallas/mosaic/core.py @@ -19,7 +19,7 @@ from collections.abc import Sequence import dataclasses import enum import functools -from typing import Any, ClassVar, Hashable, Literal +from typing import Any, ClassVar, Literal import jax from jax._src import core as jax_core @@ -39,6 +39,7 @@ BlockSpec = pallas_core.BlockSpec BlockSpecTree = pallas_core.BlockSpecTree GridMapping = pallas_core.GridMapping NoBlockSpec = pallas_core.NoBlockSpec +ScratchShapeTree = pallas_core.ScratchShapeTree AbstractMemoryRef = pallas_core.AbstractMemoryRef no_block_spec = pallas_core.no_block_spec _convert_block_spec_to_block_mapping = pallas_core._convert_block_spec_to_block_mapping @@ -174,14 +175,9 @@ class MemoryRef: jax_core.ShapedArray(self.shape, self.dtype), self.memory_space) -@dataclasses.dataclass(init=False, unsafe_hash=True) +@dataclasses.dataclass(init=False, kw_only=True, unsafe_hash=True) class PrefetchScalarGridSpec(pallas_core.GridSpec): - grid: TupleGrid - grid_names: tuple[Hashable, ...] | None num_scalar_prefetch: int - in_specs: pallas_core.BlockSpecTree - out_specs: pallas_core.BlockSpecTree - scratch_shapes: tuple[Any, ...] 
def __init__( self, @@ -189,9 +185,9 @@ class PrefetchScalarGridSpec(pallas_core.GridSpec): grid: Grid = (), in_specs: BlockSpecTree = no_block_spec, out_specs: BlockSpecTree = no_block_spec, - scratch_shapes: Any | Sequence[Any] = () + scratch_shapes: ScratchShapeTree = () ): - super().__init__(grid, in_specs, out_specs) + super().__init__(grid, in_specs, out_specs, scratch_shapes) self.num_scalar_prefetch = num_scalar_prefetch self.scratch_shapes = tuple(scratch_shapes) @@ -199,14 +195,6 @@ class PrefetchScalarGridSpec(pallas_core.GridSpec): return AbstractMemoryRef(jax_core.ShapedArray(aval.shape, aval.dtype), TPUMemorySpace.SMEM) - def _make_scratch_aval(self, obj: object) -> jax_core.AbstractValue: - if isinstance(obj, MemoryRef): - return obj.get_aval() - if isinstance(obj, SemaphoreType): - return obj.get_aval() - raise ValueError(f"No registered conversion for {type(obj)}. " - "Only VMEM and SemaphoreType are supported.") - @dataclasses.dataclass(frozen=True) class TensorCore: diff --git a/jax/_src/pallas/mosaic_gpu/__init__.py b/jax/_src/pallas/mosaic_gpu/__init__.py index 11258f741..1bd512834 100644 --- a/jax/_src/pallas/mosaic_gpu/__init__.py +++ b/jax/_src/pallas/mosaic_gpu/__init__.py @@ -17,7 +17,6 @@ from jax._src.pallas.mosaic_gpu.core import Barrier from jax._src.pallas.mosaic_gpu.core import GPUBlockSpec from jax._src.pallas.mosaic_gpu.core import GPUCompilerParams -from jax._src.pallas.mosaic_gpu.core import GPUGridSpec from jax._src.pallas.mosaic_gpu.core import GPUMemorySpace from jax._src.pallas.mosaic_gpu.primitives import async_copy_smem_to_gmem from jax._src.pallas.mosaic_gpu.primitives import async_copy_gmem_to_smem diff --git a/jax/_src/pallas/mosaic_gpu/core.py b/jax/_src/pallas/mosaic_gpu/core.py index 3ef205d33..5a046afea 100644 --- a/jax/_src/pallas/mosaic_gpu/core.py +++ b/jax/_src/pallas/mosaic_gpu/core.py @@ -150,26 +150,6 @@ class GPUBlockSpec(pallas_core.BlockSpec): ) -@dataclasses.dataclass(init=False, kw_only=True) -class GPUGridSpec(pallas_core.GridSpec): - scratch_shapes: Sequence[Any] - - def __init__( - self, - grid: pallas_core.Grid = (), - in_specs: pallas_core.BlockSpecTree = pallas_core.no_block_spec, - out_specs: pallas_core.BlockSpecTree = pallas_core.no_block_spec, - scratch_shapes: Sequence[Any] = () - ): - super().__init__(grid, in_specs, out_specs) - self.scratch_shapes = tuple(scratch_shapes) - - def _make_scratch_aval(self, obj: object) -> jax_core.AbstractValue: - if isinstance(obj, (MemoryRef, Barrier)): - return obj.get_aval() - raise TypeError(f"Cannot convert {obj} to an abstract value") - - # TODO(b/354568887): Cosolidate this with TPU's MemoryRef. 
@dataclasses.dataclass(frozen=True) class MemoryRef: diff --git a/jax/_src/pallas/pallas_call.py b/jax/_src/pallas/pallas_call.py index b69fb03f0..206c0cdee 100644 --- a/jax/_src/pallas/pallas_call.py +++ b/jax/_src/pallas/pallas_call.py @@ -62,6 +62,7 @@ BlockSpec = pallas_core.BlockSpec BlockSpecTree = pallas_core.BlockSpecTree NoBlockSpec = pallas_core.NoBlockSpec no_block_spec = pallas_core.no_block_spec +ScratchShapeTree = pallas_core.ScratchShapeTree CostEstimate = pallas_core.CostEstimate # See the docstring for GridMapping for the calling convention @@ -1233,6 +1234,7 @@ def pallas_call( grid: TupleGrid = (), in_specs: BlockSpecTree = no_block_spec, out_specs: BlockSpecTree = no_block_spec, + scratch_shapes: ScratchShapeTree = (), input_output_aliases: dict[int, int] = {}, debug: bool = False, interpret: bool = False, @@ -1250,8 +1252,9 @@ def pallas_call( corresponding ``in_specs`` and ``out_specs``. out_shape: a PyTree of :class:`jax.ShapeDtypeStruct` describing the shape and dtypes of the outputs. - grid_spec: An alternative way to specify ``grid``, ``in_specs``, and - ``out_specs``. If given, those other parameters must not be also given. + grid_spec: An alternative way to specify ``grid``, ``in_specs``, + ``out_specs`` and ``scratch_shapes``. If given, those other parameters + must not be also given. grid: the iteration space, as a tuple of integers. The kernel is executed as many times as ``prod(grid)``. See details at :ref:`pallas_grid`. @@ -1265,6 +1268,9 @@ def pallas_call( The default value for ``out_specs`` specifies the whole array, e.g., as ``pl.BlockSpec(x.shape, lambda *indices: (0,) * x.ndim)``. See details at :ref:`pallas_blockspec`. + scratch_shapes: a PyTree of backend-specific temporary objects required + by the kernel, such as temporary buffers, synchronization primitives, + etc. input_output_aliases: a dictionary mapping the index of some inputs to the index of the output that aliases them. These indices are in the flattened inputs and outputs. @@ -1305,7 +1311,7 @@ def pallas_call( } if grid_spec is None: - grid_spec = GridSpec(grid, in_specs, out_specs) + grid_spec = GridSpec(grid, in_specs, out_specs, scratch_shapes) else: if grid: raise ValueError( @@ -1319,6 +1325,10 @@ def pallas_call( raise ValueError( "If `grid_spec` is specified, then `out_specs` must " f"be `no_block_spec`. It is {out_specs}") + if scratch_shapes: + raise ValueError( + "If `grid_spec` is specified, then `scratch_shapes` must " + f"be `()`. 
It is {scratch_shapes}") del grid, in_specs, out_specs grid_spec, dynamic_grid_bounds = pallas_core.unzip_dynamic_grid_bounds(grid_spec) # TODO(necula): this canonicalization may be convenient for some usage diff --git a/jax/_src/pallas/triton/lowering.py b/jax/_src/pallas/triton/lowering.py index 856bcae97..0a23e512d 100644 --- a/jax/_src/pallas/triton/lowering.py +++ b/jax/_src/pallas/triton/lowering.py @@ -277,6 +277,10 @@ def lower_jaxpr_to_triton_module( raise NotImplementedError( "scalar prefetch not implemented in the Triton backend" ) + if jaxpr.invars[grid_mapping.slice_scratch_ops]: + raise NotImplementedError( + "scratch memory not implemented in the Triton backend" + ) with grid_mapping.trace_env(): jaxpr, _ = pe.dce_jaxpr( jaxpr, [True] * len(jaxpr.outvars), instantiate=True diff --git a/tests/pallas/mosaic_gpu_test.py b/tests/pallas/mosaic_gpu_test.py index bd9df6182..4810e7808 100644 --- a/tests/pallas/mosaic_gpu_test.py +++ b/tests/pallas/mosaic_gpu_test.py @@ -134,16 +134,15 @@ class PallasCallTest(PallasTest): np.testing.assert_array_equal(kernel(x), x + 1.0) def test_add_one_with_async_copy_gmem_to_smem(self): + @functools.partial( pl.pallas_call, out_shape=jax.ShapeDtypeStruct([128], jnp.float32), - grid_spec=plgpu.GPUGridSpec( - in_specs=(pl.BlockSpec(memory_space=plgpu.GMEM),), - scratch_shapes=[ - plgpu.SMEM((128,), jnp.float32), - plgpu.Barrier(num_arrivals=1), - ], - ), + in_specs=(pl.BlockSpec(memory_space=plgpu.GMEM),), + scratch_shapes=[ + plgpu.SMEM((128,), jnp.float32), + plgpu.Barrier(num_arrivals=1), + ], ) def kernel(x_ref_gmem, o_ref, scratch_ref, barrier_ref): plgpu.async_copy_gmem_to_smem(