Mirror of https://github.com/ROCm/jax.git (synced 2025-04-15 19:36:06 +00:00)
[pallas] More simplification of grid mapping and calling convention
In the previous PR #22552 I expanded `GridMapping` to encode more parts of the calling convention. Here we use that new functionality and clean up some code. I have removed the internal methods from `BlockSpec` and `GridSpec` because these classes are part of the API. I also added entries to pallas/CHANGELOG.
This commit is contained in:
parent 68972de021
commit 70a11acbb1
@@ -10,6 +10,7 @@ Classes
:toctree: _autosummary

BlockSpec
GridSpec
Slice

Functions

@@ -34,4 +35,4 @@ Functions
atomic_or
atomic_xchg

debug_print
@@ -17,6 +17,13 @@ Remember to align the itemized text with the first line of an item within a list

* {class}`jax.experimental.pallas.BlockSpec` now expects `block_shape` to
  be passed *before* `index_map`. The old argument order is deprecated and
  will be removed in a future release.
* {class}`jax.experimental.pallas.GridSpec` no longer has the `in_specs_tree`
  and `out_specs_tree` fields; the `in_specs` and `out_specs` fields now
  store the values as pytrees of BlockSpec. Previously, `in_specs` and
  `out_specs` were flattened ({jax-issue}`#22552`).
* The method `compute_index` of {class}`jax.experimental.pallas.BlockSpec` has
  been removed because it is private. Similarly, `get_grid_mapping` and
  `unzip_dynamic_grid_bounds` have been removed from `GridSpec` ({jax-issue}`#22593`).
* Fixed the interpreter mode to work with a BlockSpec that involves padding
  ({jax-issue}`#22275`).
  Padding in interpreter mode will be with NaN, to help debug out-of-bounds
  accesses.
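The entry above about the new `BlockSpec` argument order is easiest to see in code. Below is a minimal sketch, not part of this commit, of a tiled element-wise add; it assumes 2D float inputs whose dimensions are multiples of the 128x128 block, and it passes `block_shape` before `index_map`.

import jax
from jax.experimental import pallas as pl

def add_kernel(x_ref, y_ref, o_ref):
  # Each kernel invocation sees one (128, 128) block of each operand.
  o_ref[...] = x_ref[...] + y_ref[...]

def add(x, y):
  # New argument order: block_shape first, then index_map.
  spec = pl.BlockSpec((128, 128), lambda i, j: (i, j))
  return pl.pallas_call(
      add_kernel,
      grid=(x.shape[0] // 128, x.shape[1] // 128),
      in_specs=[spec, spec],
      out_specs=spec,
      out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
  )(x, y)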
@@ -1705,7 +1705,7 @@ def canonicalize_dim(d: DimSize, context: str="") -> DimSize:
"""Canonicalizes and checks for errors in a user-provided shape dimension value.

Args:
f: a Python value that represents a dimension.
d: a Python value that represents a dimension.

Returns:
A canonical dimension value.
@@ -15,12 +15,13 @@
"""Module for pallas-core functionality."""
from __future__ import annotations

from collections.abc import Callable, Iterator, Sequence
from collections.abc import Callable, Iterable, Iterator, Sequence
import contextlib
import copy
import dataclasses
import enum
import functools
import itertools
import threading
from typing import Any, Hashable, Union
import warnings

@@ -55,6 +56,10 @@ TupleGrid = tuple[GridElement, ...]
Grid = Union[NamedGrid, TupleGrid]
StaticGrid = tuple[int, ...]
GridMappingGrid = tuple[int | DynamicGridDim, ...]

# Pytrees of jax.ShapeDtypeStruct
ShapeDtypeStructTree = tuple[jax.ShapeDtypeStruct, ...]

split_list = util.split_list

map, unsafe_map = util.safe_map, map
@@ -202,14 +207,13 @@ blocked = Blocked()
IndexingMode = Union[Blocked, Unblocked]


@dataclasses.dataclass(unsafe_hash=True)
@dataclasses.dataclass
class BlockSpec:
"""Specifies how an array should be sliced for each iteration of a kernel.
"""Specifies how an array should be sliced for each invocation of a kernel.

See :ref:`pallas_blockspec` for more details.
This object contains the parameters passed through the API.
An internal canonicalized version is in BlockMapping.
"""
# An internal canonicalized version is in BlockMapping.
block_shape: tuple[int | None, ...] | None = None
index_map: Callable[..., Any] | None = None
memory_space: Any | None = dataclasses.field(kw_only=True, default=None)

@@ -242,22 +246,25 @@ class BlockSpec:
self.memory_space = memory_space
self.indexing_mode = indexing_mode

def compute_index(self, *args):
assert self.index_map is not None
out = self.index_map(*args)
if not isinstance(out, tuple):
out = (out,)
return out

def compute_index(bs: BlockSpec, *args):
assert bs.index_map is not None
out = bs.index_map(*args)
if not isinstance(out, tuple):
out = (out,)
return out


class NoBlockSpec:
pass
def __repr__(self):
return "NoBlockSpec"
no_block_spec = NoBlockSpec()


# A PyTree of BlockSpec | NoBlockSpec.
# BlockSpecTree = Sequence[BlockSpec | NoBlockSpec, ...] | NoBlockSpec
BlockSpecTree = Any


@dataclasses.dataclass(frozen=True)
class BlockMapping:
"""An internal canonicalized version of BlockSpec.
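A standalone sketch, not from this diff, of the tuple normalization that `compute_index` performs on the `index_map` result; the helper name here is only illustrative.

def normalize_index(index_map, *args):
  # A scalar result from index_map is wrapped into a 1-tuple,
  # mirroring what compute_index does above.
  out = index_map(*args)
  return out if isinstance(out, tuple) else (out,)

assert normalize_index(lambda i: i, 3) == (3,)
assert normalize_index(lambda i, j: (i, j), 1, 2) == (1, 2)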
@@ -399,6 +406,7 @@ class GridMapping:

if self.grid_names is not None:
assert len(self.grid) == len(self.grid_names), (self.grid, self.grid_names)

for bm in self.block_mappings:
bm.check_invariants()
assert tuple(self.index_map_avals) == tuple(bm.index_map_jaxpr.in_avals), (
@@ -439,45 +447,52 @@ class GridMapping:

@property
def slice_index_ops(self):
"""Returns a slice object to select the index operands to a kernel."""
"""Returns a slice object to select the index operands to a kernel.
This works on a sequence that contains *index, *consts, *ins, *outs, *scratch.
"""
return slice(0, self.num_index_operands)

@property
def slice_block_ops(self):
"""Returns a slice to select all but the index operands to a kernel."""
"""Returns a slice to select all but the index operands to a kernel.
This works on a sequence that contains *index, *consts, *ins, *outs, *scratch.
"""
return slice(self.num_index_operands, None)

@property
def slice_scratch_ops(self):
"""Returns a slice object to select the scratch operands to a kernel."""
"""Returns a slice object to select the scratch operands to a kernel.
This works on a sequence that contains *index, *consts, *ins, *outs, *scratch.
"""
if self.num_scratch_operands:
return slice(-self.num_scratch_operands, None)
else:
return slice(0, 0)
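A quick illustration of the three slice properties above; this is a standalone sketch (not from the commit) using a plain Python list in the documented operand order *index, *consts, *ins, *outs, *scratch.

# Stand-in operand sequence in the order [*index, *consts, *ins, *outs, *scratch].
operands = ["idx0", "const0", "in0", "out0", "scratch0"]
num_index_operands, num_scratch_operands = 1, 1

slice_index_ops = slice(0, num_index_operands)
slice_block_ops = slice(num_index_operands, None)
slice_scratch_ops = (slice(-num_scratch_operands, None)
                     if num_scratch_operands else slice(0, 0))

assert operands[slice_index_ops] == ["idx0"]
assert operands[slice_block_ops] == ["const0", "in0", "out0", "scratch0"]
assert operands[slice_scratch_ops] == ["scratch0"]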
# TODO(necula): this is used to recover the old `in_shapes`, but it probably
# is not needed anymore, with some cleanup.
@property
def in_shapes(self) -> tuple[jax.ShapeDtypeStruct, ...]:
def in_shapes(self) -> Iterable[jax.ShapeDtypeStruct]:
"""The shapes of *index, *consts, *inputs."""
index_shapes = [jax.ShapeDtypeStruct(ia.inner_aval.shape,
index_shapes = (jax.ShapeDtypeStruct(ia.inner_aval.shape,
ia.inner_aval.dtype)
for ia in self.index_map_avals[len(self.grid):]]
consts_inputs_shapes = [
for ia in self.index_map_avals[len(self.grid):])
consts_inputs_shapes = (
bm.array_shape_dtype
for bm in self.block_mappings[
:self.num_constant_operands + self.num_inputs]]
return tuple(index_shapes + consts_inputs_shapes)
:self.num_constant_operands + self.num_inputs])
return itertools.chain(index_shapes, consts_inputs_shapes)

# TODO(necula): this is used to recover the old `out_shapes`, but it probably
# is not needed anymore, with some cleanup.
@property
def out_shapes(self) -> tuple[jax.ShapeDtypeStruct, ...]:
def block_mappings_output(self) -> Iterable[BlockMapping]:
return itertools.islice(
self.block_mappings,
self.num_constant_operands + self.num_inputs,
self.num_constant_operands + self.num_inputs + self.num_outputs)

@property
def out_shapes(self) -> Iterable[jax.ShapeDtypeStruct]:
return tuple(
bm.array_shape_dtype
for bm in self.block_mappings[
self.num_constant_operands + self.num_inputs:
self.num_constant_operands + self.num_inputs + self.num_outputs])
bm.array_shape_dtype for bm in self.block_mappings_output)
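The new `block_mappings_output` property simply slices the flat `block_mappings` tuple, which is ordered [*consts, *ins, *outs]. A standalone sketch of that selection, with strings standing in for BlockMapping objects:

import itertools

block_mappings = ("const0", "in0", "in1", "out0")  # stand-ins for BlockMapping objects
num_constant_operands, num_inputs, num_outputs = 1, 2, 1

outputs = itertools.islice(
    block_mappings,
    num_constant_operands + num_inputs,
    num_constant_operands + num_inputs + num_outputs)
assert list(outputs) == ["out0"]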
def _is_valid_grid_dim(dim: int | jax.Array) -> bool:
if isinstance(dim, jax.Array):

@@ -500,9 +515,9 @@ def _convert_block_spec_to_block_mapping(
if block_spec is no_block_spec:
block_spec = BlockSpec(None, None)
if block_spec.index_map is None:
compute_index = lambda *args: (0,) * len(array_aval.shape)
index_map_func = lambda *args: (0,) * len(array_aval.shape)
else:
compute_index = block_spec.compute_index
index_map_func = functools.partial(compute_index, block_spec)
if block_spec.block_shape is None:
block_shape = array_aval.shape
else:

@@ -522,7 +537,7 @@ def _convert_block_spec_to_block_mapping(
"dynamically-shaped blocks. "
f"{origin} has block_shape: {block_aval.shape}")

flat_index_map_fun, _ = api_util.flatten_fun(lu.wrap_init(compute_index),
flat_index_map_fun, _ = api_util.flatten_fun(lu.wrap_init(index_map_func),
index_map_tree)
with tracing_grid_env(grid, mapped_dims):
jaxpr, out_avals, consts, () = pe.trace_to_jaxpr_dynamic(flat_index_map_fun,
@@ -559,16 +574,21 @@ def _tile_ref(ref: state.AbstractRef, block_shape: tuple[int, ...] | None
shape = tuple(s for s in block_shape if s is not None)
return ref.update(inner_aval=ref.inner_aval.update(shape=shape))

@dataclasses.dataclass(init=False, unsafe_hash=True)
class GridSpec:
"""Encodes the parameters of the grid, as given through the API.
index_map_grid_aval = jax_core.ShapedArray((), jnp.int32)

An internal sanitized version is in GridMapping.
@dataclasses.dataclass(init=False)
class GridSpec:
"""Encodes the grid parameters for :func:`jax.experimental.pallas.pallas_call`.

See the documentation for :func:`jax.experimental.pallas.pallas_call`,
and also :ref:`pallas_grids_and_blockspecs` for a more detailed
description of the parameters.
"""
# A canonicalized internal version is in GridMapping.
grid: TupleGrid
grid_names: tuple[Hashable, ...] | None
in_specs: tuple[BlockSpec | NoBlockSpec, ...] | NoBlockSpec
out_specs: tuple[BlockSpec | NoBlockSpec, ...] | NoBlockSpec
in_specs: BlockSpecTree
out_specs: BlockSpecTree

def __init__(
self,
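For reference, a minimal sketch (not from this commit) of how a user-facing `GridSpec` is bundled and passed to `pallas_call`; it assumes a 1D input whose length is a multiple of 256.

import jax
from jax.experimental import pallas as pl

def copy_kernel(x_ref, o_ref):
  o_ref[...] = x_ref[...]

def copy(x):
  grid_spec = pl.GridSpec(
      grid=(x.shape[0] // 256,),
      in_specs=[pl.BlockSpec((256,), lambda i: (i,))],
      out_specs=pl.BlockSpec((256,), lambda i: (i,)),
  )
  return pl.pallas_call(
      copy_kernel,
      grid_spec=grid_spec,
      out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
  )(x)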
@@ -601,149 +621,151 @@ class GridSpec:
self.grid = grid # type: ignore
self.grid_names = grid_names

def get_grid_mapping(
self,
in_avals: Sequence[jax_core.AbstractValue],
in_tree: tree_util.PyTreeDef,
in_paths: Sequence[tree_util.KeyPath],
out_avals: Sequence[jax_core.AbstractValue],
out_tree: tree_util.PyTreeDef,
out_paths: Sequence[tree_util.KeyPath],
num_scalar_prefetch: int = 0,
scratch_shapes: Sequence[Any] = (),
) -> tuple[tuple[AbstractMemoryRef, ...],
GridMapping]:
assert all(i is None or isinstance(i, int) for i in self.grid)
grid_mapping_grid = tuple(
dynamic_grid_dim if d is None else d for d in self.grid
)
# The inputs for the index maps
index_map_avals = (
(jax_core.ShapedArray((), jnp.dtype("int32")),) * len(self.grid))
index_map_tree = tree_util.tree_structure((index_map_avals, {}))

if num_scalar_prefetch:
all_avals = tree_util.tree_unflatten(in_tree, in_avals)
scalar_avals, unflat_in_avals = split_list(
all_avals, [num_scalar_prefetch])
flat_scalar_avals, scalar_tree = tree_util.tree_flatten(scalar_avals)
num_flat_scalar_prefetch = len(flat_scalar_avals)
scalar_ref_avals = [
self._make_scalar_ref_aval(aval)
for aval in flat_scalar_avals]
jaxpr_scalar_ref_avals = tree_util.tree_unflatten(
scalar_tree, scalar_ref_avals)
in_avals, in_tree = tree_util.tree_flatten(tuple(unflat_in_avals))
index_map_tree = tree_util.tree_structure(((*index_map_avals,
*scalar_avals), {}))
index_map_avals = (*index_map_avals, *scalar_ref_avals)
del scalar_ref_avals, flat_scalar_avals, scalar_tree
del scalar_avals, unflat_in_avals, all_avals
else:
num_flat_scalar_prefetch = 0
jaxpr_scalar_ref_avals = ()

if scratch_shapes:
flat_scratch_shapes, scratch_tree = tree_util.tree_flatten(
scratch_shapes)
flat_scratch_avals = map(self._make_scratch_aval, flat_scratch_shapes)
num_flat_scratch_operands = len(flat_scratch_avals)
jaxpr_scratch_avals = tree_util.tree_unflatten(
scratch_tree, flat_scratch_avals)
if not isinstance(jaxpr_scratch_avals, (tuple, list)):
jaxpr_scratch_avals = (jaxpr_scratch_avals,)
del flat_scratch_avals, flat_scratch_shapes, scratch_tree
else:
num_flat_scratch_operands = 0
jaxpr_scratch_avals = ()

if self.in_specs is not no_block_spec:
flat_in_specs, in_specs_tree = tree_util.tree_flatten(self.in_specs)
if in_specs_tree != in_tree:
raise ValueError(
pytreedef_mismatch_err_msg("`in_specs`", in_specs_tree,
"inputs", in_tree))
else:
flat_in_specs = [no_block_spec] * len(in_avals)

in_block_mappings = map(
partial(
_convert_block_spec_to_block_mapping,
index_map_avals=index_map_avals,
index_map_tree=index_map_tree,
grid=grid_mapping_grid,
mapped_dims=(),
what="inputs",
),
flat_in_specs,
in_paths[num_flat_scalar_prefetch:],
in_avals,
)

if self.out_specs is not no_block_spec:
flat_out_specs, out_specs_tree = tree_util.tree_flatten(self.out_specs)
if out_specs_tree != out_tree:
raise ValueError(
pytreedef_mismatch_err_msg("`out_specs`", out_specs_tree,
"`out_shape`", out_tree))
else:
flat_out_specs = [no_block_spec] * len(out_avals)

out_block_mappings = map(
partial(
_convert_block_spec_to_block_mapping,
index_map_avals=index_map_avals,
index_map_tree=index_map_tree,
grid=grid_mapping_grid,
mapped_dims=(),
what="outputs",
),
flat_out_specs,
out_paths,
out_avals,
)
grid_mapping = GridMapping(
grid=grid_mapping_grid, # type: ignore[arg-type]
grid_names=self.grid_names,
block_mappings=(*in_block_mappings, *out_block_mappings),
index_map_avals=index_map_avals, # type: ignore[arg-type]
index_map_tree=index_map_tree,
vmapped_dims=(),
num_index_operands=num_flat_scalar_prefetch,
num_constant_operands=0, # Fixed up later
num_inputs=len(flat_in_specs),
num_outputs=len(flat_out_specs),
num_scratch_operands=num_flat_scratch_operands,
)
grid_mapping.check_invariants()
in_ref_avals = [bm.block_aval for bm in in_block_mappings]
jaxpr_in_ref_avals = tree_util.tree_unflatten(in_tree, in_ref_avals)
jaxpr_in_avals = (*jaxpr_scalar_ref_avals, *jaxpr_in_ref_avals)
out_ref_avals = [bm.block_aval for bm in out_block_mappings]
jaxpr_out_avals = tree_util.tree_unflatten(out_tree, out_ref_avals)
if not isinstance(jaxpr_out_avals, (tuple, list)):
jaxpr_out_avals = (jaxpr_out_avals,)
return (*jaxpr_in_avals, *jaxpr_out_avals,
*jaxpr_scratch_avals), grid_mapping
def _make_scratch_aval(self, obj: object) -> jax_core.AbstractValue:
assert False # Not needed in GridSpec

def _make_scalar_ref_aval(self, aval):
assert False # Not needed in GridSpec

def unzip_dynamic_grid_bounds(
self,
) -> tuple[GridSpec, tuple[Any, ...]]:
static_grid = tuple(
d if isinstance(d, int) else None for d in self.grid
)
dynamic_bounds = tuple(d for d in self.grid if not isinstance(d, int))
# We can't use dataclasses.replace, because our fields are incompatible
# with __init__'s signature.
static_self = copy.copy(self)
static_self.grid = static_grid # type: ignore
return static_self, dynamic_bounds
def get_grid_mapping(
grid_spec: GridSpec,
in_avals: Sequence[jax_core.AbstractValue],
in_tree: tree_util.PyTreeDef,
in_paths: Sequence[tree_util.KeyPath],
out_avals: Sequence[jax_core.AbstractValue],
out_tree: tree_util.PyTreeDef,
out_paths: Sequence[tree_util.KeyPath],
) -> tuple[tuple[jax_core.AbstractValue, ...],
GridMapping]:
assert all(i is None or isinstance(i, int) for i in grid_spec.grid)
grid_mapping_grid = tuple(
dynamic_grid_dim if d is None else d for d in grid_spec.grid
)
# The inputs for the index maps
index_map_avals = (
(index_map_grid_aval,) * len(grid_spec.grid))
index_map_tree = tree_util.tree_structure((index_map_avals, {}))

num_scalar_prefetch: int = getattr(grid_spec, "num_scalar_prefetch", 0)
if num_scalar_prefetch:
all_avals = tree_util.tree_unflatten(in_tree, in_avals)
scalar_avals, unflat_in_avals = split_list(
all_avals, [num_scalar_prefetch])
flat_scalar_avals, scalar_tree = tree_util.tree_flatten(scalar_avals)
num_flat_scalar_prefetch = len(flat_scalar_avals)
scalar_ref_avals = [
grid_spec._make_scalar_ref_aval(aval)
for aval in flat_scalar_avals]
jaxpr_scalar_ref_avals = tree_util.tree_unflatten(
scalar_tree, scalar_ref_avals)
in_avals, in_tree = tree_util.tree_flatten(tuple(unflat_in_avals))
index_map_tree = tree_util.tree_structure(((*index_map_avals,
*scalar_avals), {}))
index_map_avals = (*index_map_avals, *scalar_ref_avals)
del scalar_ref_avals, flat_scalar_avals, scalar_tree
del scalar_avals, unflat_in_avals, all_avals
else:
num_flat_scalar_prefetch = 0
jaxpr_scalar_ref_avals = ()

scratch_shapes: tuple[Any, ...] = getattr(grid_spec, "scratch_shapes", ())
if scratch_shapes:
flat_scratch_shapes, scratch_tree = tree_util.tree_flatten(
scratch_shapes)
flat_scratch_avals = map(grid_spec._make_scratch_aval, flat_scratch_shapes)
num_flat_scratch_operands = len(flat_scratch_avals)
jaxpr_scratch_avals = tree_util.tree_unflatten(
scratch_tree, flat_scratch_avals)
if not isinstance(jaxpr_scratch_avals, (tuple, list)):
jaxpr_scratch_avals = (jaxpr_scratch_avals,)
del flat_scratch_avals, flat_scratch_shapes, scratch_tree
else:
num_flat_scratch_operands = 0
jaxpr_scratch_avals = ()

if grid_spec.in_specs is not no_block_spec:
flat_in_specs, in_specs_tree = tree_util.tree_flatten(grid_spec.in_specs)
if in_specs_tree != in_tree:
raise ValueError(
pytreedef_mismatch_err_msg("`in_specs`", in_specs_tree,
"inputs", in_tree))
else:
flat_in_specs = [no_block_spec] * len(in_avals)

in_block_mappings = map(
partial(
_convert_block_spec_to_block_mapping,
index_map_avals=index_map_avals,
index_map_tree=index_map_tree,
grid=grid_mapping_grid,
mapped_dims=(),
what="inputs",
),
flat_in_specs,
in_paths[num_flat_scalar_prefetch:],
in_avals,
)

if grid_spec.out_specs is not no_block_spec:
flat_out_specs, out_specs_tree = tree_util.tree_flatten(grid_spec.out_specs)
if out_specs_tree != out_tree:
raise ValueError(
pytreedef_mismatch_err_msg("`out_specs`", out_specs_tree,
"`out_shape`", out_tree))
else:
flat_out_specs = [no_block_spec] * len(out_avals)

out_block_mappings = map(
partial(
_convert_block_spec_to_block_mapping,
index_map_avals=index_map_avals,
index_map_tree=index_map_tree,
grid=grid_mapping_grid,
mapped_dims=(),
what="outputs",
),
flat_out_specs,
out_paths,
out_avals,
)
grid_mapping = GridMapping(
grid=grid_mapping_grid, # type: ignore[arg-type]
grid_names=grid_spec.grid_names,
block_mappings=(*in_block_mappings, *out_block_mappings),
index_map_avals=index_map_avals, # type: ignore[arg-type]
index_map_tree=index_map_tree,
vmapped_dims=(),
num_index_operands=num_flat_scalar_prefetch,
num_constant_operands=0, # Fixed up later
num_inputs=len(flat_in_specs),
num_outputs=len(flat_out_specs),
num_scratch_operands=num_flat_scratch_operands,
)
grid_mapping.check_invariants()
in_ref_avals = [bm.block_aval for bm in in_block_mappings]
jaxpr_in_ref_avals = tree_util.tree_unflatten(in_tree, in_ref_avals)
jaxpr_in_avals = (*jaxpr_scalar_ref_avals,
*jaxpr_in_ref_avals)
out_ref_avals = [bm.block_aval for bm in out_block_mappings]
jaxpr_out_avals = tree_util.tree_unflatten(out_tree, out_ref_avals)
if not isinstance(jaxpr_out_avals, (tuple, list)):
jaxpr_out_avals = (jaxpr_out_avals,)
return (*jaxpr_in_avals, *jaxpr_out_avals,
*jaxpr_scratch_avals), grid_mapping
def unzip_dynamic_grid_bounds(
grid_spec: GridSpec) -> tuple[GridSpec, tuple[Any, ...]]:
static_grid = tuple(
d if isinstance(d, int) else None for d in grid_spec.grid
)
dynamic_bounds = tuple(d for d in grid_spec.grid if not isinstance(d, int))
# We can't use dataclasses.replace, because our fields are incompatible
# with __init__'s signature.
static_self = copy.copy(grid_spec)
static_self.grid = static_grid # type: ignore
return static_self, dynamic_bounds

def pytreedef_mismatch_err_msg(
@@ -67,7 +67,7 @@ class barrier_semaphore(semaphore_dtype): pass
class AbstractSemaphoreTyRules:
@staticmethod
def pallas_interpret_element_aval(_) -> jax_core.ShapedArray:
return jax_core.ShapedArray((), jnp.dtype('int32'))
return pallas_core.index_map_grid_aval

class AbstractSemaphoreTy(dtypes.ExtendedDType):
name: str
@@ -145,8 +145,8 @@ class PrefetchScalarGridSpec(pallas_core.GridSpec):
grid: TupleGrid
grid_names: tuple[Hashable, ...] | None
num_scalar_prefetch: int
in_specs: tuple[BlockSpec | NoBlockSpec, ...] | NoBlockSpec
out_specs: tuple[BlockSpec | NoBlockSpec, ...] | NoBlockSpec
in_specs: pallas_core.BlockSpecTree
out_specs: pallas_core.BlockSpecTree
scratch_shapes: tuple[Any, ...]

def __init__(

@@ -173,14 +173,6 @@ class PrefetchScalarGridSpec(pallas_core.GridSpec):
raise ValueError(f"No registered conversion for {type(obj)}. "
"Only VMEM and SemaphoreType are supported.")

def get_grid_mapping( # type: ignore[override]
self, in_avals, in_tree, in_paths, out_avals, out_tree, out_paths
) -> tuple[tuple[jax_core.AbstractValue, ...], GridMapping]:
return super().get_grid_mapping(in_avals, in_tree, in_paths,
out_avals, out_tree, out_paths,
num_scalar_prefetch=self.num_scalar_prefetch,
scratch_shapes=self.scratch_shapes)
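For context, a minimal sketch (not from this commit, TPU-only, with placeholder shapes) of how `PrefetchScalarGridSpec` is typically constructed; it assumes `idx` holds block-row indices and that the row count is a multiple of the block size.

import jax
from jax.experimental import pallas as pl
from jax.experimental.pallas import tpu as pltpu

def gather_kernel(idx_ref, x_ref, o_ref):
  # idx_ref is a prefetched scalar Ref; the index maps below read it too.
  o_ref[...] = x_ref[...]

def gather_rows(idx, x):
  grid_spec = pltpu.PrefetchScalarGridSpec(
      num_scalar_prefetch=1,
      grid=(x.shape[0] // 8,),
      in_specs=[pl.BlockSpec((8, x.shape[1]), lambda i, idx_ref: (idx_ref[i], 0))],
      out_specs=pl.BlockSpec((8, x.shape[1]), lambda i, idx_ref: (i, 0)),
  )
  return pl.pallas_call(
      gather_kernel,
      grid_spec=grid_spec,
      out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
  )(idx, x)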
@dataclasses.dataclass(frozen=True)
class TensorCore:
@@ -262,6 +262,7 @@ def _get_arg_type(
memory_space = TPUMemorySpace.VMEM
if isinstance(aval, tpu_core.AbstractSemaphore):
return aval_to_ir_type(aval), None
# TODO(necula): clean this None block_mapping
if block_mapping is None:
return aval_to_ir_type(aval, memory_space=memory_space), aval.shape
shape = tuple(1 if b is pl_core.mapped else b for b in block_mapping.block_shape)

@@ -296,6 +297,7 @@ class MosaicGridMapping:
self.jaxpr = jaxpr
self.block_mappings = grid_mapping.block_mappings
self.mapped_dims = grid_mapping.vmapped_dims
# TODO(necula): clean this using new grid_mapping helpers
num_scalar_prefetch = grid_mapping.num_index_operands
num_scratch = grid_mapping.num_scratch_operands
# jaxpr has signature [*scalar_prefetch, *consts, *in_ops, *out_ops, *scratch]

@@ -348,7 +350,7 @@ class MosaicGridMapping:
for aval in scratch_avals
)
self.grid_types, _ = unzip2([
_get_arg_type(jax_core.ShapedArray((), jnp.int32), None)
_get_arg_type(pl_core.index_map_grid_aval, None)
for _ in range(len(self.grid))
])
self._prepare_mesh_info(mesh)
@@ -432,9 +434,6 @@ def lower_jaxpr_to_module(
mesh: mesh_lib.Mesh | None = None,
for_verification: bool = False,
) -> tuple[Module, tuple[Any, ...]]:
# TODO(necula): cleanup
in_shapes = grid_mapping.in_shapes
out_shapes = grid_mapping.out_shapes
for bm in grid_mapping.block_mappings:
def err_details():
return (f"Block spec for {bm.origin} has block shape "
@@ -510,33 +509,17 @@ def lower_jaxpr_to_module(
window_params = []
grid = mosaic_grid_mapping.grid
if grid:
invars = jaxpr.invars
if grid_mapping.num_scratch_operands > 0:
invars = invars[
grid_mapping.num_index_operands:-grid_mapping.num_scratch_operands]
else:
invars = invars[grid_mapping.num_index_operands:]
# invars now = *consts, *ins, *outs
avals = tuple(v.aval for v in invars)
# TODO(necula): we should not need block_operand_shapes anymore
block_operand_shapes = (
*in_shapes[grid_mapping.num_index_operands:],
*out_shapes,
)
assert len(block_operand_shapes) == len(grid_mapping.block_mappings)
for i, (full_ty, bm, aval) in enumerate(
zip(block_operand_shapes, grid_mapping.block_mappings, avals)
):
for i, bm in enumerate(grid_mapping.block_mappings):
func_name = f"transform_{i}"
# ANY operands don't support windowing and require empty window_params.
if aval.memory_space == tpu_core.TPUMemorySpace.ANY:
if bm.block_aval.memory_space == tpu_core.TPUMemorySpace.ANY:
# We checked above that the block does not require windowing.
window_params.append(ir.DictAttr.get())
continue
mlir_func = lower_jaxpr_to_transform_func(
ctx,
bm.index_map_jaxpr.jaxpr,
aval,
bm.block_aval,
name=func_name,
mosaic_grid_mapping=mosaic_grid_mapping,
for_verification=for_verification,

@@ -547,7 +530,7 @@ def lower_jaxpr_to_module(
]
# If we have an extended dtype, we need to add the block shape for the
# remaining physical dtype.
block_shape += list(_get_aval_physical_dtype_shape(aval.inner_aval))
block_shape += list(_get_aval_physical_dtype_shape(bm.block_aval.inner_aval))
window_shape = ir.DenseI64ArrayAttr.get(block_shape)
block_params = dict(
window_bounds=window_shape,
@@ -941,7 +924,7 @@ def _make_index(s):
def _maybe_cast_to_index(cast_to_index, x):
if cast_to_index:
return _make_index(x)
return _ensure_mlir_value(x, aval=jax_core.ShapedArray((), jnp.int32))
return _ensure_mlir_value(x, aval=pl_core.index_map_grid_aval)


def _index_to_start_size_stride(

@@ -2156,9 +2139,8 @@ def _lower_jaxpr_to_for_loop(ctx: LoweringRuleContext,
if unroll != 1:
raise NotImplementedError(
f"Only unroll={num_steps=} and unroll=1 supported. Got {unroll=}.")
i32 = jax_core.ShapedArray((), jnp.int32)
lbd = _ensure_mlir_value(start, i32)
ubd = arith.addi(lbd, _ensure_mlir_value(num_steps, i32))
lbd = _ensure_mlir_value(start, pl_core.index_map_grid_aval)
ubd = arith.addi(lbd, _ensure_mlir_value(num_steps, pl_core.index_map_grid_aval))
step = ir_constant(1, mlir_type=_dtype_to_ir_type(jnp.dtype("int32")))
for_op = scf.ForOp(lbd, ubd, step, args)
with ir.InsertionPoint(for_op.body):

@@ -2626,8 +2608,8 @@ def _device_id_to_logical(
return sum(a * b for a, b in zip(indices, mesh_strides))
lower_ctx = LoweringRuleContext(
lowering_context=ctx.lowering_context,
avals_in=[jax_core.ShapedArray((), jnp.int32)] * len(device_ids),
avals_out=[jax_core.ShapedArray((), jnp.int32)],
avals_in=[pl_core.index_map_grid_aval] * len(device_ids),
avals_out=[pl_core.index_map_grid_aval],
block_shapes=(None,) * len(device_ids),
)
return lower_fun(_linearize_mesh_indices, multiple_results=False)(
@@ -74,8 +74,6 @@ def pallas_call_tpu_lowering_rule(
compiler_params: dict[str, Any]):
"""Lowers a pallas_call to a Mosaic TPU custom call."""
del interpret
# TODO(necula): cleanup
out_shapes = grid_mapping.out_shapes
if debug:
print(jaxpr)
if "mosaic_params" in compiler_params:

@@ -118,7 +116,9 @@ def pallas_call_tpu_lowering_rule(
(a[0] + num_dyn_bounds + num_extra_args, a[1])
for a in input_output_aliases
)
out_avals = [jax_core.ShapedArray(s.shape, s.dtype) for s in out_shapes]
out_avals = [jax_core.ShapedArray(bm.array_shape_dtype.shape,
bm.array_shape_dtype.dtype)
for bm in grid_mapping.block_mappings_output]

if promela_dump_path := _DUMP_PROMELA_TO.value:
num_devices = 1 if mesh is None else mesh.devices.size
@@ -276,7 +276,7 @@ class BufferedRef:

@property
def compute_index(self):
return self.spec.compute_index
return lambda *args: pallas_core.compute_index(self.spec, *args)

@property
def memory_space(self):
@@ -147,7 +147,7 @@ def lower_jaxpr_to_module(
name: str,
compiler_params: dict[str, Any],
) -> LoweringResult:
in_structs = grid_mapping.in_shapes
in_structs = tuple(grid_mapping.in_shapes)
out_structs = grid_mapping.out_shapes
assert len(jaxpr.outvars) == 0
assert not grid_mapping.vmapped_dims
@@ -15,7 +15,7 @@
"""Module for calling pallas functions from JAX."""
from __future__ import annotations

from collections.abc import Callable, Sequence
from collections.abc import Callable, Iterable, Sequence
from functools import partial, reduce
import itertools
from typing import Any

@@ -53,6 +53,7 @@ map, unsafe_map = safe_map, map
zip, unsafe_zip = safe_zip, zip

Grid = pallas_core.Grid
TupleGrid = pallas_core.TupleGrid
GridSpec = pallas_core.GridSpec
BlockMapping = pallas_core.BlockMapping
GridMapping = pallas_core.GridMapping
@@ -118,14 +119,16 @@ def _initialize_scratch_vals(scratch_avals) -> tuple[jax.Array, ...]:
return tuple(uninitialized_value(a.shape, a.dtype) for a in scratch_avals)

def _initialize_output_vals(
out_shapes, input_args, input_output_aliases) -> Sequence[jax.Array]:
block_mappings_output: Iterable[BlockMapping],
input_args, input_output_aliases) -> Sequence[jax.Array]:
oi_map = {v: k for k, v in input_output_aliases}
output_vals = []
for i, out_shape in enumerate(out_shapes):
for i, bm in enumerate(block_mappings_output):
if i in oi_map:
output_vals.append(input_args[oi_map[i]])
else:
output_vals.append(uninitialized_value(out_shape.shape, out_shape.dtype))
output_vals.append(uninitialized_value(bm.array_shape_dtype.shape,
bm.array_shape_dtype.dtype))
return output_vals

def _logical_to_interpret_mode_dtype(dtype):
@@ -171,8 +174,6 @@ def _pallas_call_impl_interpret(
grid_mapping: GridMapping,
compiler_params: Any):
del compiler_params, name
# TODO(necula): cleanup
out_shapes = grid_mapping.out_shapes
# If we're in interpreter mode, we *scan* over the grid and eval the
# discharged jaxpr.
dynamic_grid_args, args = split_list( # type: ignore

@@ -189,16 +190,18 @@ def _pallas_call_impl_interpret(
discharged_jaxpr, discharged_consts = state_discharge.discharge_state(jaxpr, ())
if debug:
print(discharged_jaxpr)
out = _initialize_output_vals(out_shapes, args, input_output_aliases)
scalars, args = split_list(args, [grid_mapping.num_index_operands]) # type: ignore
out = _initialize_output_vals(grid_mapping.block_mappings_output,
args, input_output_aliases)
scalars = args[grid_mapping.slice_index_ops]
block_args = args[len(scalars):]
# invars: [*scalar_prefetch, *consts, *inputs, *outputs, *scratch]
# args now contains: *consts, *inputs, *outputs
# block_args now contains: *consts, *inputs, *outputs
scratch_invars = jaxpr.invars[grid_mapping.slice_scratch_ops]
scratch_avals = [v.aval for v in scratch_invars]
scratch_values = _initialize_scratch_vals(scratch_avals)

carry = []
for x, bm in zip(itertools.chain(args, out), grid_mapping.block_mappings):
for x, bm in zip(itertools.chain(block_args, out), grid_mapping.block_mappings):
if isinstance(bm.indexing_mode, pallas_core.Unblocked):
padding = bm.indexing_mode.padding
if padding is not None and any(p != (0, 0) for p in padding):
@@ -224,7 +227,7 @@ def _pallas_call_impl_interpret(
carry = map(_pad_values_to_block_dimension, carry, block_shapes)
carry.extend(scratch_values)

num_inout = len(args) + len(out)
num_inout_blocks = len(block_args) + len(out)
grid_start_indices = (jnp.int32(0),) * len(grid)
if grid:
num_iterations = reduce(jnp.multiply, grid)
@@ -239,19 +242,19 @@ def _pallas_call_impl_interpret(
i, *_ = carry
return i < num_iterations
def body(carry):
i, loop_idx, *carry = carry
i, loop_idx, *carry_blocks = carry
local_grid_env = tuple(
pallas_core.GridAxis(idx, b)
for dim, (idx, b) in enumerate(zip(loop_idx, grid))
if dim not in grid_mapping.vmapped_dims
)
carry, scratch = split_list(carry, [num_inout])
carry_consts_ins, scratch = split_list(carry_blocks, [num_inout_blocks])
with pallas_core.grid_env(local_grid_env):
start_indices = [
None if bm is None else bm.compute_start_indices_interpret(loop_idx, *scalars)
for bm in grid_mapping.block_mappings]
blocks = map(_maybe_dynamic_slice, start_indices, block_shapes, carry,
is_indexing_dim)
blocks = map(_maybe_dynamic_slice, start_indices, block_shapes,
carry_consts_ins, is_indexing_dim)
with pallas_core.grid_env(local_grid_env):
assert len(discharged_jaxpr.invars) == len(scalars) + len(blocks) + len(
scratch_values
@@ -263,20 +266,21 @@ def _pallas_call_impl_interpret(
)
blocks = jax_core.eval_jaxpr(discharged_jaxpr, discharged_consts, *scalars,
*blocks, *scratch)
blocks = blocks[grid_mapping.num_index_operands:]
blocks, out_scratch = split_list(blocks, [num_inout])
carry = map(_maybe_dynamic_update_slice, start_indices, block_shapes,
carry, blocks, is_indexing_dim)
return (i + 1, _get_next_indices(grid, loop_idx), *carry, *out_scratch)

_, out_inout, out_scratch = split_list(
blocks, [grid_mapping.num_index_operands, num_inout_blocks])
out_carry = map(_maybe_dynamic_update_slice, start_indices, block_shapes,
carry_consts_ins, out_inout, is_indexing_dim)
return (i + 1, _get_next_indices(grid, loop_idx),
*out_carry, *out_scratch)

(_, _, *carry) = lax.while_loop(
cond, body, (jnp.int32(0), grid_start_indices, *carry)
)
_, out, _ = split_list(carry, [len(args), len(out)])
assert len(grid_mapping.block_mappings) == len(args) + len(out)
out_block_mappings = grid_mapping.block_mappings[len(args):]

out_out = carry[len(block_args):len(block_args) + len(out)]
out_nopad = []
for o, expected_o_shape, bm in zip(out, out_shapes, out_block_mappings):
for o, bm in zip(out_out, grid_mapping.block_mappings_output):
if isinstance(bm.indexing_mode, pallas_core.Unblocked):
padding = bm.indexing_mode.padding
if padding is not None and any(p != (0, 0) for p in padding):
@@ -285,23 +289,22 @@ def _pallas_call_impl_interpret(
pad_low, pad_high = zip(*padding)
limit_indices = [s - p for s, p in zip(o.shape, pad_high)]
o = lax.slice(o, pad_low, limit_indices)
if o.shape != expected_o_shape.shape:
o = lax.slice(o, (0,) * o.ndim, expected_o_shape.shape)
if o.shape != bm.array_shape_dtype.shape:
o = lax.slice(o, (0,) * o.ndim, bm.array_shape_dtype.shape)
out_nopad.append(o)
return out_nopad

pallas_call_p.def_impl(_pallas_call_impl)

def _pallas_call_abstract_eval(*avals, grid_mapping, **_):
out_shapes = grid_mapping.out_shapes
return map(lambda x: jax_core.ShapedArray(x.shape, x.dtype), out_shapes)
def _pallas_call_abstract_eval(*avals, grid_mapping: GridMapping, **_):
return tuple(jax_core.ShapedArray(bm.array_shape_dtype.shape,
bm.array_shape_dtype.dtype)
for bm in grid_mapping.block_mappings_output)
pallas_call_p.def_abstract_eval(_pallas_call_abstract_eval)
def _pallas_call_jvp_rule(primals, tangents, *, jaxpr, name,
input_output_aliases: tuple[tuple[int, int], ...],
grid_mapping, debug, interpret, compiler_params: Any):
# TODO(necula): cleanup
out_shapes = grid_mapping.out_shapes
if grid_mapping.num_dynamic_grid_bounds:
raise NotImplementedError("interpret with dynamic grid bounds unsupported")
if grid_mapping.num_index_operands:
@@ -310,7 +313,7 @@ def _pallas_call_jvp_rule(primals, tangents, *, jaxpr, name,
raise NotImplementedError("JVP with aliasing not supported.")
nonzero_tangents = [not isinstance(t, ad_util.Zero) for t in tangents]
tangents = [t for t in tangents if type(t) is not ad_util.Zero]
nonzero_tangents_with_outputs = nonzero_tangents + [True] * len(out_shapes)
nonzero_tangents_with_outputs = nonzero_tangents + [True] * grid_mapping.num_outputs
closed_jaxpr = jax_core.ClosedJaxpr(jaxpr, ())
jvp_jaxpr_, _ = ad.jvp_jaxpr(closed_jaxpr, nonzero_tangents_with_outputs, [])
jvp_jaxpr, () = jvp_jaxpr_.jaxpr, jvp_jaxpr_.consts # TODO consts
@@ -322,7 +325,7 @@ def _pallas_call_jvp_rule(primals, tangents, *, jaxpr, name,
# compatible w/ `pallas_call` (inputs then outputs), we need to shuffle around
# the jaxpr's invars.
primal_refs, primal_out_refs, tangent_refs, tangent_out_refs = split_list(
jvp_jaxpr.invars, [len(primals), len(out_shapes), len(tangents)]
jvp_jaxpr.invars, [len(primals), grid_mapping.num_outputs, len(tangents)]
)
invars = (*primal_refs, *tangent_refs, *primal_out_refs, *tangent_out_refs)
effs = []
@@ -335,6 +338,7 @@ def _pallas_call_jvp_rule(primals, tangents, *, jaxpr, name,
jvp_jaxpr = jvp_jaxpr.replace(invars=invars, effects=effs)
if debug:
print(jvp_jaxpr)
# TODO(necula): does this work with consts?
in_bms, out_bms = split_list(grid_mapping.block_mappings, [len(primals)])
jvp_bms = (*in_bms, *in_bms, *out_bms, *out_bms)
jvp_grid_mapping = grid_mapping.replace(
@@ -369,8 +373,7 @@ def _batch_block_mapping(grid_mapping: GridMapping,
if dim is not batching.not_mapped:
indices.insert(dim, new_idx)
return tuple(indices)
i32_aval = jax_core.ShapedArray((), jnp.int32)
idx_avals = [i32_aval, *block_mapping.index_map_jaxpr.in_avals]
idx_avals = [pallas_core.index_map_grid_aval, *block_mapping.index_map_jaxpr.in_avals]
with grid_mapping.trace_env():
block_mapping_jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(
lu.wrap_init(_block_map_function), idx_avals)
@@ -444,8 +447,6 @@ def _batch_with_explicit_loop(
to the current iteration index and dynamic_updates an (initially empty) output
allocation.
"""
# TODO(necula): cleanup
out_shapes = grid_mapping.out_shapes
if not dims:
raise NotImplementedError("vmapping pallas_call with no arguments.")

@@ -465,10 +466,9 @@ def _batch_with_explicit_loop(
# The output arrays are completely overwritten, so we can just initialize
# empty arrays.
initial_state = [
jnp.empty(
tuple_insert(out_shape.shape, 0, axis_size), dtype=out_shape.dtype
)
for out_shape in out_shapes
jnp.empty(tuple_insert(bm.array_shape_dtype.shape, 0, axis_size),
dtype=bm.array_shape_dtype.dtype)
for bm in grid_mapping.block_mappings_output
]

def body(batch_index: jax.Array, state: list[jax.Array]) -> list[jax.Array]:
@@ -528,8 +528,6 @@ def _pallas_call_batching_rule(
interpret: bool,
compiler_params: Any,
):
# TODO(necula): cleanup
out_shapes = grid_mapping.out_shapes
def _maybe_squeeze_out_bdim(
x: jax.Array, bdim: int | batching.NotMapped
) -> jax.Array:

@@ -631,7 +629,7 @@ def _pallas_call_batching_rule(
args, dims, input_output_aliases=input_output_aliases, axis_size=axis_size
)

all_dims = list(dims) + [0] * len(out_shapes)
all_dims = list(dims) + [0] * grid_mapping.num_outputs

num_index_operands = grid_mapping.num_index_operands
num_scratch_operands = grid_mapping.num_scratch_operands

@@ -647,10 +645,12 @@ def _pallas_call_batching_rule(
block_mappings,
)

index_map_tree_args, index_map_tree_kwargs = grid_mapping.index_map_tree.unflatten(grid_mapping.index_map_avals)
index_map_tree_args, index_map_tree_kwargs = grid_mapping.index_map_tree.unflatten(
grid_mapping.index_map_avals)
assert not index_map_tree_kwargs
batched_index_map_args = (jax_core.ShapedArray((), jnp.int32),) + index_map_tree_args
batched_index_map_avals, batched_index_map_tree = tree_util.tree_flatten((batched_index_map_args, {}))
batched_index_map_args = (pallas_core.index_map_grid_aval,) + index_map_tree_args
batched_index_map_avals, batched_index_map_tree = tree_util.tree_flatten(
(batched_index_map_args, {}))
batched_grid_mapping = grid_mapping.replace(
grid=(axis_size, *grid_mapping.grid),
block_mappings=tuple(batched_block_mappings),
@@ -706,18 +706,12 @@ def pallas_call_checkify_rule(error: checkify.Error,
# 4) Create block specs for the error state and call pallas_call with
# the new kernel.
dynamic_grid_bounds, scalars, args = split_list( # type: ignore
args, [grid_mapping.num_dynamic_grid_bounds, grid_mapping.num_index_operands]
args, [grid_mapping.num_dynamic_grid_bounds,
grid_mapping.num_index_operands]
)
num_scalars = len(scalars)
num_invars = len(jaxpr.invars)
num_inputs_outputs = (
num_invars
- grid_mapping.num_index_operands
- grid_mapping.num_scratch_operands
)
num_kernel_inputs = len(args)
num_scratch = num_invars - num_inputs_outputs
num_kernel_outputs = num_invars - num_scratch - num_kernel_inputs
num_kernel_outputs = grid_mapping.num_outputs

# Trace the jaxpr to get an initial error value so the kernel jaxpr has all of
# the required inputs.
@@ -989,11 +983,11 @@ def pallas_call(
out_shape: Any,
*,
grid_spec: GridSpec | None = None,
debug: bool = False,
grid: Grid = (),
grid: TupleGrid = (),
in_specs: BlockSpecTree = no_block_spec,
out_specs: BlockSpecTree = no_block_spec,
input_output_aliases: dict[int, int] = {},
debug: bool = False,
interpret: bool = False,
name: str | None = None,
compiler_params: dict[str, Any] | None = None,

@@ -1008,9 +1002,8 @@ def pallas_call(
corresponding ``in_specs`` and ``out_specs``.
out_shape: a PyTree of :class:`jax.ShapeDtypeStruct` describing the shape
and dtypes of the outputs.
grid_spec: TO BE DOCUMENTED.
debug: if True, Pallas prints various intermediate forms of the kernel
as it is being processed.
grid_spec: An alternative way to specify ``grid``, ``in_specs``, and
``out_specs``. If given, those other parameters must not also be given.
grid: the iteration space, as a tuple of integers. The kernel is executed
as many times as ``prod(grid)``.
See details at :ref:`pallas_grid`.

@@ -1027,6 +1020,8 @@ def pallas_call(
input_output_aliases: a dictionary mapping the index of some inputs to
the index of the output that aliases them. These indices are in the
flattened inputs and outputs.
debug: if True, Pallas prints various intermediate forms of the kernel
as it is being processed.
interpret: runs the ``pallas_call`` as a ``jax.jit`` of a scan over the
grid whose body is the kernel lowered as a JAX function. This does not
require a TPU or a GPU, and is the only way to run Pallas kernels on CPU.

@@ -1059,7 +1054,7 @@ def pallas_call(
"If `grid_spec` is specified, then `out_specs` must "
f"be `no_block_spec`. It is {out_specs}")
del grid, in_specs, out_specs
grid_spec, dynamic_grid_bounds = grid_spec.unzip_dynamic_grid_bounds()
grid_spec, dynamic_grid_bounds = pallas_core.unzip_dynamic_grid_bounds(grid_spec)
# TODO(necula): this canonicalization may be convenient for some usage
# but it is lossy, because it prevents expressing functions that return
# lists.

@@ -1078,7 +1073,8 @@ def pallas_call(
flat_out_avals = tuple(jax_core.ShapedArray(v.shape, v.dtype)
for v in flat_out_shapes)
# TODO(necula): check that input_output_aliases is well-formed: no duplicates, etc.
kernel_avals, grid_mapping = grid_spec.get_grid_mapping(
kernel_avals, grid_mapping = pallas_core.get_grid_mapping(
grid_spec,
flat_in_avals, in_tree, in_paths,
flat_out_avals, out_tree, out_paths)
flat_kernel_avals, kernel_in_tree = tree_util.tree_flatten(kernel_avals)
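The `interpret` parameter documented above is handy for CPU-only debugging; a minimal sketch, not part of this commit:

import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl

def double_kernel(x_ref, o_ref):
  o_ref[...] = x_ref[...] * 2

x = jnp.arange(8, dtype=jnp.float32)
# interpret=True runs the kernel as a scan over the grid on any backend,
# including CPU, which is useful for debugging.
y = pl.pallas_call(
    double_kernel,
    out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
    interpret=True,
)(x)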
@@ -250,7 +250,6 @@ def _new_ir_context() -> ir.Context:

def lower_jaxpr_to_triton_module(
jaxpr: jax_core.Jaxpr,
in_out_shapes,
grid_mapping: GridMapping,
name: str,
platform: str

@@ -313,23 +312,22 @@ def lower_jaxpr_to_triton_module(
raise NotImplementedError(
"Scalar prefetch not supported in Triton lowering."
)
for bm in grid_mapping.block_mappings:
if not isinstance(bm.indexing_mode, Blocked):
raise NotImplementedError(
"Only Blocked indexing mode is supported in Triton lowering."
)
if not all(isinstance(bm.indexing_mode, Blocked)
for bm in grid_mapping.block_mappings):
raise NotImplementedError(
"Only Blocked indexing mode is supported in Triton lowering."
)
start_indices = map(
functools.partial(_eval_index_map, ctx, program_ids),
grid_mapping.block_mappings,
)
block_infos = [
BlockInfo(
jax.ShapeDtypeStruct(shape_dtype.shape, shape_dtype.dtype),
block_mapping.array_shape_dtype,
start_idx,
block_mapping.block_shape,
)
for shape_dtype, block_mapping, start_idx in zip(
in_out_shapes,
for block_mapping, start_idx in zip(
grid_mapping.block_mappings,
start_indices,
)
@@ -50,9 +50,14 @@ def pallas_call_lowering(
compiler_params: dict[str, Any],
):
del interpret
# TODO(necula): cleanup
in_shapes = grid_mapping.in_shapes
out_shapes = grid_mapping.out_shapes
if grid_mapping.num_dynamic_grid_bounds:
raise NotImplementedError(
"dynamic grid bounds not supported in the Triton backend"
)
if grid_mapping.num_index_operands:
raise NotImplementedError(
"scalar prefetch not implemented in the Triton backend"
)
triton_params = compiler_params.get("triton", compiler_params)
num_warps = triton_params.pop("num_warps", 4)
[lowering_platform] = ctx.platforms or ctx.module_context.platforms

@@ -66,7 +71,7 @@ def pallas_call_lowering(
print(grid_mapping)

lowering_result = lowering.lower_jaxpr_to_triton_module(
jaxpr, (*in_shapes, *out_shapes), grid_mapping, name, lowering_platform
jaxpr, grid_mapping, name, lowering_platform
)
module_op = lowering_result.module.operation
if debug:

@@ -74,8 +79,9 @@ def pallas_call_lowering(

grid_x, grid_y, grid_z = normalize_grid(lowering_result.grid)
out_types = [
ir.RankedTensorType.get(shape.shape, mlir.dtype_to_ir_type(shape.dtype))
for shape in out_shapes
ir.RankedTensorType.get(bm.array_shape_dtype.shape,
mlir.dtype_to_ir_type(bm.array_shape_dtype.dtype))
for bm in grid_mapping.block_mappings_output
]
buf = io.BytesIO()
module_op.write_bytecode(buf)
@@ -25,6 +25,7 @@ from jax._src.pallas.core import IndexingMode
from jax._src.pallas.core import no_block_spec
from jax._src.pallas.core import Unblocked
from jax._src.pallas.core import unblocked
from jax._src.pallas.core import GridSpec
from jax._src.pallas.pallas_call import pallas_call
from jax._src.pallas.pallas_call import pallas_call_p
from jax._src.pallas.primitives import atomic_add