[pallas] More simplification of grid mapping and calling convention

In the previous PR #22552 I expanded `GridMapping` to encode more
parts of the calling convention. Here we use that new functionality
and clean up some code.

I removed the internal methods from `BlockSpec` and `GridSpec` because
these classes are part of the public API and should not expose internal helpers.

I added entries to pallas/CHANGELOG.
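
For illustration, the removed methods now have module-level counterparts in
the internal `jax._src.pallas.core` module. A minimal sketch (internal names,
subject to change; the block shape and index map are made up):

    from jax._src.pallas import core as pallas_core
    from jax.experimental import pallas as pl

    spec = pl.BlockSpec((8, 128), lambda i, j: (i, j))

    # Previously: spec.compute_index(0, 1).
    # The module-level helper always returns a tuple of block indices.
    assert pallas_core.compute_index(spec, 0, 1) == (0, 1)

    # Similarly, grid_spec.get_grid_mapping(...) is now
    # pallas_core.get_grid_mapping(grid_spec, ...), and
    # grid_spec.unzip_dynamic_grid_bounds() is now
    # pallas_core.unzip_dynamic_grid_bounds(grid_spec).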
George Necula 2024-07-23 15:25:14 +03:00
parent 68972de021
commit 70a11acbb1
13 changed files with 308 additions and 303 deletions

View File

@ -10,6 +10,7 @@ Classes
:toctree: _autosummary
BlockSpec
GridSpec
Slice
Functions
@ -34,4 +35,4 @@ Functions
atomic_or
atomic_xchg
debug_print
debug_print

View File

@ -17,6 +17,13 @@ Remember to align the itemized text with the first line of an item within a list
* {class}`jax.experimental.pallas.BlockSpec` now expects `block_shape` to
be passed *before* `index_map`. The old argument order is deprecated and
will be removed in a future release.
* {class}`jax.experimental.pallas.GridSpec` no longer has the `in_specs_tree`
and `out_specs_tree` fields; the `in_specs` and `out_specs` fields now
store their values as pytrees of `BlockSpec`. Previously, `in_specs` and
`out_specs` were flattened ({jax-issue}`#22552`).
* The method `compute_index` of {class}`jax.experimental.pallas.BlockSpec` has
been removed because it is private. Similarly, the `get_grid_mapping` and
`unzip_dynamic_grid_bounds` methods have been removed from `GridSpec` ({jax-issue}`#22593`).
* Fixed the interpreter mode to work with `BlockSpec`s that involve padding
({jax-issue}`#22275`).
Padding in interpreter mode will be with NaN, to help debug out-of-bounds
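
To make the `BlockSpec` items above concrete, a minimal sketch (the block
shape, index map, and pytree structure are made up for illustration):

    from jax.experimental import pallas as pl

    # New argument order: block_shape first, then index_map.
    spec = pl.BlockSpec((128, 128), lambda i, j: (i, j))
    # Deprecated order, still accepted until removed in a future release:
    # spec = pl.BlockSpec(lambda i, j: (i, j), (128, 128))

    # in_specs / out_specs can now be arbitrary pytrees of BlockSpec that
    # match the pytree structure of the inputs / outputs, e.g.:
    in_specs = [spec, {"rhs": spec}]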

View File

@ -1705,7 +1705,7 @@ def canonicalize_dim(d: DimSize, context: str="") -> DimSize:
"""Canonicalizes and checks for errors in a user-provided shape dimension value.
Args:
f: a Python value that represents a dimension.
d: a Python value that represents a dimension.
Returns:
A canonical dimension value.

View File

@ -15,12 +15,13 @@
"""Module for pallas-core functionality."""
from __future__ import annotations
from collections.abc import Callable, Iterator, Sequence
from collections.abc import Callable, Iterable, Iterator, Sequence
import contextlib
import copy
import dataclasses
import enum
import functools
import itertools
import threading
from typing import Any, Hashable, Union
import warnings
@ -55,6 +56,10 @@ TupleGrid = tuple[GridElement, ...]
Grid = Union[NamedGrid, TupleGrid]
StaticGrid = tuple[int, ...]
GridMappingGrid = tuple[int | DynamicGridDim, ...]
# Pytrees of jax.ShapeDtypeStruct
ShapeDtypeStructTree = tuple[jax.ShapeDtypeStruct, ...]
split_list = util.split_list
map, unsafe_map = util.safe_map, map
@ -202,14 +207,13 @@ blocked = Blocked()
IndexingMode = Union[Blocked, Unblocked]
@dataclasses.dataclass(unsafe_hash=True)
@dataclasses.dataclass
class BlockSpec:
"""Specifies how an array should be sliced for each iteration of a kernel.
"""Specifies how an array should be sliced for each invocation of a kernel.
See :ref:`pallas_blockspec` for more details.
This object contains the parameters passed through the API.
An internal canonicalized version is in BlockMapping.
"""
# An internal canonicalized version is in BlockMapping.
block_shape: tuple[int | None, ...] | None = None
index_map: Callable[..., Any] | None = None
memory_space: Any | None = dataclasses.field(kw_only=True, default=None)
@ -242,22 +246,25 @@ class BlockSpec:
self.memory_space = memory_space
self.indexing_mode = indexing_mode
def compute_index(self, *args):
assert self.index_map is not None
out = self.index_map(*args)
if not isinstance(out, tuple):
out = (out,)
return out
def compute_index(bs: BlockSpec, *args):
assert bs.index_map is not None
out = bs.index_map(*args)
if not isinstance(out, tuple):
out = (out,)
return out
class NoBlockSpec:
pass
def __repr__(self):
return "NoBlockSpec"
no_block_spec = NoBlockSpec()
# A PyTree of BlockSpec | NoBlockSpec.
# BlockSpecTree = Sequence[BlockSpec | NoBlockSpec, ...] | NoBlockSpec
BlockSpecTree = Any
@dataclasses.dataclass(frozen=True)
class BlockMapping:
"""An internal canonicalized version of BlockSpec.
@ -399,6 +406,7 @@ class GridMapping:
if self.grid_names is not None:
assert len(self.grid) == len(self.grid_names), (self.grid, self.grid_names)
for bm in self.block_mappings:
bm.check_invariants()
assert tuple(self.index_map_avals) == tuple(bm.index_map_jaxpr.in_avals), (
@ -439,45 +447,52 @@ class GridMapping:
@property
def slice_index_ops(self):
"""Returns a slice object to select the index operands to a kernel."""
"""Returns a slice object to select the index operands to a kernel.
This works on a sequence that contains *index, *consts, *ins, *outs, *scratch.
"""
return slice(0, self.num_index_operands)
@property
def slice_block_ops(self):
"""Returns a slice to select all but the index operands to a kernel."""
"""Returns a slice to select all but the index operands to a kernel.
This works on a sequence that contains *index, *consts, *ins, *outs, *scratch.
"""
return slice(self.num_index_operands, None)
@property
def slice_scratch_ops(self):
"""Returns a slice object to select the scratch operands to a kernel."""
"""Returns a slice object to select the scratch operands to a kernel.
This works on a sequence that contains *index, *consts, *ins, *outs, *scratch.
"""
if self.num_scratch_operands:
return slice(-self.num_scratch_operands, None)
else:
return slice(0, 0)
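# Illustrative example (hypothetical operand counts, not part of this diff):
# the kernel operand sequence is [*index, *consts, *ins, *outs, *scratch],
# so with num_index_operands == 2 and num_scratch_operands == 1:
#   slice_index_ops   == slice(0, 2)      # the index operands
#   slice_block_ops   == slice(2, None)   # consts, ins, outs, scratch
#   slice_scratch_ops == slice(-1, None)  # the scratch operands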
# TODO(necula): this is used to recover the old `in_shapes`, but it probably
# is not needed anymore, with some cleanup.
@property
def in_shapes(self) -> tuple[jax.ShapeDtypeStruct, ...]:
def in_shapes(self) -> Iterable[jax.ShapeDtypeStruct]:
"""The shapes of *index, *consts, *inputs."""
index_shapes = [jax.ShapeDtypeStruct(ia.inner_aval.shape,
index_shapes = (jax.ShapeDtypeStruct(ia.inner_aval.shape,
ia.inner_aval.dtype)
for ia in self.index_map_avals[len(self.grid):]]
consts_inputs_shapes = [
for ia in self.index_map_avals[len(self.grid):])
consts_inputs_shapes = (
bm.array_shape_dtype
for bm in self.block_mappings[
:self.num_constant_operands + self.num_inputs]]
return tuple(index_shapes + consts_inputs_shapes)
:self.num_constant_operands + self.num_inputs])
return itertools.chain(index_shapes, consts_inputs_shapes)
# TODO(necula): this is used to recover the old `out_shapes`, but it probably
# is not needed anymore, with some cleanup.
@property
def out_shapes(self) -> tuple[jax.ShapeDtypeStruct, ...]:
def block_mappings_output(self) -> Iterable[BlockMapping]:
return itertools.islice(
self.block_mappings,
self.num_constant_operands + self.num_inputs,
self.num_constant_operands + self.num_inputs + self.num_outputs)
@property
def out_shapes(self) -> Iterable[jax.ShapeDtypeStruct]:
return tuple(
bm.array_shape_dtype
for bm in self.block_mappings[
self.num_constant_operands + self.num_inputs:
self.num_constant_operands + self.num_inputs + self.num_outputs])
bm.array_shape_dtype for bm in self.block_mappings_output)
def _is_valid_grid_dim(dim: int | jax.Array) -> bool:
if isinstance(dim, jax.Array):
@ -500,9 +515,9 @@ def _convert_block_spec_to_block_mapping(
if block_spec is no_block_spec:
block_spec = BlockSpec(None, None)
if block_spec.index_map is None:
compute_index = lambda *args: (0,) * len(array_aval.shape)
index_map_func = lambda *args: (0,) * len(array_aval.shape)
else:
compute_index = block_spec.compute_index
index_map_func = functools.partial(compute_index, block_spec)
if block_spec.block_shape is None:
block_shape = array_aval.shape
else:
@ -522,7 +537,7 @@ def _convert_block_spec_to_block_mapping(
"dynamically-shaped blocks. "
f"{origin} has block_shape: {block_aval.shape}")
flat_index_map_fun, _ = api_util.flatten_fun(lu.wrap_init(compute_index),
flat_index_map_fun, _ = api_util.flatten_fun(lu.wrap_init(index_map_func),
index_map_tree)
with tracing_grid_env(grid, mapped_dims):
jaxpr, out_avals, consts, () = pe.trace_to_jaxpr_dynamic(flat_index_map_fun,
@ -559,16 +574,21 @@ def _tile_ref(ref: state.AbstractRef, block_shape: tuple[int, ...] | None
shape = tuple(s for s in block_shape if s is not None)
return ref.update(inner_aval=ref.inner_aval.update(shape=shape))
@dataclasses.dataclass(init=False, unsafe_hash=True)
class GridSpec:
"""Encodes the parameters of the grid, as given through the API.
index_map_grid_aval = jax_core.ShapedArray((), jnp.int32)
An internal sanitized version is in GridMapping.
@dataclasses.dataclass(init=False)
class GridSpec:
"""Encodes the grid parameters for :func:`jax.experimental.pallas.pallas_call`.
See the documentation for :func:`jax.experimental.pallas.pallas_call`,
and also :ref:`pallas_grids_and_blockspecs` for a more detailed
description of the parameters.
"""
# A canonicalized internal version is in GridMapping.
grid: TupleGrid
grid_names: tuple[Hashable, ...] | None
in_specs: tuple[BlockSpec | NoBlockSpec, ...] | NoBlockSpec
out_specs: tuple[BlockSpec | NoBlockSpec, ...] | NoBlockSpec
in_specs: BlockSpecTree
out_specs: BlockSpecTree
def __init__(
self,
@ -601,149 +621,151 @@ class GridSpec:
self.grid = grid # type: ignore
self.grid_names = grid_names
def get_grid_mapping(
self,
in_avals: Sequence[jax_core.AbstractValue],
in_tree: tree_util.PyTreeDef,
in_paths: Sequence[tree_util.KeyPath],
out_avals: Sequence[jax_core.AbstractValue],
out_tree: tree_util.PyTreeDef,
out_paths: Sequence[tree_util.KeyPath],
num_scalar_prefetch: int = 0,
scratch_shapes: Sequence[Any] = (),
) -> tuple[tuple[AbstractMemoryRef, ...],
GridMapping]:
assert all(i is None or isinstance(i, int) for i in self.grid)
grid_mapping_grid = tuple(
dynamic_grid_dim if d is None else d for d in self.grid
)
# The inputs for the index maps
index_map_avals = (
(jax_core.ShapedArray((), jnp.dtype("int32")),) * len(self.grid))
index_map_tree = tree_util.tree_structure((index_map_avals, {}))
if num_scalar_prefetch:
all_avals = tree_util.tree_unflatten(in_tree, in_avals)
scalar_avals, unflat_in_avals = split_list(
all_avals, [num_scalar_prefetch])
flat_scalar_avals, scalar_tree = tree_util.tree_flatten(scalar_avals)
num_flat_scalar_prefetch = len(flat_scalar_avals)
scalar_ref_avals = [
self._make_scalar_ref_aval(aval)
for aval in flat_scalar_avals]
jaxpr_scalar_ref_avals = tree_util.tree_unflatten(
scalar_tree, scalar_ref_avals)
in_avals, in_tree = tree_util.tree_flatten(tuple(unflat_in_avals))
index_map_tree = tree_util.tree_structure(((*index_map_avals,
*scalar_avals), {}))
index_map_avals = (*index_map_avals, *scalar_ref_avals)
del scalar_ref_avals, flat_scalar_avals, scalar_tree
del scalar_avals, unflat_in_avals, all_avals
else:
num_flat_scalar_prefetch = 0
jaxpr_scalar_ref_avals = ()
if scratch_shapes:
flat_scratch_shapes, scratch_tree = tree_util.tree_flatten(
scratch_shapes)
flat_scratch_avals = map(self._make_scratch_aval, flat_scratch_shapes)
num_flat_scratch_operands = len(flat_scratch_avals)
jaxpr_scratch_avals = tree_util.tree_unflatten(
scratch_tree, flat_scratch_avals)
if not isinstance(jaxpr_scratch_avals, (tuple, list)):
jaxpr_scratch_avals = (jaxpr_scratch_avals,)
del flat_scratch_avals, flat_scratch_shapes, scratch_tree
else:
num_flat_scratch_operands = 0
jaxpr_scratch_avals = ()
if self.in_specs is not no_block_spec:
flat_in_specs, in_specs_tree = tree_util.tree_flatten(self.in_specs)
if in_specs_tree != in_tree:
raise ValueError(
pytreedef_mismatch_err_msg("`in_specs`", in_specs_tree,
"inputs", in_tree))
else:
flat_in_specs = [no_block_spec] * len(in_avals)
in_block_mappings = map(
partial(
_convert_block_spec_to_block_mapping,
index_map_avals=index_map_avals,
index_map_tree=index_map_tree,
grid=grid_mapping_grid,
mapped_dims=(),
what="inputs",
),
flat_in_specs,
in_paths[num_flat_scalar_prefetch:],
in_avals,
)
if self.out_specs is not no_block_spec:
flat_out_specs, out_specs_tree = tree_util.tree_flatten(self.out_specs)
if out_specs_tree != out_tree:
raise ValueError(
pytreedef_mismatch_err_msg("`out_specs`", out_specs_tree,
"`out_shape`", out_tree))
else:
flat_out_specs = [no_block_spec] * len(out_avals)
out_block_mappings = map(
partial(
_convert_block_spec_to_block_mapping,
index_map_avals=index_map_avals,
index_map_tree=index_map_tree,
grid=grid_mapping_grid,
mapped_dims=(),
what="outputs",
),
flat_out_specs,
out_paths,
out_avals,
)
grid_mapping = GridMapping(
grid=grid_mapping_grid, # type: ignore[arg-type]
grid_names=self.grid_names,
block_mappings=(*in_block_mappings, *out_block_mappings),
index_map_avals=index_map_avals, # type: ignore[arg-type]
index_map_tree=index_map_tree,
vmapped_dims=(),
num_index_operands=num_flat_scalar_prefetch,
num_constant_operands=0, # Fixed up later
num_inputs=len(flat_in_specs),
num_outputs=len(flat_out_specs),
num_scratch_operands=num_flat_scratch_operands,
)
grid_mapping.check_invariants()
in_ref_avals = [bm.block_aval for bm in in_block_mappings]
jaxpr_in_ref_avals = tree_util.tree_unflatten(in_tree, in_ref_avals)
jaxpr_in_avals = (*jaxpr_scalar_ref_avals, *jaxpr_in_ref_avals)
out_ref_avals = [bm.block_aval for bm in out_block_mappings]
jaxpr_out_avals = tree_util.tree_unflatten(out_tree, out_ref_avals)
if not isinstance(jaxpr_out_avals, (tuple, list)):
jaxpr_out_avals = (jaxpr_out_avals,)
return (*jaxpr_in_avals, *jaxpr_out_avals,
*jaxpr_scratch_avals), grid_mapping
def _make_scratch_aval(self, obj: object) -> jax_core.AbstractValue:
assert False # Not needed in GridSpec
def _make_scalar_ref_aval(self, aval):
assert False # Not needed in GridSpec
def unzip_dynamic_grid_bounds(
self,
) -> tuple[GridSpec, tuple[Any, ...]]:
static_grid = tuple(
d if isinstance(d, int) else None for d in self.grid
)
dynamic_bounds = tuple(d for d in self.grid if not isinstance(d, int))
# We can't use dataclasses.replace, because our fields are incompatible
# with __init__'s signature.
static_self = copy.copy(self)
static_self.grid = static_grid # type: ignore
return static_self, dynamic_bounds
def get_grid_mapping(
grid_spec: GridSpec,
in_avals: Sequence[jax_core.AbstractValue],
in_tree: tree_util.PyTreeDef,
in_paths: Sequence[tree_util.KeyPath],
out_avals: Sequence[jax_core.AbstractValue],
out_tree: tree_util.PyTreeDef,
out_paths: Sequence[tree_util.KeyPath],
) -> tuple[tuple[jax_core.AbstractValue, ...],
GridMapping]:
assert all(i is None or isinstance(i, int) for i in grid_spec.grid)
grid_mapping_grid = tuple(
dynamic_grid_dim if d is None else d for d in grid_spec.grid
)
# The inputs for the index maps
index_map_avals = (
(index_map_grid_aval,) * len(grid_spec.grid))
index_map_tree = tree_util.tree_structure((index_map_avals, {}))
num_scalar_prefetch: int = getattr(grid_spec, "num_scalar_prefetch", 0)
if num_scalar_prefetch:
all_avals = tree_util.tree_unflatten(in_tree, in_avals)
scalar_avals, unflat_in_avals = split_list(
all_avals, [num_scalar_prefetch])
flat_scalar_avals, scalar_tree = tree_util.tree_flatten(scalar_avals)
num_flat_scalar_prefetch = len(flat_scalar_avals)
scalar_ref_avals = [
grid_spec._make_scalar_ref_aval(aval)
for aval in flat_scalar_avals]
jaxpr_scalar_ref_avals = tree_util.tree_unflatten(
scalar_tree, scalar_ref_avals)
in_avals, in_tree = tree_util.tree_flatten(tuple(unflat_in_avals))
index_map_tree = tree_util.tree_structure(((*index_map_avals,
*scalar_avals), {}))
index_map_avals = (*index_map_avals, *scalar_ref_avals)
del scalar_ref_avals, flat_scalar_avals, scalar_tree
del scalar_avals, unflat_in_avals, all_avals
else:
num_flat_scalar_prefetch = 0
jaxpr_scalar_ref_avals = ()
scratch_shapes: tuple[Any, ...] = getattr(grid_spec, "scratch_shapes", ())
if scratch_shapes:
flat_scratch_shapes, scratch_tree = tree_util.tree_flatten(
scratch_shapes)
flat_scratch_avals = map(grid_spec._make_scratch_aval, flat_scratch_shapes)
num_flat_scratch_operands = len(flat_scratch_avals)
jaxpr_scratch_avals = tree_util.tree_unflatten(
scratch_tree, flat_scratch_avals)
if not isinstance(jaxpr_scratch_avals, (tuple, list)):
jaxpr_scratch_avals = (jaxpr_scratch_avals,)
del flat_scratch_avals, flat_scratch_shapes, scratch_tree
else:
num_flat_scratch_operands = 0
jaxpr_scratch_avals = ()
if grid_spec.in_specs is not no_block_spec:
flat_in_specs, in_specs_tree = tree_util.tree_flatten(grid_spec.in_specs)
if in_specs_tree != in_tree:
raise ValueError(
pytreedef_mismatch_err_msg("`in_specs`", in_specs_tree,
"inputs", in_tree))
else:
flat_in_specs = [no_block_spec] * len(in_avals)
in_block_mappings = map(
partial(
_convert_block_spec_to_block_mapping,
index_map_avals=index_map_avals,
index_map_tree=index_map_tree,
grid=grid_mapping_grid,
mapped_dims=(),
what="inputs",
),
flat_in_specs,
in_paths[num_flat_scalar_prefetch:],
in_avals,
)
if grid_spec.out_specs is not no_block_spec:
flat_out_specs, out_specs_tree = tree_util.tree_flatten(grid_spec.out_specs)
if out_specs_tree != out_tree:
raise ValueError(
pytreedef_mismatch_err_msg("`out_specs`", out_specs_tree,
"`out_shape`", out_tree))
else:
flat_out_specs = [no_block_spec] * len(out_avals)
out_block_mappings = map(
partial(
_convert_block_spec_to_block_mapping,
index_map_avals=index_map_avals,
index_map_tree=index_map_tree,
grid=grid_mapping_grid,
mapped_dims=(),
what="outputs",
),
flat_out_specs,
out_paths,
out_avals,
)
grid_mapping = GridMapping(
grid=grid_mapping_grid, # type: ignore[arg-type]
grid_names=grid_spec.grid_names,
block_mappings=(*in_block_mappings, *out_block_mappings),
index_map_avals=index_map_avals, # type: ignore[arg-type]
index_map_tree=index_map_tree,
vmapped_dims=(),
num_index_operands=num_flat_scalar_prefetch,
num_constant_operands=0, # Fixed up later
num_inputs=len(flat_in_specs),
num_outputs=len(flat_out_specs),
num_scratch_operands=num_flat_scratch_operands,
)
grid_mapping.check_invariants()
in_ref_avals = [bm.block_aval for bm in in_block_mappings]
jaxpr_in_ref_avals = tree_util.tree_unflatten(in_tree, in_ref_avals)
jaxpr_in_avals = (*jaxpr_scalar_ref_avals,
*jaxpr_in_ref_avals)
out_ref_avals = [bm.block_aval for bm in out_block_mappings]
jaxpr_out_avals = tree_util.tree_unflatten(out_tree, out_ref_avals)
if not isinstance(jaxpr_out_avals, (tuple, list)):
jaxpr_out_avals = (jaxpr_out_avals,)
return (*jaxpr_in_avals, *jaxpr_out_avals,
*jaxpr_scratch_avals), grid_mapping
def unzip_dynamic_grid_bounds(
grid_spec: GridSpec) -> tuple[GridSpec, tuple[Any, ...]]:
static_grid = tuple(
d if isinstance(d, int) else None for d in grid_spec.grid
)
dynamic_bounds = tuple(d for d in grid_spec.grid if not isinstance(d, int))
# We can't use dataclasses.replace, because our fields are incompatible
# with __init__'s signature.
static_self = copy.copy(grid_spec)
static_self.grid = static_grid # type: ignore
return static_self, dynamic_bounds
def pytreedef_mismatch_err_msg(

View File

@ -67,7 +67,7 @@ class barrier_semaphore(semaphore_dtype): pass
class AbstractSemaphoreTyRules:
@staticmethod
def pallas_interpret_element_aval(_) -> jax_core.ShapedArray:
return jax_core.ShapedArray((), jnp.dtype('int32'))
return pallas_core.index_map_grid_aval
class AbstractSemaphoreTy(dtypes.ExtendedDType):
name: str
@ -145,8 +145,8 @@ class PrefetchScalarGridSpec(pallas_core.GridSpec):
grid: TupleGrid
grid_names: tuple[Hashable, ...] | None
num_scalar_prefetch: int
in_specs: tuple[BlockSpec | NoBlockSpec, ...] | NoBlockSpec
out_specs: tuple[BlockSpec | NoBlockSpec, ...] | NoBlockSpec
in_specs: pallas_core.BlockSpecTree
out_specs: pallas_core.BlockSpecTree
scratch_shapes: tuple[Any, ...]
def __init__(
@ -173,14 +173,6 @@ class PrefetchScalarGridSpec(pallas_core.GridSpec):
raise ValueError(f"No registered conversion for {type(obj)}. "
"Only VMEM and SemaphoreType are supported.")
def get_grid_mapping( # type: ignore[override]
self, in_avals, in_tree, in_paths, out_avals, out_tree, out_paths
) -> tuple[tuple[jax_core.AbstractValue, ...], GridMapping]:
return super().get_grid_mapping(in_avals, in_tree, in_paths,
out_avals, out_tree, out_paths,
num_scalar_prefetch=self.num_scalar_prefetch,
scratch_shapes=self.scratch_shapes)
@dataclasses.dataclass(frozen=True)
class TensorCore:

View File

@ -262,6 +262,7 @@ def _get_arg_type(
memory_space = TPUMemorySpace.VMEM
if isinstance(aval, tpu_core.AbstractSemaphore):
return aval_to_ir_type(aval), None
# TODO(necula): clean this None block_mapping
if block_mapping is None:
return aval_to_ir_type(aval, memory_space=memory_space), aval.shape
shape = tuple(1 if b is pl_core.mapped else b for b in block_mapping.block_shape)
@ -296,6 +297,7 @@ class MosaicGridMapping:
self.jaxpr = jaxpr
self.block_mappings = grid_mapping.block_mappings
self.mapped_dims = grid_mapping.vmapped_dims
# TODO(necula): clean this using new grid_mapping helpers
num_scalar_prefetch = grid_mapping.num_index_operands
num_scratch = grid_mapping.num_scratch_operands
# jaxpr has signature [*scalar_prefetch, *consts, *in_ops, *out_ops, *scratch]
@ -348,7 +350,7 @@ class MosaicGridMapping:
for aval in scratch_avals
)
self.grid_types, _ = unzip2([
_get_arg_type(jax_core.ShapedArray((), jnp.int32), None)
_get_arg_type(pl_core.index_map_grid_aval, None)
for _ in range(len(self.grid))
])
self._prepare_mesh_info(mesh)
@ -432,9 +434,6 @@ def lower_jaxpr_to_module(
mesh: mesh_lib.Mesh | None = None,
for_verification: bool = False,
) -> tuple[Module, tuple[Any, ...]]:
# TODO(necula): cleanup
in_shapes = grid_mapping.in_shapes
out_shapes = grid_mapping.out_shapes
for bm in grid_mapping.block_mappings:
def err_details():
return (f"Block spec for {bm.origin} has block shape "
@ -510,33 +509,17 @@ def lower_jaxpr_to_module(
window_params = []
grid = mosaic_grid_mapping.grid
if grid:
invars = jaxpr.invars
if grid_mapping.num_scratch_operands > 0:
invars = invars[
grid_mapping.num_index_operands:-grid_mapping.num_scratch_operands]
else:
invars = invars[grid_mapping.num_index_operands:]
# invars now = *consts, *ins, *outs
avals = tuple(v.aval for v in invars)
# TODO(necula): we should not need block_operand_shapes anymore
block_operand_shapes = (
*in_shapes[grid_mapping.num_index_operands:],
*out_shapes,
)
assert len(block_operand_shapes) == len(grid_mapping.block_mappings)
for i, (full_ty, bm, aval) in enumerate(
zip(block_operand_shapes, grid_mapping.block_mappings, avals)
):
for i, bm in enumerate(grid_mapping.block_mappings):
func_name = f"transform_{i}"
# ANY operands don't support windowing and require empty window_params.
if aval.memory_space == tpu_core.TPUMemorySpace.ANY:
if bm.block_aval.memory_space == tpu_core.TPUMemorySpace.ANY:
# We checked above that the block does not require windowing.
window_params.append(ir.DictAttr.get())
continue
mlir_func = lower_jaxpr_to_transform_func(
ctx,
bm.index_map_jaxpr.jaxpr,
aval,
bm.block_aval,
name=func_name,
mosaic_grid_mapping=mosaic_grid_mapping,
for_verification=for_verification,
@ -547,7 +530,7 @@ def lower_jaxpr_to_module(
]
# If we have an extended dtype, we need to add the block shape for the
# remaining physical dtype.
block_shape += list(_get_aval_physical_dtype_shape(aval.inner_aval))
block_shape += list(_get_aval_physical_dtype_shape(bm.block_aval.inner_aval))
window_shape = ir.DenseI64ArrayAttr.get(block_shape)
block_params = dict(
window_bounds=window_shape,
@ -941,7 +924,7 @@ def _make_index(s):
def _maybe_cast_to_index(cast_to_index, x):
if cast_to_index:
return _make_index(x)
return _ensure_mlir_value(x, aval=jax_core.ShapedArray((), jnp.int32))
return _ensure_mlir_value(x, aval=pl_core.index_map_grid_aval)
def _index_to_start_size_stride(
@ -2156,9 +2139,8 @@ def _lower_jaxpr_to_for_loop(ctx: LoweringRuleContext,
if unroll != 1:
raise NotImplementedError(
f"Only unroll={num_steps=} and unroll=1 supported. Got {unroll=}.")
i32 = jax_core.ShapedArray((), jnp.int32)
lbd = _ensure_mlir_value(start, i32)
ubd = arith.addi(lbd, _ensure_mlir_value(num_steps, i32))
lbd = _ensure_mlir_value(start, pl_core.index_map_grid_aval)
ubd = arith.addi(lbd, _ensure_mlir_value(num_steps, pl_core.index_map_grid_aval))
step = ir_constant(1, mlir_type=_dtype_to_ir_type(jnp.dtype("int32")))
for_op = scf.ForOp(lbd, ubd, step, args)
with ir.InsertionPoint(for_op.body):
@ -2626,8 +2608,8 @@ def _device_id_to_logical(
return sum(a * b for a, b in zip(indices, mesh_strides))
lower_ctx = LoweringRuleContext(
lowering_context=ctx.lowering_context,
avals_in=[jax_core.ShapedArray((), jnp.int32)] * len(device_ids),
avals_out=[jax_core.ShapedArray((), jnp.int32)],
avals_in=[pl_core.index_map_grid_aval] * len(device_ids),
avals_out=[pl_core.index_map_grid_aval],
block_shapes=(None,) * len(device_ids),
)
return lower_fun(_linearize_mesh_indices, multiple_results=False)(

View File

@ -74,8 +74,6 @@ def pallas_call_tpu_lowering_rule(
compiler_params: dict[str, Any]):
"""Lowers a pallas_call to a Mosaic TPU custom call."""
del interpret
# TODO(necula): cleanup
out_shapes = grid_mapping.out_shapes
if debug:
print(jaxpr)
if "mosaic_params" in compiler_params:
@ -118,7 +116,9 @@ def pallas_call_tpu_lowering_rule(
(a[0] + num_dyn_bounds + num_extra_args, a[1])
for a in input_output_aliases
)
out_avals = [jax_core.ShapedArray(s.shape, s.dtype) for s in out_shapes]
out_avals = [jax_core.ShapedArray(bm.array_shape_dtype.shape,
bm.array_shape_dtype.dtype)
for bm in grid_mapping.block_mappings_output]
if promela_dump_path := _DUMP_PROMELA_TO.value:
num_devices = 1 if mesh is None else mesh.devices.size

View File

@ -276,7 +276,7 @@ class BufferedRef:
@property
def compute_index(self):
return self.spec.compute_index
return lambda *args: pallas_core.compute_index(self.spec, *args)
@property
def memory_space(self):

View File

@ -147,7 +147,7 @@ def lower_jaxpr_to_module(
name: str,
compiler_params: dict[str, Any],
) -> LoweringResult:
in_structs = grid_mapping.in_shapes
in_structs = tuple(grid_mapping.in_shapes)
out_structs = grid_mapping.out_shapes
assert len(jaxpr.outvars) == 0
assert not grid_mapping.vmapped_dims

View File

@ -15,7 +15,7 @@
"""Module for calling pallas functions from JAX."""
from __future__ import annotations
from collections.abc import Callable, Sequence
from collections.abc import Callable, Iterable, Sequence
from functools import partial, reduce
import itertools
from typing import Any
@ -53,6 +53,7 @@ map, unsafe_map = safe_map, map
zip, unsafe_zip = safe_zip, zip
Grid = pallas_core.Grid
TupleGrid = pallas_core.TupleGrid
GridSpec = pallas_core.GridSpec
BlockMapping = pallas_core.BlockMapping
GridMapping = pallas_core.GridMapping
@ -118,14 +119,16 @@ def _initialize_scratch_vals(scratch_avals) -> tuple[jax.Array, ...]:
return tuple(uninitialized_value(a.shape, a.dtype) for a in scratch_avals)
def _initialize_output_vals(
out_shapes, input_args, input_output_aliases) -> Sequence[jax.Array]:
block_mappings_output: Iterable[BlockMapping],
input_args, input_output_aliases) -> Sequence[jax.Array]:
oi_map = {v: k for k, v in input_output_aliases}
output_vals = []
for i, out_shape in enumerate(out_shapes):
for i, bm in enumerate(block_mappings_output):
if i in oi_map:
output_vals.append(input_args[oi_map[i]])
else:
output_vals.append(uninitialized_value(out_shape.shape, out_shape.dtype))
output_vals.append(uninitialized_value(bm.array_shape_dtype.shape,
bm.array_shape_dtype.dtype))
return output_vals
def _logical_to_interpret_mode_dtype(dtype):
@ -171,8 +174,6 @@ def _pallas_call_impl_interpret(
grid_mapping: GridMapping,
compiler_params: Any):
del compiler_params, name
# TODO(necula): cleanup
out_shapes = grid_mapping.out_shapes
# If we're in interpreter mode, we *scan* over the grid and eval the
# discharged jaxpr.
dynamic_grid_args, args = split_list( # type: ignore
@ -189,16 +190,18 @@ def _pallas_call_impl_interpret(
discharged_jaxpr, discharged_consts = state_discharge.discharge_state(jaxpr, ())
if debug:
print(discharged_jaxpr)
out = _initialize_output_vals(out_shapes, args, input_output_aliases)
scalars, args = split_list(args, [grid_mapping.num_index_operands]) # type: ignore
out = _initialize_output_vals(grid_mapping.block_mappings_output,
args, input_output_aliases)
scalars = args[grid_mapping.slice_index_ops]
block_args = args[len(scalars):]
# invars: [*scalar_prefetch, *consts, *inputs, *outputs, *scratch]
# args now contains: *consts, *inputs, *outputs
# block_args now contains: *consts, *inputs, *outputs
scratch_invars = jaxpr.invars[grid_mapping.slice_scratch_ops]
scratch_avals = [v.aval for v in scratch_invars]
scratch_values = _initialize_scratch_vals(scratch_avals)
carry = []
for x, bm in zip(itertools.chain(args, out), grid_mapping.block_mappings):
for x, bm in zip(itertools.chain(block_args, out), grid_mapping.block_mappings):
if isinstance(bm.indexing_mode, pallas_core.Unblocked):
padding = bm.indexing_mode.padding
if padding is not None and any(p != (0, 0) for p in padding):
@ -224,7 +227,7 @@ def _pallas_call_impl_interpret(
carry = map(_pad_values_to_block_dimension, carry, block_shapes)
carry.extend(scratch_values)
num_inout = len(args) + len(out)
num_inout_blocks = len(block_args) + len(out)
grid_start_indices = (jnp.int32(0),) * len(grid)
if grid:
num_iterations = reduce(jnp.multiply, grid)
@ -239,19 +242,19 @@ def _pallas_call_impl_interpret(
i, *_ = carry
return i < num_iterations
def body(carry):
i, loop_idx, *carry = carry
i, loop_idx, *carry_blocks = carry
local_grid_env = tuple(
pallas_core.GridAxis(idx, b)
for dim, (idx, b) in enumerate(zip(loop_idx, grid))
if dim not in grid_mapping.vmapped_dims
)
carry, scratch = split_list(carry, [num_inout])
carry_consts_ins, scratch = split_list(carry_blocks, [num_inout_blocks])
with pallas_core.grid_env(local_grid_env):
start_indices = [
None if bm is None else bm.compute_start_indices_interpret(loop_idx, *scalars)
for bm in grid_mapping.block_mappings]
blocks = map(_maybe_dynamic_slice, start_indices, block_shapes, carry,
is_indexing_dim)
blocks = map(_maybe_dynamic_slice, start_indices, block_shapes,
carry_consts_ins, is_indexing_dim)
with pallas_core.grid_env(local_grid_env):
assert len(discharged_jaxpr.invars) == len(scalars) + len(blocks) + len(
scratch_values
@ -263,20 +266,21 @@ def _pallas_call_impl_interpret(
)
blocks = jax_core.eval_jaxpr(discharged_jaxpr, discharged_consts, *scalars,
*blocks, *scratch)
blocks = blocks[grid_mapping.num_index_operands:]
blocks, out_scratch = split_list(blocks, [num_inout])
carry = map(_maybe_dynamic_update_slice, start_indices, block_shapes,
carry, blocks, is_indexing_dim)
return (i + 1, _get_next_indices(grid, loop_idx), *carry, *out_scratch)
_, out_inout, out_scratch = split_list(
blocks, [grid_mapping.num_index_operands, num_inout_blocks])
out_carry = map(_maybe_dynamic_update_slice, start_indices, block_shapes,
carry_consts_ins, out_inout, is_indexing_dim)
return (i + 1, _get_next_indices(grid, loop_idx),
*out_carry, *out_scratch)
(_, _, *carry) = lax.while_loop(
cond, body, (jnp.int32(0), grid_start_indices, *carry)
)
_, out, _ = split_list(carry, [len(args), len(out)])
assert len(grid_mapping.block_mappings) == len(args) + len(out)
out_block_mappings = grid_mapping.block_mappings[len(args):]
out_out = carry[len(block_args):len(block_args) + len(out)]
out_nopad = []
for o, expected_o_shape, bm in zip(out, out_shapes, out_block_mappings):
for o, bm in zip(out_out, grid_mapping.block_mappings_output):
if isinstance(bm.indexing_mode, pallas_core.Unblocked):
padding = bm.indexing_mode.padding
if padding is not None and any(p != (0, 0) for p in padding):
@ -285,23 +289,22 @@ def _pallas_call_impl_interpret(
pad_low, pad_high = zip(*padding)
limit_indices = [s - p for s, p in zip(o.shape, pad_high)]
o = lax.slice(o, pad_low, limit_indices)
if o.shape != expected_o_shape.shape:
o = lax.slice(o, (0,) * o.ndim, expected_o_shape.shape)
if o.shape != bm.array_shape_dtype.shape:
o = lax.slice(o, (0,) * o.ndim, bm.array_shape_dtype.shape)
out_nopad.append(o)
return out_nopad
pallas_call_p.def_impl(_pallas_call_impl)
def _pallas_call_abstract_eval(*avals, grid_mapping, **_):
out_shapes = grid_mapping.out_shapes
return map(lambda x: jax_core.ShapedArray(x.shape, x.dtype), out_shapes)
def _pallas_call_abstract_eval(*avals, grid_mapping: GridMapping, **_):
return tuple(jax_core.ShapedArray(bm.array_shape_dtype.shape,
bm.array_shape_dtype.dtype)
for bm in grid_mapping.block_mappings_output)
pallas_call_p.def_abstract_eval(_pallas_call_abstract_eval)
def _pallas_call_jvp_rule(primals, tangents, *, jaxpr, name,
input_output_aliases: tuple[tuple[int, int], ...],
grid_mapping, debug, interpret, compiler_params: Any):
# TODO(necula): cleanup
out_shapes = grid_mapping.out_shapes
if grid_mapping.num_dynamic_grid_bounds:
raise NotImplementedError("interpret with dynamic grid bounds unsupported")
if grid_mapping.num_index_operands:
@ -310,7 +313,7 @@ def _pallas_call_jvp_rule(primals, tangents, *, jaxpr, name,
raise NotImplementedError("JVP with aliasing not supported.")
nonzero_tangents = [not isinstance(t, ad_util.Zero) for t in tangents]
tangents = [t for t in tangents if type(t) is not ad_util.Zero]
nonzero_tangents_with_outputs = nonzero_tangents + [True] * len(out_shapes)
nonzero_tangents_with_outputs = nonzero_tangents + [True] * grid_mapping.num_outputs
closed_jaxpr = jax_core.ClosedJaxpr(jaxpr, ())
jvp_jaxpr_, _ = ad.jvp_jaxpr(closed_jaxpr, nonzero_tangents_with_outputs, [])
jvp_jaxpr, () = jvp_jaxpr_.jaxpr, jvp_jaxpr_.consts # TODO consts
@ -322,7 +325,7 @@ def _pallas_call_jvp_rule(primals, tangents, *, jaxpr, name,
# compatible w/ `pallas_call` (inputs then outputs), we need to shuffle around
# the jaxpr's invars.
primal_refs, primal_out_refs, tangent_refs, tangent_out_refs = split_list(
jvp_jaxpr.invars, [len(primals), len(out_shapes), len(tangents)]
jvp_jaxpr.invars, [len(primals), grid_mapping.num_outputs, len(tangents)]
)
invars = (*primal_refs, *tangent_refs, *primal_out_refs, *tangent_out_refs)
effs = []
@ -335,6 +338,7 @@ def _pallas_call_jvp_rule(primals, tangents, *, jaxpr, name,
jvp_jaxpr = jvp_jaxpr.replace(invars=invars, effects=effs)
if debug:
print(jvp_jaxpr)
# TODO(necula): does this work with consts?
in_bms, out_bms = split_list(grid_mapping.block_mappings, [len(primals)])
jvp_bms = (*in_bms, *in_bms, *out_bms, *out_bms)
jvp_grid_mapping = grid_mapping.replace(
@ -369,8 +373,7 @@ def _batch_block_mapping(grid_mapping: GridMapping,
if dim is not batching.not_mapped:
indices.insert(dim, new_idx)
return tuple(indices)
i32_aval = jax_core.ShapedArray((), jnp.int32)
idx_avals = [i32_aval, *block_mapping.index_map_jaxpr.in_avals]
idx_avals = [pallas_core.index_map_grid_aval, *block_mapping.index_map_jaxpr.in_avals]
with grid_mapping.trace_env():
block_mapping_jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(
lu.wrap_init(_block_map_function), idx_avals)
@ -444,8 +447,6 @@ def _batch_with_explicit_loop(
to the current iteration index and dynamic_updates an (initially empty) output
allocation.
"""
# TODO(necula): cleanup
out_shapes = grid_mapping.out_shapes
if not dims:
raise NotImplementedError("vmapping pallas_call with no arguments.")
@ -465,10 +466,9 @@ def _batch_with_explicit_loop(
# The output arrays are completely overwritten, so we can just initialize
# empty arrays.
initial_state = [
jnp.empty(
tuple_insert(out_shape.shape, 0, axis_size), dtype=out_shape.dtype
)
for out_shape in out_shapes
jnp.empty(tuple_insert(bm.array_shape_dtype.shape, 0, axis_size),
dtype=bm.array_shape_dtype.dtype)
for bm in grid_mapping.block_mappings_output
]
def body(batch_index: jax.Array, state: list[jax.Array]) -> list[jax.Array]:
@ -528,8 +528,6 @@ def _pallas_call_batching_rule(
interpret: bool,
compiler_params: Any,
):
# TODO(necula): cleanup
out_shapes = grid_mapping.out_shapes
def _maybe_squeeze_out_bdim(
x: jax.Array, bdim: int | batching.NotMapped
) -> jax.Array:
@ -631,7 +629,7 @@ def _pallas_call_batching_rule(
args, dims, input_output_aliases=input_output_aliases, axis_size=axis_size
)
all_dims = list(dims) + [0] * len(out_shapes)
all_dims = list(dims) + [0] * grid_mapping.num_outputs
num_index_operands = grid_mapping.num_index_operands
num_scratch_operands = grid_mapping.num_scratch_operands
@ -647,10 +645,12 @@ def _pallas_call_batching_rule(
block_mappings,
)
index_map_tree_args, index_map_tree_kwargs = grid_mapping.index_map_tree.unflatten(grid_mapping.index_map_avals)
index_map_tree_args, index_map_tree_kwargs = grid_mapping.index_map_tree.unflatten(
grid_mapping.index_map_avals)
assert not index_map_tree_kwargs
batched_index_map_args = (jax_core.ShapedArray((), jnp.int32),) + index_map_tree_args
batched_index_map_avals, batched_index_map_tree = tree_util.tree_flatten((batched_index_map_args, {}))
batched_index_map_args = (pallas_core.index_map_grid_aval,) + index_map_tree_args
batched_index_map_avals, batched_index_map_tree = tree_util.tree_flatten(
(batched_index_map_args, {}))
batched_grid_mapping = grid_mapping.replace(
grid=(axis_size, *grid_mapping.grid),
block_mappings=tuple(batched_block_mappings),
@ -706,18 +706,12 @@ def pallas_call_checkify_rule(error: checkify.Error,
# 4) Create block specs for the error state and call pallas_call with
# the new kernel.
dynamic_grid_bounds, scalars, args = split_list( # type: ignore
args, [grid_mapping.num_dynamic_grid_bounds, grid_mapping.num_index_operands]
args, [grid_mapping.num_dynamic_grid_bounds,
grid_mapping.num_index_operands]
)
num_scalars = len(scalars)
num_invars = len(jaxpr.invars)
num_inputs_outputs = (
num_invars
- grid_mapping.num_index_operands
- grid_mapping.num_scratch_operands
)
num_kernel_inputs = len(args)
num_scratch = num_invars - num_inputs_outputs
num_kernel_outputs = num_invars - num_scratch - num_kernel_inputs
num_kernel_outputs = grid_mapping.num_outputs
# Trace the jaxpr to get an initial error value so the kernel jaxpr has all of
# the required inputs.
@ -989,11 +983,11 @@ def pallas_call(
out_shape: Any,
*,
grid_spec: GridSpec | None = None,
debug: bool = False,
grid: Grid = (),
grid: TupleGrid = (),
in_specs: BlockSpecTree = no_block_spec,
out_specs: BlockSpecTree = no_block_spec,
input_output_aliases: dict[int, int] = {},
debug: bool = False,
interpret: bool = False,
name: str | None = None,
compiler_params: dict[str, Any] | None = None,
@ -1008,9 +1002,8 @@ def pallas_call(
corresponding ``in_specs`` and ``out_specs``.
out_shape: a PyTree of :class:`jax.ShapeDtypeStruct` describing the shape
and dtypes of the outputs.
grid_spec: TO BE DOCUMENTED.
debug: if True, Pallas prints various intermediate forms of the kernel
as it is being processed.
grid_spec: An alternative way to specify ``grid``, ``in_specs``, and
``out_specs``. If given, those other parameters must not be also given.
grid: the iteration space, as a tuple of integers. The kernel is executed
as many times as ``prod(grid)``.
See details at :ref:`pallas_grid`.
@ -1027,6 +1020,8 @@ def pallas_call(
input_output_aliases: a dictionary mapping the index of some inputs to
the index of the output that aliases them. These indices are in the
flattened inputs and outputs.
debug: if True, Pallas prints various intermediate forms of the kernel
as it is being processed.
interpret: runs the ``pallas_call`` as a ``jax.jit`` of a scan over the
grid whose body is the kernel lowered as a JAX function. This does not
require a TPU or a GPU, and is the only way to run Pallas kernels on CPU.
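
A minimal usage sketch of `grid_spec` together with `interpret` (illustrative
only: the kernel, shapes, and block sizes are made up, and `GridSpec` is
assumed to take `grid`, `in_specs`, and `out_specs` keyword arguments):

    import jax
    import jax.numpy as jnp
    from jax.experimental import pallas as pl

    def add_one_kernel(x_ref, o_ref):
      o_ref[...] = x_ref[...] + 1.0

    grid_spec = pl.GridSpec(
        grid=(2,),
        in_specs=[pl.BlockSpec((128,), lambda i: (i,))],
        out_specs=pl.BlockSpec((128,), lambda i: (i,)),
    )
    y = pl.pallas_call(
        add_one_kernel,
        out_shape=jax.ShapeDtypeStruct((256,), jnp.float32),
        grid_spec=grid_spec,
        interpret=True,  # runs on CPU without a TPU/GPU
    )(jnp.arange(256, dtype=jnp.float32))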
@ -1059,7 +1054,7 @@ def pallas_call(
"If `grid_spec` is specified, then `out_specs` must "
f"be `no_block_spec`. It is {out_specs}")
del grid, in_specs, out_specs
grid_spec, dynamic_grid_bounds = grid_spec.unzip_dynamic_grid_bounds()
grid_spec, dynamic_grid_bounds = pallas_core.unzip_dynamic_grid_bounds(grid_spec)
# TODO(necula): this canonicalization may be convenient for some usage
# but it is lossy, because it prevents expressing functions that return
# lists.
@ -1078,7 +1073,8 @@ def pallas_call(
flat_out_avals = tuple(jax_core.ShapedArray(v.shape, v.dtype)
for v in flat_out_shapes)
# TODO(necula): check that input_output_aliases is well-formed: no duplicates, etc.
kernel_avals, grid_mapping = grid_spec.get_grid_mapping(
kernel_avals, grid_mapping = pallas_core.get_grid_mapping(
grid_spec,
flat_in_avals, in_tree, in_paths,
flat_out_avals, out_tree, out_paths)
flat_kernel_avals, kernel_in_tree = tree_util.tree_flatten(kernel_avals)

View File

@ -250,7 +250,6 @@ def _new_ir_context() -> ir.Context:
def lower_jaxpr_to_triton_module(
jaxpr: jax_core.Jaxpr,
in_out_shapes,
grid_mapping: GridMapping,
name: str,
platform: str
@ -313,23 +312,22 @@ def lower_jaxpr_to_triton_module(
raise NotImplementedError(
"Scalar prefetch not supported in Triton lowering."
)
for bm in grid_mapping.block_mappings:
if not isinstance(bm.indexing_mode, Blocked):
raise NotImplementedError(
"Only Blocked indexing mode is supported in Triton lowering."
)
if not all(isinstance(bm.indexing_mode, Blocked)
for bm in grid_mapping.block_mappings):
raise NotImplementedError(
"Only Blocked indexing mode is supported in Triton lowering."
)
start_indices = map(
functools.partial(_eval_index_map, ctx, program_ids),
grid_mapping.block_mappings,
)
block_infos = [
BlockInfo(
jax.ShapeDtypeStruct(shape_dtype.shape, shape_dtype.dtype),
block_mapping.array_shape_dtype,
start_idx,
block_mapping.block_shape,
)
for shape_dtype, block_mapping, start_idx in zip(
in_out_shapes,
for block_mapping, start_idx in zip(
grid_mapping.block_mappings,
start_indices,
)

View File

@ -50,9 +50,14 @@ def pallas_call_lowering(
compiler_params: dict[str, Any],
):
del interpret
# TODO(necula): cleanup
in_shapes = grid_mapping.in_shapes
out_shapes = grid_mapping.out_shapes
if grid_mapping.num_dynamic_grid_bounds:
raise NotImplementedError(
"dynamic grid bounds not supported in the Triton backend"
)
if grid_mapping.num_index_operands:
raise NotImplementedError(
"scalar prefetch not implemented in the Triton backend"
)
triton_params = compiler_params.get("triton", compiler_params)
num_warps = triton_params.pop("num_warps", 4)
[lowering_platform] = ctx.platforms or ctx.module_context.platforms
@ -66,7 +71,7 @@ def pallas_call_lowering(
print(grid_mapping)
lowering_result = lowering.lower_jaxpr_to_triton_module(
jaxpr, (*in_shapes, *out_shapes), grid_mapping, name, lowering_platform
jaxpr, grid_mapping, name, lowering_platform
)
module_op = lowering_result.module.operation
if debug:
@ -74,8 +79,9 @@ def pallas_call_lowering(
grid_x, grid_y, grid_z = normalize_grid(lowering_result.grid)
out_types = [
ir.RankedTensorType.get(shape.shape, mlir.dtype_to_ir_type(shape.dtype))
for shape in out_shapes
ir.RankedTensorType.get(bm.array_shape_dtype.shape,
mlir.dtype_to_ir_type(bm.array_shape_dtype.dtype))
for bm in grid_mapping.block_mappings_output
]
buf = io.BytesIO()
module_op.write_bytecode(buf)

View File

@ -25,6 +25,7 @@ from jax._src.pallas.core import IndexingMode
from jax._src.pallas.core import no_block_spec
from jax._src.pallas.core import Unblocked
from jax._src.pallas.core import unblocked
from jax._src.pallas.core import GridSpec
from jax._src.pallas.pallas_call import pallas_call
from jax._src.pallas.pallas_call import pallas_call_p
from jax._src.pallas.primitives import atomic_add