[Pallas] Refactor memory space handling
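
BlockSpec now carries an optional memory_space alongside an optional
index_map and block_shape, and a shared no_block_spec sentinel replaces
both None and the private _no_block_spec. On the TPU side, memory spaces
are expressed through the TPUMemorySpace enum and MemoryRef (replacing the
tpu_primitives.VMEM/Type classes), and the Mosaic lowering now reads each
operand's memory space off its BlockMapping rather than from a separate
memory_spaces parameter.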

PiperOrigin-RevId: 563586933
Sharad Vikram 2023-09-07 17:08:18 -07:00 committed by jax authors
parent d0c4c9b3fe
commit cb114f247a
11 changed files with 264 additions and 182 deletions

View File

@@ -70,12 +70,24 @@ class Mapped:
mapped = Mapped()
@dataclasses.dataclass(frozen=True)
@dataclasses.dataclass(init=False, unsafe_hash=True)
class BlockSpec:
index_map: Callable[..., Any]
block_shape: tuple[int | None, ...]
index_map: Callable[..., Any] | None
block_shape: tuple[int | None, ...] | None
memory_space: Any
def __init__(self, index_map: Callable[..., Any] | None = None,
block_shape: tuple[int | None, ...] | None = None,
memory_space: Any = None):
self.index_map = index_map
if block_shape is not None and not isinstance(block_shape, tuple):
block_shape = tuple(block_shape)
self.block_shape = block_shape
self.memory_space = memory_space
def compute_index(self, *args):
assert self.index_map is not None
assert self.block_shape is not None
out = self.index_map(*args)
if not isinstance(out, tuple):
out = (out,)
@@ -86,6 +98,7 @@ class BlockSpec:
class BlockMapping:
block_shape: tuple[Mapped | int, ...]
index_map_jaxpr: jax_core.ClosedJaxpr
memory_space: Any
def compute_start_indices(self, loop_idx, *args):
discharged_jaxpr, discharged_consts = state_discharge.discharge_state(
@@ -123,107 +136,131 @@ def _preprocess_grid(grid: Grid | int | None) -> Grid:
def _convert_block_spec_to_block_mapping(
in_avals: list[jax_core.ShapedArray], block_spec: BlockSpec | None,
aval: jax_core.ShapedArray,
) -> BlockSpec | None:
if block_spec is _no_block_spec:
if block_spec is no_block_spec:
return None
if block_spec.index_map is None:
compute_index = lambda *args: (0,) * len(aval.shape)
block_shape = aval.shape
else:
compute_index = block_spec.compute_index
block_shape = block_spec.block_shape
block_shape = tuple(
mapped if s is None else s for s in block_spec.block_shape)
mapped if s is None else s for s in block_shape)
jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(
lu.wrap_init(block_spec.compute_index), in_avals)
return BlockMapping(block_shape, jax_core.ClosedJaxpr(jaxpr, consts))
lu.wrap_init(compute_index), in_avals)
return BlockMapping(block_shape, jax_core.ClosedJaxpr(jaxpr, consts),
block_spec.memory_space)
def _compute_shape_from_block_spec(block_spec: BlockSpec | None,
arg_shape: tuple[int, ...]
) -> tuple[int, ...]:
if block_spec is _no_block_spec:
return arg_shape
return tuple(s for s in block_spec.block_shape if s is not None)
def _tile_ref(ref: jax_core.AbstractRef, block_shape: tuple[int, ...] | None
) -> jax_core.AbstractRef:
if block_shape is None:
return ref
shape = tuple(s for s in block_shape if s is not None)
return state.shaped_array_ref(shape, ref.dtype)
def _get_ref_avals(grid, in_avals, in_specs, out_avals, out_specs):
in_ref_avals = map(state.AbstractRef, in_avals)
out_ref_avals = map(state.AbstractRef, out_avals)
if grid is None:
in_specs = [None] * len(in_avals)
out_specs = [None] * len(out_avals)
in_ref_avals = [state.shaped_array_ref(arg.shape, arg.dtype)
for arg in in_avals]
out_ref_avals = [state.shaped_array_ref(arg.shape, arg.dtype)
for arg in out_avals]
else:
in_ref_avals = [
state.shaped_array_ref(
_compute_shape_from_block_spec(
block_spec, arg.shape), arg.dtype)
for block_spec, arg in zip(in_specs, in_avals)]
out_ref_avals = [
state.shaped_array_ref(
_compute_shape_from_block_spec(
block_spec, arg.shape), arg.dtype)
for block_spec, arg in zip(out_specs, out_avals)]
return in_specs, in_ref_avals, out_specs, out_ref_avals
tiled_in_ref_avals = [
aval if in_spec is no_block_spec
else _tile_ref(aval, in_spec.block_shape)
for aval, in_spec in zip(in_ref_avals, in_specs)
]
tiled_out_ref_avals = [
aval if out_spec is no_block_spec
else _tile_ref(aval, out_spec.block_shape)
for aval, out_spec in zip(out_ref_avals, out_specs)
]
return in_specs, tiled_in_ref_avals, out_specs, tiled_out_ref_avals
class NoBlockSpec:
pass
no_block_spec = NoBlockSpec()
_no_block_spec = object()
@dataclasses.dataclass(init=False)
@dataclasses.dataclass(init=False, unsafe_hash=True)
class GridSpec:
grid: Grid
in_specs: Sequence[BlockSpec | None] | None
out_specs: tuple[BlockSpec | None, ...] | None
in_specs: tuple[BlockSpec | NoBlockSpec, ...]
out_specs: tuple[BlockSpec | NoBlockSpec, ...]
in_specs_tree: Any
out_specs_tree: Any
def __init__(
self,
grid: Grid | None = None,
in_specs: Sequence[BlockSpec | None] | None = None,
out_specs: BlockSpec | Sequence[BlockSpec | None] | None = None,
in_specs: BlockSpec
| Sequence[BlockSpec | NoBlockSpec]
| NoBlockSpec = no_block_spec,
out_specs: BlockSpec
| Sequence[BlockSpec | NoBlockSpec]
| NoBlockSpec = no_block_spec,
):
if grid is None:
if in_specs is not None:
raise ValueError("Cannot specify `in_specs` with a `None` grid.")
if out_specs is not None:
raise ValueError("Cannot specify `out_specs` with a `None` grid.")
self.grid = _preprocess_grid(grid)
self.in_specs = in_specs
if out_specs is not None and not isinstance(out_specs, (tuple, list)):
out_specs = (out_specs,)
if out_specs is not None and not isinstance(out_specs, tuple):
# Be more lenient for in/out_specs
if isinstance(in_specs, list):
in_specs = tuple(in_specs)
if isinstance(out_specs, list):
out_specs = tuple(out_specs)
self.out_specs = out_specs
self.grid = _preprocess_grid(grid)
if in_specs is not no_block_spec:
flat_in_specs, self.in_specs_tree = tree_util.tree_flatten(in_specs)
self.in_specs = tuple(flat_in_specs)
else:
self.in_specs = in_specs
self.in_specs_tree = None
if out_specs is not no_block_spec:
flat_out_specs, self.out_specs_tree = tree_util.tree_flatten(out_specs)
self.out_specs = tuple(flat_out_specs)
else:
self.out_specs = out_specs
self.out_specs_tree = None
def _get_in_out_specs(self, in_avals, in_tree, out_avals, out_tree):
if self.in_specs is no_block_spec:
flat_in_specs = [no_block_spec] * len(in_avals)
else:
flat_in_specs = self.in_specs
if self.in_specs_tree != in_tree:
raise ValueError(
"Pytree specs for arguments and `in_specs` must match: "
f"{in_tree} vs. {self.in_specs_tree}")
if self.out_specs is no_block_spec:
flat_out_specs = [no_block_spec] * len(out_avals)
else:
flat_out_specs = self.out_specs
if self.out_specs_tree != out_tree:
raise ValueError(
"Pytree specs for `out_shape` and `out_specs` must match: "
f"{out_tree} vs. {self.out_specs_tree}")
return flat_in_specs, flat_out_specs
def get_grid_mapping(
self, in_avals, in_tree, out_avals, out_tree
) -> tuple[tuple[jax_core.AbstractValue, ...], GridMapping]:
if self.in_specs is not None:
in_specs = self.in_specs
in_spec_tree = tree_util.tree_structure(tuple(in_specs))
if in_spec_tree != in_tree:
raise ValueError(
"Pytree specs for arguments and `in_specs` must match: "
f"{in_tree} vs. {in_spec_tree}")
else:
in_specs = [_no_block_spec] * len(in_avals)
if self.out_specs is not None:
out_specs = self.out_specs
out_spec_tree = tree_util.tree_structure(out_specs)
if out_spec_tree != out_tree:
raise ValueError(
"Pytree specs for `out_shape` and `out_specs` must match: "
f"{out_tree} vs. {out_spec_tree}")
else:
out_specs = [_no_block_spec] * len(out_avals)
flat_in_specs = tree_util.tree_leaves(in_specs)
flat_out_specs = tree_util.tree_leaves(out_specs)
flat_in_specs, flat_out_specs = self._get_in_out_specs(
in_avals, in_tree, out_avals, out_tree)
in_specs, in_ref_avals, out_specs, out_ref_avals = _get_ref_avals(
self.grid, in_avals, flat_in_specs, out_avals,
flat_out_specs)
grid_avals = [jax_core.ShapedArray((), jnp.dtype("int32"))] * len(self.grid)
in_block_mappings = map(
partial(_convert_block_spec_to_block_mapping, grid_avals), in_specs)
partial(_convert_block_spec_to_block_mapping, grid_avals), in_specs,
in_ref_avals)
out_block_mappings = map(
partial(_convert_block_spec_to_block_mapping, grid_avals), out_specs)
partial(_convert_block_spec_to_block_mapping, grid_avals), out_specs,
out_ref_avals)
grid_mapping = GridMapping(
self.grid, (*in_block_mappings, *out_block_mappings), (),
num_index_operands=0)
jaxpr_in_avals = tree_util.tree_unflatten(in_tree, in_ref_avals)
jaxpr_out_avals = tree_util.tree_unflatten(out_tree, out_ref_avals)
if not isinstance(jaxpr_out_avals, (tuple, list)):
jaxpr_out_avals = (jaxpr_out_avals,)
return (*jaxpr_in_avals, *jaxpr_out_avals), grid_mapping
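
The refactored BlockSpec takes index_map and block_shape as optional
arguments (in that order) and gains a memory_space field, with the public
no_block_spec sentinel standing in where no spec is given. A minimal usage
sketch in interpret mode; the kernel, shapes, and grid are illustrative,
not part of this commit:

import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl

def add_kernel(x_ref, y_ref, o_ref):
  # Each grid step sees one (128, 128) block of each operand.
  o_ref[...] = x_ref[...] + y_ref[...]

x = jnp.ones((512, 128), jnp.float32)
y = jnp.ones((512, 128), jnp.float32)
block = pl.BlockSpec(lambda i: (i, 0), (128, 128))  # index_map, block_shape
out = pl.pallas_call(
    add_kernel,
    out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
    grid=(4,),
    in_specs=[block, block],
    out_specs=block,
    interpret=True,  # run on the interpreter so the sketch is portable
)(x, y)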

View File

@@ -23,10 +23,10 @@ from jax._src.pallas.mosaic.kernel_regeneration_util import extract_kernel_regen
from jax._src.pallas.mosaic.primitives import repeat
from jax._src.pallas.mosaic.primitives import trace
from jax._src.pallas.mosaic.primitives import run_scoped
from jax._src.pallas.mosaic.primitives import VMEM
SMEM = TPUMemorySpace.SMEM
ANY = TPUMemorySpace.ANY
CMEM = TPUMemorySpace.CMEM
SMEM = TPUMemorySpace.SMEM
VMEM = TPUMemorySpace.VMEM
del pallas_call_registration
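
With VMEM no longer a class in mosaic.primitives, all four names exported
here are members of the TPUMemorySpace enum. A small sanity sketch (the
pltpu alias is illustrative):

from jax.experimental.pallas import tpu as pltpu

assert pltpu.VMEM is pltpu.TPUMemorySpace.VMEM
assert pltpu.SMEM is pltpu.TPUMemorySpace.SMEM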

View File

@@ -19,6 +19,7 @@ from collections.abc import Sequence
import dataclasses
import enum
import functools
from typing import Any
from jax._src import core as jax_core
from jax._src import state
@@ -30,12 +31,16 @@ from jax._src.pallas import core as pallas_core
# TODO(sharadmv): enable type checking
# mypy: ignore-errors
map, unsafe_map = util.safe_map, map
zip, unsafe_zip = util.safe_zip, zip
partial = functools.partial
Grid = pallas_core.Grid
BlockSpec = pallas_core.BlockSpec
GridMapping = pallas_core.GridMapping
NoBlockSpec = pallas_core.NoBlockSpec
no_block_spec = pallas_core.no_block_spec
_preprocess_grid = pallas_core._preprocess_grid
_compute_shape_from_block_spec = pallas_core._compute_shape_from_block_spec
_convert_block_spec_to_block_mapping = pallas_core._convert_block_spec_to_block_mapping
split_list = util.split_list
@@ -49,59 +54,114 @@ class TPUMemorySpace(enum.Enum):
def __str__(self) -> str:
return self.value
def __call__(self, shape: tuple[int, ...], dtype: jnp.dtype):
# A convenience function for constructing MemoryRef types.
return MemoryRef(shape, dtype, self)
@dataclasses.dataclass(init=False)
class AbstractMemoryRef(state.AbstractRef):
__slots__ = ["inner_aval", "memory_space"]
def __init__(self, inner_aval: jax_core.AbstractValue,
memory_space: TPUMemorySpace):
assert isinstance(inner_aval, jax_core.ShapedArray)
self.inner_aval = inner_aval
self.memory_space = memory_space
def __repr__(self) -> str:
return f'MemRef<{self.memory_space}>{{{self.inner_aval.str_short()}}}'
def at_least_vspace(self):
return AbstractMemoryRef(
self.inner_aval.at_least_vspace(), self.memory_space)
def __eq__(self, other):
return (type(self) is type(other) and self.inner_aval == other.inner_aval
and self.memory_space == other.memory_space)
def __hash__(self):
return hash((self.__class__, self.inner_aval, self.memory_space))
def _ref_raise_to_shaped(ref_aval: AbstractMemoryRef, weak_type):
return AbstractMemoryRef(
jax_core.raise_to_shaped(ref_aval.inner_aval, weak_type),
ref_aval.memory_space)
jax_core.raise_to_shaped_mappings[AbstractMemoryRef] = _ref_raise_to_shaped
@dataclasses.dataclass(frozen=True)
class MemoryRef:
"""Like jax.ShapeDtypeStruct but with memory spaces."""
shape: tuple[int, ...]
dtype: jnp.dtype
memory_space: TPUMemorySpace = TPUMemorySpace.ANY
def get_aval(self):
return AbstractMemoryRef(jax_core.ShapedArray(self.shape, self.dtype),
self.memory_space)
@dataclasses.dataclass(init=False, unsafe_hash=True)
class PrefetchScalarGridSpec(pallas_core.GridSpec):
grid: Grid
num_scalar_prefetch: int
in_specs: Sequence[BlockSpec | None] | None
out_specs: tuple[BlockSpec | None, ...] | None
in_specs: tuple[BlockSpec | NoBlockSpec, ...]
out_specs: tuple[BlockSpec | NoBlockSpec, ...]
in_specs_tree: Any
out_specs_tree: Any
def __init__(
self,
num_scalar_prefetch: int,
grid: Grid | None = None,
in_specs: Sequence[BlockSpec | None] | None = None,
out_specs: BlockSpec | Sequence[BlockSpec | None] | None = None,
in_specs: BlockSpec
| Sequence[BlockSpec | NoBlockSpec]
| NoBlockSpec = no_block_spec,
out_specs: BlockSpec
| Sequence[BlockSpec | NoBlockSpec]
| NoBlockSpec = no_block_spec,
):
if grid is None:
raise NotImplementedError("Should pass in non-`None` grid.")
self.grid = _preprocess_grid(grid)
if out_specs is not None and not isinstance(out_specs, (tuple, list)):
out_specs = (out_specs,)
if out_specs is not None and not isinstance(out_specs, tuple):
out_specs = tuple(out_specs)
super().__init__(grid, in_specs, out_specs)
self.num_scalar_prefetch = num_scalar_prefetch
self.in_specs = in_specs
self.out_specs = out_specs
def get_grid_mapping(
self, in_avals, in_tree, out_avals, out_tree
) -> tuple[tuple[jax_core.AbstractValue, ...], GridMapping]:
scalar_avals, in_avals = split_list(in_avals, [self.num_scalar_prefetch])
flat_in_specs = tree_util.tree_leaves(self.in_specs)
flat_out_specs = tree_util.tree_leaves(self.out_specs)
all_avals = tree_util.tree_unflatten(in_tree, in_avals)
scalar_avals, unflat_in_avals = split_list(
all_avals, [self.num_scalar_prefetch])
flat_scalar_avals, scalar_tree = tree_util.tree_flatten(scalar_avals)
num_flat_scalar_prefetch = len(flat_scalar_avals)
in_avals, in_avals_tree = tree_util.tree_flatten(tuple(unflat_in_avals))
flat_in_specs, flat_out_specs = self._get_in_out_specs(
in_avals, in_avals_tree, out_avals, out_tree)
in_specs, in_ref_avals, out_specs, out_ref_avals = (
pallas_core._get_ref_avals(
self.grid, in_avals, flat_in_specs,
out_avals, flat_out_specs))
scalar_ref_avals = [
state.shaped_array_ref(aval.shape, aval.dtype)
for aval in scalar_avals]
for aval in flat_scalar_avals]
grid_avals = [jax_core.ShapedArray((), jnp.dtype("int32"))] * len(self.grid)
in_block_mappings = map(
partial(_convert_block_spec_to_block_mapping,
(*grid_avals, *scalar_ref_avals)), in_specs)
(*grid_avals, *scalar_ref_avals)), in_specs, in_ref_avals)
out_block_mappings = map(
partial(_convert_block_spec_to_block_mapping,
(*grid_avals, *scalar_ref_avals)), out_specs)
(*grid_avals, *scalar_ref_avals)), out_specs, out_ref_avals)
grid_mapping = GridMapping(
grid=self.grid,
block_mappings=(*in_block_mappings, *out_block_mappings),
mapped_dims=(),
num_index_operands=self.num_scalar_prefetch,
num_index_operands=num_flat_scalar_prefetch,
)
jaxpr_in_avals = tree_util.tree_unflatten(
in_tree, [*scalar_ref_avals, *in_ref_avals])
jaxpr_scalar_ref_avals = tree_util.tree_unflatten(
scalar_tree, scalar_ref_avals)
jaxpr_in_ref_avals = tree_util.tree_unflatten(in_avals_tree, in_ref_avals)
jaxpr_in_avals = (*jaxpr_scalar_ref_avals, *jaxpr_in_ref_avals)
jaxpr_out_avals = tree_util.tree_unflatten(out_tree, out_ref_avals)
if not isinstance(jaxpr_out_avals, (tuple, list)):
jaxpr_out_avals = (jaxpr_out_avals,)
return (*jaxpr_in_avals, *jaxpr_out_avals), grid_mapping
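
TPUMemorySpace.__call__ makes each enum member double as a constructor for
MemoryRef, whose get_aval() yields an AbstractMemoryRef tagged with that
space; PrefetchScalarGridSpec meanwhile flattens pytrees of scalar-prefetch
operands, so num_index_operands now counts flattened leaves. A sketch of
the MemoryRef side (the printed form follows the __repr__ defined above,
assuming the enum's string value is "vmem"):

import jax.numpy as jnp
from jax.experimental.pallas import tpu as pltpu

scratch = pltpu.VMEM((8, 128), jnp.float32)  # MemoryRef via __call__
aval = scratch.get_aval()
print(aval)  # MemRef<vmem>{float32[8,128]}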

View File

@@ -143,20 +143,17 @@ def lower_jaxpr_to_module(
grid_mapping: core.GridMapping,
jaxpr: jax_core.Jaxpr,
dimension_semantics: tuple[str | None, ...] | None,
memory_spaces: tuple[TPUMemorySpace | None, ...] | None
) -> ir.Module:
m = ir.Module.create()
sym_tab = ir.SymbolTable(m.operation)
if all(bm is None for bm in grid_mapping.block_mappings):
if not grid_mapping.grid:
# Trivial grid-map, we don't need to populate the transform functions.
func_op = lower_jaxpr_to_func(ctx, jaxpr, grid_mapping=grid_mapping,
memory_spaces=memory_spaces,
name="main")
m.body.append(func_op)
sym_tab.insert(func_op)
return m
func_op = lower_jaxpr_to_func(ctx, jaxpr, grid_mapping=grid_mapping,
memory_spaces=memory_spaces,
name="main")
m.body.append(func_op)
sym_tab.insert(func_op)
@@ -256,10 +253,11 @@ def lower_jaxpr_to_func(
ctx: ir.Context,
jaxpr: jax_core.Jaxpr,
*,
memory_spaces: Sequence[tpu_core.TPUMemorySpace | None] | None,
grid_mapping: core.GridMapping | None,
name: str,
) -> func.FuncOp:
memory_spaces = [None if bm is None else bm.memory_space
for bm in grid_mapping.block_mappings]
if grid_mapping:
arg_types = map(
aval_to_ir_type,
@@ -277,22 +275,15 @@ def lower_jaxpr_to_func(
)
return (aval_to_ir_type(aval, shape=shape, memory_space=memory_space),
block_mapping.block_shape)
if memory_spaces is None:
memory_spaces = [None] * len(jaxpr.invars)
if len(memory_spaces) != len(jaxpr.invars):
raise ValueError("Must have as many memory spaces as inputs and outputs.")
if grid_mapping is None:
block_mappings = [None] * len(jaxpr.invars)
else:
scalar_prefetch = grid_mapping.num_index_operands
block_mappings = grid_mapping.block_mappings
block_mappings = [*[None] * scalar_prefetch, *block_mappings]
for memory_space in memory_spaces[:scalar_prefetch]:
if memory_space is not None and memory_space != SMEM:
raise ValueError("Cannot specify non-SMEM memory space for "
"scalar prefetch inputs.")
memory_spaces = memory_spaces[scalar_prefetch:]
memory_spaces = [*[SMEM] * scalar_prefetch, *memory_spaces]
assert len(memory_spaces) == len(jaxpr.invars), (
"Must have as many memory spaces as inputs and outputs.")
invar_arg_types, block_shapes = unzip2(
map(_get_arg_type, [invar.aval for invar in jaxpr.invars], block_mappings,
memory_spaces)
@@ -1402,24 +1393,25 @@ def _trace_stop_lowering_rule(ctx: LoweringRuleContext):
lowering_rules[tpu_primitives.trace_stop_p] = _trace_stop_lowering_rule
def _alloc_type(type: tpu_primitives.Type):
if isinstance(type, tpu_primitives.VMEM):
aval = type.get_aval()
vmem = ir.Attribute.parse("#tpu.memory_space<vmem>")
def _alloc_value(aval: jax_core.AbstractValue) -> ir.Value:
if isinstance(aval, tpu_core.AbstractMemoryRef):
memspace = ir.Attribute.parse(f"#tpu.memory_space<{aval.memory_space}>")
out_type = ir.MemRefType.get(
aval.shape, mlir.dtype_to_ir_type(aval.dtype), memory_space=vmem)
aval.shape, mlir.dtype_to_ir_type(aval.dtype), memory_space=memspace)
return memref.AllocaOp(out_type, [], []).result
raise NotImplementedError(f"Cannot allocate {type}.")
raise NotImplementedError(f"Cannot allocate {type(aval)}.")
def _run_scoped_lowering_rule(ctx: LoweringRuleContext, *consts, jaxpr,
types):
def _run_scoped_lowering_rule(ctx: LoweringRuleContext, *consts, jaxpr):
region = tpu.RegionOp()
in_avals = [v.aval for v in jaxpr.invars]
jaxpr = pe.convert_constvars_jaxpr(jaxpr)
with ir.InsertionPoint(region.body):
args = [_alloc_type(type) for type in types]
args = map(_alloc_value, in_avals)
block_shapes = tuple(a.shape if isinstance(a, state.AbstractRef) else None
for a in in_avals)
ctx = ctx.lowering_context.replace(
block_shapes=(*ctx.block_shapes, *(t.get_block_shape() for t in types))
block_shapes=(*ctx.block_shapes, *block_shapes)
)
jaxpr_subcomp(ctx, jaxpr, *consts, *args)
tpu.YieldOp([])
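
Because _convert_block_spec_to_block_mapping now tolerates a BlockSpec with
index_map=None (the whole array is mapped), a spec can constrain only the
memory space, and the lowering above picks that space up from
BlockMapping.memory_space. A hedged TPU-only sketch; kernel and aliases are
illustrative:

import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl
from jax.experimental.pallas import tpu as pltpu

def copy_kernel(x_ref, o_ref):
  o_ref[...] = x_ref[...]

x = jnp.zeros((8, 128), jnp.float32)
out = pl.pallas_call(
    copy_kernel,
    out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
    grid=(1,),
    # No index_map/block_shape: the full array is mapped and only the
    # memory space is pinned.
    in_specs=[pl.BlockSpec(memory_space=pltpu.VMEM)],
    out_specs=pl.BlockSpec(memory_space=pltpu.VMEM),
)(x)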

View File

@@ -61,13 +61,11 @@ def pallas_call_tpu_lowering_rule(
if mosaic_params is None:
mosaic_params = {}
dimension_semantics = mosaic_params.get("dimension_semantics", None)
memory_spaces = mosaic_params.get("memory_spaces", None)
kernel_regeneration_metadata = mosaic_params.get(
"kernel_regeneration_metadata"
)
mosaic_module = lowering.lower_jaxpr_to_module(
mlir_ctx, grid_mapping, jaxpr, dimension_semantics=dimension_semantics,
memory_spaces=memory_spaces)
mlir_ctx, grid_mapping, jaxpr, dimension_semantics=dimension_semantics)
if debug:
print(mosaic_module)
out_avals = [jax_core.ShapedArray(s.shape, s.dtype) for s in out_shapes]
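
The "memory_spaces" entry of the Mosaic params is gone; memory placement
now travels on each BlockSpec and its BlockMapping. A sketch of the
replacement, with the old spelling recalled in a comment (aliases
illustrative):

from jax.experimental import pallas as pl
from jax.experimental.pallas import tpu as pltpu

# Previously (roughly): pl.pallas_call(..., mosaic_params=dict(
#     memory_spaces=[...])). Now the space rides on the spec itself:
smem_spec = pl.BlockSpec(memory_space=pltpu.SMEM)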

View File

@@ -16,19 +16,20 @@
from __future__ import annotations
import contextlib
import dataclasses
from typing import Callable
from jax._src import api_util
from jax._src import core as jax_core
from jax._src import effects
from jax._src import linear_util as lu
from jax._src import state
from jax._src import tree_util
from jax._src import util
from jax._src.interpreters import mlir
from jax._src.interpreters import partial_eval as pe
import jax.numpy as jnp
map, unsafe_map = util.safe_map, map
zip, unsafe_zip = util.safe_zip, zip
repeat_p = jax_core.Primitive('repeat')
@@ -74,27 +75,6 @@ def trace(message: str, level: int = 10):
trace_stop_p.bind()
class Type:
def get_aval(self) -> jax_core.AbstractValue:
raise NotImplementedError()
def get_block_shape(self) -> tuple[int, ...] | None:
raise NotImplementedError()
@dataclasses.dataclass(frozen=True)
class VMEM(Type):
shape: tuple[int, ...]
dtype: jnp.dtype
def get_aval(self) -> jax_core.AbstractValue:
return state.AbstractRef(jax_core.ShapedArray(self.shape, self.dtype))
def get_block_shape(self) -> tuple[int, ...] | None:
return self.shape
run_scoped_p = jax_core.Primitive('run_scoped')
run_scoped_p.multiple_results = True
@@ -102,13 +82,13 @@ run_scoped_p.multiple_results = True
def run_scoped(f: Callable[..., None], *types, **kw_types) -> None:
flat_types, in_tree = tree_util.tree_flatten((types, kw_types))
flat_fun, _ = api_util.flatten_fun(lu.wrap_init(f), in_tree)
avals = [type.get_aval() for type in flat_types]
avals = map(lambda t: t.get_aval(), flat_types)
jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(flat_fun, avals)
run_scoped_p.bind(*consts, jaxpr=jaxpr, types=tuple(flat_types))
run_scoped_p.bind(*consts, jaxpr=jaxpr)
@run_scoped_p.def_effectful_abstract_eval
def _run_scoped_abstract_eval(*args, jaxpr, types):
def _run_scoped_abstract_eval(*args, jaxpr):
# jaxpr will have effects for its inputs (Refs that are allocated) and for
# constvars (closed over Refs). The effects for the allocated Refs are local
# to the jaxpr and shouldn't propagate out.
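
run_scoped now infers everything from objects exposing get_aval(), which is
why the Type and VMEM classes above could be deleted; on TPU the MemoryRef
built by TPUMemorySpace.__call__ fills that role. A sketch of a kernel body
using it, assuming the pltpu re-export; shapes and names are illustrative:

import jax.numpy as jnp
from jax.experimental.pallas import tpu as pltpu

def kernel(x_ref, o_ref):
  def body(scratch_ref):
    # scratch_ref is a VMEM-backed Ref that lives only inside body.
    scratch_ref[...] = x_ref[...] * 2.0
    o_ref[...] = scratch_ref[...]
  pltpu.run_scoped(body, pltpu.VMEM((8, 128), jnp.float32))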

View File

@@ -50,6 +50,8 @@ BlockSpec = pallas_core.BlockSpec
GridSpec = pallas_core.GridSpec
BlockMapping = pallas_core.BlockMapping
GridMapping = pallas_core.GridMapping
NoBlockSpec = pallas_core.NoBlockSpec
no_block_spec = pallas_core.no_block_spec
pallas_call_p = jax_core.Primitive('pallas_call')
pallas_call_p.multiple_results = True
@@ -224,7 +226,8 @@ def _batch_block_mapping(grid: Tuple[int, ...], aval: jax_core.ShapedArray,
new_block_shape = tuple_insert(shape, dim, pallas_core.mapped)
jaxpr = jax_core.ClosedJaxpr(block_mapping_jaxpr, consts)
if block_mapping is None:
return BlockMapping(block_shape=new_block_shape, index_map_jaxpr=jaxpr)
return BlockMapping(block_shape=new_block_shape, index_map_jaxpr=jaxpr,
memory_space=None)
return block_mapping.replace(block_shape=new_block_shape,
index_map_jaxpr=jaxpr)
@@ -324,40 +327,44 @@ def _hoist_consts_to_refs(jaxpr: jax_core.Jaxpr) -> jax_core.Jaxpr:
return hoisted_jaxpr
@weakref_lru_cache
def _initial_style_open_jaxpr(fun: Callable, in_tree, in_avals,
primitive_name: str | None = None):
def _trace_to_jaxpr(fun: Callable, grid_spec, flat_in_avals,
flat_out_avals, in_tree, out_tree):
avals, grid_mapping = grid_spec.get_grid_mapping(flat_in_avals, in_tree,
flat_out_avals, out_tree)
jaxpr_flat_avals, jaxpr_in_tree = tree_util.tree_flatten(avals)
wrapped_fun, out_tree_thunk = api_util.flatten_fun_nokwargs(
lu.wrap_init(fun), in_tree)
debug = pe.debug_info(fun, in_tree, out_tree_thunk, False,
primitive_name or "<unknown>")
jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(wrapped_fun, in_avals, debug)
lu.wrap_init(fun), jaxpr_in_tree)
debug = pe.debug_info(fun, jaxpr_in_tree, out_tree_thunk, False, "pallas_call")
jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(wrapped_fun, jaxpr_flat_avals,
debug)
jaxpr = _hoist_consts_to_refs(jaxpr)
return jaxpr, consts, out_tree_thunk()
return grid_mapping, jaxpr, consts, out_tree_thunk()
def _extract_function_name(f: Callable, name: str | None) -> str:
if name is None:
name = f.__name__ if hasattr(f, "__name__") and f.__name__ else "func"
return name
def pallas_call(
f: Callable[..., None], out_shape: Any, *,
f: Callable[..., None],
out_shape: Any,
*,
grid_spec: GridSpec | None = None,
debug: bool = False,
grid: Grid | None = None,
in_specs: Sequence[BlockSpec | None] | None = None,
out_specs: BlockSpec | Sequence[BlockSpec | None] | None = None,
in_specs: Sequence[BlockSpec | NoBlockSpec] | NoBlockSpec = no_block_spec,
out_specs: BlockSpec | NoBlockSpec
| Sequence[BlockSpec | NoBlockSpec] = no_block_spec,
input_output_aliases: Dict[int, int] = {},
interpret: bool = False,
name: str | None = None,
**compiler_params: Any):
**compiler_params: Any,
):
name = _extract_function_name(f, name)
if grid_spec is None:
grid_spec = GridSpec(grid, in_specs, out_specs)
name = _extract_function_name(f, name)
singleton = False
if not isinstance(out_shape, (tuple, list)):
out_shape = (out_shape,)
singleton = True
if not isinstance(out_shape, tuple):
if isinstance(out_shape, list):
out_shape = tuple(out_shape)
flat_out_shapes, out_tree = tree_util.tree_flatten(out_shape)
flat_out_shapes = [jax.ShapeDtypeStruct(x.shape, x.dtype)
@@ -365,14 +372,13 @@ def pallas_call(
@jax.jit
def wrapped(*args):
flat_args, in_tree = tree_util.tree_flatten(args)
flat_avals = [jax_core.raise_to_shaped(jax_core.get_aval(a))
for a in flat_args]
avals, grid_mapping = grid_spec.get_grid_mapping(flat_avals, in_tree,
flat_out_shapes, out_tree)
jaxpr_flat_avals, jaxpr_in_tree = tree_util.tree_flatten(avals)
jaxpr, consts, _ = _initial_style_open_jaxpr(f, jaxpr_in_tree,
tuple(jaxpr_flat_avals),
primitive_name="pallas_call")
flat_in_avals = tuple(jax_core.raise_to_shaped(jax_core.get_aval(a))
for a in flat_args)
flat_out_avals = tuple(jax_core.ShapedArray(v.shape, v.dtype)
for v in flat_out_shapes)
grid_mapping, jaxpr, consts, _ = _trace_to_jaxpr(
f, grid_spec, flat_in_avals, flat_out_avals, in_tree,
out_tree)
which_linear = (False,) * len(flat_args)
out_flat = pallas_call_p.bind(
*consts, *flat_args, jaxpr=jaxpr, name=name, which_linear=which_linear,
@@ -384,7 +390,5 @@ def pallas_call(
input_output_aliases=tuple(input_output_aliases.items()),
**compiler_params)
out = tree_util.tree_unflatten(out_tree, out_flat)
if singleton:
return out[0]
return out
return wrapped
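
With the singleton bookkeeping gone, the pytree structure of out_shape
alone decides the result structure: a single ShapeDtypeStruct unflattens to
a single array and a list is canonicalized to a tuple. A self-contained
sketch in interpret mode (kernel and shapes illustrative):

import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl

def copy_kernel(x_ref, o_ref):
  o_ref[...] = x_ref[...]

x = jnp.zeros((8, 128), jnp.float32)
out = pl.pallas_call(
    copy_kernel,
    out_shape=jax.ShapeDtypeStruct((8, 128), jnp.float32),
    interpret=True,
)(x)
assert out.shape == (8, 128)  # a bare array, not a 1-tuple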

View File

@@ -16,6 +16,7 @@
from jax._src import pallas
from jax._src.pallas.core import BlockSpec
from jax._src.pallas.core import no_block_spec
from jax._src.pallas.indexing import ds
from jax._src.pallas.indexing import dslice
from jax._src.pallas.indexing import broadcast_to

View File

@@ -695,6 +695,7 @@ def _flash_attention_bwd_dkv(
def qo_index_map(batch_index, _, head_index, q_seq_index):
return (batch_index, head_index, q_seq_index, 0)
qo_spec = pl.BlockSpec(qo_index_map, (1, 1, block_q_major, head_dim))
assert qo_spec.block_shape is not None
assert q.ndim == len(qo_spec.block_shape)
do_spec = qo_spec
assert do.ndim == len(qo_spec.block_shape)
@@ -703,16 +704,19 @@
del q_seq_index
return (batch_index, head_index, kv_seq_index, 0)
kv_spec = pl.BlockSpec(kv_index_map, (1, 1, block_k_major, head_dim))
assert kv_spec.block_shape is not None
assert k.ndim == len(kv_spec.block_shape)
assert v.ndim == len(kv_spec.block_shape)
def lm_index_map(batch_index, _, head_index, q_seq_index):
return (batch_index, head_index, q_seq_index, 0)
lm_spec = pl.BlockSpec(lm_index_map, (1, 1, block_q_major, MIN_BLOCK_SIZE))
assert lm_spec.block_shape is not None
assert l.ndim == len(lm_spec.block_shape)
assert m.ndim == len(lm_spec.block_shape)
di_spec = pl.BlockSpec(qo_index_map, (1, 1, block_q_major, MIN_BLOCK_SIZE))
assert di_spec.block_shape is not None
assert di.ndim == len(di_spec.block_shape)
in_specs = [
@@ -882,6 +886,7 @@ def _flash_attention_bwd_dq(
return (batch_index, head_index, kv_seq_index, 0)
kv_spec = pl.BlockSpec(kv_index_map, (1, 1, block_k_major, head_dim))
assert kv_spec.block_shape is not None
assert k.ndim == len(kv_spec.block_shape)
assert v.ndim == len(kv_spec.block_shape)
@@ -890,10 +895,12 @@
return (batch_index, head_index, q_seq_index, 0)
lm_spec = pl.BlockSpec(lm_index_map, (1, 1, block_q_major, MIN_BLOCK_SIZE))
assert lm_spec.block_shape is not None
assert l.ndim == len(lm_spec.block_shape)
assert m.ndim == len(lm_spec.block_shape)
di_spec = pl.BlockSpec(qo_index_map, (1, 1, block_q_major, MIN_BLOCK_SIZE))
assert di_spec.block_shape is not None
assert di.ndim == len(di_spec.block_shape)
in_specs = [
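
These added asserts only narrow types: block_shape is Optional on the
refactored BlockSpec, so each spec built with an explicit shape is asserted
non-None before len() for the benefit of the type checker.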

View File

@@ -15,9 +15,10 @@
"""Contains Mosaic specific Pallas functions."""
from jax._src.pallas.mosaic import PrefetchScalarGridSpec
from jax._src.pallas.mosaic import TPUMemorySpace
from jax._src.pallas.mosaic import ANY
from jax._src.pallas.mosaic import CMEM
from jax._src.pallas.mosaic import VMEM
from jax._src.pallas.mosaic import SMEM
from jax._src.pallas.mosaic import VMEM
from jax._src.pallas.mosaic import encode_kernel_regeneration_metadata
from jax._src.pallas.mosaic import extract_kernel_regeneration_metadata
from jax._src.pallas.mosaic import repeat

View File

@@ -29,7 +29,7 @@ from jax._src import linear_util as lu
from jax._src import test_util as jtu
from jax._src import state
from jax._src.lax.control_flow.for_loop import for_loop
from jax._src.pallas.pallas_call import _initial_style_open_jaxpr
from jax._src.pallas.pallas_call import _trace_to_jaxpr
from jax.config import config
from jax.interpreters import partial_eval as pe
import jax.numpy as jnp
@@ -134,7 +134,7 @@ class PallasTest(parameterized.TestCase):
super().setUp()
if compile_jaxpr:
compile_jaxpr.cache_clear()
_initial_style_open_jaxpr.cache_clear()
_trace_to_jaxpr.cache_clear()
def pallas_call(self, *args, **kwargs):
return pl.pallas_call(*args, **kwargs, interpret=self.INTERPRET)
@@ -639,7 +639,8 @@ class PallasCallTest(PallasTest):
def f(x):
return add_one(add_one(x))
self.assertEqual(f(0.), 2.)
x = jnp.array(0., dtype=jnp.float32)
self.assertEqual(f(x), 2.)
self.assertEqual(trace_count, 1)
def test_pallas_compilation_cache(self):
@@ -657,7 +658,8 @@ class PallasCallTest(PallasTest):
def f(x):
return add_one(add_one(x))
self.assertEqual(f(0.), 2.)
x = jnp.array(0., dtype=jnp.float32)
self.assertEqual(f(x), 2.)
num_misses = compile_jaxpr.cache_info().misses
self.assertEqual(num_misses, 1)
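
The switch from f(0.) to f(jnp.array(0., dtype=jnp.float32)) keeps the
trace-count assertions stable: _trace_to_jaxpr is cached on the raw input
avals, and a Python scalar traces as a weakly-typed float32 while the first
pallas_call's output is strongly typed, so the nested call would otherwise
miss the cache and retrace. A sketch of the distinction (jax.core is an
internal API, shown only for illustration):

import jax.numpy as jnp
from jax import core as jax_core

weak = jax_core.get_aval(0.)                 # float32[], weak_type=True
strong = jax_core.get_aval(jnp.float32(0.))  # float32[], weak_type=False
assert weak.weak_type and not strong.weak_type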