Reverts 0d058ce86f04a44a51abba1261768fb46edf69d9
PiperOrigin-RevId: 655871052
This commit is contained in:
parent c8ea86c9c9
commit 4063373b22
@@ -34,7 +34,7 @@ This generalizes to any tuple of integers (a length `d` grid will correspond
to `d` nested loops).
The kernel is executed as many times
as `prod(grid)`.
The default grid value `None` stands for `()`, and results in one
The default grid value `()` results in one
kernel invocation.
Each of these invocations is referred to as a "program".
To access which program (i.e. which element of the grid) the kernel is currently
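Below is a minimal sketch, not part of this commit, of the grid and "program"
semantics described in the docstring above. It assumes the public
jax.experimental.pallas API; the kernel, shape, and grid are made up for
illustration.

    import jax
    import jax.numpy as jnp
    from jax.experimental import pallas as pl

    def iota_kernel(o_ref):
      # Which element of the (1-D) grid this invocation ("program") is.
      i = pl.program_id(0)
      o_ref[i] = i

    out = pl.pallas_call(
        iota_kernel,
        out_shape=jax.ShapeDtypeStruct((8,), jnp.int32),
        grid=(8,),  # the kernel body runs prod(grid) == 8 times
    )()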
@@ -27,6 +27,7 @@ import warnings

import jax
from jax._src import api_util
from jax._src import config
from jax._src import core as jax_core
from jax._src import deprecations
from jax._src import linear_util as lu
@@ -40,7 +41,8 @@ import jax.numpy as jnp


class DynamicGridDim:
  pass
  def __repr__(self):
    return "DynamicGridDim"
dynamic_grid_dim = DynamicGridDim()


@@ -173,18 +175,27 @@ def current_grid_env() -> GridEnv | None:


class Mapped:
  pass
  """Used as a block shape dimension to denote a mapped dimension.
  A mapped dimension behaves like `1` except it is squeezed from the block.
  See :ref:`pallas_blockspec` for more details.
  """
  def __repr__(self):
    return "Mapped"
mapped = Mapped()


@dataclasses.dataclass(frozen=True)
class Unblocked:
  padding: tuple[tuple[int, int], ...] | None = None

  def __repr__(self):
    return f"Unblocked(padding={self.padding})"
unblocked = Unblocked()


class Blocked:
  pass
  def __repr__(self):
    return "Blocked"
blocked = Blocked()


@@ -196,6 +207,8 @@ class BlockSpec:
  """Specifies how an array should be sliced for each iteration of a kernel.

  See :ref:`pallas_blockspec` for more details.
  This object contains the parameters passed through the API.
  An internal canonicalized version is in BlockMapping.
  """
  block_shape: tuple[int | None, ...] | None = None
  index_map: Callable[..., Any] | None = None
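As a hedged illustration (not part of this commit) of how the block_shape and
index_map parameters are supplied through the public API, the sketch below
tiles a made-up (512, 512) array into (128, 128) blocks; each program (i, j)
gets the block whose block index the index_map returns.

    import jax
    import jax.numpy as jnp
    from jax.experimental import pallas as pl

    def copy_kernel(x_ref, o_ref):
      # Each invocation sees one (128, 128) block of the input and output.
      o_ref[...] = x_ref[...]

    x = jnp.ones((512, 512), jnp.float32)
    y = pl.pallas_call(
        copy_kernel,
        out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
        grid=(4, 4),
        in_specs=[pl.BlockSpec((128, 128), lambda i, j: (i, j))],
        out_specs=pl.BlockSpec((128, 128), lambda i, j: (i, j)),
    )(x)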
@@ -247,9 +260,38 @@ BlockSpecTree = Any

@dataclasses.dataclass(frozen=True)
class BlockMapping:
  """An internal canonicalized version of BlockSpec.

  See the `check_invariants` method for precise specification.
  """
  block_shape: tuple[Mapped | int, ...]
  block_aval: AbstractMemoryRef  # The block ref aval
  index_map_jaxpr: jax_core.ClosedJaxpr
  indexing_mode: IndexingMode
  array_shape_dtype: jax.ShapeDtypeStruct  # The whole array
  origin: str  # The origin, e.g. input[2]["field"]

  def check_invariants(self) -> None:
    if not config.enable_checks.value: return

    unmapped_block_shape = tuple(s for s in self.block_shape if s is not mapped)
    assert unmapped_block_shape == self.block_aval.shape, (
        self.block_shape, self.block_aval)
    assert len(self.block_shape) == len(self.array_shape_dtype.shape), (
        self.block_shape, self.array_shape_dtype
    )

    assert not self.index_map_jaxpr.consts
    assert len(self.block_shape) == len(self.index_map_jaxpr.out_avals)
    assert all(ov.shape == () and
               (ov.dtype == jnp.int32 or ov.dtype == jnp.int64)
               for ov in self.index_map_jaxpr.out_avals), (
        self.index_map_jaxpr.out_avals)

  def replace(self, **kwargs):
    new_self = dataclasses.replace(self, **kwargs)
    new_self.check_invariants()
    return new_self

  def compute_start_indices_interpret(self, loop_idx, *args):
    discharged_jaxpr, discharged_consts = state_discharge.discharge_state(
@@ -269,8 +311,6 @@ class BlockMapping:
    else:
      raise RuntimeError(f"Unknown indexing mode: {self.indexing_mode}")

  replace = dataclasses.replace


@contextlib.contextmanager
def tracing_grid_env(grid: GridMappingGrid, mapped_dims: tuple[int, ...]):
@@ -285,16 +325,86 @@ def tracing_grid_env(grid: GridMappingGrid, mapped_dims: tuple[int, ...]):

@dataclasses.dataclass(frozen=True)
class GridMapping:
  """An internal canonicalized version of GridSpec.

  Encodes the calling conventions of the pallas_call primitive, the kernel,
  and the index maps.

  The pallas_call is invoked with: ``*dynamic_grid_sizes, *index, *consts, *inputs``.
  The ``index`` operands are for the scalar prefetch.
  The ``consts`` are constants captured by the kernel function.

  The kernel function is invoked with:
  ``*index, *consts, *inputs, *scratch``.

  The index map functions are invoked with:
  ``*program_ids, *index``.

  See the `check_invariants` method for a more precise specification.
  """
  grid: GridMappingGrid
  grid_names: tuple[Hashable, ...] | None
  block_mappings: tuple[BlockMapping | None, ...]
  mapped_dims: tuple[int, ...] = ()
  num_index_operands: int = 0
  num_scratch_operands: int = 0
  # Number of constants hoisted to operands by ``_hoist_consts_to_refs``.
  num_constant_operands: int = 0

  replace = dataclasses.replace
  # Block mappings for: *consts, *inputs, *outputs
  block_mappings: tuple[BlockMapping, ...]
  # The inputs for tracing the index map: the tree and the flat avals
  index_map_tree: tree_util.PyTreeDef
  index_map_avals: tuple[jax_core.AbstractValue]
  # Which dimensions in `grid` are vmapped.
  vmapped_dims: tuple[int, ...]

  num_index_operands: int
  # Number of captured constants hoisted to operands.
  num_constant_operands: int
  num_inputs: int
  num_outputs: int
  num_scratch_operands: int

  def check_invariants(self) -> None:
    if not config.enable_checks.value: return
    assert (len(self.block_mappings) ==
            self.num_constant_operands + self.num_inputs + self.num_outputs), (
        self.num_constant_operands, self.num_inputs, self.num_outputs,
        self.block_mappings
    )
    # index_map_avals = int32[] * len(self.grid) + index_operands
    assert len(self.index_map_avals) == len(self.grid) + self.num_index_operands, (
        self.index_map_avals,
        self.grid,
        self.num_index_operands,
    )
    # Check that we can put together the avals and the tree.
    index_map_args, index_map_kwargs = self.index_map_tree.unflatten(
        self.index_map_avals)
    assert not index_map_kwargs
    assert len(index_map_args) >= len(self.grid)
    for i in range(len(self.grid)):
      index_map_arg = index_map_args[i]
      assert index_map_arg.shape == ()
      assert index_map_arg.dtype == jnp.int32

    assert len(self.vmapped_dims) <= len(self.grid)
    for i in self.vmapped_dims:
      assert 0 <= i < len(self.grid)

    if self.grid_names is not None:
      assert len(self.grid) == len(self.grid_names), (self.grid, self.grid_names)
    for bm in self.block_mappings:
      bm.check_invariants()
      assert tuple(self.index_map_avals) == tuple(bm.index_map_jaxpr.in_avals), (
          self.index_map_avals,
          bm.index_map_jaxpr.in_avals,
      )

  def replace(self, **kwargs) -> GridMapping:
    new_self = dataclasses.replace(self, **kwargs)
    new_self.check_invariants()
    return new_self

  @property
  # TODO(necula): deprecate and then remove this property.
  def mapped_dims(self) -> tuple[int, ...]:
    return self.vmapped_dims

  @property
  def num_dynamic_grid_bounds(self):
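To make the index-map convention above concrete, here is a hedged sketch (not
taken from this commit): with one scalar-prefetch ("index") operand, an index
map receives the program ids followed by a ref to that operand. The names
below are illustrative only.

    def index_map(i, j, block_idx_ref):  # called as (*program_ids, *index)
      # Select which block to fetch, e.g. through a prefetched permutation
      # of block indices along the first grid axis.
      return block_idx_ref[i], j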
@@ -314,74 +424,123 @@ class GridMapping:
      axis_env_ctx = jax_core.extend_axis_env_nd(
          zip(self.grid_names, self.grid)
      )
    with tracing_grid_env(self.grid, self.mapped_dims), axis_env_ctx:
    with tracing_grid_env(self.grid, self.vmapped_dims), axis_env_ctx:
      yield

  @property
  def slice_index_ops(self):
    """Returns a slice object to select the index operands to a kernel."""
    return slice(0, self.num_index_operands)

  @property
  def slice_block_ops(self):
    """Returns a slice to select all but the index operands to a kernel."""
    return slice(self.num_index_operands, None)

  @property
  def slice_scratch_ops(self):
    """Returns a slice object to select the scratch operands to a kernel."""
    if self.num_scratch_operands:
      return slice(-self.num_scratch_operands, None)
    else:
      return slice(0, 0)

  # TODO(necula): this is used to recover the old `in_shapes`, but it probably
  # is not needed anymore, with some cleanup.
  @property
  def in_shapes(self) -> tuple[jax.ShapeDtypeStruct, ...]:
    """The shapes of *index, *consts, *inputs."""
    index_shapes = [jax.ShapeDtypeStruct(ia.inner_aval.shape,
                                         ia.inner_aval.dtype)
                    for ia in self.index_map_avals[len(self.grid):]]
    consts_inputs_shapes = [
        bm.array_shape_dtype
        for bm in self.block_mappings[
            :self.num_constant_operands + self.num_inputs]]
    return tuple(index_shapes + consts_inputs_shapes)

  # TODO(necula): this is used to recover the old `out_shapes`, but it probably
  # is not needed anymore, with some cleanup.
  @property
  def out_shapes(self) -> tuple[jax.ShapeDtypeStruct, ...]:
    return tuple(
        bm.array_shape_dtype
        for bm in self.block_mappings[
            self.num_constant_operands + self.num_inputs:
            self.num_constant_operands + self.num_inputs + self.num_outputs])

def _is_valid_grid_dim(dim: int | jax.Array) -> bool:
|
||||
if isinstance(dim, jax.Array):
|
||||
return True
|
||||
return jax_core.is_dim(dim)
|
||||
|
||||
def _preprocess_grid(grid: Grid | int | None) -> tuple[TupleGrid, GridNames]:
|
||||
if grid is None:
|
||||
return (), None
|
||||
if isinstance(grid, int):
|
||||
return (grid,), None
|
||||
# Handle empty grid
|
||||
if not grid:
|
||||
return grid, None # type: ignore
|
||||
# Check if we have a named grid
|
||||
if isinstance(grid[0], tuple):
|
||||
grid_names, grid = util.unzip2(grid) # type: ignore
|
||||
else:
|
||||
grid_names = None
|
||||
# TODO(b/353730556): allow NumPy scalars in grids
|
||||
if not all(_is_valid_grid_dim(g) for g in grid): # type: ignore
|
||||
raise ValueError(
|
||||
f"Grid must be a tuple of integers or jax.Array, got {grid}"
|
||||
)
|
||||
return grid, grid_names # type: ignore
|
||||
|
||||
|
||||
def _convert_block_spec_to_block_mapping(
|
||||
in_avals: Sequence[jax_core.ShapedArray],
|
||||
block_spec: BlockSpec,
|
||||
path: tree_util.KeyPath,
|
||||
aval: jax_core.ShapedArray,
|
||||
in_tree: Any,
|
||||
array_aval: jax_core.ShapedArray,
|
||||
*,
|
||||
# Inputs for the index_map
|
||||
index_map_avals: Sequence[jax_core.AbstractValue],
|
||||
index_map_tree: tree_util.PyTreeDef,
|
||||
grid: GridMappingGrid,
|
||||
mapped_dims: tuple[int, ...],
|
||||
what: str, # Used to localize error messages, e.g., {what}{path}
|
||||
) -> BlockMapping | None:
|
||||
) -> BlockMapping:
|
||||
origin = f"{what}{tree_util.keystr(path)}"
|
||||
if block_spec is no_block_spec:
|
||||
return None
|
||||
block_spec = BlockSpec(None, None)
|
||||
if block_spec.index_map is None:
|
||||
compute_index = lambda *args, **kwargs: (0,) * len(aval.shape)
|
||||
compute_index = lambda *args: (0,) * len(array_aval.shape)
|
||||
else:
|
||||
compute_index = block_spec.compute_index
|
||||
if block_spec.block_shape is None:
|
||||
block_shape = aval.shape
|
||||
block_shape = array_aval.shape
|
||||
else:
|
||||
block_shape = block_spec.block_shape
|
||||
block_shape = tuple(
|
||||
mapped if s is None else s for s in block_shape)
|
||||
flat_fun, _ = api_util.flatten_fun(lu.wrap_init(compute_index), in_tree)
|
||||
with tracing_grid_env(grid, mapped_dims):
|
||||
jaxpr, out_avals, consts, () = pe.trace_to_jaxpr_dynamic(flat_fun, in_avals)
|
||||
if len(out_avals) != len(block_shape):
|
||||
if len(array_aval.shape) != len(block_shape):
|
||||
raise ValueError(
|
||||
f"Index map for {what}{tree_util.keystr(path)} must return "
|
||||
f"{len(aval.shape)} values to match {block_shape=}. "
|
||||
f"Currently returning {len(out_avals)} values."
|
||||
f"Block shape for {origin} (= {block_shape}) "
|
||||
f"must have the same number of dimensions as the array shape {array_aval.shape}"
|
||||
)
|
||||
unmapped_block_shape = tuple(s for s in block_shape if s is not None)
|
||||
block_aval = AbstractMemoryRef(array_aval.update(shape=unmapped_block_shape),
|
||||
block_spec.memory_space)
|
||||
|
||||
if not jax_core.is_constant_shape(block_aval.shape):
|
||||
raise ValueError(
|
||||
"shape polymorphism for Pallas does not support "
|
||||
"dynamically-shaped blocks. "
|
||||
f"{origin} has block_shape: {block_aval.shape}")
|
||||
|
||||
flat_index_map_fun, _ = api_util.flatten_fun(lu.wrap_init(compute_index),
|
||||
index_map_tree)
|
||||
with tracing_grid_env(grid, mapped_dims):
|
||||
jaxpr, out_avals, consts, () = pe.trace_to_jaxpr_dynamic(flat_index_map_fun,
|
||||
index_map_avals)
|
||||
mapped_block_shape = tuple(
|
||||
mapped if s is None else s for s in block_shape)
|
||||
if len(out_avals) != len(mapped_block_shape):
|
||||
raise ValueError(
|
||||
# TODO(necula): show the name and location of the index map function
|
||||
f"Index map for {origin} must return "
|
||||
f"{len(block_aval.shape)} values to match block shape {mapped_block_shape}. "
|
||||
f"Currently returning {len(out_avals)} values."
|
||||
)
|
||||
if consts:
|
||||
raise NotImplementedError(
|
||||
f"Index map for {what}{tree_util.keystr(path)} captures constants: "
|
||||
# TODO(necula): show the name and location of the index map function
|
||||
f"Index map for {origin} captures constants: "
|
||||
f"{consts}")
|
||||
return BlockMapping(
|
||||
block_shape, jax_core.ClosedJaxpr(jaxpr, consts), block_spec.indexing_mode
|
||||
mapping = BlockMapping(
|
||||
block_shape=mapped_block_shape,
|
||||
block_aval=block_aval,
|
||||
index_map_jaxpr=jax_core.ClosedJaxpr(jaxpr, consts),
|
||||
indexing_mode=block_spec.indexing_mode,
|
||||
array_shape_dtype=jax.ShapeDtypeStruct(array_aval.shape, array_aval.dtype),
|
||||
origin=origin,
|
||||
)
|
||||
|
||||
mapping.check_invariants()
|
||||
return mapping
|
||||
|
||||
def _tile_ref(ref: state.AbstractRef, block_shape: tuple[int, ...] | None
|
||||
) -> state.AbstractRef:
|
||||
@ -390,65 +549,20 @@ def _tile_ref(ref: state.AbstractRef, block_shape: tuple[int, ...] | None
|
||||
shape = tuple(s for s in block_shape if s is not None)
|
||||
return ref.update(inner_aval=ref.inner_aval.update(shape=shape))
|
||||
|
||||
|
||||
def _get_ref_avals(in_avals: Sequence[jax_core.ShapedArray],
|
||||
in_specs: Sequence[BlockSpec],
|
||||
in_paths: Sequence[tree_util.KeyPath],
|
||||
out_avals: Sequence[jax_core.ShapedArray],
|
||||
out_specs: Sequence[BlockSpec],
|
||||
out_paths: Sequence[tree_util.KeyPath]):
|
||||
def make_ref_aval(aval: jax_core.ShapedArray,
|
||||
spec: BlockSpec,
|
||||
path: tree_util.KeyPath,
|
||||
what: str) -> state.AbstractRef:
|
||||
if spec is no_block_spec:
|
||||
memory_space = None
|
||||
block_shape = None
|
||||
else:
|
||||
memory_space = spec.memory_space
|
||||
block_shape = spec.block_shape
|
||||
|
||||
ref_aval = AbstractMemoryRef(aval, memory_space)
|
||||
if block_shape is not None:
|
||||
if len(ref_aval.shape) != len(block_shape):
|
||||
raise ValueError(
|
||||
f"Block shape for {what}{tree_util.keystr(path)} (= {block_shape}) "
|
||||
f"must have the same number of dimensions as the array shape {ref_aval.shape}"
|
||||
)
|
||||
block_shape_unmapped = tuple(s for s in block_shape if s is not None)
|
||||
ref_aval = ref_aval.update(
|
||||
inner_aval=ref_aval.inner_aval.update(shape=block_shape_unmapped))
|
||||
|
||||
if not jax_core.is_constant_shape(ref_aval.shape):
|
||||
raise ValueError(
|
||||
"shape polymorphism for Pallas does not support "
|
||||
"dynamically-shaped blocks. "
|
||||
f"{what}{tree_util.keystr(path)} has block_shape: {ref_aval.shape}")
|
||||
return ref_aval
|
||||
|
||||
in_ref_avals = [
|
||||
make_ref_aval(aval, in_spec, in_path, "input")
|
||||
for aval, in_spec, in_path in zip(in_avals, in_specs, in_paths)
|
||||
]
|
||||
out_ref_avals = [
|
||||
make_ref_aval(aval, out_spec, out_path, "output")
|
||||
for aval, out_spec, out_path in zip(out_avals, out_specs, out_paths)
|
||||
]
|
||||
return in_ref_avals, out_ref_avals
|
||||
|
||||
|
||||
@dataclasses.dataclass(init=False, unsafe_hash=True)
|
||||
class GridSpec:
|
||||
"""Encodes the parameters of the grid, as given through the API.
|
||||
|
||||
An internal sanitized version is in GridMapping.
|
||||
"""
|
||||
grid: TupleGrid
|
||||
grid_names: tuple[Hashable, ...] | None
|
||||
in_specs: tuple[BlockSpec | NoBlockSpec, ...]
|
||||
out_specs: tuple[BlockSpec | NoBlockSpec, ...]
|
||||
in_specs_tree: Any
|
||||
out_specs_tree: Any
|
||||
in_specs: tuple[BlockSpec | NoBlockSpec, ...] | NoBlockSpec
|
||||
out_specs: tuple[BlockSpec | NoBlockSpec, ...] | NoBlockSpec
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
grid: Grid | None = None,
|
||||
grid: Grid = (),
|
||||
in_specs: BlockSpecTree = no_block_spec,
|
||||
out_specs: BlockSpecTree = no_block_spec,
|
||||
):
|
||||
@ -460,89 +574,153 @@ class GridSpec:
|
||||
if isinstance(out_specs, list):
|
||||
out_specs = tuple(out_specs)
|
||||
|
||||
self.grid, self.grid_names = _preprocess_grid(grid)
|
||||
if in_specs is not no_block_spec:
|
||||
flat_in_specs, self.in_specs_tree = tree_util.tree_flatten(in_specs)
|
||||
self.in_specs = tuple(flat_in_specs)
|
||||
else:
|
||||
self.in_specs = in_specs
|
||||
self.in_specs_tree = None
|
||||
if out_specs is not no_block_spec:
|
||||
flat_out_specs, self.out_specs_tree = tree_util.tree_flatten(out_specs)
|
||||
self.out_specs = tuple(flat_out_specs)
|
||||
else:
|
||||
self.out_specs = out_specs
|
||||
self.out_specs_tree = None
|
||||
self.in_specs = in_specs
|
||||
self.out_specs = out_specs
|
||||
|
||||
def _get_in_out_specs(self, in_avals, in_tree, out_avals, out_tree):
|
||||
if self.in_specs is no_block_spec:
|
||||
flat_in_specs = [no_block_spec] * len(in_avals)
|
||||
else:
|
||||
flat_in_specs = self.in_specs
|
||||
if self.in_specs_tree != in_tree:
|
||||
raise ValueError(
|
||||
pytreedef_mismatch_err_msg("`in_specs`", self.in_specs_tree,
|
||||
"inputs", in_tree))
|
||||
if self.out_specs is no_block_spec:
|
||||
flat_out_specs = [no_block_spec] * len(out_avals)
|
||||
else:
|
||||
flat_out_specs = self.out_specs
|
||||
if self.out_specs_tree != out_tree:
|
||||
raise ValueError(
|
||||
pytreedef_mismatch_err_msg("`out_specs`", self.out_specs_tree,
|
||||
"`out_shape`", out_tree))
|
||||
return flat_in_specs, flat_out_specs
|
||||
grid_names = None
|
||||
if isinstance(grid, int):
|
||||
grid = (grid,)
|
||||
elif grid and isinstance(grid[0], tuple): # Check if we have a named grid
|
||||
grid_names, grid = util.unzip2(grid) # type: ignore
|
||||
|
||||
# TODO(b/353730556): allow NumPy scalars in grids
|
||||
if not all(_is_valid_grid_dim(g) for g in grid): # type: ignore
|
||||
raise ValueError(
|
||||
f"Grid must be a tuple of integers or jax.Array, got {grid}"
|
||||
)
|
||||
self.grid = grid # type: ignore
|
||||
self.grid_names = grid_names
|
||||
|
||||
def get_grid_mapping(
|
||||
self, in_avals, in_tree, in_paths, out_avals, out_tree, out_paths
|
||||
) -> tuple[tuple[jax_core.AbstractValue, ...], GridMapping]:
|
||||
self,
|
||||
in_avals: Sequence[jax_core.AbstractValue],
|
||||
in_tree: tree_util.PyTreeDef,
|
||||
in_paths: Sequence[tree_util.KeyPath],
|
||||
out_avals: Sequence[jax_core.AbstractValue],
|
||||
out_tree: tree_util.PyTreeDef,
|
||||
out_paths: Sequence[tree_util.KeyPath],
|
||||
num_scalar_prefetch: int = 0,
|
||||
scratch_shapes: Sequence[Any] = (),
|
||||
) -> tuple[tuple[AbstractMemoryRef, ...],
|
||||
GridMapping]:
|
||||
assert all(i is None or isinstance(i, int) for i in self.grid)
|
||||
grid_mapping_grid = tuple(
|
||||
dynamic_grid_dim if d is None else d for d in self.grid
|
||||
)
|
||||
flat_in_specs, flat_out_specs = self._get_in_out_specs(
|
||||
in_avals, in_tree, out_avals, out_tree)
|
||||
in_ref_avals, out_ref_avals = _get_ref_avals(
|
||||
in_avals, flat_in_specs, in_paths,
|
||||
out_avals, flat_out_specs, out_paths)
|
||||
grid_avals = [jax_core.ShapedArray((), jnp.dtype("int32"))] * len(self.grid)
|
||||
# Create args, kwargs pytree def
|
||||
grid_tree = tree_util.tree_structure((tuple(grid_avals), {}))
|
||||
# The inputs for the index maps
|
||||
index_map_avals = (
|
||||
(jax_core.ShapedArray((), jnp.dtype("int32")),) * len(self.grid))
|
||||
index_map_tree = tree_util.tree_structure((index_map_avals, {}))
|
||||
|
||||
if num_scalar_prefetch:
|
||||
all_avals = tree_util.tree_unflatten(in_tree, in_avals)
|
||||
scalar_avals, unflat_in_avals = split_list(
|
||||
all_avals, [num_scalar_prefetch])
|
||||
flat_scalar_avals, scalar_tree = tree_util.tree_flatten(scalar_avals)
|
||||
num_flat_scalar_prefetch = len(flat_scalar_avals)
|
||||
scalar_ref_avals = [
|
||||
self._make_scalar_ref_aval(aval)
|
||||
for aval in flat_scalar_avals]
|
||||
jaxpr_scalar_ref_avals = tree_util.tree_unflatten(
|
||||
scalar_tree, scalar_ref_avals)
|
||||
in_avals, in_tree = tree_util.tree_flatten(tuple(unflat_in_avals))
|
||||
index_map_tree = tree_util.tree_structure(((*index_map_avals,
|
||||
*scalar_avals), {}))
|
||||
index_map_avals = (*index_map_avals, *scalar_ref_avals)
|
||||
del scalar_ref_avals, flat_scalar_avals, scalar_tree
|
||||
del scalar_avals, unflat_in_avals, all_avals
|
||||
else:
|
||||
num_flat_scalar_prefetch = 0
|
||||
jaxpr_scalar_ref_avals = ()
|
||||
|
||||
if scratch_shapes:
|
||||
flat_scratch_shapes, scratch_tree = tree_util.tree_flatten(
|
||||
scratch_shapes)
|
||||
flat_scratch_avals = map(self._make_scratch_aval, flat_scratch_shapes)
|
||||
num_flat_scratch_operands = len(flat_scratch_avals)
|
||||
jaxpr_scratch_avals = tree_util.tree_unflatten(
|
||||
scratch_tree, flat_scratch_avals)
|
||||
if not isinstance(jaxpr_scratch_avals, (tuple, list)):
|
||||
jaxpr_scratch_avals = (jaxpr_scratch_avals,)
|
||||
del flat_scratch_avals, flat_scratch_shapes, scratch_tree
|
||||
else:
|
||||
num_flat_scratch_operands = 0
|
||||
jaxpr_scratch_avals = ()
|
||||
|
||||
if self.in_specs is not no_block_spec:
|
||||
flat_in_specs, in_specs_tree = tree_util.tree_flatten(self.in_specs)
|
||||
if in_specs_tree != in_tree:
|
||||
raise ValueError(
|
||||
pytreedef_mismatch_err_msg("`in_specs`", in_specs_tree,
|
||||
"inputs", in_tree))
|
||||
else:
|
||||
flat_in_specs = [no_block_spec] * len(in_avals)
|
||||
|
||||
in_block_mappings = map(
|
||||
partial(
|
||||
_convert_block_spec_to_block_mapping,
|
||||
grid_avals,
|
||||
in_tree=grid_tree,
|
||||
index_map_avals=index_map_avals,
|
||||
index_map_tree=index_map_tree,
|
||||
grid=grid_mapping_grid,
|
||||
mapped_dims=(),
|
||||
what="input",
|
||||
what="inputs",
|
||||
),
|
||||
flat_in_specs,
|
||||
in_paths,
|
||||
in_ref_avals,
|
||||
in_paths[num_flat_scalar_prefetch:],
|
||||
in_avals,
|
||||
)
|
||||
|
||||
if self.out_specs is not no_block_spec:
|
||||
flat_out_specs, out_specs_tree = tree_util.tree_flatten(self.out_specs)
|
||||
if out_specs_tree != out_tree:
|
||||
raise ValueError(
|
||||
pytreedef_mismatch_err_msg("`out_specs`", out_specs_tree,
|
||||
"`out_shape`", out_tree))
|
||||
else:
|
||||
flat_out_specs = [no_block_spec] * len(out_avals)
|
||||
|
||||
out_block_mappings = map(
|
||||
partial(
|
||||
_convert_block_spec_to_block_mapping,
|
||||
grid_avals,
|
||||
in_tree=grid_tree,
|
||||
index_map_avals=index_map_avals,
|
||||
index_map_tree=index_map_tree,
|
||||
grid=grid_mapping_grid,
|
||||
mapped_dims=(),
|
||||
what="output",
|
||||
what="outputs",
|
||||
),
|
||||
flat_out_specs,
|
||||
out_paths,
|
||||
out_ref_avals,
|
||||
out_avals,
|
||||
)
|
||||
grid_mapping = GridMapping(
|
||||
grid_mapping_grid, self.grid_names, # type: ignore
|
||||
(*in_block_mappings, *out_block_mappings)
|
||||
grid=grid_mapping_grid, # type: ignore[arg-type]
|
||||
grid_names=self.grid_names,
|
||||
block_mappings=(*in_block_mappings, *out_block_mappings),
|
||||
index_map_avals=index_map_avals, # type: ignore[arg-type]
|
||||
index_map_tree=index_map_tree,
|
||||
vmapped_dims=(),
|
||||
num_index_operands=num_flat_scalar_prefetch,
|
||||
num_constant_operands=0, # Fixed up later
|
||||
num_inputs=len(flat_in_specs),
|
||||
num_outputs=len(flat_out_specs),
|
||||
num_scratch_operands=num_flat_scratch_operands,
|
||||
)
|
||||
jaxpr_in_avals = tree_util.tree_unflatten(in_tree, in_ref_avals)
|
||||
grid_mapping.check_invariants()
|
||||
in_ref_avals = [bm.block_aval for bm in in_block_mappings]
|
||||
jaxpr_in_ref_avals = tree_util.tree_unflatten(in_tree, in_ref_avals)
|
||||
jaxpr_in_avals = (*jaxpr_scalar_ref_avals, *jaxpr_in_ref_avals)
|
||||
out_ref_avals = [bm.block_aval for bm in out_block_mappings]
|
||||
jaxpr_out_avals = tree_util.tree_unflatten(out_tree, out_ref_avals)
|
||||
if not isinstance(jaxpr_out_avals, (tuple, list)):
|
||||
jaxpr_out_avals = (jaxpr_out_avals,)
|
||||
return (*jaxpr_in_avals, *jaxpr_out_avals), grid_mapping
|
||||
return (*jaxpr_in_avals, *jaxpr_out_avals,
|
||||
*jaxpr_scratch_avals), grid_mapping
|
||||
|
||||
def _make_scratch_aval(self, obj: object) -> jax_core.AbstractValue:
|
||||
assert False # Not needed in GridSpec
|
||||
|
||||
def _make_scalar_ref_aval(self, aval):
|
||||
assert False # Not needed in GridSpec
|
||||
|
||||
def unzip_dynamic_grid_bounds(
|
||||
self,
|
||||
@ -557,6 +735,7 @@ class GridSpec:
|
||||
static_self.grid = static_grid # type: ignore
|
||||
return static_self, dynamic_bounds
|
||||
|
||||
|
||||
def pytreedef_mismatch_err_msg(
|
||||
what1: str, tree1: tree_util.PyTreeDef,
|
||||
what2: str, tree2: tree_util.PyTreeDef) -> str:
|
||||
|
@ -24,7 +24,6 @@ from typing import Any, Hashable
|
||||
import jax
|
||||
from jax._src import core as jax_core
|
||||
from jax._src import dtypes
|
||||
from jax._src import tree_util
|
||||
from jax._src import util
|
||||
import jax.numpy as jnp
|
||||
from jax._src.pallas import core as pallas_core
|
||||
@ -141,30 +140,19 @@ class MemoryRef:
|
||||
jax_core.ShapedArray(self.shape, self.dtype), self.memory_space)
|
||||
|
||||
|
||||
def _make_aval(obj: object) -> jax_core.AbstractValue:
|
||||
if isinstance(obj, MemoryRef):
|
||||
return obj.get_aval()
|
||||
if isinstance(obj, SemaphoreType):
|
||||
return obj.get_aval()
|
||||
raise ValueError(f"No registered conversion for {type(obj)}. "
|
||||
"Only VMEM and SemaphoreType are supported.")
|
||||
|
||||
|
||||
@dataclasses.dataclass(init=False, unsafe_hash=True)
|
||||
class PrefetchScalarGridSpec(pallas_core.GridSpec):
|
||||
grid: TupleGrid
|
||||
grid_names: tuple[Hashable, ...] | None
|
||||
num_scalar_prefetch: int
|
||||
in_specs: tuple[BlockSpec | NoBlockSpec, ...]
|
||||
out_specs: tuple[BlockSpec | NoBlockSpec, ...]
|
||||
in_specs_tree: Any
|
||||
out_specs_tree: Any
|
||||
in_specs: tuple[BlockSpec | NoBlockSpec, ...] | NoBlockSpec
|
||||
out_specs: tuple[BlockSpec | NoBlockSpec, ...] | NoBlockSpec
|
||||
scratch_shapes: tuple[Any, ...]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
num_scalar_prefetch: int,
|
||||
grid: Grid | None = None,
|
||||
grid: Grid = (),
|
||||
in_specs: BlockSpecTree = no_block_spec,
|
||||
out_specs: BlockSpecTree = no_block_spec,
|
||||
scratch_shapes: Any | Sequence[Any] = ()
|
||||
@ -173,84 +161,25 @@ class PrefetchScalarGridSpec(pallas_core.GridSpec):
|
||||
self.num_scalar_prefetch = num_scalar_prefetch
|
||||
self.scratch_shapes = tuple(scratch_shapes)
|
||||
|
||||
def get_grid_mapping(
|
||||
def _make_scalar_ref_aval(self, aval):
|
||||
return AbstractMemoryRef(jax_core.ShapedArray(aval.shape, aval.dtype),
|
||||
TPUMemorySpace.SMEM)
|
||||
|
||||
def _make_scratch_aval(self, obj: object) -> jax_core.AbstractValue:
|
||||
if isinstance(obj, MemoryRef):
|
||||
return obj.get_aval()
|
||||
if isinstance(obj, SemaphoreType):
|
||||
return obj.get_aval()
|
||||
raise ValueError(f"No registered conversion for {type(obj)}. "
|
||||
"Only VMEM and SemaphoreType are supported.")
|
||||
|
||||
def get_grid_mapping( # type: ignore[override]
|
||||
self, in_avals, in_tree, in_paths, out_avals, out_tree, out_paths
|
||||
) -> tuple[tuple[jax_core.AbstractValue, ...], GridMapping]:
|
||||
assert all(i is None or isinstance(i, int) for i in self.grid)
|
||||
grid_mapping_grid = tuple(
|
||||
pallas_core.dynamic_grid_dim if d is None else d for d in self.grid
|
||||
)
|
||||
all_avals = tree_util.tree_unflatten(in_tree, in_avals)
|
||||
flat_scratch_shapes, scratch_tree = tree_util.tree_flatten(
|
||||
self.scratch_shapes)
|
||||
flat_scratch_avals = map(_make_aval, flat_scratch_shapes)
|
||||
scalar_avals, unflat_in_avals = split_list(
|
||||
all_avals, [self.num_scalar_prefetch])
|
||||
flat_scalar_avals, scalar_tree = tree_util.tree_flatten(scalar_avals)
|
||||
num_flat_scalar_prefetch = len(flat_scalar_avals)
|
||||
in_avals, in_avals_tree = tree_util.tree_flatten(tuple(unflat_in_avals))
|
||||
flat_in_specs, flat_out_specs = self._get_in_out_specs(
|
||||
in_avals, in_avals_tree, out_avals, out_tree)
|
||||
in_ref_avals, out_ref_avals = (
|
||||
pallas_core._get_ref_avals(
|
||||
in_avals, flat_in_specs, in_paths[num_flat_scalar_prefetch:],
|
||||
out_avals, flat_out_specs, out_paths))
|
||||
scalar_ref_avals = [
|
||||
AbstractMemoryRef(jax_core.ShapedArray(aval.shape, aval.dtype),
|
||||
TPUMemorySpace.SMEM)
|
||||
for aval in flat_scalar_avals]
|
||||
grid_avals = [jax_core.ShapedArray((), jnp.dtype("int32"))] * len(self.grid)
|
||||
# Create args, kwargs pytree def
|
||||
index_map_in_tree = tree_util.tree_structure(
|
||||
((*grid_avals, *scalar_avals), {})
|
||||
)
|
||||
in_block_mappings = map(
|
||||
partial(
|
||||
_convert_block_spec_to_block_mapping,
|
||||
(*grid_avals, *scalar_ref_avals),
|
||||
in_tree=index_map_in_tree,
|
||||
grid=grid_mapping_grid,
|
||||
mapped_dims=(),
|
||||
what="input",
|
||||
),
|
||||
flat_in_specs,
|
||||
in_paths[num_flat_scalar_prefetch:],
|
||||
in_ref_avals,
|
||||
)
|
||||
out_block_mappings = map(
|
||||
partial(
|
||||
_convert_block_spec_to_block_mapping,
|
||||
(*grid_avals, *scalar_ref_avals),
|
||||
in_tree=index_map_in_tree,
|
||||
grid=grid_mapping_grid,
|
||||
mapped_dims=(),
|
||||
what="output",
|
||||
),
|
||||
flat_out_specs,
|
||||
out_paths,
|
||||
out_ref_avals,
|
||||
)
|
||||
grid_mapping = GridMapping(
|
||||
grid=grid_mapping_grid, grid_names=self.grid_names, # type: ignore
|
||||
block_mappings=(*in_block_mappings, *out_block_mappings),
|
||||
mapped_dims=(),
|
||||
num_index_operands=num_flat_scalar_prefetch,
|
||||
num_scratch_operands=len(flat_scratch_avals)
|
||||
)
|
||||
jaxpr_scalar_ref_avals = tree_util.tree_unflatten(
|
||||
scalar_tree, scalar_ref_avals)
|
||||
jaxpr_in_ref_avals = tree_util.tree_unflatten(in_avals_tree, in_ref_avals)
|
||||
jaxpr_scratch_avals = tree_util.tree_unflatten(
|
||||
scratch_tree, flat_scratch_avals)
|
||||
if not isinstance(jaxpr_scratch_avals, (tuple, list)):
|
||||
jaxpr_scratch_avals = (jaxpr_scratch_avals,)
|
||||
jaxpr_in_avals = (*jaxpr_scalar_ref_avals,
|
||||
*jaxpr_in_ref_avals)
|
||||
jaxpr_out_avals = tree_util.tree_unflatten(out_tree, out_ref_avals)
|
||||
if not isinstance(jaxpr_out_avals, (tuple, list)):
|
||||
jaxpr_out_avals = (jaxpr_out_avals,)
|
||||
return (*jaxpr_in_avals, *jaxpr_out_avals,
|
||||
*jaxpr_scratch_avals), grid_mapping
|
||||
return super().get_grid_mapping(in_avals, in_tree, in_paths,
|
||||
out_avals, out_tree, out_paths,
|
||||
num_scalar_prefetch=self.num_scalar_prefetch,
|
||||
scratch_shapes=self.scratch_shapes)
|
||||
|
||||
|
||||
@dataclasses.dataclass(frozen=True)
|
||||
|
@ -294,7 +294,7 @@ class MosaicGridMapping:
|
||||
self.grid_names = grid_mapping.grid_names
|
||||
self.jaxpr = jaxpr
|
||||
self.block_mappings = grid_mapping.block_mappings
|
||||
self.mapped_dims = grid_mapping.mapped_dims
|
||||
self.mapped_dims = grid_mapping.vmapped_dims
|
||||
num_scalar_prefetch = grid_mapping.num_index_operands
|
||||
num_scratch = grid_mapping.num_scratch_operands
|
||||
# jaxpr has signature [*scalar_prefetch, *consts, *in_ops, *out_ops, *scratch]
|
||||
@ -425,13 +425,15 @@ class MeshInfo:
|
||||
def lower_jaxpr_to_module(
|
||||
ctx: ir.Context,
|
||||
grid_mapping: pl_core.GridMapping,
|
||||
in_shapes: tuple[jax.ShapeDtypeStruct, ...],
|
||||
out_shapes: tuple[jax.ShapeDtypeStruct, ...],
|
||||
jaxpr: jax_core.Jaxpr,
|
||||
dimension_semantics: tuple[str | None, ...] | None,
|
||||
mesh: mesh_lib.Mesh | None = None,
|
||||
for_verification: bool = False,
|
||||
) -> tuple[Module, tuple[Any, ...]]:
|
||||
# TODO(necula): cleanup
|
||||
in_shapes = grid_mapping.in_shapes
|
||||
out_shapes = grid_mapping.out_shapes
|
||||
|
||||
mosaic_grid_mapping = MosaicGridMapping(
|
||||
jaxpr, grid_mapping, dimension_semantics, mesh)
|
||||
mosaic_grid_mapping.maybe_compress_grid()
|
||||
@ -454,10 +456,8 @@ def lower_jaxpr_to_module(
|
||||
invars = invars[grid_mapping.num_index_operands:]
|
||||
# invars now = *consts, *ins, *outs
|
||||
avals = tuple(v.aval for v in invars)
|
||||
# TODO(necula): we should not need block_operand_shapes anymore
|
||||
block_operand_shapes = (
|
||||
*[jax.ShapeDtypeStruct(v.aval.shape,
|
||||
v.aval.dtype)
|
||||
for v in invars[:grid_mapping.num_constant_operands]],
|
||||
*in_shapes[grid_mapping.num_index_operands:],
|
||||
*out_shapes,
|
||||
)
|
||||
@ -466,10 +466,6 @@ def lower_jaxpr_to_module(
|
||||
zip(block_operand_shapes, grid_mapping.block_mappings, avals)
|
||||
):
|
||||
func_name = f"transform_{i}"
|
||||
if bm is None:
|
||||
raise NotImplementedError(
|
||||
"BlockSpecs are required on TPU when grid is specified"
|
||||
)
|
||||
# ANY operands don't support windowing and require empty window_params.
|
||||
if aval.memory_space == tpu_core.TPUMemorySpace.ANY:
|
||||
# We may not require windowing if our block_shape matches the original
|
||||
|
@ -32,7 +32,6 @@ from jax._src.lib.mlir import ir
|
||||
from jax._src.pallas import core
|
||||
from jax._src.pallas.mosaic import lowering
|
||||
from jax._src.pallas.mosaic import verification
|
||||
from jax._src.pallas.pallas_call import pallas_call_p
|
||||
from jax.experimental import mosaic
|
||||
from jax.experimental.mosaic.dialects import tpu
|
||||
|
||||
@ -69,21 +68,13 @@ def pallas_call_tpu_lowering_rule(
|
||||
name: str,
|
||||
grid_mapping: core.GridMapping,
|
||||
input_output_aliases: tuple[tuple[int, int], ...],
|
||||
in_shapes: tuple[jax.ShapeDtypeStruct, ...],
|
||||
out_shapes: tuple[jax.ShapeDtypeStruct, ...],
|
||||
debug: bool,
|
||||
interpret: bool,
|
||||
compiler_params: dict[str, Any]):
|
||||
"""Lowers a pallas_call to a Mosaic TPU custom call."""
|
||||
if interpret:
|
||||
# TODO(necula): is this branch still needed?
|
||||
return mlir.lower_fun(pallas_call_p.impl, multiple_results=True)(
|
||||
ctx, *in_nodes, jaxpr=jaxpr, name=name, out_shapes=out_shapes,
|
||||
in_shapes=in_shapes,
|
||||
interpret=interpret, debug=debug,
|
||||
input_output_aliases=input_output_aliases,
|
||||
grid_mapping=grid_mapping,
|
||||
compiler_params=compiler_params)
|
||||
del interpret
|
||||
# TODO(necula): cleanup
|
||||
out_shapes = grid_mapping.out_shapes
|
||||
if debug:
|
||||
print(jaxpr)
|
||||
if "mosaic_params" in compiler_params:
|
||||
@ -114,7 +105,7 @@ def pallas_call_tpu_lowering_rule(
|
||||
with mlir_ctx, ir.Location.unknown(mlir_ctx):
|
||||
dimension_semantics = mosaic_params.get("dimension_semantics", None)
|
||||
return lowering.lower_jaxpr_to_module(
|
||||
mlir_ctx, grid_mapping, in_shapes, out_shapes, jaxpr,
|
||||
mlir_ctx, grid_mapping, jaxpr,
|
||||
dimension_semantics=dimension_semantics, mesh=mesh,
|
||||
for_verification=for_verification)
|
||||
mosaic_module, extra_args = lower_module(for_verification=False)
|
||||
|
@ -155,14 +155,14 @@ class LoweringError(Exception):
|
||||
|
||||
def lower_jaxpr_to_module(
|
||||
grid_mapping: pl_core.GridMapping,
|
||||
in_structs: tuple[jax.ShapeDtypeStruct, ...],
|
||||
out_structs: tuple[jax.ShapeDtypeStruct, ...],
|
||||
jaxpr: jax_core.Jaxpr,
|
||||
name: str,
|
||||
compiler_params: dict[str, Any],
|
||||
) -> LoweringResult:
|
||||
in_structs = grid_mapping.in_shapes
|
||||
out_structs = grid_mapping.out_shapes
|
||||
assert len(jaxpr.outvars) == 0
|
||||
assert not grid_mapping.mapped_dims
|
||||
assert not grid_mapping.vmapped_dims
|
||||
grid = grid_mapping.grid
|
||||
if len(grid) < 3:
|
||||
grid += (1,) * (3 - len(grid))
|
||||
|
@ -19,12 +19,10 @@ from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
import jax
|
||||
from jax import core as jax_core
|
||||
from jax._src.interpreters import mlir
|
||||
from jax._src.pallas import core as pallas_core
|
||||
from jax._src.pallas.mosaic_gpu import lowering
|
||||
from jax._src.pallas.pallas_call import pallas_call_p
|
||||
from jax.experimental.mosaic import gpu as mosaic_gpu
|
||||
|
||||
|
||||
@ -33,30 +31,13 @@ def pallas_call_lowering(
|
||||
*args,
|
||||
jaxpr: jax_core.Jaxpr,
|
||||
name: str,
|
||||
in_shapes: tuple[jax.ShapeDtypeStruct, ...],
|
||||
out_shapes: tuple[jax.ShapeDtypeStruct, ...],
|
||||
interpret: bool,
|
||||
debug: bool,
|
||||
input_output_aliases: tuple[tuple[int, int], ...],
|
||||
grid_mapping: pallas_core.GridMapping,
|
||||
compiler_params: dict[str, Any],
|
||||
):
|
||||
if interpret:
|
||||
# TODO(necula): is this still needed?
|
||||
return mlir.lower_fun(pallas_call_p.impl, multiple_results=True)(
|
||||
ctx,
|
||||
*args,
|
||||
jaxpr=jaxpr,
|
||||
name=name,
|
||||
out_shapes=out_shapes,
|
||||
in_shapes=in_shapes,
|
||||
interpret=interpret,
|
||||
debug=debug,
|
||||
input_output_aliases=input_output_aliases,
|
||||
grid_mapping=grid_mapping,
|
||||
compiler_params=compiler_params,
|
||||
)
|
||||
|
||||
del interpret
|
||||
if grid_mapping.num_dynamic_grid_bounds:
|
||||
raise NotImplementedError(
|
||||
"dynamic grid bounds not supported in the Mosaic GPU backend"
|
||||
@ -75,8 +56,6 @@ def pallas_call_lowering(
|
||||
)
|
||||
lowering_result = lowering.lower_jaxpr_to_module(
|
||||
grid_mapping,
|
||||
in_shapes,
|
||||
out_shapes,
|
||||
jaxpr,
|
||||
name,
|
||||
compiler_params,
|
||||
|
@ -61,6 +61,7 @@ BlockSpecTree = pallas_core.BlockSpecTree
|
||||
NoBlockSpec = pallas_core.NoBlockSpec
|
||||
no_block_spec = pallas_core.no_block_spec
|
||||
|
||||
# See the docstring for GridMapping for the calling convention
|
||||
pallas_call_p = jax_core.Primitive('pallas_call')
|
||||
pallas_call_p.multiple_results = True
|
||||
|
||||
@ -103,8 +104,6 @@ def _pad_values_to_block_dimension(value,
|
||||
Returns:
|
||||
A padded array.
|
||||
"""
|
||||
if block_shape is None:
|
||||
return value
|
||||
padded_shape = tuple(
|
||||
((v - 1) // b + 1) * b for v, b in zip(value.shape, block_shape)
|
||||
)
|
||||
@ -166,12 +165,14 @@ def _pallas_call_impl(*args, **kwargs):
|
||||
def _pallas_call_impl_interpret(
|
||||
*args,
|
||||
jaxpr: jax_core.Jaxpr,
|
||||
name: str, in_shapes, out_shapes,
|
||||
name: str,
|
||||
debug: bool,
|
||||
input_output_aliases: tuple[tuple[int, int], ...],
|
||||
grid_mapping: GridMapping,
|
||||
compiler_params: Any):
|
||||
del compiler_params, name, in_shapes
|
||||
del compiler_params, name
|
||||
# TODO(necula): cleanup
|
||||
out_shapes = grid_mapping.out_shapes
|
||||
# If we're in interpreter mode, we *scan* over the grid and eval the
|
||||
# discharged jaxpr.
|
||||
dynamic_grid_args, args = split_list( # type: ignore
|
||||
@ -192,21 +193,13 @@ def _pallas_call_impl_interpret(
|
||||
scalars, args = split_list(args, [grid_mapping.num_index_operands]) # type: ignore
|
||||
# invars: [*scalar_prefetch, *consts, *inputs, *outputs, *scratch]
|
||||
# args now contains: *consts, *inputs, *outputs
|
||||
num_invars = len(jaxpr.invars)
|
||||
num_inputs_outputs = (
|
||||
num_invars
|
||||
- grid_mapping.num_index_operands
|
||||
- grid_mapping.num_scratch_operands
|
||||
)
|
||||
_, _, scratch_invars = split_list(
|
||||
jaxpr.invars, [grid_mapping.num_index_operands, num_inputs_outputs]
|
||||
)
|
||||
scratch_invars = jaxpr.invars[grid_mapping.slice_scratch_ops]
|
||||
scratch_avals = [v.aval for v in scratch_invars]
|
||||
scratch_values = _initialize_scratch_vals(scratch_avals)
|
||||
|
||||
carry = []
|
||||
for x, bm in zip(itertools.chain(args, out), grid_mapping.block_mappings):
|
||||
if bm is not None and isinstance(bm.indexing_mode, pallas_core.Unblocked):
|
||||
if isinstance(bm.indexing_mode, pallas_core.Unblocked):
|
||||
padding = bm.indexing_mode.padding
|
||||
if padding is not None and any(p != (0, 0) for p in padding):
|
||||
if input_output_aliases:
|
||||
@ -215,18 +208,14 @@ def _pallas_call_impl_interpret(
|
||||
x = lax.pad(x, pad_value, [(*p, 0) for p in padding])
|
||||
carry.append(x)
|
||||
|
||||
block_shapes_without_mapped_dims = [
|
||||
None if block_mapping is None else block_mapping.block_shape
|
||||
for block_mapping in grid_mapping.block_mappings
|
||||
]
|
||||
is_indexing_dim = [
|
||||
None if bm is None else tuple(b is pallas_core.mapped for b in bm)
|
||||
for bm in block_shapes_without_mapped_dims
|
||||
tuple(b is pallas_core.mapped for b in bm.block_shape)
|
||||
for bm in grid_mapping.block_mappings
|
||||
]
|
||||
block_shapes = [
|
||||
None if (bm is None or iid is None)
|
||||
else tuple(1 if i else b for i, b in zip(iid, bm))
|
||||
for iid, bm in zip(is_indexing_dim, block_shapes_without_mapped_dims)
|
||||
None if iid is None
|
||||
else tuple(1 if i else b for i, b in zip(iid, bm.block_shape))
|
||||
for iid, bm in zip(is_indexing_dim, grid_mapping.block_mappings)
|
||||
]
|
||||
|
||||
# Pad values to evenly divide into block dimensions. This matches the
|
||||
@ -254,7 +243,7 @@ def _pallas_call_impl_interpret(
|
||||
local_grid_env = tuple(
|
||||
pallas_core.GridAxis(idx, b)
|
||||
for dim, (idx, b) in enumerate(zip(loop_idx, grid))
|
||||
if dim not in grid_mapping.mapped_dims
|
||||
if dim not in grid_mapping.vmapped_dims
|
||||
)
|
||||
carry, scratch = split_list(carry, [num_inout])
|
||||
with pallas_core.grid_env(local_grid_env):
|
||||
@ -288,7 +277,7 @@ def _pallas_call_impl_interpret(
|
||||
out_block_mappings = grid_mapping.block_mappings[len(args):]
|
||||
out_nopad = []
|
||||
for o, expected_o_shape, bm in zip(out, out_shapes, out_block_mappings):
|
||||
if bm is not None and isinstance(bm.indexing_mode, pallas_core.Unblocked):
|
||||
if isinstance(bm.indexing_mode, pallas_core.Unblocked):
|
||||
padding = bm.indexing_mode.padding
|
||||
if padding is not None and any(p != (0, 0) for p in padding):
|
||||
if input_output_aliases:
|
||||
@ -303,13 +292,16 @@ def _pallas_call_impl_interpret(
|
||||
|
||||
pallas_call_p.def_impl(_pallas_call_impl)
|
||||
|
||||
def _pallas_call_abstract_eval(*avals, out_shapes, **_):
|
||||
def _pallas_call_abstract_eval(*avals, grid_mapping, **_):
|
||||
out_shapes = grid_mapping.out_shapes
|
||||
return map(lambda x: jax_core.ShapedArray(x.shape, x.dtype), out_shapes)
|
||||
pallas_call_p.def_abstract_eval(_pallas_call_abstract_eval)
|
||||
|
||||
def _pallas_call_jvp_rule(primals, tangents, *, jaxpr, name,
|
||||
input_output_aliases: tuple[tuple[int, int], ...],
|
||||
in_shapes, out_shapes, grid_mapping, debug, interpret, compiler_params: Any):
|
||||
grid_mapping, debug, interpret, compiler_params: Any):
|
||||
# TODO(necula): cleanup
|
||||
out_shapes = grid_mapping.out_shapes
|
||||
if grid_mapping.num_dynamic_grid_bounds:
|
||||
raise NotImplementedError("interpret with dynamic grid bounds unsupported")
|
||||
if grid_mapping.num_index_operands:
|
||||
@ -345,14 +337,17 @@ def _pallas_call_jvp_rule(primals, tangents, *, jaxpr, name,
|
||||
print(jvp_jaxpr)
|
||||
in_bms, out_bms = split_list(grid_mapping.block_mappings, [len(primals)])
|
||||
jvp_bms = (*in_bms, *in_bms, *out_bms, *out_bms)
|
||||
jvp_grid_mapping = grid_mapping.replace(
|
||||
block_mappings=jvp_bms,
|
||||
num_inputs=grid_mapping.num_inputs * 2,
|
||||
num_outputs=grid_mapping.num_outputs * 2,
|
||||
)
|
||||
out_flat = pallas_call_p.bind(
|
||||
*primals,
|
||||
*tangents,
|
||||
jaxpr=jvp_jaxpr,
|
||||
name=f"{name}_jvp",
|
||||
in_shapes=(*in_shapes, *in_shapes),
|
||||
out_shapes=(*out_shapes, *out_shapes),
|
||||
grid_mapping=grid_mapping.replace(block_mappings=jvp_bms),
|
||||
grid_mapping=jvp_grid_mapping,
|
||||
interpret=interpret,
|
||||
debug=debug,
|
||||
input_output_aliases=(),
|
||||
@ -362,40 +357,38 @@ def _pallas_call_jvp_rule(primals, tangents, *, jaxpr, name,
|
||||
return out_primals, out_tangents
|
||||
ad.primitive_jvps[pallas_call_p] = _pallas_call_jvp_rule
|
||||
|
||||
def _batch_block_mapping(grid_mapping: GridMapping, aval: jax_core.ShapedArray,
|
||||
def _batch_block_mapping(grid_mapping: GridMapping,
|
||||
axis_size: int,
|
||||
aval: jax_core.ShapedArray,
|
||||
dim: int | batching.NotMapped,
|
||||
block_mapping: BlockMapping | None) -> BlockMapping:
|
||||
block_mapping: BlockMapping) -> BlockMapping:
|
||||
def _block_map_function(new_idx, *args):
|
||||
if block_mapping is None:
|
||||
indices = [0] * len(aval.shape)
|
||||
else:
|
||||
indices = jax_core.eval_jaxpr(block_mapping.index_map_jaxpr.jaxpr,
|
||||
block_mapping.index_map_jaxpr.consts,
|
||||
*args)
|
||||
indices = jax_core.eval_jaxpr(block_mapping.index_map_jaxpr.jaxpr,
|
||||
block_mapping.index_map_jaxpr.consts,
|
||||
*args)
|
||||
if dim is not batching.not_mapped:
|
||||
indices.insert(dim, new_idx)
|
||||
return tuple(indices)
|
||||
i32_aval = jax_core.ShapedArray((), jnp.int32)
|
||||
if block_mapping is None:
|
||||
idx_avals = [i32_aval] * (len(grid_mapping.grid) + 1)
|
||||
else:
|
||||
idx_avals = [i32_aval, *block_mapping.index_map_jaxpr.in_avals]
|
||||
idx_avals = [i32_aval, *block_mapping.index_map_jaxpr.in_avals]
|
||||
with grid_mapping.trace_env():
|
||||
block_mapping_jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(
|
||||
lu.wrap_init(_block_map_function), idx_avals)
|
||||
shape = aval.shape if block_mapping is None else block_mapping.block_shape
|
||||
shape = block_mapping.block_shape
|
||||
if dim is batching.not_mapped:
|
||||
new_block_shape = shape
|
||||
new_array_shape_dtype = block_mapping.array_shape_dtype
|
||||
else:
|
||||
new_block_shape = tuple_insert(shape, dim, pallas_core.mapped)
|
||||
new_array_shape_dtype = jax.ShapeDtypeStruct(
|
||||
tuple_insert(block_mapping.array_shape_dtype.shape,
|
||||
dim,
|
||||
axis_size),
|
||||
block_mapping.array_shape_dtype.dtype)
|
||||
|
||||
jaxpr = jax_core.ClosedJaxpr(block_mapping_jaxpr, consts)
|
||||
if block_mapping is None:
|
||||
return BlockMapping(
|
||||
block_shape=new_block_shape,
|
||||
index_map_jaxpr=jaxpr,
|
||||
indexing_mode=pallas_core.blocked,
|
||||
)
|
||||
return block_mapping.replace(block_shape=new_block_shape,
|
||||
array_shape_dtype=new_array_shape_dtype,
|
||||
index_map_jaxpr=jaxpr)
|
||||
|
||||
|
||||
@ -434,8 +427,6 @@ def _batch_with_explicit_loop(
|
||||
*,
|
||||
jaxpr: jax_core.Jaxpr,
|
||||
name: str,
|
||||
in_shapes: tuple[jax.ShapeDtypeStruct, ...],
|
||||
out_shapes: tuple[jax.ShapeDtypeStruct, ...],
|
||||
grid_mapping: GridMapping,
|
||||
input_output_aliases: tuple[tuple[int, int], ...],
|
||||
debug: bool,
|
||||
@ -453,7 +444,8 @@ def _batch_with_explicit_loop(
|
||||
to the current iteration index and dynamic_updates an (initially empty) output
|
||||
allocation.
|
||||
"""
|
||||
|
||||
# TODO(necula): cleanup
|
||||
out_shapes = grid_mapping.out_shapes
|
||||
if not dims:
|
||||
raise NotImplementedError("vmapping pallas_call with no arguments.")
|
||||
|
||||
@ -499,13 +491,10 @@ def _batch_with_explicit_loop(
|
||||
axis=dim,
|
||||
)
|
||||
)
|
||||
|
||||
batch_out = pallas_call_p.bind(
|
||||
*batch_args,
|
||||
jaxpr=jaxpr,
|
||||
name=name,
|
||||
in_shapes=in_shapes,
|
||||
out_shapes=out_shapes,
|
||||
grid_mapping=grid_mapping,
|
||||
input_output_aliases=input_output_aliases,
|
||||
debug=debug,
|
||||
@ -533,15 +522,14 @@ def _pallas_call_batching_rule(
|
||||
*,
|
||||
jaxpr: jax_core.Jaxpr,
|
||||
name: str,
|
||||
in_shapes: tuple[jax.ShapeDtypeStruct, ...],
|
||||
out_shapes: tuple[jax.ShapeDtypeStruct, ...],
|
||||
grid_mapping: GridMapping,
|
||||
input_output_aliases: tuple[tuple[int, int], ...],
|
||||
debug: bool,
|
||||
interpret: bool,
|
||||
compiler_params: Any,
|
||||
):
|
||||
|
||||
# TODO(necula): cleanup
|
||||
out_shapes = grid_mapping.out_shapes
|
||||
def _maybe_squeeze_out_bdim(
|
||||
x: jax.Array, bdim: int | batching.NotMapped
|
||||
) -> jax.Array:
|
||||
@ -558,8 +546,6 @@ def _pallas_call_batching_rule(
|
||||
*args,
|
||||
jaxpr=jaxpr,
|
||||
name=name,
|
||||
in_shapes=in_shapes,
|
||||
out_shapes=out_shapes,
|
||||
grid_mapping=grid_mapping,
|
||||
input_output_aliases=input_output_aliases,
|
||||
debug=debug,
|
||||
@ -591,8 +577,6 @@ def _pallas_call_batching_rule(
|
||||
dims=dynamic_grid_dims + dims,
|
||||
jaxpr=jaxpr,
|
||||
name=name,
|
||||
in_shapes=in_shapes,
|
||||
out_shapes=out_shapes,
|
||||
grid_mapping=grid_mapping,
|
||||
input_output_aliases=input_output_aliases,
|
||||
debug=debug,
|
||||
@ -625,8 +609,6 @@ def _pallas_call_batching_rule(
|
||||
dims=scalar_bdims + bdims,
|
||||
jaxpr=jaxpr,
|
||||
name=name,
|
||||
in_shapes=in_shapes,
|
||||
out_shapes=out_shapes,
|
||||
grid_mapping=grid_mapping,
|
||||
input_output_aliases=input_output_aliases,
|
||||
debug=debug,
|
||||
@ -652,7 +634,6 @@ def _pallas_call_batching_rule(
|
||||
all_dims = list(dims) + [0] * len(out_shapes)
|
||||
|
||||
num_index_operands = grid_mapping.num_index_operands
|
||||
num_constant_operands = grid_mapping.num_constant_operands
|
||||
num_scratch_operands = grid_mapping.num_scratch_operands
|
||||
|
||||
# Only add a batch dimension for the avals that actually have a grid mapping.
|
||||
@ -660,37 +641,27 @@ def _pallas_call_batching_rule(
|
||||
# operands (the last in the list).
|
||||
avals_to_batch = avals[num_index_operands:(len(avals) - num_scratch_operands)]
|
||||
batched_block_mappings = map(
|
||||
partial(_batch_block_mapping, grid_mapping),
|
||||
partial(_batch_block_mapping, grid_mapping, axis_size),
|
||||
avals_to_batch,
|
||||
all_dims[num_index_operands:],
|
||||
block_mappings,
|
||||
)
|
||||
|
||||
# TODO(necula): should fix in_shapes to include the consts
|
||||
dims_no_consts = (
|
||||
dims[:num_index_operands] +
|
||||
dims[num_index_operands + num_constant_operands:]
|
||||
)
|
||||
batched_in_shapes = tuple(
|
||||
jax.ShapeDtypeStruct(x.shape if dim is batching.not_mapped else
|
||||
tuple_insert(x.shape, dim, axis_size),
|
||||
x.dtype)
|
||||
for x, dim in zip(in_shapes, dims_no_consts))
|
||||
batched_out_shapes = tuple(
|
||||
jax.ShapeDtypeStruct(tuple_insert(x.shape, 0, axis_size), x.dtype)
|
||||
for x in out_shapes)
|
||||
|
||||
index_map_tree_args, index_map_tree_kwargs = grid_mapping.index_map_tree.unflatten(grid_mapping.index_map_avals)
|
||||
assert not index_map_tree_kwargs
|
||||
batched_index_map_args = (jax_core.ShapedArray((), jnp.int32),) + index_map_tree_args
|
||||
batched_index_map_avals, batched_index_map_tree = tree_util.tree_flatten((batched_index_map_args, {}))
|
||||
batched_grid_mapping = grid_mapping.replace(
|
||||
grid=(axis_size, *grid_mapping.grid),
|
||||
block_mappings=tuple(batched_block_mappings),
|
||||
mapped_dims=(0,) + tuple(a + 1 for a in grid_mapping.mapped_dims))
|
||||
index_map_avals=batched_index_map_avals,
|
||||
index_map_tree=batched_index_map_tree,
|
||||
vmapped_dims=(0,) + tuple(a + 1 for a in grid_mapping.vmapped_dims))
|
||||
out = pallas_call_p.bind(
|
||||
*dynamic_grid_args,
|
||||
*args,
|
||||
jaxpr=jaxpr,
|
||||
name=f"batched_{name}",
|
||||
in_shapes=batched_in_shapes,
|
||||
out_shapes=batched_out_shapes,
|
||||
grid_mapping=batched_grid_mapping,
|
||||
input_output_aliases=input_output_aliases,
|
||||
debug=debug,
|
||||
@ -725,7 +696,6 @@ def pallas_call_checkify_rule(error: checkify.Error,
|
||||
interpret: bool,
|
||||
input_output_aliases: tuple[tuple[int, int], ...],
|
||||
grid_mapping: GridMapping,
|
||||
out_shapes,
|
||||
**kwargs):
|
||||
# We implement the checkify rule in 4 steps:
|
||||
# 1) First, trace the kernel body to get the expected error shapes.
|
||||
@ -831,37 +801,25 @@ def pallas_call_checkify_rule(error: checkify.Error,
|
||||
|
||||
# Prepare pallas_call inputs. We need to create new block specs
|
||||
# for the new error inputs and outputs.
|
||||
scalar_avals = map(checkify.get_shaped_aval, scalars)
|
||||
error_block_specs = [pallas_core.BlockSpec(
|
||||
index_map=lambda *args: (0,) * len(error.shape),
|
||||
block_shape=error.shape)
|
||||
for error in shaped_err_avals]
|
||||
error_block_specs = [pallas_core.BlockSpec(None, None)] * len(shaped_err_avals)
|
||||
error_paths, _ = unzip2(tree_util.tree_flatten_with_path(error_block_specs)[0])
|
||||
grid_avals = [
|
||||
jax_core.ShapedArray((), jnp.dtype("int32"))] * len(grid_mapping.grid)
|
||||
scalar_ref_avals = [
|
||||
pallas_core.AbstractMemoryRef(
|
||||
jax_core.ShapedArray(aval.shape, aval.dtype),
|
||||
pallas_core.MemorySpace.INDEX)
|
||||
for aval in scalar_avals]
|
||||
grid_tree = tree_util.tree_structure(((*grid_avals, *scalar_avals), {}))
|
||||
error_block_mappings = map(
|
||||
partial(
|
||||
pallas_core._convert_block_spec_to_block_mapping,
|
||||
(*grid_avals, *scalar_ref_avals),
|
||||
in_tree=grid_tree,
|
||||
index_map_avals=grid_mapping.index_map_avals,
|
||||
index_map_tree=grid_mapping.index_map_tree,
|
||||
grid=grid_mapping.grid,
|
||||
mapped_dims=grid_mapping.mapped_dims,
|
||||
mapped_dims=grid_mapping.vmapped_dims,
|
||||
what="error"),
|
||||
error_block_specs, error_paths, error_memref_aval)
|
||||
error_block_specs, error_paths, shaped_err_avals)
|
||||
input_block_mappings, output_block_mappings = split_list(
|
||||
grid_mapping.block_mappings, [num_kernel_inputs,])
|
||||
grid_mapping_with_error = grid_mapping.replace(
|
||||
block_mappings=(*error_block_mappings, *input_block_mappings,
|
||||
*error_block_mappings, *output_block_mappings)
|
||||
*error_block_mappings, *output_block_mappings),
|
||||
num_inputs=grid_mapping.num_inputs + len(error_block_mappings),
|
||||
num_outputs=grid_mapping.num_outputs + len(error_block_mappings)
|
||||
)
|
||||
error_out_shapes = tuple(
|
||||
jax.ShapeDtypeStruct(e.shape, e.dtype) for e in shaped_err_avals)
|
||||
# Bump all input_output_aliases by num_err_vals to make room for error
|
||||
# TODO(justinfu): Don't bump scalars here.
|
||||
input_output_aliases = tuple(
|
||||
@ -870,17 +828,11 @@ def pallas_call_checkify_rule(error: checkify.Error,
|
||||
(i+num_scalars, i) for i in range(num_err_vals)) + input_output_aliases
|
||||
|
||||
new_vals_in = [*scalars, *err_vals, *args]
|
||||
new_input_shapes = tuple(
|
||||
jax.ShapeDtypeStruct(x.shape, x.dtype) for x in [
|
||||
*scalars, *shaped_err_avals, *args])
|
||||
del kwargs['in_shapes']
|
||||
result = pallas_call_p.bind(*dynamic_grid_bounds, *new_vals_in,
|
||||
jaxpr=final_jaxpr,
|
||||
interpret=interpret,
|
||||
grid_mapping=grid_mapping_with_error,
|
||||
input_output_aliases=input_output_aliases_with_error,
|
||||
in_shapes=new_input_shapes,
|
||||
out_shapes=error_out_shapes + out_shapes,
|
||||
**kwargs)
|
||||
errors, results = split_list(result, [num_err_vals])
|
||||
# TODO(b/350593266): Remove line below once we support ()-shaped scalars.
|
||||
@ -890,25 +842,20 @@ def pallas_call_checkify_rule(error: checkify.Error,
|
||||
checkify.error_checks[pallas_call_p] = pallas_call_checkify_rule
|
||||

@weakref_lru_cache
def _trace_to_jaxpr(fun: Callable, grid_spec: GridSpec,
flat_in_avals: Sequence[jax_core.AbstractValue],
flat_out_avals: Sequence[jax_core.AbstractValue],
in_tree: tree_util.PyTreeDef,
in_paths: Sequence[tree_util.KeyPath],
out_tree: tree_util.PyTreeDef,
out_paths: Sequence[tree_util.KeyPath],
interpret: bool):
avals, grid_mapping = grid_spec.get_grid_mapping(flat_in_avals, in_tree, in_paths,
flat_out_avals, out_tree, out_paths)
def _trace_kernel_to_jaxpr(fun: Callable,
grid_mapping: GridMapping,
kernel_avals: tuple[pallas_core.AbstractMemRef, ...],
kernel_in_tree: tree_util.PyTreeDef,
interpret: bool):
if interpret:
avals = jax.tree_util.tree_map(_logical_aval_to_interpret_mode_aval, avals)
jaxpr_flat_avals, jaxpr_in_tree = tree_util.tree_flatten(avals)
wrapped_fun, out_tree_thunk = api_util.flatten_fun_nokwargs(
lu.wrap_init(fun), jaxpr_in_tree)
debug = pe.debug_info(fun, jaxpr_in_tree, out_tree_thunk, False, "pallas_call")
kernel_avals = tuple(map(_logical_aval_to_interpret_mode_aval,
kernel_avals))
wrapped_kernel_fun, out_tree_thunk = api_util.flatten_fun_nokwargs(
lu.wrap_init(fun), kernel_in_tree)
debug = pe.debug_info(fun, kernel_in_tree, out_tree_thunk, False, "pallas_call")
with grid_mapping.trace_env():
jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(wrapped_fun,
jaxpr_flat_avals, debug)
jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(wrapped_kernel_fun,
kernel_avals, debug)
if consts:
# Pad ``block_mappings`` to account for the hoisted constants.
# The constants will be right after the index operands and just before
@ -918,20 +865,14 @@ def _trace_to_jaxpr(fun: Callable, grid_spec: GridSpec,
index=grid_mapping.num_index_operands,
make_abstract_ref=lambda aval: pallas_core.AbstractMemoryRef(aval, None))
num_constant_operands = len(consts)
# TODO(necula): refactor grid_mapping to remove this code duplication
grid_avals = [jax_core.ShapedArray((), jnp.dtype("int32"))] * len(grid_mapping.grid)
if grid_mapping.num_index_operands:
grid_avals += flat_in_avals[:grid_mapping.num_index_operands] # type: ignore
# Create args, kwargs pytree def
grid_tree = tree_util.tree_structure((tuple(grid_avals), {}))
const_block_mappings = []
for c_idx, c in enumerate(consts):
const_block_mapping = pallas_core._convert_block_spec_to_block_mapping(
grid_avals,
pallas_core.BlockSpec(None, None),
path=(tree_util.SequenceKey(c_idx),),
aval=jax_core.ShapedArray(c.shape, c.dtype),
in_tree=grid_tree,
array_aval=jax_core.ShapedArray(c.shape, c.dtype),
index_map_avals=grid_mapping.index_map_avals,
index_map_tree=grid_mapping.index_map_tree,
grid=grid_mapping.grid,
mapped_dims=(),
what="consts",
@ -942,7 +883,12 @@ def _trace_to_jaxpr(fun: Callable, grid_spec: GridSpec,
block_mappings=(*const_block_mappings, *grid_mapping.block_mappings),
num_constant_operands=num_constant_operands,
)
return grid_mapping, jaxpr, consts, out_tree_thunk()
kernel_out_tree = out_tree_thunk()
if kernel_out_tree != tree_util.tree_structure(None):
raise ValueError(
"The kernel function in a pallas_call should return None. "
f"Found a PyTree: {kernel_out_tree}")
return grid_mapping, jaxpr, consts

def _extract_function_name(f: Callable, name: str | None) -> str:
if name is None:
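For context, `_trace_kernel_to_jaxpr` now rejects, at trace time, kernels that return a value. A small sketch of what this rules out (hypothetical kernel, public API assumed):

import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl

def bad_kernel(x_ref, o_ref):
    o_ref[...] = x_ref[...] * 2.0
    return o_ref[...]  # kernels communicate through refs; returning a value is an error

x = jnp.ones((8, 128), jnp.float32)
# pl.pallas_call(bad_kernel, out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype))(x)
# -> ValueError: The kernel function in a pallas_call should return None. ...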
@ -1044,7 +990,7 @@ def pallas_call(
*,
grid_spec: GridSpec | None = None,
debug: bool = False,
grid: Grid | None = None,
grid: Grid = (),
in_specs: BlockSpecTree = no_block_spec,
out_specs: BlockSpecTree = no_block_spec,
input_output_aliases: dict[int, int] = {},
@ -1066,8 +1012,7 @@ def pallas_call(
debug: if True, Pallas prints various intermediate forms of the kernel
as it is being processed.
grid: the iteration space, as a tuple of integers. The kernel is executed
as many times as ``prod(grid)``. The default value ``None`` is equivalent
to ``()``.
as many times as ``prod(grid)``.
See details at :ref:`pallas_grid`.
in_specs: a PyTree of :class:`jax.experimental.pallas.BlockSpec` with
a structure matching that of the positional arguments.
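For context, with the new default `grid=()` the kernel body runs exactly once over the whole, un-blocked operands. A minimal sketch (not taken from this change), assuming the public `pl.pallas_call` API:

import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl

def copy_kernel(x_ref, o_ref):
    o_ref[...] = x_ref[...]

x = jnp.arange(8 * 128, dtype=jnp.float32).reshape(8, 128)
# grid=() by default -> a single program instance sees the full arrays.
y = pl.pallas_call(
    copy_kernel,
    out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype))(x)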
@ -1097,10 +1042,23 @@ def pallas_call(
name = _extract_function_name(f, name)
if compiler_params is None:
compiler_params = {}
if grid is not None and grid_spec is not None:
raise ValueError("Cannot specify both grid and grid_spec at the same time.")

if grid_spec is None:
grid_spec = GridSpec(grid, in_specs, out_specs)
else:
if grid:
raise ValueError(
"If `grid_spec` is specified, then `grid` must "
f"be `()`. It is {grid}")
if in_specs is not no_block_spec:
raise ValueError(
"If `grid_spec` is specified, then `in_specs` must "
f"be `no_block_spec`. It is {in_specs}")
if out_specs is not no_block_spec:
raise ValueError(
"If `grid_spec` is specified, then `out_specs` must "
f"be `no_block_spec`. It is {out_specs}")
del grid, in_specs, out_specs
grid_spec, dynamic_grid_bounds = grid_spec.unzip_dynamic_grid_bounds()
# TODO(necula): this canonicalization may be convenient for some usage
# but it is lossy, because it prevents expressing functions that return
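For context, `grid`/`in_specs`/`out_specs` and `grid_spec` are now mutually exclusive ways of describing the iteration space. A sketch of the two call styles (shapes, index maps, and the kernel are made up for illustration):

import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl

def copy_kernel(x_ref, o_ref):
    o_ref[...] = x_ref[...]

x = jnp.ones((16, 128), jnp.float32)
out_shape = jax.ShapeDtypeStruct(x.shape, x.dtype)
block = pl.BlockSpec((8, 128), lambda i: (i, 0))

# Style 1: pass grid and specs directly.
y1 = pl.pallas_call(copy_kernel, out_shape=out_shape,
                    grid=(2,), in_specs=[block], out_specs=block)(x)

# Style 2: bundle them in a GridSpec; grid/in_specs/out_specs must then stay at
# their defaults, otherwise pallas_call raises a ValueError.
grid_spec = pl.GridSpec(grid=(2,), in_specs=[block], out_specs=block)
y2 = pl.pallas_call(copy_kernel, out_shape=out_shape, grid_spec=grid_spec)(x)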
@ -1119,14 +1077,14 @@ def pallas_call(
for a in flat_args)
flat_out_avals = tuple(jax_core.ShapedArray(v.shape, v.dtype)
for v in flat_out_shapes)
# TODO(necula): check that input_output_aliases is well-formed: shapes match, no duplicates, etc.
grid_mapping, jaxpr, consts, f_out_tree = _trace_to_jaxpr(
f, grid_spec, flat_in_avals, flat_out_avals, in_tree, in_paths,
out_tree, out_paths, interpret=interpret)
if f_out_tree != tree_util.tree_flatten(None)[1]:
raise ValueError(
"The kernel function in a pallas_call should return None. "
f"Found a PyTree: {f_out_tree}")
# TODO(necula): check that input_output_aliases is well-formed: no duplicates, etc.
kernel_avals, grid_mapping = grid_spec.get_grid_mapping(
flat_in_avals, in_tree, in_paths,
flat_out_avals, out_tree, out_paths)
flat_kernel_avals, kernel_in_tree = tree_util.tree_flatten(kernel_avals)
grid_mapping, jaxpr, consts = _trace_kernel_to_jaxpr(
f, grid_mapping, tuple(flat_kernel_avals), kernel_in_tree,
interpret=interpret)
for i_idx, o_idx in input_output_aliases.items():
if i_idx not in range(len(flat_in_avals)):
raise ValueError(
@ -1152,9 +1110,7 @@ def pallas_call(
out_flat = pallas_call_p.bind(
*dynamic_grid_bounds, *index_args, *consts, *rest_args,
jaxpr=jaxpr, name=name,
in_shapes=tuple(jax.ShapeDtypeStruct(a.shape, a.dtype)
for a in flat_args),
out_shapes=tuple(flat_out_shapes), debug=debug,
debug=debug,
interpret=interpret,
grid_mapping=grid_mapping,
input_output_aliases=tuple(input_output_aliases.items()),

@ -112,10 +112,8 @@ class LoweringError(Exception):


def _eval_index_map(
ctx: ModuleContext, idx, block_mapping: BlockMapping | None
ctx: ModuleContext, idx, block_mapping: BlockMapping
):
if block_mapping is None:
return None
block_indices = lower_jaxpr_to_triton_ir(
ctx, block_mapping.index_map_jaxpr.jaxpr, None, *idx
)
@ -187,13 +185,13 @@ def _process_grid_to_3d_grid(grid_mapping: GridMapping):

# Preserve grid order provided to pallas_call
for i, s in enumerate(grid_mapping.grid):
if i not in grid_mapping.mapped_dims:
if i not in grid_mapping.vmapped_dims:
launch_grid.append(s)
launch_grid_to_pallas_grid.append(i)

# For mapped dims, iterate from inner to outer. This follows the pallas_call
# batching rule that prepends the vmapped dimension.
for dim in reversed(grid_mapping.mapped_dims):
for dim in reversed(grid_mapping.vmapped_dims):
s = grid_mapping.grid[dim]
launch_grid.append(s)
launch_grid_to_pallas_grid.append(dim)
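For context, `vmapped_dims` (renamed from `mapped_dims`) records the grid axes that the `pallas_call` batching rule prepends; the Triton launch grid keeps the user's grid order and appends those batched axes afterwards. A sketch of where such axes come from (hypothetical shapes, public API assumed):

import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl

def scale_kernel(x_ref, o_ref):
    o_ref[...] = x_ref[...] * 2.0

call = pl.pallas_call(
    scale_kernel,
    out_shape=jax.ShapeDtypeStruct((8, 128), jnp.float32),
    grid=(8,),
    in_specs=[pl.BlockSpec((1, 128), lambda i: (i, 0))],
    out_specs=pl.BlockSpec((1, 128), lambda i: (i, 0)))

batched = jax.vmap(call)            # batching prepends a grid axis of size 4
y = batched(jnp.ones((4, 8, 128)))  # that prepended axis is what vmapped_dims tracks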
@ -287,7 +285,7 @@ def lower_jaxpr_to_triton_module(
local_program_ids = [
pid
for i, pid in enumerate(program_ids)
if i not in grid_mapping.mapped_dims
if i not in grid_mapping.vmapped_dims
]
ctx = ModuleContext(
name, grid_mapping, local_program_ids, mlir.TracebackCaches(), platform
@ -297,7 +295,7 @@ def lower_jaxpr_to_triton_module(
"Scalar prefetch not supported in Triton lowering."
)
for bm in grid_mapping.block_mappings:
if bm is not None and not isinstance(bm.indexing_mode, Blocked):
if not isinstance(bm.indexing_mode, Blocked):
raise NotImplementedError(
"Only Blocked indexing mode is supported in Triton lowering."
)
@ -305,20 +303,14 @@ def lower_jaxpr_to_triton_module(
functools.partial(_eval_index_map, ctx, program_ids),
grid_mapping.block_mappings,
)
consts_shapes = [
jax.ShapeDtypeStruct(v.aval.shape, v.aval.dtype)
for v in jaxpr.invars[grid_mapping.num_index_operands:grid_mapping.num_index_operands + grid_mapping.num_constant_operands]
]
block_infos = [
BlockInfo(
jax.ShapeDtypeStruct(shape_dtype.shape, shape_dtype.dtype),
start_idx,
block_mapping.block_shape,
)
if block_mapping is not None
else None
for shape_dtype, block_mapping, start_idx in zip(
(*consts_shapes, *in_out_shapes),
in_out_shapes,
grid_mapping.block_mappings,
start_indices,
)

@ -19,12 +19,10 @@ from __future__ import annotations
import io
from typing import Any

import jax
from jax import core as jax_core
from jax._src.interpreters import mlir
from jax._src.lib.mlir import ir
from jax._src.pallas import core as pallas_core
from jax._src.pallas.pallas_call import pallas_call_p
from jax._src.pallas.triton import lowering


@ -45,30 +43,16 @@ def pallas_call_lowering(
*in_nodes,
jaxpr: jax_core.Jaxpr,
name: str,
in_shapes: tuple[jax.ShapeDtypeStruct, ...],
out_shapes: tuple[jax.ShapeDtypeStruct, ...],
interpret: bool,
debug: bool,
input_output_aliases: tuple[tuple[int, int], ...],
grid_mapping: pallas_core.GridMapping,
compiler_params: dict[str, Any],
):
if interpret:
# TODO(necula): is this branch still needed?
return mlir.lower_fun(pallas_call_p.impl, multiple_results=True)(
ctx,
*in_nodes,
jaxpr=jaxpr,
name=name,
out_shapes=out_shapes,
in_shapes=in_shapes,
interpret=interpret,
debug=debug,
input_output_aliases=input_output_aliases,
grid_mapping=grid_mapping,
compiler_params=compiler_params,
)

del interpret
# TODO(necula): cleanup
in_shapes = grid_mapping.in_shapes
out_shapes = grid_mapping.out_shapes
if grid_mapping.num_dynamic_grid_bounds:
raise NotImplementedError(
"dynamic grid bounds not supported in the Triton backend"
@ -479,7 +479,6 @@ class CompatTest(bctu.CompatTestBase):
np.asarray(out), atol=1e-4, rtol=1e-4))

@jtu.parameterized_filterable(
one_containing="f32",
kwargs=[
dict(testcase_name=f"_dtype={dtype_name}", dtype_name=dtype_name)
for dtype_name in ("f32", "f64", "c64", "c128")])

@ -26,7 +26,7 @@ from jax import random
from jax._src import config
from jax._src import test_util as jtu
from jax._src.lax.control_flow.for_loop import for_loop
from jax._src.pallas.pallas_call import _trace_to_jaxpr
from jax._src.pallas.pallas_call import _trace_kernel_to_jaxpr
from jax.experimental import pallas as pl
from jax.experimental.pallas.ops.gpu import attention
from jax.experimental.pallas.ops.gpu import layer_norm
@ -129,7 +129,7 @@ class PallasBaseTest(jtu.JaxTestCase):
self.skipTest("Only works on non-Windows platforms")

super().setUp()
_trace_to_jaxpr.cache_clear()
_trace_kernel_to_jaxpr.cache_clear()

def pallas_call(self, *args, **kwargs):
return pl.pallas_call(*args, **kwargs, interpret=self.INTERPRET)

@ -27,7 +27,7 @@ from absl.testing import parameterized
import jax
from jax._src import config
from jax._src import test_util as jtu
from jax._src.pallas.pallas_call import _trace_to_jaxpr
from jax._src.pallas.pallas_call import _trace_kernel_to_jaxpr
import jax.numpy as jnp
from jax.experimental import pallas as pl
from jax import export
@ -93,7 +93,7 @@ class ShapePolyTest(jtu.JaxTestCase,
if sys.platform == "win32":
self.skipTest("Only works on non-Windows platforms")
super().setUp()
_trace_to_jaxpr.cache_clear()
_trace_kernel_to_jaxpr.cache_clear()

def test_copy(self):
# The blocks are static, but the input and the grid are of polymorphic

@ -32,7 +32,7 @@ from jax._src import config
from jax._src import dtypes
from jax._src import test_util as jtu
from jax._src.lax.control_flow.for_loop import for_loop
from jax._src.pallas.pallas_call import _trace_to_jaxpr
from jax._src.pallas.pallas_call import _trace_kernel_to_jaxpr
from jax.experimental import pallas as pl
import jax.numpy as jnp
import numpy as np
@ -145,14 +145,13 @@ class PallasBaseTest(jtu.JaxTestCase):
self.skipTest("Only works on non-Windows platforms")

super().setUp()
_trace_to_jaxpr.cache_clear()
_trace_kernel_to_jaxpr.cache_clear()

def pallas_call(self, *args, **kwargs):
return pl.pallas_call(*args, **kwargs, interpret=self.INTERPRET)


class PallasCallTest(PallasBaseTest):

def test_add_one(self):
if jtu.test_device_matches(["tpu"]) and not self.INTERPRET:
self.skipTest("On TPU the test works only in interpret mode")
@ -655,7 +654,6 @@ class PallasCallInterpreterTest(PallasCallTest):


class ApiErrorTest(PallasBaseTest):

def test_pallas_kernel_args_mismatch(self):
a = np.arange(256, dtype=np.int32)
f = self.pallas_call(lambda x_ref: None, # Missing o_ref
@ -719,7 +717,7 @@ class ApiErrorTest(PallasBaseTest):
in_specs=[pl.BlockSpec((4,), lambda: (0, 0))])
with self.assertRaisesRegex(
ValueError,
"Index map for input\\[0\\] must return 1 values to match .*Currently returning 2 values."):
"Index map for inputs\\[0\\] must return 1 values to match .*Currently returning 2 values."):
f(a)

def test_pallas_call_index_map_captures_consts(self):
@ -730,7 +728,7 @@ class ApiErrorTest(PallasBaseTest):
in_specs=[pl.BlockSpec((4,), lambda: index_map_result)])
with self.assertRaisesRegex(
NotImplementedError,
"Index map for input\\[0\\] captures constants"):
"Index map for inputs\\[0\\] captures constants"):
f(a)

def test_pallas_call_out_specs_mismatch_shape(self):
@ -752,7 +750,7 @@ class ApiErrorTest(PallasBaseTest):
in_specs=[pl.BlockSpec((1, 1), lambda: (0, 0))])
with self.assertRaisesRegex(
ValueError,
"Block shape for input\\[0\\] .* must have the same number of dimensions as the "
"Block shape for inputs\\[0\\] .* must have the same number of dimensions as the "
"array shape"):

f(a)
@ -762,7 +760,7 @@ class ApiErrorTest(PallasBaseTest):
out_specs=[pl.BlockSpec((1, 1), lambda: 0)])
with self.assertRaisesRegex(
ValueError,
"Block shape for output\\[0\\] .* must have the same number of dimensions as the "
"Block shape for outputs\\[0\\] .* must have the same number of dimensions as the "
"array shape"):
f(a)
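For context, the re-worded messages above come from validating block shapes against the array rank and index-map arity against the block shape. A sketch of one spec that passes and two that fail (hypothetical shapes, public API assumed):

import jax.numpy as jnp
from jax.experimental import pallas as pl

x = jnp.ones((16, 128), jnp.float32)

ok = pl.BlockSpec((8, 128), lambda i: (i, 0))   # 2-D block, index map returns 2 values

# pl.BlockSpec((8,), lambda i: (i,))            # 1-D block for a 2-D array:
#   "Block shape for inputs[0] ... must have the same number of dimensions as the array shape"
# pl.BlockSpec((8, 128), lambda i: i)           # 2-D block, only 1 index returned:
#   "Index map for inputs[0] must return 2 values to match block shape (8, 128)."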
@ -1893,7 +1891,6 @@ class PallasCheckifyInterpreterTest(PallasBaseTest):


class PallasCallNamedGridTest(PallasBaseTest):

def test_named_grid(self):

def kernel(x_ref, y_ref):

@ -24,7 +24,7 @@ from jax import random
from jax._src.lib import xla_extension
from jax._src import config
from jax._src import test_util as jtu
from jax._src.pallas.pallas_call import _trace_to_jaxpr
from jax._src.pallas.pallas_call import _trace_kernel_to_jaxpr
from jax.experimental import pallas as pl
import jax.numpy as jnp
import numpy as np
@ -52,7 +52,7 @@ class PallasBaseTest(jtu.JaxTestCase):
self.skipTest("Only works on non-Windows platforms")

super().setUp()
_trace_to_jaxpr.cache_clear()
_trace_kernel_to_jaxpr.cache_clear()

def pallas_call(self, *args, **kwargs):
return pl.pallas_call(*args, **kwargs, interpret=self.INTERPRET)

@ -29,7 +29,7 @@ from jax._src import state
from jax._src import test_util as jtu
from jax._src.interpreters import partial_eval as pe
from jax._src.lib import xla_extension
from jax._src.pallas.pallas_call import _trace_to_jaxpr
from jax._src.pallas.pallas_call import _trace_kernel_to_jaxpr
from jax.experimental import mesh_utils
from jax.experimental import mosaic
from jax.experimental import pallas as pl
@ -65,7 +65,7 @@ class PallasBaseTest(jtu.JaxTestCase):
if not jtu.test_device_matches(['tpu']) and not self.INTERPRET:
self.skipTest('Test requires TPUs, or interpret mode')
super().setUp()
_trace_to_jaxpr.cache_clear()
_trace_kernel_to_jaxpr.cache_clear()

def pallas_call(self, *args, **kwargs):
return pl.pallas_call(*args, **kwargs, interpret=self.INTERPRET)
@ -117,12 +117,17 @@ class PallasCallScalarPrefetchTest(PallasBaseTest):

@jtu.parameterized_filterable(
kwargs=[
dict(scratch=scratch, vmap=vmap)
dict(scratch=scratch, vmap=vmap, dyn_grid=dyn_grid)
for scratch in [True, False]
for vmap in [True, False]
for vmap in [False, True]
for dyn_grid in [False, True]
]
)
def test_scalar_prefetch_hoisted_const(self, *, scratch: bool, vmap: bool):
def test_scalar_prefetch_calling_convention(
self, *,
scratch: bool, vmap: bool, dyn_grid: bool):
# Tests that we process correctly all the various inputs and outputs:
# dynamic_grid_dims, index, inputs, outputs, scratch.
if jtu.test_device_matches(["cpu"]) and jax.config.x64_enabled:
self.skipTest("TODO: dslice(start, 1) raises error about slice inputs being int32 and int64")
# to_store will be hoisted as constants. Choose distinct shapes from in/outs.
@ -133,30 +138,39 @@ class PallasCallScalarPrefetchTest(PallasBaseTest):
x_shape = (16, 128)
x = np.arange(math.prod(x_shape), dtype=np.float32).reshape(x_shape)

def f(x):
def f(x, grid_size):
s = jnp.array([1, 0], jnp.int32) # iteration 0 -> 1, iteration 1 -> 0
@functools.partial(
self.pallas_call,
out_shape=jax.ShapeDtypeStruct((64, 128), x.dtype),
grid_spec=pltpu.PrefetchScalarGridSpec(
num_scalar_prefetch=1,
grid=(2,),
num_scalar_prefetch=1, # 1 pytree
grid=(grid_size,),
in_specs=[pl.BlockSpec((8, 128),
lambda i, s_ref: (pl.load(s_ref, (i,)), 0))],
lambda i, s_ref: (pl.load(s_ref[0], (i,)), 0))],
out_specs=pl.BlockSpec((32, 128),
lambda i, s_ref: (pl.load(s_ref, i), 0)),
lambda i, s_ref: (pl.load(s_ref[0], i), 0)),
scratch_shapes=([pltpu.SemaphoreType.REGULAR((3,))] if scratch
else []),
),
)
def kernel(s_ref, src, dst, *scratch_refs):
def kernel(s_refs, src, dst, *scratch_refs):
s_ref, s2, s3 = s_refs
assert s_ref.shape == (2,)
assert s2.shape == (3,)
assert s3 is None
store_idx = s_ref[pl.program_id(0)]
pl.store(dst, (pl.dslice(store_idx, 1), slice(None)), to_store)
return kernel(s, x)
# Pass a pytree of scalars
return kernel((s, np.arange(3, dtype=np.int32), None), x)

if dyn_grid:
f = jax.jit(f)
if vmap:
f = jax.vmap(f)
res = f(x)
res = jax.vmap(lambda x: f(x, 2))(x)
else:
res = f(x, 2)

if vmap:
for i in range(x.shape[0]):
self.assertAllClose(res[i, 0:1], to_store)
@ -165,6 +179,35 @@ class PallasCallScalarPrefetchTest(PallasBaseTest):
self.assertAllClose(res[0:1], to_store)
self.assertAllClose(res[33:34], to_store)
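For context, the test above exercises the TPU scalar-prefetch calling convention: the prefetch operands form a single pytree that is passed before the array arguments and is seen first by the kernel and by every index map. A condensed sketch of that convention (shapes and values are made up), assuming the public `jax.experimental.pallas.tpu` module:

import jax
import numpy as np
from jax.experimental import pallas as pl
from jax.experimental.pallas import tpu as pltpu

def kernel(s_refs, x_ref, o_ref):
    s, _extra, _none_leaf = s_refs      # the scalar-prefetch pytree comes first
    o_ref[...] = x_ref[...]

grid_spec = pltpu.PrefetchScalarGridSpec(
    num_scalar_prefetch=1,              # one pytree of scalar arrays
    grid=(2,),
    in_specs=[pl.BlockSpec((8, 128), lambda i, s: (pl.load(s[0], (i,)), 0))],
    out_specs=pl.BlockSpec((8, 128), lambda i, s: (pl.load(s[0], (i,)), 0)))

x = np.arange(16 * 128, dtype=np.float32).reshape(16, 128)
scalars = (np.array([1, 0], np.int32), np.arange(3, dtype=np.int32), None)
y = pl.pallas_call(kernel,
                   out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
                   grid_spec=grid_spec)(scalars, x)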
def test_with_unhashable_grid_spec(self):
# Make sure that we don't crash when the GridSpec has non-hashable parts
@functools.partial(
self.pallas_call,
out_shape=[[jax.ShapeDtypeStruct((8, 128), np.int32)]],
grid_spec=pltpu.PrefetchScalarGridSpec(
num_scalar_prefetch=1, # 1 pytree
grid=(1,),
in_specs=[[pl.BlockSpec((8, 128),
lambda i, s_ref: (0, 0))]],
out_specs=[[pl.BlockSpec((8, 128),
lambda i, s_ref: (0, 0))]],
scratch_shapes=[[pltpu.SemaphoreType.REGULAR((3,))]],
),
)
def kernel(s_ref, x_ref, o_ref, scratch_ref):
assert isinstance(s_ref, list)
assert isinstance(x_ref, list)
assert isinstance(o_ref, list)
assert isinstance(scratch_ref, list)
o_ref[0][...] = x_ref[0][...]

x_shape = (8, 128)
s = np.array([0, 1], np.int32)
x = np.arange(math.prod(x_shape), dtype=np.int32).reshape(x_shape)
res = kernel([s, s], [x])
self.assertIsInstance(res, tuple) # Even though we asked for a list!
self.assertAllClose(res[0][0], x)

def test_block_spec_with_wrong_block_shape_errors(self):
def body(x_ref, o_ref):
o_ref[...] = x_ref[...]
@ -210,7 +253,7 @@ class PallasCallScalarPrefetchTest(PallasBaseTest):
x = jnp.ones((16, 128))
with self.assertRaisesRegex(
ValueError,
r'Index map for input\[0\] must return 2 values to match block_shape=\(8, 128\).'
r'Index map for inputs\[0\] must return 2 values to match block shape \(8, 128\).'
' Currently returning 1 values.'):
_ = self.pallas_call(
body,
@ -2356,7 +2399,6 @@ class PallasCallTPUBooleanInterpretTest(PallasCallTPUBooleanTest):


class PallasCallTPUCheckifyTest(PallasBaseTest):

@parameterized.parameters((2,), (5,), (6,), (7,))
def test_checkify_with_scalar_prefetch(self, threshold):
def body(scalar_ref, x_ref, o_ref):

@ -318,7 +318,6 @@ class PallasBaseTest(jtu.JaxTestCase):


class SplashAttentionTest(PallasBaseTest):

@parameterized.product(
is_mqa=(False, True),
is_segmented=(False, True),
@ -544,7 +543,7 @@ class SplashAttentionTest(PallasBaseTest):
data.draw(mha_strategy())
)

# Avoid segment ids for rectangular matrices, as its hard to enforce
# Avoid segment ids for rectangular matrices, as it's hard to enforce
# valid masks (non-0 rows).
hp.assume(q_seq_len == kv_seq_len or not is_segmented)