Add support for named grids in pallas_call.

PiperOrigin-RevId: 655036727
This commit is contained in:
Sharad Vikram 2024-07-22 23:24:31 -07:00 committed by jax authors
parent a18872aa13
commit 499ceeeb2c
6 changed files with 209 additions and 22 deletions

View File

@ -22,7 +22,7 @@ import dataclasses
import enum
import functools
import threading
from typing import Any, Union
from typing import Any, Hashable, Union
import warnings
import jax
@ -44,9 +44,14 @@ dynamic_grid_dim = DynamicGridDim()
partial = functools.partial
Grid = tuple[Union[int, jax_core.Array], ...]
GridElement = int | jax_core.Array
GridName = Hashable
GridNames = tuple[Hashable, ...] | None
NamedGrid = tuple[tuple[GridName, int], ...]
TupleGrid = tuple[GridElement, ...]
Grid = Union[NamedGrid, TupleGrid]
StaticGrid = tuple[int, ...]
GridMappingGrid = tuple[Union[int, DynamicGridDim], ...]
GridMappingGrid = tuple[int | DynamicGridDim, ...]
split_list = util.split_list
map, unsafe_map = util.safe_map, map
@ -280,6 +285,7 @@ def tracing_grid_env(grid: GridMappingGrid, mapped_dims: tuple[int, ...]):
@dataclasses.dataclass(frozen=True)
class GridMapping:
grid: GridMappingGrid
grid_names: tuple[Hashable, ...] | None
block_mappings: tuple[BlockMapping | None, ...]
mapped_dims: tuple[int, ...] = ()
num_index_operands: int = 0
@ -301,16 +307,39 @@ class GridMapping:
@contextlib.contextmanager
def trace_env(self):
with tracing_grid_env(self.grid, self.mapped_dims):
if self.grid_names is None:
axis_env_ctx = contextlib.nullcontext()
else:
axis_env_ctx = jax_core.extend_axis_env_nd(
zip(self.grid_names, self.grid)
)
with tracing_grid_env(self.grid, self.mapped_dims), axis_env_ctx:
yield
def _is_valid_grid_dim(dim: int | jax.Array) -> bool:
if isinstance(dim, jax.Array):
return True
return jax_core.is_dim(dim)
def _preprocess_grid(grid: Grid | int | None) -> Grid:
def _preprocess_grid(grid: Grid | int | None) -> tuple[TupleGrid, GridNames]:
if grid is None:
return ()
return (), None
if isinstance(grid, int):
return (grid,)
return grid
return (grid,), None
# Handle empty grid
if not grid:
return grid, None # type: ignore
# Check if we have a named grid
if isinstance(grid[0], tuple):
grid_names, grid = util.unzip2(grid) # type: ignore
else:
grid_names = None
# TODO(b/353730556): allow NumPy scalars in grids
if not all(_is_valid_grid_dim(g) for g in grid): # type: ignore
raise ValueError(
f"Grid must be a tuple of integers or jax.Array, got {grid}"
)
return grid, grid_names # type: ignore
def _convert_block_spec_to_block_mapping(
@ -409,7 +438,8 @@ def _get_ref_avals(in_avals: Sequence[jax_core.ShapedArray],
@dataclasses.dataclass(init=False, unsafe_hash=True)
class GridSpec:
grid: Grid
grid: TupleGrid
grid_names: tuple[Hashable, ...] | None
in_specs: tuple[BlockSpec | NoBlockSpec, ...]
out_specs: tuple[BlockSpec | NoBlockSpec, ...]
in_specs_tree: Any
@ -429,7 +459,7 @@ class GridSpec:
if isinstance(out_specs, list):
out_specs = tuple(out_specs)
self.grid = _preprocess_grid(grid)
self.grid, self.grid_names = _preprocess_grid(grid)
if in_specs is not no_block_spec:
flat_in_specs, self.in_specs_tree = tree_util.tree_flatten(in_specs)
self.in_specs = tuple(flat_in_specs)
@ -504,7 +534,8 @@ class GridSpec:
out_ref_avals,
)
grid_mapping = GridMapping(
grid_mapping_grid, (*in_block_mappings, *out_block_mappings) # type: ignore
grid_mapping_grid, self.grid_names, # type: ignore
(*in_block_mappings, *out_block_mappings)
)
jaxpr_in_avals = tree_util.tree_unflatten(in_tree, in_ref_avals)
jaxpr_out_avals = tree_util.tree_unflatten(out_tree, out_ref_avals)

View File

@ -19,7 +19,7 @@ from collections.abc import Sequence
import dataclasses
import enum
import functools
from typing import Any
from typing import Any, Hashable
from jax._src import core as jax_core
from jax._src import dtypes
@ -33,6 +33,7 @@ zip, unsafe_zip = util.safe_zip, zip
partial = functools.partial
Grid = pallas_core.Grid
TupleGrid = pallas_core.TupleGrid
BlockSpec = pallas_core.BlockSpec
BlockSpecTree = pallas_core.BlockSpecTree
GridMapping = pallas_core.GridMapping
@ -149,7 +150,8 @@ def _make_aval(obj: object) -> jax_core.AbstractValue:
@dataclasses.dataclass(init=False, unsafe_hash=True)
class PrefetchScalarGridSpec(pallas_core.GridSpec):
grid: Grid
grid: TupleGrid
grid_names: tuple[Hashable, ...] | None
num_scalar_prefetch: int
in_specs: tuple[BlockSpec | NoBlockSpec, ...]
out_specs: tuple[BlockSpec | NoBlockSpec, ...]
@ -227,7 +229,7 @@ class PrefetchScalarGridSpec(pallas_core.GridSpec):
out_ref_avals,
)
grid_mapping = GridMapping(
grid=grid_mapping_grid, # type: ignore
grid=grid_mapping_grid, grid_names=self.grid_names, # type: ignore
block_mappings=(*in_block_mappings, *out_block_mappings),
mapped_dims=(),
num_index_operands=num_flat_scalar_prefetch,

View File

@ -19,7 +19,7 @@ from collections.abc import Callable, Sequence
import dataclasses
import functools
import string
from typing import Any
from typing import Any, Hashable
import jax
from jax import lax
@ -99,6 +99,7 @@ class MeshContext:
class LoweringContext:
ir_context: ir.Context
grid_rank: int # Includes both user and vmap axes.
grid_names: tuple[Hashable, ...] | None
mapped_dims: tuple[int, ...] # Indices of vmapped grid dimensions.
user_grid_indices: Sequence[ir.Value] | None
block_shapes: list[tuple[int | pl_core.Mapped, ...]]
@ -252,6 +253,7 @@ def _get_arg_type(
@dataclasses.dataclass(init=False)
class MosaicGridMapping:
grid: tuple[int, ...] | None
grid_names: tuple[Hashable, ...] | None
jaxpr: jax_core.Jaxpr
block_mappings: tuple[pl_core.BlockMapping | None, ...]
mapped_dims: tuple[int, ...]
@ -269,6 +271,7 @@ class MosaicGridMapping:
dimension_semantics: tuple[str, ...] | None,
mesh: mesh_lib.Mesh | None):
self.grid = grid_mapping.grid
self.grid_names = grid_mapping.grid_names
self.jaxpr = jaxpr
self.block_mappings = grid_mapping.block_mappings
self.mapped_dims = grid_mapping.mapped_dims
@ -341,6 +344,12 @@ class MosaicGridMapping:
"Cannot use communication in pallas_call without shard_map."
)
axis_names = mesh.axis_names
if self.grid_names is not None:
if any(a in self.grid_names for a in axis_names):
raise ValueError(
"Cannot shadow axis mesh axis names with grid names. mesh axis"
f" names: {mesh.axis_names}, grid names: {self.grid_names}"
)
# We need mesh <-> logical translation tables. Since the logical IDs are
# just linearized versions of the mesh IDs, we create those tables.
mesh_strides = pallas_utils.strides_from_shape(tuple(
@ -356,7 +365,19 @@ class MosaicGridMapping:
@functools.cached_property
def has_communication(self) -> bool:
return bool(jax_core.used_axis_names_jaxpr(self.jaxpr))
nonlocal_axis_names = set()
def _get_nonlocal_axis_names(jaxpr: jax_core.Jaxpr):
return {
e.name
for e in jaxpr.effects
if isinstance(e, jax_core.NamedAxisEffect)
and (not self.grid_names or e.name not in self.grid_names)
}
nonlocal_axis_names.update(_get_nonlocal_axis_names(self.jaxpr))
for bm in self.block_mappings:
if bm is not None:
nonlocal_axis_names.update(_get_nonlocal_axis_names(bm.index_map_jaxpr))
return bool(nonlocal_axis_names)
def get_extra_args(self) -> tuple[Any, ...]:
return ()
@ -531,6 +552,7 @@ def lower_jaxpr_to_transform_func(
lowering_context = LoweringContext(
ctx,
len(mosaic_grid_mapping.grid),
mosaic_grid_mapping.grid_names,
mosaic_grid_mapping.mapped_dims,
None,
arg_block_shapes,
@ -599,6 +621,7 @@ def lower_jaxpr_to_func(
lowering_context = LoweringContext(
ctx,
len(mosaic_grid_mapping.grid),
mosaic_grid_mapping.grid_names,
mosaic_grid_mapping.mapped_dims,
jaxpr_indices,
arg_block_shapes,
@ -2627,10 +2650,18 @@ def _device_id_lowering_rule(ctx: LoweringRuleContext):
return tpu.DeviceIdOp().result
lowering_rules[tpu_primitives.device_id_p] = _device_id_lowering_rule
def _axis_index_rule(ctx: LoweringRuleContext, *, axis_name: str):
def _axis_index_rule(ctx: LoweringRuleContext, *, axis_name: Hashable):
grid_names = ctx.lowering_context.grid_names
if grid_names and axis_name in grid_names:
# We are querying a named axis corresponding to a grid dimension.
return _program_id_lowering_rule(ctx, axis=grid_names.index(axis_name))
# We are querying a named axis corresponding to a mesh dimension.
device_id = tpu.DeviceIdOp().result
mesh_shape = ctx.lowering_context.mesh_context.mesh_shape
axis_names = ctx.lowering_context.mesh_context.axis_names
mesh_context = ctx.lowering_context.mesh_context
if mesh_context is None:
raise ValueError("Mesh context is not set.")
mesh_shape = mesh_context.mesh_shape
axis_names = mesh_context.axis_names
axis_index = axis_names.index(axis_name)
axis_size = ir_constant(mesh_shape[axis_index])
minor_divisor = ir_constant(

View File

@ -908,7 +908,7 @@ def _trace_to_jaxpr(fun: Callable, grid_spec: GridSpec,
wrapped_fun, out_tree_thunk = api_util.flatten_fun_nokwargs(
lu.wrap_init(fun), jaxpr_in_tree)
debug = pe.debug_info(fun, jaxpr_in_tree, out_tree_thunk, False, "pallas_call")
with pallas_core.tracing_grid_env(grid_mapping.grid, ()):
with grid_mapping.trace_env():
jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(wrapped_fun,
jaxpr_flat_avals, debug)
if consts:
@ -1032,6 +1032,13 @@ jax_core.custom_str_eqn_compact_rules[pallas_call_p] = (
_pallas_custom_str_eqn_compact
)
def _pallas_call_typecheck_rule(*in_avals, grid_mapping, **params):
with grid_mapping.trace_env():
return pallas_call_p.abstract_eval(
*in_avals, grid_mapping=grid_mapping, **params
)
jax_core.custom_typechecks[pallas_call_p] = _pallas_call_typecheck_rule
def pallas_call(
f: Callable[..., None],

View File

@ -21,7 +21,7 @@ import dataclasses
import functools
import math
import operator
from typing import Any, TypeVar
from typing import Any, Hashable, TypeVar
import jax
from jax import lax
@ -256,7 +256,10 @@ def lower_jaxpr_to_triton_module(
name: str,
platform: str
) -> LoweringResult:
jaxpr, _ = pe.dce_jaxpr(jaxpr, [True] * len(jaxpr.outvars), instantiate=True)
with grid_mapping.trace_env():
jaxpr, _ = pe.dce_jaxpr(
jaxpr, [True] * len(jaxpr.outvars), instantiate=True
)
with _new_ir_context(), ir.Location.unknown():
module = ir.Module.create()
param_types = [
@ -2240,6 +2243,14 @@ def _remat_lowering_rule(ctx: LoweringRuleContext, *args, jaxpr, **_):
triton_lowering_rules[ad_util.stop_gradient_p] = lambda _, x: x
@register_lowering(lax.axis_index_p)
def _axis_index_rule(ctx: LoweringRuleContext, *, axis_name: Hashable):
grid_names = ctx.context.grid_mapping.grid_names
if axis_name in grid_names:
# We are querying a named axis corresponding to a grid dimension.
return _program_id_lowering_rule(ctx, axis=grid_names.index(axis_name))
raise LookupError(f"Axis name {axis_name} not found in grid.")
def _is_read_only(ref_effects) -> bool:
if len(ref_effects) == 0:
return True

View File

@ -1927,5 +1927,110 @@ class PallasCheckifyInterpreterTest(PallasBaseTest):
err.throw()
class PallasCallNamedGridTest(PallasBaseTest):
def test_named_grid(self):
def kernel(x_ref, y_ref):
y_ref[...] = x_ref[...]
x = jnp.arange(2 * 8 * 128, dtype=np.int32).reshape((2, 8, 128))
y = self.pallas_call(
kernel,
out_shape=x,
in_specs=[
pl.BlockSpec((None, 8, 128), lambda i: (i, 0, 0)),
],
out_specs=pl.BlockSpec((None, 8, 128), lambda i: (i, 0, 0)),
grid=(("i", 2),)
)(x)
np.testing.assert_array_equal(y, x)
def test_named_grid_reordered_names(self):
def kernel(x_ref, y_ref):
y_ref[...] = x_ref[...]
x = jnp.arange(4 * 16 * 128, dtype=np.int32).reshape((4, 16, 128))
y = self.pallas_call(
kernel,
out_shape=x,
in_specs=[
pl.BlockSpec((None, 8, 128), lambda i, j: (i, j, 0)),
],
out_specs=pl.BlockSpec((None, 8, 128), lambda i, j: (i, j, 0)),
grid=(("j", 4), ("i", 2))
)(x)
np.testing.assert_array_equal(y, x)
def test_can_query_named_grid_size_in_kernel_via_psum(self):
def kernel(x_ref, y_ref):
self.assertEqual(lax.psum(1, "i"), 2)
self.assertEqual(lax.psum(1, "j"), 4)
y_ref[...] = x_ref[...]
x = jnp.arange(4 * 16 * 128, dtype=np.int32).reshape((4, 16, 128))
y = self.pallas_call(
kernel,
out_shape=x,
in_specs=[
pl.BlockSpec((None, 8, 128), lambda i, j: (i, j, 0)),
],
out_specs=pl.BlockSpec((None, 8, 128), lambda i, j: (i, j, 0)),
grid=(("j", 4), ("i", 2))
)(x)
np.testing.assert_array_equal(y, x)
def test_can_query_named_dynamic_grid_size_in_kernel_via_psum(self):
# TODO(): Enable dynamic grid size via axis_size primitive.
self.skipTest("Not supported.")
def kernel(x_ref, y_ref):
self.assertEqual(lax.psum(1, "i"), 2)
self.assertEqual(lax.psum(1, "j"), 4)
y_ref[...] = x_ref[...]
x = jnp.arange(4 * 8 * 128, dtype=np.int32).reshape((4, 8, 128))
@jax.jit
def foo(n):
return self.pallas_call(
kernel,
out_shape=x,
in_specs=[
pl.BlockSpec((None, 8, 128), lambda i: (i, 0, 0)),
],
out_specs=pl.BlockSpec((None, 8, 128), lambda i: (i, 0, 0)),
grid=(("i", n),)
)(x)
y = foo(4)
np.testing.assert_array_equal(y, x)
def test_can_query_named_grid_program_id_in_kernel_via_axis_index(self):
if self.INTERPRET:
self.skipTest("Not supported in interpret mode.")
def kernel(x_ref, y_ref):
i_index = lax.axis_index("i")
y_ref[...] = x_ref[...] + i_index
x = jnp.arange(4 * 8 * 128, dtype=np.int32).reshape((4, 8, 128))
y = self.pallas_call(
kernel,
out_shape=x,
in_specs=[
pl.BlockSpec((None, 8, 128), lambda i: (i, 0, 0)),
],
out_specs=pl.BlockSpec((None, 8, 128), lambda i: (i, 0, 0)),
grid=(("i", 4),),
)(x)
np.testing.assert_array_equal(
y, x + jnp.arange(4, dtype=jnp.int32)[:, None, None]
)
class PallasCallNamedGridInterpretTest(PallasCallNamedGridTest):
INTERPRET = True
if __name__ == "__main__":
absltest.main()