mirror of
https://github.com/ROCm/jax.git
synced 2025-04-18 12:56:07 +00:00
Add support for dynamically computed grid bounds in Pallas kernels.
PiperOrigin-RevId: 603389883
This commit is contained in:
parent
69e4dd41b5
commit
21070b24d7
@ -15,11 +15,12 @@
|
||||
"""Module for pallas-core functionality."""
|
||||
from __future__ import annotations
|
||||
|
||||
import copy
|
||||
from collections.abc import Sequence
|
||||
import contextlib
|
||||
import dataclasses
|
||||
import functools
|
||||
from typing import Any, Callable
|
||||
from typing import Any, Callable, Union
|
||||
from collections.abc import Iterator
|
||||
|
||||
from jax._src import api_util
|
||||
@ -36,7 +37,8 @@ import jax.numpy as jnp
|
||||
# mypy: ignore-errors
|
||||
|
||||
partial = functools.partial
|
||||
Grid = tuple[int, ...]
|
||||
Grid = tuple[Union[int, None], ...] # None indicates that the bound is dynamic.
|
||||
StaticGrid = tuple[int, ...] # None indicates that the bound is dynamic.
|
||||
split_list = util.split_list
|
||||
|
||||
map, unsafe_map = util.safe_map, map
|
||||
@ -160,7 +162,7 @@ class BlockMapping:
|
||||
|
||||
@dataclasses.dataclass(frozen=True)
|
||||
class GridMapping:
|
||||
grid: tuple[int, ...]
|
||||
grid: Grid
|
||||
block_mappings: tuple[BlockMapping | None, ...]
|
||||
mapped_dims: tuple[int, ...]
|
||||
num_index_operands: int
|
||||
@ -168,6 +170,16 @@ class GridMapping:
|
||||
|
||||
replace = dataclasses.replace
|
||||
|
||||
@property
|
||||
def num_dynamic_grid_bounds(self):
|
||||
return sum(b is None for b in self.grid)
|
||||
|
||||
@property
|
||||
def static_grid(self) -> StaticGrid:
|
||||
if self.num_dynamic_grid_bounds:
|
||||
raise ValueError("Expected a grid with fully static bounds")
|
||||
return self.grid # typing: ignore
|
||||
|
||||
|
||||
def _preprocess_grid(grid: Grid | int | None) -> Grid:
|
||||
if grid is None:
|
||||
@ -317,3 +329,12 @@ class GridSpec:
|
||||
if not isinstance(jaxpr_out_avals, (tuple, list)):
|
||||
jaxpr_out_avals = (jaxpr_out_avals,)
|
||||
return (*jaxpr_in_avals, *jaxpr_out_avals), grid_mapping
|
||||
|
||||
def unzip_dynamic_grid_bounds(self) -> tuple[GridSpec, tuple[Any, ...]]:
|
||||
static_grid = tuple(d if isinstance(d, int) else None for d in self.grid)
|
||||
dynamic_bounds = tuple(d for d in self.grid if not isinstance(d, int))
|
||||
# We can't use dataclasses.replace, because our fields are incompatible
|
||||
# with __init__'s signature.
|
||||
static_self = copy.copy(self)
|
||||
static_self.grid = static_grid
|
||||
return static_self, dynamic_bounds
|
||||
|
@ -68,6 +68,9 @@ TPUMemorySpace = tpu_core.TPUMemorySpace
|
||||
VMEM = tpu_core.TPUMemorySpace.VMEM
|
||||
SMEM = tpu_core.TPUMemorySpace.SMEM
|
||||
|
||||
# The value interpreter as a dynamic dimension by MLIR.
|
||||
MLIR_DYNAMIC = -9223372036854775808
|
||||
|
||||
partial = functools.partial
|
||||
map, unsafe_map = safe_map, map # pylint: disable=redefined-builtin
|
||||
zip, unsafe_zip = safe_zip, zip # pylint: disable=redefined-builtin
|
||||
@ -381,6 +384,10 @@ def lower_jaxpr_to_module(
|
||||
zip(block_operand_shapes, grid_mapping.block_mappings, avals)
|
||||
):
|
||||
func_name = f"transform_{i}"
|
||||
if bm is None:
|
||||
raise NotImplementedError(
|
||||
"BlockSpecs are required on TPU when grid is specified"
|
||||
)
|
||||
if bm.index_map_jaxpr.consts:
|
||||
raise NotImplementedError("Index map jaxpr with consts not supported.")
|
||||
# ANY operands don't support windowing and require empty window_params.
|
||||
@ -421,7 +428,8 @@ def lower_jaxpr_to_module(
|
||||
m.body.append(mlir_func)
|
||||
sym_tab.insert(mlir_func)
|
||||
func_op.attributes["window_params"] = ir.ArrayAttr.get(window_params)
|
||||
func_op.attributes["iteration_bounds"] = ir.DenseI64ArrayAttr.get(grid)
|
||||
static_grid = [MLIR_DYNAMIC if b is None else b for b in grid]
|
||||
func_op.attributes["iteration_bounds"] = ir.DenseI64ArrayAttr.get(static_grid)
|
||||
|
||||
func_op.attributes["scalar_prefetch"] = ir.IntegerAttr.get(
|
||||
ir.IntegerType.get_signless(64), len(mosaic_grid_mapping.scalar_prefetch_types))
|
||||
|
@ -45,6 +45,10 @@ def pallas_call_tpu_lowering_rule(
|
||||
**compiler_params: Any):
|
||||
"""Lowers a pallas_call to a Mosaic TPU custom call."""
|
||||
if interpret:
|
||||
if grid_mapping.num_dynamic_grid_bounds:
|
||||
raise NotImplementedError(
|
||||
"Dynamic grid bounds not supported in interpret mode."
|
||||
)
|
||||
return mlir.lower_fun(pallas_call_p.impl, multiple_results=True)(
|
||||
ctx, *in_nodes, jaxpr=jaxpr, name=name, out_shapes=out_shapes,
|
||||
in_shapes=in_shapes,
|
||||
@ -78,8 +82,14 @@ def pallas_call_tpu_lowering_rule(
|
||||
raise NotImplementedError(
|
||||
"Cannot use both input_output_aliases and extra_args."
|
||||
)
|
||||
num_dyn_bounds = grid_mapping.num_dynamic_grid_bounds
|
||||
input_output_aliases = tuple(
|
||||
(a[0] + num_dyn_bounds, a[1]) for a in input_output_aliases
|
||||
)
|
||||
out_avals = [jax_core.ShapedArray(s.shape, s.dtype) for s in out_shapes]
|
||||
def _lower_fun(*args):
|
||||
# Dynamic grid bounds have to go at the front.
|
||||
dynamic_grid_args, args = args[:num_dyn_bounds], args[num_dyn_bounds:],
|
||||
return mosaic.as_tpu_kernel(
|
||||
mosaic_module,
|
||||
out_avals,
|
||||
@ -91,6 +101,7 @@ def pallas_call_tpu_lowering_rule(
|
||||
flags=mosaic_params.get("flags", None),
|
||||
input_output_aliases=input_output_aliases,
|
||||
)(
|
||||
*dynamic_grid_args,
|
||||
*extra_args,
|
||||
*args,
|
||||
collective_id=mosaic_params.get("collective_id", None),
|
||||
|
@ -120,7 +120,7 @@ def _tree_map_with_kwargs(f, *args, **kwargs):
|
||||
)
|
||||
|
||||
|
||||
def _get_next_indices(grid: core.Grid, indices: GridIndices) -> GridIndices:
|
||||
def _get_next_indices(grid: core.StaticGrid, indices: GridIndices) -> GridIndices:
|
||||
"""Takes a grid and current indices and returns the next indices.
|
||||
|
||||
grid: (3, 4, 5)
|
||||
@ -412,7 +412,7 @@ class Pipeline(Protocol):
|
||||
def emit_pipeline_with_allocations(
|
||||
body: PipelineBody,
|
||||
*,
|
||||
grid: core.Grid,
|
||||
grid: core.StaticGrid,
|
||||
in_specs: PipelineBlockSpecs,
|
||||
out_specs: PipelineBlockSpecs,
|
||||
should_accumulate_out: Union[Sequence[bool], Any] = False,
|
||||
@ -1046,7 +1046,7 @@ def emit_pipeline_with_allocations(
|
||||
def emit_pipeline(
|
||||
body: PipelineBody,
|
||||
*,
|
||||
grid: core.Grid,
|
||||
grid: core.StaticGrid,
|
||||
in_specs: PipelineBlockSpecs,
|
||||
out_specs: PipelineBlockSpecs,
|
||||
should_accumulate_out: Union[Sequence[bool], Any] = False,
|
||||
|
@ -95,11 +95,13 @@ def _pallas_call_impl(*args, jaxpr, name, out_shapes, which_linear,
|
||||
input_output_aliases: tuple[tuple[int, int], ...],
|
||||
grid_mapping: GridMapping,
|
||||
**compiler_params: Any):
|
||||
if grid_mapping.num_dynamic_grid_bounds:
|
||||
raise NotImplementedError("interpret with dynamic grid bounds unsupported")
|
||||
if interpret:
|
||||
# If we're in interpreter mode, we *scan* over the grid and eval the
|
||||
# discharged jaxpr. This should reproduce exactly what compiling to Triton
|
||||
# will do.
|
||||
grid = grid_mapping.grid
|
||||
grid = grid_mapping.static_grid
|
||||
discharged_jaxpr, consts = state_discharge.discharge_state(jaxpr, ())
|
||||
if debug:
|
||||
print(discharged_jaxpr)
|
||||
@ -195,6 +197,8 @@ pallas_call_p.def_abstract_eval(_pallas_call_abstract_eval)
|
||||
def _pallas_call_jvp_rule(primals, tangents, *, jaxpr, name, which_linear,
|
||||
input_output_aliases: tuple[tuple[int, int], ...],
|
||||
in_shapes, out_shapes, grid_mapping, debug, interpret, **compiler_params: Any):
|
||||
if grid_mapping.num_dynamic_grid_bounds:
|
||||
raise NotImplementedError("interpret with dynamic grid bounds unsupported")
|
||||
if grid_mapping.num_index_operands:
|
||||
raise NotImplementedError
|
||||
if input_output_aliases:
|
||||
@ -282,6 +286,8 @@ def _pallas_call_batching_rule(args, dims, *,
|
||||
interpret: bool,
|
||||
which_linear: tuple[bool, ...],
|
||||
**compiler_params: Any):
|
||||
if grid_mapping.num_dynamic_grid_bounds:
|
||||
raise NotImplementedError("interpret with dynamic grid bounds unsupported")
|
||||
if grid_mapping.num_index_operands:
|
||||
scalar_args, args = split_list(args, [grid_mapping.num_index_operands])
|
||||
scalar_bdims, bdims = split_list(dims, [grid_mapping.num_index_operands])
|
||||
@ -453,8 +459,11 @@ def pallas_call(
|
||||
**compiler_params: Any,
|
||||
):
|
||||
name = _extract_function_name(f, name)
|
||||
if grid is not None and grid_spec is not None:
|
||||
raise ValueError("Cannot specify both grid and grid_spec at the same time.")
|
||||
if grid_spec is None:
|
||||
grid_spec = GridSpec(grid, in_specs, out_specs)
|
||||
grid_spec, dynamic_grid_bounds = grid_spec.unzip_dynamic_grid_bounds()
|
||||
if isinstance(out_shape, list):
|
||||
out_shape = tuple(out_shape)
|
||||
flat_out_shapes, out_tree = tree_util.tree_flatten(out_shape)
|
||||
@ -472,7 +481,8 @@ def pallas_call(
|
||||
out_tree)
|
||||
which_linear = (False,) * len(flat_args)
|
||||
out_flat = pallas_call_p.bind(
|
||||
*consts, *flat_args, jaxpr=jaxpr, name=name, which_linear=which_linear,
|
||||
*dynamic_grid_bounds, *consts, *flat_args,
|
||||
jaxpr=jaxpr, name=name, which_linear=which_linear,
|
||||
in_shapes=tuple(jax.ShapeDtypeStruct(a.shape, a.dtype)
|
||||
for a in flat_args),
|
||||
out_shapes=tuple(flat_out_shapes), debug=debug,
|
||||
|
@ -1568,6 +1568,8 @@ def pallas_call_lowering(
|
||||
triton_params: dict[str, Any] | None = None,
|
||||
**compiler_params: Any,
|
||||
):
|
||||
if grid_mapping.num_dynamic_grid_bounds:
|
||||
raise NotImplementedError("dynamic grid bounds not supported in the Triton backend")
|
||||
if interpret:
|
||||
return mlir.lower_fun(pallas_call_p.impl, multiple_results=True)(
|
||||
ctx,
|
||||
|
@ -1091,6 +1091,28 @@ class PallasCallTest(PallasTPUTest):
|
||||
kernel, out_shape=x, mosaic_params=dict(vmem_limit_bytes=int(2**18))
|
||||
)(x)
|
||||
|
||||
def test_dynamic_grid(self):
|
||||
shape = (8, 128)
|
||||
result_ty = jax.ShapeDtypeStruct(shape, jnp.float32)
|
||||
|
||||
def kernel(y_ref):
|
||||
@pl.when(pl.program_id(0) == 0)
|
||||
def _init():
|
||||
y_ref[...] = jnp.zeros_like(y_ref)
|
||||
y_ref[...] += 1
|
||||
|
||||
@jax.jit
|
||||
def dynamic_kernel(steps):
|
||||
return pl.pallas_call(
|
||||
kernel,
|
||||
grid=(steps * 2,),
|
||||
out_specs=pl.BlockSpec(lambda i: (0, 0), shape),
|
||||
out_shape=result_ty,
|
||||
)()
|
||||
np.testing.assert_array_equal(
|
||||
dynamic_kernel(4), np.full(shape, 8.0, np.float32)
|
||||
)
|
||||
|
||||
|
||||
class PallasUXTest(PallasTPUTest):
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user