Add support for dynamically computed grid bounds in Pallas kernels.

PiperOrigin-RevId: 603389883
This commit is contained in:
Adam Paszke 2024-02-01 09:14:30 -08:00 committed by jax authors
parent 69e4dd41b5
commit 21070b24d7
7 changed files with 83 additions and 9 deletions

View File

@ -15,11 +15,12 @@
"""Module for pallas-core functionality."""
from __future__ import annotations
import copy
from collections.abc import Sequence
import contextlib
import dataclasses
import functools
from typing import Any, Callable
from typing import Any, Callable, Union
from collections.abc import Iterator
from jax._src import api_util
@ -36,7 +37,8 @@ import jax.numpy as jnp
# mypy: ignore-errors
partial = functools.partial
Grid = tuple[int, ...]
Grid = tuple[Union[int, None], ...] # None indicates that the bound is dynamic.
StaticGrid = tuple[int, ...] # None indicates that the bound is dynamic.
split_list = util.split_list
map, unsafe_map = util.safe_map, map
@ -160,7 +162,7 @@ class BlockMapping:
@dataclasses.dataclass(frozen=True)
class GridMapping:
grid: tuple[int, ...]
grid: Grid
block_mappings: tuple[BlockMapping | None, ...]
mapped_dims: tuple[int, ...]
num_index_operands: int
@ -168,6 +170,16 @@ class GridMapping:
replace = dataclasses.replace
@property
def num_dynamic_grid_bounds(self):
return sum(b is None for b in self.grid)
@property
def static_grid(self) -> StaticGrid:
if self.num_dynamic_grid_bounds:
raise ValueError("Expected a grid with fully static bounds")
return self.grid # typing: ignore
def _preprocess_grid(grid: Grid | int | None) -> Grid:
if grid is None:
@ -317,3 +329,12 @@ class GridSpec:
if not isinstance(jaxpr_out_avals, (tuple, list)):
jaxpr_out_avals = (jaxpr_out_avals,)
return (*jaxpr_in_avals, *jaxpr_out_avals), grid_mapping
def unzip_dynamic_grid_bounds(self) -> tuple[GridSpec, tuple[Any, ...]]:
static_grid = tuple(d if isinstance(d, int) else None for d in self.grid)
dynamic_bounds = tuple(d for d in self.grid if not isinstance(d, int))
# We can't use dataclasses.replace, because our fields are incompatible
# with __init__'s signature.
static_self = copy.copy(self)
static_self.grid = static_grid
return static_self, dynamic_bounds

View File

@ -68,6 +68,9 @@ TPUMemorySpace = tpu_core.TPUMemorySpace
VMEM = tpu_core.TPUMemorySpace.VMEM
SMEM = tpu_core.TPUMemorySpace.SMEM
# The value interpreter as a dynamic dimension by MLIR.
MLIR_DYNAMIC = -9223372036854775808
partial = functools.partial
map, unsafe_map = safe_map, map # pylint: disable=redefined-builtin
zip, unsafe_zip = safe_zip, zip # pylint: disable=redefined-builtin
@ -381,6 +384,10 @@ def lower_jaxpr_to_module(
zip(block_operand_shapes, grid_mapping.block_mappings, avals)
):
func_name = f"transform_{i}"
if bm is None:
raise NotImplementedError(
"BlockSpecs are required on TPU when grid is specified"
)
if bm.index_map_jaxpr.consts:
raise NotImplementedError("Index map jaxpr with consts not supported.")
# ANY operands don't support windowing and require empty window_params.
@ -421,7 +428,8 @@ def lower_jaxpr_to_module(
m.body.append(mlir_func)
sym_tab.insert(mlir_func)
func_op.attributes["window_params"] = ir.ArrayAttr.get(window_params)
func_op.attributes["iteration_bounds"] = ir.DenseI64ArrayAttr.get(grid)
static_grid = [MLIR_DYNAMIC if b is None else b for b in grid]
func_op.attributes["iteration_bounds"] = ir.DenseI64ArrayAttr.get(static_grid)
func_op.attributes["scalar_prefetch"] = ir.IntegerAttr.get(
ir.IntegerType.get_signless(64), len(mosaic_grid_mapping.scalar_prefetch_types))

View File

@ -45,6 +45,10 @@ def pallas_call_tpu_lowering_rule(
**compiler_params: Any):
"""Lowers a pallas_call to a Mosaic TPU custom call."""
if interpret:
if grid_mapping.num_dynamic_grid_bounds:
raise NotImplementedError(
"Dynamic grid bounds not supported in interpret mode."
)
return mlir.lower_fun(pallas_call_p.impl, multiple_results=True)(
ctx, *in_nodes, jaxpr=jaxpr, name=name, out_shapes=out_shapes,
in_shapes=in_shapes,
@ -78,8 +82,14 @@ def pallas_call_tpu_lowering_rule(
raise NotImplementedError(
"Cannot use both input_output_aliases and extra_args."
)
num_dyn_bounds = grid_mapping.num_dynamic_grid_bounds
input_output_aliases = tuple(
(a[0] + num_dyn_bounds, a[1]) for a in input_output_aliases
)
out_avals = [jax_core.ShapedArray(s.shape, s.dtype) for s in out_shapes]
def _lower_fun(*args):
# Dynamic grid bounds have to go at the front.
dynamic_grid_args, args = args[:num_dyn_bounds], args[num_dyn_bounds:],
return mosaic.as_tpu_kernel(
mosaic_module,
out_avals,
@ -91,6 +101,7 @@ def pallas_call_tpu_lowering_rule(
flags=mosaic_params.get("flags", None),
input_output_aliases=input_output_aliases,
)(
*dynamic_grid_args,
*extra_args,
*args,
collective_id=mosaic_params.get("collective_id", None),

View File

@ -120,7 +120,7 @@ def _tree_map_with_kwargs(f, *args, **kwargs):
)
def _get_next_indices(grid: core.Grid, indices: GridIndices) -> GridIndices:
def _get_next_indices(grid: core.StaticGrid, indices: GridIndices) -> GridIndices:
"""Takes a grid and current indices and returns the next indices.
grid: (3, 4, 5)
@ -412,7 +412,7 @@ class Pipeline(Protocol):
def emit_pipeline_with_allocations(
body: PipelineBody,
*,
grid: core.Grid,
grid: core.StaticGrid,
in_specs: PipelineBlockSpecs,
out_specs: PipelineBlockSpecs,
should_accumulate_out: Union[Sequence[bool], Any] = False,
@ -1046,7 +1046,7 @@ def emit_pipeline_with_allocations(
def emit_pipeline(
body: PipelineBody,
*,
grid: core.Grid,
grid: core.StaticGrid,
in_specs: PipelineBlockSpecs,
out_specs: PipelineBlockSpecs,
should_accumulate_out: Union[Sequence[bool], Any] = False,

View File

@ -95,11 +95,13 @@ def _pallas_call_impl(*args, jaxpr, name, out_shapes, which_linear,
input_output_aliases: tuple[tuple[int, int], ...],
grid_mapping: GridMapping,
**compiler_params: Any):
if grid_mapping.num_dynamic_grid_bounds:
raise NotImplementedError("interpret with dynamic grid bounds unsupported")
if interpret:
# If we're in interpreter mode, we *scan* over the grid and eval the
# discharged jaxpr. This should reproduce exactly what compiling to Triton
# will do.
grid = grid_mapping.grid
grid = grid_mapping.static_grid
discharged_jaxpr, consts = state_discharge.discharge_state(jaxpr, ())
if debug:
print(discharged_jaxpr)
@ -195,6 +197,8 @@ pallas_call_p.def_abstract_eval(_pallas_call_abstract_eval)
def _pallas_call_jvp_rule(primals, tangents, *, jaxpr, name, which_linear,
input_output_aliases: tuple[tuple[int, int], ...],
in_shapes, out_shapes, grid_mapping, debug, interpret, **compiler_params: Any):
if grid_mapping.num_dynamic_grid_bounds:
raise NotImplementedError("interpret with dynamic grid bounds unsupported")
if grid_mapping.num_index_operands:
raise NotImplementedError
if input_output_aliases:
@ -282,6 +286,8 @@ def _pallas_call_batching_rule(args, dims, *,
interpret: bool,
which_linear: tuple[bool, ...],
**compiler_params: Any):
if grid_mapping.num_dynamic_grid_bounds:
raise NotImplementedError("interpret with dynamic grid bounds unsupported")
if grid_mapping.num_index_operands:
scalar_args, args = split_list(args, [grid_mapping.num_index_operands])
scalar_bdims, bdims = split_list(dims, [grid_mapping.num_index_operands])
@ -453,8 +459,11 @@ def pallas_call(
**compiler_params: Any,
):
name = _extract_function_name(f, name)
if grid is not None and grid_spec is not None:
raise ValueError("Cannot specify both grid and grid_spec at the same time.")
if grid_spec is None:
grid_spec = GridSpec(grid, in_specs, out_specs)
grid_spec, dynamic_grid_bounds = grid_spec.unzip_dynamic_grid_bounds()
if isinstance(out_shape, list):
out_shape = tuple(out_shape)
flat_out_shapes, out_tree = tree_util.tree_flatten(out_shape)
@ -472,7 +481,8 @@ def pallas_call(
out_tree)
which_linear = (False,) * len(flat_args)
out_flat = pallas_call_p.bind(
*consts, *flat_args, jaxpr=jaxpr, name=name, which_linear=which_linear,
*dynamic_grid_bounds, *consts, *flat_args,
jaxpr=jaxpr, name=name, which_linear=which_linear,
in_shapes=tuple(jax.ShapeDtypeStruct(a.shape, a.dtype)
for a in flat_args),
out_shapes=tuple(flat_out_shapes), debug=debug,

View File

@ -1568,6 +1568,8 @@ def pallas_call_lowering(
triton_params: dict[str, Any] | None = None,
**compiler_params: Any,
):
if grid_mapping.num_dynamic_grid_bounds:
raise NotImplementedError("dynamic grid bounds not supported in the Triton backend")
if interpret:
return mlir.lower_fun(pallas_call_p.impl, multiple_results=True)(
ctx,

View File

@ -1091,6 +1091,28 @@ class PallasCallTest(PallasTPUTest):
kernel, out_shape=x, mosaic_params=dict(vmem_limit_bytes=int(2**18))
)(x)
def test_dynamic_grid(self):
shape = (8, 128)
result_ty = jax.ShapeDtypeStruct(shape, jnp.float32)
def kernel(y_ref):
@pl.when(pl.program_id(0) == 0)
def _init():
y_ref[...] = jnp.zeros_like(y_ref)
y_ref[...] += 1
@jax.jit
def dynamic_kernel(steps):
return pl.pallas_call(
kernel,
grid=(steps * 2,),
out_specs=pl.BlockSpec(lambda i: (0, 0), shape),
out_shape=result_ty,
)()
np.testing.assert_array_equal(
dynamic_kernel(4), np.full(shape, 8.0, np.float32)
)
class PallasUXTest(PallasTPUTest):