Mirror of https://github.com/ROCm/jax.git (synced 2025-04-16 03:46:06 +00:00)

Commit e90336947a (parent b904599b98)

Pulled scratch_shapes into GridSpec

It is supported by Mosaic TPU and Mosaic GPU and unsupported by Triton.

PiperOrigin-RevId: 675950199
@@ -19,18 +19,22 @@ Remember to align the itemized text with the first line of an item within a list
     to be scalars. The restrictions on the arguments are backend-specific:
     Non-scalar arguments are currently only supported on GPU, when using Triton.

 * Deprecations

 * New functionality
+  * {func}`jax.experimental.pallas.pallas_call` now accepts `scratch_shapes`,
+    a PyTree specifying backend-specific temporary objects needed by the
+    kernel, for example, buffers, synchronization primitives etc.

 ## Released with jax 0.4.33 (September 16, 2024)

 ## Released with jax 0.4.32 (September 11, 2024)

 ## Released with jax 0.4.32

 * Changes
   * The kernel function is not allowed to close over constants. Instead, all the needed arrays
     must be passed as inputs, with proper block specs ({jax-issue}`#22746`).

 * Deprecations

 * New functionality
   * Improved error messages for mistakes in the signature of the index map functions,
     to include the name and source location of the index map.

@@ -56,10 +60,6 @@ Remember to align the itemized text with the first line of an item within a list
   * Previously it was possible to import many APIs that are meant to be
     private, as `jax.experimental.pallas.pallas`. This is not possible anymore.

 * Deprecations

 * New Functionality
   * Added documentation for BlockSpec: {ref}`pallas_grids_and_blockspecs`.
   * Improved error messages for the {func}`jax.experimental.pallas.pallas_call`
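For illustration, here is a minimal sketch of how the new argument is used from user code. It is not part of the diff itself: it assumes a TPU backend and uses `pltpu.VMEM` as the scratch leaf, and the kernel simply stages its result through the scratch buffer, which arrives as an extra `Ref` argument after the inputs and outputs.

```python
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl
from jax.experimental.pallas import tpu as pltpu

def add_one_kernel(x_ref, o_ref, scratch_ref):
  # Scratch refs are passed to the kernel after the input and output refs.
  scratch_ref[...] = x_ref[...] + 1.0
  o_ref[...] = scratch_ref[...]

@jax.jit
def add_one(x):
  return pl.pallas_call(
      add_one_kernel,
      out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
      # The new argument: a PyTree of backend-specific scratch objects.
      scratch_shapes=[pltpu.VMEM(x.shape, x.dtype)],
  )(x)

# Example call: add_one(jnp.ones((8, 128), jnp.float32))
```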
@@ -728,7 +728,16 @@ def _convert_block_spec_to_block_mapping(

 index_map_grid_aval = jax_core.ShapedArray((), jnp.int32)

-@dataclasses.dataclass(init=False)
+class ScratchShape(Protocol):
+  def get_aval(self) -> jax_core.AbstractValue:
+    ...
+
+
+ScratchShapeTree = Sequence[Union[ScratchShape, "ScratchShapeTree"]]
+
+
+@dataclasses.dataclass(init=False, kw_only=True)
 class GridSpec:
   """Encodes the grid parameters for :func:`jax.experimental.pallas.pallas_call`.

@@ -741,12 +750,14 @@ class GridSpec:
   grid_names: tuple[Hashable, ...] | None
   in_specs: BlockSpecTree
   out_specs: BlockSpecTree
+  scratch_shapes: ScratchShapeTree = ()

   def __init__(
       self,
       grid: Grid = (),
       in_specs: BlockSpecTree = no_block_spec,
       out_specs: BlockSpecTree = no_block_spec,
+      scratch_shapes: ScratchShapeTree = (),
   ):
     # Be more lenient for in/out_specs
     if isinstance(in_specs, list):
@@ -758,6 +769,7 @@ class GridSpec:

     self.in_specs = in_specs
     self.out_specs = out_specs
+    self.scratch_shapes = tuple(scratch_shapes)

     grid_names = None
     if isinstance(grid, int):
@@ -773,9 +785,6 @@ class GridSpec:
     self.grid = grid  # type: ignore
     self.grid_names = grid_names

-  def _make_scratch_aval(self, obj: object) -> jax_core.AbstractValue:
-    assert False  # Not needed in GridSpec
-
   def _make_scalar_ref_aval(self, aval):
     assert False  # Not needed in GridSpec

@@ -820,12 +829,10 @@ def get_grid_mapping(
   else:
     num_flat_scalar_prefetch = 0
     jaxpr_scalar_ref_avals = ()

-  scratch_shapes: tuple[Any, ...] = getattr(grid_spec, "scratch_shapes", ())
-  if scratch_shapes:
+  if grid_spec.scratch_shapes:
     flat_scratch_shapes, scratch_tree = tree_util.tree_flatten(
-        scratch_shapes)
-    flat_scratch_avals = map(grid_spec._make_scratch_aval, flat_scratch_shapes)
+        grid_spec.scratch_shapes)
+    flat_scratch_avals = map(lambda s: s.get_aval(), flat_scratch_shapes)
     num_flat_scratch_operands = len(flat_scratch_avals)
     jaxpr_scratch_avals = tree_util.tree_unflatten(
         scratch_tree, flat_scratch_avals)
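In short, any PyTree leaf that satisfies the `ScratchShape` protocol (it exposes a `get_aval()` method) can appear in `GridSpec.scratch_shapes`; `get_grid_mapping` flattens the tree and asks each leaf for its abstract value. A condensed, illustrative restatement of that logic (a sketch, not the library code verbatim):

```python
from jax import tree_util

def scratch_avals(grid_spec):
  """Flatten scratch_shapes and ask each leaf for its abstract value."""
  if not grid_spec.scratch_shapes:
    return (), 0
  flat_shapes, treedef = tree_util.tree_flatten(grid_spec.scratch_shapes)
  flat_avals = [leaf.get_aval() for leaf in flat_shapes]  # ScratchShape.get_aval()
  return tree_util.tree_unflatten(treedef, flat_avals), len(flat_avals)
```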
@@ -19,7 +19,7 @@ from collections.abc import Sequence
 import dataclasses
 import enum
 import functools
-from typing import Any, ClassVar, Hashable, Literal
+from typing import Any, ClassVar, Literal

 import jax
 from jax._src import core as jax_core
@@ -39,6 +39,7 @@ BlockSpec = pallas_core.BlockSpec
 BlockSpecTree = pallas_core.BlockSpecTree
 GridMapping = pallas_core.GridMapping
 NoBlockSpec = pallas_core.NoBlockSpec
+ScratchShapeTree = pallas_core.ScratchShapeTree
 AbstractMemoryRef = pallas_core.AbstractMemoryRef
 no_block_spec = pallas_core.no_block_spec
 _convert_block_spec_to_block_mapping = pallas_core._convert_block_spec_to_block_mapping
@@ -174,14 +175,9 @@ class MemoryRef:
         jax_core.ShapedArray(self.shape, self.dtype), self.memory_space)


-@dataclasses.dataclass(init=False, unsafe_hash=True)
+@dataclasses.dataclass(init=False, kw_only=True, unsafe_hash=True)
 class PrefetchScalarGridSpec(pallas_core.GridSpec):
-  grid: TupleGrid
-  grid_names: tuple[Hashable, ...] | None
   num_scalar_prefetch: int
-  in_specs: pallas_core.BlockSpecTree
-  out_specs: pallas_core.BlockSpecTree
-  scratch_shapes: tuple[Any, ...]

   def __init__(
       self,
@@ -189,9 +185,9 @@ class PrefetchScalarGridSpec(pallas_core.GridSpec):
       grid: Grid = (),
       in_specs: BlockSpecTree = no_block_spec,
       out_specs: BlockSpecTree = no_block_spec,
-      scratch_shapes: Any | Sequence[Any] = ()
+      scratch_shapes: ScratchShapeTree = ()
   ):
-    super().__init__(grid, in_specs, out_specs)
+    super().__init__(grid, in_specs, out_specs, scratch_shapes)
     self.num_scalar_prefetch = num_scalar_prefetch
-    self.scratch_shapes = tuple(scratch_shapes)

@@ -199,14 +195,6 @@ class PrefetchScalarGridSpec(pallas_core.GridSpec):
     return AbstractMemoryRef(jax_core.ShapedArray(aval.shape, aval.dtype),
                              TPUMemorySpace.SMEM)

-  def _make_scratch_aval(self, obj: object) -> jax_core.AbstractValue:
-    if isinstance(obj, MemoryRef):
-      return obj.get_aval()
-    if isinstance(obj, SemaphoreType):
-      return obj.get_aval()
-    raise ValueError(f"No registered conversion for {type(obj)}. "
-                     "Only VMEM and SemaphoreType are supported.")


 @dataclasses.dataclass(frozen=True)
 class TensorCore:
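For TPU users, the spelling of `PrefetchScalarGridSpec` does not change; `scratch_shapes` is now simply forwarded to the base `GridSpec`. A hedged sketch (the kernel, shapes, and names here are made up for illustration): the prefetched scalar comes first in the kernel signature, followed by inputs, outputs, and then the scratch refs.

```python
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl
from jax.experimental.pallas import tpu as pltpu

def copy_kernel(s_ref, x_ref, o_ref, vmem_ref, sem_ref):
  # s_ref: prefetched scalar(s); vmem_ref / sem_ref: the scratch operands.
  del s_ref, sem_ref  # unused in this toy kernel
  vmem_ref[...] = x_ref[...]
  o_ref[...] = vmem_ref[...]

def copy(x, s):
  return pl.pallas_call(
      copy_kernel,
      out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
      grid_spec=pltpu.PrefetchScalarGridSpec(
          num_scalar_prefetch=1,
          scratch_shapes=[pltpu.VMEM(x.shape, x.dtype),
                          pltpu.SemaphoreType.REGULAR],
      ),
  )(s, x)

# Example call: copy(jnp.ones((8, 128), jnp.float32), jnp.array([0], jnp.int32))
```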
@@ -17,7 +17,6 @@
 from jax._src.pallas.mosaic_gpu.core import Barrier
 from jax._src.pallas.mosaic_gpu.core import GPUBlockSpec
 from jax._src.pallas.mosaic_gpu.core import GPUCompilerParams
-from jax._src.pallas.mosaic_gpu.core import GPUGridSpec
 from jax._src.pallas.mosaic_gpu.core import GPUMemorySpace
 from jax._src.pallas.mosaic_gpu.primitives import async_copy_smem_to_gmem
 from jax._src.pallas.mosaic_gpu.primitives import async_copy_gmem_to_smem
@@ -150,26 +150,6 @@ class GPUBlockSpec(pallas_core.BlockSpec):
     )


-@dataclasses.dataclass(init=False, kw_only=True)
-class GPUGridSpec(pallas_core.GridSpec):
-  scratch_shapes: Sequence[Any]
-
-  def __init__(
-      self,
-      grid: pallas_core.Grid = (),
-      in_specs: pallas_core.BlockSpecTree = pallas_core.no_block_spec,
-      out_specs: pallas_core.BlockSpecTree = pallas_core.no_block_spec,
-      scratch_shapes: Sequence[Any] = ()
-  ):
-    super().__init__(grid, in_specs, out_specs)
-    self.scratch_shapes = tuple(scratch_shapes)
-
-  def _make_scratch_aval(self, obj: object) -> jax_core.AbstractValue:
-    if isinstance(obj, (MemoryRef, Barrier)):
-      return obj.get_aval()
-    raise TypeError(f"Cannot convert {obj} to an abstract value")
-
-
 # TODO(b/354568887): Cosolidate this with TPU's MemoryRef.
 @dataclasses.dataclass(frozen=True)
 class MemoryRef:
@@ -62,6 +62,7 @@ BlockSpec = pallas_core.BlockSpec
 BlockSpecTree = pallas_core.BlockSpecTree
 NoBlockSpec = pallas_core.NoBlockSpec
 no_block_spec = pallas_core.no_block_spec
+ScratchShapeTree = pallas_core.ScratchShapeTree
 CostEstimate = pallas_core.CostEstimate

 # See the docstring for GridMapping for the calling convention
@@ -1233,6 +1234,7 @@ def pallas_call(
     grid: TupleGrid = (),
     in_specs: BlockSpecTree = no_block_spec,
     out_specs: BlockSpecTree = no_block_spec,
+    scratch_shapes: ScratchShapeTree = (),
     input_output_aliases: dict[int, int] = {},
     debug: bool = False,
     interpret: bool = False,
@@ -1250,8 +1252,9 @@ def pallas_call(
       corresponding ``in_specs`` and ``out_specs``.
     out_shape: a PyTree of :class:`jax.ShapeDtypeStruct` describing the shape
       and dtypes of the outputs.
-    grid_spec: An alternative way to specify ``grid``, ``in_specs``, and
-      ``out_specs``. If given, those other parameters must not be also given.
+    grid_spec: An alternative way to specify ``grid``, ``in_specs``,
+      ``out_specs`` and ``scratch_shapes``. If given, those other parameters
+      must not be also given.
     grid: the iteration space, as a tuple of integers. The kernel is executed
       as many times as ``prod(grid)``.
       See details at :ref:`pallas_grid`.
@@ -1265,6 +1268,9 @@ def pallas_call(
       The default value for ``out_specs`` specifies the whole array,
       e.g., as ``pl.BlockSpec(x.shape, lambda *indices: (0,) * x.ndim)``.
       See details at :ref:`pallas_blockspec`.
+    scratch_shapes: a PyTree of backend-specific temporary objects required
+      by the kernel, such as temporary buffers, synchronization primitives,
+      etc.
     input_output_aliases: a dictionary mapping the index of some inputs to
       the index of the output that aliases them. These indices are in the
       flattened inputs and outputs.
@@ -1305,7 +1311,7 @@ def pallas_call(
   }

   if grid_spec is None:
-    grid_spec = GridSpec(grid, in_specs, out_specs)
+    grid_spec = GridSpec(grid, in_specs, out_specs, scratch_shapes)
   else:
     if grid:
       raise ValueError(
@@ -1319,6 +1325,10 @@ def pallas_call(
       raise ValueError(
           "If `grid_spec` is specified, then `out_specs` must "
           f"be `no_block_spec`. It is {out_specs}")
+    if scratch_shapes:
+      raise ValueError(
+          "If `grid_spec` is specified, then `scratch_shapes` must "
+          f"be `()`. It is {scratch_shapes}")
   del grid, in_specs, out_specs
   grid_spec, dynamic_grid_bounds = pallas_core.unzip_dynamic_grid_bounds(grid_spec)
   # TODO(necula): this canonicalization may be convenient for some usage
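As the validation above shows, `scratch_shapes` can now be given either directly to `pallas_call` or on a `GridSpec`, but not both. A hedged sketch of the two spellings (the toy kernel and shapes are invented for illustration, not taken from this diff):

```python
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl
from jax.experimental.pallas import tpu as pltpu

def double_kernel(x_ref, o_ref, tmp_ref):
  tmp_ref[...] = x_ref[...] * 2.0
  o_ref[...] = tmp_ref[...]

out = jax.ShapeDtypeStruct((8, 128), jnp.float32)
scratch = [pltpu.VMEM((8, 128), jnp.float32)]

# 1) Directly on pallas_call; it is folded into
#    GridSpec(grid, in_specs, out_specs, scratch_shapes) internally.
call_direct = pl.pallas_call(double_kernel, out_shape=out, scratch_shapes=scratch)

# 2) On the grid_spec instead.
call_via_spec = pl.pallas_call(
    double_kernel, out_shape=out, grid_spec=pl.GridSpec(scratch_shapes=scratch))

# Passing both grid_spec= and scratch_shapes= now raises ValueError
# ("If `grid_spec` is specified, then `scratch_shapes` must be `()`...").
```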
@@ -277,6 +277,10 @@ def lower_jaxpr_to_triton_module(
     raise NotImplementedError(
         "scalar prefetch not implemented in the Triton backend"
     )
+  if jaxpr.invars[grid_mapping.slice_scratch_ops]:
+    raise NotImplementedError(
+        "scratch memory not implemented in the Triton backend"
+    )
   with grid_mapping.trace_env():
     jaxpr, _ = pe.dce_jaxpr(
         jaxpr, [True] * len(jaxpr.outvars), instantiate=True
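Per the commit message, scratch_shapes is unsupported by Triton, and the check above is where that surfaces. A hedged illustration of what a user would hit (assuming the default Triton lowering on GPU and the `plgpu` names from the test below; not an exact transcript):

```python
# Illustration only, not run here: the guard added above rejects any
# scratch operand during Triton lowering.
#
#   pl.pallas_call(
#       kernel,
#       out_shape=jax.ShapeDtypeStruct([128], jnp.float32),
#       scratch_shapes=[plgpu.SMEM((128,), jnp.float32)],
#   )(x)
#   # -> NotImplementedError: scratch memory not implemented in the Triton backend
```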
@@ -134,16 +134,15 @@ class PallasCallTest(PallasTest):
     np.testing.assert_array_equal(kernel(x), x + 1.0)

   def test_add_one_with_async_copy_gmem_to_smem(self):

     @functools.partial(
         pl.pallas_call,
         out_shape=jax.ShapeDtypeStruct([128], jnp.float32),
-        grid_spec=plgpu.GPUGridSpec(
-            in_specs=(pl.BlockSpec(memory_space=plgpu.GMEM),),
-            scratch_shapes=[
-                plgpu.SMEM((128,), jnp.float32),
-                plgpu.Barrier(num_arrivals=1),
-            ],
-        ),
+        in_specs=(pl.BlockSpec(memory_space=plgpu.GMEM),),
+        scratch_shapes=[
+            plgpu.SMEM((128,), jnp.float32),
+            plgpu.Barrier(num_arrivals=1),
+        ],
     )
     def kernel(x_ref_gmem, o_ref, scratch_ref, barrier_ref):
       plgpu.async_copy_gmem_to_smem(