Pulled scratch_shapes into GridSpec

`scratch_shapes` is supported by the Mosaic TPU and Mosaic GPU backends, but not by Triton.
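
A minimal usage sketch of the new parameter (not part of this commit), assuming the Mosaic TPU backend and the usual `pltpu` import from `jax.experimental.pallas.tpu`:

```python
# Sketch only: assumes a TPU-capable runtime and the Mosaic TPU backend.
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl
from jax.experimental.pallas import tpu as pltpu

def add_one_kernel(x_ref, o_ref, scratch_ref):
  # scratch_ref is a backend-specific temporary buffer allocated from
  # scratch_shapes; scratch refs are passed after the input and output refs.
  scratch_ref[...] = x_ref[...] + 1.0
  o_ref[...] = scratch_ref[...]

x = jnp.zeros((8, 128), jnp.float32)
y = pl.pallas_call(
    add_one_kernel,
    out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
    scratch_shapes=[pltpu.VMEM(x.shape, x.dtype)],  # now accepted directly
)(x)
```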

PiperOrigin-RevId: 675950199
Sergei Lebedev 2024-09-18 05:25:37 -07:00 committed by jax authors
parent b904599b98
commit e90336947a
8 changed files with 52 additions and 65 deletions

View File

@ -19,18 +19,22 @@ Remember to align the itemized text with the first line of an item within a list
to be scalars. The restrictions on the arguments are backend-specific:
Non-scalar arguments are currently only supported on GPU, when using Triton.
* Deprecations
* New functionality
* {func}`jax.experimental.pallas.pallas_call` now accepts `scratch_shapes`,
a PyTree specifying backend-specific temporary objects needed by the
kernel, for example buffers and synchronization primitives.
## Released with jax 0.4.33 (September 16, 2024)
## Released with jax 0.4.32 (September 11, 2024)
## Released with jax 0.4.32
* Changes
* The kernel function is not allowed to close over constants. Instead, all the needed arrays
must be passed as inputs, with proper block specs ({jax-issue}`#22746`).
* Deprecations
* New functionality
* Improved error messages for mistakes in the signature of the index map functions,
to include the name and source location of the index map.
@ -56,10 +60,6 @@ Remember to align the itemized text with the first line of an item within a list
* Previously it was possible to import many APIs that are meant to be
private, as `jax.experimental.pallas.pallas`. This is not possible anymore.
* Deprecations
* New Functionality
* Added documentation for BlockSpec: {ref}`pallas_grids_and_blockspecs`.
* Improved error messages for the {func}`jax.experimental.pallas.pallas_call`

View File

@ -728,7 +728,16 @@ def _convert_block_spec_to_block_mapping(
index_map_grid_aval = jax_core.ShapedArray((), jnp.int32)
@dataclasses.dataclass(init=False)
class ScratchShape(Protocol):
def get_aval(self) -> jax_core.AbstractValue:
...
ScratchShapeTree = Sequence[Union[ScratchShape, "ScratchShapeTree"]]
@dataclasses.dataclass(init=False, kw_only=True)
class GridSpec:
"""Encodes the grid parameters for :func:`jax.experimental.pallas.pallas_call`.
@ -741,12 +750,14 @@ class GridSpec:
grid_names: tuple[Hashable, ...] | None
in_specs: BlockSpecTree
out_specs: BlockSpecTree
scratch_shapes: ScratchShapeTree = ()
def __init__(
self,
grid: Grid = (),
in_specs: BlockSpecTree = no_block_spec,
out_specs: BlockSpecTree = no_block_spec,
scratch_shapes: ScratchShapeTree = (),
):
# Be more lenient for in/out_specs
if isinstance(in_specs, list):
@ -758,6 +769,7 @@ class GridSpec:
self.in_specs = in_specs
self.out_specs = out_specs
self.scratch_shapes = tuple(scratch_shapes)
grid_names = None
if isinstance(grid, int):
@ -773,9 +785,6 @@ class GridSpec:
self.grid = grid # type: ignore
self.grid_names = grid_names
def _make_scratch_aval(self, obj: object) -> jax_core.AbstractValue:
assert False # Not needed in GridSpec
def _make_scalar_ref_aval(self, aval):
assert False # Not needed in GridSpec
@ -820,12 +829,10 @@ def get_grid_mapping(
else:
num_flat_scalar_prefetch = 0
jaxpr_scalar_ref_avals = ()
scratch_shapes: tuple[Any, ...] = getattr(grid_spec, "scratch_shapes", ())
if scratch_shapes:
if grid_spec.scratch_shapes:
flat_scratch_shapes, scratch_tree = tree_util.tree_flatten(
scratch_shapes)
flat_scratch_avals = map(grid_spec._make_scratch_aval, flat_scratch_shapes)
grid_spec.scratch_shapes)
flat_scratch_avals = map(lambda s: s.get_aval(), flat_scratch_shapes)
num_flat_scratch_operands = len(flat_scratch_avals)
jaxpr_scratch_avals = tree_util.tree_unflatten(
scratch_tree, flat_scratch_avals)
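
For reference, a condensed, illustrative sketch of the flattening logic above: any leaf of `scratch_shapes` only has to satisfy the `ScratchShape` protocol, i.e. expose `get_aval()`. The helper name below is hypothetical; only the names it mirrors come from the diff.

```python
# Illustrative helper (not part of the commit): how get_grid_mapping now
# turns grid_spec.scratch_shapes into abstract values via the protocol.
from jax import tree_util

def scratch_avals(scratch_shapes):
  flat_shapes, tree = tree_util.tree_flatten(tuple(scratch_shapes))
  flat_avals = [s.get_aval() for s in flat_shapes]  # ScratchShape.get_aval()
  return tree_util.tree_unflatten(tree, flat_avals)
```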

View File

@ -19,7 +19,7 @@ from collections.abc import Sequence
import dataclasses
import enum
import functools
from typing import Any, ClassVar, Hashable, Literal
from typing import Any, ClassVar, Literal
import jax
from jax._src import core as jax_core
@ -39,6 +39,7 @@ BlockSpec = pallas_core.BlockSpec
BlockSpecTree = pallas_core.BlockSpecTree
GridMapping = pallas_core.GridMapping
NoBlockSpec = pallas_core.NoBlockSpec
ScratchShapeTree = pallas_core.ScratchShapeTree
AbstractMemoryRef = pallas_core.AbstractMemoryRef
no_block_spec = pallas_core.no_block_spec
_convert_block_spec_to_block_mapping = pallas_core._convert_block_spec_to_block_mapping
@ -174,14 +175,9 @@ class MemoryRef:
jax_core.ShapedArray(self.shape, self.dtype), self.memory_space)
@dataclasses.dataclass(init=False, unsafe_hash=True)
@dataclasses.dataclass(init=False, kw_only=True, unsafe_hash=True)
class PrefetchScalarGridSpec(pallas_core.GridSpec):
grid: TupleGrid
grid_names: tuple[Hashable, ...] | None
num_scalar_prefetch: int
in_specs: pallas_core.BlockSpecTree
out_specs: pallas_core.BlockSpecTree
scratch_shapes: tuple[Any, ...]
def __init__(
self,
@ -189,9 +185,9 @@ class PrefetchScalarGridSpec(pallas_core.GridSpec):
grid: Grid = (),
in_specs: BlockSpecTree = no_block_spec,
out_specs: BlockSpecTree = no_block_spec,
scratch_shapes: Any | Sequence[Any] = ()
scratch_shapes: ScratchShapeTree = ()
):
super().__init__(grid, in_specs, out_specs)
super().__init__(grid, in_specs, out_specs, scratch_shapes)
self.num_scalar_prefetch = num_scalar_prefetch
self.scratch_shapes = tuple(scratch_shapes)
@ -199,14 +195,6 @@ class PrefetchScalarGridSpec(pallas_core.GridSpec):
return AbstractMemoryRef(jax_core.ShapedArray(aval.shape, aval.dtype),
TPUMemorySpace.SMEM)
def _make_scratch_aval(self, obj: object) -> jax_core.AbstractValue:
if isinstance(obj, MemoryRef):
return obj.get_aval()
if isinstance(obj, SemaphoreType):
return obj.get_aval()
raise ValueError(f"No registered conversion for {type(obj)}. "
"Only VMEM and SemaphoreType are supported.")
@dataclasses.dataclass(frozen=True)
class TensorCore:

View File

@ -17,7 +17,6 @@
from jax._src.pallas.mosaic_gpu.core import Barrier
from jax._src.pallas.mosaic_gpu.core import GPUBlockSpec
from jax._src.pallas.mosaic_gpu.core import GPUCompilerParams
from jax._src.pallas.mosaic_gpu.core import GPUGridSpec
from jax._src.pallas.mosaic_gpu.core import GPUMemorySpace
from jax._src.pallas.mosaic_gpu.primitives import async_copy_smem_to_gmem
from jax._src.pallas.mosaic_gpu.primitives import async_copy_gmem_to_smem

View File

@ -150,26 +150,6 @@ class GPUBlockSpec(pallas_core.BlockSpec):
)
@dataclasses.dataclass(init=False, kw_only=True)
class GPUGridSpec(pallas_core.GridSpec):
scratch_shapes: Sequence[Any]
def __init__(
self,
grid: pallas_core.Grid = (),
in_specs: pallas_core.BlockSpecTree = pallas_core.no_block_spec,
out_specs: pallas_core.BlockSpecTree = pallas_core.no_block_spec,
scratch_shapes: Sequence[Any] = ()
):
super().__init__(grid, in_specs, out_specs)
self.scratch_shapes = tuple(scratch_shapes)
def _make_scratch_aval(self, obj: object) -> jax_core.AbstractValue:
if isinstance(obj, (MemoryRef, Barrier)):
return obj.get_aval()
raise TypeError(f"Cannot convert {obj} to an abstract value")
# TODO(b/354568887): Consolidate this with TPU's MemoryRef.
@dataclasses.dataclass(frozen=True)
class MemoryRef:

View File

@ -62,6 +62,7 @@ BlockSpec = pallas_core.BlockSpec
BlockSpecTree = pallas_core.BlockSpecTree
NoBlockSpec = pallas_core.NoBlockSpec
no_block_spec = pallas_core.no_block_spec
ScratchShapeTree = pallas_core.ScratchShapeTree
CostEstimate = pallas_core.CostEstimate
# See the docstring for GridMapping for the calling convention
@ -1233,6 +1234,7 @@ def pallas_call(
grid: TupleGrid = (),
in_specs: BlockSpecTree = no_block_spec,
out_specs: BlockSpecTree = no_block_spec,
scratch_shapes: ScratchShapeTree = (),
input_output_aliases: dict[int, int] = {},
debug: bool = False,
interpret: bool = False,
@ -1250,8 +1252,9 @@ def pallas_call(
corresponding ``in_specs`` and ``out_specs``.
out_shape: a PyTree of :class:`jax.ShapeDtypeStruct` describing the shape
and dtypes of the outputs.
grid_spec: An alternative way to specify ``grid``, ``in_specs``, and
``out_specs``. If given, those other parameters must not be also given.
grid_spec: An alternative way to specify ``grid``, ``in_specs``,
``out_specs`` and ``scratch_shapes``. If given, those other parameters
must not also be given.
grid: the iteration space, as a tuple of integers. The kernel is executed
as many times as ``prod(grid)``.
See details at :ref:`pallas_grid`.
@ -1265,6 +1268,9 @@ def pallas_call(
The default value for ``out_specs`` specifies the whole array,
e.g., as ``pl.BlockSpec(x.shape, lambda *indices: (0,) * x.ndim)``.
See details at :ref:`pallas_blockspec`.
scratch_shapes: a PyTree of backend-specific temporary objects required
by the kernel, such as temporary buffers and synchronization primitives.
input_output_aliases: a dictionary mapping the index of some inputs to
the index of the output that aliases them. These indices are in the
flattened inputs and outputs.
@ -1305,7 +1311,7 @@ def pallas_call(
}
if grid_spec is None:
grid_spec = GridSpec(grid, in_specs, out_specs)
grid_spec = GridSpec(grid, in_specs, out_specs, scratch_shapes)
else:
if grid:
raise ValueError(
@ -1319,6 +1325,10 @@ def pallas_call(
raise ValueError(
"If `grid_spec` is specified, then `out_specs` must "
f"be `no_block_spec`. It is {out_specs}")
if scratch_shapes:
raise ValueError(
"If `grid_spec` is specified, then `scratch_shapes` must "
f"be `()`. It is {scratch_shapes}")
del grid, in_specs, out_specs
grid_spec, dynamic_grid_bounds = pallas_core.unzip_dynamic_grid_bounds(grid_spec)
# TODO(necula): this canonicalization may be convenient for some usage
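
To make the new calling convention concrete, here is a hedged sketch (TPU backend assumed) of the two equivalent ways to supply scratch after this change; passing both at once triggers the `ValueError` added above.

```python
# Sketch only: pltpu.VMEM and PrefetchScalarGridSpec are the usual TPU helpers.
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl
from jax.experimental.pallas import tpu as pltpu

def copy_kernel(x_ref, o_ref, acc_ref):
  acc_ref[...] = x_ref[...]
  o_ref[...] = acc_ref[...]

x = jnp.ones((8, 128), jnp.float32)
out_shape = jax.ShapeDtypeStruct(x.shape, x.dtype)
scratch = [pltpu.VMEM(x.shape, x.dtype)]

# 1) Directly on pallas_call (new in this commit):
y1 = pl.pallas_call(copy_kernel, out_shape=out_shape, scratch_shapes=scratch)(x)

# 2) Via a grid_spec, which now forwards scratch_shapes to the base GridSpec:
y2 = pl.pallas_call(
    copy_kernel,
    out_shape=out_shape,
    grid_spec=pltpu.PrefetchScalarGridSpec(
        num_scalar_prefetch=0, scratch_shapes=scratch),
)(x)

# Supplying both grid_spec= and scratch_shapes= raises a ValueError.
```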

View File

@ -277,6 +277,10 @@ def lower_jaxpr_to_triton_module(
raise NotImplementedError(
"scalar prefetch not implemented in the Triton backend"
)
if jaxpr.invars[grid_mapping.slice_scratch_ops]:
raise NotImplementedError(
"scratch memory not implemented in the Triton backend"
)
with grid_mapping.trace_env():
jaxpr, _ = pe.dce_jaxpr(
jaxpr, [True] * len(jaxpr.outvars), instantiate=True

View File

@ -134,16 +134,15 @@ class PallasCallTest(PallasTest):
np.testing.assert_array_equal(kernel(x), x + 1.0)
def test_add_one_with_async_copy_gmem_to_smem(self):
@functools.partial(
pl.pallas_call,
out_shape=jax.ShapeDtypeStruct([128], jnp.float32),
grid_spec=plgpu.GPUGridSpec(
in_specs=(pl.BlockSpec(memory_space=plgpu.GMEM),),
scratch_shapes=[
plgpu.SMEM((128,), jnp.float32),
plgpu.Barrier(num_arrivals=1),
],
),
in_specs=(pl.BlockSpec(memory_space=plgpu.GMEM),),
scratch_shapes=[
plgpu.SMEM((128,), jnp.float32),
plgpu.Barrier(num_arrivals=1),
],
)
def kernel(x_ref_gmem, o_ref, scratch_ref, barrier_ref):
plgpu.async_copy_gmem_to_smem(