From 499ceeeb2c8e6384c8221fc77d3b2b23af59cff0 Mon Sep 17 00:00:00 2001
From: Sharad Vikram
Date: Mon, 22 Jul 2024 23:24:31 -0700
Subject: [PATCH] Add support for named grids in pallas_call.

PiperOrigin-RevId: 655036727
---
 jax/_src/pallas/core.py            |  53 ++++++++++++---
 jax/_src/pallas/mosaic/core.py     |   8 ++-
 jax/_src/pallas/mosaic/lowering.py |  41 +++++++++--
 jax/_src/pallas/pallas_call.py     |   9 ++-
 jax/_src/pallas/triton/lowering.py |  15 ++++-
 tests/pallas/pallas_test.py        | 105 +++++++++++++++++++++++++++++
 6 files changed, 209 insertions(+), 22 deletions(-)

diff --git a/jax/_src/pallas/core.py b/jax/_src/pallas/core.py
index ae5f267d2..53164e6e4 100644
--- a/jax/_src/pallas/core.py
+++ b/jax/_src/pallas/core.py
@@ -22,7 +22,7 @@ import dataclasses
 import enum
 import functools
 import threading
-from typing import Any, Union
+from typing import Any, Hashable, Union
 import warnings
 
 import jax
@@ -44,9 +44,14 @@ dynamic_grid_dim = DynamicGridDim()
 
 
 partial = functools.partial
-Grid = tuple[Union[int, jax_core.Array], ...]
+GridElement = int | jax_core.Array
+GridName = Hashable
+GridNames = tuple[Hashable, ...] | None
+NamedGrid = tuple[tuple[GridName, int], ...]
+TupleGrid = tuple[GridElement, ...]
+Grid = Union[NamedGrid, TupleGrid]
 StaticGrid = tuple[int, ...]
-GridMappingGrid = tuple[Union[int, DynamicGridDim], ...]
+GridMappingGrid = tuple[int | DynamicGridDim, ...]
 split_list = util.split_list
 
 map, unsafe_map = util.safe_map, map
@@ -280,6 +285,7 @@ def tracing_grid_env(grid: GridMappingGrid, mapped_dims: tuple[int, ...]):
 @dataclasses.dataclass(frozen=True)
 class GridMapping:
   grid: GridMappingGrid
+  grid_names: tuple[Hashable, ...] | None
   block_mappings: tuple[BlockMapping | None, ...]
   mapped_dims: tuple[int, ...] = ()
   num_index_operands: int = 0
@@ -301,16 +307,39 @@ class GridMapping:
 
   @contextlib.contextmanager
   def trace_env(self):
-    with tracing_grid_env(self.grid, self.mapped_dims):
+    if self.grid_names is None:
+      axis_env_ctx = contextlib.nullcontext()
+    else:
+      axis_env_ctx = jax_core.extend_axis_env_nd(
+          zip(self.grid_names, self.grid)
+      )
+    with tracing_grid_env(self.grid, self.mapped_dims), axis_env_ctx:
       yield
 
 
+def _is_valid_grid_dim(dim: int | jax.Array) -> bool:
+  if isinstance(dim, jax.Array):
+    return True
+  return jax_core.is_dim(dim)
-def _preprocess_grid(grid: Grid | int | None) -> Grid:
+def _preprocess_grid(grid: Grid | int | None) -> tuple[TupleGrid, GridNames]:
   if grid is None:
-    return ()
+    return (), None
   if isinstance(grid, int):
-    return (grid,)
-  return grid
+    return (grid,), None
+  # Handle empty grid
+  if not grid:
+    return grid, None  # type: ignore
+  # Check if we have a named grid
+  if isinstance(grid[0], tuple):
+    grid_names, grid = util.unzip2(grid)  # type: ignore
+  else:
+    grid_names = None
+  # TODO(b/353730556): allow NumPy scalars in grids
+  if not all(_is_valid_grid_dim(g) for g in grid):  # type: ignore
+    raise ValueError(
+        f"Grid must be a tuple of integers or jax.Array, got {grid}"
+    )
+  return grid, grid_names  # type: ignore
 
 
 def _convert_block_spec_to_block_mapping(
@@ -409,7 +438,8 @@ def _get_ref_avals(in_avals: Sequence[jax_core.ShapedArray],
 
 @dataclasses.dataclass(init=False, unsafe_hash=True)
 class GridSpec:
-  grid: Grid
+  grid: TupleGrid
+  grid_names: tuple[Hashable, ...] | None
   in_specs: tuple[BlockSpec | NoBlockSpec, ...]
   out_specs: tuple[BlockSpec | NoBlockSpec, ...]
   in_specs_tree: Any
@@ -429,7 +459,7 @@ class GridSpec:
     if isinstance(out_specs, list):
       out_specs = tuple(out_specs)
 
-    self.grid = _preprocess_grid(grid)
+    self.grid, self.grid_names = _preprocess_grid(grid)
     if in_specs is not no_block_spec:
       flat_in_specs, self.in_specs_tree = tree_util.tree_flatten(in_specs)
       self.in_specs = tuple(flat_in_specs)
@@ -504,7 +534,8 @@ class GridSpec:
         out_ref_avals,
     )
     grid_mapping = GridMapping(
-        grid_mapping_grid, (*in_block_mappings, *out_block_mappings)  # type: ignore
+        grid_mapping_grid, self.grid_names,  # type: ignore
+        (*in_block_mappings, *out_block_mappings)
     )
     jaxpr_in_avals = tree_util.tree_unflatten(in_tree, in_ref_avals)
     jaxpr_out_avals = tree_util.tree_unflatten(out_tree, out_ref_avals)
diff --git a/jax/_src/pallas/mosaic/core.py b/jax/_src/pallas/mosaic/core.py
index 20a7f8207..68a37b777 100644
--- a/jax/_src/pallas/mosaic/core.py
+++ b/jax/_src/pallas/mosaic/core.py
@@ -19,7 +19,7 @@ from collections.abc import Sequence
 import dataclasses
 import enum
 import functools
-from typing import Any
+from typing import Any, Hashable
 
 from jax._src import core as jax_core
 from jax._src import dtypes
@@ -33,6 +33,7 @@ zip, unsafe_zip = util.safe_zip, zip
 partial = functools.partial
 
 Grid = pallas_core.Grid
+TupleGrid = pallas_core.TupleGrid
 BlockSpec = pallas_core.BlockSpec
 BlockSpecTree = pallas_core.BlockSpecTree
 GridMapping = pallas_core.GridMapping
@@ -149,7 +150,8 @@ def _make_aval(obj: object) -> jax_core.AbstractValue:
 
 @dataclasses.dataclass(init=False, unsafe_hash=True)
 class PrefetchScalarGridSpec(pallas_core.GridSpec):
-  grid: Grid
+  grid: TupleGrid
+  grid_names: tuple[Hashable, ...] | None
   num_scalar_prefetch: int
   in_specs: tuple[BlockSpec | NoBlockSpec, ...]
   out_specs: tuple[BlockSpec | NoBlockSpec, ...]
@@ -227,7 +229,7 @@ class PrefetchScalarGridSpec(pallas_core.GridSpec):
         out_ref_avals,
     )
     grid_mapping = GridMapping(
-        grid=grid_mapping_grid,  # type: ignore
+        grid=grid_mapping_grid, grid_names=self.grid_names,  # type: ignore
        block_mappings=(*in_block_mappings, *out_block_mappings),
         mapped_dims=(),
         num_index_operands=num_flat_scalar_prefetch,
diff --git a/jax/_src/pallas/mosaic/lowering.py b/jax/_src/pallas/mosaic/lowering.py
index f4aa9a70f..084534624 100644
--- a/jax/_src/pallas/mosaic/lowering.py
+++ b/jax/_src/pallas/mosaic/lowering.py
@@ -19,7 +19,7 @@ from collections.abc import Callable, Sequence
 import dataclasses
 import functools
 import string
-from typing import Any
+from typing import Any, Hashable
 
 import jax
 from jax import lax
@@ -99,6 +99,7 @@ class MeshContext:
 class LoweringContext:
   ir_context: ir.Context
   grid_rank: int  # Includes both user and vmap axes.
+  grid_names: tuple[Hashable, ...] | None
   mapped_dims: tuple[int, ...]  # Indices of vmapped grid dimensions.
   user_grid_indices: Sequence[ir.Value] | None
   block_shapes: list[tuple[int | pl_core.Mapped, ...]]
@@ -252,6 +253,7 @@ def _get_arg_type(
 @dataclasses.dataclass(init=False)
 class MosaicGridMapping:
   grid: tuple[int, ...] | None
+  grid_names: tuple[Hashable, ...] | None
   jaxpr: jax_core.Jaxpr
   block_mappings: tuple[pl_core.BlockMapping | None, ...]
   mapped_dims: tuple[int, ...]
@@ -269,6 +271,7 @@ class MosaicGridMapping:
                dimension_semantics: tuple[str, ...] | None,
                mesh: mesh_lib.Mesh | None):
     self.grid = grid_mapping.grid
+    self.grid_names = grid_mapping.grid_names
     self.jaxpr = jaxpr
     self.block_mappings = grid_mapping.block_mappings
     self.mapped_dims = grid_mapping.mapped_dims
@@ -341,6 +344,12 @@ class MosaicGridMapping:
             "Cannot use communication in pallas_call without shard_map."
         )
       axis_names = mesh.axis_names
+      if self.grid_names is not None:
+        if any(a in self.grid_names for a in axis_names):
+          raise ValueError(
+              "Cannot shadow mesh axis names with grid names. mesh axis"
+              f" names: {mesh.axis_names}, grid names: {self.grid_names}"
+          )
       # We need mesh <-> logical translation tables. Since the logical IDs are
       # just linearized versions of the mesh IDs, we create those tables.
       mesh_strides = pallas_utils.strides_from_shape(tuple(
@@ -356,7 +365,19 @@ class MosaicGridMapping:
 
   @functools.cached_property
   def has_communication(self) -> bool:
-    return bool(jax_core.used_axis_names_jaxpr(self.jaxpr))
+    nonlocal_axis_names = set()
+    def _get_nonlocal_axis_names(jaxpr: jax_core.Jaxpr):
+      return {
+          e.name
+          for e in jaxpr.effects
+          if isinstance(e, jax_core.NamedAxisEffect)
+          and (not self.grid_names or e.name not in self.grid_names)
+      }
+    nonlocal_axis_names.update(_get_nonlocal_axis_names(self.jaxpr))
+    for bm in self.block_mappings:
+      if bm is not None:
+        nonlocal_axis_names.update(_get_nonlocal_axis_names(bm.index_map_jaxpr))
+    return bool(nonlocal_axis_names)
 
   def get_extra_args(self) -> tuple[Any, ...]:
     return ()
@@ -531,6 +552,7 @@ def lower_jaxpr_to_transform_func(
   lowering_context = LoweringContext(
       ctx,
       len(mosaic_grid_mapping.grid),
+      mosaic_grid_mapping.grid_names,
       mosaic_grid_mapping.mapped_dims,
       None,
       arg_block_shapes,
@@ -599,6 +621,7 @@ def lower_jaxpr_to_func(
   lowering_context = LoweringContext(
       ctx,
       len(mosaic_grid_mapping.grid),
+      mosaic_grid_mapping.grid_names,
       mosaic_grid_mapping.mapped_dims,
       jaxpr_indices,
       arg_block_shapes,
@@ -2627,10 +2650,18 @@ def _device_id_lowering_rule(ctx: LoweringRuleContext):
   return tpu.DeviceIdOp().result
 lowering_rules[tpu_primitives.device_id_p] = _device_id_lowering_rule
 
-def _axis_index_rule(ctx: LoweringRuleContext, *, axis_name: str):
+def _axis_index_rule(ctx: LoweringRuleContext, *, axis_name: Hashable):
+  grid_names = ctx.lowering_context.grid_names
+  if grid_names and axis_name in grid_names:
+    # We are querying a named axis corresponding to a grid dimension.
+    return _program_id_lowering_rule(ctx, axis=grid_names.index(axis_name))
+  # We are querying a named axis corresponding to a mesh dimension.
   device_id = tpu.DeviceIdOp().result
-  mesh_shape = ctx.lowering_context.mesh_context.mesh_shape
-  axis_names = ctx.lowering_context.mesh_context.axis_names
+  mesh_context = ctx.lowering_context.mesh_context
+  if mesh_context is None:
+    raise ValueError("Mesh context is not set.")
+  mesh_shape = mesh_context.mesh_shape
+  axis_names = mesh_context.axis_names
   axis_index = axis_names.index(axis_name)
   axis_size = ir_constant(mesh_shape[axis_index])
   minor_divisor = ir_constant(
diff --git a/jax/_src/pallas/pallas_call.py b/jax/_src/pallas/pallas_call.py
index d185133c1..e725229cb 100644
--- a/jax/_src/pallas/pallas_call.py
+++ b/jax/_src/pallas/pallas_call.py
@@ -908,7 +908,7 @@ def _trace_to_jaxpr(fun: Callable, grid_spec: GridSpec,
   wrapped_fun, out_tree_thunk = api_util.flatten_fun_nokwargs(
       lu.wrap_init(fun), jaxpr_in_tree)
   debug = pe.debug_info(fun, jaxpr_in_tree, out_tree_thunk, False, "pallas_call")
-  with pallas_core.tracing_grid_env(grid_mapping.grid, ()):
+  with grid_mapping.trace_env():
     jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(wrapped_fun,
                                                      jaxpr_flat_avals, debug)
   if consts:
@@ -1032,6 +1032,13 @@ jax_core.custom_str_eqn_compact_rules[pallas_call_p] = (
     _pallas_custom_str_eqn_compact
 )
 
+def _pallas_call_typecheck_rule(*in_avals, grid_mapping, **params):
+  with grid_mapping.trace_env():
+    return pallas_call_p.abstract_eval(
+        *in_avals, grid_mapping=grid_mapping, **params
+    )
+jax_core.custom_typechecks[pallas_call_p] = _pallas_call_typecheck_rule
+
 
 def pallas_call(
     f: Callable[..., None],
diff --git a/jax/_src/pallas/triton/lowering.py b/jax/_src/pallas/triton/lowering.py
index 5653dcf6a..bd9cf711b 100644
--- a/jax/_src/pallas/triton/lowering.py
+++ b/jax/_src/pallas/triton/lowering.py
@@ -21,7 +21,7 @@ import dataclasses
 import functools
 import math
 import operator
-from typing import Any, TypeVar
+from typing import Any, Hashable, TypeVar
 
 import jax
 from jax import lax
@@ -256,7 +256,10 @@ def lower_jaxpr_to_triton_module(
     name: str,
     platform: str
 ) -> LoweringResult:
-  jaxpr, _ = pe.dce_jaxpr(jaxpr, [True] * len(jaxpr.outvars), instantiate=True)
+  with grid_mapping.trace_env():
+    jaxpr, _ = pe.dce_jaxpr(
+        jaxpr, [True] * len(jaxpr.outvars), instantiate=True
+    )
   with _new_ir_context(), ir.Location.unknown():
     module = ir.Module.create()
     param_types = [
@@ -2240,6 +2243,14 @@ def _remat_lowering_rule(ctx: LoweringRuleContext, *args, jaxpr, **_):
 triton_lowering_rules[ad_util.stop_gradient_p] = lambda _, x: x
 
 
+@register_lowering(lax.axis_index_p)
+def _axis_index_rule(ctx: LoweringRuleContext, *, axis_name: Hashable):
+  grid_names = ctx.context.grid_mapping.grid_names
+  if grid_names and axis_name in grid_names:
+    # We are querying a named axis corresponding to a grid dimension.
+    return _program_id_lowering_rule(ctx, axis=grid_names.index(axis_name))
+  raise LookupError(f"Axis name {axis_name} not found in grid.")
+
 def _is_read_only(ref_effects) -> bool:
   if len(ref_effects) == 0:
     return True
diff --git a/tests/pallas/pallas_test.py b/tests/pallas/pallas_test.py
index 10f3c9ee7..00af80ae3 100644
--- a/tests/pallas/pallas_test.py
+++ b/tests/pallas/pallas_test.py
@@ -1927,5 +1927,110 @@ class PallasCheckifyInterpreterTest(PallasBaseTest):
       err.throw()
 
 
+class PallasCallNamedGridTest(PallasBaseTest):
+
+  def test_named_grid(self):
+
+    def kernel(x_ref, y_ref):
+      y_ref[...] = x_ref[...]
+
+    x = jnp.arange(2 * 8 * 128, dtype=np.int32).reshape((2, 8, 128))
+    y = self.pallas_call(
+        kernel,
+        out_shape=x,
+        in_specs=[
+            pl.BlockSpec((None, 8, 128), lambda i: (i, 0, 0)),
+        ],
+        out_specs=pl.BlockSpec((None, 8, 128), lambda i: (i, 0, 0)),
+        grid=(("i", 2),)
+    )(x)
+    np.testing.assert_array_equal(y, x)
+
+  def test_named_grid_reordered_names(self):
+
+    def kernel(x_ref, y_ref):
+      y_ref[...] = x_ref[...]
+
+    x = jnp.arange(4 * 16 * 128, dtype=np.int32).reshape((4, 16, 128))
+    y = self.pallas_call(
+        kernel,
+        out_shape=x,
+        in_specs=[
+            pl.BlockSpec((None, 8, 128), lambda i, j: (i, j, 0)),
+        ],
+        out_specs=pl.BlockSpec((None, 8, 128), lambda i, j: (i, j, 0)),
+        grid=(("j", 4), ("i", 2))
+    )(x)
+    np.testing.assert_array_equal(y, x)
+
+  def test_can_query_named_grid_size_in_kernel_via_psum(self):
+
+    def kernel(x_ref, y_ref):
+      self.assertEqual(lax.psum(1, "i"), 2)
+      self.assertEqual(lax.psum(1, "j"), 4)
+      y_ref[...] = x_ref[...]
+
+    x = jnp.arange(4 * 16 * 128, dtype=np.int32).reshape((4, 16, 128))
+    y = self.pallas_call(
+        kernel,
+        out_shape=x,
+        in_specs=[
+            pl.BlockSpec((None, 8, 128), lambda i, j: (i, j, 0)),
+        ],
+        out_specs=pl.BlockSpec((None, 8, 128), lambda i, j: (i, j, 0)),
+        grid=(("j", 4), ("i", 2))
+    )(x)
+    np.testing.assert_array_equal(y, x)
+
+  def test_can_query_named_dynamic_grid_size_in_kernel_via_psum(self):
+    # TODO(): Enable dynamic grid size via axis_size primitive.
+    self.skipTest("Not supported.")
+
+    def kernel(x_ref, y_ref):
+      self.assertEqual(lax.psum(1, "i"), 2)
+      self.assertEqual(lax.psum(1, "j"), 4)
+      y_ref[...] = x_ref[...]
+
+    x = jnp.arange(4 * 8 * 128, dtype=np.int32).reshape((4, 8, 128))
+    @jax.jit
+    def foo(n):
+      return self.pallas_call(
+          kernel,
+          out_shape=x,
+          in_specs=[
+              pl.BlockSpec((None, 8, 128), lambda i: (i, 0, 0)),
+          ],
+          out_specs=pl.BlockSpec((None, 8, 128), lambda i: (i, 0, 0)),
+          grid=(("i", n),)
+      )(x)
+    y = foo(4)
+    np.testing.assert_array_equal(y, x)
+
+  def test_can_query_named_grid_program_id_in_kernel_via_axis_index(self):
+    if self.INTERPRET:
+      self.skipTest("Not supported in interpret mode.")
+    def kernel(x_ref, y_ref):
+      i_index = lax.axis_index("i")
+      y_ref[...] = x_ref[...] + i_index
+
+    x = jnp.arange(4 * 8 * 128, dtype=np.int32).reshape((4, 8, 128))
+    y = self.pallas_call(
+        kernel,
+        out_shape=x,
+        in_specs=[
+            pl.BlockSpec((None, 8, 128), lambda i: (i, 0, 0)),
+        ],
+        out_specs=pl.BlockSpec((None, 8, 128), lambda i: (i, 0, 0)),
+        grid=(("i", 4),),
+    )(x)
+    np.testing.assert_array_equal(
+        y, x + jnp.arange(4, dtype=jnp.int32)[:, None, None]
+    )
+
+
+class PallasCallNamedGridInterpretTest(PallasCallNamedGridTest):
+  INTERPRET = True
+
+
 if __name__ == "__main__":
   absltest.main()
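
Usage sketch (illustrative only, mirroring the new tests above): a named grid is passed to pl.pallas_call as a tuple of (name, size) pairs, and the name can then be queried inside the kernel with lax.axis_index / lax.psum. The kernel name and shapes below are hypothetical, the jax.experimental.pallas import path is assumed, and this requires a backend on which Pallas lowers lax.axis_index (the tests above skip interpret mode for that case).

    import jax
    import jax.numpy as jnp
    from jax import lax
    from jax.experimental import pallas as pl

    def add_program_id_kernel(x_ref, y_ref):
      # "i" is a named grid axis: lax.axis_index gives the current program id
      # along it, and lax.psum(1, "i") resolves to its size (4 here).
      i = lax.axis_index("i")
      y_ref[...] = x_ref[...] + i

    x = jnp.arange(4 * 8 * 128, dtype=jnp.int32).reshape((4, 8, 128))
    y = pl.pallas_call(
        add_program_id_kernel,
        out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
        in_specs=[pl.BlockSpec((None, 8, 128), lambda i: (i, 0, 0))],
        out_specs=pl.BlockSpec((None, 8, 128), lambda i: (i, 0, 0)),
        grid=(("i", 4),),  # named grid: a tuple of (name, size) pairs
    )(x)
    # y[k] == x[k] + k for each of the 4 steps along the "i" grid axis.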