[pallas] pallas_call_p is now parameterized by a mesh

The mesh is necessary to add support for clusters to the Mosaic GPU backend.

PiperOrigin-RevId: 737792129
Sergei Lebedev authored 2025-03-17 16:29:02 -07:00; committed by jax authors
parent b4966130a3
commit 051687dc4c
10 changed files with 133 additions and 38 deletions
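
To make the change concrete, the sketch below constructs the kind of mesh that now parameterizes `pallas_call_p`. The field and property names (`grid`, `cluster`, `num_threads`, `axis_names`, `backend`, `shape`) are taken from the diff below; treating them as constructor keywords is an assumption, and the actual kernel launch (via `pl.core_map`) is omitted.

```python
# Hedged sketch, not part of this commit: the mesh object that pallas_call_p
# now carries. Keyword-argument construction is assumed from the dataclass
# fields shown in the diff; importing mosaic_gpu needs a Mosaic-GPU-enabled
# jaxlib.
from jax.experimental.pallas import mosaic_gpu as plgpu

mesh = plgpu.GPUMesh(
    grid=(4,),
    cluster=(2,),    # thread-block cluster dims, now forwarded to lowering
    num_threads=2,   # warpgroups per block
    axis_names=("x", "cluster", "wg"),
)

# The two properties the new Mesh protocol requires:
assert mesh.backend == "mosaic_gpu"
# shape flattens grid, cluster and warpgroup count into one ordered mapping,
# e.g. OrderedDict([("x", 4), ("cluster", 2), ("wg", 2)]).
print(mesh.shape)
```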

View File

@ -15,6 +15,7 @@
"""Module for pallas-core functionality."""
from __future__ import annotations
import collections
from collections.abc import Callable, Iterable, Iterator, Sequence
import contextlib
import copy
@ -1068,6 +1069,17 @@ def _core_map_abstract_eval(*args, jaxpr, mesh, **_):
return [], effs
class Mesh(Protocol):
@property
def backend(self) -> str:
...
@property
def shape(self) -> collections.OrderedDict[object, int]:
...
_core_map_mesh_rules: dict[type[Any], Callable[..., Any]] = {}
@ -1075,9 +1087,8 @@ def default_mesh_discharge_rule(
in_avals,
out_avals,
*args,
grid,
mesh,
compiler_params,
backend,
jaxpr,
debug,
interpret,
@ -1100,19 +1111,22 @@ def default_mesh_discharge_rule(
if isinstance(eff, state_types.WriteEffect)
)
any_spec = BlockSpec(memory_space=MemorySpace.ANY)
grid_spec = GridSpec(
grid=tuple(mesh.shape.items()),
in_specs=[any_spec] * len(in_avals),
out_specs=[any_spec] * len(modified_idxs),
)
from jax._src.pallas import pallas_call # Avoid circular dependency.
outs = pallas_call.pallas_call(
outs = pallas_call._pallas_call(
body,
name=name,
out_shape=[in_avals[idx] for idx in modified_idxs],
in_specs=[any_spec] * len(in_avals),
out_specs=[any_spec] * len(modified_idxs),
input_output_aliases={
in_idx: out_idx for out_idx, in_idx in enumerate(modified_idxs)
},
grid=grid,
grid_spec=grid_spec,
mesh=mesh,
compiler_params=compiler_params,
backend=backend,
interpret=interpret,
debug=debug,
cost_estimate=cost_estimate,
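
The hunk above introduces a structural `Mesh` protocol: any object exposing a `backend` string and an ordered `shape` mapping can parameterize `pallas_call_p`, and `default_mesh_discharge_rule` now derives the grid from `mesh.shape` instead of taking a separate `grid` argument. A minimal illustration (the class name is hypothetical):

```python
# Hedged sketch: a toy object satisfying the Mesh protocol from the hunk
# above. "ToyMesh" is a hypothetical name used only for illustration.
import collections

class ToyMesh:
  @property
  def backend(self) -> str:
    return "mosaic_tpu"

  @property
  def shape(self) -> collections.OrderedDict[object, int]:
    return collections.OrderedDict([("core", 2)])

mesh = ToyMesh()
# default_mesh_discharge_rule builds the grid from the mesh itself:
grid = tuple(mesh.shape.items())   # (("core", 2),)
```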

View File

@ -340,11 +340,12 @@ def pallas_call_hlo_interpret(
debug: bool,
input_output_aliases: tuple[tuple[int, int], ...],
grid_mapping: GridMapping,
mesh: pallas_core.Mesh | None,
compiler_params: Any,
cost_estimate: CostEstimate,
out_avals: tuple[jax_core.AbstractValue, ...],
):
del compiler_params, cost_estimate, out_avals
del mesh, compiler_params, cost_estimate, out_avals
debug_info = jaxpr.debug_info
# If we're in interpret mode, we *scan* over the grid and eval the
# discharged jaxpr.

View File

@ -211,6 +211,10 @@ class TensorCoreMesh:
devices: np.ndarray
axis_names: Sequence[str]
@property
def backend(self) -> str:
return "mosaic_tpu"
@property
def shape(self):
return collections.OrderedDict(zip(self.axis_names, self.devices.shape))
@ -259,7 +263,6 @@ def _tensorcore_mesh_discharge_rule(
compiler_params = TPUCompilerParams()
if len(mesh.shape) > 1:
raise NotImplementedError("Mesh must be 1D")
core_axis_name, num_cores = list(mesh.shape.items())[0]
if compiler_params.dimension_semantics is not None:
raise ValueError(
"dimension_semantics must be None for TensorCoreMesh"
@ -269,13 +272,12 @@ def _tensorcore_mesh_discharge_rule(
out_avals,
*args,
jaxpr=jaxpr,
grid=((core_axis_name, num_cores),),
mesh=mesh,
compiler_params=compiler_params.replace(
dimension_semantics=(PARALLEL,)
),
debug=debug,
interpret=interpret,
backend="mosaic_tpu",
cost_estimate=cost_estimate,
name=name,
)

View File

@ -1351,12 +1351,13 @@ def interpret_pallas_call(
debug: bool,
input_output_aliases: tuple[tuple[int, int], ...],
grid_mapping: GridMapping,
mesh: pallas_core.Mesh | None,
compiler_params: Any,
cost_estimate: CostEstimate,
out_avals: tuple[jax_core.AbstractValue, ...],
interpret_params: TPUInterpretParams,
):
del debug, cost_estimate, out_avals
del debug, mesh, cost_estimate, out_avals
# args contains: *dynamic_grid_sizes, *index, *inputs. (No consts?)
dynamic_grid_args, scalars, input_args = split_list(

View File

@ -108,6 +108,7 @@ def pallas_call_tpu_lowering_rule(
*in_nodes,
jaxpr: jax_core.Jaxpr,
grid_mapping: core.GridMapping,
mesh: pallas_core.Mesh | None,
input_output_aliases: tuple[tuple[int, int], ...],
debug: bool,
interpret: bool,
@ -116,7 +117,8 @@ def pallas_call_tpu_lowering_rule(
out_avals: tuple[jax_core.AbstractValue, ...],
):
"""Lowers a pallas_call to a Mosaic TPU custom call."""
del interpret
del mesh, interpret # Unused.
debug_info = jaxpr._debug_info
if debug:
print(f"\nThe kernel jaxpr for pallas_call {debug_info.func_src_info}:")
@ -126,11 +128,11 @@ def pallas_call_tpu_lowering_rule(
else:
mosaic_params = {}
mesh = None
jax_mesh = None
axis_context = ctx.module_context.axis_context
if axis_context is not None:
if isinstance(axis_context, sharding_impls.SPMDAxisContext):
mesh = axis_context.mesh
jax_mesh = axis_context.mesh
mlir_ctx = mlir.JaxIrContext()
mlir_ctx.append_dialect_registry(mlir.upstream_dialects)
mlir_ctx.load_all_available_dialects()
@ -147,7 +149,7 @@ def pallas_call_tpu_lowering_rule(
grid_mapping,
jaxpr,
dimension_semantics=dimension_semantics,
mesh=mesh,
mesh=jax_mesh,
for_verification=for_verification,
dynamic_shape_replacement_enabled=pallas_core.dynamic_shapes_export_enabled(),
)
@ -164,11 +166,11 @@ def pallas_call_tpu_lowering_rule(
)
if promela_dump_path := _DUMP_PROMELA_TO.value:
num_devices = 1 if mesh is None else mesh.devices.size
num_devices = 1 if jax_mesh is None else jax_mesh.devices.size
num_cores = (
jax.devices()[0].num_cores
if mesh is None
else mesh.devices[0].num_cores
if jax_mesh is None
else jax_mesh.devices[0].num_cores
)
verification_module, _ = lower_module(for_verification=True)
model = verification.export_promela_model(

View File

@ -18,7 +18,7 @@ from __future__ import annotations
import abc
import collections
from collections.abc import Sequence
from collections.abc import Iterable, Sequence
import dataclasses
import enum
import itertools as it
@ -519,9 +519,16 @@ class GPUMesh:
)
@property
def shape(self):
def backend(self) -> str:
return "mosaic_gpu"
@property
def shape(self) -> collections.OrderedDict[object, int]:
pairs: Iterable[tuple[object, int]]
if self.num_threads is not None:
pairs = zip(self.axis_names, (*self.grid, *self.cluster, self.num_threads))
pairs = zip(
self.axis_names, (*self.grid, *self.cluster, self.num_threads)
)
else:
pairs = tuple(
zip(
@ -563,8 +570,7 @@ def _gpu_mesh_discharge_rule(
out_avals,
*args,
jaxpr=jaxpr,
grid=tuple(mesh.shape.items()),
backend="mosaic_gpu",
mesh=mesh,
compiler_params=compiler_params,
debug=debug,
interpret=interpret,

View File

@ -450,6 +450,7 @@ def _block_spec_from_block_mapping(
def lower_pipelined_jaxpr_to_module(
grid_mapping: pallas_core.GridMapping,
mesh: pallas_core.Mesh | None,
jaxpr: jax_core.Jaxpr,
compiler_params: dict[str, Any],
cost_estimate: pallas_core.CostEstimate | None,
@ -473,7 +474,10 @@ def lower_pipelined_jaxpr_to_module(
block_mappings, [grid_mapping.num_inputs]
)
if grid_mapping.grid_names: # Last dim corresponds to the warpgroup count
if mesh is not None:
assert isinstance(mesh, gpu_core.GPUMesh)
if mesh and mesh.num_threads is not None:
# Last dim corresponds to the warpgroup count.
block = (128 * grid_mapping.grid[-1], 1, 1)
grid = grid_mapping.grid[:-1]
else:
@ -566,6 +570,7 @@ def lower_pipelined_jaxpr_to_module(
parallel_grid,
grid_mapping.grid_names,
block,
mesh.cluster if mesh is not None else (),
[bm.array_shape_dtype for bm in in_block_mappings],
[bm.array_shape_dtype for bm in out_block_mappings],
new_jaxpr,
@ -578,6 +583,7 @@ def lower_jaxpr_to_module(
grid: Sequence[int],
grid_names: Sequence[str],
block: Sequence[int],
cluster: Sequence[int],
in_shapes: Sequence[jax.ShapeDtypeStruct],
out_shapes: Sequence[jax.ShapeDtypeStruct],
jaxpr: jax_core.Jaxpr,
@ -640,7 +646,7 @@ def lower_jaxpr_to_module(
mgpu_core._lower_as_gpu_kernel(
body,
grid=parallel_grid,
cluster=(),
cluster=cluster,
block=block,
in_shapes=in_shapes,
out_shape=out_shapes,

View File

@ -38,6 +38,7 @@ def pallas_call_lowering(
debug: bool,
input_output_aliases: tuple[tuple[int, int], ...],
grid_mapping: pallas_core.GridMapping,
mesh: pallas_core.Mesh | None,
compiler_params: dict[str, Any],
cost_estimate: pallas_core.CostEstimate | None,
out_avals: tuple[jax_core.AbstractValue, ...],
@ -63,6 +64,7 @@ def pallas_call_lowering(
lowering_result = lowering.lower_pipelined_jaxpr_to_module(
grid_mapping,
mesh,
jaxpr,
compiler_params,
cost_estimate,

View File

@ -20,7 +20,7 @@ import dataclasses
import enum
from functools import partial, reduce
import types
from typing import Any, Literal
from typing import Any, Literal, cast
import jax
from jax import lax
@ -119,6 +119,7 @@ def _pallas_call_jvp_rule(
jaxpr: jax_core.Jaxpr,
input_output_aliases: tuple[tuple[int, int], ...],
grid_mapping: GridMapping,
mesh: pallas_core.Mesh | None,
debug: bool,
interpret: bool,
compiler_params: Any,
@ -133,6 +134,8 @@ def _pallas_call_jvp_rule(
raise NotImplementedError
if input_output_aliases:
raise NotImplementedError("JVP with aliasing not supported.")
if mesh is not None:
raise NotImplementedError("pallas_call with a mesh does not support JVP")
nonzero_tangents = [not isinstance(t, ad_util.Zero) for t in tangents]
tangents = [t for t in tangents if type(t) is not ad_util.Zero]
nonzero_tangents_with_outputs = nonzero_tangents + [True] * grid_mapping.num_outputs
@ -181,6 +184,7 @@ def _pallas_call_jvp_rule(
*tangents,
jaxpr=jvp_jaxpr,
grid_mapping=jvp_grid_mapping,
mesh=mesh,
interpret=interpret,
debug=debug,
input_output_aliases=(),
@ -317,6 +321,7 @@ def _batch_with_explicit_loop(
*,
jaxpr: jax_core.Jaxpr,
grid_mapping: GridMapping,
mesh: pallas_core.Mesh | None,
input_output_aliases: tuple[tuple[int, int], ...],
debug: bool,
interpret: bool,
@ -384,6 +389,7 @@ def _batch_with_explicit_loop(
*batch_args,
jaxpr=jaxpr,
grid_mapping=grid_mapping,
mesh=mesh,
input_output_aliases=input_output_aliases,
debug=debug,
interpret=interpret,
@ -413,6 +419,7 @@ def _pallas_call_batching_rule(
*,
jaxpr: jax_core.Jaxpr,
grid_mapping: GridMapping,
mesh: pallas_core.Mesh | None,
input_output_aliases: tuple[tuple[int, int], ...],
debug: bool,
interpret: bool,
@ -421,6 +428,11 @@ def _pallas_call_batching_rule(
out_avals: tuple[jax_core.AbstractValue, ...],
backend: _Backend | None,
):
if mesh is not None:
raise NotImplementedError(
"pallas_call with a mesh does not support batching"
)
def _maybe_squeeze_out_bdim(
x: jax.Array, bdim: int | batching.NotMapped
) -> jax.Array:
@ -445,6 +457,7 @@ def _pallas_call_batching_rule(
*args,
jaxpr=jaxpr,
grid_mapping=grid_mapping,
mesh=mesh,
input_output_aliases=input_output_aliases,
debug=debug,
interpret=interpret,
@ -478,6 +491,7 @@ def _pallas_call_batching_rule(
dims=dynamic_grid_dims + dims,
jaxpr=jaxpr,
grid_mapping=grid_mapping,
mesh=mesh,
input_output_aliases=input_output_aliases,
debug=debug,
interpret=interpret,
@ -512,6 +526,7 @@ def _pallas_call_batching_rule(
dims=scalar_bdims + bdims,
jaxpr=jaxpr,
grid_mapping=grid_mapping,
mesh=mesh,
input_output_aliases=input_output_aliases,
debug=debug,
interpret=interpret,
@ -890,6 +905,7 @@ def _pallas_call_batching_rule(
*args,
jaxpr=jaxpr,
grid_mapping=batched_grid_mapping,
mesh=mesh,
input_output_aliases=input_output_aliases,
debug=debug,
interpret=interpret,
@ -1339,12 +1355,13 @@ def _pallas_call_state_discharge_rule(
jaxpr: jax_core.Jaxpr,
input_output_aliases: tuple[tuple[int, int], ...],
grid_mapping: GridMapping,
mesh: pallas_core.Mesh | None,
debug: bool,
interpret: bool,
compiler_params: Any,
cost_estimate: CostEstimate | None,
out_avals: tuple[jax_core.AbstractValue, ...],
backend: _Backend | None = None
backend: _Backend | None = None,
):
del avals_out
assert all(isinstance(v.aval, state.AbstractRef) for v in jaxpr.constvars)
@ -1440,6 +1457,7 @@ def _pallas_call_state_discharge_rule(
jaxpr=new_jaxpr,
input_output_aliases=new_input_output_aliases,
grid_mapping=new_grid_mapping,
mesh=mesh,
debug=debug,
interpret=interpret,
compiler_params=compiler_params,
@ -1526,16 +1544,6 @@ def pallas_call(
invoke the Pallas kernel.
"""
if compiler_params is None:
compiler_params = {}
if isinstance(compiler_params, pallas_core.CompilerParams):
if compiler_params.PLATFORM not in ["mosaic", "mosaic_gpu", "triton"]:
raise ValueError(
f"Unknown platform in compiler params: {compiler_params.PLATFORM}")
compiler_params = {
compiler_params.PLATFORM: dataclasses.asdict(compiler_params)
}
if grid_spec is None:
grid_spec = GridSpec(grid, in_specs, out_specs, scratch_shapes)
else:
@ -1556,6 +1564,55 @@ def pallas_call(
"If `grid_spec` is specified, then `scratch_shapes` must "
f"be `()`. It is {scratch_shapes}")
del grid, in_specs, out_specs
return _pallas_call(
kernel,
out_shape,
grid_spec=grid_spec,
input_output_aliases=input_output_aliases,
debug=debug,
interpret=interpret,
name=name,
compiler_params=compiler_params,
cost_estimate=cost_estimate,
backend=backend,
)
def _pallas_call(
kernel: Callable[..., None],
out_shape: Any,
*,
grid_spec: GridSpec,
mesh: pallas_core.Mesh | None = None,
input_output_aliases: dict[int, int] = {},
debug: bool = False,
interpret: bool = False,
name: str | None = None,
compiler_params: dict[str, Any] | pallas_core.CompilerParams | None = None,
cost_estimate: CostEstimate | None = None,
backend: _Backend | None = None,
):
if compiler_params is None:
compiler_params = {}
if isinstance(compiler_params, pallas_core.CompilerParams):
if compiler_params.PLATFORM not in ["mosaic", "mosaic_gpu", "triton"]:
raise ValueError(
f"Unknown platform in compiler params: {compiler_params.PLATFORM}"
)
compiler_params = {
compiler_params.PLATFORM: dataclasses.asdict(compiler_params)
}
if mesh is not None:
if tuple(mesh.shape.values()) != grid_spec.grid:
raise ValueError(
f"Mesh shape {tuple(mesh.shape.values())} does not match grid "
f"shape {grid_spec.grid}."
)
if backend is not None:
raise ValueError("If `mesh` is specified, then `backend` must be `None`.")
backend = cast(_Backend, mesh.backend)
grid_spec, dynamic_grid_bounds = pallas_core.unzip_dynamic_grid_bounds(grid_spec)
# TODO(necula): this canonicalization may be convenient for some usage
# but it is lossy, because it prevents expressing functions that return
@ -1643,6 +1700,7 @@ def pallas_call(
debug=debug,
interpret=interpret,
grid_mapping=grid_mapping,
mesh=mesh,
input_output_aliases=tuple(input_output_aliases.items()),
compiler_params=compiler_params,
cost_estimate=cost_estimate,
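
The new `_pallas_call` entry point above ties the two parameters together: when a `mesh` is given, the backend is derived from `mesh.backend`, `backend` must not also be passed, and the mesh shape must agree with the grid sizes. A small standalone mirror of that logic (the helper name is hypothetical):

```python
# Hedged sketch mirroring the consistency check added to _pallas_call above.
# "resolve_backend" is a hypothetical helper, not part of the diff.
def resolve_backend(mesh, grid, backend):
  if mesh is None:
    return backend
  if tuple(mesh.shape.values()) != tuple(grid):
    raise ValueError(
        f"Mesh shape {tuple(mesh.shape.values())} does not match grid "
        f"shape {tuple(grid)}."
    )
  if backend is not None:
    raise ValueError("If `mesh` is specified, then `backend` must be `None`.")
  return mesh.backend  # e.g. "mosaic_gpu" or "mosaic_tpu"
```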

View File

@ -50,6 +50,7 @@ def pallas_call_lowering(
debug: bool,
input_output_aliases: tuple[tuple[int, int], ...],
grid_mapping: pallas_core.GridMapping,
mesh: pallas_core.Mesh | None,
compiler_params: dict[str, Any],
cost_estimate: pallas_core.CostEstimate | None,
out_avals: tuple[jax_core.AbstractValue, ...],
@ -64,6 +65,8 @@ def pallas_call_lowering(
raise NotImplementedError(
"scalar prefetch not implemented in the Triton backend"
)
if mesh is not None:
raise NotImplementedError("mesh is not supported in the Triton backend")
triton_params = compiler_params.get("triton", compiler_params)
num_warps = triton_params.get("num_warps", 4)
num_warps = 4 if num_warps is None else num_warps