Reverts 0d058ce86f04a44a51abba1261768fb46edf69d9

PiperOrigin-RevId: 655871052
George Necula 2024-07-25 01:49:59 -07:00 committed by jax authors
parent c8ea86c9c9
commit 4063373b22
17 changed files with 564 additions and 521 deletions

View File

@ -34,7 +34,7 @@ This generalizes to any tuple of integers (a length `d` grid will correspond
to `d` nested loops).
The kernel is executed as many times
as `prod(grid)`.
The default grid value `None` stands for `()`, and results in one
The default grid value `()` results in one
kernel invocation.
Each of these invocations is referred to as a "program".
To access which program (i.e. which element of the grid) the kernel is currently
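A minimal sketch (assuming the usual `jax.experimental.pallas` API and hypothetical
shapes) of a one-dimensional grid with two programs; each program handles one
(4,)-shaped block and reads its position in the grid via `pl.program_id`:

    import jax
    import jax.numpy as jnp
    from jax.experimental import pallas as pl

    def add_program_id_kernel(x_ref, o_ref):
      i = pl.program_id(0)  # which program (grid element) is currently running
      o_ref[...] = x_ref[...] + i

    x = jnp.arange(8, dtype=jnp.float32)
    out = pl.pallas_call(
        add_program_id_kernel,
        grid=(2,),  # the kernel is executed prod((2,)) == 2 times
        in_specs=[pl.BlockSpec((4,), lambda i: (i,))],
        out_specs=pl.BlockSpec((4,), lambda i: (i,)),
        out_shape=jax.ShapeDtypeStruct((8,), jnp.float32),
    )(x)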

View File

@ -27,6 +27,7 @@ import warnings
import jax
from jax._src import api_util
from jax._src import config
from jax._src import core as jax_core
from jax._src import deprecations
from jax._src import linear_util as lu
@ -40,7 +41,8 @@ import jax.numpy as jnp
class DynamicGridDim:
pass
def __repr__(self):
return "DynamicGridDim"
dynamic_grid_dim = DynamicGridDim()
@ -173,18 +175,27 @@ def current_grid_env() -> GridEnv | None:
class Mapped:
pass
"""Used as a block shape dimension to denote a mapped dimension.
A mapped dimension behaves like `1` except it is squeezed from the block.
See :ref:`pallas_blockspec` for more details.
"""
def __repr__(self):
return "Mapped"
mapped = Mapped()
@dataclasses.dataclass(frozen=True)
class Unblocked:
padding: tuple[tuple[int, int], ...] | None = None
def __repr__(self):
return f"Unblocked(padding={self.padding})"
unblocked = Unblocked()
class Blocked:
pass
def __repr__(self):
return "Blocked"
blocked = Blocked()
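A small illustration (a sketch, not part of this change), using `mapped` as defined
above: a `None` entry in a user-provided `block_shape` is canonicalized to the
`mapped` sentinel, and that dimension is squeezed from the block the kernel sees:

    user_block_shape = (None, 128)
    canonical = tuple(mapped if s is None else s for s in user_block_shape)
    assert canonical == (mapped, 128)  # the kernel receives a Ref of shape (128,)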
@ -196,6 +207,8 @@ class BlockSpec:
"""Specifies how an array should be sliced for each iteration of a kernel.
See :ref:`pallas_blockspec` for more details.
This object contains the parameters passed through the API.
An internal canonicalized version is in BlockMapping.
"""
block_shape: tuple[int | None, ...] | None = None
index_map: Callable[..., Any] | None = None
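# Hedged usage sketch (via the public `pl.BlockSpec` alias): each program with
# grid coordinates (i, j) gets the (128, 128) block whose *block index* is (i, j),
# i.e. elements [i*128:(i+1)*128, j*128:(j+1)*128] of the underlying array:
#   pl.BlockSpec((128, 128), lambda i, j: (i, j))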
@ -247,9 +260,38 @@ BlockSpecTree = Any
@dataclasses.dataclass(frozen=True)
class BlockMapping:
"""An internal canonicalized version of BlockSpec.
See the `check_invariants` method for a precise specification.
"""
block_shape: tuple[Mapped | int, ...]
block_aval: AbstractMemoryRef # The block ref aval
index_map_jaxpr: jax_core.ClosedJaxpr
indexing_mode: IndexingMode
array_shape_dtype: jax.ShapeDtypeStruct # The whole array
origin: str # The origin, e.g. input[2]["field"]
def check_invariants(self) -> None:
if not config.enable_checks.value: return
unmapped_block_shape = tuple(s for s in self.block_shape if s is not mapped)
assert unmapped_block_shape == self.block_aval.shape, (
self.block_shape, self.block_aval)
assert len(self.block_shape) == len(self.array_shape_dtype.shape), (
self.block_shape, self.array_shape_dtype
)
assert not self.index_map_jaxpr.consts
assert len(self.block_shape) == len(self.index_map_jaxpr.out_avals)
assert all(ov.shape == () and
(ov.dtype == jnp.int32 or ov.dtype == jnp.int64)
for ov in self.index_map_jaxpr.out_avals), (
self.index_map_jaxpr.out_avals)
def replace(self, **kwargs):
new_self = dataclasses.replace(self, **kwargs)
new_self.check_invariants()
return new_self
def compute_start_indices_interpret(self, loop_idx, *args):
discharged_jaxpr, discharged_consts = state_discharge.discharge_state(
@ -269,8 +311,6 @@ class BlockMapping:
else:
raise RuntimeError(f"Unknown indexing mode: {self.indexing_mode}")
replace = dataclasses.replace
@contextlib.contextmanager
def tracing_grid_env(grid: GridMappingGrid, mapped_dims: tuple[int, ...]):
@ -285,16 +325,86 @@ def tracing_grid_env(grid: GridMappingGrid, mapped_dims: tuple[int, ...]):
@dataclasses.dataclass(frozen=True)
class GridMapping:
"""An internal canonicalized version of GridSpec.
Encodes the calling conventions of the pallas_call primitive, the kernel,
and the index maps.
The pallas_call is invoked with: ``*dynamic_grid_sizes, *index, *consts, *inputs``.
The ``index`` operands are for the scalar prefetch.
The ``consts`` are constants captured by the kernel function.
The kernel function is invoked with:
``*index, *consts, *inputs, *outputs, *scratch``.
The index map functions are invoked with:
``*program_ids, *index``.
See the `check_invariants` method for a more precise specification.
"""
grid: GridMappingGrid
grid_names: tuple[Hashable, ...] | None
block_mappings: tuple[BlockMapping | None, ...]
mapped_dims: tuple[int, ...] = ()
num_index_operands: int = 0
num_scratch_operands: int = 0
# Number of constants hoisted to operands by ``_hoist_consts_to_refs``.
num_constant_operands: int = 0
replace = dataclasses.replace
# Block mappings for: *consts, *inputs, *outputs
block_mappings: tuple[BlockMapping, ...]
# The inputs for tracing the index map: the tree and the flat avals
index_map_tree: tree_util.PyTreeDef
index_map_avals: tuple[jax_core.AbstractValue, ...]
# Which dimensions in `grid` are vmapped.
vmapped_dims: tuple[int, ...]
num_index_operands: int
# Number of captured constants hoisted to operands.
num_constant_operands: int
num_inputs: int
num_outputs: int
num_scratch_operands: int
def check_invariants(self) -> None:
if not config.enable_checks.value: return
assert (len(self.block_mappings) ==
self.num_constant_operands + self.num_inputs + self.num_outputs), (
self.num_constant_operands, self.num_inputs, self.num_outputs,
self.block_mappings
)
# index_map_avals = int32[] * len(self.grid) + index_operands
assert len(self.index_map_avals) == len(self.grid) + self.num_index_operands, (
self.index_map_avals,
self.grid,
self.num_index_operands,
)
# Check that we can put together the avals and the tree.
index_map_args, index_map_kwargs = self.index_map_tree.unflatten(
self.index_map_avals)
assert not index_map_kwargs
assert len(index_map_args) >= len(self.grid)
for i in range(len(self.grid)):
index_map_arg = index_map_args[i]
assert index_map_arg.shape == ()
assert index_map_arg.dtype == jnp.int32
assert len(self.vmapped_dims) <= len(self.grid)
for i in self.vmapped_dims:
assert 0 <= i < len(self.grid)
if self.grid_names is not None:
assert len(self.grid) == len(self.grid_names), (self.grid, self.grid_names)
for bm in self.block_mappings:
bm.check_invariants()
assert tuple(self.index_map_avals) == tuple(bm.index_map_jaxpr.in_avals), (
self.index_map_avals,
bm.index_map_jaxpr.in_avals,
)
def replace(self, **kwargs) -> GridMapping:
new_self = dataclasses.replace(self, **kwargs)
new_self.check_invariants()
return new_self
@property
# TODO(necula): deprecate and then remove this property.
def mapped_dims(self) -> tuple[int, ...]:
return self.vmapped_dims
@property
def num_dynamic_grid_bounds(self):
@ -314,74 +424,123 @@ class GridMapping:
axis_env_ctx = jax_core.extend_axis_env_nd(
zip(self.grid_names, self.grid)
)
with tracing_grid_env(self.grid, self.mapped_dims), axis_env_ctx:
with tracing_grid_env(self.grid, self.vmapped_dims), axis_env_ctx:
yield
@property
def slice_index_ops(self):
"""Returns a slice object to select the index operands to a kernel."""
return slice(0, self.num_index_operands)
@property
def slice_block_ops(self):
"""Returns a slice to select all but the index operands to a kernel."""
return slice(self.num_index_operands, None)
@property
def slice_scratch_ops(self):
"""Returns a slice object to select the scratch operands to a kernel."""
if self.num_scratch_operands:
return slice(-self.num_scratch_operands, None)
else:
return slice(0, 0)
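# For example (a hedged illustration): with jaxpr.invars laid out as
# [*scalar_prefetch, *consts, *inputs, *outputs, *scratch], a consumer can write
#   jaxpr.invars[grid_mapping.slice_index_ops]    # the scalar-prefetch refs
#   jaxpr.invars[grid_mapping.slice_block_ops]    # everything after them
#   jaxpr.invars[grid_mapping.slice_scratch_ops]  # just the scratch refs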
# TODO(necula): this is used to recover the old `in_shapes`, but it probably
# is not needed anymore, with some cleanup.
@property
def in_shapes(self) -> tuple[jax.ShapeDtypeStruct, ...]:
"""The shapes of *index, *consts, *inputs."""
index_shapes = [jax.ShapeDtypeStruct(ia.inner_aval.shape,
ia.inner_aval.dtype)
for ia in self.index_map_avals[len(self.grid):]]
consts_inputs_shapes = [
bm.array_shape_dtype
for bm in self.block_mappings[
:self.num_constant_operands + self.num_inputs]]
return tuple(index_shapes + consts_inputs_shapes)
# TODO(necula): this is used to recover the old `out_shapes`, but it probably
# is not needed anymore, with some cleanup.
@property
def out_shapes(self) -> tuple[jax.ShapeDtypeStruct, ...]:
return tuple(
bm.array_shape_dtype
for bm in self.block_mappings[
self.num_constant_operands + self.num_inputs:
self.num_constant_operands + self.num_inputs + self.num_outputs])
def _is_valid_grid_dim(dim: int | jax.Array) -> bool:
if isinstance(dim, jax.Array):
return True
return jax_core.is_dim(dim)
def _preprocess_grid(grid: Grid | int | None) -> tuple[TupleGrid, GridNames]:
if grid is None:
return (), None
if isinstance(grid, int):
return (grid,), None
# Handle empty grid
if not grid:
return grid, None # type: ignore
# Check if we have a named grid
if isinstance(grid[0], tuple):
grid_names, grid = util.unzip2(grid) # type: ignore
else:
grid_names = None
# TODO(b/353730556): allow NumPy scalars in grids
if not all(_is_valid_grid_dim(g) for g in grid): # type: ignore
raise ValueError(
f"Grid must be a tuple of integers or jax.Array, got {grid}"
)
return grid, grid_names # type: ignore
def _convert_block_spec_to_block_mapping(
in_avals: Sequence[jax_core.ShapedArray],
block_spec: BlockSpec,
path: tree_util.KeyPath,
aval: jax_core.ShapedArray,
in_tree: Any,
array_aval: jax_core.ShapedArray,
*,
# Inputs for the index_map
index_map_avals: Sequence[jax_core.AbstractValue],
index_map_tree: tree_util.PyTreeDef,
grid: GridMappingGrid,
mapped_dims: tuple[int, ...],
what: str, # Used to localize error messages, e.g., {what}{path}
) -> BlockMapping | None:
) -> BlockMapping:
origin = f"{what}{tree_util.keystr(path)}"
if block_spec is no_block_spec:
return None
block_spec = BlockSpec(None, None)
if block_spec.index_map is None:
compute_index = lambda *args, **kwargs: (0,) * len(aval.shape)
compute_index = lambda *args: (0,) * len(array_aval.shape)
else:
compute_index = block_spec.compute_index
if block_spec.block_shape is None:
block_shape = aval.shape
block_shape = array_aval.shape
else:
block_shape = block_spec.block_shape
block_shape = tuple(
mapped if s is None else s for s in block_shape)
flat_fun, _ = api_util.flatten_fun(lu.wrap_init(compute_index), in_tree)
with tracing_grid_env(grid, mapped_dims):
jaxpr, out_avals, consts, () = pe.trace_to_jaxpr_dynamic(flat_fun, in_avals)
if len(out_avals) != len(block_shape):
if len(array_aval.shape) != len(block_shape):
raise ValueError(
f"Index map for {what}{tree_util.keystr(path)} must return "
f"{len(aval.shape)} values to match {block_shape=}. "
f"Currently returning {len(out_avals)} values."
f"Block shape for {origin} (= {block_shape}) "
f"must have the same number of dimensions as the array shape {array_aval.shape}"
)
unmapped_block_shape = tuple(s for s in block_shape if s is not None)
block_aval = AbstractMemoryRef(array_aval.update(shape=unmapped_block_shape),
block_spec.memory_space)
if not jax_core.is_constant_shape(block_aval.shape):
raise ValueError(
"shape polymorphism for Pallas does not support "
"dynamically-shaped blocks. "
f"{origin} has block_shape: {block_aval.shape}")
flat_index_map_fun, _ = api_util.flatten_fun(lu.wrap_init(compute_index),
index_map_tree)
with tracing_grid_env(grid, mapped_dims):
jaxpr, out_avals, consts, () = pe.trace_to_jaxpr_dynamic(flat_index_map_fun,
index_map_avals)
mapped_block_shape = tuple(
mapped if s is None else s for s in block_shape)
if len(out_avals) != len(mapped_block_shape):
raise ValueError(
# TODO(necula): show the name and location of the index map function
f"Index map for {origin} must return "
f"{len(block_aval.shape)} values to match block shape {mapped_block_shape}. "
f"Currently returning {len(out_avals)} values."
)
if consts:
raise NotImplementedError(
f"Index map for {what}{tree_util.keystr(path)} captures constants: "
# TODO(necula): show the name and location of the index map function
f"Index map for {origin} captures constants: "
f"{consts}")
return BlockMapping(
block_shape, jax_core.ClosedJaxpr(jaxpr, consts), block_spec.indexing_mode
mapping = BlockMapping(
block_shape=mapped_block_shape,
block_aval=block_aval,
index_map_jaxpr=jax_core.ClosedJaxpr(jaxpr, consts),
indexing_mode=block_spec.indexing_mode,
array_shape_dtype=jax.ShapeDtypeStruct(array_aval.shape, array_aval.dtype),
origin=origin,
)
mapping.check_invariants()
return mapping
def _tile_ref(ref: state.AbstractRef, block_shape: tuple[int, ...] | None
) -> state.AbstractRef:
@ -390,65 +549,20 @@ def _tile_ref(ref: state.AbstractRef, block_shape: tuple[int, ...] | None
shape = tuple(s for s in block_shape if s is not None)
return ref.update(inner_aval=ref.inner_aval.update(shape=shape))
def _get_ref_avals(in_avals: Sequence[jax_core.ShapedArray],
in_specs: Sequence[BlockSpec],
in_paths: Sequence[tree_util.KeyPath],
out_avals: Sequence[jax_core.ShapedArray],
out_specs: Sequence[BlockSpec],
out_paths: Sequence[tree_util.KeyPath]):
def make_ref_aval(aval: jax_core.ShapedArray,
spec: BlockSpec,
path: tree_util.KeyPath,
what: str) -> state.AbstractRef:
if spec is no_block_spec:
memory_space = None
block_shape = None
else:
memory_space = spec.memory_space
block_shape = spec.block_shape
ref_aval = AbstractMemoryRef(aval, memory_space)
if block_shape is not None:
if len(ref_aval.shape) != len(block_shape):
raise ValueError(
f"Block shape for {what}{tree_util.keystr(path)} (= {block_shape}) "
f"must have the same number of dimensions as the array shape {ref_aval.shape}"
)
block_shape_unmapped = tuple(s for s in block_shape if s is not None)
ref_aval = ref_aval.update(
inner_aval=ref_aval.inner_aval.update(shape=block_shape_unmapped))
if not jax_core.is_constant_shape(ref_aval.shape):
raise ValueError(
"shape polymorphism for Pallas does not support "
"dynamically-shaped blocks. "
f"{what}{tree_util.keystr(path)} has block_shape: {ref_aval.shape}")
return ref_aval
in_ref_avals = [
make_ref_aval(aval, in_spec, in_path, "input")
for aval, in_spec, in_path in zip(in_avals, in_specs, in_paths)
]
out_ref_avals = [
make_ref_aval(aval, out_spec, out_path, "output")
for aval, out_spec, out_path in zip(out_avals, out_specs, out_paths)
]
return in_ref_avals, out_ref_avals
@dataclasses.dataclass(init=False, unsafe_hash=True)
class GridSpec:
"""Encodes the parameters of the grid, as given through the API.
An internal sanitized version is in GridMapping.
"""
grid: TupleGrid
grid_names: tuple[Hashable, ...] | None
in_specs: tuple[BlockSpec | NoBlockSpec, ...]
out_specs: tuple[BlockSpec | NoBlockSpec, ...]
in_specs_tree: Any
out_specs_tree: Any
in_specs: tuple[BlockSpec | NoBlockSpec, ...] | NoBlockSpec
out_specs: tuple[BlockSpec | NoBlockSpec, ...] | NoBlockSpec
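# Hedged usage sketch: a GridSpec bundles grid, in_specs and out_specs, and can
# be passed to pallas_call via its grid_spec= argument, e.g.
#   GridSpec(grid=(2, 2),
#            in_specs=[pl.BlockSpec((128, 128), lambda i, j: (i, j))],
#            out_specs=pl.BlockSpec((128, 128), lambda i, j: (i, j)))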
def __init__(
self,
grid: Grid | None = None,
grid: Grid = (),
in_specs: BlockSpecTree = no_block_spec,
out_specs: BlockSpecTree = no_block_spec,
):
@ -460,89 +574,153 @@ class GridSpec:
if isinstance(out_specs, list):
out_specs = tuple(out_specs)
self.grid, self.grid_names = _preprocess_grid(grid)
if in_specs is not no_block_spec:
flat_in_specs, self.in_specs_tree = tree_util.tree_flatten(in_specs)
self.in_specs = tuple(flat_in_specs)
else:
self.in_specs = in_specs
self.in_specs_tree = None
if out_specs is not no_block_spec:
flat_out_specs, self.out_specs_tree = tree_util.tree_flatten(out_specs)
self.out_specs = tuple(flat_out_specs)
else:
self.out_specs = out_specs
self.out_specs_tree = None
self.in_specs = in_specs
self.out_specs = out_specs
def _get_in_out_specs(self, in_avals, in_tree, out_avals, out_tree):
if self.in_specs is no_block_spec:
flat_in_specs = [no_block_spec] * len(in_avals)
else:
flat_in_specs = self.in_specs
if self.in_specs_tree != in_tree:
raise ValueError(
pytreedef_mismatch_err_msg("`in_specs`", self.in_specs_tree,
"inputs", in_tree))
if self.out_specs is no_block_spec:
flat_out_specs = [no_block_spec] * len(out_avals)
else:
flat_out_specs = self.out_specs
if self.out_specs_tree != out_tree:
raise ValueError(
pytreedef_mismatch_err_msg("`out_specs`", self.out_specs_tree,
"`out_shape`", out_tree))
return flat_in_specs, flat_out_specs
grid_names = None
if isinstance(grid, int):
grid = (grid,)
elif grid and isinstance(grid[0], tuple): # Check if we have a named grid
grid_names, grid = util.unzip2(grid) # type: ignore
# TODO(b/353730556): allow NumPy scalars in grids
if not all(_is_valid_grid_dim(g) for g in grid): # type: ignore
raise ValueError(
f"Grid must be a tuple of integers or jax.Array, got {grid}"
)
self.grid = grid # type: ignore
self.grid_names = grid_names
def get_grid_mapping(
self, in_avals, in_tree, in_paths, out_avals, out_tree, out_paths
) -> tuple[tuple[jax_core.AbstractValue, ...], GridMapping]:
self,
in_avals: Sequence[jax_core.AbstractValue],
in_tree: tree_util.PyTreeDef,
in_paths: Sequence[tree_util.KeyPath],
out_avals: Sequence[jax_core.AbstractValue],
out_tree: tree_util.PyTreeDef,
out_paths: Sequence[tree_util.KeyPath],
num_scalar_prefetch: int = 0,
scratch_shapes: Sequence[Any] = (),
) -> tuple[tuple[AbstractMemoryRef, ...],
GridMapping]:
assert all(i is None or isinstance(i, int) for i in self.grid)
grid_mapping_grid = tuple(
dynamic_grid_dim if d is None else d for d in self.grid
)
flat_in_specs, flat_out_specs = self._get_in_out_specs(
in_avals, in_tree, out_avals, out_tree)
in_ref_avals, out_ref_avals = _get_ref_avals(
in_avals, flat_in_specs, in_paths,
out_avals, flat_out_specs, out_paths)
grid_avals = [jax_core.ShapedArray((), jnp.dtype("int32"))] * len(self.grid)
# Create args, kwargs pytree def
grid_tree = tree_util.tree_structure((tuple(grid_avals), {}))
# The inputs for the index maps
index_map_avals = (
(jax_core.ShapedArray((), jnp.dtype("int32")),) * len(self.grid))
index_map_tree = tree_util.tree_structure((index_map_avals, {}))
if num_scalar_prefetch:
all_avals = tree_util.tree_unflatten(in_tree, in_avals)
scalar_avals, unflat_in_avals = split_list(
all_avals, [num_scalar_prefetch])
flat_scalar_avals, scalar_tree = tree_util.tree_flatten(scalar_avals)
num_flat_scalar_prefetch = len(flat_scalar_avals)
scalar_ref_avals = [
self._make_scalar_ref_aval(aval)
for aval in flat_scalar_avals]
jaxpr_scalar_ref_avals = tree_util.tree_unflatten(
scalar_tree, scalar_ref_avals)
in_avals, in_tree = tree_util.tree_flatten(tuple(unflat_in_avals))
index_map_tree = tree_util.tree_structure(((*index_map_avals,
*scalar_avals), {}))
index_map_avals = (*index_map_avals, *scalar_ref_avals)
del scalar_ref_avals, flat_scalar_avals, scalar_tree
del scalar_avals, unflat_in_avals, all_avals
else:
num_flat_scalar_prefetch = 0
jaxpr_scalar_ref_avals = ()
if scratch_shapes:
flat_scratch_shapes, scratch_tree = tree_util.tree_flatten(
scratch_shapes)
flat_scratch_avals = map(self._make_scratch_aval, flat_scratch_shapes)
num_flat_scratch_operands = len(flat_scratch_avals)
jaxpr_scratch_avals = tree_util.tree_unflatten(
scratch_tree, flat_scratch_avals)
if not isinstance(jaxpr_scratch_avals, (tuple, list)):
jaxpr_scratch_avals = (jaxpr_scratch_avals,)
del flat_scratch_avals, flat_scratch_shapes, scratch_tree
else:
num_flat_scratch_operands = 0
jaxpr_scratch_avals = ()
if self.in_specs is not no_block_spec:
flat_in_specs, in_specs_tree = tree_util.tree_flatten(self.in_specs)
if in_specs_tree != in_tree:
raise ValueError(
pytreedef_mismatch_err_msg("`in_specs`", in_specs_tree,
"inputs", in_tree))
else:
flat_in_specs = [no_block_spec] * len(in_avals)
in_block_mappings = map(
partial(
_convert_block_spec_to_block_mapping,
grid_avals,
in_tree=grid_tree,
index_map_avals=index_map_avals,
index_map_tree=index_map_tree,
grid=grid_mapping_grid,
mapped_dims=(),
what="input",
what="inputs",
),
flat_in_specs,
in_paths,
in_ref_avals,
in_paths[num_flat_scalar_prefetch:],
in_avals,
)
if self.out_specs is not no_block_spec:
flat_out_specs, out_specs_tree = tree_util.tree_flatten(self.out_specs)
if out_specs_tree != out_tree:
raise ValueError(
pytreedef_mismatch_err_msg("`out_specs`", out_specs_tree,
"`out_shape`", out_tree))
else:
flat_out_specs = [no_block_spec] * len(out_avals)
out_block_mappings = map(
partial(
_convert_block_spec_to_block_mapping,
grid_avals,
in_tree=grid_tree,
index_map_avals=index_map_avals,
index_map_tree=index_map_tree,
grid=grid_mapping_grid,
mapped_dims=(),
what="output",
what="outputs",
),
flat_out_specs,
out_paths,
out_ref_avals,
out_avals,
)
grid_mapping = GridMapping(
grid_mapping_grid, self.grid_names, # type: ignore
(*in_block_mappings, *out_block_mappings)
grid=grid_mapping_grid, # type: ignore[arg-type]
grid_names=self.grid_names,
block_mappings=(*in_block_mappings, *out_block_mappings),
index_map_avals=index_map_avals, # type: ignore[arg-type]
index_map_tree=index_map_tree,
vmapped_dims=(),
num_index_operands=num_flat_scalar_prefetch,
num_constant_operands=0, # Fixed up later
num_inputs=len(flat_in_specs),
num_outputs=len(flat_out_specs),
num_scratch_operands=num_flat_scratch_operands,
)
jaxpr_in_avals = tree_util.tree_unflatten(in_tree, in_ref_avals)
grid_mapping.check_invariants()
in_ref_avals = [bm.block_aval for bm in in_block_mappings]
jaxpr_in_ref_avals = tree_util.tree_unflatten(in_tree, in_ref_avals)
jaxpr_in_avals = (*jaxpr_scalar_ref_avals, *jaxpr_in_ref_avals)
out_ref_avals = [bm.block_aval for bm in out_block_mappings]
jaxpr_out_avals = tree_util.tree_unflatten(out_tree, out_ref_avals)
if not isinstance(jaxpr_out_avals, (tuple, list)):
jaxpr_out_avals = (jaxpr_out_avals,)
return (*jaxpr_in_avals, *jaxpr_out_avals), grid_mapping
return (*jaxpr_in_avals, *jaxpr_out_avals,
*jaxpr_scratch_avals), grid_mapping
def _make_scratch_aval(self, obj: object) -> jax_core.AbstractValue:
assert False # Not needed in GridSpec
def _make_scalar_ref_aval(self, aval):
assert False # Not needed in GridSpec
def unzip_dynamic_grid_bounds(
self,
@ -557,6 +735,7 @@ class GridSpec:
static_self.grid = static_grid # type: ignore
return static_self, dynamic_bounds
def pytreedef_mismatch_err_msg(
what1: str, tree1: tree_util.PyTreeDef,
what2: str, tree2: tree_util.PyTreeDef) -> str:

View File

@ -24,7 +24,6 @@ from typing import Any, Hashable
import jax
from jax._src import core as jax_core
from jax._src import dtypes
from jax._src import tree_util
from jax._src import util
import jax.numpy as jnp
from jax._src.pallas import core as pallas_core
@ -141,30 +140,19 @@ class MemoryRef:
jax_core.ShapedArray(self.shape, self.dtype), self.memory_space)
def _make_aval(obj: object) -> jax_core.AbstractValue:
if isinstance(obj, MemoryRef):
return obj.get_aval()
if isinstance(obj, SemaphoreType):
return obj.get_aval()
raise ValueError(f"No registered conversion for {type(obj)}. "
"Only VMEM and SemaphoreType are supported.")
@dataclasses.dataclass(init=False, unsafe_hash=True)
class PrefetchScalarGridSpec(pallas_core.GridSpec):
grid: TupleGrid
grid_names: tuple[Hashable, ...] | None
num_scalar_prefetch: int
in_specs: tuple[BlockSpec | NoBlockSpec, ...]
out_specs: tuple[BlockSpec | NoBlockSpec, ...]
in_specs_tree: Any
out_specs_tree: Any
in_specs: tuple[BlockSpec | NoBlockSpec, ...] | NoBlockSpec
out_specs: tuple[BlockSpec | NoBlockSpec, ...] | NoBlockSpec
scratch_shapes: tuple[Any, ...]
def __init__(
self,
num_scalar_prefetch: int,
grid: Grid | None = None,
grid: Grid = (),
in_specs: BlockSpecTree = no_block_spec,
out_specs: BlockSpecTree = no_block_spec,
scratch_shapes: Any | Sequence[Any] = ()
@ -173,84 +161,25 @@ class PrefetchScalarGridSpec(pallas_core.GridSpec):
self.num_scalar_prefetch = num_scalar_prefetch
self.scratch_shapes = tuple(scratch_shapes)
def get_grid_mapping(
def _make_scalar_ref_aval(self, aval):
return AbstractMemoryRef(jax_core.ShapedArray(aval.shape, aval.dtype),
TPUMemorySpace.SMEM)
def _make_scratch_aval(self, obj: object) -> jax_core.AbstractValue:
if isinstance(obj, MemoryRef):
return obj.get_aval()
if isinstance(obj, SemaphoreType):
return obj.get_aval()
raise ValueError(f"No registered conversion for {type(obj)}. "
"Only VMEM and SemaphoreType are supported.")
def get_grid_mapping( # type: ignore[override]
self, in_avals, in_tree, in_paths, out_avals, out_tree, out_paths
) -> tuple[tuple[jax_core.AbstractValue, ...], GridMapping]:
assert all(i is None or isinstance(i, int) for i in self.grid)
grid_mapping_grid = tuple(
pallas_core.dynamic_grid_dim if d is None else d for d in self.grid
)
all_avals = tree_util.tree_unflatten(in_tree, in_avals)
flat_scratch_shapes, scratch_tree = tree_util.tree_flatten(
self.scratch_shapes)
flat_scratch_avals = map(_make_aval, flat_scratch_shapes)
scalar_avals, unflat_in_avals = split_list(
all_avals, [self.num_scalar_prefetch])
flat_scalar_avals, scalar_tree = tree_util.tree_flatten(scalar_avals)
num_flat_scalar_prefetch = len(flat_scalar_avals)
in_avals, in_avals_tree = tree_util.tree_flatten(tuple(unflat_in_avals))
flat_in_specs, flat_out_specs = self._get_in_out_specs(
in_avals, in_avals_tree, out_avals, out_tree)
in_ref_avals, out_ref_avals = (
pallas_core._get_ref_avals(
in_avals, flat_in_specs, in_paths[num_flat_scalar_prefetch:],
out_avals, flat_out_specs, out_paths))
scalar_ref_avals = [
AbstractMemoryRef(jax_core.ShapedArray(aval.shape, aval.dtype),
TPUMemorySpace.SMEM)
for aval in flat_scalar_avals]
grid_avals = [jax_core.ShapedArray((), jnp.dtype("int32"))] * len(self.grid)
# Create args, kwargs pytree def
index_map_in_tree = tree_util.tree_structure(
((*grid_avals, *scalar_avals), {})
)
in_block_mappings = map(
partial(
_convert_block_spec_to_block_mapping,
(*grid_avals, *scalar_ref_avals),
in_tree=index_map_in_tree,
grid=grid_mapping_grid,
mapped_dims=(),
what="input",
),
flat_in_specs,
in_paths[num_flat_scalar_prefetch:],
in_ref_avals,
)
out_block_mappings = map(
partial(
_convert_block_spec_to_block_mapping,
(*grid_avals, *scalar_ref_avals),
in_tree=index_map_in_tree,
grid=grid_mapping_grid,
mapped_dims=(),
what="output",
),
flat_out_specs,
out_paths,
out_ref_avals,
)
grid_mapping = GridMapping(
grid=grid_mapping_grid, grid_names=self.grid_names, # type: ignore
block_mappings=(*in_block_mappings, *out_block_mappings),
mapped_dims=(),
num_index_operands=num_flat_scalar_prefetch,
num_scratch_operands=len(flat_scratch_avals)
)
jaxpr_scalar_ref_avals = tree_util.tree_unflatten(
scalar_tree, scalar_ref_avals)
jaxpr_in_ref_avals = tree_util.tree_unflatten(in_avals_tree, in_ref_avals)
jaxpr_scratch_avals = tree_util.tree_unflatten(
scratch_tree, flat_scratch_avals)
if not isinstance(jaxpr_scratch_avals, (tuple, list)):
jaxpr_scratch_avals = (jaxpr_scratch_avals,)
jaxpr_in_avals = (*jaxpr_scalar_ref_avals,
*jaxpr_in_ref_avals)
jaxpr_out_avals = tree_util.tree_unflatten(out_tree, out_ref_avals)
if not isinstance(jaxpr_out_avals, (tuple, list)):
jaxpr_out_avals = (jaxpr_out_avals,)
return (*jaxpr_in_avals, *jaxpr_out_avals,
*jaxpr_scratch_avals), grid_mapping
return super().get_grid_mapping(in_avals, in_tree, in_paths,
out_avals, out_tree, out_paths,
num_scalar_prefetch=self.num_scalar_prefetch,
scratch_shapes=self.scratch_shapes)
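A hedged usage sketch (assuming the public aliases `pl` for jax.experimental.pallas
and `pltpu` for jax.experimental.pallas.tpu); the scalar-prefetch operands become
SMEM refs that both the kernel and every index map receive:

    import jax.numpy as jnp
    from jax.experimental import pallas as pl
    from jax.experimental.pallas import tpu as pltpu

    grid_spec = pltpu.PrefetchScalarGridSpec(
        num_scalar_prefetch=1,  # the first kernel argument is a scalar-prefetch SMEM ref
        grid=(8,),
        # Index maps are called with (*program_ids, *scalar_refs); here the
        # prefetched values pick which row-block of the input each program reads.
        in_specs=[pl.BlockSpec((128, 128), lambda i, s_ref: (s_ref[i], 0))],
        out_specs=pl.BlockSpec((128, 128), lambda i, s_ref: (i, 0)),
        scratch_shapes=[pltpu.VMEM((128, 128), jnp.float32)],
    )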
@dataclasses.dataclass(frozen=True)

View File

@ -294,7 +294,7 @@ class MosaicGridMapping:
self.grid_names = grid_mapping.grid_names
self.jaxpr = jaxpr
self.block_mappings = grid_mapping.block_mappings
self.mapped_dims = grid_mapping.mapped_dims
self.mapped_dims = grid_mapping.vmapped_dims
num_scalar_prefetch = grid_mapping.num_index_operands
num_scratch = grid_mapping.num_scratch_operands
# jaxpr has signature [*scalar_prefetch, *consts, *in_ops, *out_ops, *scratch]
@ -425,13 +425,15 @@ class MeshInfo:
def lower_jaxpr_to_module(
ctx: ir.Context,
grid_mapping: pl_core.GridMapping,
in_shapes: tuple[jax.ShapeDtypeStruct, ...],
out_shapes: tuple[jax.ShapeDtypeStruct, ...],
jaxpr: jax_core.Jaxpr,
dimension_semantics: tuple[str | None, ...] | None,
mesh: mesh_lib.Mesh | None = None,
for_verification: bool = False,
) -> tuple[Module, tuple[Any, ...]]:
# TODO(necula): cleanup
in_shapes = grid_mapping.in_shapes
out_shapes = grid_mapping.out_shapes
mosaic_grid_mapping = MosaicGridMapping(
jaxpr, grid_mapping, dimension_semantics, mesh)
mosaic_grid_mapping.maybe_compress_grid()
@ -454,10 +456,8 @@ def lower_jaxpr_to_module(
invars = invars[grid_mapping.num_index_operands:]
# invars now = *consts, *ins, *outs
avals = tuple(v.aval for v in invars)
# TODO(necula): we should not need block_operand_shapes anymore
block_operand_shapes = (
*[jax.ShapeDtypeStruct(v.aval.shape,
v.aval.dtype)
for v in invars[:grid_mapping.num_constant_operands]],
*in_shapes[grid_mapping.num_index_operands:],
*out_shapes,
)
@ -466,10 +466,6 @@ def lower_jaxpr_to_module(
zip(block_operand_shapes, grid_mapping.block_mappings, avals)
):
func_name = f"transform_{i}"
if bm is None:
raise NotImplementedError(
"BlockSpecs are required on TPU when grid is specified"
)
# ANY operands don't support windowing and require empty window_params.
if aval.memory_space == tpu_core.TPUMemorySpace.ANY:
# We may not require windowing if our block_shape matches the original

View File

@ -32,7 +32,6 @@ from jax._src.lib.mlir import ir
from jax._src.pallas import core
from jax._src.pallas.mosaic import lowering
from jax._src.pallas.mosaic import verification
from jax._src.pallas.pallas_call import pallas_call_p
from jax.experimental import mosaic
from jax.experimental.mosaic.dialects import tpu
@ -69,21 +68,13 @@ def pallas_call_tpu_lowering_rule(
name: str,
grid_mapping: core.GridMapping,
input_output_aliases: tuple[tuple[int, int], ...],
in_shapes: tuple[jax.ShapeDtypeStruct, ...],
out_shapes: tuple[jax.ShapeDtypeStruct, ...],
debug: bool,
interpret: bool,
compiler_params: dict[str, Any]):
"""Lowers a pallas_call to a Mosaic TPU custom call."""
if interpret:
# TODO(necula): is this branch still needed?
return mlir.lower_fun(pallas_call_p.impl, multiple_results=True)(
ctx, *in_nodes, jaxpr=jaxpr, name=name, out_shapes=out_shapes,
in_shapes=in_shapes,
interpret=interpret, debug=debug,
input_output_aliases=input_output_aliases,
grid_mapping=grid_mapping,
compiler_params=compiler_params)
del interpret
# TODO(necula): cleanup
out_shapes = grid_mapping.out_shapes
if debug:
print(jaxpr)
if "mosaic_params" in compiler_params:
@ -114,7 +105,7 @@ def pallas_call_tpu_lowering_rule(
with mlir_ctx, ir.Location.unknown(mlir_ctx):
dimension_semantics = mosaic_params.get("dimension_semantics", None)
return lowering.lower_jaxpr_to_module(
mlir_ctx, grid_mapping, in_shapes, out_shapes, jaxpr,
mlir_ctx, grid_mapping, jaxpr,
dimension_semantics=dimension_semantics, mesh=mesh,
for_verification=for_verification)
mosaic_module, extra_args = lower_module(for_verification=False)

View File

@ -155,14 +155,14 @@ class LoweringError(Exception):
def lower_jaxpr_to_module(
grid_mapping: pl_core.GridMapping,
in_structs: tuple[jax.ShapeDtypeStruct, ...],
out_structs: tuple[jax.ShapeDtypeStruct, ...],
jaxpr: jax_core.Jaxpr,
name: str,
compiler_params: dict[str, Any],
) -> LoweringResult:
in_structs = grid_mapping.in_shapes
out_structs = grid_mapping.out_shapes
assert len(jaxpr.outvars) == 0
assert not grid_mapping.mapped_dims
assert not grid_mapping.vmapped_dims
grid = grid_mapping.grid
if len(grid) < 3:
grid += (1,) * (3 - len(grid))

View File

@ -19,12 +19,10 @@ from __future__ import annotations
from typing import Any
import jax
from jax import core as jax_core
from jax._src.interpreters import mlir
from jax._src.pallas import core as pallas_core
from jax._src.pallas.mosaic_gpu import lowering
from jax._src.pallas.pallas_call import pallas_call_p
from jax.experimental.mosaic import gpu as mosaic_gpu
@ -33,30 +31,13 @@ def pallas_call_lowering(
*args,
jaxpr: jax_core.Jaxpr,
name: str,
in_shapes: tuple[jax.ShapeDtypeStruct, ...],
out_shapes: tuple[jax.ShapeDtypeStruct, ...],
interpret: bool,
debug: bool,
input_output_aliases: tuple[tuple[int, int], ...],
grid_mapping: pallas_core.GridMapping,
compiler_params: dict[str, Any],
):
if interpret:
# TODO(necula): is this still needed?
return mlir.lower_fun(pallas_call_p.impl, multiple_results=True)(
ctx,
*args,
jaxpr=jaxpr,
name=name,
out_shapes=out_shapes,
in_shapes=in_shapes,
interpret=interpret,
debug=debug,
input_output_aliases=input_output_aliases,
grid_mapping=grid_mapping,
compiler_params=compiler_params,
)
del interpret
if grid_mapping.num_dynamic_grid_bounds:
raise NotImplementedError(
"dynamic grid bounds not supported in the Mosaic GPU backend"
@ -75,8 +56,6 @@ def pallas_call_lowering(
)
lowering_result = lowering.lower_jaxpr_to_module(
grid_mapping,
in_shapes,
out_shapes,
jaxpr,
name,
compiler_params,

View File

@ -61,6 +61,7 @@ BlockSpecTree = pallas_core.BlockSpecTree
NoBlockSpec = pallas_core.NoBlockSpec
no_block_spec = pallas_core.no_block_spec
# See the docstring for GridMapping for the calling convention
pallas_call_p = jax_core.Primitive('pallas_call')
pallas_call_p.multiple_results = True
@ -103,8 +104,6 @@ def _pad_values_to_block_dimension(value,
Returns:
A padded array.
"""
if block_shape is None:
return value
padded_shape = tuple(
((v - 1) // b + 1) * b for v, b in zip(value.shape, block_shape)
)
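# e.g., with hypothetical sizes v=10, b=4: ((10 - 1) // 4 + 1) * 4 == 12, i.e.
# each dimension is rounded up to the next multiple of its block size.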
@ -166,12 +165,14 @@ def _pallas_call_impl(*args, **kwargs):
def _pallas_call_impl_interpret(
*args,
jaxpr: jax_core.Jaxpr,
name: str, in_shapes, out_shapes,
name: str,
debug: bool,
input_output_aliases: tuple[tuple[int, int], ...],
grid_mapping: GridMapping,
compiler_params: Any):
del compiler_params, name, in_shapes
del compiler_params, name
# TODO(necula): cleanup
out_shapes = grid_mapping.out_shapes
# If we're in interpreter mode, we *scan* over the grid and eval the
# discharged jaxpr.
dynamic_grid_args, args = split_list( # type: ignore
@ -192,21 +193,13 @@ def _pallas_call_impl_interpret(
scalars, args = split_list(args, [grid_mapping.num_index_operands]) # type: ignore
# invars: [*scalar_prefetch, *consts, *inputs, *outputs, *scratch]
# args now contains: *consts, *inputs, *outputs
num_invars = len(jaxpr.invars)
num_inputs_outputs = (
num_invars
- grid_mapping.num_index_operands
- grid_mapping.num_scratch_operands
)
_, _, scratch_invars = split_list(
jaxpr.invars, [grid_mapping.num_index_operands, num_inputs_outputs]
)
scratch_invars = jaxpr.invars[grid_mapping.slice_scratch_ops]
scratch_avals = [v.aval for v in scratch_invars]
scratch_values = _initialize_scratch_vals(scratch_avals)
carry = []
for x, bm in zip(itertools.chain(args, out), grid_mapping.block_mappings):
if bm is not None and isinstance(bm.indexing_mode, pallas_core.Unblocked):
if isinstance(bm.indexing_mode, pallas_core.Unblocked):
padding = bm.indexing_mode.padding
if padding is not None and any(p != (0, 0) for p in padding):
if input_output_aliases:
@ -215,18 +208,14 @@ def _pallas_call_impl_interpret(
x = lax.pad(x, pad_value, [(*p, 0) for p in padding])
carry.append(x)
block_shapes_without_mapped_dims = [
None if block_mapping is None else block_mapping.block_shape
for block_mapping in grid_mapping.block_mappings
]
is_indexing_dim = [
None if bm is None else tuple(b is pallas_core.mapped for b in bm)
for bm in block_shapes_without_mapped_dims
tuple(b is pallas_core.mapped for b in bm.block_shape)
for bm in grid_mapping.block_mappings
]
block_shapes = [
None if (bm is None or iid is None)
else tuple(1 if i else b for i, b in zip(iid, bm))
for iid, bm in zip(is_indexing_dim, block_shapes_without_mapped_dims)
None if iid is None
else tuple(1 if i else b for i, b in zip(iid, bm.block_shape))
for iid, bm in zip(is_indexing_dim, grid_mapping.block_mappings)
]
# Pad values to evenly divide into block dimensions. This matches the
@ -254,7 +243,7 @@ def _pallas_call_impl_interpret(
local_grid_env = tuple(
pallas_core.GridAxis(idx, b)
for dim, (idx, b) in enumerate(zip(loop_idx, grid))
if dim not in grid_mapping.mapped_dims
if dim not in grid_mapping.vmapped_dims
)
carry, scratch = split_list(carry, [num_inout])
with pallas_core.grid_env(local_grid_env):
@ -288,7 +277,7 @@ def _pallas_call_impl_interpret(
out_block_mappings = grid_mapping.block_mappings[len(args):]
out_nopad = []
for o, expected_o_shape, bm in zip(out, out_shapes, out_block_mappings):
if bm is not None and isinstance(bm.indexing_mode, pallas_core.Unblocked):
if isinstance(bm.indexing_mode, pallas_core.Unblocked):
padding = bm.indexing_mode.padding
if padding is not None and any(p != (0, 0) for p in padding):
if input_output_aliases:
@ -303,13 +292,16 @@ def _pallas_call_impl_interpret(
pallas_call_p.def_impl(_pallas_call_impl)
def _pallas_call_abstract_eval(*avals, out_shapes, **_):
def _pallas_call_abstract_eval(*avals, grid_mapping, **_):
out_shapes = grid_mapping.out_shapes
return map(lambda x: jax_core.ShapedArray(x.shape, x.dtype), out_shapes)
pallas_call_p.def_abstract_eval(_pallas_call_abstract_eval)
def _pallas_call_jvp_rule(primals, tangents, *, jaxpr, name,
input_output_aliases: tuple[tuple[int, int], ...],
in_shapes, out_shapes, grid_mapping, debug, interpret, compiler_params: Any):
grid_mapping, debug, interpret, compiler_params: Any):
# TODO(necula): cleanup
out_shapes = grid_mapping.out_shapes
if grid_mapping.num_dynamic_grid_bounds:
raise NotImplementedError("interpret with dynamic grid bounds unsupported")
if grid_mapping.num_index_operands:
@ -345,14 +337,17 @@ def _pallas_call_jvp_rule(primals, tangents, *, jaxpr, name,
print(jvp_jaxpr)
in_bms, out_bms = split_list(grid_mapping.block_mappings, [len(primals)])
jvp_bms = (*in_bms, *in_bms, *out_bms, *out_bms)
jvp_grid_mapping = grid_mapping.replace(
block_mappings=jvp_bms,
num_inputs=grid_mapping.num_inputs * 2,
num_outputs=grid_mapping.num_outputs * 2,
)
out_flat = pallas_call_p.bind(
*primals,
*tangents,
jaxpr=jvp_jaxpr,
name=f"{name}_jvp",
in_shapes=(*in_shapes, *in_shapes),
out_shapes=(*out_shapes, *out_shapes),
grid_mapping=grid_mapping.replace(block_mappings=jvp_bms),
grid_mapping=jvp_grid_mapping,
interpret=interpret,
debug=debug,
input_output_aliases=(),
@ -362,40 +357,38 @@ def _pallas_call_jvp_rule(primals, tangents, *, jaxpr, name,
return out_primals, out_tangents
ad.primitive_jvps[pallas_call_p] = _pallas_call_jvp_rule
def _batch_block_mapping(grid_mapping: GridMapping, aval: jax_core.ShapedArray,
def _batch_block_mapping(grid_mapping: GridMapping,
axis_size: int,
aval: jax_core.ShapedArray,
dim: int | batching.NotMapped,
block_mapping: BlockMapping | None) -> BlockMapping:
block_mapping: BlockMapping) -> BlockMapping:
def _block_map_function(new_idx, *args):
if block_mapping is None:
indices = [0] * len(aval.shape)
else:
indices = jax_core.eval_jaxpr(block_mapping.index_map_jaxpr.jaxpr,
block_mapping.index_map_jaxpr.consts,
*args)
indices = jax_core.eval_jaxpr(block_mapping.index_map_jaxpr.jaxpr,
block_mapping.index_map_jaxpr.consts,
*args)
if dim is not batching.not_mapped:
indices.insert(dim, new_idx)
return tuple(indices)
i32_aval = jax_core.ShapedArray((), jnp.int32)
if block_mapping is None:
idx_avals = [i32_aval] * (len(grid_mapping.grid) + 1)
else:
idx_avals = [i32_aval, *block_mapping.index_map_jaxpr.in_avals]
idx_avals = [i32_aval, *block_mapping.index_map_jaxpr.in_avals]
with grid_mapping.trace_env():
block_mapping_jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(
lu.wrap_init(_block_map_function), idx_avals)
shape = aval.shape if block_mapping is None else block_mapping.block_shape
shape = block_mapping.block_shape
if dim is batching.not_mapped:
new_block_shape = shape
new_array_shape_dtype = block_mapping.array_shape_dtype
else:
new_block_shape = tuple_insert(shape, dim, pallas_core.mapped)
new_array_shape_dtype = jax.ShapeDtypeStruct(
tuple_insert(block_mapping.array_shape_dtype.shape,
dim,
axis_size),
block_mapping.array_shape_dtype.dtype)
jaxpr = jax_core.ClosedJaxpr(block_mapping_jaxpr, consts)
if block_mapping is None:
return BlockMapping(
block_shape=new_block_shape,
index_map_jaxpr=jaxpr,
indexing_mode=pallas_core.blocked,
)
return block_mapping.replace(block_shape=new_block_shape,
array_shape_dtype=new_array_shape_dtype,
index_map_jaxpr=jaxpr)
@ -434,8 +427,6 @@ def _batch_with_explicit_loop(
*,
jaxpr: jax_core.Jaxpr,
name: str,
in_shapes: tuple[jax.ShapeDtypeStruct, ...],
out_shapes: tuple[jax.ShapeDtypeStruct, ...],
grid_mapping: GridMapping,
input_output_aliases: tuple[tuple[int, int], ...],
debug: bool,
@ -453,7 +444,8 @@ def _batch_with_explicit_loop(
to the current iteration index and dynamic_updates an (initially empty) output
allocation.
"""
# TODO(necula): cleanup
out_shapes = grid_mapping.out_shapes
if not dims:
raise NotImplementedError("vmapping pallas_call with no arguments.")
@ -499,13 +491,10 @@ def _batch_with_explicit_loop(
axis=dim,
)
)
batch_out = pallas_call_p.bind(
*batch_args,
jaxpr=jaxpr,
name=name,
in_shapes=in_shapes,
out_shapes=out_shapes,
grid_mapping=grid_mapping,
input_output_aliases=input_output_aliases,
debug=debug,
@ -533,15 +522,14 @@ def _pallas_call_batching_rule(
*,
jaxpr: jax_core.Jaxpr,
name: str,
in_shapes: tuple[jax.ShapeDtypeStruct, ...],
out_shapes: tuple[jax.ShapeDtypeStruct, ...],
grid_mapping: GridMapping,
input_output_aliases: tuple[tuple[int, int], ...],
debug: bool,
interpret: bool,
compiler_params: Any,
):
# TODO(necula): cleanup
out_shapes = grid_mapping.out_shapes
def _maybe_squeeze_out_bdim(
x: jax.Array, bdim: int | batching.NotMapped
) -> jax.Array:
@ -558,8 +546,6 @@ def _pallas_call_batching_rule(
*args,
jaxpr=jaxpr,
name=name,
in_shapes=in_shapes,
out_shapes=out_shapes,
grid_mapping=grid_mapping,
input_output_aliases=input_output_aliases,
debug=debug,
@ -591,8 +577,6 @@ def _pallas_call_batching_rule(
dims=dynamic_grid_dims + dims,
jaxpr=jaxpr,
name=name,
in_shapes=in_shapes,
out_shapes=out_shapes,
grid_mapping=grid_mapping,
input_output_aliases=input_output_aliases,
debug=debug,
@ -625,8 +609,6 @@ def _pallas_call_batching_rule(
dims=scalar_bdims + bdims,
jaxpr=jaxpr,
name=name,
in_shapes=in_shapes,
out_shapes=out_shapes,
grid_mapping=grid_mapping,
input_output_aliases=input_output_aliases,
debug=debug,
@ -652,7 +634,6 @@ def _pallas_call_batching_rule(
all_dims = list(dims) + [0] * len(out_shapes)
num_index_operands = grid_mapping.num_index_operands
num_constant_operands = grid_mapping.num_constant_operands
num_scratch_operands = grid_mapping.num_scratch_operands
# Only add a batch dimension for the avals that actually have a grid mapping.
@ -660,37 +641,27 @@ def _pallas_call_batching_rule(
# operands (the last in the list).
avals_to_batch = avals[num_index_operands:(len(avals) - num_scratch_operands)]
batched_block_mappings = map(
partial(_batch_block_mapping, grid_mapping),
partial(_batch_block_mapping, grid_mapping, axis_size),
avals_to_batch,
all_dims[num_index_operands:],
block_mappings,
)
# TODO(necula): should fix in_shapes to include the consts
dims_no_consts = (
dims[:num_index_operands] +
dims[num_index_operands + num_constant_operands:]
)
batched_in_shapes = tuple(
jax.ShapeDtypeStruct(x.shape if dim is batching.not_mapped else
tuple_insert(x.shape, dim, axis_size),
x.dtype)
for x, dim in zip(in_shapes, dims_no_consts))
batched_out_shapes = tuple(
jax.ShapeDtypeStruct(tuple_insert(x.shape, 0, axis_size), x.dtype)
for x in out_shapes)
index_map_tree_args, index_map_tree_kwargs = grid_mapping.index_map_tree.unflatten(grid_mapping.index_map_avals)
assert not index_map_tree_kwargs
batched_index_map_args = (jax_core.ShapedArray((), jnp.int32),) + index_map_tree_args
batched_index_map_avals, batched_index_map_tree = tree_util.tree_flatten((batched_index_map_args, {}))
batched_grid_mapping = grid_mapping.replace(
grid=(axis_size, *grid_mapping.grid),
block_mappings=tuple(batched_block_mappings),
mapped_dims=(0,) + tuple(a + 1 for a in grid_mapping.mapped_dims))
index_map_avals=batched_index_map_avals,
index_map_tree=batched_index_map_tree,
vmapped_dims=(0,) + tuple(a + 1 for a in grid_mapping.vmapped_dims))
out = pallas_call_p.bind(
*dynamic_grid_args,
*args,
jaxpr=jaxpr,
name=f"batched_{name}",
in_shapes=batched_in_shapes,
out_shapes=batched_out_shapes,
grid_mapping=batched_grid_mapping,
input_output_aliases=input_output_aliases,
debug=debug,
@ -725,7 +696,6 @@ def pallas_call_checkify_rule(error: checkify.Error,
interpret: bool,
input_output_aliases: tuple[tuple[int, int], ...],
grid_mapping: GridMapping,
out_shapes,
**kwargs):
# We implement the checkify rule in 4 steps:
# 1) First, trace the kernel body to get the expected error shapes.
@ -831,37 +801,25 @@ def pallas_call_checkify_rule(error: checkify.Error,
# Prepare pallas_call inputs. We need to create new block specs
# for the new error inputs and outputs.
scalar_avals = map(checkify.get_shaped_aval, scalars)
error_block_specs = [pallas_core.BlockSpec(
index_map=lambda *args: (0,) * len(error.shape),
block_shape=error.shape)
for error in shaped_err_avals]
error_block_specs = [pallas_core.BlockSpec(None, None)] * len(shaped_err_avals)
error_paths, _ = unzip2(tree_util.tree_flatten_with_path(error_block_specs)[0])
grid_avals = [
jax_core.ShapedArray((), jnp.dtype("int32"))] * len(grid_mapping.grid)
scalar_ref_avals = [
pallas_core.AbstractMemoryRef(
jax_core.ShapedArray(aval.shape, aval.dtype),
pallas_core.MemorySpace.INDEX)
for aval in scalar_avals]
grid_tree = tree_util.tree_structure(((*grid_avals, *scalar_avals), {}))
error_block_mappings = map(
partial(
pallas_core._convert_block_spec_to_block_mapping,
(*grid_avals, *scalar_ref_avals),
in_tree=grid_tree,
index_map_avals=grid_mapping.index_map_avals,
index_map_tree=grid_mapping.index_map_tree,
grid=grid_mapping.grid,
mapped_dims=grid_mapping.mapped_dims,
mapped_dims=grid_mapping.vmapped_dims,
what="error"),
error_block_specs, error_paths, error_memref_aval)
error_block_specs, error_paths, shaped_err_avals)
input_block_mappings, output_block_mappings = split_list(
grid_mapping.block_mappings, [num_kernel_inputs,])
grid_mapping_with_error = grid_mapping.replace(
block_mappings=(*error_block_mappings, *input_block_mappings,
*error_block_mappings, *output_block_mappings)
*error_block_mappings, *output_block_mappings),
num_inputs=grid_mapping.num_inputs + len(error_block_mappings),
num_outputs=grid_mapping.num_outputs + len(error_block_mappings)
)
error_out_shapes = tuple(
jax.ShapeDtypeStruct(e.shape, e.dtype) for e in shaped_err_avals)
# Bump all input_output_aliases by num_err_vals to make room for error
# TODO(justinfu): Don't bump scalars here.
input_output_aliases = tuple(
@ -870,17 +828,11 @@ def pallas_call_checkify_rule(error: checkify.Error,
(i+num_scalars, i) for i in range(num_err_vals)) + input_output_aliases
new_vals_in = [*scalars, *err_vals, *args]
new_input_shapes = tuple(
jax.ShapeDtypeStruct(x.shape, x.dtype) for x in [
*scalars, *shaped_err_avals, *args])
del kwargs['in_shapes']
result = pallas_call_p.bind(*dynamic_grid_bounds, *new_vals_in,
jaxpr=final_jaxpr,
interpret=interpret,
grid_mapping=grid_mapping_with_error,
input_output_aliases=input_output_aliases_with_error,
in_shapes=new_input_shapes,
out_shapes=error_out_shapes + out_shapes,
**kwargs)
errors, results = split_list(result, [num_err_vals])
# TODO(b/350593266): Remove line below once we support ()-shaped scalars.
@ -890,25 +842,20 @@ def pallas_call_checkify_rule(error: checkify.Error,
checkify.error_checks[pallas_call_p] = pallas_call_checkify_rule
@weakref_lru_cache
def _trace_to_jaxpr(fun: Callable, grid_spec: GridSpec,
flat_in_avals: Sequence[jax_core.AbstractValue],
flat_out_avals: Sequence[jax_core.AbstractValue],
in_tree: tree_util.PyTreeDef,
in_paths: Sequence[tree_util.KeyPath],
out_tree: tree_util.PyTreeDef,
out_paths: Sequence[tree_util.KeyPath],
interpret: bool):
avals, grid_mapping = grid_spec.get_grid_mapping(flat_in_avals, in_tree, in_paths,
flat_out_avals, out_tree, out_paths)
def _trace_kernel_to_jaxpr(fun: Callable,
grid_mapping: GridMapping,
kernel_avals: tuple[pallas_core.AbstractMemoryRef, ...],
kernel_in_tree: tree_util.PyTreeDef,
interpret: bool):
if interpret:
avals = jax.tree_util.tree_map(_logical_aval_to_interpret_mode_aval, avals)
jaxpr_flat_avals, jaxpr_in_tree = tree_util.tree_flatten(avals)
wrapped_fun, out_tree_thunk = api_util.flatten_fun_nokwargs(
lu.wrap_init(fun), jaxpr_in_tree)
debug = pe.debug_info(fun, jaxpr_in_tree, out_tree_thunk, False, "pallas_call")
kernel_avals = tuple(map(_logical_aval_to_interpret_mode_aval,
kernel_avals))
wrapped_kernel_fun, out_tree_thunk = api_util.flatten_fun_nokwargs(
lu.wrap_init(fun), kernel_in_tree)
debug = pe.debug_info(fun, kernel_in_tree, out_tree_thunk, False, "pallas_call")
with grid_mapping.trace_env():
jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(wrapped_fun,
jaxpr_flat_avals, debug)
jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(wrapped_kernel_fun,
kernel_avals, debug)
if consts:
# Pad ``block_mappings`` to account for the hoisted constants.
# The constants will be right after the index operands and just before
@ -918,20 +865,14 @@ def _trace_to_jaxpr(fun: Callable, grid_spec: GridSpec,
index=grid_mapping.num_index_operands,
make_abstract_ref=lambda aval: pallas_core.AbstractMemoryRef(aval, None))
num_constant_operands = len(consts)
# TODO(necula): refactor grid_mapping to remove this code duplication
grid_avals = [jax_core.ShapedArray((), jnp.dtype("int32"))] * len(grid_mapping.grid)
if grid_mapping.num_index_operands:
grid_avals += flat_in_avals[:grid_mapping.num_index_operands] # type: ignore
# Create args, kwargs pytree def
grid_tree = tree_util.tree_structure((tuple(grid_avals), {}))
const_block_mappings = []
for c_idx, c in enumerate(consts):
const_block_mapping = pallas_core._convert_block_spec_to_block_mapping(
grid_avals,
pallas_core.BlockSpec(None, None),
path=(tree_util.SequenceKey(c_idx),),
aval=jax_core.ShapedArray(c.shape, c.dtype),
in_tree=grid_tree,
array_aval=jax_core.ShapedArray(c.shape, c.dtype),
index_map_avals=grid_mapping.index_map_avals,
index_map_tree=grid_mapping.index_map_tree,
grid=grid_mapping.grid,
mapped_dims=(),
what="consts",
@ -942,7 +883,12 @@ def _trace_to_jaxpr(fun: Callable, grid_spec: GridSpec,
block_mappings=(*const_block_mappings, *grid_mapping.block_mappings),
num_constant_operands=num_constant_operands,
)
return grid_mapping, jaxpr, consts, out_tree_thunk()
kernel_out_tree = out_tree_thunk()
if kernel_out_tree != tree_util.tree_structure(None):
raise ValueError(
"The kernel function in a pallas_call should return None. "
f"Found a PyTree: {kernel_out_tree}")
return grid_mapping, jaxpr, consts
def _extract_function_name(f: Callable, name: str | None) -> str:
if name is None:
@ -1044,7 +990,7 @@ def pallas_call(
*,
grid_spec: GridSpec | None = None,
debug: bool = False,
grid: Grid | None = None,
grid: Grid = (),
in_specs: BlockSpecTree = no_block_spec,
out_specs: BlockSpecTree = no_block_spec,
input_output_aliases: dict[int, int] = {},
@ -1066,8 +1012,7 @@ def pallas_call(
debug: if True, Pallas prints various intermediate forms of the kernel
as it is being processed.
grid: the iteration space, as a tuple of integers. The kernel is executed
as many times as ``prod(grid)``. The default value ``None`` is equivalent
to ``()``.
as many times as ``prod(grid)``.
See details at :ref:`pallas_grid`.
in_specs: a PyTree of :class:`jax.experimental.pallas.BlockSpec` with
a structure matching that of the positional arguments.
@ -1097,10 +1042,23 @@ def pallas_call(
name = _extract_function_name(f, name)
if compiler_params is None:
compiler_params = {}
if grid is not None and grid_spec is not None:
raise ValueError("Cannot specify both grid and grid_spec at the same time.")
if grid_spec is None:
grid_spec = GridSpec(grid, in_specs, out_specs)
else:
if grid:
raise ValueError(
"If `grid_spec` is specified, then `grid` must "
f"be `()`. It is {grid}")
if in_specs is not no_block_spec:
raise ValueError(
"If `grid_spec` is specified, then `in_specs` must "
f"be `no_block_spec`. It is {in_specs}")
if out_specs is not no_block_spec:
raise ValueError(
"If `grid_spec` is specified, then `out_specs` must "
f"be `no_block_spec`. It is {out_specs}")
del grid, in_specs, out_specs
grid_spec, dynamic_grid_bounds = grid_spec.unzip_dynamic_grid_bounds()
# TODO(necula): this canonicalization may be convenient for some usage
# but it is lossy, because it prevents expressing functions that return
@ -1119,14 +1077,14 @@ def pallas_call(
for a in flat_args)
flat_out_avals = tuple(jax_core.ShapedArray(v.shape, v.dtype)
for v in flat_out_shapes)
# TODO(necula): check that input_output_aliases is well-formed: shapes match, no duplicates, etc.
grid_mapping, jaxpr, consts, f_out_tree = _trace_to_jaxpr(
f, grid_spec, flat_in_avals, flat_out_avals, in_tree, in_paths,
out_tree, out_paths, interpret=interpret)
if f_out_tree != tree_util.tree_flatten(None)[1]:
raise ValueError(
"The kernel function in a pallas_call should return None. "
f"Found a PyTree: {f_out_tree}")
# TODO(necula): check that input_output_aliases is well-formed: no duplicates, etc.
kernel_avals, grid_mapping = grid_spec.get_grid_mapping(
flat_in_avals, in_tree, in_paths,
flat_out_avals, out_tree, out_paths)
flat_kernel_avals, kernel_in_tree = tree_util.tree_flatten(kernel_avals)
grid_mapping, jaxpr, consts = _trace_kernel_to_jaxpr(
f, grid_mapping, tuple(flat_kernel_avals), kernel_in_tree,
interpret=interpret)
for i_idx, o_idx in input_output_aliases.items():
if i_idx not in range(len(flat_in_avals)):
raise ValueError(
@ -1152,9 +1110,7 @@ def pallas_call(
out_flat = pallas_call_p.bind(
*dynamic_grid_bounds, *index_args, *consts, *rest_args,
jaxpr=jaxpr, name=name,
in_shapes=tuple(jax.ShapeDtypeStruct(a.shape, a.dtype)
for a in flat_args),
out_shapes=tuple(flat_out_shapes), debug=debug,
debug=debug,
interpret=interpret,
grid_mapping=grid_mapping,
input_output_aliases=tuple(input_output_aliases.items()),

View File

@ -112,10 +112,8 @@ class LoweringError(Exception):
def _eval_index_map(
ctx: ModuleContext, idx, block_mapping: BlockMapping | None
ctx: ModuleContext, idx, block_mapping: BlockMapping
):
if block_mapping is None:
return None
block_indices = lower_jaxpr_to_triton_ir(
ctx, block_mapping.index_map_jaxpr.jaxpr, None, *idx
)
@ -187,13 +185,13 @@ def _process_grid_to_3d_grid(grid_mapping: GridMapping):
# Preserve grid order provided to pallas_call
for i, s in enumerate(grid_mapping.grid):
if i not in grid_mapping.mapped_dims:
if i not in grid_mapping.vmapped_dims:
launch_grid.append(s)
launch_grid_to_pallas_grid.append(i)
# For mapped dims, iterate from inner to outer. This follows the pallas_call
# batching rule that prepends the vmapped dimension.
for dim in reversed(grid_mapping.mapped_dims):
for dim in reversed(grid_mapping.vmapped_dims):
s = grid_mapping.grid[dim]
launch_grid.append(s)
launch_grid_to_pallas_grid.append(dim)
@ -287,7 +285,7 @@ def lower_jaxpr_to_triton_module(
local_program_ids = [
pid
for i, pid in enumerate(program_ids)
if i not in grid_mapping.mapped_dims
if i not in grid_mapping.vmapped_dims
]
ctx = ModuleContext(
name, grid_mapping, local_program_ids, mlir.TracebackCaches(), platform
@ -297,7 +295,7 @@ def lower_jaxpr_to_triton_module(
"Scalar prefetch not supported in Triton lowering."
)
for bm in grid_mapping.block_mappings:
if bm is not None and not isinstance(bm.indexing_mode, Blocked):
if not isinstance(bm.indexing_mode, Blocked):
raise NotImplementedError(
"Only Blocked indexing mode is supported in Triton lowering."
)
@ -305,20 +303,14 @@ def lower_jaxpr_to_triton_module(
functools.partial(_eval_index_map, ctx, program_ids),
grid_mapping.block_mappings,
)
consts_shapes = [
jax.ShapeDtypeStruct(v.aval.shape, v.aval.dtype)
for v in jaxpr.invars[grid_mapping.num_index_operands:grid_mapping.num_index_operands + grid_mapping.num_constant_operands]
]
block_infos = [
BlockInfo(
jax.ShapeDtypeStruct(shape_dtype.shape, shape_dtype.dtype),
start_idx,
block_mapping.block_shape,
)
if block_mapping is not None
else None
for shape_dtype, block_mapping, start_idx in zip(
(*consts_shapes, *in_out_shapes),
in_out_shapes,
grid_mapping.block_mappings,
start_indices,
)

View File

@ -19,12 +19,10 @@ from __future__ import annotations
import io
from typing import Any
import jax
from jax import core as jax_core
from jax._src.interpreters import mlir
from jax._src.lib.mlir import ir
from jax._src.pallas import core as pallas_core
from jax._src.pallas.pallas_call import pallas_call_p
from jax._src.pallas.triton import lowering
@ -45,30 +43,16 @@ def pallas_call_lowering(
*in_nodes,
jaxpr: jax_core.Jaxpr,
name: str,
in_shapes: tuple[jax.ShapeDtypeStruct, ...],
out_shapes: tuple[jax.ShapeDtypeStruct, ...],
interpret: bool,
debug: bool,
input_output_aliases: tuple[tuple[int, int], ...],
grid_mapping: pallas_core.GridMapping,
compiler_params: dict[str, Any],
):
if interpret:
# TODO(necula): is this branch still needed?
return mlir.lower_fun(pallas_call_p.impl, multiple_results=True)(
ctx,
*in_nodes,
jaxpr=jaxpr,
name=name,
out_shapes=out_shapes,
in_shapes=in_shapes,
interpret=interpret,
debug=debug,
input_output_aliases=input_output_aliases,
grid_mapping=grid_mapping,
compiler_params=compiler_params,
)
del interpret
# TODO(necula): cleanup
in_shapes = grid_mapping.in_shapes
out_shapes = grid_mapping.out_shapes
if grid_mapping.num_dynamic_grid_bounds:
raise NotImplementedError(
"dynamic grid bounds not supported in the Triton backend"

View File

@ -479,7 +479,6 @@ class CompatTest(bctu.CompatTestBase):
np.asarray(out), atol=1e-4, rtol=1e-4))
@jtu.parameterized_filterable(
one_containing="f32",
kwargs=[
dict(testcase_name=f"_dtype={dtype_name}", dtype_name=dtype_name)
for dtype_name in ("f32", "f64", "c64", "c128")])

View File

@ -26,7 +26,7 @@ from jax import random
from jax._src import config
from jax._src import test_util as jtu
from jax._src.lax.control_flow.for_loop import for_loop
from jax._src.pallas.pallas_call import _trace_to_jaxpr
from jax._src.pallas.pallas_call import _trace_kernel_to_jaxpr
from jax.experimental import pallas as pl
from jax.experimental.pallas.ops.gpu import attention
from jax.experimental.pallas.ops.gpu import layer_norm
@ -129,7 +129,7 @@ class PallasBaseTest(jtu.JaxTestCase):
self.skipTest("Only works on non-Windows platforms")
super().setUp()
_trace_to_jaxpr.cache_clear()
_trace_kernel_to_jaxpr.cache_clear()
def pallas_call(self, *args, **kwargs):
return pl.pallas_call(*args, **kwargs, interpret=self.INTERPRET)

View File

@ -27,7 +27,7 @@ from absl.testing import parameterized
import jax
from jax._src import config
from jax._src import test_util as jtu
from jax._src.pallas.pallas_call import _trace_to_jaxpr
from jax._src.pallas.pallas_call import _trace_kernel_to_jaxpr
import jax.numpy as jnp
from jax.experimental import pallas as pl
from jax import export
@ -93,7 +93,7 @@ class ShapePolyTest(jtu.JaxTestCase,
if sys.platform == "win32":
self.skipTest("Only works on non-Windows platforms")
super().setUp()
_trace_to_jaxpr.cache_clear()
_trace_kernel_to_jaxpr.cache_clear()
def test_copy(self):
# The blocks are static, but the input and the grid are of polymorphic

View File

@ -32,7 +32,7 @@ from jax._src import config
from jax._src import dtypes
from jax._src import test_util as jtu
from jax._src.lax.control_flow.for_loop import for_loop
from jax._src.pallas.pallas_call import _trace_to_jaxpr
from jax._src.pallas.pallas_call import _trace_kernel_to_jaxpr
from jax.experimental import pallas as pl
import jax.numpy as jnp
import numpy as np
@ -145,14 +145,13 @@ class PallasBaseTest(jtu.JaxTestCase):
self.skipTest("Only works on non-Windows platforms")
super().setUp()
_trace_to_jaxpr.cache_clear()
_trace_kernel_to_jaxpr.cache_clear()
def pallas_call(self, *args, **kwargs):
return pl.pallas_call(*args, **kwargs, interpret=self.INTERPRET)
class PallasCallTest(PallasBaseTest):
def test_add_one(self):
if jtu.test_device_matches(["tpu"]) and not self.INTERPRET:
self.skipTest("On TPU the test works only in interpret mode")
@ -655,7 +654,6 @@ class PallasCallInterpreterTest(PallasCallTest):
class ApiErrorTest(PallasBaseTest):
def test_pallas_kernel_args_mismatch(self):
a = np.arange(256, dtype=np.int32)
f = self.pallas_call(lambda x_ref: None, # Missing o_ref
@ -719,7 +717,7 @@ class ApiErrorTest(PallasBaseTest):
in_specs=[pl.BlockSpec((4,), lambda: (0, 0))])
with self.assertRaisesRegex(
ValueError,
"Index map for input\\[0\\] must return 1 values to match .*Currently returning 2 values."):
"Index map for inputs\\[0\\] must return 1 values to match .*Currently returning 2 values."):
f(a)
def test_pallas_call_index_map_captures_consts(self):
@ -730,7 +728,7 @@ class ApiErrorTest(PallasBaseTest):
in_specs=[pl.BlockSpec((4,), lambda: index_map_result)])
with self.assertRaisesRegex(
NotImplementedError,
"Index map for input\\[0\\] captures constants"):
"Index map for inputs\\[0\\] captures constants"):
f(a)
def test_pallas_call_out_specs_mismatch_shape(self):
@ -752,7 +750,7 @@ class ApiErrorTest(PallasBaseTest):
in_specs=[pl.BlockSpec((1, 1), lambda: (0, 0))])
with self.assertRaisesRegex(
ValueError,
"Block shape for input\\[0\\] .* must have the same number of dimensions as the "
"Block shape for inputs\\[0\\] .* must have the same number of dimensions as the "
"array shape"):
f(a)
@ -762,7 +760,7 @@ class ApiErrorTest(PallasBaseTest):
out_specs=[pl.BlockSpec((1, 1), lambda: 0)])
with self.assertRaisesRegex(
ValueError,
"Block shape for output\\[0\\] .* must have the same number of dimensions as the "
"Block shape for outputs\\[0\\] .* must have the same number of dimensions as the "
"array shape"):
f(a)
@ -1893,7 +1891,6 @@ class PallasCheckifyInterpreterTest(PallasBaseTest):
class PallasCallNamedGridTest(PallasBaseTest):
def test_named_grid(self):
def kernel(x_ref, y_ref):

View File

@ -24,7 +24,7 @@ from jax import random
from jax._src.lib import xla_extension
from jax._src import config
from jax._src import test_util as jtu
from jax._src.pallas.pallas_call import _trace_to_jaxpr
from jax._src.pallas.pallas_call import _trace_kernel_to_jaxpr
from jax.experimental import pallas as pl
import jax.numpy as jnp
import numpy as np
@ -52,7 +52,7 @@ class PallasBaseTest(jtu.JaxTestCase):
self.skipTest("Only works on non-Windows platforms")
super().setUp()
_trace_to_jaxpr.cache_clear()
_trace_kernel_to_jaxpr.cache_clear()
def pallas_call(self, *args, **kwargs):
return pl.pallas_call(*args, **kwargs, interpret=self.INTERPRET)

View File

@ -29,7 +29,7 @@ from jax._src import state
from jax._src import test_util as jtu
from jax._src.interpreters import partial_eval as pe
from jax._src.lib import xla_extension
from jax._src.pallas.pallas_call import _trace_to_jaxpr
from jax._src.pallas.pallas_call import _trace_kernel_to_jaxpr
from jax.experimental import mesh_utils
from jax.experimental import mosaic
from jax.experimental import pallas as pl
@ -65,7 +65,7 @@ class PallasBaseTest(jtu.JaxTestCase):
if not jtu.test_device_matches(['tpu']) and not self.INTERPRET:
self.skipTest('Test requires TPUs, or interpret mode')
super().setUp()
_trace_to_jaxpr.cache_clear()
_trace_kernel_to_jaxpr.cache_clear()
def pallas_call(self, *args, **kwargs):
return pl.pallas_call(*args, **kwargs, interpret=self.INTERPRET)
@ -117,12 +117,17 @@ class PallasCallScalarPrefetchTest(PallasBaseTest):
@jtu.parameterized_filterable(
kwargs=[
dict(scratch=scratch, vmap=vmap)
dict(scratch=scratch, vmap=vmap, dyn_grid=dyn_grid)
for scratch in [True, False]
for vmap in [True, False]
for vmap in [False, True]
for dyn_grid in [False, True]
]
)
def test_scalar_prefetch_hoisted_const(self, *, scratch: bool, vmap: bool):
def test_scalar_prefetch_calling_convention(
self, *,
scratch: bool, vmap: bool, dyn_grid: bool):
# Tests that we correctly process all the various inputs and outputs:
# dynamic_grid_dims, index, inputs, outputs, scratch.
if jtu.test_device_matches(["cpu"]) and jax.config.x64_enabled:
self.skipTest("TODO: dslice(start, 1) raises error about slice inputs being int32 and int64")
# to_store will be hoisted as a constant. Choose shapes distinct from the inputs/outputs.
@ -133,30 +138,39 @@ class PallasCallScalarPrefetchTest(PallasBaseTest):
x_shape = (16, 128)
x = np.arange(math.prod(x_shape), dtype=np.float32).reshape(x_shape)
def f(x):
def f(x, grid_size):
s = jnp.array([1, 0], jnp.int32) # iteration 0 -> 1, iteration 1 -> 0
@functools.partial(
self.pallas_call,
out_shape=jax.ShapeDtypeStruct((64, 128), x.dtype),
grid_spec=pltpu.PrefetchScalarGridSpec(
num_scalar_prefetch=1,
grid=(2,),
num_scalar_prefetch=1, # 1 pytree
grid=(grid_size,),
in_specs=[pl.BlockSpec((8, 128),
lambda i, s_ref: (pl.load(s_ref, (i,)), 0))],
lambda i, s_ref: (pl.load(s_ref[0], (i,)), 0))],
out_specs=pl.BlockSpec((32, 128),
lambda i, s_ref: (pl.load(s_ref, i), 0)),
lambda i, s_ref: (pl.load(s_ref[0], i), 0)),
scratch_shapes=([pltpu.SemaphoreType.REGULAR((3,))] if scratch
else []),
),
)
def kernel(s_ref, src, dst, *scratch_refs):
def kernel(s_refs, src, dst, *scratch_refs):
s_ref, s2, s3 = s_refs
assert s_ref.shape == (2,)
assert s2.shape == (3,)
assert s3 is None
store_idx = s_ref[pl.program_id(0)]
pl.store(dst, (pl.dslice(store_idx, 1), slice(None)), to_store)
return kernel(s, x)
# Pass a pytree of scalars
return kernel((s, np.arange(3, dtype=np.int32), None), x)
if dyn_grid:
f = jax.jit(f)
if vmap:
f = jax.vmap(f)
res = f(x)
res = jax.vmap(lambda x: f(x, 2))(x)
else:
res = f(x, 2)
if vmap:
for i in range(x.shape[0]):
self.assertAllClose(res[i, 0:1], to_store)
@ -165,6 +179,35 @@ class PallasCallScalarPrefetchTest(PallasBaseTest):
self.assertAllClose(res[0:1], to_store)
self.assertAllClose(res[33:34], to_store)
def test_with_unhashable_grid_spec(self):
# Make sure that we don't crash when the GridSpec has non-hashable parts
@functools.partial(
self.pallas_call,
out_shape=[[jax.ShapeDtypeStruct((8, 128), np.int32)]],
grid_spec=pltpu.PrefetchScalarGridSpec(
num_scalar_prefetch=1, # 1 pytree
grid=(1,),
in_specs=[[pl.BlockSpec((8, 128),
lambda i, s_ref: (0, 0))]],
out_specs=[[pl.BlockSpec((8, 128),
lambda i, s_ref: (0, 0))]],
scratch_shapes=[[pltpu.SemaphoreType.REGULAR((3,))]],
),
)
def kernel(s_ref, x_ref, o_ref, scratch_ref):
assert isinstance(s_ref, list)
assert isinstance(x_ref, list)
assert isinstance(o_ref, list)
assert isinstance(scratch_ref, list)
o_ref[0][...] = x_ref[0][...]
x_shape = (8, 128)
s = np.array([0, 1], np.int32)
x = np.arange(math.prod(x_shape), dtype=np.int32).reshape(x_shape)
res = kernel([s, s], [x])
self.assertIsInstance(res, tuple) # Even though we asked for a list!
self.assertAllClose(res[0][0], x)
def test_block_spec_with_wrong_block_shape_errors(self):
def body(x_ref, o_ref):
o_ref[...] = x_ref[...]
@ -210,7 +253,7 @@ class PallasCallScalarPrefetchTest(PallasBaseTest):
x = jnp.ones((16, 128))
with self.assertRaisesRegex(
ValueError,
r'Index map for input\[0\] must return 2 values to match block_shape=\(8, 128\).'
r'Index map for inputs\[0\] must return 2 values to match block shape \(8, 128\).'
' Currently returning 1 values.'):
_ = self.pallas_call(
body,
@ -2356,7 +2399,6 @@ class PallasCallTPUBooleanInterpretTest(PallasCallTPUBooleanTest):
class PallasCallTPUCheckifyTest(PallasBaseTest):
@parameterized.parameters((2,), (5,), (6,), (7,))
def test_checkify_with_scalar_prefetch(self, threshold):
def body(scalar_ref, x_ref, o_ref):

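For readers of the test changes above, here is a minimal sketch (not part of this
commit) of the pytree-based scalar-prefetch calling convention that
test_scalar_prefetch_calling_convention exercises: num_scalar_prefetch=1 means the
kernel and every index_map receive a single pytree of scalar refs as their first
extra argument. The helper name double_blocks, the shapes, and the use of
interpret=True off-TPU are illustrative assumptions; the BlockSpec/index_map
pattern mirrors the test code in the diff.

import functools
import jax
import jax.numpy as jnp
import numpy as np
from jax.experimental import pallas as pl
from jax.experimental.pallas import tpu as pltpu

def double_blocks(x, block_order):
  # block_order is a scalar-prefetch array: block_order[i] is the block that
  # grid step i reads and writes.
  @functools.partial(
      pl.pallas_call,
      out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
      grid_spec=pltpu.PrefetchScalarGridSpec(
          num_scalar_prefetch=1,  # one pytree of scalar refs
          grid=(2,),
          in_specs=[pl.BlockSpec((8, 128),
                                 lambda i, s_refs: (pl.load(s_refs[0], (i,)), 0))],
          out_specs=pl.BlockSpec((8, 128),
                                 lambda i, s_refs: (pl.load(s_refs[0], (i,)), 0)),
      ),
      interpret=True,  # assumption: run in interpret mode when no TPU is available
  )
  def kernel(s_refs, x_ref, o_ref):
    del s_refs  # the same scalar pytree is also visible inside the kernel
    o_ref[...] = x_ref[...] * 2
  # Pass the scalars as a pytree (here a 1-tuple), mirroring the test above.
  return kernel((block_order,), x)

x = jnp.arange(16 * 128, dtype=jnp.float32).reshape(16, 128)
out = double_blocks(x, np.array([1, 0], np.int32))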
View File

@ -318,7 +318,6 @@ class PallasBaseTest(jtu.JaxTestCase):
class SplashAttentionTest(PallasBaseTest):
@parameterized.product(
is_mqa=(False, True),
is_segmented=(False, True),
@ -544,7 +543,7 @@ class SplashAttentionTest(PallasBaseTest):
data.draw(mha_strategy())
)
# Avoid segment ids for rectangular matrices, as its hard to enforce
# Avoid segment ids for rectangular matrices, as it's hard to enforce
# valid masks (non-0 rows).
hp.assume(q_seq_len == kv_seq_len or not is_segmented)