From 051687dc4c899df3d95c30b812ade401d8b31166 Mon Sep 17 00:00:00 2001 From: Sergei Lebedev Date: Mon, 17 Mar 2025 16:29:02 -0700 Subject: [PATCH] [pallas] `pallas_call_p` is now parameterized by a mesh The mesh is necessary to add support for clusters to the Mosaic GPU backend. PiperOrigin-RevId: 737792129 --- jax/_src/pallas/core.py | 28 +++++-- jax/_src/pallas/hlo_interpreter.py | 3 +- jax/_src/pallas/mosaic/core.py | 8 +- jax/_src/pallas/mosaic/interpret.py | 3 +- .../pallas/mosaic/pallas_call_registration.py | 16 ++-- jax/_src/pallas/mosaic_gpu/core.py | 16 ++-- jax/_src/pallas/mosaic_gpu/lowering.py | 10 ++- .../mosaic_gpu/pallas_call_registration.py | 2 + jax/_src/pallas/pallas_call.py | 82 ++++++++++++++++--- .../pallas/triton/pallas_call_registration.py | 3 + 10 files changed, 133 insertions(+), 38 deletions(-) diff --git a/jax/_src/pallas/core.py b/jax/_src/pallas/core.py index 466f6037a..5342a6946 100644 --- a/jax/_src/pallas/core.py +++ b/jax/_src/pallas/core.py @@ -15,6 +15,7 @@ """Module for pallas-core functionality.""" from __future__ import annotations +import collections from collections.abc import Callable, Iterable, Iterator, Sequence import contextlib import copy @@ -1068,6 +1069,17 @@ def _core_map_abstract_eval(*args, jaxpr, mesh, **_): return [], effs +class Mesh(Protocol): + + @property + def backend(self) -> str: + ... + + @property + def shape(self) -> collections.OrderedDict[object, int]: + ... + + _core_map_mesh_rules: dict[type[Any], Callable[..., Any]] = {} @@ -1075,9 +1087,8 @@ def default_mesh_discharge_rule( in_avals, out_avals, *args, - grid, + mesh, compiler_params, - backend, jaxpr, debug, interpret, @@ -1100,19 +1111,22 @@ def default_mesh_discharge_rule( if isinstance(eff, state_types.WriteEffect) ) any_spec = BlockSpec(memory_space=MemorySpace.ANY) + grid_spec = GridSpec( + grid=tuple(mesh.shape.items()), + in_specs=[any_spec] * len(in_avals), + out_specs=[any_spec] * len(modified_idxs), + ) from jax._src.pallas import pallas_call # Avoid circular dependency. - outs = pallas_call.pallas_call( + outs = pallas_call._pallas_call( body, name=name, out_shape=[in_avals[idx] for idx in modified_idxs], - in_specs=[any_spec] * len(in_avals), - out_specs=[any_spec] * len(modified_idxs), input_output_aliases={ in_idx: out_idx for out_idx, in_idx in enumerate(modified_idxs) }, - grid=grid, + grid_spec=grid_spec, + mesh=mesh, compiler_params=compiler_params, - backend=backend, interpret=interpret, debug=debug, cost_estimate=cost_estimate, diff --git a/jax/_src/pallas/hlo_interpreter.py b/jax/_src/pallas/hlo_interpreter.py index 8d7543b31..6fbe5e914 100644 --- a/jax/_src/pallas/hlo_interpreter.py +++ b/jax/_src/pallas/hlo_interpreter.py @@ -340,11 +340,12 @@ def pallas_call_hlo_interpret( debug: bool, input_output_aliases: tuple[tuple[int, int], ...], grid_mapping: GridMapping, + mesh: pallas_core.Mesh | None, compiler_params: Any, cost_estimate: CostEstimate, out_avals: tuple[jax_core.AbstractValue, ...], ): - del compiler_params, cost_estimate, out_avals + del mesh, compiler_params, cost_estimate, out_avals debug_info = jaxpr.debug_info # If we're in interpret mode, we *scan* over the grid and eval the # discharged jaxpr. 
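Editor's note: the core.py hunk above introduces the `Mesh` protocol (a `backend` string plus an ordered `shape` mapping) and rewrites `default_mesh_discharge_rule` to build its `GridSpec` from `mesh.shape` instead of taking `grid` and `backend` separately. The standalone sketch below is not part of the patch and uses a hypothetical `CoreMesh` class; it only shows what an object has to provide to satisfy the protocol and how the grid is recovered from it:

    import collections
    from typing import Protocol

    class Mesh(Protocol):  # mirrors the protocol added to jax/_src/pallas/core.py
      @property
      def backend(self) -> str: ...

      @property
      def shape(self) -> collections.OrderedDict[object, int]: ...

    class CoreMesh:  # hypothetical implementation, for illustration only
      @property
      def backend(self) -> str:
        return "mosaic_tpu"

      @property
      def shape(self) -> collections.OrderedDict[object, int]:
        return collections.OrderedDict([("core", 2)])

    mesh: Mesh = CoreMesh()
    # default_mesh_discharge_rule now derives the grid from the mesh itself:
    grid = tuple(mesh.shape.items())
    assert grid == (("core", 2),)
    assert mesh.backend == "mosaic_tpu"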
diff --git a/jax/_src/pallas/mosaic/core.py b/jax/_src/pallas/mosaic/core.py index 3e60e471d..f582248ee 100644 --- a/jax/_src/pallas/mosaic/core.py +++ b/jax/_src/pallas/mosaic/core.py @@ -211,6 +211,10 @@ class TensorCoreMesh: devices: np.ndarray axis_names: Sequence[str] + @property + def backend(self) -> str: + return "mosaic_tpu" + @property def shape(self): return collections.OrderedDict(zip(self.axis_names, self.devices.shape)) @@ -259,7 +263,6 @@ def _tensorcore_mesh_discharge_rule( compiler_params = TPUCompilerParams() if len(mesh.shape) > 1: raise NotImplementedError("Mesh must be 1D") - core_axis_name, num_cores = list(mesh.shape.items())[0] if compiler_params.dimension_semantics is not None: raise ValueError( "dimension_semantics must be None for TensorCoreMesh" @@ -269,13 +272,12 @@ def _tensorcore_mesh_discharge_rule( out_avals, *args, jaxpr=jaxpr, - grid=((core_axis_name, num_cores),), + mesh=mesh, compiler_params=compiler_params.replace( dimension_semantics=(PARALLEL,) ), debug=debug, interpret=interpret, - backend="mosaic_tpu", cost_estimate=cost_estimate, name=name, ) diff --git a/jax/_src/pallas/mosaic/interpret.py b/jax/_src/pallas/mosaic/interpret.py index a731bfdfd..e92de91f4 100644 --- a/jax/_src/pallas/mosaic/interpret.py +++ b/jax/_src/pallas/mosaic/interpret.py @@ -1351,12 +1351,13 @@ def interpret_pallas_call( debug: bool, input_output_aliases: tuple[tuple[int, int], ...], grid_mapping: GridMapping, + mesh: pallas_core.Mesh | None, compiler_params: Any, cost_estimate: CostEstimate, out_avals: tuple[jax_core.AbstractValue, ...], interpret_params: TPUInterpretParams, ): - del debug, cost_estimate, out_avals + del debug, mesh, cost_estimate, out_avals # args contains: *dynamic_grid_sizes, *index, *inputs. (No consts?) dynamic_grid_args, scalars, input_args = split_list( diff --git a/jax/_src/pallas/mosaic/pallas_call_registration.py b/jax/_src/pallas/mosaic/pallas_call_registration.py index 887c9629a..896af0c46 100644 --- a/jax/_src/pallas/mosaic/pallas_call_registration.py +++ b/jax/_src/pallas/mosaic/pallas_call_registration.py @@ -108,6 +108,7 @@ def pallas_call_tpu_lowering_rule( *in_nodes, jaxpr: jax_core.Jaxpr, grid_mapping: core.GridMapping, + mesh: pallas_core.Mesh | None, input_output_aliases: tuple[tuple[int, int], ...], debug: bool, interpret: bool, @@ -116,7 +117,8 @@ def pallas_call_tpu_lowering_rule( out_avals: tuple[jax_core.AbstractValue, ...], ): """Lowers a pallas_call to a Mosaic TPU custom call.""" - del interpret + del mesh, interpret # Unused. 
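# --- Editor's sketch (not part of the patch) --------------------------------
# The mosaic/core.py hunk above drops the hand-rolled
#   core_axis_name, num_cores = list(mesh.shape.items())[0]
# because default_mesh_discharge_rule can now read the grid straight off the
# TensorCoreMesh it is given. The snippet below only mimics the logic of
# TensorCoreMesh.shape (an OrderedDict zipping axis_names with devices.shape);
# the axis name "core" and the 2-core device array are hypothetical.
import collections
import numpy as np

axis_names = ("core",)
devices = np.empty((2,), dtype=object)  # stand-in for the per-core device array
shape = collections.OrderedDict(zip(axis_names, devices.shape))
# The grid the old code built by hand is exactly what the discharge rule now
# derives from the mesh:
assert tuple(shape.items()) == (("core", 2),)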
+ debug_info = jaxpr._debug_info if debug: print(f"\nThe kernel jaxpr for pallas_call {debug_info.func_src_info}:") @@ -126,11 +128,11 @@ def pallas_call_tpu_lowering_rule( else: mosaic_params = {} - mesh = None + jax_mesh = None axis_context = ctx.module_context.axis_context if axis_context is not None: if isinstance(axis_context, sharding_impls.SPMDAxisContext): - mesh = axis_context.mesh + jax_mesh = axis_context.mesh mlir_ctx = mlir.JaxIrContext() mlir_ctx.append_dialect_registry(mlir.upstream_dialects) mlir_ctx.load_all_available_dialects() @@ -147,7 +149,7 @@ def pallas_call_tpu_lowering_rule( grid_mapping, jaxpr, dimension_semantics=dimension_semantics, - mesh=mesh, + mesh=jax_mesh, for_verification=for_verification, dynamic_shape_replacement_enabled=pallas_core.dynamic_shapes_export_enabled(), ) @@ -164,11 +166,11 @@ def pallas_call_tpu_lowering_rule( ) if promela_dump_path := _DUMP_PROMELA_TO.value: - num_devices = 1 if mesh is None else mesh.devices.size + num_devices = 1 if jax_mesh is None else jax_mesh.devices.size num_cores = ( jax.devices()[0].num_cores - if mesh is None - else mesh.devices[0].num_cores + if jax_mesh is None + else jax_mesh.devices[0].num_cores ) verification_module, _ = lower_module(for_verification=True) model = verification.export_promela_model( diff --git a/jax/_src/pallas/mosaic_gpu/core.py b/jax/_src/pallas/mosaic_gpu/core.py index 2e491074c..630c1b8f4 100644 --- a/jax/_src/pallas/mosaic_gpu/core.py +++ b/jax/_src/pallas/mosaic_gpu/core.py @@ -18,7 +18,7 @@ from __future__ import annotations import abc import collections -from collections.abc import Sequence +from collections.abc import Iterable, Sequence import dataclasses import enum import itertools as it @@ -519,9 +519,16 @@ class GPUMesh: ) @property - def shape(self): + def backend(self) -> str: + return "mosaic_gpu" + + @property + def shape(self) -> collections.OrderedDict[object, int]: + pairs: Iterable[tuple[object, int]] if self.num_threads is not None: - pairs = zip(self.axis_names, (*self.grid, *self.cluster, self.num_threads)) + pairs = zip( + self.axis_names, (*self.grid, *self.cluster, self.num_threads) + ) else: pairs = tuple( zip( @@ -563,8 +570,7 @@ def _gpu_mesh_discharge_rule( out_avals, *args, jaxpr=jaxpr, - grid=tuple(mesh.shape.items()), - backend="mosaic_gpu", + mesh=mesh, compiler_params=compiler_params, debug=debug, interpret=interpret, diff --git a/jax/_src/pallas/mosaic_gpu/lowering.py b/jax/_src/pallas/mosaic_gpu/lowering.py index 5c863baf6..6b06e6b7d 100644 --- a/jax/_src/pallas/mosaic_gpu/lowering.py +++ b/jax/_src/pallas/mosaic_gpu/lowering.py @@ -450,6 +450,7 @@ def _block_spec_from_block_mapping( def lower_pipelined_jaxpr_to_module( grid_mapping: pallas_core.GridMapping, + mesh: pallas_core.Mesh | None, jaxpr: jax_core.Jaxpr, compiler_params: dict[str, Any], cost_estimate: pallas_core.CostEstimate | None, @@ -473,7 +474,10 @@ def lower_pipelined_jaxpr_to_module( block_mappings, [grid_mapping.num_inputs] ) - if grid_mapping.grid_names: # Last dim corresponds to the warpgroup count + if mesh is not None: + assert isinstance(mesh, gpu_core.GPUMesh) + if mesh and mesh.num_threads is not None: + # Last dim corresponds to the warpgroup count. 
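# --- Editor's sketch (not part of the patch) --------------------------------
# The GPUMesh changes above give the mesh a "mosaic_gpu" backend and a typed
# shape property that zips axis_names with (*grid, *cluster, num_threads), and
# lower_pipelined_jaxpr_to_module now receives the mesh so it can read
# mesh.num_threads (the warpgroup count) and forward mesh.cluster to
# _lower_as_gpu_kernel instead of hard-coding cluster=(). The axis names and
# sizes below are hypothetical and only illustrate the ordering of shape:
import collections

axis_names = ("x", "cluster_x", "wg")
grid, cluster, num_threads = (4,), (2,), 2
shape = collections.OrderedDict(zip(axis_names, (*grid, *cluster, num_threads)))
assert tuple(shape.items()) == (("x", 4), ("cluster_x", 2), ("wg", 2))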
block = (128 * grid_mapping.grid[-1], 1, 1) grid = grid_mapping.grid[:-1] else: @@ -566,6 +570,7 @@ def lower_pipelined_jaxpr_to_module( parallel_grid, grid_mapping.grid_names, block, + mesh.cluster if mesh is not None else (), [bm.array_shape_dtype for bm in in_block_mappings], [bm.array_shape_dtype for bm in out_block_mappings], new_jaxpr, @@ -578,6 +583,7 @@ def lower_jaxpr_to_module( grid: Sequence[int], grid_names: Sequence[str], block: Sequence[int], + cluster: Sequence[int], in_shapes: Sequence[jax.ShapeDtypeStruct], out_shapes: Sequence[jax.ShapeDtypeStruct], jaxpr: jax_core.Jaxpr, @@ -640,7 +646,7 @@ def lower_jaxpr_to_module( mgpu_core._lower_as_gpu_kernel( body, grid=parallel_grid, - cluster=(), + cluster=cluster, block=block, in_shapes=in_shapes, out_shape=out_shapes, diff --git a/jax/_src/pallas/mosaic_gpu/pallas_call_registration.py b/jax/_src/pallas/mosaic_gpu/pallas_call_registration.py index a9e5ead8d..d506349fe 100644 --- a/jax/_src/pallas/mosaic_gpu/pallas_call_registration.py +++ b/jax/_src/pallas/mosaic_gpu/pallas_call_registration.py @@ -38,6 +38,7 @@ def pallas_call_lowering( debug: bool, input_output_aliases: tuple[tuple[int, int], ...], grid_mapping: pallas_core.GridMapping, + mesh: pallas_core.Mesh | None, compiler_params: dict[str, Any], cost_estimate: pallas_core.CostEstimate | None, out_avals: tuple[jax_core.AbstractValue, ...], @@ -63,6 +64,7 @@ def pallas_call_lowering( lowering_result = lowering.lower_pipelined_jaxpr_to_module( grid_mapping, + mesh, jaxpr, compiler_params, cost_estimate, diff --git a/jax/_src/pallas/pallas_call.py b/jax/_src/pallas/pallas_call.py index b1e1da34f..d0b74b2e5 100644 --- a/jax/_src/pallas/pallas_call.py +++ b/jax/_src/pallas/pallas_call.py @@ -20,7 +20,7 @@ import dataclasses import enum from functools import partial, reduce import types -from typing import Any, Literal +from typing import Any, Literal, cast import jax from jax import lax @@ -119,6 +119,7 @@ def _pallas_call_jvp_rule( jaxpr: jax_core.Jaxpr, input_output_aliases: tuple[tuple[int, int], ...], grid_mapping: GridMapping, + mesh: pallas_core.Mesh | None, debug: bool, interpret: bool, compiler_params: Any, @@ -133,6 +134,8 @@ def _pallas_call_jvp_rule( raise NotImplementedError if input_output_aliases: raise NotImplementedError("JVP with aliasing not supported.") + if mesh is not None: + raise NotImplementedError("pallas_call with a mesh does not support JVP") nonzero_tangents = [not isinstance(t, ad_util.Zero) for t in tangents] tangents = [t for t in tangents if type(t) is not ad_util.Zero] nonzero_tangents_with_outputs = nonzero_tangents + [True] * grid_mapping.num_outputs @@ -181,6 +184,7 @@ def _pallas_call_jvp_rule( *tangents, jaxpr=jvp_jaxpr, grid_mapping=jvp_grid_mapping, + mesh=mesh, interpret=interpret, debug=debug, input_output_aliases=(), @@ -317,6 +321,7 @@ def _batch_with_explicit_loop( *, jaxpr: jax_core.Jaxpr, grid_mapping: GridMapping, + mesh: pallas_core.Mesh | None, input_output_aliases: tuple[tuple[int, int], ...], debug: bool, interpret: bool, @@ -384,6 +389,7 @@ def _batch_with_explicit_loop( *batch_args, jaxpr=jaxpr, grid_mapping=grid_mapping, + mesh=mesh, input_output_aliases=input_output_aliases, debug=debug, interpret=interpret, @@ -413,6 +419,7 @@ def _pallas_call_batching_rule( *, jaxpr: jax_core.Jaxpr, grid_mapping: GridMapping, + mesh: pallas_core.Mesh | None, input_output_aliases: tuple[tuple[int, int], ...], debug: bool, interpret: bool, @@ -421,6 +428,11 @@ def _pallas_call_batching_rule( out_avals: tuple[jax_core.AbstractValue, ...], 
backend: _Backend | None, ): + if mesh is not None: + raise NotImplementedError( + "pallas_call with a mesh does not support batching" + ) + def _maybe_squeeze_out_bdim( x: jax.Array, bdim: int | batching.NotMapped ) -> jax.Array: @@ -445,6 +457,7 @@ def _pallas_call_batching_rule( *args, jaxpr=jaxpr, grid_mapping=grid_mapping, + mesh=mesh, input_output_aliases=input_output_aliases, debug=debug, interpret=interpret, @@ -478,6 +491,7 @@ def _pallas_call_batching_rule( dims=dynamic_grid_dims + dims, jaxpr=jaxpr, grid_mapping=grid_mapping, + mesh=mesh, input_output_aliases=input_output_aliases, debug=debug, interpret=interpret, @@ -512,6 +526,7 @@ def _pallas_call_batching_rule( dims=scalar_bdims + bdims, jaxpr=jaxpr, grid_mapping=grid_mapping, + mesh=mesh, input_output_aliases=input_output_aliases, debug=debug, interpret=interpret, @@ -890,6 +905,7 @@ def _pallas_call_batching_rule( *args, jaxpr=jaxpr, grid_mapping=batched_grid_mapping, + mesh=mesh, input_output_aliases=input_output_aliases, debug=debug, interpret=interpret, @@ -1339,12 +1355,13 @@ def _pallas_call_state_discharge_rule( jaxpr: jax_core.Jaxpr, input_output_aliases: tuple[tuple[int, int], ...], grid_mapping: GridMapping, + mesh: pallas_core.Mesh | None, debug: bool, interpret: bool, compiler_params: Any, cost_estimate: CostEstimate | None, out_avals: tuple[jax_core.AbstractValue, ...], - backend: _Backend | None = None + backend: _Backend | None = None, ): del avals_out assert all(isinstance(v.aval, state.AbstractRef) for v in jaxpr.constvars) @@ -1440,6 +1457,7 @@ def _pallas_call_state_discharge_rule( jaxpr=new_jaxpr, input_output_aliases=new_input_output_aliases, grid_mapping=new_grid_mapping, + mesh=mesh, debug=debug, interpret=interpret, compiler_params=compiler_params, @@ -1526,16 +1544,6 @@ def pallas_call( invoke the Pallas kernel. """ - if compiler_params is None: - compiler_params = {} - if isinstance(compiler_params, pallas_core.CompilerParams): - if compiler_params.PLATFORM not in ["mosaic", "mosaic_gpu", "triton"]: - raise ValueError( - f"Unknown platform in compiler params: {compiler_params.PLATFORM}") - compiler_params = { - compiler_params.PLATFORM: dataclasses.asdict(compiler_params) - } - if grid_spec is None: grid_spec = GridSpec(grid, in_specs, out_specs, scratch_shapes) else: @@ -1556,6 +1564,55 @@ def pallas_call( "If `grid_spec` is specified, then `scratch_shapes` must " f"be `()`. 
It is {scratch_shapes}") del grid, in_specs, out_specs + return _pallas_call( + kernel, + out_shape, + grid_spec=grid_spec, + input_output_aliases=input_output_aliases, + debug=debug, + interpret=interpret, + name=name, + compiler_params=compiler_params, + cost_estimate=cost_estimate, + backend=backend, + ) + + +def _pallas_call( + kernel: Callable[..., None], + out_shape: Any, + *, + grid_spec: GridSpec, + mesh: pallas_core.Mesh | None = None, + input_output_aliases: dict[int, int] = {}, + debug: bool = False, + interpret: bool = False, + name: str | None = None, + compiler_params: dict[str, Any] | pallas_core.CompilerParams | None = None, + cost_estimate: CostEstimate | None = None, + backend: _Backend | None = None, +): + if compiler_params is None: + compiler_params = {} + if isinstance(compiler_params, pallas_core.CompilerParams): + if compiler_params.PLATFORM not in ["mosaic", "mosaic_gpu", "triton"]: + raise ValueError( + f"Unknown platform in compiler params: {compiler_params.PLATFORM}" + ) + compiler_params = { + compiler_params.PLATFORM: dataclasses.asdict(compiler_params) + } + + if mesh is not None: + if tuple(mesh.shape.values()) != grid_spec.grid: + raise ValueError( + f"Mesh shape {tuple(mesh.shape.values())} does not match grid " + f"shape {grid_spec.grid}." + ) + if backend is not None: + raise ValueError("If `mesh` is specified, then `backend` must be `None`.") + backend = cast(_Backend, mesh.backend) + grid_spec, dynamic_grid_bounds = pallas_core.unzip_dynamic_grid_bounds(grid_spec) # TODO(necula): this canonicalization may be convenient for some usage # but it is lossy, because it prevents expressing functions that return @@ -1643,6 +1700,7 @@ def pallas_call( debug=debug, interpret=interpret, grid_mapping=grid_mapping, + mesh=mesh, input_output_aliases=tuple(input_output_aliases.items()), compiler_params=compiler_params, cost_estimate=cost_estimate, diff --git a/jax/_src/pallas/triton/pallas_call_registration.py b/jax/_src/pallas/triton/pallas_call_registration.py index 4e3bd0697..4e8775e51 100644 --- a/jax/_src/pallas/triton/pallas_call_registration.py +++ b/jax/_src/pallas/triton/pallas_call_registration.py @@ -50,6 +50,7 @@ def pallas_call_lowering( debug: bool, input_output_aliases: tuple[tuple[int, int], ...], grid_mapping: pallas_core.GridMapping, + mesh: pallas_core.Mesh | None, compiler_params: dict[str, Any], cost_estimate: pallas_core.CostEstimate | None, out_avals: tuple[jax_core.AbstractValue, ...], @@ -64,6 +65,8 @@ def pallas_call_lowering( raise NotImplementedError( "scalar prefetch not implemented in the Triton backend" ) + if mesh is not None: + raise NotImplementedError("mesh is not supported in the Triton backend") triton_params = compiler_params.get("triton", compiler_params) num_warps = triton_params.get("num_warps", 4) num_warps = 4 if num_warps is None else num_warps
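Editor's note: taken together, the pallas_call.py changes split the public `pallas_call` wrapper from an internal `_pallas_call` that additionally accepts a `mesh`, checks that the mesh shape matches the grid, derives the backend from `mesh.backend` (rejecting an explicit `backend`), and threads `mesh` into `pallas_call_p`; JVP, batching, and the Triton lowering currently reject a non-None mesh. The standalone sketch below is not part of the patch -- `_FakeMesh` and `check_mesh_against_grid` are hypothetical names that only mirror the validation added to `_pallas_call`:

    import collections

    class _FakeMesh:
      # Hypothetical stand-in for a pallas mesh: a backend plus an ordered shape.
      backend = "mosaic_gpu"

      @property
      def shape(self):
        return collections.OrderedDict([("x", 2), ("y", 2)])

    def check_mesh_against_grid(mesh, grid, backend):
      # Mirrors the checks _pallas_call performs before dispatching.
      if tuple(mesh.shape.values()) != grid:
        raise ValueError(
            f"Mesh shape {tuple(mesh.shape.values())} does not match grid "
            f"shape {grid}.")
      if backend is not None:
        raise ValueError("If `mesh` is specified, then `backend` must be `None`.")
      return mesh.backend  # the backend now comes from the mesh

    assert check_mesh_against_grid(_FakeMesh(), (2, 2), None) == "mosaic_gpu"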