mirror of https://github.com/ROCm/jax.git (synced 2025-04-18 12:56:07 +00:00)

[Pallas] Refactor memory space handling

PiperOrigin-RevId: 563586933

This commit is contained in:
parent d0c4c9b3fe
commit cb114f247a
@@ -70,12 +70,24 @@ class Mapped:
 mapped = Mapped()


-@dataclasses.dataclass(frozen=True)
+@dataclasses.dataclass(init=False, unsafe_hash=True)
 class BlockSpec:
-  index_map: Callable[..., Any]
-  block_shape: tuple[int | None, ...]
+  index_map: Callable[..., Any] | None
+  block_shape: tuple[int | None, ...] | None
+  memory_space: Any
+
+  def __init__(self, index_map: Callable[..., Any] | None = None,
+               block_shape: tuple[int | None, ...] | None = None,
+               memory_space: Any = None):
+    self.index_map = index_map
+    if block_shape is not None and not isinstance(block_shape, tuple):
+      block_shape = tuple(block_shape)
+    self.block_shape = block_shape
+    self.memory_space = memory_space

   def compute_index(self, *args):
+    assert self.index_map is not None
+    assert self.block_shape is not None
     out = self.index_map(*args)
     if not isinstance(out, tuple):
       out = (out,)
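A minimal sketch (not part of the diff) of how the refactored BlockSpec behaves, assuming only the definitions in the hunk above:

# Assumes the new BlockSpec above; all three fields are now optional.
spec = BlockSpec(index_map=lambda i, j: (i, j),
                 block_shape=[128, 128],   # a list is coerced to a tuple in __init__
                 memory_space=None)        # backend-specific, e.g. a TPU memory space
assert spec.block_shape == (128, 128)
assert spec.compute_index(0, 1) == (0, 1)  # index_map output is normalized to a tuple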
@@ -86,6 +98,7 @@ class BlockSpec:
 class BlockMapping:
   block_shape: tuple[Mapped | int, ...]
   index_map_jaxpr: jax_core.ClosedJaxpr
+  memory_space: Any

   def compute_start_indices(self, loop_idx, *args):
     discharged_jaxpr, discharged_consts = state_discharge.discharge_state(
@ -123,107 +136,131 @@ def _preprocess_grid(grid: Grid | int | None) -> Grid:
|
||||
|
||||
def _convert_block_spec_to_block_mapping(
|
||||
in_avals: list[jax_core.ShapedArray], block_spec: BlockSpec | None,
|
||||
aval: jax_core.ShapedArray,
|
||||
) -> BlockSpec | None:
|
||||
if block_spec is _no_block_spec:
|
||||
if block_spec is no_block_spec:
|
||||
return None
|
||||
if block_spec.index_map is None:
|
||||
compute_index = lambda *args: (0,) * len(aval.shape)
|
||||
block_shape = aval.shape
|
||||
else:
|
||||
compute_index = block_spec.compute_index
|
||||
block_shape = block_spec.block_shape
|
||||
block_shape = tuple(
|
||||
mapped if s is None else s for s in block_spec.block_shape)
|
||||
mapped if s is None else s for s in block_shape)
|
||||
jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(
|
||||
lu.wrap_init(block_spec.compute_index), in_avals)
|
||||
return BlockMapping(block_shape, jax_core.ClosedJaxpr(jaxpr, consts))
|
||||
lu.wrap_init(compute_index), in_avals)
|
||||
return BlockMapping(block_shape, jax_core.ClosedJaxpr(jaxpr, consts),
|
||||
block_spec.memory_space)
|
||||
|
||||
|
||||
def _compute_shape_from_block_spec(block_spec: BlockSpec | None,
|
||||
arg_shape: tuple[int, ...]
|
||||
) -> tuple[int, ...]:
|
||||
if block_spec is _no_block_spec:
|
||||
return arg_shape
|
||||
return tuple(s for s in block_spec.block_shape if s is not None)
|
||||
def _tile_ref(ref: jax_core.AbstractRef, block_shape: tuple[int, ...] | None
|
||||
) -> jax_core.AbstractRef:
|
||||
if block_shape is None:
|
||||
return ref
|
||||
shape = tuple(s for s in block_shape if s is not None)
|
||||
return state.shaped_array_ref(shape, ref.dtype)
|
||||
|
||||
|
||||
def _get_ref_avals(grid, in_avals, in_specs, out_avals, out_specs):
|
||||
in_ref_avals = map(state.AbstractRef, in_avals)
|
||||
out_ref_avals = map(state.AbstractRef, out_avals)
|
||||
if grid is None:
|
||||
in_specs = [None] * len(in_avals)
|
||||
out_specs = [None] * len(out_avals)
|
||||
in_ref_avals = [state.shaped_array_ref(arg.shape, arg.dtype)
|
||||
for arg in in_avals]
|
||||
out_ref_avals = [state.shaped_array_ref(arg.shape, arg.dtype)
|
||||
for arg in out_avals]
|
||||
else:
|
||||
in_ref_avals = [
|
||||
state.shaped_array_ref(
|
||||
_compute_shape_from_block_spec(
|
||||
block_spec, arg.shape), arg.dtype)
|
||||
for block_spec, arg in zip(in_specs, in_avals)]
|
||||
out_ref_avals = [
|
||||
state.shaped_array_ref(
|
||||
_compute_shape_from_block_spec(
|
||||
block_spec, arg.shape), arg.dtype)
|
||||
for block_spec, arg in zip(out_specs, out_avals)]
|
||||
return in_specs, in_ref_avals, out_specs, out_ref_avals
|
||||
tiled_in_ref_avals = [
|
||||
aval if in_spec is no_block_spec
|
||||
else _tile_ref(aval, in_spec.block_shape)
|
||||
for aval, in_spec in zip(in_ref_avals, in_specs)
|
||||
]
|
||||
tiled_out_ref_avals = [
|
||||
aval if out_spec is no_block_spec
|
||||
else _tile_ref(aval, out_spec.block_shape)
|
||||
for aval, out_spec in zip(out_ref_avals, out_specs)
|
||||
]
|
||||
return in_specs, tiled_in_ref_avals, out_specs, tiled_out_ref_avals
|
||||
|
||||
class NoBlockSpec:
|
||||
pass
|
||||
no_block_spec = NoBlockSpec()
|
||||
|
||||
_no_block_spec = object()
|
||||
|
||||
@dataclasses.dataclass(init=False)
|
||||
@dataclasses.dataclass(init=False, unsafe_hash=True)
|
||||
class GridSpec:
|
||||
grid: Grid
|
||||
in_specs: Sequence[BlockSpec | None] | None
|
||||
out_specs: tuple[BlockSpec | None, ...] | None
|
||||
in_specs: tuple[BlockSpec | NoBlockSpec, ...]
|
||||
out_specs: tuple[BlockSpec | NoBlockSpec, ...]
|
||||
in_specs_tree: Any
|
||||
out_specs_tree: Any
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
grid: Grid | None = None,
|
||||
in_specs: Sequence[BlockSpec | None] | None = None,
|
||||
out_specs: BlockSpec | Sequence[BlockSpec | None] | None = None,
|
||||
in_specs: BlockSpec
|
||||
| Sequence[BlockSpec | NoBlockSpec]
|
||||
| NoBlockSpec = no_block_spec,
|
||||
out_specs: BlockSpec
|
||||
| Sequence[BlockSpec | NoBlockSpec]
|
||||
| NoBlockSpec = no_block_spec,
|
||||
):
|
||||
if grid is None:
|
||||
if in_specs is not None:
|
||||
raise ValueError("Cannot specify `in_specs` with a `None` grid.")
|
||||
if out_specs is not None:
|
||||
raise ValueError("Cannot specify `out_specs` with a `None` grid.")
|
||||
self.grid = _preprocess_grid(grid)
|
||||
self.in_specs = in_specs
|
||||
if out_specs is not None and not isinstance(out_specs, (tuple, list)):
|
||||
out_specs = (out_specs,)
|
||||
if out_specs is not None and not isinstance(out_specs, tuple):
|
||||
# Be more lenient for in/out_specs
|
||||
if isinstance(in_specs, list):
|
||||
in_specs = tuple(in_specs)
|
||||
if isinstance(out_specs, list):
|
||||
out_specs = tuple(out_specs)
|
||||
self.out_specs = out_specs
|
||||
|
||||
self.grid = _preprocess_grid(grid)
|
||||
if in_specs is not no_block_spec:
|
||||
flat_in_specs, self.in_specs_tree = tree_util.tree_flatten(in_specs)
|
||||
self.in_specs = tuple(flat_in_specs)
|
||||
else:
|
||||
self.in_specs = in_specs
|
||||
self.in_specs_tree = None
|
||||
if out_specs is not no_block_spec:
|
||||
flat_out_specs, self.out_specs_tree = tree_util.tree_flatten(out_specs)
|
||||
self.out_specs = tuple(flat_out_specs)
|
||||
else:
|
||||
self.out_specs = out_specs
|
||||
self.out_specs_tree = None
|
||||
|
||||
def _get_in_out_specs(self, in_avals, in_tree, out_avals, out_tree):
|
||||
if self.in_specs is no_block_spec:
|
||||
flat_in_specs = [no_block_spec] * len(in_avals)
|
||||
else:
|
||||
flat_in_specs = self.in_specs
|
||||
if self.in_specs_tree != in_tree:
|
||||
raise ValueError(
|
||||
"Pytree specs for arguments and `in_specs` must match: "
|
||||
f"{in_tree} vs. {self.in_specs_tree}")
|
||||
if self.out_specs is no_block_spec:
|
||||
flat_out_specs = [no_block_spec] * len(out_avals)
|
||||
else:
|
||||
flat_out_specs = self.out_specs
|
||||
if self.out_specs_tree != out_tree:
|
||||
raise ValueError(
|
||||
"Pytree specs for `out_shape` and `out_specs` must match: "
|
||||
f"{out_tree} vs. {self.out_specs_tree}")
|
||||
return flat_in_specs, flat_out_specs
|
||||
|
||||
def get_grid_mapping(
|
||||
self, in_avals, in_tree, out_avals, out_tree
|
||||
) -> tuple[tuple[jax_core.AbstractValue, ...], GridMapping]:
|
||||
if self.in_specs is not None:
|
||||
in_specs = self.in_specs
|
||||
in_spec_tree = tree_util.tree_structure(tuple(in_specs))
|
||||
if in_spec_tree != in_tree:
|
||||
raise ValueError(
|
||||
"Pytree specs for arguments and `in_specs` must match: "
|
||||
f"{in_tree} vs. {in_spec_tree}")
|
||||
else:
|
||||
in_specs = [_no_block_spec] * len(in_avals)
|
||||
if self.out_specs is not None:
|
||||
out_specs = self.out_specs
|
||||
out_spec_tree = tree_util.tree_structure(out_specs)
|
||||
if out_spec_tree != out_tree:
|
||||
raise ValueError(
|
||||
"Pytree specs for `out_shape` and `out_specs` must match: "
|
||||
f"{out_tree} vs. {out_spec_tree}")
|
||||
else:
|
||||
out_specs = [_no_block_spec] * len(out_avals)
|
||||
flat_in_specs = tree_util.tree_leaves(in_specs)
|
||||
flat_out_specs = tree_util.tree_leaves(out_specs)
|
||||
flat_in_specs, flat_out_specs = self._get_in_out_specs(
|
||||
in_avals, in_tree, out_avals, out_tree)
|
||||
in_specs, in_ref_avals, out_specs, out_ref_avals = _get_ref_avals(
|
||||
self.grid, in_avals, flat_in_specs, out_avals,
|
||||
flat_out_specs)
|
||||
grid_avals = [jax_core.ShapedArray((), jnp.dtype("int32"))] * len(self.grid)
|
||||
in_block_mappings = map(
|
||||
partial(_convert_block_spec_to_block_mapping, grid_avals), in_specs)
|
||||
partial(_convert_block_spec_to_block_mapping, grid_avals), in_specs,
|
||||
in_ref_avals)
|
||||
out_block_mappings = map(
|
||||
partial(_convert_block_spec_to_block_mapping, grid_avals), out_specs)
|
||||
partial(_convert_block_spec_to_block_mapping, grid_avals), out_specs,
|
||||
out_ref_avals)
|
||||
grid_mapping = GridMapping(
|
||||
self.grid, (*in_block_mappings, *out_block_mappings), (),
|
||||
num_index_operands=0)
|
||||
jaxpr_in_avals = tree_util.tree_unflatten(in_tree, in_ref_avals)
|
||||
jaxpr_out_avals = tree_util.tree_unflatten(out_tree, out_ref_avals)
|
||||
if not isinstance(jaxpr_out_avals, (tuple, list)):
|
||||
jaxpr_out_avals = (jaxpr_out_avals,)
|
||||
return (*jaxpr_in_avals, *jaxpr_out_avals), grid_mapping
|
||||
|
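A short illustrative sketch (`_expand_specs` is a hypothetical helper, not in the diff) of the role of the new sentinel: `no_block_spec` means "no spec was given at all", while a BlockSpec whose fields are None now means "use the whole array":

def _expand_specs(specs, n):
  # Mirrors _get_in_out_specs above: a missing spec expands to one
  # no_block_spec sentinel per operand.
  return [no_block_spec] * n if specs is no_block_spec else list(specs)

assert _expand_specs(no_block_spec, 2) == [no_block_spec, no_block_spec]
assert _expand_specs([BlockSpec()], 1)[0] is not no_block_spec  # BlockSpec() is now legal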
@@ -23,10 +23,10 @@ from jax._src.pallas.mosaic.kernel_regeneration_util import extract_kernel_regen
 from jax._src.pallas.mosaic.primitives import repeat
 from jax._src.pallas.mosaic.primitives import trace
 from jax._src.pallas.mosaic.primitives import run_scoped
-from jax._src.pallas.mosaic.primitives import VMEM

-SMEM = TPUMemorySpace.SMEM
 ANY = TPUMemorySpace.ANY
 CMEM = TPUMemorySpace.CMEM
+SMEM = TPUMemorySpace.SMEM
+VMEM = TPUMemorySpace.VMEM

 del pallas_call_registration
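After this hunk all four memory spaces are plain TPUMemorySpace aliases. A sketch of what that implies for users of the re-exported names (assuming the usual pltpu import alias):

from jax.experimental.pallas import tpu as pltpu

# VMEM is no longer a dataclass from primitives.py; it is an enum member,
# interchangeable with the other spaces.
assert pltpu.VMEM is pltpu.TPUMemorySpace.VMEM
assert pltpu.SMEM is pltpu.TPUMemorySpace.SMEM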
@@ -19,6 +19,7 @@ from collections.abc import Sequence
 import dataclasses
 import enum
+import functools
 from typing import Any

 from jax._src import core as jax_core
 from jax._src import state
@@ -30,12 +31,16 @@ from jax._src.pallas import core as pallas_core
 # TODO(sharadmv): enable type checking
 # mypy: ignore-errors

+map, unsafe_map = util.safe_map, map
+zip, unsafe_zip = util.safe_zip, zip
+
+partial = functools.partial
 Grid = pallas_core.Grid
 BlockSpec = pallas_core.BlockSpec
 GridMapping = pallas_core.GridMapping
+NoBlockSpec = pallas_core.NoBlockSpec
+no_block_spec = pallas_core.no_block_spec
 _preprocess_grid = pallas_core._preprocess_grid
-_compute_shape_from_block_spec = pallas_core._compute_shape_from_block_spec
 _convert_block_spec_to_block_mapping = pallas_core._convert_block_spec_to_block_mapping
 split_list = util.split_list
@@ -49,59 +54,114 @@ class TPUMemorySpace(enum.Enum):
   def __str__(self) -> str:
     return self.value

+  def __call__(self, shape: tuple[int, ...], dtype: jnp.dtype):
+    # A convenience function for constructing MemoryRef types.
+    return MemoryRef(shape, dtype, self)

-@dataclasses.dataclass(init=False)
+
+class AbstractMemoryRef(state.AbstractRef):
+  __slots__ = ["inner_aval", "memory_space"]
+
+  def __init__(self, inner_aval: jax_core.AbstractValue,
+               memory_space: TPUMemorySpace):
+    assert isinstance(inner_aval, jax_core.ShapedArray)
+    self.inner_aval = inner_aval
+    self.memory_space = memory_space
+
+  def __repr__(self) -> str:
+    return f'MemRef<{self.memory_space}>{{{self.inner_aval.str_short()}}}'
+
+  def at_least_vspace(self):
+    return AbstractMemoryRef(
+        self.inner_aval.at_least_vspace(), self.memory_space)
+
+  def __eq__(self, other):
+    return (type(self) is type(other) and self.inner_aval == other.inner_aval
+            and self.memory_space == other.memory_space)
+
+  def __hash__(self):
+    return hash((self.__class__, self.inner_aval, self.memory_space))
+
+
+def _ref_raise_to_shaped(ref_aval: AbstractMemoryRef, weak_type):
+  return AbstractMemoryRef(
+      jax_core.raise_to_shaped(ref_aval.inner_aval, weak_type),
+      ref_aval.memory_space)
+jax_core.raise_to_shaped_mappings[AbstractMemoryRef] = _ref_raise_to_shaped
+
+
+@dataclasses.dataclass(frozen=True)
+class MemoryRef:
+  """Like jax.ShapeDtypeStruct but with memory spaces."""
+  shape: tuple[int, ...]
+  dtype: jnp.dtype
+  memory_space: TPUMemorySpace = TPUMemorySpace.ANY
+
+  def get_aval(self):
+    return AbstractMemoryRef(jax_core.ShapedArray(self.shape, self.dtype),
+                             self.memory_space)
+
+
+@dataclasses.dataclass(init=False, unsafe_hash=True)
 class PrefetchScalarGridSpec(pallas_core.GridSpec):
   grid: Grid
   num_scalar_prefetch: int
-  in_specs: Sequence[BlockSpec | None] | None
-  out_specs: tuple[BlockSpec | None, ...] | None
+  in_specs: tuple[BlockSpec | NoBlockSpec, ...]
+  out_specs: tuple[BlockSpec | NoBlockSpec, ...]
+  in_specs_tree: Any
+  out_specs_tree: Any

   def __init__(
       self,
       num_scalar_prefetch: int,
       grid: Grid | None = None,
-      in_specs: Sequence[BlockSpec | None] | None = None,
-      out_specs: BlockSpec | Sequence[BlockSpec | None] | None = None,
+      in_specs: BlockSpec
+      | Sequence[BlockSpec | NoBlockSpec]
+      | NoBlockSpec = no_block_spec,
+      out_specs: BlockSpec
+      | Sequence[BlockSpec | NoBlockSpec]
+      | NoBlockSpec = no_block_spec,
   ):
-    if grid is None:
-      raise NotImplementedError("Should pass in non-`None` grid.")
-    self.grid = _preprocess_grid(grid)
-    if out_specs is not None and not isinstance(out_specs, (tuple, list)):
-      out_specs = (out_specs,)
-    if out_specs is not None and not isinstance(out_specs, tuple):
-      out_specs = tuple(out_specs)
+    super().__init__(grid, in_specs, out_specs)
     self.num_scalar_prefetch = num_scalar_prefetch
-    self.in_specs = in_specs
-    self.out_specs = out_specs

   def get_grid_mapping(
       self, in_avals, in_tree, out_avals, out_tree
   ) -> tuple[tuple[jax_core.AbstractValue, ...], GridMapping]:
-    scalar_avals, in_avals = split_list(in_avals, [self.num_scalar_prefetch])
-    flat_in_specs = tree_util.tree_leaves(self.in_specs)
-    flat_out_specs = tree_util.tree_leaves(self.out_specs)
+    all_avals = tree_util.tree_unflatten(in_tree, in_avals)
+    scalar_avals, unflat_in_avals = split_list(
+        all_avals, [self.num_scalar_prefetch])
+    flat_scalar_avals, scalar_tree = tree_util.tree_flatten(scalar_avals)
+    num_flat_scalar_prefetch = len(flat_scalar_avals)
+    in_avals, in_avals_tree = tree_util.tree_flatten(tuple(unflat_in_avals))
+    flat_in_specs, flat_out_specs = self._get_in_out_specs(
+        in_avals, in_avals_tree, out_avals, out_tree)
     in_specs, in_ref_avals, out_specs, out_ref_avals = (
         pallas_core._get_ref_avals(
             self.grid, in_avals, flat_in_specs,
             out_avals, flat_out_specs))
     scalar_ref_avals = [
         state.shaped_array_ref(aval.shape, aval.dtype)
-        for aval in scalar_avals]
+        for aval in flat_scalar_avals]
     grid_avals = [jax_core.ShapedArray((), jnp.dtype("int32"))] * len(self.grid)
     in_block_mappings = map(
         partial(_convert_block_spec_to_block_mapping,
-                (*grid_avals, *scalar_ref_avals)), in_specs)
+                (*grid_avals, *scalar_ref_avals)), in_specs, in_ref_avals)
     out_block_mappings = map(
         partial(_convert_block_spec_to_block_mapping,
-                (*grid_avals, *scalar_ref_avals)), out_specs)
+                (*grid_avals, *scalar_ref_avals)), out_specs, out_ref_avals)
     grid_mapping = GridMapping(
         grid=self.grid,
         block_mappings=(*in_block_mappings, *out_block_mappings),
         mapped_dims=(),
-        num_index_operands=self.num_scalar_prefetch,
+        num_index_operands=num_flat_scalar_prefetch,
     )
-    jaxpr_in_avals = tree_util.tree_unflatten(
-        in_tree, [*scalar_ref_avals, *in_ref_avals])
+    jaxpr_scalar_ref_avals = tree_util.tree_unflatten(
+        scalar_tree, scalar_ref_avals)
+    jaxpr_in_ref_avals = tree_util.tree_unflatten(in_avals_tree, in_ref_avals)
+    jaxpr_in_avals = (*jaxpr_scalar_ref_avals, *jaxpr_in_ref_avals)
     jaxpr_out_avals = tree_util.tree_unflatten(out_tree, out_ref_avals)
     if not isinstance(jaxpr_out_avals, (tuple, list)):
       jaxpr_out_avals = (jaxpr_out_avals,)
     return (*jaxpr_in_avals, *jaxpr_out_avals), grid_mapping
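A sketch of the new convenience path from a memory space to an abstract Ref (assumes the classes in the hunk above):

import jax.numpy as jnp

# Calling an enum member goes through TPUMemorySpace.__call__ and builds a
# MemoryRef; get_aval() then wraps it into an AbstractMemoryRef.
scratch = TPUMemorySpace.VMEM((8, 128), jnp.float32)   # MemoryRef
aval = scratch.get_aval()                              # AbstractMemoryRef
assert aval.memory_space is TPUMemorySpace.VMEM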
@@ -143,20 +143,17 @@ def lower_jaxpr_to_module(
     grid_mapping: core.GridMapping,
     jaxpr: jax_core.Jaxpr,
     dimension_semantics: tuple[str | None, ...] | None,
-    memory_spaces: tuple[TPUMemorySpace | None, ...] | None
 ) -> ir.Module:
   m = ir.Module.create()
   sym_tab = ir.SymbolTable(m.operation)
-  if all(bm is None for bm in grid_mapping.block_mappings):
+  if not grid_mapping.grid:
     # Trivial grid-map, we don't need to populate the transform functions.
     func_op = lower_jaxpr_to_func(ctx, jaxpr, grid_mapping=grid_mapping,
-                                  memory_spaces=memory_spaces,
                                   name="main")
     m.body.append(func_op)
     sym_tab.insert(func_op)
     return m
   func_op = lower_jaxpr_to_func(ctx, jaxpr, grid_mapping=grid_mapping,
-                                memory_spaces=memory_spaces,
                                 name="main")
   m.body.append(func_op)
   sym_tab.insert(func_op)
@@ -256,10 +253,11 @@ def lower_jaxpr_to_func(
     ctx: ir.Context,
     jaxpr: jax_core.Jaxpr,
     *,
-    memory_spaces: Sequence[tpu_core.TPUMemorySpace | None] | None,
     grid_mapping: core.GridMapping | None,
     name: str,
 ) -> func.FuncOp:
+  memory_spaces = [None if bm is None else bm.memory_space
+                   for bm in grid_mapping.block_mappings]
   if grid_mapping:
     arg_types = map(
         aval_to_ir_type,
@@ -277,22 +275,15 @@ def lower_jaxpr_to_func(
     )
     return (aval_to_ir_type(aval, shape=shape, memory_space=memory_space),
             block_mapping.block_shape)
-  if memory_spaces is None:
-    memory_spaces = [None] * len(jaxpr.invars)
-  if len(memory_spaces) != len(jaxpr.invars):
-    raise ValueError("Must have as many memory spaces as inputs and outputs.")
   if grid_mapping is None:
     block_mappings = [None] * len(jaxpr.invars)
   else:
     scalar_prefetch = grid_mapping.num_index_operands
     block_mappings = grid_mapping.block_mappings
     block_mappings = [*[None] * scalar_prefetch, *block_mappings]
-    for memory_space in memory_spaces[:scalar_prefetch]:
-      if memory_space is not None and memory_space != SMEM:
-        raise ValueError("Cannot specify non-SMEM memory space for "
-                         "scalar prefetch inputs.")
-    memory_spaces = memory_spaces[scalar_prefetch:]
     memory_spaces = [*[SMEM] * scalar_prefetch, *memory_spaces]
+  assert len(memory_spaces) == len(jaxpr.invars), (
+      "Must have as many memory spaces as inputs and outputs.")
   invar_arg_types, block_shapes = unzip2(
       map(_get_arg_type, [invar.aval for invar in jaxpr.invars], block_mappings,
           memory_spaces)
@@ -1402,24 +1393,25 @@ def _trace_stop_lowering_rule(ctx: LoweringRuleContext):
 lowering_rules[tpu_primitives.trace_stop_p] = _trace_stop_lowering_rule


-def _alloc_type(type: tpu_primitives.Type):
-  if isinstance(type, tpu_primitives.VMEM):
-    aval = type.get_aval()
-    vmem = ir.Attribute.parse("#tpu.memory_space<vmem>")
+def _alloc_value(aval: jax_core.AbstractValue) -> ir.Value:
+  if isinstance(aval, tpu_core.AbstractMemoryRef):
+    memspace = ir.Attribute.parse(f"#tpu.memory_space<{aval.memory_space}>")
     out_type = ir.MemRefType.get(
-        aval.shape, mlir.dtype_to_ir_type(aval.dtype), memory_space=vmem)
+        aval.shape, mlir.dtype_to_ir_type(aval.dtype), memory_space=memspace)
     return memref.AllocaOp(out_type, [], []).result
-  raise NotImplementedError(f"Cannot allocate {type}.")
+  raise NotImplementedError(f"Cannot allocate {type(aval)}.")


-def _run_scoped_lowering_rule(ctx: LoweringRuleContext, *consts, jaxpr,
-                              types):
+def _run_scoped_lowering_rule(ctx: LoweringRuleContext, *consts, jaxpr):
   region = tpu.RegionOp()
+  in_avals = [v.aval for v in jaxpr.invars]
   jaxpr = pe.convert_constvars_jaxpr(jaxpr)
   with ir.InsertionPoint(region.body):
-    args = [_alloc_type(type) for type in types]
+    args = map(_alloc_value, in_avals)
+    block_shapes = tuple(a.shape if isinstance(a, state.AbstractRef) else None
+                         for a in in_avals)
     ctx = ctx.lowering_context.replace(
-        block_shapes=(*ctx.block_shapes, *(t.get_block_shape() for t in types))
+        block_shapes=(*ctx.block_shapes, *block_shapes)
     )
     jaxpr_subcomp(ctx, jaxpr, *consts, *args)
     tpu.YieldOp([])
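For orientation, a hedged sketch (not from the diff) of the kernel-side API this lowering serves: scratch allocation is now driven by the avals of the scoped jaxpr's inputs, so a kernel passes MemoryRef-style types whose get_aval() yields a Ref aval:

import jax.numpy as jnp
from jax.experimental.pallas import tpu as pltpu

def kernel(x_ref, o_ref):
  def body(scratch_ref):
    # scratch_ref is lowered through _alloc_value above, carrying its
    # memory space in the AbstractMemoryRef rather than in a `types` arg.
    scratch_ref[...] = x_ref[...] + 1.0
    o_ref[...] = scratch_ref[...]
  pltpu.run_scoped(body, pltpu.VMEM((8, 128), jnp.float32))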
@@ -61,13 +61,11 @@ def pallas_call_tpu_lowering_rule(
   if mosaic_params is None:
     mosaic_params = {}
   dimension_semantics = mosaic_params.get("dimension_semantics", None)
-  memory_spaces = mosaic_params.get("memory_spaces", None)
   kernel_regeneration_metadata = mosaic_params.get(
       "kernel_regeneration_metadata"
   )
   mosaic_module = lowering.lower_jaxpr_to_module(
-      mlir_ctx, grid_mapping, jaxpr, dimension_semantics=dimension_semantics,
-      memory_spaces=memory_spaces)
+      mlir_ctx, grid_mapping, jaxpr, dimension_semantics=dimension_semantics)
   if debug:
     print(mosaic_module)
   out_avals = [jax_core.ShapedArray(s.shape, s.dtype) for s in out_shapes]
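A sketch of the intended replacement (hedged; kwargs per the BlockSpec hunk at the top of this commit): memory spaces now ride on each operand's BlockSpec instead of on mosaic_params:

from jax.experimental import pallas as pl
from jax.experimental.pallas import tpu as pltpu

# Before: mosaic_params={"memory_spaces": [...]}.  After: per-operand specs.
spec = pl.BlockSpec(lambda i: (i, 0), (8, 128),
                    memory_space=pltpu.TPUMemorySpace.VMEM)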
@@ -16,19 +16,20 @@
 from __future__ import annotations

 import contextlib
-import dataclasses
 from typing import Callable

 from jax._src import api_util
 from jax._src import core as jax_core
 from jax._src import effects
 from jax._src import linear_util as lu
 from jax._src import state
 from jax._src import tree_util
 from jax._src import util
 from jax._src.interpreters import mlir
 from jax._src.interpreters import partial_eval as pe
 import jax.numpy as jnp

+map, unsafe_map = util.safe_map, map
+zip, unsafe_zip = util.safe_zip, zip
+
 repeat_p = jax_core.Primitive('repeat')
@@ -74,27 +75,6 @@ def trace(message: str, level: int = 10):
   trace_stop_p.bind()


-class Type:
-
-  def get_aval(self) -> jax_core.AbstractValue:
-    raise NotImplementedError()
-
-  def get_block_shape(self) -> tuple[int, ...] | None:
-    raise NotImplementedError()
-
-
-@dataclasses.dataclass(frozen=True)
-class VMEM(Type):
-  shape: tuple[int, ...]
-  dtype: jnp.dtype
-
-  def get_aval(self) -> jax_core.AbstractValue:
-    return state.AbstractRef(jax_core.ShapedArray(self.shape, self.dtype))
-
-  def get_block_shape(self) -> tuple[int, ...] | None:
-    return self.shape
-
-
 run_scoped_p = jax_core.Primitive('run_scoped')
 run_scoped_p.multiple_results = True

@@ -102,13 +82,13 @@ run_scoped_p.multiple_results = True
 def run_scoped(f: Callable[..., None], *types, **kw_types) -> None:
   flat_types, in_tree = tree_util.tree_flatten((types, kw_types))
   flat_fun, _ = api_util.flatten_fun(lu.wrap_init(f), in_tree)
-  avals = [type.get_aval() for type in flat_types]
+  avals = map(lambda t: t.get_aval(), flat_types)
   jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(flat_fun, avals)
-  run_scoped_p.bind(*consts, jaxpr=jaxpr, types=tuple(flat_types))
+  run_scoped_p.bind(*consts, jaxpr=jaxpr)


 @run_scoped_p.def_effectful_abstract_eval
-def _run_scoped_abstract_eval(*args, jaxpr, types):
+def _run_scoped_abstract_eval(*args, jaxpr):
   # jaxpr will have effects for its inputs (Refs that are allocated) and for
   # constvars (closed over Refs). The effects for the allocated Refs are local
   # to the jaxpr and shouldn't propagate out.
@@ -50,6 +50,8 @@ BlockSpec = pallas_core.BlockSpec
 GridSpec = pallas_core.GridSpec
 BlockMapping = pallas_core.BlockMapping
 GridMapping = pallas_core.GridMapping
+NoBlockSpec = pallas_core.NoBlockSpec
+no_block_spec = pallas_core.no_block_spec

 pallas_call_p = jax_core.Primitive('pallas_call')
 pallas_call_p.multiple_results = True
@@ -224,7 +226,8 @@ def _batch_block_mapping(grid: Tuple[int, ...], aval: jax_core.ShapedArray,
   new_block_shape = tuple_insert(shape, dim, pallas_core.mapped)
   jaxpr = jax_core.ClosedJaxpr(block_mapping_jaxpr, consts)
   if block_mapping is None:
-    return BlockMapping(block_shape=new_block_shape, index_map_jaxpr=jaxpr)
+    return BlockMapping(block_shape=new_block_shape, index_map_jaxpr=jaxpr,
+                        memory_space=None)
   return block_mapping.replace(block_shape=new_block_shape,
                                index_map_jaxpr=jaxpr)
@@ -324,40 +327,44 @@ def _hoist_consts_to_refs(jaxpr: jax_core.Jaxpr) -> jax_core.Jaxpr:
   return hoisted_jaxpr


 @weakref_lru_cache
-def _initial_style_open_jaxpr(fun: Callable, in_tree, in_avals,
-                              primitive_name: str | None = None):
+def _trace_to_jaxpr(fun: Callable, grid_spec, flat_in_avals,
+                    flat_out_avals, in_tree, out_tree):
+  avals, grid_mapping = grid_spec.get_grid_mapping(flat_in_avals, in_tree,
+                                                   flat_out_avals, out_tree)
+  jaxpr_flat_avals, jaxpr_in_tree = tree_util.tree_flatten(avals)
   wrapped_fun, out_tree_thunk = api_util.flatten_fun_nokwargs(
-      lu.wrap_init(fun), in_tree)
-  debug = pe.debug_info(fun, in_tree, out_tree_thunk, False,
-                        primitive_name or "<unknown>")
-  jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(wrapped_fun, in_avals, debug)
+      lu.wrap_init(fun), jaxpr_in_tree)
+  debug = pe.debug_info(fun, jaxpr_in_tree, out_tree_thunk, False, "pallas_call")
+  jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(wrapped_fun, jaxpr_flat_avals,
+                                               debug)
   jaxpr = _hoist_consts_to_refs(jaxpr)
-  return jaxpr, consts, out_tree_thunk()
+  return grid_mapping, jaxpr, consts, out_tree_thunk()


 def _extract_function_name(f: Callable, name: str | None) -> str:
   if name is None:
     name = f.__name__ if hasattr(f, "__name__") and f.__name__ else "func"
   return name


 def pallas_call(
-    f: Callable[..., None], out_shape: Any, *,
+    f: Callable[..., None],
+    out_shape: Any,
+    *,
     grid_spec: GridSpec | None = None,
     debug: bool = False,
     grid: Grid | None = None,
-    in_specs: Sequence[BlockSpec | None] | None = None,
-    out_specs: BlockSpec | Sequence[BlockSpec | None] | None = None,
+    in_specs: Sequence[BlockSpec | NoBlockSpec] | NoBlockSpec = no_block_spec,
+    out_specs: BlockSpec | NoBlockSpec
+    | Sequence[BlockSpec | NoBlockSpec] = no_block_spec,
     input_output_aliases: Dict[int, int] = {},
     interpret: bool = False,
     name: str | None = None,
-    **compiler_params: Any):
+    **compiler_params: Any,
+):
+  name = _extract_function_name(f, name)
   if grid_spec is None:
     grid_spec = GridSpec(grid, in_specs, out_specs)
-  name = _extract_function_name(f, name)
   singleton = False
   if not isinstance(out_shape, (tuple, list)):
     out_shape = (out_shape,)
     singleton = True
-  if not isinstance(out_shape, tuple):
+  if isinstance(out_shape, list):
     out_shape = tuple(out_shape)
   flat_out_shapes, out_tree = tree_util.tree_flatten(out_shape)
   flat_out_shapes = [jax.ShapeDtypeStruct(x.shape, x.dtype)
@@ -365,14 +372,13 @@ def pallas_call(
   @jax.jit
   def wrapped(*args):
     flat_args, in_tree = tree_util.tree_flatten(args)
-    flat_avals = [jax_core.raise_to_shaped(jax_core.get_aval(a))
-                  for a in flat_args]
-    avals, grid_mapping = grid_spec.get_grid_mapping(flat_avals, in_tree,
-                                                     flat_out_shapes, out_tree)
-    jaxpr_flat_avals, jaxpr_in_tree = tree_util.tree_flatten(avals)
-    jaxpr, consts, _ = _initial_style_open_jaxpr(f, jaxpr_in_tree,
-                                                 tuple(jaxpr_flat_avals),
-                                                 primitive_name="pallas_call")
+    flat_in_avals = tuple(jax_core.raise_to_shaped(jax_core.get_aval(a))
+                          for a in flat_args)
+    flat_out_avals = tuple(jax_core.ShapedArray(v.shape, v.dtype)
+                           for v in flat_out_shapes)
+    grid_mapping, jaxpr, consts, _ = _trace_to_jaxpr(
+        f, grid_spec, flat_in_avals, flat_out_avals, in_tree,
+        out_tree)
     which_linear = (False,) * len(flat_args)
     out_flat = pallas_call_p.bind(
         *consts, *flat_args, jaxpr=jaxpr, name=name, which_linear=which_linear,
@@ -384,7 +390,5 @@ def pallas_call(
         input_output_aliases=tuple(input_output_aliases.items()),
         **compiler_params)
     out = tree_util.tree_unflatten(out_tree, out_flat)
-    if singleton:
-      return out[0]
     return out
   return wrapped
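An end-to-end sketch of the updated entry point (hedged; it relies only on defaults, so in_specs/out_specs fall back to no_block_spec):

import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl

def add_one_kernel(x_ref, o_ref):
  o_ref[...] = x_ref[...] + 1.0

add_one = pl.pallas_call(
    add_one_kernel,
    out_shape=jax.ShapeDtypeStruct((8, 128), jnp.float32),
)
y = add_one(jnp.zeros((8, 128), jnp.float32))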
@@ -16,6 +16,7 @@

 from jax._src import pallas
 from jax._src.pallas.core import BlockSpec
+from jax._src.pallas.core import no_block_spec
 from jax._src.pallas.indexing import ds
 from jax._src.pallas.indexing import dslice
 from jax._src.pallas.indexing import broadcast_to
@@ -695,6 +695,7 @@ def _flash_attention_bwd_dkv(
   def qo_index_map(batch_index, _, head_index, q_seq_index):
     return (batch_index, head_index, q_seq_index, 0)
   qo_spec = pl.BlockSpec(qo_index_map, (1, 1, block_q_major, head_dim))
+  assert qo_spec.block_shape is not None
   assert q.ndim == len(qo_spec.block_shape)
   do_spec = qo_spec
   assert do.ndim == len(qo_spec.block_shape)
@@ -703,16 +704,19 @@ def _flash_attention_bwd_dkv(
     del q_seq_index
     return (batch_index, head_index, kv_seq_index, 0)
   kv_spec = pl.BlockSpec(kv_index_map, (1, 1, block_k_major, head_dim))
+  assert kv_spec.block_shape is not None
   assert k.ndim == len(kv_spec.block_shape)
   assert v.ndim == len(kv_spec.block_shape)

   def lm_index_map(batch_index, _, head_index, q_seq_index):
     return (batch_index, head_index, q_seq_index, 0)
   lm_spec = pl.BlockSpec(lm_index_map, (1, 1, block_q_major, MIN_BLOCK_SIZE))
+  assert lm_spec.block_shape is not None
   assert l.ndim == len(lm_spec.block_shape)
   assert m.ndim == len(lm_spec.block_shape)

   di_spec = pl.BlockSpec(qo_index_map, (1, 1, block_q_major, MIN_BLOCK_SIZE))
+  assert di_spec.block_shape is not None
   assert di.ndim == len(di_spec.block_shape)

   in_specs = [
@@ -882,6 +886,7 @@ def _flash_attention_bwd_dq(
     return (batch_index, head_index, kv_seq_index, 0)

   kv_spec = pl.BlockSpec(kv_index_map, (1, 1, block_k_major, head_dim))
+  assert kv_spec.block_shape is not None
   assert k.ndim == len(kv_spec.block_shape)
   assert v.ndim == len(kv_spec.block_shape)

@@ -890,10 +895,12 @@ def _flash_attention_bwd_dq(
     return (batch_index, head_index, q_seq_index, 0)

   lm_spec = pl.BlockSpec(lm_index_map, (1, 1, block_q_major, MIN_BLOCK_SIZE))
+  assert lm_spec.block_shape is not None
   assert l.ndim == len(lm_spec.block_shape)
   assert m.ndim == len(lm_spec.block_shape)

   di_spec = pl.BlockSpec(qo_index_map, (1, 1, block_q_major, MIN_BLOCK_SIZE))
+  assert di_spec.block_shape is not None
   assert di.ndim == len(di_spec.block_shape)

   in_specs = [
@@ -15,9 +15,10 @@
 """Contains Mosaic specific Pallas functions."""
 from jax._src.pallas.mosaic import PrefetchScalarGridSpec
 from jax._src.pallas.mosaic import TPUMemorySpace
 from jax._src.pallas.mosaic import ANY
 from jax._src.pallas.mosaic import CMEM
-from jax._src.pallas.mosaic import VMEM
+from jax._src.pallas.mosaic import SMEM
+from jax._src.pallas.mosaic import VMEM
 from jax._src.pallas.mosaic import encode_kernel_regeneration_metadata
 from jax._src.pallas.mosaic import extract_kernel_regeneration_metadata
 from jax._src.pallas.mosaic import repeat
@@ -29,7 +29,7 @@ from jax._src import linear_util as lu
 from jax._src import test_util as jtu
 from jax._src import state
 from jax._src.lax.control_flow.for_loop import for_loop
-from jax._src.pallas.pallas_call import _initial_style_open_jaxpr
+from jax._src.pallas.pallas_call import _trace_to_jaxpr
 from jax.config import config
 from jax.interpreters import partial_eval as pe
 import jax.numpy as jnp
@@ -134,7 +134,7 @@ class PallasTest(parameterized.TestCase):
     super().setUp()
     if compile_jaxpr:
       compile_jaxpr.cache_clear()
-    _initial_style_open_jaxpr.cache_clear()
+    _trace_to_jaxpr.cache_clear()

   def pallas_call(self, *args, **kwargs):
     return pl.pallas_call(*args, **kwargs, interpret=self.INTERPRET)
@@ -639,7 +639,8 @@ class PallasCallTest(PallasTest):
     def f(x):
       return add_one(add_one(x))

-    self.assertEqual(f(0.), 2.)
+    x = jnp.array(0., dtype=jnp.float32)
+    self.assertEqual(f(x), 2.)
     self.assertEqual(trace_count, 1)

   def test_pallas_compilation_cache(self):
@@ -657,7 +658,8 @@ class PallasCallTest(PallasTest):
     def f(x):
       return add_one(add_one(x))

-    self.assertEqual(f(0.), 2.)
+    x = jnp.array(0., dtype=jnp.float32)
+    self.assertEqual(f(x), 2.)
     num_misses = compile_jaxpr.cache_info().misses
     self.assertEqual(num_misses, 1)