mirror of
https://github.com/ROCm/jax.git
synced 2025-04-18 12:56:07 +00:00
Add support for named grids in pallas_call.
PiperOrigin-RevId: 655036727
This commit is contained in:
parent
a18872aa13
commit
499ceeeb2c
@ -22,7 +22,7 @@ import dataclasses
|
||||
import enum
|
||||
import functools
|
||||
import threading
|
||||
from typing import Any, Union
|
||||
from typing import Any, Hashable, Union
|
||||
import warnings
|
||||
|
||||
import jax
|
||||
@ -44,9 +44,14 @@ dynamic_grid_dim = DynamicGridDim()
|
||||
|
||||
|
||||
partial = functools.partial
|
||||
Grid = tuple[Union[int, jax_core.Array], ...]
|
||||
GridElement = int | jax_core.Array
|
||||
GridName = Hashable
|
||||
GridNames = tuple[Hashable, ...] | None
|
||||
NamedGrid = tuple[tuple[GridName, int], ...]
|
||||
TupleGrid = tuple[GridElement, ...]
|
||||
Grid = Union[NamedGrid, TupleGrid]
|
||||
StaticGrid = tuple[int, ...]
|
||||
GridMappingGrid = tuple[Union[int, DynamicGridDim], ...]
|
||||
GridMappingGrid = tuple[int | DynamicGridDim, ...]
|
||||
split_list = util.split_list
|
||||
|
||||
map, unsafe_map = util.safe_map, map
|
||||
@ -280,6 +285,7 @@ def tracing_grid_env(grid: GridMappingGrid, mapped_dims: tuple[int, ...]):
|
||||
@dataclasses.dataclass(frozen=True)
|
||||
class GridMapping:
|
||||
grid: GridMappingGrid
|
||||
grid_names: tuple[Hashable, ...] | None
|
||||
block_mappings: tuple[BlockMapping | None, ...]
|
||||
mapped_dims: tuple[int, ...] = ()
|
||||
num_index_operands: int = 0
|
||||
@ -301,16 +307,39 @@ class GridMapping:
|
||||
|
||||
@contextlib.contextmanager
|
||||
def trace_env(self):
|
||||
with tracing_grid_env(self.grid, self.mapped_dims):
|
||||
if self.grid_names is None:
|
||||
axis_env_ctx = contextlib.nullcontext()
|
||||
else:
|
||||
axis_env_ctx = jax_core.extend_axis_env_nd(
|
||||
zip(self.grid_names, self.grid)
|
||||
)
|
||||
with tracing_grid_env(self.grid, self.mapped_dims), axis_env_ctx:
|
||||
yield
|
||||
|
||||
def _is_valid_grid_dim(dim: int | jax.Array) -> bool:
|
||||
if isinstance(dim, jax.Array):
|
||||
return True
|
||||
return jax_core.is_dim(dim)
|
||||
|
||||
def _preprocess_grid(grid: Grid | int | None) -> Grid:
|
||||
def _preprocess_grid(grid: Grid | int | None) -> tuple[TupleGrid, GridNames]:
|
||||
if grid is None:
|
||||
return ()
|
||||
return (), None
|
||||
if isinstance(grid, int):
|
||||
return (grid,)
|
||||
return grid
|
||||
return (grid,), None
|
||||
# Handle empty grid
|
||||
if not grid:
|
||||
return grid, None # type: ignore
|
||||
# Check if we have a named grid
|
||||
if isinstance(grid[0], tuple):
|
||||
grid_names, grid = util.unzip2(grid) # type: ignore
|
||||
else:
|
||||
grid_names = None
|
||||
# TODO(b/353730556): allow NumPy scalars in grids
|
||||
if not all(_is_valid_grid_dim(g) for g in grid): # type: ignore
|
||||
raise ValueError(
|
||||
f"Grid must be a tuple of integers or jax.Array, got {grid}"
|
||||
)
|
||||
return grid, grid_names # type: ignore
|
||||
|
||||
|
||||
def _convert_block_spec_to_block_mapping(
|
||||
@ -409,7 +438,8 @@ def _get_ref_avals(in_avals: Sequence[jax_core.ShapedArray],
|
||||
|
||||
@dataclasses.dataclass(init=False, unsafe_hash=True)
|
||||
class GridSpec:
|
||||
grid: Grid
|
||||
grid: TupleGrid
|
||||
grid_names: tuple[Hashable, ...] | None
|
||||
in_specs: tuple[BlockSpec | NoBlockSpec, ...]
|
||||
out_specs: tuple[BlockSpec | NoBlockSpec, ...]
|
||||
in_specs_tree: Any
|
||||
@ -429,7 +459,7 @@ class GridSpec:
|
||||
if isinstance(out_specs, list):
|
||||
out_specs = tuple(out_specs)
|
||||
|
||||
self.grid = _preprocess_grid(grid)
|
||||
self.grid, self.grid_names = _preprocess_grid(grid)
|
||||
if in_specs is not no_block_spec:
|
||||
flat_in_specs, self.in_specs_tree = tree_util.tree_flatten(in_specs)
|
||||
self.in_specs = tuple(flat_in_specs)
|
||||
@ -504,7 +534,8 @@ class GridSpec:
|
||||
out_ref_avals,
|
||||
)
|
||||
grid_mapping = GridMapping(
|
||||
grid_mapping_grid, (*in_block_mappings, *out_block_mappings) # type: ignore
|
||||
grid_mapping_grid, self.grid_names, # type: ignore
|
||||
(*in_block_mappings, *out_block_mappings)
|
||||
)
|
||||
jaxpr_in_avals = tree_util.tree_unflatten(in_tree, in_ref_avals)
|
||||
jaxpr_out_avals = tree_util.tree_unflatten(out_tree, out_ref_avals)
|
||||
|
@ -19,7 +19,7 @@ from collections.abc import Sequence
|
||||
import dataclasses
|
||||
import enum
|
||||
import functools
|
||||
from typing import Any
|
||||
from typing import Any, Hashable
|
||||
|
||||
from jax._src import core as jax_core
|
||||
from jax._src import dtypes
|
||||
@ -33,6 +33,7 @@ zip, unsafe_zip = util.safe_zip, zip
|
||||
|
||||
partial = functools.partial
|
||||
Grid = pallas_core.Grid
|
||||
TupleGrid = pallas_core.TupleGrid
|
||||
BlockSpec = pallas_core.BlockSpec
|
||||
BlockSpecTree = pallas_core.BlockSpecTree
|
||||
GridMapping = pallas_core.GridMapping
|
||||
@ -149,7 +150,8 @@ def _make_aval(obj: object) -> jax_core.AbstractValue:
|
||||
|
||||
@dataclasses.dataclass(init=False, unsafe_hash=True)
|
||||
class PrefetchScalarGridSpec(pallas_core.GridSpec):
|
||||
grid: Grid
|
||||
grid: TupleGrid
|
||||
grid_names: tuple[Hashable, ...] | None
|
||||
num_scalar_prefetch: int
|
||||
in_specs: tuple[BlockSpec | NoBlockSpec, ...]
|
||||
out_specs: tuple[BlockSpec | NoBlockSpec, ...]
|
||||
@ -227,7 +229,7 @@ class PrefetchScalarGridSpec(pallas_core.GridSpec):
|
||||
out_ref_avals,
|
||||
)
|
||||
grid_mapping = GridMapping(
|
||||
grid=grid_mapping_grid, # type: ignore
|
||||
grid=grid_mapping_grid, grid_names=self.grid_names, # type: ignore
|
||||
block_mappings=(*in_block_mappings, *out_block_mappings),
|
||||
mapped_dims=(),
|
||||
num_index_operands=num_flat_scalar_prefetch,
|
||||
|
@ -19,7 +19,7 @@ from collections.abc import Callable, Sequence
|
||||
import dataclasses
|
||||
import functools
|
||||
import string
|
||||
from typing import Any
|
||||
from typing import Any, Hashable
|
||||
|
||||
import jax
|
||||
from jax import lax
|
||||
@ -99,6 +99,7 @@ class MeshContext:
|
||||
class LoweringContext:
|
||||
ir_context: ir.Context
|
||||
grid_rank: int # Includes both user and vmap axes.
|
||||
grid_names: tuple[Hashable, ...] | None
|
||||
mapped_dims: tuple[int, ...] # Indices of vmapped grid dimensions.
|
||||
user_grid_indices: Sequence[ir.Value] | None
|
||||
block_shapes: list[tuple[int | pl_core.Mapped, ...]]
|
||||
@ -252,6 +253,7 @@ def _get_arg_type(
|
||||
@dataclasses.dataclass(init=False)
|
||||
class MosaicGridMapping:
|
||||
grid: tuple[int, ...] | None
|
||||
grid_names: tuple[Hashable, ...] | None
|
||||
jaxpr: jax_core.Jaxpr
|
||||
block_mappings: tuple[pl_core.BlockMapping | None, ...]
|
||||
mapped_dims: tuple[int, ...]
|
||||
@ -269,6 +271,7 @@ class MosaicGridMapping:
|
||||
dimension_semantics: tuple[str, ...] | None,
|
||||
mesh: mesh_lib.Mesh | None):
|
||||
self.grid = grid_mapping.grid
|
||||
self.grid_names = grid_mapping.grid_names
|
||||
self.jaxpr = jaxpr
|
||||
self.block_mappings = grid_mapping.block_mappings
|
||||
self.mapped_dims = grid_mapping.mapped_dims
|
||||
@ -341,6 +344,12 @@ class MosaicGridMapping:
|
||||
"Cannot use communication in pallas_call without shard_map."
|
||||
)
|
||||
axis_names = mesh.axis_names
|
||||
if self.grid_names is not None:
|
||||
if any(a in self.grid_names for a in axis_names):
|
||||
raise ValueError(
|
||||
"Cannot shadow axis mesh axis names with grid names. mesh axis"
|
||||
f" names: {mesh.axis_names}, grid names: {self.grid_names}"
|
||||
)
|
||||
# We need mesh <-> logical translation tables. Since the logical IDs are
|
||||
# just linearized versions of the mesh IDs, we create those tables.
|
||||
mesh_strides = pallas_utils.strides_from_shape(tuple(
|
||||
@ -356,7 +365,19 @@ class MosaicGridMapping:
|
||||
|
||||
@functools.cached_property
|
||||
def has_communication(self) -> bool:
|
||||
return bool(jax_core.used_axis_names_jaxpr(self.jaxpr))
|
||||
nonlocal_axis_names = set()
|
||||
def _get_nonlocal_axis_names(jaxpr: jax_core.Jaxpr):
|
||||
return {
|
||||
e.name
|
||||
for e in jaxpr.effects
|
||||
if isinstance(e, jax_core.NamedAxisEffect)
|
||||
and (not self.grid_names or e.name not in self.grid_names)
|
||||
}
|
||||
nonlocal_axis_names.update(_get_nonlocal_axis_names(self.jaxpr))
|
||||
for bm in self.block_mappings:
|
||||
if bm is not None:
|
||||
nonlocal_axis_names.update(_get_nonlocal_axis_names(bm.index_map_jaxpr))
|
||||
return bool(nonlocal_axis_names)
|
||||
|
||||
def get_extra_args(self) -> tuple[Any, ...]:
|
||||
return ()
|
||||
@ -531,6 +552,7 @@ def lower_jaxpr_to_transform_func(
|
||||
lowering_context = LoweringContext(
|
||||
ctx,
|
||||
len(mosaic_grid_mapping.grid),
|
||||
mosaic_grid_mapping.grid_names,
|
||||
mosaic_grid_mapping.mapped_dims,
|
||||
None,
|
||||
arg_block_shapes,
|
||||
@ -599,6 +621,7 @@ def lower_jaxpr_to_func(
|
||||
lowering_context = LoweringContext(
|
||||
ctx,
|
||||
len(mosaic_grid_mapping.grid),
|
||||
mosaic_grid_mapping.grid_names,
|
||||
mosaic_grid_mapping.mapped_dims,
|
||||
jaxpr_indices,
|
||||
arg_block_shapes,
|
||||
@ -2627,10 +2650,18 @@ def _device_id_lowering_rule(ctx: LoweringRuleContext):
|
||||
return tpu.DeviceIdOp().result
|
||||
lowering_rules[tpu_primitives.device_id_p] = _device_id_lowering_rule
|
||||
|
||||
def _axis_index_rule(ctx: LoweringRuleContext, *, axis_name: str):
|
||||
def _axis_index_rule(ctx: LoweringRuleContext, *, axis_name: Hashable):
|
||||
grid_names = ctx.lowering_context.grid_names
|
||||
if grid_names and axis_name in grid_names:
|
||||
# We are querying a named axis corresponding to a grid dimension.
|
||||
return _program_id_lowering_rule(ctx, axis=grid_names.index(axis_name))
|
||||
# We are querying a named axis corresponding to a mesh dimension.
|
||||
device_id = tpu.DeviceIdOp().result
|
||||
mesh_shape = ctx.lowering_context.mesh_context.mesh_shape
|
||||
axis_names = ctx.lowering_context.mesh_context.axis_names
|
||||
mesh_context = ctx.lowering_context.mesh_context
|
||||
if mesh_context is None:
|
||||
raise ValueError("Mesh context is not set.")
|
||||
mesh_shape = mesh_context.mesh_shape
|
||||
axis_names = mesh_context.axis_names
|
||||
axis_index = axis_names.index(axis_name)
|
||||
axis_size = ir_constant(mesh_shape[axis_index])
|
||||
minor_divisor = ir_constant(
|
||||
|
@ -908,7 +908,7 @@ def _trace_to_jaxpr(fun: Callable, grid_spec: GridSpec,
|
||||
wrapped_fun, out_tree_thunk = api_util.flatten_fun_nokwargs(
|
||||
lu.wrap_init(fun), jaxpr_in_tree)
|
||||
debug = pe.debug_info(fun, jaxpr_in_tree, out_tree_thunk, False, "pallas_call")
|
||||
with pallas_core.tracing_grid_env(grid_mapping.grid, ()):
|
||||
with grid_mapping.trace_env():
|
||||
jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(wrapped_fun,
|
||||
jaxpr_flat_avals, debug)
|
||||
if consts:
|
||||
@ -1032,6 +1032,13 @@ jax_core.custom_str_eqn_compact_rules[pallas_call_p] = (
|
||||
_pallas_custom_str_eqn_compact
|
||||
)
|
||||
|
||||
def _pallas_call_typecheck_rule(*in_avals, grid_mapping, **params):
|
||||
with grid_mapping.trace_env():
|
||||
return pallas_call_p.abstract_eval(
|
||||
*in_avals, grid_mapping=grid_mapping, **params
|
||||
)
|
||||
jax_core.custom_typechecks[pallas_call_p] = _pallas_call_typecheck_rule
|
||||
|
||||
|
||||
def pallas_call(
|
||||
f: Callable[..., None],
|
||||
|
@ -21,7 +21,7 @@ import dataclasses
|
||||
import functools
|
||||
import math
|
||||
import operator
|
||||
from typing import Any, TypeVar
|
||||
from typing import Any, Hashable, TypeVar
|
||||
|
||||
import jax
|
||||
from jax import lax
|
||||
@ -256,7 +256,10 @@ def lower_jaxpr_to_triton_module(
|
||||
name: str,
|
||||
platform: str
|
||||
) -> LoweringResult:
|
||||
jaxpr, _ = pe.dce_jaxpr(jaxpr, [True] * len(jaxpr.outvars), instantiate=True)
|
||||
with grid_mapping.trace_env():
|
||||
jaxpr, _ = pe.dce_jaxpr(
|
||||
jaxpr, [True] * len(jaxpr.outvars), instantiate=True
|
||||
)
|
||||
with _new_ir_context(), ir.Location.unknown():
|
||||
module = ir.Module.create()
|
||||
param_types = [
|
||||
@ -2240,6 +2243,14 @@ def _remat_lowering_rule(ctx: LoweringRuleContext, *args, jaxpr, **_):
|
||||
triton_lowering_rules[ad_util.stop_gradient_p] = lambda _, x: x
|
||||
|
||||
|
||||
@register_lowering(lax.axis_index_p)
|
||||
def _axis_index_rule(ctx: LoweringRuleContext, *, axis_name: Hashable):
|
||||
grid_names = ctx.context.grid_mapping.grid_names
|
||||
if axis_name in grid_names:
|
||||
# We are querying a named axis corresponding to a grid dimension.
|
||||
return _program_id_lowering_rule(ctx, axis=grid_names.index(axis_name))
|
||||
raise LookupError(f"Axis name {axis_name} not found in grid.")
|
||||
|
||||
def _is_read_only(ref_effects) -> bool:
|
||||
if len(ref_effects) == 0:
|
||||
return True
|
||||
|
@ -1927,5 +1927,110 @@ class PallasCheckifyInterpreterTest(PallasBaseTest):
|
||||
err.throw()
|
||||
|
||||
|
||||
class PallasCallNamedGridTest(PallasBaseTest):
|
||||
|
||||
def test_named_grid(self):
|
||||
|
||||
def kernel(x_ref, y_ref):
|
||||
y_ref[...] = x_ref[...]
|
||||
|
||||
x = jnp.arange(2 * 8 * 128, dtype=np.int32).reshape((2, 8, 128))
|
||||
y = self.pallas_call(
|
||||
kernel,
|
||||
out_shape=x,
|
||||
in_specs=[
|
||||
pl.BlockSpec((None, 8, 128), lambda i: (i, 0, 0)),
|
||||
],
|
||||
out_specs=pl.BlockSpec((None, 8, 128), lambda i: (i, 0, 0)),
|
||||
grid=(("i", 2),)
|
||||
)(x)
|
||||
np.testing.assert_array_equal(y, x)
|
||||
|
||||
def test_named_grid_reordered_names(self):
|
||||
|
||||
def kernel(x_ref, y_ref):
|
||||
y_ref[...] = x_ref[...]
|
||||
|
||||
x = jnp.arange(4 * 16 * 128, dtype=np.int32).reshape((4, 16, 128))
|
||||
y = self.pallas_call(
|
||||
kernel,
|
||||
out_shape=x,
|
||||
in_specs=[
|
||||
pl.BlockSpec((None, 8, 128), lambda i, j: (i, j, 0)),
|
||||
],
|
||||
out_specs=pl.BlockSpec((None, 8, 128), lambda i, j: (i, j, 0)),
|
||||
grid=(("j", 4), ("i", 2))
|
||||
)(x)
|
||||
np.testing.assert_array_equal(y, x)
|
||||
|
||||
def test_can_query_named_grid_size_in_kernel_via_psum(self):
|
||||
|
||||
def kernel(x_ref, y_ref):
|
||||
self.assertEqual(lax.psum(1, "i"), 2)
|
||||
self.assertEqual(lax.psum(1, "j"), 4)
|
||||
y_ref[...] = x_ref[...]
|
||||
|
||||
x = jnp.arange(4 * 16 * 128, dtype=np.int32).reshape((4, 16, 128))
|
||||
y = self.pallas_call(
|
||||
kernel,
|
||||
out_shape=x,
|
||||
in_specs=[
|
||||
pl.BlockSpec((None, 8, 128), lambda i, j: (i, j, 0)),
|
||||
],
|
||||
out_specs=pl.BlockSpec((None, 8, 128), lambda i, j: (i, j, 0)),
|
||||
grid=(("j", 4), ("i", 2))
|
||||
)(x)
|
||||
np.testing.assert_array_equal(y, x)
|
||||
|
||||
def test_can_query_named_dynamic_grid_size_in_kernel_via_psum(self):
|
||||
# TODO(): Enable dynamic grid size via axis_size primitive.
|
||||
self.skipTest("Not supported.")
|
||||
|
||||
def kernel(x_ref, y_ref):
|
||||
self.assertEqual(lax.psum(1, "i"), 2)
|
||||
self.assertEqual(lax.psum(1, "j"), 4)
|
||||
y_ref[...] = x_ref[...]
|
||||
|
||||
x = jnp.arange(4 * 8 * 128, dtype=np.int32).reshape((4, 8, 128))
|
||||
@jax.jit
|
||||
def foo(n):
|
||||
return self.pallas_call(
|
||||
kernel,
|
||||
out_shape=x,
|
||||
in_specs=[
|
||||
pl.BlockSpec((None, 8, 128), lambda i: (i, 0, 0)),
|
||||
],
|
||||
out_specs=pl.BlockSpec((None, 8, 128), lambda i: (i, 0, 0)),
|
||||
grid=(("i", n),)
|
||||
)(x)
|
||||
y = foo(4)
|
||||
np.testing.assert_array_equal(y, x)
|
||||
|
||||
def test_can_query_named_grid_program_id_in_kernel_via_axis_index(self):
|
||||
if self.INTERPRET:
|
||||
self.skipTest("Not supported in interpret mode.")
|
||||
def kernel(x_ref, y_ref):
|
||||
i_index = lax.axis_index("i")
|
||||
y_ref[...] = x_ref[...] + i_index
|
||||
|
||||
x = jnp.arange(4 * 8 * 128, dtype=np.int32).reshape((4, 8, 128))
|
||||
y = self.pallas_call(
|
||||
kernel,
|
||||
out_shape=x,
|
||||
in_specs=[
|
||||
pl.BlockSpec((None, 8, 128), lambda i: (i, 0, 0)),
|
||||
],
|
||||
out_specs=pl.BlockSpec((None, 8, 128), lambda i: (i, 0, 0)),
|
||||
grid=(("i", 4),),
|
||||
)(x)
|
||||
np.testing.assert_array_equal(
|
||||
y, x + jnp.arange(4, dtype=jnp.int32)[:, None, None]
|
||||
)
|
||||
|
||||
|
||||
class PallasCallNamedGridInterpretTest(PallasCallNamedGridTest):
|
||||
INTERPRET = True
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
absltest.main()
|
||||
|
Loading…
x
Reference in New Issue
Block a user