From cb114f247a78d10b50504d7495a3cf0ecdd0c020 Mon Sep 17 00:00:00 2001
From: Sharad Vikram
Date: Thu, 7 Sep 2023 17:08:18 -0700
Subject: [PATCH] [Pallas] Refactor memory space handling

Memory spaces now live on `BlockSpec` and `BlockMapping` instead of being
threaded through the Mosaic lowering as a separate `memory_spaces` argument.
`BlockSpec` gains an optional `memory_space` field (and its `index_map` and
`block_shape` become optional), `TPUMemorySpace` values are callable and
construct `MemoryRef`s (replacing the `tpu_primitives.VMEM` type used by
`run_scoped`), and `in_specs`/`out_specs` default to a `no_block_spec`
sentinel and may be pytrees of specs.

PiperOrigin-RevId: 563586933
---
 jax/_src/pallas/core.py                       | 175 +++++++++++-------
 jax/_src/pallas/mosaic/__init__.py            |   6 +-
 jax/_src/pallas/mosaic/core.py                | 108 ++++++++---
 jax/_src/pallas/mosaic/lowering.py            |  40 ++--
 .../pallas/mosaic/pallas_call_registration.py |   4 +-
 jax/_src/pallas/mosaic/primitives.py          |  32 +---
 jax/_src/pallas/pallas_call.py                |  60 +++---
 jax/experimental/pallas/__init__.py           |   1 +
 .../pallas/ops/tpu/flash_attention.py         |   7 +
 jax/experimental/pallas/tpu.py                |   3 +-
 tests/pallas/pallas_test.py                   |  10 +-
 11 files changed, 264 insertions(+), 182 deletions(-)

diff --git a/jax/_src/pallas/core.py b/jax/_src/pallas/core.py
index bd2d0492c..2d82ee046 100644
--- a/jax/_src/pallas/core.py
+++ b/jax/_src/pallas/core.py
@@ -70,12 +70,24 @@ class Mapped:
 mapped = Mapped()
 
 
-@dataclasses.dataclass(frozen=True)
+@dataclasses.dataclass(init=False, unsafe_hash=True)
 class BlockSpec:
-  index_map: Callable[..., Any]
-  block_shape: tuple[int | None, ...]
+  index_map: Callable[..., Any] | None
+  block_shape: tuple[int | None, ...] | None
+  memory_space: Any
+
+  def __init__(self, index_map: Callable[..., Any] | None = None,
+               block_shape: tuple[int | None, ...] | None = None,
+               memory_space: Any = None):
+    self.index_map = index_map
+    if block_shape is not None and not isinstance(block_shape, tuple):
+      block_shape = tuple(block_shape)
+    self.block_shape = block_shape
+    self.memory_space = memory_space
 
   def compute_index(self, *args):
+    assert self.index_map is not None
+    assert self.block_shape is not None
     out = self.index_map(*args)
     if not isinstance(out, tuple):
       out = (out,)
@@ -86,6 +98,7 @@ class BlockSpec:
 class BlockMapping:
   block_shape: tuple[Mapped | int, ...]
   index_map_jaxpr: jax_core.ClosedJaxpr
+  memory_space: Any
 
   def compute_start_indices(self, loop_idx, *args):
     discharged_jaxpr, discharged_consts = state_discharge.discharge_state(
@@ -123,107 +136,131 @@ def _preprocess_grid(grid: Grid | int | None) -> Grid:
 def _convert_block_spec_to_block_mapping(
     in_avals: list[jax_core.ShapedArray],
     block_spec: BlockSpec | None,
+    aval: jax_core.ShapedArray,
 ) -> BlockSpec | None:
-  if block_spec is _no_block_spec:
+  if block_spec is no_block_spec:
     return None
+  if block_spec.index_map is None:
+    compute_index = lambda *args: (0,) * len(aval.shape)
+    block_shape = aval.shape
+  else:
+    compute_index = block_spec.compute_index
+    block_shape = block_spec.block_shape
   block_shape = tuple(
-      mapped if s is None else s for s in block_spec.block_shape)
+      mapped if s is None else s for s in block_shape)
   jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(
-      lu.wrap_init(block_spec.compute_index), in_avals)
-  return BlockMapping(block_shape, jax_core.ClosedJaxpr(jaxpr, consts))
+      lu.wrap_init(compute_index), in_avals)
+  return BlockMapping(block_shape, jax_core.ClosedJaxpr(jaxpr, consts),
+                      block_spec.memory_space)
 
 
-def _compute_shape_from_block_spec(block_spec: BlockSpec | None,
-                                   arg_shape: tuple[int, ...]
-                                   ) -> tuple[int, ...]:
-  if block_spec is _no_block_spec:
-    return arg_shape
-  return tuple(s for s in block_spec.block_shape if s is not None)
+def _tile_ref(ref: state.AbstractRef, block_shape: tuple[int, ...] | None
+              ) -> state.AbstractRef:
+  if block_shape is None:
+    return ref
+  shape = tuple(s for s in block_shape if s is not None)
+  return state.shaped_array_ref(shape, ref.dtype)
 
 
 def _get_ref_avals(grid, in_avals, in_specs, out_avals, out_specs):
+  in_ref_avals = map(state.AbstractRef, in_avals)
+  out_ref_avals = map(state.AbstractRef, out_avals)
   if grid is None:
     in_specs = [None] * len(in_avals)
     out_specs = [None] * len(out_avals)
-    in_ref_avals = [state.shaped_array_ref(arg.shape, arg.dtype)
-                    for arg in in_avals]
-    out_ref_avals = [state.shaped_array_ref(arg.shape, arg.dtype)
-                     for arg in out_avals]
-  else:
-    in_ref_avals = [
-        state.shaped_array_ref(
-            _compute_shape_from_block_spec(
-                block_spec, arg.shape), arg.dtype)
-        for block_spec, arg in zip(in_specs, in_avals)]
-    out_ref_avals = [
-        state.shaped_array_ref(
-            _compute_shape_from_block_spec(
-                block_spec, arg.shape), arg.dtype)
-        for block_spec, arg in zip(out_specs, out_avals)]
-  return in_specs, in_ref_avals, out_specs, out_ref_avals
+  tiled_in_ref_avals = [
+      aval if in_spec is no_block_spec
+      else _tile_ref(aval, in_spec.block_shape)
+      for aval, in_spec in zip(in_ref_avals, in_specs)
+  ]
+  tiled_out_ref_avals = [
+      aval if out_spec is no_block_spec
+      else _tile_ref(aval, out_spec.block_shape)
+      for aval, out_spec in zip(out_ref_avals, out_specs)
+  ]
+  return in_specs, tiled_in_ref_avals, out_specs, tiled_out_ref_avals
 
 
+class NoBlockSpec:
+  pass
+no_block_spec = NoBlockSpec()
 
-_no_block_spec = object()
-
-@dataclasses.dataclass(init=False)
+
+@dataclasses.dataclass(init=False, unsafe_hash=True)
 class GridSpec:
   grid: Grid
-  in_specs: Sequence[BlockSpec | None] | None
-  out_specs: tuple[BlockSpec | None, ...] | None
+  in_specs: tuple[BlockSpec | NoBlockSpec, ...]
+  out_specs: tuple[BlockSpec | NoBlockSpec, ...]
+  in_specs_tree: Any
+  out_specs_tree: Any
 
   def __init__(
       self,
       grid: Grid | None = None,
-      in_specs: Sequence[BlockSpec | None] | None = None,
-      out_specs: BlockSpec | Sequence[BlockSpec | None] | None = None,
+      in_specs: BlockSpec
+      | Sequence[BlockSpec | NoBlockSpec]
+      | NoBlockSpec = no_block_spec,
+      out_specs: BlockSpec
+      | Sequence[BlockSpec | NoBlockSpec]
+      | NoBlockSpec = no_block_spec,
   ):
-    if grid is None:
-      if in_specs is not None:
-        raise ValueError("Cannot specify `in_specs` with a `None` grid.")
-      if out_specs is not None:
-        raise ValueError("Cannot specify `out_specs` with a `None` grid.")
-    self.grid = _preprocess_grid(grid)
-    self.in_specs = in_specs
-    if out_specs is not None and not isinstance(out_specs, (tuple, list)):
-      out_specs = (out_specs,)
-    if out_specs is not None and not isinstance(out_specs, tuple):
+    # Be more lenient for in/out_specs
+    if isinstance(in_specs, list):
+      in_specs = tuple(in_specs)
+    if isinstance(out_specs, list):
       out_specs = tuple(out_specs)
-    self.out_specs = out_specs
+
+    self.grid = _preprocess_grid(grid)
+    if in_specs is not no_block_spec:
+      flat_in_specs, self.in_specs_tree = tree_util.tree_flatten(in_specs)
+      self.in_specs = tuple(flat_in_specs)
+    else:
+      self.in_specs = in_specs
+      self.in_specs_tree = None
+    if out_specs is not no_block_spec:
+      flat_out_specs, self.out_specs_tree = tree_util.tree_flatten(out_specs)
+      self.out_specs = tuple(flat_out_specs)
+    else:
+      self.out_specs = out_specs
+      self.out_specs_tree = None
+
+  def _get_in_out_specs(self, in_avals, in_tree, out_avals, out_tree):
+    if self.in_specs is no_block_spec:
+      flat_in_specs = [no_block_spec] * len(in_avals)
+    else:
+      flat_in_specs = self.in_specs
+      if self.in_specs_tree != in_tree:
+        raise ValueError(
+            "Pytree specs for arguments and `in_specs` must match: "
+            f"{in_tree} vs. {self.in_specs_tree}")
+    if self.out_specs is no_block_spec:
+      flat_out_specs = [no_block_spec] * len(out_avals)
+    else:
+      flat_out_specs = self.out_specs
+      if self.out_specs_tree != out_tree:
+        raise ValueError(
+            "Pytree specs for `out_shape` and `out_specs` must match: "
+            f"{out_tree} vs. {self.out_specs_tree}")
+    return flat_in_specs, flat_out_specs
 
   def get_grid_mapping(
       self, in_avals, in_tree, out_avals, out_tree
   ) -> tuple[tuple[jax_core.AbstractValue, ...], GridMapping]:
-    if self.in_specs is not None:
-      in_specs = self.in_specs
-      in_spec_tree = tree_util.tree_structure(tuple(in_specs))
-      if in_spec_tree != in_tree:
-        raise ValueError(
-            "Pytree specs for arguments and `in_specs` must match: "
-            f"{in_tree} vs. {in_spec_tree}")
-    else:
-      in_specs = [_no_block_spec] * len(in_avals)
-    if self.out_specs is not None:
-      out_specs = self.out_specs
-      out_spec_tree = tree_util.tree_structure(out_specs)
-      if out_spec_tree != out_tree:
-        raise ValueError(
-            "Pytree specs for `out_shape` and `out_specs` must match: "
-            f"{out_tree} vs. {out_spec_tree}")
-    else:
-      out_specs = [_no_block_spec] * len(out_avals)
-    flat_in_specs = tree_util.tree_leaves(in_specs)
-    flat_out_specs = tree_util.tree_leaves(out_specs)
+    flat_in_specs, flat_out_specs = self._get_in_out_specs(
+        in_avals, in_tree, out_avals, out_tree)
     in_specs, in_ref_avals, out_specs, out_ref_avals = _get_ref_avals(
         self.grid, in_avals, flat_in_specs,
         out_avals, flat_out_specs)
     grid_avals = [jax_core.ShapedArray((), jnp.dtype("int32"))] * len(self.grid)
     in_block_mappings = map(
-        partial(_convert_block_spec_to_block_mapping, grid_avals), in_specs)
+        partial(_convert_block_spec_to_block_mapping, grid_avals), in_specs,
+        in_ref_avals)
     out_block_mappings = map(
-        partial(_convert_block_spec_to_block_mapping, grid_avals), out_specs)
+        partial(_convert_block_spec_to_block_mapping, grid_avals), out_specs,
+        out_ref_avals)
     grid_mapping = GridMapping(
         self.grid, (*in_block_mappings, *out_block_mappings), (),
         num_index_operands=0)
     jaxpr_in_avals = tree_util.tree_unflatten(in_tree, in_ref_avals)
     jaxpr_out_avals = tree_util.tree_unflatten(out_tree, out_ref_avals)
+    if not isinstance(jaxpr_out_avals, (tuple, list)):
+      jaxpr_out_avals = (jaxpr_out_avals,)
     return (*jaxpr_in_avals, *jaxpr_out_avals), grid_mapping
 
diff --git a/jax/_src/pallas/mosaic/__init__.py b/jax/_src/pallas/mosaic/__init__.py
index 1387fc1f1..5843a5f1f 100644
--- a/jax/_src/pallas/mosaic/__init__.py
+++ b/jax/_src/pallas/mosaic/__init__.py
@@ -23,10 +23,10 @@ from jax._src.pallas.mosaic.kernel_regeneration_util import extract_kernel_regen
 from jax._src.pallas.mosaic.primitives import repeat
 from jax._src.pallas.mosaic.primitives import trace
 from jax._src.pallas.mosaic.primitives import run_scoped
-from jax._src.pallas.mosaic.primitives import VMEM
-SMEM = TPUMemorySpace.SMEM
+ANY = TPUMemorySpace.ANY
 CMEM = TPUMemorySpace.CMEM
-
+SMEM = TPUMemorySpace.SMEM
+VMEM = TPUMemorySpace.VMEM
 del pallas_call_registration
diff --git a/jax/_src/pallas/mosaic/core.py b/jax/_src/pallas/mosaic/core.py
index bfeda7b7a..bab61751c 100644
--- a/jax/_src/pallas/mosaic/core.py
+++ b/jax/_src/pallas/mosaic/core.py
@@ -19,6 +19,7 @@ from collections.abc import Sequence
 import dataclasses
 import enum
 import functools
+from typing import Any
 
 from jax._src import core as jax_core
 from jax._src import state
@@ -30,12 +31,16 @@ from jax._src.pallas import core as pallas_core
 # TODO(sharadmv): enable type checking
 # mypy: ignore-errors
 
+map, unsafe_map = util.safe_map, map
+zip, unsafe_zip = util.safe_zip, zip
+
 partial = functools.partial
 Grid = pallas_core.Grid
 BlockSpec = pallas_core.BlockSpec
 GridMapping = pallas_core.GridMapping
+NoBlockSpec = pallas_core.NoBlockSpec
+no_block_spec = pallas_core.no_block_spec
 _preprocess_grid = pallas_core._preprocess_grid
-_compute_shape_from_block_spec = pallas_core._compute_shape_from_block_spec
 _convert_block_spec_to_block_mapping = pallas_core._convert_block_spec_to_block_mapping
 split_list = util.split_list
 
@@ -49,59 +54,114 @@ class TPUMemorySpace(enum.Enum):
   def __str__(self) -> str:
     return self.value
 
+  def __call__(self, shape: tuple[int, ...], dtype: jnp.dtype):
+    # A convenience function for constructing MemoryRef types.
+    return MemoryRef(shape, dtype, self)
 
-@dataclasses.dataclass(init=False)
+
+class AbstractMemoryRef(state.AbstractRef):
+  __slots__ = ["inner_aval", "memory_space"]
+
+  def __init__(self, inner_aval: jax_core.AbstractValue,
+               memory_space: TPUMemorySpace):
+    assert isinstance(inner_aval, jax_core.ShapedArray)
+    self.inner_aval = inner_aval
+    self.memory_space = memory_space
+
+  def __repr__(self) -> str:
+    return f'MemRef<{self.memory_space}>{{{self.inner_aval.str_short()}}}'
+
+  def at_least_vspace(self):
+    return AbstractMemoryRef(
+        self.inner_aval.at_least_vspace(), self.memory_space)
+
+  def __eq__(self, other):
+    return (type(self) is type(other) and self.inner_aval == other.inner_aval
+            and self.memory_space == other.memory_space)
+
+  def __hash__(self):
+    return hash((self.__class__, self.inner_aval, self.memory_space))
+
+
+def _ref_raise_to_shaped(ref_aval: AbstractMemoryRef, weak_type):
+  return AbstractMemoryRef(
+      jax_core.raise_to_shaped(ref_aval.inner_aval, weak_type),
+      ref_aval.memory_space)
+jax_core.raise_to_shaped_mappings[AbstractMemoryRef] = _ref_raise_to_shaped
+
+
+@dataclasses.dataclass(frozen=True)
+class MemoryRef:
+  """Like jax.ShapeDtypeStruct but with memory spaces."""
+  shape: tuple[int, ...]
+  dtype: jnp.dtype
+  memory_space: TPUMemorySpace = TPUMemorySpace.ANY
+
+  def get_aval(self):
+    return AbstractMemoryRef(jax_core.ShapedArray(self.shape, self.dtype),
+                             self.memory_space)
+
+
+@dataclasses.dataclass(init=False, unsafe_hash=True)
 class PrefetchScalarGridSpec(pallas_core.GridSpec):
   grid: Grid
   num_scalar_prefetch: int
-  in_specs: Sequence[BlockSpec | None] | None
-  out_specs: tuple[BlockSpec | None, ...] | None
+  in_specs: tuple[BlockSpec | NoBlockSpec, ...]
+  out_specs: tuple[BlockSpec | NoBlockSpec, ...]
+  in_specs_tree: Any
+  out_specs_tree: Any
 
   def __init__(
       self,
       num_scalar_prefetch: int,
       grid: Grid | None = None,
-      in_specs: Sequence[BlockSpec | None] | None = None,
-      out_specs: BlockSpec | Sequence[BlockSpec | None] | None = None,
+      in_specs: BlockSpec
+      | Sequence[BlockSpec | NoBlockSpec]
+      | NoBlockSpec = no_block_spec,
+      out_specs: BlockSpec
+      | Sequence[BlockSpec | NoBlockSpec]
+      | NoBlockSpec = no_block_spec,
   ):
-    if grid is None:
-      raise NotImplementedError("Should pass in non-`None` grid.")
-    self.grid = _preprocess_grid(grid)
-    if out_specs is not None and not isinstance(out_specs, (tuple, list)):
-      out_specs = (out_specs,)
-    if out_specs is not None and not isinstance(out_specs, tuple):
-      out_specs = tuple(out_specs)
+    super().__init__(grid, in_specs, out_specs)
     self.num_scalar_prefetch = num_scalar_prefetch
-    self.in_specs = in_specs
-    self.out_specs = out_specs
 
   def get_grid_mapping(
       self, in_avals, in_tree, out_avals, out_tree
   ) -> tuple[tuple[jax_core.AbstractValue, ...], GridMapping]:
-    scalar_avals, in_avals = split_list(in_avals, [self.num_scalar_prefetch])
-    flat_in_specs = tree_util.tree_leaves(self.in_specs)
-    flat_out_specs = tree_util.tree_leaves(self.out_specs)
+    all_avals = tree_util.tree_unflatten(in_tree, in_avals)
+    scalar_avals, unflat_in_avals = split_list(
+        all_avals, [self.num_scalar_prefetch])
+    flat_scalar_avals, scalar_tree = tree_util.tree_flatten(scalar_avals)
+    num_flat_scalar_prefetch = len(flat_scalar_avals)
+    in_avals, in_avals_tree = tree_util.tree_flatten(tuple(unflat_in_avals))
+    flat_in_specs, flat_out_specs = self._get_in_out_specs(
+        in_avals, in_avals_tree, out_avals, out_tree)
     in_specs, in_ref_avals, out_specs, out_ref_avals = (
         pallas_core._get_ref_avals(
            self.grid, in_avals, flat_in_specs,
           out_avals, flat_out_specs))
     scalar_ref_avals = [
         state.shaped_array_ref(aval.shape, aval.dtype)
-        for aval in scalar_avals]
+        for aval in flat_scalar_avals]
     grid_avals = [jax_core.ShapedArray((), jnp.dtype("int32"))] * len(self.grid)
     in_block_mappings = map(
         partial(_convert_block_spec_to_block_mapping,
-                (*grid_avals, *scalar_ref_avals)), in_specs)
+                (*grid_avals, *scalar_ref_avals)), in_specs, in_ref_avals)
     out_block_mappings = map(
         partial(_convert_block_spec_to_block_mapping,
-                (*grid_avals, *scalar_ref_avals)), out_specs)
+                (*grid_avals, *scalar_ref_avals)), out_specs, out_ref_avals)
     grid_mapping = GridMapping(
         grid=self.grid,
         block_mappings=(*in_block_mappings, *out_block_mappings),
         mapped_dims=(),
-        num_index_operands=self.num_scalar_prefetch,
+        num_index_operands=num_flat_scalar_prefetch,
     )
-    jaxpr_in_avals = tree_util.tree_unflatten(
-        in_tree, [*scalar_ref_avals, *in_ref_avals])
+    jaxpr_scalar_ref_avals = tree_util.tree_unflatten(
+        scalar_tree, scalar_ref_avals)
+    jaxpr_in_ref_avals = tree_util.tree_unflatten(in_avals_tree, in_ref_avals)
+    jaxpr_in_avals = (*jaxpr_scalar_ref_avals, *jaxpr_in_ref_avals)
     jaxpr_out_avals = tree_util.tree_unflatten(out_tree, out_ref_avals)
+    if not isinstance(jaxpr_out_avals, (tuple, list)):
+      jaxpr_out_avals = (jaxpr_out_avals,)
     return (*jaxpr_in_avals, *jaxpr_out_avals), grid_mapping
 
diff --git a/jax/_src/pallas/mosaic/lowering.py b/jax/_src/pallas/mosaic/lowering.py
index 0ea62f1c1..d077d7561 100644
--- a/jax/_src/pallas/mosaic/lowering.py
+++ b/jax/_src/pallas/mosaic/lowering.py
@@ -143,20 +143,17 @@ def lower_jaxpr_to_module(
     grid_mapping: core.GridMapping,
     jaxpr: jax_core.Jaxpr,
     dimension_semantics: tuple[str | None, ...] | None,
-    memory_spaces: tuple[TPUMemorySpace | None, ...] | None
 ) -> ir.Module:
   m = ir.Module.create()
   sym_tab = ir.SymbolTable(m.operation)
-  if all(bm is None for bm in grid_mapping.block_mappings):
+  if not grid_mapping.grid:
     # Trivial grid-map, we don't need to populate the transform functions.
    func_op = lower_jaxpr_to_func(ctx, jaxpr, grid_mapping=grid_mapping,
-                                  memory_spaces=memory_spaces,
                                   name="main")
     m.body.append(func_op)
     sym_tab.insert(func_op)
     return m
   func_op = lower_jaxpr_to_func(ctx, jaxpr, grid_mapping=grid_mapping,
-                                memory_spaces=memory_spaces,
                                 name="main")
   m.body.append(func_op)
   sym_tab.insert(func_op)
@@ -256,10 +253,11 @@ def lower_jaxpr_to_func(
     ctx: ir.Context,
     jaxpr: jax_core.Jaxpr,
     *,
-    memory_spaces: Sequence[tpu_core.TPUMemorySpace | None] | None,
     grid_mapping: core.GridMapping | None,
     name: str,
 ) -> func.FuncOp:
+  memory_spaces = [None if bm is None else bm.memory_space
+                   for bm in grid_mapping.block_mappings]
   if grid_mapping:
     arg_types = map(
         aval_to_ir_type,
@@ -277,22 +275,15 @@ def lower_jaxpr_to_func(
     )
     return (aval_to_ir_type(aval, shape=shape, memory_space=memory_space),
             block_mapping.block_shape)
-  if memory_spaces is None:
-    memory_spaces = [None] * len(jaxpr.invars)
-  if len(memory_spaces) != len(jaxpr.invars):
-    raise ValueError("Must have as many memory spaces as inputs and outputs.")
   if grid_mapping is None:
     block_mappings = [None] * len(jaxpr.invars)
   else:
     scalar_prefetch = grid_mapping.num_index_operands
     block_mappings = grid_mapping.block_mappings
     block_mappings = [*[None] * scalar_prefetch, *block_mappings]
-    for memory_space in memory_spaces[:scalar_prefetch]:
-      if memory_space is not None and memory_space != SMEM:
-        raise ValueError("Cannot specify non-SMEM memory space for "
-                         "scalar prefetch inputs.")
-    memory_spaces = memory_spaces[scalar_prefetch:]
     memory_spaces = [*[SMEM] * scalar_prefetch, *memory_spaces]
+  assert len(memory_spaces) == len(jaxpr.invars), (
+      "Must have as many memory spaces as inputs and outputs.")
   invar_arg_types, block_shapes = unzip2(
       map(_get_arg_type, [invar.aval for invar in jaxpr.invars],
           block_mappings, memory_spaces)
@@ -1402,24 +1393,25 @@ def _trace_stop_lowering_rule(ctx: LoweringRuleContext):
 lowering_rules[tpu_primitives.trace_stop_p] = _trace_stop_lowering_rule
 
 
-def _alloc_type(type: tpu_primitives.Type):
-  if isinstance(type, tpu_primitives.VMEM):
-    aval = type.get_aval()
-    vmem = ir.Attribute.parse("#tpu.memory_space<vmem>")
+def _alloc_value(aval: jax_core.AbstractValue) -> ir.Value:
+  if isinstance(aval, tpu_core.AbstractMemoryRef):
+    memspace = ir.Attribute.parse(f"#tpu.memory_space<{aval.memory_space}>")
     out_type = ir.MemRefType.get(
-        aval.shape, mlir.dtype_to_ir_type(aval.dtype), memory_space=vmem)
+        aval.shape, mlir.dtype_to_ir_type(aval.dtype), memory_space=memspace)
     return memref.AllocaOp(out_type, [], []).result
-  raise NotImplementedError(f"Cannot allocate {type}.")
+  raise NotImplementedError(f"Cannot allocate {type(aval)}.")
 
 
-def _run_scoped_lowering_rule(ctx: LoweringRuleContext, *consts, jaxpr,
-                              types):
+def _run_scoped_lowering_rule(ctx: LoweringRuleContext, *consts, jaxpr):
   region = tpu.RegionOp()
+  in_avals = [v.aval for v in jaxpr.invars]
   jaxpr = pe.convert_constvars_jaxpr(jaxpr)
   with ir.InsertionPoint(region.body):
-    args = [_alloc_type(type) for type in types]
+    args = map(_alloc_value, in_avals)
+    block_shapes = tuple(a.shape if isinstance(a, state.AbstractRef) else None
+                         for a in in_avals)
     ctx = ctx.lowering_context.replace(
-        block_shapes=(*ctx.block_shapes, *(t.get_block_shape() for t in types))
+        block_shapes=(*ctx.block_shapes, *block_shapes)
     )
     jaxpr_subcomp(ctx, jaxpr, *consts, *args)
     tpu.YieldOp([])
diff --git a/jax/_src/pallas/mosaic/pallas_call_registration.py b/jax/_src/pallas/mosaic/pallas_call_registration.py
index f8b552406..3607c6f63 100644
--- a/jax/_src/pallas/mosaic/pallas_call_registration.py
+++ b/jax/_src/pallas/mosaic/pallas_call_registration.py
@@ -61,13 +61,11 @@ def pallas_call_tpu_lowering_rule(
   if mosaic_params is None:
     mosaic_params = {}
   dimension_semantics = mosaic_params.get("dimension_semantics", None)
-  memory_spaces = mosaic_params.get("memory_spaces", None)
   kernel_regeneration_metadata = mosaic_params.get(
       "kernel_regeneration_metadata"
   )
   mosaic_module = lowering.lower_jaxpr_to_module(
-      mlir_ctx, grid_mapping, jaxpr, dimension_semantics=dimension_semantics,
-      memory_spaces=memory_spaces)
+      mlir_ctx, grid_mapping, jaxpr, dimension_semantics=dimension_semantics)
   if debug:
     print(mosaic_module)
   out_avals = [jax_core.ShapedArray(s.shape, s.dtype) for s in out_shapes]
diff --git a/jax/_src/pallas/mosaic/primitives.py b/jax/_src/pallas/mosaic/primitives.py
index 9e73c7610..bc04e85f1 100644
--- a/jax/_src/pallas/mosaic/primitives.py
+++ b/jax/_src/pallas/mosaic/primitives.py
@@ -16,19 +16,20 @@
 from __future__ import annotations
 
 import contextlib
-import dataclasses
 from typing import Callable
 
 from jax._src import api_util
 from jax._src import core as jax_core
 from jax._src import effects
 from jax._src import linear_util as lu
-from jax._src import state
 from jax._src import tree_util
+from jax._src import util
 from jax._src.interpreters import mlir
 from jax._src.interpreters import partial_eval as pe
 import jax.numpy as jnp
 
+map, unsafe_map = util.safe_map, map
+zip, unsafe_zip = util.safe_zip, zip
 
 repeat_p = jax_core.Primitive('repeat')
 
@@ -74,27 +75,6 @@ def trace(message: str, level: int = 10):
   trace_stop_p.bind()
 
 
-class Type:
-
-  def get_aval(self) -> jax_core.AbstractValue:
-    raise NotImplementedError()
-
-  def get_block_shape(self) -> tuple[int, ...] | None:
-    raise NotImplementedError()
-
-
-@dataclasses.dataclass(frozen=True)
-class VMEM(Type):
-  shape: tuple[int, ...]
-  dtype: jnp.dtype
-
-  def get_aval(self) -> jax_core.AbstractValue:
-    return state.AbstractRef(jax_core.ShapedArray(self.shape, self.dtype))
-
-  def get_block_shape(self) -> tuple[int, ...] | None:
-    return self.shape
-
-
 run_scoped_p = jax_core.Primitive('run_scoped')
 run_scoped_p.multiple_results = True
 
@@ -102,13 +82,13 @@ run_scoped_p.multiple_results = True
 def run_scoped(f: Callable[..., None], *types, **kw_types) -> None:
   flat_types, in_tree = tree_util.tree_flatten((types, kw_types))
   flat_fun, _ = api_util.flatten_fun(lu.wrap_init(f), in_tree)
-  avals = [type.get_aval() for type in flat_types]
+  avals = map(lambda t: t.get_aval(), flat_types)
   jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(flat_fun, avals)
-  run_scoped_p.bind(*consts, jaxpr=jaxpr, types=tuple(flat_types))
+  run_scoped_p.bind(*consts, jaxpr=jaxpr)
 
 
 @run_scoped_p.def_effectful_abstract_eval
-def _run_scoped_abstract_eval(*args, jaxpr, types):
+def _run_scoped_abstract_eval(*args, jaxpr):
   # jaxpr will have effects for its inputs (Refs that are allocated) and for
   # constvars (closed over Refs). The effects for the allocated Refs are local
   # to the jaxpr and shouldn't propagate out.
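The primitives.py hunk above removes the ad-hoc Type/VMEM classes: scratch
requests are now ordinary MemoryRef values built by calling a TPUMemorySpace,
and run_scoped recovers shapes from the jaxpr's input avals. A minimal sketch
of a kernel written against the new surface (illustrative only, not part of
the patch; it assumes run_scoped is re-exported from
jax.experimental.pallas.tpu like the memory-space names, and it needs a TPU
backend to actually lower):

import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl
from jax.experimental.pallas import tpu as pltpu

def double_kernel(x_ref, o_ref):
  def body(scratch_ref):
    # scratch_ref is a VMEM-backed Ref, alive only for the scope of `body`.
    scratch_ref[...] = x_ref[...] * 2.0
    o_ref[...] = scratch_ref[...]
  # pltpu.VMEM((8, 128), jnp.float32) now builds a MemoryRef via
  # TPUMemorySpace.__call__ instead of the deleted VMEM dataclass.
  pltpu.run_scoped(body, pltpu.VMEM((8, 128), jnp.float32))

def double(x):  # x: f32[8, 128]
  return pl.pallas_call(
      double_kernel,
      out_shape=jax.ShapeDtypeStruct((8, 128), jnp.float32),
  )(x)
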
diff --git a/jax/_src/pallas/pallas_call.py b/jax/_src/pallas/pallas_call.py
index 7987b8cc0..3aa32fb07 100644
--- a/jax/_src/pallas/pallas_call.py
+++ b/jax/_src/pallas/pallas_call.py
@@ -50,6 +50,8 @@ BlockSpec = pallas_core.BlockSpec
 GridSpec = pallas_core.GridSpec
 BlockMapping = pallas_core.BlockMapping
 GridMapping = pallas_core.GridMapping
+NoBlockSpec = pallas_core.NoBlockSpec
+no_block_spec = pallas_core.no_block_spec
 
 pallas_call_p = jax_core.Primitive('pallas_call')
 pallas_call_p.multiple_results = True
@@ -224,7 +226,8 @@ def _batch_block_mapping(grid: Tuple[int, ...], aval: jax_core.ShapedArray,
   new_block_shape = tuple_insert(shape, dim, pallas_core.mapped)
   jaxpr = jax_core.ClosedJaxpr(block_mapping_jaxpr, consts)
   if block_mapping is None:
-    return BlockMapping(block_shape=new_block_shape, index_map_jaxpr=jaxpr)
+    return BlockMapping(block_shape=new_block_shape, index_map_jaxpr=jaxpr,
+                        memory_space=None)
   return block_mapping.replace(block_shape=new_block_shape,
                                index_map_jaxpr=jaxpr)
 
@@ -324,40 +327,44 @@ def _hoist_consts_to_refs(jaxpr: jax_core.Jaxpr) -> jax_core.Jaxpr:
   return hoisted_jaxpr
 
 
 @weakref_lru_cache
-def _initial_style_open_jaxpr(fun: Callable, in_tree, in_avals,
-                              primitive_name: str | None = None):
+def _trace_to_jaxpr(fun: Callable, grid_spec, flat_in_avals,
+                    flat_out_avals, in_tree, out_tree):
+  avals, grid_mapping = grid_spec.get_grid_mapping(flat_in_avals, in_tree,
+                                                   flat_out_avals, out_tree)
+  jaxpr_flat_avals, jaxpr_in_tree = tree_util.tree_flatten(avals)
   wrapped_fun, out_tree_thunk = api_util.flatten_fun_nokwargs(
-      lu.wrap_init(fun), in_tree)
-  debug = pe.debug_info(fun, in_tree, out_tree_thunk, False,
-                        primitive_name or "<unknown>")
-  jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(wrapped_fun, in_avals, debug)
+      lu.wrap_init(fun), jaxpr_in_tree)
+  debug = pe.debug_info(fun, jaxpr_in_tree, out_tree_thunk, False, "pallas_call")
+  jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(wrapped_fun, jaxpr_flat_avals,
+                                               debug)
   jaxpr = _hoist_consts_to_refs(jaxpr)
-  return jaxpr, consts, out_tree_thunk()
+  return grid_mapping, jaxpr, consts, out_tree_thunk()
 
 
 def _extract_function_name(f: Callable, name: str | None) -> str:
   if name is None:
     name = f.__name__ if hasattr(f, "__name__") and f.__name__ else "func"
   return name
 
+
 def pallas_call(
-    f: Callable[..., None], out_shape: Any, *,
+    f: Callable[..., None],
+    out_shape: Any,
+    *,
     grid_spec: GridSpec | None = None,
     debug: bool = False,
     grid: Grid | None = None,
-    in_specs: Sequence[BlockSpec | None] | None = None,
-    out_specs: BlockSpec | Sequence[BlockSpec | None] | None = None,
+    in_specs: Sequence[BlockSpec | NoBlockSpec] | NoBlockSpec = no_block_spec,
+    out_specs: BlockSpec | NoBlockSpec
+    | Sequence[BlockSpec | NoBlockSpec] = no_block_spec,
     input_output_aliases: Dict[int, int] = {},
     interpret: bool = False,
     name: str | None = None,
-    **compiler_params: Any):
+    **compiler_params: Any,
+):
+  name = _extract_function_name(f, name)
   if grid_spec is None:
     grid_spec = GridSpec(grid, in_specs, out_specs)
-  name = _extract_function_name(f, name)
-  singleton = False
-  if not isinstance(out_shape, (tuple, list)):
-    out_shape = (out_shape,)
-    singleton = True
-  if not isinstance(out_shape, tuple):
+  if isinstance(out_shape, list):
     out_shape = tuple(out_shape)
   flat_out_shapes, out_tree = tree_util.tree_flatten(out_shape)
   flat_out_shapes = [jax.ShapeDtypeStruct(x.shape, x.dtype)
@@ -365,14 +372,13 @@ def pallas_call(
   @jax.jit
   def wrapped(*args):
     flat_args, in_tree = tree_util.tree_flatten(args)
-    flat_avals = [jax_core.raise_to_shaped(jax_core.get_aval(a))
-                  for a in flat_args]
-    avals, grid_mapping = grid_spec.get_grid_mapping(flat_avals, in_tree,
-                                                     flat_out_shapes, out_tree)
-    jaxpr_flat_avals, jaxpr_in_tree = tree_util.tree_flatten(avals)
-    jaxpr, consts, _ = _initial_style_open_jaxpr(f, jaxpr_in_tree,
-                                                 tuple(jaxpr_flat_avals),
-                                                 primitive_name="pallas_call")
+    flat_in_avals = tuple(jax_core.raise_to_shaped(jax_core.get_aval(a))
+                          for a in flat_args)
+    flat_out_avals = tuple(jax_core.ShapedArray(v.shape, v.dtype)
+                           for v in flat_out_shapes)
+    grid_mapping, jaxpr, consts, _ = _trace_to_jaxpr(
+        f, grid_spec, flat_in_avals, flat_out_avals, in_tree,
+        out_tree)
     which_linear = (False,) * len(flat_args)
     out_flat = pallas_call_p.bind(
         *consts, *flat_args, jaxpr=jaxpr, name=name, which_linear=which_linear,
@@ -384,7 +390,5 @@ def pallas_call(
         input_output_aliases=tuple(input_output_aliases.items()),
         **compiler_params)
     out = tree_util.tree_unflatten(out_tree, out_flat)
-    if singleton:
-      return out[0]
     return out
   return wrapped
diff --git a/jax/experimental/pallas/__init__.py b/jax/experimental/pallas/__init__.py
index 45e1964e1..afcb62ad4 100644
--- a/jax/experimental/pallas/__init__.py
+++ b/jax/experimental/pallas/__init__.py
@@ -16,6 +16,7 @@
 from jax._src import pallas
 from jax._src.pallas.core import BlockSpec
+from jax._src.pallas.core import no_block_spec
 from jax._src.pallas.indexing import ds
 from jax._src.pallas.indexing import dslice
 from jax._src.pallas.indexing import broadcast_to
diff --git a/jax/experimental/pallas/ops/tpu/flash_attention.py b/jax/experimental/pallas/ops/tpu/flash_attention.py
index fdace6dc7..9d94da20a 100644
--- a/jax/experimental/pallas/ops/tpu/flash_attention.py
+++ b/jax/experimental/pallas/ops/tpu/flash_attention.py
@@ -695,6 +695,7 @@ def _flash_attention_bwd_dkv(
   def qo_index_map(batch_index, _, head_index, q_seq_index):
     return (batch_index, head_index, q_seq_index, 0)
   qo_spec = pl.BlockSpec(qo_index_map, (1, 1, block_q_major, head_dim))
+  assert qo_spec.block_shape is not None
   assert q.ndim == len(qo_spec.block_shape)
   do_spec = qo_spec
   assert do.ndim == len(qo_spec.block_shape)
@@ -703,16 +704,19 @@ def _flash_attention_bwd_dkv(
     del q_seq_index
     return (batch_index, head_index, kv_seq_index, 0)
   kv_spec = pl.BlockSpec(kv_index_map, (1, 1, block_k_major, head_dim))
+  assert kv_spec.block_shape is not None
   assert k.ndim == len(kv_spec.block_shape)
   assert v.ndim == len(kv_spec.block_shape)
 
   def lm_index_map(batch_index, _, head_index, q_seq_index):
     return (batch_index, head_index, q_seq_index, 0)
   lm_spec = pl.BlockSpec(lm_index_map, (1, 1, block_q_major, MIN_BLOCK_SIZE))
+  assert lm_spec.block_shape is not None
   assert l.ndim == len(lm_spec.block_shape)
   assert m.ndim == len(lm_spec.block_shape)
 
   di_spec = pl.BlockSpec(qo_index_map, (1, 1, block_q_major, MIN_BLOCK_SIZE))
+  assert di_spec.block_shape is not None
   assert di.ndim == len(di_spec.block_shape)
 
   in_specs = [
@@ -882,6 +886,7 @@ def _flash_attention_bwd_dq(
     return (batch_index, head_index, kv_seq_index, 0)
 
   kv_spec = pl.BlockSpec(kv_index_map, (1, 1, block_k_major, head_dim))
+  assert kv_spec.block_shape is not None
   assert k.ndim == len(kv_spec.block_shape)
   assert v.ndim == len(kv_spec.block_shape)
 
@@ -890,10 +895,12 @@ def _flash_attention_bwd_dq(
     return (batch_index, head_index, q_seq_index, 0)
 
   lm_spec = pl.BlockSpec(lm_index_map, (1, 1, block_q_major, MIN_BLOCK_SIZE))
+  assert lm_spec.block_shape is not None
   assert l.ndim == len(lm_spec.block_shape)
   assert m.ndim == len(lm_spec.block_shape)
 
   di_spec = pl.BlockSpec(qo_index_map, (1, 1, block_q_major, MIN_BLOCK_SIZE))
+  assert di_spec.block_shape is not None
   assert di.ndim == len(di_spec.block_shape)
 
   in_specs = [
diff --git a/jax/experimental/pallas/tpu.py b/jax/experimental/pallas/tpu.py
index dcf365f6b..d008f4d29 100644
--- a/jax/experimental/pallas/tpu.py
+++ b/jax/experimental/pallas/tpu.py
@@ -15,9 +15,10 @@
 """Contains Mosaic specific Pallas functions."""
 from jax._src.pallas.mosaic import PrefetchScalarGridSpec
 from jax._src.pallas.mosaic import TPUMemorySpace
+from jax._src.pallas.mosaic import ANY
 from jax._src.pallas.mosaic import CMEM
-from jax._src.pallas.mosaic import VMEM
 from jax._src.pallas.mosaic import SMEM
+from jax._src.pallas.mosaic import VMEM
 from jax._src.pallas.mosaic import encode_kernel_regeneration_metadata
 from jax._src.pallas.mosaic import extract_kernel_regeneration_metadata
 from jax._src.pallas.mosaic import repeat
diff --git a/tests/pallas/pallas_test.py b/tests/pallas/pallas_test.py
index eed2c6362..9de56d9a8 100644
--- a/tests/pallas/pallas_test.py
+++ b/tests/pallas/pallas_test.py
@@ -29,7 +29,7 @@ from jax._src import linear_util as lu
 from jax._src import test_util as jtu
 from jax._src import state
 from jax._src.lax.control_flow.for_loop import for_loop
-from jax._src.pallas.pallas_call import _initial_style_open_jaxpr
+from jax._src.pallas.pallas_call import _trace_to_jaxpr
 from jax.config import config
 from jax.interpreters import partial_eval as pe
 import jax.numpy as jnp
@@ -134,7 +134,7 @@ class PallasTest(parameterized.TestCase):
     super().setUp()
     if compile_jaxpr:
       compile_jaxpr.cache_clear()
-    _initial_style_open_jaxpr.cache_clear()
+    _trace_to_jaxpr.cache_clear()
 
   def pallas_call(self, *args, **kwargs):
     return pl.pallas_call(*args, **kwargs, interpret=self.INTERPRET)
@@ -639,7 +639,8 @@ class PallasCallTest(PallasTest):
     def f(x):
       return add_one(add_one(x))
 
-    self.assertEqual(f(0.), 2.)
+    x = jnp.array(0., dtype=jnp.float32)
+    self.assertEqual(f(x), 2.)
     self.assertEqual(trace_count, 1)
 
   def test_pallas_compilation_cache(self):
@@ -657,7 +658,8 @@ class PallasCallTest(PallasTest):
     def f(x):
       return add_one(add_one(x))
 
-    self.assertEqual(f(0.), 2.)
+    x = jnp.array(0., dtype=jnp.float32)
+    self.assertEqual(f(x), 2.)
     num_misses = compile_jaxpr.cache_info().misses
     self.assertEqual(num_misses, 1)
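
Beyond the diff itself, the resulting pallas_call surface can be exercised as
follows. A minimal sketch (not from the patch) of the new defaults: in_specs
and out_specs fall back to the no_block_spec sentinel rather than None, and a
single out_shape no longer needs the removed singleton-wrapping logic. This
runs in interpret mode, so no TPU is required:

import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl

def add_kernel(x_ref, y_ref, o_ref):
  o_ref[...] = x_ref[...] + y_ref[...]

@jax.jit
def add(x, y):
  # No grid and no specs: each argument is passed whole, with in_specs and
  # out_specs silently defaulting to pl.no_block_spec.
  return pl.pallas_call(
      add_kernel,
      out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
      interpret=True,
  )(x, y)

x = jnp.arange(8, dtype=jnp.float32)
print(add(x, x))  # [ 0.  2.  4. ...]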
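
The BlockSpec(memory_space=...) parameter is what replaces the deleted
mosaic_params={"memory_spaces": ...} plumbing. An untested, illustrative
sketch of pinning whole-array operands to TPU memory spaces (the SMEM/VMEM
names are the TPUMemorySpace aliases this patch re-exports; lowering requires
a TPU backend):

import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl
from jax.experimental.pallas import tpu as pltpu

def scale_kernel(scale_ref, x_ref, o_ref):
  # scale_ref lives in SMEM, so scale_ref[0] is an ordinary scalar read.
  o_ref[...] = x_ref[...] * scale_ref[0]

def scale(s, x):  # s: f32[1], x: f32[8, 128]
  return pl.pallas_call(
      scale_kernel,
      out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
      in_specs=[
          # index_map/block_shape may now be omitted: the whole array is
          # passed and only the memory space is constrained.
          pl.BlockSpec(memory_space=pltpu.SMEM),
          pl.BlockSpec(memory_space=pltpu.VMEM),
      ],
  )(s, x)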
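
Finally, PrefetchScalarGridSpec now flattens pytrees of scalar-prefetch
operands (num_index_operands becomes the flattened leaf count). A hedged,
TPU-only sketch of the familiar scalar-prefetch pattern, where index maps
receive the grid indices followed by the scalar Refs; the names and shapes
here are hypothetical:

import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl
from jax.experimental.pallas import tpu as pltpu

def gather_kernel(idx_ref, x_ref, o_ref):
  # idx_ref is a scalar-prefetch Ref in SMEM; the index map below already
  # selected which block of x this program instance sees.
  o_ref[...] = x_ref[...]

def gather_blocks(idxs, x):  # idxs: i32[8]; x must have >= (max(idxs)+1)*128 rows
  grid_spec = pltpu.PrefetchScalarGridSpec(
      num_scalar_prefetch=1,
      grid=(8,),
      # Index maps take (*grid_indices, *scalar_refs).
      in_specs=[pl.BlockSpec(lambda i, idx_ref: (idx_ref[i], 0), (128, 128))],
      out_specs=pl.BlockSpec(lambda i, idx_ref: (i, 0), (128, 128)),
  )
  return pl.pallas_call(
      gather_kernel,
      grid_spec=grid_spec,
      out_shape=jax.ShapeDtypeStruct((8 * 128, 128), x.dtype),
  )(idxs, x)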